Compare commits

...

25 Commits

Author SHA1 Message Date
Garfield Dai
2fccbef77d update templates. 2023-10-11 22:06:47 +08:00
takatost
14b61a5d5b Merge branch 'main' into feat/advanced-prompt-backend
# Conflicts:
#	api/core/model_providers/providers/tongyi_provider.py
2023-10-11 21:06:36 +08:00
Garfield Dai
400b47b0cd Merge remote-tracking branch 'origin/feat/advanced-prompt-backend' into feat/advanced-prompt-backend 2023-10-11 19:44:59 +08:00
Garfield Dai
b634812408 add advanced prompt. 2023-10-11 19:44:33 +08:00
takatost
10ea116113 fix: baichuan model mode missing 2023-10-11 18:36:13 +08:00
takatost
fc8f14737f Merge remote-tracking branch 'origin/main' into feat/advanced-prompt-backend 2023-10-11 17:57:10 +08:00
takatost
17289406d2 feat: optimize template parser 2023-10-11 17:18:11 +08:00
takatost
daa89a726f feat: replace prompt template parse implement 2023-10-11 16:45:25 +08:00
takatost
2480ab4027 fix: use None in llm call when stop words is empty 2023-10-11 14:14:17 +08:00
Garfield Dai
04034fa3d9 add validation. 2023-10-10 21:14:06 +08:00
Garfield Dai
94a0e0bacb update. 2023-10-10 20:37:39 +08:00
takatost
f9f4f1e9c2 feat: add stop validate rule 2023-10-10 20:31:03 +08:00
takatost
88cd59b337 feat: use custom stop words in llm call 2023-10-10 20:17:18 +08:00
Garfield Dai
619834744e Merge remote-tracking branch 'origin/feat/advanced-prompt-backend' into feat/advanced-prompt-backend 2023-10-10 19:54:09 +08:00
Garfield Dai
a2a7f9451a add validation 2023-10-10 19:52:13 +08:00
takatost
d1e4f36aa3 fix: tests 2023-10-10 17:51:03 +08:00
takatost
1d8b0008a1 feat: remove chinese check 2023-10-10 17:45:42 +08:00
Garfield Dai
56185d7144 add baichuan. 2023-10-10 17:44:10 +08:00
Garfield Dai
2c599ae672 Merge remote-tracking branch 'origin/main' into feat/advanced-prompt-backend 2023-10-10 17:29:58 +08:00
Garfield Dai
d793259e54 feat: add advanced prompt templates. 2023-10-10 17:13:33 +08:00
takatost
d4c64169f2 feat: add dataset configs 2023-10-10 16:43:07 +08:00
Garfield Dai
a3b71227ae feat: add advanced prompt templates. 2023-10-10 15:38:11 +08:00
Garfield Dai
d92b9f6508 feat: add advanced prompt templates. 2023-10-10 14:39:36 +08:00
Garfield Dai
45553f3d96 feat: add advanced prompt templates. 2023-10-10 14:14:01 +08:00
takatost
898062c38d feat: add model mode in text generation model list api 2023-10-10 12:08:53 +08:00
53 changed files with 686 additions and 288 deletions

View File

@@ -31,6 +31,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -81,6 +82,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 512,
"temperature": 1,

View File

@@ -9,7 +9,7 @@ api = ExternalApi(bp)
from . import setup, version, apikey, admin
# Import app controllers
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate

View File

