Mirror of https://gitee.com/dify_ai/dify.git (synced 2025-12-06 19:42:42 +08:00)

Compare commits: chore/work...feat/pytho (43 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | a77b0c8216 |  |
|  | 542e904614 |  |
|  | cc1f90103b |  |
|  | 909cbee550 |  |
|  | 090eb13430 |  |
|  | 1cd35c5289 |  |
|  | 62fba95b1e |  |
|  | 055f510619 |  |
|  | 54ac80b2c6 |  |
|  | eb275c0fca |  |
|  | db4cdd5d60 |  |
|  | 1e9317910c |  |
|  | f3011ce67a |  |
|  | 52dddf9cd6 |  |
|  | 5ce6776383 |  |
|  | dfcf337fe1 |  |
|  | 5dc70d5be0 |  |
|  | b351d24a3c |  |
|  | 089c47cab1 |  |
|  | 6af6c7846b |  |
|  | 7a8e027a93 |  |
|  | d0dd81cf84 |  |
|  | 65b832c46c |  |
|  | a90b60c36f |  |
|  | 94a07706ec |  |
|  | ab2eacb6c1 |  |
|  | aead192743 |  |
|  | c1e8584b97 |  |
|  | 8a2b208299 |  |
|  | 2b6882bd97 |  |
|  | aa51662d98 |  |
|  | 3068526797 |  |
|  | 298d8c2d88 |  |
|  | 294e01a8c1 |  |
|  | 3a5aa4587c |  |
|  | cf1778e696 |  |
|  | 54db4c176a |  |
|  | 5d3e8a31d0 |  |
|  | 885dff82e3 |  |
|  | 3c4aa24198 |  |
|  | 33b0814323 |  |
|  | 45ae511036 |  |
|  | 0fa063c640 |  |
`.github/workflows/deploy-dev.yml` (vendored, 2 changes)
```diff
@@ -18,7 +18,7 @@ jobs:
       - name: Deploy to server
         uses: appleboy/ssh-action@v0.1.8
         with:
-          host: ${{ secrets.RAG_SSH_HOST }}
+          host: ${{ secrets.SSH_HOST }}
           username: ${{ secrets.SSH_USER }}
           key: ${{ secrets.SSH_PRIVATE_KEY }}
           script: |
```
```diff
@@ -81,7 +81,6 @@ ignore = [
     "SIM113", # enumerate-for-loop
     "SIM117", # multiple-with-statements
     "SIM210", # if-expr-with-true-false
-    "UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
 ]

 [lint.per-file-ignores]
```
```diff
@@ -362,11 +362,11 @@ class HttpConfig(BaseSettings):
     )

     HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
-        ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60
+        ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
     )

     HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
-        ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20
+        ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
     )

     HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
```
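The hunk above only raises defaults, but the surrounding `Field(ge=1, ...)` pattern is what enforces sane values at startup. A minimal sketch of how such a settings field behaves, assuming pydantic-settings as the `BaseSettings` provider (the class name and environment wiring here are illustrative, not Dify's real config module):

```python
# Sketch of the Field(ge=..., default=...) pattern from the diff above.
# Class name and env handling are illustrative assumptions.
from pydantic import Field, ValidationError
from pydantic_settings import BaseSettings

class HttpTimeouts(BaseSettings):
    # ge=1 rejects zero or negative timeouts at load time, not at request time
    HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(ge=1, default=600)
    HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(ge=1, default=600)

cfg = HttpTimeouts()
print(cfg.HTTP_REQUEST_MAX_READ_TIMEOUT)  # 600 unless overridden

# BaseSettings also reads fields from the environment, so setting
# HTTP_REQUEST_MAX_READ_TIMEOUT=30 in the env overrides the default.
try:
    HttpTimeouts(HTTP_REQUEST_MAX_READ_TIMEOUT=0)  # violates ge=1
except ValidationError as e:
    print(e)
```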
```diff
@@ -771,7 +771,7 @@ class MailConfig(BaseSettings):

     MAIL_TEMPLATING_TIMEOUT: int = Field(
         description="""
-        Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
+        Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
         Only available in sandbox mode.""",
         default=3,
     )
```
```diff
@@ -90,7 +90,7 @@ class ModelConfigResource(Resource):
             if not isinstance(tool, dict) or len(tool.keys()) <= 3:
                 continue

-            agent_tool_entity = AgentToolEntity(**tool)
+            agent_tool_entity = AgentToolEntity.model_validate(tool)
             # get tool
             try:
                 tool_runtime = ToolManager.get_agent_tool_runtime(
@@ -124,7 +124,7 @@ class ModelConfigResource(Resource):
         # encrypt agent tool parameters if it's secret-input
         agent_mode = new_app_model_config.agent_mode_dict
         for tool in agent_mode.get("tools") or []:
-            agent_tool_entity = AgentToolEntity(**tool)
+            agent_tool_entity = AgentToolEntity.model_validate(tool)

             # get tool
             key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
```
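The swap from `Model(**data)` to `Model.model_validate(data)` recurs throughout this compare. A minimal sketch of the difference with a stand-in model (the real `AgentToolEntity` lives in Dify's codebase; these fields are only the ones visible in the hunk above):

```python
# Why Model(**d) becomes Model.model_validate(d): model_validate is the
# idiomatic Pydantic v2 entry point and accepts any mapping directly.
from pydantic import BaseModel, ValidationError

class AgentTool(BaseModel):  # stand-in; not Dify's real AgentToolEntity
    provider_id: str
    provider_type: str
    tool_name: str

payload = {"provider_id": "p", "provider_type": "builtin", "tool_name": "search"}

a = AgentTool(**payload)               # works, but needs keyword-expandable keys
b = AgentTool.model_validate(payload)  # validates any mapping, no ** unpacking
assert a == b

# Bad payloads fail uniformly with ValidationError instead of TypeError:
try:
    AgentTool.model_validate({"provider_id": "p"})  # missing required fields
except ValidationError as e:
    print(e.error_count(), "validation errors")
```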
```diff
@@ -15,7 +15,7 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType,
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.indexing_runner import IndexingRunner
 from core.rag.extractor.entity.datasource_type import DatasourceType
-from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
 from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
@@ -257,13 +257,15 @@ class DataSourceNotionApi(Resource):
         for page in notion_info["pages"]:
             extract_setting = ExtractSetting(
                 datasource_type=DatasourceType.NOTION.value,
-                notion_info={
-                    "credential_id": credential_id,
-                    "notion_workspace_id": workspace_id,
-                    "notion_obj_id": page["page_id"],
-                    "notion_page_type": page["type"],
-                    "tenant_id": current_user.current_tenant_id,
-                },
+                notion_info=NotionInfo.model_validate(
+                    {
+                        "credential_id": credential_id,
+                        "notion_workspace_id": workspace_id,
+                        "notion_obj_id": page["page_id"],
+                        "notion_page_type": page["type"],
+                        "tenant_id": current_user.current_tenant_id,
+                    }
+                ),
                 document_model=args["doc_form"],
             )
             extract_settings.append(extract_setting)
```
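Here the `notion_info` argument changes from a raw dict to a validated `NotionInfo` model. A sketch of the pattern with field names taken from the diff (the field types are assumptions; the real `NotionInfo` lives in `core.rag.extractor.entity.extract_setting`):

```python
# Sketch of the raw-dict -> typed-model migration shown above.
# Field names come from the diff; types are illustrative assumptions.
from pydantic import BaseModel

class NotionInfo(BaseModel):
    credential_id: str | None = None  # assumption: credential may be absent
    notion_workspace_id: str
    notion_obj_id: str
    notion_page_type: str
    tenant_id: str

info = NotionInfo.model_validate(
    {
        "credential_id": "cred-1",
        "notion_workspace_id": "ws-1",
        "notion_obj_id": "page-1",
        "notion_page_type": "page",
        "tenant_id": "tenant-1",
    }
)
# Downstream code now gets attribute access and a guaranteed shape:
print(info.notion_obj_id)
```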
```diff
@@ -24,7 +24,7 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.extractor.entity.datasource_type import DatasourceType
-from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
@@ -513,13 +513,15 @@ class DatasetIndexingEstimateApi(Resource):
         for page in notion_info["pages"]:
             extract_setting = ExtractSetting(
                 datasource_type=DatasourceType.NOTION.value,
-                notion_info={
-                    "credential_id": credential_id,
-                    "notion_workspace_id": workspace_id,
-                    "notion_obj_id": page["page_id"],
-                    "notion_page_type": page["type"],
-                    "tenant_id": current_user.current_tenant_id,
-                },
+                notion_info=NotionInfo.model_validate(
+                    {
+                        "credential_id": credential_id,
+                        "notion_workspace_id": workspace_id,
+                        "notion_obj_id": page["page_id"],
+                        "notion_page_type": page["type"],
+                        "tenant_id": current_user.current_tenant_id,
+                    }
+                ),
                 document_model=args["doc_form"],
             )
             extract_settings.append(extract_setting)
@@ -528,14 +530,16 @@ class DatasetIndexingEstimateApi(Resource):
         for url in website_info_list["urls"]:
             extract_setting = ExtractSetting(
                 datasource_type=DatasourceType.WEBSITE.value,
-                website_info={
-                    "provider": website_info_list["provider"],
-                    "job_id": website_info_list["job_id"],
-                    "url": url,
-                    "tenant_id": current_user.current_tenant_id,
-                    "mode": "crawl",
-                    "only_main_content": website_info_list["only_main_content"],
-                },
+                website_info=WebsiteInfo.model_validate(
+                    {
+                        "provider": website_info_list["provider"],
+                        "job_id": website_info_list["job_id"],
+                        "url": url,
+                        "tenant_id": current_user.current_tenant_id,
+                        "mode": "crawl",
+                        "only_main_content": website_info_list["only_main_content"],
+                    }
+                ),
                 document_model=args["doc_form"],
             )
             extract_settings.append(extract_setting)
```
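`ExtractSetting` now carries a typed payload for whichever datasource is active: `notion_info` for Notion imports, `website_info` for crawls. A self-contained sketch of how such a model can pair `datasource_type` with exactly one payload (names follow the diff; the enum values and the validator rule are assumptions about intent, not Dify's real code):

```python
# Assumed shape of an ExtractSetting-like model, based on the call sites above.
from pydantic import BaseModel, model_validator

class NotionInfo(BaseModel):
    credential_id: str | None = None
    notion_workspace_id: str
    notion_obj_id: str
    notion_page_type: str
    tenant_id: str

class WebsiteInfo(BaseModel):
    provider: str
    job_id: str
    url: str
    tenant_id: str
    mode: str
    only_main_content: bool

class ExtractSetting(BaseModel):
    datasource_type: str
    notion_info: NotionInfo | None = None    # set only for Notion imports
    website_info: WebsiteInfo | None = None  # set only for website crawls
    document_model: str

    @model_validator(mode="after")
    def _check_payload(self):
        # illustrative rule: one typed payload per datasource kind
        if self.datasource_type == "notion_import" and self.notion_info is None:
            raise ValueError("notion_info required for notion imports")
        if self.datasource_type == "website_crawl" and self.website_info is None:
            raise ValueError("website_info required for website crawls")
        return self
```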
```diff
@@ -44,7 +44,7 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.rag.extractor.entity.datasource_type import DatasourceType
-from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
 from extensions.ext_database import db
 from fields.document_fields import (
     dataset_and_document_fields,
@@ -305,7 +305,7 @@ class DatasetDocumentListApi(Resource):
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
         args = parser.parse_args()
-        knowledge_config = KnowledgeConfig(**args)
+        knowledge_config = KnowledgeConfig.model_validate(args)

         if not dataset.indexing_technique and not knowledge_config.indexing_technique:
             raise ValueError("indexing_technique is required.")
```
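These `KnowledgeConfig` hunks validate the output of flask_restx's `reqparse` directly. Since `parse_args()` returns a dict-like namespace, it can be handed to `model_validate` without `**` unpacking. A sketch with a stand-in model (the real `KnowledgeConfig` and its field set are Dify's; these fields are assumptions for illustration):

```python
# Validating reqparse-style output with Pydantic, as in the hunks above.
from pydantic import BaseModel

class KnowledgeConfigSketch(BaseModel):  # stand-in, not Dify's KnowledgeConfig
    indexing_technique: str | None = None
    doc_language: str = "English"

# parse_args() yields a dict subclass, so a plain dict models it here:
args = {"indexing_technique": "high_quality", "doc_language": "English"}
config = KnowledgeConfigSketch.model_validate(args)
print(config.indexing_technique)  # "high_quality"
```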
```diff
@@ -395,7 +395,7 @@ class DatasetInitApi(Resource):
         parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
         args = parser.parse_args()

-        knowledge_config = KnowledgeConfig(**args)
+        knowledge_config = KnowledgeConfig.model_validate(args)
         if knowledge_config.indexing_technique == "high_quality":
             if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
                 raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
```
```diff
@@ -547,13 +547,15 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 continue
             extract_setting = ExtractSetting(
                 datasource_type=DatasourceType.NOTION.value,
-                notion_info={
-                    "credential_id": data_source_info["credential_id"],
-                    "notion_workspace_id": data_source_info["notion_workspace_id"],
-                    "notion_obj_id": data_source_info["notion_page_id"],
-                    "notion_page_type": data_source_info["type"],
-                    "tenant_id": current_user.current_tenant_id,
-                },
+                notion_info=NotionInfo.model_validate(
+                    {
+                        "credential_id": data_source_info["credential_id"],
+                        "notion_workspace_id": data_source_info["notion_workspace_id"],
+                        "notion_obj_id": data_source_info["notion_page_id"],
+                        "notion_page_type": data_source_info["type"],
+                        "tenant_id": current_user.current_tenant_id,
+                    }
+                ),
                 document_model=document.doc_form,
             )
             extract_settings.append(extract_setting)
@@ -562,14 +564,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 continue
             extract_setting = ExtractSetting(
                 datasource_type=DatasourceType.WEBSITE.value,
-                website_info={
-                    "provider": data_source_info["provider"],
-                    "job_id": data_source_info["job_id"],
-                    "url": data_source_info["url"],
-                    "tenant_id": current_user.current_tenant_id,
-                    "mode": data_source_info["mode"],
-                    "only_main_content": data_source_info["only_main_content"],
-                },
+                website_info=WebsiteInfo.model_validate(
+                    {
+                        "provider": data_source_info["provider"],
+                        "job_id": data_source_info["job_id"],
+                        "url": data_source_info["url"],
+                        "tenant_id": current_user.current_tenant_id,
+                        "mode": data_source_info["mode"],
+                        "only_main_content": data_source_info["only_main_content"],
+                    }
+                ),
                 document_model=document.doc_form,
             )
             extract_settings.append(extract_setting)
```
```diff
@@ -309,7 +309,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         )
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
-        segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
+        segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
         return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

     @setup_required
@@ -564,7 +564,7 @@ class ChildChunkAddApi(Resource):
         args = parser.parse_args()
         try:
             chunks_data = args["chunks"]
-            chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
+            chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
             child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
         except ChildChunkIndexingServiceError as e:
             raise ChildChunkIndexingError(str(e))
```
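The `ChildChunkAddApi` hunk applies the same migration inside a list comprehension, so each chunk payload is validated individually. A short sketch with a stand-in model (the real `ChildChunkUpdateArgs` fields are Dify's):

```python
# List-of-payloads validation, as in the ChildChunkAddApi hunk above.
from pydantic import BaseModel, ValidationError

class ChunkArgs(BaseModel):  # stand-in; fields are illustrative
    content: str
    keywords: list[str] = []

chunks_data = [{"content": "a"}, {"content": "b", "keywords": ["k"]}]
chunks = [ChunkArgs.model_validate(chunk) for chunk in chunks_data]

# One malformed element fails fast, naming the offending field:
try:
    [ChunkArgs.model_validate(c) for c in [{"keywords": "oops"}]]
except ValidationError as e:
    print(e)
```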
```diff
@@ -28,7 +28,7 @@ class DatasetMetadataCreateApi(Resource):
         parser.add_argument("type", type=str, required=True, nullable=False, location="json")
         parser.add_argument("name", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
-        metadata_args = MetadataArgs(**args)
+        metadata_args = MetadataArgs.model_validate(args)

         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
@@ -137,7 +137,7 @@ class DocumentMetadataEditApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
         args = parser.parse_args()
-        metadata_args = MetadataOperationData(**args)
+        metadata_args = MetadataOperationData.model_validate(args)

         MetadataService.update_documents_metadata(dataset, metadata_args)

```
```diff
@@ -88,7 +88,7 @@ class CustomizedPipelineTemplateApi(Resource):
             nullable=True,
         )
         args = parser.parse_args()
-        pipeline_template_info = PipelineTemplateInfoEntity(**args)
+        pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
         RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
         return 200

```
```diff
@@ -6,7 +6,7 @@ from flask_restx import Resource, inputs, marshal_with, reqparse
 from sqlalchemy import and_, select
 from werkzeug.exceptions import BadRequest, Forbidden, NotFound

-from controllers.console import api
+from controllers.console import console_ns
 from controllers.console.explore.wraps import InstalledAppResource
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from extensions.ext_database import db
@@ -22,6 +22,7 @@ from services.feature_service import FeatureService
 logger = logging.getLogger(__name__)


+@console_ns.route("/installed-apps")
 class InstalledAppsListApi(Resource):
     @login_required
     @account_initialization_required
@@ -154,6 +155,7 @@ class InstalledAppsListApi(Resource):
         return {"message": "App installed successfully"}


+@console_ns.route("/installed-apps/<uuid:installed_app_id>")
 class InstalledAppApi(InstalledAppResource):
     """
     update and delete an installed app
@@ -185,7 +187,3 @@ class InstalledAppApi(InstalledAppResource):
         db.session.commit()

         return {"result": "success", "message": "App info updated successfully"}
-
-
-api.add_resource(InstalledAppsListApi, "/installed-apps")
-api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")
```
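From here on, the compare repeats one refactor across the console controllers: each resource gains a `@console_ns.route(...)` decorator and the trailing `api.add_resource(...)` registrations are deleted. A minimal flask_restx sketch of the two registration styles (the app wiring below is illustrative; only `console_ns` mirrors a name from the diff):

```python
# Minimal sketch of the registration refactor applied across these controllers.
from flask import Flask
from flask_restx import Api, Namespace, Resource

app = Flask(__name__)
api = Api(app)
console_ns = Namespace("console", path="/console/api")

# New style: the route lives next to the resource it serves.
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
    def get(self):
        return {"installed_apps": []}

api.add_namespace(console_ns)

# Old style, removed throughout this compare: registration far from the class.
# api.add_resource(InstalledAppsListApi, "/installed-apps")
```

Keeping the route on the class means the URL, permissions decorators, and handler read as one unit, and there is no registration list to drift out of sync.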
```diff
@@ -1,7 +1,7 @@
 from flask_restx import marshal_with

 from controllers.common import fields
-from controllers.console import api
+from controllers.console import console_ns
 from controllers.console.app.error import AppUnavailableError
 from controllers.console.explore.wraps import InstalledAppResource
 from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
@@ -9,6 +9,7 @@ from models.model import AppMode, InstalledApp
 from services.app_service import AppService


+@console_ns.route("/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters")
 class AppParameterApi(InstalledAppResource):
     """Resource for app variables."""

@@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
         return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)


+@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
 class ExploreAppMetaApi(InstalledAppResource):
     def get(self, installed_app: InstalledApp):
         """Get app meta"""
@@ -46,9 +48,3 @@ class ExploreAppMetaApi(InstalledAppResource):
         if not app_model:
             raise ValueError("App not found")
         return AppService().get_app_meta(app_model)
-
-
-api.add_resource(
-    AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
-)
-api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
```
```diff
@@ -1,7 +1,7 @@
 from flask_restx import Resource, fields, marshal_with, reqparse

 from constants.languages import languages
-from controllers.console import api
+from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required
 from libs.helper import AppIconUrlField
 from libs.login import current_user, login_required
@@ -35,6 +35,7 @@ recommended_app_list_fields = {
 }


+@console_ns.route("/explore/apps")
 class RecommendedAppListApi(Resource):
     @login_required
     @account_initialization_required
@@ -56,13 +57,10 @@ class RecommendedAppListApi(Resource):
         return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)


+@console_ns.route("/explore/apps/<uuid:app_id>")
 class RecommendedAppApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_id):
         app_id = str(app_id)
         return RecommendedAppService.get_recommend_app_detail(app_id)
-
-
-api.add_resource(RecommendedAppListApi, "/explore/apps")
-api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")
```
```diff
@@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse
 from flask_restx.inputs import int_range
 from werkzeug.exceptions import NotFound

-from controllers.console import api
+from controllers.console import console_ns
 from controllers.console.explore.error import NotCompletionAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from fields.conversation_fields import message_file_fields
@@ -25,6 +25,7 @@ message_fields = {
 }


+@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
 class SavedMessageListApi(InstalledAppResource):
     saved_message_infinite_scroll_pagination_fields = {
         "limit": fields.Integer,
@@ -66,6 +67,9 @@ class SavedMessageListApi(InstalledAppResource):
         return {"result": "success"}


+@console_ns.route(
+    "/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", endpoint="installed_app_saved_message"
+)
 class SavedMessageApi(InstalledAppResource):
     def delete(self, installed_app, message_id):
         app_model = installed_app.app
@@ -80,15 +84,3 @@ class SavedMessageApi(InstalledAppResource):
         SavedMessageService.delete(app_model, current_user, message_id)

         return {"result": "success"}, 204
-
-
-api.add_resource(
-    SavedMessageListApi,
-    "/installed-apps/<uuid:installed_app_id>/saved-messages",
-    endpoint="installed_app_saved_messages",
-)
-api.add_resource(
-    SavedMessageApi,
-    "/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
-    endpoint="installed_app_saved_message",
-)
```
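Note that the moved routes keep their explicit `endpoint=` names (e.g. `installed_app_saved_messages`). Flask's `url_for` resolves by endpoint name, so pinning the name keeps reverse routing stable while the registration style changes. A sketch of that behavior (names mirror the diff; the handler and wiring are illustrative):

```python
# Explicit endpoint names keep url_for() stable across the refactor.
from flask import Flask, url_for
from flask_restx import Api, Namespace, Resource

app = Flask(__name__)
api = Api(app)
ns = Namespace("console", path="/console/api")

@ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages",
          endpoint="installed_app_saved_messages")
class SavedMessageListApi(Resource):
    def get(self, installed_app_id):
        return {"data": []}

api.add_namespace(ns)

with app.test_request_context():
    # Reverse routing uses the pinned endpoint name, not the class name.
    print(url_for("installed_app_saved_messages",
                  installed_app_id="00000000-0000-0000-0000-000000000000"))
```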
```diff
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session

 from configs import dify_config
 from constants.languages import supported_language
-from controllers.console import api
+from controllers.console import console_ns
 from controllers.console.auth.error import (
     EmailAlreadyInUseError,
     EmailChangeLimitError,
@@ -45,6 +45,7 @@ from services.billing_service import BillingService
 from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError


+@console_ns.route("/account/init")
 class AccountInitApi(Resource):
     @setup_required
     @login_required
@@ -97,6 +98,7 @@ class AccountInitApi(Resource):
         return {"result": "success"}


+@console_ns.route("/account/profile")
 class AccountProfileApi(Resource):
     @setup_required
     @login_required
@@ -109,6 +111,7 @@ class AccountProfileApi(Resource):
         return current_user


+@console_ns.route("/account/name")
 class AccountNameApi(Resource):
     @setup_required
     @login_required
@@ -130,6 +133,7 @@ class AccountNameApi(Resource):
         return updated_account


+@console_ns.route("/account/avatar")
 class AccountAvatarApi(Resource):
     @setup_required
     @login_required
@@ -147,6 +151,7 @@ class AccountAvatarApi(Resource):
         return updated_account


+@console_ns.route("/account/interface-language")
 class AccountInterfaceLanguageApi(Resource):
     @setup_required
     @login_required
@@ -164,6 +169,7 @@ class AccountInterfaceLanguageApi(Resource):
         return updated_account


+@console_ns.route("/account/interface-theme")
 class AccountInterfaceThemeApi(Resource):
     @setup_required
     @login_required
@@ -181,6 +187,7 @@ class AccountInterfaceThemeApi(Resource):
         return updated_account


+@console_ns.route("/account/timezone")
 class AccountTimezoneApi(Resource):
     @setup_required
     @login_required
@@ -202,6 +209,7 @@ class AccountTimezoneApi(Resource):
         return updated_account


+@console_ns.route("/account/password")
 class AccountPasswordApi(Resource):
     @setup_required
     @login_required
@@ -227,6 +235,7 @@ class AccountPasswordApi(Resource):
         return {"result": "success"}


+@console_ns.route("/account/integrates")
 class AccountIntegrateApi(Resource):
     integrate_fields = {
         "provider": fields.String,
@@ -283,6 +292,7 @@ class AccountIntegrateApi(Resource):
         return {"data": integrate_data}


+@console_ns.route("/account/delete/verify")
 class AccountDeleteVerifyApi(Resource):
     @setup_required
     @login_required
@@ -298,6 +308,7 @@ class AccountDeleteVerifyApi(Resource):
         return {"result": "success", "data": token}


+@console_ns.route("/account/delete")
 class AccountDeleteApi(Resource):
     @setup_required
     @login_required
@@ -320,6 +331,7 @@ class AccountDeleteApi(Resource):
         return {"result": "success"}


+@console_ns.route("/account/delete/feedback")
 class AccountDeleteUpdateFeedbackApi(Resource):
     @setup_required
     def post(self):
@@ -333,6 +345,7 @@ class AccountDeleteUpdateFeedbackApi(Resource):
         return {"result": "success"}


+@console_ns.route("/account/education/verify")
 class EducationVerifyApi(Resource):
     verify_fields = {
         "token": fields.String,
@@ -352,6 +365,7 @@ class EducationVerifyApi(Resource):
         return BillingService.EducationIdentity.verify(account.id, account.email)


+@console_ns.route("/account/education")
 class EducationApi(Resource):
     status_fields = {
         "result": fields.Boolean,
@@ -396,6 +410,7 @@ class EducationApi(Resource):
         return res


+@console_ns.route("/account/education/autocomplete")
 class EducationAutoCompleteApi(Resource):
     data_fields = {
         "data": fields.List(fields.String),
@@ -419,6 +434,7 @@ class EducationAutoCompleteApi(Resource):
         return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])


+@console_ns.route("/account/change-email")
 class ChangeEmailSendEmailApi(Resource):
     @enable_change_email
     @setup_required
@@ -467,6 +483,7 @@ class ChangeEmailSendEmailApi(Resource):
         return {"result": "success", "data": token}


+@console_ns.route("/account/change-email/validity")
 class ChangeEmailCheckApi(Resource):
     @enable_change_email
     @setup_required
@@ -508,6 +525,7 @@ class ChangeEmailCheckApi(Resource):
         return {"is_valid": True, "email": token_data.get("email"), "token": new_token}


+@console_ns.route("/account/change-email/reset")
 class ChangeEmailResetApi(Resource):
     @enable_change_email
     @setup_required
@@ -547,6 +565,7 @@ class ChangeEmailResetApi(Resource):
         return updated_account


+@console_ns.route("/account/change-email/check-email-unique")
 class CheckEmailUnique(Resource):
     @setup_required
     def post(self):
@@ -558,28 +577,3 @@ class CheckEmailUnique(Resource):
         if not AccountService.check_email_unique(args["email"]):
             raise EmailAlreadyInUseError()
         return {"result": "success"}
-
-
-# Register API resources
-api.add_resource(AccountInitApi, "/account/init")
-api.add_resource(AccountProfileApi, "/account/profile")
-api.add_resource(AccountNameApi, "/account/name")
-api.add_resource(AccountAvatarApi, "/account/avatar")
-api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
-api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
-api.add_resource(AccountTimezoneApi, "/account/timezone")
-api.add_resource(AccountPasswordApi, "/account/password")
-api.add_resource(AccountIntegrateApi, "/account/integrates")
-api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
-api.add_resource(AccountDeleteApi, "/account/delete")
-api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
-api.add_resource(EducationVerifyApi, "/account/education/verify")
-api.add_resource(EducationApi, "/account/education")
-api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
-# Change email
-api.add_resource(ChangeEmailSendEmailApi, "/account/change-email")
-api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity")
-api.add_resource(ChangeEmailResetApi, "/account/change-email/reset")
-api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique")
-# api.add_resource(AccountEmailApi, '/account/email')
-# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
```
```diff
@@ -1,7 +1,7 @@
 from flask_restx import Resource, reqparse
 from werkzeug.exceptions import Forbidden

-from controllers.console import api
+from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -10,6 +10,9 @@ from models.account import Account, TenantAccountRole
 from services.model_load_balancing_service import ModelLoadBalancingService


+@console_ns.route(
+    "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
+)
 class LoadBalancingCredentialsValidateApi(Resource):
     @setup_required
     @login_required
@@ -61,6 +64,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
         return response


+@console_ns.route(
+    "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
+)
 class LoadBalancingConfigCredentialsValidateApi(Resource):
     @setup_required
     @login_required
@@ -111,15 +117,3 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
         response["error"] = error

         return response
-
-
-# Load Balancing Config
-api.add_resource(
-    LoadBalancingCredentialsValidateApi,
-    "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
-)
-
-api.add_resource(
-    LoadBalancingConfigCredentialsValidateApi,
-    "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
-)
```
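These provider routes use Flask's `path` converter rather than `string`. A quick sketch of the difference, since provider identifiers can contain slashes (the example provider name below is an illustrative assumption about the format, not taken from this diff):

```python
# Flask's <path:...> converter accepts slashes; <string:...> stops at "/".
from flask import Flask

app = Flask(__name__)

@app.route("/model-providers/<path:provider>/models")
def models(provider: str):
    return {"provider": provider}

with app.test_client() as client:
    resp = client.get("/model-providers/langgenius/openai/openai/models")
    print(resp.get_json())  # {'provider': 'langgenius/openai/openai'}

# With <string:provider>, the same URL would 404 because the segment
# matcher refuses to cross a "/".
```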
```diff
@@ -6,7 +6,7 @@ from flask_restx import Resource, marshal_with, reqparse

 import services
 from configs import dify_config
-from controllers.console import api
+from controllers.console import console_ns
 from controllers.console.auth.error import (
     CannotTransferOwnerToSelfError,
     EmailCodeError,
@@ -33,6 +33,7 @@ from services.errors.account import AccountAlreadyInTenantError
 from services.feature_service import FeatureService


+@console_ns.route("/workspaces/current/members")
 class MemberListApi(Resource):
     """List all members of current tenant."""

@@ -49,6 +50,7 @@ class MemberListApi(Resource):
         return {"result": "success", "accounts": members}, 200


+@console_ns.route("/workspaces/current/members/invite-email")
 class MemberInviteEmailApi(Resource):
     """Invite a new member by email."""

@@ -111,6 +113,7 @@ class MemberInviteEmailApi(Resource):
         }, 201


+@console_ns.route("/workspaces/current/members/<uuid:member_id>")
 class MemberCancelInviteApi(Resource):
     """Cancel an invitation by member id."""

@@ -143,6 +146,7 @@ class MemberCancelInviteApi(Resource):
         }, 200


+@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
 class MemberUpdateRoleApi(Resource):
     """Update member role."""

@@ -177,6 +181,7 @@ class MemberUpdateRoleApi(Resource):
         return {"result": "success"}


+@console_ns.route("/workspaces/current/dataset-operators")
 class DatasetOperatorMemberListApi(Resource):
     """List all members of current tenant."""

@@ -193,6 +198,7 @@ class DatasetOperatorMemberListApi(Resource):
         return {"result": "success", "accounts": members}, 200


+@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
 class SendOwnerTransferEmailApi(Resource):
     """Send owner transfer email."""

@@ -233,6 +239,7 @@ class SendOwnerTransferEmailApi(Resource):
         return {"result": "success", "data": token}


+@console_ns.route("/workspaces/current/members/owner-transfer-check")
 class OwnerTransferCheckApi(Resource):
     @setup_required
     @login_required
@@ -278,6 +285,7 @@ class OwnerTransferCheckApi(Resource):
         return {"is_valid": True, "email": token_data.get("email"), "token": new_token}


+@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
 class OwnerTransfer(Resource):
     @setup_required
     @login_required
@@ -339,14 +347,3 @@ class OwnerTransfer(Resource):
             raise ValueError(str(e))

         return {"result": "success"}
-
-
-api.add_resource(MemberListApi, "/workspaces/current/members")
-api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
-api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>")
-api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role")
-api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")
-# owner transfer
-api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email")
-api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check")
-api.add_resource(OwnerTransfer, "/workspaces/current/members/<uuid:member_id>/owner-transfer")
```
```diff
@@ -5,7 +5,7 @@ from flask_login import current_user
 from flask_restx import Resource, reqparse
 from werkzeug.exceptions import Forbidden

-from controllers.console import api
+from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -17,6 +17,7 @@ from services.billing_service import BillingService
 from services.model_provider_service import ModelProviderService


+@console_ns.route("/workspaces/current/model-providers")
 class ModelProviderListApi(Resource):
     @setup_required
     @login_required
@@ -45,6 +46,7 @@ class ModelProviderListApi(Resource):
         return jsonable_encoder({"data": provider_list})


+@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
 class ModelProviderCredentialApi(Resource):
     @setup_required
     @login_required
@@ -151,6 +153,7 @@ class ModelProviderCredentialApi(Resource):
         return {"result": "success"}, 204


+@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
 class ModelProviderCredentialSwitchApi(Resource):
     @setup_required
     @login_required
@@ -175,6 +178,7 @@ class ModelProviderCredentialSwitchApi(Resource):
         return {"result": "success"}


+@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
 class ModelProviderValidateApi(Resource):
     @setup_required
     @login_required
@@ -211,6 +215,7 @@ class ModelProviderValidateApi(Resource):
         return response


+@console_ns.route("/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>")
 class ModelProviderIconApi(Resource):
     """
     Get model provider icon
@@ -229,6 +234,7 @@ class ModelProviderIconApi(Resource):
         return send_file(io.BytesIO(icon), mimetype=mimetype)


+@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
 class PreferredProviderTypeUpdateApi(Resource):
     @setup_required
     @login_required
@@ -262,6 +268,7 @@ class PreferredProviderTypeUpdateApi(Resource):
         return {"result": "success"}


+@console_ns.route("/workspaces/current/model-providers/<path:provider>/checkout-url")
 class ModelProviderPaymentCheckoutUrlApi(Resource):
     @setup_required
     @login_required
@@ -281,21 +288,3 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
             prefilled_email=current_user.email,
         )
         return data
-
-
-api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
-
-api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
-api.add_resource(
-    ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
-)
-api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
-
-api.add_resource(
-    PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
-)
-api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
-api.add_resource(
-    ModelProviderIconApi,
-    "/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
-)
```
```diff
@@ -4,7 +4,7 @@ from flask_login import current_user
 from flask_restx import Resource, reqparse
 from werkzeug.exceptions import Forbidden

-from controllers.console import api
+from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -17,6 +17,7 @@ from services.model_provider_service import ModelProviderService
 logger = logging.getLogger(__name__)


+@console_ns.route("/workspaces/current/default-model")
 class DefaultModelApi(Resource):
     @setup_required
     @login_required
@@ -85,6 +86,7 @@ class DefaultModelApi(Resource):
         return {"result": "success"}


+@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
 class ModelProviderModelApi(Resource):
     @setup_required
     @login_required
@@ -187,6 +189,7 @@ class ModelProviderModelApi(Resource):
         return {"result": "success"}, 204


+@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
 class ModelProviderModelCredentialApi(Resource):
     @setup_required
     @login_required
@@ -364,6 +367,7 @@ class ModelProviderModelCredentialApi(Resource):
         return {"result": "success"}, 204


+@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
 class ModelProviderModelCredentialSwitchApi(Resource):
     @setup_required
     @login_required
@@ -395,6 +399,9 @@ class ModelProviderModelCredentialSwitchApi(Resource):
         return {"result": "success"}


+@console_ns.route(
+    "/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
+)
 class ModelProviderModelEnableApi(Resource):
     @setup_required
     @login_required
@@ -422,6 +429,9 @@ class ModelProviderModelEnableApi(Resource):
         return {"result": "success"}


+@console_ns.route(
+    "/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
+)
 class ModelProviderModelDisableApi(Resource):
     @setup_required
     @login_required
@@ -449,6 +459,7 @@ class ModelProviderModelDisableApi(Resource):
         return {"result": "success"}


+@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
 class ModelProviderModelValidateApi(Resource):
     @setup_required
     @login_required
@@ -494,6 +505,7 @@ class ModelProviderModelValidateApi(Resource):
         return response


+@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
 class ModelProviderModelParameterRuleApi(Resource):
     @setup_required
     @login_required
@@ -513,6 +525,7 @@ class ModelProviderModelParameterRuleApi(Resource):
         return jsonable_encoder({"data": parameter_rules})


+@console_ns.route("/workspaces/current/models/model-types/<string:model_type>")
 class ModelProviderAvailableModelApi(Resource):
     @setup_required
     @login_required
@@ -524,32 +537,3 @@ class ModelProviderAvailableModelApi(Resource):
         models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

         return jsonable_encoder({"data": models})
-
-
-api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
-api.add_resource(
-    ModelProviderModelEnableApi,
-    "/workspaces/current/model-providers/<path:provider>/models/enable",
-    endpoint="model-provider-model-enable",
-)
-api.add_resource(
-    ModelProviderModelDisableApi,
-    "/workspaces/current/model-providers/<path:provider>/models/disable",
-    endpoint="model-provider-model-disable",
-)
-api.add_resource(
-    ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
-)
-api.add_resource(
-    ModelProviderModelCredentialSwitchApi,
-    "/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
-)
-api.add_resource(
-    ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
-)
-
-api.add_resource(
-    ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
-)
-api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
-api.add_resource(DefaultModelApi, "/workspaces/current/default-model")
```
```diff
@@ -6,7 +6,7 @@ from flask_restx import Resource, reqparse
 from werkzeug.exceptions import Forbidden

 from configs import dify_config
-from controllers.console import api
+from controllers.console import console_ns
 from controllers.console.workspace import plugin_permission_required
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -19,6 +19,7 @@ from services.plugin.plugin_permission_service import PluginPermissionService
 from services.plugin.plugin_service import PluginService


+@console_ns.route("/workspaces/current/plugin/debugging-key")
 class PluginDebuggingKeyApi(Resource):
     @setup_required
     @login_required
@@ -37,6 +38,7 @@ class PluginDebuggingKeyApi(Resource):
         raise ValueError(e)


+@console_ns.route("/workspaces/current/plugin/list")
 class PluginListApi(Resource):
     @setup_required
     @login_required
@@ -55,6 +57,7 @@ class PluginListApi(Resource):
         return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})


+@console_ns.route("/workspaces/current/plugin/list/latest-versions")
 class PluginListLatestVersionsApi(Resource):
     @setup_required
     @login_required
@@ -72,6 +75,7 @@ class PluginListLatestVersionsApi(Resource):
         return jsonable_encoder({"versions": versions})


+@console_ns.route("/workspaces/current/plugin/list/installations/ids")
 class PluginListInstallationsFromIdsApi(Resource):
     @setup_required
     @login_required
@@ -91,6 +95,7 @@ class PluginListInstallationsFromIdsApi(Resource):
         return jsonable_encoder({"plugins": plugins})


+@console_ns.route("/workspaces/current/plugin/icon")
 class PluginIconApi(Resource):
     @setup_required
     def get(self):
@@ -108,6 +113,7 @@ class PluginIconApi(Resource):
         return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)


+@console_ns.route("/workspaces/current/plugin/upload/pkg")
 class PluginUploadFromPkgApi(Resource):
     @setup_required
     @login_required
@@ -131,6 +137,7 @@ class PluginUploadFromPkgApi(Resource):
         return jsonable_encoder(response)


+@console_ns.route("/workspaces/current/plugin/upload/github")
 class PluginUploadFromGithubApi(Resource):
     @setup_required
     @login_required
@@ -153,6 +160,7 @@ class PluginUploadFromGithubApi(Resource):
         return jsonable_encoder(response)


+@console_ns.route("/workspaces/current/plugin/upload/bundle")
 class PluginUploadFromBundleApi(Resource):
     @setup_required
     @login_required
@@ -176,6 +184,7 @@ class PluginUploadFromBundleApi(Resource):
         return jsonable_encoder(response)


+@console_ns.route("/workspaces/current/plugin/install/pkg")
 class PluginInstallFromPkgApi(Resource):
     @setup_required
     @login_required
@@ -201,6 +210,7 @@ class PluginInstallFromPkgApi(Resource):
         return jsonable_encoder(response)


+@console_ns.route("/workspaces/current/plugin/install/github")
 class PluginInstallFromGithubApi(Resource):
     @setup_required
     @login_required
@@ -230,6 +240,7 @@ class PluginInstallFromGithubApi(Resource):
         return jsonable_encoder(response)


+@console_ns.route("/workspaces/current/plugin/install/marketplace")
 class PluginInstallFromMarketplaceApi(Resource):
     @setup_required
     @login_required
@@ -255,6 +266,7 @@ class PluginInstallFromMarketplaceApi(Resource):
         return jsonable_encoder(response)


+@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
 class PluginFetchMarketplacePkgApi(Resource):
     @setup_required
     @login_required
@@ -280,6 +292,7 @@ class PluginFetchMarketplacePkgApi(Resource):
         raise ValueError(e)


+@console_ns.route("/workspaces/current/plugin/fetch-manifest")
 class PluginFetchManifestApi(Resource):
     @setup_required
     @login_required
@@ -304,6 +317,7 @@ class PluginFetchManifestApi(Resource):
         raise ValueError(e)


+@console_ns.route("/workspaces/current/plugin/tasks")
 class PluginFetchInstallTasksApi(Resource):
     @setup_required
     @login_required
@@ -325,6 +339,7 @@ class PluginFetchInstallTasksApi(Resource):
         raise ValueError(e)


+@console_ns.route("/workspaces/current/plugin/tasks/<task_id>")
 class PluginFetchInstallTaskApi(Resource):
     @setup_required
     @login_required
@@ -339,6 +354,7 @@ class PluginFetchInstallTaskApi(Resource):
         raise ValueError(e)


+@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete")
 class PluginDeleteInstallTaskApi(Resource):
     @setup_required
     @login_required
@@ -353,6 +369,7 @@ class PluginDeleteInstallTaskApi(Resource):
         raise ValueError(e)


+@console_ns.route("/workspaces/current/plugin/tasks/delete_all")
 class PluginDeleteAllInstallTaskItemsApi(Resource):
     @setup_required
     @login_required
@@ -367,6 +384,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
         raise ValueError(e)


+@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
 class PluginDeleteInstallTaskItemApi(Resource):
     @setup_required
     @login_required
@@ -381,6 +399,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
         raise ValueError(e)


+@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
 class PluginUpgradeFromMarketplaceApi(Resource):
     @setup_required
     @login_required
@@ -404,6 +423,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
         raise ValueError(e)


+@console_ns.route("/workspaces/current/plugin/upgrade/github")
 class PluginUpgradeFromGithubApi(Resource):
     @setup_required
     @login_required
@@ -435,6 +455,7 @@ class PluginUpgradeFromGithubApi(Resource):
         raise ValueError(e)


+@console_ns.route("/workspaces/current/plugin/uninstall")
 class PluginUninstallApi(Resource):
     @setup_required
     @login_required
@@ -453,6 +474,7 @@ class PluginUninstallApi(Resource):
         raise ValueError(e)


+@console_ns.route("/workspaces/current/plugin/permission/change")
 class PluginChangePermissionApi(Resource):
     @setup_required
     @login_required
@@ -475,6 +497,7 @@ class PluginChangePermissionApi(Resource):
         return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}


+@console_ns.route("/workspaces/current/plugin/permission/fetch")
 class PluginFetchPermissionApi(Resource):
     @setup_required
     @login_required
@@ -499,6 +522,7 @@ class PluginFetchPermissionApi(Resource):
         )


+@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
 class PluginFetchDynamicSelectOptionsApi(Resource):
     @setup_required
     @login_required
@@ -535,6 +559,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
         return jsonable_encoder({"options": options})


+@console_ns.route("/workspaces/current/plugin/preferences/change")
 class PluginChangePreferencesApi(Resource):
     @setup_required
     @login_required
@@ -590,6 +615,7 @@ class PluginChangePreferencesApi(Resource):
         return jsonable_encoder({"success": True})


+@console_ns.route("/workspaces/current/plugin/preferences/fetch")
 class PluginFetchPreferencesApi(Resource):
     @setup_required
     @login_required
@@ -628,6 +654,7 @@ class PluginFetchPreferencesApi(Resource):
         return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})


+@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
 class PluginAutoUpgradeExcludePluginApi(Resource):
     @setup_required
     @login_required
@@ -641,35 +668,3 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
         args = req.parse_args()

         return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
-
-
-api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
-api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
-api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
-api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
-api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
-api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
-api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
-api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
-api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
-api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
-api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
-api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
-api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
-api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
-api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
-api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
-api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
-api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
-api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
-api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
-api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg")
-
-api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
-api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
-
-api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
-
-api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch")
-api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change")
-api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude")
```
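Both the plugin icon endpoint above and the tool icon endpoint below serve raw bytes through `send_file` over an in-memory buffer. A small sketch of that pattern (the byte payload and cache age are illustrative, not Dify's values):

```python
# The send_file(io.BytesIO(...), mimetype=..., max_age=...) pattern used by
# the icon endpoints; payload and max_age below are illustrative.
import io
from flask import Flask, send_file

app = Flask(__name__)

@app.route("/icon")
def icon():
    icon_bytes = b"\x89PNG\r\n\x1a\n"  # stand-in for real image bytes
    # max_age sets Cache-Control so clients can cache the icon
    return send_file(io.BytesIO(icon_bytes), mimetype="image/png", max_age=3600)

with app.test_client() as client:
    resp = client.get("/icon")
    print(resp.headers.get("Cache-Control"))  # e.g. includes "max-age=3600"
```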
```diff
@@ -10,7 +10,7 @@ from flask_restx import (
 from werkzeug.exceptions import Forbidden

 from configs import dify_config
-from controllers.console import api
+from controllers.console import console_ns
 from controllers.console.wraps import (
     account_initialization_required,
     enterprise_license_required,
@@ -47,6 +47,7 @@ def is_valid_url(url: str) -> bool:
     return False


+@console_ns.route("/workspaces/current/tool-providers")
 class ToolProviderListApi(Resource):
     @setup_required
     @login_required
@@ -71,6 +72,7 @@ class ToolProviderListApi(Resource):
         return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))


+@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
 class ToolBuiltinProviderListToolsApi(Resource):
     @setup_required
     @login_required
@@ -88,6 +90,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/info")
 class ToolBuiltinProviderInfoApi(Resource):
     @setup_required
     @login_required
@@ -100,6 +103,7 @@ class ToolBuiltinProviderInfoApi(Resource):
         return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))


+@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
 class ToolBuiltinProviderDeleteApi(Resource):
     @setup_required
     @login_required
@@ -121,6 +125,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
 class ToolBuiltinProviderAddApi(Resource):
     @setup_required
     @login_required
@@ -150,6 +155,7 @@ class ToolBuiltinProviderAddApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
 class ToolBuiltinProviderUpdateApi(Resource):
     @setup_required
     @login_required
@@ -181,6 +187,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
         return result


+@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credentials")
 class ToolBuiltinProviderGetCredentialsApi(Resource):
     @setup_required
     @login_required
@@ -196,6 +203,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
 class ToolBuiltinProviderIconApi(Resource):
     @setup_required
     def get(self, provider):
@@ -204,6 +212,7 @@ class ToolBuiltinProviderIconApi(Resource):
         return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)


+@console_ns.route("/workspaces/current/tool-provider/api/add")
 class ToolApiProviderAddApi(Resource):
     @setup_required
     @login_required
@@ -243,6 +252,7 @@ class ToolApiProviderAddApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/api/remote")
 class ToolApiProviderGetRemoteSchemaApi(Resource):
     @setup_required
     @login_required
@@ -266,6 +276,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/api/tools")
 class ToolApiProviderListToolsApi(Resource):
     @setup_required
     @login_required
@@ -291,6 +302,7 @@ class ToolApiProviderListToolsApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/api/update")
 class ToolApiProviderUpdateApi(Resource):
     @setup_required
     @login_required
@@ -332,6 +344,7 @@ class ToolApiProviderUpdateApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/api/delete")
 class ToolApiProviderDeleteApi(Resource):
     @setup_required
     @login_required
@@ -358,6 +371,7 @@ class ToolApiProviderDeleteApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/api/get")
 class ToolApiProviderGetApi(Resource):
     @setup_required
     @login_required
@@ -381,6 +395,7 @@ class ToolApiProviderGetApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>")
 class ToolBuiltinProviderCredentialsSchemaApi(Resource):
     @setup_required
     @login_required
@@ -396,6 +411,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/api/schema")
 class ToolApiProviderSchemaApi(Resource):
     @setup_required
     @login_required
@@ -412,6 +428,7 @@ class ToolApiProviderSchemaApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
 class ToolApiProviderPreviousTestApi(Resource):
     @setup_required
     @login_required
@@ -439,6 +456,7 @@ class ToolApiProviderPreviousTestApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/workflow/create")
 class ToolWorkflowProviderCreateApi(Resource):
     @setup_required
     @login_required
@@ -478,6 +496,7 @@ class ToolWorkflowProviderCreateApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/workflow/update")
 class ToolWorkflowProviderUpdateApi(Resource):
     @setup_required
     @login_required
@@ -520,6 +539,7 @@ class ToolWorkflowProviderUpdateApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
 class ToolWorkflowProviderDeleteApi(Resource):
     @setup_required
     @login_required
@@ -545,6 +565,7 @@ class ToolWorkflowProviderDeleteApi(Resource):
         )


+@console_ns.route("/workspaces/current/tool-provider/workflow/get")
 class ToolWorkflowProviderGetApi(Resource):
     @setup_required
     @login_required
@@ -579,6 +600,7 @@ class ToolWorkflowProviderGetApi(Resource):
         return jsonable_encoder(tool)


+@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
 class ToolWorkflowProviderListToolApi(Resource):
     @setup_required
     @login_required
@@ -603,6 +625,7 @@ class ToolWorkflowProviderListToolApi(Resource):
         )


+@console_ns.route("/workspaces/current/tools/builtin")
 class ToolBuiltinListApi(Resource):
     @setup_required
     @login_required
@@ -624,6 +647,7 @@
```
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tools/api")
|
||||
class ToolApiListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -642,6 +666,7 @@ class ToolApiListApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tools/workflow")
|
||||
class ToolWorkflowListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -663,6 +688,7 @@ class ToolWorkflowListApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-labels")
|
||||
class ToolLabelsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -672,6 +698,7 @@ class ToolLabelsApi(Resource):
|
||||
return jsonable_encoder(ToolLabelsService.list_tool_labels())
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/tool/authorization-url")
|
||||
class ToolPluginOAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -716,6 +743,7 @@ class ToolPluginOAuthApi(Resource):
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
|
||||
class ToolOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
@@ -766,6 +794,7 @@ class ToolOAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
|
||||
class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -779,6 +808,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
|
||||
class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -822,6 +852,7 @@ class ToolOAuthCustomClient(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema")
|
||||
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -834,6 +865,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/info")
|
||||
class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -849,6 +881,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp")
|
||||
class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -933,6 +966,7 @@ class ToolProviderMCPApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
|
||||
class ToolMCPAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -978,6 +1012,7 @@ class ToolMCPAuthApi(Resource):
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
|
||||
class ToolMCPDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -988,6 +1023,7 @@ class ToolMCPDetailApi(Resource):
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tools/mcp")
|
||||
class ToolMCPListAllApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -1001,6 +1037,7 @@ class ToolMCPListAllApi(Resource):
|
||||
return [tool.to_dict() for tool in tools]
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
|
||||
class ToolMCPUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -1014,6 +1051,7 @@ class ToolMCPUpdateApi(Resource):
|
||||
return jsonable_encoder(tools)
|
||||
|
||||
|
||||
@console_ns.route("/mcp/oauth/callback")
|
||||
class ToolMCPCallbackApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -1024,67 +1062,3 @@ class ToolMCPCallbackApi(Resource):
|
||||
authorization_code = args["code"]
|
||||
handle_callback(state_key, authorization_code)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
# tool provider
|
||||
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||
|
||||
# tool oauth
|
||||
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
|
||||
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
|
||||
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
|
||||
|
||||
# builtin tool provider
|
||||
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
||||
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
|
||||
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderCredentialsSchemaApi,
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderGetOauthClientSchemaApi,
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
|
||||
)
|
||||
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||
|
||||
# api tool provider
|
||||
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
|
||||
api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote")
|
||||
api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools")
|
||||
api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update")
|
||||
api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete")
|
||||
api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get")
|
||||
api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema")
|
||||
api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre")
|
||||
|
||||
# workflow tool provider
|
||||
api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create")
|
||||
api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update")
|
||||
api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete")
|
||||
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
|
||||
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
|
||||
|
||||
# mcp tool provider
|
||||
api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
|
||||
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
|
||||
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
|
||||
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
|
||||
api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
|
||||
|
||||
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
|
||||
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
|
||||
api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
|
||||
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
|
||||
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")
|
||||
|
||||
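
The hunks above replace the file's trailing api.add_resource(...) registrations with @console_ns.route(...) decorators, so each Resource declares its own URL where it is defined. A minimal sketch of that flask-restx pattern, assuming a stand-alone app (the namespace name, path, and route here are illustrative, not taken from the commit):

from flask import Flask
from flask_restx import Api, Namespace, Resource

app = Flask(__name__)
api = Api(app)
console_ns = Namespace("console", path="/console/api")

@console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource):
    def get(self):
        # the decorator registers the URL at class-definition time,
        # so no separate api.add_resource(...) call is needed
        return {"providers": []}

api.add_namespace(console_ns)

One registration style instead of two keeps the URL next to the handler and removes the risk of a Resource that is defined but never wired up.
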
@@ -14,7 +14,7 @@ from controllers.common.errors import (
    TooManyFilesError,
    UnsupportedFileTypeError,
)
-from controllers.console import api
+from controllers.console import console_ns
from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
@@ -65,6 +65,7 @@ tenants_fields = {
workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField}


+@console_ns.route("/workspaces")
class TenantListApi(Resource):
    @setup_required
    @login_required
@@ -93,6 +94,7 @@ class TenantListApi(Resource):
        return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200


+@console_ns.route("/all-workspaces")
class WorkspaceListApi(Resource):
    @setup_required
    @admin_required
@@ -118,6 +120,8 @@ class WorkspaceListApi(Resource):
        }, 200


+@console_ns.route("/workspaces/current", endpoint="workspaces_current")
+@console_ns.route("/info", endpoint="info")  # Deprecated
class TenantApi(Resource):
    @setup_required
    @login_required
@@ -143,11 +147,10 @@ class TenantApi(Resource):
            else:
                raise Unauthorized("workspace is archived")

        if not tenant:
            raise ValueError("No tenant available")
        return WorkspaceService.get_tenant_info(tenant), 200


+@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource):
    @setup_required
    @login_required
@@ -172,6 +175,7 @@ class SwitchWorkspaceApi(Resource):
        return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}


+@console_ns.route("/workspaces/custom-config")
class CustomConfigWorkspaceApi(Resource):
    @setup_required
    @login_required
@@ -202,6 +206,7 @@ class CustomConfigWorkspaceApi(Resource):
        return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}


+@console_ns.route("/workspaces/custom-config/webapp-logo/upload")
class WebappLogoWorkspaceApi(Resource):
    @setup_required
    @login_required
@@ -242,6 +247,7 @@ class WebappLogoWorkspaceApi(Resource):
        return {"id": upload_file.id}, 201


+@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource):
    @setup_required
    @login_required
@@ -261,13 +267,3 @@ class WorkspaceInfoApi(Resource):
        db.session.commit()

        return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
-
-
-api.add_resource(TenantListApi, "/workspaces")  # GET for getting all tenants
-api.add_resource(WorkspaceListApi, "/all-workspaces")  # GET for getting all tenants
-api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current")  # GET for getting current tenant info
-api.add_resource(TenantApi, "/info", endpoint="info")  # Deprecated
-api.add_resource(SwitchWorkspaceApi, "/workspaces/switch")  # POST for switching tenant
-api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
-api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
-api.add_resource(WorkspaceInfoApi, "/workspaces/info")  # POST for changing workspace info
@@ -128,7 +128,7 @@ def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
            raise ValueError("invalid json")

        try:
-            payload = payload_type(**data)
+            payload = payload_type.model_validate(data)
        except Exception as e:
            raise ValueError(f"invalid payload: {str(e)}")
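
This and the following hunks swap Model(**data) for Model.model_validate(data) throughout. Both run Pydantic validation, but model_validate accepts any mapping directly (no kwargs unpacking, so non-identifier keys cannot blow up), and it is the documented Pydantic v2 entry point for untrusted dicts. A small sketch of the equivalence and the error path, using a made-up model:

from pydantic import BaseModel, ValidationError

class Payload(BaseModel):
    name: str
    size: int = 0

data = {"name": "demo", "size": "3"}  # "3" is coerced to int by validation
assert Payload(**data) == Payload.model_validate(data)

try:
    Payload.model_validate({"size": 1})  # missing required field
except ValidationError as e:
    print(e.error_count(), "validation error")
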
@@ -280,7 +280,7 @@ class DatasetListApi(DatasetApiResource):
            external_knowledge_id=args["external_knowledge_id"],
            embedding_model_provider=args["embedding_model_provider"],
            embedding_model_name=args["embedding_model"],
-            retrieval_model=RetrievalModel(**args["retrieval_model"])
+            retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
            if args["retrieval_model"] is not None
            else None,
        )

@@ -136,7 +136,7 @@ class DocumentAddByTextApi(DatasetApiResource):
            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
        }
        args["data_source"] = data_source
-        knowledge_config = KnowledgeConfig(**args)
+        knowledge_config = KnowledgeConfig.model_validate(args)
        # validate args
        DocumentService.document_create_args_validate(knowledge_config)

@@ -221,7 +221,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
        args["data_source"] = data_source
        # validate args
        args["original_document_id"] = str(document_id)
-        knowledge_config = KnowledgeConfig(**args)
+        knowledge_config = KnowledgeConfig.model_validate(args)
        DocumentService.document_create_args_validate(knowledge_config)

        try:
@@ -328,7 +328,7 @@ class DocumentAddByFileApi(DatasetApiResource):
        }
        args["data_source"] = data_source
        # validate args
-        knowledge_config = KnowledgeConfig(**args)
+        knowledge_config = KnowledgeConfig.model_validate(args)
        DocumentService.document_create_args_validate(knowledge_config)

        dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
@@ -426,7 +426,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
        # validate args
        args["original_document_id"] = str(document_id)

-        knowledge_config = KnowledgeConfig(**args)
+        knowledge_config = KnowledgeConfig.model_validate(args)
        DocumentService.document_create_args_validate(knowledge_config)

        try:
@@ -51,7 +51,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
    def post(self, tenant_id, dataset_id):
        """Create metadata for a dataset."""
        args = metadata_create_parser.parse_args()
-        metadata_args = MetadataArgs(**args)
+        metadata_args = MetadataArgs.model_validate(args)

        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
@@ -200,7 +200,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
        DatasetService.check_dataset_permission(dataset, current_user)

        args = document_metadata_parser.parse_args()
-        metadata_args = MetadataOperationData(**args)
+        metadata_args = MetadataOperationData.model_validate(args)

        MetadataService.update_documents_metadata(dataset, metadata_args)

@@ -98,7 +98,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
        parser.add_argument("is_published", type=bool, required=True, location="json")
        args: ParseResult = parser.parse_args()

-        datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
+        datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
        assert isinstance(current_user, Account)
        rag_pipeline_service: RagPipelineService = RagPipelineService()
        pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)

@@ -252,7 +252,7 @@ class DatasetSegmentApi(DatasetApiResource):
        args = segment_update_parser.parse_args()

        updated_segment = SegmentService.update_segment(
-            SegmentUpdateArgs(**args["segment"]), segment, document, dataset
+            SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
        )
        return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
@@ -126,6 +126,8 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
    end_user_id = enterprise_user_decoded.get("end_user_id")
    session_id = enterprise_user_decoded.get("session_id")
    user_auth_type = enterprise_user_decoded.get("auth_type")
+    exchanged_token_expires_unix = enterprise_user_decoded.get("exp")

    if not user_auth_type:
        raise Unauthorized("Missing auth_type in the token.")

@@ -169,8 +171,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
        )
        db.session.add(end_user)
        db.session.commit()
-    exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
-    exp = int(exp_dt.timestamp())

+    exp = int((datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)).timestamp())
+    if exchanged_token_expires_unix:
+        exp = int(exchanged_token_expires_unix)

    payload = {
        "iss": site.id,
        "sub": "Web API Passport",
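
The second hunk above makes the exchanged web passport honor the upstream enterprise token's expiry: the default ACCESS_TOKEN_EXPIRE_MINUTES window applies only when the decoded token carries no exp claim. A hedged sketch of that rule in isolation (the constant is a stand-in for dify_config.ACCESS_TOKEN_EXPIRE_MINUTES):

from datetime import UTC, datetime, timedelta

ACCESS_TOKEN_EXPIRE_MINUTES = 15  # stand-in value, not from the commit

def passport_exp(upstream_exp: int | None) -> int:
    # default lifetime, overridden by the upstream token's exp when present,
    # so the exchanged token does not outlive the token it was minted from
    exp = int((datetime.now(UTC) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)).timestamp())
    if upstream_exp:
        exp = int(upstream_exp)
    return exp
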
@@ -40,7 +40,7 @@ class AgentConfigManager:
                    "credential_id": tool.get("credential_id", None),
                }

-                agent_tools.append(AgentToolEntity(**agent_tool_properties))
+                agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))

        if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
            "react_router",

@@ -61,9 +61,6 @@ class AppRunner:
        if model_context_tokens is None:
            return -1

-        if max_tokens is None:
-            max_tokens = 0
-
        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        if prompt_tokens + max_tokens > model_context_tokens:

@@ -116,7 +116,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
        rag_pipeline_variables = []
        if workflow.rag_pipeline_variables:
            for v in workflow.rag_pipeline_variables:
-                rag_pipeline_variable = RAGPipelineVariable(**v)
+                rag_pipeline_variable = RAGPipelineVariable.model_validate(v)
                if (
                    rag_pipeline_variable.belong_to_node_id
                    in (self.application_generate_entity.start_node_id, "shared")
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator


class I18nObject(BaseModel):
@@ -11,11 +11,12 @@ class I18nObject(BaseModel):
    pt_BR: str | None = Field(default=None)
    ja_JP: str | None = Field(default=None)

-    def __init__(self, **data):
-        super().__init__(**data)
+    @model_validator(mode="after")
+    def _(self):
        self.zh_Hans = self.zh_Hans or self.en_US
        self.pt_BR = self.pt_BR or self.en_US
        self.ja_JP = self.ja_JP or self.en_US
+        return self

    def to_dict(self) -> dict:
        return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
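
This hunk (and several like it below) replaces an overridden __init__ with a Pydantic v2 @model_validator(mode="after"). The validator runs after field validation on every construction path, including model_validate and model_validate_json, whereas a hand-written __init__ only covers keyword construction. A minimal sketch of the pattern:

from pydantic import BaseModel, model_validator

class I18n(BaseModel):
    en_US: str
    zh_Hans: str | None = None

    @model_validator(mode="after")
    def _default_locales(self):
        # runs after field validation on every construction path
        self.zh_Hans = self.zh_Hans or self.en_US
        return self  # an after-validator must return the model

assert I18n(en_US="hello").zh_Hans == "hello"
assert I18n.model_validate_json('{"en_US": "hi"}').zh_Hans == "hi"
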
@@ -5,7 +5,7 @@ from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, model_validator
from sqlalchemy import func, select
from sqlalchemy.orm import Session

@@ -73,9 +73,8 @@ class ProviderConfiguration(BaseModel):
    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

-    def __init__(self, **data):
-        super().__init__(**data)
-
+    @model_validator(mode="after")
+    def _(self):
        if self.provider.provider not in original_provider_configurate_methods:
            original_provider_configurate_methods[self.provider.provider] = []
            for configurate_method in self.provider.configurate_methods:
@@ -90,6 +89,7 @@ class ProviderConfiguration(BaseModel):
            and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
        ):
            self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
+        return self

    def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
        """

@@ -131,7 +131,7 @@ class CodeExecutor:
        if (code := response_data.get("code")) != 0:
            raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")

-        response_code = CodeExecutionResponse(**response_data)
+        response_code = CodeExecutionResponse.model_validate(response_data)

        if response_code.data.error:
            raise CodeExecutionError(response_code.data.error)

@@ -26,7 +26,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
    response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
    response.raise_for_status()

-    return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
+    return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]


def batch_fetch_plugin_manifests_ignore_deserialization_error(
@@ -41,7 +41,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
    result: list[MarketplacePluginDeclaration] = []
    for plugin in response.json()["data"]["plugins"]:
        try:
-            result.append(MarketplacePluginDeclaration(**plugin))
+            result.append(MarketplacePluginDeclaration.model_validate(plugin))
        except Exception:
            pass
@@ -20,7 +20,7 @@ from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
-from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -357,14 +357,16 @@ class IndexingRunner:
                raise ValueError("no notion import info found")
            extract_setting = ExtractSetting(
                datasource_type=DatasourceType.NOTION.value,
-                notion_info={
-                    "credential_id": data_source_info["credential_id"],
-                    "notion_workspace_id": data_source_info["notion_workspace_id"],
-                    "notion_obj_id": data_source_info["notion_page_id"],
-                    "notion_page_type": data_source_info["type"],
-                    "document": dataset_document,
-                    "tenant_id": dataset_document.tenant_id,
-                },
+                notion_info=NotionInfo.model_validate(
+                    {
+                        "credential_id": data_source_info["credential_id"],
+                        "notion_workspace_id": data_source_info["notion_workspace_id"],
+                        "notion_obj_id": data_source_info["notion_page_id"],
+                        "notion_page_type": data_source_info["type"],
+                        "document": dataset_document,
+                        "tenant_id": dataset_document.tenant_id,
+                    }
+                ),
                document_model=dataset_document.doc_form,
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
@@ -378,14 +380,16 @@ class IndexingRunner:
                raise ValueError("no website import info found")
            extract_setting = ExtractSetting(
                datasource_type=DatasourceType.WEBSITE.value,
-                website_info={
-                    "provider": data_source_info["provider"],
-                    "job_id": data_source_info["job_id"],
-                    "tenant_id": dataset_document.tenant_id,
-                    "url": data_source_info["url"],
-                    "mode": data_source_info["mode"],
-                    "only_main_content": data_source_info["only_main_content"],
-                },
+                website_info=WebsiteInfo.model_validate(
+                    {
+                        "provider": data_source_info["provider"],
+                        "job_id": data_source_info["job_id"],
+                        "tenant_id": dataset_document.tenant_id,
+                        "url": data_source_info["url"],
+                        "mode": data_source_info["mode"],
+                        "only_main_content": data_source_info["only_main_content"],
+                    }
+                ),
                document_model=dataset_document.doc_form,
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
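
These hunks stop feeding raw dicts into ExtractSetting's notion_info/website_info fields and validate them into their own models first, so a missing or misspelled key fails at construction rather than deep inside the extractor. A sketch of the failure mode this catches, with the field list mirroring the hunk above (illustrative, not the full model):

from pydantic import BaseModel, ValidationError

class WebsiteInfo(BaseModel):
    provider: str
    job_id: str
    tenant_id: str
    url: str
    mode: str
    only_main_content: bool

try:
    WebsiteInfo.model_validate({"provider": "firecrawl", "url": "https://example.com"})
except ValidationError as e:
    # the missing keys are reported up front, not later inside extract()
    print(sorted(str(err["loc"][0]) for err in e.errors()))
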
@@ -294,7 +294,7 @@ class ClientSession(
                    method="completion/complete",
                    params=types.CompleteRequestParams(
                        ref=ref,
-                        argument=types.CompletionArgument(**argument),
+                        argument=types.CompletionArgument.model_validate(argument),
                    ),
                )
            ),

@@ -1,4 +1,4 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator


class I18nObject(BaseModel):
@@ -9,7 +9,8 @@ class I18nObject(BaseModel):
    zh_Hans: str | None = None
    en_US: str

-    def __init__(self, **data):
-        super().__init__(**data)
+    @model_validator(mode="after")
+    def _(self):
        if not self.zh_Hans:
            self.zh_Hans = self.en_US
+        return self

@@ -1,7 +1,7 @@
from collections.abc import Sequence
from enum import Enum, StrEnum, auto

-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@@ -46,10 +46,11 @@ class FormOption(BaseModel):
    value: str
    show_on: list[FormShowOnObject] = []

-    def __init__(self, **data):
-        super().__init__(**data)
+    @model_validator(mode="after")
+    def _(self):
        if not self.label:
            self.label = I18nObject(en_US=self.value)
+        return self


class CredentialFormSchema(BaseModel):

@@ -269,17 +269,17 @@ class ModelProviderFactory:
        }

        if model_type == ModelType.LLM:
-            return LargeLanguageModel(**init_params)  # type: ignore
+            return LargeLanguageModel.model_validate(init_params)
        elif model_type == ModelType.TEXT_EMBEDDING:
-            return TextEmbeddingModel(**init_params)  # type: ignore
+            return TextEmbeddingModel.model_validate(init_params)
        elif model_type == ModelType.RERANK:
-            return RerankModel(**init_params)  # type: ignore
+            return RerankModel.model_validate(init_params)
        elif model_type == ModelType.SPEECH2TEXT:
-            return Speech2TextModel(**init_params)  # type: ignore
+            return Speech2TextModel.model_validate(init_params)
        elif model_type == ModelType.MODERATION:
-            return ModerationModel(**init_params)  # type: ignore
+            return ModerationModel.model_validate(init_params)
        elif model_type == ModelType.TTS:
-            return TTSModel(**init_params)  # type: ignore
+            return TTSModel.model_validate(init_params)

    def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
        """

@@ -51,7 +51,7 @@ class ApiModeration(Moderation):
            params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)

            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
-            return ModerationInputsResult(**result)
+            return ModerationInputsResult.model_validate(result)

        return ModerationInputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
@@ -67,7 +67,7 @@ class ApiModeration(Moderation):
            params = ModerationOutputParams(app_id=self.app_id, text=text)

            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
-            return ModerationOutputsResult(**result)
+            return ModerationOutputsResult.model_validate(result)

        return ModerationOutputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response

@@ -84,15 +84,15 @@ class RequestInvokeLLM(BaseRequestInvokeModel):

        for i in range(len(v)):
            if v[i]["role"] == PromptMessageRole.USER.value:
-                v[i] = UserPromptMessage(**v[i])
+                v[i] = UserPromptMessage.model_validate(v[i])
            elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
-                v[i] = AssistantPromptMessage(**v[i])
+                v[i] = AssistantPromptMessage.model_validate(v[i])
            elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
-                v[i] = SystemPromptMessage(**v[i])
+                v[i] = SystemPromptMessage.model_validate(v[i])
            elif v[i]["role"] == PromptMessageRole.TOOL.value:
-                v[i] = ToolPromptMessage(**v[i])
+                v[i] = ToolPromptMessage.model_validate(v[i])
            else:
-                v[i] = PromptMessage(**v[i])
+                v[i] = PromptMessage.model_validate(v[i])

        return v
@@ -94,7 +94,7 @@ class BasePluginClient:
        self,
        method: str,
        path: str,
-        type: type[T],
+        type_: type[T],
        headers: dict | None = None,
        data: bytes | dict | None = None,
        params: dict | None = None,
@@ -104,13 +104,13 @@ class BasePluginClient:
        Make a stream request to the plugin daemon inner API and yield the response as a model.
        """
        for line in self._stream_request(method, path, params, headers, data, files):
-            yield type(**json.loads(line))  # type: ignore
+            yield type_(**json.loads(line))  # type: ignore

    def _request_with_model(
        self,
        method: str,
        path: str,
-        type: type[T],
+        type_: type[T],
        headers: dict | None = None,
        data: bytes | None = None,
        params: dict | None = None,
@@ -120,13 +120,13 @@ class BasePluginClient:
        Make a request to the plugin daemon inner API and return the response as a model.
        """
        response = self._request(method, path, headers, data, params, files)
-        return type(**response.json())  # type: ignore
+        return type_(**response.json())  # type: ignore

    def _request_with_plugin_daemon_response(
        self,
        method: str,
        path: str,
-        type: type[T],
+        type_: type[T],
        headers: dict | None = None,
        data: bytes | dict | None = None,
        params: dict | None = None,
@@ -140,22 +140,22 @@ class BasePluginClient:
            response = self._request(method, path, headers, data, params, files)
            response.raise_for_status()
        except HTTPError as e:
-            msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}"
-            logger.exception(msg)
+            logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
            raise e
        except Exception as e:
            msg = f"Failed to request plugin daemon, url: {path}"
-            logger.exception(msg)
+            logger.exception("Failed to request plugin daemon, url: %s", path)
            raise ValueError(msg) from e

        try:
            json_response = response.json()
            if transformer:
                json_response = transformer(json_response)
-            rep = PluginDaemonBasicResponse[type](**json_response)  # type: ignore
+            # https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why
+            rep = PluginDaemonBasicResponse[type_].model_validate(json_response)  # type: ignore
        except Exception:
            msg = (
-                f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}],"
+                f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}],"
                f" url: {path}"
            )
            logger.exception(msg)
@@ -163,7 +163,7 @@ class BasePluginClient:

        if rep.code != 0:
            try:
-                error = PluginDaemonError(**json.loads(rep.message))
+                error = PluginDaemonError.model_validate(json.loads(rep.message))
            except Exception:
                raise ValueError(f"{rep.message}, code: {rep.code}")

@@ -178,7 +178,7 @@ class BasePluginClient:
        self,
        method: str,
        path: str,
-        type: type[T],
+        type_: type[T],
        headers: dict | None = None,
        data: bytes | dict | None = None,
        params: dict | None = None,
@@ -189,7 +189,7 @@ class BasePluginClient:
        """
        for line in self._stream_request(method, path, params, headers, data, files):
            try:
-                rep = PluginDaemonBasicResponse[type].model_validate_json(line)  # type: ignore
+                rep = PluginDaemonBasicResponse[type_].model_validate_json(line)  # type: ignore
            except (ValueError, TypeError):
                # TODO modify this when line_data has code and message
                try:
@@ -204,7 +204,7 @@ class BasePluginClient:
        if rep.code != 0:
            if rep.code == -500:
                try:
-                    error = PluginDaemonError(**json.loads(rep.message))
+                    error = PluginDaemonError.model_validate(json.loads(rep.message))
                except Exception:
                    raise PluginDaemonInnerError(code=rep.code, message=rep.message)
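
Renaming the parameter from type to type_ stops shadowing the type builtin inside these methods; while shadowed, PluginDaemonBasicResponse[type] subscripts a local variable named like a class, which is exactly the "variable is not valid as a type" confusion the linked Stack Overflow answer describes. A hedged sketch of the runtime-parametrized generic response pattern, with stand-in models:

import json
from typing import Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)

class DaemonResponse(BaseModel, Generic[T]):  # stand-in for PluginDaemonBasicResponse
    code: int
    message: str = ""
    data: T | None = None

class Manifest(BaseModel):  # hypothetical payload type
    name: str

def parse(line: str, type_: type[T]) -> DaemonResponse[T]:
    # Pydantic v2 lets a generic model be parametrized with a runtime class
    # object; a parameter named `type` would shadow the builtin and confuse
    # both readers and type checkers
    return DaemonResponse[type_].model_validate(json.loads(line))

rep = parse('{"code": 0, "data": {"name": "demo"}}', Manifest)
assert rep.data is not None and rep.data.name == "demo"
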
@@ -46,7 +46,9 @@ class PluginDatasourceManager(BasePluginClient):
            params={"page": 1, "page_size": 256},
            transformer=transformer,
        )
-        local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
+        local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate(
+            self._get_local_file_datasource_provider()
+        )

        for provider in response:
            ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
@@ -104,7 +106,7 @@ class PluginDatasourceManager(BasePluginClient):
        Fetch datasource provider for the given tenant and plugin.
        """
        if provider_id == "langgenius/file/file":
-            return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
+            return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider())

        tool_provider_id = DatasourceProviderID(provider_id)

@@ -162,7 +162,7 @@ class PluginModelClient(BasePluginClient):
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/llm/invoke",
-            type=LLMResultChunk,
+            type_=LLMResultChunk,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
@@ -208,7 +208,7 @@ class PluginModelClient(BasePluginClient):
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
-            type=PluginLLMNumTokensResponse,
+            type_=PluginLLMNumTokensResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
@@ -250,7 +250,7 @@ class PluginModelClient(BasePluginClient):
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
-            type=TextEmbeddingResult,
+            type_=TextEmbeddingResult,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
@@ -291,7 +291,7 @@ class PluginModelClient(BasePluginClient):
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
-            type=PluginTextEmbeddingNumTokensResponse,
+            type_=PluginTextEmbeddingNumTokensResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
@@ -334,7 +334,7 @@ class PluginModelClient(BasePluginClient):
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
-            type=RerankResult,
+            type_=RerankResult,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
@@ -378,7 +378,7 @@ class PluginModelClient(BasePluginClient):
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/tts/invoke",
-            type=PluginStringResultResponse,
+            type_=PluginStringResultResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
@@ -422,7 +422,7 @@ class PluginModelClient(BasePluginClient):
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
-            type=PluginVoicesResponse,
+            type_=PluginVoicesResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
@@ -466,7 +466,7 @@ class PluginModelClient(BasePluginClient):
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
-            type=PluginStringResultResponse,
+            type_=PluginStringResultResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
@@ -506,7 +506,7 @@ class PluginModelClient(BasePluginClient):
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
-            type=PluginBasicBooleanResponse,
+            type_=PluginBasicBooleanResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
@@ -1,6 +1,6 @@
from collections.abc import Generator
from dataclasses import dataclass, field
-from typing import TypeVar, Union, cast
+from typing import TypeVar, Union

from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
@@ -87,7 +87,8 @@ def merge_blob_chunks(
                    ),
                    meta=resp.meta,
                )
-                yield cast(MessageType, merged_message)
+                assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage))
+                yield merged_message  # type: ignore
                # Clean up the buffer
                del files[chunk_id]
            else:
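
The second hunk trades typing.cast for a real isinstance assertion. cast is purely a type-checker hint and never inspects the value, so a wrongly-typed message would have flowed through unnoticed; the assert both narrows the type for checkers and fails fast at runtime. In miniature:

from typing import cast

class ToolMsg: ...
class AgentMsg: ...

def emit(msg: object) -> object:
    _ = cast(ToolMsg, msg)  # no runtime check: "succeeds" even if msg is a str
    assert isinstance(msg, (ToolMsg, AgentMsg))  # narrows AND verifies
    return msg

emit(ToolMsg())
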
@@ -134,7 +134,7 @@ class RetrievalService:
            if not dataset:
                return []
            metadata_condition = (
-                MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
+                MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
            )
            all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
                dataset.tenant_id,

@@ -17,9 +17,6 @@ class NotionInfo(BaseModel):
    tenant_id: str
    model_config = ConfigDict(arbitrary_types_allowed=True)

-    def __init__(self, **data):
-        super().__init__(**data)
-

class WebsiteInfo(BaseModel):
    """
@@ -47,6 +44,3 @@ class ExtractSetting(BaseModel):
    website_info: WebsiteInfo | None = None
    document_model: str | None = None
    model_config = ConfigDict(arbitrary_types_allowed=True)
-
-    def __init__(self, **data):
-        super().__init__(**data)

@@ -38,11 +38,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
            raise ValueError("No process rule found.")
        if process_rule.get("mode") == "automatic":
            automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
-            rules = Rule(**automatic_rule)
+            rules = Rule.model_validate(automatic_rule)
        else:
            if not process_rule.get("rules"):
                raise ValueError("No rules found in process rule.")
-            rules = Rule(**process_rule.get("rules"))
+            rules = Rule.model_validate(process_rule.get("rules"))
        # Split the text documents into nodes.
        if not rules.segmentation:
            raise ValueError("No segmentation found in rules.")

@@ -40,7 +40,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
            raise ValueError("No process rule found.")
        if not process_rule.get("rules"):
            raise ValueError("No rules found in process rule.")
-        rules = Rule(**process_rule.get("rules"))
+        rules = Rule.model_validate(process_rule.get("rules"))
        all_documents: list[Document] = []
        if rules.parent_mode == ParentMode.PARAGRAPH:
            # Split the text documents into nodes.
@@ -110,7 +110,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
                child_documents = document.children
                if child_documents:
                    formatted_child_documents = [
-                        Document(**child_document.model_dump()) for child_document in child_documents
+                        Document.model_validate(child_document.model_dump()) for child_document in child_documents
                    ]
                    vector.create(formatted_child_documents)

@@ -224,7 +224,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
        return child_nodes

    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
-        parent_childs = ParentChildStructureChunk(**chunks)
+        parent_childs = ParentChildStructureChunk.model_validate(chunks)
        documents = []
        for parent_child in parent_childs.parent_child_chunks:
            metadata = {
@@ -274,7 +274,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
            vector.create(all_child_documents)

    def format_preview(self, chunks: Any) -> Mapping[str, Any]:
-        parent_childs = ParentChildStructureChunk(**chunks)
+        parent_childs = ParentChildStructureChunk.model_validate(chunks)
        preview = []
        for parent_child in parent_childs.parent_child_chunks:
            preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})

@@ -47,7 +47,7 @@ class QAIndexProcessor(BaseIndexProcessor):
            raise ValueError("No process rule found.")
        if not process_rule.get("rules"):
            raise ValueError("No rules found in process rule.")
-        rules = Rule(**process_rule.get("rules"))
+        rules = Rule.model_validate(process_rule.get("rules"))
        splitter = self._get_splitter(
            processing_rule_mode=process_rule.get("mode"),
            max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
@@ -168,7 +168,7 @@ class QAIndexProcessor(BaseIndexProcessor):
        return docs

    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
-        qa_chunks = QAStructureChunk(**chunks)
+        qa_chunks = QAStructureChunk.model_validate(chunks)
        documents = []
        for qa_chunk in qa_chunks.qa_chunks:
            metadata = {
@@ -191,7 +191,7 @@ class QAIndexProcessor(BaseIndexProcessor):
            raise ValueError("Indexing technique must be high quality.")

    def format_preview(self, chunks: Any) -> Mapping[str, Any]:
-        qa_chunks = QAStructureChunk(**chunks)
+        qa_chunks = QAStructureChunk.model_validate(chunks)
        preview = []
        for qa_chunk in qa_chunks.qa_chunks:
            preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
@@ -2,6 +2,7 @@

from __future__ import annotations

+import re
from typing import Any

from core.model_manager import ModelInstance
@@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
-        self._separators = separators or ["\n\n", "\n", " ", ""]
+        self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""]

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
@@ -90,16 +91,19 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
        # Now that we have the separator, split the text
        if separator:
            if separator == " ":
-                splits = text.split()
+                splits = re.split(r" +", text)
            else:
                splits = text.split(separator)
                splits = [item + separator if i < len(splits) - 1 else item for i, item in enumerate(splits)]
        else:
            splits = list(text)
-        splits = [s for s in splits if (s not in {"", "\n"})]
+        if separator == "\n":
+            splits = [s for s in splits if s != ""]
+        else:
+            splits = [s for s in splits if (s not in {"", "\n"})]
        _good_splits = []
        _good_splits_lengths = []  # cache the lengths of the splits
-        _separator = "" if self._keep_separator else separator
+        _separator = separator if self._keep_separator else ""
        s_lens = self._length_function(splits)
        if separator != "":
            for s, s_len in zip(splits, s_lens):
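
Two behavior notes on this hunk, sketched below: re.split(r" +", ...) collapses runs of spaces without producing the empty fragments that text.split(" ") does, and the added "。" and ". " separators give the recursive splitter sentence boundaries (including the Chinese full stop) to try before falling back to single spaces:

import re

text = "first  second   third"
print(text.split(" "))        # ['first', '', 'second', '', '', 'third']
print(re.split(r" +", text))  # ['first', 'second', 'third']

separators = ["\n\n", "\n", "。", ". ", " ", ""]  # order the splitter tries
print("句一。句二。".split("。"))  # ['句一', '句二', ''] — sentence-level pieces
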
@@ -90,7 +90,7 @@ class BuiltinToolProviderController(ToolProviderController):
            tools.append(
                assistant_tool_class(
                    provider=provider,
-                    entity=ToolEntity(**tool),
+                    entity=ToolEntity.model_validate(tool),
                    runtime=ToolRuntime(tenant_id=""),
                )
            )

@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator


class I18nObject(BaseModel):
@@ -11,11 +11,12 @@ class I18nObject(BaseModel):
    pt_BR: str | None = Field(default=None)
    ja_JP: str | None = Field(default=None)

-    def __init__(self, **data):
-        super().__init__(**data)
+    @model_validator(mode="after")
+    def _populate_missing_locales(self):
        self.zh_Hans = self.zh_Hans or self.en_US
        self.pt_BR = self.pt_BR or self.en_US
        self.ja_JP = self.ja_JP or self.en_US
+        return self

    def to_dict(self):
        return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

@@ -54,7 +54,7 @@ class MCPToolProviderController(ToolProviderController):
        """
        tools = []
        tools_data = json.loads(db_provider.tools)
-        remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
+        remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
        user = db_provider.load_user()
        tools = [
            ToolEntity(

@@ -1008,7 +1008,7 @@ class ToolManager:
            config = tool_configurations.get(parameter.name, {})
            if not (config and isinstance(config, dict) and config.get("value") is not None):
                continue
-            tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
+            tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
            if tool_input.type == "variable":
                variable = variable_pool.get(tool_input.value)
                if variable is None:

@@ -126,7 +126,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                    data_source_type=document.data_source_type,
                    segment_id=segment.id,
                    retriever_from=self.retriever_from,
-                    score=document_score_list.get(segment.index_node_id, None),
+                    score=document_score_list.get(segment.index_node_id),
                    doc_metadata=document.doc_metadata,
                )

@@ -105,10 +105,10 @@ class RedisChannel:
            command_type = CommandType(command_type_value)

            if command_type == CommandType.ABORT:
-                return AbortCommand(**data)
+                return AbortCommand.model_validate(data)
            else:
                # For other command types, use base class
-                return GraphEngineCommand(**data)
+                return GraphEngineCommand.model_validate(data)

        except (ValueError, TypeError):
            return None

@@ -16,7 +16,7 @@ class EndNode(Node):
    _node_data: EndNodeData

    def init_node_data(self, data: Mapping[str, Any]):
-        self._node_data = EndNodeData(**data)
+        self._node_data = EndNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy
@@ -342,10 +342,13 @@ class IterationNode(Node):
        iterator_list_value: Sequence[object],
        iter_run_map: dict[str, float],
    ) -> Generator[NodeEventBase, None, None]:
+        # Flatten the list of lists if all outputs are lists
+        flattened_outputs = self._flatten_outputs_if_needed(outputs)
+
        yield IterationSucceededEvent(
            start_at=started_at,
            inputs=inputs,
-            outputs={"output": outputs},
+            outputs={"output": flattened_outputs},
            steps=len(iterator_list_value),
            metadata={
                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
@@ -357,13 +360,39 @@ class IterationNode(Node):
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                outputs={"output": outputs},
+                outputs={"output": flattened_outputs},
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
                },
            )
        )

+    def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]:
+        """
+        Flatten the outputs list if all elements are lists.
+        This maintains backward compatibility with version 1.8.1 behavior.
+        """
+        if not outputs:
+            return outputs
+
+        # Check if all non-None outputs are lists
+        non_none_outputs = [output for output in outputs if output is not None]
+        if not non_none_outputs:
+            return outputs
+
+        if all(isinstance(output, list) for output in non_none_outputs):
+            # Flatten the list of lists
+            flattened: list[Any] = []
+            for output in outputs:
+                if isinstance(output, list):
+                    flattened.extend(output)
+                elif output is not None:
+                    # This shouldn't happen based on our check, but handle it gracefully
+                    flattened.append(output)
+            return flattened
+
+        return outputs
+
    def _handle_iteration_failure(
        self,
        started_at: datetime,
@@ -373,10 +402,13 @@ class IterationNode(Node):
        iter_run_map: dict[str, float],
        error: IterationNodeError,
    ) -> Generator[NodeEventBase, None, None]:
+        # Flatten the list of lists if all outputs are lists (even in failure case)
+        flattened_outputs = self._flatten_outputs_if_needed(outputs)
+
        yield IterationFailedEvent(
            start_at=started_at,
            inputs=inputs,
-            outputs={"output": outputs},
+            outputs={"output": flattened_outputs},
            steps=len(iterator_list_value),
            metadata={
                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
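
The new helper restores the 1.8.1 output shape: an iteration whose steps each return a list now yields one flat list, while scalar or mixed-shape outputs pass through unchanged. A standalone check of those semantics (logic copied from _flatten_outputs_if_needed above):

def flatten_if_needed(outputs: list[object]) -> list[object]:
    # flatten only when every non-None element is itself a list
    non_none = [o for o in outputs if o is not None]
    if non_none and all(isinstance(o, list) for o in non_none):
        return [item for o in outputs if isinstance(o, list) for item in o]
    return outputs

assert flatten_if_needed([[1, 2], [3]]) == [1, 2, 3]
assert flatten_if_needed([[1], None]) == [1]    # None entries drop out when flattening
assert flatten_if_needed([1, [2]]) == [1, [2]]  # mixed shapes are left as-is
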
@@ -18,7 +18,7 @@ class IterationStartNode(Node):
|
||||
_node_data: IterationStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IterationStartNodeData(**data)
|
||||
self._node_data = IterationStartNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy

@@ -2,7 +2,7 @@ import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any, cast
from typing import Any

from sqlalchemy import func, select

@@ -62,7 +62,7 @@ class KnowledgeIndexNode(Node):
return self._node_data

def _run(self) -> NodeRunResult:  # type: ignore
node_data = cast(KnowledgeIndexNodeData, self._node_data)
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id:

@@ -41,7 +41,7 @@ class ListOperatorNode(Node):
_node_data: ListOperatorNodeData

def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData(**data)
self._node_data = ListOperatorNodeData.model_validate(data)

def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

@@ -18,7 +18,7 @@ class LoopEndNode(Node):
_node_data: LoopEndNodeData

def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData(**data)
self._node_data = LoopEndNodeData.model_validate(data)

def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

@@ -18,7 +18,7 @@ class LoopStartNode(Node):
_node_data: LoopStartNodeData

def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData(**data)
self._node_data = LoopStartNodeData.model_validate(data)

def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

@@ -16,7 +16,7 @@ class StartNode(Node):
_node_data: StartNodeData

def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData(**data)
self._node_data = StartNodeData.model_validate(data)

def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

@@ -15,7 +15,7 @@ class VariableAggregatorNode(Node):
_node_data: VariableAssignerNodeData

def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData(**data)
self._node_data = VariableAssignerNodeData.model_validate(data)

def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

@@ -14,7 +14,7 @@ def handle(sender, **kwargs):
for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
try:
tool_entity = ToolEntity(**node_data["data"])
tool_entity = ToolEntity.model_validate(node_data["data"])
tool_runtime = ToolManager.get_tool_runtime(
provider_type=tool_entity.provider_type,
provider_id=tool_entity.provider_id,

@@ -61,7 +61,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]:

for node in knowledge_retrieval_nodes:
try:
node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
node_data = KnowledgeRetrievalNodeData.model_validate(node.get("data", {}))
dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids)
except Exception:
continue

@@ -1,15 +1,16 @@
import enum
import json
from dataclasses import field
from datetime import datetime
from typing import Any, Optional

import sqlalchemy as sa
from flask_login import UserMixin  # type: ignore[import-untyped]
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated

from models.base import Base
from models.base import TypeBase

from .engine import db
from .types import StringUUID
@@ -83,31 +84,37 @@ class AccountStatus(enum.StrEnum):
CLOSED = "closed"


class Account(UserMixin, Base):
class Account(UserMixin, TypeBase):
__tablename__ = "accounts"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))

id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
name: Mapped[str] = mapped_column(String(255))
email: Mapped[str] = mapped_column(String(255))
password: Mapped[str | None] = mapped_column(String(255))
password_salt: Mapped[str | None] = mapped_column(String(255))
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True)
interface_language: Mapped[str | None] = mapped_column(String(255))
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True)
timezone: Mapped[str | None] = mapped_column(String(255))
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True)
last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
password: Mapped[str | None] = mapped_column(String(255), default=None)
password_salt: Mapped[str | None] = mapped_column(String(255), default=None)
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
interface_language: Mapped[str | None] = mapped_column(String(255), default=None)
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
timezone: Mapped[str | None] = mapped_column(String(255), default=None)
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
last_active_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
status: Mapped[str] = mapped_column(
String(16), server_default=sa.text("'active'::character varying"), default="active"
)
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)

@reconstructor
def init_on_load(self):
self.role: TenantAccountRole | None = None
self._current_tenant: Tenant | None = None
role: TenantAccountRole | None = field(default=None, init=False)
_current_tenant: "Tenant | None" = field(default=None, init=False)

@property
def is_password_set(self):
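The `Base` -> `TypeBase` switch above is what forces the explicit `init=False` / `default` / `default_factory` arguments on every column: the models become dataclass-style mapped classes with a generated `__init__`. A minimal sketch of the pattern, assuming `TypeBase` combines `DeclarativeBase` with `MappedAsDataclass` (inferred from the diff; the actual definition lives in models.base and is not shown here):

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

class TypeBase(DeclarativeBase, MappedAsDataclass):
    pass

class Example(TypeBase):
    __tablename__ = "examples"
    # init=False: excluded from the generated __init__; the DB server default fills it
    id: Mapped[int] = mapped_column(sa.Integer, primary_key=True, init=False)
    name: Mapped[str] = mapped_column(sa.String(255))
    # default: the Python-side __init__ default, mirroring the server_default
    status: Mapped[str] = mapped_column(sa.String(16), server_default="active", default="active")

row = Example(name="demo")  # constructor arguments now come from the dataclass
assert row.status == "active"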
@@ -226,18 +233,24 @@ class TenantStatus(enum.StrEnum):
ARCHIVE = "archive"


class Tenant(Base):
class Tenant(TypeBase):
__tablename__ = "tenants"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)

id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
custom_config: Mapped[str | None] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text, default=None)
plan: Mapped[str] = mapped_column(
String(255), server_default=sa.text("'basic'::character varying"), default="basic"
)
status: Mapped[str] = mapped_column(
String(255), server_default=sa.text("'normal'::character varying"), default="normal"
)
custom_config: Mapped[str | None] = mapped_column(sa.Text, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)

def get_accounts(self) -> list[Account]:
return list(
@@ -257,7 +270,7 @@ class Tenant(Base):
self.custom_config = json.dumps(value)


class TenantAccountJoin(Base):
class TenantAccountJoin(TypeBase):
__tablename__ = "tenant_account_joins"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@@ -266,17 +279,21 @@ class TenantAccountJoin(Base):
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
)

id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
role: Mapped[str] = mapped_column(String(16), server_default="normal")
invited_by: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal")
invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)


class AccountIntegrate(Base):
class AccountIntegrate(TypeBase):
__tablename__ = "account_integrates"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@@ -284,16 +301,20 @@ class AccountIntegrate(Base):
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
)

id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
account_id: Mapped[str] = mapped_column(StringUUID)
provider: Mapped[str] = mapped_column(String(16))
open_id: Mapped[str] = mapped_column(String(255))
encrypted_token: Mapped[str] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)


class InvitationCode(Base):
class InvitationCode(TypeBase):
__tablename__ = "invitation_codes"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
@@ -301,18 +322,22 @@ class InvitationCode(Base):
sa.Index("invitation_codes_code_idx", "code", "status"),
)

id: Mapped[int] = mapped_column(sa.Integer)
id: Mapped[int] = mapped_column(sa.Integer, init=False)
batch: Mapped[str] = mapped_column(String(255))
code: Mapped[str] = mapped_column(String(32))
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying"))
used_at: Mapped[datetime | None] = mapped_column(DateTime)
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID)
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID)
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
status: Mapped[str] = mapped_column(
String(16), server_default=sa.text("'unused'::character varying"), default="unused"
)
used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None)
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"), nullable=False, init=False
)


class TenantPluginPermission(Base):
class TenantPluginPermission(TypeBase):
class InstallPermission(enum.StrEnum):
EVERYONE = "everyone"
ADMINS = "admins"
@@ -329,13 +354,17 @@ class TenantPluginPermission(Base):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
)

id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone")
debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone")
install_permission: Mapped[InstallPermission] = mapped_column(
String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE
)
debug_permission: Mapped[DebugPermission] = mapped_column(
String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY
)


class TenantPluginAutoUpgradeStrategy(Base):
class TenantPluginAutoUpgradeStrategy(TypeBase):
class StrategySetting(enum.StrEnum):
DISABLED = "disabled"
FIX_ONLY = "fix_only"
@@ -352,12 +381,20 @@ class TenantPluginAutoUpgradeStrategy(Base):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
)

id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only")
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)  # seconds of the day
upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude")
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False)  # plugin_id (author/name)
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False)  # plugin_id (author/name)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
strategy_setting: Mapped[StrategySetting] = mapped_column(
String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY
)
upgrade_mode: Mapped[UpgradeMode] = mapped_column(
String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE
)
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
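Note that the list-valued columns above move to `default_factory=list`: dataclass-style mapped columns follow the stdlib dataclass rule that mutable defaults must come from a factory. A runnable sketch of the rule itself:

from dataclasses import dataclass, field

@dataclass
class Strategy:
    # Each instance gets its own list; a bare mutable default would be
    # rejected at class definition time to prevent cross-instance sharing.
    exclude_plugins: list[str] = field(default_factory=list)

a, b = Strategy(), Strategy()
a.exclude_plugins.append("author/name")
assert b.exclude_plugins == []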

@@ -754,7 +754,7 @@ class DocumentSegment(Base):
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
rules = Rule.model_validate(rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
@@ -772,7 +772,7 @@ class DocumentSegment(Base):
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
rules = Rule.model_validate(rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)

@@ -152,7 +152,7 @@ class ApiToolProvider(Base):
def tools(self) -> list["ApiToolBundle"]:
from core.tools.entities.tool_bundle import ApiToolBundle

return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]
return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)]

@property
def credentials(self) -> dict[str, Any]:
@@ -242,7 +242,10 @@ class WorkflowToolProvider(Base):
def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration

return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)]
return [
WorkflowToolParameterConfiguration.model_validate(config)
for config in json.loads(self.parameter_configuration)
]

@property
def app(self) -> App | None:
@@ -312,7 +315,7 @@ class MCPToolProvider(Base):
def mcp_tools(self) -> list["MCPTool"]:
from core.mcp.types import Tool as MCPTool

return [MCPTool(**tool) for tool in json.loads(self.tools)]
return [MCPTool.model_validate(tool) for tool in json.loads(self.tools)]

@property
def provider_icon(self) -> Mapping[str, str] | str:
@@ -552,4 +555,4 @@ class DeprecatedPublishedAppTool(Base):
def description_i18n(self) -> "I18nObject":
from core.tools.entities.common_entities import I18nObject

return I18nObject(**json.loads(self.description))
return I18nObject.model_validate(json.loads(self.description))

@@ -360,7 +360,9 @@ class Workflow(Base):

@property
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# _environment_variables is guaranteed to be non-None due to server_default="{}"
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"

# Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id
@@ -444,7 +446,9 @@ class Workflow(Base):

@property
def conversation_variables(self) -> Sequence[Variable]:
# _conversation_variables is guaranteed to be non-None due to server_default="{}"
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"

variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]

@@ -110,7 +110,7 @@ dev = [
"lxml-stubs~=0.5.1",
"ty~=0.0.1a19",
"basedpyright~=1.31.0",
"ruff~=0.12.3",
"ruff~=0.14.0",
"pytest~=8.3.2",
"pytest-benchmark~=4.0.0",
"pytest-cov~=4.1.0",

@@ -24,9 +24,7 @@
"reportUnknownLambdaType": "hint",
"reportMissingParameterType": "hint",
"reportMissingTypeArgument": "hint",
"reportUnnecessaryContains": "hint",
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryCast": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",

@@ -246,10 +246,8 @@ class AccountService:
)
)

account = Account()
account.email = email
account.name = name

password_to_set = None
salt_to_set = None
if password:
valid_password(password)

@@ -261,14 +259,18 @@ class AccountService:
password_hashed = hash_password(password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()

account.password = base64_password_hashed
account.password_salt = base64_salt
password_to_set = base64_password_hashed
salt_to_set = base64_salt

account.interface_language = interface_language
account.interface_theme = interface_theme

# Set timezone based on language
account.timezone = language_timezone_mapping.get(interface_language, "UTC")
account = Account(
name=name,
email=email,
password=password_to_set,
password_salt=salt_to_set,
interface_language=interface_language,
interface_theme=interface_theme,
timezone=language_timezone_mapping.get(interface_language, "UTC"),
)

db.session.add(account)
db.session.commit()

@@ -659,31 +659,31 @@ class AppDslService:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL.value:
tool_entity = ToolNodeData(**node["data"])
tool_entity = ToolNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.LLM.value:
llm_entity = LLMNodeData(**node["data"])
llm_entity = LLMNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER.value:
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR.value:
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_RETRIEVAL.value:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
if (
@@ -773,7 +773,7 @@ class AppDslService:
"""
Returns the leaked dependencies in current workspace
"""
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
if not dependencies:
return []

@@ -70,7 +70,7 @@ class EnterpriseService:
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
return WebAppSettings.model_validate(data)

@classmethod
def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
@@ -100,7 +100,7 @@ class EnterpriseService:
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
return WebAppSettings.model_validate(data)

@classmethod
def update_app_access_mode(cls, app_id: str, access_mode: str):

@@ -1,6 +1,7 @@
from collections.abc import Sequence
from enum import Enum

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator

from configs import dify_config
from core.entities.model_entities import (
@@ -71,7 +72,7 @@ class ProviderResponse(BaseModel):
icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: list[ModelType]
supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: ProviderCredentialSchema | None = None
model_credential_schema: ModelCredentialSchema | None = None
@@ -82,9 +83,8 @@ class ProviderResponse(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())

def __init__(self, **data):
super().__init__(**data)

@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@@ -97,6 +97,7 @@ class ProviderResponse(BaseModel):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
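The response classes in this file drop their pass-through `__init__` overrides in favor of an after-mode validator, which runs once all fields are populated and must return the instance. A minimal sketch of the pattern, assuming Pydantic v2 (field names here are hypothetical):

from pydantic import BaseModel, model_validator

class Response(BaseModel):
    provider: str
    icon_url: str | None = None

    @model_validator(mode="after")
    def _fill_icon(self):
        # Runs after field validation; safe to read validated fields here.
        if self.icon_url is None:
            self.icon_url = f"/providers/{self.provider}/icon"
        return self  # after-validators must return the model instance

assert Response(provider="openai").icon_url == "/providers/openai/icon"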


class ProviderWithModelsResponse(BaseModel):
@@ -112,9 +113,8 @@ class ProviderWithModelsResponse(BaseModel):
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]

def __init__(self, **data):
super().__init__(**data)

@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@@ -127,6 +127,7 @@ class ProviderWithModelsResponse(BaseModel):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self


class SimpleProviderEntityResponse(SimpleProviderEntity):
@@ -136,9 +137,8 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):

tenant_id: str

def __init__(self, **data):
super().__init__(**data)

@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@@ -151,6 +151,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self


class DefaultModelResponse(BaseModel):

@@ -46,7 +46,7 @@ class HitTestingService:

from core.app.app_config.entities import MetadataFilteringCondition

metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions)
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)

metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
dataset_ids=[dataset.id],

@@ -123,7 +123,7 @@ class OpsService:
config_class: type[BaseTracingConfig] = provider_config["config_class"]
other_keys: list[str] = provider_config["other_keys"]

default_config_instance: BaseTracingConfig = config_class(**tracing_config)
default_config_instance = config_class.model_validate(tracing_config)
for key in other_keys:
if key in tracing_config and tracing_config[key] == "":
tracing_config[key] = getattr(default_config_instance, key, None)

@@ -269,7 +269,7 @@ class PluginMigration:
for tool in agent_config["tools"]:
if isinstance(tool, dict):
try:
tool_entity = AgentToolEntity(**tool)
tool_entity = AgentToolEntity.model_validate(tool)
if (
tool_entity.provider_type == ToolProviderType.BUILT_IN.value
and tool_entity.provider_id not in excluded_providers

@@ -358,7 +358,7 @@ class RagPipelineService:
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration)

# update dataset
dataset = pipeline.retrieve_dataset(session=session)

@@ -288,7 +288,7 @@ class RagPipelineDslService:
dataset_id = None
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if (
dataset
and pipeline.is_published
@@ -426,7 +426,7 @@ class RagPipelineDslService:
dataset_id = None
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if not dataset:
dataset = Dataset(
tenant_id=account.current_tenant_id,
@@ -734,35 +734,35 @@ class RagPipelineDslService:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL.value:
tool_entity = ToolNodeData(**node["data"])
tool_entity = ToolNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.DATASOURCE.value:
datasource_entity = DatasourceNodeData(**node["data"])
datasource_entity = DatasourceNodeData.model_validate(node["data"])
if datasource_entity.provider_type != "local_file":
dependencies.append(datasource_entity.plugin_id)
case NodeType.LLM.value:
llm_entity = LLMNodeData(**node["data"])
llm_entity = LLMNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER.value:
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR.value:
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_INDEX.value:
knowledge_index_entity = KnowledgeConfiguration(**node["data"])
knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
if knowledge_index_entity.indexing_technique == "high_quality":
if knowledge_index_entity.embedding_model_provider:
dependencies.append(
@@ -783,7 +783,7 @@ class RagPipelineDslService:
),
)
case NodeType.KNOWLEDGE_RETRIEVAL.value:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
if (
@@ -873,7 +873,7 @@ class RagPipelineDslService:
"""
Returns the leaked dependencies in current workspace
"""
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
if not dependencies:
return []

@@ -156,13 +156,13 @@ class RagPipelineTransformService:
self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)

if indexing_technique == "high_quality":
knowledge_configuration.embedding_model = dataset.embedding_model
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
if retrieval_model:
retrieval_setting = RetrievalSetting(**retrieval_model)
retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
if indexing_technique == "economy":
retrieval_setting.search_method = "keyword_search"
knowledge_configuration.retrieval_model = retrieval_setting

@@ -1,7 +1,7 @@
import hashlib
import json
from datetime import datetime
from typing import Any, cast
from typing import Any

from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
@@ -55,7 +55,7 @@ class MCPToolManageService:
cache=NoOpProviderCredentialCache(),
)

return cast(dict[str, str], encrypter_instance.encrypt(headers))
return encrypter_instance.encrypt(headers)

@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:

@@ -242,7 +242,7 @@ class ToolTransformService:
is_team_authorization=db_provider.authed,
server_url=db_provider.masked_server_url,
tools=ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
db_provider, [MCPTool.model_validate(tool) for tool in json.loads(db_provider.tools)]
),
updated_at=int(db_provider.updated_at.timestamp()),
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
@@ -387,6 +387,7 @@ class ToolTransformService:
labels=labels or [],
)
else:
assert tool.operation_id
return ToolApiEntity(
author=tool.author,
name=tool.operation_id or "",

@@ -36,7 +36,7 @@ def process_trace_tasks(file_info):
if trace_info.get("workflow_data"):
trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"])
if trace_info.get("documents"):
trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]]
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]

try:
if trace_instance:

@@ -79,7 +79,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
@@ -112,7 +112,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_execution_id = str(uuid.uuid4())

# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)

# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

@@ -100,7 +100,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
@@ -133,7 +133,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_execution_id = str(uuid.uuid4())

# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)

# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

@@ -33,17 +33,19 @@ class TestChatMessageApiPermissions:
@pytest.fixture
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing."""
account = Account()
account.id = str(uuid.uuid4())
account.name = "Test User"
account.email = "test@example.com"

account = Account(
name="Test User",
email="test@example.com",
)
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
account.id = str(uuid.uuid4())

tenant = Tenant()
# Create mock tenant
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"

mock_session_instance = mock.Mock()

@@ -32,17 +32,16 @@ class TestModelConfigResourcePermissions:
@pytest.fixture
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing."""
account = Account()

account = Account(name="Test User", email="test@example.com")
account.id = str(uuid.uuid4())
account.name = "Test User"
account.email = "test@example.com"
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()

tenant = Tenant()
# Create mock tenant
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"

mock_session_instance = mock.Mock()

@@ -36,7 +36,7 @@ def test_api_tool(setup_http_mock):
entity=ToolEntity(
identity=ToolIdentity(provider="", author="", name="", label=I18nObject(en_US="test tool")),
),
api_bundle=ApiToolBundle(**tool_bundle),
api_bundle=ApiToolBundle.model_validate(tool_bundle),
runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}),
provider_id="test_tool",
)

@@ -16,6 +16,7 @@ from services.errors.account import (
AccountPasswordError,
AccountRegisterError,
CurrentPasswordIncorrectError,
TenantNotFoundError,
)
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError

@@ -1414,7 +1415,7 @@ class TestTenantService:
)

# Try to get current tenant (should fail)
with pytest.raises(AttributeError):
with pytest.raises((AttributeError, TenantNotFoundError)):
TenantService.get_current_tenant_by_account(account)

def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies):

@@ -44,27 +44,26 @@ class TestWorkflowService:
Account: Created test account instance
"""
fake = fake or Faker()
account = Account()
account.id = fake.uuid4()
account.email = fake.email()
account.name = fake.name()
account.avatar_url = fake.url()
account.tenant_id = fake.uuid4()
account.status = "active"
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US"  # Set interface language for Site creation
account = Account(
email=fake.email(),
name=fake.name(),
avatar=fake.url(),
status="active",
interface_language="en-US",  # Set interface language for Site creation
)
account.created_at = fake.date_time_this_year()
account.id = fake.uuid4()
account.updated_at = account.created_at

# Create a tenant for the account
from models.account import Tenant

tenant = Tenant()
tenant.id = account.tenant_id
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant = Tenant(
name=f"Test Tenant {fake.company()}",
plan="basic",
status="active",
)
tenant.id = account.current_tenant_id
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at

@@ -91,20 +90,21 @@ class TestWorkflowService:
App: Created test app instance
"""
fake = fake or Faker()
app = App()
app.id = fake.uuid4()
app.tenant_id = fake.uuid4()
app.name = fake.company()
app.description = fake.text()
app.mode = AppMode.WORKFLOW
app.icon_type = "emoji"
app.icon = "🤖"
app.icon_background = "#FFEAD5"
app.enable_site = True
app.enable_api = True
app.created_by = fake.uuid4()
app = App(
id=fake.uuid4(),
tenant_id=fake.uuid4(),
name=fake.company(),
description=fake.text(),
mode=AppMode.WORKFLOW,
icon_type="emoji",
icon="🤖",
icon_background="#FFEAD5",
enable_site=True,
enable_api=True,
created_by=fake.uuid4(),
workflow_id=None,  # Will be set when workflow is created
)
app.updated_by = app.created_by
app.workflow_id = None  # Will be set when workflow is created

from extensions.ext_database import db

@@ -126,19 +126,20 @@ class TestWorkflowService:
Workflow: Created test workflow instance
"""
fake = fake or Faker()
workflow = Workflow()
workflow.id = fake.uuid4()
workflow.tenant_id = app.tenant_id
workflow.app_id = app.id
workflow.type = WorkflowType.WORKFLOW.value
workflow.version = Workflow.VERSION_DRAFT
workflow.graph = json.dumps({"nodes": [], "edges": []})
workflow.features = json.dumps({"features": []})
# unique_hash is a computed property based on graph and features
workflow.created_by = account.id
workflow.updated_by = account.id
workflow.environment_variables = []
workflow.conversation_variables = []
workflow = Workflow(
id=fake.uuid4(),
tenant_id=app.tenant_id,
app_id=app.id,
type=WorkflowType.WORKFLOW.value,
version=Workflow.VERSION_DRAFT,
graph=json.dumps({"nodes": [], "edges": []}),
features=json.dumps({"features": []}),
# unique_hash is a computed property based on graph and features
created_by=account.id,
updated_by=account.id,
environment_variables=[],
conversation_variables=[],
)

from extensions.ext_database import db

@@ -48,11 +48,8 @@ class TestDeleteSegmentFromIndexTask:
Tenant: Created test tenant instance
"""
fake = fake or Faker()
tenant = Tenant()
tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active")
tenant.id = fake.uuid4()
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at

@@ -73,16 +70,14 @@ class TestDeleteSegmentFromIndexTask:
Account: Created test account instance
"""
fake = fake or Faker()
account = Account()
account = Account(
name=fake.name(),
email=fake.email(),
avatar=fake.url(),
status="active",
interface_language="en-US",
)
account.id = fake.uuid4()
account.email = fake.email()
account.name = fake.name()
account.avatar_url = fake.url()
account.tenant_id = tenant.id
account.status = "active"
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US"
account.created_at = fake.date_time_this_year()
account.updated_at = account.created_at

@@ -43,27 +43,30 @@ class TestDisableSegmentsFromIndexTask:
Account: Created test account instance
"""
fake = fake or Faker()
account = Account()
account = Account(
email=fake.email(),
name=fake.name(),
avatar=fake.url(),
status="active",
interface_language="en-US",
)
account.id = fake.uuid4()
account.email = fake.email()
account.name = fake.name()
account.avatar_url = fake.url()
# monkey-patch attributes for test setup
account.tenant_id = fake.uuid4()
account.status = "active"
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US"
account.created_at = fake.date_time_this_year()
account.updated_at = account.created_at

# Create a tenant for the account
from models.account import Tenant

tenant = Tenant()
tenant = Tenant(
name=f"Test Tenant {fake.company()}",
plan="basic",
status="active",
)
tenant.id = account.tenant_id
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at

@@ -91,20 +94,21 @@ class TestDisableSegmentsFromIndexTask:
Dataset: Created test dataset instance
"""
fake = fake or Faker()
dataset = Dataset()
dataset.id = fake.uuid4()
dataset.tenant_id = account.tenant_id
dataset.name = f"Test Dataset {fake.word()}"
dataset.description = fake.text(max_nb_chars=200)
dataset.provider = "vendor"
dataset.permission = "only_me"
dataset.data_source_type = "upload_file"
dataset.indexing_technique = "high_quality"
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.embedding_model = "text-embedding-ada-002"
dataset.embedding_model_provider = "openai"
dataset.built_in_field_enabled = False
dataset = Dataset(
id=fake.uuid4(),
tenant_id=account.tenant_id,
name=f"Test Dataset {fake.word()}",
description=fake.text(max_nb_chars=200),
provider="vendor",
permission="only_me",
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
updated_by=account.id,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
built_in_field_enabled=False,
)

from extensions.ext_database import db

@@ -128,6 +132,7 @@ class TestDisableSegmentsFromIndexTask:
"""
fake = fake or Faker()
document = DatasetDocument()

document.id = fake.uuid4()
document.tenant_id = dataset.tenant_id
document.dataset_id = dataset.id
@@ -153,7 +158,6 @@ class TestDisableSegmentsFromIndexTask:
document.archived = False
document.doc_form = "text_model"  # Use text_model form for testing
document.doc_language = "en"

from extensions.ext_database import db

db.session.add(document)

@@ -96,9 +96,9 @@ class TestMailInviteMemberTask:
password=fake.password(),
interface_language="en-US",
status=AccountStatus.ACTIVE.value,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
account.created_at = datetime.now(UTC)
account.updated_at = datetime.now(UTC)
db_session_with_containers.add(account)
db_session_with_containers.commit()
db_session_with_containers.refresh(account)
@@ -106,9 +106,9 @@ class TestMailInviteMemberTask:
# Create tenant
tenant = Tenant(
name=fake.company(),
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
tenant.created_at = datetime.now(UTC)
tenant.updated_at = datetime.now(UTC)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
db_session_with_containers.refresh(tenant)
@@ -118,8 +118,8 @@ class TestMailInviteMemberTask:
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
created_at=datetime.now(UTC),
)
tenant_join.created_at = datetime.now(UTC)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()

@@ -164,9 +164,10 @@ class TestMailInviteMemberTask:
password="",
interface_language="en-US",
status=AccountStatus.PENDING.value,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)

account.created_at = datetime.now(UTC)
account.updated_at = datetime.now(UTC)
db_session_with_containers.add(account)
db_session_with_containers.commit()
db_session_with_containers.refresh(account)
@@ -176,8 +177,8 @@ class TestMailInviteMemberTask:
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.NORMAL.value,
created_at=datetime.now(UTC),
)
tenant_join.created_at = datetime.now(UTC)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()

@@ -15,13 +15,13 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
# Set environment variables using monkeypatch
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")  # Custom value for testing
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "600")
monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "300")  # Custom value for testing

# load dotenv file with pydantic-settings
config = DifyConfig()
@@ -35,16 +35,36 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0
assert config.TEMPLATE_TRANSFORM_MAX_LENGTH == 400_000

# annotated field with default value
assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 600
# annotated field with custom configured value
assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 300

# annotated field with configured value
# annotated field with custom configured value
assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30

# values from pyproject.toml
assert Version(config.project.version) >= Version("1.0.0")


def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch):
"""Test that HTTP timeout defaults are correctly set"""
# clear system environment variables
os.environ.clear()

# Set minimal required env vars
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")

config = DifyConfig()

# Verify default timeout values
assert config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT == 10
assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 600
assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 600
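The new test pins the 600-second read/write defaults without touching the env; a quick sketch of how a deployment would still override them, assuming the usual pydantic-settings behavior of reading the environment when `DifyConfig()` is instantiated:

import os

# Exported before startup, these take precedence over the built-in defaults:
os.environ["HTTP_REQUEST_MAX_READ_TIMEOUT"] = "120"
os.environ["HTTP_REQUEST_MAX_WRITE_TIMEOUT"] = "60"
# config = DifyConfig()  # would now report 120 / 60 for the two timeouts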

# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
@@ -55,7 +75,6 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
# Set environment variables using monkeypatch
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")

@@ -105,7 +124,6 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):
# Set environment variables using monkeypatch
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
Some files were not shown because too many files have changed in this diff