Mirror of https://gitee.com/dify_ai/dify.git (synced 2025-12-07 03:45:27 +08:00)

Compare commits: mcp-condit...refactor/u (64 commits)
Commits (64):

227ca64e13, 61a0fcc2ea, f627348b11, 87fb9a6b69, 97a2e2ec2e, 68d357d7f6,
a103ad3ee7, f65d5a9761, 6e0a5f5bbd, 22f858152f, 775d2e14fc, 744b287e67,
c0fc5d98f0, 08ea79d730, f31b821cc0, 34be16874f, e9738b891f, 829796514a,
ef1db35f80, f9c67621ca, e29e8e3180, 7a81e720d4, 55600c0eb1, 35e41d7d68,
b610cf9a11, c8e9edc024, 471cd760d7, 7f48c57edf, 6569801162, 9dd83f50a7,
59c56b1b0d, 94cd2de940, 3c23375607, 56047f638f, 9c01d3e775, c85c87f3da,
eaa02e3d55, 0219222a60, dba659b220, ee6458768e, ed3d02dc6d, 95471b1188,
6190cfbfd8, 11f2f95103, 2abbc14703, b2b2816ade, 4461df1bd9, f7f6b4a8b0,
41be581594, 20ad5b7ac2, a1c0bd7a1c, fd7c4e8a6d, 41e549af14, b7360140ee,
c71f7c7613, c905c47775, 4ca7ba000c, f260627660, 1e9142c213, 82890fe38e,
7dc7c8af98, addebc465a, 5ab315aeaf, f092bc1912
@@ -6,7 +6,7 @@ cd web && pnpm install
 pipx install uv
 
 echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
-echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
+echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
 echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
 echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
 echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
.vscode/launch.json.template (vendored, 7 changed lines)
@@ -8,8 +8,7 @@
             "module": "flask",
             "env": {
                 "FLASK_APP": "app.py",
-                "FLASK_ENV": "development",
-                "GEVENT_SUPPORT": "True"
+                "FLASK_ENV": "development"
             },
             "args": [
                 "run",
@@ -28,9 +27,7 @@
             "type": "debugpy",
             "request": "launch",
             "module": "celery",
-            "env": {
-                "GEVENT_SUPPORT": "True"
-            },
+            "env": {},
             "args": [
                 "-A",
                 "app.celery",
@@ -371,6 +371,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
 UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 
+# Comma-separated list of file extensions blocked from upload for security reasons.
+# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
+# Empty by default to allow all file types.
+# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
+UPLOAD_FILE_EXTENSION_BLACKLIST=
+
 # Model configuration
 MULTIMODAL_SEND_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
@@ -608,3 +614,6 @@ SWAGGER_UI_PATH=/swagger-ui.html
 # Whether to encrypt dataset IDs when exporting DSL files (default: true)
 # Set to false to export dataset IDs as plain text for easier cross-environment import
 DSL_EXPORT_ENCRYPT_DATASET_ID=true
+
+# Maximum number of segments for dataset segments API (0 for unlimited)
+DATASET_MAX_SEGMENTS_PER_REQUEST=0
@@ -15,7 +15,11 @@ FROM base AS packages
 # RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
 
 RUN apt-get update \
-    && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
+    && apt-get install -y --no-install-recommends \
+    # basic environment
+    g++ \
+    # for building gmpy2
+    libmpfr-dev libmpc-dev
 
 # Install Python dependencies
 COPY pyproject.toml uv.lock ./
@@ -49,7 +53,9 @@ RUN \
     # Install dependencies
     && apt-get install -y --no-install-recommends \
     # basic environment
-    curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
+    curl nodejs \
+    # for gmpy2 \
+    libgmp-dev libmpfr-dev libmpc-dev \
     # For Security
     expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
     # install fonts to support the use of tools like pypdfium2
@@ -80,7 +80,7 @@
 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
 
    ```bash
-   uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
+   uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
    ```
 
 Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
api/app.py (19 changed lines)
@@ -13,23 +13,12 @@ if is_db_command():
 
     app = create_migrations_app()
 else:
-    # It seems that JetBrains Python debugger does not work well with gevent,
-    # so we need to disable gevent in debug mode.
-    # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
-    # if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
-    #     from gevent import monkey
+    # Gunicorn and Celery handle monkey patching automatically in production by
+    # specifying the `gevent` worker class. Manual monkey patching is not required here.
     #
-    #     # gevent
-    #     monkey.patch_all()
+    # See `api/docker/entrypoint.sh` (lines 33 and 47) for details.
     #
-    #     from grpc.experimental import gevent as grpc_gevent  # type: ignore
-    #
-    #     # grpc gevent
-    #     grpc_gevent.init_gevent()
-
-    #     import psycogreen.gevent  # type: ignore
-    #
-    #     psycogreen.gevent.patch_psycopg()
+    # For third-party library patching, refer to `gunicorn.conf.py` and `celery_entrypoint.py`.
 
     from app_factory import create_app
 
@@ -1601,7 +1601,7 @@ def transform_datasource_credentials():
             "integration_secret": api_key,
         }
         datasource_provider = DatasourceProvider(
-            provider="jina",
+            provider="jinareader",
             tenant_id=tenant_id,
             plugin_id=jina_plugin_id,
             auth_type=api_key_credential_type.value,
@@ -331,6 +331,31 @@ class FileUploadConfig(BaseSettings):
         default=10,
     )
 
+    inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
+        description=(
+            "Comma-separated list of file extensions that are blocked from upload. "
+            "Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
+            "Empty by default to allow all file types."
+        ),
+        validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
+        default="",
+    )
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
+        """
+        Parse and return the blacklist as a set of lowercase extensions.
+        Returns an empty set if no blacklist is configured.
+        """
+        if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
+            return set()
+        return {
+            ext.strip().lower().strip(".")
+            for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
+            if ext.strip()
+        }
+
 
 class HttpConfig(BaseSettings):
     """
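As a usage sketch (not part of this changeset): because the computed property normalizes the raw env string into a set of bare lowercase extensions, a caller only needs a membership test. The helper name below is illustrative.

```python
from configs import dify_config


def is_extension_blocked(filename: str) -> bool:
    # "Report.EXE" -> "exe"; the computed property above has already
    # lowercased and dot-stripped every blacklist entry.
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    return bool(ext) and ext in dify_config.UPLOAD_FILE_EXTENSION_BLACKLIST
```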
@@ -920,6 +945,11 @@ class DataSetConfig(BaseSettings):
         default=True,
     )
 
+    DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
+        description="Maximum number of segments for dataset segments API (0 for unlimited)",
+        default=0,
+    )
+
 
 class WorkspaceConfig(BaseSettings):
     """
@@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings):
         default=True,
     )
 
+    WEAVIATE_GRPC_ENDPOINT: str | None = Field(
+        description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
+        default=None,
+    )
+
     WEAVIATE_BATCH_SIZE: PositiveInt = Field(
         description="Number of objects to be processed in a single batch operation (default is 100)",
         default=100,
@@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
     code = 415
 
 
+class BlockedFileExtensionError(BaseHTTPException):
+    error_code = "file_extension_blocked"
+    description = "The file extension is blocked for security reasons."
+    code = 400
+
+
 class TooManyFilesError(BaseHTTPException):
     error_code = "too_many_files"
     description = "Only one file is allowed."
@@ -16,6 +16,7 @@ from fields.annotation_fields import (
     annotation_fields,
     annotation_hit_history_fields,
 )
+from libs.helper import uuid_value
 from libs.login import login_required
 from services.annotation_service import AppAnnotationService
@@ -175,8 +176,10 @@ class AnnotationApi(Resource):
         api.model(
             "CreateAnnotationRequest",
             {
-                "question": fields.String(required=True, description="Question text"),
-                "answer": fields.String(required=True, description="Answer text"),
+                "message_id": fields.String(description="Message ID (optional)"),
+                "question": fields.String(description="Question text (required when message_id not provided)"),
+                "answer": fields.String(description="Answer text (use 'answer' or 'content')"),
+                "content": fields.String(description="Content text (use 'answer' or 'content')"),
                 "annotation_reply": fields.Raw(description="Annotation reply data"),
             },
         )
@@ -193,11 +196,14 @@ class AnnotationApi(Resource):
         app_id = str(app_id)
         parser = (
             reqparse.RequestParser()
-            .add_argument("question", required=True, type=str, location="json")
-            .add_argument("answer", required=True, type=str, location="json")
+            .add_argument("message_id", required=False, type=uuid_value, location="json")
+            .add_argument("question", required=False, type=str, location="json")
+            .add_argument("answer", required=False, type=str, location="json")
+            .add_argument("content", required=False, type=str, location="json")
             .add_argument("annotation_reply", required=False, type=dict, location="json")
         )
         args = parser.parse_args()
-        annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
+        annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
         return annotation
 
     @setup_required
@@ -1,7 +1,5 @@
-from datetime import datetime
-
-import pytz
 import sqlalchemy as sa
+from flask import abort
 from flask_restx import Resource, marshal_with, reqparse
 from flask_restx.inputs import int_range
 from sqlalchemy import func, or_
@@ -19,7 +17,7 @@ from fields.conversation_fields import (
     conversation_pagination_fields,
     conversation_with_summary_pagination_fields,
 )
-from libs.datetime_utils import naive_utc_now
+from libs.datetime_utils import naive_utc_now, parse_time_range
 from libs.helper import DatetimeString
 from libs.login import current_account_with_tenant, login_required
 from models import Conversation, EndUser, Message, MessageAnnotation
@@ -90,25 +88,17 @@ class CompletionConversationApi(Resource):
 
         account = current_user
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
+        if start_datetime_utc:
             query = query.where(Conversation.created_at >= start_datetime_utc)
 
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=59)
-
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-
+        if end_datetime_utc:
+            end_datetime_utc = end_datetime_utc.replace(second=59)
             query = query.where(Conversation.created_at < end_datetime_utc)
 
         # FIXME, the type ignore in this file
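The body of `parse_time_range` is not included in this compare view. Judging from its call sites and the inline logic it replaces, it plausibly parses naive `%Y-%m-%d %H:%M` strings, localizes them to the account's timezone, converts to UTC, and raises `ValueError` on bad input so the controllers can map that to a 400. A hedged reconstruction:

```python
# Hypothetical sketch of libs/datetime_utils.parse_time_range, inferred from
# the replaced inline code; the actual helper may differ.
from datetime import datetime

import pytz


def parse_time_range(
    start: str | None, end: str | None, tz_name: str
) -> tuple[datetime | None, datetime | None]:
    tz = pytz.timezone(tz_name)

    def to_utc(value: str | None) -> datetime | None:
        if not value:
            return None
        # Same format and second-zeroing as the code removed above.
        local = datetime.strptime(value, "%Y-%m-%d %H:%M").replace(second=0)
        return tz.localize(local).astimezone(pytz.utc)

    start_utc, end_utc = to_utc(start), to_utc(end)
    if start_utc and end_utc and start_utc > end_utc:
        raise ValueError("start time must be earlier than end time")
    return start_utc, end_utc
```

Note the conversation call sites still widen the end bound to `second=59` themselves, which matches the helper returning minute-precision values.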
@@ -270,29 +260,21 @@ class ChatConversationApi(Resource):
 
         account = current_user
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
+        if start_datetime_utc:
             match args["sort_by"]:
                 case "updated_at" | "-updated_at":
                     query = query.where(Conversation.updated_at >= start_datetime_utc)
                 case "created_at" | "-created_at" | _:
                     query = query.where(Conversation.created_at >= start_datetime_utc)
 
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=59)
-
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-
+        if end_datetime_utc:
+            end_datetime_utc = end_datetime_utc.replace(second=59)
             match args["sort_by"]:
                 case "updated_at" | "-updated_at":
                     query = query.where(Conversation.updated_at <= end_datetime_utc)
@@ -16,7 +16,6 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.wraps import (
     account_initialization_required,
-    cloud_edition_billing_resource_check,
     edit_permission_required,
     setup_required,
 )
@@ -24,12 +23,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
-from fields.conversation_fields import annotation_fields, message_detail_fields
+from fields.conversation_fields import message_detail_fields
 from libs.helper import uuid_value
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from libs.login import current_account_with_tenant, login_required
 from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
-from services.annotation_service import AppAnnotationService
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
 from services.message_service import MessageService
@@ -194,45 +192,6 @@ class MessageFeedbackApi(Resource):
         return {"result": "success"}
 
 
-@console_ns.route("/apps/<uuid:app_id>/annotations")
-class MessageAnnotationApi(Resource):
-    @api.doc("create_message_annotation")
-    @api.doc(description="Create message annotation")
-    @api.doc(params={"app_id": "Application ID"})
-    @api.expect(
-        api.model(
-            "MessageAnnotationRequest",
-            {
-                "message_id": fields.String(description="Message ID"),
-                "question": fields.String(required=True, description="Question text"),
-                "answer": fields.String(required=True, description="Answer text"),
-                "annotation_reply": fields.Raw(description="Annotation reply"),
-            },
-        )
-    )
-    @api.response(200, "Annotation created successfully", annotation_fields)
-    @api.response(403, "Insufficient permissions")
-    @marshal_with(annotation_fields)
-    @get_app_model
-    @setup_required
-    @login_required
-    @cloud_edition_billing_resource_check("annotation")
-    @account_initialization_required
-    @edit_permission_required
-    def post(self, app_model):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("message_id", required=False, type=uuid_value, location="json")
-            .add_argument("question", required=True, type=str, location="json")
-            .add_argument("answer", required=True, type=str, location="json")
-            .add_argument("annotation_reply", required=False, type=dict, location="json")
-        )
-        args = parser.parse_args()
-        annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
-
-        return annotation
-
-
 @console_ns.route("/apps/<uuid:app_id>/annotations/count")
 class MessageAnnotationCountApi(Resource):
     @api.doc("get_annotation_count")
@@ -1,9 +1,7 @@
-from datetime import datetime
 from decimal import Decimal
 
-import pytz
 import sqlalchemy as sa
-from flask import jsonify
+from flask import abort, jsonify
 from flask_restx import Resource, fields, reqparse
 
 from controllers.console import api, console_ns
@@ -11,6 +9,7 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
+from libs.datetime_utils import parse_time_range
 from libs.helper import DatetimeString
 from libs.login import current_account_with_tenant, login_required
 from models import AppMode, Message
@@ -56,26 +55,16 @@ WHERE
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
         assert account.timezone is not None
 
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
+        if start_datetime_utc:
             sql_query += " AND created_at >= :start"
             arg_dict["start"] = start_datetime_utc
 
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-
+        if end_datetime_utc:
             sql_query += " AND created_at < :end"
             arg_dict["end"] = end_datetime_utc
@@ -120,8 +109,11 @@ class DailyConversationStatistic(Resource):
         )
         args = parser.parse_args()
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
 
+        try:
+            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
         stmt = (
             sa.select(
@@ -134,18 +126,10 @@ class DailyConversationStatistic(Resource):
             .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
         )
 
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
+        if start_datetime_utc:
             stmt = stmt.where(Message.created_at >= start_datetime_utc)
 
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
+        if end_datetime_utc:
             stmt = stmt.where(Message.created_at < end_datetime_utc)
 
         stmt = stmt.group_by("date").order_by("date")
@@ -198,26 +182,17 @@ WHERE
     AND invoke_from != :invoke_from"""
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
+        if start_datetime_utc:
             sql_query += " AND created_at >= :start"
             arg_dict["start"] = start_datetime_utc
 
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-
+        if end_datetime_utc:
             sql_query += " AND created_at < :end"
             arg_dict["end"] = end_datetime_utc
@@ -273,26 +248,17 @@ WHERE
     AND invoke_from != :invoke_from"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
+        if start_datetime_utc:
             sql_query += " AND created_at >= :start"
             arg_dict["start"] = start_datetime_utc
 
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-
+        if end_datetime_utc:
             sql_query += " AND created_at < :end"
             arg_dict["end"] = end_datetime_utc
@@ -357,26 +323,17 @@ FROM
     AND m.invoke_from != :invoke_from"""
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
+        if start_datetime_utc:
             sql_query += " AND c.created_at >= :start"
             arg_dict["start"] = start_datetime_utc
 
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-
+        if end_datetime_utc:
             sql_query += " AND c.created_at < :end"
             arg_dict["end"] = end_datetime_utc
@@ -446,26 +403,17 @@ WHERE
     AND m.invoke_from != :invoke_from"""
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
+        if start_datetime_utc:
             sql_query += " AND m.created_at >= :start"
             arg_dict["start"] = start_datetime_utc
 
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-
+        if end_datetime_utc:
             sql_query += " AND m.created_at < :end"
             arg_dict["end"] = end_datetime_utc
@@ -525,26 +473,17 @@ WHERE
     AND invoke_from != :invoke_from"""
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
+        if start_datetime_utc:
             sql_query += " AND created_at >= :start"
             arg_dict["start"] = start_datetime_utc
 
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-
+        if end_datetime_utc:
             sql_query += " AND created_at < :end"
             arg_dict["end"] = end_datetime_utc
@@ -602,26 +541,17 @@ WHERE
     AND invoke_from != :invoke_from"""
         arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
+        if start_datetime_utc:
             sql_query += " AND created_at >= :start"
             arg_dict["start"] = start_datetime_utc
 
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-
+        if end_datetime_utc:
             sql_query += " AND created_at < :end"
             arg_dict["end"] = end_datetime_utc
@@ -102,7 +102,18 @@ class DraftWorkflowApi(Resource):
             },
         )
     )
-    @api.response(200, "Draft workflow synced successfully", workflow_fields)
+    @api.response(
+        200,
+        "Draft workflow synced successfully",
+        api.model(
+            "SyncDraftWorkflowResponse",
+            {
+                "result": fields.String,
+                "hash": fields.String,
+                "updated_at": fields.String,
+            },
+        ),
+    )
     @api.response(400, "Invalid workflow configuration")
     @api.response(403, "Permission denied")
     @edit_permission_required
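With the inline model, the documented 200 body for a successful sync would look like the illustrative payload below (placeholder values, not from this changeset):

```python
# Illustrative response body matching the SyncDraftWorkflowResponse model:
response_body = {
    "result": "success",
    "hash": "<workflow-hash>",
    "updated_at": "<timestamp>",  # serialized as a string per fields.String
}
```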
@@ -1,7 +1,4 @@
-from datetime import datetime
-
-import pytz
-from flask import jsonify
+from flask import abort, jsonify
 from flask_restx import Resource, reqparse
 from sqlalchemy.orm import sessionmaker
 
@@ -9,6 +6,7 @@ from controllers.console import api, console_ns
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from extensions.ext_database import db
+from libs.datetime_utils import parse_time_range
 from libs.helper import DatetimeString
 from libs.login import current_account_with_tenant, login_required
 from models.enums import WorkflowRunTriggeredFrom
@@ -43,23 +41,11 @@ class WorkflowDailyRunsStatistic(Resource):
         args = parser.parse_args()
 
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        start_date = None
-        end_date = None
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_date = start_datetime_timezone.astimezone(utc_timezone)
-
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_date = end_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
         response_data = self._workflow_run_repo.get_daily_runs_statistics(
             tenant_id=app_model.tenant_id,
@@ -100,23 +86,11 @@ class WorkflowDailyTerminalsStatistic(Resource):
         args = parser.parse_args()
 
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        start_date = None
-        end_date = None
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_date = start_datetime_timezone.astimezone(utc_timezone)
-
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_date = end_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
         response_data = self._workflow_run_repo.get_daily_terminals_statistics(
             tenant_id=app_model.tenant_id,
@@ -157,23 +131,11 @@ class WorkflowDailyTokenCostStatistic(Resource):
         args = parser.parse_args()
 
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        start_date = None
-        end_date = None
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_date = start_datetime_timezone.astimezone(utc_timezone)
-
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_date = end_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
         response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
             tenant_id=app_model.tenant_id,
@@ -214,23 +176,11 @@ class WorkflowAverageAppInteractionStatistic(Resource):
         args = parser.parse_args()
 
         assert account.timezone is not None
-        timezone = pytz.timezone(account.timezone)
-        utc_timezone = pytz.utc
-
-        start_date = None
-        end_date = None
-
-        if args["start"]:
-            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
-            start_datetime = start_datetime.replace(second=0)
-            start_datetime_timezone = timezone.localize(start_datetime)
-            start_date = start_datetime_timezone.astimezone(utc_timezone)
-
-        if args["end"]:
-            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
-            end_datetime = end_datetime.replace(second=0)
-            end_datetime_timezone = timezone.localize(end_datetime)
-            end_date = end_datetime_timezone.astimezone(utc_timezone)
+        try:
+            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+        except ValueError as e:
+            abort(400, description=str(e))
 
         response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
             tenant_id=app_model.tenant_id,
@@ -2,6 +2,7 @@ from flask_restx import Resource, reqparse
 
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
+from enums.cloud_plan import CloudPlan
 from libs.login import current_account_with_tenant, login_required
 from services.billing_service import BillingService
@@ -16,7 +17,13 @@ class Subscription(Resource):
         current_user, current_tenant_id = current_account_with_tenant()
         parser = (
             reqparse.RequestParser()
-            .add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
+            .add_argument(
+                "plan",
+                type=str,
+                required=True,
+                location="args",
+                choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
+            )
             .add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
         )
         args = parser.parse_args()
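The `enums.cloud_plan.CloudPlan` enum itself is not included in this compare view. Given the string literals it replaces ("sandbox", "professional", "team") and that its members are passed directly as reqparse `choices`, it is presumably a string-valued enum along these lines:

```python
# Hypothetical sketch of enums/cloud_plan.py; the real module is not in this diff.
from enum import StrEnum


class CloudPlan(StrEnum):
    SANDBOX = "sandbox"
    PROFESSIONAL = "professional"
    TEAM = "team"
```

Since `StrEnum` members compare equal to their string values, checks such as `plan == CloudPlan.SANDBOX` stay compatible with plans stored as plain strings.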
@@ -746,7 +746,7 @@ class DocumentApi(DocumentResource):
             "name": document.name,
             "created_from": document.created_from,
             "created_by": document.created_by,
-            "created_at": document.created_at.timestamp(),
+            "created_at": int(document.created_at.timestamp()),
             "tokens": document.tokens,
             "indexing_status": document.indexing_status,
             "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@@ -779,7 +779,7 @@ class DocumentApi(DocumentResource):
             "name": document.name,
             "created_from": document.created_from,
             "created_by": document.created_by,
-            "created_at": document.created_at.timestamp(),
+            "created_at": int(document.created_at.timestamp()),
             "tokens": document.tokens,
             "indexing_status": document.indexing_status,
             "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@@ -8,6 +8,7 @@ import services
 from configs import dify_config
 from constants import DOCUMENT_EXTENSIONS
 from controllers.common.errors import (
+    BlockedFileExtensionError,
     FilenameNotExistsError,
     FileTooLargeError,
     NoFileUploadedError,
@@ -83,6 +84,8 @@ class FileApi(Resource):
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
+        except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
+            raise BlockedFileExtensionError(blocked_extension_error.description)
 
         return upload_file, 201
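The controller now translates `services.errors.file.BlockedFileExtensionError` into the HTTP error class added earlier. The service-side guard that raises it is not part of this compare view; a minimal sketch, assuming the service error accepts a description message:

```python
# Hypothetical guard inside the file upload service; helper name is illustrative.
from configs import dify_config
from services.errors.file import BlockedFileExtensionError


def ensure_extension_allowed(filename: str) -> None:
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    if ext and ext in dify_config.UPLOAD_FILE_EXTENSION_BLACKLIST:
        raise BlockedFileExtensionError(f"The file extension '{ext}' is blocked for security reasons.")
```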
@@ -21,6 +21,7 @@ from controllers.console.wraps import (
     cloud_edition_billing_resource_check,
     setup_required,
 )
+from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
 from libs.helper import TimestampField
 from libs.login import current_account_with_tenant, login_required
@@ -83,7 +84,7 @@ class TenantListApi(Resource):
                 "name": tenant.name,
                 "status": tenant.status,
                 "created_at": tenant.created_at,
-                "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
+                "plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
                 "current": tenant.id == current_tenant_id if current_tenant_id else False,
             }
@@ -10,6 +10,7 @@ from flask import abort, request
 
 from configs import dify_config
 from controllers.console.workspace.error import AccountNotInitializedError
+from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.login import current_account_with_tenant
@@ -133,7 +134,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
         features = FeatureService.get_features(current_tenant_id)
         if features.billing.enabled:
             if resource == "add_segment":
-                if features.billing.subscription.plan == "sandbox":
+                if features.billing.subscription.plan == CloudPlan.SANDBOX:
                     abort(
                         403,
                         "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
@@ -592,7 +592,7 @@ class DocumentApi(DatasetApiResource):
             "name": document.name,
             "created_from": document.created_from,
             "created_by": document.created_by,
-            "created_at": document.created_at.timestamp(),
+            "created_at": int(document.created_at.timestamp()),
             "tokens": document.tokens,
             "indexing_status": document.indexing_status,
             "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@@ -625,7 +625,7 @@ class DocumentApi(DatasetApiResource):
             "name": document.name,
             "created_from": document.created_from,
             "created_by": document.created_by,
-            "created_at": document.created_at.timestamp(),
+            "created_at": int(document.created_at.timestamp()),
             "tokens": document.tokens,
             "indexing_status": document.indexing_status,
             "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@@ -2,6 +2,7 @@ from flask import request
 from flask_restx import marshal, reqparse
 from werkzeug.exceptions import NotFound
 
+from configs import dify_config
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import ProviderNotInitializeError
 from controllers.service_api.wraps import (
@@ -107,6 +108,10 @@ class SegmentApi(DatasetApiResource):
         # validate args
         args = segment_create_parser.parse_args()
         if args["segments"] is not None:
+            segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
+            if segments_limit > 0 and len(args["segments"]) > segments_limit:
+                raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
+
             for args_item in args["segments"]:
                 SegmentService.segment_create_args_validate(args_item, document)
             segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
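For illustration, with `DATASET_MAX_SEGMENTS_PER_REQUEST=2` a service-API request that posts three segments would now be rejected. The endpoint path below follows Dify's knowledge-base API; treat the exact payload fields as assumptions:

```python
import requests

resp = requests.post(
    "https://api.dify.ai/v1/datasets/<dataset_id>/documents/<document_id>/segments",
    headers={"Authorization": "Bearer <dataset-api-key>"},
    json={
        "segments": [
            {"content": "chunk one"},
            {"content": "chunk two"},
            {"content": "chunk three"},  # third segment exceeds a limit of 2
        ]
    },
)
# The new guard raises: "Exceeded maximum segments limit of 2."
```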
@@ -13,6 +13,7 @@ from sqlalchemy import select, update
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
 
+from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.datetime_utils import naive_utc_now
@@ -66,6 +67,7 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
 
             kwargs["app_model"] = app_model
 
+            # If caller needs end-user context, attach EndUser to current_user
             if fetch_user_arg:
                 if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
                     user_id = request.args.get("user")
@@ -74,7 +76,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
                 elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
                     user_id = request.form.get("user")
                 else:
-                    # use default-user
                     user_id = None
 
                 if not user_id and fetch_user_arg.required:
@@ -89,6 +90,28 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
                 # Set EndUser as current logged-in user for flask_login.current_user
                 current_app.login_manager._update_request_context_with_user(end_user)  # type: ignore
                 user_logged_in.send(current_app._get_current_object(), user=end_user)  # type: ignore
+            else:
+                # For service API without end-user context, ensure an Account is logged in
+                # so services relying on current_account_with_tenant() work correctly.
+                tenant_owner_info = (
+                    db.session.query(Tenant, Account)
+                    .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
+                    .join(Account, TenantAccountJoin.account_id == Account.id)
+                    .where(
+                        Tenant.id == app_model.tenant_id,
+                        TenantAccountJoin.role == "owner",
+                        Tenant.status == TenantStatus.NORMAL,
+                    )
+                    .one_or_none()
+                )
+
+                if tenant_owner_info:
+                    tenant_model, account = tenant_owner_info
+                    account.current_tenant = tenant_model
+                    current_app.login_manager._update_request_context_with_user(account)  # type: ignore
+                    user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore
+                else:
+                    raise Unauthorized("Tenant owner account not found or tenant is not active.")
 
             return view_func(*args, **kwargs)
@@ -138,7 +161,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
         features = FeatureService.get_features(api_token.tenant_id)
         if features.billing.enabled:
             if resource == "add_segment":
-                if features.billing.subscription.plan == "sandbox":
+                if features.billing.subscription.plan == CloudPlan.SANDBOX:
                     raise Forbidden(
                         "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
                     )
@@ -1,6 +1,6 @@
 import logging
 import time
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
 from sqlalchemy import select
@@ -25,6 +25,7 @@ from core.moderation.input_moderation import InputModeration
 from core.variables.variables import VariableUnion
 from core.workflow.enums import WorkflowType
 from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -61,11 +62,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         app: App,
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
+        graph_engine_layers: Sequence[GraphEngineLayer] = (),
     ):
         super().__init__(
             queue_manager=queue_manager,
             variable_loader=variable_loader,
             app_id=application_generate_entity.app_config.app_id,
+            graph_engine_layers=graph_engine_layers,
         )
         self.application_generate_entity = application_generate_entity
         self.conversation = conversation
@@ -195,6 +198,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         )
 
         workflow_entry.graph_engine.layer(persistence_layer)
+        for layer in self._graph_engine_layers:
+            workflow_entry.graph_engine.layer(layer)
 
         generator = workflow_entry.run()
@@ -1,3 +1,4 @@
+import json
 import logging
 import re
 import time
@@ -60,6 +61,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
 from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.enums import WorkflowExecutionStatus
 from core.workflow.nodes import NodeType
@@ -391,6 +393,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         if should_direct_answer:
             return
 
+        current_time = time.perf_counter()
+        if self._task_state.first_token_time is None and delta_text.strip():
+            self._task_state.first_token_time = current_time
+            self._task_state.is_streaming_response = True
+
+        if delta_text.strip():
+            self._task_state.last_token_time = current_time
+
         # Only publish tts message at text chunk streaming
         if tts_publisher and queue_message:
             tts_publisher.publish(queue_message)
@@ -772,7 +782,33 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         message.answer = answer_text
         message.updated_at = naive_utc_now()
         message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
-        message.message_metadata = self._task_state.metadata.model_dump_json()
+
+        # Set usage first before dumping metadata
+        if graph_runtime_state and graph_runtime_state.llm_usage:
+            usage = graph_runtime_state.llm_usage
+            message.message_tokens = usage.prompt_tokens
+            message.message_unit_price = usage.prompt_unit_price
+            message.message_price_unit = usage.prompt_price_unit
+            message.answer_tokens = usage.completion_tokens
+            message.answer_unit_price = usage.completion_unit_price
+            message.answer_price_unit = usage.completion_price_unit
+            message.total_price = usage.total_price
+            message.currency = usage.currency
+            self._task_state.metadata.usage = usage
+        else:
+            usage = LLMUsage.empty_usage()
+            self._task_state.metadata.usage = usage
+
+        # Add streaming metrics to usage if available
+        if self._task_state.is_streaming_response and self._task_state.first_token_time:
+            start_time = self._base_task_pipeline.start_at
+            first_token_time = self._task_state.first_token_time
+            last_token_time = self._task_state.last_token_time or first_token_time
+            usage.time_to_first_token = round(first_token_time - start_time, 3)
+            usage.time_to_generate = round(last_token_time - first_token_time, 3)
+
+        metadata = self._task_state.metadata.model_dump()
+        message.message_metadata = json.dumps(jsonable_encoder(metadata))
         message_files = [
             MessageFile(
                 message_id=message.id,
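Worked example of the new metrics: with `start_at` at t = 0.00 s, the first non-empty delta arriving at t = 0.42 s, and the last at t = 3.97 s, the recorded usage gains `time_to_first_token = 0.42` and `time_to_generate = 3.55`.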
@@ -790,20 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         ]
         session.add_all(message_files)
 
-        if graph_runtime_state and graph_runtime_state.llm_usage:
-            usage = graph_runtime_state.llm_usage
-            message.message_tokens = usage.prompt_tokens
-            message.message_unit_price = usage.prompt_unit_price
-            message.message_price_unit = usage.prompt_price_unit
-            message.answer_tokens = usage.completion_tokens
-            message.answer_unit_price = usage.completion_unit_price
-            message.answer_price_unit = usage.completion_price_unit
-            message.total_price = usage.total_price
-            message.currency = usage.currency
-            self._task_state.metadata.usage = usage
-        else:
-            self._task_state.metadata.usage = LLMUsage.empty_usage()
-
     def _seed_graph_runtime_state_from_queue_manager(self) -> None:
         """Bootstrap the cached runtime state from the queue manager when present."""
         candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
@@ -144,7 +144,7 @@ class AgentChatAppRunner(AppRunner):
             prompt_template_entity=app_config.prompt_template,
             inputs=dict(inputs),
             files=list(files),
-            query=query or "",
+            query=query,
             memory=memory,
         )
@@ -172,7 +172,7 @@ class AgentChatAppRunner(AppRunner):
             prompt_template_entity=app_config.prompt_template,
             inputs=dict(inputs),
             files=list(files),
-            query=query or "",
+            query=query,
             memory=memory,
         )
@@ -79,7 +79,7 @@ class AppRunner:
         prompt_template_entity: PromptTemplateEntity,
         inputs: Mapping[str, str],
         files: Sequence["File"],
-        query: str | None = None,
+        query: str = "",
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
         image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
@@ -105,7 +105,7 @@ class AppRunner:
             app_mode=AppMode.value_of(app_record.mode),
             prompt_template_entity=prompt_template_entity,
             inputs=inputs,
-            query=query or "",
+            query=query,
             files=files,
             context=context,
             memory=memory,
@@ -190,7 +190,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
             conversation_id=conversation.id,
             inputs=application_generate_entity.inputs,
-            query=application_generate_entity.query or "",
+            query=application_generate_entity.query,
             message="",
             message_tokens=0,
             message_unit_price=0,
@@ -40,6 +40,7 @@ from core.workflow.repositories.draft_variable_repository import DraftVariableSa
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
+from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.flask_utils import preserve_flask_contexts
@@ -255,7 +256,7 @@ class PipelineGenerator(BaseAppGenerator):
         json_text = json.dumps(text)
         upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
         features = FeatureService.get_features(dataset.tenant_id)
-        if features.billing.enabled and features.billing.subscription.plan == "sandbox":
+        if features.billing.enabled and features.billing.subscription.plan == CloudPlan.SANDBOX:
             tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
             tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
@@ -135,6 +135,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         )
 
         workflow_entry.graph_engine.layer(persistence_layer)
+        for layer in self._graph_engine_layers:
+            workflow_entry.graph_engine.layer(layer)
 
         generator = workflow_entry.run()
@@ -1,5 +1,5 @@
 import time
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -27,6 +27,7 @@ from core.app.entities.queue_entities import (
 )
 from core.workflow.entities import GraphInitParams
 from core.workflow.graph import Graph
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.graph_events import (
     GraphEngineEvent,
     GraphRunFailedEvent,
@@ -69,10 +70,12 @@ class WorkflowBasedAppRunner:
         queue_manager: AppQueueManager,
         variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
         app_id: str,
+        graph_engine_layers: Sequence[GraphEngineLayer] = (),
     ):
         self._queue_manager = queue_manager
         self._variable_loader = variable_loader
         self._app_id = app_id
+        self._graph_engine_layers = graph_engine_layers
 
     def _init_graph(
         self,
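The new `graph_engine_layers` parameter lets callers attach extra layers next to the built-in persistence layer; the runner hunks above register each one via `workflow_entry.graph_engine.layer(...)`. A minimal sketch of a custom layer, with illustrative logging behavior:

```python
import logging

from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent

logger = logging.getLogger(__name__)


class EventLoggingLayer(GraphEngineLayer):
    """Hypothetical layer that logs every engine event for debugging."""

    def on_graph_start(self) -> None:
        logger.info("graph started")

    def on_event(self, event: GraphEngineEvent) -> None:
        logger.info("engine event: %s", type(event).__name__)

    def on_graph_end(self, error: Exception | None) -> None:
        logger.info("graph ended (error=%s)", error)
```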
@@ -129,7 +129,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
     app_config: EasyUIBasedAppConfig = None  # type: ignore
     model_conf: ModelConfigWithCredentialsEntity
 
-    query: str | None = None
+    query: str = ""
 
     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
@@ -48,6 +48,9 @@ class WorkflowTaskState(TaskState):
     """
 
     answer: str = ""
+    first_token_time: float | None = None
+    last_token_time: float | None = None
+    is_streaming_response: bool = False
 
 
 class StreamEvent(StrEnum):
api/core/app/layers/pause_state_persist_layer.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker

from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory


class PauseStatePersistenceLayer(GraphEngineLayer):
    def __init__(self, session_factory: Engine | sessionmaker, state_owner_user_id: str):
        """Create a PauseStatePersistenceLayer.

        The `state_owner_user_id` is used when creating the state file for a pause.
        It should generally be the id of the workflow's creator.
        """
        if isinstance(session_factory, Engine):
            session_factory = sessionmaker(session_factory)
        self._session_maker = session_factory
        self._state_owner_user_id = state_owner_user_id

    def _get_repo(self) -> APIWorkflowRunRepository:
        return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)

    def on_graph_start(self) -> None:
        """
        Called when graph execution starts.

        This is called after the engine has been initialized but before any nodes
        are executed. Layers can use this to set up resources or log start information.
        """
        pass

    def on_event(self, event: GraphEngineEvent) -> None:
        """
        Called for every event emitted by the engine.

        This method receives all events generated during graph execution, including:
        - Graph lifecycle events (start, success, failure)
        - Node execution events (start, success, failure, retry)
        - Stream events for response nodes
        - Container events (iteration, loop)

        Args:
            event: The event emitted by the engine
        """
        if not isinstance(event, GraphRunPausedEvent):
            return

        assert self.graph_runtime_state is not None
        workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
        assert workflow_run_id is not None
        repo = self._get_repo()
        repo.create_workflow_pause(
            workflow_run_id=workflow_run_id,
            state_owner_user_id=self._state_owner_user_id,
            state=self.graph_runtime_state.dumps(),
        )

    def on_graph_end(self, error: Exception | None) -> None:
        """
        Called when graph execution ends.

        This is called after all nodes have been executed or when execution is
        aborted. Layers can use this to clean up resources or log final state.

        Args:
            error: The exception that caused execution to fail, or None if successful
        """
        pass
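The layer is self-contained, so wiring it up is a one-liner next to the layer registration shown in the WorkflowAppRunner hunk above. A minimal sketch of such wiring (the `workflow_entry` and `workflow.created_by` names are assumptions for illustration, not part of this diff):

from extensions.ext_database import db

# Hypothetical wiring: register the layer so paused runs get their state snapshotted.
pause_layer = PauseStatePersistenceLayer(db.engine, state_owner_user_id=workflow.created_by)
workflow_entry.graph_engine.layer(pause_layer)  # state is written only on GraphRunPausedEvent
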
@@ -121,7 +121,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
        if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
            # start generate conversation name thread
            self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
                conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
                conversation_id=self._conversation_id, query=self._application_generate_entity.query
            )

        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

@@ -140,7 +140,27 @@ class MessageCycleManager:
        if not self._application_generate_entity.app_config.additional_features:
            raise ValueError("Additional features not found")
        if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
            self._task_state.metadata.retriever_resources = event.retriever_resources
            merged_resources = [r for r in self._task_state.metadata.retriever_resources or [] if r]
            existing_ids = {(r.dataset_id, r.document_id) for r in merged_resources if r.dataset_id and r.document_id}

            # Add new unique resources from the event
            for resource in event.retriever_resources or []:
                if not resource:
                    continue

                is_duplicate = (
                    resource.dataset_id
                    and resource.document_id
                    and (resource.dataset_id, resource.document_id) in existing_ids
                )

                if not is_duplicate:
                    merged_resources.append(resource)

            for i, resource in enumerate(merged_resources, 1):
                resource.position = i

            self._task_state.metadata.retriever_resources = merged_resources

    def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
        """

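For readers skimming the hunk: the merge is keyed on `(dataset_id, document_id)`, keeps resources that lack either id, and renumbers positions from 1. A stripped-down sketch of the same rule (function name and inputs are hypothetical):

def merge_retriever_resources(state_resources: list, event_resources: list) -> list:
    # Keys already present in the task state block matching event resources;
    # resources missing either id are always kept.
    existing = {(r.dataset_id, r.document_id) for r in state_resources if r.dataset_id and r.document_id}
    merged = list(state_resources)
    for r in event_resources:
        if r and not (r.dataset_id and r.document_id and (r.dataset_id, r.document_id) in existing):
            merged.append(r)
    for i, r in enumerate(merged, 1):  # positions are renumbered 1..n
        r.position = i
    return merged
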
@@ -14,7 +14,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.encryption import create_provider_encrypter

if TYPE_CHECKING:
    from models.tools import MCPToolProvider
@@ -272,6 +271,8 @@ class MCPProviderEntity(BaseModel):

    def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
        """Generic method to decrypt dictionary fields"""
        from core.tools.utils.encryption import create_provider_encrypter

        if not data:
            return {}

@@ -74,6 +74,10 @@ class File(BaseModel):
        storage_key: str | None = None,
        dify_model_identity: str | None = FILE_MODEL_IDENTITY,
        url: str | None = None,
        # Legacy compatibility fields - explicitly handle known extra fields
        tool_file_id: str | None = None,
        upload_file_id: str | None = None,
        datasource_file_id: str | None = None,
    ):
        super().__init__(
            id=id,

@@ -6,10 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class NodeJsTemplateTransformer(TemplateTransformer):
    @classmethod
    def get_runner_script(cls) -> str:
        runner_script = dedent(
            f"""
            // declare main function
            {cls._code_placeholder}
        runner_script = dedent(f""" {cls._code_placeholder}

            // decode and prepare input object
            var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
@@ -21,6 +18,5 @@ class NodeJsTemplateTransformer(TemplateTransformer):
            var output_json = JSON.stringify(output_obj)
            var result = `<<RESULT>>${{output_json}}<<RESULT>>`
            console.log(result)
            """
        )
            """)
        return runner_script

@@ -6,9 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
    @classmethod
    def get_runner_script(cls) -> str:
        runner_script = dedent(f"""
            # declare main function
            {cls._code_placeholder}
        runner_script = dedent(f""" {cls._code_placeholder}

            import json
            from base64 import b64decode

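Both transformer hunks make the same move: the code placeholder now sits on the opening f""" line instead of an indented line of its own. A plausible motivation (an assumption; the diff itself does not say) is textwrap.dedent's common-prefix rule, which a two-line demo makes concrete:

from textwrap import dedent

# Whitespace-only lines are ignored, so the common 4-space margin is stripped.
print(repr(dedent("""
    a
    b
""")))  # -> '\na\nb\n'

# A first line at column 0 shrinks the common prefix to nothing, so no margin
# is stripped at all -- the effect multi-line user code substituted into an
# indented placeholder would have on the surrounding template.
print(repr(dedent("""x
    a
""")))  # -> 'x\n    a\n'
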
@@ -29,6 +29,18 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
    return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]


def batch_fetch_plugin_by_ids(plugin_ids: list[str]) -> list[dict]:
    if not plugin_ids:
        return []

    url = str(marketplace_api_url / "api/v1/plugins/batch")
    response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
    response.raise_for_status()

    data = response.json()
    return data.get("data", {}).get("plugins", [])


def batch_fetch_plugin_manifests_ignore_deserialization_error(
    plugin_ids: list[str],
) -> Sequence[MarketplacePluginDeclaration]:

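A hypothetical call site for the new helper (the ids and the response key below are made up for illustration; note the endpoint returns raw dicts rather than validated MarketplacePluginDeclaration objects, and an empty id list short-circuits without a network call):

plugins = batch_fetch_plugin_by_ids(["langgenius/openai", "langgenius/google"])  # hypothetical ids
for plugin in plugins:
    print(plugin.get("plugin_id"))  # "plugin_id" is an assumed response key
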
@@ -109,12 +109,16 @@ class ClientSession(
        self._message_handler = message_handler or _default_message_handler

    def initialize(self) -> types.InitializeResult:
        sampling = types.SamplingCapability()
        roots = types.RootsCapability(
            # TODO: Should this be based on whether we
            # _will_ send notifications, or only whether
            # they're supported?
            listChanged=True,
        # Only set capabilities if non-default callbacks are provided
        # This prevents servers from attempting callbacks when we don't actually support them
        sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
        roots = (
            types.RootsCapability(
                # Only enable listChanged if we have a custom callback
                listChanged=True,
            )
            if self._list_roots_callback is not _default_list_roots_callback
            else None
        )

        result = self.send_request(

@@ -38,6 +38,8 @@ class LLMUsageMetadata(TypedDict, total=False):
    prompt_price: Union[float, str]
    completion_price: Union[float, str]
    latency: float
    time_to_first_token: float
    time_to_generate: float


class LLMUsage(ModelUsage):
@@ -57,6 +59,8 @@ class LLMUsage(ModelUsage):
    total_price: Decimal
    currency: str
    latency: float
    time_to_first_token: float | None = None
    time_to_generate: float | None = None

    @classmethod
    def empty_usage(cls):
@@ -73,6 +77,8 @@ class LLMUsage(ModelUsage):
            total_price=Decimal("0.0"),
            currency="USD",
            latency=0.0,
            time_to_first_token=None,
            time_to_generate=None,
        )

    @classmethod
@@ -108,6 +114,8 @@ class LLMUsage(ModelUsage):
            prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
            completion_price=Decimal(str(metadata.get("completion_price", 0))),
            latency=metadata.get("latency", 0.0),
            time_to_first_token=metadata.get("time_to_first_token"),
            time_to_generate=metadata.get("time_to_generate"),
        )

    def plus(self, other: LLMUsage) -> LLMUsage:
@@ -133,6 +141,8 @@ class LLMUsage(ModelUsage):
            total_price=self.total_price + other.total_price,
            currency=other.currency,
            latency=self.latency + other.latency,
            time_to_first_token=other.time_to_first_token,
            time_to_generate=other.time_to_generate,
        )

    def __add__(self, other: LLMUsage) -> LLMUsage:

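A sketch (not from the diff) of how the new fields behave under plus() as defined in the hunk: latency accumulates, while the streaming timings are taken from the right-hand operand. It assumes any empty-usage short-circuit in plus() is bypassed by giving both operands non-zero token counts, and that constructing test values via model_copy is acceptable:

a = LLMUsage.empty_usage().model_copy(
    update={"total_tokens": 10, "latency": 0.5, "time_to_first_token": 0.25}
)
b = LLMUsage.empty_usage().model_copy(
    update={"total_tokens": 5, "latency": 1.25, "time_to_generate": 0.75}
)
total = a.plus(b)
assert total.latency == 1.75                # summed
assert total.time_to_first_token is None    # taken from b, which has none
assert total.time_to_generate == 0.75       # taken from b
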
@@ -62,6 +62,9 @@ class MessageTraceInfo(BaseTraceInfo):
    file_list: Union[str, dict[str, Any], list] | None = None
    message_file_data: Any | None = None
    conversation_mode: str
    gen_ai_server_time_to_first_token: float | None = None
    llm_streaming_time_to_generate: float | None = None
    is_streaming_request: bool = False


class ModerationTraceInfo(BaseTraceInfo):

@@ -14,7 +14,7 @@ from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
    OPS_FILE_PATH,
    TracingProviderEnum,
@@ -141,6 +141,8 @@ provider_config_map = OpsTraceProviderConfigMap()

class OpsTraceManager:
    ops_trace_instances_cache: LRUCache = LRUCache(maxsize=128)
    decrypted_configs_cache: LRUCache = LRUCache(maxsize=128)
    _decryption_cache_lock = threading.RLock()

    @classmethod
    def encrypt_tracing_config(
@@ -161,7 +163,7 @@ class OpsTraceManager:
            provider_config_map[tracing_provider]["other_keys"],
        )

        new_config = {}
        new_config: dict[str, Any] = {}
        # Encrypt necessary keys
        for key in secret_keys:
            if key in tracing_config:
@@ -191,20 +193,41 @@ class OpsTraceManager:
        :param tracing_config: tracing config
        :return:
        """
        config_class, secret_keys, other_keys = (
            provider_config_map[tracing_provider]["config_class"],
            provider_config_map[tracing_provider]["secret_keys"],
            provider_config_map[tracing_provider]["other_keys"],
        config_json = json.dumps(tracing_config, sort_keys=True)
        decrypted_config_key = (
            tenant_id,
            tracing_provider,
            config_json,
        )
        new_config = {}
        for key in secret_keys:
            if key in tracing_config:
                new_config[key] = decrypt_token(tenant_id, tracing_config[key])

        for key in other_keys:
            new_config[key] = tracing_config.get(key, "")
        # First check without lock for performance
        cached_config = cls.decrypted_configs_cache.get(decrypted_config_key)
        if cached_config is not None:
            return dict(cached_config)

        return config_class(**new_config).model_dump()
        with cls._decryption_cache_lock:
            # Second check (double-checked locking) to prevent race conditions
            cached_config = cls.decrypted_configs_cache.get(decrypted_config_key)
            if cached_config is not None:
                return dict(cached_config)

            config_class, secret_keys, other_keys = (
                provider_config_map[tracing_provider]["config_class"],
                provider_config_map[tracing_provider]["secret_keys"],
                provider_config_map[tracing_provider]["other_keys"],
            )
            new_config: dict[str, Any] = {}
            keys_to_decrypt = [key for key in secret_keys if key in tracing_config]
            if keys_to_decrypt:
                decrypted_values = batch_decrypt_token(tenant_id, [tracing_config[key] for key in keys_to_decrypt])
                new_config.update(zip(keys_to_decrypt, decrypted_values))

            for key in other_keys:
                new_config[key] = tracing_config.get(key, "")

            decrypted_config = config_class(**new_config).model_dump()
            cls.decrypted_configs_cache[decrypted_config_key] = decrypted_config
            return dict(decrypted_config)

    @classmethod
    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
@@ -219,7 +242,7 @@ class OpsTraceManager:
            provider_config_map[tracing_provider]["secret_keys"],
            provider_config_map[tracing_provider]["other_keys"],
        )
        new_config = {}
        new_config: dict[str, Any] = {}
        for key in secret_keys:
            if key in decrypt_tracing_config:
                new_config[key] = obfuscated_token(decrypt_tracing_config[key])
@@ -596,6 +619,8 @@ class TraceTask:
            file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
            file_list.append(file_url)

        streaming_metrics = self._extract_streaming_metrics(message_data)

        metadata = {
            "conversation_id": message_data.conversation_id,
            "ls_provider": message_data.model_provider,
@@ -628,6 +653,9 @@ class TraceTask:
            metadata=metadata,
            message_file_data=message_file_data,
            conversation_mode=conversation_mode,
            gen_ai_server_time_to_first_token=streaming_metrics.get("gen_ai_server_time_to_first_token"),
            llm_streaming_time_to_generate=streaming_metrics.get("llm_streaming_time_to_generate"),
            is_streaming_request=streaming_metrics.get("is_streaming_request", False),
        )

        return message_trace_info
@@ -853,6 +881,24 @@ class TraceTask:

        return generate_name_trace_info

    def _extract_streaming_metrics(self, message_data) -> dict:
        if not message_data.message_metadata:
            return {}

        try:
            metadata = json.loads(message_data.message_metadata)
            usage = metadata.get("usage", {})
            time_to_first_token = usage.get("time_to_first_token")
            time_to_generate = usage.get("time_to_generate")

            return {
                "gen_ai_server_time_to_first_token": time_to_first_token,
                "llm_streaming_time_to_generate": time_to_generate,
                "is_streaming_request": time_to_first_token is not None,
            }
        except (json.JSONDecodeError, AttributeError):
            return {}


trace_manager_timer: threading.Timer | None = None
trace_manager_queue: queue.Queue = queue.Queue()

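The decryption cache above is a textbook double-checked locking read-through cache. The generic shape, reduced to its essentials (a sketch, not the Dify code):

import threading

from cachetools import LRUCache

_cache: LRUCache = LRUCache(maxsize=128)
_lock = threading.RLock()

def get_or_compute(key, compute):
    value = _cache.get(key)        # optimistic read, no lock
    if value is not None:
        return value
    with _lock:
        value = _cache.get(key)    # re-check under the lock
        if value is None:
            value = compute()      # only one thread pays the expensive call
            _cache[key] = value
        return value
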
@@ -5,12 +5,18 @@ Tencent APM Trace Client - handles network operations, metrics, and API communic
from __future__ import annotations

import importlib
import json
import logging
import os
import socket
from typing import TYPE_CHECKING
from urllib.parse import urlparse

try:
    from importlib.metadata import version
except ImportError:
    from importlib_metadata import version  # type: ignore[import-not-found]

if TYPE_CHECKING:
    from opentelemetry.metrics import Meter
    from opentelemetry.metrics._internal.instrument import Histogram
@@ -27,12 +33,27 @@ from opentelemetry.util.types import AttributeValue

from configs import dify_config

from .entities.tencent_semconv import LLM_OPERATION_DURATION
from .entities.semconv import (
    GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
    GEN_AI_STREAMING_TIME_TO_GENERATE,
    GEN_AI_TOKEN_USAGE,
    GEN_AI_TRACE_DURATION,
    LLM_OPERATION_DURATION,
)
from .entities.tencent_trace_entity import SpanData

logger = logging.getLogger(__name__)


def _get_opentelemetry_sdk_version() -> str:
    """Get OpenTelemetry SDK version dynamically."""
    try:
        return version("opentelemetry-sdk")
    except Exception:
        logger.debug("Failed to get opentelemetry-sdk version, using default")
        return "1.27.0"  # fallback version


class TencentTraceClient:
    """Tencent APM trace client using OpenTelemetry OTLP exporter"""

@@ -57,6 +78,9 @@ class TencentTraceClient:
                ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
                ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
                ResourceAttributes.HOST_NAME: socket.gethostname(),
                ResourceAttributes.TELEMETRY_SDK_LANGUAGE: "python",
                ResourceAttributes.TELEMETRY_SDK_NAME: "opentelemetry",
                ResourceAttributes.TELEMETRY_SDK_VERSION: _get_opentelemetry_sdk_version(),
            }
        )
        # Prepare gRPC endpoint/metadata
@@ -80,18 +104,23 @@ class TencentTraceClient:
        )
        self.tracer_provider.add_span_processor(self.span_processor)

        self.tracer = self.tracer_provider.get_tracer("dify.tencent_apm")
        # use dify api version as tracer version
        self.tracer = self.tracer_provider.get_tracer("dify-sdk", dify_config.project.version)

        # Store span contexts for parent-child relationships
        self.span_contexts: dict[int, trace_api.SpanContext] = {}

        self.meter: Meter | None = None
        self.meter_provider: MeterProvider | None = None
        self.hist_llm_duration: Histogram | None = None
        self.hist_token_usage: Histogram | None = None
        self.hist_time_to_first_token: Histogram | None = None
        self.hist_time_to_generate: Histogram | None = None
        self.hist_trace_duration: Histogram | None = None
        self.metric_reader: MetricReader | None = None

        # Metrics exporter and instruments
        try:
            from opentelemetry import metrics
            from opentelemetry.sdk.metrics import Histogram, MeterProvider
            from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader

@@ -99,7 +128,7 @@ class TencentTraceClient:
            use_http_protobuf = protocol in {"http/protobuf", "http-protobuf"}
            use_http_json = protocol in {"http/json", "http-json"}

            # Set preferred temporality for histograms to DELTA
            # Tencent APM works best with delta aggregation temporality
            preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA}

            def _create_metric_exporter(exporter_cls, **kwargs):
@@ -174,23 +203,66 @@ class TencentTraceClient:
            )

            if metric_reader is not None:
                # Use instance-level MeterProvider instead of global to support config changes
                # without worker restart. Each TencentTraceClient manages its own MeterProvider.
                provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
                metrics.set_meter_provider(provider)
                self.meter = metrics.get_meter("dify-sdk", dify_config.project.version)
                self.meter_provider = provider
                self.meter = provider.get_meter("dify-sdk", dify_config.project.version)

                # LLM operation duration histogram
                self.hist_llm_duration = self.meter.create_histogram(
                    name=LLM_OPERATION_DURATION,
                    unit="s",
                    description="LLM operation duration (seconds)",
                )

                # Token usage histogram with exponential buckets
                self.hist_token_usage = self.meter.create_histogram(
                    name=GEN_AI_TOKEN_USAGE,
                    unit="token",
                    description="Number of tokens used in prompt and completions",
                )

                # Time to first token histogram
                self.hist_time_to_first_token = self.meter.create_histogram(
                    name=GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
                    unit="s",
                    description="Time to first token for streaming LLM responses (seconds)",
                )

                # Time to generate histogram
                self.hist_time_to_generate = self.meter.create_histogram(
                    name=GEN_AI_STREAMING_TIME_TO_GENERATE,
                    unit="s",
                    description="Total time to generate streaming LLM responses (seconds)",
                )

                # Trace duration histogram
                self.hist_trace_duration = self.meter.create_histogram(
                    name=GEN_AI_TRACE_DURATION,
                    unit="s",
                    description="End-to-end GenAI trace duration (seconds)",
                )

                self.metric_reader = metric_reader
            else:
                self.meter = None
                self.meter_provider = None
                self.hist_llm_duration = None
                self.hist_token_usage = None
                self.hist_time_to_first_token = None
                self.hist_time_to_generate = None
                self.hist_trace_duration = None
                self.metric_reader = None
        except Exception:
            logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
            self.meter = None
            self.meter_provider = None
            self.hist_llm_duration = None
            self.hist_token_usage = None
            self.hist_time_to_first_token = None
            self.hist_time_to_generate = None
            self.hist_trace_duration = None
            self.metric_reader = None

    def add_span(self, span_data: SpanData) -> None:
@@ -212,10 +284,158 @@ class TencentTraceClient:
            if attributes:
                for k, v in attributes.items():
                    attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v  # type: ignore[assignment]

            logger.info(
                "[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
                LLM_OPERATION_DURATION,
                latency_seconds,
                json.dumps(attrs, ensure_ascii=False),
            )

            self.hist_llm_duration.record(latency_seconds, attrs)  # type: ignore[attr-defined]
        except Exception:
            logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)

    def record_token_usage(
        self,
        token_count: int,
        token_type: str,
        operation_name: str,
        request_model: str,
        response_model: str,
        server_address: str,
        provider: str,
    ) -> None:
        """Record token usage histogram.

        Args:
            token_count: Number of tokens used
            token_type: "input" or "output"
            operation_name: Operation name (e.g., "chat")
            request_model: Model used in request
            response_model: Model used in response
            server_address: Server address
            provider: Model provider name
        """
        try:
            if not hasattr(self, "hist_token_usage") or self.hist_token_usage is None:
                return

            attributes = {
                "gen_ai.operation.name": operation_name,
                "gen_ai.request.model": request_model,
                "gen_ai.response.model": response_model,
                "gen_ai.system": provider,
                "gen_ai.token.type": token_type,
                "server.address": server_address,
            }

            logger.info(
                "[Tencent Metrics] Metric: %s | Value: %d | Attributes: %s",
                GEN_AI_TOKEN_USAGE,
                token_count,
                json.dumps(attributes, ensure_ascii=False),
            )

            self.hist_token_usage.record(token_count, attributes)  # type: ignore[attr-defined]
        except Exception:
            logger.debug("[Tencent APM] Failed to record token usage", exc_info=True)

    def record_time_to_first_token(
        self, ttft_seconds: float, provider: str, model: str, operation_name: str = "chat"
    ) -> None:
        """Record time to first token histogram.

        Args:
            ttft_seconds: Time to first token in seconds
            provider: Model provider name
            model: Model name
            operation_name: Operation name (default: "chat")
        """
        try:
            if not hasattr(self, "hist_time_to_first_token") or self.hist_time_to_first_token is None:
                return

            attributes = {
                "gen_ai.operation.name": operation_name,
                "gen_ai.system": provider,
                "gen_ai.request.model": model,
                "gen_ai.response.model": model,
                "stream": "true",
            }

            logger.info(
                "[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
                GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
                ttft_seconds,
                json.dumps(attributes, ensure_ascii=False),
            )

            self.hist_time_to_first_token.record(ttft_seconds, attributes)  # type: ignore[attr-defined]
        except Exception:
            logger.debug("[Tencent APM] Failed to record time to first token", exc_info=True)

    def record_time_to_generate(
        self, ttg_seconds: float, provider: str, model: str, operation_name: str = "chat"
    ) -> None:
        """Record time to generate histogram.

        Args:
            ttg_seconds: Time to generate in seconds
            provider: Model provider name
            model: Model name
            operation_name: Operation name (default: "chat")
        """
        try:
            if not hasattr(self, "hist_time_to_generate") or self.hist_time_to_generate is None:
                return

            attributes = {
                "gen_ai.operation.name": operation_name,
                "gen_ai.system": provider,
                "gen_ai.request.model": model,
                "gen_ai.response.model": model,
                "stream": "true",
            }

            logger.info(
                "[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
                GEN_AI_STREAMING_TIME_TO_GENERATE,
                ttg_seconds,
                json.dumps(attributes, ensure_ascii=False),
            )

            self.hist_time_to_generate.record(ttg_seconds, attributes)  # type: ignore[attr-defined]
        except Exception:
            logger.debug("[Tencent APM] Failed to record time to generate", exc_info=True)

    def record_trace_duration(self, duration_seconds: float, attributes: dict[str, str] | None = None) -> None:
        """Record end-to-end trace duration histogram in seconds.

        Args:
            duration_seconds: Trace duration in seconds
            attributes: Optional attributes (e.g., conversation_mode, app_id)
        """
        try:
            if not hasattr(self, "hist_trace_duration") or self.hist_trace_duration is None:
                return

            attrs: dict[str, str] = {}
            if attributes:
                for k, v in attributes.items():
                    attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v  # type: ignore[assignment]

            logger.info(
                "[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
                GEN_AI_TRACE_DURATION,
                duration_seconds,
                json.dumps(attrs, ensure_ascii=False),
            )

            self.hist_trace_duration.record(duration_seconds, attrs)  # type: ignore[attr-defined]
        except Exception:
            logger.debug("[Tencent APM] Failed to record trace duration", exc_info=True)

    def _create_and_export_span(self, span_data: SpanData) -> None:
        """Create span using OpenTelemetry Tracer API"""
        try:
@@ -296,11 +516,19 @@ class TencentTraceClient:

            if self.tracer_provider:
                self.tracer_provider.shutdown()

            # Shutdown instance-level meter provider
            if self.meter_provider is not None:
                try:
                    self.meter_provider.shutdown()  # type: ignore[attr-defined]
                except Exception:
                    logger.debug("[Tencent APM] Error shutting down meter provider", exc_info=True)

            if self.metric_reader is not None:
                try:
                    self.metric_reader.shutdown()  # type: ignore[attr-defined]
                except Exception:
                    pass
                    logger.debug("[Tencent APM] Error shutting down metric reader", exc_info=True)

        except Exception:
            logger.exception("[Tencent APM] Error during client shutdown")

@@ -47,6 +47,9 @@ GEN_AI_COMPLETION = "gen_ai.completion"

GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"

# Streaming Span Attributes
GEN_AI_IS_STREAMING_REQUEST = "llm.is_streaming"  # Same as OpenLLMetry semconv

# Tool
TOOL_NAME = "tool.name"

@@ -62,6 +65,19 @@ INSTRUMENTATION_LANGUAGE = "python"

# Metrics
LLM_OPERATION_DURATION = "gen_ai.client.operation.duration"
GEN_AI_TOKEN_USAGE = "gen_ai.client.token.usage"
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN = "gen_ai.server.time_to_first_token"
GEN_AI_STREAMING_TIME_TO_GENERATE = "gen_ai.streaming.time_to_generate"
# The LLM trace duration metric, which is exclusive to Tencent APM
GEN_AI_TRACE_DURATION = "gen_ai.trace.duration"

# Token Usage Attributes
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
GEN_AI_REQUEST_MODEL = "gen_ai.request.model"
GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
GEN_AI_SYSTEM = "gen_ai.system"
GEN_AI_TOKEN_TYPE = "gen_ai.token.type"
SERVER_ADDRESS = "server.address"


class GenAISpanKind(Enum):

@@ -14,10 +14,11 @@ from core.ops.entities.trace_entity import (
    ToolTraceInfo,
    WorkflowTraceInfo,
)
from core.ops.tencent_trace.entities.tencent_semconv import (
from core.ops.tencent_trace.entities.semconv import (
    GEN_AI_COMPLETION,
    GEN_AI_FRAMEWORK,
    GEN_AI_IS_ENTRY,
    GEN_AI_IS_STREAMING_REQUEST,
    GEN_AI_MODEL_NAME,
    GEN_AI_PROMPT,
    GEN_AI_PROVIDER,
@@ -156,6 +157,25 @@ class TencentSpanBuilder:
        outputs = node_execution.outputs or {}
        usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})

        attributes = {
            GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
            GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
            GEN_AI_FRAMEWORK: "dify",
            GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
            GEN_AI_PROVIDER: process_data.get("model_provider", ""),
            GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
            GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
            GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
            GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
            GEN_AI_COMPLETION: str(outputs.get("text", "")),
            GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
            INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
            OUTPUT_VALUE: str(outputs.get("text", "")),
        }

        if usage_data.get("time_to_first_token") is not None:
            attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"

        return SpanData(
            trace_id=trace_id,
            parent_span_id=workflow_span_id,
@@ -163,21 +183,7 @@ class TencentSpanBuilder:
            name="GENERATION",
            start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
            end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
            attributes={
                GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
                GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
                GEN_AI_FRAMEWORK: "dify",
                GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
                GEN_AI_PROVIDER: process_data.get("model_provider", ""),
                GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
                GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
                GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
                GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
                GEN_AI_COMPLETION: str(outputs.get("text", "")),
                GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
                INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
                OUTPUT_VALUE: str(outputs.get("text", "")),
            },
            attributes=attributes,
            status=TencentSpanBuilder._get_workflow_node_status(node_execution),
        )

@@ -191,6 +197,19 @@ class TencentSpanBuilder:
        if trace_info.error:
            status = Status(StatusCode.ERROR, trace_info.error)

        attributes = {
            GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
            GEN_AI_USER_ID: str(user_id),
            GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
            GEN_AI_FRAMEWORK: "dify",
            GEN_AI_IS_ENTRY: "true",
            INPUT_VALUE: str(trace_info.inputs or ""),
            OUTPUT_VALUE: str(trace_info.outputs or ""),
        }

        if trace_info.is_streaming_request:
            attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"

        return SpanData(
            trace_id=trace_id,
            parent_span_id=None,
@@ -198,15 +217,7 @@ class TencentSpanBuilder:
            name="message",
            start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
            end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
            attributes={
                GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
                GEN_AI_USER_ID: str(user_id),
                GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
                GEN_AI_FRAMEWORK: "dify",
                GEN_AI_IS_ENTRY: "true",
                INPUT_VALUE: str(trace_info.inputs or ""),
                OUTPUT_VALUE: str(trace_info.outputs or ""),
            },
            attributes=attributes,
            status=status,
            links=links,
        )

@@ -90,6 +90,9 @@ class TencentDataTrace(BaseTraceInstance):

            self._process_workflow_nodes(trace_info, trace_id)

            # Record trace duration for entry span
            self._record_workflow_trace_duration(trace_info)

        except Exception:
            logger.exception("[Tencent APM] Failed to process workflow trace")

@@ -107,6 +110,11 @@ class TencentDataTrace(BaseTraceInstance):

            self.trace_client.add_span(message_span)

            self._record_message_llm_metrics(trace_info)

            # Record trace duration for entry span
            self._record_message_trace_duration(trace_info)

        except Exception:
            logger.exception("[Tencent APM] Failed to process message trace")

@@ -290,24 +298,219 @@ class TencentDataTrace(BaseTraceInstance):
    def _record_llm_metrics(self, node_execution: WorkflowNodeExecution) -> None:
        """Record LLM performance metrics"""
        try:
            if not hasattr(self.trace_client, "record_llm_duration"):
                return

            process_data = node_execution.process_data or {}
            usage = process_data.get("usage", {})
            latency_s = float(usage.get("latency", 0.0))
            outputs = node_execution.outputs or {}
            usage = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})

            if latency_s > 0:
                attributes = {
                    "provider": process_data.get("model_provider", ""),
                    "model": process_data.get("model_name", ""),
                    "span_kind": "GENERATION",
                }
                self.trace_client.record_llm_duration(latency_s, attributes)
            model_provider = process_data.get("model_provider", "unknown")
            model_name = process_data.get("model_name", "unknown")
            model_mode = process_data.get("model_mode", "chat")

            # Record LLM duration
            if hasattr(self.trace_client, "record_llm_duration"):
                latency_s = float(usage.get("latency", 0.0))

                if latency_s > 0:
                    # Determine if streaming from usage metrics
                    is_streaming = usage.get("time_to_first_token") is not None

                    attributes = {
                        "gen_ai.system": model_provider,
                        "gen_ai.response.model": model_name,
                        "gen_ai.operation.name": model_mode,
                        "stream": "true" if is_streaming else "false",
                    }
                    self.trace_client.record_llm_duration(latency_s, attributes)

            # Record streaming metrics from usage
            time_to_first_token = usage.get("time_to_first_token")
            if time_to_first_token is not None and hasattr(self.trace_client, "record_time_to_first_token"):
                ttft_seconds = float(time_to_first_token)
                if ttft_seconds > 0:
                    self.trace_client.record_time_to_first_token(
                        ttft_seconds=ttft_seconds, provider=model_provider, model=model_name, operation_name=model_mode
                    )

            time_to_generate = usage.get("time_to_generate")
            if time_to_generate is not None and hasattr(self.trace_client, "record_time_to_generate"):
                ttg_seconds = float(time_to_generate)
                if ttg_seconds > 0:
                    self.trace_client.record_time_to_generate(
                        ttg_seconds=ttg_seconds, provider=model_provider, model=model_name, operation_name=model_mode
                    )

            # Record token usage
            if hasattr(self.trace_client, "record_token_usage"):
                # Extract token counts
                input_tokens = int(usage.get("prompt_tokens", 0))
                output_tokens = int(usage.get("completion_tokens", 0))

                if input_tokens > 0 or output_tokens > 0:
                    server_address = f"{model_provider}"

                    # Record input tokens
                    if input_tokens > 0:
                        self.trace_client.record_token_usage(
                            token_count=input_tokens,
                            token_type="input",
                            operation_name=model_mode,
                            request_model=model_name,
                            response_model=model_name,
                            server_address=server_address,
                            provider=model_provider,
                        )

                    # Record output tokens
                    if output_tokens > 0:
                        self.trace_client.record_token_usage(
                            token_count=output_tokens,
                            token_type="output",
                            operation_name=model_mode,
                            request_model=model_name,
                            response_model=model_name,
                            server_address=server_address,
                            provider=model_provider,
                        )

        except Exception:
            logger.debug("[Tencent APM] Failed to record LLM metrics")

    def _record_message_llm_metrics(self, trace_info: MessageTraceInfo) -> None:
        """Record LLM metrics for message traces"""
        try:
            trace_metadata = trace_info.metadata or {}
            message_data = trace_info.message_data or {}
            provider_latency = 0.0
            if isinstance(message_data, dict):
                provider_latency = float(message_data.get("provider_response_latency", 0.0) or 0.0)
            else:
                provider_latency = float(getattr(message_data, "provider_response_latency", 0.0) or 0.0)

            model_provider = trace_metadata.get("ls_provider") or (
                message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
            )
            model_name = trace_metadata.get("ls_model_name") or (
                message_data.get("model_id", "") if isinstance(message_data, dict) else ""
            )

            # Record LLM duration
            if provider_latency > 0 and hasattr(self.trace_client, "record_llm_duration"):
                is_streaming = trace_info.is_streaming_request

                duration_attributes = {
                    "gen_ai.system": model_provider,
                    "gen_ai.response.model": model_name,
                    "gen_ai.operation.name": "chat",  # Message traces are always chat
                    "stream": "true" if is_streaming else "false",
                }
                self.trace_client.record_llm_duration(provider_latency, duration_attributes)

            # Record streaming metrics for message traces
            if trace_info.is_streaming_request:
                # Record time to first token
                if trace_info.gen_ai_server_time_to_first_token is not None and hasattr(
                    self.trace_client, "record_time_to_first_token"
                ):
                    ttft_seconds = float(trace_info.gen_ai_server_time_to_first_token)
                    if ttft_seconds > 0:
                        self.trace_client.record_time_to_first_token(
                            ttft_seconds=ttft_seconds, provider=str(model_provider or ""), model=str(model_name or "")
                        )

                # Record time to generate
                if trace_info.llm_streaming_time_to_generate is not None and hasattr(
                    self.trace_client, "record_time_to_generate"
                ):
                    ttg_seconds = float(trace_info.llm_streaming_time_to_generate)
                    if ttg_seconds > 0:
                        self.trace_client.record_time_to_generate(
                            ttg_seconds=ttg_seconds, provider=str(model_provider or ""), model=str(model_name or "")
                        )

            # Record token usage
            if hasattr(self.trace_client, "record_token_usage"):
                input_tokens = int(trace_info.message_tokens or 0)
                output_tokens = int(trace_info.answer_tokens or 0)

                if input_tokens > 0:
                    self.trace_client.record_token_usage(
                        token_count=input_tokens,
                        token_type="input",
                        operation_name="chat",
                        request_model=str(model_name or ""),
                        response_model=str(model_name or ""),
                        server_address=str(model_provider or ""),
                        provider=str(model_provider or ""),
                    )

                if output_tokens > 0:
                    self.trace_client.record_token_usage(
                        token_count=output_tokens,
                        token_type="output",
                        operation_name="chat",
                        request_model=str(model_name or ""),
                        response_model=str(model_name or ""),
                        server_address=str(model_provider or ""),
                        provider=str(model_provider or ""),
                    )

        except Exception:
            logger.debug("[Tencent APM] Failed to record message LLM metrics")

    def _record_workflow_trace_duration(self, trace_info: WorkflowTraceInfo) -> None:
        """Record end-to-end workflow trace duration."""
        try:
            if not hasattr(self.trace_client, "record_trace_duration"):
                return

            # Calculate duration from start_time and end_time to match span duration
            if trace_info.start_time and trace_info.end_time:
                duration_s = (trace_info.end_time - trace_info.start_time).total_seconds()
            else:
                # Fallback to workflow_run_elapsed_time if timestamps not available
                duration_s = float(trace_info.workflow_run_elapsed_time)

            if duration_s > 0:
                attributes = {
                    "conversation_mode": "workflow",
                    "workflow_status": trace_info.workflow_run_status,
                }

                # Add conversation_id if available
                if trace_info.conversation_id:
                    attributes["has_conversation"] = "true"
                else:
                    attributes["has_conversation"] = "false"

                self.trace_client.record_trace_duration(duration_s, attributes)

        except Exception:
            logger.debug("[Tencent APM] Failed to record workflow trace duration")

    def _record_message_trace_duration(self, trace_info: MessageTraceInfo) -> None:
        """Record end-to-end message trace duration."""
        try:
            if not hasattr(self.trace_client, "record_trace_duration"):
                return

            # Calculate duration from start_time and end_time
            if trace_info.start_time and trace_info.end_time:
                duration = (trace_info.end_time - trace_info.start_time).total_seconds()

                if duration > 0:
                    attributes = {
                        "conversation_mode": trace_info.conversation_mode,
                    }

                    # Add streaming flag if available
                    if hasattr(trace_info, "is_streaming_request"):
                        attributes["stream"] = "true" if trace_info.is_streaming_request else "false"

                    self.trace_client.record_trace_duration(duration, attributes)

        except Exception:
            logger.debug("[Tencent APM] Failed to record message trace duration")

    def __del__(self):
        """Ensure proper cleanup on garbage collection."""
        try:

@@ -39,11 +39,13 @@ class WeaviateConfig(BaseModel):

    Attributes:
        endpoint: Weaviate server endpoint URL
        grpc_endpoint: Optional Weaviate gRPC server endpoint URL
        api_key: Optional API key for authentication
        batch_size: Number of objects to batch per insert operation
    """

    endpoint: str
    grpc_endpoint: str | None = None
    api_key: str | None = None
    batch_size: int = 100

@@ -88,9 +90,22 @@ class WeaviateVector(BaseVector):
        http_secure = p.scheme == "https"
        http_port = p.port or (443 if http_secure else 80)

        grpc_host = host
        grpc_secure = http_secure
        grpc_port = 443 if grpc_secure else 50051
        # Parse gRPC configuration
        if config.grpc_endpoint:
            # URLs without a scheme are not parsed correctly in some Python versions,
            # see https://bugs.python.org/issue27657
            grpc_endpoint_with_scheme = (
                config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
            )
            grpc_p = urlparse(grpc_endpoint_with_scheme)
            grpc_host = grpc_p.hostname or "localhost"
            grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
            grpc_secure = grpc_p.scheme == "grpcs"
        else:
            # Infer from HTTP endpoint as fallback
            grpc_host = host
            grpc_secure = http_secure
            grpc_port = 443 if grpc_secure else 50051

        client = weaviate.connect_to_custom(
            http_host=host,
@@ -100,6 +115,7 @@ class WeaviateVector(BaseVector):
            grpc_port=grpc_port,
            grpc_secure=grpc_secure,
            auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
            skip_init_checks=True,  # Skip PyPI version check to avoid unnecessary HTTP requests
        )

        if not client.is_ready():
@@ -431,6 +447,7 @@ class WeaviateVectorFactory(AbstractVectorFactory):
            collection_name=collection_name,
            config=WeaviateConfig(
                endpoint=dify_config.WEAVIATE_ENDPOINT or "",
                grpc_endpoint=dify_config.WEAVIATE_GRPC_ENDPOINT or "",
                api_key=dify_config.WEAVIATE_API_KEY,
                batch_size=dify_config.WEAVIATE_BATCH_SIZE,
            ),

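The scheme-prepending step above is the load-bearing detail. A two-line demonstration of the urlparse quirk it works around (the endpoint value is hypothetical):

from urllib.parse import urlparse

# Without a scheme, urlparse() puts "host:port" into path rather than netloc
# (https://bugs.python.org/issue27657), so a scheme is prepended first.
raw = "weaviate-grpc.internal:50051"  # hypothetical endpoint
p = urlparse(raw if "://" in raw else f"grpc://{raw}")
assert (p.hostname, p.port) == ("weaviate-grpc.internal", 50051)
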
@@ -210,12 +210,13 @@ class Tool(ABC):
            meta=meta,
        )

    def create_json_message(self, object: dict) -> ToolInvokeMessage:
    def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
        """
        create a json message
        """
        return ToolInvokeMessage(
            type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
            type=ToolInvokeMessage.MessageType.JSON,
            message=ToolInvokeMessage.JsonMessage(json_object=object, suppress_output=suppress_output),
        )

    def create_variable_message(

@@ -129,6 +129,7 @@ class ToolInvokeMessage(BaseModel):

    class JsonMessage(BaseModel):
        json_object: dict
        suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")

    class BlobMessage(BaseModel):
        blob: bytes

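A quick sketch of the flag's intent: a suppressed JSON message still reaches the client as structured JSON, but _handle_tool_response in the next hunk skips it when assembling the plain-text result string, which is how WorkflowTool below avoids emitting the same payload twice. Assuming the two classes validate as shown in this diff:

from core.tools.entities.tool_entities import ToolInvokeMessage

msg = ToolInvokeMessage(
    type=ToolInvokeMessage.MessageType.JSON,
    message=ToolInvokeMessage.JsonMessage(json_object={"answer": 42}, suppress_output=True),
)
assert msg.message.suppress_output  # excluded from the result string, kept on the wire
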
@@ -228,29 +228,41 @@ class ToolEngine:
        """
        Handle tool response
        """
        result = ""
        parts: list[str] = []
        json_parts: list[str] = []

        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.TEXT:
                result += cast(ToolInvokeMessage.TextMessage, response.message).text
                parts.append(cast(ToolInvokeMessage.TextMessage, response.message).text)
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                result += (
                parts.append(
                    f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
                    + " please tell user to check it."
                )
            elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
                result += (
                parts.append(
                    "image has been created and sent to user already, "
                    + "you do not need to create it, just tell the user to check it now."
                )
            elif response.type == ToolInvokeMessage.MessageType.JSON:
                result += json.dumps(
                    safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
                    ensure_ascii=False,
                json_message = cast(ToolInvokeMessage.JsonMessage, response.message)
                if json_message.suppress_output:
                    continue
                json_parts.append(
                    json.dumps(
                        safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
                        ensure_ascii=False,
                    )
                )
            else:
                result += str(response.message)
                parts.append(str(response.message))

        return result
        # Add JSON parts, avoiding duplicates from text parts.
        if json_parts:
            existing_parts = set(parts)
            parts.extend(p for p in json_parts if p not in existing_parts)

        return "".join(parts)

    @staticmethod
    def _extract_tool_response_binary_and_text(

@@ -117,7 +117,7 @@ class WorkflowTool(Tool):
        self._latest_usage = self._derive_usage_from_result(data)

        yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
        yield self.create_json_message(outputs)
        yield self.create_json_message(outputs, suppress_output=True)

    @property
    def latest_usage(self) -> LLMUsage:

@@ -4,6 +4,7 @@ from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_pause import WorkflowPauseEntity

__all__ = [
    "AgentNodeStrategyInit",
@@ -12,4 +13,5 @@ __all__ = [
    "VariablePool",
    "WorkflowExecution",
    "WorkflowNodeExecution",
    "WorkflowPauseEntity",
]

api/core/workflow/entities/pause_reason.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from enum import StrEnum, auto
from typing import Annotated, Any, ClassVar, TypeAlias

from pydantic import BaseModel, Discriminator, Tag


class _PauseReasonType(StrEnum):
    HUMAN_INPUT_REQUIRED = auto()
    SCHEDULED_PAUSE = auto()


class _PauseReasonBase(BaseModel):
    TYPE: ClassVar[_PauseReasonType]


class HumanInputRequired(_PauseReasonBase):
    TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED


class SchedulingPause(_PauseReasonBase):
    TYPE = _PauseReasonType.SCHEDULED_PAUSE

    message: str


def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
    if isinstance(v, _PauseReasonBase):
        return v.TYPE
    elif isinstance(v, dict):
        reason_type_str = v.get("TYPE")
        if reason_type_str is None:
            return None
        try:
            reason_type = _PauseReasonType(reason_type_str)
        except ValueError:
            return None
        return reason_type
    else:
        # return None if the discriminator value isn't found
        return None


PauseReason: TypeAlias = Annotated[
    (
        Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
        | Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
    ),
    Discriminator(_get_pause_reason_discriminator),
]
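A small usage sketch for the discriminated union above: the custom discriminator reads the "TYPE" tag from dicts, and StrEnum auto() lowercases member names, so plain JSON deserializes to the matching concrete class (the message text is hypothetical):

from pydantic import TypeAdapter

adapter: TypeAdapter[PauseReason] = TypeAdapter(PauseReason)
reason = adapter.validate_python({"TYPE": "scheduled_pause", "message": "nightly window"})
assert isinstance(reason, SchedulingPause)
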
api/core/workflow/entities/workflow_pause.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""
Domain entities for workflow pause management.

This module contains the domain model for workflow pause, which is used
by the core workflow module. These models are independent of the storage mechanism
and don't contain implementation details like tenant_id, app_id, etc.
"""

from abc import ABC, abstractmethod
from datetime import datetime


class WorkflowPauseEntity(ABC):
    """
    Abstract base class for workflow pause entities.

    This domain model represents a paused workflow execution state,
    without implementation details like tenant_id, app_id, etc.
    It provides the interface for managing workflow pause/resume operations
    and state persistence through file storage.

    The `WorkflowPauseEntity` is never reused. If a workflow execution pauses multiple times,
    it will generate multiple `WorkflowPauseEntity` records.
    """

    @property
    @abstractmethod
    def id(self) -> str:
        """The identifier of the current WorkflowPauseEntity."""
        pass

    @property
    @abstractmethod
    def workflow_execution_id(self) -> str:
        """The identifier of the workflow execution record the pause is associated with.
        Corresponds to `WorkflowExecution.id`.
        """

    @abstractmethod
    def get_state(self) -> bytes:
        """
        Retrieve the serialized workflow state from storage.

        This method should load and return the workflow execution state
        that was saved when the workflow was paused. The state contains
        all necessary information to resume the workflow execution.

        Returns:
            bytes: The serialized workflow state containing
                execution context, variable values, node states, etc.

        """
        ...

    @property
    @abstractmethod
    def resumed_at(self) -> datetime | None:
        """`resumed_at` returns the resumption time of the current pause, or `None` if
        the pause has not been resumed yet.
        """
        pass
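A minimal in-memory implementation sketch of the abstract entity above (illustrative only; real implementations are storage-backed):

from datetime import datetime

class InMemoryWorkflowPause(WorkflowPauseEntity):
    def __init__(self, id_: str, workflow_execution_id: str, state: bytes):
        self._id = id_
        self._workflow_execution_id = workflow_execution_id
        self._state = state
        self._resumed_at: datetime | None = None

    @property
    def id(self) -> str:
        return self._id

    @property
    def workflow_execution_id(self) -> str:
        return self._workflow_execution_id

    def get_state(self) -> bytes:
        return self._state

    @property
    def resumed_at(self) -> datetime | None:
        return self._resumed_at
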
@@ -92,13 +92,111 @@ class WorkflowType(StrEnum):


class WorkflowExecutionStatus(StrEnum):
    # State diagram for the workflow run status:
    #
    # ---
    # title: State diagram for Workflow run state
    # ---
    # stateDiagram-v2
    #     scheduled: Scheduled
    #     running: Running
    #     succeeded: Succeeded
    #     failed: Failed
    #     partial_succeeded: Partial Succeeded
    #     paused: Paused
    #     stopped: Stopped
    #
    #     [*] --> scheduled
    #     scheduled --> running: Start Execution
    #     running --> paused: Human input required
    #     paused --> running: Human input added
    #     paused --> stopped: User stops execution
    #     running --> succeeded: Execution finishes without any error
    #     running --> failed: Execution finishes with errors
    #     running --> stopped: User stops execution
    #     running --> partial_succeeded: Some errors occurred and were handled during execution
    #
    #     scheduled --> stopped: User stops execution
    #
    #     succeeded --> [*]
    #     failed --> [*]
    #     partial_succeeded --> [*]
    #     stopped --> [*]

    # `SCHEDULED` means that the workflow is scheduled to run, but has not
    # started running yet. (maybe due to possible worker saturation.)
    #
    # This enum value is currently unused.
    SCHEDULED = "scheduled"

    # `RUNNING` means the workflow is executing.
    RUNNING = "running"

    # `SUCCEEDED` means the execution of the workflow succeeded without any error.
    SUCCEEDED = "succeeded"

    # `FAILED` means the execution of the workflow failed with some errors.
    FAILED = "failed"

    # `STOPPED` means the execution of workflow was stopped, either manually
    # by the user, or automatically by the Dify application (E.G. the moderation
    # mechanism.)
    STOPPED = "stopped"

    # `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
    # execution, but they were successfully handled (e.g., by using an error
    # strategy such as "fail branch" or "default value").
    PARTIAL_SUCCEEDED = "partial-succeeded"

    # `PAUSED` indicates that the workflow execution is temporarily paused
    # (e.g., awaiting human input) and is expected to resume later.
    PAUSED = "paused"

    def is_ended(self) -> bool:
        return self in _END_STATE


_END_STATE = frozenset(
    [
        WorkflowExecutionStatus.SUCCEEDED,
        WorkflowExecutionStatus.FAILED,
        WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
        WorkflowExecutionStatus.STOPPED,
    ]
)
|
||||
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
|
||||
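A quick sketch of how these enum helpers behave (illustrative only; the values come straight from the enum above):

```python
# Terminal states answer True to is_ended(); PAUSED and RUNNING do not.
assert WorkflowExecutionStatus.SUCCEEDED.is_ended()
assert WorkflowExecutionStatus.STOPPED.is_ended()
assert not WorkflowExecutionStatus.PAUSED.is_ended()
assert not WorkflowExecutionStatus.RUNNING.is_ended()

# StrEnum members compare equal to their string values, which is why
# a column typed as EnumText(WorkflowExecutionStatus) can hold them directly.
assert WorkflowExecutionStatus.PARTIAL_SUCCEEDED == "partial-succeeded"
```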
@@ -3,6 +3,8 @@ from typing import final

from typing_extensions import override

from core.workflow.entities.pause_reason import SchedulingPause

from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from .command_processor import CommandHandler
@@ -25,4 +27,7 @@ class PauseCommandHandler(CommandHandler):
    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
        assert isinstance(command, PauseCommand)
        logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
        execution.pause(command.reason)
        # Convert string reason to PauseReason if needed
        reason = command.reason
        pause_reason = SchedulingPause(message=reason)
        execution.pause(pause_reason)

@@ -8,6 +8,7 @@ from typing import Literal

from pydantic import BaseModel, Field

from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeState

from .node_execution import NodeExecution
@@ -41,7 +42,7 @@ class GraphExecutionState(BaseModel):
    completed: bool = Field(default=False)
    aborted: bool = Field(default=False)
    paused: bool = Field(default=False)
    pause_reason: str | None = Field(default=None)
    pause_reason: PauseReason | None = Field(default=None)
    error: GraphExecutionErrorState | None = Field(default=None)
    exceptions_count: int = Field(default=0)
    node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
@@ -106,7 +107,7 @@ class GraphExecution:
    completed: bool = False
    aborted: bool = False
    paused: bool = False
    pause_reason: str | None = None
    pause_reason: PauseReason | None = None
    error: Exception | None = None
    node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
    exceptions_count: int = 0
@@ -130,7 +131,7 @@ class GraphExecution:
        self.aborted = True
        self.error = RuntimeError(f"Aborted: {reason}")

    def pause(self, reason: str | None = None) -> None:
    def pause(self, reason: PauseReason) -> None:
        """Pause the graph execution without marking it complete."""
        if self.completed:
            raise RuntimeError("Cannot pause execution that has completed")

@@ -36,4 +36,4 @@ class PauseCommand(GraphEngineCommand):
    """Command to pause a running workflow execution."""

    command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
    reason: str | None = Field(default=None, description="Optional reason for pause")
    reason: str = Field(default="unknown reason", description="reason for pause")
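Taken together, these hunks mean a plain string carried by `PauseCommand` is now wrapped into a typed reason before it reaches `GraphExecution.pause`. A rough sketch of the resulting flow (names taken from the hunks above; the `execution` object and the surrounding engine wiring are assumed):

```python
# Hypothetical driver code; in the real engine the command arrives
# through the command processor rather than being handled directly.
command = PauseCommand(reason="waiting for operator approval")
pause_reason = SchedulingPause(message=command.reason)  # typed PauseReason
execution.pause(pause_reason)

# pause() records the paused flag and reason on the execution,
# per the GraphExecution dataclass fields shown above.
assert execution.paused
assert execution.pause_reason is pause_reason
```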
@@ -210,7 +210,7 @@ class EventHandler:
    def _(self, event: NodeRunPauseRequestedEvent) -> None:
        """Handle pause requests emitted by nodes."""

        pause_reason = event.reason or "Awaiting human input"
        pause_reason = event.reason
        self._graph_execution.pause(pause_reason)
        self._state_manager.finish_execution(event.node_id)
        if event.node_id in self._graph.nodes:

@@ -247,8 +247,11 @@ class GraphEngine:

        # Handle completion
        if self._graph_execution.is_paused:
            pause_reason = self._graph_execution.pause_reason
            assert pause_reason is not None, "pause_reason should not be None when execution is paused."
            # Ensure we have a valid PauseReason for the event
            paused_event = GraphRunPausedEvent(
                reason=self._graph_execution.pause_reason,
                reason=pause_reason,
                outputs=self._graph_runtime_state.outputs,
            )
            self._event_manager.notify_layers(paused_event)

@@ -216,7 +216,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
    def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
        execution = self._get_workflow_execution()
        execution.status = WorkflowExecutionStatus.PAUSED
        execution.error_message = event.reason or "Workflow execution paused"
        execution.outputs = event.outputs
        self._populate_completion_statistics(execution, update_finished=False)

@@ -296,7 +295,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
            domain_execution,
            event.node_run_result,
            WorkflowNodeExecutionStatus.PAUSED,
            error=event.reason,
            error="",
            update_outputs=False,
        )

@@ -1,5 +1,6 @@
from pydantic import Field

from core.workflow.entities.pause_reason import PauseReason
from core.workflow.graph_events import BaseGraphEvent

@@ -44,7 +45,8 @@ class GraphRunAbortedEvent(BaseGraphEvent):
class GraphRunPausedEvent(BaseGraphEvent):
    """Event emitted when a graph run is paused by user command."""

    reason: str | None = Field(default=None, description="reason for pause")
    # reason: str | None = Field(default=None, description="reason for pause")
    reason: PauseReason = Field(..., description="reason for pause")
    outputs: dict[str, object] = Field(
        default_factory=dict,
        description="Outputs available to the client while the run is paused.",

@@ -5,6 +5,7 @@ from pydantic import Field

from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.pause_reason import PauseReason

from .base import GraphNodeEventBase

@@ -54,4 +55,4 @@ class NodeRunRetryEvent(NodeRunStartedEvent):


class NodeRunPauseRequestedEvent(GraphNodeEventBase):
    reason: str | None = Field(default=None, description="Optional pause reason")
    reason: PauseReason = Field(..., description="pause reason")

@@ -5,6 +5,7 @@ from pydantic import Field

from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.node_events import NodeRunResult

from .base import NodeEventBase
@@ -43,4 +44,4 @@ class StreamCompletedEvent(NodeEventBase):


class PauseRequestedEvent(NodeEventBase):
    reason: str | None = Field(default=None, description="Optional pause reason")
    reason: PauseReason = Field(..., description="pause reason")

@@ -104,7 +104,7 @@ class HttpRequestNode(Node):
                status=WorkflowNodeExecutionStatus.FAILED,
                outputs={
                    "status_code": response.status_code,
                    "body": response.text if not files else "",
                    "body": response.text if not files.value else "",
                    "headers": response.headers,
                    "files": files,
                },

@@ -1,6 +1,7 @@
from collections.abc import Mapping
from typing import Any

from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
@@ -64,7 +65,7 @@ class HumanInputNode(Node):
        return self._pause_generator()

    def _pause_generator(self):
        yield PauseRequestedEvent(reason=self._node_data.pause_reason)
        yield PauseRequestedEvent(reason=HumanInputRequired())

    def _is_completion_ready(self) -> bool:
        """Determine whether all required inputs are satisfied."""

@@ -3,6 +3,7 @@ import io
import json
import logging
import re
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal

@@ -384,6 +385,8 @@ class LLMNode(Node):
            output_schema = LLMNode.fetch_structured_output_schema(
                structured_output=structured_output or {},
            )
            request_start_time = time.perf_counter()

            invoke_result = invoke_llm_with_structured_output(
                provider=model_instance.provider,
                model_schema=model_schema,
@@ -396,6 +399,8 @@ class LLMNode(Node):
                user=user_id,
            )
        else:
            request_start_time = time.perf_counter()

            invoke_result = model_instance.invoke_llm(
                prompt_messages=list(prompt_messages),
                model_parameters=node_data_model.completion_params,
@@ -411,6 +416,7 @@ class LLMNode(Node):
            node_id=node_id,
            node_type=node_type,
            reasoning_format=reasoning_format,
            request_start_time=request_start_time,
        )

    @staticmethod
@@ -422,14 +428,20 @@ class LLMNode(Node):
        node_id: str,
        node_type: NodeType,
        reasoning_format: Literal["separated", "tagged"] = "tagged",
        request_start_time: float | None = None,
    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
        # For blocking mode
        if isinstance(invoke_result, LLMResult):
            duration = None
            if request_start_time is not None:
                duration = time.perf_counter() - request_start_time
                invoke_result.usage.latency = round(duration, 3)
            event = LLMNode.handle_blocking_result(
                invoke_result=invoke_result,
                saver=file_saver,
                file_outputs=file_outputs,
                reasoning_format=reasoning_format,
                request_latency=duration,
            )
            yield event
            return
@@ -441,6 +453,12 @@ class LLMNode(Node):
        usage = LLMUsage.empty_usage()
        finish_reason = None
        full_text_buffer = io.StringIO()

        # Initialize streaming metrics tracking
        start_time = request_start_time if request_start_time is not None else time.perf_counter()
        first_token_time = None
        has_content = False

        collected_structured_output = None  # Collect structured_output from streaming chunks
        # Consume the invoke result and handle generator exception
        try:
@@ -457,6 +475,11 @@ class LLMNode(Node):
                file_saver=file_saver,
                file_outputs=file_outputs,
            ):
                # Detect first token for TTFT calculation
                if text_part and not has_content:
                    first_token_time = time.perf_counter()
                    has_content = True

                full_text_buffer.write(text_part)
                yield StreamChunkEvent(
                    selector=[node_id, "text"],
@@ -489,6 +512,16 @@ class LLMNode(Node):
        # Extract clean text and reasoning from <think> tags
        clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)

        # Calculate streaming metrics
        end_time = time.perf_counter()
        total_duration = end_time - start_time
        usage.latency = round(total_duration, 3)
        if has_content and first_token_time:
            gen_ai_server_time_to_first_token = first_token_time - start_time
            llm_streaming_time_to_generate = end_time - first_token_time
            usage.time_to_first_token = round(gen_ai_server_time_to_first_token, 3)
            usage.time_to_generate = round(llm_streaming_time_to_generate, 3)

        yield ModelInvokeCompletedEvent(
            # Use clean_text for separated mode, full_text for tagged mode
            text=clean_text if reasoning_format == "separated" else full_text,
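The streaming metrics in the hunk above reduce to two quantities: time to first token (request start to first streamed chunk) and generation time (first chunk to stream end). A standalone sketch of the same arithmetic, with a fake chunk source standing in for the model stream:

```python
import time

def fake_chunks():
    # Stand-in for a model's streaming response (illustrative only).
    for part in ["Hello", ", ", "world"]:
        time.sleep(0.01)
        yield part

start_time = time.perf_counter()
first_token_time = None
for text_part in fake_chunks():
    if text_part and first_token_time is None:
        first_token_time = time.perf_counter()
end_time = time.perf_counter()

latency = round(end_time - start_time, 3)                 # total request latency
ttft = round(first_token_time - start_time, 3)            # time to first token
time_to_generate = round(end_time - first_token_time, 3)  # streaming duration
print(latency, ttft, time_to_generate)
```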
@@ -1068,6 +1101,7 @@ class LLMNode(Node):
        saver: LLMFileSaver,
        file_outputs: list["File"],
        reasoning_format: Literal["separated", "tagged"] = "tagged",
        request_latency: float | None = None,
    ) -> ModelInvokeCompletedEvent:
        buffer = io.StringIO()
        for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
@@ -1088,7 +1122,7 @@ class LLMNode(Node):
        # Extract clean text and reasoning from <think> tags
        clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)

        return ModelInvokeCompletedEvent(
        event = ModelInvokeCompletedEvent(
            # Use clean_text for separated mode, full_text for tagged mode
            text=clean_text if reasoning_format == "separated" else full_text,
            usage=invoke_result.usage,
@@ -1098,6 +1132,9 @@ class LLMNode(Node):
            # Pass structured output if enabled
            structured_output=getattr(invoke_result, "structured_output", None),
        )
        if request_latency is not None:
            event.usage.latency = round(request_latency, 3)
        return event

    @staticmethod
    def save_multimodal_image_output(

@@ -3,6 +3,7 @@ from typing import Any, Protocol

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView


class ReadOnlyVariablePool(Protocol):
@@ -30,6 +31,9 @@ class ReadOnlyGraphRuntimeState(Protocol):
    All methods return defensive copies to ensure immutability.
    """

    @property
    def system_variable(self) -> SystemVariableReadOnlyView: ...

    @property
    def variable_pool(self) -> ReadOnlyVariablePool:
        """Get read-only access to the variable pool."""

@@ -6,6 +6,7 @@ from typing import Any

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView

from .graph_runtime_state import GraphRuntimeState
from .variable_pool import VariablePool
@@ -42,6 +43,10 @@ class ReadOnlyGraphRuntimeStateWrapper:
        self._state = state
        self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)

    @property
    def system_variable(self) -> SystemVariableReadOnlyView:
        return self._state.variable_pool.system_variables.as_view()

    @property
    def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
        return self._variable_pool_wrapper

@@ -1,4 +1,5 @@
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any

from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
@@ -108,3 +109,102 @@ class SystemVariable(BaseModel):
        if self.invoke_from is not None:
            d[SystemVariableKey.INVOKE_FROM] = self.invoke_from
        return d

    def as_view(self) -> "SystemVariableReadOnlyView":
        return SystemVariableReadOnlyView(self)


class SystemVariableReadOnlyView:
    """
    A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol.

    This class wraps a SystemVariable instance and provides read-only access to all its fields.
    It always reads the latest data from the wrapped instance and prevents any write operations.
    """

    def __init__(self, system_variable: SystemVariable) -> None:
        """
        Initialize the read-only view with a SystemVariable instance.

        Args:
            system_variable: The SystemVariable instance to wrap
        """
        self._system_variable = system_variable

    @property
    def user_id(self) -> str | None:
        return self._system_variable.user_id

    @property
    def app_id(self) -> str | None:
        return self._system_variable.app_id

    @property
    def workflow_id(self) -> str | None:
        return self._system_variable.workflow_id

    @property
    def workflow_execution_id(self) -> str | None:
        return self._system_variable.workflow_execution_id

    @property
    def query(self) -> str | None:
        return self._system_variable.query

    @property
    def conversation_id(self) -> str | None:
        return self._system_variable.conversation_id

    @property
    def dialogue_count(self) -> int | None:
        return self._system_variable.dialogue_count

    @property
    def document_id(self) -> str | None:
        return self._system_variable.document_id

    @property
    def original_document_id(self) -> str | None:
        return self._system_variable.original_document_id

    @property
    def dataset_id(self) -> str | None:
        return self._system_variable.dataset_id

    @property
    def batch(self) -> str | None:
        return self._system_variable.batch

    @property
    def datasource_type(self) -> str | None:
        return self._system_variable.datasource_type

    @property
    def invoke_from(self) -> str | None:
        return self._system_variable.invoke_from

    @property
    def files(self) -> Sequence[File]:
        """
        Get a copy of the files from the wrapped SystemVariable.

        Returns:
            A defensive copy of the files sequence to prevent modification
        """
        return tuple(self._system_variable.files)  # Convert to immutable tuple

    @property
    def datasource_info(self) -> Mapping[str, Any] | None:
        """
        Get a copy of the datasource info from the wrapped SystemVariable.

        Returns:
            A view of the datasource info mapping to prevent modification
        """
        if self._system_variable.datasource_info is None:
            return None
        return MappingProxyType(self._system_variable.datasource_info)

    def __repr__(self) -> str:
        """Return a string representation of the read-only view."""
        return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})"
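A short sketch of the view's live-read behavior (the constructor arguments are made up for illustration and depend on `SystemVariable`'s actual schema):

```python
# Hypothetical usage of the classes defined above.
sys_var = SystemVariable(user_id="user-1", query="hello")
view = sys_var.as_view()

assert view.query == "hello"
sys_var.query = "updated"
# The view is not a snapshot: it always reflects the wrapped instance.
assert view.query == "updated"

# Mutating collections through the view fails by construction:
# view.files returns a tuple, so there is no append() to call, and
# view.datasource_info is a MappingProxyType, which rejects item assignment.
```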
0
api/enums/__init__.py
Normal file

15
api/enums/cloud_plan.py
Normal file
@@ -0,0 +1,15 @@
from enum import StrEnum, auto


class CloudPlan(StrEnum):
    """
    Enum representing user plan types in the cloud platform.

    SANDBOX: Free/default plan with limited features
    PROFESSIONAL: Professional paid plan
    TEAM: Team collaboration paid plan
    """

    SANDBOX = auto()
    PROFESSIONAL = auto()
    TEAM = auto()
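Worth noting: with `StrEnum`, `auto()` produces the lower-cased member name, so these values serialize as plain strings without any extra mapping:

```python
from enum import StrEnum, auto

class CloudPlan(StrEnum):
    SANDBOX = auto()

# auto() on a StrEnum yields the lower-cased member name.
assert CloudPlan.SANDBOX == "sandbox"
assert f"plan={CloudPlan.SANDBOX}" == "plan=sandbox"
```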
@@ -85,7 +85,7 @@ class Storage:
            case _:
                raise ValueError(f"unsupported storage type {storage_type}")

    def save(self, filename, data):
    def save(self, filename: str, data: bytes):
        self.storage_runner.save(filename, data)

    @overload

@@ -8,7 +8,7 @@ class BaseStorage(ABC):
    """Interface for file storage."""

    @abstractmethod
    def save(self, filename, data):
    def save(self, filename: str, data: bytes):
        raise NotImplementedError

    @abstractmethod

@@ -2,6 +2,19 @@ import psycogreen.gevent as pscycogreen_gevent  # type: ignore
from gevent import events as gevent_events
from grpc.experimental import gevent as grpc_gevent  # type: ignore

# WARNING: This module is loaded very early in the Gunicorn worker lifecycle,
# before gevent's monkey-patching is applied. Importing modules at the top level here can
# interfere with gevent's ability to properly patch the standard library,
# potentially causing subtle and difficult-to-diagnose bugs.
#
# To ensure correct behavior, defer any initialization or imports that depend on monkey-patching
# to the `post_patch` hook below, or use a gevent_events subscriber as shown.
#
# For further context, see: https://github.com/langgenius/dify/issues/26689
#
# Note: The `post_fork` hook is also executed before monkey-patching,
# so moving imports there does not resolve this issue.

# NOTE(QuantumGhost): here we cannot use post_fork to patch gRPC, as
# grpc_gevent.init_gevent must be called after patching stdlib.
# Gunicorn calls `post_init` before applying monkey patch.
@@ -11,7 +24,7 @@ from grpc.experimental import gevent as grpc_gevent  # type: ignore
# ref:
# - https://github.com/grpc/grpc/blob/62533ea13879d6ee95c6fda11ec0826ca822c9dd/src/python/grpcio/grpc/experimental/gevent.py
# - https://github.com/gevent/gevent/issues/2060#issuecomment-3016768668
# - https://github.com/benoitc/gunicorn/blob/master/gunicorn/arbiter.py#L607-L613
# - https://github.com/benoitc/gunicorn/blob/23.0.0/gunicorn/arbiter.py#L605-L609


def post_patch(event):
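The diff cuts off before the body of `post_patch`, but the mechanism it relies on is gevent's event-subscriber hook. A minimal sketch of that pattern, assuming the handler's job is to run gRPC's gevent initialization once stdlib patching has finished:

```python
from gevent import events as gevent_events
from grpc.experimental import gevent as grpc_gevent  # type: ignore

def post_patch(event):
    # Only react once the stdlib has actually been monkey-patched.
    if isinstance(event, gevent_events.GeventDidPatchBuiltinModulesEvent):
        grpc_gevent.init_gevent()

# gevent delivers patch-lifecycle events to every registered subscriber.
gevent_events.subscribers.append(post_patch)
```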
@@ -2,6 +2,8 @@ import abc
import datetime
from typing import Protocol

import pytz


class _NowFunction(Protocol):
    @abc.abstractmethod
@@ -20,3 +22,51 @@ def naive_utc_now() -> datetime.datetime:
    representing current UTC time.
    """
    return _now_func(datetime.UTC).replace(tzinfo=None)


def parse_time_range(
    start: str | None, end: str | None, tzname: str
) -> tuple[datetime.datetime | None, datetime.datetime | None]:
    """
    Parse time range strings and convert to UTC datetime objects.
    Handles DST ambiguity and non-existent times gracefully.

    Args:
        start: Start time string (YYYY-MM-DD HH:MM)
        end: End time string (YYYY-MM-DD HH:MM)
        tzname: Timezone name

    Returns:
        tuple: (start_datetime_utc, end_datetime_utc)

    Raises:
        ValueError: When time range is invalid or start > end
    """
    tz = pytz.timezone(tzname)
    utc = pytz.utc

    def _parse(time_str: str | None, label: str) -> datetime.datetime | None:
        if not time_str:
            return None

        try:
            dt = datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M").replace(second=0)
        except ValueError as e:
            raise ValueError(f"Invalid {label} time format: {e}")

        try:
            return tz.localize(dt, is_dst=None).astimezone(utc)
        except pytz.AmbiguousTimeError:
            return tz.localize(dt, is_dst=False).astimezone(utc)
        except pytz.NonExistentTimeError:
            dt += datetime.timedelta(hours=1)
            return tz.localize(dt, is_dst=None).astimezone(utc)

    start_dt = _parse(start, "start")
    end_dt = _parse(end, "end")

    # Range validation
    if start_dt and end_dt and start_dt > end_dt:
        raise ValueError("start must be earlier than or equal to end")

    return start_dt, end_dt
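To see why the pytz error handling matters, consider the US spring-forward gap: 2025-03-09 02:30 does not exist in America/New_York, so the function shifts it forward an hour instead of failing. An illustrative call:

```python
# 02:30 does not exist on the US spring-forward date; the parser
# nudges it to 03:30 local time and converts to UTC.
start_utc, end_utc = parse_time_range(
    "2025-03-09 02:30", "2025-03-09 12:00", "America/New_York"
)
print(start_utc)  # 2025-03-09 07:30:00+00:00 (03:30 EDT expressed in UTC)
print(end_utc)    # 2025-03-09 16:00:00+00:00
```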
@@ -0,0 +1,41 @@
"""add WorkflowPause model

Revision ID: 03f8dcbc611e
Revises: ae662b25d9bc
Create Date: 2025-10-22 16:11:31.805407

"""

from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "03f8dcbc611e"
down_revision = "ae662b25d9bc"
branch_labels = None
depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "workflow_pauses",
        sa.Column("workflow_id", models.types.StringUUID(), nullable=False),
        sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False),
        sa.Column("resumed_at", sa.DateTime(), nullable=True),
        sa.Column("state_object_key", sa.String(length=255), nullable=False),
        sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
        sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.PrimaryKeyConstraint("id", name=op.f("workflow_pauses_pkey")),
        sa.UniqueConstraint("workflow_run_id", name=op.f("workflow_pauses_workflow_run_id_key")),
    )

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("workflow_pauses")
    # ### end Alembic commands ###

@@ -0,0 +1,60 @@
"""make message annotation question not nullable

Revision ID: 9e6fa5cbcd80
Revises: 03f8dcbc611e
Create Date: 2025-11-06 16:03:54.549378

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '9e6fa5cbcd80'
down_revision = '03f8dcbc611e'
branch_labels = None
depends_on = None


def upgrade():
    bind = op.get_bind()
    message_annotations = sa.table(
        "message_annotations",
        sa.column("id", sa.String),
        sa.column("message_id", sa.String),
        sa.column("question", sa.Text),
    )
    messages = sa.table(
        "messages",
        sa.column("id", sa.String),
        sa.column("query", sa.Text),
    )
    update_question_from_message = (
        sa.update(message_annotations)
        .where(
            sa.and_(
                message_annotations.c.question.is_(None),
                message_annotations.c.message_id.isnot(None),
            )
        )
        .values(
            question=sa.select(sa.func.coalesce(messages.c.query, ""))
            .where(messages.c.id == message_annotations.c.message_id)
            .scalar_subquery()
        )
    )
    bind.execute(update_question_from_message)

    fill_remaining_questions = (
        sa.update(message_annotations)
        .where(message_annotations.c.question.is_(None))
        .values(question="")
    )
    bind.execute(fill_remaining_questions)
    with op.batch_alter_table('message_annotations', schema=None) as batch_op:
        batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=False)


def downgrade():
    with op.batch_alter_table('message_annotations', schema=None) as batch_op:
        batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=True)
@@ -88,6 +88,7 @@ from .workflow import (
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionOffload,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowPause,
    WorkflowRun,
    WorkflowType,
)
@@ -177,6 +178,7 @@ __all__ = [
    "WorkflowNodeExecutionModel",
    "WorkflowNodeExecutionOffload",
    "WorkflowNodeExecutionTriggeredFrom",
    "WorkflowPause",
    "WorkflowRun",
    "WorkflowRunTriggeredFrom",
    "WorkflowToolProvider",

@@ -1,6 +1,12 @@
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass
from datetime import datetime

from sqlalchemy import DateTime, func, text
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models.engine import metadata
from models.types import StringUUID


class Base(DeclarativeBase):
@@ -13,3 +19,34 @@ class TypeBase(MappedAsDataclass, DeclarativeBase):
    """

    metadata = metadata


class DefaultFieldsMixin:
    id: Mapped[str] = mapped_column(
        StringUUID,
        primary_key=True,
        # NOTE: The default and server_default serve as fallback mechanisms.
        # The application can generate the `id` before saving to optimize
        # the insertion process (especially for interdependent models)
        # and reduce database roundtrips.
        default=uuidv7,
        server_default=text("uuidv7()"),
    )

    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=naive_utc_now,
        server_default=func.current_timestamp(),
    )

    updated_at: Mapped[datetime] = mapped_column(
        __name_pos=DateTime,
        nullable=False,
        default=naive_utc_now,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}(id={self.id})>"
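Any model that inherits the mixin picks up `id`, `created_at`, and `updated_at` without repeating the column definitions; `WorkflowPause` further down is the first consumer. A hypothetical example (the table and model below are invented for illustration):

```python
from sqlalchemy import String

# Illustrative only; not a model that exists in the diff.
class AuditLog(DefaultFieldsMixin, Base):
    __tablename__ = "audit_logs"

    action: Mapped[str] = mapped_column(String(255), nullable=False)

log = AuditLog()
log.action = "login"
# id/created_at/updated_at are filled by the mixin's defaults on flush,
# or the id can be assigned up front to avoid an extra database roundtrip.
```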
@@ -3,6 +3,7 @@ import re
import uuid
from collections.abc import Mapping
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

@@ -914,34 +915,40 @@ class Message(Base):
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    app_id = mapped_column(StringUUID, nullable=False)
    model_provider = mapped_column(String(255), nullable=True)
    model_id = mapped_column(String(255), nullable=True)
    override_model_configs = mapped_column(sa.Text)
    conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    model_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
    model_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
    override_model_configs: Mapped[str | None] = mapped_column(sa.Text)
    conversation_id: Mapped[str] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
    _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
    query: Mapped[str] = mapped_column(sa.Text, nullable=False)
    message = mapped_column(sa.JSON, nullable=False)
    message: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
    message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
    message_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
    message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
    message_unit_price: Mapped[Decimal] = mapped_column(sa.Numeric(10, 4), nullable=False)
    message_price_unit: Mapped[Decimal] = mapped_column(
        sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")
    )
    answer: Mapped[str] = mapped_column(sa.Text, nullable=False)
    answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
    answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
    answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
    parent_message_id = mapped_column(StringUUID, nullable=True)
    provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
    total_price = mapped_column(sa.Numeric(10, 7))
    answer_unit_price: Mapped[Decimal] = mapped_column(sa.Numeric(10, 4), nullable=False)
    answer_price_unit: Mapped[Decimal] = mapped_column(
        sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")
    )
    parent_message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
    total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7))
    currency: Mapped[str] = mapped_column(String(255), nullable=False)
    status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
    error = mapped_column(sa.Text)
    message_metadata = mapped_column(sa.Text)
    status: Mapped[str] = mapped_column(
        String(255), nullable=False, server_default=sa.text("'normal'::character varying")
    )
    error: Mapped[str | None] = mapped_column(sa.Text)
    message_metadata: Mapped[str | None] = mapped_column(sa.Text)
    invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
    from_source: Mapped[str] = mapped_column(String(255), nullable=False)
    from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
    from_account_id: Mapped[str | None] = mapped_column(StringUUID)
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
    updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
    workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
    app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True)
@@ -1212,9 +1219,13 @@ class Message(Base):
    @property
    def workflow_run(self):
        if self.workflow_run_id:
            from .workflow import WorkflowRun
            from sqlalchemy.orm import sessionmaker

            return db.session.query(WorkflowRun).where(WorkflowRun.id == self.workflow_run_id).first()
            from repositories.factory import DifyAPIRepositoryFactory

            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
            repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
            return repo.get_workflow_run_by_id_without_tenant(run_id=self.workflow_run_id)

        return None

@@ -1275,20 +1286,20 @@ class MessageFeedback(Base):
        sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
    )

    id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    app_id = mapped_column(StringUUID, nullable=False)
    conversation_id = mapped_column(StringUUID, nullable=False)
    message_id = mapped_column(StringUUID, nullable=False)
    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    rating: Mapped[str] = mapped_column(String(255), nullable=False)
    content = mapped_column(sa.Text)
    content: Mapped[str | None] = mapped_column(sa.Text)
    from_source: Mapped[str] = mapped_column(String(255), nullable=False)
    from_end_user_id = mapped_column(StringUUID)
    from_account_id = mapped_column(StringUUID)
    created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
    from_account_id: Mapped[str | None] = mapped_column(StringUUID)
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

    @property
    def from_account(self):
    def from_account(self) -> Account | None:
        account = db.session.query(Account).where(Account.id == self.from_account_id).first()
        return account

@@ -1362,7 +1373,7 @@ class MessageAnnotation(Base):
    app_id: Mapped[str] = mapped_column(StringUUID)
    conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
    message_id: Mapped[str | None] = mapped_column(StringUUID)
    question = mapped_column(sa.Text, nullable=True)
    question: Mapped[str] = mapped_column(sa.Text, nullable=False)
    content = mapped_column(sa.Text, nullable=False)
    hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
    account_id = mapped_column(StringUUID, nullable=False)

@@ -13,8 +13,11 @@ from core.file.constants import maybe_file_object
from core.file.models import File
from core.variables import utils as variable_utils
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import NodeType
from core.workflow.constants import (
    CONVERSATION_VARIABLE_NODE_ID,
    SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
@@ -35,7 +38,7 @@ from factories import variable_factory
from libs import helper

from .account import Account
from .base import Base
from .base import Base, DefaultFieldsMixin
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType
from .types import EnumText, StringUUID
@@ -247,7 +250,9 @@ class Workflow(Base):
        return node_type

    @staticmethod
    def get_enclosing_node_type_and_id(node_config: Mapping[str, Any]) -> tuple[NodeType, str] | None:
    def get_enclosing_node_type_and_id(
        node_config: Mapping[str, Any],
    ) -> tuple[NodeType, str] | None:
        in_loop = node_config.get("isInLoop", False)
        in_iteration = node_config.get("isInIteration", False)
        if in_loop:
@@ -306,7 +311,10 @@ class Workflow(Base):
        if "nodes" not in graph_dict:
            return []

        start_node = next((node for node in graph_dict["nodes"] if node["data"]["type"] == "start"), None)
        start_node = next(
            (node for node in graph_dict["nodes"] if node["data"]["type"] == "start"),
            None,
        )
        if not start_node:
            return []

@@ -359,7 +367,9 @@ class Workflow(Base):
        return db.session.execute(stmt).scalar_one()

    @property
    def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
    def environment_variables(
        self,
    ) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
        # TODO: find some way to init `self._environment_variables` when instance created.
        if self._environment_variables is None:
            self._environment_variables = "{}"
@@ -376,7 +386,9 @@ class Workflow(Base):
        ]

        # decrypt secret variables value
        def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
        def decrypt_func(
            var: Variable,
        ) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
            if isinstance(var, SecretVariable):
                return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
            elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
@@ -537,7 +549,10 @@ class WorkflowRun(Base):
    version: Mapped[str] = mapped_column(String(255))
    graph: Mapped[str | None] = mapped_column(sa.Text)
    inputs: Mapped[str | None] = mapped_column(sa.Text)
    status: Mapped[str] = mapped_column(String(255))  # running, succeeded, failed, stopped, partial-succeeded
    status: Mapped[str] = mapped_column(
        EnumText(WorkflowExecutionStatus, length=255),
        nullable=False,
    )
    outputs: Mapped[str | None] = mapped_column(sa.Text, default="{}")
    error: Mapped[str | None] = mapped_column(sa.Text)
    elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
@@ -549,6 +564,15 @@ class WorkflowRun(Base):
    finished_at: Mapped[datetime | None] = mapped_column(DateTime)
    exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)

    pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
        "WorkflowPause",
        primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
        uselist=False,
        # require explicit preloading.
        lazy="raise",
        back_populates="workflow_run",
    )

    @property
    def created_by_account(self):
        created_by_role = CreatorUserRole(self.created_by_role)
@@ -1034,7 +1058,16 @@ class WorkflowAppLog(Base):

    @property
    def workflow_run(self):
        return db.session.get(WorkflowRun, self.workflow_run_id)
        if self.workflow_run_id:
            from sqlalchemy.orm import sessionmaker

            from repositories.factory import DifyAPIRepositoryFactory

            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
            repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
            return repo.get_workflow_run_by_id_without_tenant(run_id=self.workflow_run_id)

        return None

    @property
    def created_by_account(self):
@@ -1073,7 +1106,10 @@ class ConversationVariable(Base):
        DateTime, nullable=False, server_default=func.current_timestamp(), index=True
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
        DateTime,
        nullable=False,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )

    def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str):
@@ -1101,10 +1137,6 @@ class ConversationVariable(Base):
_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])


def _naive_utc_datetime():
    return naive_utc_now()


class WorkflowDraftVariable(Base):
    """`WorkflowDraftVariable` records variables and outputs generated during
    debugging of a workflow or chatflow.
@@ -1138,14 +1170,14 @@ class WorkflowDraftVariable(Base):
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=_naive_utc_datetime,
        default=naive_utc_now,
        server_default=func.current_timestamp(),
    )

    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=_naive_utc_datetime,
        default=naive_utc_now,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )
@@ -1412,8 +1444,8 @@ class WorkflowDraftVariable(Base):
        file_id: str | None = None,
    ) -> "WorkflowDraftVariable":
        variable = WorkflowDraftVariable()
        variable.created_at = _naive_utc_datetime()
        variable.updated_at = _naive_utc_datetime()
        variable.created_at = naive_utc_now()
        variable.updated_at = naive_utc_now()
        variable.description = description
        variable.app_id = app_id
        variable.node_id = node_id
@@ -1518,7 +1550,7 @@ class WorkflowDraftVariableFile(Base):
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=_naive_utc_datetime,
        default=naive_utc_now,
        server_default=func.current_timestamp(),
    )

@@ -1583,3 +1615,68 @@ class WorkflowDraftVariableFile(Base):

def is_system_variable_editable(name: str) -> bool:
    return name in _EDITABLE_SYSTEM_VARIABLE


class WorkflowPause(DefaultFieldsMixin, Base):
    """
    WorkflowPause records the paused state and related metadata for a specific workflow run.

    Each `WorkflowRun` can have zero or one associated `WorkflowPause`, depending on its execution status.
    If a `WorkflowRun` is in the `PAUSED` state, there must be a corresponding `WorkflowPause`
    that has not yet been resumed.
    Otherwise, there should be no active (non-resumed) `WorkflowPause` linked to that run.

    This model captures the execution context required to resume workflow processing at a later time.
    """

    __tablename__ = "workflow_pauses"
    __table_args__ = (
        # Design Note:
        # Instead of adding a `pause_id` field to the `WorkflowRun` model (which would require a migration
        # on a potentially large table), we reference `WorkflowRun` from `WorkflowPause` and enforce a unique
        # constraint on `workflow_run_id` to guarantee a one-to-one relationship.
        UniqueConstraint("workflow_run_id"),
    )

    # `workflow_id` represents the unique identifier of the workflow associated with this pause.
    # It corresponds to the `id` field in the `Workflow` model.
    #
    # Since an application can have multiple versions of a workflow, each with its own unique ID,
    # the `app_id` alone is insufficient to determine which workflow version should be loaded
    # when resuming a suspended workflow.
    workflow_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
    )

    # `workflow_run_id` represents the identifier of the execution of the workflow,
    # corresponding to the `id` field of `WorkflowRun`.
    workflow_run_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
    )

    # `resumed_at` records the timestamp when the suspended workflow was resumed.
    # It is set to `NULL` if the workflow has not been resumed.
    #
    # NOTE: Resuming a suspended WorkflowPause does not delete the record immediately.
    # It only sets `resumed_at` to a non-null value.
    resumed_at: Mapped[datetime | None] = mapped_column(
        sa.DateTime,
        nullable=True,
    )

    # `state_object_key` stores the object key referencing the serialized runtime state
    # of the `GraphEngine`. This object captures the complete execution context of the
    # workflow at the moment it was paused, enabling accurate resumption.
    state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)

    # Relationship to WorkflowRun
    workflow_run: Mapped["WorkflowRun"] = orm.relationship(
        foreign_keys=[workflow_run_id],
        # require explicit preloading.
        lazy="raise",
        uselist=False,
        primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id",
        back_populates="pause",
    )
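Because both sides of the relationship use `lazy="raise"`, touching `run.pause` on a lazily loaded instance raises instead of silently issuing a query; callers have to opt in with an eager-load option. A sketch (the `session` and `run_id` are assumed to exist):

```python
from sqlalchemy import select
from sqlalchemy.orm import selectinload

# Explicit preloading is mandatory: lazy="raise" turns forgotten eager
# loads into immediate errors instead of hidden N+1 queries.
stmt = (
    select(WorkflowRun)
    .options(selectinload(WorkflowRun.pause))
    .where(WorkflowRun.id == run_id)
)
run = session.scalars(stmt).one()
if run.pause is not None and run.pause.resumed_at is None:
    print("run is paused; state object:", run.pause.state_object_key)
```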
@@ -37,6 +37,7 @@ dependencies = [
    "numpy~=1.26.4",
    "openpyxl~=3.1.5",
    "opik~=1.8.72",
    "litellm==1.77.1", # Pinned to avoid madoka dependency issue
    "opentelemetry-api==1.27.0",
    "opentelemetry-distro==0.48b0",
    "opentelemetry-exporter-otlp==1.27.0",
@@ -74,7 +75,7 @@ dependencies = [
    "resend~=2.9.0",
    "sentry-sdk[flask]~=2.28.0",
    "sqlalchemy~=2.0.29",
    "starlette==0.47.2",
    "starlette==0.49.1",
    "tiktoken~=0.9.0",
    "transformers~=4.56.1",
    "unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
@@ -209,9 +210,9 @@ vdb = [
    "pgvector==0.2.5",
    "pymilvus~=2.5.0",
    "pymochow==2.2.9",
    "pyobvector~=0.2.15",
    "pyobvector~=0.2.17",
    "qdrant-client==1.9.0",
    "tablestore==6.2.0",
    "tablestore==6.3.7",
    "tcvectordb~=1.6.4",
    "tidb-vector==0.0.9",
    "upstash-vector==0.6.0",

@@ -38,6 +38,7 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Protocol

from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
@@ -251,6 +252,116 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
        """
        ...

    def create_workflow_pause(
        self,
        workflow_run_id: str,
        state_owner_user_id: str,
        state: str,
    ) -> WorkflowPauseEntity:
        """
        Create a new workflow pause state.

        Creates a pause state for a workflow run, storing the current execution
        state and marking the workflow as paused. This is used when a workflow
        needs to be suspended and later resumed.

        Args:
            workflow_run_id: Identifier of the workflow run to pause
            state_owner_user_id: User ID who owns the pause state for file storage
            state: Serialized workflow execution state (JSON string)

        Returns:
            WorkflowPauseEntity representing the created pause state

        Raises:
            ValueError: If workflow_run_id is invalid or workflow run doesn't exist
            RuntimeError: If workflow is already paused or in invalid state
        """
        # NOTE: we may get rid of the `state_owner_user_id` in the parameter list.
        # However, removing it would require an extra lookup of the `Workflow` model
        # while creating the pause.
        ...

    def resume_workflow_pause(
        self,
        workflow_run_id: str,
        pause_entity: WorkflowPauseEntity,
    ) -> WorkflowPauseEntity:
        """
        Resume a paused workflow.

        Marks a paused workflow as resumed, sets the `resumed_at` field of the
        WorkflowPauseEntity, and returns the workflow to running status. Returns
        the pause entity that was resumed.

        The returned `WorkflowPauseEntity` model has `resumed_at` set.

        NOTE: this method does not delete the corresponding `WorkflowPauseEntity` record and associated states.
        It's the caller's responsibility to clear the corresponding state with `delete_workflow_pause`.

        Args:
            workflow_run_id: Identifier of the workflow run to resume
            pause_entity: The pause entity to resume

        Returns:
            WorkflowPauseEntity representing the resumed pause state

        Raises:
            ValueError: If workflow_run_id is invalid
            RuntimeError: If workflow is not paused or already resumed
        """
        ...

    def delete_workflow_pause(
        self,
        pause_entity: WorkflowPauseEntity,
    ) -> None:
        """
        Delete a workflow pause state.

        Permanently removes the pause state for a workflow run, including
        the stored state file. Used for cleanup operations when a paused
        workflow is no longer needed.

        Args:
            pause_entity: The pause entity to delete

        Raises:
            ValueError: If pause_entity is invalid
            RuntimeError: If workflow is not paused

        Note:
            This operation is irreversible. The stored workflow state will be
            permanently deleted along with the pause record.
        """
        ...

    def prune_pauses(
        self,
        expiration: datetime,
        resumption_expiration: datetime,
        limit: int | None = None,
    ) -> Sequence[str]:
        """
        Clean up expired and old pause states.

        Removes pause states that have expired (created before the expiration time)
        and pause states that were resumed before the resumption expiration time.
        This is used for maintenance and cleanup operations.

        Args:
            expiration: Remove pause states created before this time
            resumption_expiration: Remove pause states resumed before this time
            limit: Maximum number of records deleted in one call

        Returns:
            A list of ids for pause records that were pruned

        Raises:
            ValueError: If parameters are invalid
        """
        ...

    def get_daily_runs_statistics(
        self,
        tenant_id: str,
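End to end, the protocol implies a simple lifecycle: create a pause while the run is RUNNING, resume it later, then delete the state once it is no longer needed. A hedged sketch of a caller (repository construction, `run_id`, `user_id`, and the serialized state are assumed):

```python
# Illustrative lifecycle only; error handling and engine integration omitted.
pause = repo.create_workflow_pause(
    workflow_run_id=run_id,
    state_owner_user_id=user_id,
    state=serialized_engine_state,  # JSON string from the GraphEngine
)

# ... later, when the awaited human input arrives ...
resumed = repo.resume_workflow_pause(workflow_run_id=run_id, pause_entity=pause)
engine_state = resumed.get_state()  # bytes to rebuild the GraphEngine

# Resuming only stamps resumed_at; cleanup is explicit.
repo.delete_workflow_pause(resumed)
```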
@@ -20,19 +20,26 @@ Implementation Notes:
"""

import logging
import uuid
from collections.abc import Sequence
from datetime import datetime
from decimal import Decimal
from typing import Any, cast

import sqlalchemy as sa
-from sqlalchemy import delete, func, select
+from sqlalchemy import and_, delete, func, null, or_, select
from sqlalchemy.engine import CursorResult
-from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.orm import Session, selectinload, sessionmaker

from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.types import (

@@ -45,6 +52,10 @@ from repositories.types import (
logger = logging.getLogger(__name__)


+class _WorkflowRunError(Exception):
+    pass
+
+
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
    """
    SQLAlchemy implementation of APIWorkflowRunRepository.
@@ -301,6 +312,281 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
        logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
        return total_deleted

    def create_workflow_pause(
        self,
        workflow_run_id: str,
        state_owner_user_id: str,
        state: str,
    ) -> WorkflowPauseEntity:
        """
        Create a new workflow pause state.

        Creates a pause state for a workflow run, storing the current execution
        state and marking the workflow as paused. This is used when a workflow
        needs to be suspended and later resumed.

        Args:
            workflow_run_id: Identifier of the workflow run to pause
            state_owner_user_id: User ID who owns the pause state, used for file storage
            state: Serialized workflow execution state (JSON string)

        Returns:
            WorkflowPauseEntity representing the created pause state

        Raises:
            ValueError: If workflow_run_id is invalid or the workflow run doesn't exist
            RuntimeError: If the workflow is already paused or in an invalid state
        """
        previous_pause_model_query = select(WorkflowPauseModel).where(
            WorkflowPauseModel.workflow_run_id == workflow_run_id
        )
        with self._session_maker() as session, session.begin():
            # Get the workflow run
            workflow_run = session.get(WorkflowRun, workflow_run_id)
            if workflow_run is None:
                raise ValueError(f"WorkflowRun not found: {workflow_run_id}")

            # Check if the workflow is in RUNNING status
            if workflow_run.status != WorkflowExecutionStatus.RUNNING:
                raise _WorkflowRunError(
                    f"Only WorkflowRun with RUNNING status can be paused, "
                    f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}"
                )

            # Remove any previous pause record for this workflow run
            previous_pause = session.scalars(previous_pause_model_query).first()
            if previous_pause:
                self._delete_pause_model(session, previous_pause)
                # we need to flush here to ensure that the old one is actually deleted.
                session.flush()

            # Upload the state file
            state_obj_key = f"workflow-state-{uuid.uuid4()}.json"
            storage.save(state_obj_key, state.encode())

            # Create the pause record
            pause_model = WorkflowPauseModel()
            pause_model.id = str(uuidv7())
            pause_model.workflow_id = workflow_run.workflow_id
            pause_model.workflow_run_id = workflow_run.id
            pause_model.state_object_key = state_obj_key
            pause_model.created_at = naive_utc_now()

            # Update workflow run status
            workflow_run.status = WorkflowExecutionStatus.PAUSED

            # Save everything in a transaction
            session.add(pause_model)
            session.add(workflow_run)

            logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)

            return _PrivateWorkflowPauseEntity.from_models(pause_model)
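A caller serializes the engine state and hands it over in one transaction (an illustrative sketch; `repo`, `run_id`, and `user_id` are placeholders, not part of this diff):

import json

state_json = json.dumps({"current_node": "llm-1", "variables": {"x": 1}})
pause = repo.create_workflow_pause(
    workflow_run_id=run_id,
    state_owner_user_id=user_id,
    state=state_json,
)
assert pause.resumed_at is None  # a fresh pause has not been resumed yet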
    def get_workflow_pause(
        self,
        workflow_run_id: str,
    ) -> WorkflowPauseEntity | None:
        """
        Get an existing workflow pause state.

        Retrieves the pause state for a specific workflow run if it exists.
        Used to check if a workflow is paused and to retrieve its saved state.

        Args:
            workflow_run_id: Identifier of the workflow run to get the pause state for

        Returns:
            WorkflowPauseEntity if a pause state exists, None otherwise

        Raises:
            ValueError: If workflow_run_id is invalid
        """
        with self._session_maker() as session:
            # Query the workflow run together with its pause record
            stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run_id)
            workflow_run = session.scalar(stmt)

            if workflow_run is None:
                raise ValueError(f"WorkflowRun not found: {workflow_run_id}")

            pause_model = workflow_run.pause
            if pause_model is None:
                return None

            return _PrivateWorkflowPauseEntity.from_models(pause_model)
    def resume_workflow_pause(
        self,
        workflow_run_id: str,
        pause_entity: WorkflowPauseEntity,
    ) -> WorkflowPauseEntity:
        """
        Resume a paused workflow.

        Marks a paused workflow as resumed, clearing the pause state and
        returning the workflow to running status. Returns the pause entity
        that was resumed.

        Args:
            workflow_run_id: Identifier of the workflow run to resume
            pause_entity: The pause entity to resume

        Returns:
            WorkflowPauseEntity representing the resumed pause state

        Raises:
            ValueError: If workflow_run_id is invalid
            RuntimeError: If the workflow is not paused or already resumed
        """
        with self._session_maker() as session, session.begin():
            # Get the workflow run together with its pause record
            stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run_id)
            workflow_run = session.scalar(stmt)

            if workflow_run is None:
                raise ValueError(f"WorkflowRun not found: {workflow_run_id}")

            if workflow_run.status != WorkflowExecutionStatus.PAUSED:
                raise _WorkflowRunError(
                    f"WorkflowRun is not in PAUSED status, workflow_run_id={workflow_run_id}, "
                    f"current_status={workflow_run.status}"
                )
            pause_model = workflow_run.pause
            if pause_model is None:
                raise _WorkflowRunError(f"No pause state found for workflow run: {workflow_run_id}")

            if pause_model.id != pause_entity.id:
                raise _WorkflowRunError(
                    "different id in WorkflowPause and WorkflowPauseEntity, "
                    f"WorkflowPause.id={pause_model.id}, "
                    f"WorkflowPauseEntity.id={pause_entity.id}"
                )

            if pause_model.resumed_at is not None:
                raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")

            # Mark as resumed
            pause_model.resumed_at = naive_utc_now()
            workflow_run.pause_id = None  # type: ignore
            workflow_run.status = WorkflowExecutionStatus.RUNNING

            session.add(pause_model)
            session.add(workflow_run)

            logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)

            return _PrivateWorkflowPauseEntity.from_models(pause_model)
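Resumption is therefore a fetch-then-resume sequence, with the entity id cross-checked against the stored record (an illustrative sketch; `repo` and `run_id` are placeholders):

pause = repo.get_workflow_pause(run_id)
if pause is not None and pause.resumed_at is None:
    state_bytes = pause.get_state()        # the saved engine state, as bytes
    resumed = repo.resume_workflow_pause(run_id, pause)
    assert resumed.resumed_at is not None  # the pause is now marked resumed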
    def delete_workflow_pause(
        self,
        pause_entity: WorkflowPauseEntity,
    ) -> None:
        """
        Delete a workflow pause state.

        Permanently removes the pause state for a workflow run, including
        the stored state file. Used for cleanup operations when a paused
        workflow is no longer needed.

        Args:
            pause_entity: The pause entity to delete

        Raises:
            ValueError: If pause_entity is invalid
            _WorkflowRunError: If the workflow is not paused

        Note:
            This operation is irreversible. The stored workflow state will be
            permanently deleted along with the pause record.
        """
        with self._session_maker() as session, session.begin():
            # Get the pause model by ID
            pause_model = session.get(WorkflowPauseModel, pause_entity.id)
            if pause_model is None:
                raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}")
            self._delete_pause_model(session, pause_model)

    @staticmethod
    def _delete_pause_model(session: Session, pause_model: WorkflowPauseModel):
        # Delete the stored state file, then the pause record
        storage.delete(pause_model.state_object_key)
        session.delete(pause_model)

        logger.info("Deleted workflow pause %s for workflow run %s", pause_model.id, pause_model.workflow_run_id)
    def prune_pauses(
        self,
        expiration: datetime,
        resumption_expiration: datetime,
        limit: int | None = None,
    ) -> Sequence[str]:
        """
        Clean up expired and old pause states.

        Removes pause states that were created before the expiration time, as
        well as pause states that were resumed before resumption_expiration.
        This is used for maintenance and cleanup operations.

        Args:
            expiration: Remove pause states created before this time
            resumption_expiration: Remove pause states resumed before this time
            limit: Maximum number of records deleted in one call

        Returns:
            A list of IDs for the pause records that were pruned

        Raises:
            ValueError: If parameters are invalid
        """
        _limit: int = limit or 1000
        pruned_record_ids: list[str] = []
        # Match expired pauses (created before the expiration time) and old
        # resumed pauses (resumed before resumption_expiration)
        cond = or_(
            WorkflowPauseModel.created_at < expiration,
            and_(
                WorkflowPauseModel.resumed_at.is_not(null()),
                WorkflowPauseModel.resumed_at < resumption_expiration,
            ),
        )
        # First, collect the pause records to delete with their state files
        stmt = select(WorkflowPauseModel).where(cond).limit(_limit)

        with self._session_maker(expire_on_commit=False) as session:
            # Get all records to delete
            pauses_to_delete = session.scalars(stmt).all()

        # Delete state files from storage, then the records
        for pause in pauses_to_delete:
            # todo: this issues a separate query for each WorkflowPauseModel record.
            # consider batching this lookup.
            with self._session_maker(expire_on_commit=False) as session, session.begin():
                try:
                    storage.delete(pause.state_object_key)
                    logger.info(
                        "Deleted state object for pause, pause_id=%s, object_key=%s",
                        pause.id,
                        pause.state_object_key,
                    )
                except Exception:
                    logger.exception(
                        "Failed to delete state file for pause, pause_id=%s, object_key=%s",
                        pause.id,
                        pause.state_object_key,
                    )
                    continue
                session.delete(pause)
                pruned_record_ids.append(pause.id)
                logger.info(
                    "workflow pause record deleted, id=%s, resumed_at=%s",
                    pause.id,
                    pause.resumed_at,
                )

        return pruned_record_ids
    def get_daily_runs_statistics(
        self,
        tenant_id: str,
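Note the design choice in prune_pauses above: each record is pruned in its own short transaction, so a storage failure on one state object skips only that record and the sweep continues; the cost, flagged in the todo, is one session per record.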
@@ -510,3 +796,69 @@ GROUP BY
        )

        return cast(list[AverageInteractionStats], response_data)


class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
    """
    Private implementation of WorkflowPauseEntity for the SQLAlchemy repository.

    This implementation is internal to the repository layer and provides
    the concrete implementation of the WorkflowPauseEntity interface.
    """

    def __init__(
        self,
        *,
        pause_model: WorkflowPauseModel,
    ) -> None:
        self._pause_model = pause_model
        self._cached_state: bytes | None = None

    @classmethod
    def from_models(cls, workflow_pause_model) -> "_PrivateWorkflowPauseEntity":
        """
        Create a _PrivateWorkflowPauseEntity from a database model.

        Args:
            workflow_pause_model: The WorkflowPause database model

        Returns:
            _PrivateWorkflowPauseEntity: The constructed entity

        Raises:
            ValueError: If required model attributes are missing
        """
        return cls(pause_model=workflow_pause_model)

    @property
    def id(self) -> str:
        return self._pause_model.id

    @property
    def workflow_execution_id(self) -> str:
        return self._pause_model.workflow_run_id

    def get_state(self) -> bytes:
        """
        Retrieve the serialized workflow state from storage.

        Returns:
            bytes: The serialized workflow state

        Raises:
            FileNotFoundError: If the state file cannot be found
            IOError: If there are issues reading the state file
            _WorkflowRunError: If the state cannot be deserialized properly
        """
        if self._cached_state is not None:
            return self._cached_state

        # Load the state from storage and cache it for subsequent reads
        state_data = storage.load(self._pause_model.state_object_key)
        self._cached_state = state_data
        return state_data

    @property
    def resumed_at(self) -> datetime | None:
        return self._pause_model.resumed_at
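The entity lazily loads the state blob and memoizes it, so repeated reads hit object storage only once (an illustrative sketch; `pause` is a placeholder):

first = pause.get_state()   # loads from object storage
second = pause.get_state()  # served from the in-memory cache
assert first is second      # the same bytes object is returned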
@@ -7,6 +7,7 @@ from sqlalchemy.exc import SQLAlchemyError

import app
from configs import dify_config
+from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import (

@@ -63,7 +64,7 @@ def clean_messages():
            plan = features.billing.subscription.plan
        else:
            plan = plan_cache.decode()
-        if plan == "sandbox":
+        if plan == CloudPlan.SANDBOX:
            # clean related message
            db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
                synchronize_session=False
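This and the following hunks swap the bare "sandbox" literal for a shared enum. The enum definition itself is outside the shown hunks; a minimal sketch of what enums/cloud_plan.py presumably contains, inferred from how it is compared against cached plan strings (the other members are assumptions):

from enum import StrEnum

class CloudPlan(StrEnum):
    SANDBOX = "sandbox"
    # further tiers (e.g. PROFESSIONAL, TEAM) presumably exist but are not
    # exercised in this diff

Because StrEnum members are str subclasses, CloudPlan.SANDBOX == "sandbox" still holds, so comparisons against plan strings decoded from Redis or returned by the billing API behave exactly as before.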
@@ -9,6 +9,7 @@ from sqlalchemy.exc import SQLAlchemyError

import app
from configs import dify_config
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document

@@ -35,7 +36,7 @@ def clean_unused_datasets_task():
            },
            {
                "clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_PRO_CLEAN_DAY_SETTING),
-                "plan_filter": "sandbox",
+                "plan_filter": CloudPlan.SANDBOX,
                "add_logs": False,
            },
        ]
@@ -7,6 +7,7 @@ from sqlalchemy import select

import app
from configs import dify_config
+from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service

@@ -45,7 +46,7 @@ def mail_clean_document_notify_task():
        for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
            features = FeatureService.get_features(tenant_id)
            plan = features.billing.subscription.plan
-            if plan != "sandbox":
+            if plan != CloudPlan.SANDBOX:
                knowledge_details = []
                # check tenant
                tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first()
@@ -32,41 +32,48 @@ class AppAnnotationService:

        if not app:
            raise NotFound("App not found")

+        answer = args.get("answer") or args.get("content")
+        if answer is None:
+            raise ValueError("Either 'answer' or 'content' must be provided")
+
        if args.get("message_id"):
            message_id = str(args["message_id"])
            # get message info
            message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first()

            if not message:
                raise NotFound("Message Not Exists.")

+            question = args.get("question") or message.query or ""
+
            annotation: MessageAnnotation | None = message.annotation
            # save the message annotation
            if annotation:
-                annotation.content = args["answer"]
-                annotation.question = args["question"]
+                annotation.content = answer
+                annotation.question = question
            else:
                annotation = MessageAnnotation(
                    app_id=app.id,
                    conversation_id=message.conversation_id,
                    message_id=message.id,
-                    content=args["answer"],
-                    question=args["question"],
+                    content=answer,
+                    question=question,
                    account_id=current_user.id,
                )
        else:
-            annotation = MessageAnnotation(
-                app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
-            )
+            question = args.get("question")
+            if not question:
+                raise ValueError("'question' is required when 'message_id' is not provided")
+
+            annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
        db.session.add(annotation)
        db.session.commit()
        # if annotation reply is enabled, add annotation to index

        annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
        assert current_tenant_id is not None
        if annotation_setting:
            add_annotation_to_index_task.delay(
                annotation.id,
-                args["question"],
+                annotation.question,
                current_tenant_id,
                app_id,
                annotation_setting.collection_binding_id,
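The net effect: the create path now accepts either key for the reply text, and message-bound annotations fall back to the message's own query (an illustrative trace of the new rules):

args = {"content": "Hello!", "message_id": "..."}
answer = args.get("answer") or args.get("content")  # -> "Hello!"
# for a message-bound annotation with no explicit question:
# question = args.get("question") or message.query or ""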
@@ -179,8 +186,12 @@ class AppAnnotationService:
        if not app:
            raise NotFound("App not found")

+        question = args.get("question")
+        if question is None:
+            raise ValueError("'question' is required")
+
        annotation = MessageAnnotation(
-            app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
+            app_id=app.id, content=args["answer"], question=question, account_id=current_user.id
        )
        db.session.add(annotation)
        db.session.commit()

@@ -189,7 +200,7 @@ class AppAnnotationService:
        if annotation_setting:
            add_annotation_to_index_task.delay(
                annotation.id,
-                args["question"],
+                question,
                current_tenant_id,
                app_id,
                annotation_setting.collection_binding_id,

@@ -214,8 +225,12 @@ class AppAnnotationService:
        if not annotation:
            raise NotFound("Annotation not found")

+        question = args.get("question")
+        if question is None:
+            raise ValueError("'question' is required")
+
        annotation.content = args["answer"]
-        annotation.question = args["question"]
+        annotation.question = question

        db.session.commit()
        # if annotation reply is enabled, add annotation to index
@@ -10,6 +10,7 @@ from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting import RateLimit
+from enums.cloud_plan import CloudPlan
from libs.helper import RateLimiter
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow

@@ -44,7 +45,7 @@ class AppGenerateService:
        if dify_config.BILLING_ENABLED:
            # check if it's free plan
            limit_info = BillingService.get_info(app_model.tenant_id)
-            if limit_info["subscription"]["plan"] == "sandbox":
+            if limit_info["subscription"]["plan"] == CloudPlan.SANDBOX:
                if cls.system_rate_limiter.is_rate_limited(app_model.tenant_id):
                    raise InvokeRateLimitError(
                        "Rate limit exceeded, please upgrade your plan "
@@ -4,6 +4,7 @@ from typing import Literal
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed

+from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import RateLimiter

@@ -31,7 +32,7 @@ class BillingService:

        return {
            "limit": knowledge_rate_limit.get("limit", 10),
-            "subscription_plan": knowledge_rate_limit.get("subscription_plan", "sandbox"),
+            "subscription_plan": knowledge_rate_limit.get("subscription_plan", CloudPlan.SANDBOX),
        }

    @classmethod
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
+from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant

@@ -358,7 +359,7 @@ class ClearFreePlanTenantExpiredLogs:
            try:
                if (
                    not dify_config.BILLING_ENABLED
-                    or BillingService.get_info(tenant_id)["subscription"]["plan"] == "sandbox"
+                    or BillingService.get_info(tenant_id)["subscription"]["plan"] == CloudPlan.SANDBOX
                ):
                    # only process sandbox tenant
                    cls.process_tenant(flask_app, tenant_id, days, batch)
@@ -22,6 +22,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from enums.cloud_plan import CloudPlan
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
from extensions.ext_database import db

@@ -1042,7 +1043,7 @@ class DatasetService:
        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None
        features = FeatureService.get_features(current_user.current_tenant_id)
-        if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
+        if not features.billing.enabled or features.billing.subscription.plan == CloudPlan.SANDBOX:
            return {
                "document_ids": [],
                "count": 0,

@@ -1416,8 +1417,6 @@ class DocumentService:
        # check document limit
        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None
-        assert knowledge_config.data_source
-        assert knowledge_config.data_source.info_list

        features = FeatureService.get_features(current_user.current_tenant_id)

@@ -1440,7 +1439,7 @@ class DocumentService:
            count = len(website_info.urls)
            batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)

-            if features.billing.subscription.plan == "sandbox" and count > 1:
+            if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
                raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

@@ -1448,7 +1447,7 @@ class DocumentService:
        DocumentService.check_documents_upload_quota(count, features)

        # if dataset is empty, update dataset data_source_type
-        if not dataset.data_source_type:
+        if not dataset.data_source_type and knowledge_config.data_source:
            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type

        if not dataset.indexing_technique:

@@ -1494,6 +1493,10 @@ class DocumentService:
            documents.append(document)
            batch = document.batch
        else:
+            # When creating new documents, data_source must be provided
+            if not knowledge_config.data_source:
+                raise ValueError("Data source is required when creating new documents")
+
            batch = time.strftime("%Y%m%d%H%M%S") + str(100000 + secrets.randbelow(exclusive_upper_bound=900000))
            # save process rule
            if not dataset_process_rule:

@@ -1725,7 +1728,7 @@ class DocumentService:
            # count = len(website_info.urls) # type: ignore
            # batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)

-            # if features.billing.subscription.plan == "sandbox" and count > 1:
+            # if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
            #     raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
            # if count > batch_upload_limit:
            #     raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

@@ -2194,7 +2197,7 @@ class DocumentService:
                website_info = knowledge_config.data_source.info_list.website_info_list
                if website_info:
                    count = len(website_info.urls)
-                    if features.billing.subscription.plan == "sandbox" and count > 1:
+                    if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
                        raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
                    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
                    if count > batch_upload_limit:
@@ -92,16 +92,6 @@ class EnterpriseService:

        return ret

-    @classmethod
-    def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings:
-        if not app_code:
-            raise ValueError("app_code must be provided.")
-        params = {"appCode": app_code}
-        data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
-        if not data:
-            raise ValueError("No data found.")
-        return WebAppSettings.model_validate(data)
-
    @classmethod
    def update_app_access_mode(cls, app_id: str, access_mode: str):
        if not app_id:

@@ -11,3 +11,7 @@ class FileTooLargeError(BaseServiceError):

class UnsupportedFileTypeError(BaseServiceError):
    pass


+class BlockedFileExtensionError(BaseServiceError):
+    description = "File extension '{extension}' is not allowed for security reasons"
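The new error backs the upload-extension blacklist setting. The enforcement point is not part of the shown hunks; a minimal sketch of how a file service might apply it (the helper name, the constructor call, and the parsing of dify_config.UPLOAD_FILE_EXTENSION_BLACKLIST are all assumptions):

def _ensure_extension_allowed(filename: str) -> None:
    # hypothetical helper; assumes the setting is a comma-separated,
    # lowercase, dot-free list such as "exe,bat,cmd"
    blacklist = {e.strip() for e in dify_config.UPLOAD_FILE_EXTENSION_BLACKLIST.split(",") if e.strip()}
    extension = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    if extension in blacklist:
        # how BaseServiceError is constructed is an assumption here
        raise BlockedFileExtensionError(description=f"File extension '{extension}' is not allowed for security reasons")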
@@ -3,12 +3,13 @@ from enum import StrEnum
from pydantic import BaseModel, ConfigDict, Field

from configs import dify_config
+from enums.cloud_plan import CloudPlan
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService


class SubscriptionModel(BaseModel):
-    plan: str = "sandbox"
+    plan: str = CloudPlan.SANDBOX
    interval: str = ""
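Since CloudPlan is a StrEnum, the field keeps its declared str type and its serialized default is unchanged (a quick check, under the enum sketch given earlier):

assert CloudPlan.SANDBOX == "sandbox"
assert SubscriptionModel().plan == "sandbox"  # the pydantic default is the same string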
@@ -186,7 +187,7 @@ class FeatureService:
            knowledge_rate_limit.enabled = True
            limit_info = BillingService.get_knowledge_rate_limit(tenant_id)
            knowledge_rate_limit.limit = limit_info.get("limit", 10)
-            knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", "sandbox")
+            knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", CloudPlan.SANDBOX)
        return knowledge_rate_limit

    @classmethod

@@ -240,7 +241,7 @@ class FeatureService:
        features.billing.subscription.interval = billing_info["subscription"]["interval"]
        features.education.activated = billing_info["subscription"].get("education", False)

-        if features.billing.subscription.plan != "sandbox":
+        if features.billing.subscription.plan != CloudPlan.SANDBOX:
            features.webapp_copyright_enabled = True
        else:
            features.is_allow_transfer_workspace = False
Some files were not shown because too many files have changed in this diff.