Refactor: move mcp connection utilities to common (#11304)

### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2025-11-17 15:34:17 +08:00
committed by GitHub
parent 0569b50fed
commit bd4bc57009
10 changed files with 27 additions and 27 deletions

View File

@@ -30,7 +30,7 @@ from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM

View File

@@ -21,9 +21,8 @@ from functools import partial
from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase
from common.misc_utils import hash_str2int
from rag.llm.chat_model import ToolCallSession
from rag.prompts.generator import kb_prompt
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
from timeit import default_timer as timer

View File

@@ -25,7 +25,7 @@ from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@manager.route("/list", methods=["POST"]) # noqa: F821

View File

@@ -41,12 +41,12 @@ def list_agents(tenant_id):
return get_error_data_result("The agent doesn't exist.")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 30))
orderby = request.args.get("orderby", "update_time")
order_by = request.args.get("orderby", "update_time")
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
desc = False
else:
desc = True
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title)
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title)
return get_result(data=canvas)

View File

@@ -41,7 +41,7 @@ from api.db.db_models import init_database_tables as init_web_db
from api.db.init_data import init_web_data
from common.versions import get_ragflow_version
from common.config_utils import show_configs
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
from rag.utils.redis_conn import RedisDistributedLock
stop_event = threading.Event()

View File

@@ -37,7 +37,7 @@ from peewee import OperationalError
from common.constants import ActiveEnum
from api.db.db_models import APIToken
from api.utils.json_encode import CustomJSONEncoder
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from api.db.services.tenant_llm_service import LLMFactoriesService
from common.connection_utils import timeout
from common.constants import RetCode

View File

@@ -69,7 +69,7 @@ class SlimConnectorWithPermSync(ABC):
class CheckpointedConnectorWithPermSync(ABC):
"""Checkpointed connector interface (with permission sync)"""
"""Checkpoint connector interface (with permission sync)"""
@abstractmethod
def load_from_checkpoint(
@@ -143,7 +143,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]):
@abc.abstractmethod
def is_dynamic(self) -> bool:
"""If dynamic, the credentials may change during usage ... maening the client
"""If dynamic, the credentials may change during usage ... meaning the client
needs to use the locking features of the credentials provider to operate
correctly.

View File

@@ -21,7 +21,7 @@ import weakref
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from string import Template
from typing import Any, Literal
from typing import Any, Literal, Protocol
from typing_extensions import override
@@ -30,12 +30,15 @@ from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
from rag.llm.chat_model import ToolCallSession
MCPTaskType = Literal["list_tools", "tool_call"]
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
class ToolCallSession(Protocol):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
class MCPToolCallSession(ToolCallSession):
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
@@ -106,7 +109,8 @@ class MCPToolCallSession(ToolCallSession):
await self._process_mcp_tasks(None, msg)
else:
await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
await self._process_mcp_tasks(None,
f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
while not self._close:
@@ -164,7 +168,8 @@ class MCPToolCallSession(ToolCallSession):
raise
async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout)
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments,
timeout=timeout)
if result.isError:
return f"MCP server error: {result.content}"
@@ -283,7 +288,8 @@ def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) ->
except Exception:
logging.exception("Exception during MCP session cleanup thread management")
logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
logging.info(
f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
def shutdown_all_mcp_sessions():
@@ -298,7 +304,7 @@ def shutdown_all_mcp_sessions():
logging.info("All MCPToolCallSession instances have been closed.")
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool|dict) -> dict[str, Any]:
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]:
if isinstance(mcp_tool, dict):
return {
"type": "function",

View File

@@ -22,7 +22,6 @@ import re
import time
from abc import ABC
from copy import deepcopy
from typing import Any, Protocol
from urllib.parse import urljoin
import json_repair
@@ -65,10 +64,6 @@ LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
class ToolCallSession(Protocol):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))

View File

@@ -155,13 +155,13 @@ def qbullets_category(sections):
if re.match(pro, sec) and not not_bullet(sec):
hits[i] += 1
break
maxium = 0
maximum = 0
res = -1
for i, h in enumerate(hits):
if h <= maxium:
if h <= maximum:
continue
res = i
maxium = h
maximum = h
return res, QUESTION_PATTERN[res]
@@ -222,13 +222,13 @@ def bullets_category(sections):
if re.match(p, sec) and not not_bullet(sec):
hits[i] += 1
break
maxium = 0
maximum = 0
res = -1
for i, h in enumerate(hits):
if h <= maxium:
if h <= maximum:
continue
res = i
maxium = h
maximum = h
return res