Compare commits

...

15 Commits

Author SHA1 Message Date
-LAN-
1d577fd64b Merge remote-tracking branch 'origin/main' into feat/environment-variables-in-workflow 2024-07-06 03:51:06 +08:00
-LAN-
202c2ff06f Merge remote-tracking branch 'origin/main' into feat/environment-variables-in-workflow 2024-07-05 14:34:58 +08:00
-LAN-
fc0e7f8382 chore(gitignore): avoid upload .vscode in api dir. 2024-07-05 14:33:58 +08:00
-LAN-
dfd244f6ba fix: Avoid modify user inputs in start node 2024-07-04 16:30:20 +08:00
-LAN-
e04e965361 fix: Fix or remove wrong type hints 2024-07-04 15:19:09 +08:00
-LAN-
c95bfe2495 refactor: Improve some type hints 2024-07-04 14:03:32 +08:00
-LAN-
97c5a800d0 refactor(api/core/app/app_config/base_app_config_manager.py): Remove unused code. 2024-07-04 13:59:20 +08:00
-LAN-
bc5853749c refactor: Rename BaseWorkflow and BaseWorkflowCallback 2024-07-04 13:53:01 +08:00
-LAN-
e7dc09f815 fix(api/core/app/app_config/features/text_to_speech/manager.py): returns type in tts config manager 2024-07-04 13:42:15 +08:00
-LAN-
a118b23a4d Merge remote-tracking branch 'origin/main' into feat/environment-variables-in-workflow 2024-07-03 21:25:48 +08:00
-LAN-
13304db406 Merge remote-tracking branch 'origin/main' into feat/environment-variables-in-workflow 2024-07-03 16:48:14 +08:00
-LAN-
694d29294d feat(api/fields): Update value field type to Raw in workflow_fields.py 2024-06-25 17:13:07 +08:00
-LAN-
4f7959b5c4 feat(api/workflow): Encrypt environment variables in workflow model. 2024-06-25 17:05:08 +08:00
-LAN-
211ae43d29 feat(api/workflow): Add environment_variables to workflow creation and update 2024-06-25 15:34:46 +08:00
-LAN-
fae197ce1a feat(api/workflow): Support environment variables in workflow.
1. Add environment_variables in workflow model(not migrate yet).
2. Accept environment_variables in controllers and runner.
3. Include environment_variables when import and export.
2024-06-25 15:06:06 +08:00
28 changed files with 264 additions and 149 deletions

1
.gitignore vendored
View File

@@ -174,5 +174,6 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json
pyrightconfig.json
api/.vscode
.idea/

View File

@@ -19,6 +19,7 @@ from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.model import App, AppMode
from models.workflow import EnvironmentVariable
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService
@@ -39,7 +40,7 @@ class DraftWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model)
@@ -62,13 +63,15 @@ class DraftWorkflowApi(Resource):
if not current_user.is_editor:
raise Forbidden()
content_type = request.headers.get('Content-Type')
content_type = request.headers.get('Content-Type', '')
if 'application/json' in content_type:
parser = reqparse.RequestParser()
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
parser.add_argument('hash', type=str, required=False, location='json')
# TODO: set this to required=True after frontend is updated
parser.add_argument('environment_variables', type=list, required=False, location='json')
args = parser.parse_args()
elif 'text/plain' in content_type:
try:
@@ -82,7 +85,8 @@ class DraftWorkflowApi(Resource):
args = {
'graph': data.get('graph'),
'features': data.get('features'),
'hash': data.get('hash')
'hash': data.get('hash'),
'environment_variables': data.get('environment_variables')
}
except json.JSONDecodeError:
return {'message': 'Invalid JSON data'}, 400
@@ -92,12 +96,15 @@ class DraftWorkflowApi(Resource):
workflow_service = WorkflowService()
try:
environment_variables_list = args.get('environment_variables') or []
environment_variables = [EnvironmentVariable(**obj) for obj in environment_variables_list]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args.get('graph'),
features=args.get('features'),
graph=args['graph'],
features=args['features'],
unique_hash=args.get('hash'),
account=current_user
account=current_user,
environment_variables=environment_variables,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()

View File

