Compare commits

...

1 Commit

Author: -LAN-
SHA1: abbd8b3114
Message: feat(api/workflow): Backend support for conversation variables.
Signed-off-by: -LAN- <laipz8200@outlook.com>
Date: 2024-08-12 18:44:59 +08:00
43 changed files with 812 additions and 414 deletions

View File

@@ -12,19 +12,14 @@ from configs.packaging import PackagingInfo
class DifyConfig(
# Packaging info
PackagingInfo,
# Deployment configs
DeploymentConfig,
# Feature configs
FeatureConfig,
# Middleware configs
MiddlewareConfig,
# Extra service configs
ExtraServiceConfig,
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
@@ -36,7 +31,6 @@ class DifyConfig(
env_file='.env',
env_file_encoding='utf-8',
frozen=True,
# ignore extra attributes
extra='ignore',
)
@@ -67,3 +61,5 @@ class DifyConfig(
SSRF_PROXY_HTTPS_URL: str | None = None
MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.')
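The new MAX_VARIABLE_SIZE setting is the limit the variable factory enforces later in this commit. A minimal sketch of reading it through the dify_config singleton used elsewhere in this diff:

from configs import dify_config

# 5 KB by default; build_variable_from_mapping rejects anything larger.
print(dify_config.MAX_VARIABLE_SIZE)  # 5120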

View File

@@ -17,6 +17,7 @@ from .app import (
audio,
completion,
conversation,
conversation_variables,
generator,
message,
model_config,

View File

@@ -0,0 +1,61 @@
from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.conversation_variable_fields import paginated_conversation_variable_fields
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
class ConversationVariablesApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_fields)
def get(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument('conversation_id', type=str, location='args')
args = parser.parse_args()
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
if args['conversation_id']:
stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id'])
else:
raise ValueError('conversation_id is required')
# NOTE: This is a temporary solution to avoid performance issues.
page = 1
page_size = 100
stmt = stmt.limit(page_size).offset((page - 1) * page_size)
with Session(db.engine) as session:
rows = session.scalars(stmt).all()
return {
'page': page,
'limit': page_size,
'total': len(rows),
'has_more': False,
'data': [
{
'created_at': row.created_at,
'updated_at': row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
],
}
api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables')
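For context, a console client could exercise the new route roughly like this. This is a sketch only: the host, the console API prefix, and the auth header are deployment-specific placeholders, not part of this diff.

import requests

resp = requests.get(
    'https://dify.example.com/console/api/apps/<app_id>/conversation-variables',
    params={'conversation_id': '<conversation_id>'},
    headers={'Authorization': 'Bearer <console_token>'},  # placeholder auth
)
print(resp.json()['data'])  # conversation variables with created_at / updated_at

Note that the handler above raises when conversation_id is omitted, so the query parameter is effectively required.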

View File

@@ -74,6 +74,7 @@ class DraftWorkflowApi(Resource):
parser.add_argument('hash', type=str, required=False, location='json')
# TODO: set this to required=True after frontend is updated
parser.add_argument('environment_variables', type=list, required=False, location='json')
parser.add_argument('conversation_variables', type=list, required=False, location='json')
args = parser.parse_args()
elif 'text/plain' in content_type:
try:
@@ -88,7 +89,8 @@ class DraftWorkflowApi(Resource):
'graph': data.get('graph'),
'features': data.get('features'),
'hash': data.get('hash'),
'environment_variables': data.get('environment_variables')
'environment_variables': data.get('environment_variables'),
'conversation_variables': data.get('conversation_variables'),
}
except json.JSONDecodeError:
return {'message': 'Invalid JSON data'}, 400
@@ -100,6 +102,8 @@ class DraftWorkflowApi(Resource):
try:
environment_variables_list = args.get('environment_variables') or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
conversation_variables_list = args.get('conversation_variables') or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args['graph'],
@@ -107,6 +111,7 @@ class DraftWorkflowApi(Resource):
unique_hash=args.get('hash'),
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
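The draft-sync body now carries both variable lists; each mapping in conversation_variables goes through factory.build_variable_from_mapping before being stored. An illustrative payload shape (field values are examples only, not part of this diff):

payload = {
    'graph': {'nodes': [], 'edges': []},
    'features': {},
    'hash': None,
    'environment_variables': [],
    'conversation_variables': [
        {
            'id': 'a-uuid-string',
            'name': 'user_language',
            'value_type': 'string',
            'value': 'en-US',
            'description': 'Preferred reply language.',
        },
    ],
}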

View File

@@ -3,8 +3,9 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.file.file_obj import FileExtraConfig
from core.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode
from models import AppMode
class ModelConfigEntity(BaseModel):
@@ -200,11 +201,6 @@ class TracingConfigEntity(BaseModel):
tracing_provider: str
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class AppAdditionalFeatures(BaseModel):

View File

@@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.app.app_config.entities import FileExtraConfig
from core.file.file_obj import FileExtraConfig
class FileUploadConfigManager:

View File

@@ -113,7 +113,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=invoke_from,
@@ -180,7 +179,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
@@ -189,12 +187,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream
)
def _generate(self, app_model: App,
def _generate(self, *,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation = None,
conversation: Conversation | None = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
is_first_conversation = False

View File

@@ -4,6 +4,9 @@ import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -17,11 +20,12 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueSto
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
from models.workflow import ConversationVariable, Workflow
logger = logging.getLogger(__name__)
@@ -31,10 +35,13 @@ class AdvancedChatAppRunner(AppRunner):
AdvancedChat Application Runner
"""
def run(self, application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> None:
def run(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@@ -48,11 +55,11 @@ class AdvancedChatAppRunner(AppRunner):
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
query = application_generate_entity.query
@@ -68,35 +75,66 @@ class AdvancedChatAppRunner(AppRunner):
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id,
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
):
return
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.commit()
# Convert database entities to variables
conversation_variables = [item.to_variable() for item in conversation_variables]
# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
@@ -106,43 +144,30 @@ class AdvancedChatAppRunner(AppRunner):
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth
call_depth=application_generate_entity.call_depth,
variable_pool=variable_pool,
)
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
@@ -150,22 +175,25 @@ class AdvancedChatAppRunner(AppRunner):
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def handle_input_moderation(
self, queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
self,
queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> bool:
"""
Handle input moderation
@@ -192,17 +220,20 @@ class AdvancedChatAppRunner(AppRunner):
queue_manager=queue_manager,
text=str(e),
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
)
return True
return False
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
def handle_annotation_reply(
self,
app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity,
) -> bool:
"""
Handle annotation reply
:param app_record: app record
@@ -217,29 +248,27 @@ class AdvancedChatAppRunner(AppRunner):
message=message,
query=query,
user_id=app_generate_entity.user_id,
invoke_from=app_generate_entity.invoke_from
invoke_from=app_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
)
self._stream_output(
queue_manager=queue_manager,
text=annotation_reply.content,
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
)
return True
return False
def _stream_output(self, queue_manager: AppQueueManager,
text: str,
stream: bool,
stopped_by: QueueStopEvent.StopBy) -> None:
def _stream_output(
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
) -> None:
"""
Direct output
:param queue_manager: application queue manager
@@ -250,21 +279,10 @@ class AdvancedChatAppRunner(AppRunner):
if stream:
index = 0
for token in text:
queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueStopEvent(stopped_by=stopped_by),
PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)

View File

@@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
@@ -26,8 +27,7 @@ class WorkflowAppRunner:
Workflow Application Runner
"""
def run(self, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager) -> None:
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@@ -47,25 +47,36 @@ class WorkflowAppRunner:
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
files = application_generate_entity.files
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
@@ -75,44 +86,33 @@ class WorkflowAppRunner:
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth
call_depth=application_generate_entity.call_depth,
variable_pool=variable_pool,
)
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
if not app_record.workflow_id:
raise ValueError("Workflow not initialized")
raise ValueError('Workflow not initialized')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
@@ -120,11 +120,13 @@ class WorkflowAppRunner:
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow

View File

@@ -1,6 +1,7 @@
from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
@@ -50,4 +51,5 @@ __all__ = [
'ArrayNumberVariable',
'ArrayObjectVariable',
'ArrayFileVariable',
'ArraySegment',
]

View File

@@ -0,0 +1,2 @@
class VariableError(Exception):
pass

View File

@@ -1,8 +1,10 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from core.file.file_obj import FileVar
from .exc import VariableError
from .segments import (
ArrayAnySegment,
FileSegment,
@@ -29,39 +31,43 @@ from .variables import (
)
def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
if (value_type := m.get('value_type')) is None:
raise ValueError('missing value type')
if not m.get('name'):
raise ValueError('missing name')
if (value := m.get('value')) is None:
raise ValueError('missing value')
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if (value_type := mapping.get('value_type')) is None:
raise VariableError('missing value type')
if not mapping.get('name'):
raise VariableError('missing name')
if (value := mapping.get('value')) is None:
raise VariableError('missing value')
match value_type:
case SegmentType.STRING:
return StringVariable.model_validate(m)
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
return SecretVariable.model_validate(m)
result = SecretVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, int):
return IntegerVariable.model_validate(m)
result = IntegerVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, float):
return FloatVariable.model_validate(m)
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise ValueError(f'invalid number value {value}')
raise VariableError(f'invalid number value {value}')
case SegmentType.FILE:
return FileVariable.model_validate(m)
result = FileVariable.model_validate(mapping)
case SegmentType.OBJECT if isinstance(value, dict):
return ObjectVariable.model_validate(
{**m, 'value': {k: build_variable_from_mapping(v) for k, v in value.items()}}
)
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
return ArrayStringVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
result = ArrayStringVariable.model_validate(mapping)
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
return ArrayNumberVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
return ArrayObjectVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
result = ArrayObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_FILE if isinstance(value, list):
return ArrayFileVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
raise ValueError(f'not supported value type {value_type}')
mapping = dict(mapping)
mapping['value'] = [{'value': v} for v in value]
result = ArrayFileVariable.model_validate(mapping)
case _:
raise VariableError(f'not supported value type {value_type}')
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}')
return result
def build_segment(value: Any, /) -> Segment:
@@ -74,12 +80,9 @@ def build_segment(value: Any, /) -> Segment:
if isinstance(value, float):
return FloatSegment(value=value)
if isinstance(value, dict):
# TODO: Limit the depth of the object
return ObjectSegment(value=value)
if isinstance(value, list):
# TODO: Limit the depth of the array
elements = [build_segment(v) for v in value]
return ArrayAnySegment(value=elements)
return ArrayAnySegment(value=value)
if isinstance(value, FileVar):
return FileSegment(value=value)
raise ValueError(f'not supported value {value}')
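Putting the reworked factory together: mappings are validated into Variable objects, and the new size guard surfaces as VariableError instead of silently accepting oversized values. A small, self-contained sketch (field values are illustrative):

from uuid import uuid4

from core.app.segments import factory
from core.app.segments.exc import VariableError

variable = factory.build_variable_from_mapping({
    'id': str(uuid4()),
    'value_type': 'string',
    'name': 'greeting',
    'value': 'hello',
})
print(variable.value)  # 'hello'

try:
    factory.build_variable_from_mapping({
        'id': str(uuid4()),
        'value_type': 'string',
        'name': 'too_big',
        'value': 'a' * 6 * 1024,  # exceeds MAX_VARIABLE_SIZE (5 KB by default)
    })
except VariableError as exc:
    print(exc)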

View File

@@ -1,4 +1,5 @@
import json
import sys
from collections.abc import Mapping, Sequence
from typing import Any
@@ -37,6 +38,10 @@ class Segment(BaseModel):
def markdown(self) -> str:
return str(self.value)
@property
def size(self) -> int:
return sys.getsizeof(self.value)
def to_object(self) -> Any:
return self.value
@@ -105,28 +110,25 @@ class ArraySegment(Segment):
def markdown(self) -> str:
return '\n'.join(['- ' + item.markdown for item in self.value])
def to_object(self):
return [v.to_object() for v in self.value]
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Segment]
value: Sequence[Any]
class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[StringSegment]
value: Sequence[str]
class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER
value: Sequence[FloatSegment | IntegerSegment]
value: Sequence[float | int]
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[ObjectSegment]
value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment):

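With the relaxed Array*Segment value types above, array segments now hold plain Python values instead of nested segments, and every segment reports an approximate size. A quick sketch using the build_segment helper from the factory module shown earlier:

from core.app.segments import factory

seg = factory.build_segment(['a', 1, {'k': 'v'}])
print(seg.value)  # ['a', 1, {'k': 'v'}] - raw values, no wrapping segments
print(seg.size)   # shallow byte estimate via sys.getsizeof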
View File

@@ -1,14 +1,19 @@
import enum
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel
from core.app.app_config.entities import FileExtraConfig
from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from extensions.ext_database import db
from models.model import UploadFile
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class FileType(enum.Enum):
@@ -114,6 +119,7 @@ class FileVar(BaseModel):
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
from models.model import UploadFile
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url

View File

@@ -5,8 +5,7 @@ from urllib.parse import parse_qs, urlparse
import requests
from core.app.app_config.entities import FileExtraConfig
from core.file.file_obj import FileBelongsTo, FileTransferMethod, FileType, FileVar
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, MessageFile, UploadFile

View File

@@ -2,7 +2,6 @@ import base64
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
def obfuscated_token(token: str):
@@ -14,6 +13,7 @@ def obfuscated_token(token: str):
def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
raise ValueError(f'Tenant with id {tenant_id} not found')
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)

View File

@@ -23,10 +23,12 @@ class NodeType(Enum):
HTTP_REQUEST = 'http-request'
TOOL = 'tool'
VARIABLE_AGGREGATOR = 'variable-aggregator'
# TODO: merge this into VARIABLE_AGGREGATOR
VARIABLE_ASSIGNER = 'variable-assigner'
LOOP = 'loop'
ITERATION = 'iteration'
PARAMETER_EXTRACTOR = 'parameter-extractor'
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
@classmethod
def value_of(cls, value: str) -> 'NodeType':

View File

@@ -13,6 +13,7 @@ VariableValue = Union[str, int, float, dict, list, FileVar]
SYSTEM_VARIABLE_NODE_ID = 'sys'
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
class VariablePool:
@@ -21,6 +22,7 @@ class VariablePool:
system_variables: Mapping[SystemVariable, Any],
user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable] | None = None,
) -> None:
# system variables
# for example:
@@ -44,9 +46,13 @@ class VariablePool:
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Add environment variables to the variable pool
for var in environment_variables or []:
for var in environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
for var in conversation_variables or []:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.

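Conversation variables are addressed in the pool under the new 'conversation' node id, symmetrical to 'env' for environment variables. A self-contained sketch (the variable name user_language is illustrative):

from uuid import uuid4

from core.app.segments import factory
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool

conversation_variable = factory.build_variable_from_mapping({
    'id': str(uuid4()),
    'name': 'user_language',
    'value_type': 'string',
    'value': 'en-US',
})
pool = VariablePool(
    system_variables={SystemVariable.CONVERSATION_ID: 'conversation-id'},
    user_inputs={},
    environment_variables=[],
    conversation_variables=[conversation_variable],
)
print(pool.get(['conversation', 'user_language']).value)  # 'en-US'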
View File

@@ -8,6 +8,7 @@ from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from models import WorkflowNodeExecutionStatus
class UserFrom(Enum):
@@ -91,14 +92,19 @@ class BaseNode(ABC):
:param variable_pool: variable pool
:return:
"""
result = self._run(
variable_pool=variable_pool
)
try:
result = self._run(
variable_pool=variable_pool
)
self.node_run_result = result
return result
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
self.node_run_result = result
return result
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None:
"""
Publish text chunk
:param text: chunk text

View File

@@ -0,0 +1,109 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.segments import SegmentType, Variable, factory
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
class VariableAssignerNodeError(Exception):
pass
class WriteMode(str, Enum):
OVER_WRITE = 'over-write'
APPEND = 'append'
CLEAR = 'clear'
class VariableAssignerData(BaseNodeData):
title: str = 'Variable Assigner'
desc: Optional[str] = 'Assign a value to a variable'
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]
class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={'value': updated_value})
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
case _:
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Over write the variable.
variable_pool.add(data.assigned_variable_selector, updated_variable)
# Update conversation variable.
# TODO: Find a better way to use the database.
conversation_id = variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
'value': income_value.to_object(),
},
)
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError('conversation variable not found in the database')
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return factory.build_segment([])
case SegmentType.OBJECT:
return factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment('')
case SegmentType.NUMBER:
return factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
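The three write modes map onto the branches above: OVER_WRITE replaces the value, APPEND extends an array, and CLEAR resets the variable to the zero value of its type via get_zero_value. A short sketch of that helper (module path as imported by the engine manager and the unit tests later in this diff):

from core.app.segments import SegmentType
from core.workflow.nodes.variable_assigner import WriteMode, get_zero_value

print(list(WriteMode))                                       # over-write, append, clear
print(get_zero_value(SegmentType.ARRAY_STRING).to_object())  # []
print(get_zero_value(SegmentType.STRING).to_object())        # ''
print(get_zero_value(SegmentType.NUMBER).to_object())        # 0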

View File

@@ -4,12 +4,11 @@ from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from configs import dify_config
from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -30,6 +29,7 @@ from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
from extensions.ext_database import db
from models.workflow import (
Workflow,
@@ -51,7 +51,8 @@ node_classes: Mapping[NodeType, type[BaseNode]] = {
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
}
logger = logging.getLogger(__name__)
@@ -94,10 +95,9 @@ class WorkflowEngineManager:
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
user_inputs: Mapping[str, Any],
system_inputs: Mapping[SystemVariable, Any],
callbacks: Sequence[WorkflowCallback],
call_depth: int = 0
call_depth: int = 0,
variable_pool: VariablePool,
) -> None:
"""
:param workflow: Workflow instance
@@ -122,12 +122,6 @@ class WorkflowEngineManager:
if not isinstance(graph.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
)
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
if call_depth > workflow_call_max_depth:
@@ -403,6 +397,7 @@ class WorkflowEngineManager:
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
conversation_variables=workflow.conversation_variables,
)
if node_cls is None:
@@ -468,6 +463,7 @@ class WorkflowEngineManager:
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
conversation_variables=workflow.conversation_variables,
)
# variable selector to variable mapping

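run_workflow no longer builds the variable pool internally; callers construct it and pass it in. A condensed sketch of the new calling convention, wrapped in a function so the surrounding objects are explicit parameters rather than hidden assumptions; parameter names are taken from the hunks and call sites in this diff.

from collections.abc import Mapping, Sequence
from typing import Any

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import Variable
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from models.workflow import Workflow


def run_with_pool(
    *,
    workflow: Workflow,
    user_id: str,
    user_from: UserFrom,
    invoke_from: InvokeFrom,
    callbacks: Sequence[WorkflowCallback],
    user_inputs: Mapping[str, Any],
    system_inputs: Mapping[SystemVariable, Any],
    conversation_variables: Sequence[Variable],  # [] for plain workflow apps
    call_depth: int = 0,
) -> None:
    # Build the pool up front, as the runners above now do...
    variable_pool = VariablePool(
        system_variables=system_inputs,
        user_inputs=user_inputs,
        environment_variables=workflow.environment_variables,
        conversation_variables=conversation_variables,
    )
    # ...and hand it to the engine instead of raw inputs.
    WorkflowEngineManager().run_workflow(
        workflow=workflow,
        user_id=user_id,
        user_from=user_from,
        invoke_from=invoke_from,
        callbacks=callbacks,
        call_depth=call_depth,
        variable_pool=variable_pool,
    )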
View File

@@ -0,0 +1,21 @@
from flask_restful import fields
from libs.helper import TimestampField
conversation_variable_fields = {
'id': fields.String,
'name': fields.String,
'value_type': fields.String(attribute='value_type.value'),
'value': fields.String,
'description': fields.String,
'created_at': TimestampField,
'updated_at': TimestampField,
}
paginated_conversation_variable_fields = {
'page': fields.Integer,
'limit': fields.Integer,
'total': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_variable_fields), attribute='data'),
}

View File

@@ -32,11 +32,12 @@ class EnvironmentVariableField(fields.Raw):
return value
environment_variable_fields = {
conversation_variable_fields = {
'id': fields.String,
'name': fields.String,
'value': fields.Raw,
'value_type': fields.String(attribute='value_type.value'),
'value': fields.Raw,
'description': fields.String,
}
workflow_fields = {
@@ -50,4 +51,5 @@ workflow_fields = {
'updated_at': TimestampField,
'tool_published': fields.Boolean,
'environment_variables': fields.List(EnvironmentVariableField()),
'conversation_variables': fields.List(fields.Nested(conversation_variable_fields)),
}

View File

@@ -1,15 +1,19 @@
from enum import Enum
from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
from .model import AppMode
from .types import StringUUID
from .workflow import ConversationVariable, WorkflowNodeExecutionStatus
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus']
class CreatedByRole(Enum):
"""
Enum class for createdByRole
"""
ACCOUNT = "account"
END_USER = "end_user"
ACCOUNT = 'account'
END_USER = 'end_user'
@classmethod
def value_of(cls, value: str) -> 'CreatedByRole':
@@ -23,49 +27,3 @@ class CreatedByRole(Enum):
if role.value == value:
return role
raise ValueError(f'invalid createdByRole value {value}')
class CreatedFrom(Enum):
"""
Enum class for createdFrom
"""
SERVICE_API = "service-api"
WEB_APP = "web-app"
EXPLORE = "explore"
@classmethod
def value_of(cls, value: str) -> 'CreatedFrom':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for role in cls:
if role.value == value:
return role
raise ValueError(f'invalid createdFrom value {value}')
class StringUUID(TypeDecorator):
impl = CHAR
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == 'postgresql':
return str(value)
else:
return value.hex
def load_dialect_impl(self, dialect):
if dialect.name == 'postgresql':
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect):
if value is None:
return value
return str(value)

View File

@@ -4,7 +4,8 @@ import json
from flask_login import UserMixin
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class AccountStatus(str, enum.Enum):

View File

@@ -1,7 +1,8 @@
import enum
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class APIBasedExtensionPoint(enum.Enum):

View File

@@ -16,9 +16,10 @@ from configs import dify_config
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import StringUUID
from models.account import Account
from models.model import App, Tag, TagBinding, UploadFile
from .account import Account
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
class Dataset(db.Model):

View File

@@ -14,8 +14,8 @@ from core.file.upload_file_parser import UploadFileParser
from extensions.ext_database import db
from libs.helper import generate_string
from . import StringUUID
from .account import Account, Tenant
from .types import StringUUID
class DifySetup(db.Model):

View File

@@ -1,7 +1,8 @@
from enum import Enum
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class ProviderType(Enum):

View File

@@ -3,7 +3,8 @@ import json
from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class DataSourceOauthBinding(db.Model):

View File

@@ -2,7 +2,8 @@ import json
from enum import Enum
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class ToolProviderName(Enum):

View File

@@ -6,8 +6,9 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from extensions.ext_database import db
from models import StringUUID
from models.model import Account, App, Tenant
from .model import Account, App, Tenant
from .types import StringUUID
class BuiltinToolProvider(db.Model):

api/models/types.py (new file, 26 lines)

View File

@@ -0,0 +1,26 @@
from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
class StringUUID(TypeDecorator):
impl = CHAR
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == 'postgresql':
return str(value)
else:
return value.hex
def load_dialect_impl(self, dialect):
if dialect.name == 'postgresql':
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect):
if value is None:
return value
return str(value)
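StringUUID now lives in models/types.py and is re-exported from the models package, which is why the model modules above switch to `from .types import StringUUID`. Typical column usage, as in the ConversationVariable table added later in this diff (the model below is purely illustrative, not part of the commit):

from extensions.ext_database import db
from models.types import StringUUID


class ExampleRecord(db.Model):
    __tablename__ = 'example_records'  # hypothetical table, for illustration only

    id = db.Column(StringUUID, primary_key=True)
    owner_id = db.Column(StringUUID, nullable=False, index=True)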

View File

@@ -1,7 +1,8 @@
from extensions.ext_database import db
from models import StringUUID
from models.model import Message
from .model import Message
from .types import StringUUID
class SavedMessage(db.Model):

View File

@@ -3,18 +3,18 @@ from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional, Union
from sqlalchemy import func
from sqlalchemy.orm import Mapped
import contexts
from constants import HIDDEN_VALUE
from core.app.segments import (
SecretVariable,
Variable,
factory,
)
from core.app.segments import SecretVariable, Variable, factory
from core.helper import encrypter
from extensions.ext_database import db
from libs import helper
from models import StringUUID
from models.account import Account
from .account import Account
from .types import StringUUID
class CreatedByRole(Enum):
@@ -122,6 +122,7 @@ class Workflow(db.Model):
updated_by = db.Column(StringUUID)
updated_at = db.Column(db.DateTime)
_environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}')
_conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}')
@property
def created_by_account(self):
@@ -249,9 +250,27 @@ class Workflow(db.Model):
'graph': self.graph_dict,
'features': self.features_dict,
'environment_variables': [var.model_dump(mode='json') for var in environment_variables],
'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables],
}
return result
@property
def conversation_variables(self) -> Sequence[Variable]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = '{}'
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()]
return results
@conversation_variables.setter
def conversation_variables(self, value: Sequence[Variable]) -> None:
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
)
class WorkflowRunTriggeredFrom(Enum):
"""
@@ -702,3 +721,34 @@ class WorkflowAppLog(db.Model):
created_by_role = CreatedByRole.value_of(self.created_by_role)
return db.session.get(EndUser, self.created_by) \
if created_by_role == CreatedByRole.END_USER else None
class ConversationVariable(db.Model):
__tablename__ = 'workflow__conversation_variables'
id: Mapped[str] = db.Column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)
app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True)
data = db.Column(db.Text, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp())
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
self.id = id
self.app_id = app_id
self.conversation_id = conversation_id
self.data = data
@classmethod
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable':
obj = cls(
id=variable.id,
app_id=app_id,
conversation_id=conversation_id,
data=variable.model_dump_json(),
)
return obj
def to_variable(self) -> Variable:
mapping = json.loads(self.data)
return factory.build_variable_from_mapping(mapping)
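The conversation_variables property pair on Workflow stores variables in the new `conversation_variables` text column as a JSON object keyed by variable name and rebuilds them through the segments factory on read. A self-contained sketch of that round trip, mirroring the setter and getter above (the variable itself is illustrative):

import json
from uuid import uuid4

from core.app.segments import factory

variable = factory.build_variable_from_mapping({
    'id': str(uuid4()),
    'name': 'user_language',
    'value_type': 'string',
    'value': 'en-US',
})

# What the setter writes into the column:
stored = json.dumps({variable.name: variable.model_dump()}, ensure_ascii=False)

# What the getter does when reading it back:
restored = [factory.build_variable_from_mapping(v) for v in json.loads(stored).values()]
assert restored[0] == variable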

View File

@@ -238,6 +238,8 @@ class AppDslService:
# init draft workflow
environment_variables_list = workflow_data.get('environment_variables') or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
conversation_variables_list = workflow_data.get('conversation_variables') or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
workflow_service = WorkflowService()
draft_workflow = workflow_service.sync_draft_workflow(
app_model=app,
@@ -246,6 +248,7 @@ class AppDslService:
unique_hash=None,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
workflow_service.publish_workflow(
app_model=app,

View File

@@ -6,7 +6,6 @@ from core.app.app_config.entities import (
DatasetRetrieveConfigEntity,
EasyUIBasedAppConfig,
ExternalDataVariableEntity,
FileExtraConfig,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
@@ -14,6 +13,7 @@ from core.app.app_config.entities import (
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.file.file_obj import FileExtraConfig
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder

View File

@@ -72,6 +72,7 @@ class WorkflowService:
unique_hash: Optional[str],
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
) -> Workflow:
"""
Sync draft workflow
@@ -99,7 +100,8 @@ class WorkflowService:
graph=json.dumps(graph),
features=json.dumps(features),
created_by=account.id,
environment_variables=environment_variables
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
db.session.add(workflow)
# update draft workflow if found
@@ -109,6 +111,7 @@ class WorkflowService:
workflow.updated_by = account.id
workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
# commit db session changes
db.session.commit()
@@ -145,7 +148,8 @@ class WorkflowService:
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
environment_variables=draft_workflow.environment_variables
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,
)
# commit db session changes
@@ -336,8 +340,8 @@ class WorkflowService:
)
if not workflow_nodes:
return elapsed_time
for node in workflow_nodes:
elapsed_time += node.elapsed_time
return elapsed_time
return elapsed_time

View File

@@ -1,8 +1,10 @@
import logging
import time
from collections.abc import Callable
import click
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
from extensions.ext_database import db
@@ -28,7 +30,7 @@ from models.model import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
@shared_task(queue='app_deletion', bind=True, max_retries=3)
@@ -54,6 +56,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_app_tag_bindings(tenant_id, app_id)
_delete_end_users(tenant_id, app_id)
_delete_trace_app_configs(tenant_id, app_id)
_delete_conversation_variables(app_id=app_id)
end_at = time.perf_counter()
logging.info(click.style(f'App and related data deleted: {app_id} latency: {end_at - start_at}', fg='green'))
@@ -225,6 +228,13 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
"conversation"
)
def _delete_conversation_variables(*, app_id: str):
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
with db.engine.connect() as conn:
conn.execute(stmt)
conn.commit()
logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg='green'))
def _delete_app_messages(tenant_id: str, app_id: str):
def del_message(message_id: str):
@@ -299,7 +309,7 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
)
def _delete_records(query_sql: str, params: dict, delete_func: callable, name: str) -> None:
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
with db.engine.begin() as conn:
rs = conn.execute(db.text(query_sql), params)

View File

@@ -7,15 +7,16 @@ from core.app.segments import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileSegment,
FileVariable,
FloatVariable,
IntegerVariable,
NoneSegment,
ObjectSegment,
SecretVariable,
StringVariable,
factory,
)
from core.app.segments.exc import VariableError
def test_string_variable():
@@ -44,7 +45,7 @@ def test_secret_variable():
def test_invalid_value_type():
test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'}
with pytest.raises(ValueError):
with pytest.raises(VariableError):
factory.build_variable_from_mapping(test_data)
@@ -77,26 +78,14 @@ def test_object_variable():
'name': 'test_object',
'description': 'Description of the variable.',
'value': {
'key1': {
'id': str(uuid4()),
'value_type': 'string',
'name': 'text',
'value': 'text',
'description': 'Description of the variable.',
},
'key2': {
'id': str(uuid4()),
'value_type': 'number',
'name': 'number',
'value': 1,
'description': 'Description of the variable.',
},
'key1': 'text',
'key2': 2,
},
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ObjectSegment)
assert isinstance(variable.value['key1'], StringVariable)
assert isinstance(variable.value['key2'], IntegerVariable)
assert isinstance(variable.value['key1'], str)
assert isinstance(variable.value['key2'], int)
def test_array_string_variable():
@@ -106,26 +95,14 @@ def test_array_string_variable():
'name': 'test_array',
'description': 'Description of the variable.',
'value': [
{
'id': str(uuid4()),
'value_type': 'string',
'name': 'text',
'value': 'text',
'description': 'Description of the variable.',
},
{
'id': str(uuid4()),
'value_type': 'string',
'name': 'text',
'value': 'text',
'description': 'Description of the variable.',
},
'text',
'text',
],
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayStringVariable)
assert isinstance(variable.value[0], StringVariable)
assert isinstance(variable.value[1], StringVariable)
assert isinstance(variable.value[0], str)
assert isinstance(variable.value[1], str)
def test_array_number_variable():
@@ -135,26 +112,14 @@ def test_array_number_variable():
'name': 'test_array',
'description': 'Description of the variable.',
'value': [
{
'id': str(uuid4()),
'value_type': 'number',
'name': 'number',
'value': 1,
'description': 'Description of the variable.',
},
{
'id': str(uuid4()),
'value_type': 'number',
'name': 'number',
'value': 2.0,
'description': 'Description of the variable.',
},
1,
2.0,
],
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayNumberVariable)
assert isinstance(variable.value[0], IntegerVariable)
assert isinstance(variable.value[1], FloatVariable)
assert isinstance(variable.value[0], int)
assert isinstance(variable.value[1], float)
def test_array_object_variable():
@@ -165,59 +130,23 @@ def test_array_object_variable():
'description': 'Description of the variable.',
'value': [
{
'id': str(uuid4()),
'value_type': 'object',
'name': 'object',
'description': 'Description of the variable.',
'value': {
'key1': {
'id': str(uuid4()),
'value_type': 'string',
'name': 'text',
'value': 'text',
'description': 'Description of the variable.',
},
'key2': {
'id': str(uuid4()),
'value_type': 'number',
'name': 'number',
'value': 1,
'description': 'Description of the variable.',
},
},
'key1': 'text',
'key2': 1,
},
{
'id': str(uuid4()),
'value_type': 'object',
'name': 'object',
'description': 'Description of the variable.',
'value': {
'key1': {
'id': str(uuid4()),
'value_type': 'string',
'name': 'text',
'value': 'text',
'description': 'Description of the variable.',
},
'key2': {
'id': str(uuid4()),
'value_type': 'number',
'name': 'number',
'value': 1,
'description': 'Description of the variable.',
},
},
'key1': 'text',
'key2': 1,
},
],
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayObjectVariable)
assert isinstance(variable.value[0], ObjectSegment)
assert isinstance(variable.value[1], ObjectSegment)
assert isinstance(variable.value[0].value['key1'], StringVariable)
assert isinstance(variable.value[0].value['key2'], IntegerVariable)
assert isinstance(variable.value[1].value['key1'], StringVariable)
assert isinstance(variable.value[1].value['key2'], IntegerVariable)
assert isinstance(variable.value[0], dict)
assert isinstance(variable.value[1], dict)
assert isinstance(variable.value[0]['key1'], str)
assert isinstance(variable.value[0]['key2'], int)
assert isinstance(variable.value[1]['key1'], str)
assert isinstance(variable.value[1]['key2'], int)
def test_file_variable():
@@ -257,51 +186,53 @@ def test_array_file_variable():
'value': [
{
'id': str(uuid4()),
'name': 'file',
'value_type': 'file',
'value': {
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
{
'id': str(uuid4()),
'name': 'file',
'value_type': 'file',
'value': {
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
],
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayFileVariable)
assert isinstance(variable.value[0], FileVariable)
assert isinstance(variable.value[1], FileVariable)
assert isinstance(variable.value[0], FileSegment)
assert isinstance(variable.value[1], FileSegment)
def test_variable_cannot_large_than_5_kb():
with pytest.raises(VariableError):
factory.build_variable_from_mapping(
{
'id': str(uuid4()),
'value_type': 'string',
'name': 'test_text',
'value': 'a' * 1024 * 6,
}
)

View File

@@ -2,8 +2,8 @@ from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import FileExtraConfig, ModelConfigEntity
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.app.app_config.entities import ModelConfigEntity
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform

View File

@@ -0,0 +1,150 @@
from unittest import mock
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import ArrayStringVariable, StringVariable
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
DEFAULT_NODE_ID = 'node_id'
def test_overwrite_string_variable():
conversation_variable = StringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value='the first value',
)
input_variable = StringVariable(
id=str(uuid4()),
name='test_string_variable',
value='the second value',
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.OVER_WRITE.value,
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
variable_pool.add(
[DEFAULT_NODE_ID, input_variable.name],
input_variable,
)
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
node.run(variable_pool)
mock_run.assert_called_once()
got = variable_pool.get(['conversation', conversation_variable.name])
assert got is not None
assert got.value == 'the second value'
assert got.to_object() == 'the second value'
def test_append_variable_to_array():
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value=['the first value'],
)
input_variable = StringVariable(
id=str(uuid4()),
name='test_string_variable',
value='the second value',
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.APPEND.value,
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
variable_pool.add(
[DEFAULT_NODE_ID, input_variable.name],
input_variable,
)
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
node.run(variable_pool)
mock_run.assert_called_once()
got = variable_pool.get(['conversation', conversation_variable.name])
assert got is not None
assert got.to_object() == ['the first value', 'the second value']
def test_clear_array():
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value=['the first value'],
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.CLEAR.value,
'input_variable_selector': [],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
node.run(variable_pool)
got = variable_pool.get(['conversation', conversation_variable.name])
assert got is not None
assert got.to_object() == []

View File

@@ -0,0 +1,25 @@
from uuid import uuid4
from core.app.segments import SegmentType, factory
from models import ConversationVariable
def test_from_variable_and_to_variable():
variable = factory.build_variable_from_mapping(
{
'id': str(uuid4()),
'name': 'name',
'value_type': SegmentType.OBJECT,
'value': {
'key': {
'key': 'value',
}
},
}
)
conversation_variable = ConversationVariable.from_variable(
app_id='app_id', conversation_id='conversation_id', variable=variable
)
assert conversation_variable.to_variable() == variable