@@ -0,0 +1,24 @@
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateList(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('app_mode', type=str, required=True, location='args')
parser.add_argument('model_mode', type=str, required=True, location='args')
parser.add_argument('model_name', type=str, required=True, location='args')
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')

View File

@@ -17,8 +17,8 @@ from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
@@ -30,7 +30,7 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
query = PromptBuilder.process_template(query)
query = PromptTemplateParser.remove_template_variables(query)
memory = None
if conversation:
@@ -160,14 +160,28 @@ class Completion:
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
# get llm prompt
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
if app_model_config.prompt_type == 'simple':
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
else:
prompt_messages = model_instance.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
model_config = app_model_config.model_dict
completion_params = model_config.get("completion_params", {})
stop_words = completion_params.get("stop", [])
cls.recale_llm_max_tokens(
model_instance=model_instance,
@@ -176,7 +190,7 @@ class Completion:
response = model_instance.run(
messages=prompt_messages,
stop=stop_words,
stop=stop_words if stop_words else None,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
@@ -290,7 +304,10 @@ class Completion:
original_completion = message.answer.strip()
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
prompt = prompt.format({
"prompt": PromptTemplateParser.remove_template_variables(pre_prompt),
"original_completion": original_completion
})
prompt_messages = [PromptMessage(content=prompt)]

View File

@@ -10,7 +10,7 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -74,10 +74,10 @@ class ConversationMessageTask:
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
prompt_template = JinjaPromptTemplate.from_template(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
prompt_template = PromptTemplateParser(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
try:
introduction = prompt_template.format(**prompt_inputs)
introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass
@@ -150,12 +150,12 @@ class ConversationMessageTask:
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price
@@ -163,7 +163,7 @@ class ConversationMessageTask:
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptBuilder.process_template(
self.message.answer = PromptTemplateParser.remove_template_variables(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
@@ -226,15 +226,15 @@ class ConversationMessageTask:
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price

View File

@@ -10,7 +10,7 @@ from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
GENERATOR_QA_PROMPT
@@ -56,7 +56,7 @@ class LLMGenerator:
)
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_with_empty_context = prompt.format({"context": ''})
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
max_context_token_length = max_context_token_length if max_context_token_length else 1500
@@ -84,7 +84,7 @@ class LLMGenerator:
if not context:
return '[message too long, no summary]'
prompt = prompt.format(context=context)
prompt = prompt.format({"context": context})
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
@@ -93,7 +93,7 @@ class LLMGenerator:
@classmethod
def generate_introduction(cls, tenant_id: str, pre_prompt: str):
prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt)
prompt = prompt.format({"prompt": pre_prompt})
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
@@ -109,13 +109,14 @@ class LLMGenerator:
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
prompt = JinjaPromptTemplate(
template="{{histories}}\n{{format_instructions}}\nquestions:\n",
input_variables=["histories"],
partial_variables={"format_instructions": format_instructions}
prompt_template = PromptTemplateParser(
template="{{histories}}\n{{format_instructions}}\nquestions:\n"
)
_input = prompt.format_prompt(histories=histories)
prompt = prompt_template.format({
"histories": histories,
"format_instructions": format_instructions
})
try:
model_instance = ModelFactory.get_text_generation_model(
@@ -128,10 +129,10 @@ class LLMGenerator:
except ProviderTokenNotInitError:
return []
prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]
try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
questions = output_parser.parse(output.content)
except LLMError:
questions = []
@@ -145,19 +146,21 @@ class LLMGenerator:
def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
output_parser = RuleConfigGeneratorOutputParser()
prompt = OutLinePromptTemplate(
template=output_parser.get_format_instructions(),
input_variables=["audiences", "hoping_to_solve"],
partial_variables={
"variable": '{variable}',
"lanA": '{lanA}',
"lanB": '{lanB}',
"topic": '{topic}'
},
validate_template=False
prompt_template = PromptTemplateParser(
template=output_parser.get_format_instructions()
)
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
prompt = prompt_template.format(
inputs={
"audiences": audiences,
"hoping_to_solve": hoping_to_solve,
"variable": "{{variable}}",
"lanA": "{{lanA}}",
"lanB": "{{lanB}}",
"topic": "{{topic}}"
},
remove_template_variables=False
)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
@@ -167,10 +170,10 @@ class LLMGenerator:
)
)
prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]
try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
rule_config = output_parser.parse(output.content)
except LLMError as e:
raise e

View File

@@ -282,7 +282,7 @@ class IndexingRunner:
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -378,7 +378,7 @@ class IndexingRunner:
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts

View File

@@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:

View File

@@ -12,20 +12,20 @@ class LLMRunResult(BaseModel):
class MessageType(enum.Enum):
HUMAN = 'human'
USER = 'user'
ASSISTANT = 'assistant'
SYSTEM = 'system'
class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
type: MessageType = MessageType.USER
content: str = ''
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
lc_messages.append(AIMessage(content=message.content))
@@ -39,7 +39,7 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
elif isinstance(message, SystemMessage):

View File

@@ -17,7 +17,7 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from core.third_party.langchain.llms.fake import FakeLLM
import logging
@@ -227,7 +227,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return:
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -245,7 +245,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.0001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -260,7 +260,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.000001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
price_unit = self.price_config['unit']
else:
price_unit = self.price_config['unit']
@@ -324,6 +324,73 @@ class BaseLLM(BaseProviderModel):
prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops
def get_advanced_prompt(self, app_mode: str,
app_model_config: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
model_mode = app_model_config.model_dict['mode']
conversation_histories_role = {
"user_prefix": "user",
"assistant_prefix": "assistant"
}
prompt_list = []
prompt_messages = []
if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
else:
raise Exception("app_mode or model_mode not support")
for prompt_item in prompt_list:
prompt = prompt_item['text']
# todo: key word validation
prompt_template = PromptTemplateParser(template=prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
if context:
prompt_inputs['#context#'] = context
if memory:
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, 2000)
prompt_inputs['#histories#'] = histories
if query:
prompt_inputs['#query#'] = query
prompt = prompt_template.format(
prompt_inputs
)
prompt = re.sub(r'<\|.*?\|>', '', prompt)
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
print(f'advanced prompt: {prompt}')
return prompt_messages
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
@@ -337,17 +404,17 @@ class BaseLLM(BaseProviderModel):
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
context=context
{'context': context}
)
pre_prompt_content = ''
if pre_prompt:
prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_template = PromptTemplateParser(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
**prompt_inputs
prompt_inputs
)
prompt = ''
@@ -380,10 +447,8 @@ class BaseLLM(BaseProviderModel):
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format(
histories=histories
)
prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format({'histories': histories})
prompt = ''
for order in prompt_rules['system_prompt_orders']:
@@ -394,15 +459,15 @@ class BaseLLM(BaseProviderModel):
elif order == 'histories_prompt':
prompt += histories_prompt_content
prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
query_prompt_content = prompt_template.format(
query=query
)
prompt_template = PromptTemplateParser(template=query_prompt)
query_prompt_content = prompt_template.format({'query': query})
prompt += query_prompt_content
prompt = re.sub(r'<\|.*?\|>', '', prompt)
print(f'simple prompt: {prompt}')
stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None
@@ -444,7 +509,7 @@ class BaseLLM(BaseProviderModel):
chat_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
if message.type == MessageType.USER:
chat_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
chat_messages.append(AIMessage(content=message.content))

View File

@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType
@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
'mode': ModelMode.CHAT.value,
},
{
'id': 'claude-2',
'name': 'claude-2',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -12,7 +12,7 @@ from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
AZURE_OPENAI_API_VERSION
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
}
credentials = json.loads(provider_model.encrypted_config)
if provider_model.model_type == ModelType.TEXT_GENERATION.value:
model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])
if credentials['base_model_name'] in [
'gpt-4',
'gpt-4-32k',
@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):
return model_list
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name == 'text-davinci-003':
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
models = [
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]

View File

@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
@@ -21,6 +21,9 @@ class BaichuanProvider(BaseModelProvider):
Returns the name of a provider.
"""
return 'baichuan'
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
@@ -28,6 +31,7 @@ class BaichuanProvider(BaseModelProvider):
{
'id': 'baichuan2-53b',
'name': 'Baichuan2-53B',
'mode': ModelMode.CHAT.value,
}
]
else:

View File

@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()
return [{
'id': provider_model.model_name,
'name': provider_model.model_name
} for provider_model in provider_models]
provider_model_list = []
for provider_model in provider_models:
provider_model_dict = {
'id': provider_model.model_name,
'name': provider_model.model_name
}
if model_type == ModelType.TEXT_GENERATION:
provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
provider_model_list.append(provider_model_dict)
return provider_model_list
@abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
"""
raise NotImplementedError
@abstractmethod
def _get_text_generation_model_mode(self, model_name) -> str:
"""
get text generation model mode.
:param model_name:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_model_class(self, model_type: ModelType) -> Type:
"""

View File

@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType
@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
{
'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'chatglm-6b',
'name': 'ChatGLM-6B',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -5,7 +5,7 @@ import requests
from huggingface_hub import HfApi
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION)
if credentials['completion_type'] == 'chat_completion':
return ModelMode.CHAT.value
else:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
@@ -29,10 +29,12 @@ class MinimaxProvider(BaseModelProvider):
{
'id': 'abab5.5-chat',
'name': 'abab5.5-chat',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'abab5-chat',
'name': 'abab5-chat',
'mode': ModelMode.COMPLETION.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
@@ -45,6 +47,9 @@ class MinimaxProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]
@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name in COMPLETION_MODELS:
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -3,7 +3,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -24,6 +24,9 @@ class OpenLLMProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -6,7 +6,8 @@ import replicate
from replicate.exceptions import ReplicateError
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
ModelMode
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark
@@ -30,15 +30,20 @@ class SparkProvider(BaseModelProvider):
{
'id': 'spark',
'name': 'Spark V1.5',
'mode': ModelMode.CHAT.value,
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
'mode': ModelMode.CHAT.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
@@ -26,15 +26,20 @@ class TongyiProvider(BaseModelProvider):
{
'id': 'qwen-turbo',
'name': 'qwen-turbo',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'qwen-plus',
'name': 'qwen-plus',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.wenxin import Wenxin
@@ -26,19 +26,25 @@ class WenxinProvider(BaseModelProvider):
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -6,7 +6,7 @@ from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
@@ -29,18 +29,22 @@ class ZhipuAIProvider(BaseModelProvider):
{
'id': 'chatglm_pro',
'name': 'chatglm_pro',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_std',
'name': 'chatglm_std',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite',
'name': 'chatglm_lite',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite_32k',
'name': 'chatglm_lite_32k',
'mode': ModelMode.CHAT.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
@@ -53,6 +57,9 @@ class ZhipuAIProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -1,7 +1,5 @@
import math
from typing import Optional
from flask import current_app
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
@@ -27,7 +25,6 @@ from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
from models.provider import ProviderType
class OrchestratorRuleParser:
@@ -51,6 +48,7 @@ class OrchestratorRuleParser:
tool_configs = agent_mode_config.get('tools', [])
agent_provider_name = model_dict.get('provider', 'openai')
agent_model_name = model_dict.get('name', 'gpt-4')
dataset_configs = self.app_model_config.dataset_configs_dict
agent_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
@@ -97,13 +95,14 @@ class OrchestratorRuleParser:
summary_model_instance = None
tools = self.to_tools(
agent_model_instance=agent_model_instance,
tool_configs=tool_configs,
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
agent_model_instance=agent_model_instance,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
return_resource=return_resource,
retriever_from=retriever_from
retriever_from=retriever_from,
dataset_configs=dataset_configs
)
if len(tools) == 0:
@@ -171,20 +170,12 @@ class OrchestratorRuleParser:
return None
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
conversation_message_task: ConversationMessageTask,
rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
retriever_from: str = 'dev') -> list[BaseTool]:
def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
"""
Convert app agent tool configs to tools
:param agent_model_instance:
:param rest_tokens:
:param tool_configs: app agent tool configs
:param conversation_message_task:
:param callbacks:
:param return_resource:
:param retriever_from:
:return:
"""
tools = []
@@ -196,15 +187,15 @@ class OrchestratorRuleParser:
tool = None
if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
elif tool_type == "web_reader":
tool = self.to_web_reader_tool(agent_model_instance)
tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
elif tool_type == "google_search":
tool = self.to_google_search_tool()
tool = self.to_google_search_tool(tool_config=tool_val, **kwargs)
elif tool_type == "wikipedia":
tool = self.to_wikipedia_tool()
tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs)
elif tool_type == "current_datetime":
tool = self.to_current_datetime_tool()
tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs)
if tool:
tool.callbacks.extend(callbacks)
@@ -213,12 +204,15 @@ class OrchestratorRuleParser:
return tools
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
dataset_configs: dict, rest_tokens: int,
return_resource: bool = False, retriever_from: str = 'dev',
**kwargs) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens:
:param tool_config:
:param dataset_configs:
:param conversation_message_task:
:param return_resource:
:param retriever_from:
@@ -236,10 +230,20 @@ class OrchestratorRuleParser:
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
return None
k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
top_k = dataset_configs.get("top_k", 2)
# dynamically adjust top_k when the remaining token number is not enough to support top_k
top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
score_threshold = None
score_threshold_config = dataset_configs.get("score_threshold")
if score_threshold_config and score_threshold_config.get("enable"):
score_threshold = score_threshold_config.get("value")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
k=k,
top_k=top_k,
score_threshold=score_threshold,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task,
return_resource=return_resource,
@@ -248,7 +252,7 @@ class OrchestratorRuleParser:
return tool
def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
def to_web_reader_tool(self, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
"""
A tool for reading web pages
@@ -277,7 +281,7 @@ class OrchestratorRuleParser:
return tool
def to_google_search_tool(self) -> Optional[BaseTool]:
def to_google_search_tool(self, **kwargs) -> Optional[BaseTool]:
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
func_kwargs = tool_provider.credentials_to_func_kwargs()
if not func_kwargs:
@@ -296,14 +300,14 @@ class OrchestratorRuleParser:
return tool
def to_current_datetime_tool(self) -> Optional[BaseTool]:
def to_current_datetime_tool(self, **kwargs) -> Optional[BaseTool]:
tool = DatetimeTool(
callbacks=[DifyStdOutCallbackHandler()]
)
return tool
def to_wikipedia_tool(self) -> Optional[BaseTool]:
def to_wikipedia_tool(self, **kwargs) -> Optional[BaseTool]:
class WikipediaInput(BaseModel):
query: str = Field(..., description="search query.")
@@ -315,22 +319,18 @@ class OrchestratorRuleParser:
)
@classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MAX_K = 10
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
if rest_tokens == -1:
return DEFAULT_K
return top_k
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
return top_k
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
return DEFAULT_K
return top_k
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
@@ -338,14 +338,7 @@ class OrchestratorRuleParser:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K:
if rest_tokens < segment_max_tokens * top_k:
return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
return min(context_limit_tokens // segment_max_tokens, MAX_K)
return min(top_k, 10)

View File

@@ -0,0 +1,81 @@
CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "
},
"conversation_histories_role": {
"user_prefix": "Human: ",
"assistant_prefix": "Assistant: "
}
}
}
CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n{{#pre_prompt#}}"
},{
"role": "user",
"text": "{{#query#}}"
}]
}
}
COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n{{#pre_prompt#}}"
}]
}
}
COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n{{#pre_prompt#}}"
}
}
}
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"
},
"conversation_histories_role": {
"user_prefix": "用户:",
"assistant_prefix": "助手:"
}
}
}
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n{{#pre_prompt#}}"
},{
"role": "user",
"text": ""
}]
}
}
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n{{#pre_prompt#}}"
}]
}
}
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n{{#pre_prompt#}}"
}
}
}

View File

@@ -1,38 +1,24 @@
import re
from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
from langchain.schema import BaseMessage
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
class PromptBuilder:
@classmethod
def parse_prompt(cls, prompt: str, inputs: dict) -> str:
prompt_template = PromptTemplateParser(prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt = prompt_template.format(prompt_inputs)
return prompt
@classmethod
def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
system_message = system_prompt_template.format(**prompt_inputs)
return system_message
return SystemMessage(content=cls.parse_prompt(prompt_content, inputs))
@classmethod
def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
ai_message = ai_prompt_template.format(**prompt_inputs)
return ai_message
return AIMessage(content=cls.parse_prompt(prompt_content, inputs))
@classmethod
def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
human_message = human_prompt_template.format(**inputs)
return human_message
@classmethod
def process_template(cls, template: str):
processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
# processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
# processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
return processed_template
return HumanMessage(content=cls.parse_prompt(prompt_content, inputs))

View File

@@ -1,79 +1,38 @@
import re
from typing import Any
from jinja2 import Environment, meta
from langchain import PromptTemplate
from langchain.formatting import StrictFormatter
REGEX = re.compile(r"\{\{([a-zA-Z0-9_]{1,16}|#histories#|#query#|#context#)\}\}")
class JinjaPromptTemplate(PromptTemplate):
template_format: str = "jinja2"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
class PromptTemplateParser:
"""
Rules:
1. Template variables must be enclosed in `{{}}`.
2. The template variable Key can only be: letters + numbers + underscore, with a maximum length of 16 characters.
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
4. In addition to the above, 3 types of special template variable Keys are accepted:
`{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
"""
def __init__(self, template: str):
self.template = template
self.variable_keys = self.extract()
def extract(self) -> list:
# Regular expression to match the template rules
return re.findall(REGEX, self.template)
def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found
if remove_template_variables:
return PromptTemplateParser.remove_template_variables(value)
return value
return re.sub(REGEX, replacer, self.template)
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
env = Environment()
template = template.replace("{{}}", "{}")
ast = env.parse(template)
input_variables = meta.find_undeclared_variables(ast)
if "partial_variables" in kwargs:
partial_variables = kwargs["partial_variables"]
input_variables = {
var for var in input_variables if var not in partial_variables
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
class OutLinePromptTemplate(PromptTemplate):
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
input_variables = {
v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
return OneLineFormatter().format(self.template, **kwargs)
class OneLineFormatter(StrictFormatter):
def parse(self, format_string):
last_end = 0
results = []
for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
field_name = match.group(1)
start, end = match.span()
literal_text = format_string[last_end:start]
last_end = end
results.append((literal_text, field_name, '', None))
remaining_literal_text = format_string[last_end:]
if remaining_literal_text:
results.append((remaining_literal_text, None, None, None))
return results
def remove_template_variables(cls, text: str):
return re.sub(REGEX, r'{\1}', text)

View File

@@ -157,10 +157,10 @@ and fill in variables, with a welcome sentence, and keep TLDR.
```
<< MY INTENDED AUDIENCES >>
{audiences}
{{audiences}}
<< HOPING TO SOLVE >>
{hoping_to_solve}
{{hoping_to_solve}}
<< OUTPUT >>
"""

View File

@@ -1,5 +1,5 @@
import json
from typing import Type
from typing import Type, Optional
from flask import current_app
from langchain.tools import BaseTool
@@ -28,7 +28,8 @@ class DatasetRetrieverTool(BaseTool):
tenant_id: str
dataset_id: str
k: int = 3
top_k: int = 2
score_threshold: Optional[float] = None
conversation_message_task: ConversationMessageTask
return_resource: str
retriever_from: str
@@ -66,7 +67,7 @@ class DatasetRetrieverTool(BaseTool):
)
)
documents = kw_table_index.search(query, search_kwargs={'k': self.k})
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
return str("\n".join([document.page_content for document in documents]))
else:
@@ -80,20 +81,21 @@ class DatasetRetrieverTool(BaseTool):
return ''
except ProviderTokenNotInitError:
return ''
embeddings = CacheEmbedding(embedding_model)
embeddings = CacheEmbedding(embedding_model)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
if self.k > 0:
if self.top_k > 0:
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': self.k,
'k': self.top_k,
'score_threshold': self.score_threshold,
'filter': {
'group_id': [dataset.id]
}

View File

@@ -28,6 +28,10 @@ model_config_fields = {
'dataset_query_variable': fields.String,
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
'prompt_type': fields.String,
'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
}
app_detail_fields = {

View File

@@ -0,0 +1,37 @@
"""add advanced prompt templates
Revision ID: b3a09c049e8e
Revises: 2e9819ca5b28
Create Date: 2023-10-10 15:23:23.395420
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'b3a09c049e8e'
down_revision = '2e9819ca5b28'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('dataset_configs')
batch_op.drop_column('completion_prompt_config')
batch_op.drop_column('chat_prompt_config')
batch_op.drop_column('prompt_type')
# ### end Alembic commands ###

View File

@@ -93,6 +93,10 @@ class AppModelConfig(db.Model):
agent_mode = db.Column(db.Text)
sensitive_word_avoidance = db.Column(db.Text)
retriever_resource = db.Column(db.Text)
prompt_type = db.Column(db.String(255), nullable=False, default='simple')
chat_prompt_config = db.Column(db.Text)
completion_prompt_config = db.Column(db.Text)
dataset_configs = db.Column(db.Text)
@property
def app(self):
@@ -139,6 +143,18 @@ class AppModelConfig(db.Model):
def agent_mode_dict(self) -> dict:
return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []}
@property
def chat_prompt_config_dict(self) -> dict:
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
@property
def completion_prompt_config_dict(self) -> dict:
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
@property
def dataset_configs_dict(self) -> dict:
return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
def to_dict(self) -> dict:
return {
"provider": "",
@@ -155,7 +171,11 @@ class AppModelConfig(db.Model):
"user_input_form": self.user_input_form_list,
"dataset_query_variable": self.dataset_query_variable,
"pre_prompt": self.pre_prompt,
"agent_mode": self.agent_mode_dict
"agent_mode": self.agent_mode_dict,
"prompt_type": self.prompt_type,
"chat_prompt_config": self.chat_prompt_config_dict,
"completion_prompt_config": self.completion_prompt_config_dict,
"dataset_configs": self.dataset_configs_dict
}
def from_model_config_dict(self, model_config: dict):
@@ -177,6 +197,13 @@ class AppModelConfig(db.Model):
self.agent_mode = json.dumps(model_config['agent_mode'])
self.retriever_resource = json.dumps(model_config['retriever_resource']) \
if model_config.get('retriever_resource') else None
self.prompt_type = model_config.get('prompt_type', 'simple')
self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \
if model_config.get('chat_prompt_config') else None
self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \
if model_config.get('completion_prompt_config') else None
self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
if model_config.get('dataset_configs') else None
return self
def copy(self):
@@ -196,7 +223,11 @@ class AppModelConfig(db.Model):
user_input_form=self.user_input_form,
dataset_query_variable=self.dataset_query_variable,
pre_prompt=self.pre_prompt,
agent_mode=self.agent_mode
agent_mode=self.agent_mode,
prompt_type=self.prompt_type,
chat_prompt_config=self.chat_prompt_config,
completion_prompt_config=self.completion_prompt_config,
dataset_configs=self.dataset_configs
)
return new_app_model_config

