mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-06 19:42:42 +08:00
Compare commits
59 Commits
fix/file-s
...
feat/add-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2debc3b33a | ||
|
|
d5ac2e76fd | ||
|
|
f224dd5602 | ||
|
|
b7f29e8db6 | ||
|
|
05c122d085 | ||
|
|
1f211aa2e9 | ||
|
|
4b6a4fadb5 | ||
|
|
51a54c912e | ||
|
|
f1fef5a826 | ||
|
|
2422735689 | ||
|
|
216442ddc1 | ||
|
|
dd3ac7a2c9 | ||
|
|
11447324ff | ||
|
|
f8210b353e | ||
|
|
2b66c1358b | ||
|
|
102d86d4b6 | ||
|
|
227f49a0cc | ||
|
|
a17f169e01 | ||
|
|
72ea3d6b98 | ||
|
|
17cacf258e | ||
|
|
f7aacefcd6 | ||
|
|
ace7ffab5f | ||
|
|
eec63b112f | ||
|
|
caf7bc8569 | ||
|
|
fd437ff4c5 | ||
|
|
fb218f8b10 | ||
|
|
4693080ce0 | ||
|
|
60ddcdf960 | ||
|
|
303bafb3ac | ||
|
|
7a0d0d9b96 | ||
|
|
84a9d2d072 | ||
|
|
1b5adf40da | ||
|
|
59a32aaae6 | ||
|
|
18106a4fc6 | ||
|
|
fc2297a2ca | ||
|
|
5b7b765090 | ||
|
|
90769ac709 | ||
|
|
ac9f1e9de5 | ||
|
|
5bf31e7a86 | ||
|
|
dd17506078 | ||
|
|
5d1424f67c | ||
|
|
2346b0ab99 | ||
|
|
88dec6ef2b | ||
|
|
e71f494839 | ||
|
|
22bb0414a1 | ||
|
|
6477bb8d77 | ||
|
|
70ddc0ce43 | ||
|
|
9986e4c6d0 | ||
|
|
e2710161f6 | ||
|
|
5f11fe521d | ||
|
|
d018b32d0b | ||
|
|
e54b7cda3d | ||
|
|
fc63841169 | ||
|
|
b674c598f9 | ||
|
|
710230a294 | ||
|
|
169f7440ac | ||
|
|
57ec12eb6b | ||
|
|
2c26f77a25 | ||
|
|
95dc90e6b2 |
@@ -1,5 +1,9 @@
|
||||