@@ -1,6 +1,7 @@
from typing import Optional, Union
from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom
from core.app.app_config.entities import AppAdditionalFeatures
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
@@ -10,37 +11,19 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import AppMode, AppModelConfig
from models.model import AppMode
class BaseAppConfigManager:
@classmethod
def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom,
app_model_config: Union[AppModelConfig, dict],
config_dict: Optional[dict] = None) -> dict:
"""
Convert app model config to config dict
:param config_from: app model config from
:param app_model_config: app model config
:param config_dict: app model config dict
:return:
"""
if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
return config_dict
@classmethod
def convert_features(cls, config_dict: dict, app_mode: AppMode) -> AppAdditionalFeatures:
def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> AppAdditionalFeatures:
"""
Convert app config to app model config
:param config_dict: app config
:param app_mode: app mode
"""
config_dict = config_dict.copy()
config_dict = dict(config_dict.items())
additional_features = AppAdditionalFeatures()
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(

View File

@@ -1,11 +1,12 @@
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional
from core.app.app_config.entities import FileExtraConfig
class FileUploadConfigManager:
@classmethod
def convert(cls, config: dict, is_vision: bool = True) -> Optional[FileExtraConfig]:
def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
"""
Convert model config to model config

View File

@@ -3,13 +3,13 @@ from core.app.app_config.entities import TextToSpeechEntity
class TextToSpeechConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict):
"""
Convert model config to model config
:param config: model config args
"""
text_to_speech = False
text_to_speech = None
text_to_speech_dict = config.get('text_to_speech')
if text_to_speech_dict:
if text_to_speech_dict.get('enabled'):

View File

@@ -1,7 +1,8 @@
import logging
import os
import time
from typing import Optional, cast
from collections.abc import Mapping
from typing import Any, Optional, cast
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
@@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
@@ -87,7 +89,7 @@ class AdvancedChatAppRunner(AppRunner):
db.session.close()
workflow_callbacks = [WorkflowEventTriggerCallback(
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
@@ -161,7 +163,7 @@ class AdvancedChatAppRunner(AppRunner):
self, queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: dict,
inputs: Mapping[str, Any],
query: str,
message_id: str
) -> bool:

View File

@@ -1,9 +1,11 @@
import json
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppBlockingResponse,
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
@@ -18,12 +20,13 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
response = {
'event': 'message',
'task_id': blocking_response.task_id,
@@ -39,7 +42,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -53,8 +56,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -83,8 +85,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@@ -114,7 +114,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def process(self):
"""
Process generate task pipeline.
:return:
@@ -137,8 +137,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-> ChatbotAppBlockingResponse:
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
"""
Process blocking response.
:return:
@@ -168,8 +167,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
raise Exception('Queue listening stopped unexpectedly.')
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-> Generator[ChatbotAppStreamResponse, None, None]:
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
"""
To stream response.
:return:

View File

@@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager

View File

@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Union
from typing import Any, Union
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -15,44 +15,41 @@ class AppGenerateResponseConverter(ABC):
@classmethod
def convert(cls, response: Union[
AppBlockingResponse,
Generator[AppStreamResponse, None, None]
], invoke_from: InvokeFrom) -> Union[
dict,
Generator[str, None, None]
]:
Generator[AppStreamResponse, Any, None]
], invoke_from: InvokeFrom):
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if isinstance(response, cls._blocking_response_type):
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
def _generate():
def _generate_full_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_full_response(response):
if chunk == 'ping':
yield f'event: {chunk}\n\n'
else:
yield f'data: {chunk}\n\n'
return _generate()
return _generate_full_response()
else:
if isinstance(response, cls._blocking_response_type):
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:
def _generate():
def _generate_simple_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_simple_response(response):
if chunk == 'ping':
yield f'event: {chunk}\n\n'
else:
yield f'data: {chunk}\n\n'
return _generate()
return _generate_simple_response()
@classmethod
@abstractmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@@ -68,7 +65,7 @@ class AppGenerateResponseConverter(ABC):
raise NotImplementedError
@classmethod
def _get_simple_metadata(cls, metadata: dict) -> dict:
def _get_simple_metadata(cls, metadata: dict[str, Any]):
"""
Get simple metadata.
:param metadata: metadata

View File

@@ -38,7 +38,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
) -> Union[dict, Generator[dict, None, None]]:
):
"""
Generate App response.
@@ -150,8 +150,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
stream: bool = True):
"""
Generate App response.

View File

@@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
@@ -57,7 +58,7 @@ class WorkflowAppRunner:
db.session.close()
workflow_callbacks = [WorkflowEventTriggerCallback(
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]

View File

@@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager

View File

@@ -2,7 +2,7 @@ from typing import Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
@@ -15,7 +15,7 @@ _TEXT_COLOR_MAPPING = {
}
class WorkflowLoggingCallback(BaseWorkflowCallback):
class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
self.current_node_id = None

View File

@@ -1,3 +1,4 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
@@ -76,7 +77,7 @@ class AppGenerateEntity(BaseModel):
# app config
app_config: AppConfig
inputs: dict[str, Any]
inputs: Mapping[str, Any]
files: list[FileVar] = []
user_id: str
@@ -140,7 +141,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
app_config: WorkflowUIBasedAppConfig
conversation_id: Optional[str] = None
query: Optional[str] = None
query: str
class SingleIterationRunEntity(BaseModel):
"""

