Compare commits

...

1 Commits

Author SHA1 Message Date
-LAN-
62369a9ee8 fix(workflow_tool): loss current user 2025-10-04 02:04:18 +08:00
2 changed files with 85 additions and 3 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
@@ -18,7 +19,8 @@ from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs.login import current_user
from models.model import App
from models.account import Account
from models.model import App, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -79,11 +81,13 @@ class WorkflowTool(Tool):
generator = WorkflowAppGenerator()
assert self.runtime is not None
assert self.runtime.invoke_from is not None
assert current_user is not None
user = self._resolve_user(user_id)
if user is None:
raise ToolInvokeError("workflow tool invoke missing user context")
result = generator.generate(
app_model=app,
workflow=workflow,
user=current_user,
user=user,
args={"inputs": tool_parameters, "files": files},
invoke_from=self.runtime.invoke_from,
streaming=False,
@@ -227,3 +231,26 @@ class WorkflowTool(Tool):
elif transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file_dict.get("related_id")
return file_dict
def _resolve_user(self, user_id: str) -> Account | EndUser | None:
runtime = self.runtime
try:
user_candidate = current_user
except RuntimeError:
user_candidate = None
if user_candidate is not None and getattr(user_candidate, "is_authenticated", False):
return user_candidate
if not user_id or runtime is None:
return None
invoke_from = runtime.invoke_from
if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.PUBLISHED}:
end_user = (
db.session.query(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == runtime.tenant_id).first()
)
if end_user:
return end_user
return db.session.query(Account).where(Account.id == user_id).first()

View File

@@ -40,9 +40,64 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
lambda *args, **kwargs: {"data": {"error": "oops"}},
)
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
monkeypatch.setattr(
WorkflowTool,
"_resolve_user",
lambda self, _user_id: type("DummyUser", (), {"id": _user_id, "is_authenticated": True})(),
raising=False,
)
with pytest.raises(ToolInvokeError) as exc_info:
# WorkflowTool always returns a generator, so we need to iterate to
# actually `run` the tool.
list(tool.invoke("test_user", {}))
assert exc_info.value.args == ("oops",)
def test_workflow_tool_falls_back_to_user_resolver_when_no_current_user(monkeypatch: pytest.MonkeyPatch):
entity = ToolEntity(
identity=ToolIdentity(author="tester", name="work", label=I18nObject(en_US="work"), provider="prv"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="tenant-id", invoke_from=InvokeFrom.SERVICE_API)
tool = WorkflowTool(
workflow_app_id="app-id",
workflow_as_tool_id="tool-id",
version="1",
workflow_entities={},
workflow_call_depth=0,
entity=entity,
runtime=runtime,
)
# keep tool internals simple for the test
monkeypatch.setattr(tool, "_get_app", lambda *_args, **_kwargs: object())
monkeypatch.setattr(tool, "_get_workflow", lambda *_args, **_kwargs: object())
monkeypatch.setattr(tool, "_transform_args", lambda tool_parameters, **_: (tool_parameters, []))
captured: dict[str, str] = {}
class DummyUser:
id = "dummy-user"
is_authenticated = True
dummy_user = DummyUser()
def fake_resolver(self, user_id: str):
captured["user_id"] = user_id
return dummy_user
def fake_generate(self, *, user, **_kwargs):
assert user is dummy_user
return {"data": {"outputs": {}}}
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", fake_generate)
monkeypatch.setattr("core.tools.workflow_as_tool.tool.current_user", None)
monkeypatch.setattr(WorkflowTool, "_resolve_user", fake_resolver, raising=False)
result = list(tool.invoke("user-123", {}))
assert captured["user_id"] == "user-123"
assert len(result) == 2 # text + json outputs