Compare commits

...

43 Commits

Author SHA1 Message Date
autofix-ci[bot]
a77b0c8216 [autofix.ci] apply automated fixes 2025-10-10 14:17:25 +00:00
lyzno1
542e904614 Update README.md 2025-10-10 22:15:30 +08:00
lyzno1
cc1f90103b apply suggestions 2025-10-10 22:11:20 +08:00
autofix-ci[bot]
909cbee550 [autofix.ci] apply automated fixes 2025-10-10 14:06:10 +00:00
lyzno1
090eb13430 refactor: simplify boolean expressions for stream parameter
Simplify verbose ternary expressions to direct boolean comparisons:
- Changed `True if response_mode == "streaming" else False`
- To: `response_mode == "streaming"`

This is more concise and follows Python idioms. Applied to 6 locations:
- async_client.py: create_completion_message, create_chat_message, run_specific_workflow
- client.py: create_completion_message, create_chat_message, run_specific_workflow

All tests pass (30/30).
2025-10-10 22:04:21 +08:00
lyzno1
1cd35c5289 perf: optimize file uploads and improve code consistency
Addresses final Gemini code review feedback (round 4):

1. Optimized async file uploads for better performance:
   - Removed unnecessary `await f.read()` that loaded entire files into memory
   - Pass aiofiles file object directly to httpx for streaming uploads
   - Significantly reduces memory usage for large file uploads
   - httpx handles the async file streaming efficiently
   - Fixed in 3 async methods:
     * create_document_by_file
     * update_document_by_file
     * upload_pipeline_file

2. Fixed potential information leak in sync client:
   - Changed from `files = {"file": f}` to `files = {"file": (basename, f)}`
   - Prevents full local file path from being sent in multipart requests
   - Only sends the filename (basename) instead of complete path
   - Fixed in 3 sync methods:
     * create_document_by_file
     * update_document_by_file
     * upload_pipeline_file

3. Improved code consistency in query parameter construction:
   - Simplified parameter building by leveraging httpx's automatic None filtering
   - Changed from multiple `if` statements to direct dict construction
   - More readable and consistent with existing patterns (e.g., get_conversations)
   - Applied to 4 methods:
     * get_conversation_messages
     * get_workflow_logs
     * list_documents
     * query_segments

4. Enhanced documentation:
   - Added comprehensive "Asynchronous Usage" section to README
   - Included examples for async chat, completion, workflow, and dataset operations
   - Documented benefits of async usage (performance, scalability, non-blocking I/O)
   - Listed all available async client classes

All tests pass (30/30).
2025-10-10 22:02:30 +08:00
lyzno1
62fba95b1e fix: improve API design and cross-platform compatibility
Addresses additional Gemini code review feedback (round 3):

1. Improved annotation_reply_action API design:
   - Changed parameters from Optional (float | None, str | None) to required (float, str)
   - Removed runtime ValueError check that's now unnecessary
   - Let Python's TypeError handle missing arguments more idiomatically
   - Applies to both sync (client.py) and async (async_client.py) versions

2. Fixed cross-platform file path handling:
   - Added `import os` to both client.py and async_client.py
   - Replaced file_path.split('/')[-1] with os.path.basename(file_path)
   - Ensures compatibility with Windows path separators (\)
   - Fixed in 3 async methods:
     * create_document_by_file
     * update_document_by_file
     * upload_pipeline_file

All tests pass (30/30).
2025-10-10 21:54:56 +08:00
lyzno1
055f510619 fix: address Gemini code review issues (round 2)
This commit resolves all critical and medium priority issues identified
in the second Gemini code review:

1. Python version compatibility:
   - Updated requires-python to >=3.10 in pyproject.toml
   - Allows use of union type syntax (dict | None) introduced in Python 3.10
   - Eliminates need to use typing.Optional for cleaner code

2. params conflict in query_segments:
   - Changed kwargs["params"] to kwargs.pop("params")
   - Prevents TypeError from duplicate 'params' argument
   - Applies to both sync and async clients

3. Type annotation inconsistency in annotation_reply_action:
   - Changed parameter types from float/str to float | None / str | None
   - Matches runtime behavior that checks for None values
   - Applies to both sync and async clients

4. URL encoding issue in list_datasets:
   - Changed from f-string URL construction to params dict
   - Lets httpx handle proper URL encoding of query parameters
   - Prevents potential injection and handles special characters correctly

5. Async file I/O blocking issue:
   - Added aiofiles>=23.0.0 dependency
   - Replaced synchronous open() with aiofiles.open() in async methods:
     * create_document_by_file
     * update_document_by_file
     * upload_pipeline_file
   - Prevents blocking the asyncio event loop during file operations

All tests pass (30/30 runnable tests, 12 failures unrelated to changes
due to SOCKS proxy configuration).
2025-10-10 21:39:41 +08:00
lyzno1
54ac80b2c6 fix: address Gemini code review issues
- Add **kwargs support to _send_request methods (sync and async)
- Fix file handle leaks in create/update_document_by_file methods
- Update README examples to use context managers for proper resource cleanup

Fixes identified by Gemini code review:
1. Critical: **kwargs not handled causing TypeError in KnowledgeBaseClient
2. Critical: File descriptors leaked in document file upload methods
3. Medium: README examples not demonstrating best practices
2025-10-10 21:23:59 +08:00
autofix-ci[bot]
eb275c0fca [autofix.ci] apply automated fixes 2025-10-10 13:10:54 +00:00
lyzno1
db4cdd5d60 rm test 2025-10-10 21:08:45 +08:00
lyzno1
1e9317910c test: add comprehensive unit tests for httpx migration and async support
- Add test_httpx_migration.py with 15 tests covering:
  - httpx.Client initialization and configuration
  - Context manager support (with statement)
  - Request parameter handling (json, params)
  - Response compatibility with requests.Response API
  - All client classes (DifyClient, ChatClient, CompletionClient, etc.)
  - Inheritance chain verification

- Add test_async_client.py with 15 tests covering:
  - API parity between sync and async clients (6 client classes)
  - httpx.AsyncClient initialization and configuration
  - Async context manager support (async with statement)
  - Async request handling with await
  - All async client classes functionality

- Update pyproject.toml:
  - Add dev dependencies: pytest>=7.0.0, pytest-asyncio>=0.21.0
  - Add pytest configuration for async mode

Test results: 30/30 tests passing
- 15 sync httpx migration tests 
- 15 async client tests 

All tests use proper mocking to avoid network dependencies.
2025-10-10 21:04:29 +08:00
autofix-ci[bot]
f3011ce67a [autofix.ci] apply automated fixes 2025-10-10 12:57:29 +00:00
lyzno1
52dddf9cd6 feat: migrate Python SDK to httpx with async/await support
- Migrate from requests to httpx for better performance and modern features
- Add complete async/await support with AsyncDifyClient and all async client classes
- Maintain 100% backward compatibility with existing sync API
- Add context manager support for proper resource management
- Implement connection pooling for performance improvement

Breaking changes:
- Bump Python requirement from >=3.6 to >=3.8 (required by httpx)

New features:
- AsyncChatClient, AsyncCompletionClient, AsyncWorkflowClient, AsyncWorkspaceClient, AsyncKnowledgeBaseClient
- Full async/await support with httpx.AsyncClient
- Context manager support (with/async with statements)
- Fine-grained timeout control

All 64 public methods available in both sync and async versions.
2025-10-10 20:54:23 +08:00
lyzno1
5ce6776383 refactor: simplify update_dataset method with dict comprehension
Use more Pythonic dict comprehension to filter None values instead of
multiple if statements. This improves code readability and maintainability
while preserving the same functionality.

Changes:
- Replace multiple if-None checks with dict comprehension
- More concise implementation (from ~15 lines to ~8 lines)
- Easier to maintain when adding new parameters

Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
2025-10-10 19:36:06 +08:00
autofix-ci[bot]
dfcf337fe1 [autofix.ci] apply automated fixes 2025-10-10 11:28:43 +00:00
lyzno1
5dc70d5be0 ruff format 2025-10-10 19:26:26 +08:00
lyzno1
b351d24a3c ruff fix 2025-10-10 19:26:08 +08:00
autofix-ci[bot]
089c47cab1 [autofix.ci] apply automated fixes 2025-10-10 11:24:47 +00:00
lyzno1
6af6c7846b chore: migrate from setup.py to pyproject.toml
Modernize Python SDK package management following PEP 621 standards:

**Changes:**
- Add pyproject.toml with hatchling build backend
- Remove deprecated setup.py
- Update MANIFEST.in to include README and LICENSE
- Add uv.lock for reproducible builds

**Benefits:**
- Standard PEP 621 declarative configuration
- Better tooling support (uv, pip, poetry)
- Faster builds with uv
- Maintains backward compatibility (Python >=3.6)

All metadata preserved from original setup.py:
- Author: Dify <hello@dify.ai>
- Version: 0.1.12
- Dependencies: requests (no version constraint)
- Classifiers: unchanged
2025-10-10 19:21:56 +08:00
lyzno1
7a8e027a93 feat: add missing core Service APIs to Python SDK
Added 5 high-priority core APIs to improve SDK coverage from 87.7% to 96.5%:

**Dataset Management (KnowledgeBaseClient):**
- get_dataset(): Get detailed dataset information
- update_dataset(): Update dataset configuration (name, description, indexing, embeddings)
- batch_update_document_status(): Batch enable/disable/archive/unarchive documents

**Conversation Variables (ChatClient):**
- get_conversation_variables(): Get all variables for a conversation
- update_conversation_variable(): Update specific conversation variable value

**Testing:**
- Added 4 comprehensive test methods covering 12+ test scenarios
- All tests pass with 100% coverage of new APIs
- Follows existing test patterns with unittest.mock

