Compare commits

...

2 Commits

Author     SHA1        Message                                                              Date
takatost   4d995ab3ac  fix: model mode                                                      2023-11-24 19:42:50 +08:00
takatost   254e4df78e  feat: ChatGLM use OpenAI API schema instead of original API schema   2023-11-24 19:38:51 +08:00
2 changed files with 65 additions and 23 deletions
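
The gist of the two commits: newer ChatGLM3 servers expose an OpenAI-compatible HTTP API, so the provider drops langchain's original `ChatGLM` wrapper and talks to the server through an OpenAI-style chat client instead. As a minimal sketch of what that schema means on the wire, here is a chat-completions request against a self-hosted ChatGLM server using the pre-1.0 `openai` SDK (the SDK generation implied by the `openai.error` references in the diff); the endpoint URL and model id are illustrative assumptions, not values from this repo:

```python
import openai  # openai < 1.0, matching the openai.error usage in the diff below

openai.api_key = "1"  # placeholder; a self-hosted ChatGLM server typically ignores the key
openai.api_base = "http://localhost:8000/v1"  # assumed local ChatGLM3 endpoint

# Same request shape as OpenAI's /v1/chat/completions
response = openai.ChatCompletion.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.7,
    max_tokens=256,
)
print(response.choices[0].message.content)
```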

View File

@@ -1,27 +1,42 @@
 import decimal
+import logging
 from typing import List, Optional, Any
 
+import openai
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import ChatGLM
-from langchain.schema import LLMResult
+from langchain.schema import LLMResult, get_buffer_string
 
-from core.model_providers.error import LLMBadRequestError
+from core.model_providers.error import LLMBadRequestError, LLMRateLimitError, LLMAuthorizationError, \
+    LLMAPIUnavailableError, LLMAPIConnectionError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage, MessageType
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
 
 
 class ChatGLMModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT
 
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        return ChatGLM(
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        client = EnhanceChatOpenAI(
+            model_name=self.name,
+            temperature=provider_model_kwargs.get('temperature'),
+            max_tokens=provider_model_kwargs.get('max_tokens'),
+            model_kwargs=extra_model_kwargs,
             streaming=self.streaming,
             callbacks=self.callbacks,
-            endpoint_url=self.credentials.get('api_base'),
-            **provider_model_kwargs
+            request_timeout=60,
+            openai_api_key="1",
+            openai_api_base=self.credentials['api_base'] + '/v1'
         )
+
+        return client
 
     def _run(self, messages: List[PromptMessage],
              stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
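
`EnhanceChatOpenAI` is Dify's in-repo client class; its internals are not part of this diff, but the construction above amounts to pointing a langchain-style OpenAI chat client at the ChatGLM server. A rough stand-in under that assumption, using only kwargs that appear in the diff (endpoint URL and sample values are placeholders):

```python
# Sketch only: assumes EnhanceChatOpenAI behaves like langchain's ChatOpenAI.
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

client = ChatOpenAI(
    model_name="chatglm3-6b",
    temperature=0.7,
    max_tokens=256,
    model_kwargs={"top_p": 0.9},
    request_timeout=60,
    openai_api_key="1",  # placeholder key, as in the diff
    openai_api_base="http://localhost:8000/v1",  # assumed server URL
)
print(client([HumanMessage(content="Hello")]).content)
```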
@@ -45,19 +60,40 @@ class ChatGLMModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
 
     def get_currency(self):
         return 'RMB'
 
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
-        for k, v in provider_model_kwargs.items():
-            if hasattr(self.client, k):
-                setattr(self.client, k, v)
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        self.client.temperature = provider_model_kwargs.get('temperature')
+        self.client.max_tokens = provider_model_kwargs.get('max_tokens')
+        self.client.model_kwargs = extra_model_kwargs
 
     def handle_exceptions(self, ex: Exception) -> Exception:
-        if isinstance(ex, ValueError):
-            return LLMBadRequestError(f"ChatGLM: {str(ex)}")
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to ChatGLM API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to ChatGLM API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("ChatGLM service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
         else:
             return ex
 
     @classmethod
     def support_streaming(cls):
         return True
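
Two behavioral notes on this hunk. Token counting switches from one flattened completion prompt to per-message accounting: `get_buffer_string([m])` renders a single chat message as, e.g., `Human: <content>`, the per-message counts are summed, and one token per message is subtracted, apparently an offset correction for that rendering, with the `max(..., 0)` floor kept. A small sketch of the pieces involved (legacy langchain; actual counts depend on the tokenizer behind `get_num_tokens`):

```python
from langchain.schema import AIMessage, HumanMessage, get_buffer_string

prompts = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
for m in prompts:
    print(get_buffer_string([m]))  # "Human: Hi", then "AI: Hello!"

# The diff's accounting, with client.get_num_tokens as the token counter:
# max(sum(client.get_num_tokens(get_buffer_string([m])) for m in prompts)
#     - len(prompts), 0)
```

The second change widens `handle_exceptions` from a single `ValueError` case to the full `openai.error` hierarchy, mapping invalid-request, connection, availability, rate-limit, and authentication failures onto Dify's own error types.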

View File

@@ -25,21 +25,26 @@ class ChatGLMProvider(BaseModelProvider):
         if model_type == ModelType.TEXT_GENERATION:
             return [
                 {
-                    'id': 'chatglm2-6b',
-                    'name': 'ChatGLM2-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm3-6b',
+                    'name': 'ChatGLM3-6B',
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
-                    'id': 'chatglm-6b',
-                    'name': 'ChatGLM-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm3-6b-32k',
+                    'name': 'ChatGLM3-6B-32K',
+                    'mode': ModelMode.CHAT.value,
                 },
+                {
+                    'id': 'chatglm2-6b',
+                    'name': 'ChatGLM2-6B',
+                    'mode': ModelMode.CHAT.value,
+                }
             ]
         else:
             return []
 
     def _get_text_generation_model_mode(self, model_name) -> str:
-        return ModelMode.COMPLETION.value
+        return ModelMode.CHAT.value
 
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
@@ -64,8 +69,9 @@ class ChatGLMProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
-            'chatglm-6b': 2000,
-            'chatglm2-6b': 32000,
+            'chatglm3-6b-32k': 32000,
+            'chatglm3-6b': 8000,
+            'chatglm2-6b': 8000,
         }
 
         return ModelKwargsRules(