mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-07 03:45:27 +08:00
Compare commits
15 Commits
feat/e2e-t
...
feat/envir
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1d577fd64b | ||
|
|
202c2ff06f | ||
|
|
fc0e7f8382 | ||
|
|
dfd244f6ba | ||
|
|
e04e965361 | ||
|
|
c95bfe2495 | ||
|
|
97c5a800d0 | ||
|
|
bc5853749c | ||
|
|
e7dc09f815 | ||
|
|
a118b23a4d | ||
|
|
13304db406 | ||
|
|
694d29294d | ||
|
|
4f7959b5c4 | ||
|
|
211ae43d29 | ||
|
|
fae197ce1a |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -174,5 +174,6 @@ sdks/python-client/dify_client.egg-info
|
||||
.vscode/*
|
||||
!.vscode/launch.json
|
||||
pyrightconfig.json
|
||||
api/.vscode
|
||||
|
||||
.idea/
|
||||
|
||||
@@ -19,6 +19,7 @@ from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import App, AppMode
|
||||
from models.workflow import EnvironmentVariable
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.workflow_service import WorkflowService
|
||||
@@ -39,7 +40,7 @@ class DraftWorkflowApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
@@ -62,13 +63,15 @@ class DraftWorkflowApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
content_type = request.headers.get('Content-Type')
|
||||
content_type = request.headers.get('Content-Type', '')
|
||||
|
||||
if 'application/json' in content_type:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('hash', type=str, required=False, location='json')
|
||||
# TODO: set this to required=True after frontend is updated
|
||||
parser.add_argument('environment_variables', type=list, required=False, location='json')
|
||||
args = parser.parse_args()
|
||||
elif 'text/plain' in content_type:
|
||||
try:
|
||||
@@ -82,7 +85,8 @@ class DraftWorkflowApi(Resource):
|
||||
args = {
|
||||
'graph': data.get('graph'),
|
||||
'features': data.get('features'),
|
||||
'hash': data.get('hash')
|
||||
'hash': data.get('hash'),
|
||||
'environment_variables': data.get('environment_variables')
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {'message': 'Invalid JSON data'}, 400
|
||||
@@ -92,12 +96,15 @@ class DraftWorkflowApi(Resource):
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
environment_variables_list = args.get('environment_variables') or []
|
||||
environment_variables = [EnvironmentVariable(**obj) for obj in environment_variables_list]
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=args.get('graph'),
|
||||
features=args.get('features'),
|
||||
graph=args['graph'],
|
||||
features=args['features'],
|
||||
unique_hash=args.get('hash'),
|
||||
account=current_user
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom
|
||||
from core.app.app_config.entities import AppAdditionalFeatures
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
|
||||
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
|
||||
@@ -10,37 +11,19 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
|
||||
SuggestedQuestionsAfterAnswerConfigManager,
|
||||
)
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class BaseAppConfigManager:
|
||||
|
||||
@classmethod
|
||||
def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom,
|
||||
app_model_config: Union[AppModelConfig, dict],
|
||||
config_dict: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Convert app model config to config dict
|
||||
:param config_from: app model config from
|
||||
:param app_model_config: app model config
|
||||
:param config_dict: app model config dict
|
||||
:return:
|
||||
"""
|
||||
if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
|
||||
app_model_config_dict = app_model_config.to_dict()
|
||||
config_dict = app_model_config_dict.copy()
|
||||
|
||||
return config_dict
|
||||
|
||||
@classmethod
|
||||
def convert_features(cls, config_dict: dict, app_mode: AppMode) -> AppAdditionalFeatures:
|
||||
def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> AppAdditionalFeatures:
|
||||
"""
|
||||
Convert app config to app model config
|
||||
|
||||
:param config_dict: app config
|
||||
:param app_mode: app mode
|
||||
"""
|
||||
config_dict = config_dict.copy()
|
||||
config_dict = dict(config_dict.items())
|
||||
|
||||
additional_features = AppAdditionalFeatures()
|
||||
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
|
||||
|
||||
class FileUploadConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict, is_vision: bool = True) -> Optional[FileExtraConfig]:
|
||||
def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
|
||||
@@ -3,13 +3,13 @@ from core.app.app_config.entities import TextToSpeechEntity
|
||||
|
||||
class TextToSpeechConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict):
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
text_to_speech = False
|
||||
text_to_speech = None
|
||||
text_to_speech_dict = config.get('text_to_speech')
|
||||
if text_to_speech_dict:
|
||||
if text_to_speech_dict.get('enabled'):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional, cast
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
@@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||
from core.moderation.base import ModerationException
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
@@ -87,7 +89,7 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(
|
||||
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
|
||||
queue_manager=queue_manager,
|
||||
workflow=workflow
|
||||
)]
|
||||
@@ -161,7 +163,7 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
self, queue_manager: AppQueueManager,
|
||||
app_record: App,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
inputs: dict,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str
|
||||
) -> bool:
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppBlockingResponse,
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@@ -18,12 +20,13 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = ChatbotAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
|
||||
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
|
||||
response = {
|
||||
'event': 'message',
|
||||
'task_id': blocking_response.task_id,
|
||||
@@ -39,7 +42,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
|
||||
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@@ -53,8 +56,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
|
||||
-> Generator[str, None, None]:
|
||||
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@@ -83,8 +85,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
|
||||
-> Generator[str, None, None]:
|
||||
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@@ -114,7 +114,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._stream_generate_routes = self._get_stream_generate_routes()
|
||||
self._conversation_name_generate_thread = None
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
def process(self):
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
@@ -137,8 +137,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
|
||||
-> ChatbotAppBlockingResponse:
|
||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
@@ -168,8 +167,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
raise Exception('Queue listening stopped unexpectedly.')
|
||||
|
||||
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
|
||||
-> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
|
||||
"""
|
||||
To stream response.
|
||||
:return:
|
||||
|
||||
@@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||
class WorkflowEventTriggerCallback(WorkflowCallback):
|
||||
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
||||
@@ -15,44 +15,41 @@ class AppGenerateResponseConverter(ABC):
|
||||
@classmethod
|
||||
def convert(cls, response: Union[
|
||||
AppBlockingResponse,
|
||||
Generator[AppStreamResponse, None, None]
|
||||
], invoke_from: InvokeFrom) -> Union[
|
||||
dict,
|
||||
Generator[str, None, None]
|
||||
]:
|
||||
Generator[AppStreamResponse, Any, None]
|
||||
], invoke_from: InvokeFrom):
|
||||
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
|
||||
if isinstance(response, cls._blocking_response_type):
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
else:
|
||||
def _generate():
|
||||
def _generate_full_response() -> Generator[str, Any, None]:
|
||||
for chunk in cls.convert_stream_full_response(response):
|
||||
if chunk == 'ping':
|
||||
yield f'event: {chunk}\n\n'
|
||||
else:
|
||||
yield f'data: {chunk}\n\n'
|
||||
|
||||
return _generate()
|
||||
return _generate_full_response()
|
||||
else:
|
||||
if isinstance(response, cls._blocking_response_type):
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_simple_response(response)
|
||||
else:
|
||||
def _generate():
|
||||
def _generate_simple_response() -> Generator[str, Any, None]:
|
||||
for chunk in cls.convert_stream_simple_response(response):
|
||||
if chunk == 'ping':
|
||||
yield f'event: {chunk}\n\n'
|
||||
else:
|
||||
yield f'data: {chunk}\n\n'
|
||||
|
||||
return _generate()
|
||||
return _generate_simple_response()
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict:
|
||||
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict:
|
||||
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@@ -68,7 +65,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _get_simple_metadata(cls, metadata: dict) -> dict:
|
||||
def _get_simple_metadata(cls, metadata: dict[str, Any]):
|
||||
"""
|
||||
Get simple metadata.
|
||||
:param metadata: metadata
|
||||
|
||||
@@ -38,7 +38,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
call_depth: int = 0,
|
||||
) -> Union[dict, Generator[dict, None, None]]:
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@@ -150,8 +150,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
node_id: str,
|
||||
user: Account,
|
||||
args: dict,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
stream: bool = True):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
@@ -57,7 +58,7 @@ class WorkflowAppRunner:
|
||||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(
|
||||
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
|
||||
queue_manager=queue_manager,
|
||||
workflow=workflow
|
||||
)]
|
||||
|
||||
@@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||
class WorkflowEventTriggerCallback(WorkflowCallback):
|
||||
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from core.app.entities.queue_entities import AppQueueEvent
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
|
||||
@@ -15,7 +15,7 @@ _TEXT_COLOR_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
class WorkflowLoggingCallback(BaseWorkflowCallback):
|
||||
class WorkflowLoggingCallback(WorkflowCallback):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.current_node_id = None
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -76,7 +77,7 @@ class AppGenerateEntity(BaseModel):
|
||||
# app config
|
||||
app_config: AppConfig
|
||||
|
||||
inputs: dict[str, Any]
|
||||
inputs: Mapping[str, Any]
|
||||
files: list[FileVar] = []
|
||||
user_id: str
|
||||
|
||||
@@ -140,7 +141,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
query: Optional[str] = None
|
||||
query: str
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Union
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
@@ -16,7 +17,7 @@ class MessageFileParser:
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
|
||||
def validate_and_transform_files_arg(self, files: list[dict], file_extra_config: FileExtraConfig,
|
||||
def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig,
|
||||
user: Union[Account, EndUser]) -> list[FileVar]:
|
||||
"""
|
||||
validate and transform files arg
|
||||
|
||||
@@ -6,7 +6,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
|
||||
|
||||
class BaseWorkflowCallback(ABC):
|
||||
class WorkflowCallback(ABC):
|
||||
@abstractmethod
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
@@ -78,7 +78,7 @@ class BaseWorkflowCallback(ABC):
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
inputs: Optional[dict] = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -82,9 +83,9 @@ class NodeRunResult(BaseModel):
|
||||
"""
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
inputs: Optional[dict] = None # node inputs
|
||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
process_data: Optional[dict] = None # process data
|
||||
outputs: Optional[dict] = None # node outputs
|
||||
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
||||
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from models.workflow import EnvironmentVariable
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||
|
||||
@@ -21,10 +23,17 @@ class ValueType(Enum):
|
||||
FILE = "file"
|
||||
|
||||
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
|
||||
|
||||
class VariablePool:
|
||||
|
||||
def __init__(self, system_variables: dict[SystemVariable, Any],
|
||||
user_inputs: dict) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
system_variables: Mapping[SystemVariable, Any],
|
||||
user_inputs: Mapping[str, Any],
|
||||
# TODO: remove Optional
|
||||
environment_variables: Optional[Sequence[EnvironmentVariable]] = None,
|
||||
) -> None:
|
||||
# system variables
|
||||
# for example:
|
||||
# {
|
||||
@@ -36,6 +45,9 @@ class VariablePool:
|
||||
self.system_variables = system_variables
|
||||
for system_variable, value in system_variables.items():
|
||||
self.append_variable('sys', [system_variable.value], value)
|
||||
self.environment_variables = environment_variables or []
|
||||
for var in self.environment_variables:
|
||||
self.append_variable(ENVIRONMENT_VARIABLE_NODE_ID, [var.name], var.value)
|
||||
|
||||
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
|
||||
"""
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@@ -46,7 +47,7 @@ class BaseNode(ABC):
|
||||
node_data: BaseNodeData
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
|
||||
callbacks: list[BaseWorkflowCallback]
|
||||
callbacks: Sequence[WorkflowCallback]
|
||||
|
||||
def __init__(self, tenant_id: str,
|
||||
app_id: str,
|
||||
@@ -54,8 +55,8 @@ class BaseNode(ABC):
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
config: dict,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
config: Mapping[str, Any],
|
||||
callbacks: Sequence[WorkflowCallback] | None = None,
|
||||
workflow_call_depth: int = 0) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
@@ -65,7 +66,8 @@ class BaseNode(ABC):
|
||||
self.invoke_from = invoke_from
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
|
||||
self.node_id = config.get("id")
|
||||
# TODO: May need to check if key exists.
|
||||
self.node_id = config["id"]
|
||||
if not self.node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class StartNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
# Get cleaned inputs
|
||||
cleaned_inputs = variable_pool.user_inputs
|
||||
cleaned_inputs = dict(variable_pool.user_inputs)
|
||||
|
||||
for var in variable_pool.system_variables:
|
||||
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, cast
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask import current_app
|
||||
|
||||
@@ -8,8 +9,8 @@ from core.app.app_config.entities import FileExtraConfig
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
@@ -36,7 +37,7 @@ from models.workflow import (
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
|
||||
node_classes = {
|
||||
node_classes: Mapping[NodeType, type[BaseNode]] = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.ANSWER: AnswerNode,
|
||||
@@ -87,14 +88,14 @@ class WorkflowEngineManager:
|
||||
|
||||
return default_config
|
||||
|
||||
def run_workflow(self, workflow: Workflow,
|
||||
def run_workflow(self, *, workflow: Workflow,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
user_inputs: dict,
|
||||
system_inputs: Optional[dict] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
call_depth: Optional[int] = 0,
|
||||
user_inputs: Mapping[str, Any],
|
||||
system_inputs: Mapping[SystemVariable, Any],
|
||||
callbacks: Sequence[WorkflowCallback],
|
||||
call_depth: int = 0,
|
||||
variable_pool: Optional[VariablePool] = None) -> None:
|
||||
"""
|
||||
:param workflow: Workflow instance
|
||||
@@ -123,7 +124,8 @@ class WorkflowEngineManager:
|
||||
if not variable_pool:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=user_inputs
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH")
|
||||
@@ -155,7 +157,7 @@ class WorkflowEngineManager:
|
||||
|
||||
def _run_workflow(self, workflow: Workflow,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
callbacks: Sequence[WorkflowCallback],
|
||||
start_at: Optional[str] = None,
|
||||
end_at: Optional[str] = None) -> None:
|
||||
"""
|
||||
@@ -174,8 +176,8 @@ class WorkflowEngineManager:
|
||||
graph = workflow.graph_dict
|
||||
|
||||
try:
|
||||
predecessor_node: BaseNode = None
|
||||
current_iteration_node: BaseIterationNode = None
|
||||
predecessor_node: BaseNode | None = None
|
||||
current_iteration_node: BaseIterationNode | None = None
|
||||
has_entry_node = False
|
||||
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
|
||||
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
|
||||
@@ -236,7 +238,7 @@ class WorkflowEngineManager:
|
||||
# move to next iteration
|
||||
next_node_id = next_iteration
|
||||
# get next id
|
||||
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
|
||||
next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks)
|
||||
|
||||
if not next_node:
|
||||
break
|
||||
@@ -296,7 +298,7 @@ class WorkflowEngineManager:
|
||||
workflow_run_state.current_iteration_state = None
|
||||
continue
|
||||
else:
|
||||
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
|
||||
next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks)
|
||||
|
||||
# run workflow, run multiple target nodes in the future
|
||||
self._run_workflow_node(
|
||||
@@ -420,7 +422,7 @@ class WorkflowEngineManager:
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
callbacks: Sequence[WorkflowCallback],
|
||||
) -> None:
|
||||
"""
|
||||
Single iteration run workflow node
|
||||
@@ -447,7 +449,8 @@ class WorkflowEngineManager:
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={}
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
# variable selector to variable mapping
|
||||
@@ -536,7 +539,7 @@ class WorkflowEngineManager:
|
||||
end_at=end_node_id
|
||||
)
|
||||
|
||||
def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None:
|
||||
"""
|
||||
Workflow run success
|
||||
:param callbacks: workflow callbacks
|
||||
@@ -548,7 +551,7 @@ class WorkflowEngineManager:
|
||||
callback.on_workflow_run_succeeded()
|
||||
|
||||
def _workflow_run_failed(self, error: str,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
callbacks: Sequence[WorkflowCallback]) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
:param error: error message
|
||||
@@ -561,11 +564,11 @@ class WorkflowEngineManager:
|
||||
error=error
|
||||
)
|
||||
|
||||
def _workflow_iteration_started(self, graph: dict,
|
||||
def _workflow_iteration_started(self, *, graph: Mapping[str, Any],
|
||||
current_iteration_node: BaseIterationNode,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
callbacks: Sequence[WorkflowCallback]) -> None:
|
||||
"""
|
||||
Workflow iteration started
|
||||
:param current_iteration_node: current iteration node
|
||||
@@ -598,10 +601,10 @@ class WorkflowEngineManager:
|
||||
# add steps
|
||||
workflow_run_state.workflow_node_steps += 1
|
||||
|
||||
def _workflow_iteration_next(self, graph: dict,
|
||||
def _workflow_iteration_next(self, *, graph: Mapping[str, Any],
|
||||
current_iteration_node: BaseIterationNode,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
callbacks: Sequence[WorkflowCallback]) -> None:
|
||||
"""
|
||||
Workflow iteration next
|
||||
:param workflow_run_state: workflow run state
|
||||
@@ -630,9 +633,9 @@ class WorkflowEngineManager:
|
||||
for node in nodes:
|
||||
workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id'))
|
||||
|
||||
def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode,
|
||||
def _workflow_iteration_completed(self, *, current_iteration_node: BaseIterationNode,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
callbacks: Sequence[WorkflowCallback]) -> None:
|
||||
if callbacks:
|
||||
if isinstance(workflow_run_state.current_iteration_state, IterationState):
|
||||
for callback in callbacks:
|
||||
@@ -645,10 +648,10 @@ class WorkflowEngineManager:
|
||||
}
|
||||
)
|
||||
|
||||
def _get_next_overall_node(self, workflow_run_state: WorkflowRunState,
|
||||
graph: dict,
|
||||
def _get_next_overall_node(self, *, workflow_run_state: WorkflowRunState,
|
||||
graph: Mapping[str, Any],
|
||||
predecessor_node: Optional[BaseNode] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
callbacks: Sequence[WorkflowCallback],
|
||||
start_at: Optional[str] = None,
|
||||
end_at: Optional[str] = None) -> Optional[BaseNode]:
|
||||
"""
|
||||
@@ -740,9 +743,9 @@ class WorkflowEngineManager:
|
||||
)
|
||||
|
||||
def _get_node(self, workflow_run_state: WorkflowRunState,
|
||||
graph: dict,
|
||||
graph: Mapping[str, Any],
|
||||
node_id: str,
|
||||
callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]:
|
||||
callbacks: Sequence[WorkflowCallback]):
|
||||
"""
|
||||
Get node from graph by node id
|
||||
"""
|
||||
@@ -753,7 +756,7 @@ class WorkflowEngineManager:
|
||||
for node_config in nodes:
|
||||
if node_config.get('id') == node_id:
|
||||
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_cls = node_classes[node_type]
|
||||
return node_cls(
|
||||
tenant_id=workflow_run_state.tenant_id,
|
||||
app_id=workflow_run_state.app_id,
|
||||
@@ -766,8 +769,6 @@ class WorkflowEngineManager:
|
||||
workflow_call_depth=workflow_run_state.workflow_call_depth
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||
"""
|
||||
Check timeout
|
||||
@@ -786,10 +787,10 @@ class WorkflowEngineManager:
|
||||
if node_and_result.node_id == node_id
|
||||
])
|
||||
|
||||
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
|
||||
def _run_workflow_node(self, *, workflow_run_state: WorkflowRunState,
|
||||
node: BaseNode,
|
||||
predecessor_node: Optional[BaseNode] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
callbacks: Sequence[WorkflowCallback]) -> None:
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_node_execute_started(
|
||||
@@ -940,12 +941,14 @@ class WorkflowEngineManager:
|
||||
|
||||
return new_value
|
||||
|
||||
def _mapping_user_inputs_to_variable_pool(self,
|
||||
variable_mapping: dict,
|
||||
user_inputs: dict,
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
node_instance: BaseNode):
|
||||
def _mapping_user_inputs_to_variable_pool(
|
||||
self,
|
||||
variable_mapping: dict,
|
||||
user_inputs: dict,
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
node_instance: BaseNode
|
||||
):
|
||||
for variable_key, variable_selector in variable_mapping.items():
|
||||
if variable_key not in user_inputs:
|
||||
raise ValueError(f'Variable key {variable_key} not found in user inputs.')
|
||||
|
||||
@@ -3,6 +3,13 @@ from flask_restful import fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
environment_variable_fields = {
|
||||
'name': fields.String,
|
||||
'value': fields.Raw,
|
||||
'value_type': fields.String(attribute='value_type.value'),
|
||||
'exportable': fields.Boolean,
|
||||
}
|
||||
|
||||
workflow_fields = {
|
||||
'id': fields.String,
|
||||
'graph': fields.Raw(attribute='graph_dict'),
|
||||
@@ -13,4 +20,5 @@ workflow_fields = {
|
||||
'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),
|
||||
'updated_at': TimestampField,
|
||||
'tool_published': fields.Boolean,
|
||||
'environment_variables': fields.List(fields.Nested(environment_variable_fields)),
|
||||
}
|
||||
|
||||
9
api/models/helpers.py
Normal file
9
api/models/helpers.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.workflow import EnvironmentVariable
|
||||
|
||||
|
||||
def encrypt_environment_variable(var: "EnvironmentVariable", *, encrypt_func) -> "EnvironmentVariable":
|
||||
var.value = encrypt_func(var.value)
|
||||
return var
|
||||
@@ -1,11 +1,18 @@
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from flask_login import current_user
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models import StringUUID
|
||||
from models.account import Account
|
||||
from models.helpers import encrypt_environment_variable
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
class CreatedByRole(Enum):
|
||||
@@ -62,6 +69,30 @@ class WorkflowType(Enum):
|
||||
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
|
||||
|
||||
|
||||
class EnvironmentType(str, Enum):
|
||||
STRING = 'string'
|
||||
NUMBER = 'number'
|
||||
SECRET = 'secret'
|
||||
|
||||
class EnvironmentVariable(BaseModel):
|
||||
name: str
|
||||
value: Any
|
||||
value_type: EnvironmentType
|
||||
exportable: bool
|
||||
|
||||
def export(self):
|
||||
if not self.exportable:
|
||||
raise ValueError(f'environment variable {self.name} is not exportable')
|
||||
if self.value_type == EnvironmentType.SECRET:
|
||||
cp = self.model_copy()
|
||||
cp.value = None
|
||||
return cp.model_dump(mode='json')
|
||||
return self.model_dump(mode='json')
|
||||
|
||||
class DBEnvironmentVariable(BaseModel):
|
||||
data: Sequence[EnvironmentVariable]
|
||||
|
||||
|
||||
class Workflow(db.Model):
|
||||
"""
|
||||
Workflow, for `Workflow App` and `Chat App workflow mode`.
|
||||
@@ -112,6 +143,19 @@ class Workflow(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_by = db.Column(StringUUID)
|
||||
updated_at = db.Column(db.DateTime)
|
||||
# TODO: update this field to sqlalchemy column after frontend update.
|
||||
# JSON example:
|
||||
# {
|
||||
# "data": [
|
||||
# {
|
||||
# "name": "ENV_VAR_NAME",
|
||||
# "value": "ENV_VAR_VALUE",
|
||||
# "value_type": "string",
|
||||
# "exportable": true
|
||||
# },
|
||||
# ]
|
||||
# }
|
||||
_environment_variables = '{"data": [{"name": "TEST_ENV_NAME", "value": "TEST_ENV_VALUE", "value_type": "string", "exportable": true}, {"name": "TEST_ENV_NAME_2", "value":2, "value_type": "number", "exportable": true}]}'
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
@@ -122,11 +166,11 @@ class Workflow(db.Model):
|
||||
return Account.query.get(self.updated_by) if self.updated_by else None
|
||||
|
||||
@property
|
||||
def graph_dict(self):
|
||||
return json.loads(self.graph) if self.graph else None
|
||||
def graph_dict(self) -> Mapping[str, Any]:
|
||||
return json.loads(self.graph) if self.graph else {}
|
||||
|
||||
@property
|
||||
def features_dict(self):
|
||||
def features_dict(self) -> Mapping[str, Any]:
|
||||
return json.loads(self.features) if self.features else {}
|
||||
|
||||
def user_input_form(self, to_old_structure: bool = False) -> list:
|
||||
@@ -177,6 +221,34 @@ class Workflow(db.Model):
|
||||
WorkflowToolProvider.app_id == self.app_id
|
||||
).first() is not None
|
||||
|
||||
@property
|
||||
def environment_variables(self) -> Sequence[EnvironmentVariable]:
|
||||
return DBEnvironmentVariable.model_validate_json(self._environment_variables).data
|
||||
|
||||
@environment_variables.setter
|
||||
def environment_variables(self, vars: Sequence[EnvironmentVariable]):
|
||||
# get current user from flask context, may not a good way.
|
||||
user = current_user
|
||||
if not isinstance(user, Account | EndUser):
|
||||
raise ValueError('current user is not account or end user')
|
||||
|
||||
|
||||
previous_vars = {var.name: var for var in self.environment_variables}
|
||||
|
||||
new_vars = []
|
||||
for var in vars:
|
||||
if var.name in previous_vars and var != previous_vars[var.name]:
|
||||
new_vars.append(var)
|
||||
elif var.name in previous_vars and var == previous_vars[var.name]:
|
||||
new_vars.append(previous_vars[var.name])
|
||||
elif var.value_type == EnvironmentType.SECRET:
|
||||
new_vars.append(encrypt_environment_variable(var, encrypt_func=lambda t: encrypter.encrypt_token(user.current_tenant_id, t)))
|
||||
else:
|
||||
new_vars.append(var)
|
||||
|
||||
self._environment_variables = DBEnvironmentVariable(data=new_vars).model_dump_json()
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(Enum):
|
||||
"""
|
||||
Workflow Run Triggered From Enum
|
||||
|
||||
@@ -21,6 +21,7 @@ from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.tools import ApiToolProvider
|
||||
from models.workflow import EnvironmentVariable
|
||||
from services.tag_service import TagService
|
||||
from services.workflow_service import WorkflowService
|
||||
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
|
||||
@@ -193,12 +194,16 @@ class AppService:
|
||||
if workflow:
|
||||
# init draft workflow
|
||||
workflow_service = WorkflowService()
|
||||
# parse environment variables.
|
||||
environment_variables_list = workflow.get('environment_variables') or []
|
||||
environment_variables = [EnvironmentVariable(**obj) for obj in environment_variables_list]
|
||||
draft_workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app,
|
||||
graph=workflow.get('graph'),
|
||||
features=workflow.get('features'),
|
||||
unique_hash=None,
|
||||
account=account
|
||||
account=account,
|
||||
environment_variables=environment_variables
|
||||
)
|
||||
workflow_service.publish_workflow(
|
||||
app_model=app,
|
||||
@@ -244,9 +249,12 @@ class AppService:
|
||||
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app)
|
||||
if not workflow:
|
||||
raise ValueError("Draft workflow not found")
|
||||
export_data['workflow'] = {
|
||||
"graph": workflow.graph_dict,
|
||||
"features": workflow.features_dict
|
||||
"features": workflow.features_dict,
|
||||
"environment_variables": [var.export() for var in workflow.environment_variables if var.exportable]
|
||||
}
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
|
||||
@@ -199,7 +199,8 @@ class WorkflowConverter:
|
||||
version='draft',
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account_id
|
||||
created_by=account_id,
|
||||
environment_variables=[],
|
||||
)
|
||||
|
||||
db.session.add(workflow)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
@@ -17,6 +18,7 @@ from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
from models.workflow import (
|
||||
CreatedByRole,
|
||||
EnvironmentVariable,
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
@@ -63,11 +65,15 @@ class WorkflowService:
|
||||
|
||||
return workflow
|
||||
|
||||
def sync_draft_workflow(self, app_model: App,
|
||||
graph: dict,
|
||||
features: dict,
|
||||
unique_hash: Optional[str],
|
||||
account: Account) -> Workflow:
|
||||
def sync_draft_workflow(
|
||||
self,
|
||||
app_model: App,
|
||||
graph: dict,
|
||||
features: dict,
|
||||
unique_hash: Optional[str],
|
||||
account: Account,
|
||||
environment_variables: Sequence[EnvironmentVariable],
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
:raises WorkflowHashNotEqualError
|
||||
@@ -75,10 +81,8 @@ class WorkflowService:
|
||||
# fetch draft workflow by app_model
|
||||
workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if workflow:
|
||||
# validate unique hash
|
||||
if workflow.unique_hash != unique_hash:
|
||||
raise WorkflowHashNotEqualError()
|
||||
if workflow and workflow.unique_hash != unique_hash:
|
||||
raise WorkflowHashNotEqualError()
|
||||
|
||||
# validate features structure
|
||||
self.validate_features_structure(
|
||||
@@ -95,7 +99,8 @@ class WorkflowService:
|
||||
version='draft',
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account.id
|
||||
created_by=account.id,
|
||||
environment_variables=environment_variables
|
||||
)
|
||||
db.session.add(workflow)
|
||||
# update draft workflow if found
|
||||
@@ -104,6 +109,7 @@ class WorkflowService:
|
||||
workflow.features = json.dumps(features)
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow.environment_variables = environment_variables
|
||||
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
@@ -189,7 +195,8 @@ class WorkflowService:
|
||||
version=str(datetime.now(timezone.utc).replace(tzinfo=None)),
|
||||
graph=draft_workflow.graph,
|
||||
features=draft_workflow.features,
|
||||
created_by=account.id
|
||||
created_by=account.id,
|
||||
environment_variables=draft_workflow.environment_variables
|
||||
)
|
||||
|
||||
# commit db session changes
|
||||
|
||||
Reference in New Issue
Block a user