**Documentation:**
- Updated README with usage examples for all new APIs
- Consistent formatting with existing documentation
2025-10-10 19:11:39 +08:00
Bowen Liang
d0dd81cf84 chore: bump ruff to 0.14 (#26063) 2025-10-10 18:10:23 +08:00
znn
65b832c46c pan and zoom during workflow execution (#24254) 2025-10-10 17:07:25 +08:00
znn
a90b60c36f removing horus eye and adding mcp icon (#25323)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2025-10-10 17:00:03 +08:00
诗浓
94a07706ec fix: restore None guards for _environment_variables/_conversation_variables getters (#25633) 2025-10-10 16:32:09 +08:00
Asuka Minato
ab2eacb6c1 use model_validate (#26182)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-10 17:30:13 +09:00
Xiyuan Chen
aead192743 Fix/token exp when exchange main (#26708) 2025-10-10 01:24:36 -07:00
Asuka Minato
c1e8584b97 feat: Refactor api.add_resource to @console_ns.route decorator (#26386)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-10 16:23:39 +08:00
Asuka Minato
8a2b208299 Refactor account models to use SQLAlchemy 2.0 dataclass mapping (#26415)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-10 17:12:12 +09:00
znn
2b6882bd97 fix chunks 2 (#26623)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-10 16:01:33 +08:00
Guangdong Liu
aa51662d98 refactor(api): add new endpoints for workspace management and update routing (#26465)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-10 15:59:14 +08:00
github-actions[bot]
3068526797 chore: translate i18n files and update type definitions (#26709)
Co-authored-by: iamjoel <2120155+iamjoel@users.noreply.github.com>
2025-10-10 15:55:24 +08:00
Jyong
298d8c2d88 Update deploy-dev.yml (#26712) 2025-10-10 15:54:33 +08:00
fenglin
294e01a8c1 Fix/tool provider tag internationalization (#26710)
Co-authored-by: qiaofenglin <qiaofenglin@baidu.com>
2025-10-10 15:52:09 +08:00
Coding On Star
3a5aa4587c feat(billing): add tax information tooltips in pricing footer (#26705)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-10-10 15:34:56 +08:00
crazywoola
cf1778e696 fix: issue w/ timepicker (#26696)
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
2025-10-10 13:17:33 +08:00
yihong
54db4c176a fix: drop useless logic (#26678)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-10-10 12:59:28 +08:00
Novice
5d3e8a31d0 fix: restore array flattening behavior in iteration node (#26695) 2025-10-10 10:54:32 +08:00
Nan LI
885dff82e3 feat: update HTTP timeout configurations and enhance timeout input handling in UI (#26685) 2025-10-10 09:00:06 +08:00
Asuka Minato
3c4aa24198 Refactor: Remove unnecessary casts and tighten type checking (#26625)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-09 22:11:14 +08:00
GuanMu
33b0814323 refactor(types): remove any usages and strengthen typings across web and base (#26677) 2025-10-09 21:36:42 +08:00
Tianyi Jing
45ae511036 fix: add missing toType to toolCredentialToFormSchemas (#26681)
Signed-off-by: jingfelix <jingfelix@outlook.com>
2025-10-09 21:23:15 +08:00
Asuka Minato
0fa063c640 Refactor: Remove reportUnnecessaryContains from pyrightconfig.json (#26626)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
2025-10-09 10:22:41 +08:00
166 changed files with 3541 additions and 1333 deletions

View File

@@ -18,7 +18,7 @@ jobs:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@@ -81,7 +81,6 @@ ignore = [
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
]
[lint.per-file-ignores]

View File

@@ -362,11 +362,11 @@ class HttpConfig(BaseSettings):
)
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
)
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
)
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
@@ -771,7 +771,7 @@ class MailConfig(BaseSettings):
MAIL_TEMPLATING_TIMEOUT: int = Field(
description="""
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
Only available in sandbox mode.""",
default=3,
)

View File

@@ -90,7 +90,7 @@ class ModelConfigResource(Resource):
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
@@ -124,7 +124,7 @@ class ModelConfigResource(Resource):
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get("tools") or []:
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"

View File

@@ -15,7 +15,7 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType,
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
@@ -257,13 +257,15 @@ class DataSourceNotionApi(Resource):
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)

View File

@@ -24,7 +24,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import related_app_list
@@ -513,13 +513,15 @@ class DatasetIndexingEstimateApi(Resource):
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
@@ -528,14 +530,16 @@ class DatasetIndexingEstimateApi(Resource):
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
},
website_info=WebsiteInfo.model_validate(
{
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)

View File

@@ -44,7 +44,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from extensions.ext_database import db
from fields.document_fields import (
dataset_and_document_fields,
@@ -305,7 +305,7 @@ class DatasetDocumentListApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.")
@@ -395,7 +395,7 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@@ -547,13 +547,15 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
@@ -562,14 +564,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
},
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)

View File

@@ -309,7 +309,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@@ -564,7 +564,7 @@ class ChildChunkAddApi(Resource):
args = parser.parse_args()
try:
chunks_data = args["chunks"]
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))

View File

@@ -28,7 +28,7 @@ class DatasetMetadataCreateApi(Resource):
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
metadata_args = MetadataArgs(**args)
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@@ -137,7 +137,7 @@ class DocumentMetadataEditApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
metadata_args = MetadataOperationData(**args)
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@@ -88,7 +88,7 @@ class CustomizedPipelineTemplateApi(Resource):
nullable=True,
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200

View File

@@ -6,7 +6,7 @@ from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.console import api
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
@@ -22,6 +22,7 @@ from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@@ -154,6 +155,7 @@ class InstalledAppsListApi(Resource):
return {"message": "App installed successfully"}
@console_ns.route("/installed-apps/<uuid:installed_app_id>")
class InstalledAppApi(InstalledAppResource):
"""
update and delete an installed app
@@ -185,7 +187,3 @@ class InstalledAppApi(InstalledAppResource):
db.session.commit()
return {"result": "success", "message": "App info updated successfully"}
api.add_resource(InstalledAppsListApi, "/installed-apps")
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")

View File

@@ -1,7 +1,7 @@
from flask_restx import marshal_with
from controllers.common import fields
from controllers.console import api
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
@@ -9,6 +9,7 @@ from models.model import AppMode, InstalledApp
from services.app_service import AppService
@console_ns.route("/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters")
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
@@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp):
"""Get app meta"""
@@ -46,9 +48,3 @@ class ExploreAppMetaApi(InstalledAppResource):
if not app_model:
raise ValueError("App not found")
return AppService().get_app_meta(app_model)
api.add_resource(
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
)
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

View File

@@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from libs.login import current_user, login_required
@@ -35,6 +35,7 @@ recommended_app_list_fields = {
}
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@login_required
@account_initialization_required
@@ -56,13 +57,10 @@ class RecommendedAppListApi(Resource):
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
@console_ns.route("/explore/apps/<uuid:app_id>")
class RecommendedAppApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
app_id = str(app_id)
return RecommendedAppService.get_recommend_app_detail(app_id)
api.add_resource(RecommendedAppListApi, "/explore/apps")
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")

View File

@@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
@@ -25,6 +25,7 @@ message_fields = {
}
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
@@ -66,6 +67,9 @@ class SavedMessageListApi(InstalledAppResource):
return {"result": "success"}
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", endpoint="installed_app_saved_message"
)
class SavedMessageApi(InstalledAppResource):
def delete(self, installed_app, message_id):
app_model = installed_app.app
@@ -80,15 +84,3 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204
api.add_resource(
SavedMessageListApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages",
endpoint="installed_app_saved_messages",
)
api.add_resource(
SavedMessageApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
endpoint="installed_app_saved_message",
)

View File

@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailChangeLimitError,
@@ -45,6 +45,7 @@ from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@setup_required
@login_required
@@ -97,6 +98,7 @@ class AccountInitApi(Resource):
return {"result": "success"}
@console_ns.route("/account/profile")
class AccountProfileApi(Resource):
@setup_required
@login_required
@@ -109,6 +111,7 @@ class AccountProfileApi(Resource):
return current_user
@console_ns.route("/account/name")
class AccountNameApi(Resource):
@setup_required
@login_required
@@ -130,6 +133,7 @@ class AccountNameApi(Resource):
return updated_account
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@setup_required
@login_required
@@ -147,6 +151,7 @@ class AccountAvatarApi(Resource):
return updated_account
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
@setup_required
@login_required
@@ -164,6 +169,7 @@ class AccountInterfaceLanguageApi(Resource):
return updated_account
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
@setup_required
@login_required
@@ -181,6 +187,7 @@ class AccountInterfaceThemeApi(Resource):
return updated_account
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
@setup_required
@login_required
@@ -202,6 +209,7 @@ class AccountTimezoneApi(Resource):
return updated_account
@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
@setup_required
@login_required
@@ -227,6 +235,7 @@ class AccountPasswordApi(Resource):
return {"result": "success"}
@console_ns.route("/account/integrates")
class AccountIntegrateApi(Resource):
integrate_fields = {
"provider": fields.String,
@@ -283,6 +292,7 @@ class AccountIntegrateApi(Resource):
return {"data": integrate_data}
@console_ns.route("/account/delete/verify")
class AccountDeleteVerifyApi(Resource):
@setup_required
@login_required
@@ -298,6 +308,7 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
@setup_required
@login_required
@@ -320,6 +331,7 @@ class AccountDeleteApi(Resource):
return {"result": "success"}
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
@setup_required
def post(self):
@@ -333,6 +345,7 @@ class AccountDeleteUpdateFeedbackApi(Resource):
return {"result": "success"}
@console_ns.route("/account/education/verify")
class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
@@ -352,6 +365,7 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email)
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
@@ -396,6 +410,7 @@ class EducationApi(Resource):
return res
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
@@ -419,6 +434,7 @@ class EducationAutoCompleteApi(Resource):
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
@enable_change_email
@setup_required
@@ -467,6 +483,7 @@ class ChangeEmailSendEmailApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
@enable_change_email
@setup_required
@@ -508,6 +525,7 @@ class ChangeEmailCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
@enable_change_email
@setup_required
@@ -547,6 +565,7 @@ class ChangeEmailResetApi(Resource):
return updated_account
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
@setup_required
def post(self):
@@ -558,28 +577,3 @@ class CheckEmailUnique(Resource):
if not AccountService.check_email_unique(args["email"]):
raise EmailAlreadyInUseError()
return {"result": "success"}
# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
api.add_resource(AccountNameApi, "/account/name")
api.add_resource(AccountAvatarApi, "/account/avatar")
api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
api.add_resource(AccountTimezoneApi, "/account/timezone")
api.add_resource(AccountPasswordApi, "/account/password")
api.add_resource(AccountIntegrateApi, "/account/integrates")
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
api.add_resource(AccountDeleteApi, "/account/delete")
api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
api.add_resource(EducationVerifyApi, "/account/education/verify")
api.add_resource(EducationApi, "/account/education")
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
# Change email
api.add_resource(ChangeEmailSendEmailApi, "/account/change-email")
api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity")
api.add_resource(ChangeEmailResetApi, "/account/change-email/reset")
api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

View File

@@ -1,7 +1,7 @@
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -10,6 +10,9 @@ from models.account import Account, TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
)
class LoadBalancingCredentialsValidateApi(Resource):
@setup_required
@login_required
@@ -61,6 +64,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
return response
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
)
class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required
@login_required
@@ -111,15 +117,3 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
response["error"] = error
return response
# Load Balancing Config
api.add_resource(
LoadBalancingCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
)
api.add_resource(
LoadBalancingConfigCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
)

View File

@@ -6,7 +6,7 @@ from flask_restx import Resource, marshal_with, reqparse
import services
from configs import dify_config
from controllers.console import api
from controllers.console import console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
@@ -33,6 +33,7 @@ from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
"""List all members of current tenant."""
@@ -49,6 +50,7 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource):
"""Invite a new member by email."""
@@ -111,6 +113,7 @@ class MemberInviteEmailApi(Resource):
}, 201
@console_ns.route("/workspaces/current/members/<uuid:member_id>")
class MemberCancelInviteApi(Resource):
"""Cancel an invitation by member id."""
@@ -143,6 +146,7 @@ class MemberCancelInviteApi(Resource):
}, 200
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource):
"""Update member role."""
@@ -177,6 +181,7 @@ class MemberUpdateRoleApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/dataset-operators")
class DatasetOperatorMemberListApi(Resource):
"""List all members of current tenant."""
@@ -193,6 +198,7 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email."""
@@ -233,6 +239,7 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource):
@setup_required
@login_required
@@ -278,6 +285,7 @@ class OwnerTransferCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource):
@setup_required
@login_required
@@ -339,14 +347,3 @@ class OwnerTransfer(Resource):
raise ValueError(str(e))
return {"result": "success"}
api.add_resource(MemberListApi, "/workspaces/current/members")
api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>")
api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role")
api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")
# owner transfer
api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email")
api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check")
api.add_resource(OwnerTransfer, "/workspaces/current/members/<uuid:member_id>/owner-transfer")

View File

@@ -5,7 +5,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -17,6 +17,7 @@ from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource):
@setup_required
@login_required
@@ -45,6 +46,7 @@ class ModelProviderListApi(Resource):
return jsonable_encoder({"data": provider_list})
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@@ -151,6 +153,7 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource):
@setup_required
@login_required
@@ -175,6 +178,7 @@ class ModelProviderCredentialSwitchApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@@ -211,6 +215,7 @@ class ModelProviderValidateApi(Resource):
return response
@console_ns.route("/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>")
class ModelProviderIconApi(Resource):
"""
Get model provider icon
@@ -229,6 +234,7 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource):
@setup_required
@login_required
@@ -262,6 +268,7 @@ class PreferredProviderTypeUpdateApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/checkout-url")
class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required
@login_required
@@ -281,21 +288,3 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
prefilled_email=current_user.email,
)
return data
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
api.add_resource(
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
)
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource(
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
)
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
api.add_resource(
ModelProviderIconApi,
"/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
)

View File

@@ -4,7 +4,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -17,6 +17,7 @@ from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
@setup_required
@login_required
@@ -85,6 +86,7 @@ class DefaultModelApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource):
@setup_required
@login_required
@@ -187,6 +189,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@@ -364,6 +367,7 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource):
@setup_required
@login_required
@@ -395,6 +399,9 @@ class ModelProviderModelCredentialSwitchApi(Resource):
return {"result": "success"}
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource):
@setup_required
@login_required
@@ -422,6 +429,9 @@ class ModelProviderModelEnableApi(Resource):
return {"result": "success"}
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource):
@setup_required
@login_required
@@ -449,6 +459,7 @@ class ModelProviderModelDisableApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource):
@setup_required
@login_required
@@ -494,6 +505,7 @@ class ModelProviderModelValidateApi(Resource):
return response
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@@ -513,6 +525,7 @@ class ModelProviderModelParameterRuleApi(Resource):
return jsonable_encoder({"data": parameter_rules})
@console_ns.route("/workspaces/current/models/model-types/<string:model_type>")
class ModelProviderAvailableModelApi(Resource):
@setup_required
@login_required
@@ -524,32 +537,3 @@ class ModelProviderAvailableModelApi(Resource):
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
return jsonable_encoder({"data": models})
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
api.add_resource(
ModelProviderModelEnableApi,
"/workspaces/current/model-providers/<path:provider>/models/enable",
endpoint="model-provider-model-enable",
)
api.add_resource(
ModelProviderModelDisableApi,
"/workspaces/current/model-providers/<path:provider>/models/disable",
endpoint="model-provider-model-disable",
)
api.add_resource(
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
)
api.add_resource(
ModelProviderModelCredentialSwitchApi,
"/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
)
api.add_resource(
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
)
api.add_resource(
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
)
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

View File

@@ -6,7 +6,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -19,6 +19,7 @@ from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@@ -37,6 +38,7 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@setup_required
@login_required
@@ -55,6 +57,7 @@ class PluginListApi(Resource):
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource):
@setup_required
@login_required
@@ -72,6 +75,7 @@ class PluginListLatestVersionsApi(Resource):
return jsonable_encoder({"versions": versions})
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource):
@setup_required
@login_required
@@ -91,6 +95,7 @@ class PluginListInstallationsFromIdsApi(Resource):
return jsonable_encoder({"plugins": plugins})
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource):
@setup_required
def get(self):
@@ -108,6 +113,7 @@ class PluginIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@console_ns.route("/workspaces/current/plugin/upload/pkg")
class PluginUploadFromPkgApi(Resource):
@setup_required
@login_required
@@ -131,6 +137,7 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource):
@setup_required
@login_required
@@ -153,6 +160,7 @@ class PluginUploadFromGithubApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/upload/bundle")
class PluginUploadFromBundleApi(Resource):
@setup_required
@login_required
@@ -176,6 +184,7 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource):
@setup_required
@login_required
@@ -201,6 +210,7 @@ class PluginInstallFromPkgApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource):
@setup_required
@login_required
@@ -230,6 +240,7 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource):
@setup_required
@login_required
@@ -255,6 +266,7 @@ class PluginInstallFromMarketplaceApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource):
@setup_required
@login_required
@@ -280,6 +292,7 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource):
@setup_required
@login_required
@@ -304,6 +317,7 @@ class PluginFetchManifestApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource):
@setup_required
@login_required
@@ -325,6 +339,7 @@ class PluginFetchInstallTasksApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>")
class PluginFetchInstallTaskApi(Resource):
@setup_required
@login_required
@@ -339,6 +354,7 @@ class PluginFetchInstallTaskApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete")
class PluginDeleteInstallTaskApi(Resource):
@setup_required
@login_required
@@ -353,6 +369,7 @@ class PluginDeleteInstallTaskApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/delete_all")
class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required
@login_required
@@ -367,6 +384,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
class PluginDeleteInstallTaskItemApi(Resource):
@setup_required
@login_required
@@ -381,6 +399,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required
@login_required
@@ -404,6 +423,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource):
@setup_required
@login_required
@@ -435,6 +455,7 @@ class PluginUpgradeFromGithubApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource):
@setup_required
@login_required
@@ -453,6 +474,7 @@ class PluginUninstallApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource):
@setup_required
@login_required
@@ -475,6 +497,7 @@ class PluginChangePermissionApi(Resource):
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@console_ns.route("/workspaces/current/plugin/permission/fetch")
class PluginFetchPermissionApi(Resource):
@setup_required
@login_required
@@ -499,6 +522,7 @@ class PluginFetchPermissionApi(Resource):
)
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource):
@setup_required
@login_required
@@ -535,6 +559,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@setup_required
@login_required
@@ -590,6 +615,7 @@ class PluginChangePreferencesApi(Resource):
return jsonable_encoder({"success": True})
@console_ns.route("/workspaces/current/plugin/preferences/fetch")
class PluginFetchPreferencesApi(Resource):
@setup_required
@login_required
@@ -628,6 +654,7 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource):
@setup_required
@login_required
@@ -641,35 +668,3 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
args = req.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch")
api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change")
api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude")

View File

@@ -10,7 +10,7 @@ from flask_restx import (
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
@@ -47,6 +47,7 @@ def is_valid_url(url: str) -> bool:
return False
@console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource):
@setup_required
@login_required
@@ -71,6 +72,7 @@ class ToolProviderListApi(Resource):
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@login_required
@@ -88,6 +90,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/info")
class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@@ -100,6 +103,7 @@ class ToolBuiltinProviderInfoApi(Resource):
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
class ToolBuiltinProviderDeleteApi(Resource):
@setup_required
@login_required
@@ -121,6 +125,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@@ -150,6 +155,7 @@ class ToolBuiltinProviderAddApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
class ToolBuiltinProviderUpdateApi(Resource):
@setup_required
@login_required
@@ -181,6 +187,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
return result
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credentials")
class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@@ -196,6 +203,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
@@ -204,6 +212,7 @@ class ToolBuiltinProviderIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@console_ns.route("/workspaces/current/tool-provider/api/add")
class ToolApiProviderAddApi(Resource):
@setup_required
@login_required
@@ -243,6 +252,7 @@ class ToolApiProviderAddApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/remote")
class ToolApiProviderGetRemoteSchemaApi(Resource):
@setup_required
@login_required
@@ -266,6 +276,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/tools")
class ToolApiProviderListToolsApi(Resource):
@setup_required
@login_required
@@ -291,6 +302,7 @@ class ToolApiProviderListToolsApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/update")
class ToolApiProviderUpdateApi(Resource):
@setup_required
@login_required
@@ -332,6 +344,7 @@ class ToolApiProviderUpdateApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/delete")
class ToolApiProviderDeleteApi(Resource):
@setup_required
@login_required
@@ -358,6 +371,7 @@ class ToolApiProviderDeleteApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/get")
class ToolApiProviderGetApi(Resource):
@setup_required
@login_required
@@ -381,6 +395,7 @@ class ToolApiProviderGetApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>")
class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required
@login_required
@@ -396,6 +411,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/schema")
class ToolApiProviderSchemaApi(Resource):
@setup_required
@login_required
@@ -412,6 +428,7 @@ class ToolApiProviderSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
class ToolApiProviderPreviousTestApi(Resource):
@setup_required
@login_required
@@ -439,6 +456,7 @@ class ToolApiProviderPreviousTestApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
class ToolWorkflowProviderCreateApi(Resource):
@setup_required
@login_required
@@ -478,6 +496,7 @@ class ToolWorkflowProviderCreateApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
class ToolWorkflowProviderUpdateApi(Resource):
@setup_required
@login_required
@@ -520,6 +539,7 @@ class ToolWorkflowProviderUpdateApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
class ToolWorkflowProviderDeleteApi(Resource):
@setup_required
@login_required
@@ -545,6 +565,7 @@ class ToolWorkflowProviderDeleteApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
class ToolWorkflowProviderGetApi(Resource):
@setup_required
@login_required
@@ -579,6 +600,7 @@ class ToolWorkflowProviderGetApi(Resource):
return jsonable_encoder(tool)
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
class ToolWorkflowProviderListToolApi(Resource):
@setup_required
@login_required
@@ -603,6 +625,7 @@ class ToolWorkflowProviderListToolApi(Resource):
)
@console_ns.route("/workspaces/current/tools/builtin")
class ToolBuiltinListApi(Resource):
@setup_required
@login_required
@@ -624,6 +647,7 @@ class ToolBuiltinListApi(Resource):
)
@console_ns.route("/workspaces/current/tools/api")
class ToolApiListApi(Resource):
@setup_required
@login_required
@@ -642,6 +666,7 @@ class ToolApiListApi(Resource):
)
@console_ns.route("/workspaces/current/tools/workflow")
class ToolWorkflowListApi(Resource):
@setup_required
@login_required
@@ -663,6 +688,7 @@ class ToolWorkflowListApi(Resource):
)
@console_ns.route("/workspaces/current/tool-labels")
class ToolLabelsApi(Resource):
@setup_required
@login_required
@@ -672,6 +698,7 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
@console_ns.route("/oauth/plugin/<path:provider>/tool/authorization-url")
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@@ -716,6 +743,7 @@ class ToolPluginOAuthApi(Resource):
return response
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider):
@@ -766,6 +794,7 @@ class ToolOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required
@login_required
@@ -779,6 +808,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@@ -822,6 +852,7 @@ class ToolOAuthCustomClient(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema")
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@@ -834,6 +865,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/info")
class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@@ -849,6 +881,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/mcp")
class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@@ -933,6 +966,7 @@ class ToolProviderMCPApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
class ToolMCPAuthApi(Resource):
@setup_required
@login_required
@@ -978,6 +1012,7 @@ class ToolMCPAuthApi(Resource):
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
class ToolMCPDetailApi(Resource):
@setup_required
@login_required
@@ -988,6 +1023,7 @@ class ToolMCPDetailApi(Resource):
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@console_ns.route("/workspaces/current/tools/mcp")
class ToolMCPListAllApi(Resource):
@setup_required
@login_required
@@ -1001,6 +1037,7 @@ class ToolMCPListAllApi(Resource):
return [tool.to_dict() for tool in tools]
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@@ -1014,6 +1051,7 @@ class ToolMCPUpdateApi(Resource):
return jsonable_encoder(tools)
@console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource):
def get(self):
parser = reqparse.RequestParser()
@@ -1024,67 +1062,3 @@ class ToolMCPCallbackApi(Resource):
authorization_code = args["code"]
handle_callback(state_key, authorization_code)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
)
api.add_resource(
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
)
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
)
api.add_resource(
ToolBuiltinProviderGetOauthClientSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
# api tool provider
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote")
api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools")
api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update")
api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete")
api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get")
api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema")
api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre")
# workflow tool provider
api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create")
api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update")
api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete")
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
# mcp tool provider
api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")

View File

@@ -14,7 +14,7 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console import api
from controllers.console import console_ns
from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
@@ -65,6 +65,7 @@ tenants_fields = {
workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField}
@console_ns.route("/workspaces")
class TenantListApi(Resource):
@setup_required
@login_required
@@ -93,6 +94,7 @@ class TenantListApi(Resource):
return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200
@console_ns.route("/all-workspaces")
class WorkspaceListApi(Resource):
@setup_required
@admin_required
@@ -118,6 +120,8 @@ class WorkspaceListApi(Resource):
}, 200
@console_ns.route("/workspaces/current", endpoint="workspaces_current")
@console_ns.route("/info", endpoint="info") # Deprecated
class TenantApi(Resource):
@setup_required
@login_required
@@ -143,11 +147,10 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")
if not tenant:
raise ValueError("No tenant available")
return WorkspaceService.get_tenant_info(tenant), 200
@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource):
@setup_required
@login_required
@@ -172,6 +175,7 @@ class SwitchWorkspaceApi(Resource):
return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
@console_ns.route("/workspaces/custom-config")
class CustomConfigWorkspaceApi(Resource):
@setup_required
@login_required
@@ -202,6 +206,7 @@ class CustomConfigWorkspaceApi(Resource):
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
@console_ns.route("/workspaces/custom-config/webapp-logo/upload")
class WebappLogoWorkspaceApi(Resource):
@setup_required
@login_required
@@ -242,6 +247,7 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201
@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource):
@setup_required
@login_required
@@ -261,13 +267,3 @@ class WorkspaceInfoApi(Resource):
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants
api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info
api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
api.add_resource(WorkspaceInfoApi, "/workspaces/info") # POST for changing workspace info

View File

@@ -128,7 +128,7 @@ def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseMo
raise ValueError("invalid json")
try:
payload = payload_type(**data)
payload = payload_type.model_validate(data)
except Exception as e:
raise ValueError(f"invalid payload: {str(e)}")

View File

@@ -280,7 +280,7 @@ class DatasetListApi(DatasetApiResource):
external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel(**args["retrieval_model"])
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
)

View File

@@ -136,7 +136,7 @@ class DocumentAddByTextApi(DatasetApiResource):
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
# validate args
DocumentService.document_create_args_validate(knowledge_config)
@@ -221,7 +221,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
@@ -328,7 +328,7 @@ class DocumentAddByFileApi(DatasetApiResource):
}
args["data_source"] = data_source
# validate args
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
@@ -426,7 +426,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:

View File

@@ -51,7 +51,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset."""
args = metadata_create_parser.parse_args()
metadata_args = MetadataArgs(**args)
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@@ -200,7 +200,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
args = document_metadata_parser.parse_args()
metadata_args = MetadataOperationData(**args)
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@@ -98,7 +98,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
parser.add_argument("is_published", type=bool, required=True, location="json")
args: ParseResult = parser.parse_args()
datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)

View File

@@ -252,7 +252,7 @@ class DatasetSegmentApi(DatasetApiResource):
args = segment_update_parser.parse_args()
updated_segment = SegmentService.update_segment(
SegmentUpdateArgs(**args["segment"]), segment, document, dataset
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200

View File

@@ -126,6 +126,8 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user_id = enterprise_user_decoded.get("end_user_id")
session_id = enterprise_user_decoded.get("session_id")
user_auth_type = enterprise_user_decoded.get("auth_type")
exchanged_token_expires_unix = enterprise_user_decoded.get("exp")
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")
@@ -169,8 +171,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
)
db.session.add(end_user)
db.session.commit()
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
exp = int(exp_dt.timestamp())
exp = int((datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)).timestamp())
if exchanged_token_expires_unix:
exp = int(exchanged_token_expires_unix)
payload = {
"iss": site.id,
"sub": "Web API Passport",

View File

@@ -40,7 +40,7 @@ class AgentConfigManager:
"credential_id": tool.get("credential_id", None),
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
"react_router",

View File

@@ -61,9 +61,6 @@ class AppRunner:
if model_context_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_context_tokens:

View File

@@ -116,7 +116,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
rag_pipeline_variable = RAGPipelineVariable.model_validate(v)
if (
rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared")

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
class I18nObject(BaseModel):
@@ -11,11 +11,12 @@ class I18nObject(BaseModel):
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@@ -5,7 +5,7 @@ from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from sqlalchemy import func, select
from sqlalchemy.orm import Session
@@ -73,9 +73,8 @@ class ProviderConfiguration(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in self.provider.configurate_methods:
@@ -90,6 +89,7 @@ class ProviderConfiguration(BaseModel):
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
return self
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
"""

View File

@@ -131,7 +131,7 @@ class CodeExecutor:
if (code := response_data.get("code")) != 0:
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")
response_code = CodeExecutionResponse(**response_data)
response_code = CodeExecutionResponse.model_validate(response_data)
if response_code.data.error:
raise CodeExecutionError(response_code.data.error)

View File

@@ -26,7 +26,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]
def batch_fetch_plugin_manifests_ignore_deserialization_error(
@@ -41,7 +41,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
result.append(MarketplacePluginDeclaration.model_validate(plugin))
except Exception:
pass

View File

@@ -20,7 +20,7 @@ from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -357,14 +357,16 @@ class IndexingRunner:
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
@@ -378,14 +380,16 @@ class IndexingRunner:
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
},
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])

View File

@@ -294,7 +294,7 @@ class ClientSession(
method="completion/complete",
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
argument=types.CompletionArgument.model_validate(argument),
),
)
),

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
class I18nObject(BaseModel):
@@ -9,7 +9,8 @@ class I18nObject(BaseModel):
zh_Hans: str | None = None
en_US: str
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if not self.zh_Hans:
self.zh_Hans = self.en_US
return self

View File

@@ -1,7 +1,7 @@
from collections.abc import Sequence
from enum import Enum, StrEnum, auto
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@@ -46,10 +46,11 @@ class FormOption(BaseModel):
value: str
show_on: list[FormShowOnObject] = []
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if not self.label:
self.label = I18nObject(en_US=self.value)
return self
class CredentialFormSchema(BaseModel):

View File

@@ -269,17 +269,17 @@ class ModelProviderFactory:
}
if model_type == ModelType.LLM:
return LargeLanguageModel(**init_params) # type: ignore
return LargeLanguageModel.model_validate(init_params)
elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel(**init_params) # type: ignore
return TextEmbeddingModel.model_validate(init_params)
elif model_type == ModelType.RERANK:
return RerankModel(**init_params) # type: ignore
return RerankModel.model_validate(init_params)
elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel(**init_params) # type: ignore
return Speech2TextModel.model_validate(init_params)
elif model_type == ModelType.MODERATION:
return ModerationModel(**init_params) # type: ignore
return ModerationModel.model_validate(init_params)
elif model_type == ModelType.TTS:
return TTSModel(**init_params) # type: ignore
return TTSModel.model_validate(init_params)
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""

View File

@@ -51,7 +51,7 @@ class ApiModeration(Moderation):
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
return ModerationInputsResult(**result)
return ModerationInputsResult.model_validate(result)
return ModerationInputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
@@ -67,7 +67,7 @@ class ApiModeration(Moderation):
params = ModerationOutputParams(app_id=self.app_id, text=text)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
return ModerationOutputsResult(**result)
return ModerationOutputsResult.model_validate(result)
return ModerationOutputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response

View File

@@ -84,15 +84,15 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
for i in range(len(v)):
if v[i]["role"] == PromptMessageRole.USER.value:
v[i] = UserPromptMessage(**v[i])
v[i] = UserPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
v[i] = AssistantPromptMessage(**v[i])
v[i] = AssistantPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
v[i] = SystemPromptMessage(**v[i])
v[i] = SystemPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.TOOL.value:
v[i] = ToolPromptMessage(**v[i])
v[i] = ToolPromptMessage.model_validate(v[i])
else:
v[i] = PromptMessage(**v[i])
v[i] = PromptMessage.model_validate(v[i])
return v

View File

@@ -94,7 +94,7 @@ class BasePluginClient:
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@@ -104,13 +104,13 @@ class BasePluginClient:
Make a stream request to the plugin daemon inner API and yield the response as a model.
"""
for line in self._stream_request(method, path, params, headers, data, files):
yield type(**json.loads(line)) # type: ignore
yield type_(**json.loads(line)) # type: ignore
def _request_with_model(
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | None = None,
params: dict | None = None,
@@ -120,13 +120,13 @@ class BasePluginClient:
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
return type(**response.json()) # type: ignore
return type_(**response.json()) # type: ignore
def _request_with_plugin_daemon_response(
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@@ -140,22 +140,22 @@ class BasePluginClient:
response = self._request(method, path, headers, data, params, files)
response.raise_for_status()
except HTTPError as e:
msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}"
logger.exception(msg)
logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
raise e
except Exception as e:
msg = f"Failed to request plugin daemon, url: {path}"
logger.exception(msg)
logger.exception("Failed to request plugin daemon, url: %s", path)
raise ValueError(msg) from e
try:
json_response = response.json()
if transformer:
json_response = transformer(json_response)
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
# https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why
rep = PluginDaemonBasicResponse[type_].model_validate(json_response) # type: ignore
except Exception:
msg = (
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}],"
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}],"
f" url: {path}"
)
logger.exception(msg)
@@ -163,7 +163,7 @@ class BasePluginClient:
if rep.code != 0:
try:
error = PluginDaemonError(**json.loads(rep.message))
error = PluginDaemonError.model_validate(json.loads(rep.message))
except Exception:
raise ValueError(f"{rep.message}, code: {rep.code}")
@@ -178,7 +178,7 @@ class BasePluginClient:
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@@ -189,7 +189,7 @@ class BasePluginClient:
"""
for line in self._stream_request(method, path, params, headers, data, files):
try:
rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore
rep = PluginDaemonBasicResponse[type_].model_validate_json(line) # type: ignore
except (ValueError, TypeError):
# TODO modify this when line_data has code and message
try:
@@ -204,7 +204,7 @@ class BasePluginClient:
if rep.code != 0:
if rep.code == -500:
try:
error = PluginDaemonError(**json.loads(rep.message))
error = PluginDaemonError.model_validate(json.loads(rep.message))
except Exception:
raise PluginDaemonInnerError(code=rep.code, message=rep.message)

View File

@@ -46,7 +46,9 @@ class PluginDatasourceManager(BasePluginClient):
params={"page": 1, "page_size": 256},
transformer=transformer,
)
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate(
self._get_local_file_datasource_provider()
)
for provider in response:
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
@@ -104,7 +106,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource provider for the given tenant and plugin.
"""
if provider_id == "langgenius/file/file":
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider())
tool_provider_id = DatasourceProviderID(provider_id)

View File

@@ -162,7 +162,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
type=LLMResultChunk,
type_=LLMResultChunk,
data=jsonable_encoder(
{
"user_id": user_id,
@@ -208,7 +208,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
type=PluginLLMNumTokensResponse,
type_=PluginLLMNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@@ -250,7 +250,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type=TextEmbeddingResult,
type_=TextEmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
@@ -291,7 +291,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
type=PluginTextEmbeddingNumTokensResponse,
type_=PluginTextEmbeddingNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@@ -334,7 +334,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
type=RerankResult,
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
@@ -378,7 +378,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
type=PluginStringResultResponse,
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@@ -422,7 +422,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
type=PluginVoicesResponse,
type_=PluginVoicesResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@@ -466,7 +466,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
type=PluginStringResultResponse,
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@@ -506,7 +506,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
type=PluginBasicBooleanResponse,
type_=PluginBasicBooleanResponse,
data=jsonable_encoder(
{
"user_id": user_id,

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import TypeVar, Union, cast
from typing import TypeVar, Union
from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
@@ -87,7 +87,8 @@ def merge_blob_chunks(
),
meta=resp.meta,
)
yield cast(MessageType, merged_message)
assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage))
yield merged_message # type: ignore
# Clean up the buffer
del files[chunk_id]
else:

View File

@@ -134,7 +134,7 @@ class RetrievalService:
if not dataset:
return []
metadata_condition = (
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
)
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id,

View File

@@ -17,9 +17,6 @@ class NotionInfo(BaseModel):
tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
super().__init__(**data)
class WebsiteInfo(BaseModel):
"""
@@ -47,6 +44,3 @@ class ExtractSetting(BaseModel):
website_info: WebsiteInfo | None = None
document_model: str | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
super().__init__(**data)

View File

@@ -38,11 +38,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
raise ValueError("No process rule found.")
if process_rule.get("mode") == "automatic":
automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
rules = Rule(**automatic_rule)
rules = Rule.model_validate(automatic_rule)
else:
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
rules = Rule.model_validate(process_rule.get("rules"))
# Split the text documents into nodes.
if not rules.segmentation:
raise ValueError("No segmentation found in rules.")

View File

@@ -40,7 +40,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
raise ValueError("No process rule found.")
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
rules = Rule.model_validate(process_rule.get("rules"))
all_documents: list[Document] = []
if rules.parent_mode == ParentMode.PARAGRAPH:
# Split the text documents into nodes.
@@ -110,7 +110,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_documents = document.children
if child_documents:
formatted_child_documents = [
Document(**child_document.model_dump()) for child_document in child_documents
Document.model_validate(child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
@@ -224,7 +224,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return child_nodes
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
parent_childs = ParentChildStructureChunk(**chunks)
parent_childs = ParentChildStructureChunk.model_validate(chunks)
documents = []
for parent_child in parent_childs.parent_child_chunks:
metadata = {
@@ -274,7 +274,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
vector.create(all_child_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk(**chunks)
parent_childs = ParentChildStructureChunk.model_validate(chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})

View File

@@ -47,7 +47,7 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError("No process rule found.")
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
rules = Rule.model_validate(process_rule.get("rules"))
splitter = self._get_splitter(
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
@@ -168,7 +168,7 @@ class QAIndexProcessor(BaseIndexProcessor):
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
qa_chunks = QAStructureChunk(**chunks)
qa_chunks = QAStructureChunk.model_validate(chunks)
documents = []
for qa_chunk in qa_chunks.qa_chunks:
metadata = {
@@ -191,7 +191,7 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError("Indexing technique must be high quality.")
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
qa_chunks = QAStructureChunk(**chunks)
qa_chunks = QAStructureChunk.model_validate(chunks)
preview = []
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import re
from typing import Any
from core.model_manager import ModelInstance
@@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._fixed_separator = fixed_separator
self._separators = separators or ["\n\n", "\n", " ", ""]
self._separators = separators or ["\n\n", "\n", "", ". ", " ", ""]
def split_text(self, text: str) -> list[str]:
"""Split incoming text and return chunks."""
@@ -90,16 +91,19 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
# Now that we have the separator, split the text
if separator:
if separator == " ":
splits = text.split()
splits = re.split(r" +", text)
else:
splits = text.split(separator)
splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
else:
splits = list(text)
splits = [s for s in splits if (s not in {"", "\n"})]
if separator == "\n":
splits = [s for s in splits if s != ""]
else:
splits = [s for s in splits if (s not in {"", "\n"})]
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
_separator = "" if self._keep_separator else separator
_separator = separator if self._keep_separator else ""
s_lens = self._length_function(splits)
if separator != "":
for s, s_len in zip(splits, s_lens):

View File

@@ -90,7 +90,7 @@ class BuiltinToolProviderController(ToolProviderController):
tools.append(
assistant_tool_class(
provider=provider,
entity=ToolEntity(**tool),
entity=ToolEntity.model_validate(tool),
runtime=ToolRuntime(tenant_id=""),
)
)

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
class I18nObject(BaseModel):
@@ -11,11 +11,12 @@ class I18nObject(BaseModel):
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _populate_missing_locales(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self):
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@@ -54,7 +54,7 @@ class MCPToolProviderController(ToolProviderController):
"""
tools = []
tools_data = json.loads(db_provider.tools)
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
user = db_provider.load_user()
tools = [
ToolEntity(

View File

@@ -1008,7 +1008,7 @@ class ToolManager:
config = tool_configurations.get(parameter.name, {})
if not (config and isinstance(config, dict) and config.get("value") is not None):
continue
tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:

View File

@@ -126,7 +126,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=document_score_list.get(segment.index_node_id, None),
score=document_score_list.get(segment.index_node_id),
doc_metadata=document.doc_metadata,
)

View File

@@ -105,10 +105,10 @@ class RedisChannel:
command_type = CommandType(command_type_value)
if command_type == CommandType.ABORT:
return AbortCommand(**data)
return AbortCommand.model_validate(data)
else:
# For other command types, use base class
return GraphEngineCommand(**data)
return GraphEngineCommand.model_validate(data)
except (ValueError, TypeError):
return None

View File

@@ -16,7 +16,7 @@ class EndNode(Node):
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData(**data)
self._node_data = EndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@@ -342,10 +342,13 @@ class IterationNode(Node):
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists
flattened_outputs = self._flatten_outputs_if_needed(outputs)
yield IterationSucceededEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
@@ -357,13 +360,39 @@ class IterationNode(Node):
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
outputs={"output": flattened_outputs},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
},
)
)
def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]:
"""
Flatten the outputs list if all elements are lists.
This maintains backward compatibility with version 1.8.1 behavior.
"""
if not outputs:
return outputs
# Check if all non-None outputs are lists
non_none_outputs = [output for output in outputs if output is not None]
if not non_none_outputs:
return outputs
if all(isinstance(output, list) for output in non_none_outputs):
# Flatten the list of lists
flattened: list[Any] = []
for output in outputs:
if isinstance(output, list):
flattened.extend(output)
elif output is not None:
# This shouldn't happen based on our check, but handle it gracefully
flattened.append(output)
return flattened
return outputs
def _handle_iteration_failure(
self,
started_at: datetime,
@@ -373,10 +402,13 @@ class IterationNode(Node):
iter_run_map: dict[str, float],
error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists (even in failure case)
flattened_outputs = self._flatten_outputs_if_needed(outputs)
yield IterationFailedEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,

View File

@@ -18,7 +18,7 @@ class IterationStartNode(Node):
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData(**data)
self._node_data = IterationStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@@ -2,7 +2,7 @@ import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any, cast
from typing import Any
from sqlalchemy import func, select
@@ -62,7 +62,7 @@ class KnowledgeIndexNode(Node):
return self._node_data
def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeIndexNodeData, self._node_data)
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id:

View File

@@ -41,7 +41,7 @@ class ListOperatorNode(Node):
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData(**data)
self._node_data = ListOperatorNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@@ -18,7 +18,7 @@ class LoopEndNode(Node):
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData(**data)
self._node_data = LoopEndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@@ -18,7 +18,7 @@ class LoopStartNode(Node):
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData(**data)
self._node_data = LoopStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@@ -16,7 +16,7 @@ class StartNode(Node):
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData(**data)
self._node_data = StartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@@ -15,7 +15,7 @@ class VariableAggregatorNode(Node):
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData(**data)
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@@ -14,7 +14,7 @@ def handle(sender, **kwargs):
for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
try:
tool_entity = ToolEntity(**node_data["data"])
tool_entity = ToolEntity.model_validate(node_data["data"])
tool_runtime = ToolManager.get_tool_runtime(
provider_type=tool_entity.provider_type,
provider_id=tool_entity.provider_id,

View File

@@ -61,7 +61,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]:
for node in knowledge_retrieval_nodes:
try:
node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
node_data = KnowledgeRetrievalNodeData.model_validate(node.get("data", {}))
dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids)
except Exception:
continue

View File

@@ -1,15 +1,16 @@
import enum
import json
from dataclasses import field
from datetime import datetime
from typing import Any, Optional
import sqlalchemy as sa
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated
from models.base import Base
from models.base import TypeBase
from .engine import db
from .types import StringUUID
@@ -83,31 +84,37 @@ class AccountStatus(enum.StrEnum):
CLOSED = "closed"
class Account(UserMixin, Base):
class Account(UserMixin, TypeBase):
__tablename__ = "accounts"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
name: Mapped[str] = mapped_column(String(255))
email: Mapped[str] = mapped_column(String(255))
password: Mapped[str | None] = mapped_column(String(255))
password_salt: Mapped[str | None] = mapped_column(String(255))
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True)
interface_language: Mapped[str | None] = mapped_column(String(255))
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True)
timezone: Mapped[str | None] = mapped_column(String(255))
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True)
last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
password: Mapped[str | None] = mapped_column(String(255), default=None)
password_salt: Mapped[str | None] = mapped_column(String(255), default=None)
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
interface_language: Mapped[str | None] = mapped_column(String(255), default=None)
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
timezone: Mapped[str | None] = mapped_column(String(255), default=None)
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
last_active_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
status: Mapped[str] = mapped_column(
String(16), server_default=sa.text("'active'::character varying"), default="active"
)
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
@reconstructor
def init_on_load(self):
self.role: TenantAccountRole | None = None
self._current_tenant: Tenant | None = None
role: TenantAccountRole | None = field(default=None, init=False)
_current_tenant: "Tenant | None" = field(default=None, init=False)
@property
def is_password_set(self):
@@ -226,18 +233,24 @@ class TenantStatus(enum.StrEnum):
ARCHIVE = "archive"
class Tenant(Base):
class Tenant(TypeBase):
__tablename__ = "tenants"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
custom_config: Mapped[str | None] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text, default=None)
plan: Mapped[str] = mapped_column(
String(255), server_default=sa.text("'basic'::character varying"), default="basic"
)
status: Mapped[str] = mapped_column(
String(255), server_default=sa.text("'normal'::character varying"), default="normal"
)
custom_config: Mapped[str | None] = mapped_column(sa.Text, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)
def get_accounts(self) -> list[Account]:
return list(
@@ -257,7 +270,7 @@ class Tenant(Base):
self.custom_config = json.dumps(value)
class TenantAccountJoin(Base):
class TenantAccountJoin(TypeBase):
__tablename__ = "tenant_account_joins"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@@ -266,17 +279,21 @@ class TenantAccountJoin(Base):
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
role: Mapped[str] = mapped_column(String(16), server_default="normal")
invited_by: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal")
invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
class AccountIntegrate(Base):
class AccountIntegrate(TypeBase):
__tablename__ = "account_integrates"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@@ -284,16 +301,20 @@ class AccountIntegrate(Base):
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
account_id: Mapped[str] = mapped_column(StringUUID)
provider: Mapped[str] = mapped_column(String(16))
open_id: Mapped[str] = mapped_column(String(255))
encrypted_token: Mapped[str] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
class InvitationCode(Base):
class InvitationCode(TypeBase):
__tablename__ = "invitation_codes"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
@@ -301,18 +322,22 @@ class InvitationCode(Base):
sa.Index("invitation_codes_code_idx", "code", "status"),
)
id: Mapped[int] = mapped_column(sa.Integer)
id: Mapped[int] = mapped_column(sa.Integer, init=False)
batch: Mapped[str] = mapped_column(String(255))
code: Mapped[str] = mapped_column(String(32))
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying"))
used_at: Mapped[datetime | None] = mapped_column(DateTime)
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID)
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID)
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
status: Mapped[str] = mapped_column(
String(16), server_default=sa.text("'unused'::character varying"), default="unused"
)
used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None)
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"), nullable=False, init=False
)
class TenantPluginPermission(Base):
class TenantPluginPermission(TypeBase):
class InstallPermission(enum.StrEnum):
EVERYONE = "everyone"
ADMINS = "admins"
@@ -329,13 +354,17 @@ class TenantPluginPermission(Base):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone")
debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone")
install_permission: Mapped[InstallPermission] = mapped_column(
String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE
)
debug_permission: Mapped[DebugPermission] = mapped_column(
String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY
)
class TenantPluginAutoUpgradeStrategy(Base):
class TenantPluginAutoUpgradeStrategy(TypeBase):
class StrategySetting(enum.StrEnum):
DISABLED = "disabled"
FIX_ONLY = "fix_only"
@@ -352,12 +381,20 @@ class TenantPluginAutoUpgradeStrategy(Base):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only")
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day
upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude")
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
strategy_setting: Mapped[StrategySetting] = mapped_column(
String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY
)
upgrade_mode: Mapped[UpgradeMode] = mapped_column(
String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE
)
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@@ -754,7 +754,7 @@ class DocumentSegment(Base):
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
rules = Rule.model_validate(rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
@@ -772,7 +772,7 @@ class DocumentSegment(Base):
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
rules = Rule.model_validate(rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)

View File

@@ -152,7 +152,7 @@ class ApiToolProvider(Base):
def tools(self) -> list["ApiToolBundle"]:
from core.tools.entities.tool_bundle import ApiToolBundle
return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]
return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)]
@property
def credentials(self) -> dict[str, Any]:
@@ -242,7 +242,10 @@ class WorkflowToolProvider(Base):
def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)]
return [
WorkflowToolParameterConfiguration.model_validate(config)
for config in json.loads(self.parameter_configuration)
]
@property
def app(self) -> App | None:
@@ -312,7 +315,7 @@ class MCPToolProvider(Base):
def mcp_tools(self) -> list["MCPTool"]:
from core.mcp.types import Tool as MCPTool
return [MCPTool(**tool) for tool in json.loads(self.tools)]
return [MCPTool.model_validate(tool) for tool in json.loads(self.tools)]
@property
def provider_icon(self) -> Mapping[str, str] | str:
@@ -552,4 +555,4 @@ class DeprecatedPublishedAppTool(Base):
def description_i18n(self) -> "I18nObject":
from core.tools.entities.common_entities import I18nObject
return I18nObject(**json.loads(self.description))
return I18nObject.model_validate(json.loads(self.description))