View File

@@ -1,4 +1,5 @@
from typing import Union
from collections.abc import Mapping, Sequence
from typing import Any, Union
import requests
@@ -16,7 +17,7 @@ class MessageFileParser:
self.tenant_id = tenant_id
self.app_id = app_id
def validate_and_transform_files_arg(self, files: list[dict], file_extra_config: FileExtraConfig,
def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig,
user: Union[Account, EndUser]) -> list[FileVar]:
"""
validate and transform files arg

View File

@@ -6,7 +6,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
class BaseWorkflowCallback(ABC):
class WorkflowCallback(ABC):
@abstractmethod
def on_workflow_run_started(self) -> None:
"""
@@ -78,7 +78,7 @@ class BaseWorkflowCallback(ABC):
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
inputs: Optional[dict] = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""

View File

@@ -1,3 +1,4 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
@@ -82,9 +83,9 @@ class NodeRunResult(BaseModel):
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict] = None # node inputs
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict] = None # process data
outputs: Optional[dict] = None # node outputs
outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches

View File

@@ -1,8 +1,10 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional, Union
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
from models.workflow import EnvironmentVariable
VariableValue = Union[str, int, float, dict, list, FileVar]
@@ -21,10 +23,17 @@ class ValueType(Enum):
FILE = "file"
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
class VariablePool:
def __init__(self, system_variables: dict[SystemVariable, Any],
user_inputs: dict) -> None:
def __init__(
self,
system_variables: Mapping[SystemVariable, Any],
user_inputs: Mapping[str, Any],
# TODO: remove Optional
environment_variables: Optional[Sequence[EnvironmentVariable]] = None,
) -> None:
# system variables
# for example:
# {
@@ -36,6 +45,9 @@ class VariablePool:
self.system_variables = system_variables
for system_variable, value in system_variables.items():
self.append_variable('sys', [system_variable.value], value)
self.environment_variables = environment_variables or []
for var in self.environment_variables:
self.append_variable(ENVIRONMENT_VARIABLE_NODE_ID, [var.name], var.value)
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
"""

View File

@@ -1,9 +1,10 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Optional
from typing import Any, Optional
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
@@ -46,7 +47,7 @@ class BaseNode(ABC):
node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None
callbacks: list[BaseWorkflowCallback]
callbacks: Sequence[WorkflowCallback]
def __init__(self, tenant_id: str,
app_id: str,
@@ -54,8 +55,8 @@ class BaseNode(ABC):
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
config: dict,
callbacks: list[BaseWorkflowCallback] = None,
config: Mapping[str, Any],
callbacks: Sequence[WorkflowCallback] | None = None,
workflow_call_depth: int = 0) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
@@ -65,7 +66,8 @@ class BaseNode(ABC):
self.invoke_from = invoke_from
self.workflow_call_depth = workflow_call_depth
self.node_id = config.get("id")
# TODO: May need to check if key exists.
self.node_id = config["id"]
if not self.node_id:
raise ValueError("Node ID is required.")

View File

@@ -18,7 +18,7 @@ class StartNode(BaseNode):
:return:
"""
# Get cleaned inputs
cleaned_inputs = variable_pool.user_inputs
cleaned_inputs = dict(variable_pool.user_inputs)
for var in variable_pool.system_variables:
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]

View File

