Mirror of https://gitee.com/dify_ai/dify.git (synced 2025-12-07 03:45:27 +08:00)

Compare commits: feat/llm-s...feat/model (82 commits)
| SHA1 |
|---|
| a9bae7aafd |
| 48be8fb6cc |
| dd02a9ac9d |
| b203139356 |
| c479fcf251 |
| d7c3e54eaa |
| d5fe50e471 |
| 205535c8e9 |
| e9aedc701c |
| cf464d252d |
| 5e09ac696c |
| ba9357da96 |
| c6fb879cea |
| e2cb7006c4 |
| 3737e0b087 |
| a1158cc946 |
| 404f8a790c |
| 35a008af18 |
| 79bf590576 |
| 61e39bccdf |
| 6b7dfee88b |
| 21412a8c55 |
| 239e40c8d5 |
| 1ce2c7f3e8 |
| de750a67ec |
| 8e6ea4d117 |
| ef188564f3 |
| 413271eaa6 |
| eb1ce3dd6b |
| 18e4f42c3c |
| e0e92921b5 |
| 94e22ba0fd |
| 67eefd0ba1 |
| bf031af7b1 |
| 617611ee22 |
| d43b884c2a |
| 80f5ee1eb2 |
| ee30497237 |
| be964c78ec |
| 2543162dec |
| 3136eb8e4b |
| 7b6523e54d |
| 30c051d485 |
| f191d372f0 |
| cb69cb2d64 |
| 5d9c67e97e |
| 0ba37592f7 |
| e0e8667a0b |
| 2157d9e17e |
| 62e7fa1f63 |
| 0ac7366cdc |
| c768d97637 |
| 9bd8e62702 |
| 9a3acdcff8 |
| 93c1ee225e |
| 1e32175cdc |
| 00d9f037b5 |
| 44a2eca449 |
| 20df6e9c00 |
| 7ba3e599d2 |
| 4247a6b807 |
| 775dc47abe |
| da9269ca97 |
| d2e3744ca3 |
| 3914cf07e7 |
| 1e7418095f |
| efe5db38ee |
| 523efbfea5 |
| b96ecd072a |
| 28ffe7e3db |
| 721294948c |
| b287aaccec |
| bbc6efd773 |
| dc9c5a4bc7 |
| e90c532c3a |
| 397e2a8522 |
| 8f547e6340 |
| defd5520ea |
| b6b608219a |
| 22a1bc337f |
| caa179a1d3 |
| e8e47aee21 |
@@ -6,7 +6,7 @@

 本指南和 Dify 一样在不断完善中。如果有任何滞后于项目实际情况的地方,恳请谅解,我们也欢迎任何改进建议。

-关于许可证,请花一分钟阅读我们简短的[许可和贡献者协议](./LICENSE)。社区同时也遵循[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
+关于许可证,请花一分钟阅读我们简短的[许可和贡献者协议](./LICENSE)。同时也请遵循社区[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。

 ## 开始之前
@@ -8,7 +8,7 @@
 <a href="https://cloud.dify.ai">Dify Cloud</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
 <a href="https://docs.dify.ai">Documentation</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Enterprise inquiry</a>
+<a href="https://dify.ai/pricing">Dify edition overview</a>
 </p>

 <p align="center">
@@ -4,7 +4,7 @@
 <a href="https://cloud.dify.ai">Dify Cloud</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">الاستضافة الذاتية</a> ·
 <a href="https://docs.dify.ai">التوثيق</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">استفسار الشركات (للإنجليزية فقط)</a>
+<a href="https://dify.ai/pricing">نظرة عامة على منتجات Dify</a>
 </p>

 <p align="center">
@@ -8,7 +8,7 @@
 <a href="https://cloud.dify.ai">ডিফাই ক্লাউড</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">সেল্ফ-হোস্টিং</a> ·
 <a href="https://docs.dify.ai">ডকুমেন্টেশন</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">ব্যাবসায়িক অনুসন্ধান</a>
+<a href="https://dify.ai/pricing">Dify পণ্যের রূপভেদ</a>
 </p>

 <p align="center">
@@ -4,7 +4,7 @@
 <a href="https://cloud.dify.ai">Dify 云服务</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">自托管</a> ·
 <a href="https://docs.dify.ai">文档</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">(需用英文)常见问题解答 / 联系团队</a>
+<a href="https://dify.ai/pricing">Dify 产品形态总览</a>
 </div>

 <p align="center">
@@ -8,7 +8,7 @@
 <a href="https://cloud.dify.ai">Dify Cloud</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">Selbstgehostetes</a> ·
 <a href="https://docs.dify.ai">Dokumentation</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Anfrage an Unternehmen</a>
+<a href="https://dify.ai/pricing">Überblick über die Dify-Produkte</a>
 </p>

 <p align="center">
@@ -4,7 +4,7 @@
 <a href="https://cloud.dify.ai">Dify Cloud</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">Auto-alojamiento</a> ·
 <a href="https://docs.dify.ai">Documentación</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Consultas empresariales (en inglés)</a>
+<a href="https://dify.ai/pricing">Resumen de las ediciones de Dify</a>
 </p>

 <p align="center">
@@ -4,7 +4,7 @@
 <a href="https://cloud.dify.ai">Dify Cloud</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">Auto-hébergement</a> ·
 <a href="https://docs.dify.ai">Documentation</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Demande d’entreprise (en anglais seulement)</a>
+<a href="https://dify.ai/pricing">Présentation des différentes offres Dify</a>
 </p>

 <p align="center">
@@ -4,7 +4,7 @@
 <a href="https://cloud.dify.ai">Dify Cloud</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">セルフホスティング</a> ·
 <a href="https://docs.dify.ai">ドキュメント</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">企業のお問い合わせ(英語のみ)</a>
+<a href="https://dify.ai/pricing">Difyの各種エディションについて</a>
 </p>

 <p align="center">
@@ -4,7 +4,7 @@
 <a href="https://cloud.dify.ai">Dify Cloud</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
 <a href="https://docs.dify.ai">Documentation</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Commercial enquiries</a>
+<a href="https://dify.ai/pricing">Dify product editions</a>
 </p>

 <p align="center">
@@ -4,7 +4,7 @@
 <a href="https://cloud.dify.ai">Dify 클라우드</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">셀프-호스팅</a> ·
 <a href="https://docs.dify.ai">문서</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">기업 문의 (영어만 가능)</a>
+<a href="https://dify.ai/pricing">Dify 제품 에디션 안내</a>
 </p>

 <p align="center">
@@ -8,7 +8,7 @@
 <a href="https://cloud.dify.ai">Dify Cloud</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">Auto-hospedagem</a> ·
 <a href="https://docs.dify.ai">Documentação</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Consultas empresariais</a>
+<a href="https://dify.ai/pricing">Visão geral das edições do Dify</a>
 </p>

 <p align="center">
@@ -8,7 +8,7 @@
 <a href="https://cloud.dify.ai">Dify Cloud</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">Samostojno gostovanje</a> ·
 <a href="https://docs.dify.ai">Dokumentacija</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Povpraševanje za podjetja</a>
+<a href="https://dify.ai/pricing">Pregled ponudb izdelkov Dify</a>
 </p>

 <p align="center">
@@ -4,7 +4,7 @@
 <a href="https://cloud.dify.ai">Dify Bulut</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">Kendi Sunucunuzda Barındırma</a> ·
 <a href="https://docs.dify.ai">Dokümantasyon</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Yalnızca İngilizce: Kurumsal Sorgulama</a>
+<a href="https://dify.ai/pricing">Dify ürün seçeneklerine genel bakış</a>
 </p>

 <p align="center">
@@ -8,7 +8,7 @@
 <a href="https://cloud.dify.ai">Dify 雲端服務</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">自行託管</a> ·
 <a href="https://docs.dify.ai">說明文件</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">企業諮詢</a>
+<a href="https://dify.ai/pricing">產品方案概覽</a>
 </p>

 <p align="center">
@@ -4,7 +4,7 @@
 <a href="https://cloud.dify.ai">Dify Cloud</a> ·
 <a href="https://docs.dify.ai/getting-started/install-self-hosted">Tự triển khai</a> ·
 <a href="https://docs.dify.ai">Tài liệu</a> ·
-<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">Yêu cầu doanh nghiệp</a>
+<a href="https://dify.ai/pricing">Tổng quan các lựa chọn sản phẩm Dify</a>
 </p>

 <p align="center">
@@ -482,4 +482,7 @@ OTEL_MAX_QUEUE_SIZE=2048
 OTEL_MAX_EXPORT_BATCH_SIZE=512
 OTEL_METRIC_EXPORT_INTERVAL=60000
 OTEL_BATCH_EXPORT_TIMEOUT=10000
 OTEL_METRIC_EXPORT_TIMEOUT=30000
+
+# Prevent Clickjacking
+ALLOW_EMBED=false
@@ -52,6 +52,7 @@ def initialize_extensions(app: DifyApp):
         ext_mail,
         ext_migrate,
         ext_otel,
+        ext_otel_patch,
         ext_proxy_fix,
         ext_redis,
         ext_repositories,
@@ -84,6 +85,7 @@ def initialize_extensions(app: DifyApp):
         ext_proxy_fix,
         ext_blueprints,
         ext_commands,
+        ext_otel_patch,  # Apply patch before initializing OpenTelemetry
         ext_otel,
     ]
     for ext in extensions:
@@ -13,6 +13,7 @@ from .observability import ObservabilityConfig
 from .packaging import PackagingInfo
 from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
 from .remote_settings_sources.apollo import ApolloSettingsSource
+from .remote_settings_sources.nacos import NacosSettingsSource

 logger = logging.getLogger(__name__)
@@ -34,6 +35,8 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
         match remote_source_name:
             case RemoteSettingsSourceName.APOLLO:
                 remote_source = ApolloSettingsSource(current_state)
+            case RemoteSettingsSourceName.NACOS:
+                remote_source = NacosSettingsSource(current_state)
             case _:
                 logger.warning(f"Unsupported remote source: {remote_source_name}")
                 return {}
@@ -22,6 +22,7 @@ from .vdb.baidu_vector_config import BaiduVectorDBConfig
 from .vdb.chroma_config import ChromaConfig
 from .vdb.couchbase_config import CouchbaseConfig
 from .vdb.elasticsearch_config import ElasticsearchConfig
+from .vdb.huawei_cloud_config import HuaweiCloudConfig
 from .vdb.lindorm_config import LindormConfig
 from .vdb.milvus_config import MilvusConfig
 from .vdb.myscale_config import MyScaleConfig
@@ -263,6 +264,7 @@ class MiddlewareConfig(
     VectorStoreConfig,
     AnalyticdbConfig,
     ChromaConfig,
+    HuaweiCloudConfig,
     MilvusConfig,
     MyScaleConfig,
     OpenSearchConfig,
api/configs/middleware/vdb/huawei_cloud_config.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class HuaweiCloudConfig(BaseSettings):
    """
    Configuration settings for Huawei cloud search service
    """

    HUAWEI_CLOUD_HOSTS: Optional[str] = Field(
        description="Hostname or IP address of the Huawei cloud search service instance",
        default=None,
    )

    HUAWEI_CLOUD_USER: Optional[str] = Field(
        description="Username for authenticating with Huawei cloud search service",
        default=None,
    )

    HUAWEI_CLOUD_PASSWORD: Optional[str] = Field(
        description="Password for authenticating with Huawei cloud search service",
        default=None,
    )
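For context, pydantic-settings populates each field above from an identically named environment variable. A minimal usage sketch; the host and credentials below are placeholders, not defaults from the diff:

```python
# Sketch: HuaweiCloudConfig reads its fields from the environment.
import os

os.environ["HUAWEI_CLOUD_HOSTS"] = "127.0.0.1:9200"  # placeholder host
os.environ["HUAWEI_CLOUD_USER"] = "admin"            # placeholder user
os.environ["HUAWEI_CLOUD_PASSWORD"] = "changeme"     # placeholder password

config = HuaweiCloudConfig()
print(config.HUAWEI_CLOUD_HOSTS)  # -> 127.0.0.1:9200
```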
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="1.2.0",
+        default="1.3.0",
     )

     COMMIT_SHA: str = Field(
@@ -3,3 +3,4 @@ from enum import StrEnum

 class RemoteSettingsSourceName(StrEnum):
     APOLLO = "apollo"
+    NACOS = "nacos"
api/configs/remote_settings_sources/nacos/__init__.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import logging
import os
from collections.abc import Mapping
from typing import Any

from pydantic.fields import FieldInfo

from .http_request import NacosHttpClient

logger = logging.getLogger(__name__)

from configs.remote_settings_sources.base import RemoteSettingsSource

from .utils import _parse_config


class NacosSettingsSource(RemoteSettingsSource):
    def __init__(self, configs: Mapping[str, Any]):
        self.configs = configs
        self.remote_configs: dict[str, Any] = {}
        self.async_init()

    def async_init(self):
        data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
        group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
        tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")

        params = {"dataId": data_id, "group": group, "tenant": tenant}
        try:
            content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
            self.remote_configs = self._parse_config(content)
        except Exception as e:
            logger.exception("[get-access-token] exception occurred")
            raise

    def _parse_config(self, content: str) -> dict:
        if not content:
            return {}
        try:
            return _parse_config(self, content)
        except Exception as e:
            raise RuntimeError(f"Failed to parse config: {e}")

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        if not isinstance(self.remote_configs, dict):
            raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")

        field_value = self.remote_configs.get(field_name)
        if field_value is None:
            return None, field_name, False

        return field_value, field_name, False
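For context, the settings source above is driven entirely by environment variables and fetches its payload synchronously during `__init__` (despite the `async_init` name). A hypothetical wiring sketch; it needs a reachable Nacos server, and `CONSOLE_WEB_URL` is just an example key:

```python
# Hypothetical setup sketch; assumes a Nacos server answers on localhost:8848.
import os

os.environ["DIFY_ENV_NACOS_SERVER_ADDR"] = "localhost:8848"
os.environ["DIFY_ENV_NACOS_DATA_ID"] = "dify-api-env.properties"
os.environ["DIFY_ENV_NACOS_GROUP"] = "nacos-dify"

source = NacosSettingsSource({})  # fetches and parses the remote config
value, key, _ = source.get_field_value(None, "CONSOLE_WEB_URL")
print(key, "=", value)  # value is None when the key is absent
```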
api/configs/remote_settings_sources/nacos/http_request.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import base64
import hashlib
import hmac
import logging
import os
import time

import requests

logger = logging.getLogger(__name__)


class NacosHttpClient:
    def __init__(self):
        self.username = os.getenv("DIFY_ENV_NACOS_USERNAME")
        self.password = os.getenv("DIFY_ENV_NACOS_PASSWORD")
        self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
        self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
        self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
        self.token = None
        self.token_ttl = 18000
        self.token_expire_time: float = 0

    def http_request(self, url, method="GET", headers=None, params=None):
        try:
            self._inject_auth_info(headers, params)
            response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
            response.raise_for_status()
            return response.text
        except requests.exceptions.RequestException as e:
            return f"Request to Nacos failed: {e}"

    def _inject_auth_info(self, headers, params, module="config"):
        headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})

        if module == "login":
            return

        ts = str(int(time.time() * 1000))

        if self.ak and self.sk:
            sign_str = self.get_sign_str(params["group"], params["tenant"], ts)
            headers["Spas-AccessKey"] = self.ak
            headers["Spas-Signature"] = self.__do_sign(sign_str, self.sk)
            headers["timeStamp"] = ts
        if self.username and self.password:
            self.get_access_token(force_refresh=False)
            params["accessToken"] = self.token

    def __do_sign(self, sign_str, sk):
        return (
            base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
            .decode()
            .strip()
        )

    def get_sign_str(self, group, tenant, ts):
        sign_str = ""
        if tenant:
            sign_str = tenant + "+"
        if group:
            sign_str = sign_str + group + "+"
        if sign_str:
            sign_str += ts
        return sign_str

    def get_access_token(self, force_refresh=False):
        current_time = time.time()
        if self.token and not force_refresh and self.token_expire_time > current_time:
            return self.token

        params = {"username": self.username, "password": self.password}
        url = "http://" + self.server + "/nacos/v1/auth/login"
        try:
            resp = requests.request("POST", url, headers=None, params=params)
            resp.raise_for_status()
            response_data = resp.json()
            self.token = response_data.get("accessToken")
            self.token_ttl = response_data.get("tokenTtl", 18000)
            self.token_expire_time = current_time + self.token_ttl - 10
        except Exception as e:
            logger.exception("[get-access-token] exception occur")
            raise
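For reference, the Spas signature produced by `get_sign_str` and `__do_sign` above is a base64-encoded HMAC-SHA1 over `tenant+group+timestamp`. A standalone sketch with placeholder values:

```python
# Standalone sketch of the signing scheme above; keys are placeholders.
import base64
import hashlib
import hmac
import time

ts = str(int(time.time() * 1000))  # millisecond timestamp, as in the client
sign_str = "my-namespace" + "+" + "nacos-dify" + "+" + ts  # tenant+group+ts
secret_key = "demo-secret-key"  # placeholder secret key

signature = (
    base64.encodebytes(hmac.new(secret_key.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
    .decode()
    .strip()
)
print(signature)  # sent as the Spas-Signature header
```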
api/configs/remote_settings_sources/nacos/utils.py (new file, 31 lines)
@@ -0,0 +1,31 @@
def _parse_config(self, content: str) -> dict[str, str]:
    config: dict[str, str] = {}
    if not content:
        return config

    for line in content.splitlines():
        cleaned_line = line.strip()
        if not cleaned_line or cleaned_line.startswith(("#", "!")):
            continue

        separator_index = -1
        for i, c in enumerate(cleaned_line):
            if c in ("=", ":") and (i == 0 or cleaned_line[i - 1] != "\\"):
                separator_index = i
                break

        if separator_index == -1:
            continue

        key = cleaned_line[:separator_index].strip()
        raw_value = cleaned_line[separator_index + 1 :].strip()

        try:
            decoded_value = bytes(raw_value, "utf-8").decode("unicode_escape")
            decoded_value = decoded_value.replace(r"\=", "=").replace(r"\:", ":")
        except UnicodeDecodeError:
            decoded_value = raw_value

        config[key] = decoded_value

    return config
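The helper above implements a subset of Java `.properties` semantics: `#`/`!` comments, `=` or `:` separators, and backslash escapes. A usage sketch (the `self` parameter is unused by the function body, so `None` is passed):

```python
# Usage sketch for _parse_config above.
sample = """
# comments and blank lines are skipped
db.host=127.0.0.1
db.port: 5432
greeting=hello\\u0020world
url=https\\://example.com
"""

config = _parse_config(None, sample)
assert config["db.host"] == "127.0.0.1"
assert config["db.port"] == "5432"              # ':' also works as a separator
assert config["greeting"] == "hello world"      # \\u0020 decoded to a space
assert config["url"] == "https://example.com"   # escaped ':' restored after decoding
```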
@@ -80,8 +80,6 @@ class ChatMessageTextApi(Resource):
     @account_initialization_required
     @get_app_model
     def post(self, app_model: App):
-        from werkzeug.exceptions import InternalServerError
-
         try:
             parser = reqparse.RequestParser()
             parser.add_argument("message_id", type=str, location="json")
@@ -85,5 +85,35 @@ class RuleCodeGenerateApi(Resource):
         return code_result


+class RuleStructuredOutputGenerateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+
+        account = current_user
+        try:
+            structured_output = LLMGenerator.generate_structured_output(
+                tenant_id=account.current_tenant_id,
+                instruction=args["instruction"],
+                model_config=args["model_config"],
+            )
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+
+        return structured_output
+
+
 api.add_resource(RuleGenerateApi, "/rule-generate")
 api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
+api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
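A hypothetical client-side sketch of the new endpoint; the `/console/api` prefix and the auth header are assumptions about the console API, not shown in the diff:

```python
# Hypothetical request to the new rule-structured-output-generate endpoint.
import requests

resp = requests.post(
    "http://localhost:5001/console/api/rule-structured-output-generate",
    json={
        "instruction": "I need name and age",
        "model_config": {"provider": "openai", "name": "gpt-4o-mini", "model_parameters": {}},
    },
    headers={"Authorization": "Bearer <console-session-token>"},  # placeholder auth
)
print(resp.json())  # e.g. {"output": "<json schema string>", "error": ""}
```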
@@ -74,7 +74,9 @@ class OAuthDataSourceBinding(Resource):
         if not oauth_provider:
             return {"error": "Invalid provider"}, 400
         if "code" in request.args:
-            code = request.args.get("code")
+            code = request.args.get("code", "")
+            if not code:
+                return {"error": "Invalid code"}, 400
             try:
                 oauth_provider.get_access_token(code)
             except requests.exceptions.HTTPError as e:
@@ -16,7 +16,7 @@ from controllers.console.auth.error import (
     PasswordMismatchError,
 )
 from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
-from controllers.console.wraps import setup_required
+from controllers.console.wraps import email_password_login_enabled, setup_required
 from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
 from libs.helper import email, extract_remote_ip

@@ -30,6 +30,7 @@ from services.feature_service import FeatureService

 class ForgotPasswordSendEmailApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=email, required=True, location="json")
@@ -62,6 +63,7 @@

 class ForgotPasswordCheckApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=str, required=True, location="json")
@@ -86,12 +88,21 @@ class ForgotPasswordCheckApi(Resource):
             AccountService.add_forgot_password_error_rate_limit(args["email"])
             raise EmailCodeError()

+        # Verified, revoke the first token
+        AccountService.revoke_reset_password_token(args["token"])
+
+        # Refresh token data by generating a new token
+        _, new_token = AccountService.generate_reset_password_token(
+            user_email, code=args["code"], additional_data={"phase": "reset"}
+        )
+
         AccountService.reset_forgot_password_error_rate_limit(args["email"])
-        return {"is_valid": True, "email": token_data.get("email")}
+        return {"is_valid": True, "email": token_data.get("email"), "token": new_token}


 class ForgotPasswordResetApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("token", type=str, required=True, nullable=False, location="json")
@@ -107,6 +118,9 @@ class ForgotPasswordResetApi(Resource):
         reset_data = AccountService.get_reset_password_data(args["token"])
         if not reset_data:
             raise InvalidTokenError()
+        # Must use token in reset phase
+        if reset_data.get("phase", "") != "reset":
+            raise InvalidTokenError()

         # Revoke token to prevent reuse
         AccountService.revoke_reset_password_token(args["token"])
@@ -22,7 +22,7 @@ from controllers.console.error import (
     EmailSendIpLimitError,
     NotAllowedCreateWorkspace,
 )
-from controllers.console.wraps import setup_required
+from controllers.console.wraps import email_password_login_enabled, setup_required
 from events.tenant_event import tenant_was_created
 from libs.helper import email, extract_remote_ip
 from libs.password import valid_password

@@ -38,6 +38,7 @@ class LoginApi(Resource):
     """Resource for user login."""

     @setup_required
+    @email_password_login_enabled
     def post(self):
         """Authenticate user and login."""
         parser = reqparse.RequestParser()
@@ -110,6 +111,7 @@ class LogoutApi(Resource):

 class ResetPasswordSendEmailApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=email, required=True, location="json")
@@ -664,6 +664,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.OPENGAUSS
                 | VectorType.OCEANBASE
                 | VectorType.TABLESTORE
+                | VectorType.HUAWEI_CLOUD
                 | VectorType.TENCENT
             ):
                 return {

@@ -710,6 +711,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.OCEANBASE
                 | VectorType.TABLESTORE
                 | VectorType.TENCENT
+                | VectorType.HUAWEI_CLOUD
             ):
                 return {
                     "retrieval_method": [
@@ -5,6 +5,7 @@ from werkzeug.exceptions import Forbidden
 from controllers.console import api
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.plugin.manager.exc import PluginPermissionDeniedError
 from libs.login import login_required
 from services.plugin.endpoint_service import EndpointService

@@ -28,15 +29,18 @@ class EndpointCreateApi(Resource):
         settings = args["settings"]
         name = args["name"]

-        return {
-            "success": EndpointService.create_endpoint(
-                tenant_id=user.current_tenant_id,
-                user_id=user.id,
-                plugin_unique_identifier=plugin_unique_identifier,
-                name=name,
-                settings=settings,
-            )
-        }
+        try:
+            return {
+                "success": EndpointService.create_endpoint(
+                    tenant_id=user.current_tenant_id,
+                    user_id=user.id,
+                    plugin_unique_identifier=plugin_unique_identifier,
+                    name=name,
+                    settings=settings,
+                )
+            }
+        except PluginPermissionDeniedError as e:
+            raise ValueError(e.description) from e


 class EndpointListApi(Resource):
@@ -210,3 +210,16 @@ def enterprise_license_required(view):
         return view(*args, **kwargs)

     return decorated
+
+
+def email_password_login_enabled(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        features = FeatureService.get_system_features()
+        if features.enable_email_password_login:
+            return view(*args, **kwargs)
+
+        # otherwise, return 403
+        abort(403)
+
+    return decorated
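The decorator can guard any console view; a minimal sketch of applying it to a new resource, assuming the same flask_restful setup used by the controllers above:

```python
# Minimal sketch: decorator order mirrors the login/forgot-password views.
from flask_restful import Resource

class ExampleProtectedApi(Resource):
    @setup_required                # platform-setup check runs first
    @email_password_login_enabled  # then 403 unless the feature flag is on
    def post(self):
        return {"result": "ok"}
```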
@@ -1,3 +1,5 @@
+from mimetypes import guess_extension
+
 from flask import request
 from flask_restful import Resource, marshal_with  # type: ignore
 from werkzeug.exceptions import Forbidden

@@ -9,8 +11,8 @@ from controllers.files.error import UnsupportedFileTypeError
 from controllers.inner_api.plugin.wraps import get_user
 from controllers.service_api.app.error import FileTooLargeError
 from core.file.helpers import verify_plugin_file_signature
+from core.tools.tool_file_manager import ToolFileManager
 from fields.file_fields import file_fields
-from services.file_service import FileService


 class PluginUploadFileApi(Resource):

@@ -51,19 +53,26 @@ class PluginUploadFileApi(Resource):
             raise Forbidden("Invalid request.")

         try:
-            upload_file = FileService.upload_file(
-                filename=filename,
-                content=file.read(),
+            tool_file = ToolFileManager.create_file_by_raw(
+                user_id=user.id,
+                tenant_id=tenant_id,
+                file_binary=file.read(),
                 mimetype=mimetype,
-                user=user,
-                source=None,
+                filename=filename,
+                conversation_id=None,
             )
+
+            extension = guess_extension(tool_file.mimetype) or ".bin"
+            preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
+            tool_file.mime_type = mimetype
+            tool_file.extension = extension
+            tool_file.preview_url = preview_url
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()

-        return upload_file, 201
+        return tool_file, 201


 api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")
@@ -46,6 +46,7 @@ class MessageListApi(WebApiResource):
         "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
         "created_at": TimestampField,
         "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
+        "metadata": fields.Raw(attribute="message_metadata_dict"),
         "status": fields.String,
         "error": fields.String,
     }
@@ -21,14 +21,13 @@ from core.model_runtime.entities import (
     AssistantPromptMessage,
     LLMUsage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.utils.extract_thread_messages import extract_thread_messages

@@ -501,7 +500,7 @@ class BaseAgentRunner(AppRunner):
         )
         if not file_objs:
             return UserPromptMessage(content=message.query)
-        prompt_message_contents: list[PromptMessageContent] = []
+        prompt_message_contents: list[PromptMessageContentUnionTypes] = []
         prompt_message_contents.append(TextPromptMessageContent(data=message.query))
         for file in file_objs:
             prompt_message_contents.append(
@@ -5,12 +5,11 @@ from core.file import file_manager
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
-    PromptMessageContent,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.model_runtime.utils.encoders import jsonable_encoder

@@ -40,7 +39,7 @@ class CotChatAgentRunner(CotAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=query))

             # get image detail config
@@ -15,14 +15,13 @@ from core.model_runtime.entities import (
     LLMResultChunkDelta,
     LLMUsage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageContentType,
     SystemPromptMessage,
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine

@@ -395,7 +394,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=query))

             # get image detail config
@@ -6,7 +6,6 @@ from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
 from models.dataset import Document as DatasetDocument
-from models.model import DatasetRetrieverResource


 class DatasetIndexToolCallbackHandler:

@@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler:

     def return_retriever_resource_info(self, resource: list):
         """Handle return_retriever_resource_info."""
-        if resource and len(resource) > 0:
-            for item in resource:
-                dataset_retriever_resource = DatasetRetrieverResource(
-                    message_id=self._message_id,
-                    position=item.get("position") or 0,
-                    dataset_id=item.get("dataset_id"),
-                    dataset_name=item.get("dataset_name"),
-                    document_id=item.get("document_id"),
-                    document_name=item.get("document_name"),
-                    data_source_type=item.get("data_source_type"),
-                    segment_id=item.get("segment_id"),
-                    score=item.get("score") if "score" in item else None,
-                    hit_count=item.get("hit_count") if "hit_count" in item else None,
-                    word_count=item.get("word_count") if "word_count" in item else None,
-                    segment_position=item.get("segment_position") if "segment_position" in item else None,
-                    index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
-                    content=item.get("content"),
-                    retriever_from=item.get("retriever_from"),
-                    created_by=self._user_id,
-                )
-                db.session.add(dataset_retriever_resource)
-            db.session.commit()
-
         self._queue_manager.publish(
             QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
         )
@@ -7,9 +7,9 @@ from core.model_runtime.entities import (
     AudioPromptMessageContent,
     DocumentPromptMessageContent,
     ImagePromptMessageContent,
-    MultiModalPromptMessageContent,
     VideoPromptMessageContent,
 )
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
 from extensions.ext_storage import storage

 from . import helpers

@@ -43,7 +43,7 @@ def to_prompt_message_content(
     /,
     *,
     image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
-) -> MultiModalPromptMessageContent:
+) -> PromptMessageContentUnionTypes:
     if f.extension is None:
         raise ValueError("Missing file extension")
     if f.mime_type is None:

@@ -58,7 +58,7 @@ def to_prompt_message_content(
     if f.type == FileType.IMAGE:
         params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

-    prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = {
+    prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
         FileType.IMAGE: ImagePromptMessageContent,
         FileType.AUDIO: AudioPromptMessageContent,
         FileType.VIDEO: VideoPromptMessageContent,
@@ -10,6 +10,7 @@ from core.llm_generator.prompts import (
     GENERATOR_QA_PROMPT,
     JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
     PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
+    SYSTEM_STRUCTURED_OUTPUT_GENERATE,
     WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
 )
 from core.model_manager import ModelManager
@@ -340,3 +341,37 @@ class LLMGenerator:

         answer = cast(str, response.message.content)
         return answer.strip()
+
+    @classmethod
+    def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
+            provider=model_config.get("provider", ""),
+            model=model_config.get("name", ""),
+        )
+
+        prompt_messages = [
+            SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
+            UserPromptMessage(content=instruction),
+        ]
+        model_parameters = model_config.get("model_parameters", {})
+
+        try:
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+                ),
+            )
+
+            generated_json_schema = cast(str, response.message.content)
+            return {"output": generated_json_schema, "error": ""}
+
+        except InvokeError as e:
+            error = str(e)
+            return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
+        except Exception as e:
+            logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}")
+            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
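A hypothetical caller sketch for the new method; provider and model names are placeholders:

```python
# Sketch: generate a JSON Schema from a plain-language description.
result = LLMGenerator.generate_structured_output(
    tenant_id="tenant-123",  # placeholder tenant
    instruction="I need name and age",
    model_config={
        "provider": "openai",   # placeholder provider
        "name": "gpt-4o-mini",  # placeholder model
        "model_parameters": {"temperature": 0},
    },
)
if result["error"]:
    print("generation failed:", result["error"])
else:
    print(result["output"])  # raw JSON Schema string, no markdown fences
```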
@@ -220,3 +220,110 @@ Here is the task description: {{INPUT_TEXT}}

 You just need to generate the output
 """  # noqa: E501
+
+SYSTEM_STRUCTURED_OUTPUT_GENERATE = """
+Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements.
+
+## Instructions:
+
+1. Analyze the user's description of their data needs
+2. Identify each property that should be included in the schema
+3. Determine the appropriate data type for each property
+4. Decide which properties should be required
+5. Generate a complete JSON Schema with proper syntax
+6. Include appropriate constraints when specified (min/max values, patterns, formats)
+7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting.
+8. DO NOT use markdown code blocks (``` or ``` json). Return the raw JSON Schema directly.
+
+## Examples:
+
+### Example 1:
+**User Input:** I need name and age
+**JSON Schema Output:**
+{
+    "type": "object",
+    "properties": {
+        "name": { "type": "string" },
+        "age": { "type": "number" }
+    },
+    "required": ["name", "age"]
+}
+
+### Example 2:
+**User Input:** I want to store information about books including title, author, publication year and optional page count
+**JSON Schema Output:**
+{
+    "type": "object",
+    "properties": {
+        "title": { "type": "string" },
+        "author": { "type": "string" },
+        "publicationYear": { "type": "integer" },
+        "pageCount": { "type": "integer" }
+    },
+    "required": ["title", "author", "publicationYear"]
+}
+
+### Example 3:
+**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18)
+**JSON Schema Output:**
+{
+    "type": "object",
+    "properties": {
+        "email": {
+            "type": "string",
+            "format": "email"
+        },
+        "password": {
+            "type": "string",
+            "minLength": 8
+        },
+        "age": {
+            "type": "integer",
+            "minimum": 18
+        }
+    },
+    "required": ["email", "password", "age"]
+}
+
+### Example 4:
+**User Input:** I need an album schema; the album has songs, and each song has name, duration, and artist.
+**JSON Schema Output:**
+{
+    "type": "object",
+    "properties": {
+        "songs": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "name": {
+                        "type": "string"
+                    },
+                    "id": {
+                        "type": "string"
+                    },
+                    "duration": {
+                        "type": "string"
+                    },
+                    "artist": {
+                        "type": "string"
+                    }
+                },
+                "required": [
+                    "name",
+                    "id",
+                    "duration",
+                    "artist"
+                ]
+            }
+        }
+    },
+    "required": [
+        "songs"
+    ]
+}
+
+Now, generate a JSON Schema based on my description
+"""  # noqa: E501
api/core/memory/__init__.py (new empty file)

api/core/memory/base_memory.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from abc import abstractmethod
from collections.abc import Sequence
from typing import Optional

from core.model_runtime.entities import (
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageRole,
    TextPromptMessageContent,
)


class BaseMemory:
    @abstractmethod
    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
    ) -> Sequence[PromptMessage]:
        """
        Get history prompt messages.
        :param max_token_limit: max token limit
        :param message_limit: message limit
        :return:
        """

    def get_history_prompt_text(
        self,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: Optional[int] = None,
    ) -> str:
        """
        Get history prompt text.
        :param human_prefix: human prefix
        :param ai_prefix: ai prefix
        :param max_token_limit: max token limit
        :param message_limit: message limit
        :return:
        """
        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)

        string_messages = []
        for m in prompt_messages:
            if m.role == PromptMessageRole.USER:
                role = human_prefix
            elif m.role == PromptMessageRole.ASSISTANT:
                role = ai_prefix
            else:
                continue

            if isinstance(m.content, list):
                inner_msg = ""
                for content in m.content:
                    if isinstance(content, TextPromptMessageContent):
                        inner_msg += f"{content.data}\n"
                    elif isinstance(content, ImagePromptMessageContent):
                        inner_msg += "[image]\n"

                string_messages.append(f"{role}: {inner_msg.strip()}")
            else:
                message = f"{role}: {m.content}"
                string_messages.append(message)

        return "\n".join(string_messages)
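An illustrative subclass, not part of the diff, showing the template method at work: `get_history_prompt_text` is inherited, and only `get_history_prompt_messages` needs to be supplied:

```python
# Illustrative fixed-list memory exercising the inherited text formatter.
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage

class StaticMemory(BaseMemory):
    def get_history_prompt_messages(self, max_token_limit=2000, message_limit=None):
        return [
            UserPromptMessage(content="Hello"),
            AssistantPromptMessage(content="Hi, how can I help?"),
        ]

print(StaticMemory().get_history_prompt_text())
# Human: Hello
# Assistant: Hi, how can I help?
```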
api/core/memory/model_context_memory.py (new file, 200 lines)
@@ -0,0 +1,200 @@
import json
from collections.abc import Sequence
from typing import Optional, cast

from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.memory.base_memory import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentUnionTypes,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import LLMMemoryType
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun


class ModelContextMemory(BaseMemory):
    def __init__(self, conversation: Conversation, node_id: str, model_instance: ModelInstance) -> None:
        self.conversation = conversation
        self.node_id = node_id
        self.model_instance = model_instance

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
    ) -> Sequence[PromptMessage]:
        """
        Get history prompt messages.
        :param max_token_limit: max token limit
        :param message_limit: message limit
        """
        thread_messages = list(reversed(self._fetch_thread_messages(message_limit)))
        if not thread_messages:
            return []
        # Get all required workflow_run_ids
        workflow_run_ids = [msg.workflow_run_id for msg in thread_messages]

        # Batch query all related WorkflowNodeExecution records
        node_executions = (
            db.session.query(WorkflowNodeExecution)
            .filter(
                WorkflowNodeExecution.workflow_run_id.in_(workflow_run_ids),
                WorkflowNodeExecution.node_id == self.node_id,
                WorkflowNodeExecution.status.in_(
                    [WorkflowNodeExecutionStatus.SUCCEEDED, WorkflowNodeExecutionStatus.EXCEPTION]
                ),
            )
            .all()
        )

        # Create mapping from workflow_run_id to node_execution
        node_execution_map = {ne.workflow_run_id: ne for ne in node_executions}

        # Get the last node_execution
        last_node_execution = node_execution_map.get(thread_messages[-1].workflow_run_id)
        prompt_messages = self._get_prompt_messages_in_process_data(last_node_execution)

        # Batch query all message-related files
        message_ids = [msg.id for msg in thread_messages]
        all_files = db.session.query(MessageFile).filter(MessageFile.message_id.in_(message_ids)).all()

        # Create mapping from message_id to files
        files_map = {}
        for file in all_files:
            if file.message_id not in files_map:
                files_map[file.message_id] = []
            files_map[file.message_id].append(file)

        for message in thread_messages:
            files = files_map.get(message.id, [])
            node_execution = node_execution_map.get(message.workflow_run_id)
            if node_execution and files:
                file_objs, detail = self._handle_file(message, files)
                if file_objs:
                    outputs = node_execution.outputs_dict.get("text", "") if node_execution.outputs_dict else ""
                    if not outputs:
                        continue
                    if outputs not in [prompt.content for prompt in prompt_messages]:
                        continue
                    outputs_index = [prompt.content for prompt in prompt_messages].index(outputs)
                    prompt_index = outputs_index - 1
                    prompt_message_contents: list[PromptMessageContentUnionTypes] = []
                    content = cast(str, prompt_messages[prompt_index].content)
                    prompt_message_contents.append(TextPromptMessageContent(data=content))
                    for file in file_objs:
                        prompt_message = file_manager.to_prompt_message_content(
                            file,
                            image_detail_config=detail,
                        )
                        prompt_message_contents.append(prompt_message)
                    prompt_messages[prompt_index].content = prompt_message_contents
        return prompt_messages

    def _get_prompt_messages_in_process_data(
        self,
        node_execution: WorkflowNodeExecution,
    ) -> list[PromptMessage]:
        """
        Get prompt messages in process data.
        :param node_execution: node execution
        :return: prompt messages
        """
        prompt_messages = []
        if not node_execution.process_data:
            return []

        try:
            process_data = json.loads(node_execution.process_data)
            if process_data.get("memory_type", "") != LLMMemoryType.INDEPENDENT:
                return []
            prompts = process_data.get("prompts", [])
            for prompt in prompts:
                prompt_content = prompt.get("text", "")
                if prompt.get("role", "") == "user":
                    prompt_messages.append(UserPromptMessage(content=prompt_content))
                elif prompt.get("role", "") == "assistant":
                    prompt_messages.append(AssistantPromptMessage(content=prompt_content))
            output = node_execution.outputs_dict.get("text", "") if node_execution.outputs_dict else ""
            prompt_messages.append(AssistantPromptMessage(content=output))
        except json.JSONDecodeError:
            return []
        return prompt_messages

    def _fetch_thread_messages(self, message_limit: int | None = None) -> list[Message]:
        """
        Fetch thread messages.
        :param message_limit: message limit
        :return: thread messages
        """
        query = (
            db.session.query(
                Message.id,
                Message.query,
                Message.answer,
                Message.created_at,
                Message.workflow_run_id,
                Message.parent_message_id,
                Message.answer_tokens,
            )
            .filter(
                Message.conversation_id == self.conversation.id,
            )
            .order_by(Message.created_at.desc())
        )

        if message_limit and message_limit > 0:
            message_limit = min(message_limit, 500)
        else:
            message_limit = 500

        messages = query.limit(message_limit).all()

        # fetch the thread messages
        thread_messages = extract_thread_messages(messages)

        # for newly created message, its answer is temporarily empty, we don't need to add it to memory
        if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
            thread_messages.pop(0)
        if not thread_messages:
            return []
        return thread_messages

    def _handle_file(self, message: Message, files: list[MessageFile]):
        """
        Handle file for memory.
        :param message: message
        :param files: files
        :return: file objects and detail
        """
        file_extra_config = None
        if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
            file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
        else:
            if message.workflow_run_id:
                workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()

                if workflow_run and workflow_run.workflow:
                    file_extra_config = FileUploadConfigManager.convert(
                        workflow_run.workflow.features_dict, is_vision=False
                    )

        detail = ImagePromptMessageContent.DETAIL.LOW
        app_record = self.conversation.app

        if file_extra_config and app_record:
            file_objs = file_factory.build_from_message_files(
                message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
            )
            if file_extra_config.image_config and file_extra_config.image_config.detail:
                detail = file_extra_config.image_config.detail
        else:
            file_objs = []
        return file_objs, detail
@@ -3,16 +3,16 @@ from typing import Optional

 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.file import file_manager
+from core.memory.base_memory import BaseMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageContent,
-    PromptMessageRole,
     TextPromptMessageContent,
     UserPromptMessage,
 )
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from extensions.ext_database import db
 from factories import file_factory

@@ -20,7 +20,7 @@ from models.model import AppMode, Conversation, Message, MessageFile
 from models.workflow import WorkflowRun


-class TokenBufferMemory:
+class TokenBufferMemory(BaseMemory):
     def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
         self.conversation = conversation
         self.model_instance = model_instance

@@ -100,7 +100,7 @@ class TokenBufferMemory(BaseMemory):
         if not file_objs:
             prompt_messages.append(UserPromptMessage(content=message.query))
         else:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=message.query))
             for file in file_objs:
                 prompt_message = file_manager.to_prompt_message_content(

@@ -129,44 +129,3 @@ class TokenBufferMemory(BaseMemory):
         curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)

         return prompt_messages
-
-    def get_history_prompt_text(
-        self,
-        human_prefix: str = "Human",
-        ai_prefix: str = "Assistant",
-        max_token_limit: int = 2000,
-        message_limit: Optional[int] = None,
-    ) -> str:
-        """
-        Get history prompt text.
-        :param human_prefix: human prefix
-        :param ai_prefix: ai prefix
-        :param max_token_limit: max token limit
-        :param message_limit: message limit
-        :return:
-        """
-        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
-
-        string_messages = []
-        for m in prompt_messages:
-            if m.role == PromptMessageRole.USER:
-                role = human_prefix
-            elif m.role == PromptMessageRole.ASSISTANT:
-                role = ai_prefix
-            else:
-                continue
-
-            if isinstance(m.content, list):
-                inner_msg = ""
-                for content in m.content:
-                    if isinstance(content, TextPromptMessageContent):
-                        inner_msg += f"{content.data}\n"
-                    elif isinstance(content, ImagePromptMessageContent):
-                        inner_msg += "[image]\n"
-
-                string_messages.append(f"{role}: {inner_msg.strip()}")
-            else:
-                message = f"{role}: {m.content}"
-                string_messages.append(message)
-
-        return "\n".join(string_messages)
@@ -1,8 +1,8 @@
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Optional
+from typing import Annotated, Any, Literal, Optional, Union

-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, field_serializer, field_validator


 class PromptMessageRole(Enum):
@@ -61,11 +61,7 @@ class PromptMessageContentType(StrEnum):


 class PromptMessageContent(BaseModel):
-    """
-    Model class for prompt message content.
-    """
-
-    type: PromptMessageContentType
+    pass


 class TextPromptMessageContent(PromptMessageContent):
@@ -73,7 +69,7 @@ class TextPromptMessageContent(PromptMessageContent):
     Model class for text prompt message content.
     """

-    type: PromptMessageContentType = PromptMessageContentType.TEXT
+    type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
     data: str
@@ -82,7 +78,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
     Model class for multi-modal prompt message content.
     """

-    type: PromptMessageContentType
     format: str = Field(default=..., description="the format of multi-modal file")
     base64_data: str = Field(default="", description="the base64 data of multi-modal file")
     url: str = Field(default="", description="the url of multi-modal file")
@@ -94,11 +89,11 @@


 class VideoPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.VIDEO
+    type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO


 class AudioPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.AUDIO
+    type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO


 class ImagePromptMessageContent(MultiModalPromptMessageContent):
@@ -110,12 +105,24 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
         LOW = "low"
         HIGH = "high"

-    type: PromptMessageContentType = PromptMessageContentType.IMAGE
+    type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
     detail: DETAIL = DETAIL.LOW


 class DocumentPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
+
+
+PromptMessageContentUnionTypes = Annotated[
+    Union[
+        TextPromptMessageContent,
+        ImagePromptMessageContent,
+        DocumentPromptMessageContent,
+        AudioPromptMessageContent,
+        VideoPromptMessageContent,
+    ],
+    Field(discriminator="type"),
+]


 class PromptMessage(BaseModel):
@@ -124,7 +131,7 @@ class PromptMessage(BaseModel):
|
||||
"""
|
||||
|
||||
role: PromptMessageRole
|
||||
content: Optional[str | Sequence[PromptMessageContent]] = None
|
||||
content: Optional[str | list[PromptMessageContentUnionTypes]] = None
|
||||
name: Optional[str] = None
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
@@ -135,6 +142,16 @@ class PromptMessage(BaseModel):
|
||||
"""
|
||||
return not self.content
|
||||
|
||||
@field_serializer("content")
|
||||
def serialize_content(
|
||||
self, content: Optional[Union[str, Sequence[PromptMessageContent]]]
|
||||
) -> Optional[str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent]]:
|
||||
if content is None or isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
return [item.model_dump() if hasattr(item, "model_dump") else item for item in content]
|
||||
return content
|
||||
|
||||
|
||||
class UserPromptMessage(PromptMessage):
|
||||
"""
|
||||
|
||||
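Reviewer note: the switch above from an open `Sequence[PromptMessageContent]` to the `PromptMessageContentUnionTypes` discriminated union is what lets Pydantic rebuild the concrete content subclass from a serialized dict. A minimal standalone sketch of the same pattern (simplified stand-in classes, not the Dify entities):

```python
# Discriminated union: the Literal "type" field tells Pydantic which subclass to build.
from enum import StrEnum
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class ContentType(StrEnum):
    TEXT = "text"
    IMAGE = "image"


class TextContent(BaseModel):
    type: Literal[ContentType.TEXT] = ContentType.TEXT
    data: str


class ImageContent(BaseModel):
    type: Literal[ContentType.IMAGE] = ContentType.IMAGE
    url: str = ""


ContentUnion = Annotated[Union[TextContent, ImageContent], Field(discriminator="type")]

adapter = TypeAdapter(list[ContentUnion])
items = adapter.validate_python([{"type": "text", "data": "hi"}, {"type": "image", "url": "x"}])
assert isinstance(items[0], TextContent) and isinstance(items[1], ImageContent)
```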
@@ -2,7 +2,7 @@ from decimal import Decimal
 from enum import Enum, StrEnum
 from typing import Any, Optional
 
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, model_validator
 
 from core.model_runtime.entities.common_entities import I18nObject
 
@@ -85,6 +85,7 @@ class ModelFeature(Enum):
     DOCUMENT = "document"
     VIDEO = "video"
     AUDIO = "audio"
+    STRUCTURED_OUTPUT = "structured-output"
 
 
 class DefaultParameterName(StrEnum):
@@ -197,6 +198,19 @@ class AIModelEntity(ProviderModel):
     parameter_rules: list[ParameterRule] = []
     pricing: Optional[PriceConfig] = None
 
+    @model_validator(mode="after")
+    def validate_model(self):
+        supported_schema_keys = ["json_schema"]
+        schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
+        if not schema_key:
+            return self
+        if self.features is None:
+            self.features = [ModelFeature.STRUCTURED_OUTPUT]
+        else:
+            if ModelFeature.STRUCTURED_OUTPUT not in self.features:
+                self.features.append(ModelFeature.STRUCTURED_OUTPUT)
+        return self
+
 
 class ModelUsage(BaseModel):
     pass
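The `validate_model` hook added above infers the new `STRUCTURED_OUTPUT` feature whenever a `json_schema` parameter rule is declared. A standalone sketch of that `@model_validator(mode="after")` pattern (simplified models, not the Dify entities):

```python
from pydantic import BaseModel, model_validator


class Rule(BaseModel):
    name: str


class ModelEntity(BaseModel):
    parameter_rules: list[Rule] = []
    features: list[str] | None = None

    @model_validator(mode="after")
    def infer_structured_output(self):
        # mirror of the diff: declaring a "json_schema" rule implies the feature flag
        if any(rule.name == "json_schema" for rule in self.parameter_rules):
            if "structured-output" not in (self.features or []):
                self.features = (self.features or []) + ["structured-output"]
        return self


entity = ModelEntity(parameter_rules=[Rule(name="json_schema")])
assert entity.features == ["structured-output"]
```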
@@ -1,5 +1,6 @@
 import logging
 import time
+import uuid
 from collections.abc import Generator, Sequence
 from typing import Optional, Union
 
@@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
 logger = logging.getLogger(__name__)
 
 
+def _gen_tool_call_id() -> str:
+    return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
+
+
+def _increase_tool_call(
+    new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
+):
+    """
+    Merge incremental tool call updates into existing tool calls.
+
+    :param new_tool_calls: List of new tool call deltas to be merged.
+    :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
+    """
+
+    def get_tool_call(tool_call_id: str):
+        """
+        Get or create a tool call by ID
+
+        :param tool_call_id: tool call ID
+        :return: existing or new tool call
+        """
+        if not tool_call_id:
+            return existing_tools_calls[-1]
+
+        _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
+        if _tool_call is None:
+            _tool_call = AssistantPromptMessage.ToolCall(
+                id=tool_call_id,
+                type="function",
+                function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+            )
+            existing_tools_calls.append(_tool_call)
+
+        return _tool_call
+
+    for new_tool_call in new_tool_calls:
+        # generate ID for tool calls with function name but no ID to track them
+        if new_tool_call.function.name and not new_tool_call.id:
+            new_tool_call.id = _gen_tool_call_id()
+        # get tool call
+        tool_call = get_tool_call(new_tool_call.id)
+        # update tool call
+        if new_tool_call.id:
+            tool_call.id = new_tool_call.id
+        if new_tool_call.type:
+            tool_call.type = new_tool_call.type
+        if new_tool_call.function.name:
+            tool_call.function.name = new_tool_call.function.name
+        if new_tool_call.function.arguments:
+            tool_call.function.arguments += new_tool_call.function.arguments
+
+
 class LargeLanguageModel(AIModel):
     """
     Model class for large language model.
@@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
         system_fingerprint = None
         tools_calls: list[AssistantPromptMessage.ToolCall] = []
 
-        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
-            def get_tool_call(tool_name: str):
-                if not tool_name:
-                    return tools_calls[-1]
-
-                tool_call = next(
-                    (tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
-                )
-                if tool_call is None:
-                    tool_call = AssistantPromptMessage.ToolCall(
-                        id="",
-                        type="",
-                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
-                    )
-                    tools_calls.append(tool_call)
-
-                return tool_call
-
-            for new_tool_call in new_tool_calls:
-                # get tool call
-                tool_call = get_tool_call(new_tool_call.function.name)
-                # update tool call
-                if new_tool_call.id:
-                    tool_call.id = new_tool_call.id
-                if new_tool_call.type:
-                    tool_call.type = new_tool_call.type
-                if new_tool_call.function.name:
-                    tool_call.function.name = new_tool_call.function.name
-                if new_tool_call.function.arguments:
-                    tool_call.function.arguments += new_tool_call.function.arguments
-
         for chunk in result:
             if isinstance(chunk.delta.message.content, str):
                 content += chunk.delta.message.content
             elif isinstance(chunk.delta.message.content, list):
                 content_list.extend(chunk.delta.message.content)
             if chunk.delta.message.tool_calls:
-                increase_tool_call(chunk.delta.message.tool_calls)
+                _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
 
             usage = chunk.delta.usage or LLMUsage.empty_usage()
             system_fingerprint = chunk.system_fingerprint
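The refactor above pulls tool-call merging out to module level and, crucially, keys it on tool-call IDs instead of function names, generating an ID when a streamed delta carries a name but no ID. A simplified illustration of how argument fragments accumulate across chunks (stand-in dataclasses, not `AssistantPromptMessage`):

```python
from dataclasses import dataclass, field


@dataclass
class Function:
    name: str = ""
    arguments: str = ""


@dataclass
class ToolCall:
    id: str = ""
    function: Function = field(default_factory=Function)


def merge(deltas: list[ToolCall], calls: list[ToolCall]) -> None:
    for delta in deltas:
        # find the call this delta belongs to, or start a new one
        target = next((c for c in calls if delta.id and c.id == delta.id), None)
        if target is None:
            target = ToolCall(id=delta.id)
            calls.append(target)
        if delta.function.name:
            target.function.name = delta.function.name
        if delta.function.arguments:
            target.function.arguments += delta.function.arguments  # append fragments


calls: list[ToolCall] = []
merge([ToolCall(id="t1", function=Function(name="get_weather"))], calls)
merge([ToolCall(id="t1", function=Function(arguments='{"city": "Par'))], calls)
merge([ToolCall(id="t1", function=Function(arguments='is"}'))], calls)
assert calls[0].function.arguments == '{"city": "Paris"}'
```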
@@ -39,6 +39,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
         :param query: str
         :return: dict
         """
+        # FIXME(-LAN-): Avoid import service into core
         workflow_service = WorkflowService()
         node_id = "1919810"
         node_data = ParameterExtractorNodeData(
@@ -89,6 +90,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
         :param query: str
         :return: dict
         """
+        # FIXME(-LAN-): Avoid import service into core
         workflow_service = WorkflowService()
         node_id = "1919810"
         node_data = QuestionClassifierNodeData(
@@ -9,13 +9,12 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageRole,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@@ -125,7 +124,7 @@ class AdvancedPromptTransform(PromptTransform):
             prompt = Jinja2Formatter.format(prompt, prompt_inputs)
 
         if files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
                 prompt_message_contents.append(
@@ -201,7 +200,7 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
 
         if files and query is not None:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=query))
             for file in files:
                 prompt_message_contents.append(
@@ -1,3 +1,4 @@
+from enum import Enum
 from typing import Literal, Optional
 
 from pydantic import BaseModel
@@ -24,6 +25,11 @@ class CompletionModelPromptTemplate(BaseModel):
     edition_type: Optional[Literal["basic", "jinja2"]] = None
 
 
+class LLMMemoryType(str, Enum):
+    INDEPENDENT = "independent"
+    GLOBAL = "global"
+
+
 class MemoryConfig(BaseModel):
     """
     Memory Config.
@@ -48,3 +54,4 @@ class MemoryConfig(BaseModel):
     role_prefix: Optional[RolePrefix] = None
     window: WindowConfig
     query_prompt_template: Optional[str] = None
+    type: LLMMemoryType = LLMMemoryType.GLOBAL
@@ -11,7 +11,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageContent,
+    PromptMessageContentUnionTypes,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
@@ -277,7 +277,7 @@ class SimplePromptTransform(PromptTransform):
         image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> UserPromptMessage:
         if files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
                 prompt_message_contents.append(
@@ -1,9 +1,8 @@
-from typing import Any
-
 from constants import UUID_NIL
+from models.model import Message
 
 
-def extract_thread_messages(messages: list[Any]):
+def extract_thread_messages(messages: list[Message]) -> list[Message]:
     thread_messages = []
     next_message = None
 
@@ -124,6 +124,15 @@ class ProviderManager:
 
         # Get All preferred provider types of the workspace
         provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
+        # Ensure that both the original provider name and its ModelProviderID string representation
+        # are present in the dictionary to handle cases where either form might be used
+        for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
+            provider_id = ModelProviderID(provider_name)
+            if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict:
+                # Add the ModelProviderID string representation if it's not already present
+                provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = (
+                    provider_name_to_preferred_model_provider_records_dict[provider_name]
+                )
 
         # Get All provider model settings
         provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
@@ -497,8 +506,8 @@ class ProviderManager:
 
     @staticmethod
     def _init_trial_provider_records(
-        tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]
-    ) -> dict[str, list]:
+        tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
+    ) -> dict[str, list[Provider]]:
         """
         Initialize trial provider records if not exists.
 
@@ -532,7 +541,7 @@ class ProviderManager:
             if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
                 try:
                     # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
-                    provider_record = Provider(
+                    new_provider_record = Provider(
                         tenant_id=tenant_id,
                         # TODO: Use provider name with prefix after the data migration.
                         provider_name=ModelProviderID(provider_name).provider_name,
@@ -542,11 +551,12 @@ class ProviderManager:
                         quota_used=0,
                         is_valid=True,
                     )
-                    db.session.add(provider_record)
+                    db.session.add(new_provider_record)
                     db.session.commit()
+                    provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
                 except IntegrityError:
                     db.session.rollback()
-                    provider_record = (
+                    existed_provider_record = (
                         db.session.query(Provider)
                         .filter(
                             Provider.tenant_id == tenant_id,
@@ -556,11 +566,14 @@ class ProviderManager:
                         )
                         .first()
                     )
-                    if provider_record and not provider_record.is_valid:
-                        provider_record.is_valid = True
+                    if not existed_provider_record:
+                        continue
+
+                    if not existed_provider_record.is_valid:
+                        existed_provider_record.is_valid = True
                         db.session.commit()
 
-                    provider_name_to_provider_records_dict[provider_name].append(provider_record)
+                    provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
 
         return provider_name_to_provider_records_dict
0	api/core/rag/datasource/vdb/huawei/__init__.py	Normal file
215	api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py	Normal file
@@ -0,0 +1,215 @@
import json
import logging
import ssl
from typing import Any, Optional

from elasticsearch import Elasticsearch
from pydantic import BaseModel, model_validator

from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)


def create_ssl_context() -> ssl.SSLContext:
    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE
    return ssl_context


class HuaweiCloudVectorConfig(BaseModel):
    hosts: str
    username: str | None
    password: str | None

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["hosts"]:
            raise ValueError("config HOSTS is required")
        return values

    def to_elasticsearch_params(self) -> dict[str, Any]:
        params = {
            "hosts": self.hosts.split(","),
            "verify_certs": False,
            "ssl_show_warn": False,
            "request_timeout": 30000,
            "retry_on_timeout": True,
            "max_retries": 10,
        }
        if self.username and self.password:
            params["basic_auth"] = (self.username, self.password)
        return params


class HuaweiCloudVector(BaseVector):
    def __init__(self, index_name: str, config: HuaweiCloudVectorConfig):
        super().__init__(index_name.lower())
        self._client = Elasticsearch(**config.to_elasticsearch_params())

    def get_type(self) -> str:
        return VectorType.HUAWEI_CLOUD

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        uuids = self._get_uuids(documents)
        for i in range(len(documents)):
            self._client.index(
                index=self._collection_name,
                id=uuids[i],
                document={
                    Field.CONTENT_KEY.value: documents[i].page_content,
                    Field.VECTOR.value: embeddings[i] or None,
                    Field.METADATA_KEY.value: documents[i].metadata or {},
                },
            )
        self._client.indices.refresh(index=self._collection_name)
        return uuids

    def text_exists(self, id: str) -> bool:
        return bool(self._client.exists(index=self._collection_name, id=id))

    def delete_by_ids(self, ids: list[str]) -> None:
        if not ids:
            return
        for id in ids:
            self._client.delete(index=self._collection_name, id=id)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
        results = self._client.search(index=self._collection_name, body=query_str)
        ids = [hit["_id"] for hit in results["hits"]["hits"]]
        if ids:
            self.delete_by_ids(ids)

    def delete(self) -> None:
        self._client.indices.delete(index=self._collection_name)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)

        query = {
            "size": top_k,
            "query": {
                "vector": {
                    Field.VECTOR.value: {
                        "vector": query_vector,
                        "topk": top_k,
                    }
                }
            },
        }

        results = self._client.search(index=self._collection_name, body=query)

        docs_and_scores = []
        for hit in results["hits"]["hits"]:
            docs_and_scores.append(
                (
                    Document(
                        page_content=hit["_source"][Field.CONTENT_KEY.value],
                        vector=hit["_source"][Field.VECTOR.value],
                        metadata=hit["_source"][Field.METADATA_KEY.value],
                    ),
                    hit["_score"],
                )
            )

        docs = []
        for doc, score in docs_and_scores:
            score_threshold = float(kwargs.get("score_threshold") or 0.0)
            if score > score_threshold:
                if doc.metadata is not None:
                    doc.metadata["score"] = score
                    docs.append(doc)

        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        query_str = {"match": {Field.CONTENT_KEY.value: query}}
        results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
        docs = []
        for hit in results["hits"]["hits"]:
            docs.append(
                Document(
                    page_content=hit["_source"][Field.CONTENT_KEY.value],
                    vector=hit["_source"][Field.VECTOR.value],
                    metadata=hit["_source"][Field.METADATA_KEY.value],
                )
            )

        return docs

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
        self.create_collection(embeddings, metadatas)
        self.add_texts(texts, embeddings, **kwargs)

    def create_collection(
        self,
        embeddings: list[list[float]],
        metadatas: Optional[list[dict[Any, Any]]] = None,
        index_params: Optional[dict] = None,
    ):
        lock_name = f"vector_indexing_lock_{self._collection_name}"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                logger.info(f"Collection {self._collection_name} already exists.")
                return

            if not self._client.indices.exists(index=self._collection_name):
                dim = len(embeddings[0])
                mappings = {
                    "properties": {
                        Field.CONTENT_KEY.value: {"type": "text"},
                        Field.VECTOR.value: {  # Make sure the dimension is correct here
                            "type": "vector",
                            "dimension": dim,
                            "indexing": True,
                            "algorithm": "GRAPH",
                            "metric": "cosine",
                            "neighbors": 32,
                            "efc": 128,
                        },
                        Field.METADATA_KEY.value: {
                            "type": "object",
                            "properties": {
                                "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
                            },
                        },
                    }
                }
                settings = {"index.vector": True}
                self._client.indices.create(index=self._collection_name, mappings=mappings, settings=settings)

            redis_client.set(collection_exist_cache_key, 1, ex=3600)


class HuaweiCloudVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HuaweiCloudVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.HUAWEI_CLOUD, collection_name))

        return HuaweiCloudVector(
            index_name=collection_name,
            config=HuaweiCloudVectorConfig(
                hosts=dify_config.HUAWEI_CLOUD_HOSTS or "http://localhost:9200",
                username=dify_config.HUAWEI_CLOUD_USER,
                password=dify_config.HUAWEI_CLOUD_PASSWORD,
            ),
        )
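A minimal usage sketch for the new vector store, under stated assumptions: a reachable Huawei Cloud CSS (Elasticsearch-compatible) endpoint, pre-computed embeddings, and the usual Dify runtime services such as Redis for the indexing lock. Classes and signatures are the ones added above; the host and data values are hypothetical.

```python
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import (
    HuaweiCloudVector,
    HuaweiCloudVectorConfig,
)
from core.rag.models.document import Document

config = HuaweiCloudVectorConfig(hosts="http://localhost:9200", username=None, password=None)
vector = HuaweiCloudVector(index_name="Vector_index_demo", config=config)

docs = [Document(page_content="hello huawei cloud", metadata={"doc_id": "d1"})]
embeddings = [[0.1, 0.2, 0.3]]  # normally produced by the dataset's embedding model
vector.create(docs, embeddings)  # creates the GRAPH/cosine index, then adds the texts
hits = vector.search_by_vector([0.1, 0.2, 0.3], top_k=1)
```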
@@ -2,12 +2,12 @@ import array
 import json
 import re
 import uuid
-from contextlib import contextmanager
 from typing import Any
 
 import jieba.posseg as pseg  # type: ignore
 import numpy
 import oracledb
+from oracledb.connection import Connection
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
@@ -70,6 +70,7 @@ class OracleVector(BaseVector):
         super().__init__(collection_name)
         self.pool = self._create_connection_pool(config)
         self.table_name = f"embedding_{collection_name}"
+        self.config = config
 
     def get_type(self) -> str:
         return VectorType.ORACLE
@@ -107,16 +108,19 @@ class OracleVector(BaseVector):
             outconverter=self.numpy_converter_out,
         )
 
+    def _get_connection(self) -> Connection:
+        connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
+        return connection
+
     def _create_connection_pool(self, config: OracleVectorConfig):
         pool_params = {
             "user": config.user,
             "password": config.password,
             "dsn": config.dsn,
             "min": 1,
-            "max": 50,
+            "max": 5,
             "increment": 1,
         }
+
         if config.is_autonomous:
             pool_params.update(
                 {
@@ -125,22 +129,8 @@ class OracleVector(BaseVector):
                     "wallet_password": config.wallet_password,
                 }
             )
 
         return oracledb.create_pool(**pool_params)
 
-    @contextmanager
-    def _get_cursor(self):
-        conn = self.pool.acquire()
-        conn.inputtypehandler = self.input_type_handler
-        conn.outputtypehandler = self.output_type_handler
-        cur = conn.cursor()
-        try:
-            yield cur
-        finally:
-            cur.close()
-            conn.commit()
-            conn.close()
-
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         dimension = len(embeddings[0])
         self._create_collection(dimension)
@@ -162,41 +152,68 @@ class OracleVector(BaseVector):
                     numpy.array(embeddings[i]),
                 )
             )
-        # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
-        with self._get_cursor() as cur:
-            cur.executemany(
-                f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
-            )
+        with self._get_connection() as conn:
+            conn.inputtypehandler = self.input_type_handler
+            conn.outputtypehandler = self.output_type_handler
+            # with conn.cursor() as cur:
+            #     cur.executemany(
+            #         f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
+            #     )
+            #     conn.commit()
+            for value in values:
+                with conn.cursor() as cur:
+                    try:
+                        cur.execute(
+                            f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
+                            VALUES (:1, :2, :3, :4)""",
+                            value,
+                        )
+                        conn.commit()
+                    except Exception as e:
+                        print(e)
+            conn.close()
         return pks
 
     def text_exists(self, id: str) -> bool:
-        with self._get_cursor() as cur:
-            cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
-            return cur.fetchone() is not None
+        with self._get_connection() as conn:
+            with conn.cursor() as cur:
+                cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
+                return cur.fetchone() is not None
+            conn.close()
 
     def get_by_ids(self, ids: list[str]) -> list[Document]:
-        with self._get_cursor() as cur:
-            cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
-            docs = []
-            for record in cur:
-                docs.append(Document(page_content=record[1], metadata=record[0]))
+        with self._get_connection() as conn:
+            with conn.cursor() as cur:
+                cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+                docs = []
+                for record in cur:
+                    docs.append(Document(page_content=record[1], metadata=record[0]))
+            self.pool.release(connection=conn)
+            conn.close()
         return docs
 
     def delete_by_ids(self, ids: list[str]) -> None:
         if not ids:
             return
-        with self._get_cursor() as cur:
-            cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
+        with self._get_connection() as conn:
+            with conn.cursor() as cur:
+                cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
+                conn.commit()
+            conn.close()
 
     def delete_by_metadata_field(self, key: str, value: str) -> None:
-        with self._get_cursor() as cur:
-            cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
+        with self._get_connection() as conn:
+            with conn.cursor() as cur:
+                cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
+                conn.commit()
+            conn.close()
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         """
         Search the nearest neighbors to a vector.
 
         :param query_vector: The input vector to search for similar items.
         :param top_k: The number of nearest neighbors to return, default is 5.
         :return: List of Documents that are nearest to the query vector.
         """
         top_k = kwargs.get("top_k", 4)
@@ -205,20 +222,25 @@ class OracleVector(BaseVector):
         if document_ids_filter:
             document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
             where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
-        with self._get_cursor() as cur:
-            cur.execute(
-                f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
-                f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
-                [numpy.array(query_vector)],
-            )
-            docs = []
-            score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            for record in cur:
-                metadata, text, distance = record
-                score = 1 - distance
-                metadata["score"] = score
-                if score > score_threshold:
-                    docs.append(Document(page_content=text, metadata=metadata))
+        with self._get_connection() as conn:
+            conn.inputtypehandler = self.input_type_handler
+            conn.outputtypehandler = self.output_type_handler
+            with conn.cursor() as cur:
+                cur.execute(
+                    f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
+                    AS distance FROM {self.table_name}
+                    {where_clause} ORDER BY distance fetch first {top_k} rows only""",
+                    [numpy.array(query_vector)],
+                )
+                docs = []
+                score_threshold = float(kwargs.get("score_threshold") or 0.0)
+                for record in cur:
+                    metadata, text, distance = record
+                    score = 1 - distance
+                    metadata["score"] = score
+                    if score > score_threshold:
+                        docs.append(Document(page_content=text, metadata=metadata))
+            conn.close()
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -228,7 +250,7 @@ class OracleVector(BaseVector):
 
         top_k = kwargs.get("top_k", 5)
-        # just not implement fetch by score_threshold now, may be later
-        # score_threshold = float(kwargs.get("score_threshold") or 0.0)
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
         if len(query) > 0:
             # Check which language the query is in
             zh_pattern = re.compile("[\u4e00-\u9fa5]+")
@@ -239,7 +261,7 @@ class OracleVector(BaseVector):
             words = pseg.cut(query)
             current_entity = ""
             for word, pos in words:
-                if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:  # nr: 人名,ns: 地名,nt: 机构名
+                if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:  # nr: 人名, ns: 地名, nt: 机构名
                     current_entity += word
                 else:
                     if current_entity:
@@ -260,30 +282,35 @@ class OracleVector(BaseVector):
             for token in all_tokens:
                 if token not in stop_words:
                     entities.append(token)
-            with self._get_cursor() as cur:
-                document_ids_filter = kwargs.get("document_ids_filter")
-                where_clause = ""
-                if document_ids_filter:
-                    document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
-                    where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
-                cur.execute(
-                    f"select meta, text, embedding FROM {self.table_name}"
-                    f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
-                    f"order by score(1) desc fetch first {top_k} rows only",
-                    [" ACCUM ".join(entities)],
-                )
-                docs = []
-                for record in cur:
-                    metadata, text, embedding = record
-                    docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
+            with self._get_connection() as conn:
+                with conn.cursor() as cur:
+                    document_ids_filter = kwargs.get("document_ids_filter")
+                    where_clause = ""
+                    if document_ids_filter:
+                        document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+                        where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
+                    cur.execute(
+                        f"""select meta, text, embedding FROM {self.table_name}
                        WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
                        order by score(1) desc fetch first {top_k} rows only""",
+                        kk=" ACCUM ".join(entities),
+                    )
+                    docs = []
+                    for record in cur:
+                        metadata, text, embedding = record
+                        docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
+                conn.close()
             return docs
         else:
-            return [Document(page_content="", metadata={})]
+            return []
 
     def delete(self) -> None:
-        with self._get_cursor() as cur:
-            cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
+        with self._get_connection() as conn:
+            with conn.cursor() as cur:
+                cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
+                conn.commit()
+            conn.close()
 
     def _create_collection(self, dimension: int):
         cache_key = f"vector_indexing_{self._collection_name}"
@@ -293,11 +320,14 @@ class OracleVector(BaseVector):
             if redis_client.get(collection_exist_cache_key):
                 return
 
-            with self._get_cursor() as cur:
-                cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
-            redis_client.set(collection_exist_cache_key, 1, ex=3600)
-            with self._get_cursor() as cur:
-                cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+            with self._get_connection() as conn:
+                with conn.cursor() as cur:
+                    cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
+                    redis_client.set(collection_exist_cache_key, 1, ex=3600)
+                with conn.cursor() as cur:
+                    cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+                conn.commit()
+            conn.close()
 
 
 class OracleVectorFactory(AbstractVectorFactory):
@@ -156,6 +156,10 @@ class Vector:
                 from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory
 
                 return TableStoreVectorFactory
+            case VectorType.HUAWEI_CLOUD:
+                from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
+
+                return HuaweiCloudVectorFactory
             case _:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
 
@@ -26,3 +26,4 @@ class VectorType(StrEnum):
     OCEANBASE = "oceanbase"
     OPENGAUSS = "opengauss"
     TABLESTORE = "tablestore"
+    HUAWEI_CLOUD = "huawei_cloud"
@@ -126,9 +126,7 @@ class WordExtractor(BaseExtractor):
 
             db.session.add(upload_file)
             db.session.commit()
-            image_map[rel.target_part] = (
-                f""
-            )
+            image_map[rel.target_part] = f""
 
         return image_map
@@ -869,7 +869,9 @@ class DatasetRetrieval:
                     )
                 )
             metadata_condition = MetadataCondition(
-                logical_operator=metadata_filtering_conditions.logical_operator,  # type: ignore
+                logical_operator=metadata_filtering_conditions.logical_operator
+                if metadata_filtering_conditions
+                else "or",  # type: ignore
                 conditions=conditions,
             )
         elif metadata_filtering_mode == "manual":
@@ -891,10 +893,10 @@ class DatasetRetrieval:
         else:
             raise ValueError("Invalid metadata filtering mode")
         if filters:
-            if metadata_filtering_conditions.logical_operator == "or":  # type: ignore
-                document_query = document_query.filter(or_(*filters))
-            else:
+            if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and":  # type: ignore
                 document_query = document_query.filter(and_(*filters))
+            else:
+                document_query = document_query.filter(or_(*filters))
         documents = document_query.all()
         # group by dataset_id
         metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
@@ -17,7 +17,7 @@ RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
 WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
 
 # Repository type literals
-RepositoryType = Literal["workflow_node_execution"]
+_RepositoryType = Literal["workflow_node_execution"]
 
 
 class RepositoryFactory:
@@ -32,7 +32,7 @@ class RepositoryFactory:
     _factory_functions: dict[str, RepositoryFactoryFunc] = {}
 
     @classmethod
-    def _register_factory(cls, repository_type: RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
+    def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
         """
         Register a factory function for a specific repository type.
         This is a private method and should not be called directly.
@@ -44,7 +44,7 @@ class RepositoryFactory:
         cls._factory_functions[repository_type] = factory_func
 
     @classmethod
-    def _create_repository(cls, repository_type: RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
+    def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
         """
         Create a new repository instance with the provided parameters.
         This is a private method and should not be called directly.
@@ -86,3 +86,12 @@ class WorkflowNodeExecutionRepository(Protocol):
             execution: The WorkflowNodeExecution instance to update
         """
         ...
+
+    def clear(self) -> None:
+        """
+        Clear all WorkflowNodeExecution records based on implementation-specific criteria.
+
+        This method is intended to be used for bulk deletion operations, such as removing
+        all records associated with a specific app_id and tenant_id in multi-tenant implementations.
+        """
+        ...
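Since `WorkflowNodeExecutionRepository` is a `Protocol`, the new `clear()` method is satisfied structurally; implementations need not inherit anything. A standalone sketch of the idea (simplified types, not Dify's):

```python
from typing import Protocol


class Repository(Protocol):
    def save(self, execution: dict) -> None: ...
    def clear(self) -> None: ...


class InMemoryRepository:
    def __init__(self) -> None:
        self._rows: list[dict] = []

    def save(self, execution: dict) -> None:
        self._rows.append(execution)

    def clear(self) -> None:
        # bulk deletion, e.g. everything for one app_id/tenant_id in a real implementation
        self._rows.clear()


def wipe(repo: Repository) -> None:
    repo.clear()


wipe(InMemoryRepository())  # accepted by type checkers: structural typing, no inheritance
```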
@@ -94,7 +94,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                     "title": item.metadata.get("title"),
                     "content": item.page_content,
                 }
-            context_list.append(source)
+                context_list.append(source)
         for hit_callback in self.hit_callbacks:
             hit_callback.return_retriever_resource_info(context_list)
@@ -16,7 +16,7 @@ from core.variables.segments import StringSegment
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated
+from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
 from core.workflow.nodes.base.entities import BaseNodeData
 from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.event.event import RunCompletedEvent
@@ -251,7 +251,12 @@ class AgentNode(ToolNode):
                     prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
                 ]
                 value["history_prompt_messages"] = history_prompt_messages
-                value["entity"] = model_schema.model_dump(mode="json") if model_schema else None
+                if model_schema:
+                    # remove structured output feature to support old version agent plugin
+                    model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
+                    value["entity"] = model_schema.model_dump(mode="json")
+                else:
+                    value["entity"] = None
             result[parameter_name] = value
 
         return result
@@ -348,3 +353,10 @@ class AgentNode(ToolNode):
         )
         model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
         return model_instance, model_schema
+
+    def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
+        if model_schema.features:
+            for feature in model_schema.features:
+                if feature.value not in AgentOldVersionModelFeatures:
+                    model_schema.features.remove(feature)
+        return model_schema
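One caveat on `_remove_unsupported_model_features_for_old_version` above: it calls `list.remove()` on the list it is iterating, and in Python that can skip the element immediately after a removal. A small demonstration of the pitfall and the safer rebuild (illustration only, not part of the commit):

```python
features = ["vision", "structured-output", "audio"]
allowed = {"vision", "audio"}

# Removing while iterating the same list skips the next element; iterating a
# copy (list(features)) makes the removal safe:
for f in list(features):
    if f not in allowed:
        features.remove(f)

features = [f for f in features if f in allowed]  # equivalent, simpler rebuild
assert features == ["vision", "audio"]
```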
@@ -24,3 +24,18 @@ class AgentNodeData(BaseNodeData):
 class ParamsAutoGenerated(Enum):
     CLOSE = 0
     OPEN = 1
+
+
+class AgentOldVersionModelFeatures(Enum):
+    """
+    Enum class for old SDK version llm feature.
+    """
+
+    TOOL_CALL = "tool-call"
+    MULTI_TOOL_CALL = "multi-tool-call"
+    AGENT_THOUGHT = "agent-thought"
+    VISION = "vision"
+    STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"
@@ -349,7 +349,9 @@ class KnowledgeRetrievalNode(LLMNode):
                     )
                 )
             metadata_condition = MetadataCondition(
-                logical_operator=node_data.metadata_filtering_conditions.logical_operator,  # type: ignore
+                logical_operator=node_data.metadata_filtering_conditions.logical_operator
+                if node_data.metadata_filtering_conditions
+                else "or",  # type: ignore
                 conditions=conditions,
             )
         elif node_data.metadata_filtering_mode == "manual":
@@ -380,7 +382,10 @@ class KnowledgeRetrievalNode(LLMNode):
         else:
             raise ValueError("Invalid metadata filtering mode")
         if filters:
-            if node_data.metadata_filtering_conditions.logical_operator == "and":  # type: ignore
+            if (
+                node_data.metadata_filtering_conditions
+                and node_data.metadata_filtering_conditions.logical_operator == "and"
+            ):  # type: ignore
                 document_query = document_query.filter(and_(*filters))
             else:
                 document_query = document_query.filter(or_(*filters))
@@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData):
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig = Field(default_factory=VisionConfig)
+    structured_output: dict | None = None
+    structured_output_enabled: bool = False
 
     @field_validator("prompt_config", mode="before")
     @classmethod
@@ -4,6 +4,8 @@ from collections.abc import Generator, Mapping, Sequence
 from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, Optional, cast
 
+import json_repair
+
 from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
@@ -11,6 +13,7 @@ from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.file import FileType, file_manager
 from core.helper.code_executor import CodeExecutor, CodeLanguage
+from core.memory.model_context_memory import ModelContextMemory
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
@@ -22,16 +25,22 @@ from core.model_runtime.entities import (
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
-    PromptMessageContent,
+    PromptMessageContentUnionTypes,
     PromptMessageRole,
     SystemPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import ModelProviderID
-from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
+from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, LLMMemoryType, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.variables import (
     ArrayAnySegment,
@@ -57,6 +66,12 @@ from core.workflow.nodes.event import (
     RunRetrieverResourceEvent,
     RunStreamChunkEvent,
 )
+from core.workflow.utils.structured_output.entities import (
+    ResponseFormat,
+    SpecialModelType,
+    SupportStructuredOutputStatus,
+)
+from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
@@ -92,6 +107,12 @@ class LLMNode(BaseNode[LLMNodeData]):
     _node_type = NodeType.LLM
 
     def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+        def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
+            """Process structured output if enabled"""
+            if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
+                return None
+            return self._parse_structured_output(text)
+
         node_inputs: Optional[dict[str, Any]] = None
         process_data = None
         result_text = ""
@@ -130,7 +151,6 @@ class LLMNode(BaseNode[LLMNodeData]):
             if isinstance(event, RunRetrieverResourceEvent):
                 context = event.context
                 yield event
-
             if context:
                 node_inputs["#context#"] = context
 
@@ -171,6 +191,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                 ),
                 "model_provider": model_config.provider,
                 "model_name": model_config.model,
+                "memory_type": self.node_data.memory.type if self.node_data.memory else None,
             }
 
             # handle invoke result
@@ -192,7 +213,9 @@ class LLMNode(BaseNode[LLMNodeData]):
                     self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                     break
             outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
-
+            structured_output = process_structured_output(result_text)
+            if structured_output:
+                outputs["structured_output"] = structured_output
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -513,7 +536,12 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         if not model_schema:
             raise ModelNotExistError(f"Model {model_name} not exist.")
 
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
+            completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
+        elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
+            # Set appropriate response format based on model capabilities
+            self._set_response_format(completion_params, model_schema.parameter_rules)
         return model_instance, ModelConfigWithCredentialsEntity(
             provider=provider_name,
             model=model_name,
@@ -527,10 +555,9 @@ class LLMNode(BaseNode[LLMNodeData]):
 
     def _fetch_memory(
         self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
-    ) -> Optional[TokenBufferMemory]:
+    ) -> Optional[TokenBufferMemory | ModelContextMemory]:
         if not node_data_memory:
             return None
 
         # get conversation id
         conversation_id_variable = self.graph_runtime_state.variable_pool.get(
             ["sys", SystemVariableKey.CONVERSATION_ID.value]
@@ -549,7 +576,15 @@ class LLMNode(BaseNode[LLMNodeData]):
         if not conversation:
             return None
 
-        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+        memory = (
+            TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+            if node_data_memory.type == LLMMemoryType.GLOBAL
+            else ModelContextMemory(
+                conversation=conversation,
+                node_id=self.node_id,
+                model_instance=model_instance,
+            )
+        )
 
         return memory
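The effect of the new `LLMMemoryType` selection above: GLOBAL keeps one shared history per conversation, while INDEPENDENT additionally scopes history to the owning node. A simplified illustration of the difference (stand-in key scheme, not Dify's actual storage layout):

```python
conversation_id = "conv-1"


def memory_key(memory_type: str, node_id: str) -> str:
    if memory_type == "global":
        return conversation_id  # one history shared by the whole app
    return f"{conversation_id}:{node_id}"  # one history per LLM node


assert memory_key("global", "llm_a") == memory_key("global", "llm_b")
assert memory_key("independent", "llm_a") != memory_key("independent", "llm_b")
```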
@@ -559,7 +594,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         sys_query: str | None = None,
         sys_files: Sequence["File"],
         context: str | None = None,
-        memory: TokenBufferMemory | None = None,
+        memory: TokenBufferMemory | ModelContextMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
         memory_config: MemoryConfig | None = None,
@@ -568,8 +603,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
     ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
-        # FIXME: fix the type error cause prompt_messages is type quick a few times
-        prompt_messages: list[Any] = []
+        prompt_messages: list[PromptMessage] = []
 
         if isinstance(prompt_template, list):
             # For chat model
@@ -631,12 +665,14 @@ class LLMNode(BaseNode[LLMNodeData]):
                 # For issue #11247 - Check if prompt content is a string or a list
                 prompt_content_type = type(prompt_content)
                 if prompt_content_type == str:
+                    prompt_content = str(prompt_content)
                     if "#histories#" in prompt_content:
                         prompt_content = prompt_content.replace("#histories#", memory_text)
                     else:
                         prompt_content = memory_text + "\n" + prompt_content
                     prompt_messages[0].content = prompt_content
                 elif prompt_content_type == list:
+                    prompt_content = prompt_content if isinstance(prompt_content, list) else []
                     for content_item in prompt_content:
                         if content_item.type == PromptMessageContentType.TEXT:
                             if "#histories#" in content_item.data:
@@ -649,9 +685,10 @@ class LLMNode(BaseNode[LLMNodeData]):
             # Add current query to the prompt message
             if sys_query:
                 if prompt_content_type == str:
-                    prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
+                    prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
                     prompt_messages[0].content = prompt_content
                 elif prompt_content_type == list:
+                    prompt_content = prompt_content if isinstance(prompt_content, list) else []
                     for content_item in prompt_content:
                         if content_item.type == PromptMessageContentType.TEXT:
                             content_item.data = sys_query + "\n" + content_item.data
@@ -681,7 +718,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
             if isinstance(prompt_message.content, list):
-                prompt_message_content = []
+                prompt_message_content: list[PromptMessageContentUnionTypes] = []
                 for content_item in prompt_message.content:
                     # Skip content if features are not defined
                     if not model_config.model_schema.features:
@@ -724,10 +761,29 @@ class LLMNode(BaseNode[LLMNodeData]):
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
 
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
+            filtered_prompt_messages = self._handle_prompt_based_schema(
+                prompt_messages=filtered_prompt_messages,
+            )
         stop = model_config.stop
         return filtered_prompt_messages, stop
 
+    def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
+        structured_output: dict[str, Any] | list[Any] = {}
+        try:
+            parsed = json.loads(result_text)
+            if not isinstance(parsed, (dict | list)):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        except json.JSONDecodeError as e:
+            # if the result_text is not a valid json, try to repair it
+            parsed = json_repair.loads(result_text)
+            if not isinstance(parsed, (dict | list)):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        return structured_output
+
     @classmethod
     def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
         provider_model_bundle = model_instance.provider_model_bundle
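`_parse_structured_output` above tries strict `json.loads` first and only then falls back to the third-party `json_repair` package, which tolerates common LLM output mistakes such as single quotes or trailing commas. A standalone sketch of the same parse-then-repair fallback:

```python
import json

import json_repair


def parse_structured(text: str):
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return json_repair.loads(text)


assert parse_structured('{"a": 1}') == {"a": 1}
assert parse_structured("{'a': 1,}") == {"a": 1}  # repaired despite being invalid JSON
```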
@@ -926,8 +982,170 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
|
||||
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
:return: Updated model parameters with JSON schema configuration
|
||||
"""
|
||||
# Process schema according to model requirements
|
||||
schema = self._fetch_structured_output_schema()
|
||||
schema_json = self._prepare_schema_for_model(schema)
|
||||
|
||||
# Set JSON schema in parameters
|
||||
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
|
||||
|
||||
# Set appropriate response format if required by the model
|
||||
for rule in rules:
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||
|
||||
return model_parameters
|
||||
|
||||
def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Handle structured output for models without native JSON schema support.
|
||||
This function modifies the prompt messages to include schema-based output requirements.
|
||||
|
||||
Args:
|
||||
prompt_messages: Original sequence of prompt messages
|
||||
|
||||
Returns:
|
||||
list[PromptMessage]: Updated prompt messages with structured output requirements
|
||||
"""
|
||||
# Convert schema to string format
|
||||
schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
|
||||
|
||||
# Find existing system prompt with schema placeholder
|
||||
system_prompt = next(
|
||||
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
|
||||
None,
|
||||
)
|
||||
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
|
||||
# Prepare system prompt content
|
||||
system_prompt_content = (
|
||||
structured_output_prompt + "\n\n" + system_prompt.content
|
||||
if system_prompt and isinstance(system_prompt.content, str)
|
||||
else structured_output_prompt
|
||||
)
|
||||
system_prompt = SystemPromptMessage(content=system_prompt_content)
|
||||
|
||||
# Extract content from the last user message
|
||||
|
||||
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
|
||||
updated_prompt = [system_prompt] + filtered_prompts
|
||||
|
||||
return updated_prompt
|
||||
|
||||
def _set_response_format(self, model_parameters: dict, rules: list) -> None:
|
||||
"""
|
||||
Set the appropriate response format parameter based on model rules.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
"""
|
||||
for rule in rules:
|
||||
if rule.name == "response_format":
|
||||
if ResponseFormat.JSON.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||
|
||||
    def _prepare_schema_for_model(self, schema: dict) -> dict:
        """
        Prepare JSON schema based on model requirements.

        Different models have different requirements for JSON schema formatting.
        This function handles these differences.

        :param schema: The original JSON schema
        :return: Processed schema compatible with the current model
        """
        # Copy so the original schema object is not replaced (note: dict.copy() is
        # shallow, so the recursive transformations below still touch nested sub-dicts)
        processed_schema = schema.copy()

        # Convert boolean types to string types (common requirement)
        convert_boolean_to_string(processed_schema)

        # Apply model-specific transformations
        if SpecialModelType.GEMINI in self.node_data.model.name:
            remove_additional_properties(processed_schema)
            return processed_schema
        elif SpecialModelType.OLLAMA in self.node_data.model.provider:
            return processed_schema
        else:
            # Default format with name field
            return {"schema": processed_schema, "name": "llm_response"}

    def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
        """
        Fetch the model schema for the configured model.
        """
        model_name = self.node_data.model.name
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
        )
        model_type_instance = model_instance.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)
        model_credentials = model_instance.credentials
        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
        return model_schema

    def _fetch_structured_output_schema(self) -> dict[str, Any]:
        """
        Fetch the structured output schema from the node data.

        Returns:
            dict[str, Any]: The structured output schema
        """
        if not self.node_data.structured_output:
            raise LLMNodeError("Please provide a valid structured output schema")
        structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
        if not structured_output_schema:
            raise LLMNodeError("Please provide a valid structured output schema")

        try:
            schema = json.loads(structured_output_schema)
            if not isinstance(schema, dict):
                raise LLMNodeError("structured_output_schema must be a JSON object")
            return schema
        except json.JSONDecodeError:
            raise LLMNodeError("structured_output_schema is not valid JSON format")

    def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
        """
        Check if the current model supports structured output.

        Returns:
            SupportStructuredOutputStatus: The support status of structured output
        """
        # Early return if structured output is disabled
        if (
            not isinstance(self.node_data, LLMNodeData)
            or not self.node_data.structured_output_enabled
            or not self.node_data.structured_output
        ):
            return SupportStructuredOutputStatus.DISABLED
        # Get model schema and check if it exists
        model_schema = self._fetch_model_schema(self.node_data.model.provider)
        if not model_schema:
            return SupportStructuredOutputStatus.DISABLED

        # Check if model supports structured output feature
        return (
            SupportStructuredOutputStatus.SUPPORTED
            if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
            else SupportStructuredOutputStatus.UNSUPPORTED
        )


def _combine_message_content_with_role(
    *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
):
    match role:
        case PromptMessageRole.USER:
            return UserPromptMessage(content=contents)
@@ -992,7 +1210,7 @@ def _calculate_rest_token(

def _handle_memory_chat_mode(
    *,
-    memory: TokenBufferMemory | None,
+    memory: TokenBufferMemory | ModelContextMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
@@ -1009,7 +1227,7 @@ def _handle_memory_chat_mode(

def _handle_memory_completion_mode(
    *,
-    memory: TokenBufferMemory | None,
+    memory: TokenBufferMemory | ModelContextMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> str:
@@ -1064,3 +1282,49 @@ def _handle_completion_template(
    )
    prompt_messages.append(prompt_message)
    return prompt_messages


def remove_additional_properties(schema: dict) -> None:
    """
    Remove additionalProperties fields from JSON schema.
    Used for models like Gemini that don't support this property.

    :param schema: JSON schema to modify in-place
    """
    if not isinstance(schema, dict):
        return

    # Remove additionalProperties at current level
    schema.pop("additionalProperties", None)

    # Process nested structures recursively
    for value in schema.values():
        if isinstance(value, dict):
            remove_additional_properties(value)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    remove_additional_properties(item)


def convert_boolean_to_string(schema: dict) -> None:
    """
    Convert boolean type specifications to string in JSON schema.

    :param schema: JSON schema to modify in-place
    """
    if not isinstance(schema, dict):
        return

    # Check for boolean type at current level
    if schema.get("type") == "boolean":
        schema["type"] = "string"

    # Process nested dictionaries and lists recursively
    for value in schema.values():
        if isinstance(value, dict):
            convert_boolean_to_string(value)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    convert_boolean_to_string(item)
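As a quick sanity check of the two helpers above, the following standalone sketch (not part of the changeset) runs them over a small nested schema:

```python
# Exercise remove_additional_properties / convert_boolean_to_string as defined above.
schema = {
    "type": "object",
    "additionalProperties": False,
    "properties": {
        "active": {"type": "boolean"},
        "tags": {"type": "array", "items": {"type": "boolean"}},
    },
}

convert_boolean_to_string(schema)     # boolean -> string, applied recursively
remove_additional_properties(schema)  # strips the Gemini-unsupported key

assert schema["properties"]["active"]["type"] == "string"
assert schema["properties"]["tags"]["items"]["type"] == "string"
assert "additionalProperties" not in schema
```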
api/core/workflow/utils/structured_output/entities.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from enum import StrEnum


class ResponseFormat(StrEnum):
    """Constants for model response formats"""

    JSON_SCHEMA = "json_schema"  # native structured output mode; some models, like Gemini and GPT-4o, support it
    JSON = "JSON"  # JSON mode; some models, like Claude, support it
    JSON_OBJECT = "json_object"  # another alias for JSON mode; some models, like deepseek-chat and qwen, use it


class SpecialModelType(StrEnum):
    """Constants for identifying model types"""

    GEMINI = "gemini"
    OLLAMA = "ollama"


class SupportStructuredOutputStatus(StrEnum):
    """Constants for structured output support status"""

    SUPPORTED = "supported"
    UNSUPPORTED = "unsupported"
    DISABLED = "disabled"
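A side note on the `in` checks used by `_prepare_schema_for_model` earlier in this diff: because these enums subclass `StrEnum`, each member is a real `str`, so a member can be used directly as a substring probe against a model or provider name. A minimal sketch:

```python
from enum import StrEnum


class SpecialModelType(StrEnum):
    GEMINI = "gemini"
    OLLAMA = "ollama"


# StrEnum members are str instances, so substring checks work directly:
print(SpecialModelType.GEMINI in "gemini-1.5-pro")  # True
print(SpecialModelType.GEMINI in "gpt-4o")          # False
```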
api/core/workflow/utils/structured_output/prompt.py (new file, 17 lines)
@@ -0,0 +1,17 @@
STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You can answer questions and output in JSON format.
constraints:
- You must output in JSON format.
- Do not output a boolean value; use the string type instead.
- Do not output an integer or float value; use the number type instead.
e.g.:
Here is the JSON schema:
{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}

Here is the user's question:
My name is John Doe and I am 30 years old.

output:
{"name": "John Doe", "age": 30}
Here is the JSON schema:
{{schema}}
"""  # noqa: E501
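The `{{schema}}` token at the end is a plain placeholder: `_handle_prompt_based_schema` substitutes the serialized user schema with a literal `str.replace`, with no template engine involved. A small sketch of that substitution (the example schema is illustrative):

```python
import json

schema = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
schema_str = json.dumps(schema, ensure_ascii=False)

# Same substitution the LLM node performs before prepending the system prompt
system_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
assert "{{schema}}" not in system_prompt
```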
@@ -26,9 +26,12 @@ def init_app(app: DifyApp):

    # Always add StreamHandler to log to console
    sh = logging.StreamHandler(sys.stdout)
-    sh.addFilter(RequestIdFilter())
    log_handlers.append(sh)

+    # Apply RequestIdFilter to all handlers
+    for handler in log_handlers:
+        handler.addFilter(RequestIdFilter())
+
    logging.basicConfig(
        level=dify_config.LOG_LEVEL,
        format=dify_config.LOG_FORMAT,
@@ -14,7 +14,7 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
from opentelemetry.instrumentation.celery import CeleryInstrumentor
from opentelemetry.instrumentation.flask import FlaskInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
-from opentelemetry.metrics import get_meter_provider, set_meter_provider
+from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider
from opentelemetry.propagate import set_global_textmap
from opentelemetry.propagators.b3 import B3Format
from opentelemetry.propagators.composite import CompositePropagator
@@ -112,6 +112,11 @@ def is_celery_worker():


def init_flask_instrumentor(app: DifyApp):
+    meter = get_meter("http_metrics", version=dify_config.CURRENT_VERSION)
+    _http_response_counter = meter.create_counter(
+        "http.server.response.count", description="Total number of HTTP responses by status code", unit="{response}"
+    )
+
    def response_hook(span: Span, status: str, response_headers: list):
        if span and span.is_recording():
            if status.startswith("2"):
@@ -119,6 +124,11 @@ def init_flask_instrumentor(app: DifyApp):
            else:
                span.set_status(StatusCode.ERROR, status)

+            status = status.split(" ")[0]
+            status_code = int(status)
+            status_class = f"{status_code // 100}xx"
+            _http_response_counter.add(1, {"status_code": status_code, "status_class": status_class})
+
    instrumentor = FlaskInstrumentor()
    if dify_config.DEBUG:
        logging.info("Initializing Flask instrumentor")
api/extensions/ext_otel_patch.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""
|
||||
Patch for OpenTelemetry context detach method to handle None tokens gracefully.
|
||||
|
||||
This patch addresses the issue where OpenTelemetry's context.detach() method raises a TypeError
|
||||
when called with a None token. The error occurs in the contextvars_context.py file where it tries
|
||||
to call reset() on a None token.
|
||||
|
||||
Related GitHub issue: https://github.com/langgenius/dify/issues/18496
|
||||
|
||||
Error being fixed:
|
||||
```
|
||||
Traceback (most recent call last):
|
||||
File "opentelemetry/context/__init__.py", line 154, in detach
|
||||
_RUNTIME_CONTEXT.detach(token)
|
||||
File "opentelemetry/context/contextvars_context.py", line 50, in detach
|
||||
self._current_context.reset(token) # type: ignore
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
TypeError: expected an instance of Token, got None
|
||||
```
|
||||
|
||||
Instead of modifying the third-party package directly, this patch monkey-patches the
|
||||
context.detach method to gracefully handle None tokens.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from functools import wraps
|
||||
|
||||
from opentelemetry import context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Store the original detach method
|
||||
original_detach = context.detach
|
||||
|
||||
|
||||
# Create a patched version that handles None tokens
|
||||
@wraps(original_detach)
|
||||
def patched_detach(token):
|
||||
"""
|
||||
A patched version of context.detach that handles None tokens gracefully.
|
||||
"""
|
||||
if token is None:
|
||||
logger.debug("Attempted to detach a None token, skipping")
|
||||
return
|
||||
|
||||
return original_detach(token)
|
||||
|
||||
|
||||
def is_enabled():
|
||||
"""
|
||||
Check if the extension is enabled.
|
||||
Always enable this patch to prevent errors even when OpenTelemetry is disabled.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
def init_app(app):
|
||||
"""
|
||||
Initialize the OpenTelemetry context patch.
|
||||
"""
|
||||
# Replace the original detach method with our patched version
|
||||
context.detach = patched_detach
|
||||
logger.info("OpenTelemetry context.detach patched to handle None tokens")
|
||||
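To see what the patch buys in practice, the sketch below (assuming `init_app` above has already run) pairs a normal attach/detach and then calls `detach(None)`, which previously raised the `TypeError` quoted in the docstring and is now a logged no-op:

```python
from opentelemetry import context

# Normal paired usage is unaffected by the patch
token = context.attach(context.set_value("demo-key", "demo-value"))
context.detach(token)

# Previously: TypeError from the contextvars runtime; now: debug log + early return
context.detach(None)
```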
@@ -19,6 +19,7 @@ file_fields = {
    "mime_type": fields.String,
    "created_by": fields.String,
    "created_at": TimestampField,
+    "preview_url": fields.String,
}

remote_file_info_fields = {
@@ -3,8 +3,8 @@ import re
import uuid
from collections.abc import Mapping
from datetime import datetime
-from enum import Enum
-from typing import TYPE_CHECKING, Optional
+from enum import Enum, StrEnum
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast

from core.plugin.entities.plugin import GenericProviderID
from core.tools.entities.tool_entities import ToolProviderType
@@ -13,9 +13,6 @@ from services.plugin.plugin_service import PluginService
if TYPE_CHECKING:
    from models.workflow import Workflow

-from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Literal, cast
-
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin  # type: ignore
@@ -1091,12 +1088,7 @@ class Message(db.Model):  # type: ignore[name-defined]

    @property
    def retriever_resources(self):
-        return (
-            db.session.query(DatasetRetrieverResource)
-            .filter(DatasetRetrieverResource.message_id == self.id)
-            .order_by(DatasetRetrieverResource.position.asc())
-            .all()
-        )
+        return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []

    @property
    def message_files(self):
@@ -1,14 +1,12 @@
import json
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
-from enum import Enum
+from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Self, Union
from uuid import uuid4

if TYPE_CHECKING:
    from models.model import AppMode

-from enum import StrEnum
-from typing import TYPE_CHECKING
-
import sqlalchemy as sa
from sqlalchemy import Index, PrimaryKeyConstraint, func
@@ -245,6 +243,13 @@ class Workflow(Base):

    @property
    def tool_published(self) -> bool:
+        """
+        DEPRECATED: This property is not accurate for determining if a workflow is published as a tool.
+        It only checks if there's a WorkflowToolProvider for the app, not if this specific workflow version
+        is the one being used by the tool.
+
+        For accurate checking, use a direct query with tenant_id, app_id, and version.
+        """
        from models.tools import WorkflowToolProvider

        return (
@@ -601,6 +606,17 @@ class WorkflowNodeExecution(Base):
            "triggered_from",
            "node_execution_id",
        ),
+        db.Index(
+            "workflow_node_execution_run_node_status_idx",
+            "workflow_run_id",
+            "node_id",
+            "status",
+        ),
+        db.Index(
+            "workflow_node_execution_run_status_idx",
+            "workflow_run_id",
+            "status",
+        ),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
@@ -630,6 +646,7 @@ class WorkflowNodeExecution(Base):
    @property
    def created_by_account(self):
        created_by_role = CreatedByRole(self.created_by_role)
+        # TODO(-LAN-): Avoid using db.session.get() here.
        return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None

    @property
@@ -637,6 +654,7 @@ class WorkflowNodeExecution(Base):
        from models.model import EndUser

        created_by_role = CreatedByRole(self.created_by_role)
+        # TODO(-LAN-): Avoid using db.session.get() here.
        return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None

    @property
@@ -30,6 +30,7 @@ dependencies = [
    "gunicorn~=23.0.0",
    "httpx[socks]~=0.27.0",
    "jieba==0.42.1",
+    "json-repair>=0.41.1",
    "langfuse~=2.51.3",
    "langsmith~=0.1.77",
    "mailchimp-transactional~=1.0.50",
@@ -163,10 +164,7 @@ storage = [
############################################################
# [ Tools ] dependency group
############################################################
-tools = [
-    "cloudscraper~=1.2.71",
-    "nltk~=3.9.1",
-]
+tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]

############################################################
# [ VDB ] dependency group
@@ -180,7 +178,7 @@ vdb = [
    "couchbase~=4.3.0",
    "elasticsearch==8.14.0",
    "opensearch-py==2.4.0",
-    "oracledb~=2.2.1",
+    "oracledb==3.0.0",
    "pgvecto-rs[sqlalchemy]~=0.2.1",
    "pgvector==0.2.5",
    "pymilvus~=2.5.0",
@@ -6,7 +6,7 @@ import logging
from collections.abc import Sequence
from typing import Optional

-from sqlalchemy import UnaryExpression, asc, desc, select
+from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

@@ -36,9 +36,13 @@ class SQLAlchemyWorkflowNodeExecutionRepository:
        """
        # If an engine is provided, create a sessionmaker from it
        if isinstance(session_factory, Engine):
-            self._session_factory = sessionmaker(bind=session_factory)
-        else:
+            self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
+        elif isinstance(session_factory, sessionmaker):
            self._session_factory = session_factory
+        else:
+            raise ValueError(
+                f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
+            )

        self._tenant_id = tenant_id
        self._app_id = app_id
@@ -168,3 +172,25 @@ class SQLAlchemyWorkflowNodeExecutionRepository:

        session.merge(execution)
        session.commit()
+
+    def clear(self) -> None:
+        """
+        Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
+
+        This method deletes all WorkflowNodeExecution records that match the tenant_id
+        and app_id (if provided) associated with this repository instance.
+        """
+        with self._session_factory() as session:
+            stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
+
+            if self._app_id:
+                stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
+
+            result = session.execute(stmt)
+            session.commit()
+
+            deleted_count = result.rowcount
+            logger.info(
+                f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
+                + (f" and app {self._app_id}" if self._app_id else "")
+            )
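A hypothetical usage of the new `clear()` method; the engine URL, tenant and app IDs, and the keyword names are placeholders inferred from the constructor shown above, not values taken from this diff:

```python
from sqlalchemy import create_engine

engine = create_engine("postgresql+psycopg2://user:pass@localhost/dify")  # placeholder DSN

repository = SQLAlchemyWorkflowNodeExecutionRepository(
    session_factory=engine,  # an Engine is wrapped in a sessionmaker by __init__
    tenant_id="tenant-uuid",
    app_id="app-uuid",
)
repository.clear()  # deletes every record for this tenant (and app, when set)
```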
@@ -407,10 +407,8 @@ class AccountService:
            raise PasswordResetRateLimitExceededError()

-        code = "".join([str(random.randint(0, 9)) for _ in range(6)])
-        token = TokenManager.generate_token(
-            account=account, email=email, token_type="reset_password", additional_data={"code": code}
-        )
+        code, token = cls.generate_reset_password_token(account_email, account)

        send_reset_password_mail_task.delay(
            language=language,
            to=account_email,
@@ -419,6 +417,22 @@ class AccountService:
        cls.reset_password_rate_limiter.increment_rate_limit(account_email)
        return token

+    @classmethod
+    def generate_reset_password_token(
+        cls,
+        email: str,
+        account: Optional[Account] = None,
+        code: Optional[str] = None,
+        additional_data: dict[str, Any] = {},
+    ):
+        if not code:
+            code = "".join([str(random.randint(0, 9)) for _ in range(6)])
+        additional_data["code"] = code
+        token = TokenManager.generate_token(
+            account=account, email=email, token_type="reset_password", additional_data=additional_data
+        )
+        return code, token
+
    @classmethod
    def revoke_reset_password_token(cls, token: str):
        TokenManager.revoke_token(token, "reset_password")
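A sketch of how the extracted helper behaves (the email and extra payload are placeholders). When no `code` is supplied, a fresh 6-digit code is generated and embedded in the token payload alongside any extra data:

```python
code, token = AccountService.generate_reset_password_token(
    email="user@example.com",
    additional_data={"source": "console"},  # merged with {"code": ...}
)
assert len(code) == 6 and code.isdigit()
```

One caveat worth flagging: `additional_data: dict[str, Any] = {}` is a shared mutable default, so calls that omit the argument keep writing `code` into the same dict object across invocations; passing an explicit dict, as above, sidesteps that.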
@@ -40,7 +40,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60  # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024  # 10MB
-CURRENT_DSL_VERSION = "0.1.5"
+CURRENT_DSL_VERSION = "0.2.0"


class ImportMode(StrEnum):
@@ -1,3 +1,4 @@
+from configs import dify_config
from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
from core.plugin.manager.plugin import PluginInstallationManager
@@ -111,6 +112,8 @@ class DependenciesAnalysisService:
        Generate the latest version of dependencies
        """
        dependencies = list(set(dependencies))
+        if not dify_config.MARKETPLACE_ENABLED:
+            return []
        deps = marketplace.batch_fetch_plugin_manifests(dependencies)
        return [
            PluginDependency(
@@ -2,13 +2,14 @@ import threading
from typing import Optional

import contexts
+from core.repository import RepositoryFactory
+from core.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.model import App
from models.workflow import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
)
@@ -127,17 +128,17 @@ class WorkflowRunService:
        if not workflow_run:
            return []

-        node_executions = (
-            db.session.query(WorkflowNodeExecution)
-            .filter(
-                WorkflowNodeExecution.tenant_id == app_model.tenant_id,
-                WorkflowNodeExecution.app_id == app_model.id,
-                WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
-                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-                WorkflowNodeExecution.workflow_run_id == run_id,
-            )
-            .order_by(WorkflowNodeExecution.index.desc())
-            .all()
-        )
-
-        return node_executions
+        # Use the repository to get the node executions
+        repository = RepositoryFactory.create_workflow_node_execution_repository(
+            params={
+                "tenant_id": app_model.tenant_id,
+                "app_id": app_model.id,
+                "session_factory": db.session.get_bind(),
+            }
+        )
+
+        # Use the repository to get the node executions with ordering
+        order_config = OrderConfig(order_by=["index"], order_direction="desc")
+        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
+
+        return list(node_executions)
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
+from core.repository import RepositoryFactory
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -27,6 +28,7 @@ from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, AppMode
+from models.tools import WorkflowToolProvider
from models.workflow import (
    Workflow,
    WorkflowNodeExecution,
@@ -282,8 +284,15 @@ class WorkflowService:
        workflow_node_execution.created_by = account.id
        workflow_node_execution.workflow_id = draft_workflow.id

-        db.session.add(workflow_node_execution)
-        db.session.commit()
+        # Use the repository to save the workflow node execution
+        repository = RepositoryFactory.create_workflow_node_execution_repository(
+            params={
+                "tenant_id": app_model.tenant_id,
+                "app_id": app_model.id,
+                "session_factory": db.session.get_bind(),
+            }
+        )
+        repository.save(workflow_node_execution)

        return workflow_node_execution
@@ -515,8 +524,19 @@ class WorkflowService:
            # Cannot delete a workflow that's currently in use by an app
            raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")

-        # Check if this workflow is published as a tool
-        if workflow.tool_published:
+        # Don't use workflow.tool_published as it's not accurate for specific workflow versions
+        # Check if there's a tool provider using this specific workflow version
+        tool_provider = (
+            session.query(WorkflowToolProvider)
+            .filter(
+                WorkflowToolProvider.tenant_id == workflow.tenant_id,
+                WorkflowToolProvider.app_id == workflow.app_id,
+                WorkflowToolProvider.version == workflow.version,
+            )
+            .first()
+        )
+
+        if tool_provider:
            # Cannot delete a workflow that's published as a tool
            raise WorkflowInUseError("Cannot delete workflow that is published as a tool")
@@ -7,6 +7,7 @@ from celery import shared_task  # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError

+from core.repository import RepositoryFactory
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.model import (
@@ -30,7 +31,7 @@ from models.model import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
-from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
+from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun
@shared_task(queue="app_deletion", bind=True, max_retries=3)
@@ -187,18 +188,20 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):


def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
-    def del_workflow_node_execution(workflow_node_execution_id: str):
-        db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete(
-            synchronize_session=False
-        )
-
-    _delete_records(
-        """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
-        {"tenant_id": tenant_id, "app_id": app_id},
-        del_workflow_node_execution,
-        "workflow node execution",
-    )
+    # Create a repository instance for WorkflowNodeExecution
+    repository = RepositoryFactory.create_workflow_node_execution_repository(
+        params={
+            "tenant_id": tenant_id,
+            "app_id": app_id,
+            "session_factory": db.session.get_bind(),
+        }
+    )
+
+    # Use the clear method to delete all records for this tenant_id and app_id
+    repository.clear()

    logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green"))


def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
    def del_workflow_app_log(workflow_app_log_id: str):
@@ -0,0 +1,88 @@
import os

import pytest
from _pytest.monkeypatch import MonkeyPatch
from api.core.rag.datasource.vdb.field import Field
from elasticsearch import Elasticsearch


class MockIndicesClient:
    def __init__(self):
        pass

    def create(self, index, mappings, settings):
        return {"acknowledge": True}

    def refresh(self, index):
        return {"acknowledge": True}

    def delete(self, index):
        return {"acknowledge": True}

    def exists(self, index):
        return True


class MockClient:
    def __init__(self, **kwargs):
        self.indices = MockIndicesClient()

    def index(self, **kwargs):
        return {"acknowledge": True}

    def exists(self, **kwargs):
        return True

    def delete(self, **kwargs):
        return {"acknowledge": True}

    def search(self, **kwargs):
        return {
            "took": 1,
            "hits": {
                "hits": [
                    {
                        "_source": {
                            Field.CONTENT_KEY.value: "abcdef",
                            Field.VECTOR.value: [1, 2],
                            Field.METADATA_KEY.value: {},
                        },
                        "_score": 1.0,
                    },
                    {
                        "_source": {
                            Field.CONTENT_KEY.value: "123456",
                            Field.VECTOR.value: [2, 2],
                            Field.METADATA_KEY.value: {},
                        },
                        "_score": 0.9,
                    },
                    {
                        "_source": {
                            Field.CONTENT_KEY.value: "a1b2c3",
                            Field.VECTOR.value: [3, 2],
                            Field.METADATA_KEY.value: {},
                        },
                        "_score": 0.8,
                    },
                ]
            },
        }


MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_client_mock(request, monkeypatch: MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(Elasticsearch, "__init__", MockClient.__init__)
        monkeypatch.setattr(Elasticsearch, "index", MockClient.index)
        monkeypatch.setattr(Elasticsearch, "exists", MockClient.exists)
        monkeypatch.setattr(Elasticsearch, "delete", MockClient.delete)
        monkeypatch.setattr(Elasticsearch, "search", MockClient.search)

    yield

    if MOCK:
        monkeypatch.undo()
api/tests/integration_tests/vdb/huawei/__init__.py (new file, empty)
api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
from tests.integration_tests.vdb.__mock.huaweicloudvectordb import setup_client_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis


class HuaweiCloudVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = HuaweiCloudVector(
            "dify",
            HuaweiCloudVectorConfig(
                hosts="https://127.0.0.1:9200",
                username="dify",
                password="dify",
            ),
        )

    def search_by_vector(self):
        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 3

    def search_by_full_text(self):
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 3


def test_huawei_cloud_vector(setup_mock_redis, setup_client_mock):
    HuaweiCloudVectorTest().run_all_tests()
@@ -0,0 +1,99 @@
from unittest.mock import MagicMock, patch

from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call

ToolCall = AssistantPromptMessage.ToolCall

# CASE 1: Single tool call
INPUTS_CASE_1 = [
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_1 = [
    ToolCall(
        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
    ),
]

# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
INPUTS_CASE_2 = [
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_2 = [
    ToolCall(
        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
    ),
    ToolCall(
        id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
    ),
]

# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
INPUTS_CASE_3 = [
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_3 = [
    ToolCall(
        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
    ),
    ToolCall(
        id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
    ),
]

# CASE 4: Tool call sequences with no IDs
INPUTS_CASE_4 = [
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_4 = [
    ToolCall(
        id="RANDOM_ID_1",
        type="function",
        function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
    ),
    ToolCall(
        id="RANDOM_ID_2",
        type="function",
        function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'),
    ),
]


def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
    actual = []
    _increase_tool_call(inputs, actual)
    assert actual == expected


def test__increase_tool_call():
    # case 1:
    _run_case(INPUTS_CASE_1, EXPECTED_CASE_1)

    # case 2:
    _run_case(INPUTS_CASE_2, EXPECTED_CASE_2)

    # case 3:
    _run_case(INPUTS_CASE_3, EXPECTED_CASE_3)

    # case 4:
    mock_id_generator = MagicMock()
    mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4]
    with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator):
        _run_case(INPUTS_CASE_4, EXPECTED_CASE_4)
api/tests/unit_tests/core/prompt/test_prompt_message.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    TextPromptMessageContent,
    UserPromptMessage,
)


def test_build_prompt_message_with_prompt_message_contents():
    prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")])
    assert isinstance(prompt.content, list)
    assert isinstance(prompt.content[0], TextPromptMessageContent)
    assert prompt.content[0].data == "Hello, World!"


def test_dump_prompt_message():
    example_url = "https://example.com/image.jpg"
    prompt = UserPromptMessage(
        content=[
            ImagePromptMessageContent(
                url=example_url,
                format="jpeg",
                mime_type="image/jpeg",
            )
        ]
    )
    data = prompt.model_dump()
    assert data["content"][0].get("url") == example_url
@@ -152,3 +152,27 @@ def test_update(repository, session):

    # Assert session.merge was called
    session_obj.merge.assert_called_once_with(execution)


def test_clear(repository, session, mocker: MockerFixture):
    """Test clear method."""
    session_obj, _ = session
    # Set up mock
    mock_delete = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.delete")
    mock_stmt = mocker.MagicMock()
    mock_delete.return_value = mock_stmt
    mock_stmt.where.return_value = mock_stmt

    # Mock the execute result with rowcount
    mock_result = mocker.MagicMock()
    mock_result.rowcount = 5  # Simulate 5 records deleted
    session_obj.execute.return_value = mock_result

    # Call method
    repository.clear()

    # Assert delete was called with correct parameters
    mock_delete.assert_called_once_with(WorkflowNodeExecution)
    mock_stmt.where.assert_called()
    session_obj.execute.assert_called_once_with(mock_stmt)
    session_obj.commit.assert_called_once()
@@ -40,6 +40,10 @@ def workflow_setup():


def test_delete_workflow_success(workflow_setup):
    # Setup mocks
+    # Mock the tool provider query to return None (not published as a tool)
+    workflow_setup["session"].query.return_value.filter.return_value.first.return_value = None
+
    workflow_setup["session"].scalar = MagicMock(
        side_effect=[workflow_setup["workflow"], None]
    )  # Return workflow first, then None for app
@@ -97,7 +101,12 @@ def test_delete_workflow_in_use_by_app_error(workflow_setup):

def test_delete_workflow_published_as_tool_error(workflow_setup):
    # Setup mocks
-    workflow_setup["workflow"].tool_published = True
+    from models.tools import WorkflowToolProvider
+
+    # Mock the tool provider query
+    mock_tool_provider = MagicMock(spec=WorkflowToolProvider)
+    workflow_setup["session"].query.return_value.filter.return_value.first.return_value = mock_tool_provider
+
    workflow_setup["session"].scalar = MagicMock(
        side_effect=[workflow_setup["workflow"], None]
    )  # Return workflow first, then None for app
api/uv.lock (generated, 40 lines changed)
@@ -1,5 +1,4 @@
version = 1
revision = 1
requires-python = ">=3.11, <3.13"
resolution-markers = [
    "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy'",
@@ -1178,6 +1177,7 @@ dependencies = [
    { name = "gunicorn" },
    { name = "httpx", extra = ["socks"] },
    { name = "jieba" },
+    { name = "json-repair" },
    { name = "langfuse" },
    { name = "langsmith" },
    { name = "mailchimp-transactional" },
@@ -1346,6 +1346,7 @@ requires-dist = [
    { name = "gunicorn", specifier = "~=23.0.0" },
    { name = "httpx", extras = ["socks"], specifier = "~=0.27.0" },
    { name = "jieba", specifier = "==0.42.1" },
+    { name = "json-repair", specifier = ">=0.41.1" },
    { name = "langfuse", specifier = "~=2.51.3" },
    { name = "langsmith", specifier = "~=0.1.77" },
    { name = "mailchimp-transactional", specifier = "~=1.0.50" },
@@ -1470,7 +1471,7 @@ vdb = [
    { name = "couchbase", specifier = "~=4.3.0" },
    { name = "elasticsearch", specifier = "==8.14.0" },
    { name = "opensearch-py", specifier = "==2.4.0" },
-    { name = "oracledb", specifier = "~=2.2.1" },
+    { name = "oracledb", specifier = "==3.0.0" },
    { name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" },
    { name = "pgvector", specifier = "==0.2.5" },
    { name = "pymilvus", specifier = "~=2.5.0" },
@@ -2524,6 +2525,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 },
]

+[[package]]
+name = "json-repair"
+version = "0.41.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6d/6a/6c7a75a10da6dc807b582f2449034da1ed74415e8899746bdfff97109012/json_repair-0.41.1.tar.gz", hash = "sha256:bba404b0888c84a6b86ecc02ec43b71b673cfee463baf6da94e079c55b136565", size = 31208 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/10/5c/abd7495c934d9af5c263c2245ae30cfaa716c3c0cf027b2b8fa686ee7bd4/json_repair-0.41.1-py3-none-any.whl", hash = "sha256:0e181fd43a696887881fe19fed23422a54b3e4c558b6ff27a86a8c3ddde9ae79", size = 21578 },
+]
+
[[package]]
name = "jsonpath-python"
version = "1.0.6"
@@ -3590,23 +3600,23 @@ wheels = [

[[package]]
name = "oracledb"
-version = "2.2.1"
+version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "cryptography" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/36/fb/3fbacb351833dd794abb184303a5761c4bb33df9d770fd15d01ead2ff738/oracledb-2.2.1.tar.gz", hash = "sha256:8464c6f0295f3318daf6c2c72c83c2dcbc37e13f8fd44e3e39ff8665f442d6b6", size = 580818 }
+sdist = { url = "https://files.pythonhosted.org/packages/bf/39/712f797b75705c21148fa1d98651f63c2e5cc6876e509a0a9e2f5b406572/oracledb-3.0.0.tar.gz", hash = "sha256:64dc86ee5c032febc556798b06e7b000ef6828bb0252084f6addacad3363db85", size = 840431 }
wheels = [
-    { url = "https://files.pythonhosted.org/packages/74/b7/a4238295944670fb8cc50a8cc082e0af5a0440bfb1c2bac2b18429c0a579/oracledb-2.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fb6d9a4d7400398b22edb9431334f9add884dec9877fd9c4ae531e1ccc6ee1fd", size = 3551303 },
-    { url = "https://files.pythonhosted.org/packages/4f/5f/98481d44976cd2b3086361f2d50026066b24090b0e6cd1f2a12c824e9717/oracledb-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07757c240afbb4f28112a6affc2c5e4e34b8a92e5bb9af81a40fba398da2b028", size = 12258455 },
-    { url = "https://files.pythonhosted.org/packages/e9/54/06b2540286e2b63f60877d6f3c6c40747e216b6eeda0756260e194897076/oracledb-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63daec72f853c47179e98493e9b732909d96d495bdceb521c5973a3940d28142", size = 12317476 },
-    { url = "https://files.pythonhosted.org/packages/4d/1a/67814439a4e24df83281a72cb0ba433d6b74e1bff52a9975b87a725bcba5/oracledb-2.2.1-cp311-cp311-win32.whl", hash = "sha256:fec5318d1e0ada7e4674574cb6c8d1665398e8b9c02982279107212f05df1660", size = 1369368 },
-    { url = "https://files.pythonhosted.org/packages/e3/b8/b2a8f0607be17f58ec6689ad5fd15c2956f4996c64547325e96439570edf/oracledb-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:5134dccb5a11bc755abf02fd49be6dc8141dfcae4b650b55d40509323d00b5c2", size = 1655035 },
-    { url = "https://files.pythonhosted.org/packages/24/5b/2fff762243030f31a6b1561fc8eeb142e69ba6ebd3e7fbe4a2c82f0eb6f0/oracledb-2.2.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ac5716bc9a48247fdf563f5f4ec097f5c9f074a60fd130cdfe16699208ca29b5", size = 3583960 },
-    { url = "https://files.pythonhosted.org/packages/e6/88/34117ae830e7338af7c0481f1c0fc6eda44d558e12f9203b45b491e53071/oracledb-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c150bddb882b7c73fb462aa2d698744da76c363e404570ed11d05b65811d96c3", size = 11749006 },
-    { url = "https://files.pythonhosted.org/packages/9d/58/bac788f18c21f727955652fe238de2d24a12c2b455ed4db18a6d23ff781e/oracledb-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193e1888411bc21187ade4b16b76820bd1e8f216e25602f6cd0a97d45723c1dc", size = 11950663 },
-    { url = "https://files.pythonhosted.org/packages/3b/e2/005f66ae919c6f7c73e06863256cf43aa844330e2dc61a5f9779ae44a801/oracledb-2.2.1-cp312-cp312-win32.whl", hash = "sha256:44a960f8bbb0711af222e0a9690e037b6a2a382e0559ae8eeb9cfafe26c7a3bc", size = 1324255 },
-    { url = "https://files.pythonhosted.org/packages/e6/25/759eb2143134513382e66d874c4aacfd691dec3fef7141170cfa6c1b154f/oracledb-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:470136add32f0d0084225c793f12a52b61b52c3dc00c9cd388ec6a3db3a7643e", size = 1613047 },
+    { url = "https://files.pythonhosted.org/packages/fa/bf/d872c4b3fc15cd3261fe0ea72b21d181700c92dbc050160e161654987062/oracledb-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:52daa9141c63dfa75c07d445e9bb7f69f43bfb3c5a173ecc48c798fe50288d26", size = 4312963 },
+    { url = "https://files.pythonhosted.org/packages/b1/ea/01ee29e76a610a53bb34fdc1030f04b7669c3f80b25f661e07850fc6160e/oracledb-3.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af98941789df4c6aaaf4338f5b5f6b7f2c8c3fe6f8d6a9382f177f350868747a", size = 2661536 },
+    { url = "https://files.pythonhosted.org/packages/3d/8e/ad380e34a46819224423b4773e58c350bc6269643c8969604097ced8c3bc/oracledb-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9812bb48865aaec35d73af54cd1746679f2a8a13cbd1412ab371aba2e39b3943", size = 2867461 },
+    { url = "https://files.pythonhosted.org/packages/96/09/ecc4384a27fd6e1e4de824ae9c160e4ad3aaebdaade5b4bdcf56a4d1ff63/oracledb-3.0.0-cp311-cp311-win32.whl", hash = "sha256:6c27fe0de64f2652e949eb05b3baa94df9b981a4a45fa7f8a991e1afb450c8e2", size = 1752046 },
+    { url = "https://files.pythonhosted.org/packages/62/e8/f34bde24050c6e55eeba46b23b2291f2dd7fd272fa8b322dcbe71be55778/oracledb-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:f922709672002f0b40997456f03a95f03e5712a86c61159951c5ce09334325e0", size = 2101210 },
+    { url = "https://files.pythonhosted.org/packages/6f/fc/24590c3a3d41e58494bd3c3b447a62835138e5f9b243d9f8da0cfb5da8dc/oracledb-3.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:acd0e747227dea01bebe627b07e958bf36588a337539f24db629dc3431d3f7eb", size = 4351993 },
+    { url = "https://files.pythonhosted.org/packages/b7/b6/1f3b0b7bb94d53e8857d77b2e8dbdf6da091dd7e377523e24b79dac4fd71/oracledb-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f8b402f77c22af031cd0051aea2472ecd0635c1b452998f511aa08b7350c90a4", size = 2532640 },
+    { url = "https://files.pythonhosted.org/packages/72/1a/1815f6c086ab49c00921cf155ff5eede5267fb29fcec37cb246339a5ce4d/oracledb-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:378a27782e9a37918bd07a5a1427a77cb6f777d0a5a8eac9c070d786f50120ef", size = 2765949 },
+    { url = "https://files.pythonhosted.org/packages/33/8d/208900f8d372909792ee70b2daad3f7361181e55f2217c45ed9dff658b54/oracledb-3.0.0-cp312-cp312-win32.whl", hash = "sha256:54a28c2cb08316a527cd1467740a63771cc1c1164697c932aa834c0967dc4efc", size = 1709373 },
+    { url = "https://files.pythonhosted.org/packages/0c/5e/c21754f19c896102793c3afec2277e2180aa7d505e4d7fcca24b52d14e4f/oracledb-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8289bad6d103ce42b140e40576cf0c81633e344d56e2d738b539341eacf65624", size = 2056452 },
]

[[package]]
@@ -4074,6 +4084,8 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914 },
    { url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105 },
    { url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222 },
+    { url = "https://files.pythonhosted.org/packages/1d/e3/0c9679cd66cf5604b1f070bdf4525a0c01a15187be287d8348b2eafb718e/pycryptodome-3.19.1-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:ed932eb6c2b1c4391e166e1a562c9d2f020bfff44a0e1b108f67af38b390ea89", size = 1629005 },
+    { url = "https://files.pythonhosted.org/packages/13/75/0d63bf0daafd0580b17202d8a9dd57f28c8487f26146b3e2799b0c5a059c/pycryptodome-3.19.1-pp27-pypy_73-win32.whl", hash = "sha256:81e9d23c0316fc1b45d984a44881b220062336bbdc340aa9218e8d0656587934", size = 1697997 },
]

[[package]]
@@ -15,3 +15,4 @@ pytest api/tests/integration_tests/vdb/chroma \
  api/tests/integration_tests/vdb/couchbase \
  api/tests/integration_tests/vdb/oceanbase \
  api/tests/integration_tests/vdb/tidb_vector \
+  api/tests/integration_tests/vdb/huawei \
@@ -574,6 +574,11 @@ OPENGAUSS_MIN_CONNECTION=1
OPENGAUSS_MAX_CONNECTION=5
OPENGAUSS_ENABLE_PQ=false

+# Huawei Cloud search service vector configurations, only available when VECTOR_STORE is `huawei_cloud`
+HUAWEI_CLOUD_HOSTS=https://127.0.0.1:9200
+HUAWEI_CLOUD_USER=admin
+HUAWEI_CLOUD_PASSWORD=admin
+
# Upstash Vector configuration, only available when VECTOR_STORE is `upstash`
UPSTASH_VECTOR_URL=https://xxx-vector.upstash.io
UPSTASH_VECTOR_TOKEN=dify
@@ -1063,3 +1068,6 @@ OTEL_MAX_EXPORT_BATCH_SIZE=512
OTEL_METRIC_EXPORT_INTERVAL=60000
OTEL_BATCH_EXPORT_TIMEOUT=10000
OTEL_METRIC_EXPORT_TIMEOUT=30000
+
+# Prevent Clickjacking
+ALLOW_EMBED=false
@@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
services:
  # API service
  api:
-    image: langgenius/dify-api:1.2.0
+    image: langgenius/dify-api:1.3.0
    restart: always
    environment:
      # Use the shared environment variables.
@@ -31,7 +31,7 @@ services:
  # worker service
  # The Celery worker for processing the queue.
  worker:
-    image: langgenius/dify-api:1.2.0
+    image: langgenius/dify-api:1.3.0
    restart: always
    environment:
      # Use the shared environment variables.
@@ -57,7 +57,7 @@ services:

  # Frontend web application.
  web:
-    image: langgenius/dify-web:1.2.0
+    image: langgenius/dify-web:1.3.0
    restart: always
    environment:
      CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@@ -66,6 +66,7 @@ services:
      NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0}
      TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
      CSP_WHITELIST: ${CSP_WHITELIST:-}
+      ALLOW_EMBED: ${ALLOW_EMBED:-false}
      MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai}
      MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai}
      TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
@@ -130,6 +131,7 @@ services:
      HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
      HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
      SANDBOX_PORT: ${SANDBOX_PORT:-8194}
+      PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
    volumes:
      - ./volumes/sandbox/dependencies:/dependencies
      - ./volumes/sandbox/conf:/conf
@@ -140,7 +142,7 @@ services:

  # plugin daemon
  plugin_daemon:
-    image: langgenius/dify-plugin-daemon:0.0.7-local
+    image: langgenius/dify-plugin-daemon:0.0.8-local
    restart: always
    environment:
      # Use the shared environment variables.
@@ -551,7 +553,7 @@ services:
    volumes:
      - ./volumes/opengauss/data:/var/lib/opengauss/data
    healthcheck:
-      test: ["CMD-SHELL", "netstat -lntp | grep tcp6 > /dev/null 2>&1"]
+      test: [ "CMD-SHELL", "netstat -lntp | grep tcp6 > /dev/null 2>&1" ]
      interval: 10s
      timeout: 10s
      retries: 10

@@ -60,6 +60,7 @@ services:
      HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
      HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
      SANDBOX_PORT: ${SANDBOX_PORT:-8194}
+      PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
    volumes:
      - ./volumes/sandbox/dependencies:/dependencies
      - ./volumes/sandbox/conf:/conf
@@ -70,7 +71,7 @@ services:

  # plugin daemon
  plugin_daemon:
-    image: langgenius/dify-plugin-daemon:0.0.7-local
+    image: langgenius/dify-plugin-daemon:0.0.8-local
    restart: always
    env_file:
      - ./middleware.env
Some files were not shown because too many files have changed in this diff.