Compare commits

...

6 Commits

Author  SHA1        Message                                                                     Date
Yeuoly  9156934a05  fix: xinference mock                                                        2024-01-28 08:44:47 +08:00
Yeuoly  ea866b37f0  support blocking function call                                              2024-01-27 01:41:01 +08:00
Yeuoly  c476836889  Merge branch 'fix/xinference-max-chunks' into feat/blocking-function-call   2024-01-27 00:51:40 +08:00
Yeuoly  2cda79699c  feat: min temp                                                              2024-01-23 18:29:48 +08:00
Yeuoly  f02d34cccb  feat: xinference supports tool call and fill in max tokens                 2024-01-21 13:51:50 +08:00
Yeuoly  56d2bdf73a  fix: add max chunks to xinference                                           2024-01-21 12:47:22 +08:00
22 changed files with 255 additions and 69 deletions

View File

@@ -11,6 +11,7 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
@@ -194,6 +195,13 @@ class AssistantApplicationRunner(AppRunner):
memory=memory,
)
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = AssistantCotApplicationRunner(
@@ -209,9 +217,9 @@ class AssistantApplicationRunner(AppRunner):
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance
)
invoke_result = assistant_cot_runner.run(
model_instance=model_instance,
conversation=conversation,
message=message,
query=query,
@@ -229,10 +237,10 @@ class AssistantApplicationRunner(AppRunner):
memory=memory,
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables
db_variables=tool_conversation_variables,
model_instance=model_instance
)
invoke_result = assistant_fc_runner.run(
model_instance=model_instance,
conversation=conversation,
message=message,
query=query,
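
The hunks above let the assistant runner pick the agent strategy from the model schema instead of only the app configuration: if the model advertises tool calling, the runner forces the function-calling strategy. A minimal sketch of that decision, reusing the identifiers from the diff; the helper name resolve_agent_strategy is illustrative, and the surrounding runner class is assumed from the Dify codebase:

from typing import cast

from core.entities.application_entities import AgentEntity
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel


def resolve_agent_strategy(model_instance, agent_entity: AgentEntity) -> "AgentEntity.Strategy":
    """Force function calling when the model schema advertises tool-call support."""
    llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
    model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
    if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
        agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
    return agent_entity.strategy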

View File

@@ -1,7 +1,7 @@
import logging
import json
from typing import Optional, List, Tuple, Union
from typing import Optional, List, Tuple, Union, cast
from datetime import datetime
from mimetypes import guess_extension
@@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \
AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.utils.encoders import jsonable_encoder
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_manager import ModelInstance
from core.file.message_file_parser import FileTransferMethod
logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner):
prompt_messages: Optional[List[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
) -> None:
"""
Agent runner
@@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner):
self.history_prompt_messages = prompt_messages
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
# init callback
self.agent_callback = DifyAgentCallbackHandler()
@@ -95,6 +100,14 @@ class BaseAssistantApplicationRunner(AppRunner):
MessageAgentThought.message_id == self.message.id,
).count()
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
self.stream_tool_call = True
else:
self.stream_tool_call = False
def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
"""
Repacket app orchestration config
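
The base runner now also records whether the model can stream tool calls; the function-call runner later uses this flag to choose between a streaming and a blocking invocation. A hedged helper-style sketch of the same check (the function name is illustrative, the imports mirror the diff):

from typing import cast

from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel


def supports_stream_tool_call(model_instance) -> bool:
    """True when the model schema advertises the stream-tool-call feature."""
    llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
    model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
    return bool(model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []))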

View File

@@ -20,8 +20,7 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from models.model import Conversation, Message
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
def run(self, conversation: Conversation,
message: Message,
query: str,
) -> Union[Generator, LLMResult]:
@@ -82,6 +81,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
# continue to run until there is not any tool call
function_call_state = False
@@ -390,7 +391,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# remove Action: xxx from agent thought
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
if action_name and action_input:
if action_name and action_input is not None:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,

View File

@@ -5,7 +5,7 @@ from typing import Union, Generator, Dict, Any, Tuple, List
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
from core.model_manager import ModelInstance
from core.application_queue_manager import PublishFrom
@@ -20,8 +20,7 @@ from models.model import Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
def run(self, conversation: Conversation,
message: Message,
query: str,
) -> Generator[LLMResultChunk, None, None]:
@@ -81,6 +80,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
function_call_state = False
@@ -101,12 +102,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
tools=prompt_messages_tools,
stop=app_orchestration_config.model_config.stop,
stream=True,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)
@@ -122,11 +123,41 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
current_llm_usage = None
for chunk in chunks:
if self.stream_tool_call:
for chunk in chunks:
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
response += content.data
else:
response += chunk.delta.message.content
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
else:
result: LLMResult = chunks
# check if there is any tool call
if self.check_tool_calls(chunk):
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
@@ -138,18 +169,30 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
if result.usage:
increase_usage(llm_usage, result.usage)
current_llm_usage = result.usage
if result.message and result.message.content:
if isinstance(result.message.content, list):
for content in result.message.content:
response += content.data
else:
response += chunk.delta.message.content
response += result.message.content
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
if not result.message.content:
result.message.content = ''
yield chunk
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=result.message,
usage=result.usage,
)
)
# save thought
self.save_agent_thought(
@@ -287,6 +330,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
"""
if llm_result.message.tool_calls:
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""
@@ -304,6 +355,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
))
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
json.loads(prompt_message.function.arguments),
))
return tool_calls
def organize_prompt_messages(self, prompt_template: str,
query: str = None,
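
The hunks above are the core of "support blocking function call": when the model cannot stream tool calls, invoke_llm is called with stream=False, the resulting LLMResult is inspected with check_blocking_tool_calls/extract_blocking_tool_calls, and the result is re-wrapped as a single LLMResultChunk so downstream consumers keep receiving a chunk stream. A simplified sketch of that re-wrapping, assuming the entity classes from core.model_runtime (usage accounting and tool dispatch omitted):

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta


def blocking_result_to_chunk(model: str, result: LLMResult) -> LLMResultChunk:
    """Wrap a non-streaming LLMResult into one LLMResultChunk for downstream consumers."""
    if not result.message.content:
        result.message.content = ''
    return LLMResultChunk(
        model=model,
        prompt_messages=result.prompt_messages,
        system_fingerprint=result.system_fingerprint,
        delta=LLMResultChunkDelta(
            index=0,
            message=result.message,
            usage=result.usage,
        ),
    )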

View File

@@ -78,6 +78,7 @@ class ModelFeature(Enum):
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call"
class DefaultParameterName(Enum):

View File

@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 4096

View File

@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385

View File

@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385

View File

@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385

View File

@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 4096

View File

@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000

View File

@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000

View File

@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 32768

View File

@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000

View File

@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 8192

View File

@@ -3,14 +3,14 @@ from typing import Generator, Iterator, List, Optional, Union, cast
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage)
SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
ParameterRule, ParameterType)
ParameterRule, ParameterType, ModelFeature)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper,
from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
XinferenceModelExtraParameter)
from core.model_runtime.utils import helper
from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,
@@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
"""
if 'temperature' in model_parameters:
if model_parameters['temperature'] < 0.01:
model_parameters['temperature'] = 0.01
elif model_parameters['temperature'] > 1.0:
model_parameters['temperature'] = 0.99
return self._generate(
model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
tools=tools, stop=stop, stream=stream, user=user,
@@ -65,6 +71,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
credentials['completion_type'] = 'completion'
else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
if extra_param.support_function_call:
credentials['support_function_call'] = True
except RuntimeError as e:
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
@@ -220,6 +229,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
@@ -237,7 +249,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
label=I18nObject(
zh_Hans='温度',
en_US='Temperature'
)
),
),
ParameterRule(
name='top_p',
@@ -282,6 +294,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
support_function_call = credentials.get('support_function_call', False)
entity = AIModelEntity(
model=model,
@@ -290,6 +304,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=[
ModelFeature.TOOL_CALL
] if support_function_call else [],
model_properties={
ModelPropertyKey.MODE: completion_type,
},
@@ -310,6 +327,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials')
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
client = OpenAI(
base_url=f'{credentials["server_url"]}/v1',
api_key='abc',
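
Two behavioural changes in the xinference LLM above are easy to miss: the temperature parameter is clamped into a range the backend accepts, and the generated model schema only advertises tool calling when credential validation detected it on the server. A condensed sketch of both (the standalone function names are illustrative, the logic mirrors the diff):

from core.model_runtime.entities.model_entities import ModelFeature


def clamp_temperature(model_parameters: dict) -> dict:
    """Keep temperature within the range accepted by xinference (0.01 .. 0.99 per the diff)."""
    if 'temperature' in model_parameters:
        if model_parameters['temperature'] < 0.01:
            model_parameters['temperature'] = 0.01
        elif model_parameters['temperature'] > 1.0:
            model_parameters['temperature'] = 0.99
    return model_parameters


def features_for(credentials: dict) -> list:
    """Advertise TOOL_CALL only when validation stored support_function_call in the credentials."""
    return [ModelFeature.TOOL_CALL] if credentials.get('support_function_call', False) else []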

View File

@@ -2,7 +2,7 @@ import time
from typing import Optional
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType, ModelPropertyKey
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
@@ -10,6 +10,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
@@ -102,8 +103,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
server_url = credentials['server_url']
model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens
self._invoke(model=model, credentials=credentials, texts=['ping'])
except InvokeAuthorizationError:
except (InvokeAuthorizationError, RuntimeError):
raise CredentialsValidateFailedError('Invalid api key')
@property
@@ -160,6 +168,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
@@ -167,7 +176,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={},
model_properties={
ModelPropertyKey.MAX_CHUNKS: 1,
ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
},
parameter_rules=[]
)
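
The embedding schema above derives its context size from the max_tokens value that credential validation fetched from the server, falling back to 512. The `'max_tokens' in credentials and credentials['max_tokens'] or 512` idiom is equivalent to the more common form below (example dict is illustrative):

credentials = {'max_tokens': 1024}  # illustrative credentials after validation
context_size = credentials.get('max_tokens') or 512  # same result as the and/or idiom in the diff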

View File

@@ -1,6 +1,7 @@
from threading import Lock
from time import time
from typing import List
from os import path
from requests import get
from requests.adapters import HTTPAdapter
@@ -12,11 +13,16 @@ class XinferenceModelExtraParameter(object):
model_format: str
model_handle_type: str
model_ability: List[str]
max_tokens: int = 512
support_function_call: bool = False
def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None:
def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
support_function_call: bool, max_tokens: int) -> None:
self.model_format = model_format
self.model_handle_type = model_handle_type
self.model_ability = model_ability
self.support_function_call = support_function_call
self.max_tokens = max_tokens
cache = {}
cache_lock = Lock()
@@ -49,7 +55,7 @@ class XinferenceHelper:
get xinference model extra parameter like model_format and model_handle_type
"""
url = f'{server_url}/v1/models/{model_uid}'
url = path.join(server_url, 'v1/models', model_uid)
# this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
session = Session()
@@ -66,10 +72,12 @@ class XinferenceHelper:
response_json = response.json()
model_format = response_json['model_format']
model_ability = response_json['model_ability']
model_format = response_json.get('model_format', 'ggmlv3')
model_ability = response_json.get('model_ability', [])
if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
if response_json.get('model_type') == 'embedding':
model_handle_type = 'embedding'
elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
model_handle_type = 'chatglm'
elif 'generate' in model_ability:
model_handle_type = 'generate'
@@ -78,8 +86,13 @@ class XinferenceHelper:
else:
raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
support_function_call = 'tools' in model_ability
max_tokens = response_json.get('max_tokens', 512)
return XinferenceModelExtraParameter(
model_format=model_format,
model_handle_type=model_handle_type,
model_ability=model_ability
model_ability=model_ability,
support_function_call=support_function_call,
max_tokens=max_tokens
)
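
The helper now tolerates embedding models (which report no model_format/model_ability), derives tool-call support from the 'tools' ability, and carries max_tokens through to the embedding schema. A trimmed sketch of the parsing step, assuming a response_json dict shaped like xinference's /v1/models/{uid} payload; the 'chat' branch is assumed from context because the hunk elides it:

from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceModelExtraParameter


def parse_extra_parameter(response_json: dict) -> XinferenceModelExtraParameter:
    """Map an xinference /v1/models/<uid> payload onto the extra-parameter object from the diff."""
    model_format = response_json.get('model_format', 'ggmlv3')
    model_ability = response_json.get('model_ability', [])

    if response_json.get('model_type') == 'embedding':
        model_handle_type = 'embedding'
    elif model_format == 'ggmlv3' and 'chatglm' in response_json.get('model_name', ''):
        model_handle_type = 'chatglm'
    elif 'generate' in model_ability:
        model_handle_type = 'generate'
    elif 'chat' in model_ability:  # assumed branch, not shown in the hunk
        model_handle_type = 'chat'
    else:
        raise NotImplementedError('unsupported xinference model')

    return XinferenceModelExtraParameter(
        model_format=model_format,
        model_handle_type=model_handle_type,
        model_ability=model_ability,
        support_function_call='tools' in model_ability,
        max_tokens=response_json.get('max_tokens', 512),
    )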

View File

@@ -2,6 +2,10 @@ model: glm-3-turbo
label:
en_US: glm-3-turbo
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
parameter_rules:

View File

@@ -2,6 +2,10 @@ model: glm-4
label:
en_US: glm-4
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
parameter_rules:

View File

@@ -48,7 +48,7 @@ dashscope[tokenizer]~=1.14.0
huggingface_hub~=0.16.4
transformers~=4.31.0
pandas==1.5.3
xinference-client~=0.6.4
xinference-client~=0.8.1
safetensors==0.3.2
zhipuai==1.0.7
werkzeug==2.3.8

View File

@@ -19,58 +19,86 @@ class MockXinferenceClass(object):
raise RuntimeError('404 Not Found')
if 'generate' == model_uid:
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url)
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'chat' == model_uid:
return RESTfulChatModelHandle(model_uid, base_url=self.base_url)
return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'embedding' == model_uid:
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url)
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'rerank' == model_uid:
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url)
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
raise RuntimeError('404 Not Found')
def get(self: Session, url: str, **kwargs):
if '/v1/models/' in url:
response = Response()
response = Response()
if 'v1/models/' in url:
# get model uid
model_uid = url.split('/')[-1]
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
response.status_code = 404
raise ConnectionError('404 Not Found')
return response
# check if url is valid
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
response.status_code = 404
raise ConnectionError('404 Not Found')
return response
if model_uid in ['generate', 'chat']:
response.status_code = 200
response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
return response
elif model_uid == 'embedding':
response.status_code = 200
response._content = b'''{
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "bge",
"model_lang": [
"en"
],
"revision": null,
"max_tokens": 512
}'''
return response
elif 'v1/cluster/auth' in url:
response.status_code = 200
response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
"auth": true
}'''
return response
def _check_cluster_authenticated(self):
self._cluster_authed = True
def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
# check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
@@ -133,6 +161,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
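
The mock is gated behind the MOCK_SWITCH environment variable, so these integration tests can run either against a live xinference server or fully offline. A hedged usage sketch; only the setup_xinference_mock fixture name comes from the diff, the test body is illustrative:

import os

# MOCK_SWITCH is read once, when the mock module is imported (MOCK = os.getenv(...) above),
# so in practice it is exported in the shell before pytest starts.
os.environ['MOCK_SWITCH'] = 'true'


def test_chat_model_offline(setup_xinference_mock):  # illustrative test using the fixture from the diff
    # Inside this test, Client.get_model and Session.get hit the canned responses above
    # instead of a real xinference deployment.
    ...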