View File

@@ -0,0 +1,42 @@
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
class AdvancedPromptTemplateService:
@staticmethod
def get_prompt(args: dict):
app_mode = args['app_mode']
model_mode = args['model_mode']
model_name = args['model_name']
if 'baichuan' in model_name:
return AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode)
else:
return AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode)
@staticmethod
def get_common_prompt(app_mode: str, model_mode:str):
if app_mode == 'chat':
if model_mode == 'completion':
return CHAT_APP_COMPLETION_PROMPT_CONFIG
elif model_mode == 'chat':
return CHAT_APP_CHAT_PROMPT_CONFIG
elif app_mode == 'completion':
if model_mode == 'completion':
return COMPLETION_APP_COMPLETION_PROMPT_CONFIG
elif model_mode == 'chat':
return COMPLETION_APP_CHAT_PROMPT_CONFIG
@staticmethod
def get_baichuan_prompt(app_mode: str, model_mode:str):
if app_mode == 'chat':
if model_mode == 'completion':
return BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
elif model_mode == 'chat':
return BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
elif app_mode == 'completion':
if model_mode == 'completion':
return BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
elif model_mode == 'chat':
return BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG

View File

@@ -34,40 +34,28 @@ class AppModelConfigService:
# max_tokens
if 'max_tokens' not in cp:
cp["max_tokens"] = 512
#
# if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
# llm_constant.max_context_token_length[model_name]:
# raise ValueError(
# "max_tokens must be an integer greater than 0 "
# "and not exceeding the maximum value of the corresponding model")
#
# temperature
if 'temperature' not in cp:
cp["temperature"] = 1
#
# if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
# raise ValueError("temperature must be a float between 0 and 2")
#
# top_p
if 'top_p' not in cp:
cp["top_p"] = 1
# if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
# raise ValueError("top_p must be a float between 0 and 2")
#
# presence_penalty
if 'presence_penalty' not in cp:
cp["presence_penalty"] = 0
# if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
# raise ValueError("presence_penalty must be a float between -2 and 2")
#
# presence_penalty
if 'frequency_penalty' not in cp:
cp["frequency_penalty"] = 0
# if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
# raise ValueError("frequency_penalty must be a float between -2 and 2")
# stop
if 'stop' not in cp:
cp["stop"] = []
elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type")
# Filter out extra parameters
filtered_cp = {
@@ -75,7 +63,8 @@ class AppModelConfigService:
"temperature": cp["temperature"],
"top_p": cp["top_p"],
"presence_penalty": cp["presence_penalty"],
"frequency_penalty": cp["frequency_penalty"]
"frequency_penalty": cp["frequency_penalty"],
"stop": cp["stop"]
}
return filtered_cp
@@ -211,6 +200,10 @@ class AppModelConfigService:
model_ids = [m['id'] for m in model_list]
if config["model"]["name"] not in model_ids:
raise ValueError("model.name must be in the specified model list")
# model.mode
if 'mode' not in config['model'] or not config['model']["mode"]:
config['model']["mode"] = ""
# model.completion_params
if 'completion_params' not in config["model"]:
@@ -339,6 +332,9 @@ class AppModelConfigService:
# dataset_query_variable
AppModelConfigService.is_dataset_query_variable_valid(config, mode)
# advanced prompt validation
AppModelConfigService.is_advanced_prompt_valid(config)
# Filter out extra parameters
filtered_config = {
"opening_statement": config["opening_statement"],
@@ -351,12 +347,17 @@ class AppModelConfigService:
"model": {
"provider": config["model"]["provider"],
"name": config["model"]["name"],
"mode": config['model']["mode"],
"completion_params": config["model"]["completion_params"]
},
"user_input_form": config["user_input_form"],
"dataset_query_variable": config.get('dataset_query_variable'),
"pre_prompt": config["pre_prompt"],
"agent_mode": config["agent_mode"]
"agent_mode": config["agent_mode"],
"prompt_type": config["prompt_type"],
"chat_prompt_config": config["chat_prompt_config"],
"completion_prompt_config": config["completion_prompt_config"],
"dataset_configs": config["dataset_configs"]
}
return filtered_config
@@ -375,4 +376,41 @@ class AppModelConfigService:
if dataset_exists and not dataset_query_variable:
raise ValueError("Dataset query variable is required when dataset is exist")
@staticmethod
def is_advanced_prompt_valid(config: dict) -> None:
# prompt_type
if 'prompt_type' not in config or not config["prompt_type"]:
config["prompt_type"] = "simple"
if config['prompt_type'] not in ['simple', 'advanced']:
raise ValueError("prompt_type must be in ['simple', 'advanced']")
# chat_prompt_config
if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
config["chat_prompt_config"] = {}
if not isinstance(config["chat_prompt_config"], dict):
raise ValueError("chat_prompt_config must be of object type")
# completion_prompt_config
if 'completion_prompt_config' not in config or not config["completion_prompt_config"]:
config["completion_prompt_config"] = {}
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
# dataset_configs
if 'dataset_configs' not in config or not config["dataset_configs"]:
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if config['prompt_type'] == 'advanced':
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
if config['model']["mode"] not in ['chat', 'completion']:
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")