View File

@@ -360,7 +360,9 @@ class Workflow(Base):
@property
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# _environment_variables is guaranteed to be non-None due to server_default="{}"
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id
@@ -444,7 +446,9 @@ class Workflow(Base):
@property
def conversation_variables(self) -> Sequence[Variable]:
# _conversation_variables is guaranteed to be non-None due to server_default="{}"
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]

View File

@@ -110,7 +110,7 @@ dev = [
"lxml-stubs~=0.5.1",
"ty~=0.0.1a19",
"basedpyright~=1.31.0",
"ruff~=0.12.3",
"ruff~=0.14.0",
"pytest~=8.3.2",
"pytest-benchmark~=4.0.0",
"pytest-cov~=4.1.0",

View File

@@ -24,9 +24,7 @@
"reportUnknownLambdaType": "hint",
"reportMissingParameterType": "hint",
"reportMissingTypeArgument": "hint",
"reportUnnecessaryContains": "hint",
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryCast": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",

View File

@@ -246,10 +246,8 @@ class AccountService:
)
)
account = Account()
account.email = email
account.name = name
password_to_set = None
salt_to_set = None
if password:
valid_password(password)
@@ -261,14 +259,18 @@ class AccountService:
password_hashed = hash_password(password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
password_to_set = base64_password_hashed
salt_to_set = base64_salt
account.interface_language = interface_language
account.interface_theme = interface_theme
# Set timezone based on language
account.timezone = language_timezone_mapping.get(interface_language, "UTC")
account = Account(
name=name,
email=email,
password=password_to_set,
password_salt=salt_to_set,
interface_language=interface_language,
interface_theme=interface_theme,
timezone=language_timezone_mapping.get(interface_language, "UTC"),
)
db.session.add(account)
db.session.commit()

View File

@@ -659,31 +659,31 @@ class AppDslService:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL.value:
tool_entity = ToolNodeData(**node["data"])
tool_entity = ToolNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.LLM.value:
llm_entity = LLMNodeData(**node["data"])
llm_entity = LLMNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER.value:
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR.value:
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_RETRIEVAL.value:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
if (
@@ -773,7 +773,7 @@ class AppDslService:
"""
Returns the leaked dependencies in current workspace
"""
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
if not dependencies:
return []

View File

@@ -70,7 +70,7 @@ class EnterpriseService:
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
return WebAppSettings.model_validate(data)
@classmethod
def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
@@ -100,7 +100,7 @@ class EnterpriseService:
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
return WebAppSettings.model_validate(data)
@classmethod
def update_app_access_mode(cls, app_id: str, access_mode: str):

View File

@@ -1,6 +1,7 @@
from collections.abc import Sequence
from enum import Enum
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator
from configs import dify_config
from core.entities.model_entities import (
@@ -71,7 +72,7 @@ class ProviderResponse(BaseModel):
icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: list[ModelType]
supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: ProviderCredentialSchema | None = None
model_credential_schema: ModelCredentialSchema | None = None
@@ -82,9 +83,8 @@ class ProviderResponse(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@@ -97,6 +97,7 @@ class ProviderResponse(BaseModel):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class ProviderWithModelsResponse(BaseModel):
@@ -112,9 +113,8 @@ class ProviderWithModelsResponse(BaseModel):
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@@ -127,6 +127,7 @@ class ProviderWithModelsResponse(BaseModel):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class SimpleProviderEntityResponse(SimpleProviderEntity):
@@ -136,9 +137,8 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
tenant_id: str
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@@ -151,6 +151,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class DefaultModelResponse(BaseModel):

View File

@@ -46,7 +46,7 @@ class HitTestingService:
from core.app.app_config.entities import MetadataFilteringCondition
metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions)
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
dataset_ids=[dataset.id],

View File

@@ -123,7 +123,7 @@ class OpsService:
config_class: type[BaseTracingConfig] = provider_config["config_class"]
other_keys: list[str] = provider_config["other_keys"]
default_config_instance: BaseTracingConfig = config_class(**tracing_config)
default_config_instance = config_class.model_validate(tracing_config)
for key in other_keys:
if key in tracing_config and tracing_config[key] == "":
tracing_config[key] = getattr(default_config_instance, key, None)

View File

@@ -269,7 +269,7 @@ class PluginMigration:
for tool in agent_config["tools"]:
if isinstance(tool, dict):
try:
tool_entity = AgentToolEntity(**tool)
tool_entity = AgentToolEntity.model_validate(tool)
if (
tool_entity.provider_type == ToolProviderType.BUILT_IN.value
and tool_entity.provider_id not in excluded_providers

View File

@@ -358,7 +358,7 @@ class RagPipelineService:
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration)
# update dataset
dataset = pipeline.retrieve_dataset(session=session)

View File

@@ -288,7 +288,7 @@ class RagPipelineDslService:
dataset_id = None
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if (
dataset
and pipeline.is_published
@@ -426,7 +426,7 @@ class RagPipelineDslService:
dataset_id = None
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if not dataset:
dataset = Dataset(
tenant_id=account.current_tenant_id,
@@ -734,35 +734,35 @@ class RagPipelineDslService:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL.value:
tool_entity = ToolNodeData(**node["data"])
tool_entity = ToolNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.DATASOURCE.value:
datasource_entity = DatasourceNodeData(**node["data"])
datasource_entity = DatasourceNodeData.model_validate(node["data"])
if datasource_entity.provider_type != "local_file":
dependencies.append(datasource_entity.plugin_id)
case NodeType.LLM.value:
llm_entity = LLMNodeData(**node["data"])
llm_entity = LLMNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER.value:
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR.value:
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_INDEX.value:
knowledge_index_entity = KnowledgeConfiguration(**node["data"])
knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
if knowledge_index_entity.indexing_technique == "high_quality":
if knowledge_index_entity.embedding_model_provider:
dependencies.append(
@@ -783,7 +783,7 @@ class RagPipelineDslService:
),
)
case NodeType.KNOWLEDGE_RETRIEVAL.value:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
if (
@@ -873,7 +873,7 @@ class RagPipelineDslService:
"""
Returns the leaked dependencies in current workspace
"""
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
if not dependencies:
return []

View File

@@ -156,13 +156,13 @@ class RagPipelineTransformService:
self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
if indexing_technique == "high_quality":
knowledge_configuration.embedding_model = dataset.embedding_model
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
if retrieval_model:
retrieval_setting = RetrievalSetting(**retrieval_model)
retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
if indexing_technique == "economy":
retrieval_setting.search_method = "keyword_search"
knowledge_configuration.retrieval_model = retrieval_setting

View File

@@ -1,7 +1,7 @@
import hashlib
import json
from datetime import datetime
from typing import Any, cast
from typing import Any
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
@@ -55,7 +55,7 @@ class MCPToolManageService:
cache=NoOpProviderCredentialCache(),
)
return cast(dict[str, str], encrypter_instance.encrypt(headers))
return encrypter_instance.encrypt(headers)
@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:

View File

@@ -242,7 +242,7 @@ class ToolTransformService:
is_team_authorization=db_provider.authed,
server_url=db_provider.masked_server_url,
tools=ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
db_provider, [MCPTool.model_validate(tool) for tool in json.loads(db_provider.tools)]
),
updated_at=int(db_provider.updated_at.timestamp()),
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
@@ -387,6 +387,7 @@ class ToolTransformService:
labels=labels or [],
)
else:
assert tool.operation_id
return ToolApiEntity(
author=tool.author,
name=tool.operation_id or "",

View File

@@ -36,7 +36,7 @@ def process_trace_tasks(file_info):
if trace_info.get("workflow_data"):
trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"])
if trace_info.get("documents"):
trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]]
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
try:
if trace_instance:

View File

@@ -79,7 +79,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
@@ -112,7 +112,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)
# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

