Mirror of https://gitee.com/dify_ai/dify.git, synced 2025-12-06 19:42:42 +08:00

Compare commits: feat/make-... ... refactor/t... (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 8cb76bbf4d | |
| | 381634021d | |
| | c6590aef1e | |
@@ -103,7 +103,7 @@ class GraphEngine:
         call_depth: int,
         graph: Graph,
         graph_config: Mapping[str, Any],
-        variable_pool: VariablePool,
+        graph_runtime_state: GraphRuntimeState,
         max_execution_steps: int,
         max_execution_time: int,
         thread_pool_id: Optional[str] = None,
@@ -140,7 +140,7 @@ class GraphEngine:
             call_depth=call_depth,
         )

-        self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+        self.graph_runtime_state = graph_runtime_state

         self.max_execution_steps = max_execution_steps
         self.max_execution_time = max_execution_time
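Read together, the two hunks above switch GraphEngine from building its own GraphRuntimeState to having one injected by the caller. A minimal sketch of that constructor-injection pattern, using simplified stand-in classes rather than the real Dify types:

```python
import time
from dataclasses import dataclass, field


# Simplified stand-ins for Dify's VariablePool and GraphRuntimeState; the real
# classes carry more fields, only the injection pattern is illustrated here.
@dataclass
class VariablePool:
    variables: dict = field(default_factory=dict)


@dataclass
class GraphRuntimeState:
    variable_pool: VariablePool
    start_at: float


class GraphEngine:
    def __init__(self, graph_runtime_state: GraphRuntimeState, max_execution_steps: int) -> None:
        # Before this change the engine created GraphRuntimeState itself;
        # now it simply stores the instance the caller provides.
        self.graph_runtime_state = graph_runtime_state
        self.max_execution_steps = max_execution_steps


# Caller-side wiring: build the state once, then hand it to the engine.
pool = VariablePool()
state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
engine = GraphEngine(graph_runtime_state=state, max_execution_steps=500)
assert engine.graph_runtime_state is state
```

The practical effect is that whoever drives the engine (WorkflowEntry, IterationNode, LoopNode, or a test) decides how the runtime state is created and keeps a handle to it.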
@@ -1,5 +1,6 @@
 import contextvars
 import logging
+import time
 import uuid
 from collections.abc import Generator, Mapping, Sequence
 from concurrent.futures import Future, wait
@@ -133,8 +134,11 @@ class IterationNode(BaseNode[IterationNodeData]):
         variable_pool.add([self.node_id, "item"], iterator_list_value[0])

         # init graph engine
+        from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
         from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool

+        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
         graph_engine = GraphEngine(
             tenant_id=self.tenant_id,
             app_id=self.app_id,
@@ -146,7 +150,7 @@ class IterationNode(BaseNode[IterationNodeData]):
             call_depth=self.workflow_call_depth,
             graph=iteration_graph,
             graph_config=graph_config,
-            variable_pool=variable_pool,
+            graph_runtime_state=graph_runtime_state,
             max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
             max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
             thread_pool_id=self.thread_pool_id,
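IterationNode (above) and LoopNode (below) follow the same recipe: build a fresh GraphRuntimeState for the nested graph right before creating its GraphEngine. A tiny self-contained sketch of that per-sub-run ownership (hypothetical names, not the actual node code):

```python
import time


# Hypothetical miniature versions of the classes involved, only to show that
# every nested run owns its own runtime state (mirrors the hunks, not real code).
class VariablePool(dict):
    pass


class GraphRuntimeState:
    def __init__(self, variable_pool: VariablePool, start_at: float) -> None:
        self.variable_pool = variable_pool
        self.start_at = start_at


def start_sub_run(variable_pool: VariablePool) -> GraphRuntimeState:
    # One fresh state per nested engine, created just before GraphEngine(...)
    # in the IterationNode / LoopNode hunks.
    return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())


states = [start_sub_run(VariablePool(item=i)) for i in range(3)]
assert len({id(s) for s in states}) == 3  # each iteration gets a distinct state object
```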
@@ -1,5 +1,6 @@
 import json
 import logging
+import time
 from collections.abc import Generator, Mapping, Sequence
 from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, Literal, cast
@@ -101,8 +102,11 @@ class LoopNode(BaseNode[LoopNodeData]):
             loop_variable_selectors[loop_variable.label] = variable_selector
             inputs[loop_variable.label] = processed_segment.value

+        from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
         from core.workflow.graph_engine.graph_engine import GraphEngine

+        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
         graph_engine = GraphEngine(
             tenant_id=self.tenant_id,
             app_id=self.app_id,
@@ -114,7 +118,7 @@ class LoopNode(BaseNode[LoopNodeData]):
             call_depth=self.workflow_call_depth,
             graph=loop_graph,
             graph_config=self.graph_config,
-            variable_pool=variable_pool,
+            graph_runtime_state=graph_runtime_state,
             max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
             max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
             thread_pool_id=self.thread_pool_id,
@@ -69,6 +69,7 @@ class WorkflowEntry:
             raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))

         # init workflow run state
+        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
         self.graph_engine = GraphEngine(
             tenant_id=tenant_id,
             app_id=app_id,
@@ -80,7 +81,7 @@ class WorkflowEntry:
             call_depth=call_depth,
             graph=graph,
             graph_config=graph_config,
-            variable_pool=variable_pool,
+            graph_runtime_state=graph_runtime_state,
             max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
             max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
             thread_pool_id=thread_pool_id,
@@ -1,3 +1,4 @@
+import time
 from unittest.mock import patch

 import pytest
@@ -19,6 +20,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunSucceededEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
 from core.workflow.graph_engine.graph_engine import GraphEngine
 from core.workflow.nodes.code.code_node import CodeNode
@@ -172,6 +174,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
         system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
     )

+    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
     graph_engine = GraphEngine(
         tenant_id="111",
         app_id="222",
@@ -183,7 +186,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
         graph=graph,
-        variable_pool=variable_pool,
+        graph_runtime_state=graph_runtime_state,
         max_execution_steps=500,
         max_execution_time=1200,
     )
@@ -299,6 +302,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
         user_inputs={},
     )

+    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
     graph_engine = GraphEngine(
         tenant_id="111",
         app_id="222",
@@ -310,7 +314,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
         graph=graph,
-        variable_pool=variable_pool,
+        graph_runtime_state=graph_runtime_state,
         max_execution_steps=500,
         max_execution_time=1200,
     )
@@ -479,6 +483,7 @@ def test_run_branch(mock_close, mock_remove):
         user_inputs={"uid": "takato"},
     )

+    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
     graph_engine = GraphEngine(
         tenant_id="111",
         app_id="222",
@@ -490,7 +495,7 @@ def test_run_branch(mock_close, mock_remove):
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
         graph=graph,
-        variable_pool=variable_pool,
+        graph_runtime_state=graph_runtime_state,
         max_execution_steps=500,
         max_execution_time=1200,
     )
@@ -813,6 +818,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
         system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
     )

+    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
     graph_engine = GraphEngine(
         tenant_id="111",
         app_id="222",
@@ -824,7 +830,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
         graph=graph,
-        variable_pool=variable_pool,
+        graph_runtime_state=graph_runtime_state,
         max_execution_steps=500,
         max_execution_time=1200,
     )
@@ -1,7 +1,9 @@
+import time
 from unittest.mock import patch

 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import (
@@ -11,6 +13,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunStreamChunkEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.graph_engine.graph_engine import GraphEngine
 from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
 from core.workflow.nodes.llm.node import LLMNode
@@ -163,15 +166,16 @@ class ContinueOnErrorTestHelper:
     def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
         """Helper method to create a graph engine instance for testing"""
         graph = Graph.init(graph_config=graph_config)
-        variable_pool = {
-            "system_variables": {
+        variable_pool = VariablePool(
+            system_variables={
                 SystemVariableKey.QUERY: "clear",
                 SystemVariableKey.FILES: [],
                 SystemVariableKey.CONVERSATION_ID: "abababa",
                 SystemVariableKey.USER_ID: "aaa",
             },
-            "user_inputs": user_inputs or {"uid": "takato"},
-        }
+            user_inputs=user_inputs or {"uid": "takato"},
+        )
+        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

         return GraphEngine(
             tenant_id="111",
@@ -184,7 +188,7 @@ class ContinueOnErrorTestHelper:
             invoke_from=InvokeFrom.WEB_APP,
             call_depth=0,
             graph=graph,
-            variable_pool=variable_pool,
+            graph_runtime_state=graph_runtime_state,
             max_execution_steps=500,
             max_execution_time=1200,
         )
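With these changes in place, a test that needs an engine wires up the state the same way the updated helper does. A short sketch of just that wiring, assuming the imports shown in the hunks above (so it only runs inside the Dify codebase):

```python
import time

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

# Build the pool exactly as the updated helper does, then wrap it in the
# runtime state that GraphEngine now expects to receive from its caller.
variable_pool = VariablePool(
    system_variables={SystemVariableKey.USER_ID: "aaa", SystemVariableKey.FILES: []},
    user_inputs={"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# graph_runtime_state is then passed to GraphEngine(..., graph_runtime_state=graph_runtime_state, ...)
```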