mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-06 19:42:42 +08:00
Compare commits
1 Commits
fix/valid-
...
feat/conve
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
abbd8b3114 |
@@ -12,19 +12,14 @@ from configs.packaging import PackagingInfo
|
||||
class DifyConfig(
|
||||
# Packaging info
|
||||
PackagingInfo,
|
||||
|
||||
# Deployment configs
|
||||
DeploymentConfig,
|
||||
|
||||
# Feature configs
|
||||
FeatureConfig,
|
||||
|
||||
# Middleware configs
|
||||
MiddlewareConfig,
|
||||
|
||||
# Extra service configs
|
||||
ExtraServiceConfig,
|
||||
|
||||
# Enterprise feature configs
|
||||
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
|
||||
EnterpriseFeatureConfig,
|
||||
@@ -36,7 +31,6 @@ class DifyConfig(
|
||||
env_file='.env',
|
||||
env_file_encoding='utf-8',
|
||||
frozen=True,
|
||||
|
||||
# ignore extra attributes
|
||||
extra='ignore',
|
||||
)
|
||||
@@ -67,3 +61,5 @@ class DifyConfig(
|
||||
SSRF_PROXY_HTTPS_URL: str | None = None
|
||||
|
||||
MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
|
||||
|
||||
MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.')
|
||||
|
||||
@@ -17,6 +17,7 @@ from .app import (
|
||||
audio,
|
||||
completion,
|
||||
conversation,
|
||||
conversation_variables,
|
||||
generator,
|
||||
message,
|
||||
model_config,
|
||||
|
||||
61
api/controllers/console/app/conversation_variables.py
Normal file
61
api/controllers/console/app/conversation_variables.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_variable_fields import paginated_conversation_variable_fields
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class ConversationVariablesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
@marshal_with(paginated_conversation_variable_fields)
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('conversation_id', type=str, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
stmt = (
|
||||
select(ConversationVariable)
|
||||
.where(ConversationVariable.app_id == app_model.id)
|
||||
.order_by(ConversationVariable.created_at)
|
||||
)
|
||||
if args['conversation_id']:
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id'])
|
||||
else:
|
||||
raise ValueError('conversation_id is required')
|
||||
|
||||
# NOTE: This is a temporary solution to avoid performance issues.
|
||||
page = 1
|
||||
page_size = 100
|
||||
stmt = stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
rows = session.scalars(stmt).all()
|
||||
|
||||
return {
|
||||
'page': page,
|
||||
'limit': page_size,
|
||||
'total': len(rows),
|
||||
'has_more': False,
|
||||
'data': [
|
||||
{
|
||||
'created_at': row.created_at,
|
||||
'updated_at': row.updated_at,
|
||||
**row.to_variable().model_dump(),
|
||||
}
|
||||
for row in rows
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables')
|
||||
@@ -74,6 +74,7 @@ class DraftWorkflowApi(Resource):
|
||||
parser.add_argument('hash', type=str, required=False, location='json')
|
||||
# TODO: set this to required=True after frontend is updated
|
||||
parser.add_argument('environment_variables', type=list, required=False, location='json')
|
||||
parser.add_argument('conversation_variables', type=list, required=False, location='json')
|
||||
args = parser.parse_args()
|
||||
elif 'text/plain' in content_type:
|
||||
try:
|
||||
@@ -88,7 +89,8 @@ class DraftWorkflowApi(Resource):
|
||||
'graph': data.get('graph'),
|
||||
'features': data.get('features'),
|
||||
'hash': data.get('hash'),
|
||||
'environment_variables': data.get('environment_variables')
|
||||
'environment_variables': data.get('environment_variables'),
|
||||
'conversation_variables': data.get('conversation_variables'),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {'message': 'Invalid JSON data'}, 400
|
||||
@@ -100,6 +102,8 @@ class DraftWorkflowApi(Resource):
|
||||
try:
|
||||
environment_variables_list = args.get('environment_variables') or []
|
||||
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
||||
conversation_variables_list = args.get('conversation_variables') or []
|
||||
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=args['graph'],
|
||||
@@ -107,6 +111,7 @@ class DraftWorkflowApi(Resource):
|
||||
unique_hash=args.get('hash'),
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
|
||||
@@ -3,8 +3,9 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.file.file_obj import FileExtraConfig
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from models.model import AppMode
|
||||
from models import AppMode
|
||||
|
||||
|
||||
class ModelConfigEntity(BaseModel):
|
||||
@@ -200,11 +201,6 @@ class TracingConfigEntity(BaseModel):
|
||||
tracing_provider: str
|
||||
|
||||
|
||||
class FileExtraConfig(BaseModel):
|
||||
"""
|
||||
File Upload Entity.
|
||||
"""
|
||||
image_config: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class AppAdditionalFeatures(BaseModel):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
from core.file.file_obj import FileExtraConfig
|
||||
|
||||
|
||||
class FileUploadConfigManager:
|
||||
|
||||
@@ -113,7 +113,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
@@ -180,7 +179,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
@@ -189,12 +187,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
stream=stream
|
||||
)
|
||||
|
||||
def _generate(self, app_model: App,
|
||||
def _generate(self, *,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Conversation = None,
|
||||
conversation: Conversation | None = None,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
is_first_conversation = False
|
||||
|
||||
@@ -4,6 +4,9 @@ import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@@ -17,11 +20,12 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueSto
|
||||
from core.moderation.base import ModerationException
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow
|
||||
from models.workflow import ConversationVariable, Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,10 +35,13 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
AdvancedChat Application Runner
|
||||
"""
|
||||
|
||||
def run(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message) -> None:
|
||||
def run(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
@@ -48,11 +55,11 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
raise ValueError('App not found')
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
@@ -68,35 +75,66 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
queue_manager=queue_manager,
|
||||
app_record=app_record,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message.id
|
||||
queue_manager=queue_manager,
|
||||
app_record=app_record,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
):
|
||||
return
|
||||
|
||||
# annotation reply
|
||||
if self.handle_annotation_reply(
|
||||
app_record=app_record,
|
||||
message=message,
|
||||
query=query,
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity
|
||||
app_record=app_record,
|
||||
message=message,
|
||||
query=query,
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
):
|
||||
return
|
||||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
|
||||
queue_manager=queue_manager,
|
||||
workflow=workflow
|
||||
)]
|
||||
workflow_callbacks: list[WorkflowCallback] = [
|
||||
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
|
||||
]
|
||||
|
||||
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
|
||||
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
session.commit()
|
||||
# Convert database entities to variables
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariable.QUERY: query,
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
@@ -106,43 +144,30 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.QUERY: query,
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id
|
||||
},
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth
|
||||
call_depth=application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
def single_iteration_run(self, app_id: str, workflow_id: str,
|
||||
queue_manager: AppQueueManager,
|
||||
inputs: dict, node_id: str, user_id: str) -> None:
|
||||
def single_iteration_run(
|
||||
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Single iteration run
|
||||
"""
|
||||
app_record: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
raise ValueError('App not found')
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(
|
||||
queue_manager=queue_manager,
|
||||
workflow=workflow
|
||||
)]
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
|
||||
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.single_step_run_iteration_workflow_node(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_id=user_id,
|
||||
user_inputs=inputs,
|
||||
callbacks=workflow_callbacks
|
||||
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
@@ -150,22 +175,25 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == workflow_id
|
||||
).first()
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def handle_input_moderation(
|
||||
self, queue_manager: AppQueueManager,
|
||||
app_record: App,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str
|
||||
self,
|
||||
queue_manager: AppQueueManager,
|
||||
app_record: App,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Handle input moderation
|
||||
@@ -192,17 +220,20 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
queue_manager=queue_manager,
|
||||
text=str(e),
|
||||
stream=app_generate_entity.stream,
|
||||
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
|
||||
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def handle_annotation_reply(self, app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
queue_manager: AppQueueManager,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
|
||||
def handle_annotation_reply(
|
||||
self,
|
||||
app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
queue_manager: AppQueueManager,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
) -> bool:
|
||||
"""
|
||||
Handle annotation reply
|
||||
:param app_record: app record
|
||||
@@ -217,29 +248,27 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
message=message,
|
||||
query=query,
|
||||
user_id=app_generate_entity.user_id,
|
||||
invoke_from=app_generate_entity.invoke_from
|
||||
invoke_from=app_generate_entity.invoke_from,
|
||||
)
|
||||
|
||||
if annotation_reply:
|
||||
queue_manager.publish(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
self._stream_output(
|
||||
queue_manager=queue_manager,
|
||||
text=annotation_reply.content,
|
||||
stream=app_generate_entity.stream,
|
||||
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
|
||||
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _stream_output(self, queue_manager: AppQueueManager,
|
||||
text: str,
|
||||
stream: bool,
|
||||
stopped_by: QueueStopEvent.StopBy) -> None:
|
||||
def _stream_output(
|
||||
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
|
||||
) -> None:
|
||||
"""
|
||||
Direct output
|
||||
:param queue_manager: application queue manager
|
||||
@@ -250,21 +279,10 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
if stream:
|
||||
index = 0
|
||||
for token in text:
|
||||
queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=token
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=stopped_by),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
@@ -26,8 +27,7 @@ class WorkflowAppRunner:
|
||||
Workflow Application Runner
|
||||
"""
|
||||
|
||||
def run(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager) -> None:
|
||||
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
@@ -47,25 +47,36 @@ class WorkflowAppRunner:
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
raise ValueError('App not found')
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
files = application_generate_entity.files
|
||||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
|
||||
queue_manager=queue_manager,
|
||||
workflow=workflow
|
||||
)]
|
||||
workflow_callbacks: list[WorkflowCallback] = [
|
||||
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
|
||||
]
|
||||
|
||||
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
|
||||
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
@@ -75,44 +86,33 @@ class WorkflowAppRunner:
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.USER_ID: user_id
|
||||
},
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth
|
||||
call_depth=application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
def single_iteration_run(self, app_id: str, workflow_id: str,
|
||||
queue_manager: AppQueueManager,
|
||||
inputs: dict, node_id: str, user_id: str) -> None:
|
||||
def single_iteration_run(
|
||||
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Single iteration run
|
||||
"""
|
||||
app_record: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
app_record = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
raise ValueError('App not found')
|
||||
|
||||
if not app_record.workflow_id:
|
||||
raise ValueError("Workflow not initialized")
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(
|
||||
queue_manager=queue_manager,
|
||||
workflow=workflow
|
||||
)]
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
|
||||
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.single_step_run_iteration_workflow_node(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_id=user_id,
|
||||
user_inputs=inputs,
|
||||
callbacks=workflow_callbacks
|
||||
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
@@ -120,11 +120,13 @@ class WorkflowAppRunner:
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == workflow_id
|
||||
).first()
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from .segment_group import SegmentGroup
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
@@ -50,4 +51,5 @@ __all__ = [
|
||||
'ArrayNumberVariable',
|
||||
'ArrayObjectVariable',
|
||||
'ArrayFileVariable',
|
||||
'ArraySegment',
|
||||
]
|
||||
|
||||
2
api/core/app/segments/exc.py
Normal file
2
api/core/app/segments/exc.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class VariableError(Exception):
|
||||
pass
|
||||
@@ -1,8 +1,10 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
from .exc import VariableError
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
FileSegment,
|
||||
@@ -29,39 +31,43 @@ from .variables import (
|
||||
)
|
||||
|
||||
|
||||
def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
|
||||
if (value_type := m.get('value_type')) is None:
|
||||
raise ValueError('missing value type')
|
||||
if not m.get('name'):
|
||||
raise ValueError('missing name')
|
||||
if (value := m.get('value')) is None:
|
||||
raise ValueError('missing value')
|
||||
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
if (value_type := mapping.get('value_type')) is None:
|
||||
raise VariableError('missing value type')
|
||||
if not mapping.get('name'):
|
||||
raise VariableError('missing name')
|
||||
if (value := mapping.get('value')) is None:
|
||||
raise VariableError('missing value')
|
||||
match value_type:
|
||||
case SegmentType.STRING:
|
||||
return StringVariable.model_validate(m)
|
||||
result = StringVariable.model_validate(mapping)
|
||||
case SegmentType.SECRET:
|
||||
return SecretVariable.model_validate(m)
|
||||
result = SecretVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if isinstance(value, int):
|
||||
return IntegerVariable.model_validate(m)
|
||||
result = IntegerVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if isinstance(value, float):
|
||||
return FloatVariable.model_validate(m)
|
||||
result = FloatVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||
raise ValueError(f'invalid number value {value}')
|
||||
raise VariableError(f'invalid number value {value}')
|
||||
case SegmentType.FILE:
|
||||
return FileVariable.model_validate(m)
|
||||
result = FileVariable.model_validate(mapping)
|
||||
case SegmentType.OBJECT if isinstance(value, dict):
|
||||
return ObjectVariable.model_validate(
|
||||
{**m, 'value': {k: build_variable_from_mapping(v) for k, v in value.items()}}
|
||||
)
|
||||
result = ObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
||||
return ArrayStringVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
|
||||
result = ArrayStringVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
|
||||
return ArrayNumberVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
|
||||
result = ArrayNumberVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
||||
return ArrayObjectVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
|
||||
result = ArrayObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_FILE if isinstance(value, list):
|
||||
return ArrayFileVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
|
||||
raise ValueError(f'not supported value type {value_type}')
|
||||
mapping = dict(mapping)
|
||||
mapping['value'] = [{'value': v} for v in value]
|
||||
result = ArrayFileVariable.model_validate(mapping)
|
||||
case _:
|
||||
raise VariableError(f'not supported value type {value_type}')
|
||||
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
||||
raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}')
|
||||
return result
|
||||
|
||||
|
||||
def build_segment(value: Any, /) -> Segment:
|
||||
@@ -74,12 +80,9 @@ def build_segment(value: Any, /) -> Segment:
|
||||
if isinstance(value, float):
|
||||
return FloatSegment(value=value)
|
||||
if isinstance(value, dict):
|
||||
# TODO: Limit the depth of the object
|
||||
return ObjectSegment(value=value)
|
||||
if isinstance(value, list):
|
||||
# TODO: Limit the depth of the array
|
||||
elements = [build_segment(v) for v in value]
|
||||
return ArrayAnySegment(value=elements)
|
||||
return ArrayAnySegment(value=value)
|
||||
if isinstance(value, FileVar):
|
||||
return FileSegment(value=value)
|
||||
raise ValueError(f'not supported value {value}')
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
@@ -37,6 +38,10 @@ class Segment(BaseModel):
|
||||
def markdown(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return sys.getsizeof(self.value)
|
||||
|
||||
def to_object(self) -> Any:
|
||||
return self.value
|
||||
|
||||
@@ -105,28 +110,25 @@ class ArraySegment(Segment):
|
||||
def markdown(self) -> str:
|
||||
return '\n'.join(['- ' + item.markdown for item in self.value])
|
||||
|
||||
def to_object(self):
|
||||
return [v.to_object() for v in self.value]
|
||||
|
||||
|
||||
class ArrayAnySegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_ANY
|
||||
value: Sequence[Segment]
|
||||
value: Sequence[Any]
|
||||
|
||||
|
||||
class ArrayStringSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_STRING
|
||||
value: Sequence[StringSegment]
|
||||
value: Sequence[str]
|
||||
|
||||
|
||||
class ArrayNumberSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_NUMBER
|
||||
value: Sequence[FloatSegment | IntegerSegment]
|
||||
value: Sequence[float | int]
|
||||
|
||||
|
||||
class ArrayObjectSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||
value: Sequence[ObjectSegment]
|
||||
value: Sequence[Mapping[str, Any]]
|
||||
|
||||
|
||||
class ArrayFileSegment(ArraySegment):
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
import enum
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
from core.file.tool_file_parser import ToolFileParser
|
||||
from core.file.upload_file_parser import UploadFileParser
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from extensions.ext_database import db
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
class FileExtraConfig(BaseModel):
|
||||
"""
|
||||
File Upload Entity.
|
||||
"""
|
||||
image_config: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class FileType(enum.Enum):
|
||||
@@ -114,6 +119,7 @@ class FileVar(BaseModel):
|
||||
)
|
||||
|
||||
def _get_data(self, force_url: bool = False) -> Optional[str]:
|
||||
from models.model import UploadFile
|
||||
if self.type == FileType.IMAGE:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.url
|
||||
|
||||
@@ -5,8 +5,7 @@ from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
from core.file.file_obj import FileBelongsTo, FileTransferMethod, FileType, FileVar
|
||||
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser, MessageFile, UploadFile
|
||||
|
||||
@@ -2,7 +2,6 @@ import base64
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs import rsa
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
def obfuscated_token(token: str):
|
||||
@@ -14,6 +13,7 @@ def obfuscated_token(token: str):
|
||||
|
||||
|
||||
def encrypt_token(tenant_id: str, token: str):
|
||||
from models.account import Tenant
|
||||
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
|
||||
raise ValueError(f'Tenant with id {tenant_id} not found')
|
||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||
|
||||
@@ -23,10 +23,12 @@ class NodeType(Enum):
|
||||
HTTP_REQUEST = 'http-request'
|
||||
TOOL = 'tool'
|
||||
VARIABLE_AGGREGATOR = 'variable-aggregator'
|
||||
# TODO: merge this into VARIABLE_AGGREGATOR
|
||||
VARIABLE_ASSIGNER = 'variable-assigner'
|
||||
LOOP = 'loop'
|
||||
ITERATION = 'iteration'
|
||||
PARAMETER_EXTRACTOR = 'parameter-extractor'
|
||||
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'NodeType':
|
||||
|
||||
@@ -13,6 +13,7 @@ VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||
|
||||
SYSTEM_VARIABLE_NODE_ID = 'sys'
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
|
||||
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
|
||||
|
||||
|
||||
class VariablePool:
|
||||
@@ -21,6 +22,7 @@ class VariablePool:
|
||||
system_variables: Mapping[SystemVariable, Any],
|
||||
user_inputs: Mapping[str, Any],
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable] | None = None,
|
||||
) -> None:
|
||||
# system variables
|
||||
# for example:
|
||||
@@ -44,9 +46,13 @@ class VariablePool:
|
||||
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
|
||||
|
||||
# Add environment variables to the variable pool
|
||||
for var in environment_variables or []:
|
||||
for var in environment_variables:
|
||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
# Add conversation variables to the variable pool
|
||||
for var in conversation_variables or []:
|
||||
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||
"""
|
||||
Adds a variable to the variable pool.
|
||||
|
||||
@@ -8,6 +8,7 @@ from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from models import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class UserFrom(Enum):
|
||||
@@ -91,14 +92,19 @@ class BaseNode(ABC):
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
result = self._run(
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
try:
|
||||
result = self._run(
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
self.node_run_result = result
|
||||
return result
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
self.node_run_result = result
|
||||
return result
|
||||
|
||||
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
|
||||
def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
:param text: chunk text
|
||||
|
||||
109
api/core/workflow/nodes/variable_assigner/__init__.py
Normal file
109
api/core/workflow/nodes/variable_assigner/__init__.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.segments import SegmentType, Variable, factory
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class VariableAssignerNodeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WriteMode(str, Enum):
|
||||
OVER_WRITE = 'over-write'
|
||||
APPEND = 'append'
|
||||
CLEAR = 'clear'
|
||||
|
||||
|
||||
class VariableAssignerData(BaseNodeData):
|
||||
title: str = 'Variable Assigner'
|
||||
desc: Optional[str] = 'Assign a value to a variable'
|
||||
assigned_variable_selector: Sequence[str]
|
||||
write_mode: WriteMode
|
||||
input_variable_selector: Sequence[str]
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode):
|
||||
_node_data_cls: type[BaseNodeData] = VariableAssignerData
|
||||
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
data = cast(VariableAssignerData, self.node_data)
|
||||
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = variable_pool.get(data.assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableAssignerNodeError('assigned variable not found')
|
||||
|
||||
match data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = variable_pool.get(data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError('input value not found')
|
||||
updated_variable = original_variable.model_copy(update={'value': income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = variable_pool.get(data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError('input value not found')
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
updated_variable = original_variable.model_copy(update={'value': updated_value})
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = get_zero_value(original_variable.value_type)
|
||||
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
|
||||
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
|
||||
|
||||
# Over write the variable.
|
||||
variable_pool.add(data.assigned_variable_selector, updated_variable)
|
||||
|
||||
# Update conversation variable.
|
||||
# TODO: Find a better way to use the database.
|
||||
conversation_id = variable_pool.get(['sys', 'conversation_id'])
|
||||
if not conversation_id:
|
||||
raise VariableAssignerNodeError('conversation_id not found')
|
||||
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
'value': income_value.to_object(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableAssignerNodeError('conversation variable not found in the database')
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_zero_value(t: SegmentType):
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
return factory.build_segment([])
|
||||
case SegmentType.OBJECT:
|
||||
return factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
return factory.build_segment('')
|
||||
case SegmentType.NUMBER:
|
||||
return factory.build_segment(0)
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
|
||||
@@ -4,12 +4,11 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
@@ -30,6 +29,7 @@ from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
@@ -51,7 +51,8 @@ node_classes: Mapping[NodeType, type[BaseNode]] = {
|
||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,
|
||||
NodeType.ITERATION: IterationNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -94,10 +95,9 @@ class WorkflowEngineManager:
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
user_inputs: Mapping[str, Any],
|
||||
system_inputs: Mapping[SystemVariable, Any],
|
||||
callbacks: Sequence[WorkflowCallback],
|
||||
call_depth: int = 0
|
||||
call_depth: int = 0,
|
||||
variable_pool: VariablePool,
|
||||
) -> None:
|
||||
"""
|
||||
:param workflow: Workflow instance
|
||||
@@ -122,12 +122,6 @@ class WorkflowEngineManager:
|
||||
if not isinstance(graph.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
|
||||
if call_depth > workflow_call_max_depth:
|
||||
@@ -403,6 +397,7 @@ class WorkflowEngineManager:
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=workflow.conversation_variables,
|
||||
)
|
||||
|
||||
if node_cls is None:
|
||||
@@ -468,6 +463,7 @@ class WorkflowEngineManager:
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=workflow.conversation_variables,
|
||||
)
|
||||
|
||||
# variable selector to variable mapping
|
||||
|
||||
21
api/fields/conversation_variable_fields.py
Normal file
21
api/fields/conversation_variable_fields.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
conversation_variable_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'value_type': fields.String(attribute='value_type.value'),
|
||||
'value': fields.String,
|
||||
'description': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'updated_at': TimestampField,
|
||||
}
|
||||
|
||||
paginated_conversation_variable_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer,
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_variable_fields), attribute='data'),
|
||||
}
|
||||
@@ -32,11 +32,12 @@ class EnvironmentVariableField(fields.Raw):
|
||||
return value
|
||||
|
||||
|
||||
environment_variable_fields = {
|
||||
conversation_variable_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'value': fields.Raw,
|
||||
'value_type': fields.String(attribute='value_type.value'),
|
||||
'value': fields.Raw,
|
||||
'description': fields.String,
|
||||
}
|
||||
|
||||
workflow_fields = {
|
||||
@@ -50,4 +51,5 @@ workflow_fields = {
|
||||
'updated_at': TimestampField,
|
||||
'tool_published': fields.Boolean,
|
||||
'environment_variables': fields.List(EnvironmentVariableField()),
|
||||
'conversation_variables': fields.List(fields.Nested(conversation_variable_fields)),
|
||||
}
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import CHAR, TypeDecorator
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from .model import AppMode
|
||||
from .types import StringUUID
|
||||
from .workflow import ConversationVariable, WorkflowNodeExecutionStatus
|
||||
|
||||
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus']
|
||||
|
||||
|
||||
class CreatedByRole(Enum):
|
||||
"""
|
||||
Enum class for createdByRole
|
||||
"""
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
ACCOUNT = 'account'
|
||||
END_USER = 'end_user'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'CreatedByRole':
|
||||
@@ -23,49 +27,3 @@ class CreatedByRole(Enum):
|
||||
if role.value == value:
|
||||
return role
|
||||
raise ValueError(f'invalid createdByRole value {value}')
|
||||
|
||||
|
||||
class CreatedFrom(Enum):
|
||||
"""
|
||||
Enum class for createdFrom
|
||||
"""
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
EXPLORE = "explore"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'CreatedFrom':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for role in cls:
|
||||
if role.value == value:
|
||||
return role
|
||||
raise ValueError(f'invalid createdFrom value {value}')
|
||||
|
||||
|
||||
class StringUUID(TypeDecorator):
|
||||
impl = CHAR
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
elif dialect.name == 'postgresql':
|
||||
return str(value)
|
||||
else:
|
||||
return value.hex
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == 'postgresql':
|
||||
return dialect.type_descriptor(UUID())
|
||||
else:
|
||||
return dialect.type_descriptor(CHAR(36))
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
@@ -4,7 +4,8 @@ import json
|
||||
from flask_login import UserMixin
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class AccountStatus(str, enum.Enum):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import enum
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class APIBasedExtensionPoint(enum.Enum):
|
||||
|
||||
@@ -16,9 +16,10 @@ from configs import dify_config
|
||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models import StringUUID
|
||||
from models.account import Account
|
||||
from models.model import App, Tag, TagBinding, UploadFile
|
||||
|
||||
from .account import Account
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class Dataset(db.Model):
|
||||
|
||||
@@ -14,8 +14,8 @@ from core.file.upload_file_parser import UploadFileParser
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import generate_string
|
||||
|
||||
from . import StringUUID
|
||||
from .account import Account, Tenant
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DifySetup(db.Model):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
|
||||
@@ -3,7 +3,8 @@ import json
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DataSourceOauthBinding(db.Model):
|
||||
|
||||
@@ -2,7 +2,8 @@ import json
|
||||
from enum import Enum
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class ToolProviderName(Enum):
|
||||
|
||||
@@ -6,8 +6,9 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
from models.model import Account, App, Tenant
|
||||
|
||||
from .model import Account, App, Tenant
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class BuiltinToolProvider(db.Model):
|
||||
|
||||
26
api/models/types.py
Normal file
26
api/models/types.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import CHAR, TypeDecorator
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
class StringUUID(TypeDecorator):
|
||||
impl = CHAR
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
elif dialect.name == 'postgresql':
|
||||
return str(value)
|
||||
else:
|
||||
return value.hex
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == 'postgresql':
|
||||
return dialect.type_descriptor(UUID())
|
||||
else:
|
||||
return dialect.type_descriptor(CHAR(36))
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
return str(value)
|
||||
@@ -1,7 +1,8 @@
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
from models.model import Message
|
||||
|
||||
from .model import Message
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class SavedMessage(db.Model):
|
||||
|
||||
@@ -3,18 +3,18 @@ from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
import contexts
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.app.segments import (
|
||||
SecretVariable,
|
||||
Variable,
|
||||
factory,
|
||||
)
|
||||
from core.app.segments import SecretVariable, Variable, factory
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models import StringUUID
|
||||
from models.account import Account
|
||||
|
||||
from .account import Account
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class CreatedByRole(Enum):
|
||||
@@ -122,6 +122,7 @@ class Workflow(db.Model):
|
||||
updated_by = db.Column(StringUUID)
|
||||
updated_at = db.Column(db.DateTime)
|
||||
_environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}')
|
||||
_conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}')
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
@@ -249,9 +250,27 @@ class Workflow(db.Model):
|
||||
'graph': self.graph_dict,
|
||||
'features': self.features_dict,
|
||||
'environment_variables': [var.model_dump(mode='json') for var in environment_variables],
|
||||
'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables],
|
||||
}
|
||||
return result
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[Variable]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = '{}'
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@conversation_variables.setter
|
||||
def conversation_variables(self, value: Sequence[Variable]) -> None:
|
||||
self._conversation_variables = json.dumps(
|
||||
{var.name: var.model_dump() for var in value},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(Enum):
|
||||
"""
|
||||
@@ -702,3 +721,34 @@ class WorkflowAppLog(db.Model):
|
||||
created_by_role = CreatedByRole.value_of(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) \
|
||||
if created_by_role == CreatedByRole.END_USER else None
|
||||
|
||||
|
||||
class ConversationVariable(db.Model):
|
||||
__tablename__ = 'workflow__conversation_variables'
|
||||
|
||||
id: Mapped[str] = db.Column(StringUUID, primary_key=True)
|
||||
conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)
|
||||
app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True)
|
||||
data = db.Column(db.Text, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp())
|
||||
|
||||
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
|
||||
self.id = id
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.data = data
|
||||
|
||||
@classmethod
|
||||
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable':
|
||||
obj = cls(
|
||||
id=variable.id,
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
data=variable.model_dump_json(),
|
||||
)
|
||||
return obj
|
||||
|
||||
def to_variable(self) -> Variable:
|
||||
mapping = json.loads(self.data)
|
||||
return factory.build_variable_from_mapping(mapping)
|
||||
|
||||
@@ -238,6 +238,8 @@ class AppDslService:
|
||||
# init draft workflow
|
||||
environment_variables_list = workflow_data.get('environment_variables') or []
|
||||
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
||||
conversation_variables_list = workflow_data.get('conversation_variables') or []
|
||||
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app,
|
||||
@@ -246,6 +248,7 @@ class AppDslService:
|
||||
unique_hash=None,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
workflow_service.publish_workflow(
|
||||
app_model=app,
|
||||
|
||||
@@ -6,7 +6,6 @@ from core.app.app_config.entities import (
|
||||
DatasetRetrieveConfigEntity,
|
||||
EasyUIBasedAppConfig,
|
||||
ExternalDataVariableEntity,
|
||||
FileExtraConfig,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
@@ -14,6 +13,7 @@ from core.app.app_config.entities import (
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from core.file.file_obj import FileExtraConfig
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
@@ -72,6 +72,7 @@ class WorkflowService:
|
||||
unique_hash: Optional[str],
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
@@ -99,7 +100,8 @@ class WorkflowService:
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account.id,
|
||||
environment_variables=environment_variables
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
# update draft workflow if found
|
||||
@@ -109,6 +111,7 @@ class WorkflowService:
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow.environment_variables = environment_variables
|
||||
workflow.conversation_variables = conversation_variables
|
||||
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
@@ -145,7 +148,8 @@ class WorkflowService:
|
||||
graph=draft_workflow.graph,
|
||||
features=draft_workflow.features,
|
||||
created_by=account.id,
|
||||
environment_variables=draft_workflow.environment_variables
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
conversation_variables=draft_workflow.conversation_variables,
|
||||
)
|
||||
|
||||
# commit db session changes
|
||||
@@ -336,8 +340,8 @@ class WorkflowService:
|
||||
)
|
||||
if not workflow_nodes:
|
||||
return elapsed_time
|
||||
|
||||
|
||||
for node in workflow_nodes:
|
||||
elapsed_time += node.elapsed_time
|
||||
|
||||
return elapsed_time
|
||||
return elapsed_time
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from extensions.ext_database import db
|
||||
@@ -28,7 +30,7 @@ from models.model import (
|
||||
)
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.web import PinnedConversation, SavedMessage
|
||||
from models.workflow import Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
|
||||
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
|
||||
|
||||
|
||||
@shared_task(queue='app_deletion', bind=True, max_retries=3)
|
||||
@@ -54,6 +56,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
|
||||
_delete_app_tag_bindings(tenant_id, app_id)
|
||||
_delete_end_users(tenant_id, app_id)
|
||||
_delete_trace_app_configs(tenant_id, app_id)
|
||||
_delete_conversation_variables(app_id=app_id)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style(f'App and related data deleted: {app_id} latency: {end_at - start_at}', fg='green'))
|
||||
@@ -225,6 +228,13 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
|
||||
"conversation"
|
||||
)
|
||||
|
||||
def _delete_conversation_variables(*, app_id: str):
|
||||
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
|
||||
with db.engine.connect() as conn:
|
||||
conn.execute(stmt)
|
||||
conn.commit()
|
||||
logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg='green'))
|
||||
|
||||
|
||||
def _delete_app_messages(tenant_id: str, app_id: str):
|
||||
def del_message(message_id: str):
|
||||
@@ -299,7 +309,7 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
|
||||
)
|
||||
|
||||
|
||||
def _delete_records(query_sql: str, params: dict, delete_func: callable, name: str) -> None:
|
||||
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
|
||||
while True:
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query_sql), params)
|
||||
|
||||
@@ -7,15 +7,16 @@ from core.app.segments import (
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FileSegment,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
factory,
|
||||
)
|
||||
from core.app.segments.exc import VariableError
|
||||
|
||||
|
||||
def test_string_variable():
|
||||
@@ -44,7 +45,7 @@ def test_secret_variable():
|
||||
|
||||
def test_invalid_value_type():
|
||||
test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'}
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(VariableError):
|
||||
factory.build_variable_from_mapping(test_data)
|
||||
|
||||
|
||||
@@ -77,26 +78,14 @@ def test_object_variable():
|
||||
'name': 'test_object',
|
||||
'description': 'Description of the variable.',
|
||||
'value': {
|
||||
'key1': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'text',
|
||||
'value': 'text',
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
'key2': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'number',
|
||||
'name': 'number',
|
||||
'value': 1,
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
'key1': 'text',
|
||||
'key2': 2,
|
||||
},
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ObjectSegment)
|
||||
assert isinstance(variable.value['key1'], StringVariable)
|
||||
assert isinstance(variable.value['key2'], IntegerVariable)
|
||||
assert isinstance(variable.value['key1'], str)
|
||||
assert isinstance(variable.value['key2'], int)
|
||||
|
||||
|
||||
def test_array_string_variable():
|
||||
@@ -106,26 +95,14 @@ def test_array_string_variable():
|
||||
'name': 'test_array',
|
||||
'description': 'Description of the variable.',
|
||||
'value': [
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'text',
|
||||
'value': 'text',
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'text',
|
||||
'value': 'text',
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
'text',
|
||||
'text',
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayStringVariable)
|
||||
assert isinstance(variable.value[0], StringVariable)
|
||||
assert isinstance(variable.value[1], StringVariable)
|
||||
assert isinstance(variable.value[0], str)
|
||||
assert isinstance(variable.value[1], str)
|
||||
|
||||
|
||||
def test_array_number_variable():
|
||||
@@ -135,26 +112,14 @@ def test_array_number_variable():
|
||||
'name': 'test_array',
|
||||
'description': 'Description of the variable.',
|
||||
'value': [
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'number',
|
||||
'name': 'number',
|
||||
'value': 1,
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'number',
|
||||
'name': 'number',
|
||||
'value': 2.0,
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
1,
|
||||
2.0,
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayNumberVariable)
|
||||
assert isinstance(variable.value[0], IntegerVariable)
|
||||
assert isinstance(variable.value[1], FloatVariable)
|
||||
assert isinstance(variable.value[0], int)
|
||||
assert isinstance(variable.value[1], float)
|
||||
|
||||
|
||||
def test_array_object_variable():
|
||||
@@ -165,59 +130,23 @@ def test_array_object_variable():
|
||||
'description': 'Description of the variable.',
|
||||
'value': [
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'object',
|
||||
'name': 'object',
|
||||
'description': 'Description of the variable.',
|
||||
'value': {
|
||||
'key1': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'text',
|
||||
'value': 'text',
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
'key2': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'number',
|
||||
'name': 'number',
|
||||
'value': 1,
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
},
|
||||
'key1': 'text',
|
||||
'key2': 1,
|
||||
},
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'object',
|
||||
'name': 'object',
|
||||
'description': 'Description of the variable.',
|
||||
'value': {
|
||||
'key1': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'text',
|
||||
'value': 'text',
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
'key2': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'number',
|
||||
'name': 'number',
|
||||
'value': 1,
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
},
|
||||
'key1': 'text',
|
||||
'key2': 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayObjectVariable)
|
||||
assert isinstance(variable.value[0], ObjectSegment)
|
||||
assert isinstance(variable.value[1], ObjectSegment)
|
||||
assert isinstance(variable.value[0].value['key1'], StringVariable)
|
||||
assert isinstance(variable.value[0].value['key2'], IntegerVariable)
|
||||
assert isinstance(variable.value[1].value['key1'], StringVariable)
|
||||
assert isinstance(variable.value[1].value['key2'], IntegerVariable)
|
||||
assert isinstance(variable.value[0], dict)
|
||||
assert isinstance(variable.value[1], dict)
|
||||
assert isinstance(variable.value[0]['key1'], str)
|
||||
assert isinstance(variable.value[0]['key2'], int)
|
||||
assert isinstance(variable.value[1]['key1'], str)
|
||||
assert isinstance(variable.value[1]['key2'], int)
|
||||
|
||||
|
||||
def test_file_variable():
|
||||
@@ -257,51 +186,53 @@ def test_array_file_variable():
|
||||
'value': [
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'name': 'file',
|
||||
'value_type': 'file',
|
||||
'value': {
|
||||
'id': str(uuid4()),
|
||||
'tenant_id': 'tenant_id',
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'url': 'url',
|
||||
'related_id': 'related_id',
|
||||
'extra_config': {
|
||||
'image_config': {
|
||||
'width': 100,
|
||||
'height': 100,
|
||||
},
|
||||
'tenant_id': 'tenant_id',
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'url': 'url',
|
||||
'related_id': 'related_id',
|
||||
'extra_config': {
|
||||
'image_config': {
|
||||
'width': 100,
|
||||
'height': 100,
|
||||
},
|
||||
'filename': 'filename',
|
||||
'extension': 'extension',
|
||||
'mime_type': 'mime_type',
|
||||
},
|
||||
'filename': 'filename',
|
||||
'extension': 'extension',
|
||||
'mime_type': 'mime_type',
|
||||
},
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'name': 'file',
|
||||
'value_type': 'file',
|
||||
'value': {
|
||||
'id': str(uuid4()),
|
||||
'tenant_id': 'tenant_id',
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'url': 'url',
|
||||
'related_id': 'related_id',
|
||||
'extra_config': {
|
||||
'image_config': {
|
||||
'width': 100,
|
||||
'height': 100,
|
||||
},
|
||||
'tenant_id': 'tenant_id',
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'url': 'url',
|
||||
'related_id': 'related_id',
|
||||
'extra_config': {
|
||||
'image_config': {
|
||||
'width': 100,
|
||||
'height': 100,
|
||||
},
|
||||
'filename': 'filename',
|
||||
'extension': 'extension',
|
||||
'mime_type': 'mime_type',
|
||||
},
|
||||
'filename': 'filename',
|
||||
'extension': 'extension',
|
||||
'mime_type': 'mime_type',
|
||||
},
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayFileVariable)
|
||||
assert isinstance(variable.value[0], FileVariable)
|
||||
assert isinstance(variable.value[1], FileVariable)
|
||||
assert isinstance(variable.value[0], FileSegment)
|
||||
assert isinstance(variable.value[1], FileSegment)
|
||||
|
||||
|
||||
def test_variable_cannot_large_than_5_kb():
|
||||
with pytest.raises(VariableError):
|
||||
factory.build_variable_from_mapping(
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'test_text',
|
||||
'value': 'a' * 1024 * 6,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -2,8 +2,8 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import FileExtraConfig, ModelConfigEntity
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.segments import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
||||
|
||||
DEFAULT_NODE_ID = 'node_id'
|
||||
|
||||
|
||||
def test_overwrite_string_variable():
|
||||
conversation_variable = StringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
value='the first value',
|
||||
)
|
||||
|
||||
input_variable = StringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_string_variable',
|
||||
value='the second value',
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
'assigned_variable_selector': ['conversation', conversation_variable.name],
|
||||
'write_mode': WriteMode.OVER_WRITE.value,
|
||||
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
variable_pool.add(
|
||||
[DEFAULT_NODE_ID, input_variable.name],
|
||||
input_variable,
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
||||
node.run(variable_pool)
|
||||
mock_run.assert_called_once()
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.value == 'the second value'
|
||||
assert got.to_object() == 'the second value'
|
||||
|
||||
|
||||
def test_append_variable_to_array():
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
value=['the first value'],
|
||||
)
|
||||
|
||||
input_variable = StringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_string_variable',
|
||||
value='the second value',
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
'assigned_variable_selector': ['conversation', conversation_variable.name],
|
||||
'write_mode': WriteMode.APPEND.value,
|
||||
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
variable_pool.add(
|
||||
[DEFAULT_NODE_ID, input_variable.name],
|
||||
input_variable,
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
||||
node.run(variable_pool)
|
||||
mock_run.assert_called_once()
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == ['the first value', 'the second value']
|
||||
|
||||
|
||||
def test_clear_array():
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
value=['the first value'],
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
'assigned_variable_selector': ['conversation', conversation_variable.name],
|
||||
'write_mode': WriteMode.CLEAR.value,
|
||||
'input_variable_selector': [],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
node.run(variable_pool)
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == []
|
||||
25
api/tests/unit_tests/models/test_conversation_variable.py
Normal file
25
api/tests/unit_tests/models/test_conversation_variable.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.segments import SegmentType, factory
|
||||
from models import ConversationVariable
|
||||
|
||||
|
||||
def test_from_variable_and_to_variable():
|
||||
variable = factory.build_variable_from_mapping(
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'name': 'name',
|
||||
'value_type': SegmentType.OBJECT,
|
||||
'value': {
|
||||
'key': {
|
||||
'key': 'value',
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
conversation_variable = ConversationVariable.from_variable(
|
||||
app_id='app_id', conversation_id='conversation_id', variable=variable
|
||||
)
|
||||
|
||||
assert conversation_variable.to_variable() == variable
|
||||
Reference in New Issue
Block a user