View File

@@ -100,7 +100,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
@@ -133,7 +133,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)
# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

View File

@@ -33,17 +33,19 @@ class TestChatMessageApiPermissions:
@pytest.fixture
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing."""
account = Account()
account.id = str(uuid.uuid4())
account.name = "Test User"
account.email = "test@example.com"
account = Account(
name="Test User",
email="test@example.com",
)
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
account.id = str(uuid.uuid4())
tenant = Tenant()
# Create mock tenant
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
mock_session_instance = mock.Mock()

View File

@@ -32,17 +32,16 @@ class TestModelConfigResourcePermissions:
@pytest.fixture
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing."""
account = Account()
account = Account(name="Test User", email="test@example.com")
account.id = str(uuid.uuid4())
account.name = "Test User"
account.email = "test@example.com"
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
tenant = Tenant()
# Create mock tenant
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
mock_session_instance = mock.Mock()

View File

@@ -36,7 +36,7 @@ def test_api_tool(setup_http_mock):
entity=ToolEntity(
identity=ToolIdentity(provider="", author="", name="", label=I18nObject(en_US="test tool")),
),
api_bundle=ApiToolBundle(**tool_bundle),
api_bundle=ApiToolBundle.model_validate(tool_bundle),
runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}),
provider_id="test_tool",
)