@@ -1,6 +1,7 @@
import logging
import time
from typing import Optional, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from flask import current_app
@@ -8,8 +9,8 @@ from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -36,7 +37,7 @@ from models.workflow import (
WorkflowNodeExecutionStatus,
)
node_classes = {
node_classes: Mapping[NodeType, type[BaseNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
@@ -87,14 +88,14 @@ class WorkflowEngineManager:
return default_config
def run_workflow(self, workflow: Workflow,
def run_workflow(self, *, workflow: Workflow,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
user_inputs: dict,
system_inputs: Optional[dict] = None,
callbacks: list[BaseWorkflowCallback] = None,
call_depth: Optional[int] = 0,
user_inputs: Mapping[str, Any],
system_inputs: Mapping[SystemVariable, Any],
callbacks: Sequence[WorkflowCallback],
call_depth: int = 0,
variable_pool: Optional[VariablePool] = None) -> None:
"""
:param workflow: Workflow instance
@@ -123,7 +124,8 @@ class WorkflowEngineManager:
if not variable_pool:
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=user_inputs
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
)
workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH")
@@ -155,7 +157,7 @@ class WorkflowEngineManager:
def _run_workflow(self, workflow: Workflow,
workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback] = None,
callbacks: Sequence[WorkflowCallback],
start_at: Optional[str] = None,
end_at: Optional[str] = None) -> None:
"""
@@ -174,8 +176,8 @@ class WorkflowEngineManager:
graph = workflow.graph_dict
try:
predecessor_node: BaseNode = None
current_iteration_node: BaseIterationNode = None
predecessor_node: BaseNode | None = None
current_iteration_node: BaseIterationNode | None = None
has_entry_node = False
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
@@ -236,7 +238,7 @@ class WorkflowEngineManager:
# move to next iteration
next_node_id = next_iteration
# get next id
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks)
if not next_node:
break
@@ -296,7 +298,7 @@ class WorkflowEngineManager:
workflow_run_state.current_iteration_state = None
continue
else:
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks)
# run workflow, run multiple target nodes in the future
self._run_workflow_node(
@@ -420,7 +422,7 @@ class WorkflowEngineManager:
node_id: str,
user_id: str,
user_inputs: dict,
callbacks: list[BaseWorkflowCallback] = None,
callbacks: Sequence[WorkflowCallback],
) -> None:
"""
Single iteration run workflow node
@@ -447,7 +449,8 @@ class WorkflowEngineManager:
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={}
user_inputs={},
environment_variables=workflow.environment_variables,
)
# variable selector to variable mapping
@@ -536,7 +539,7 @@ class WorkflowEngineManager:
end_at=end_node_id
)
def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None:
"""
Workflow run success
:param callbacks: workflow callbacks
@@ -548,7 +551,7 @@ class WorkflowEngineManager:
callback.on_workflow_run_succeeded()
def _workflow_run_failed(self, error: str,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: Sequence[WorkflowCallback]) -> None:
"""
Workflow run failed
:param error: error message
@@ -561,11 +564,11 @@ class WorkflowEngineManager:
error=error
)
def _workflow_iteration_started(self, graph: dict,
def _workflow_iteration_started(self, *, graph: Mapping[str, Any],
current_iteration_node: BaseIterationNode,
workflow_run_state: WorkflowRunState,
predecessor_node_id: Optional[str] = None,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: Sequence[WorkflowCallback]) -> None:
"""
Workflow iteration started
:param current_iteration_node: current iteration node
@@ -598,10 +601,10 @@ class WorkflowEngineManager:
# add steps
workflow_run_state.workflow_node_steps += 1
def _workflow_iteration_next(self, graph: dict,
def _workflow_iteration_next(self, *, graph: Mapping[str, Any],
current_iteration_node: BaseIterationNode,
workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: Sequence[WorkflowCallback]) -> None:
"""
Workflow iteration next
:param workflow_run_state: workflow run state
@@ -630,9 +633,9 @@ class WorkflowEngineManager:
for node in nodes:
workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id'))
def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode,
def _workflow_iteration_completed(self, *, current_iteration_node: BaseIterationNode,
workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: Sequence[WorkflowCallback]) -> None:
if callbacks:
if isinstance(workflow_run_state.current_iteration_state, IterationState):
for callback in callbacks:
@@ -645,10 +648,10 @@ class WorkflowEngineManager:
}
)
def _get_next_overall_node(self, workflow_run_state: WorkflowRunState,
graph: dict,
def _get_next_overall_node(self, *, workflow_run_state: WorkflowRunState,
graph: Mapping[str, Any],
predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None,
callbacks: Sequence[WorkflowCallback],
start_at: Optional[str] = None,
end_at: Optional[str] = None) -> Optional[BaseNode]:
"""
@@ -740,9 +743,9 @@ class WorkflowEngineManager:
)
def _get_node(self, workflow_run_state: WorkflowRunState,
graph: dict,
graph: Mapping[str, Any],
node_id: str,
callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]:
callbacks: Sequence[WorkflowCallback]):
"""
Get node from graph by node id
"""
@@ -753,7 +756,7 @@ class WorkflowEngineManager:
for node_config in nodes:
if node_config.get('id') == node_id:
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
node_cls = node_classes[node_type]
return node_cls(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
@@ -766,8 +769,6 @@ class WorkflowEngineManager:
workflow_call_depth=workflow_run_state.workflow_call_depth
)
return None
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
Check timeout
@@ -786,10 +787,10 @@ class WorkflowEngineManager:
if node_and_result.node_id == node_id
])
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
def _run_workflow_node(self, *, workflow_run_state: WorkflowRunState,
node: BaseNode,
predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: Sequence[WorkflowCallback]) -> None:
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_started(
@@ -940,12 +941,14 @@ class WorkflowEngineManager:
return new_value
def _mapping_user_inputs_to_variable_pool(self,
variable_mapping: dict,
user_inputs: dict,
variable_pool: VariablePool,
tenant_id: str,
node_instance: BaseNode):
def _mapping_user_inputs_to_variable_pool(
self,
variable_mapping: dict,
user_inputs: dict,
variable_pool: VariablePool,
tenant_id: str,
node_instance: BaseNode
):
for variable_key, variable_selector in variable_mapping.items():
if variable_key not in user_inputs:
raise ValueError(f'Variable key {variable_key} not found in user inputs.')

View File

@@ -3,6 +3,13 @@ from flask_restful import fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
environment_variable_fields = {
'name': fields.String,
'value': fields.Raw,
'value_type': fields.String(attribute='value_type.value'),
'exportable': fields.Boolean,
}
workflow_fields = {
'id': fields.String,
'graph': fields.Raw(attribute='graph_dict'),
@@ -13,4 +20,5 @@ workflow_fields = {
'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),
'updated_at': TimestampField,
'tool_published': fields.Boolean,
'environment_variables': fields.List(fields.Nested(environment_variable_fields)),
}

9
api/models/helpers.py Normal file
View File

@@ -0,0 +1,9 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from models.workflow import EnvironmentVariable
def encrypt_environment_variable(var: "EnvironmentVariable", *, encrypt_func) -> "EnvironmentVariable":
var.value = encrypt_func(var.value)
return var

View File

@@ -1,11 +1,18 @@
import json
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Optional, Union
from typing import Any, Optional, Union
from flask_login import current_user
from pydantic import BaseModel
from core.helper import encrypter
from extensions.ext_database import db
from libs import helper
from models import StringUUID
from models.account import Account
from models.helpers import encrypt_environment_variable
from models.model import EndUser
class CreatedByRole(Enum):
@@ -62,6 +69,30 @@ class WorkflowType(Enum):
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
class EnvironmentType(str, Enum):
STRING = 'string'
NUMBER = 'number'
SECRET = 'secret'
class EnvironmentVariable(BaseModel):
name: str
value: Any
value_type: EnvironmentType
exportable: bool
def export(self):
if not self.exportable:
raise ValueError(f'environment variable {self.name} is not exportable')
if self.value_type == EnvironmentType.SECRET:
cp = self.model_copy()
cp.value = None
return cp.model_dump(mode='json')
return self.model_dump(mode='json')
class DBEnvironmentVariable(BaseModel):
data: Sequence[EnvironmentVariable]
class Workflow(db.Model):
"""
Workflow, for `Workflow App` and `Chat App workflow mode`.
@@ -112,6 +143,19 @@ class Workflow(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(StringUUID)
updated_at = db.Column(db.DateTime)
# TODO: update this field to sqlalchemy column after frontend update.
# JSON example:
# {
# "data": [
# {
# "name": "ENV_VAR_NAME",
# "value": "ENV_VAR_VALUE",
# "value_type": "string",
# "exportable": true
# },
# ]
# }
_environment_variables = '{"data": [{"name": "TEST_ENV_NAME", "value": "TEST_ENV_VALUE", "value_type": "string", "exportable": true}, {"name": "TEST_ENV_NAME_2", "value":2, "value_type": "number", "exportable": true}]}'
@property
def created_by_account(self):
@@ -122,11 +166,11 @@ class Workflow(db.Model):
return Account.query.get(self.updated_by) if self.updated_by else None
@property
def graph_dict(self):
return json.loads(self.graph) if self.graph else None
def graph_dict(self) -> Mapping[str, Any]:
return json.loads(self.graph) if self.graph else {}
@property
def features_dict(self):
def features_dict(self) -> Mapping[str, Any]:
return json.loads(self.features) if self.features else {}
def user_input_form(self, to_old_structure: bool = False) -> list:
@@ -177,6 +221,34 @@ class Workflow(db.Model):
WorkflowToolProvider.app_id == self.app_id
).first() is not None
@property
def environment_variables(self) -> Sequence[EnvironmentVariable]:
return DBEnvironmentVariable.model_validate_json(self._environment_variables).data
@environment_variables.setter
def environment_variables(self, vars: Sequence[EnvironmentVariable]):
# get current user from flask context, may not a good way.
user = current_user
if not isinstance(user, Account | EndUser):
raise ValueError('current user is not account or end user')
previous_vars = {var.name: var for var in self.environment_variables}
new_vars = []
for var in vars:
if var.name in previous_vars and var != previous_vars[var.name]:
new_vars.append(var)
elif var.name in previous_vars and var == previous_vars[var.name]:
new_vars.append(previous_vars[var.name])
elif var.value_type == EnvironmentType.SECRET:
new_vars.append(encrypt_environment_variable(var, encrypt_func=lambda t: encrypter.encrypt_token(user.current_tenant_id, t)))
else:
new_vars.append(var)
self._environment_variables = DBEnvironmentVariable(data=new_vars).model_dump_json()
class WorkflowRunTriggeredFrom(Enum):
"""
Workflow Run Triggered From Enum

View File

@@ -21,6 +21,7 @@ from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode, AppModelConfig
from models.tools import ApiToolProvider
from models.workflow import EnvironmentVariable
from services.tag_service import TagService
from services.workflow_service import WorkflowService
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
@@ -193,12 +194,16 @@ class AppService:
if workflow:
# init draft workflow
workflow_service = WorkflowService()
# parse environment variables.
environment_variables_list = workflow.get('environment_variables') or []
environment_variables = [EnvironmentVariable(**obj) for obj in environment_variables_list]
draft_workflow = workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow.get('graph'),
features=workflow.get('features'),
unique_hash=None,
account=account
account=account,
environment_variables=environment_variables
)
workflow_service.publish_workflow(
app_model=app,
@@ -244,9 +249,12 @@ class AppService:
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app)
if not workflow:
raise ValueError("Draft workflow not found")
export_data['workflow'] = {
"graph": workflow.graph_dict,
"features": workflow.features_dict
"features": workflow.features_dict,
"environment_variables": [var.export() for var in workflow.environment_variables if var.exportable]
}
else:
app_model_config = app.app_model_config

View File

@@ -199,7 +199,8 @@ class WorkflowConverter:
version='draft',
graph=json.dumps(graph),
features=json.dumps(features),
created_by=account_id
created_by=account_id,
environment_variables=[],
)
db.session.add(workflow)