View File

@@ -482,6 +482,9 @@ class ProviderService:
'features': []
}
if 'mode' in model:
valid_model_dict['model_mode'] = model['mode']
if 'features' in model:
valid_model_dict['features'] = model['features']

View File

@@ -44,7 +44,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('claude-2')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 6

View File

@@ -69,7 +69,7 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 22

View File

@@ -48,7 +48,7 @@ def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('baichuan2-53b')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst > 0
@@ -59,7 +59,7 @@ def test_chat_run(mock_decrypt, mocker):
model = get_mock_model('baichuan2-53b')
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages,
@@ -73,7 +73,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
model = get_mock_model('baichuan2-53b', streaming=True)
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages

View File

@@ -71,7 +71,7 @@ def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mock
mocker
)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
@@ -88,7 +88,7 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke
mocker
)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5

View File

@@ -48,7 +48,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('abab5.5-chat')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5

View File

@@ -52,7 +52,7 @@ def test_chat_get_num_tokens(mock_decrypt):
openai_model = get_mock_openai_model('gpt-3.5-turbo')
rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 22

View File

@@ -55,7 +55,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('facebook/opt-125m', mocker)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5

View File

@@ -58,7 +58,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 7

View File

@@ -52,7 +52,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('spark')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 6

View File

@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('qwen-turbo')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5

View File

@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('ernie-bot')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5

View File

@@ -57,7 +57,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('llama-2-chat', mocker)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5

View File

@@ -46,7 +46,7 @@ def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('chatglm_lite')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst > 0
@@ -57,7 +57,7 @@ def test_chat_run(mock_decrypt, mocker):
model = get_mock_model('chatglm_lite')
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages,
@@ -71,7 +71,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
model = get_mock_model('chatglm_lite', streaming=True)
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages

View File

@@ -1,7 +1,7 @@
from typing import Type
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, ModelMode
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.providers.base import BaseModelProvider
@@ -12,7 +12,10 @@ class FakeModelProvider(BaseModelProvider):
return 'fake'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [{'id': 'test_model', 'name': 'Test Model'}]
return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}]
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
return OpenAIModel

View File

@@ -24,7 +24,7 @@ def test_get_supported_model_list(mocker):
provider = FakeModelProvider(provider=Provider())
result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)
assert result == [{'id': 'test_model', 'name': 'test_model'}]
assert result == [{'id': 'test_model', 'name': 'test_model', 'mode': 'completion'}]
def test_check_quota_over_limit(mocker):