View File

@@ -16,6 +16,7 @@ from services.errors.account import (
AccountPasswordError,
AccountRegisterError,
CurrentPasswordIncorrectError,
TenantNotFoundError,
)
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
@@ -1414,7 +1415,7 @@ class TestTenantService:
)
# Try to get current tenant (should fail)
with pytest.raises(AttributeError):
with pytest.raises((AttributeError, TenantNotFoundError)):
TenantService.get_current_tenant_by_account(account)
def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies):

View File

@@ -44,27 +44,26 @@ class TestWorkflowService:
Account: Created test account instance
"""
fake = fake or Faker()
account = Account()
account.id = fake.uuid4()
account.email = fake.email()
account.name = fake.name()
account.avatar_url = fake.url()
account.tenant_id = fake.uuid4()
account.status = "active"
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US" # Set interface language for Site creation
account = Account(
email=fake.email(),
name=fake.name(),
avatar=fake.url(),
status="active",
interface_language="en-US", # Set interface language for Site creation
)
account.created_at = fake.date_time_this_year()
account.id = fake.uuid4()
account.updated_at = account.created_at
# Create a tenant for the account
from models.account import Tenant
tenant = Tenant()
tenant.id = account.tenant_id
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant = Tenant(
name=f"Test Tenant {fake.company()}",
plan="basic",
status="active",
)
tenant.id = account.current_tenant_id
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
@@ -91,20 +90,21 @@ class TestWorkflowService:
App: Created test app instance
"""
fake = fake or Faker()
app = App()
app.id = fake.uuid4()
app.tenant_id = fake.uuid4()
app.name = fake.company()
app.description = fake.text()
app.mode = AppMode.WORKFLOW
app.icon_type = "emoji"
app.icon = "🤖"
app.icon_background = "#FFEAD5"
app.enable_site = True
app.enable_api = True
app.created_by = fake.uuid4()
app = App(
id=fake.uuid4(),
tenant_id=fake.uuid4(),
name=fake.company(),
description=fake.text(),
mode=AppMode.WORKFLOW,
icon_type="emoji",
icon="🤖",
icon_background="#FFEAD5",
enable_site=True,
enable_api=True,
created_by=fake.uuid4(),
workflow_id=None, # Will be set when workflow is created
)
app.updated_by = app.created_by
app.workflow_id = None # Will be set when workflow is created
from extensions.ext_database import db
@@ -126,19 +126,20 @@ class TestWorkflowService:
Workflow: Created test workflow instance
"""
fake = fake or Faker()
workflow = Workflow()
workflow.id = fake.uuid4()
workflow.tenant_id = app.tenant_id
workflow.app_id = app.id
workflow.type = WorkflowType.WORKFLOW.value
workflow.version = Workflow.VERSION_DRAFT
workflow.graph = json.dumps({"nodes": [], "edges": []})
workflow.features = json.dumps({"features": []})
# unique_hash is a computed property based on graph and features
workflow.created_by = account.id
workflow.updated_by = account.id
workflow.environment_variables = []
workflow.conversation_variables = []
workflow = Workflow(
id=fake.uuid4(),
tenant_id=app.tenant_id,
app_id=app.id,
type=WorkflowType.WORKFLOW.value,
version=Workflow.VERSION_DRAFT,
graph=json.dumps({"nodes": [], "edges": []}),
features=json.dumps({"features": []}),
# unique_hash is a computed property based on graph and features
created_by=account.id,
updated_by=account.id,
environment_variables=[],
conversation_variables=[],
)
from extensions.ext_database import db

View File

@@ -48,11 +48,8 @@ class TestDeleteSegmentFromIndexTask:
Tenant: Created test tenant instance
"""
fake = fake or Faker()
tenant = Tenant()
tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active")
tenant.id = fake.uuid4()
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
@@ -73,16 +70,14 @@ class TestDeleteSegmentFromIndexTask:
Account: Created test account instance
"""
fake = fake or Faker()
account = Account()
account = Account(
name=fake.name(),
email=fake.email(),
avatar=fake.url(),
status="active",
interface_language="en-US",
)
account.id = fake.uuid4()
account.email = fake.email()
account.name = fake.name()
account.avatar_url = fake.url()
account.tenant_id = tenant.id
account.status = "active"
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US"
account.created_at = fake.date_time_this_year()
account.updated_at = account.created_at

View File

@@ -43,27 +43,30 @@ class TestDisableSegmentsFromIndexTask:
Account: Created test account instance
"""
fake = fake or Faker()
account = Account()
account = Account(
email=fake.email(),
name=fake.name(),
avatar=fake.url(),
status="active",
interface_language="en-US",
)
account.id = fake.uuid4()
account.email = fake.email()
account.name = fake.name()
account.avatar_url = fake.url()
# monkey-patch attributes for test setup
account.tenant_id = fake.uuid4()
account.status = "active"
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US"
account.created_at = fake.date_time_this_year()
account.updated_at = account.created_at
# Create a tenant for the account
from models.account import Tenant
tenant = Tenant()
tenant = Tenant(
name=f"Test Tenant {fake.company()}",
plan="basic",
status="active",
)
tenant.id = account.tenant_id
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
@@ -91,20 +94,21 @@ class TestDisableSegmentsFromIndexTask:
Dataset: Created test dataset instance
"""
fake = fake or Faker()
dataset = Dataset()
dataset.id = fake.uuid4()
dataset.tenant_id = account.tenant_id
dataset.name = f"Test Dataset {fake.word()}"
dataset.description = fake.text(max_nb_chars=200)
dataset.provider = "vendor"
dataset.permission = "only_me"
dataset.data_source_type = "upload_file"
dataset.indexing_technique = "high_quality"
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.embedding_model = "text-embedding-ada-002"
dataset.embedding_model_provider = "openai"
dataset.built_in_field_enabled = False
dataset = Dataset(
id=fake.uuid4(),
tenant_id=account.tenant_id,
name=f"Test Dataset {fake.word()}",
description=fake.text(max_nb_chars=200),
provider="vendor",
permission="only_me",
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
updated_by=account.id,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
built_in_field_enabled=False,
)
from extensions.ext_database import db
@@ -128,6 +132,7 @@ class TestDisableSegmentsFromIndexTask:
"""
fake = fake or Faker()
document = DatasetDocument()
document.id = fake.uuid4()
document.tenant_id = dataset.tenant_id
document.dataset_id = dataset.id
@@ -153,7 +158,6 @@ class TestDisableSegmentsFromIndexTask:
document.archived = False
document.doc_form = "text_model" # Use text_model form for testing
document.doc_language = "en"
from extensions.ext_database import db
db.session.add(document)