|
||||
|
||||
<p align="center">
|
||||
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
|
||||
<a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
|
||||
|
||||
@@ -154,7 +154,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
|
||||
我们提供[ Dify 云服务](https://dify.ai),任何人都可以零设置尝试。它提供了自部署版本的所有功能,并在沙盒计划中包含 200 次免费的 GPT-4 调用。
|
||||
|
||||
- **自托管 Dify 社区版</br>**
|
||||
使用这个[入门指南](#quick-start)快速在您的环境中运行 Dify。
|
||||
使用这个[入门指南](#快速启动)快速在您的环境中运行 Dify。
|
||||
使用我们的[文档](https://docs.dify.ai)进行进一步的参考和更深入的说明。
|
||||
|
||||
- **面向企业/组织的 Dify</br>**
|
||||
|
||||
@@ -31,8 +31,17 @@ REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_USERNAME=
|
||||
REDIS_PASSWORD=difyai123456
|
||||
REDIS_USE_SSL=false
|
||||
REDIS_DB=0
|
||||
|
||||
# redis Sentinel configuration.
|
||||
REDIS_USE_SENTINEL=false
|
||||
REDIS_SENTINELS=
|
||||
REDIS_SENTINEL_SERVICE_NAME=
|
||||
REDIS_SENTINEL_USERNAME=
|
||||
REDIS_SENTINEL_PASSWORD=
|
||||
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
|
||||
|
||||
# PostgreSQL database configuration
|
||||
DB_USERNAME=postgres
|
||||
DB_PASSWORD=difyai123456
|
||||
|
||||
@@ -55,7 +55,9 @@ RUN apt-get update \
|
||||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.1-1 \
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
|
||||
# install a chinese font to support the use of tools like matplotlib
|
||||
&& apt-get install -y fonts-noto-cjk \
|
||||
&& apt-get autoremove -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -571,6 +571,11 @@ class DataSetConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
TIDB_SERVERLESS_NUMBER: PositiveInt = Field(
|
||||
description="number of tidb serverless cluster",
|
||||
default=500,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
@@ -27,6 +27,7 @@ from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
|
||||
from configs.middleware.vdb.qdrant_config import QdrantConfig
|
||||
from configs.middleware.vdb.relyt_config import RelytConfig
|
||||
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
|
||||
from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
|
||||
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
|
||||
from configs.middleware.vdb.upstash_config import UpstashConfig
|
||||
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
|
||||
@@ -54,6 +55,11 @@ class VectorStoreConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field(
|
||||
description="Enable whitelist for vector store.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class KeywordStoreConfig(BaseSettings):
|
||||
KEYWORD_STORE: str = Field(
|
||||
@@ -248,5 +254,6 @@ class MiddlewareConfig(
|
||||
InternalTestConfig,
|
||||
VikingDBConfig,
|
||||
UpstashConfig,
|
||||
TidbOnQdrantConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
65
api/configs/middleware/vdb/tidb_on_qdrant_config.py
Normal file
65
api/configs/middleware/vdb/tidb_on_qdrant_config.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class TidbOnQdrantConfig(BaseSettings):
|
||||
"""
|
||||
Tidb on Qdrant configs
|
||||
"""
|
||||
|
||||
TIDB_ON_QDRANT_URL: Optional[str] = Field(
|
||||
description="Tidb on Qdrant url",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_ON_QDRANT_API_KEY: Optional[str] = Field(
|
||||
description="Tidb on Qdrant api key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
|
||||
description="Tidb on Qdrant client timeout in seconds",
|
||||
default=20,
|
||||
)
|
||||
|
||||
TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field(
|
||||
description="whether enable grpc support for Tidb on Qdrant connection",
|
||||
default=False,
|
||||
)
|
||||
|
||||
TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field(
|
||||
description="Tidb on Qdrant grpc port",
|
||||
default=6334,
|
||||
)
|
||||
|
||||
TIDB_PUBLIC_KEY: Optional[str] = Field(
|
||||
description="Tidb account public key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_PRIVATE_KEY: Optional[str] = Field(
|
||||
description="Tidb account private key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_API_URL: Optional[str] = Field(
|
||||
description="Tidb API url",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_IAM_API_URL: Optional[str] = Field(
|
||||
description="Tidb IAM API url",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_REGION: Optional[str] = Field(
|
||||
description="Tidb serverless region",
|
||||
default="regions/aws-us-east-1",
|
||||
)
|
||||
|
||||
TIDB_PROJECT_ID: Optional[str] = Field(
|
||||
description="Tidb project id",
|
||||
default=None,
|
||||
)
|
||||
6
api/controllers/common/errors.py
Normal file
6
api/controllers/common/errors.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
|
||||
class FilenameNotExistsError(HTTPException):
|
||||
code = 400
|
||||
description = "The specified filename does not exist."
|
||||
58
api/controllers/common/helpers.py
Normal file
58
api/controllers/common/helpers.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import urllib.parse
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FileInfo(BaseModel):
|
||||
filename: str
|
||||
extension: str
|
||||
mimetype: str
|
||||
size: int
|
||||
|
||||
|
||||
def guess_file_info_from_response(response: httpx.Response):
|
||||
url = str(response.url)
|
||||
# Try to extract filename from URL
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
filename = os.path.basename(url_path)
|
||||
|
||||
# If filename couldn't be extracted, use Content-Disposition header
|
||||
if not filename:
|
||||
content_disposition = response.headers.get("Content-Disposition")
|
||||
if content_disposition:
|
||||
filename_match = re.search(r'filename="?(.+)"?', content_disposition)
|
||||
if filename_match:
|
||||
filename = filename_match.group(1)
|
||||
|
||||
# If still no filename, generate a unique one
|
||||
if not filename:
|
||||
unique_name = str(uuid4())
|
||||
filename = f"{unique_name}"
|
||||
|
||||
# Guess MIME type from filename first, then URL
|
||||
mimetype, _ = mimetypes.guess_type(filename)
|
||||
if mimetype is None:
|
||||
mimetype, _ = mimetypes.guess_type(url)
|
||||
if mimetype is None:
|
||||
# If guessing fails, use Content-Type from response headers
|
||||
mimetype = response.headers.get("Content-Type", "application/octet-stream")
|
||||
|
||||
extension = os.path.splitext(filename)[1]
|
||||
|
||||
# Ensure filename has an extension
|
||||
if not extension:
|
||||
extension = mimetypes.guess_extension(mimetype) or ".bin"
|
||||
filename = f"{filename}{extension}"
|
||||
|
||||
return FileInfo(
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mimetype=mimetype,
|
||||
size=int(response.headers.get("Content-Length", -1)),
|
||||
)
|
||||
@@ -2,9 +2,21 @@ from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
|
||||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
# File
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||
|
||||
# Remote files
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
|
||||
# Import other controllers
|
||||
from . import admin, apikey, extension, feature, ping, setup, version
|
||||
|
||||
@@ -43,7 +55,6 @@ from .datasets import (
|
||||
datasets_document,
|
||||
datasets_segments,
|
||||
external,
|
||||
file,
|
||||
hit_testing,
|
||||
website,
|
||||
)
|
||||
|
||||
@@ -10,8 +10,7 @@ from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
from . import api
|
||||
from .setup import setup_required
|
||||
from .wraps import account_initialization_required
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
|
||||
api_key_fields = {
|
||||
"id": fields.String,
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
|
||||
@@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
|
||||
@@ -6,8 +6,11 @@ from werkzeug.exceptions import Forbidden
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import NoFileUploadedError
|
||||
from controllers.console.datasets.error import TooManyFilesError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
|
||||
@@ -6,8 +6,11 @@ from werkzeug.exceptions import BadRequest, Forbidden, abort
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from fields.app_fields import (
|
||||
app_detail_fields,
|
||||
|
||||
@@ -18,8 +18,7 @@ from controllers.console.app.error import (
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
|
||||
@@ -15,8 +15,7 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
@@ -10,8 +10,7 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
|
||||
@@ -4,8 +4,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_variable_fields import paginated_conversation_variable_fields
|
||||
from libs.login import login_required
|
||||
|
||||
@@ -10,8 +10,7 @@ from controllers.console.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
@@ -14,8 +14,11 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
@@ -105,6 +108,8 @@ class ChatMessageListApi(Resource):
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
|
||||
|
||||
|
||||
@@ -6,8 +6,7 @@ from flask_restful import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
|
||||
@@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.ops_service import OpsService
|
||||
|
||||
|
||||
@@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from libs.login import login_required
|
||||
|
||||
@@ -8,8 +8,7 @@ from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
|
||||
@@ -9,8 +9,7 @@ import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from factories import variable_factory
|
||||
|
||||
@@ -3,8 +3,7 @@ from flask_restful.inputs import int_range
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
|
||||
@@ -3,8 +3,7 @@ from flask_restful.inputs import int_range
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.workflow_run_fields import (
|
||||
advanced_chat_workflow_run_pagination_fields,
|
||||
workflow_run_detail_fields,
|
||||
|
||||
@@ -8,8 +8,7 @@ from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
|
||||
@@ -7,8 +7,7 @@ from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from libs.login import login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from ..setup import setup_required
|
||||
from ..wraps import account_initialization_required
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
class ApiKeyAuthDataSource(Resource):
|
||||
|
||||
@@ -11,8 +11,7 @@ from controllers.console import api
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
from ..setup import setup_required
|
||||
from ..wraps import account_initialization_required
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
|
||||
@@ -13,7 +13,7 @@ from controllers.console.auth.error import (
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
|
||||
@@ -20,7 +20,7 @@ from controllers.console.error import (
|
||||
NotAllowedCreateWorkspace,
|
||||
NotAllowedRegister,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
|
||||
@@ -2,8 +2,7 @@ from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@@ -7,8 +7,7 @@ from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
|
||||
@@ -10,8 +10,7 @@ from controllers.console import api
|
||||
from controllers.console.apikey import api_key_fields, api_key_list
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@@ -102,6 +101,13 @@ class DatasetListApi(Resource):
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
@@ -140,6 +146,7 @@ class DatasetListApi(Resource):
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
name=args["name"],
|
||||
description=args["description"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
@@ -631,6 +638,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.TIDB_ON_QDRANT
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
|
||||
@@ -24,8 +24,11 @@ from controllers.console.datasets.error import (
|
||||
InvalidActionError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
|
||||
@@ -11,11 +11,11 @@ import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_knowledge_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
|
||||
@@ -6,8 +6,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
@@ -2,8 +2,7 @@ from flask_restful import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
|
||||
|
||||
|
||||
@@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import WebsiteCrawlError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.website_service import WebsiteService
|
||||
|
||||
|
||||
@@ -21,7 +21,12 @@ class AppParameterApi(InstalledAppResource):
|
||||
"options": fields.List(fields.String),
|
||||
}
|
||||
|
||||
system_parameters_fields = {"image_file_size_limit": fields.String}
|
||||
system_parameters_fields = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
@@ -82,7 +87,12 @@ class AppParameterApi(InstalledAppResource):
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -3,8 +3,7 @@ from flask_restful import Resource, marshal_with, reqparse
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
|
||||
@@ -5,8 +5,7 @@ from libs.login import login_required
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api
|
||||
from .setup import setup_required
|
||||
from .wraps import account_initialization_required, cloud_utm_record
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
|
||||
class FeatureApi(Resource):
|
||||
|
||||
@@ -1,25 +1,26 @@
|
||||
import urllib.parse
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_restful import Resource, marshal_with
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import (
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from fields.file_fields import file_fields, upload_config_fields
|
||||
from libs.login import login_required
|
||||
from services.file_service import FileService
|
||||
|
||||
from .errors import (
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields
|
||||
from libs.login import login_required
|
||||
from services.file_service import FileService
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
@@ -44,21 +45,29 @@ class FileApi(Resource):
|
||||
@marshal_with(file_fields)
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
def post(self):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
source = request.form.get("source")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("source", type=str, required=False, location="args")
|
||||
source = parser.parse_args().get("source")
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if source not in ("datasets", None):
|
||||
source = None
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(file=file, user=current_user, source=source)
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source=source,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
@@ -83,23 +92,3 @@ class FileSupportTypeApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
|
||||
|
||||
|
||||
class RemoteFileInfoApi(Resource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
try:
|
||||
response = ssrf_proxy.head(decoded_url)
|
||||
return {
|
||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||
"file_length": int(response.headers.get("Content-Length", 0)),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
25
api/controllers/console/files/errors.py
Normal file
25
api/controllers/console/files/errors.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class FileTooLargeError(BaseHTTPException):
|
||||
error_code = "file_too_large"
|
||||
description = "File size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = "unsupported_file_type"
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = "too_many_files"
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = "no_file_uploaded"
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
70
api/controllers/console/remote_files.py
Normal file
70
api/controllers/console/remote_files.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
|
||||
from controllers.common import helpers
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class RemoteFileInfoApi(Resource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
try:
|
||||
response = ssrf_proxy.head(decoded_url)
|
||||
return {
|
||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||
"file_length": int(response.headers.get("Content-Length", 0)),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
class RemoteFileUploadApi(Resource):
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("url", type=str, required=True, help="URL is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
url = args["url"]
|
||||
|
||||
response = ssrf_proxy.head(url)
|
||||
response.raise_for_status()
|
||||
|
||||
file_info = helpers.guess_file_info_from_response(response)
|
||||
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
return {"error": "File size exceeded"}, 400
|
||||
|
||||
response = ssrf_proxy.get(url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
|
||||
try:
|
||||
user = cast(Account, current_user)
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at,
|
||||
}, 201
|
||||
@@ -1,5 +1,3 @@
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
@@ -10,7 +8,7 @@ from models.model import DifySetup
|
||||
from services.account_service import RegisterService, TenantService
|
||||
|
||||
from . import api
|
||||
from .error import AlreadySetupError, NotInitValidateError, NotSetupError
|
||||
from .error import AlreadySetupError, NotInitValidateError
|
||||
from .init_validate import get_init_validate_status
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
@@ -52,26 +50,10 @@ class SetupApi(Resource):
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
def setup_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# check setup
|
||||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
elif not get_setup_status():
|
||||
raise NotSetupError()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def get_setup_status():
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
return DifySetup.query.first()
|
||||
else:
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
api.add_resource(SetupApi, "/setup")
|
||||
|
||||
@@ -4,8 +4,7 @@ from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.tag_fields import tag_fields
|
||||
from libs.login import login_required
|
||||
from models.model import Tag
|
||||
|
||||
@@ -8,14 +8,13 @@ from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.workspace.error import (
|
||||
AccountAlreadyInitedError,
|
||||
CurrentPasswordIncorrectError,
|
||||
InvalidInvitationCodeError,
|
||||
RepeatPasswordNotMatchError,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from libs.helper import TimestampField, timezone
|
||||
|
||||
@@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_user, login_required
|
||||
|
||||
@@ -4,8 +4,11 @@ from flask_restful import Resource, abort, marshal_with, reqparse
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_list_fields
|
||||
from libs.login import login_required
|
||||
|
||||
@@ -6,8 +6,7 @@ from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
@@ -5,8 +5,7 @@ from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
@@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
|
||||
@@ -6,6 +6,7 @@ from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqpa
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.console import api
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.datasets.error import (
|
||||
@@ -15,8 +16,11 @@ from controllers.console.datasets.error import (
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
@@ -193,12 +197,20 @@ class WebappLogoWorkspaceApi(Resource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
extension = file.filename.split(".")[-1]
|
||||
if extension.lower() not in {"svg", "png"}:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(file=file, user=current_user)
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
from functools import wraps
|
||||
|
||||
from flask import abort, request
|
||||
@@ -6,9 +7,12 @@ from flask_login import current_user
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService
|
||||
from services.operation_service import OperationService
|
||||
|
||||
from .error import NotInitValidateError, NotSetupError
|
||||
|
||||
|
||||
def account_initialization_required(view):
|
||||
@wraps(view)
|
||||
@@ -124,3 +128,17 @@ def cloud_utm_record(view):
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# check setup
|
||||
if dify_config.EDITION == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD") and not DifySetup.query.first():
|
||||
raise NotInitValidateError()
|
||||
elif dify_config.EDITION == "SELF_HOSTED" and not DifySetup.query.first():
|
||||
raise NotSetupError()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api.wraps import inner_api_only
|
||||
from events.tenant_event import tenant_was_created
|
||||
|
||||
@@ -21,7 +21,12 @@ class AppParameterApi(Resource):
|
||||
"options": fields.List(fields.String),
|
||||
}
|
||||
|
||||
system_parameters_fields = {"image_file_size_limit": fields.String}
|
||||
system_parameters_fields = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
@@ -81,7 +86,12 @@ class AppParameterApi(Resource):
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from flask import request
|
||||
from flask_restful import Resource, marshal_with
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import (
|
||||
FileTooLargeError,
|
||||
@@ -31,8 +32,17 @@ class FileApi(Resource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(file, end_user)
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=end_user,
|
||||
source="datasets",
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
|
||||
@@ -66,6 +66,13 @@ class DatasetListApi(DatasetApiResource):
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
@@ -108,6 +115,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=args["name"],
|
||||
description=args["description"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
permission=args["permission"],
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy import desc
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.dataset.error import (
|
||||
@@ -55,7 +56,12 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
upload_file = FileService.upload_text(args.get("text"), args.get("name"))
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if text is None or name is None:
|
||||
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
||||
|
||||
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
@@ -104,7 +110,11 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
raise ValueError("Dataset is not exist.")
|
||||
|
||||
if args["text"]:
|
||||
upload_file = FileService.upload_text(args.get("text"), args.get("name"))
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if text is None or name is None:
|
||||
raise ValueError("Both text and name must be strings.")
|
||||
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
@@ -163,7 +173,16 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
upload_file = FileService.upload_file(file, current_user)
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
@@ -212,7 +231,16 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
upload_file = FileService.upload_file(file, current_user)
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
|
||||
@@ -2,8 +2,17 @@ from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .files import FileApi
|
||||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
bp = Blueprint("web", __name__, url_prefix="/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
# Files
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
|
||||
from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow
|
||||
# Remote files
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
|
||||
from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow
|
||||
|
||||
@@ -21,7 +21,12 @@ class AppParameterApi(WebApiResource):
|
||||
"options": fields.List(fields.String),
|
||||
}
|
||||
|
||||
system_parameters_fields = {"image_file_size_limit": fields.String}
|
||||
system_parameters_fields = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
@@ -80,7 +85,12 @@ class AppParameterApi(WebApiResource):
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
import urllib.parse
|
||||
|
||||
from flask import request
|
||||
from flask_restful import marshal_with, reqparse
|
||||
|
||||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields, remote_file_info_fields
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class FileApi(WebApiResource):
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model, end_user):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("source", type=str, required=False, location="args")
|
||||
source = parser.parse_args().get("source")
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
try:
|
||||
upload_file = FileService.upload_file(file=file, user=end_user, source=source)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
class RemoteFileInfoApi(WebApiResource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
try:
|
||||
response = ssrf_proxy.head(decoded_url)
|
||||
return {
|
||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||
"file_length": int(response.headers.get("Content-Length", -1)),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
43
api/controllers/web/files.py
Normal file
43
api/controllers/web/files.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from flask import request
|
||||
from flask_restful import marshal_with
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.file_fields import file_fields
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class FileApi(WebApiResource):
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model, end_user):
|
||||
file = request.files["file"]
|
||||
source = request.form.get("source")
|
||||
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if source not in ("datasets", None):
|
||||
source = None
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=end_user,
|
||||
source=source,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return upload_file, 201
|
||||
68
api/controllers/web/remote_files.py
Normal file
68
api/controllers/web/remote_files.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import urllib.parse
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal_with, reqparse
|
||||
|
||||
from controllers.common import helpers
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class RemoteFileInfoApi(WebApiResource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
try:
|
||||
response = ssrf_proxy.head(decoded_url)
|
||||
return {
|
||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||
"file_length": int(response.headers.get("Content-Length", -1)),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
class RemoteFileUploadApi(WebApiResource):
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("url", type=str, required=True, help="URL is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
url = args["url"]
|
||||
|
||||
response = ssrf_proxy.head(url)
|
||||
response.raise_for_status()
|
||||
|
||||
file_info = helpers.guess_file_info_from_response(response)
|
||||
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
return {"error": "File size exceeded"}, 400
|
||||
|
||||
response = ssrf_proxy.get(url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=current_user,
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at,
|
||||
}, 201
|
||||
@@ -105,6 +105,7 @@ class LLMResult(BaseModel):
|
||||
Model class for llm result.
|
||||
"""
|
||||
|
||||
id: Optional[str] = None
|
||||
model: str
|
||||
prompt_messages: list[PromptMessage]
|
||||
message: AssistantPromptMessage
|
||||
|
||||
@@ -44,6 +44,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
self._add_custom_parameters(credentials)
|
||||
self._add_function_call(model, credentials)
|
||||
user = user[:32] if user else None
|
||||
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
|
||||
if "response_format" in model_parameters:
|
||||
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
|
||||
@@ -397,16 +397,21 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
chunk_index = 0
|
||||
|
||||
def create_final_llm_result_chunk(
|
||||
index: int, message: AssistantPromptMessage, finish_reason: str
|
||||
id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
|
||||
) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
||||
prompt_tokens = usage and usage.get("prompt_tokens")
|
||||
if prompt_tokens is None:
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = usage and usage.get("completion_tokens")
|
||||
if completion_tokens is None:
|
||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
return LLMResultChunk(
|
||||
id=id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
|
||||
@@ -450,7 +455,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
finish_reason = None # The default value of finish_reason is None
|
||||
|
||||
message_id, usage = None, None
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||
chunk = chunk.strip()
|
||||
if chunk:
|
||||
@@ -462,20 +467,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
chunk_json: dict = json.loads(decoded_chunk)
|
||||
# stream ended
|
||||
except json.JSONDecodeError as e:
|
||||
yield create_final_llm_result_chunk(
|
||||
id=message_id,
|
||||
index=chunk_index + 1,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason="Non-JSON encountered.",
|
||||
usage=usage,
|
||||
)
|
||||
break
|
||||
if chunk_json:
|
||||
if u := chunk_json.get("usage"):
|
||||
usage = u
|
||||
if not chunk_json or len(chunk_json["choices"]) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk_json["choices"][0]
|
||||
finish_reason = chunk_json["choices"][0].get("finish_reason")
|
||||
message_id = chunk_json.get("id")
|
||||
chunk_index += 1
|
||||
|
||||
if "delta" in choice:
|
||||
@@ -524,6 +535,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
continue
|
||||
|
||||
yield LLMResultChunk(
|
||||
id=message_id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
@@ -536,6 +548,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
|
||||
if tools_calls:
|
||||
yield LLMResultChunk(
|
||||
id=message_id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
@@ -545,17 +558,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
)
|
||||
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
|
||||
id=message_id,
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
response_json = response.json()
|
||||
response_json: dict = response.json()
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
|
||||
output = response_json["choices"][0]
|
||||
message_id = response_json.get("id")
|
||||
|
||||
response_content = ""
|
||||
tool_calls = None
|
||||
@@ -593,6 +611,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
id=message_id,
|
||||
model=response_json["model"],
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_message,
|
||||
|
||||
17
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py
Normal file
17
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ClusterEntity(BaseModel):
|
||||
"""
|
||||
Model Config Entity.
|
||||
"""
|
||||
|
||||
name: str
|
||||
cluster_id: str
|
||||
displayName: str
|
||||
region: str
|
||||
spendingLimit: Optional[int] = 1000
|
||||
version: str
|
||||
createdBy: str
|
||||
@@ -0,0 +1,526 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
import qdrant_client
|
||||
import requests
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client.http import models as rest
|
||||
from qdrant_client.http.models import (
|
||||
FilterSelector,
|
||||
HnswConfigDiff,
|
||||
PayloadSchemaType,
|
||||
TextIndexParams,
|
||||
TextIndexType,
|
||||
TokenizerType,
|
||||
)
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
from requests.auth import HTTPDigestAuth
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, TidbAuthBinding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
from qdrant_client.conversions import common_types
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
DictFilter = dict[str, Union[str, int, bool, dict, list]]
|
||||
MetadataFilter = Union[DictFilter, common_types.Filter]
|
||||
|
||||
|
||||
class TidbOnQdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str] = None
|
||||
timeout: float = 20
|
||||
root_path: Optional[str] = None
|
||||
grpc_port: int = 6334
|
||||
prefer_grpc: bool = False
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
path = self.endpoint.replace("path:", "")
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {"path": path}
|
||||
else:
|
||||
return {
|
||||
"url": self.endpoint,
|
||||
"api_key": self.api_key,
|
||||
"timeout": self.timeout,
|
||||
"verify": False,
|
||||
"grpc_port": self.grpc_port,
|
||||
"prefer_grpc": self.prefer_grpc,
|
||||
}
|
||||
|
||||
|
||||
class TidbConfig(BaseModel):
|
||||
api_url: str
|
||||
public_key: str
|
||||
private_key: str
|
||||
|
||||
|
||||
class TidbOnQdrantVector(BaseVector):
|
||||
def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
|
||||
self._distance_func = distance_func.upper()
|
||||
self._group_id = group_id
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.TIDB_ON_QDRANT
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if texts:
|
||||
# get embedding vector size
|
||||
vector_size = len(embeddings[0])
|
||||
# get collection name
|
||||
collection_name = self._collection_name
|
||||
# create collection
|
||||
self.create_collection(collection_name, vector_size)
|
||||
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def create_collection(self, collection_name: str, vector_size: int):
|
||||
lock_name = "vector_indexing_lock_{}".format(collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
collection_name = collection_name or uuid.uuid4().hex
|
||||
all_collection_name = []
|
||||
collections_response = self._client.get_collections()
|
||||
collection_list = collections_response.collections
|
||||
for collection in collection_list:
|
||||
all_collection_name.append(collection.name)
|
||||
if collection_name not in all_collection_name:
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
vectors_config = rest.VectorParams(
|
||||
size=vector_size,
|
||||
distance=rest.Distance[self._distance_func],
|
||||
)
|
||||
hnsw_config = HnswConfigDiff(
|
||||
m=0,
|
||||
payload_m=16,
|
||||
ef_construct=100,
|
||||
full_scan_threshold=10000,
|
||||
max_indexing_threads=0,
|
||||
on_disk=False,
|
||||
)
|
||||
self._client.recreate_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=vectors_config,
|
||||
hnsw_config=hnsw_config,
|
||||
timeout=int(self._client_config.timeout),
|
||||
)
|
||||
|
||||
# create group_id payload index
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
# create doc_id payload index
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
# create full text index
|
||||
text_index_params = TextIndexParams(
|
||||
type=TextIndexType.TEXT,
|
||||
tokenizer=TokenizerType.MULTILINGUAL,
|
||||
min_token_len=2,
|
||||
max_token_len=20,
|
||||
lowercase=True,
|
||||
)
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
|
||||
)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
added_ids = []
|
||||
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
|
||||
self._client.upsert(collection_name=self._collection_name, points=points)
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
return added_ids
|
||||
|
||||
def _generate_rest_batches(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
batch_size: int = 64,
|
||||
group_id: Optional[str] = None,
|
||||
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
texts_iterator = iter(texts)
|
||||
embeddings_iterator = iter(embeddings)
|
||||
metadatas_iterator = iter(metadatas or [])
|
||||
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
|
||||
while batch_texts := list(islice(texts_iterator, batch_size)):
|
||||
# Take the corresponding metadata and id for each text in a batch
|
||||
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
|
||||
batch_ids = list(islice(ids_iterator, batch_size))
|
||||
|
||||
# Generate the embeddings for all the texts in a batch
|
||||
batch_embeddings = list(islice(embeddings_iterator, batch_size))
|
||||
|
||||
points = [
|
||||
rest.PointStruct(
|
||||
id=point_id,
|
||||
vector=vector,
|
||||
payload=payload,
|
||||
)
|
||||
for point_id, vector, payload in zip(
|
||||
batch_ids,
|
||||
batch_embeddings,
|
||||
self._build_payloads(
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
Field.CONTENT_KEY.value,
|
||||
Field.METADATA_KEY.value,
|
||||
group_id,
|
||||
Field.GROUP_KEY.value,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
yield batch_ids, points
|
||||
|
||||
@classmethod
|
||||
def _build_payloads(
|
||||
cls,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]],
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
group_id: str,
|
||||
group_payload_key: str,
|
||||
) -> list[dict]:
|
||||
payloads = []
|
||||
for i, text in enumerate(texts):
|
||||
if text is None:
|
||||
raise ValueError(
|
||||
"At least one of the texts is None. Please remove it before "
|
||||
"calling .from_texts or .add_texts on Qdrant instance."
|
||||
)
|
||||
metadata = metadatas[i] if metadatas is not None else None
|
||||
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
|
||||
|
||||
return payloads
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
self._reload_if_needed()
|
||||
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def delete(self):
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
try:
|
||||
self._client.delete_collection(collection_name=self._collection_name)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
for node_id in ids:
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
all_collection_name = []
|
||||
collections_response = self._client.get_collections()
|
||||
collection_list = collections_response.collections
|
||||
for collection in collection_list:
|
||||
all_collection_name.append(collection.name)
|
||||
if self._collection_name not in all_collection_name:
|
||||
return False
|
||||
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
|
||||
|
||||
return len(response) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from qdrant_client.http import models
|
||||
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self._group_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=filter,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
score_threshold=kwargs.get("score_threshold", 0.0),
|
||||
)
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
# duplicate check score threshold
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
if result.score > score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value),
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Return docs most similar by bm25.
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance for each.
|
||||
"""
|
||||
from qdrant_client.http import models
|
||||
|
||||
scroll_filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="page_content",
|
||||
match=models.MatchText(text=query),
|
||||
)
|
||||
]
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
limit=kwargs.get("top_k", 2),
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
)
|
||||
results = response[0]
|
||||
documents = []
|
||||
for result in results:
|
||||
if result:
|
||||
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
|
||||
document.metadata["vector"] = result.vector
|
||||
documents.append(document)
|
||||
|
||||
return documents
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self._client, QdrantLocal):
|
||||
self._client = cast(QdrantLocal, self._client)
|
||||
self._client._load()
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
cls,
|
||||
scored_point: Any,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
) -> Document:
|
||||
return Document(
|
||||
page_content=scored_point.payload.get(content_payload_key),
|
||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||
)
|
||||
|
||||
|
||||
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||
)
|
||||
if not tidb_auth_binding:
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if idle_tidb_auth_binding:
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if tidb_auth_binding:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||
dify_config.TIDB_PROJECT_ID,
|
||||
dify_config.TIDB_API_URL,
|
||||
dify_config.TIDB_IAM_API_URL,
|
||||
dify_config.TIDB_PUBLIC_KEY,
|
||||
dify_config.TIDB_PRIVATE_KEY,
|
||||
dify_config.TIDB_REGION,
|
||||
)
|
||||
new_tidb_auth_binding = TidbAuthBinding(
|
||||
cluster_id=new_cluster["cluster_id"],
|
||||
cluster_name=new_cluster["cluster_name"],
|
||||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
tenant_id=dataset.tenant_id,
|
||||
active=True,
|
||||
status="ACTIVE",
|
||||
)
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
|
||||
return TidbOnQdrantVector(
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=TidbOnQdrantConfig(
|
||||
endpoint=dify_config.TIDB_ON_QDRANT_URL,
|
||||
api_key=TIDB_ON_QDRANT_API_KEY,
|
||||
root_path=config.root_path,
|
||||
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
|
||||
),
|
||||
)
|
||||
|
||||
def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str):
|
||||
"""
|
||||
Creates a new TiDB Serverless cluster.
|
||||
:param tidb_config: The configuration for the TiDB Cloud API.
|
||||
:param display_name: The user-friendly display name of the cluster (required).
|
||||
:param region: The region where the cluster will be created (required).
|
||||
|
||||
:return: The response from the API.
|
||||
"""
|
||||
region_object = {
|
||||
"name": region,
|
||||
}
|
||||
|
||||
labels = {
|
||||
"tidb.cloud/project": "1372813089454548012",
|
||||
}
|
||||
cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}
|
||||
|
||||
response = requests.post(
|
||||
f"{tidb_config.api_url}/clusters",
|
||||
json=cluster_data,
|
||||
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str):
|
||||
"""
|
||||
Changes the root password of a specific TiDB Serverless cluster.
|
||||
|
||||
:param tidb_config: The configuration for the TiDB Cloud API.
|
||||
:param cluster_id: The ID of the cluster for which the password is to be changed (required).
|
||||
:param new_password: The new password for the root user (required).
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
body = {"password": new_password}
|
||||
|
||||
response = requests.put(
|
||||
f"{tidb_config.api_url}/clusters/{cluster_id}/password",
|
||||
json=body,
|
||||
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
250
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
Normal file
250
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
from requests.auth import HTTPDigestAuth
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
|
||||
|
||||
class TidbService:
|
||||
@staticmethod
|
||||
def create_tidb_serverless_cluster(
|
||||
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
|
||||
):
|
||||
"""
|
||||
Creates a new TiDB Serverless cluster.
|
||||
:param project_id: The project ID of the TiDB Cloud project (required).
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param iam_url: The URL of the TiDB Cloud IAM API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param display_name: The user-friendly display name of the cluster (required).
|
||||
:param region: The region where the cluster will be created (required).
|
||||
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
region_object = {
|
||||
"name": region,
|
||||
}
|
||||
|
||||
labels = {
|
||||
"tidb.cloud/project": project_id,
|
||||
}
|
||||
|
||||
spending_limit = {
|
||||
"monthly": 100,
|
||||
}
|
||||
password = str(uuid.uuid4()).replace("-", "")[:16]
|
||||
display_name = str(uuid.uuid4()).replace("-", "")[:16]
|
||||
cluster_data = {
|
||||
"displayName": display_name,
|
||||
"region": region_object,
|
||||
"labels": labels,
|
||||
"spendingLimit": spending_limit,
|
||||
"rootPassword": password,
|
||||
}
|
||||
|
||||
response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_id = response_data["clusterId"]
|
||||
retry_count = 0
|
||||
max_retries = 30
|
||||
while retry_count < max_retries:
|
||||
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
|
||||
if cluster_response["state"] == "ACTIVE":
|
||||
user_prefix = cluster_response["userPrefix"]
|
||||
return {
|
||||
"cluster_id": cluster_id,
|
||||
"cluster_name": display_name,
|
||||
"account": f"{user_prefix}.root",
|
||||
"password": password,
|
||||
}
|
||||
time.sleep(30) # wait 30 seconds before retrying
|
||||
retry_count += 1
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
|
||||
"""
|
||||
Deletes a specific TiDB Serverless cluster.
|
||||
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param cluster_id: The ID of the cluster to be deleted (required).
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
|
||||
"""
|
||||
Deletes a specific TiDB Serverless cluster.
|
||||
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param cluster_id: The ID of the cluster to be deleted (required).
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def change_tidb_serverless_root_password(
|
||||
api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str
|
||||
):
|
||||
"""
|
||||
Changes the root password of a specific TiDB Serverless cluster.
|
||||
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param cluster_id: The ID of the cluster for which the password is to be changed (required).+
|
||||
:param account: The account for which the password is to be changed (required).
|
||||
:param new_password: The new password for the root user (required).
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
|
||||
|
||||
response = requests.patch(
|
||||
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
|
||||
json=body,
|
||||
auth=HTTPDigestAuth(public_key, private_key),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def batch_update_tidb_serverless_cluster_status(
|
||||
tidb_serverless_list: list[TidbAuthBinding],
|
||||
project_id: str,
|
||||
api_url: str,
|
||||
iam_url: str,
|
||||
public_key: str,
|
||||
private_key: str,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Update the status of a new TiDB Serverless cluster.
|
||||
:param project_id: The project ID of the TiDB Cloud project (required).
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param iam_url: The URL of the TiDB Cloud IAM API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param display_name: The user-friendly display name of the cluster (required).
|
||||
:param region: The region where the cluster will be created (required).
|
||||
|
||||
:return: The response from the API.
|
||||
"""
|
||||
clusters = []
|
||||
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
|
||||
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
|
||||
params = {"clusterIds": cluster_ids, "view": "FULL"}
|
||||
response = requests.get(
|
||||
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_infos = []
|
||||
for item in response_data["clusters"]:
|
||||
state = item["state"]
|
||||
userPrefix = item["userPrefix"]
|
||||
if state == "ACTIVE" and len(userPrefix) > 0:
|
||||
cluster_info = tidb_serverless_list_map[item["clusterId"]]
|
||||
cluster_info.status = "ACTIVE"
|
||||
cluster_info.account = f"{userPrefix}.root"
|
||||
db.session.add(cluster_info)
|
||||
db.session.commit()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def batch_create_tidb_serverless_cluster(
|
||||
batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Creates a new TiDB Serverless cluster.
|
||||
:param project_id: The project ID of the TiDB Cloud project (required).
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param iam_url: The URL of the TiDB Cloud IAM API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param display_name: The user-friendly display name of the cluster (required).
|
||||
:param region: The region where the cluster will be created (required).
|
||||
|
||||
:return: The response from the API.
|
||||
"""
|
||||
clusters = []
|
||||
for _ in range(batch_size):
|
||||
region_object = {
|
||||
"name": region,
|
||||
}
|
||||
|
||||
labels = {
|
||||
"tidb.cloud/project": project_id,
|
||||
}
|
||||
|
||||
spending_limit = {
|
||||
"monthly": 10,
|
||||
}
|
||||
password = str(uuid.uuid4()).replace("-", "")[:16]
|
||||
display_name = str(uuid.uuid4()).replace("-", "")
|
||||
cluster_data = {
|
||||
"cluster": {
|
||||
"displayName": display_name,
|
||||
"region": region_object,
|
||||
"labels": labels,
|
||||
"spendingLimit": spending_limit,
|
||||
"rootPassword": password,
|
||||
}
|
||||
}
|
||||
cache_key = f"tidb_serverless_cluster_password:{display_name}"
|
||||
redis_client.setex(cache_key, 3600, password)
|
||||
clusters.append(cluster_data)
|
||||
|
||||
request_body = {"requests": clusters}
|
||||
response = requests.post(
|
||||
f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key)
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_infos = []
|
||||
for item in response_data["clusters"]:
|
||||
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
|
||||
password = redis_client.get(cache_key)
|
||||
if not password:
|
||||
continue
|
||||
cluster_info = {
|
||||
"cluster_id": item["clusterId"],
|
||||
"cluster_name": item["displayName"],
|
||||
"account": "root",
|
||||
"password": password.decode("utf-8"),
|
||||
}
|
||||
cluster_infos.append(cluster_info)
|
||||
return cluster_infos
|
||||
else:
|
||||
response.raise_for_status()
|
||||
@@ -9,8 +9,9 @@ from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, Whitelist
|
||||
|
||||
|
||||
class AbstractVectorFactory(ABC):
|
||||
@@ -35,8 +36,18 @@ class Vector:
|
||||
|
||||
def _init_vector(self) -> BaseVector:
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
|
||||
if self._dataset.index_struct_dict:
|
||||
vector_type = self._dataset.index_struct_dict["type"]
|
||||
else:
|
||||
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
|
||||
whitelist = (
|
||||
db.session.query(Whitelist)
|
||||
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
|
||||
.one_or_none()
|
||||
)
|
||||
if whitelist:
|
||||
vector_type = VectorType.TIDB_ON_QDRANT
|
||||
|
||||
if not vector_type:
|
||||
raise ValueError("Vector store must be specified.")
|
||||
@@ -115,6 +126,10 @@ class Vector:
|
||||
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
|
||||
|
||||
return UpstashVectorFactory
|
||||
case VectorType.TIDB_ON_QDRANT:
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
|
||||
|
||||
return TidbOnQdrantVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
||||
@@ -19,3 +19,4 @@ class VectorType(str, Enum):
|
||||
BAIDU = "baidu"
|
||||
VIKINGDB = "vikingdb"
|
||||
UPSTASH = "upstash"
|
||||
TIDB_ON_QDRANT = "tidb_on_qdrant"
|
||||
|
||||
@@ -21,7 +21,6 @@ from core.rag.extractor.unstructured.unstructured_eml_extractor import Unstructu
|
||||
from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_pdf_extractor import UnstructuredPDFExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
|
||||
@@ -103,7 +102,7 @@ class ExtractProcessor:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == ".pdf":
|
||||
extractor = UnstructuredPDFExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in {".md", ".markdown"}:
|
||||
extractor = (
|
||||
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
@@ -122,6 +121,8 @@ class ExtractProcessor:
|
||||
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".ppt":
|
||||
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
# You must first specify the API key
|
||||
# because unstructured_api_key is necessary to parse .ppt documents
|
||||
elif file_extension == ".pptx":
|
||||
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".xml":
|
||||
|
||||
@@ -234,7 +234,7 @@ class WordExtractor(BaseExtractor):
|
||||
def parse_paragraph(paragraph):
|
||||
paragraph_content = []
|
||||
for run in paragraph.runs:
|
||||
if hasattr(run.element, "tag") and isinstance(element.tag, str) and run.element.tag.endswith("r"):
|
||||
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
|
||||
drawing_elements = run.element.findall(
|
||||
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
"""
|
||||
语雀客户端
|
||||
"""
|
||||
|
||||
__author__ = "佐井"
|
||||
__created__ = "2024-06-01 09:45:20"
|
||||
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
@@ -29,14 +22,13 @@ class AliYuqueTool:
|
||||
session = requests.Session()
|
||||
session.headers.update({"accept": "application/json", "X-Auth-Token": token})
|
||||
new_params = {**tool_parameters}
|
||||
# 找出需要替换的变量
|
||||
|
||||
replacements = {k: v for k, v in new_params.items() if f"{{{k}}}" in path}
|
||||
|
||||
# 替换 path 中的变量
|
||||
for key, value in replacements.items():
|
||||
path = path.replace(f"{{{key}}}", str(value))
|
||||
del new_params[key] # 从 kwargs 中删除已经替换的变量
|
||||
# 请求接口
|
||||
del new_params[key]
|
||||
|
||||
if method.upper() in {"POST", "PUT"}:
|
||||
session.headers.update(
|
||||
{
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
"""
|
||||
创建文档
|
||||
"""
|
||||
|
||||
__author__ = "佐井"
|
||||
__created__ = "2024-06-01 10:45:20"
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -13,7 +13,7 @@ description:
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: number
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
|
||||
@@ -1,11 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
删除文档
|
||||
"""
|
||||
|
||||
__author__ = "佐井"
|
||||
__created__ = "2024-09-17 22:04"
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -13,7 +13,7 @@ description:
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: number
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
"""
|
||||
获取知识库首页
|
||||
"""
|
||||
|
||||
__author__ = "佐井"
|
||||
__created__ = "2024-06-01 22:57:14"
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -1,11 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
获取知识库目录
|
||||
"""
|
||||
|
||||
__author__ = "佐井"
|
||||
__created__ = "2024-09-17 15:17:11"
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -13,7 +13,7 @@ description:
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: number
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
"""
|
||||
获取文档
|
||||
"""
|
||||
|
||||
__author__ = "佐井"
|
||||
__created__ = "2024-06-02 07:11:45"
|
||||
|
||||
import json
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
@@ -37,7 +30,6 @@ class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool):
|
||||
book_slug = path_parts[-2]
|
||||
group_id = path_parts[-3]
|
||||
|
||||
# 1. 请求首页信息,获取book_id
|
||||
new_params["group_login"] = group_id
|
||||
new_params["book_slug"] = book_slug
|
||||
index_page = json.loads(
|
||||
@@ -46,7 +38,7 @@ class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool):
|
||||
book_id = index_page.get("data", {}).get("book", {}).get("id")
|
||||
if not book_id:
|
||||
raise Exception(f"can not parse book_id from {index_page}")
|
||||
# 2. 获取文档内容
|
||||
|
||||
new_params["book_id"] = book_id
|
||||
new_params["id"] = doc_id
|
||||
data = self.request("GET", token, new_params, "/api/v2/repos/{book_id}/docs/{id}")
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
"""
|
||||
获取文档
|
||||
"""
|
||||
|
||||
__author__ = "佐井"
|
||||
__created__ = "2024-06-01 10:45:20"
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -14,7 +14,7 @@ description:
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: number
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
|
||||
@@ -1,11 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
获取知识库目录
|
||||
"""
|
||||
|
||||
__author__ = "佐井"
|
||||
__created__ = "2024-09-17 15:17:11"
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -13,7 +13,7 @@ description:
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: number
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
"""
|
||||
更新文档
|
||||
"""
|
||||
|
||||
__author__ = "佐井"
|
||||
__created__ = "2024-06-19 16:50:07"
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -12,7 +12,7 @@ description:
|
||||
llm: Update doc in a knowledge base via ID/path.
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: number
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
|
||||
@@ -1,77 +1,36 @@
|
||||
import matplotlib.pyplot as plt
|
||||
from fontTools.ttLib import TTFont
|
||||
from matplotlib.font_manager import findSystemFonts
|
||||
from matplotlib.font_manager import FontProperties
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.chart.tools.line import LinearChartTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
def set_chinese_font():
|
||||
font_list = [
|
||||
"PingFang SC",
|
||||
"SimHei",
|
||||
"Microsoft YaHei",
|
||||
"STSong",
|
||||
"SimSun",
|
||||
"Arial Unicode MS",
|
||||
"Noto Sans CJK SC",
|
||||
"Noto Sans CJK JP",
|
||||
]
|
||||
|
||||
for font in font_list:
|
||||
chinese_font = FontProperties(font)
|
||||
if chinese_font.get_name() == font:
|
||||
return chinese_font
|
||||
|
||||
return FontProperties()
|
||||
|
||||
|
||||
# use a business theme
|
||||
plt.style.use("seaborn-v0_8-darkgrid")
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
|
||||
def init_fonts():
|
||||
fonts = findSystemFonts()
|
||||
|
||||
popular_unicode_fonts = [
|
||||
"Arial Unicode MS",
|
||||
"DejaVu Sans",
|
||||
"DejaVu Sans Mono",
|
||||
"DejaVu Serif",
|
||||
"FreeMono",
|
||||
"FreeSans",
|
||||
"FreeSerif",
|
||||
"Liberation Mono",
|
||||
"Liberation Sans",
|
||||
"Liberation Serif",
|
||||
"Noto Mono",
|
||||
"Noto Sans",
|
||||
"Noto Serif",
|
||||
"Open Sans",
|
||||
"Roboto",
|
||||
"Source Code Pro",
|
||||
"Source Sans Pro",
|
||||
"Source Serif Pro",
|
||||
"Ubuntu",
|
||||
"Ubuntu Mono",
|
||||
]
|
||||
|
||||
supported_fonts = []
|
||||
|
||||
for font_path in fonts:
|
||||
try:
|
||||
font = TTFont(font_path)
|
||||
# get family name
|
||||
family_name = font["name"].getName(1, 3, 1).toUnicode()
|
||||
if family_name in popular_unicode_fonts:
|
||||
supported_fonts.append(family_name)
|
||||
except:
|
||||
pass
|
||||
|
||||
plt.rcParams["font.family"] = "sans-serif"
|
||||
# sort by order of popular_unicode_fonts
|
||||
for font in popular_unicode_fonts:
|
||||
if font in supported_fonts:
|
||||
plt.rcParams["font.sans-serif"] = font
|
||||
break
|
||||
|
||||
|
||||
init_fonts()
|
||||
font_properties = set_chinese_font()
|
||||
plt.rcParams["font.family"] = font_properties.get_name()
|
||||
|
||||
|
||||
class ChartProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
LinearChartTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"data": "1,3,5,7,9,2,4,6,8,10",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
pass
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import random
|
||||
import uuid
|
||||
@@ -6,45 +8,48 @@ import httpx
|
||||
from websocket import WebSocket
|
||||
from yarl import URL
|
||||
|
||||
from core.file.file_manager import _get_encoded_string
|
||||
from core.file.models import File
|
||||
|
||||
|
||||
class ComfyUiClient:
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = URL(base_url)
|
||||
|
||||
def get_history(self, prompt_id: str):
|
||||
def get_history(self, prompt_id: str) -> dict:
|
||||
res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id})
|
||||
history = res.json()[prompt_id]
|
||||
return history
|
||||
|
||||
def get_image(self, filename: str, subfolder: str, folder_type: str):
|
||||
def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
|
||||
response = httpx.get(
|
||||
str(self.base_url / "view"),
|
||||
params={"filename": filename, "subfolder": subfolder, "type": folder_type},
|
||||
)
|
||||
return response.content
|
||||
|
||||
def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False):
|
||||
# plan to support img2img in dify 0.10.0
|
||||
with open(input_path, "rb") as file:
|
||||
files = {"image": (name, file, "image/png")}
|
||||
data = {"type": image_type, "overwrite": str(overwrite).lower()}
|
||||
def upload_image(self, image_file: File) -> dict:
|
||||
image_content = base64.b64decode(_get_encoded_string(image_file))
|
||||
file = io.BytesIO(image_content)
|
||||
files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"}
|
||||
res = httpx.post(str(self.base_url / "upload/image"), files=files)
|
||||
return res.json()
|
||||
|
||||
res = httpx.post(str(self.base_url / "upload/image"), data=data, files=files)
|
||||
return res
|
||||
|
||||
def queue_prompt(self, client_id: str, prompt: dict):
|
||||
def queue_prompt(self, client_id: str, prompt: dict) -> str:
|
||||
res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt})
|
||||
prompt_id = res.json()["prompt_id"]
|
||||
return prompt_id
|
||||
|
||||
def open_websocket_connection(self):
|
||||
def open_websocket_connection(self) -> tuple[WebSocket, str]:
|
||||
client_id = str(uuid.uuid4())
|
||||
ws = WebSocket()
|
||||
ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
|
||||
ws.connect(ws_address)
|
||||
return ws, client_id
|
||||
|
||||
def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""):
|
||||
def set_prompt(
|
||||
self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
find the first KSampler, then can find the prompt node through it.
|
||||
"""
|
||||
@@ -58,6 +63,10 @@ class ComfyUiClient:
|
||||
if negative_prompt != "":
|
||||
negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
|
||||
prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
|
||||
|
||||
if image_name != "":
|
||||
image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0]
|
||||
prompt.get(image_loader)["inputs"]["image"] = image_name
|
||||
return prompt
|
||||
|
||||
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
|
||||
@@ -89,7 +98,7 @@ class ComfyUiClient:
|
||||
else:
|
||||
continue
|
||||
|
||||
def generate_image_by_prompt(self, prompt: dict):
|
||||
def generate_image_by_prompt(self, prompt: dict) -> list[bytes]:
|
||||
try:
|
||||
ws, client_id = self.open_websocket_connection()
|
||||
prompt_id = self.queue_prompt(client_id, prompt)
|
||||
|
||||
@@ -2,10 +2,9 @@ import json
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
from .comfyui_client import ComfyUiClient
|
||||
|
||||
|
||||
class ComfyUIWorkflowTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
@@ -14,13 +13,16 @@ class ComfyUIWorkflowTool(BuiltinTool):
|
||||
positive_prompt = tool_parameters.get("positive_prompt")
|
||||
negative_prompt = tool_parameters.get("negative_prompt")
|
||||
workflow = tool_parameters.get("workflow_json")
|
||||
image_name = ""
|
||||
if image := tool_parameters.get("image"):
|
||||
image_name = comfyui.upload_image(image).get("name")
|
||||
|
||||
try:
|
||||
origin_prompt = json.loads(workflow)
|
||||
except:
|
||||
return self.create_text_message("the Workflow JSON is not correct")
|
||||
|
||||
prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt)
|
||||
prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name)
|
||||
images = comfyui.generate_image_by_prompt(prompt)
|
||||
result = []
|
||||
for img in images:
|
||||
|
||||
@@ -24,6 +24,13 @@ parameters:
|
||||
zh_Hans: 负面提示词
|
||||
llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
||||
form: llm
|
||||
- name: image
|
||||
type: file
|
||||
label:
|
||||
en_US: Input Image
|
||||
zh_Hans: 输入的图片
|
||||
llm_description: The input image, used to transfer to the comfyui workflow to generate another image.
|
||||
form: llm
|
||||
- name: workflow_json
|
||||
type: string
|
||||
required: true
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Any
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
from core.file.models import FileTransferMethod
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -20,11 +19,9 @@ class DuckDuckGoImageSearchTool(BuiltinTool):
|
||||
"max_results": tool_parameters.get("max_results"),
|
||||
}
|
||||
response = DDGS().images(**query_dict)
|
||||
result = []
|
||||
markdown_result = "\n\n"
|
||||
json_result = []
|
||||
for res in response:
|
||||
res["transfer_method"] = FileTransferMethod.REMOTE_URL
|
||||
msg = ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=res.get("image"), save_as="", meta=res
|
||||
)
|
||||
result.append(msg)
|
||||
return result
|
||||
markdown_result += f" or ''})"
|
||||
json_result.append(self.create_json_message(res))
|
||||
return [self.create_text_message(markdown_result)] + json_result
|
||||
|
||||
@@ -130,15 +130,14 @@ class GraphEngine:
|
||||
yield GraphRunStartedEvent()
|
||||
|
||||
try:
|
||||
stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
|
||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||
stream_processor_cls = AnswerStreamProcessor
|
||||
stream_processor = AnswerStreamProcessor(
|
||||
graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
|
||||
)
|
||||
else:
|
||||
stream_processor_cls = EndStreamProcessor
|
||||
|
||||
stream_processor = stream_processor_cls(
|
||||
graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
|
||||
)
|
||||
stream_processor = EndStreamProcessor(
|
||||
graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
|
||||
)
|
||||
|
||||
# run graph
|
||||
generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
|
||||
|
||||
@@ -149,10 +149,10 @@ class AnswerStreamGeneratorRouter:
|
||||
source_node_id = edge.source_node_id
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
if source_node_type in {
|
||||
NodeType.ANSWER.value,
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER.value,
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.ANSWER,
|
||||
NodeType.IF_ELSE,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
NodeType.ITERATION,
|
||||
}:
|
||||
answer_dependencies[answer_node_id].append(source_node_id)
|
||||
else:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user