ruff check preview (#25653)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Asuka Minato
2025-09-16 13:58:12 +09:00
committed by GitHub
parent a0c7713494
commit bdd85b36a4
42 changed files with 224 additions and 342 deletions

View File

@@ -22,7 +22,7 @@ jobs:
# Fix lint errors
uv run ruff check --fix .
# Format code
uv run ruff format .
uv run ruff format ..
- name: ast-grep
run: |

View File

@@ -5,7 +5,7 @@ line-length = 120
quote-style = "double"
[lint]
preview = false
preview = true
select = [
"B", # flake8-bugbear rules
"C4", # flake8-comprehensions
@@ -65,6 +65,7 @@ ignore = [
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
"B901", # allow return in yield
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict

View File

@@ -1,6 +1,7 @@
import base64
import json
import logging
import operator
import secrets
from typing import Any
@@ -953,7 +954,7 @@ def clear_orphaned_file_records(force: bool):
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
query = "DELETE FROM message_files WHERE id IN :ids"
with db.engine.begin() as conn:
conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)})
click.echo(
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
)
@@ -1307,7 +1308,7 @@ def cleanup_orphaned_draft_variables(
if dry_run:
logger.info("DRY RUN: Would delete the following:")
for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
for app_id, count in sorted(stats["orphaned_by_app"].items(), key=operator.itemgetter(1), reverse=True)[
:10
]: # Show top 10
logger.info(" App %s: %s variables", app_id, count)

View File

@@ -355,8 +355,8 @@ class AliyunDataTrace(BaseTraceInstance):
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: node_execution.title,
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=self.get_workflow_node_status(node_execution),
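The `value if value else default` → `value or default` rewrites repeated across the tracing files rely on ordinary truthiness, so the two spellings are interchangeable (including for the numeric `or 30` schedule default further down); a quick check:

inputs = None
assert (inputs if inputs else {}) == (inputs or {}) == {}
interval = 0
# both forms treat 0 as falsy, so the fallback fires either way
assert (interval if interval else 30) == (interval or 30) == 30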

View File

@@ -144,13 +144,13 @@ class LangFuseDataTrace(BaseTraceInstance):
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
execution_metadata = node_execution.metadata or {}
metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update(
{
@@ -163,7 +163,7 @@ class LangFuseDataTrace(BaseTraceInstance):
"status": status,
}
)
process_data = node_execution.process_data if node_execution.process_data else {}
process_data = node_execution.process_data or {}
model_provider = process_data.get("model_provider", None)
model_name = process_data.get("model_name", None)
if model_provider is not None and model_name is not None:

View File

@@ -167,13 +167,13 @@ class LangSmithDataTrace(BaseTraceInstance):
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
execution_metadata = node_execution.metadata or {}
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
metadata = {str(key): value for key, value in execution_metadata.items()}
metadata.update(
@@ -188,7 +188,7 @@ class LangSmithDataTrace(BaseTraceInstance):
}
)
process_data = node_execution.process_data if node_execution.process_data else {}
process_data = node_execution.process_data or {}
if process_data and process_data.get("model_mode") == "chat":
run_type = LangSmithRunType.llm

View File

@@ -182,13 +182,13 @@ class OpikDataTrace(BaseTraceInstance):
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
execution_metadata = node_execution.metadata or {}
metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update(
{
@@ -202,7 +202,7 @@ class OpikDataTrace(BaseTraceInstance):
}
)
process_data = node_execution.process_data if node_execution.process_data else {}
process_data = node_execution.process_data or {}
provider = None
model = None

View File

@@ -1,3 +1,4 @@
import collections
import json
import logging
import os
@@ -40,7 +41,7 @@ from tasks.ops_trace_task import process_trace_tasks
logger = logging.getLogger(__name__)
class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
def __getitem__(self, provider: str) -> dict[str, Any]:
match provider:
case TracingProviderEnum.LANGFUSE:
@@ -121,7 +122,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
raise KeyError(f"Unsupported tracing provider: {provider}")
provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
provider_config_map = OpsTraceProviderConfigMap()
class OpsTraceManager:
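The switch from subclassing dict to collections.UserDict is more than cosmetic for callers that use .get() or iteration: dict's built-in methods bypass an overridden __getitem__, while UserDict routes them through it. A small sketch with a hypothetical override:

import collections

class LoudDict(dict):
    def __getitem__(self, key):
        return f"<{key}>"

class LoudUserDict(collections.UserDict):
    def __getitem__(self, key):
        return f"<{key}>"

assert LoudDict(a=1).get("a") == 1          # dict.get ignores the override
assert LoudUserDict(a=1).get("a") == "<a>"  # UserDict.get goes through __getitem__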

View File

@@ -169,13 +169,13 @@ class WeaveDataTrace(BaseTraceInstance):
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
execution_metadata = node_execution.metadata or {}
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
attributes = {str(k): v for k, v in execution_metadata.items()}
attributes.update(
@@ -190,7 +190,7 @@ class WeaveDataTrace(BaseTraceInstance):
}
)
process_data = node_execution.process_data if node_execution.process_data else {}
process_data = node_execution.process_data or {}
if process_data and process_data.get("model_mode") == "chat":
attributes.update(
{

View File

@@ -641,7 +641,7 @@ class ClickzettaVector(BaseVector):
for doc, embedding in zip(batch_docs, batch_embeddings):
# Optimized: minimal checks for common case, fallback for edge cases
metadata = doc.metadata if doc.metadata else {}
metadata = doc.metadata or {}
if not isinstance(metadata, dict):
metadata = {}

View File

@@ -103,7 +103,7 @@ class MatrixoneVector(BaseVector):
self.client = self._get_client(len(embeddings[0]), True)
assert self.client is not None
ids = []
for _, doc in enumerate(documents):
for doc in documents:
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
ids.append(doc_id)
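The enumerate() removals here and in the tests below drop a loop counter that was never used; the iteration itself is unchanged:

documents = ["doc-a", "doc-b"]
assert [doc for _, doc in enumerate(documents)] == [doc for doc in documents]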

View File

@@ -104,7 +104,7 @@ class OpenSearchVector(BaseVector):
},
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
if self._client_config.aws_service not in ["aoss"]:
if self._client_config.aws_service != "aoss":
action["_id"] = uuid4().hex
actions.append(action)

View File

@@ -159,7 +159,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
else None
)
db_model.status = domain_model.status
db_model.error = domain_model.error_message if domain_model.error_message else None
db_model.error = domain_model.error_message or None
db_model.total_tokens = domain_model.total_tokens
db_model.total_steps = domain_model.total_steps
db_model.exceptions_count = domain_model.exceptions_count

View File

@@ -320,7 +320,7 @@ class AgentNode(BaseNode):
memory = self._fetch_memory(model_instance)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size if node_data.memory.window.size else None
message_limit=node_data.memory.window.size or None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages

View File

@@ -141,9 +141,7 @@ def init_app(app: DifyApp) -> Celery:
imports.append("schedule.queue_monitor_task")
beat_schedule["datasets-queue-monitor"] = {
"task": "schedule.queue_monitor_task.queue_monitor_task",
"schedule": timedelta(
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
),
"schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
}
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
imports.append("schedule.check_upgradable_plugin_task")

View File

@@ -7,6 +7,7 @@ Supports complete lifecycle management for knowledge base files.
import json
import logging
import operator
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import StrEnum, auto
@@ -356,7 +357,7 @@ class FileLifecycleManager:
# Cleanup old versions for each file
for base_filename, versions in file_versions.items():
# Sort by version number
versions.sort(key=lambda x: x[0], reverse=True)
versions.sort(key=operator.itemgetter(0), reverse=True)
# Keep the newest max_versions versions, delete the rest
if len(versions) > max_versions:

View File

@@ -1,3 +1,4 @@
import operator
import traceback
import typing
@@ -118,7 +119,7 @@ def process_tenant_plugin_autoupgrade_check_task(
current_version = version
latest_version = manifest.latest_version
def fix_only_checker(latest_version, current_version):
def fix_only_checker(latest_version: str, current_version: str):
latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
current_version_tuple = tuple(int(val) for val in current_version.split("."))
@@ -130,8 +131,7 @@ def process_tenant_plugin_autoupgrade_check_task(
return False
version_checker = {
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
current_version: latest_version != current_version,
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne,
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
}

View File

@@ -3,6 +3,7 @@
import os
import tempfile
import unittest
from pathlib import Path
import pytest
@@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
# Test download
with tempfile.NamedTemporaryFile() as temp_file:
storage.download(test_filename, temp_file.name)
with open(temp_file.name, "rb") as f:
downloaded_content = f.read()
downloaded_content = Path(temp_file.name).read_bytes()
assert downloaded_content == test_content
# Test scan
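The open()/read()/write() pairs replaced in this and the following test files map onto pathlib one-liners; a minimal sketch using a throwaway temp directory:

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    p = Path(tmp) / "sample.txt"
    p.write_text("hello", encoding="utf-8")  # replaces: with open(p, "w", encoding="utf-8") as f: f.write(...)
    assert p.read_bytes() == b"hello"        # replaces: with open(p, "rb") as f: f.read()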

View File

@@ -12,6 +12,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances.
import uuid
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
@@ -276,8 +277,7 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage = mock_external_service_dependencies["storage"]
def mock_download(key, file_path):
with open(file_path, "w", encoding="utf-8") as f:
f.write(csv_content)
Path(file_path).write_text(csv_content, encoding="utf-8")
mock_storage.download.side_effect = mock_download
@@ -505,7 +505,7 @@ class TestBatchCreateSegmentToIndexTask:
db.session.commit()
# Test each unavailable document
for i, document in enumerate(test_cases):
for document in test_cases:
job_id = str(uuid.uuid4())
batch_create_segment_to_index_task(
job_id=job_id,
@@ -601,8 +601,7 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage = mock_external_service_dependencies["storage"]
def mock_download(key, file_path):
with open(file_path, "w", encoding="utf-8") as f:
f.write(empty_csv_content)
Path(file_path).write_text(empty_csv_content, encoding="utf-8")
mock_storage.download.side_effect = mock_download
@@ -684,8 +683,7 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage = mock_external_service_dependencies["storage"]
def mock_download(key, file_path):
with open(file_path, "w", encoding="utf-8") as f:
f.write(csv_content)
Path(file_path).write_text(csv_content, encoding="utf-8")
mock_storage.download.side_effect = mock_download

View File

@@ -362,7 +362,7 @@ class TestCleanDatasetTask:
# Create segments for each document
segments = []
for i, document in enumerate(documents):
for document in documents:
segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document)
segments.append(segment)

View File

@@ -15,7 +15,7 @@ class FakeResponse:
self.status_code = status_code
self.headers = headers or {}
self.content = content
self.text = text if text else content.decode("utf-8", errors="ignore")
self.text = text or content.decode("utf-8", errors="ignore")
# ---------------------------

View File

@@ -1,3 +1,4 @@
from pathlib import Path
from unittest.mock import Mock, create_autospec, patch
import pytest
@@ -146,19 +147,17 @@ class TestMetadataBugCompleteValidation:
# Console API create
console_create_file = "api/controllers/console/datasets/metadata.py"
if os.path.exists(console_create_file):
with open(console_create_file) as f:
content = f.read()
# Should contain nullable=False, not nullable=True
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
content = Path(console_create_file).read_text()
# Should contain nullable=False, not nullable=True
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
# Service API create
service_create_file = "api/controllers/service_api/dataset/metadata.py"
if os.path.exists(service_create_file):
with open(service_create_file) as f:
content = f.read()
# Should contain nullable=False, not nullable=True
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
assert "nullable=True" not in create_api_section
content = Path(service_create_file).read_text()
# Should contain nullable=False, not nullable=True
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
assert "nullable=True" not in create_api_section
class TestMetadataValidationSummary:

View File

@@ -1,6 +1,7 @@
from pathlib import Path
import yaml # type: ignore
from dotenv import dotenv_values
from pathlib import Path
BASE_API_AND_DOCKER_CONFIG_SET_DIFF = {
"APP_MAX_EXECUTION_TIME",
@@ -98,23 +99,15 @@ with open(Path("docker") / Path("docker-compose.yaml")) as f:
def test_yaml_config():
# python set == operator is used to compare two sets
DIFF_API_WITH_DOCKER = (
API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
)
DIFF_API_WITH_DOCKER = API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
if DIFF_API_WITH_DOCKER:
print(
f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}"
)
print(f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}")
raise Exception("API and Docker config sets are different")
DIFF_API_WITH_DOCKER_COMPOSE = (
API_CONFIG_SET
- DOCKER_COMPOSE_CONFIG_SET
- BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
API_CONFIG_SET - DOCKER_COMPOSE_CONFIG_SET - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
)
if DIFF_API_WITH_DOCKER_COMPOSE:
print(
f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}"
)
print(f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}")
raise Exception("API and Docker Compose config sets are different")
print("All tests passed!")

View File

@@ -51,9 +51,7 @@ def cleanup() -> None:
if sys.stdin.isatty():
log.separator()
log.warning("This action cannot be undone!")
confirmation = input(
"Are you sure you want to remove all config and report files? (yes/no): "
)
confirmation = input("Are you sure you want to remove all config and report files? (yes/no): ")
if confirmation.lower() not in ["yes", "y"]:
log.error("Cleanup cancelled.")

View File

@@ -3,4 +3,4 @@
from .config_helper import config_helper
from .logger_helper import Logger, ProgressLogger
__all__ = ["config_helper", "Logger", "ProgressLogger"]
__all__ = ["Logger", "ProgressLogger", "config_helper"]

View File

@@ -65,9 +65,9 @@ class ConfigHelper:
return None
try:
with open(config_path, "r") as f:
with open(config_path) as f:
return json.load(f)
except (json.JSONDecodeError, IOError) as e:
except (OSError, json.JSONDecodeError) as e:
print(f"❌ Error reading {filename}: {e}")
return None
@@ -101,7 +101,7 @@ class ConfigHelper:
with open(config_path, "w") as f:
json.dump(data, f, indent=2)
return True
except IOError as e:
except OSError as e:
print(f"❌ Error writing {filename}: {e}")
return False
@@ -133,7 +133,7 @@ class ConfigHelper:
try:
config_path.unlink()
return True
except IOError as e:
except OSError as e:
print(f"❌ Error deleting {filename}: {e}")
return False
@@ -148,9 +148,9 @@ class ConfigHelper:
return None
try:
with open(state_path, "r") as f:
with open(state_path) as f:
return json.load(f)
except (json.JSONDecodeError, IOError) as e:
except (OSError, json.JSONDecodeError) as e:
print(f"❌ Error reading {self.state_file}: {e}")
return None
@@ -170,7 +170,7 @@ class ConfigHelper:
with open(state_path, "w") as f:
json.dump(data, f, indent=2)
return True
except IOError as e:
except OSError as e:
print(f"❌ Error writing {self.state_file}: {e}")
return False
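Two mechanical cleanups in this helper: "r" is already the default mode for open(), and IOError has been an alias of OSError since Python 3.3, so catching (OSError, json.JSONDecodeError) loses nothing; a quick check:

assert IOError is OSError  # IOError is only kept as an alias in Python 3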

View File

@@ -159,9 +159,7 @@ class ProgressLogger:
if self.logger.use_colors:
progress_bar = self._create_progress_bar()
print(
f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}"
)
print(f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}")
self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)")
else:
print(f"\n[Step {self.current_step}/{self.total_steps}]")

View File

@@ -6,8 +6,7 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
from common import config_helper
from common import Logger
from common import Logger, config_helper
def configure_openai_plugin() -> None:
@@ -72,29 +71,19 @@ def configure_openai_plugin() -> None:
if response.status_code == 200:
log.success("OpenAI plugin configured successfully!")
log.key_value(
"API Base", config_payload["credentials"]["openai_api_base"]
)
log.key_value(
"API Key", config_payload["credentials"]["openai_api_key"]
)
log.key_value("API Base", config_payload["credentials"]["openai_api_base"])
log.key_value("API Key", config_payload["credentials"]["openai_api_key"])
elif response.status_code == 201:
log.success("OpenAI plugin credentials created successfully!")
log.key_value(
"API Base", config_payload["credentials"]["openai_api_base"]
)
log.key_value(
"API Key", config_payload["credentials"]["openai_api_key"]
)
log.key_value("API Base", config_payload["credentials"]["openai_api_base"])
log.key_value("API Key", config_payload["credentials"]["openai_api_key"])
elif response.status_code == 401:
log.error("Configuration failed: Unauthorized")
log.info("Token may have expired. Please run login_admin.py again")
else:
log.error(
f"Configuration failed with status code: {response.status_code}"
)
log.error(f"Configuration failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
except httpx.ConnectError:

View File

@@ -5,10 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import json
from common import config_helper
from common import Logger
import httpx
from common import Logger, config_helper
def create_api_key() -> None:
@@ -90,9 +90,7 @@ def create_api_key() -> None:
}
if config_helper.write_config("api_key_config", api_key_config):
log.info(
f"API key saved to: {config_helper.get_config_path('benchmark_state')}"
)
log.info(f"API key saved to: {config_helper.get_config_path('benchmark_state')}")
else:
log.error("No API token received")
log.debug(f"Response: {json.dumps(response_data, indent=2)}")
@@ -101,9 +99,7 @@ def create_api_key() -> None:
log.error("API key creation failed: Unauthorized")
log.info("Token may have expired. Please run login_admin.py again")
else:
log.error(
f"API key creation failed with status code: {response.status_code}"
)
log.error(f"API key creation failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
except httpx.ConnectError:

View File

@@ -5,9 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import json
from common import config_helper, Logger
import httpx
from common import Logger, config_helper
def import_workflow_app() -> None:
@@ -30,7 +31,7 @@ def import_workflow_app() -> None:
log.error(f"DSL file not found: {dsl_path}")
return
with open(dsl_path, "r") as f:
with open(dsl_path) as f:
yaml_content = f.read()
log.step("Importing workflow app from DSL...")
@@ -86,9 +87,7 @@ def import_workflow_app() -> None:
log.success("Workflow app imported successfully!")
log.key_value("App ID", app_id)
log.key_value("App Mode", response_data.get("app_mode"))
log.key_value(
"DSL Version", response_data.get("imported_dsl_version")
)
log.key_value("DSL Version", response_data.get("imported_dsl_version"))
# Save app_id to config
app_config = {
@@ -99,9 +98,7 @@ def import_workflow_app() -> None:
}
if config_helper.write_config("app_config", app_config):
log.info(
f"App config saved to: {config_helper.get_config_path('benchmark_state')}"
)
log.info(f"App config saved to: {config_helper.get_config_path('benchmark_state')}")
else:
log.error("Import completed but no app_id received")
log.debug(f"Response: {json.dumps(response_data, indent=2)}")

View File

@@ -5,10 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import time
from common import config_helper
from common import Logger
import httpx
from common import Logger, config_helper
def install_openai_plugin() -> None:
@@ -28,9 +28,7 @@ def install_openai_plugin() -> None:
# API endpoint for plugin installation
base_url = "http://localhost:5001"
install_endpoint = (
f"{base_url}/console/api/workspaces/current/plugin/install/marketplace"
)
install_endpoint = f"{base_url}/console/api/workspaces/current/plugin/install/marketplace"
# Plugin identifier
plugin_payload = {
@@ -83,9 +81,7 @@ def install_openai_plugin() -> None:
log.info("Polling for task completion...")
# Poll for task completion
task_endpoint = (
f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}"
)
task_endpoint = f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}"
max_attempts = 30 # 30 attempts with 2 second delay = 60 seconds max
attempt = 0
@@ -131,9 +127,7 @@ def install_openai_plugin() -> None:
plugins = task_info.get("plugins", [])
if plugins:
for plugin in plugins:
log.list_item(
f"{plugin.get('plugin_id')}: {plugin.get('message')}"
)
log.list_item(f"{plugin.get('plugin_id')}: {plugin.get('message')}")
break
# Continue polling if status is "pending" or other
@@ -149,9 +143,7 @@ def install_openai_plugin() -> None:
log.warning("Plugin may already be installed")
log.debug(f"Response: {response.text}")
else:
log.error(
f"Installation failed with status code: {response.status_code}"
)
log.error(f"Installation failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
except httpx.ConnectError:

View File

@@ -5,10 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import json
from common import config_helper
from common import Logger
import httpx
from common import Logger, config_helper
def login_admin() -> None:
@@ -77,16 +77,10 @@ def login_admin() -> None:
# Save token config
if config_helper.write_config("token_config", token_config):
log.info(
f"Token saved to: {config_helper.get_config_path('benchmark_state')}"
)
log.info(f"Token saved to: {config_helper.get_config_path('benchmark_state')}")
# Show truncated token for verification
token_display = (
f"{access_token[:20]}..."
if len(access_token) > 20
else "Token saved"
)
token_display = f"{access_token[:20]}..." if len(access_token) > 20 else "Token saved"
log.key_value("Access token", token_display)
elif response.status_code == 401:

View File

@@ -3,8 +3,10 @@
import json
import time
import uuid
from typing import Any, Iterator
from flask import Flask, request, jsonify, Response
from collections.abc import Iterator
from typing import Any
from flask import Flask, Response, jsonify, request
app = Flask(__name__)
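Importing Iterator from typing is deprecated in favour of collections.abc; a minimal sketch of the split imports, with a purely illustrative streaming generator:

from collections.abc import Iterator
from typing import Any

def stream_events(payload: dict[str, Any]) -> Iterator[str]:
    # only the import location changes; the annotation is used the same way
    yield f"data: {payload}\n\n"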

View File

@@ -5,10 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import json
from common import config_helper
from common import Logger
import httpx
from common import Logger, config_helper
def publish_workflow() -> None:
@@ -79,9 +79,7 @@ def publish_workflow() -> None:
try:
response_data = response.json()
if response_data:
log.debug(
f"Response: {json.dumps(response_data, indent=2)}"
)
log.debug(f"Response: {json.dumps(response_data, indent=2)}")
except json.JSONDecodeError:
# Response might be empty or non-JSON
pass
@@ -93,9 +91,7 @@ def publish_workflow() -> None:
log.error("Workflow publish failed: App not found")
log.info("Make sure the app was imported successfully")
else:
log.error(
f"Workflow publish failed with status code: {response.status_code}"
)
log.error(f"Workflow publish failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
except httpx.ConnectError:

View File

@@ -5,9 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import json
from common import config_helper, Logger
import httpx
from common import Logger, config_helper
def run_workflow(question: str = "fake question", streaming: bool = True) -> None:
@@ -70,9 +71,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non
event = data.get("event")
if event == "workflow_started":
log.progress(
f"Workflow started: {data.get('data', {}).get('id')}"
)
log.progress(f"Workflow started: {data.get('data', {}).get('id')}")
elif event == "node_started":
node_data = data.get("data", {})
log.progress(
@@ -116,9 +115,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non
# Some lines might not be JSON
pass
else:
log.error(
f"Workflow run failed with status code: {response.status_code}"
)
log.error(f"Workflow run failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
else:
# Handle blocking response
@@ -142,9 +139,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non
log.info("📤 Final Answer:")
log.info(outputs.get("answer"), indent=2)
else:
log.error(
f"Workflow run failed with status code: {response.status_code}"
)
log.error(f"Workflow run failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
except httpx.ConnectError:

View File

@@ -6,7 +6,7 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
from common import config_helper, Logger
from common import Logger, config_helper
def setup_admin_account() -> None:
@@ -24,9 +24,7 @@ def setup_admin_account() -> None:
# Save credentials to config file
if config_helper.write_config("admin_config", admin_config):
log.info(
f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}"
)
log.info(f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}")
# API setup endpoint
base_url = "http://localhost:5001"
@@ -56,9 +54,7 @@ def setup_admin_account() -> None:
log.key_value("Username", admin_config["username"])
elif response.status_code == 400:
log.warning(
"Setup may have already been completed or invalid data provided"
)
log.warning("Setup may have already been completed or invalid data provided")
log.debug(f"Response: {response.text}")
else:
log.error(f"Setup failed with status code: {response.status_code}")

View File

@@ -1,9 +1,9 @@
#!/usr/bin/env python3
import socket
import subprocess
import sys
import time
import socket
from pathlib import Path
from common import Logger, ProgressLogger
@@ -93,9 +93,7 @@ def main() -> None:
if retry.lower() in ["yes", "y"]:
return main() # Recursively call main to check again
else:
print(
"❌ Setup cancelled. Please start the required services and try again."
)
print("❌ Setup cancelled. Please start the required services and try again.")
sys.exit(1)
log.success("All required services are running!")

View File

@@ -7,29 +7,28 @@ measuring key metrics like connection rate, event throughput, and time to first
"""
import json
import time
import logging
import os
import random
import statistics
import sys
import threading
import os
import logging
import statistics
from pathlib import Path
import time
from collections import deque
from dataclasses import asdict, dataclass
from datetime import datetime
from dataclasses import dataclass, asdict
from locust import HttpUser, task, between, events, constant
from typing import TypedDict, Literal, TypeAlias
from pathlib import Path
from typing import Literal, TypeAlias, TypedDict
import requests.exceptions
from locust import HttpUser, between, constant, events, task
# Add the stress-test directory to path to import common modules
sys.path.insert(0, str(Path(__file__).parent))
from common.config_helper import ConfigHelper # type: ignore[import-not-found]
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Configuration from environment
@@ -54,6 +53,7 @@ ErrorType: TypeAlias = Literal[
class ErrorCounts(TypedDict):
"""Error count tracking"""
connection_error: int
timeout: int
invalid_json: int
@@ -65,6 +65,7 @@ class ErrorCounts(TypedDict):
class SSEEvent(TypedDict):
"""Server-Sent Event structure"""
data: str
event: str
id: str | None
@@ -72,11 +73,13 @@ class SSEEvent(TypedDict):
class WorkflowInputs(TypedDict):
"""Workflow input structure"""
question: str
class WorkflowRequestData(TypedDict):
"""Workflow request payload"""
inputs: WorkflowInputs
response_mode: Literal["streaming"]
user: str
@@ -84,6 +87,7 @@ class WorkflowRequestData(TypedDict):
class ParsedEventData(TypedDict, total=False):
"""Parsed event data from SSE stream"""
event: str
task_id: str
workflow_run_id: str
@@ -93,6 +97,7 @@ class ParsedEventData(TypedDict, total=False):
class LocustStats(TypedDict):
"""Locust statistics structure"""
total_requests: int
total_failures: int
avg_response_time: float
@@ -102,6 +107,7 @@ class LocustStats(TypedDict):
class ReportData(TypedDict):
"""JSON report structure"""
timestamp: str
duration_seconds: float
metrics: dict[str, object] # Metrics as dict for JSON serialization
@@ -154,7 +160,7 @@ class MetricsTracker:
self.total_connections = 0
self.total_events = 0
self.start_time = time.time()
# Enhanced metrics with memory limits
self.max_samples = 10000 # Prevent unbounded growth
self.ttfe_samples: deque[float] = deque(maxlen=self.max_samples)
@@ -233,9 +239,7 @@ class MetricsTracker:
max_ttfe = max(self.ttfe_samples)
p50_ttfe = statistics.median(self.ttfe_samples)
if len(self.ttfe_samples) >= 2:
quantiles = statistics.quantiles(
self.ttfe_samples, n=20, method="inclusive"
)
quantiles = statistics.quantiles(self.ttfe_samples, n=20, method="inclusive")
p95_ttfe = quantiles[18] # 19th of 19 quantiles = 95th percentile
else:
p95_ttfe = max_ttfe
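The p95 computation above works because statistics.quantiles(..., n=20) returns the 19 cut points at 5% steps, so index 18 is the 95th percentile; a small check on synthetic data:

import statistics

samples = list(range(1, 101))  # 1..100, purely illustrative
cuts = statistics.quantiles(samples, n=20, method="inclusive")
assert len(cuts) == 19
assert abs(cuts[18] - 95.05) < 1e-9  # 19th cut point = 95th percentile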
@@ -255,9 +259,7 @@ class MetricsTracker:
if durations
else 0
)
events_per_stream_avg = (
statistics.mean(events_per_stream) if events_per_stream else 0
)
events_per_stream_avg = statistics.mean(events_per_stream) if events_per_stream else 0
# Calculate inter-event latency statistics
all_inter_event_times = []
@@ -268,32 +270,20 @@ class MetricsTracker:
inter_event_latency_avg = statistics.mean(all_inter_event_times)
inter_event_latency_p50 = statistics.median(all_inter_event_times)
inter_event_latency_p95 = (
statistics.quantiles(
all_inter_event_times, n=20, method="inclusive"
)[18]
statistics.quantiles(all_inter_event_times, n=20, method="inclusive")[18]
if len(all_inter_event_times) >= 2
else max(all_inter_event_times)
)
else:
inter_event_latency_avg = inter_event_latency_p50 = (
inter_event_latency_p95
) = 0
inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0
else:
stream_duration_avg = stream_duration_p50 = stream_duration_p95 = (
events_per_stream_avg
) = 0
inter_event_latency_avg = inter_event_latency_p50 = (
inter_event_latency_p95
) = 0
stream_duration_avg = stream_duration_p50 = stream_duration_p95 = events_per_stream_avg = 0
inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0
# Also calculate overall average rates
total_elapsed = current_time - self.start_time
overall_conn_rate = (
self.total_connections / total_elapsed if total_elapsed > 0 else 0
)
overall_event_rate = (
self.total_events / total_elapsed if total_elapsed > 0 else 0
)
overall_conn_rate = self.total_connections / total_elapsed if total_elapsed > 0 else 0
overall_event_rate = self.total_events / total_elapsed if total_elapsed > 0 else 0
return MetricsSnapshot(
active_connections=self.active_connections,
@@ -389,7 +379,7 @@ class DifyWorkflowUser(HttpUser):
# Load questions from file or use defaults
if QUESTIONS_FILE and os.path.exists(QUESTIONS_FILE):
with open(QUESTIONS_FILE, "r") as f:
with open(QUESTIONS_FILE) as f:
self.questions = [line.strip() for line in f if line.strip()]
else:
self.questions = [
@@ -451,18 +441,13 @@ class DifyWorkflowUser(HttpUser):
try:
# Validate response
if response.status_code >= 400:
error_type: ErrorType = (
"http_4xx" if response.status_code < 500 else "http_5xx"
)
error_type: ErrorType = "http_4xx" if response.status_code < 500 else "http_5xx"
metrics.record_error(error_type)
response.failure(f"HTTP {response.status_code}")
return
content_type = response.headers.get("Content-Type", "")
if (
"text/event-stream" not in content_type
and "application/json" not in content_type
):
if "text/event-stream" not in content_type and "application/json" not in content_type:
logger.error(f"Expected text/event-stream, got: {content_type}")
metrics.record_error("invalid_response")
response.failure(f"Invalid content type: {content_type}")
@@ -473,10 +458,13 @@ class DifyWorkflowUser(HttpUser):
for line in response.iter_lines(decode_unicode=True):
# Check if runner is stopping
if getattr(self.environment.runner, 'state', '') in ('stopping', 'stopped'):
if getattr(self.environment.runner, "state", "") in (
"stopping",
"stopped",
):
logger.debug("Runner stopping, breaking streaming loop")
break
if line is not None:
bytes_received += len(line.encode("utf-8"))
@@ -489,9 +477,7 @@ class DifyWorkflowUser(HttpUser):
# Track inter-event timing
if last_event_time:
inter_event_times.append(
(current_time - last_event_time) * 1000
)
inter_event_times.append((current_time - last_event_time) * 1000)
last_event_time = current_time
if first_event_time is None:
@@ -512,15 +498,11 @@ class DifyWorkflowUser(HttpUser):
parsed_event: ParsedEventData = json.loads(event_data)
# Check for terminal events
if parsed_event.get("event") in TERMINAL_EVENTS:
logger.debug(
f"Received terminal event: {parsed_event.get('event')}"
)
logger.debug(f"Received terminal event: {parsed_event.get('event')}")
request_success = True
break
except json.JSONDecodeError as e:
logger.debug(
f"JSON decode error: {e} for data: {event_data[:100]}"
)
logger.debug(f"JSON decode error: {e} for data: {event_data[:100]}")
metrics.record_error("invalid_json")
except Exception as e:
@@ -583,16 +565,18 @@ def on_test_start(environment: object, **kwargs: object) -> None:
# Periodic stats reporting
def report_stats() -> None:
if not hasattr(environment, 'runner'):
if not hasattr(environment, "runner"):
return
runner = environment.runner
while hasattr(runner, 'state') and runner.state not in ["stopped", "stopping"]:
while hasattr(runner, "state") and runner.state not in ["stopped", "stopping"]:
time.sleep(5) # Report every 5 seconds
if hasattr(runner, 'state') and runner.state == "running":
if hasattr(runner, "state") and runner.state == "running":
stats = metrics.get_stats()
# Only log on master node in distributed mode
is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, 'runner') else True
is_master = (
not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True
)
if is_master:
# Clear previous lines and show updated stats
logger.info("\n" + "=" * 80)
@@ -623,15 +607,15 @@ def on_test_start(environment: object, **kwargs: object) -> None:
logger.info(
f"{'(TTFE in ms)':<25} {stats.ttfe_avg:>15.1f} {stats.ttfe_p50:>10.1f} {stats.ttfe_p95:>10.1f} {stats.ttfe_min:>10.1f} {stats.ttfe_max:>10.1f}"
)
logger.info(f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)")
logger.info(
f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)"
)
logger.info(f"{'Total Samples':<25} {stats.ttfe_total_samples:>15,d}")
# Inter-event latency
if stats.inter_event_latency_avg > 0:
logger.info("-" * 80)
logger.info(
f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}"
)
logger.info(f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}")
logger.info(
f"{'(ms between events)':<25} {stats.inter_event_latency_avg:>15.1f} {stats.inter_event_latency_p50:>10.1f} {stats.inter_event_latency_p95:>10.1f}"
)
@@ -647,9 +631,9 @@ def on_test_start(environment: object, **kwargs: object) -> None:
logger.info("=" * 80)
# Show Locust stats summary
if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'):
if hasattr(environment, "stats") and hasattr(environment.stats, "total"):
total = environment.stats.total
if hasattr(total, 'num_requests') and total.num_requests > 0:
if hasattr(total, "num_requests") and total.num_requests > 0:
logger.info(
f"{'LOCUST STATS':<25} {'Requests':>12} {'Fails':>8} {'Avg (ms)':>12} {'Min':>8} {'Max':>8}"
)
@@ -687,21 +671,15 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
logger.info("")
logger.info("EVENTS")
logger.info(f" {'Total Events Received:':<30} {stats.total_events:>10,d}")
logger.info(
f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s"
)
logger.info(
f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s"
)
logger.info(f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s")
logger.info(f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s")
logger.info("")
logger.info("STREAM METRICS")
logger.info(f" {'Avg Stream Duration:':<30} {stats.stream_duration_avg:>10.1f} ms")
logger.info(f" {'P50 Stream Duration:':<30} {stats.stream_duration_p50:>10.1f} ms")
logger.info(f" {'P95 Stream Duration:':<30} {stats.stream_duration_p95:>10.1f} ms")
logger.info(
f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}"
)
logger.info(f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}")
logger.info("")
logger.info("INTER-EVENT LATENCY")
@@ -716,7 +694,9 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
logger.info(f" {'95th Percentile:':<30} {stats.ttfe_p95:>10.1f} ms")
logger.info(f" {'Minimum:':<30} {stats.ttfe_min:>10.1f} ms")
logger.info(f" {'Maximum:':<30} {stats.ttfe_max:>10.1f} ms")
logger.info(f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})")
logger.info(
f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})"
)
logger.info(f" {'Total Samples:':<30} {stats.ttfe_total_samples:>10,d}")
# Error summary
@@ -730,7 +710,7 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
logger.info("=" * 80 + "\n")
# Export machine-readable report (only on master node)
is_master = not getattr(environment.runner, 'worker_id', None) if hasattr(environment, 'runner') else True
is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True
if is_master:
export_json_report(stats, test_duration, environment)
@@ -746,9 +726,9 @@ def export_json_report(stats: MetricsSnapshot, duration: float, environment: obj
# Access environment.stats.total attributes safely
locust_stats: LocustStats | None = None
if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'):
if hasattr(environment, "stats") and hasattr(environment.stats, "total"):
total = environment.stats.total
if hasattr(total, 'num_requests') and total.num_requests > 0:
if hasattr(total, "num_requests") and total.num_requests > 0:
locust_stats = LocustStats(
total_requests=total.num_requests,
total_failures=total.num_failures,

View File

@@ -1,7 +1,15 @@
from dify_client.client import (
ChatClient,
CompletionClient,
WorkflowClient,
KnowledgeBaseClient,
DifyClient,
KnowledgeBaseClient,
WorkflowClient,
)
__all__ = [
"ChatClient",
"CompletionClient",
"DifyClient",
"KnowledgeBaseClient",
"WorkflowClient",
]

View File

@@ -8,16 +8,16 @@ class DifyClient:
self.api_key = api_key
self.base_url = base_url
def _send_request(self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False):
def _send_request(
self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False
):
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
url = f"{self.base_url}{endpoint}"
response = requests.request(
method, url, json=json, params=params, headers=headers, stream=stream
)
response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream)
return response
@@ -25,9 +25,7 @@ class DifyClient:
headers = {"Authorization": f"Bearer {self.api_key}"}
url = f"{self.base_url}{endpoint}"
response = requests.request(
method, url, data=data, headers=headers, files=files
)
response = requests.request(method, url, data=data, headers=headers, files=files)
return response
@@ -41,9 +39,7 @@ class DifyClient:
def file_upload(self, user: str, files: dict):
data = {"user": user}
return self._send_request_with_files(
"POST", "/files/upload", data=data, files=files
)
return self._send_request_with_files("POST", "/files/upload", data=data, files=files)
def text_to_audio(self, text: str, user: str, streaming: bool = False):
data = {"text": text, "user": user, "streaming": streaming}
@@ -55,7 +51,9 @@ class DifyClient:
class CompletionClient(DifyClient):
def create_completion_message(self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None):
def create_completion_message(
self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None
):
data = {
"inputs": inputs,
"response_mode": response_mode,
@@ -99,9 +97,7 @@ class ChatClient(DifyClient):
def get_suggested(self, message_id: str, user: str):
params = {"user": user}
return self._send_request(
"GET", f"/messages/{message_id}/suggested", params=params
)
return self._send_request("GET", f"/messages/{message_id}/suggested", params=params)
def stop_message(self, task_id: str, user: str):
data = {"user": user}
@@ -112,10 +108,9 @@ class ChatClient(DifyClient):
user: str,
last_id: str | None = None,
limit: int | None = None,
pinned: bool | None = None
pinned: bool | None = None,
):
params = {"user": user, "last_id": last_id,
"limit": limit, "pinned": pinned}
params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned}
return self._send_request("GET", "/conversations", params=params)
def get_conversation_messages(
@@ -123,7 +118,7 @@ class ChatClient(DifyClient):
user: str,
conversation_id: str | None = None,
first_id: str | None = None,
limit: int | None = None
limit: int | None = None,
):
params = {"user": user}
@@ -136,13 +131,9 @@ class ChatClient(DifyClient):
return self._send_request("GET", "/messages", params=params)
def rename_conversation(
self, conversation_id: str, name: str, auto_generate: bool, user: str
):
def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str):
data = {"name": name, "auto_generate": auto_generate, "user": user}
return self._send_request(
"POST", f"/conversations/{conversation_id}/name", data
)
return self._send_request("POST", f"/conversations/{conversation_id}/name", data)
def delete_conversation(self, conversation_id: str, user: str):
data = {"user": user}
@@ -155,9 +146,7 @@ class ChatClient(DifyClient):
class WorkflowClient(DifyClient):
def run(
self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"
):
def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"):
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
return self._send_request("POST", "/workflows/run", data)
@@ -197,13 +186,9 @@ class KnowledgeBaseClient(DifyClient):
return self._send_request("POST", "/datasets", {"name": name}, **kwargs)
def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
return self._send_request(
"GET", f"/datasets?page={page}&limit={page_size}", **kwargs
)
return self._send_request("GET", f"/datasets?page={page}&limit={page_size}", **kwargs)
def create_document_by_text(
self, name, text, extra_params: dict | None = None, **kwargs
):
def create_document_by_text(self, name, text, extra_params: dict | None = None, **kwargs):
"""
Create a document by text.
@@ -272,9 +257,7 @@ class KnowledgeBaseClient(DifyClient):
data = {"name": name, "text": text}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = (
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
)
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
return self._send_request("POST", url, json=data, **kwargs)
def create_document_by_file(
@@ -315,13 +298,9 @@ class KnowledgeBaseClient(DifyClient):
if original_document_id is not None:
data["original_document_id"] = original_document_id
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
return self._send_request_with_files(
"POST", url, {"data": json.dumps(data)}, files
)
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
def update_document_by_file(
self, document_id: str, file_path: str, extra_params: dict | None = None
):
def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None):
"""
Update a document by file.
@@ -351,12 +330,8 @@ class KnowledgeBaseClient(DifyClient):
data = {}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = (
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
)
return self._send_request_with_files(
"POST", url, {"data": json.dumps(data)}, files
)
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
def batch_indexing_status(self, batch_id: str, **kwargs):
"""

View File

@@ -1,6 +1,6 @@
from setuptools import setup
with open("README.md", "r", encoding="utf-8") as fh:
with open("README.md", encoding="utf-8") as fh:
long_description = fh.read()
setup(

View File

@@ -18,9 +18,7 @@ FILE_PATH_BASE = os.path.dirname(__file__)
class TestKnowledgeBaseClient(unittest.TestCase):
def setUp(self):
self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL)
self.README_FILE_PATH = os.path.abspath(
os.path.join(FILE_PATH_BASE, "../README.md")
)
self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md"))
self.dataset_id = None
self.document_id = None
self.segment_id = None
@@ -28,9 +26,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _get_dataset_kb_client(self):
self.assertIsNotNone(self.dataset_id)
return KnowledgeBaseClient(
API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id
)
return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id)
def test_001_create_dataset(self):
response = self.knowledge_base_client.create_dataset(name="test_dataset")
@@ -76,9 +72,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_004_update_document_by_text(self):
client = self._get_dataset_kb_client()
self.assertIsNotNone(self.document_id)
response = client.update_document_by_text(
self.document_id, "test_document_updated", "test_text_updated"
)
response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated")
data = response.json()
self.assertIn("document", data)
self.assertIn("batch", data)
@@ -93,9 +87,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_006_update_document_by_file(self):
client = self._get_dataset_kb_client()
self.assertIsNotNone(self.document_id)
response = client.update_document_by_file(
self.document_id, self.README_FILE_PATH
)
response = client.update_document_by_file(self.document_id, self.README_FILE_PATH)
data = response.json()
self.assertIn("document", data)
self.assertIn("batch", data)
@@ -125,9 +117,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_010_add_segments(self):
client = self._get_dataset_kb_client()
response = client.add_segments(
self.document_id, [{"content": "test text segment 1"}]
)
response = client.add_segments(self.document_id, [{"content": "test text segment 1"}])
data = response.json()
self.assertIn("data", data)
self.assertGreater(len(data["data"]), 0)
@@ -174,18 +164,12 @@ class TestChatClient(unittest.TestCase):
self.chat_client = ChatClient(API_KEY)
def test_create_chat_message(self):
response = self.chat_client.create_chat_message(
{}, "Hello, World!", "test_user"
)
response = self.chat_client.create_chat_message({}, "Hello, World!", "test_user")
self.assertIn("answer", response.text)
def test_create_chat_message_with_vision_model_by_remote_url(self):
files = [
{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}
]
response = self.chat_client.create_chat_message(
{}, "Describe the picture.", "test_user", files=files
)
files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}]
response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
self.assertIn("answer", response.text)
def test_create_chat_message_with_vision_model_by_local_file(self):
@@ -196,15 +180,11 @@ class TestChatClient(unittest.TestCase):
"upload_file_id": "your_file_id",
}
]
response = self.chat_client.create_chat_message(
{}, "Describe the picture.", "test_user", files=files
)
response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
self.assertIn("answer", response.text)
def test_get_conversation_messages(self):
response = self.chat_client.get_conversation_messages(
"test_user", "your_conversation_id"
)
response = self.chat_client.get_conversation_messages("test_user", "your_conversation_id")
self.assertIn("answer", response.text)
def test_get_conversations(self):
@@ -223,9 +203,7 @@ class TestCompletionClient(unittest.TestCase):
self.assertIn("answer", response.text)
def test_create_completion_message_with_vision_model_by_remote_url(self):
files = [
{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}
]
files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}]
response = self.completion_client.create_completion_message(
{"query": "Describe the picture."}, "blocking", "test_user", files
)
@@ -250,9 +228,7 @@ class TestDifyClient(unittest.TestCase):
self.dify_client = DifyClient(API_KEY)
def test_message_feedback(self):
response = self.dify_client.message_feedback(
"your_message_id", "like", "test_user"
)
response = self.dify_client.message_feedback("your_message_id", "like", "test_user")
self.assertIn("success", response.text)
def test_get_application_parameters(self):