View File

@@ -96,9 +96,9 @@ class TestMailInviteMemberTask:
password=fake.password(),
interface_language="en-US",
status=AccountStatus.ACTIVE.value,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
account.created_at = datetime.now(UTC)
account.updated_at = datetime.now(UTC)
db_session_with_containers.add(account)
db_session_with_containers.commit()
db_session_with_containers.refresh(account)
@@ -106,9 +106,9 @@ class TestMailInviteMemberTask:
# Create tenant
tenant = Tenant(
name=fake.company(),
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
tenant.created_at = datetime.now(UTC)
tenant.updated_at = datetime.now(UTC)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
db_session_with_containers.refresh(tenant)
@@ -118,8 +118,8 @@ class TestMailInviteMemberTask:
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
created_at=datetime.now(UTC),
)
tenant_join.created_at = datetime.now(UTC)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()
@@ -164,9 +164,10 @@ class TestMailInviteMemberTask:
password="",
interface_language="en-US",
status=AccountStatus.PENDING.value,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
account.created_at = datetime.now(UTC)
account.updated_at = datetime.now(UTC)
db_session_with_containers.add(account)
db_session_with_containers.commit()
db_session_with_containers.refresh(account)
@@ -176,8 +177,8 @@ class TestMailInviteMemberTask:
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.NORMAL.value,
created_at=datetime.now(UTC),
)
tenant_join.created_at = datetime.now(UTC)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()

View File

@@ -15,13 +15,13 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
# Set environment variables using monkeypatch
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30") # Custom value for testing
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "600")
monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "300") # Custom value for testing
# load dotenv file with pydantic-settings
config = DifyConfig()
@@ -35,16 +35,36 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0
assert config.TEMPLATE_TRANSFORM_MAX_LENGTH == 400_000
# annotated field with default value
assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 600
# annotated field with custom configured value
assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 300
# annotated field with configured value
# annotated field with custom configured value
assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30
# values from pyproject.toml
assert Version(config.project.version) >= Version("1.0.0")
def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch):
"""Test that HTTP timeout defaults are correctly set"""
# clear system environment variables
os.environ.clear()
# Set minimal required env vars
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
config = DifyConfig()
# Verify default timeout values
assert config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT == 10
assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 600
assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 600
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
@@ -55,7 +75,6 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
# Set environment variables using monkeypatch
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
@@ -105,7 +124,6 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):
# Set environment variables using monkeypatch
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")

Some files were not shown because too many files have changed in this diff Show More