mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-06 11:29:30 +08:00
Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine
This commit is contained in:
@@ -117,7 +117,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def delete(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@@ -99,12 +99,13 @@ class MessageCycleManager:
|
||||
|
||||
# generate conversation name
|
||||
try:
|
||||
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
|
||||
name = LLMGenerator.generate_conversation_name(
|
||||
app_model.tenant_id, query, conversation_id, conversation.app_id
|
||||
)
|
||||
conversation.name = name
|
||||
except Exception:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
pass
|
||||
|
||||
db.session.merge(conversation)
|
||||
db.session.commit()
|
||||
|
||||
@@ -7,7 +7,7 @@ This module provides the interface for invoking and authenticating various model
|
||||
|
||||
## Features
|
||||
|
||||
- Supports capability invocation for 5 types of models
|
||||
- Supports capability invocation for 6 types of models
|
||||
|
||||
- `LLM` - LLM text completion, dialogue, pre-computed tokens capability
|
||||
- `Text Embedding Model` - Text Embedding, pre-computed tokens capability
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
## 功能介绍
|
||||
|
||||
- 支持 5 种模型类型的能力调用
|
||||
- 支持 6 种模型类型的能力调用
|
||||
|
||||
- `LLM` - LLM 文本补全、对话,预计算 tokens 能力
|
||||
- `Text Embedding Model` - 文本 Embedding,预计算 tokens 能力
|
||||
|
||||
@@ -150,6 +150,9 @@ class ProviderManager:
|
||||
tenant_id
|
||||
)
|
||||
|
||||
# Get All provider model credentials
|
||||
provider_name_to_provider_model_credentials_dict = self._get_all_provider_model_credentials(tenant_id)
|
||||
|
||||
provider_configurations = ProviderConfigurations(tenant_id=tenant_id)
|
||||
|
||||
# Construct ProviderConfiguration objects for each provider
|
||||
@@ -171,10 +174,18 @@ class ProviderManager:
|
||||
provider_model_records.extend(
|
||||
provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, [])
|
||||
)
|
||||
provider_model_credentials = provider_name_to_provider_model_credentials_dict.get(
|
||||
provider_entity.provider, []
|
||||
)
|
||||
provider_id_entity = ModelProviderID(provider_name)
|
||||
if provider_id_entity.is_langgenius():
|
||||
provider_model_credentials.extend(
|
||||
provider_name_to_provider_model_credentials_dict.get(provider_id_entity.provider_name, [])
|
||||
)
|
||||
|
||||
# Convert to custom configuration
|
||||
custom_configuration = self._to_custom_configuration(
|
||||
tenant_id, provider_entity, provider_records, provider_model_records
|
||||
tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials
|
||||
)
|
||||
|
||||
# Convert to system configuration
|
||||
@@ -453,6 +464,24 @@ class ProviderManager:
|
||||
)
|
||||
return provider_name_to_provider_model_settings_dict
|
||||
|
||||
@staticmethod
|
||||
def _get_all_provider_model_credentials(tenant_id: str) -> dict[str, list[ProviderModelCredential]]:
|
||||
"""
|
||||
Get All provider model credentials of the workspace.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_credentials_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
|
||||
provider_model_credentials = session.scalars(stmt)
|
||||
for provider_model_credential in provider_model_credentials:
|
||||
provider_name_to_provider_model_credentials_dict[provider_model_credential.provider_name].append(
|
||||
provider_model_credential
|
||||
)
|
||||
return provider_name_to_provider_model_credentials_dict
|
||||
|
||||
@staticmethod
|
||||
def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
|
||||
"""
|
||||
@@ -539,23 +568,6 @@ class ProviderManager:
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
|
||||
"""
|
||||
Get all the credentials records from ProviderModelCredential by provider_name
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
|
||||
)
|
||||
|
||||
all_credentials = session.scalars(stmt).all()
|
||||
return all_credentials
|
||||
|
||||
@staticmethod
|
||||
def _init_trial_provider_records(
|
||||
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
|
||||
@@ -632,6 +644,7 @@ class ProviderManager:
|
||||
provider_entity: ProviderEntity,
|
||||
provider_records: list[Provider],
|
||||
provider_model_records: list[ProviderModel],
|
||||
provider_model_credentials: list[ProviderModelCredential],
|
||||
) -> CustomConfiguration:
|
||||
"""
|
||||
Convert to custom configuration.
|
||||
@@ -647,15 +660,12 @@ class ProviderManager:
|
||||
tenant_id, provider_entity, provider_records
|
||||
)
|
||||
|
||||
# Get all model credentials once
|
||||
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
|
||||
|
||||
# Get custom models which have not been added to the model list yet
|
||||
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
|
||||
unadded_models = self._get_can_added_models(provider_model_records, provider_model_credentials)
|
||||
|
||||
# Get custom model configurations
|
||||
custom_model_configurations = self._get_custom_model_configurations(
|
||||
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
|
||||
tenant_id, provider_entity, provider_model_records, unadded_models, provider_model_credentials
|
||||
)
|
||||
|
||||
can_added_models = [
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional
|
||||
|
||||
import tablestore # type: ignore
|
||||
@@ -102,9 +103,12 @@ class TableStoreVector(BaseVector):
|
||||
return uuids
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
_, return_row, _ = self._tablestore_client.get_row(
|
||||
result = self._tablestore_client.get_row(
|
||||
table_name=self._table_name, primary_key=[("id", id)], columns_to_get=["id"]
|
||||
)
|
||||
assert isinstance(result, tuple | list)
|
||||
# Unpack the tuple result
|
||||
_, return_row, _ = result
|
||||
|
||||
return return_row is not None
|
||||
|
||||
@@ -169,6 +173,7 @@ class TableStoreVector(BaseVector):
|
||||
|
||||
def _create_search_index_if_not_exist(self, dimension: int) -> None:
|
||||
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
|
||||
assert isinstance(search_index_list, Iterable)
|
||||
if self._index_name in [t[1] for t in search_index_list]:
|
||||
logger.info("Tablestore system index[%s] already exists", self._index_name)
|
||||
return None
|
||||
@@ -212,6 +217,7 @@ class TableStoreVector(BaseVector):
|
||||
|
||||
def _delete_table_if_exist(self):
|
||||
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
|
||||
assert isinstance(search_index_list, Iterable)
|
||||
for resp_tuple in search_index_list:
|
||||
self._tablestore_client.delete_search_index(resp_tuple[0], resp_tuple[1])
|
||||
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
|
||||
@@ -269,7 +275,7 @@ class TableStoreVector(BaseVector):
|
||||
)
|
||||
|
||||
if search_response is not None:
|
||||
rows.extend([row[0][0][1] for row in search_response.rows])
|
||||
rows.extend([row[0][0][1] for row in list(search_response.rows)])
|
||||
|
||||
if search_response is None or search_response.next_token == b"":
|
||||
break
|
||||
|
||||
@@ -41,13 +41,6 @@ class WeaviateVector(BaseVector):
|
||||
|
||||
weaviate.connect.connection.has_grpc = False # ty: ignore [unresolved-attribute]
|
||||
|
||||
# Fix to minimize the performance impact of the deprecation check in weaviate-client 3.24.0,
|
||||
# by changing the connection timeout to pypi.org from 1 second to 0.001 seconds.
|
||||
# TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher,
|
||||
# which does not contain the deprecation check.
|
||||
if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"): # ty: ignore [unresolved-attribute]
|
||||
weaviate.connect.connection.PYPI_TIMEOUT = 0.001 # ty: ignore [unresolved-attribute]
|
||||
|
||||
try:
|
||||
client = weaviate.Client(
|
||||
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""ClickZetta Volume file lifecycle management
|
||||
|
||||
This module provides file lifecycle management features including version control,
|
||||
automatic cleanup, backup and restore. Supports complete lifecycle management for
|
||||
knowledge base files.
|
||||
automatic cleanup, backup and restore.
|
||||
Supports complete lifecycle management for knowledge base files.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
@@ -215,7 +215,7 @@ vdb = [
|
||||
"tidb-vector==0.0.9",
|
||||
"upstash-vector==0.6.0",
|
||||
"volcengine-compat~=1.0.0",
|
||||
"weaviate-client~=3.24.0",
|
||||
"weaviate-client~=3.26.7",
|
||||
"xinference-client~=1.2.2",
|
||||
"mo-vector~=0.1.13",
|
||||
]
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
"pythonVersion": "3.11",
|
||||
"pythonPlatform": "All",
|
||||
"reportMissingTypeStubs": false,
|
||||
"reportGeneralTypeIssues": "none",
|
||||
"reportOptionalMemberAccess": "none",
|
||||
"reportOptionalIterable": "none",
|
||||
"reportOptionalOperand": "none",
|
||||
|
||||
@@ -1093,7 +1093,7 @@ class DocumentService:
|
||||
account: Account | Any,
|
||||
dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
created_from: str = "web",
|
||||
):
|
||||
) -> tuple[list[Document], str]:
|
||||
# check doc_form
|
||||
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
|
||||
# check document limit
|
||||
|
||||
@@ -15,7 +15,7 @@ class RecommendedAppService:
|
||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
|
||||
result = retrieval_instance.get_recommended_apps_and_categories(language)
|
||||
if not result.get("recommended_apps") and language != "en-US":
|
||||
if not result.get("recommended_apps"):
|
||||
result = (
|
||||
RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval().fetch_recommended_apps_from_builtin(
|
||||
"en-US"
|
||||
|
||||
@@ -0,0 +1,553 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
DatasetRetrieveConfigEntity,
|
||||
ExternalDataVariableEntity,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
VariableEntityType,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from models.account import Account, Tenant
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.workflow import Workflow
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
|
||||
class TestWorkflowConverter:
|
||||
"""Integration tests for WorkflowConverter using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.workflow.workflow_converter.encrypter") as mock_encrypter,
|
||||
patch("services.workflow.workflow_converter.SimplePromptTransform") as mock_prompt_transform,
|
||||
patch("services.workflow.workflow_converter.AgentChatAppConfigManager") as mock_agent_chat_config_manager,
|
||||
patch("services.workflow.workflow_converter.ChatAppConfigManager") as mock_chat_config_manager,
|
||||
patch("services.workflow.workflow_converter.CompletionAppConfigManager") as mock_completion_config_manager,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_encrypter.decrypt_token.return_value = "decrypted_api_key"
|
||||
mock_prompt_transform.return_value.get_prompt_template.return_value = {
|
||||
"prompt_template": type("obj", (object,), {"template": "You are a helpful assistant {{text_input}}"})(),
|
||||
"prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"},
|
||||
}
|
||||
mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config()
|
||||
mock_chat_config_manager.get_app_config.return_value = self._create_mock_app_config()
|
||||
mock_completion_config_manager.get_app_config.return_value = self._create_mock_app_config()
|
||||
|
||||
yield {
|
||||
"encrypter": mock_encrypter,
|
||||
"prompt_transform": mock_prompt_transform,
|
||||
"agent_chat_config_manager": mock_agent_chat_config_manager,
|
||||
"chat_config_manager": mock_chat_config_manager,
|
||||
"completion_config_manager": mock_completion_config_manager,
|
||||
}
|
||||
|
||||
def _create_mock_app_config(self):
|
||||
"""Helper method to create a mock app config."""
|
||||
mock_config = type("obj", (object,), {})()
|
||||
mock_config.variables = [
|
||||
VariableEntity(
|
||||
variable="text_input",
|
||||
label="Text Input",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
)
|
||||
]
|
||||
mock_config.model = ModelConfigEntity(
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
mode=LLMMode.CHAT.value,
|
||||
parameters={},
|
||||
stop=[],
|
||||
)
|
||||
mock_config.prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="You are a helpful assistant {{text_input}}",
|
||||
)
|
||||
mock_config.dataset = None
|
||||
mock_config.external_data_variables = []
|
||||
mock_config.additional_features = type("obj", (object,), {"file_upload": None})()
|
||||
mock_config.app_model_config_dict = {}
|
||||
return mock_config
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER.value,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant, account):
|
||||
"""
|
||||
Helper method to create a test app for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
tenant: Tenant instance
|
||||
account: Account instance
|
||||
|
||||
Returns:
|
||||
App: Created app instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create app
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
mode=AppMode.CHAT.value,
|
||||
icon_type="emoji",
|
||||
icon="🤖",
|
||||
icon_background="#FF6B6B",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=10,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
configs={},
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
|
||||
# Link app model config to app
|
||||
app.app_model_config_id = app_model_config.id
|
||||
db.session.commit()
|
||||
|
||||
return app
|
||||
|
||||
def test_convert_to_workflow_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful conversion of app to workflow.
|
||||
|
||||
This test verifies:
|
||||
- Proper app to workflow conversion
|
||||
- Correct database state after conversion
|
||||
- Proper relationship establishment
|
||||
- Workflow creation with correct configuration
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant, account)
|
||||
|
||||
# Act: Execute the conversion
|
||||
workflow_converter = WorkflowConverter()
|
||||
new_app = workflow_converter.convert_to_workflow(
|
||||
app_model=app,
|
||||
account=account,
|
||||
name="Test Workflow App",
|
||||
icon_type="emoji",
|
||||
icon="🚀",
|
||||
icon_background="#4CAF50",
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert new_app is not None
|
||||
assert new_app.name == "Test Workflow App"
|
||||
assert new_app.mode == AppMode.ADVANCED_CHAT.value
|
||||
assert new_app.icon_type == "emoji"
|
||||
assert new_app.icon == "🚀"
|
||||
assert new_app.icon_background == "#4CAF50"
|
||||
assert new_app.tenant_id == app.tenant_id
|
||||
assert new_app.created_by == account.id
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(new_app)
|
||||
assert new_app.id is not None
|
||||
|
||||
# Verify workflow was created
|
||||
workflow = db.session.query(Workflow).where(Workflow.app_id == new_app.id).first()
|
||||
assert workflow is not None
|
||||
assert workflow.tenant_id == app.tenant_id
|
||||
assert workflow.type == "chat"
|
||||
|
||||
def test_convert_to_workflow_without_app_model_config_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when app model config is missing.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing app model config
|
||||
- Correct exception type and message
|
||||
- Database state remains unchanged
|
||||
"""
|
||||
# Arrange: Create test data without app model config
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
mode=AppMode.CHAT.value,
|
||||
icon_type="emoji",
|
||||
icon="🤖",
|
||||
icon_background="#FF6B6B",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=10,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
workflow_converter = WorkflowConverter()
|
||||
|
||||
# Check initial state
|
||||
initial_workflow_count = db.session.query(Workflow).count()
|
||||
|
||||
with pytest.raises(ValueError, match="App model config is required"):
|
||||
workflow_converter.convert_to_workflow(
|
||||
app_model=app,
|
||||
account=account,
|
||||
name="Test Workflow App",
|
||||
icon_type="emoji",
|
||||
icon="🚀",
|
||||
icon_background="#4CAF50",
|
||||
)
|
||||
|
||||
# Verify database state remains unchanged
|
||||
# The workflow creation happens in convert_app_model_config_to_workflow
|
||||
# which is called before the app_model_config check, so we need to clean up
|
||||
db.session.rollback()
|
||||
final_workflow_count = db.session.query(Workflow).count()
|
||||
assert final_workflow_count == initial_workflow_count
|
||||
|
||||
def test_convert_app_model_config_to_workflow_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of app model config to workflow.
|
||||
|
||||
This test verifies:
|
||||
- Proper app model config to workflow conversion
|
||||
- Correct workflow graph structure
|
||||
- Proper node creation and configuration
|
||||
- Database state management
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant, account)
|
||||
|
||||
# Act: Execute the conversion
|
||||
workflow_converter = WorkflowConverter()
|
||||
workflow = workflow_converter.convert_app_model_config_to_workflow(
|
||||
app_model=app,
|
||||
app_model_config=app.app_model_config,
|
||||
account_id=account.id,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert workflow is not None
|
||||
assert workflow.tenant_id == app.tenant_id
|
||||
assert workflow.app_id == app.id
|
||||
assert workflow.type == "chat"
|
||||
assert workflow.version == Workflow.VERSION_DRAFT
|
||||
assert workflow.created_by == account.id
|
||||
|
||||
# Verify workflow graph structure
|
||||
graph = json.loads(workflow.graph)
|
||||
assert "nodes" in graph
|
||||
assert "edges" in graph
|
||||
assert len(graph["nodes"]) > 0
|
||||
assert len(graph["edges"]) > 0
|
||||
|
||||
# Verify start node exists
|
||||
start_node = next((node for node in graph["nodes"] if node["data"]["type"] == "start"), None)
|
||||
assert start_node is not None
|
||||
assert start_node["id"] == "start"
|
||||
|
||||
# Verify LLM node exists
|
||||
llm_node = next((node for node in graph["nodes"] if node["data"]["type"] == "llm"), None)
|
||||
assert llm_node is not None
|
||||
assert llm_node["id"] == "llm"
|
||||
|
||||
# Verify answer node exists for chat mode
|
||||
answer_node = next((node for node in graph["nodes"] if node["data"]["type"] == "answer"), None)
|
||||
assert answer_node is not None
|
||||
assert answer_node["id"] == "answer"
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(workflow)
|
||||
assert workflow.id is not None
|
||||
|
||||
# Verify features were set
|
||||
features = json.loads(workflow._features) if workflow._features else {}
|
||||
assert isinstance(features, dict)
|
||||
|
||||
def test_convert_to_start_node_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful conversion to start node.
|
||||
|
||||
This test verifies:
|
||||
- Proper start node creation with variables
|
||||
- Correct node structure and data
|
||||
- Variable encoding and formatting
|
||||
"""
|
||||
# Arrange: Create test variables
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="text_input",
|
||||
label="Text Input",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
),
|
||||
VariableEntity(
|
||||
variable="number_input",
|
||||
label="Number Input",
|
||||
type=VariableEntityType.NUMBER,
|
||||
),
|
||||
]
|
||||
|
||||
# Act: Execute the conversion
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(variables=variables)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert start_node is not None
|
||||
assert start_node["id"] == "start"
|
||||
assert start_node["data"]["title"] == "START"
|
||||
assert start_node["data"]["type"] == "start"
|
||||
assert len(start_node["data"]["variables"]) == 2
|
||||
|
||||
# Verify variable encoding
|
||||
first_variable = start_node["data"]["variables"][0]
|
||||
assert first_variable["variable"] == "text_input"
|
||||
assert first_variable["label"] == "Text Input"
|
||||
assert first_variable["type"] == "text-input"
|
||||
|
||||
second_variable = start_node["data"]["variables"][1]
|
||||
assert second_variable["variable"] == "number_input"
|
||||
assert second_variable["label"] == "Number Input"
|
||||
assert second_variable["type"] == "number"
|
||||
|
||||
def test_convert_to_http_request_node_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful conversion to HTTP request node.
|
||||
|
||||
This test verifies:
|
||||
- Proper HTTP request node creation
|
||||
- Correct API configuration and authorization
|
||||
- Code node creation for response parsing
|
||||
- External data variable mapping
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant, account)
|
||||
|
||||
# Create API based extension
|
||||
api_based_extension = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="Test API Extension",
|
||||
api_key="encrypted_api_key",
|
||||
api_endpoint="https://api.example.com/test",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(api_based_extension)
|
||||
db.session.commit()
|
||||
|
||||
# Mock encrypter
|
||||
mock_external_service_dependencies["encrypter"].decrypt_token.return_value = "decrypted_api_key"
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="user_input",
|
||||
label="User Input",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
)
|
||||
]
|
||||
|
||||
external_data_variables = [
|
||||
ExternalDataVariableEntity(
|
||||
variable="external_data", type="api", config={"api_based_extension_id": api_based_extension.id}
|
||||
)
|
||||
]
|
||||
|
||||
# Act: Execute the conversion
|
||||
workflow_converter = WorkflowConverter()
|
||||
nodes, external_data_variable_node_mapping = workflow_converter._convert_to_http_request_node(
|
||||
app_model=app,
|
||||
variables=variables,
|
||||
external_data_variables=external_data_variables,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert len(nodes) == 2 # HTTP request node + code node
|
||||
assert len(external_data_variable_node_mapping) == 1
|
||||
|
||||
# Verify HTTP request node
|
||||
http_request_node = nodes[0]
|
||||
assert http_request_node["data"]["type"] == "http-request"
|
||||
assert http_request_node["data"]["method"] == "post"
|
||||
assert http_request_node["data"]["url"] == api_based_extension.api_endpoint
|
||||
assert http_request_node["data"]["authorization"]["type"] == "api-key"
|
||||
assert http_request_node["data"]["authorization"]["config"]["type"] == "bearer"
|
||||
assert http_request_node["data"]["authorization"]["config"]["api_key"] == "decrypted_api_key"
|
||||
|
||||
# Verify code node
|
||||
code_node = nodes[1]
|
||||
assert code_node["data"]["type"] == "code"
|
||||
assert code_node["data"]["code_language"] == "python3"
|
||||
assert "response_json" in code_node["data"]["variables"][0]["variable"]
|
||||
|
||||
# Verify mapping
|
||||
assert external_data_variable_node_mapping["external_data"] == code_node["id"]
|
||||
|
||||
def test_convert_to_knowledge_retrieval_node_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion to knowledge retrieval node.
|
||||
|
||||
This test verifies:
|
||||
- Proper knowledge retrieval node creation
|
||||
- Correct dataset configuration
|
||||
- Model configuration integration
|
||||
- Query variable selector setup
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create dataset config
|
||||
dataset_config = DatasetEntity(
|
||||
dataset_ids=["dataset_1", "dataset_2"],
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
top_k=10,
|
||||
score_threshold=0.8,
|
||||
reranking_model={"provider": "cohere", "model": "rerank-v2"},
|
||||
reranking_enabled=True,
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ModelConfigEntity(
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
mode=LLMMode.CHAT.value,
|
||||
parameters={"temperature": 0.7},
|
||||
stop=[],
|
||||
)
|
||||
|
||||
# Act: Execute the conversion for advanced chat mode
|
||||
workflow_converter = WorkflowConverter()
|
||||
node = workflow_converter._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=AppMode.ADVANCED_CHAT,
|
||||
dataset_config=dataset_config,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert node is not None
|
||||
assert node["data"]["type"] == "knowledge-retrieval"
|
||||
assert node["data"]["title"] == "KNOWLEDGE RETRIEVAL"
|
||||
assert node["data"]["dataset_ids"] == ["dataset_1", "dataset_2"]
|
||||
assert node["data"]["retrieval_mode"] == "multiple"
|
||||
assert node["data"]["query_variable_selector"] == ["sys", "query"]
|
||||
|
||||
# Verify multiple retrieval config
|
||||
multiple_config = node["data"]["multiple_retrieval_config"]
|
||||
assert multiple_config["top_k"] == 10
|
||||
assert multiple_config["score_threshold"] == 0.8
|
||||
assert multiple_config["reranking_model"]["provider"] == "cohere"
|
||||
assert multiple_config["reranking_model"]["model"] == "rerank-v2"
|
||||
|
||||
# Verify single retrieval config is None for multiple strategy
|
||||
assert node["data"]["single_retrieval_config"] is None
|
||||
8
api/uv.lock
generated
8
api/uv.lock
generated
@@ -1639,7 +1639,7 @@ vdb = [
|
||||
{ name = "tidb-vector", specifier = "==0.0.9" },
|
||||
{ name = "upstash-vector", specifier = "==0.6.0" },
|
||||
{ name = "volcengine-compat", specifier = "~=1.0.0" },
|
||||
{ name = "weaviate-client", specifier = "~=3.24.0" },
|
||||
{ name = "weaviate-client", specifier = "~=3.26.7" },
|
||||
{ name = "xinference-client", specifier = "~=1.2.2" },
|
||||
]
|
||||
|
||||
@@ -6708,16 +6708,16 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "weaviate-client"
|
||||
version = "3.24.2"
|
||||
version = "3.26.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "authlib" },
|
||||
{ name = "requests" },
|
||||
{ name = "validators" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1f/c1/3285a21d8885f2b09aabb65edb9a8e062a35c2d7175e1bb024fa096582ab/weaviate-client-3.24.2.tar.gz", hash = "sha256:6914c48c9a7e5ad0be9399271f9cb85d6f59ab77476c6d4e56a3925bf149edaa", size = 199332, upload-time = "2023-10-04T08:37:54.26Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f8/2e/9588bae34c1d67d05ccc07d74a4f5d73cce342b916f79ab3a9114c6607bb/weaviate_client-3.26.7.tar.gz", hash = "sha256:ea538437800abc6edba21acf213accaf8a82065584ee8b914bae4a4ad4ef6b70", size = 210480, upload-time = "2024-08-15T13:27:02.431Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/98/3136d05f93e30cf29e1db280eaadf766df18d812dfe7994bcced653b2340/weaviate_client-3.24.2-py3-none-any.whl", hash = "sha256:bc50ca5fcebcd48de0d00f66700b0cf7c31a97c4cd3d29b4036d77c5d1d9479b", size = 107968, upload-time = "2023-10-04T08:37:52.511Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/95/fb326052bc1d73cb3c19fcfaf6ebb477f896af68de07eaa1337e27ee57fa/weaviate_client-3.26.7-py3-none-any.whl", hash = "sha256:48b8d4b71df881b4e5e15964d7ac339434338ccee73779e3af7eab698a92083b", size = 120051, upload-time = "2024-08-15T13:27:00.212Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -5,5 +5,12 @@ set -x
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/.."
|
||||
|
||||
# Get the path argument if provided
|
||||
PATH_TO_CHECK="$1"
|
||||
|
||||
# run basedpyright checks
|
||||
uv run --directory api --dev basedpyright
|
||||
if [ -n "$PATH_TO_CHECK" ]; then
|
||||
uv run --directory api --dev basedpyright "$PATH_TO_CHECK"
|
||||
else
|
||||
uv run --directory api --dev basedpyright
|
||||
fi
|
||||
|
||||
@@ -54,7 +54,7 @@ const TypeSelector: FC<Props> = ({
|
||||
<InputVarTypeIcon type={selectedItem?.value as InputVarType} className='size-4 shrink-0 text-text-secondary' />
|
||||
<span
|
||||
className={`
|
||||
ml-1.5 ${!selectedItem?.name && 'text-components-input-text-placeholder'}
|
||||
ml-1.5 text-components-input-text-filled ${!selectedItem?.name && 'text-components-input-text-placeholder'}
|
||||
`}
|
||||
>
|
||||
{selectedItem?.name}
|
||||
|
||||
@@ -11,7 +11,7 @@ import { useRouter } from 'next/navigation'
|
||||
import { debounce } from 'lodash-es'
|
||||
import cn from '@/utils/classnames'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import { AiText, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication'
|
||||
import { AiText, BubbleTextMod, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication'
|
||||
import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
@@ -90,7 +90,7 @@ const NavSelector = ({ curNav, navs, createText, isApp, onCreate, onLoadmore }:
|
||||
'absolute -bottom-0.5 -right-0.5 h-3.5 w-3.5 rounded border-[0.5px] border-[rgba(0,0,0,0.02)] bg-white p-0.5 shadow-sm',
|
||||
)}>
|
||||
{nav.mode === 'advanced-chat' && (
|
||||
<ChatBot className='h-2.5 w-2.5 text-[#1570EF]' />
|
||||
<BubbleTextMod className='h-2.5 w-2.5 text-[#1570EF]' />
|
||||
)}
|
||||
{nav.mode === 'agent-chat' && (
|
||||
<CuteRobot className='h-2.5 w-2.5 text-indigo-600' />
|
||||
|
||||
@@ -97,7 +97,8 @@ const ExportImage: FC = () => {
|
||||
style: {
|
||||
width: `${contentWidth}px`,
|
||||
height: `${contentHeight}px`,
|
||||
transform: `translate(${padding - nodesBounds.x}px, ${padding - nodesBounds.y}px) scale(${zoom})`,
|
||||
transform: `translate(${padding - nodesBounds.x}px, ${padding - nodesBounds.y}px)`,
|
||||
transformOrigin: 'top left',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user