mirror of https://gitee.com/dify_ai/dify.git
synced 2025-12-07 03:45:27 +08:00

Compare commits: refactor/i… ... feat/suppo… (136 Commits)
| Author | SHA1 | Date |
|---|---|---|
| | 573adb1deb | |
| | 2020a31785 | |
| | 2c04a16eaa | |
| | 6325129761 | |
| | 9a18a98b58 | |
| | 7b9e01aa07 | |
| | 2bb19f85c6 | |
| | 17fe62cf91 | |
| | e99861d4fe | |
| | a205ee16b9 | |
| | 9835730278 | |
| | 8331b63baa | |
| | 2df4699312 | |
| | 2eae7503e1 | |
| | dbae5b0564 | |
| | 30cfc9c172 | |
| | 420a34a546 | |
| | 918bb9a2f7 | |
| | 99acdcdef7 | |
| | 59fdfc3728 | |
| | 614c5e087e | |
| | 83719cab73 | |
| | 9e73e8b9e8 | |
| | d4be356ffb | |
| | 47e0f92c0f | |
| | 6d033d4064 | |
| | ab290ed968 | |
| | 879f839d75 | |
| | 15800c6108 | |
| | 787a556bd7 | |
| | ea44b895e2 | |
| | 37f26c412f | |
| | 1da8027445 | |
| | ce3e2e5eb8 | |
| | b69f952e3e | |
| | 0784c6295d | |
| | 45c89bd6de | |
| | 8ac3bd1768 | |
| | 945d1569ee | |
| | 61526c027d | |
| | 4689e8953e | |
| | 4e701b1efd | |
| | 21c2de2d7e | |
| | 72a6cde828 | |
| | 0476937f55 | |
| | fc187d6998 | |
| | 0dcacdf83d | |
| | 7a2a8a2ffd | |
| | 2ddc3658a0 | |
| | 6a5236b200 | |
| | 6c0a91a64f | |
| | df6451076b | |
| | 19339c3de1 | |
| | c5acffb034 | |
| | 110bb5e5a5 | |
| | d7663159e9 | |
| | 2440ac43b1 | |
| | 41cb39eb5d | |
| | 809a0ab6bf | |
| | 51b63b2398 | |
| | 59b89b9971 | |
| | ecd8f32cce | |
| | 909259da37 | |
| | 366ddb05ae | |
| | d587480a3e | |
| | 765189d4f5 | |
| | f6aa2498a3 | |
| | f6641c0f41 | |
| | f4df759ba6 | |
| | 3a628bc671 | |
| | 175571e740 | |
| | 8cb3ed5cc2 | |
| | c05e47ebc0 | |
| | b2ac11bc47 | |
| | af83120832 | |
| | 1e03c97663 | |
| | 41e3ecc837 | |
| | acb2488fc8 | |
| | d6d8cca053 | |
| | f601093ccc | |
| | 0f3d4d0b6e | |
| | 60777bc610 | |
| | 21a50e22d2 | |
| | fc6e2d14a5 | |
| | c439e82038 | |
| | a97ff587d2 | |
| | 91144207e0 | |
| | 0720bc7408 | |
| | ab62a9662c | |
| | d6a8af03b4 | |
| | 65c7c01d90 | |
| | e6e76852d5 | |
| | 930c4cb609 | |
| | 0c8447fd0e | |
| | 37c3283450 | |
| | 723b69cf8d | |
| | 85859b6723 | |
| | c1a13fa553 | |
| | 4f0c9fdf2b | |
| | 4271602cfc | |
| | 4f14d7c0ca | |
| | 38554c5f3e | |
| | 138ad6e8b3 | |
| | f76f70f0b6 | |
| | 7094680e23 | |
| | 59dc7c880e | |
| | 3fb9b41fe5 | |
| | 0ccf8cb23e | |
| | 837f769960 | |
| | 3367d4258d | |
| | d608be6e7f | |
| | de9c7f2ea4 | |
| | 1fbbbb735d | |
| | 9915a70d7f | |
| | 822298f69d | |
| | ad2f25875e | |
| | ad8e79c440 | |
| | f2dcfc976d | |
| | 5ccfb1f4ba | |
| | 92614765ff | |
| | 4f066454d0 | |
| | 7ae5819c67 | |
| | 2b0f3edcef | |
| | 244687c9a7 | |
| | d22c351221 | |
| | 006496f24e | |
| | 01d500db14 | |
| | 4ac3600f81 | |
| | 6aba223383 | |
| | f1c19cda74 | |
| | 275e86a26c | |
| | 077d627953 | |
| | ca0b268ae5 | |
| | 25be7c1ad5 | |
| | 888cd86afd | |
| | 157d916154 | |
`.github/actions/setup-uv/action.yml` (vendored, 2 changes)

```diff
@@ -8,7 +8,7 @@ inputs:
   uv-version:
     description: UV version to set up
     required: true
-    default: '0.6.14'
+    default: '~=0.7.11'
   uv-lockfile:
     description: Path to the UV lockfile to restore cache from
     required: true
```
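The new default switches from an exact pin to a PEP 440 compatible-release specifier: `~=0.7.11` accepts any later `0.7.x` patch release but excludes `0.8.0`. A quick check of that behavior with the `packaging` library (illustrative only, not part of the action):

```python
from packaging.specifiers import SpecifierSet

# "~=0.7.11" is equivalent to ">=0.7.11, <0.8.0"
spec = SpecifierSet("~=0.7.11")

assert "0.7.11" in spec      # the pinned version matches
assert "0.7.20" in spec      # later patch releases match too
assert "0.8.0" not in spec   # the next minor release is excluded
```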
`.github/workflows/deploy-rag-dev.yml` (vendored, new file, 28 lines)

```diff
@@ -0,0 +1,28 @@
+name: Deploy RAG Dev
+
+permissions:
+  contents: read
+
+on:
+  workflow_run:
+    workflows: ["Build and Push API & Web"]
+    branches:
+      - "deploy/rag-dev"
+    types:
+      - completed
+
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    if: |
+      github.event.workflow_run.conclusion == 'success' &&
+      github.event.workflow_run.head_branch == 'deploy/rag-dev'
+    steps:
+      - name: Deploy to server
+        uses: appleboy/ssh-action@v0.1.8
+        with:
+          host: ${{ secrets.RAG_SSH_HOST }}
+          username: ${{ secrets.SSH_USER }}
+          key: ${{ secrets.SSH_PRIVATE_KEY }}
+          script: |
+            ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
```
`.github/workflows/expose_service_ports.sh` (vendored, 1 change)

```diff
@@ -10,6 +10,7 @@ yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-com
 yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
 yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
 yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
+yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
 yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml

 echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
```
`.github/workflows/vdb-tests.yml` (vendored, 12 changes)

```diff
@@ -31,6 +31,13 @@ jobs:
         with:
           persist-credentials: false

+      - name: Free Disk Space
+        uses: endersonmenezes/free-disk-space@v2
+        with:
+          remove_dotnet: true
+          remove_haskell: true
+          remove_tool_cache: true
+
       - name: Setup UV and Python
         uses: ./.github/actions/setup-uv
         with:
@@ -59,7 +66,7 @@ jobs:
             tidb
             tiflash

-      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
+      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
        uses: hoverkraft-tech/compose-action@v2.0.2
        with:
          compose-file: |
@@ -75,8 +82,9 @@ jobs:
            pgvector
            chroma
            elasticsearch
+           oceanbase

-      - name: Check TiDB Ready
+      - name: Check VDB Ready (TiDB)
        run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py

      - name: Test Vector Stores
```
`.gitignore` (vendored, 11 changes)

```diff
@@ -179,6 +179,7 @@ docker/volumes/pgvecto_rs/data/*
 docker/volumes/couchbase/*
 docker/volumes/oceanbase/*
 docker/volumes/plugin_daemon/*
+docker/volumes/matrixone/*
 !docker/volumes/oceanbase/init.d

 docker/nginx/conf.d/default.conf
@@ -192,12 +193,12 @@ sdks/python-client/dist
 sdks/python-client/dify_client.egg-info

 .vscode/*
-!.vscode/launch.json
+!.vscode/launch.json.template
+!.vscode/README.md
 pyrightconfig.json
 api/.vscode

 .idea/
-.vscode

 # pnpm
 /.pnpm-store
@@ -207,3 +208,9 @@ plugins.jsonl

 # mise
 mise.toml
+
+# Next.js build output
+.next/
+
+# AI Assistant
+.roo/
```
`.vscode/README.md` (vendored, new file, 14 lines)

```diff
@@ -0,0 +1,14 @@
+# Debugging with VS Code
+
+This `launch.json.template` file provides various debug configurations for the Dify project within VS Code / Cursor. To use these configurations, you should copy the contents of this file into a new file named `launch.json` in the same `.vscode` directory.
+
+## How to Use
+
+1. **Create `launch.json`**: If you don't have one, create a file named `launch.json` inside the `.vscode` directory.
+2. **Copy Content**: Copy the entire content from `launch.json.template` into your newly created `launch.json` file.
+3. **Select Debug Configuration**: Go to the Run and Debug view in VS Code / Cursor (Ctrl+Shift+D or Cmd+Shift+D).
+4. **Start Debugging**: Select the desired configuration from the dropdown menu and click the green play button.
+
+## Tips
+
+- If you need to debug with Edge browser instead of Chrome, modify the `serverReadyAction` configuration in the "Next.js: debug full stack" section, change `"debugWithChrome"` to `"debugWithEdge"` to use Microsoft Edge for debugging.
```
`.vscode/launch.json.template` (vendored, new file, 68 lines)

```diff
@@ -0,0 +1,68 @@
+{
+  "version": "0.2.0",
+  "configurations": [
+    {
+      "name": "Python: Flask API",
+      "type": "debugpy",
+      "request": "launch",
+      "module": "flask",
+      "env": {
+        "FLASK_APP": "app.py",
+        "FLASK_ENV": "development",
+        "GEVENT_SUPPORT": "True"
+      },
+      "args": [
+        "run",
+        "--host=0.0.0.0",
+        "--port=5001",
+        "--no-debugger",
+        "--no-reload"
+      ],
+      "jinja": true,
+      "justMyCode": true,
+      "cwd": "${workspaceFolder}/api",
+      "python": "${workspaceFolder}/api/.venv/bin/python"
+    },
+    {
+      "name": "Python: Celery Worker (Solo)",
+      "type": "debugpy",
+      "request": "launch",
+      "module": "celery",
+      "env": {
+        "GEVENT_SUPPORT": "True"
+      },
+      "args": [
+        "-A",
+        "app.celery",
+        "worker",
+        "-P",
+        "solo",
+        "-c",
+        "1",
+        "-Q",
+        "dataset,generation,mail,ops_trace",
+        "--loglevel",
+        "INFO"
+      ],
+      "justMyCode": false,
+      "cwd": "${workspaceFolder}/api",
+      "python": "${workspaceFolder}/api/.venv/bin/python"
+    },
+    {
+      "name": "Next.js: debug full stack",
+      "type": "node",
+      "request": "launch",
+      "program": "${workspaceFolder}/web/node_modules/next/dist/bin/next",
+      "runtimeArgs": ["--inspect"],
+      "skipFiles": ["<node_internals>/**"],
+      "serverReadyAction": {
+        "action": "debugWithChrome",
+        "killOnServerStop": true,
+        "pattern": "- Local:.+(https?://.+)",
+        "uriFormat": "%s",
+        "webRoot": "${workspaceFolder}/web"
+      },
+      "cwd": "${workspaceFolder}/web"
+    }
+  ]
+}
```
```diff
@@ -137,7 +137,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

 # Vector database configuration
-# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore
+# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
 VECTOR_STORE=weaviate

 # Weaviate configuration
@@ -294,6 +294,13 @@ VIKINGDB_SCHEMA=http
 VIKINGDB_CONNECTION_TIMEOUT=30
 VIKINGDB_SOCKET_TIMEOUT=30

+# Matrixone configration
+MATRIXONE_HOST=127.0.0.1
+MATRIXONE_PORT=6001
+MATRIXONE_USER=dump
+MATRIXONE_PASSWORD=111
+MATRIXONE_DATABASE=dify
+
 # Lindorm configuration
 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
 LINDORM_USERNAME=admin
@@ -332,9 +339,11 @@ PROMPT_GENERATION_MAX_TOKENS=512
 CODE_GENERATION_MAX_TOKENS=1024
 PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false

-# Mail configuration, support: resend, smtp
+# Mail configuration, support: resend, smtp, sendgrid
 MAIL_TYPE=
+# If using SendGrid, use the 'from' field for authentication if necessary.
 MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
 # resend configuration
 RESEND_API_KEY=
 RESEND_API_URL=https://api.resend.com
 # smtp configuration
@@ -344,7 +353,22 @@ SMTP_USERNAME=123
 SMTP_PASSWORD=abc
 SMTP_USE_TLS=true
 SMTP_OPPORTUNISTIC_TLS=false
+
+# SMTP authentication type: 'basic' for traditional username/password authentication,
+# 'oauth2' for modern OAuth2 authentication with services like Microsoft 365/Outlook
+SMTP_AUTH_TYPE=basic
+# OAuth2 configuration for SMTP (required when SMTP_AUTH_TYPE=oauth2)
+# Client ID from your registered OAuth2 application in the provider's developer portal
+SMTP_CLIENT_ID=
+# Client secret from your registered OAuth2 application in the provider's developer portal
+SMTP_CLIENT_SECRET=
+# For Microsoft OAuth2 (Office 365/Outlook)
+# Tenant ID (Directory ID) from your Azure AD/Microsoft 365 account
+SMTP_TENANT_ID=
+# OAuth2 provider name - currently only 'microsoft' is supported
+# This identifies which OAuth2 implementation to use for authentication
+SMTP_OAUTH2_PROVIDER=microsoft
+# Sendgid configuration
+SENDGRID_API_KEY=
 # Sentry configuration
 SENTRY_DSN=

@@ -491,3 +515,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000

 # Prevent Clickjacking
 ALLOW_EMBED=false
+
+# Dataset queue monitor configuration
+QUEUE_MONITOR_THRESHOLD=200
+# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai
+QUEUE_MONITOR_ALERT_EMAILS=
+# Monitor interval in minutes, default is 30 minutes
+QUEUE_MONITOR_INTERVAL=30
```
```diff
@@ -43,6 +43,7 @@ select = [
     "S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
     "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
     "S302", # suspicious-marshal-usage, disallow use of `marshal` module
+    "S311", # suspicious-non-cryptographic-random-usage
 ]

 ignore = [
```
```diff
@@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
 WORKDIR /app/api

 # Install uv
-ENV UV_VERSION=0.6.14
+ENV UV_VERSION=0.7.11

 RUN pip install --no-cache-dir uv==${UV_VERSION}

```
```diff
@@ -27,7 +27,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
 from models.dataset import Document as DatasetDocument
 from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
 from models.provider import Provider, ProviderModel
-from services.account_service import RegisterService, TenantService
+from services.account_service import AccountService, RegisterService, TenantService
 from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
 from services.plugin.data_migration import PluginDataMigration
 from services.plugin.plugin_migration import PluginMigration
@@ -68,6 +68,7 @@ def reset_password(email, new_password, password_confirm):
     account.password = base64_password_hashed
     account.password_salt = base64_salt
     db.session.commit()
+    AccountService.reset_login_error_rate_limit(email)
     click.echo(click.style("Password reset successfully.", fg="green"))


@@ -280,6 +281,7 @@ def migrate_knowledge_vector_database():
         VectorType.ELASTICSEARCH,
         VectorType.OPENGAUSS,
         VectorType.TABLESTORE,
+        VectorType.MATRIXONE,
     }
     lower_collection_vector_types = {
         VectorType.ANALYTICDB,
```
```diff
@@ -609,7 +609,7 @@ class MailConfig(BaseSettings):
     """

     MAIL_TYPE: Optional[str] = Field(
-        description="Email service provider type ('smtp' or 'resend'), default to None.",
+        description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.",
         default=None,
     )

@@ -658,11 +658,41 @@ class MailConfig(BaseSettings):
         default=False,
     )

+    SMTP_AUTH_TYPE: str = Field(
+        description="SMTP authentication type ('basic' or 'oauth2')",
+        default="basic",
+    )
+
+    SMTP_CLIENT_ID: Optional[str] = Field(
+        description="OAuth2 client ID for SMTP authentication",
+        default=None,
+    )
+
+    SMTP_CLIENT_SECRET: Optional[str] = Field(
+        description="OAuth2 client secret for SMTP authentication",
+        default=None,
+    )
+
+    SMTP_TENANT_ID: Optional[str] = Field(
+        description="OAuth2 tenant ID for Microsoft SMTP authentication",
+        default=None,
+    )
+
+    SMTP_OAUTH2_PROVIDER: str = Field(
+        description="OAuth2 provider for SMTP authentication (currently only 'microsoft' is supported)",
+        default="microsoft",
+    )
+
     EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
         description="Maximum number of emails allowed to be sent from the same IP address in a minute",
         default=50,
     )

+    SENDGRID_API_KEY: Optional[str] = Field(
+        description="API key for SendGrid service",
+        default=None,
+    )
+

 class RagEtlConfig(BaseSettings):
     """
```
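For the new `oauth2` mode, SMTP servers such as Microsoft 365 typically expect the SASL XOAUTH2 mechanism: the client first obtains a bearer token from the provider, then presents it in a specially formatted AUTH string. A minimal sketch of that handshake with Python's standard `smtplib` (the token-acquisition step is elided, and the host and helper names are placeholders, not Dify's actual implementation):

```python
import smtplib

def xoauth2_string(user: str, access_token: str) -> str:
    # SASL XOAUTH2 format: user=<addr>\x01auth=Bearer <token>\x01\x01
    return f"user={user}\x01auth=Bearer {access_token}\x01\x01"

def send_via_oauth2(host: str, user: str, access_token: str, msg: str, to: str) -> None:
    with smtplib.SMTP(host, 587) as smtp:
        smtp.starttls()
        # smtplib base64-encodes the authobject's return value for us
        smtp.auth("XOAUTH2", lambda challenge=None: xoauth2_string(user, access_token))
        smtp.sendmail(user, [to], msg)
```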
```diff
@@ -2,7 +2,7 @@ import os
 from typing import Any, Literal, Optional
 from urllib.parse import parse_qsl, quote_plus

-from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
+from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
 from pydantic_settings import BaseSettings

 from .cache.redis_config import RedisConfig
@@ -24,6 +24,7 @@ from .vdb.couchbase_config import CouchbaseConfig
 from .vdb.elasticsearch_config import ElasticsearchConfig
 from .vdb.huawei_cloud_config import HuaweiCloudConfig
 from .vdb.lindorm_config import LindormConfig
+from .vdb.matrixone_config import MatrixoneConfig
 from .vdb.milvus_config import MilvusConfig
 from .vdb.myscale_config import MyScaleConfig
 from .vdb.oceanbase_config import OceanBaseVectorConfig
@@ -256,6 +257,25 @@ class InternalTestConfig(BaseSettings):
     )


+class DatasetQueueMonitorConfig(BaseSettings):
+    """
+    Configuration settings for Dataset Queue Monitor
+    """
+
+    QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field(
+        description="Threshold for dataset queue monitor",
+        default=200,
+    )
+    QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field(
+        description="Emails for dataset queue monitor alert, separated by commas",
+        default=None,
+    )
+    QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field(
+        description="Interval for dataset queue monitor in minutes",
+        default=30,
+    )
+
+
 class MiddlewareConfig(
     # place the configs in alphabet order
     CeleryConfig,
@@ -303,5 +323,7 @@ class MiddlewareConfig(
     BaiduVectorDBConfig,
     OpenGaussConfig,
     TableStoreConfig,
+    DatasetQueueMonitorConfig,
+    MatrixoneConfig,
 ):
     pass
```
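Taken together, the three settings sketch a simple polling monitor: every `QUEUE_MONITOR_INTERVAL` minutes, compare the dataset queue depth against `QUEUE_MONITOR_THRESHOLD` and alert the configured addresses when it is exceeded. A schematic of one monitor tick (the `queue_depth` and `send_alert` helpers are placeholders, not Dify's implementation):

```python
def check_queue(config, queue_depth, send_alert) -> None:
    """One monitor tick: alert if the queue is over the configured threshold."""
    depth = queue_depth()  # e.g. pending dataset-indexing tasks in the broker
    if config.QUEUE_MONITOR_THRESHOLD and depth > config.QUEUE_MONITOR_THRESHOLD:
        emails = (config.QUEUE_MONITOR_ALERT_EMAILS or "").split(",")
        send_alert([e.strip() for e in emails if e.strip()], depth)

# Presumably scheduled every QUEUE_MONITOR_INTERVAL minutes (30 by default).
```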
`api/configs/middleware/vdb/matrixone_config.py` (new file, 14 lines)

```diff
@@ -0,0 +1,14 @@
+from pydantic import BaseModel, Field
+
+
+class MatrixoneConfig(BaseModel):
+    """Matrixone vector database configuration."""
+
+    MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
+    MATRIXONE_PORT: int = Field(default=6001, description="Port number of the Matrixone server")
+    MATRIXONE_USER: str = Field(default="dump", description="Username for authenticating with Matrixone")
+    MATRIXONE_PASSWORD: str = Field(default="111", description="Password for authenticating with Matrixone")
+    MATRIXONE_DATABASE: str = Field(default="dify", description="Name of the Matrixone database to connect to")
+    MATRIXONE_METRIC: str = Field(
+        default="l2", description="Distance metric type for vector similarity search (cosine or l2)"
+    )
```
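Note that `MatrixoneConfig` is a plain `BaseModel`, yet because it is mixed into `MiddlewareConfig` (a `BaseSettings` subclass, see above), its fields are still populated from `MATRIXONE_*` environment variables when settings load. A minimal standalone sketch of that mixin behavior with pydantic-settings (not Dify's code):

```python
import os
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings

class MatrixoneConfig(BaseModel):
    MATRIXONE_HOST: str = Field(default="localhost")
    MATRIXONE_PORT: int = Field(default=6001)

class MiddlewareConfig(BaseSettings, MatrixoneConfig):
    pass

os.environ["MATRIXONE_HOST"] = "db.internal"
config = MiddlewareConfig()
print(config.MATRIXONE_HOST, config.MATRIXONE_PORT)  # db.internal 6001
```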
```diff
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="1.4.1",
+        default="1.4.3",
     )

     COMMIT_SHA: str = Field(
```
```diff
@@ -208,7 +208,7 @@ class AnnotationBatchImportApi(Resource):
         if len(request.files) > 1:
             raise TooManyFilesError()
         # check file type
-        if not file.filename.endswith(".csv"):
+        if not file.filename or not file.filename.lower().endswith(".csv"):
             raise ValueError("Invalid file type. Only CSV files are allowed")
         return AppAnnotationService.batch_import_app_annotations(app_id, file)
```
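The hardened check fixes two edge cases: an upload with no filename (`file.filename` can be `None` or empty, so the old code could raise `AttributeError`) and uppercase extensions. A quick illustration of the difference:

```python
def old_check(filename):
    return filename.endswith(".csv")  # raises AttributeError on None

def new_check(filename):
    return bool(filename) and filename.lower().endswith(".csv")

print(new_check("DATA.CSV"))  # True  (old_check would return False)
print(new_check(None))        # False (old_check would crash)
```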
```diff
@@ -34,6 +34,20 @@ class WorkflowAppLogApi(Resource):
         parser.add_argument(
             "created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
         )
+        parser.add_argument(
+            "created_by_end_user_session_id",
+            type=str,
+            location="args",
+            required=False,
+            default=None,
+        )
+        parser.add_argument(
+            "created_by_account",
+            type=str,
+            location="args",
+            required=False,
+            default=None,
+        )
         parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
         parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
         args = parser.parse_args()
@@ -57,6 +71,8 @@ class WorkflowAppLogApi(Resource):
             created_at_after=args.created_at__after,
             page=args.page,
             limit=args.limit,
+            created_by_end_user_session_id=args.created_by_end_user_session_id,
+            created_by_account=args.created_by_account,
         )

         return workflow_app_log_pagination
```
```diff
@@ -119,9 +119,6 @@ class ForgotPasswordResetApi(Resource):
         if not reset_data:
             raise InvalidTokenError()
         # Must use token in reset phase
         if reset_data.get("phase", "") != "reset":
             raise InvalidTokenError()
-        # Must use token in reset phase
-        if reset_data.get("phase", "") != "reset":
-            raise InvalidTokenError()
```
```diff
@@ -686,6 +686,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.TABLESTORE
                 | VectorType.HUAWEI_CLOUD
                 | VectorType.TENCENT
+                | VectorType.MATRIXONE
             ):
                 return {
                     "retrieval_method": [
@@ -733,6 +734,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.TABLESTORE
                 | VectorType.TENCENT
                 | VectorType.HUAWEI_CLOUD
+                | VectorType.MATRIXONE
             ):
                 return {
                     "retrieval_method": [
```
```diff
@@ -374,7 +374,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         if len(request.files) > 1:
             raise TooManyFilesError()
         # check file type
-        if not file.filename.endswith(".csv"):
+        if not file.filename or not file.filename.lower().endswith(".csv"):
             raise ValueError("Invalid file type. Only CSV files are allowed")

         try:
```
```diff
@@ -59,7 +59,14 @@ class InstalledAppsListApi(Resource):
         if FeatureService.get_system_features().webapp_auth.enabled:
             user_id = current_user.id
             res = []
+            app_ids = [installed_app["app"].id for installed_app in installed_app_list]
+            webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids)
             for installed_app in installed_app_list:
+                webapp_setting = webapp_settings.get(installed_app["app"].id)
+                if not webapp_setting:
+                    continue
+                if webapp_setting.access_mode == "sso_verified":
+                    continue
                 app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
                 if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
                     user_id=user_id,
```
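The added lines hoist the access-mode lookup out of the loop: one `batch_get_app_access_mode_by_id` round trip replaces a per-app query (the standard fix for an N+1 pattern), and apps without a setting or marked `sso_verified` are skipped up front. A generic sketch of the shape of that change (the `batch_fetch` helper is hypothetical, not Dify's API):

```python
# Hypothetical stand-in for a settings backend; one round trip for all ids.
def batch_fetch(ids: list[str]) -> dict[str, str]:
    return {app_id: "public" for app_id in ids}

app_ids = ["a1", "a2", "a3"]
access_modes = batch_fetch(app_ids)  # instead of one call per app inside the loop

visible = [
    app_id
    for app_id in app_ids
    if access_modes.get(app_id) and access_modes[app_id] != "sso_verified"
]
```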
```diff
@@ -44,6 +44,17 @@ def only_edition_cloud(view):
     return decorated


+def only_edition_enterprise(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        if not dify_config.ENTERPRISE_ENABLED:
+            abort(404)
+
+        return view(*args, **kwargs)
+
+    return decorated
+
+
 def only_edition_self_hosted(view):
     @wraps(view)
     def decorated(*args, **kwargs):
```
```diff
@@ -29,7 +29,7 @@ from core.plugin.entities.request import (
     RequestRequestUploadFile,
 )
 from core.tools.entities.tool_entities import ToolProviderType
-from libs.helper import compact_generate_response
+from libs.helper import length_prefixed_response
 from models.account import Account, Tenant
 from models.model import EndUser

@@ -44,7 +44,7 @@ class PluginInvokeLLMApi(Resource):
             response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
             return PluginModelBackwardsInvocation.convert_to_event_stream(response)

-        return compact_generate_response(generator())
+        return length_prefixed_response(0xF, generator())


 class PluginInvokeTextEmbeddingApi(Resource):
@@ -101,7 +101,7 @@ class PluginInvokeTTSApi(Resource):
             )
             return PluginModelBackwardsInvocation.convert_to_event_stream(response)

-        return compact_generate_response(generator())
+        return length_prefixed_response(0xF, generator())


 class PluginInvokeSpeech2TextApi(Resource):
@@ -162,7 +162,7 @@ class PluginInvokeToolApi(Resource):
             ),
         )

-        return compact_generate_response(generator())
+        return length_prefixed_response(0xF, generator())


 class PluginInvokeParameterExtractorNodeApi(Resource):
@@ -228,7 +228,7 @@ class PluginInvokeAppApi(Resource):
             files=payload.files,
         )

-        return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
+        return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))


 class PluginInvokeEncryptApi(Resource):
```
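Swapping `compact_generate_response` for `length_prefixed_response(0xF, ...)` suggests the plugin daemon now consumes a length-prefixed framing instead of a bare stream, which lets the reader recover exact chunk boundaries without scanning for delimiters. The helper itself is not shown in this diff; a minimal sketch of one plausible framing (the header layout and names here are assumptions, not Dify's actual `libs.helper` implementation):

```python
import struct
from collections.abc import Generator, Iterable

def length_prefixed_frames(magic: int, chunks: Iterable[bytes]) -> Generator[bytes, None, None]:
    """Yield each chunk framed as [1-byte magic][4-byte big-endian length][payload]."""
    for chunk in chunks:
        yield struct.pack(">BI", magic, len(chunk)) + chunk

def read_frames(buf: bytes):
    """Recover the original chunks from a concatenation of frames."""
    offset = 0
    while offset < len(buf):
        magic, length = struct.unpack_from(">BI", buf, offset)
        offset += 5  # header size: 1 + 4 bytes
        yield buf[offset:offset + length]
        offset += length

framed = b"".join(length_prefixed_frames(0xF, [b"hello", b"world"]))
print(list(read_frames(framed)))  # [b'hello', b'world']
```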
```diff
@@ -32,6 +32,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
             )
             session.add(user_model)
             session.commit()
+            session.refresh(user_model)
         else:
             user_model = AccountService.load_user(user_id)
             if not user_model:
```
```diff
@@ -47,7 +47,13 @@ class AppInfoApi(Resource):
     def get(self, app_model: App):
         """Get app information"""
         tags = [tag.name for tag in app_model.tags]
-        return {"name": app_model.name, "description": app_model.description, "tags": tags, "mode": app_model.mode}
+        return {
+            "name": app_model.name,
+            "description": app_model.description,
+            "tags": tags,
+            "mode": app_model.mode,
+            "author_name": app_model.author_name,
+        }


 api.add_resource(AppParameterApi, "/parameters")
```
```diff
@@ -135,6 +135,20 @@ class WorkflowAppLogApi(Resource):
         parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
         parser.add_argument("created_at__before", type=str, location="args")
         parser.add_argument("created_at__after", type=str, location="args")
+        parser.add_argument(
+            "created_by_end_user_session_id",
+            type=str,
+            location="args",
+            required=False,
+            default=None,
+        )
+        parser.add_argument(
+            "created_by_account",
+            type=str,
+            location="args",
+            required=False,
+            default=None,
+        )
         parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
         parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
         args = parser.parse_args()
@@ -158,6 +172,8 @@ class WorkflowAppLogApi(Resource):
             created_at_after=args.created_at__after,
             page=args.page,
             limit=args.limit,
+            created_by_end_user_session_id=args.created_by_end_user_session_id,
+            created_by_account=args.created_by_account,
         )

         return workflow_app_log_pagination
```
```diff
@@ -5,7 +5,11 @@ from werkzeug.exceptions import Forbidden, NotFound
 import services.dataset_service
 from controllers.service_api import api
 from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
-from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token
+from controllers.service_api.wraps import (
+    DatasetApiResource,
+    cloud_edition_billing_rate_limit_check,
+    validate_dataset_token,
+)
 from core.model_runtime.entities.model_entities import ModelType
 from core.plugin.entities.plugin import ModelProviderID
 from core.provider_manager import ProviderManager
@@ -70,6 +74,7 @@ class DatasetListApi(DatasetApiResource):
         response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
         return response, 200

+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id):
         """Resource for creating datasets."""
         parser = reqparse.RequestParser()
@@ -193,6 +198,7 @@ class DatasetApi(DatasetApiResource):

         return data, 200

+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def patch(self, _, dataset_id):
         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
@@ -293,6 +299,7 @@ class DatasetApi(DatasetApiResource):

         return result_data, 200

+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def delete(self, _, dataset_id):
         """
         Deletes a dataset given its ID.
@@ -369,6 +376,7 @@ class DatasetTagsApi(DatasetApiResource):
         )
         parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
         args = parser.parse_args()
+        args["type"] = "knowledge"
         tag = TagService.update_tags(args, args.get("tag_id"))

         binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
```
```diff
@@ -19,7 +19,11 @@ from controllers.service_api.dataset.error import (
     ArchivedDocumentImmutableError,
     DocumentIndexingError,
 )
-from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
+from controllers.service_api.wraps import (
+    DatasetApiResource,
+    cloud_edition_billing_rate_limit_check,
+    cloud_edition_billing_resource_check,
+)
 from core.errors.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from fields.document_fields import document_fields, document_status_fields
@@ -35,6 +39,7 @@ class DocumentAddByTextApi(DatasetApiResource):

     @cloud_edition_billing_resource_check("vector_space", "dataset")
     @cloud_edition_billing_resource_check("documents", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id):
         """Create document by text."""
         parser = reqparse.RequestParser()
@@ -99,6 +104,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
     """Resource for update documents."""

     @cloud_edition_billing_resource_check("vector_space", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id, document_id):
         """Update document by text."""
         parser = reqparse.RequestParser()
@@ -158,6 +164,7 @@ class DocumentAddByFileApi(DatasetApiResource):

     @cloud_edition_billing_resource_check("vector_space", "dataset")
     @cloud_edition_billing_resource_check("documents", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id):
         """Create document by upload file."""
         args = {}
@@ -175,8 +182,11 @@ class DocumentAddByFileApi(DatasetApiResource):

         if not dataset:
             raise ValueError("Dataset does not exist.")
-        if not dataset.indexing_technique and not args.get("indexing_technique"):
+
+        indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
+        if not indexing_technique:
             raise ValueError("indexing_technique is required.")
+        args["indexing_technique"] = indexing_technique

         # save file info
         file = request.files["file"]
@@ -206,12 +216,16 @@ class DocumentAddByFileApi(DatasetApiResource):
         knowledge_config = KnowledgeConfig(**args)
         DocumentService.document_create_args_validate(knowledge_config)

+        dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
+        if not knowledge_config.original_document_id and not dataset_process_rule and not knowledge_config.process_rule:
+            raise ValueError("process_rule is required.")
+
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
                 knowledge_config=knowledge_config,
                 account=dataset.created_by_account,
-                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
+                dataset_process_rule=dataset_process_rule,
                 created_from="api",
             )
         except ProviderTokenNotInitError as ex:
@@ -225,6 +239,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
     """Resource for update documents."""

     @cloud_edition_billing_resource_check("vector_space", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id, document_id):
         """Update document by upload file."""
         args = {}
@@ -295,6 +310,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):


 class DocumentDeleteApi(DatasetApiResource):
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def delete(self, tenant_id, dataset_id, document_id):
         """Delete document."""
         document_id = str(document_id)
```
```diff
@@ -1,9 +1,10 @@
 from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
 from controllers.service_api import api
-from controllers.service_api.wraps import DatasetApiResource
+from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check


 class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id):
         dataset_id_str = str(dataset_id)
```
```diff
@@ -3,7 +3,7 @@ from flask_restful import marshal, reqparse
 from werkzeug.exceptions import NotFound

 from controllers.service_api import api
-from controllers.service_api.wraps import DatasetApiResource
+from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
 from fields.dataset_fields import dataset_metadata_fields
 from services.dataset_service import DatasetService
 from services.entities.knowledge_entities.knowledge_entities import (
@@ -14,6 +14,7 @@ from services.metadata_service import MetadataService


 class DatasetMetadataCreateServiceApi(DatasetApiResource):
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id):
         parser = reqparse.RequestParser()
         parser.add_argument("type", type=str, required=True, nullable=True, location="json")
@@ -39,6 +40,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):


 class DatasetMetadataServiceApi(DatasetApiResource):
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def patch(self, tenant_id, dataset_id, metadata_id):
         parser = reqparse.RequestParser()
         parser.add_argument("name", type=str, required=True, nullable=True, location="json")
@@ -54,6 +56,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
         metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
         return marshal(metadata, dataset_metadata_fields), 200

+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def delete(self, tenant_id, dataset_id, metadata_id):
         dataset_id_str = str(dataset_id)
         metadata_id_str = str(metadata_id)
@@ -73,6 +76,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):


 class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id, action):
         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
@@ -88,6 +92,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):


 class DocumentMetadataEditServiceApi(DatasetApiResource):
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id):
         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
```
```diff
@@ -8,6 +8,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError
 from controllers.service_api.wraps import (
     DatasetApiResource,
     cloud_edition_billing_knowledge_limit_check,
+    cloud_edition_billing_rate_limit_check,
     cloud_edition_billing_resource_check,
 )
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
@@ -35,6 +36,7 @@ class SegmentApi(DatasetApiResource):

     @cloud_edition_billing_resource_check("vector_space", "dataset")
     @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id, document_id):
         """Create single segment."""
         # check dataset
@@ -139,6 +141,7 @@ class SegmentApi(DatasetApiResource):


 class DatasetSegmentApi(DatasetApiResource):
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def delete(self, tenant_id, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -162,6 +165,7 @@ class DatasetSegmentApi(DatasetApiResource):
         return 204

     @cloud_edition_billing_resource_check("vector_space", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -236,6 +240,7 @@ class ChildChunkApi(DatasetApiResource):

     @cloud_edition_billing_resource_check("vector_space", "dataset")
     @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id, document_id, segment_id):
         """Create child chunk."""
         # check dataset
@@ -332,6 +337,7 @@ class DatasetChildChunkApi(DatasetApiResource):
     """Resource for updating child chunks."""

     @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
         """Delete child chunk."""
         # check dataset
@@ -370,6 +376,7 @@ class DatasetChildChunkApi(DatasetApiResource):

     @cloud_edition_billing_resource_check("vector_space", "dataset")
     @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
         """Update child chunk."""
         # check dataset
```
```diff
@@ -15,4 +15,17 @@ api.add_resource(FileApi, "/files/upload")
 api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
 api.add_resource(RemoteFileUploadApi, "/remote-files/upload")

-from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow
+from . import (
+    app,
+    audio,
+    completion,
+    conversation,
+    feature,
+    forgot_password,
+    login,
+    message,
+    passport,
+    saved_message,
+    site,
+    workflow,
+)
```
```diff
@@ -10,6 +10,8 @@ from libs.passport import PassportService
 from models.model import App, AppMode
 from services.app_service import AppService
+from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService
+from services.webapp_auth_service import WebAppAuthService


 class AppParameterApi(WebApiResource):
@@ -46,10 +48,22 @@ class AppMeta(WebApiResource):
 class AppAccessMode(Resource):
     def get(self):
         parser = reqparse.RequestParser()
-        parser.add_argument("appId", type=str, required=True, location="args")
+        parser.add_argument("appId", type=str, required=False, location="args")
+        parser.add_argument("appCode", type=str, required=False, location="args")
         args = parser.parse_args()

-        app_id = args["appId"]
+        features = FeatureService.get_system_features()
+        if not features.webapp_auth.enabled:
+            return {"accessMode": "public"}
+
+        app_id = args.get("appId")
+        if args.get("appCode"):
+            app_code = args["appCode"]
+            app_id = AppService.get_app_id_by_code(app_code)
+
+        if not app_id:
+            raise ValueError("appId or appCode must be provided")

         res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)

         return {"accessMode": res.access_mode}
@@ -75,6 +89,10 @@ class AppWebAuthPermission(Resource):
         except Exception as e:
             pass

+        features = FeatureService.get_system_features()
+        if not features.webapp_auth.enabled:
+            return {"result": True}
+
         parser = reqparse.RequestParser()
         parser.add_argument("appId", type=str, required=True, location="args")
         args = parser.parse_args()
@@ -82,7 +100,9 @@ class AppWebAuthPermission(Resource):
         app_id = args["appId"]
         app_code = AppService.get_app_code_by_id(app_id)

-        res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
+        res = True
+        if WebAppAuthService.is_app_require_permission_check(app_id=app_id):
+            res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
         return {"result": res}
```
`api/controllers/web/forgot_password.py` (new file, 147 lines)

```diff
@@ -0,0 +1,147 @@
+import base64
+import secrets
+
+from flask import request
+from flask_restful import Resource, reqparse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from controllers.console.auth.error import (
+    EmailCodeError,
+    EmailPasswordResetLimitError,
+    InvalidEmailError,
+    InvalidTokenError,
+    PasswordMismatchError,
+)
+from controllers.console.error import AccountNotFound, EmailSendIpLimitError
+from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
+from controllers.web import api
+from extensions.ext_database import db
+from libs.helper import email, extract_remote_ip
+from libs.password import hash_password, valid_password
+from models.account import Account
+from services.account_service import AccountService
+
+
+class ForgotPasswordSendEmailApi(Resource):
+    @only_edition_enterprise
+    @setup_required
+    @email_password_login_enabled
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("email", type=email, required=True, location="json")
+        parser.add_argument("language", type=str, required=False, location="json")
+        args = parser.parse_args()
+
+        ip_address = extract_remote_ip(request)
+        if AccountService.is_email_send_ip_limit(ip_address):
+            raise EmailSendIpLimitError()
+
+        if args["language"] is not None and args["language"] == "zh-Hans":
+            language = "zh-Hans"
+        else:
+            language = "en-US"
+
+        with Session(db.engine) as session:
+            account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
+        token = None
+        if account is None:
+            raise AccountNotFound()
+        else:
+            token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
+
+        return {"result": "success", "data": token}
+
+
+class ForgotPasswordCheckApi(Resource):
+    @only_edition_enterprise
+    @setup_required
+    @email_password_login_enabled
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("email", type=str, required=True, location="json")
+        parser.add_argument("code", type=str, required=True, location="json")
+        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+
+        user_email = args["email"]
+
+        is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
+        if is_forgot_password_error_rate_limit:
+            raise EmailPasswordResetLimitError()
+
+        token_data = AccountService.get_reset_password_data(args["token"])
+        if token_data is None:
+            raise InvalidTokenError()
+
+        if user_email != token_data.get("email"):
+            raise InvalidEmailError()
+
+        if args["code"] != token_data.get("code"):
+            AccountService.add_forgot_password_error_rate_limit(args["email"])
+            raise EmailCodeError()
+
+        # Verified, revoke the first token
+        AccountService.revoke_reset_password_token(args["token"])
+
+        # Refresh token data by generating a new token
+        _, new_token = AccountService.generate_reset_password_token(
+            user_email, code=args["code"], additional_data={"phase": "reset"}
+        )
+
+        AccountService.reset_forgot_password_error_rate_limit(args["email"])
+        return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+
+
+class ForgotPasswordResetApi(Resource):
+    @only_edition_enterprise
+    @setup_required
+    @email_password_login_enabled
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
+        parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+
+        # Validate passwords match
+        if args["new_password"] != args["password_confirm"]:
+            raise PasswordMismatchError()
+
+        # Validate token and get reset data
+        reset_data = AccountService.get_reset_password_data(args["token"])
+        if not reset_data:
+            raise InvalidTokenError()
+        # Must use token in reset phase
+        if reset_data.get("phase", "") != "reset":
+            raise InvalidTokenError()
+
+        # Revoke token to prevent reuse
+        AccountService.revoke_reset_password_token(args["token"])
+
+        # Generate secure salt and hash password
+        salt = secrets.token_bytes(16)
+        password_hashed = hash_password(args["new_password"], salt)
+
+        email = reset_data.get("email", "")
+
+        with Session(db.engine) as session:
+            account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+
+            if account:
+                self._update_existing_account(account, password_hashed, salt, session)
+            else:
+                raise AccountNotFound()
+
+        return {"result": "success"}
+
+    def _update_existing_account(self, account, password_hashed, salt, session):
+        # Update existing account credentials
+        account.password = base64.b64encode(password_hashed).decode()
+        account.password_salt = base64.b64encode(salt).decode()
+        session.commit()
+
+
+api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
+api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
+api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
```
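The three endpoints form a phased flow: request a code, validate it (which revokes the first token and issues a second one stamped `phase: reset`), then reset the password with the second token only. A rough client-side walkthrough with `requests` (base URL and payload values are placeholders):

```python
import requests

BASE = "https://example.com/api"  # placeholder host

# 1. Request a reset code; the response carries the first token.
r1 = requests.post(f"{BASE}/forgot-password", json={"email": "user@example.com"})
token1 = r1.json()["data"]

# 2. Validate the emailed code; the response carries a second, reset-phase token.
r2 = requests.post(
    f"{BASE}/forgot-password/validity",
    json={"email": "user@example.com", "code": "123456", "token": token1},
)
token2 = r2.json()["token"]

# 3. Only the phase="reset" token is accepted here; token1 is already revoked.
requests.post(
    f"{BASE}/forgot-password/resets",
    json={"token": token2, "new_password": "N3w-Passw0rd", "password_confirm": "N3w-Passw0rd"},
)
```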
```diff
@@ -1,12 +1,11 @@
 from flask import request
 from flask_restful import Resource, reqparse
 from jwt import InvalidTokenError  # type: ignore
-from werkzeug.exceptions import BadRequest

 import services
 from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
 from controllers.console.error import AccountBannedError, AccountNotFound
-from controllers.console.wraps import setup_required
+from controllers.console.wraps import only_edition_enterprise, setup_required
 from controllers.web import api
 from libs.helper import email
 from libs.password import valid_password
 from services.account_service import AccountService
@@ -16,6 +15,8 @@ from services.webapp_auth_service import WebAppAuthService
 class LoginApi(Resource):
     """Resource for web app email/password login."""

     @setup_required
+    @only_edition_enterprise
     def post(self):
         """Authenticate user and login."""
         parser = reqparse.RequestParser()
@@ -23,10 +24,6 @@ class LoginApi(Resource):
         parser.add_argument("password", type=valid_password, required=True, location="json")
         args = parser.parse_args()

-        app_code = request.headers.get("X-App-Code")
-        if app_code is None:
-            raise BadRequest("X-App-Code header is missing.")
-
         try:
             account = WebAppAuthService.authenticate(args["email"], args["password"])
         except services.errors.account.AccountLoginError:
@@ -36,12 +33,8 @@ class LoginApi(Resource):
         except services.errors.account.AccountNotFoundError:
             raise AccountNotFound()

-        WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
-
-        end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code)
-
-        token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
-        return {"result": "success", "token": token}
+        token = WebAppAuthService.login(account=account)
+        return {"result": "success", "data": {"access_token": token}}


 # class LogoutApi(Resource):
@@ -56,6 +49,7 @@ class LoginApi(Resource):

 class EmailCodeLoginSendEmailApi(Resource):
     @setup_required
+    @only_edition_enterprise
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=email, required=True, location="json")
@@ -78,6 +72,7 @@ class EmailCodeLoginSendEmailApi(Resource):

 class EmailCodeLoginApi(Resource):
     @setup_required
+    @only_edition_enterprise
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=str, required=True, location="json")
@@ -86,9 +81,6 @@ class EmailCodeLoginApi(Resource):
         args = parser.parse_args()

         user_email = args["email"]
-        app_code = request.headers.get("X-App-Code")
-        if app_code is None:
-            raise BadRequest("X-App-Code header is missing.")

         token_data = WebAppAuthService.get_email_code_login_data(args["token"])
         if token_data is None:
@@ -105,16 +97,12 @@ class EmailCodeLoginApi(Resource):
         if not account:
             raise AccountNotFound()

-        WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
-
-        end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code)
-
-        token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
+        token = WebAppAuthService.login(account=account)
         AccountService.reset_login_error_rate_limit(args["email"])
-        return {"result": "success", "token": token}
+        return {"result": "success", "data": {"access_token": token}}


-# api.add_resource(LoginApi, "/login")
+api.add_resource(LoginApi, "/login")
 # api.add_resource(LogoutApi, "/logout")
-# api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
-# api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
+api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
+api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
```
@@ -1,9 +1,11 @@
import uuid
from datetime import UTC, datetime, timedelta

from flask import request
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized

from configs import dify_config
from controllers.web import api
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
@@ -11,6 +13,7 @@ from libs.passport import PassportService
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType


class PassportResource(Resource):
@@ -20,10 +23,19 @@ class PassportResource(Resource):
        system_features = FeatureService.get_system_features()
        app_code = request.headers.get("X-App-Code")
        user_id = request.args.get("user_id")
        web_app_access_token = request.args.get("web_app_access_token")

        if app_code is None:
            raise Unauthorized("X-App-Code header is missing.")

        # exchange token for enterprise logined web user
        enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token)
        if enterprise_user_decoded:
            # a web user has already logged in, exchange a token for this app without redirecting to the login page
            return exchange_token_for_existing_web_user(
                app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
            )

        if system_features.webapp_auth.enabled:
            app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
            if not app_settings or not app_settings.access_mode == "public":
@@ -84,6 +96,128 @@ class PassportResource(Resource):
api.add_resource(PassportResource, "/passport")


def decode_enterprise_webapp_user_id(jwt_token: str | None):
    """
    Decode the enterprise user session from the Authorization header.
    """
    if not jwt_token:
        return None

    decoded = PassportService().verify(jwt_token)
    source = decoded.get("token_source")
    if not source or source != "webapp_login_token":
        raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
    return decoded


def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
    """
    Exchange a token for an existing web user session.
    """
    user_id = enterprise_user_decoded.get("user_id")
    end_user_id = enterprise_user_decoded.get("end_user_id")
    session_id = enterprise_user_decoded.get("session_id")
    user_auth_type = enterprise_user_decoded.get("auth_type")
    if not user_auth_type:
        raise Unauthorized("Missing auth_type in the token.")

    site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
    if not site:
        raise NotFound()

    app_model = db.session.query(App).filter(App.id == site.app_id).first()
    if not app_model or app_model.status != "normal" or not app_model.enable_site:
        raise NotFound()

    app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)

    if app_auth_type == WebAppAuthType.PUBLIC:
        return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
    elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
        raise WebAppAuthRequiredError("Please login as external user.")
    elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
        raise WebAppAuthRequiredError("Please login as internal user.")

    end_user = None
    if end_user_id:
        end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
    if session_id:
        end_user = (
            db.session.query(EndUser)
            .filter(
                EndUser.session_id == session_id,
                EndUser.tenant_id == app_model.tenant_id,
                EndUser.app_id == app_model.id,
            )
            .first()
        )
    if not end_user:
        if not session_id:
            raise NotFound("Missing session_id for existing web user.")
        end_user = EndUser(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type="browser",
            is_anonymous=True,
            session_id=session_id,
        )
        db.session.add(end_user)
        db.session.commit()
    exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
    exp = int(exp_dt.timestamp())
    payload = {
        "iss": site.id,
        "sub": "Web API Passport",
        "app_id": site.app_id,
        "app_code": site.code,
        "user_id": user_id,
        "end_user_id": end_user.id,
        "auth_type": user_auth_type,
        "granted_at": int(datetime.now(UTC).timestamp()),
        "token_source": "webapp",
        "exp": exp,
    }
    token: str = PassportService().issue(payload)
    return {
        "access_token": token,
    }


def _exchange_for_public_app_token(app_model, site, token_decoded):
    user_id = token_decoded.get("user_id")
    end_user = None
    if user_id:
        end_user = (
            db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
        )

    if not end_user:
        end_user = EndUser(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type="browser",
            is_anonymous=True,
            session_id=generate_session_id(),
        )

        db.session.add(end_user)
        db.session.commit()

    payload = {
        "iss": site.app_id,
        "sub": "Web API Passport",
        "app_id": site.app_id,
        "app_code": site.code,
        "end_user_id": end_user.id,
    }

    tk = PassportService().issue(payload)

    return {
        "access_token": tk,
    }


def generate_session_id():
    """
    Generate a unique session ID.
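The diff truncates the body of generate_session_id just above. A minimal sketch of such a helper, assuming it simply wraps uuid4 (illustrative only, not the committed body; the file already imports uuid):

def generate_session_id():
    """Generate a unique session ID (illustrative sketch)."""
    return str(uuid.uuid4())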
@@ -1,3 +1,4 @@
from datetime import UTC, datetime
from functools import wraps

from flask import request
@@ -8,8 +9,9 @@ from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequire
from extensions.ext_database import db
from libs.passport import PassportService
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService


def validate_jwt_token(view=None):
@@ -45,7 +47,8 @@ def decode_jwt_token():
            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
        decoded = PassportService().verify(tk)
        app_code = decoded.get("app_code")
        app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first()
        app_id = decoded.get("app_id")
        app_model = db.session.query(App).filter(App.id == app_id).first()
        site = db.session.query(Site).filter(Site.code == app_code).first()
        if not app_model:
            raise NotFound()
@@ -53,23 +56,30 @@ def decode_jwt_token():
            raise BadRequest("Site URL is no longer valid.")
        if app_model.enable_site is False:
            raise BadRequest("Site is disabled.")
        end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
        end_user_id = decoded.get("end_user_id")
        end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
        if not end_user:
            raise NotFound()

        # for enterprise webapp auth
        app_web_auth_enabled = False
        webapp_settings = None
        if system_features.webapp_auth.enabled:
            app_web_auth_enabled = (
                EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
            )
            webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
            if not webapp_settings:
                raise NotFound("Web app settings not found.")
            app_web_auth_enabled = webapp_settings.access_mode != "public"

        _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
        _validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
        _validate_user_accessibility(
            decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings
        )

        return app_model, end_user
    except Unauthorized as e:
        if system_features.webapp_auth.enabled:
            if not app_code:
                raise Unauthorized("Please re-login to access the web app.")
            app_web_auth_enabled = (
                EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public"
            )
@@ -95,15 +105,41 @@ def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_au
        raise Unauthorized("webapp token expired.")


def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
def _validate_user_accessibility(
    decoded,
    app_code,
    app_web_auth_enabled: bool,
    system_webapp_auth_enabled: bool,
    webapp_settings: WebAppSettings | None,
):
    if system_webapp_auth_enabled and app_web_auth_enabled:
        # Check if the user is allowed to access the web app
        user_id = decoded.get("user_id")
        if not user_id:
            raise WebAppAuthRequiredError()

        if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
            raise WebAppAuthAccessDeniedError()
        if not webapp_settings:
            raise WebAppAuthRequiredError("Web app settings not found.")

        if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode):
            if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
                raise WebAppAuthAccessDeniedError()

        auth_type = decoded.get("auth_type")
        granted_at = decoded.get("granted_at")
        if not auth_type:
            raise WebAppAuthAccessDeniedError("Missing auth_type in the token.")
        if not granted_at:
            raise WebAppAuthAccessDeniedError("Missing granted_at in the token.")
        # check if sso has been updated
        if auth_type == "external":
            last_update_time = EnterpriseService.get_app_sso_settings_last_update_time()
            if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
                raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")
        elif auth_type == "internal":
            last_update_time = EnterpriseService.get_workspace_sso_settings_last_update_time()
            if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
                raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")


class WebApiResource(Resource):
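A hedged sketch of how a handler built on WebApiResource consumes this validation: once decode_jwt_token succeeds, the resolved app_model and end_user are handed to the view. The resource name and route here are illustrative, not from the diff:

class ExampleStatusApi(WebApiResource):
    # validate_jwt_token runs before the handler, so app_model and end_user
    # arrive pre-validated.
    def get(self, app_model, end_user):
        return {"app_id": app_model.id, "end_user_id": end_user.id}

api.add_resource(ExampleStatusApi, "/example-status")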
@@ -138,15 +138,12 @@ class DatasetConfigManager:
        if not config.get("dataset_configs"):
            config["dataset_configs"] = {"retrieval_model": "single"}

        if not isinstance(config["dataset_configs"], dict):
            raise ValueError("dataset_configs must be of object type")

        if not config["dataset_configs"].get("datasets"):
            config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}

        if not isinstance(config["dataset_configs"], dict):
            raise ValueError("dataset_configs must be of object type")

        if not isinstance(config["dataset_configs"], dict):
            raise ValueError("dataset_configs must be of object type")

        need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
            "datasets", {}
        ).get("datasets")
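For reference, a standalone sketch of the default shape this validation produces for an empty config (plain dict manipulation, runnable as-is):

config: dict = {}
if not config.get("dataset_configs"):
    config["dataset_configs"] = {"retrieval_model": "single"}
if not config["dataset_configs"].get("datasets"):
    config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
assert config == {
    "dataset_configs": {
        "retrieval_model": "single",
        "datasets": {"strategy": "router", "datasets": []},
    }
}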
@@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, copy_current_request_context, current_app, has_request_context
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker

@@ -31,6 +31,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
@@ -366,6 +367,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        :param user: account or end user
        :param invoke_from: invoke from source
        :param application_generate_entity: application generate entity
        :param workflow_execution_repository: repository for workflow execution
        :param workflow_node_execution_repository: repository for workflow node execution
        :param conversation: conversation
        :param stream: is stream
@@ -399,20 +401,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        # new thread with request context and contextvars
        context = contextvars.copy_context()

        @copy_current_request_context
        def worker_with_context():
            # Run the worker within the copied context
            return context.run(
                self._generate_worker,
                flask_app=current_app._get_current_object(),  # type: ignore
                application_generate_entity=application_generate_entity,
                queue_manager=queue_manager,
                conversation_id=conversation.id,
                message_id=message.id,
                context=context,
            )

        worker_thread = threading.Thread(target=worker_with_context)
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "conversation_id": conversation.id,
                "message_id": message.id,
                "context": context,
            },
        )

        worker_thread.start()

@@ -449,24 +448,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        :param message_id: message ID
        :return:
        """
        for var, val in context.items():
            var.set(val)

        # FIXME(-LAN-): Save current user before entering new app context
        from flask import g

        saved_user = None
        if has_request_context() and hasattr(g, "_login_user"):
            saved_user = g._login_user

        with flask_app.app_context():
        with preserve_flask_contexts(flask_app, context_vars=context):
            try:
                # Restore user in new app context
                if saved_user is not None:
                    from flask import g

                    g._login_user = saved_user

                # get conversation and message
                conversation = self._get_conversation(conversation_id)
                message = self._get_message(message_id)
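The refactor above replaces Flask's copy_current_request_context with an explicit contextvars.Context handed to the worker thread (and restored inside preserve_flask_contexts). A minimal standalone sketch of that propagation pattern, with no Flask involved (names are illustrative):

import contextvars
import threading

request_id = contextvars.ContextVar("request_id", default=None)

def worker(context: contextvars.Context):
    # Re-enter the snapshot taken on the request thread.
    for var, val in context.items():
        var.set(val)
    print("worker sees:", request_id.get())  # -> "req-42"

request_id.set("req-42")
ctx = contextvars.copy_context()
worker_thread = threading.Thread(target=worker, kwargs={"context": ctx})
worker_thread.start()
worker_thread.join()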
@@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload

from flask import Flask, copy_current_request_context, current_app, has_request_context
from flask import Flask, current_app
from pydantic import ValidationError

from configs import dify_config
@@ -23,6 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
@@ -182,20 +183,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
        # new thread with request context and contextvars
        context = contextvars.copy_context()

        @copy_current_request_context
        def worker_with_context():
            # Run the worker within the copied context
            return context.run(
                self._generate_worker,
                flask_app=current_app._get_current_object(),  # type: ignore
                context=context,
                application_generate_entity=application_generate_entity,
                queue_manager=queue_manager,
                conversation_id=conversation.id,
                message_id=message.id,
            )

        worker_thread = threading.Thread(target=worker_with_context)
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "context": context,
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "conversation_id": conversation.id,
                "message_id": message.id,
            },
        )

        worker_thread.start()

@@ -229,24 +227,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
        :param message_id: message ID
        :return:
        """
        for var, val in context.items():
            var.set(val)

        # FIXME(-LAN-): Save current user before entering new app context
        from flask import g

        saved_user = None
        if has_request_context() and hasattr(g, "_login_user"):
            saved_user = g._login_user

        with flask_app.app_context():
        with preserve_flask_contexts(flask_app, context_vars=context):
            try:
                # Restore user in new app context
                if saved_user is not None:
                    from flask import g

                    g._login_user = saved_user

                # get conversation and message
                conversation = self._get_conversation(conversation_id)
                message = self._get_message(message_id)
@@ -1,3 +1,4 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
@@ -33,6 +34,8 @@ from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING:
    from core.file.models import File

_logger = logging.getLogger(__name__)


class AppRunner:
    def get_pre_calculate_rest_tokens(
@@ -298,7 +301,7 @@ class AppRunner:
        )

    def _handle_invoke_result_stream(
        self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
        self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
    ) -> None:
        """
        Handle invoke result
@@ -317,18 +320,28 @@ class AppRunner:
            else:
                queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)

            text += result.delta.message.content
            message = result.delta.message
            if isinstance(message.content, str):
                text += message.content
            elif isinstance(message.content, list):
                for content in message.content:
                    if not isinstance(content, str):
                        # TODO(QuantumGhost): Add multimodal output support for easy ui.
                        _logger.warning("received multimodal output, type=%s", type(content))
                        text += content.data
                    else:
                        text += content  # failback to str

            if not model:
                model = result.model

            if not prompt_messages:
                prompt_messages = result.prompt_messages
                prompt_messages = list(result.prompt_messages)

            if result.delta.usage:
                usage = result.delta.usage

        if not usage:
        if usage is None:
            usage = LLMUsage.empty_usage()

        llm_result = LLMResult(
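A standalone sketch of the str-or-list content handling introduced in _handle_invoke_result_stream, with a hypothetical stand-in for TextPromptMessageContent (runnable as-is):

from dataclasses import dataclass

@dataclass
class _TextContent:  # stand-in for TextPromptMessageContent
    data: str

def flatten(content) -> str:
    if isinstance(content, str):
        return content
    parts = []
    for item in content:
        # Non-str items are assumed to carry their text in `.data`.
        parts.append(item if isinstance(item, str) else item.data)
    return "".join(parts)

assert flatten("hi") == "hi"
assert flatten(["a", _TextContent(data="b")]) == "ab"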
@@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, copy_current_request_context, current_app, has_request_context
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker

@@ -29,6 +29,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom

@@ -194,6 +195,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        :param user: account or end user
        :param application_generate_entity: application generate entity
        :param invoke_from: invoke from source
        :param workflow_execution_repository: repository for workflow execution
        :param workflow_node_execution_repository: repository for workflow node execution
        :param streaming: is stream
        :param workflow_thread_pool_id: workflow thread pool id
@@ -209,19 +211,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
        # new thread with request context and contextvars
        context = contextvars.copy_context()

        @copy_current_request_context
        def worker_with_context():
            # Run the worker within the copied context
            return context.run(
                self._generate_worker,
                flask_app=current_app._get_current_object(),  # type: ignore
                application_generate_entity=application_generate_entity,
                queue_manager=queue_manager,
                context=context,
                workflow_thread_pool_id=workflow_thread_pool_id,
            )

        worker_thread = threading.Thread(target=worker_with_context)
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "context": context,
                "workflow_thread_pool_id": workflow_thread_pool_id,
            },
        )

        worker_thread.start()

@@ -408,24 +407,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
        :param workflow_thread_pool_id: workflow thread pool id
        :return:
        """
        for var, val in context.items():
            var.set(val)

        # FIXME(-LAN-): Save current user before entering new app context
        from flask import g

        saved_user = None
        if has_request_context() and hasattr(g, "_login_user"):
            saved_user = g._login_user

        with flask_app.app_context():
        with preserve_flask_contexts(flask_app, context_vars=context):
            try:
                # Restore user in new app context
                if saved_user is not None:
                    from flask import g

                    g._login_user = saved_user

                # workflow app
                runner = WorkflowAppRunner(
                    application_generate_entity=application_generate_entity,
@@ -48,6 +48,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    TextPromptMessageContent,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName
@@ -309,6 +310,23 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue
                if isinstance(chunk.delta.message.content, list):
                    delta_text = ""
                    for content in chunk.delta.message.content:
                        logger.debug(
                            "The content type %s in LLM chunk delta message content.: %r", type(content), content
                        )
                        if isinstance(content, TextPromptMessageContent):
                            delta_text += content.data
                        elif isinstance(content, str):
                            delta_text += content  # failback to str
                        else:
                            logger.warning(
                                "Unsupported content type %s in LLM chunk delta message content.: %r",
                                type(content),
                                content,
                            )
                            continue

                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages
@@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel):
    status: ModelStatus
    load_balancing_enabled: bool = False

    def raise_for_status(self) -> None:
        """
        Check model status and raise ValueError if not active.

        :raises ValueError: When model status is not active, with a descriptive message
        """
        if self.status == ModelStatus.ACTIVE:
            return

        error_messages = {
            ModelStatus.NO_CONFIGURE: "Model is not configured",
            ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
            ModelStatus.NO_PERMISSION: "No permission to use this model",
            ModelStatus.DISABLED: "Model is disabled",
        }

        if self.status in error_messages:
            raise ValueError(error_messages[self.status])


class ModelWithProviderEntity(ProviderModelWithStatusEntity):
    """
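A hedged sketch of the raise-or-return pattern raise_for_status encapsulates, using a stand-in enum rather than the real ModelStatus (runnable as-is):

from enum import Enum

class Status(Enum):  # stand-in for ModelStatus
    ACTIVE = "active"
    DISABLED = "disabled"

def raise_for_status(status: Status) -> None:
    if status == Status.ACTIVE:
        return
    messages = {Status.DISABLED: "Model is disabled"}
    if status in messages:
        raise ValueError(messages[status])

raise_for_status(Status.ACTIVE)  # returns silently
try:
    raise_for_status(Status.DISABLED)
except ValueError as exc:
    assert str(exc) == "Model is disabled"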
@@ -41,45 +41,53 @@ class Extensible:
        extensions = []
        position_map: dict[str, int] = {}

        # get the path of the current class
        current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
        current_dir_path = os.path.dirname(current_path)
        # Get the package name from the module path
        package_name = ".".join(cls.__module__.split(".")[:-1])

        # traverse subdirectories
        for subdir_name in os.listdir(current_dir_path):
            if subdir_name.startswith("__"):
                continue
        try:
            # Get package directory path
            package_spec = importlib.util.find_spec(package_name)
            if not package_spec or not package_spec.origin:
                raise ImportError(f"Could not find package {package_name}")

            subdir_path = os.path.join(current_dir_path, subdir_name)
            extension_name = subdir_name
            if os.path.isdir(subdir_path):
            package_dir = os.path.dirname(package_spec.origin)

            # Traverse subdirectories
            for subdir_name in os.listdir(package_dir):
                if subdir_name.startswith("__"):
                    continue

                subdir_path = os.path.join(package_dir, subdir_name)
                if not os.path.isdir(subdir_path):
                    continue

                extension_name = subdir_name
                file_names = os.listdir(subdir_path)

                # is builtin extension, builtin extension
                # in the front-end page and business logic, there are special treatments.
                # Check for extension module file
                if (extension_name + ".py") not in file_names:
                    logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
                    continue

                # Check for builtin flag and position
                builtin = False
                # default position is 0 can not be None for sort_to_dict_by_position_map
                position = 0
                if "__builtin__" in file_names:
                    builtin = True

                    builtin_file_path = os.path.join(subdir_path, "__builtin__")
                    if os.path.exists(builtin_file_path):
                        position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
                    position_map[extension_name] = position

                if (extension_name + ".py") not in file_names:
                    logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
                    continue

                # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
                py_path = os.path.join(subdir_path, extension_name + ".py")
                spec = importlib.util.spec_from_file_location(extension_name, py_path)
                # Import the extension module
                module_name = f"{package_name}.{extension_name}.{extension_name}"
                spec = importlib.util.find_spec(module_name)
                if not spec or not spec.loader:
                    raise Exception(f"Failed to load module {extension_name} from {py_path}")
                    raise ImportError(f"Failed to load module {module_name}")
                mod = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(mod)

                # Find extension class
                extension_class = None
                for name, obj in vars(mod).items():
                    if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
@@ -87,21 +95,21 @@ class Extensible:
                        break

                if not extension_class:
                    logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
                    logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
                    continue

                # Load schema if not builtin
                json_data: dict[str, Any] = {}
                if not builtin:
                    if "schema.json" not in file_names:
                    json_path = os.path.join(subdir_path, "schema.json")
                    if not os.path.exists(json_path):
                        logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
                        continue

                    json_path = os.path.join(subdir_path, "schema.json")
                    json_data = {}
                    if os.path.exists(json_path):
                        with open(json_path, encoding="utf-8") as f:
                            json_data = json.load(f)
                    with open(json_path, encoding="utf-8") as f:
                        json_data = json.load(f)

                # Create extension
                extensions.append(
                    ModuleExtension(
                        extension_class=extension_class,
@@ -113,6 +121,11 @@ class Extensible:
                    )
                )

        except Exception as e:
            logging.exception("Error scanning extensions")
            raise

        # Sort extensions by position
        sorted_extensions = sort_to_dict_by_position_map(
            position_map=position_map, data=extensions, name_func=lambda x: x.name
        )
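The Extensible refactor above swaps path-based spec_from_file_location for package-relative importlib.util.find_spec. A minimal standalone sketch of that loading pattern (the module name is just any importable dotted path):

import importlib.util

module_name = "json.tool"  # illustrative dotted path
spec = importlib.util.find_spec(module_name)
if not spec or not spec.loader:
    raise ImportError(f"Failed to load module {module_name}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
assert mod.__name__ == "json.tool"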
@@ -15,6 +15,7 @@ from core.helper.code_executor.python3.python3_transformer import Python3Templat
from core.helper.code_executor.template_transformer import TemplateTransformer

logger = logging.getLogger(__name__)
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))


class CodeExecutionError(Exception):
@@ -64,7 +65,7 @@ class CodeExecutor:
        :param code: code
        :return:
        """
        url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"
        url = code_execution_endpoint_url / "v1" / "sandbox" / "run"

        headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}
@@ -7,29 +7,28 @@ from configs import dify_config
from core.helper.download import download_with_size_limit
from core.plugin.entities.marketplace import MarketplacePluginDeclaration

marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))

def get_plugin_pkg_url(plugin_unique_identifier: str):
    return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query(
        unique_identifier=plugin_unique_identifier
    )

def get_plugin_pkg_url(plugin_unique_identifier: str) -> str:
    return str((marketplace_api_url / "api/v1/plugins/download").with_query(unique_identifier=plugin_unique_identifier))


def download_plugin_pkg(plugin_unique_identifier: str):
    url = str(get_plugin_pkg_url(plugin_unique_identifier))
    return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
    return download_with_size_limit(get_plugin_pkg_url(plugin_unique_identifier), dify_config.PLUGIN_MAX_PACKAGE_SIZE)


def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
    if len(plugin_ids) == 0:
        return []

    url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch")
    url = str(marketplace_api_url / "api/v1/plugins/batch")
    response = requests.post(url, json={"plugin_ids": plugin_ids})
    response.raise_for_status()
    return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]


def record_install_plugin_event(plugin_unique_identifier: str):
    url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count")
    url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
    response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
    response.raise_for_status()
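The marketplace helpers above hoist the yarl base URL to module level; a small sketch of the composition pattern (the base URL here is a placeholder):

from yarl import URL

base = URL("https://marketplace.example.com")  # placeholder base
url = (base / "api/v1/plugins/download").with_query(unique_identifier="abc123")
assert str(url) == (
    "https://marketplace.example.com/api/v1/plugins/download?unique_identifier=abc123"
)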
@@ -1,5 +1,5 @@
import logging
import random
import secrets
from typing import cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -38,7 +38,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
    if len(text_chunks) == 0:
        return True

    text_chunk = random.choice(text_chunks)
    text_chunk = secrets.choice(text_chunks)

    try:
        model_provider_factory = ModelProviderFactory(tenant_id)
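The one-line change above swaps random.choice for secrets.choice, which draws from the OS CSPRNG; a trivial sketch:

import secrets

text_chunks = ["chunk-a", "chunk-b", "chunk-c"]
text_chunk = secrets.choice(text_chunks)  # cryptographically strong selection
assert text_chunk in text_chunks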
@@ -542,8 +542,6 @@ class LBModelManager:

            return config

        return None

    def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
        """
        Cooldown model load balancing config
@@ -160,6 +160,10 @@ class ProviderModel(BaseModel):
    deprecated: bool = False
    model_config = ConfigDict(protected_namespaces=())

    @property
    def support_structure_output(self) -> bool:
        return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features


class ParameterRule(BaseModel):
    """
@@ -98,6 +98,7 @@ class WeaveConfig(BaseTracingConfig):
    entity: str | None = None
    project: str
    endpoint: str = "https://trace.wandb.ai"
    host: str | None = None

    @field_validator("endpoint")
    @classmethod
@@ -109,6 +110,14 @@ class WeaveConfig(BaseTracingConfig):

        return v

    @field_validator("host")
    @classmethod
    def validate_host(cls, v, info: ValidationInfo):
        if v is not None and v != "":
            if not v.startswith(("https://", "http://")):
                raise ValueError("host must start with https:// or http://")
        return v


OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
@@ -81,7 +81,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
        return {
            "config_class": WeaveConfig,
            "secret_keys": ["api_key"],
            "other_keys": ["project", "entity", "endpoint"],
            "other_keys": ["project", "entity", "endpoint", "host"],
            "trace_instance": WeaveDataTrace,
        }

@@ -251,7 +251,7 @@ class OpsTraceManager:
            provider_config_map[tracing_provider]["trace_instance"],
            provider_config_map[tracing_provider]["config_class"],
        )
        decrypt_trace_config_key = str(decrypt_trace_config)
        decrypt_trace_config_key = json.dumps(decrypt_trace_config, sort_keys=True)
        tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
        if tracing_instance is None:
            # create new tracing_instance and update the cache if it absent
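The cache-key change above matters because str() of a dict is sensitive to key insertion order, while json.dumps(..., sort_keys=True) is not (assuming the decrypted trace config serializes to a dict):

import json

a = {"project": "p", "api_key": "k"}
b = {"api_key": "k", "project": "p"}
assert str(a) != str(b)  # order-sensitive: two cache keys for the same config
assert json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True)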
@@ -40,9 +40,14 @@ class WeaveDataTrace(BaseTraceInstance):
        self.weave_api_key = weave_config.api_key
        self.project_name = weave_config.project
        self.entity = weave_config.entity
        self.host = weave_config.host

        # Login with API key first, including host if provided
        if self.host:
            login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
        else:
            login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)

        # Login with API key first
        login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
        if not login_status:
            logger.error("Failed to login to Weights & Biases with the provided API key")
            raise ValueError("Weave login failed")
@@ -386,7 +391,11 @@ class WeaveDataTrace(BaseTraceInstance):

    def api_check(self):
        try:
            login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
            if self.host:
                login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
            else:
                login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)

            if not login_status:
                raise ValueError("Weave login failed")
            else:
@@ -11,14 +11,12 @@ class BaseBackwardsInvocation:
            try:
                for chunk in response:
                    if isinstance(chunk, BaseModel | dict):
                        yield BaseBackwardsInvocationResponse(data=chunk).model_dump_json().encode() + b"\n\n"
                    elif isinstance(chunk, str):
                        yield f"event: {chunk}\n\n".encode()
                        yield BaseBackwardsInvocationResponse(data=chunk).model_dump_json().encode()
            except Exception as e:
                error_message = BaseBackwardsInvocationResponse(error=str(e)).model_dump_json()
                yield f"{error_message}\n\n".encode()
                yield error_message.encode()
        else:
            yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode() + b"\n\n"
            yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode()


T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel)
@@ -21,7 +21,7 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm import llm_utils
from models.account import Tenant


@@ -55,7 +55,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
            def handle() -> Generator[LLMResultChunk, None, None]:
                for chunk in response:
                    if chunk.delta.usage:
                        LLMNode.deduct_llm_quota(
                        llm_utils.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    chunk.prompt_messages = []
@@ -64,7 +64,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
            return handle()
        else:
            if response.usage:
                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)

            def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
                yield LLMResultChunk(
@@ -156,9 +156,23 @@ class PluginInstallTaskStartResponse(BaseModel):
    task_id: str = Field(description="The ID of the install task.")


class PluginUploadResponse(BaseModel):
class PluginVerification(BaseModel):
    """
    Verification of the plugin.
    """

    class AuthorizedCategory(StrEnum):
        Langgenius = "langgenius"
        Partner = "partner"
        Community = "community"

    authorized_category: AuthorizedCategory = Field(description="The authorized category of the plugin.")


class PluginDecodeResponse(BaseModel):
    unique_identifier: str = Field(description="The unique identifier of the plugin.")
    manifest: PluginDeclaration
    verification: Optional[PluginVerification] = Field(default=None, description="Basic verification information")


class PluginOAuthAuthorizationUrlResponse(BaseModel):
@@ -31,8 +31,7 @@ from core.plugin.impl.exc import (
    PluginUniqueIdentifierError,
)

plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_DAEMON_URL
plugin_daemon_inner_api_key = dify_config.PLUGIN_DAEMON_KEY
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))

T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))

@@ -53,9 +52,9 @@ class BasePluginClient:
        """
        Make a request to the plugin daemon inner API.
        """
        url = URL(str(plugin_daemon_inner_api_baseurl)) / path
        url = plugin_daemon_inner_api_baseurl / path
        headers = headers or {}
        headers["X-Api-Key"] = plugin_daemon_inner_api_key
        headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
        headers["Accept-Encoding"] = "gzip, deflate, br"

        if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
@@ -10,10 +10,10 @@ from core.plugin.entities.plugin import (
    PluginInstallationSource,
)
from core.plugin.entities.plugin_daemon import (
    PluginDecodeResponse,
    PluginInstallTask,
    PluginInstallTaskStartResponse,
    PluginListResponse,
    PluginUploadResponse,
)
from core.plugin.impl.base import BasePluginClient

@@ -53,7 +53,7 @@ class PluginInstaller(BasePluginClient):
        tenant_id: str,
        pkg: bytes,
        verify_signature: bool = False,
    ) -> PluginUploadResponse:
    ) -> PluginDecodeResponse:
        """
        Upload a plugin package and return the plugin unique identifier.
        """
@@ -68,7 +68,7 @@ class PluginInstaller(BasePluginClient):
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/management/install/upload/package",
            PluginUploadResponse,
            PluginDecodeResponse,
            files=body,
            data=data,
        )
@@ -176,6 +176,18 @@ class PluginInstaller(BasePluginClient):
            params={"plugin_unique_identifier": plugin_unique_identifier},
        )

    def decode_plugin_from_identifier(self, tenant_id: str, plugin_unique_identifier: str) -> PluginDecodeResponse:
        """
        Decode a plugin from an identifier.
        """
        return self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/decode/from_identifier",
            PluginDecodeResponse,
            data={"plugin_unique_identifier": plugin_unique_identifier},
            headers={"Content-Type": "application/json"},
        )

    def fetch_plugin_installation_by_ids(
        self, tenant_id: str, plugin_ids: Sequence[str]
    ) -> Sequence[PluginInstallation]:
@@ -3,7 +3,9 @@ from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional, cast

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
@@ -393,19 +395,13 @@ class ProviderManager:

    @staticmethod
    def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
        """
        Get all provider records of the workspace.

        :param tenant_id: workspace id
        :return:
        """
        providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()

        provider_name_to_provider_records_dict = defaultdict(list)
        for provider in providers:
            # TODO: Use provider name with prefix after the data migration
            provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
            providers = session.scalars(stmt)
            for provider in providers:
                # Use provider name with prefix after the data migration
                provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
        return provider_name_to_provider_records_dict

    @staticmethod
@@ -416,17 +412,12 @@ class ProviderManager:
        :param tenant_id: workspace id
        :return:
        """
        # Get all provider model records of the workspace
        provider_models = (
            db.session.query(ProviderModel)
            .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
            .all()
        )

        provider_name_to_provider_model_records_dict = defaultdict(list)
        for provider_model in provider_models:
            provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
            provider_models = session.scalars(stmt)
            for provider_model in provider_models:
                provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
        return provider_name_to_provider_model_records_dict

    @staticmethod
@@ -437,17 +428,14 @@ class ProviderManager:
        :param tenant_id: workspace id
        :return:
        """
        preferred_provider_types = (
            db.session.query(TenantPreferredModelProvider)
            .filter(TenantPreferredModelProvider.tenant_id == tenant_id)
            .all()
        )

        provider_name_to_preferred_provider_type_records_dict = {
            preferred_provider_type.provider_name: preferred_provider_type
            for preferred_provider_type in preferred_provider_types
        }

        provider_name_to_preferred_provider_type_records_dict = {}
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
            preferred_provider_types = session.scalars(stmt)
            provider_name_to_preferred_provider_type_records_dict = {
                preferred_provider_type.provider_name: preferred_provider_type
                for preferred_provider_type in preferred_provider_types
            }
        return provider_name_to_preferred_provider_type_records_dict

    @staticmethod
@@ -458,18 +446,14 @@ class ProviderManager:
        :param tenant_id: workspace id
        :return:
        """
        provider_model_settings = (
            db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
        )

        provider_name_to_provider_model_settings_dict = defaultdict(list)
        for provider_model_setting in provider_model_settings:
            (
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
            provider_model_settings = session.scalars(stmt)
            for provider_model_setting in provider_model_settings:
                provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
                    provider_model_setting
                )
            )

        return provider_name_to_provider_model_settings_dict

    @staticmethod
@@ -492,15 +476,14 @@ class ProviderManager:
        if not model_load_balancing_enabled:
            return {}

        provider_load_balancing_configs = (
            db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
        )

        provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
        for provider_load_balancing_config in provider_load_balancing_configs:
            provider_name_to_provider_load_balancing_model_configs_dict[
                provider_load_balancing_config.provider_name
            ].append(provider_load_balancing_config)
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
            provider_load_balancing_configs = session.scalars(stmt)
            for provider_load_balancing_config in provider_load_balancing_configs:
                provider_name_to_provider_load_balancing_model_configs_dict[
                    provider_load_balancing_config.provider_name
                ].append(provider_load_balancing_config)

        return provider_name_to_provider_load_balancing_model_configs_dict

@@ -626,10 +609,9 @@ class ProviderManager:
        if not cached_provider_credentials:
            try:
                # fix origin data
                if (
                    custom_provider_record.encrypted_config
                    and not custom_provider_record.encrypted_config.startswith("{")
                ):
                if custom_provider_record.encrypted_config is None:
                    raise ValueError("No credentials found")
                if not custom_provider_record.encrypted_config.startswith("{"):
                    provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
                else:
                    provider_credentials = json.loads(custom_provider_record.encrypted_config)
@@ -733,7 +715,7 @@ class ProviderManager:
            return SystemConfiguration(enabled=False)

        # Convert provider_records to dict
        quota_type_to_provider_records_dict = {}
        quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
        for provider_record in provider_records:
            if provider_record.provider_type != ProviderType.SYSTEM.value:
                continue
@@ -758,6 +740,11 @@ class ProviderManager:
            else:
                provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]

                if provider_record.quota_used is None:
                    raise ValueError("quota_used is None")
                if provider_record.quota_limit is None:
                    raise ValueError("quota_limit is None")

                quota_configuration = QuotaConfiguration(
                    quota_type=provider_quota.quota_type,
                    quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
@@ -791,10 +778,9 @@ class ProviderManager:
        cached_provider_credentials = provider_credentials_cache.get()

        if not cached_provider_credentials:
            try:
                provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
            except JSONDecodeError:
                provider_credentials = {}
            provider_credentials: dict[str, Any] = {}
            if provider_records and provider_records[0].encrypted_config:
                provider_credentials = json.loads(provider_records[0].encrypted_config)

            # Get provider credential secret variables
            provider_credential_secret_variables = self._extract_secret_variables(
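The recurring change in this file replaces legacy db.session.query(...).all() chains with 2.0-style select() run inside a short-lived Session. A self-contained sketch of the pattern against an in-memory SQLite stand-in (the model here is illustrative, not the real Provider):

from sqlalchemy import Boolean, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Provider(Base):  # minimal stand-in for the real model
    __tablename__ = "providers"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))
    is_valid: Mapped[bool] = mapped_column(Boolean, default=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine, expire_on_commit=False) as session:
    session.add(Provider(tenant_id="t1", is_valid=True))
    session.commit()
    stmt = select(Provider).where(Provider.tenant_id == "t1", Provider.is_valid == True)  # noqa: E712
    rows = list(session.scalars(stmt))
assert len(rows) == 1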
@@ -720,7 +720,7 @@ STOPWORDS = {
    "〉",
    "〈",
    "…",
    " ",
    " ",
    "0",
    "1",
    "2",
@@ -731,16 +731,6 @@ STOPWORDS = {
    "7",
    "8",
    "9",
    "0",
    "1",
    "2",
    "3",
    "4",
    "5",
    "6",
    "7",
    "8",
    "9",
    "二",
    "三",
    "四",
0    api/core/rag/datasource/vdb/matrixone/__init__.py    Normal file
233  api/core/rag/datasource/vdb/matrixone/matrixone_vector.py    Normal file
@@ -0,0 +1,233 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Any, Optional
|
||||
|
||||
from mo_vector.client import MoVectorClient # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MatrixoneConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 6001
|
||||
user: str = "dump"
|
||||
password: str = "111"
|
||||
database: str = "dify"
|
||||
metric: str = "l2"
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config host is required")
|
||||
if not values["port"]:
|
||||
raise ValueError("config port is required")
|
||||
if not values["user"]:
|
||||
raise ValueError("config user is required")
|
||||
if not values["password"]:
|
||||
raise ValueError("config password is required")
|
||||
if not values["database"]:
|
||||
raise ValueError("config database is required")
|
||||
return values
|
||||
|
||||
|
||||
def ensure_client(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.client is None:
|
||||
self.client = self._get_client(None, False)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class MatrixoneVector(BaseVector):
|
||||
"""
|
||||
Matrixone vector storage implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, collection_name: str, config: MatrixoneConfig):
|
||||
super().__init__(collection_name)
|
||||
self.config = config
|
||||
self.collection_name = collection_name.lower()
|
||||
self.client = None
|
||||
|
||||
@property
|
||||
def collection_name(self):
|
||||
return self._collection_name
|
||||
|
||||
@collection_name.setter
|
||||
def collection_name(self, value):
|
||||
self._collection_name = value
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.MATRIXONE
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if self.client is None:
|
||||
self.client = self._get_client(len(embeddings[0]), True)
|
||||
return self.add_texts(texts, embeddings)
|
||||
|
||||
def _get_client(self, dimension: Optional[int] = None, create_table: bool = False) -> MoVectorClient:
|
||||
"""
|
||||
Create a new client for the collection.
|
||||
|
||||
The collection will be created if it doesn't exist.
|
||||
"""
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
client = MoVectorClient(
|
||||
connection_string=f"mysql+pymysql://{self.config.user}:{self.config.password}@{self.config.host}:{self.config.port}/{self.config.database}",
|
||||
table_name=self.collection_name,
|
||||
vector_dimension=dimension,
|
||||
create_table=create_table,
|
||||
)
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return client
|
||||
try:
|
||||
client.create_full_text_index()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create full text index")
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
return client
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if self.client is None:
|
||||
        self.client = self._get_client(len(embeddings[0]), True)
        assert self.client is not None
        ids = []
        for _, doc in enumerate(documents):
            if doc.metadata is not None:
                doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
                ids.append(doc_id)
        self.client.insert(
            texts=[doc.page_content for doc in documents],
            embeddings=embeddings,
            metadatas=[doc.metadata for doc in documents],
            ids=ids,
        )
        return ids

    @ensure_client
    def text_exists(self, id: str) -> bool:
        assert self.client is not None
        result = self.client.get(ids=[id])
        return len(result) > 0

    @ensure_client
    def delete_by_ids(self, ids: list[str]) -> None:
        assert self.client is not None
        if not ids:
            return
        self.client.delete(ids=ids)

    @ensure_client
    def get_ids_by_metadata_field(self, key: str, value: str):
        assert self.client is not None
        results = self.client.query_by_metadata(filter={key: value})
        return [result.id for result in results]

    @ensure_client
    def delete_by_metadata_field(self, key: str, value: str) -> None:
        assert self.client is not None
        self.client.delete(filter={key: value})

    @ensure_client
    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        assert self.client is not None
        top_k = kwargs.get("top_k", 5)
        document_ids_filter = kwargs.get("document_ids_filter")
        filter = None
        if document_ids_filter:
            filter = {"document_id": {"$in": document_ids_filter}}

        results = self.client.query(
            query_vector=query_vector,
            k=top_k,
            filter=filter,
        )

        docs = []
        # TODO: add the score threshold to the query
        for result in results:
            metadata = result.metadata
            docs.append(
                Document(
                    page_content=result.document,
                    metadata=metadata,
                )
            )
        return docs

    @ensure_client
    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        assert self.client is not None
        top_k = kwargs.get("top_k", 5)
        document_ids_filter = kwargs.get("document_ids_filter")
        filter = None
        if document_ids_filter:
            filter = {"document_id": {"$in": document_ids_filter}}
        score_threshold = float(kwargs.get("score_threshold", 0.0))

        results = self.client.full_text_query(
            keywords=[query],
            k=top_k,
            filter=filter,
        )

        docs = []
        for result in results:
            metadata = result.metadata
            if isinstance(metadata, str):
                import json

                metadata = json.loads(metadata)
            score = 1 - result.distance
            if score >= score_threshold:
                metadata["score"] = score
                docs.append(
                    Document(
                        page_content=result.document,
                        metadata=metadata,
                    )
                )
        return docs

    @ensure_client
    def delete(self) -> None:
        assert self.client is not None
        self.client.delete()


class MatrixoneVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MATRIXONE, collection_name))

        config = MatrixoneConfig(
            host=dify_config.MATRIXONE_HOST or "localhost",
            port=dify_config.MATRIXONE_PORT or 6001,
            user=dify_config.MATRIXONE_USER or "dump",
            password=dify_config.MATRIXONE_PASSWORD or "111",
            database=dify_config.MATRIXONE_DATABASE or "dify",
            metric=dify_config.MATRIXONE_METRIC or "l2",
        )
        return MatrixoneVector(collection_name=collection_name, config=config)
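
# A minimal sketch of how the new factory is exercised end to end. It assumes the
# insert method shown at the top of this file is the standard `add_texts` hook from
# `BaseVector`; the document and embedding values are purely illustrative.
from core.rag.models.document import Document

def index_sample(dataset, embeddings_model):
    vector = MatrixoneVectorFactory().init_vector(dataset, attributes=[], embeddings=embeddings_model)
    doc = Document(page_content="hello matrixone", metadata={"doc_id": "d1", "document_id": "doc-1"})
    return vector.add_texts([doc], [[0.1, 0.2, 0.3]])  # returns the generated ids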
@@ -80,6 +80,23 @@ class OceanBaseVector(BaseVector):

            self.delete()

        vals = []
        params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'")
        for row in params:
            val = int(row[6])
            vals.append(val)
        if len(vals) == 0:
            raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
        if any(val == 0 for val in vals):
            try:
                self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
            except Exception as e:
                raise Exception(
                    "Failed to set ob_vector_memory_limit_percentage. "
                    + "Maybe the database user has insufficient privilege.",
                    e,
                )

        cols = [
            Column("id", String(36), primary_key=True, autoincrement=False),
            Column("vector", VECTOR(self._vec_dim)),
@@ -110,22 +127,6 @@ class OceanBaseVector(BaseVector):
                    + "to support fulltext index and vector index in the same table",
                    e,
                )
        vals = []
        params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'")
        for row in params:
            val = int(row[6])
            vals.append(val)
        if len(vals) == 0:
            raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
        if any(val == 0 for val in vals):
            try:
                self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
            except Exception as e:
                raise Exception(
                    "Failed to set ob_vector_memory_limit_percentage. "
                    + "Maybe the database user has insufficient privilege.",
                    e,
                )
        redis_client.set(collection_exist_cache_key, 1, ex=3600)
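
# The parameter probe above assumes the usual `SHOW PARAMETERS` row layout in
# OceanBase, where index 6 carries the value column. A standalone sketch of the
# same check, separated from the client for readability:
def vector_memory_configured(rows) -> bool:
    vals = [int(row[6]) for row in rows]  # column 6 = parameter value (layout assumption)
    if not vals:
        raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
    return all(val != 0 for val in vals)  # 0 means vector memory is still disabled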

    def _check_hybrid_search_support(self) -> bool:
@@ -184,7 +184,16 @@ class OpenSearchVector(BaseVector):
        }
        document_ids_filter = kwargs.get("document_ids_filter")
        if document_ids_filter:
            query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}
            query["query"] = {
                "script_score": {
                    "query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}},
                    "script": {
                        "source": "knn_score",
                        "lang": "knn",
                        "params": {"field": Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"},
                    },
                }
            }

        try:
            response = self._client.search(index=self._collection_name.lower(), body=query)
@@ -209,10 +218,10 @@ class OpenSearchVector(BaseVector):
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
        full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}}
        document_ids_filter = kwargs.get("document_ids_filter")
        if document_ids_filter:
            full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}
            full_text_query["query"]["bool"]["filter"] = [{"terms": {"metadata.document_id": document_ids_filter}}]

        response = self._client.search(index=self._collection_name.lower(), body=full_text_query)

@@ -255,7 +264,8 @@ class OpenSearchVector(BaseVector):
                Field.METADATA_KEY.value: {
                    "type": "object",
                    "properties": {
                        "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
                        "doc_id": {"type": "keyword"},  # Map doc_id to keyword type
                        "document_id": {"type": "keyword"},
                    },
                },
            }
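
# For reference, the rewritten full-text query expands to a body like the one
# below; "page_content" stands in for Field.CONTENT_KEY.value and the ids are
# illustrative. Putting the terms clause in `bool.filter` narrows the match
# without contributing to relevance scoring.
example_body = {
    "query": {
        "bool": {
            "must": [{"match": {"page_content": "reset password"}}],
            "filter": [{"terms": {"metadata.document_id": ["doc-1", "doc-2"]}}],
        }
    }
}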
@@ -261,7 +261,7 @@ class OracleVector(BaseVector):
        words = pseg.cut(query)
        current_entity = ""
        for word, pos in words:
            if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:  # nr: person name, ns: place name, nt: organization name
                current_entity += word
            else:
                if current_entity:
@@ -303,7 +303,6 @@ class OracleVector(BaseVector):
                return docs
        else:
            return [Document(page_content="", metadata={})]
            return []

    def delete(self) -> None:
        with self._get_connection() as conn:

@@ -164,6 +164,10 @@ class Vector:
                from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory

                return HuaweiCloudVectorFactory
            case VectorType.MATRIXONE:
                from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory

                return MatrixoneVectorFactory
            case _:
                raise ValueError(f"Vector store {vector_type} is not supported.")

@@ -29,3 +29,4 @@ class VectorType(StrEnum):
    OPENGAUSS = "opengauss"
    TABLESTORE = "tablestore"
    HUAWEI_CLOUD = "huawei_cloud"
    MATRIXONE = "matrixone"

@@ -41,6 +41,13 @@ class WeaviateVector(BaseVector):

        weaviate.connect.connection.has_grpc = False

        # Fix to minimize the performance impact of the deprecation check in weaviate-client 3.24.0,
        # by changing the connection timeout to pypi.org from 1 second to 0.001 seconds.
        # TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher,
        # which does not contain the deprecation check.
        if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"):
            weaviate.connect.connection.PYPI_TIMEOUT = 0.001

        try:
            client = weaviate.Client(
                url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None

@@ -139,4 +139,4 @@ class CacheEmbedding(Embeddings):
                logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'")
                raise ex

        return embedding_results
        return embedding_results  # type: ignore

@@ -22,6 +22,7 @@ class FirecrawlApp:
            "formats": ["markdown"],
            "onlyMainContent": True,
            "timeout": 30000,
            "integration": "dify",
        }
        if params:
            json_data.update(params)
@@ -39,7 +40,7 @@ class FirecrawlApp:
    def crawl_url(self, url, params=None) -> str:
        # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
        headers = self._prepare_headers()
        json_data = {"url": url}
        json_data = {"url": url, "integration": "dify"}
        if params:
            json_data.update(params)
        response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
@@ -49,7 +50,6 @@ class FirecrawlApp:
            return cast(str, job_id)
        else:
            self._handle_error(response, "start crawl job")
            # FIXME: unreachable code for mypy
            return ""  # unreachable

    def check_crawl_status(self, job_id) -> dict[str, Any]:
@@ -82,7 +82,6 @@ class FirecrawlApp:
            )
        else:
            self._handle_error(response, "check crawl status")
            # FIXME: unreachable code for mypy
            return {}  # unreachable

    def _format_crawl_status_response(
@@ -126,4 +125,31 @@ class FirecrawlApp:

    def _handle_error(self, response, action) -> None:
        error_message = response.json().get("error", "Unknown error occurred")
        raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
        raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")  # type: ignore[return]

    def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
        # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search
        headers = self._prepare_headers()
        json_data = {
            "query": query,
            "limit": 5,
            "lang": "en",
            "country": "us",
            "timeout": 60000,
            "ignoreInvalidURLs": False,
            "scrapeOptions": {},
            "integration": "dify",
        }
        if params:
            json_data.update(params)
        response = self._post_request(f"{self.base_url}/v1/search", json_data, headers)
        if response.status_code == 200:
            response_data = response.json()
            if not response_data.get("success"):
                raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}")
            return cast(dict[str, Any], response_data)
        elif response.status_code in {402, 409, 500, 429, 408}:
            self._handle_error(response, "perform search")
            return {}  # Avoid additional exception after handling error
        else:
            raise Exception(f"Failed to perform search. Status code: {response.status_code}")
|
||||
def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]:
|
||||
"""Get all the pages from a Notion database."""
|
||||
assert self._notion_access_token is not None, "Notion access token is required"
|
||||
res = requests.post(
|
||||
DATABASE_URL_TMPL.format(database_id=database_id),
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict,
|
||||
)
|
||||
|
||||
data = res.json()
|
||||
|
||||
database_content = []
|
||||
if "results" not in data or data["results"] is None:
|
||||
next_cursor = None
|
||||
has_more = True
|
||||
|
||||
while has_more:
|
||||
current_query = query_dict.copy()
|
||||
if next_cursor:
|
||||
current_query["start_cursor"] = next_cursor
|
||||
|
||||
res = requests.post(
|
||||
DATABASE_URL_TMPL.format(database_id=database_id),
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=current_query,
|
||||
)
|
||||
|
||||
response_data = res.json()
|
||||
|
||||
if "results" not in response_data or response_data["results"] is None:
|
||||
break
|
||||
|
||||
for result in response_data["results"]:
|
||||
properties = result["properties"]
|
||||
data = {}
|
||||
value: Any
|
||||
for property_name, property_value in properties.items():
|
||||
type = property_value["type"]
|
||||
if type == "multi_select":
|
||||
value = []
|
||||
multi_select_list = property_value[type]
|
||||
for multi_select in multi_select_list:
|
||||
value.append(multi_select["name"])
|
||||
elif type in {"rich_text", "title"}:
|
||||
if len(property_value[type]) > 0:
|
||||
value = property_value[type][0]["plain_text"]
|
||||
else:
|
||||
value = ""
|
||||
elif type in {"select", "status"}:
|
||||
if property_value[type]:
|
||||
value = property_value[type]["name"]
|
||||
else:
|
||||
value = ""
|
||||
else:
|
||||
value = property_value[type]
|
||||
data[property_name] = value
|
||||
row_dict = {k: v for k, v in data.items() if v}
|
||||
row_content = ""
|
||||
for key, value in row_dict.items():
|
||||
if isinstance(value, dict):
|
||||
value_dict = {k: v for k, v in value.items() if v}
|
||||
value_content = "".join(f"{k}:{v} " for k, v in value_dict.items())
|
||||
row_content = row_content + f"{key}:{value_content}\n"
|
||||
else:
|
||||
row_content = row_content + f"{key}:{value}\n"
|
||||
database_content.append(row_content)
|
||||
|
||||
has_more = response_data.get("has_more", False)
|
||||
next_cursor = response_data.get("next_cursor")
|
||||
|
||||
if not database_content:
|
||||
return []
|
||||
for result in data["results"]:
|
||||
properties = result["properties"]
|
||||
data = {}
|
||||
value: Any
|
||||
for property_name, property_value in properties.items():
|
||||
type = property_value["type"]
|
||||
if type == "multi_select":
|
||||
value = []
|
||||
multi_select_list = property_value[type]
|
||||
for multi_select in multi_select_list:
|
||||
value.append(multi_select["name"])
|
||||
elif type in {"rich_text", "title"}:
|
||||
if len(property_value[type]) > 0:
|
||||
value = property_value[type][0]["plain_text"]
|
||||
else:
|
||||
value = ""
|
||||
elif type in {"select", "status"}:
|
||||
if property_value[type]:
|
||||
value = property_value[type]["name"]
|
||||
else:
|
||||
value = ""
|
||||
else:
|
||||
value = property_value[type]
|
||||
data[property_name] = value
|
||||
row_dict = {k: v for k, v in data.items() if v}
|
||||
row_content = ""
|
||||
for key, value in row_dict.items():
|
||||
if isinstance(value, dict):
|
||||
value_dict = {k: v for k, v in value.items() if v}
|
||||
value_content = "".join(f"{k}:{v} " for k, v in value_dict.items())
|
||||
row_content = row_content + f"{key}:{value_content}\n"
|
||||
else:
|
||||
row_content = row_content + f"{key}:{value}\n"
|
||||
database_content.append(row_content)
|
||||
|
||||
return [Document(page_content="\n".join(database_content))]
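
# The rewrite above follows Notion's cursor pagination contract: each response
# carries `has_more` and `next_cursor`, and the cursor is echoed back as
# `start_cursor` on the next request. Stripped to its skeleton, the loop is:
def paginate(fetch_page):
    cursor, has_more = None, True
    while has_more:
        page = fetch_page(cursor)
        yield from page.get("results", [])
        has_more = page.get("has_more", False)
        cursor = page.get("next_cursor")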
@@ -104,7 +104,7 @@ class QAIndexProcessor(BaseIndexProcessor):

    def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
        # check file type
        if not file.filename.endswith(".csv"):
        if not file.filename or not file.filename.lower().endswith(".csv"):
            raise ValueError("Invalid file type. Only CSV files are allowed")

        try:

@@ -496,6 +496,8 @@ class DatasetRetrieval:
            all_documents = self.calculate_keyword_score(query, all_documents, top_k)
        elif index_type == "high_quality":
            all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
        else:
            all_documents = all_documents[:top_k] if top_k else all_documents

        self._on_query(query, dataset_ids, app_id, user_from, user_id)

@@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.llm import llm_utils

PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""

@@ -165,7 +165,7 @@ class ReactMultiDatasetRouter:
        text, usage = self._handle_invoke_result(invoke_result=invoke_result)

        # deduct quota
        LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
        llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)

        return text, usage

@@ -6,7 +6,7 @@ import json
import logging
from typing import Optional, Union

from sqlalchemy import select
from sqlalchemy import func, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

@@ -151,11 +151,11 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
            existing = session.scalar(select(WorkflowRun).where(WorkflowRun.id == domain_model.id_))
            if not existing:
                # For new records, get the next sequence number
                stmt = select(WorkflowRun.sequence_number).where(
                stmt = select(func.max(WorkflowRun.sequence_number)).where(
                    WorkflowRun.app_id == self._app_id,
                    WorkflowRun.tenant_id == self._tenant_id,
                )
                max_sequence = session.scalar(stmt.order_by(WorkflowRun.sequence_number.desc()))
                max_sequence = session.scalar(stmt)
                db_model.sequence_number = (max_sequence or 0) + 1
            else:
                # For updates, keep the existing sequence number
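
# The aggregate form issues one query instead of ordering every run and reading
# the first row; it compiles to roughly the SQL below (table name per the
# WorkflowRun model, shown schematically). Note that max-plus-one is still a
# read-then-write, so two concurrent inserts can race for the same number
# unless serialized elsewhere.
#   SELECT max(sequence_number) FROM workflow_runs
#    WHERE app_id = :app_id AND tenant_id = :tenant_id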
@@ -1,3 +1,4 @@
- audio
- code
- time
- qrcode
- webscraper

@@ -153,8 +153,6 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                return str("\n".join(document_context_list))
            return ""

        raise RuntimeError("not segments found")

    def _retriever(
        self,
        flask_app: Flask,

@@ -32,14 +32,14 @@ class ToolFileMessageTransformer:
            try:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                tool_file_manager = ToolFileManager()
                file = tool_file_manager.create_file_by_url(
                tool_file = tool_file_manager.create_file_by_url(
                    user_id=user_id,
                    tenant_id=tenant_id,
                    file_url=message.message.text,
                    conversation_id=conversation_id,
                )

                url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
                url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"

                yield ToolInvokeMessage(
                    type=ToolInvokeMessage.MessageType.IMAGE_LINK,
@@ -68,7 +68,7 @@ class ToolFileMessageTransformer:

                assert isinstance(message.message.blob, bytes)
                tool_file_manager = ToolFileManager()
                file = tool_file_manager.create_file_by_raw(
                tool_file = tool_file_manager.create_file_by_raw(
                    user_id=user_id,
                    tenant_id=tenant_id,
                    conversation_id=conversation_id,
@@ -77,7 +77,7 @@ class ToolFileMessageTransformer:
                    filename=filename,
                )

                url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
                url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype))

                # check if file is image
                if "image" in mimetype:

@@ -9,7 +9,7 @@ from copy import copy, deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, cast

from flask import Flask, current_app, has_request_context
from flask import Flask, current_app

from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
@@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
from models.workflow import WorkflowType

@@ -537,24 +538,9 @@ class GraphEngine:
        """
        Run parallel nodes
        """
        for var, val in context.items():
            var.set(val)

        # FIXME(-LAN-): Save current user before entering new app context
        from flask import g

        saved_user = None
        if has_request_context() and hasattr(g, "_login_user"):
            saved_user = g._login_user

        with flask_app.app_context():
        with preserve_flask_contexts(flask_app, context_vars=context):
            try:
                # Restore user in new app context
                if saved_user is not None:
                    from flask import g

                    g._login_user = saved_user

                q.put(
                    ParallelBranchRunStartedEvent(
                        parallel_id=parallel_id,
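
# `preserve_flask_contexts` replaces the hand-rolled save/restore of `g._login_user`
# above. A minimal sketch of what such a helper needs to do (the real one lives in
# libs/flask_utils.py; this reconstruction is an assumption, not the shipped code):
from contextlib import contextmanager

@contextmanager
def preserve_contexts(flask_app, context_vars):
    # replay contextvars captured in the parent thread
    for var, val in context_vars.items():
        var.set(val)
    # then enter a fresh app context for this worker thread
    with flask_app.app_context():
        yield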
@@ -653,26 +639,19 @@ class GraphEngine:
                    retry_start_at = datetime.now(UTC).replace(tzinfo=None)
                    # yield control to other threads
                    time.sleep(0.001)
                    generator = node_instance.run()
                    for item in generator:
                        if isinstance(item, GraphEngineEvent):
                            if isinstance(item, BaseIterationEvent):
                                # add parallel info to iteration event
                                item.parallel_id = parallel_id
                                item.parallel_start_node_id = parallel_start_node_id
                                item.parent_parallel_id = parent_parallel_id
                                item.parent_parallel_start_node_id = parent_parallel_start_node_id
                            elif isinstance(item, BaseLoopEvent):
                                # add parallel info to loop event
                                item.parallel_id = parallel_id
                                item.parallel_start_node_id = parallel_start_node_id
                                item.parent_parallel_id = parent_parallel_id
                                item.parent_parallel_start_node_id = parent_parallel_start_node_id

                            yield item
                    event_stream = node_instance.run()
                    for event in event_stream:
                        if isinstance(event, GraphEngineEvent):
                            # add parallel info to iteration event
                            if isinstance(event, BaseIterationEvent | BaseLoopEvent):
                                event.parallel_id = parallel_id
                                event.parallel_start_node_id = parallel_start_node_id
                                event.parent_parallel_id = parent_parallel_id
                                event.parent_parallel_start_node_id = parent_parallel_start_node_id
                            yield event
                        else:
                            if isinstance(item, RunCompletedEvent):
                                run_result = item.run_result
                            if isinstance(event, RunCompletedEvent):
                                run_result = event.run_result
                                if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                                    if (
                                        retries == max_retries
@@ -708,7 +687,7 @@ class GraphEngine:
                                    # if run failed, handle error
                                    run_result = self._handle_continue_on_error(
                                        node_instance,
                                        item.run_result,
                                        event.run_result,
                                        self.graph_runtime_state.variable_pool,
                                        handle_exceptions=handle_exceptions,
                                    )
@@ -811,28 +790,28 @@ class GraphEngine:
                                    should_continue_retry = False

                                    break
                            elif isinstance(item, RunStreamChunkEvent):
                            elif isinstance(event, RunStreamChunkEvent):
                                yield NodeRunStreamChunkEvent(
                                    id=node_instance.id,
                                    node_id=node_instance.node_id,
                                    node_type=node_instance.node_type,
                                    node_data=node_instance.node_data,
                                    chunk_content=item.chunk_content,
                                    from_variable_selector=item.from_variable_selector,
                                    chunk_content=event.chunk_content,
                                    from_variable_selector=event.from_variable_selector,
                                    route_node_state=route_node_state,
                                    parallel_id=parallel_id,
                                    parallel_start_node_id=parallel_start_node_id,
                                    parent_parallel_id=parent_parallel_id,
                                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                                )
                            elif isinstance(item, RunRetrieverResourceEvent):
                            elif isinstance(event, RunRetrieverResourceEvent):
                                yield NodeRunRetrieverResourceEvent(
                                    id=node_instance.id,
                                    node_id=node_instance.node_id,
                                    node_type=node_instance.node_type,
                                    node_data=node_instance.node_data,
                                    retriever_resources=item.retriever_resources,
                                    context=item.context,
                                    retriever_resources=event.retriever_resources,
                                    context=event.context,
                                    route_node_state=route_node_state,
                                    parallel_id=parallel_id,
                                    parallel_start_node_id=parallel_start_node_id,

@@ -214,7 +214,7 @@ class AgentNode(ToolNode):
                )
                if tool_runtime.entity.description:
                    tool_runtime.entity.description.llm = (
                        extra.get("descrption", "") or tool_runtime.entity.description.llm
                        extra.get("description", "") or tool_runtime.entity.description.llm
                    )
                for tool_runtime_params in tool_runtime.entity.parameters:
                    tool_runtime_params.form = (

@@ -57,7 +57,6 @@ class StreamProcessor(ABC):

                    # The branch_identify parameter is added to ensure that
                    # only nodes in the correct logical branch are included.
                    reachable_node_ids.append(edge.target_node_id)
                    ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
                    reachable_node_ids.extend(ids)
                else:
@@ -74,6 +73,8 @@ class StreamProcessor(ABC):
            self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

    def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
        if node_id not in self.rest_node_ids:
            self.rest_node_ids.append(node_id)
        node_ids = []
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id == self.graph.root_node_id:
@@ -397,19 +397,44 @@ def _extract_text_from_csv(file_content: bytes) -> str:
        if not rows:
            return ""

        # Create Markdown table
        markdown_table = "| " + " | ".join(rows[0]) + " |\n"
        markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n"
        for row in rows[1:]:
            markdown_table += "| " + " | ".join(row) + " |\n"
        # Combine multi-line text in the header row
        header_row = [cell.replace("\n", " ").replace("\r", "") for cell in rows[0]]

        return markdown_table.strip()
        # Create Markdown table
        markdown_table = "| " + " | ".join(header_row) + " |\n"
        markdown_table += "| " + " | ".join(["-" * len(col) for col in rows[0]]) + " |\n"

        # Process each data row and combine multi-line text in each cell
        for row in rows[1:]:
            processed_row = [cell.replace("\n", " ").replace("\r", "") for cell in row]
            markdown_table += "| " + " | ".join(processed_row) + " |\n"

        return markdown_table
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e


def _extract_text_from_excel(file_content: bytes) -> str:
    """Extract text from an Excel file using pandas."""

    def _construct_markdown_table(df: pd.DataFrame) -> str:
        """Manually construct a Markdown table from a DataFrame."""
        # Construct the header row
        header_row = "| " + " | ".join(df.columns) + " |"

        # Construct the separator row
        separator_row = "| " + " | ".join(["-" * len(col) for col in df.columns]) + " |"

        # Construct the data rows
        data_rows = []
        for _, row in df.iterrows():
            data_row = "| " + " | ".join(map(str, row)) + " |"
            data_rows.append(data_row)

        # Combine all rows into a single string
        markdown_table = "\n".join([header_row, separator_row] + data_rows)
        return markdown_table

    try:
        excel_file = pd.ExcelFile(io.BytesIO(file_content))
        markdown_table = ""
@@ -417,8 +442,15 @@ def _extract_text_from_excel(file_content: bytes) -> str:
            try:
                df = excel_file.parse(sheet_name=sheet_name)
                df.dropna(how="all", inplace=True)
                # Create Markdown table two times to separate tables with a newline
                markdown_table += df.to_markdown(index=False, floatfmt="") + "\n\n"

                # Combine multi-line text in each cell into a single line
                df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x)  # type: ignore

                # Combine multi-line text in column names into a single line
                df.columns = pd.Index([" ".join(col.splitlines()) for col in df.columns])

                # Manually construct the Markdown table
                markdown_table += _construct_markdown_table(df) + "\n\n"
            except Exception as e:
                continue
        return markdown_table
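
# Worked example of `_construct_markdown_table`: the separator row now tracks
# each column name's width instead of a fixed "---", and multi-line cells have
# already been flattened by the `applymap` pass above.
import pandas as pd
df = pd.DataFrame({"name": ["ada"], "score": [1]})
# _construct_markdown_table(df) produces:
# | name | score |
# | ---- | ----- |
# | ada | 1 |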
@@ -6,7 +6,6 @@ from pydantic import BaseModel, Field

from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus


class RunCompletedEvent(BaseModel):
@@ -39,11 +38,3 @@ class RunRetryEvent(BaseModel):
    error: str = Field(..., description="error")
    retry_index: int = Field(..., description="Retry attempt number")
    start_at: datetime = Field(..., description="Retry start time")


class SingleStepRetryEvent(NodeRunResult):
    """Single step retry event"""

    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY

    elapsed_time: float = Field(..., description="elapsed time")
@@ -1,8 +1,9 @@
import base64
import json
import secrets
import string
from collections.abc import Mapping
from copy import deepcopy
from random import randint
from typing import Any, Literal
from urllib.parse import urlencode, urlparse

@@ -434,4 +435,4 @@ def _generate_random_string(n: int) -> str:
    >>> _generate_random_string(5)
    'abcde'
    """
    return "".join([chr(randint(97, 122)) for _ in range(n)])
    return "".join(secrets.choice(string.ascii_lowercase) for _ in range(n))
@@ -7,7 +7,7 @@ from datetime import UTC, datetime
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast

from flask import Flask, current_app, has_request_context
from flask import Flask, current_app

from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
@@ -37,6 +37,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from libs.flask_utils import preserve_flask_contexts

from .exc import (
    InvalidIteratorValueError,
@@ -583,23 +584,8 @@ class IterationNode(BaseNode[IterationNodeData]):
        """
        run single iteration in parallel mode
        """
        for var, val in context.items():
            var.set(val)

        # FIXME(-LAN-): Save current user before entering new app context
        from flask import g

        saved_user = None
        if has_request_context() and hasattr(g, "_login_user"):
            saved_user = g._login_user

        with flask_app.app_context():
            # Restore user in new app context
            if saved_user is not None:
                from flask import g

                g._login_user = saved_user

        with preserve_flask_contexts(flask_app, context_vars=context):
            parallel_mode_run_id = uuid.uuid4().hex
            graph_engine_copy = graph_engine.create_copy()
            variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool

@@ -132,3 +132,12 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
    metadata_model_config: Optional[ModelConfig] = None
    metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
    vision: VisionConfig = Field(default_factory=VisionConfig)

    @property
    def structured_output_enabled(self) -> bool:
        # NOTE(QuantumGhost): Temporary workaround for issue #20725
        # (https://github.com/langgenius/dify/issues/20725).
        #
        # The proper fix would be to make `KnowledgeRetrievalNode` inherit
        # from `BaseNode` instead of `LLMNode`.
        return False
@@ -86,31 +86,31 @@ class KnowledgeRetrievalNode(LLMNode):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
            )
        # TODO(-LAN-): Move this check outside.
        # check rate limit
        if self.tenant_id:
            knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
            if knowledge_rate_limit.enabled:
                current_time = int(time.time() * 1000)
                key = f"rate_limit_{self.tenant_id}"
                redis_client.zadd(key, {current_time: current_time})
                redis_client.zremrangebyscore(key, 0, current_time - 60000)
                request_count = redis_client.zcard(key)
                if request_count > knowledge_rate_limit.limit:
                    with Session(db.engine) as session:
                        # add ratelimit record
                        rate_limit_log = RateLimitLog(
                            tenant_id=self.tenant_id,
                            subscription_plan=knowledge_rate_limit.subscription_plan,
                            operation="knowledge",
                        )
                        session.add(rate_limit_log)
                        session.commit()
                    return NodeRunResult(
                        status=WorkflowNodeExecutionStatus.FAILED,
                        inputs=variables,
                        error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
                        error_type="RateLimitExceeded",
        knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
        if knowledge_rate_limit.enabled:
            current_time = int(time.time() * 1000)
            key = f"rate_limit_{self.tenant_id}"
            redis_client.zadd(key, {current_time: current_time})
            redis_client.zremrangebyscore(key, 0, current_time - 60000)
            request_count = redis_client.zcard(key)
            if request_count > knowledge_rate_limit.limit:
                with Session(db.engine) as session:
                    # add ratelimit record
                    rate_limit_log = RateLimitLog(
                        tenant_id=self.tenant_id,
                        subscription_plan=knowledge_rate_limit.subscription_plan,
                        operation="knowledge",
                    )
                    session.add(rate_limit_log)
                    session.commit()
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=variables,
                    error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
                    error_type="RateLimitExceeded",
                )

        # retrieve knowledge
        try:
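
# The limiter above is a classic Redis sliding window: record one sorted-set
# member per request, trim everything older than 60 seconds, then count what
# remains. An isolated sketch (key naming mirrors the code above):
def over_limit(redis_client, tenant_id: str, limit: int) -> bool:
    import time
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    redis_client.zadd(key, {now_ms: now_ms})
    redis_client.zremrangebyscore(key, 0, now_ms - 60000)  # drop entries outside the window
    return redis_client.zcard(key) > limit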
@@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData):
    context: ContextConfig
    vision: VisionConfig = Field(default_factory=VisionConfig)
    structured_output: dict | None = None
    structured_output_enabled: bool = False
    # We used 'structured_output_enabled' in the past, but it's not a good name.
    structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")

    @field_validator("prompt_config", mode="before")
    @classmethod
@@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData):
        if v is None:
            return PromptConfig()
        return v

    @property
    def structured_output_enabled(self) -> bool:
        return self.structured_output_switch_on and self.structured_output is not None
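
# A self-contained demonstration of the aliasing pattern used above: payloads
# that still send the legacy key populate the renamed field, since pydantic v2
# accepts the alias during validation.
from pydantic import BaseModel, Field

class AliasDemo(BaseModel):
    switch_on: bool = Field(False, alias="enabled")

assert AliasDemo.model_validate({"enabled": True}).switch_on is True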
api/core/workflow/nodes/llm/llm_utils.py (new file, 156 lines)
@@ -0,0 +1,156 @@
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Optional, cast

from sqlalchemy import select, update
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from models import db
from models.model import Conversation
from models.provider import Provider, ProviderType

from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError


def fetch_model_config(
    tenant_id: str, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    if not node_data_model.mode:
        raise LLMModeRequiredError("LLM mode is required.")

    model = ModelManager().get_model_instance(
        tenant_id=tenant_id,
        model_type=ModelType.LLM,
        provider=node_data_model.provider,
        model=node_data_model.name,
    )

    model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)

    # check model
    provider_model = model.provider_model_bundle.configuration.get_provider_model(
        model=node_data_model.name, model_type=ModelType.LLM
    )

    if provider_model is None:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
    provider_model.raise_for_status()

    # model config
    stop: list[str] = []
    if "stop" in node_data_model.completion_params:
        stop = node_data_model.completion_params.pop("stop")

    model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
    if not model_schema:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")

    return model, ModelConfigWithCredentialsEntity(
        provider=node_data_model.provider,
        model=node_data_model.name,
        model_schema=model_schema,
        mode=node_data_model.mode,
        provider_model_bundle=model.provider_model_bundle,
        credentials=model.credentials,
        parameters=node_data_model.completion_params,
        stop=stop,
    )


def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
    variable = variable_pool.get(selector)
    if variable is None:
        return []
    elif isinstance(variable, FileSegment):
        return [variable.value]
    elif isinstance(variable, ArrayFileSegment):
        return variable.value
    elif isinstance(variable, NoneSegment | ArrayAnySegment):
        return []
    raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")


def fetch_memory(
    variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
    if not node_data_memory:
        return None

    # get conversation id
    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
    if not isinstance(conversation_id_variable, StringSegment):
        return None
    conversation_id = conversation_id_variable.value

    with Session(db.engine, expire_on_commit=False) as session:
        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
        conversation = session.scalar(stmt)
        if not conversation:
            return None

    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
    return memory


def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
    provider_model_bundle = model_instance.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration

    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_configuration = provider_configuration.system_configuration

    quota_unit = None
    for quota_configuration in system_configuration.quota_configurations:
        if quota_configuration.quota_type == system_configuration.current_quota_type:
            quota_unit = quota_configuration.quota_unit

            if quota_configuration.quota_limit == -1:
                return

            break

    used_quota = None
    if quota_unit:
        if quota_unit == QuotaUnit.TOKENS:
            used_quota = usage.total_tokens
        elif quota_unit == QuotaUnit.CREDITS:
            used_quota = dify_config.get_model_credits(model_instance.model)
        else:
            used_quota = 1

    if used_quota is not None and system_configuration.current_quota_type is not None:
        with Session(db.engine) as session:
            stmt = (
                update(Provider)
                .where(
                    Provider.tenant_id == tenant_id,
                    # TODO: Use provider name with prefix after the data migration.
                    Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
                    Provider.provider_type == ProviderType.SYSTEM.value,
                    Provider.quota_type == system_configuration.current_quota_type.value,
                    Provider.quota_limit > Provider.quota_used,
                )
                .values(
                    quota_used=Provider.quota_used + used_quota,
                    last_used=datetime.now(tz=UTC).replace(tzinfo=None),
                )
            )
            session.execute(stmt)
            session.commit()
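
# The guarded UPDATE above folds the quota check into the write itself,
# compiling to roughly the following (table and column names per the Provider
# model, shown schematically), so the check and the increment happen in one
# atomic statement:
#   UPDATE providers
#      SET quota_used = quota_used + :used, last_used = :now
#    WHERE tenant_id = :tenant AND provider_name = :name
#      AND provider_type = 'system' AND quota_type = :quota
#      AND quota_limit > quota_used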
@@ -3,18 +3,11 @@ import io
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
import json_repair
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.file import FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@@ -42,12 +35,10 @@ from core.model_runtime.entities.model_entities import (
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.variables import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
@@ -74,14 +65,11 @@ from core.workflow.nodes.event import (
|
||||
from core.workflow.utils.structured_output.entities import (
|
||||
ResponseFormat,
|
||||
SpecialModelType,
|
||||
SupportStructuredOutputStatus,
|
||||
)
|
||||
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
from . import llm_utils
|
||||
from .entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
@@ -91,7 +79,6 @@ from .entities import (
|
||||
from .exc import (
|
||||
InvalidContextStructureError,
|
||||
InvalidVariableTypeError,
|
||||
LLMModeRequiredError,
|
||||
LLMNodeError,
|
||||
MemoryRolePrefixRequiredError,
|
||||
ModelNotExistError,
|
||||
@@ -163,6 +150,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
result_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
try:
|
||||
# init messages template
|
||||
@@ -181,7 +169,10 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
|
||||
# fetch files
|
||||
files = (
|
||||
self._fetch_files(selector=self.node_data.vision.configs.variable_selector)
|
||||
llm_utils.fetch_files(
|
||||
variable_pool=variable_pool,
|
||||
selector=self.node_data.vision.configs.variable_selector,
|
||||
)
|
||||
if self.node_data.vision.enabled
|
||||
else []
|
||||
)
|
||||
@@ -203,15 +194,18 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
model_instance, model_config = self._fetch_model_config(self.node_data.model)
|
||||
|
||||
# fetch memory
|
||||
memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
node_data_memory=self.node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
query = None
|
||||
if self.node_data.memory:
|
||||
query = self.node_data.memory.query_prompt_template
|
||||
if not query and (
|
||||
query_variable := self.graph_runtime_state.variable_pool.get(
|
||||
(SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)
|
||||
)
|
||||
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||
):
|
||||
query = query_variable.text
|
||||
|
||||
@@ -225,7 +219,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
memory_config=self.node_data.memory,
|
||||
vision_enabled=self.node_data.vision.enabled,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
)
|
||||
|
||||
@@ -254,7 +248,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
usage = event.usage
|
||||
finish_reason = event.finish_reason
|
||||
# deduct quota
|
||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
break
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
structured_output = process_structured_output(result_text)
|
||||
@@ -277,7 +271,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
except LLMNodeError as e:
|
||||
except ValueError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@@ -450,18 +444,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
|
||||
return inputs
|
||||
|
||||
def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if variable is None:
|
||||
return []
|
||||
elif isinstance(variable, FileSegment):
|
||||
return [variable.value]
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
return variable.value
|
||||
elif isinstance(variable, NoneSegment | ArrayAnySegment):
|
||||
return []
|
||||
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
|
||||
|
||||
def _fetch_context(self, node_data: LLMNodeData):
|
||||
if not node_data.context.enabled:
|
||||
return
|
||||
@@ -527,91 +509,25 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
def _fetch_model_config(
|
||||
self, node_data_model: ModelConfig
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
model_name = node_data_model.name
|
||||
provider_name = node_data_model.provider
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
model, model_config_with_cred = llm_utils.fetch_model_config(
|
||||
tenant_id=self.tenant_id, node_data_model=node_data_model
|
||||
)
|
||||
completion_params = model_config_with_cred.parameters
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_credentials = model_instance.credentials
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
|
||||
# model config
|
||||
completion_params = node_data_model.completion_params
|
||||
stop = []
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = node_data_model.mode
|
||||
if not model_mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
|
||||
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
support_structured_output = self._check_model_structured_output_support()
|
||||
if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
|
||||
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
|
||||
elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
|
||||
# Set appropriate response format based on model capabilities
|
||||
self._set_response_format(completion_params, model_schema.parameter_rules)
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
model_schema=model_schema,
|
||||
mode=model_mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
|
||||
def _fetch_memory(
|
||||
self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
|
||||
) -> Optional[TokenBufferMemory]:
|
||||
if not node_data_memory:
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
||||
)
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
return memory
|
||||
if self.node_data.structured_output_enabled:
|
||||
if model_schema.support_structure_output:
|
||||
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
|
||||
else:
|
||||
# Set appropriate response format based on model capabilities
|
||||
self._set_response_format(completion_params, model_schema.parameter_rules)
|
||||
model_config_with_cred.parameters = completion_params
|
||||
# NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
|
||||
node_data_model.completion_params = completion_params
|
||||
return model, model_config_with_cred
|
||||
|
||||
def _fetch_prompt_messages(
|
||||
self,
|
||||
@@ -786,13 +702,25 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
"No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
support_structured_output = self._check_model_structured_output_support()
|
||||
if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
|
||||
filtered_prompt_messages = self._handle_prompt_based_schema(
|
||||
prompt_messages=filtered_prompt_messages,
|
||||
)
|
||||
stop = model_config.stop
|
||||
return filtered_prompt_messages, stop
|
||||
|
||||
model = ModelManager().get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.provider,
|
||||
model=model_config.model,
|
||||
)
|
||||
model_schema = model.model_type_instance.get_model_schema(
|
||||
model=model_config.model,
|
||||
credentials=model.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {model_config.model} not exist.")
|
||||
if self.node_data.structured_output_enabled:
|
||||
if not model_schema.support_structure_output:
|
||||
filtered_prompt_messages = self._handle_prompt_based_schema(
|
||||
prompt_messages=filtered_prompt_messages,
|
||||
)
|
||||
return filtered_prompt_messages, model_config.stop
|
||||
|
||||
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
|
||||
structured_output: dict[str, Any] = {}
|
||||
@@ -813,55 +741,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
structured_output = parsed
|
||||
return structured_output
|
||||
|
||||
@classmethod
|
||||
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
@@ -903,7 +782,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
variable_mapping["#context#"] = node_data.context.variable_selector
|
||||
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value]
|
||||
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
|
||||
|
||||
if node_data.memory:
|
||||
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
|
||||
@@ -1185,32 +1064,6 @@ class LLMNode(BaseNode[LLMNodeData]):
         except json.JSONDecodeError:
             raise LLMNodeError("structured_output_schema is not valid JSON format")
 
-    def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
-        """
-        Check if the current model supports structured output.
-
-        Returns:
-            SupportStructuredOutput: The support status of structured output
-        """
-        # Early return if structured output is disabled
-        if (
-            not isinstance(self.node_data, LLMNodeData)
-            or not self.node_data.structured_output_enabled
-            or not self.node_data.structured_output
-        ):
-            return SupportStructuredOutputStatus.DISABLED
-        # Get model schema and check if it exists
-        model_schema = self._fetch_model_schema(self.node_data.model.provider)
-        if not model_schema:
-            return SupportStructuredOutputStatus.DISABLED
-
-        # Check if model supports structured output feature
-        return (
-            SupportStructuredOutputStatus.SUPPORTED
-            if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
-            else SupportStructuredOutputStatus.UNSUPPORTED
-        )
-
     def _save_multimodal_output_and_convert_result_to_markdown(
         self,
         contents: str | list[PromptMessageContentUnionTypes] | None,
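With the tri-state status gone (see the `SupportStructuredOutputStatus` enum removal later in this compare), the support check collapses to a boolean feature-flag lookup on the model schema. A self-contained sketch of that reduced check, with the runtime types stubbed out:

```python
from dataclasses import dataclass, field
from enum import StrEnum


class ModelFeature(StrEnum):  # stub; the real enum lives in Dify's model runtime
    STRUCTURED_OUTPUT = "structured-output"


@dataclass
class ModelSchema:  # stub for the schema returned by get_model_schema
    features: list[ModelFeature] = field(default_factory=list)


def supports_structured_output(schema: ModelSchema) -> bool:
    # The removed method reduced to its essence: native support is just a
    # feature-flag membership test.
    return bool(schema.features and ModelFeature.STRUCTURED_OUTPUT in schema.features)


print(supports_structured_output(ModelSchema([ModelFeature.STRUCTURED_OUTPUT])))  # True
print(supports_structured_output(ModelSchema()))                                  # False
```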
@@ -28,8 +28,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.llm import LLMNode, ModelConfig
+from core.workflow.nodes.llm import ModelConfig, llm_utils
 from core.workflow.utils import variable_template_parser
 
 from .entities import ParameterExtractorNodeData
@@ -83,7 +84,7 @@ def extract_json(text):
     return None
 
 
-class ParameterExtractorNode(LLMNode):
+class ParameterExtractorNode(BaseNode):
     """
     Parameter Extractor Node.
    """
@@ -116,8 +117,11 @@ class ParameterExtractorNode(LLMNode):
         variable = self.graph_runtime_state.variable_pool.get(node_data.query)
         query = variable.text if variable else ""
 
+        variable_pool = self.graph_runtime_state.variable_pool
+
         files = (
-            self._fetch_files(
+            llm_utils.fetch_files(
+                variable_pool=variable_pool,
                 selector=node_data.vision.configs.variable_selector,
             )
             if node_data.vision.enabled
@@ -137,7 +141,9 @@ class ParameterExtractorNode(LLMNode):
             raise ModelSchemaNotFoundError("Model schema not found")
 
         # fetch memory
-        memory = self._fetch_memory(
+        memory = llm_utils.fetch_memory(
+            variable_pool=variable_pool,
+            app_id=self.app_id,
             node_data_memory=node_data.memory,
             model_instance=model_instance,
         )
@@ -279,7 +285,7 @@ class ParameterExtractorNode(LLMNode):
         tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
 
         # deduct quota
-        self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
+        llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
 
         if text is None:
             text = ""
@@ -794,7 +800,9 @@ class ParameterExtractorNode(LLMNode):
         Fetch model config.
         """
         if not self._model_instance or not self._model_config:
-            self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)
+            self._model_instance, self._model_config = llm_utils.fetch_model_config(
+                tenant_id=self.tenant_id, node_data_model=node_data_model
+            )
 
         return self._model_instance, self._model_config
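The pattern this hunk preserves is simple memoization: the node resolves its model configuration once and reuses it on every later call. A minimal sketch of the same shape, with a stand-in loader in place of `llm_utils.fetch_model_config`:

```python
from typing import Optional


class Node:
    def __init__(self) -> None:
        self._model_instance: Optional[str] = None
        self._model_config: Optional[dict] = None

    def _load(self) -> tuple[str, dict]:
        # Stand-in for llm_utils.fetch_model_config(tenant_id=..., node_data_model=...)
        print("expensive fetch")
        return "gpt-4o", {"temperature": 0.7}

    def fetch_model_config(self) -> tuple[str, dict]:
        if not self._model_instance or not self._model_config:
            self._model_instance, self._model_config = self._load()
        return self._model_instance, self._model_config


node = Node()
node.fetch_model_config()  # prints "expensive fetch"
node.fetch_model_config()  # cached: no second fetch
```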
@@ -19,3 +19,12 @@ class QuestionClassifierNodeData(BaseNodeData):
     instruction: Optional[str] = None
     memory: Optional[MemoryConfig] = None
     vision: VisionConfig = Field(default_factory=VisionConfig)
+
+    @property
+    def structured_output_enabled(self) -> bool:
+        # NOTE(QuantumGhost): Temporary workaround for issue #20725
+        # (https://github.com/langgenius/dify/issues/20725).
+        #
+        # The proper fix would be to make `QuestionClassifierNode` inherit
+        # from `BaseNode` instead of `LLMNode`.
+        return False
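The added property works because a property on a subclass shadows the attribute the base class would otherwise expose, so any caller that reads the flag sees a hard-coded `False`. A tiny, dependency-free sketch of the same trick (class names here are illustrative, not Dify's):

```python
class BaseNodeData:
    # In the real code this flag lives on LLMNodeData and is user-configurable.
    structured_output_enabled: bool = True


class ClassifierNodeData(BaseNodeData):
    @property
    def structured_output_enabled(self) -> bool:  # type: ignore[override]
        # Shadows the inherited attribute: this node never opts in.
        return False


print(BaseNodeData().structured_output_enabled)        # True
print(ClassifierNodeData().structured_output_enabled)  # False
```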
@@ -18,6 +18,7 @@ from core.workflow.nodes.llm import (
     LLMNode,
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
+    llm_utils,
 )
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -50,7 +51,9 @@ class QuestionClassifierNode(LLMNode):
         # fetch model config
         model_instance, model_config = self._fetch_model_config(node_data.model)
         # fetch memory
-        memory = self._fetch_memory(
+        memory = llm_utils.fetch_memory(
+            variable_pool=variable_pool,
+            app_id=self.app_id,
             node_data_memory=node_data.memory,
             model_instance=model_instance,
         )
@@ -59,7 +62,8 @@ class QuestionClassifierNode(LLMNode):
             node_data.instruction = variable_pool.convert_template(node_data.instruction).text
 
         files = (
-            self._fetch_files(
+            llm_utils.fetch_files(
+                variable_pool=variable_pool,
                 selector=node_data.vision.configs.variable_selector,
             )
             if node_data.vision.enabled
@@ -1,7 +1,8 @@
-from typing import Literal, Optional
+from typing import Optional
 
 from pydantic import BaseModel
 
+from core.variables.types import SegmentType
 from core.workflow.nodes.base import BaseNodeData
 
 
@@ -17,7 +18,7 @@ class AdvancedSettings(BaseModel):
         Group.
         """
 
-        output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
+        output_type: SegmentType
         variables: list[list[str]]
         group_name: str
 
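Swapping the inline `Literal[...]` for `SegmentType` gives one shared source of truth: the same enum now drives this settings model, the empty-value mapping, and the input validator in the hunks that follow. A sketch of why that matters, using a trimmed stand-in for the real enum:

```python
from enum import StrEnum


class SegmentType(StrEnum):  # trimmed stand-in for core.variables.types.SegmentType
    STRING = "string"
    NUMBER = "number"
    ARRAY_STRING = "array[string]"


# Pydantic (and plain code) can round-trip the wire value through the enum,
# so adding a new segment type is a one-line change instead of editing every
# Literal that repeats the list.
value = SegmentType("array[string]")
print(value is SegmentType.ARRAY_STRING)  # True
print(list(SegmentType))                  # every accepted output_type, in one place
```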
@@ -8,4 +8,5 @@ EMPTY_VALUE_MAPPING = {
     SegmentType.ARRAY_STRING: [],
     SegmentType.ARRAY_NUMBER: [],
     SegmentType.ARRAY_OBJECT: [],
+    SegmentType.ARRAY_FILE: [],
 }
@@ -1,5 +1,6 @@
 from typing import Any
 
+from core.file import File
 from core.variables import SegmentType
 
 from .enums import Operation
@@ -85,6 +86,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
             return isinstance(value, int | float)
         case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
             return isinstance(value, dict)
+        case SegmentType.ARRAY_FILE if operation == Operation.APPEND:
+            return isinstance(value, File)
 
         # Array & Extend / Overwrite
         case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
@@ -95,6 +98,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
             return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
         case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
             return isinstance(value, list) and all(isinstance(item, dict) for item in value)
+        case SegmentType.ARRAY_FILE if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
+            return isinstance(value, list) and all(isinstance(item, File) for item in value)
 
         case _:
             return False
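To make the new `ARRAY_FILE` cases concrete, here is a self-contained mock of the match-based validation — `File` and both enums are stubbed out (the real ones live in `core.file` and this module's `enums`), so this block only mirrors the shape of the diff:

```python
from enum import StrEnum


class File:  # stub for core.file.File
    pass


class Operation(StrEnum):
    APPEND = "append"
    EXTEND = "extend"


class SegmentType(StrEnum):
    ARRAY_FILE = "array[file]"


def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value) -> bool:
    match variable_type:
        case SegmentType.ARRAY_FILE if operation == Operation.APPEND:
            # Appending adds a single File to the array variable.
            return isinstance(value, File)
        case SegmentType.ARRAY_FILE if operation == Operation.EXTEND:
            # Extending requires a homogeneous list of File objects.
            return isinstance(value, list) and all(isinstance(item, File) for item in value)
        case _:
            return False


print(is_input_value_valid(variable_type=SegmentType.ARRAY_FILE,
                           operation=Operation.APPEND, value=File()))    # True
print(is_input_value_valid(variable_type=SegmentType.ARRAY_FILE,
                           operation=Operation.EXTEND, value=[File()]))  # True
```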
@@ -14,11 +14,3 @@ class SpecialModelType(StrEnum):
 
     GEMINI = "gemini"
     OLLAMA = "ollama"
-
-
-class SupportStructuredOutputStatus(StrEnum):
-    """Constants for structured output support status"""
-
-    SUPPORTED = "supported"
-    UNSUPPORTED = "unsupported"
-    DISABLED = "disabled"
@@ -70,6 +70,7 @@ def init_app(app: DifyApp) -> Celery:
         "schedule.update_tidb_serverless_status_task",
         "schedule.clean_messages",
         "schedule.mail_clean_document_notify_task",
+        "schedule.queue_monitor_task",
     ]
     day = dify_config.CELERY_BEAT_SCHEDULER_TIME
     beat_schedule = {
@@ -98,6 +99,12 @@ def init_app(app: DifyApp) -> Celery:
             "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
             "schedule": crontab(minute="0", hour="10", day_of_week="1"),
         },
+        "datasets-queue-monitor": {
+            "task": "schedule.queue_monitor_task.queue_monitor_task",
+            "schedule": timedelta(
+                minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
+            ),
+        },
     }
     celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
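The new beat entry runs the queue monitor on a fixed interval rather than a crontab, defaulting to every 30 minutes when `QUEUE_MONITOR_INTERVAL` is unset or zero. A minimal sketch of the same configuration shape against a bare Celery app (the broker URL is a placeholder; the task name matches the diff):

```python
from datetime import timedelta

from celery import Celery

app = Celery("dify", broker="redis://localhost:6379/0")  # placeholder broker

QUEUE_MONITOR_INTERVAL = None  # stand-in for dify_config.QUEUE_MONITOR_INTERVAL

app.conf.beat_schedule = {
    "datasets-queue-monitor": {
        "task": "schedule.queue_monitor_task.queue_monitor_task",
        # Falls back to every 30 minutes when the interval is unset or 0.
        "schedule": timedelta(minutes=QUEUE_MONITOR_INTERVAL if QUEUE_MONITOR_INTERVAL else 30),
    },
}
```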
@@ -57,6 +57,9 @@ def load_user_from_request(request_from_flask_login):
         raise Unauthorized("Invalid Authorization token.")
     decoded = PassportService().verify(auth_token)
     user_id = decoded.get("user_id")
+    source = decoded.get("token_source")
+    if source:
+        raise Unauthorized("Invalid Authorization token.")
     if not user_id:
         raise Unauthorized("Invalid Authorization token.")
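The added check rejects any JWT that carries a `token_source` claim, so tokens minted for another surface can no longer authenticate this endpoint. A self-contained sketch of the same guard over a plain decoded-claims dict (function and claim values here are illustrative):

```python
class Unauthorized(Exception):
    pass


def validate_console_claims(decoded: dict) -> str:
    # Console tokens must not declare a token_source; presence of the claim
    # means the token was issued for a different channel.
    if decoded.get("token_source"):
        raise Unauthorized("Invalid Authorization token.")
    user_id = decoded.get("user_id")
    if not user_id:
        raise Unauthorized("Invalid Authorization token.")
    return user_id


print(validate_console_claims({"user_id": "u-123"}))  # -> "u-123"
# validate_console_claims({"user_id": "u-123", "token_source": "webapp"})  # raises
```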
@@ -5,85 +5,54 @@ from flask import Flask
 
 from configs import dify_config
 from dify_app import DifyApp
+from libs.mail import MailConfigError, MailMessage, MailSender, MailSenderFactory
 
 
 class Mail:
     def __init__(self):
-        self._client = None
-        self._default_send_from = None
+        self._sender: Optional[MailSender] = None
 
     def is_inited(self) -> bool:
-        return self._client is not None
+        return self._sender is not None
 
-    def init_app(self, app: Flask):
-        mail_type = dify_config.MAIL_TYPE
-        if not mail_type:
-            logging.warning("MAIL_TYPE is not set")
-            return
-
-        if dify_config.MAIL_DEFAULT_SEND_FROM:
-            self._default_send_from = dify_config.MAIL_DEFAULT_SEND_FROM
-
-        match mail_type:
-            case "resend":
-                import resend
-
-                api_key = dify_config.RESEND_API_KEY
-                if not api_key:
-                    raise ValueError("RESEND_API_KEY is not set")
-
-                api_url = dify_config.RESEND_API_URL
-                if api_url:
-                    resend.api_url = api_url
-
-                resend.api_key = api_key
-                self._client = resend.Emails
-            case "smtp":
-                from libs.smtp import SMTPClient
-
-                if not dify_config.SMTP_SERVER or not dify_config.SMTP_PORT:
-                    raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type")
-                if not dify_config.SMTP_USE_TLS and dify_config.SMTP_OPPORTUNISTIC_TLS:
-                    raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS")
-                self._client = SMTPClient(
-                    server=dify_config.SMTP_SERVER,
-                    port=dify_config.SMTP_PORT,
-                    username=dify_config.SMTP_USERNAME or "",
-                    password=dify_config.SMTP_PASSWORD or "",
-                    _from=dify_config.MAIL_DEFAULT_SEND_FROM or "",
-                    use_tls=dify_config.SMTP_USE_TLS,
-                    opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS,
-                )
-            case _:
-                raise ValueError("Unsupported mail type {}".format(mail_type))
+    def init_app(self, app: Flask) -> None:
+        """Initialize mail sender using the new factory pattern."""
+        try:
+            self._sender = MailSenderFactory.create_from_dify_config(dify_config)
+            if self._sender:
+                logging.info("Mail sender initialized successfully")
+            else:
+                logging.warning("MAIL_TYPE is not set, mail functionality disabled")
+        except MailConfigError as e:
+            logging.exception("Failed to initialize mail sender")
+            raise ValueError(f"Mail configuration error: {e}")
+        except Exception as e:
+            logging.exception("Unexpected error initializing mail sender")
+            raise ValueError(f"Failed to initialize mail sender: {e}")
 
     def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
-        if not self._client:
-            raise ValueError("Mail client is not initialized")
-
-        if not from_ and self._default_send_from:
-            from_ = self._default_send_from
-
-        if not from_:
-            raise ValueError("mail from is not set")
-
-        if not to:
-            raise ValueError("mail to is not set")
-
-        if not subject:
-            raise ValueError("mail subject is not set")
-
-        if not html:
-            raise ValueError("mail html is not set")
-
-        self._client.send(
-            {
-                "from": from_,
-                "to": to,
-                "subject": subject,
-                "html": html,
-            }
-        )
+        """
+        Send an email using the configured mail sender.
+
+        Args:
+            to: Recipient email address
+            subject: Email subject
+            html: Email HTML content
+            from_: Sender email address (optional, uses default if not provided)
+        """
+        if not self._sender:
+            raise ValueError("Mail sender is not initialized")
+
+        try:
+            # Create mail message
+            message = MailMessage(to=to, subject=subject, html=html, from_=from_)
+
+            # Send the message
+            self._sender.send(message)
+        except Exception as e:
+            logging.exception(f"Failed to send email to {to}")
+            raise
 
 
 def is_enabled() -> bool:
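After this refactor the extension is a thin facade: `init_app` delegates construction to `MailSenderFactory.create_from_dify_config`, and `send` wraps the payload in a `MailMessage` before handing it to whichever sender the factory built. A hedged sketch of the resulting call shape — the concrete sender and the message type are stubbed here, so only the facade structure mirrors the diff:

```python
import logging
from dataclasses import dataclass
from typing import Optional


@dataclass
class MailMessage:  # stub mirroring libs.mail.MailMessage
    to: str
    subject: str
    html: str
    from_: Optional[str] = None


class ConsoleSender:  # stand-in for a MailSender built by the factory
    def send(self, message: MailMessage) -> None:
        logging.info("sending %r to %s", message.subject, message.to)


class Mail:
    def __init__(self) -> None:
        self._sender: Optional[ConsoleSender] = None

    def init_app(self) -> None:
        # Real code: MailSenderFactory.create_from_dify_config(dify_config)
        self._sender = ConsoleSender()

    def send(self, to: str, subject: str, html: str, from_: Optional[str] = None) -> None:
        if not self._sender:
            raise ValueError("Mail sender is not initialized")
        self._sender.send(MailMessage(to=to, subject=subject, html=html, from_=from_))


mail = Mail()
mail.init_app()
mail.send(to="user@example.com", subject="Hello", html="<p>Hi</p>")
```

Centralizing transport selection in the factory means new mail backends no longer require touching this extension at all.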
Some files were not shown because too many files have changed in this diff.