View File

@@ -1,5 +1,6 @@
import json
import time
from collections.abc import Sequence
from datetime import datetime, timezone
from typing import Optional
@@ -17,6 +18,7 @@ from models.account import Account
from models.model import App, AppMode
from models.workflow import (
CreatedByRole,
EnvironmentVariable,
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
@@ -63,11 +65,15 @@ class WorkflowService:
return workflow
def sync_draft_workflow(self, app_model: App,
graph: dict,
features: dict,
unique_hash: Optional[str],
account: Account) -> Workflow:
def sync_draft_workflow(
self,
app_model: App,
graph: dict,
features: dict,
unique_hash: Optional[str],
account: Account,
environment_variables: Sequence[EnvironmentVariable],
) -> Workflow:
"""
Sync draft workflow
:raises WorkflowHashNotEqualError
@@ -75,10 +81,8 @@ class WorkflowService:
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if workflow:
# validate unique hash
if workflow.unique_hash != unique_hash:
raise WorkflowHashNotEqualError()
if workflow and workflow.unique_hash != unique_hash:
raise WorkflowHashNotEqualError()
# validate features structure
self.validate_features_structure(
@@ -95,7 +99,8 @@ class WorkflowService:
version='draft',
graph=json.dumps(graph),
features=json.dumps(features),
created_by=account.id
created_by=account.id,
environment_variables=environment_variables
)
db.session.add(workflow)
# update draft workflow if found
@@ -104,6 +109,7 @@ class WorkflowService:
workflow.features = json.dumps(features)
workflow.updated_by = account.id
workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow.environment_variables = environment_variables
# commit db session changes
db.session.commit()
@@ -189,7 +195,8 @@ class WorkflowService:
version=str(datetime.now(timezone.utc).replace(tzinfo=None)),
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id
created_by=account.id,
environment_variables=draft_workflow.environment_variables
)
# commit db session changes