Compare commits

...

3 Commits

Author SHA1 Message Date
-LAN-
f042e1e545 Merge branch 'main' into feat/optimize-database-usage 2024-12-20 14:42:17 +08:00
-LAN-
7e3781e689 Merge branch 'main' into feat/optimize-database-usage 2024-12-20 14:28:21 +08:00
-LAN-
60bffd6d93 feat(app_runner): remove unnecessary app, workflow and user query
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-20 14:26:46 +08:00
9 changed files with 92 additions and 116 deletions

View File

@@ -314,12 +314,25 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
.first()
)
if not workflow:
raise ValueError("Workflow not initialized")
# chatbot app
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
workflow=workflow,
dialogue_count=self._dialogue_count,
)

View File

@@ -21,8 +21,9 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, EndUser, Message
from models.model import Conversation, Message
from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__)
@@ -35,38 +36,24 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def __init__(
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
dialogue_count: int,
workflow: Workflow,
) -> None:
super().__init__(queue_manager)
super().__init__(queue_manager=queue_manager)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
self._dialogue_count = dialogue_count
self.workflow = workflow
def run(self) -> None:
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
app_config = cast(AdvancedChatAppConfig, self.application_generate_entity.app_config)
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
@@ -75,7 +62,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
workflow=self.workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
@@ -86,7 +73,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# moderation
if self.handle_input_moderation(
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
@@ -96,7 +82,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity,
@@ -116,7 +101,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
for variable in self.workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
@@ -129,7 +114,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.USER_ID: self.application_generate_entity.user_id,
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
@@ -140,23 +125,23 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
environment_variables=self.workflow.environment_variables,
conversation_variables=conversation_variables,
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
graph = self._init_graph(graph_config=self.workflow.graph_dict)
db.session.close()
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
tenant_id=self.workflow.tenant_id,
app_id=self.workflow.app_id,
workflow_id=self.workflow.id,
workflow_type=WorkflowType.value_of(self.workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
graph_config=self.workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
@@ -177,7 +162,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def handle_input_moderation(
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
@@ -186,7 +170,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
app_id=app_record.id,
app_id=app_generate_entity.app_config.app_id,
tenant_id=app_generate_entity.app_config.tenant_id,
app_generate_entity=app_generate_entity,
inputs=inputs,
@@ -200,10 +184,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
return False
def handle_annotation_reply(
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
self,
message: Message,
query: str,
app_generate_entity: AdvancedChatAppGenerateEntity,
) -> bool:
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
app_id=app_generate_entity.app_config.app_id,
tenant_id=app_generate_entity.app_config.tenant_id,
message=message,
query=query,
user_id=app_generate_entity.user_id,

View File

@@ -116,7 +116,8 @@ class AgentChatAppRunner(AppRunner):
if query:
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
app_id=app_record.id,
tenant_id=app_config.tenant_id,
message=message,
query=query,
user_id=application_generate_entity.user_id,

View File

@@ -409,7 +409,7 @@ class AppRunner:
)
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
self, app_id: str, tenant_id: str, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
@@ -422,5 +422,10 @@ class AppRunner:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
app_id=app_id,
tenant_id=tenant_id,
message=message,
query=query,
user_id=user_id,
invoke_from=invoke_from,
)

View File

@@ -111,7 +111,8 @@ class ChatAppRunner(AppRunner):
if query:
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
app_id=app_record.id,
tenant_id=app_config.tenant_id,
message=message,
query=query,
user_id=application_generate_entity.user_id,

View File

@@ -253,11 +253,22 @@ class WorkflowAppGenerator(BaseAppGenerator):
var.set(val)
with flask_app.app_context():
try:
# workflow app
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
.first()
)
if not workflow:
raise ValueError("Workflow not initialized")
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
workflow=workflow,
)
runner.run()

View File

@@ -13,9 +13,8 @@ from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models import Workflow
from models.enums import UserFrom
from models.model import App, EndUser
from models.workflow import WorkflowType
logger = logging.getLogger(__name__)
@@ -28,18 +27,17 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
def __init__(
self,
*,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
super().__init__(queue_manager=queue_manager)
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
self.workflow = workflow
def run(self) -> None:
"""
@@ -48,26 +46,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
:param queue_manager: application queue manager
:return:
"""
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
app_config = cast(WorkflowAppConfig, self.application_generate_entity.app_config)
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
@@ -77,7 +56,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
workflow=self.workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
@@ -88,7 +67,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.USER_ID: self.application_generate_entity.user_id,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
@@ -97,21 +76,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
environment_variables=self.workflow.environment_variables,
conversation_variables=[],
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
graph = self._init_graph(graph_config=self.workflow.graph_dict)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
tenant_id=self.workflow.tenant_id,
app_id=self.workflow.app_id,
workflow_id=self.workflow.id,
workflow_type=WorkflowType.value_of(self.workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
graph_config=self.workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT

View File

@@ -1,10 +1,9 @@
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@@ -51,8 +50,6 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow
@@ -452,22 +449,3 @@ class WorkflowBasedAppRunner(AppRunner):
start_index=event.start_index,
)
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

View File

@@ -5,7 +5,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
from models.model import AppAnnotationSetting, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class AnnotationReplyFeature:
def query(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
self, app_id: str, tenant_id: str, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
@@ -26,7 +26,7 @@ class AnnotationReplyFeature:
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
if not annotation_setting:
@@ -44,8 +44,8 @@ class AnnotationReplyFeature:
)
dataset = Dataset(
id=app_record.id,
tenant_id=app_record.tenant_id,
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
@@ -70,15 +70,15 @@ class AnnotationReplyFeature:
# insert annotation history
AppAnnotationService.add_annotation_history(
annotation.id,
app_record.id,
annotation.question,
annotation.content,
query,
user_id,
message.id,
from_source,
score,
annotation_id=annotation.id,
app_id=app_id,
annotation_question=annotation.question,
annotation_content=annotation.content,
query=query,
user_id=user_id,
message_id=message.id,
from_source=from_source,
score=score,
)
return annotation