mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-06 19:42:42 +08:00
Compare commits
68 Commits
feat/segme
...
fix/web-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0d942d99dc | ||
|
|
78d3aa5fcd | ||
|
|
a7c78d2cd2 | ||
|
|
4db35fa375 | ||
|
|
e67a1413b6 | ||
|
|
4f3053a8cc | ||
|
|
b3c2bf125f | ||
|
|
9d5299e9ec | ||
|
|
aee15adf1b | ||
|
|
b185a70c21 | ||
|
|
a3aba7a9aa | ||
|
|
866ee5da91 | ||
|
|
e8039a7da8 | ||
|
|
5e0540077a | ||
|
|
b346bd9b83 | ||
|
|
062e2e915b | ||
|
|
e0a48c4972 | ||
|
|
f53242c081 | ||
|
|
4b53bb1a32 | ||
|
|
4c49ecedb5 | ||
|
|
4ff1870a4b | ||
|
|
6c832ee328 | ||
|
|
25264e7852 | ||
|
|
18dd0d569d | ||
|
|
3ea8d7a019 | ||
|
|
da3f10a55e | ||
|
|
8c991b5b26 | ||
|
|
22c1aafb9b | ||
|
|
8d6d1c442b | ||
|
|
95b179fb39 | ||
|
|
3a0a9e2d8f | ||
|
|
0a0d63457d | ||
|
|
920fb6d0e1 | ||
|
|
fd0fc8f4fe | ||
|
|
1c552ff23a | ||
|
|
5163dd38e5 | ||
|
|
2a27dad2fb | ||
|
|
930f74c610 | ||
|
|
3f250c9e12 | ||
|
|
fa408d264c | ||
|
|
09ea27f1ee | ||
|
|
db7156dafd | ||
|
|
4420281d96 | ||
|
|
d9afebe216 | ||
|
|
1d9cc5ca05 | ||
|
|
edb06f6aed | ||
|
|
6ca3bcbcfd | ||
|
|
71a9d63232 | ||
|
|
fb62017e50 | ||
|
|
9adbeadeec | ||
|
|
2f7b234cc5 | ||
|
|
4f5f9506ab | ||
|
|
0cc0b6e052 | ||
|
|
cd78adb0ab | ||
|
|
f42e7d1a61 | ||
|
|
c4d759dfba | ||
|
|
a58f95fa91 | ||
|
|
39574dcf6b | ||
|
|
5b06ded0b1 | ||
|
|
155a4733f6 | ||
|
|
b7c29ea1b6 | ||
|
|
cc2d71c253 | ||
|
|
cd11613952 | ||
|
|
e0d6d00a87 | ||
|
|
2dfb3e95f6 | ||
|
|
f207e180df | ||
|
|
948d64bbef | ||
|
|
01e912e543 |
38
.github/workflows/api-unit-tests.yml
vendored
Normal file
38
.github/workflows/api-unit-tests.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
name: Run Pytest
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- deploy/dev
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
|
||||
restore-keys: ${{ runner.os }}-pip-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest
|
||||
pip install -r api/requirements.txt
|
||||
|
||||
- name: Run pytest
|
||||
run: pytest api/tests/unit_tests
|
||||
11
README.md
11
README.md
@@ -24,17 +24,18 @@ Visual data analysis, log review, and annotation for applications
|
||||
- [x] **Anthropic**: Claude2, Claude-instant
|
||||
- [x] **Replicate**
|
||||
- [x] **Hugging Face Hub**
|
||||
- [x] **ChatGLM**
|
||||
- [x] **Llama2**
|
||||
- [x] **MiniMax**
|
||||
- [x] **Spark**
|
||||
- [x] **Wenxin**
|
||||
- [x] **Tongyi**
|
||||
- [x] **ChatGLM**
|
||||
|
||||
|
||||
We provide the following free resources for registered Dify cloud users (sign up at [dify.ai](https://dify.ai)):
|
||||
* 1000 free Claude model queries to build Claude-powered apps
|
||||
* 600,000 free Claude model tokens to build Claude-powered apps
|
||||
* 200 free OpenAI queries to build OpenAI-based apps
|
||||
* 3 million Xunfei Spark Tokens are provided for creating AI applications based on Spark.
|
||||
* 1 million MiniMax Tokens are provided for creating AI applications based on the MiniMax.
|
||||
|
||||
|
||||
**2. Visual orchestration:** Build an AI app in minutes by writing and debugging prompts visually.
|
||||
|
||||
@@ -94,8 +95,6 @@ Features under development:
|
||||
We will support more datasets, including text, webpages, and even Notion content. Users can build AI applications based on their own data sources.
|
||||
- **Plugins**, introducing ChatGPT Plugin-standard plugins for applications, or using Dify-produced plugins
|
||||
We will release plugins complying with ChatGPT standard, or Dify's own plugins to enable more capabilities in applications.
|
||||
- **Open-source models**, e.g. adopting Llama as a model provider or for further fine-tuning
|
||||
We will work with excellent open-source models like Llama, by providing them as model options in our platform, or using them for further fine-tuning.
|
||||
|
||||
|
||||
## Q&A
|
||||
|
||||
@@ -27,14 +27,16 @@
|
||||
- [x] **Anthropic**:Claude2、Claude-instant
|
||||
- [x] **Replicate**
|
||||
- [x] **Hugging Face Hub**
|
||||
- [x] **ChatGLM**
|
||||
- [x] **Llama2**
|
||||
- [x] **MiniMax**
|
||||
- [x] **讯飞星火大模型**
|
||||
- [x] **文心一言**
|
||||
- [x] **通义千问**
|
||||
- [x] **ChatGLM**
|
||||
|
||||
|
||||
我们为所有注册云端版的用户免费提供以下资源(登录 [dify.ai](https://cloud.dify.ai) 即可使用):
|
||||
* 1000 次 Claude 模型的消息调用额度,用于创建基于 Claude 模型的 AI 应用
|
||||
* 60 万 Tokens Claude 模型的消息调用额度,用于创建基于 Claude 模型的 AI 应用
|
||||
* 200 次 OpenAI 模型的消息调用额度,用于创建基于 OpenAI 模型的 AI 应用
|
||||
* 300 万 讯飞星火大模型 Token 的调用额度,用于创建基于讯飞星火大模型的 AI 应用
|
||||
* 100 万 MiniMax Token 的调用额度,用于创建基于 MiniMax 模型的 AI 应用
|
||||
@@ -90,8 +92,6 @@ docker compose up -d
|
||||
|
||||
- **数据集**,支持更多的数据集,通过网页、API 同步内容。用户可以根据自己的数据源构建 AI 应用程序。
|
||||
- **插件**,我们将发布符合 ChatGPT 标准的插件,支持更多 Dify 自己的插件,支持用户自定义插件能力,以在应用程序中启用更多功能,例如以支持以目标为导向的分解推理任务。
|
||||
- **开源模型支持**,支持 Hugging face Hub 上的开源模型。例如采用 Llama 作为模型提供者,或进行进一步的微调
|
||||
我们将与优秀的开源模型合作,通过在我们的平台中提供它们作为模型选项,或使用它们进行进一步的微调。
|
||||
|
||||
## Q&A
|
||||
|
||||
|
||||
@@ -117,10 +117,12 @@ HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
|
||||
HOSTED_ANTHROPIC_ENABLED=false
|
||||
HOSTED_ANTHROPIC_API_BASE=
|
||||
HOSTED_ANTHROPIC_API_KEY=
|
||||
HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
|
||||
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
|
||||
HOSTED_ANTHROPIC_PAID_ENABLED=false
|
||||
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
|
||||
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
|
||||
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
|
||||
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
|
||||
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
|
||||
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
@@ -16,7 +16,7 @@ EXPOSE 5001
|
||||
WORKDIR /app/api
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev
|
||||
apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev nodejs
|
||||
|
||||
COPY requirements.txt /app/api/requirements.txt
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from models.model import Account
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
from models.provider import Provider, ProviderType, ProviderQuotaType
|
||||
from models.provider import Provider, ProviderType, ProviderQuotaType, ProviderModel
|
||||
|
||||
|
||||
@click.command('reset-password', help='Reset the account password.')
|
||||
@@ -102,6 +102,7 @@ def reset_encrypt_key_pair():
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
|
||||
db.session.query(ProviderModel).delete()
|
||||
db.session.commit()
|
||||
|
||||
click.echo(click.style('Congratulations! '
|
||||
@@ -258,6 +259,8 @@ def sync_anthropic_hosted_providers():
|
||||
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
|
||||
count = 0
|
||||
|
||||
new_quota_limit = hosted_model_providers.anthropic.quota_limit
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
@@ -265,6 +268,7 @@ def sync_anthropic_hosted_providers():
|
||||
Provider.provider_name == 'anthropic',
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value,
|
||||
Provider.quota_limit != new_quota_limit
|
||||
).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
|
||||
except NotFound:
|
||||
break
|
||||
@@ -272,9 +276,9 @@ def sync_anthropic_hosted_providers():
|
||||
page += 1
|
||||
for provider in providers:
|
||||
try:
|
||||
click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
|
||||
click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
|
||||
.format(provider.tenant_id, provider.quota_limit, provider.quota_used))
|
||||
original_quota_limit = provider.quota_limit
|
||||
new_quota_limit = hosted_model_providers.anthropic.quota_limit
|
||||
division = math.ceil(new_quota_limit / 1000)
|
||||
|
||||
provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
|
||||
|
||||
@@ -48,21 +48,23 @@ DEFAULTS = {
|
||||
'WEAVIATE_GRPC_ENABLED': 'True',
|
||||
'WEAVIATE_BATCH_SIZE': 100,
|
||||
'CELERY_BACKEND': 'database',
|
||||
'PDF_PREVIEW': 'True',
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
|
||||
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_OPENAI_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_PAID_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
|
||||
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
|
||||
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
|
||||
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
|
||||
'HOSTED_ANTHROPIC_ENABLED': 'False',
|
||||
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
|
||||
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
|
||||
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
|
||||
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
|
||||
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
|
||||
'TENANT_DOCUMENT_COUNT': 100,
|
||||
'CLEAN_DAY_SETTING': 30
|
||||
'CLEAN_DAY_SETTING': 30,
|
||||
'UPLOAD_FILE_SIZE_LIMIT': 15,
|
||||
'UPLOAD_FILE_BATCH_LIMIT': 5,
|
||||
}
|
||||
|
||||
|
||||
@@ -104,7 +106,6 @@ class Config:
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
self.TESTING = False
|
||||
self.LOG_LEVEL = get_env('LOG_LEVEL')
|
||||
self.PDF_PREVIEW = get_bool_env('PDF_PREVIEW')
|
||||
|
||||
# Your App secret key will be used for securely signing the session cookie
|
||||
# Make sure you are changing this key for your deployment with a strong key.
|
||||
@@ -209,7 +210,7 @@ class Config:
|
||||
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
|
||||
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
|
||||
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
|
||||
self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
|
||||
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
|
||||
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
|
||||
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
|
||||
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
|
||||
@@ -217,23 +218,21 @@ class Config:
|
||||
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
|
||||
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
|
||||
self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
|
||||
self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
|
||||
self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT'))
|
||||
|
||||
self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
|
||||
self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
|
||||
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
|
||||
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
|
||||
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
|
||||
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
|
||||
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
|
||||
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
|
||||
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
|
||||
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
|
||||
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
|
||||
|
||||
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
|
||||
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
|
||||
|
||||
# By default it is False
|
||||
# You could disable it for compatibility with certain OpenAPI providers
|
||||
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
|
||||
|
||||
# notion import setting
|
||||
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
|
||||
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
|
||||
@@ -244,6 +243,10 @@ class Config:
|
||||
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
|
||||
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
|
||||
|
||||
# uploading settings
|
||||
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
|
||||
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
|
||||
|
||||
|
||||
class CloudEditionConfig(Config):
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
import flask_restful
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
import flask
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@@ -11,7 +14,9 @@ from controllers.console import api
|
||||
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from libs.helper import TimestampField
|
||||
@@ -124,12 +129,39 @@ class AppListApi(Resource):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
default_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
except (ProviderTokenNotInitError, LLMBadRequestError):
|
||||
default_model = None
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
default_model = None
|
||||
|
||||
if args['model_config'] is not None:
|
||||
# validate config
|
||||
model_config_dict = args['model_config']
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(
|
||||
current_user.current_tenant_id,
|
||||
model_config_dict["model"]["provider"]
|
||||
)
|
||||
|
||||
if not model_provider:
|
||||
if not default_model:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Default System Reasoning Model available. Please configure "
|
||||
f"in the Settings -> Model Provider.")
|
||||
else:
|
||||
model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
|
||||
model_config_dict["model"]["name"] = default_model.name
|
||||
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=args['model_config']
|
||||
config=model_config_dict
|
||||
)
|
||||
|
||||
app = App(
|
||||
@@ -141,21 +173,8 @@ class AppListApi(Resource):
|
||||
status='normal'
|
||||
)
|
||||
|
||||
app_model_config = AppModelConfig(
|
||||
provider="",
|
||||
model_id="",
|
||||
configs={},
|
||||
opening_statement=model_configuration['opening_statement'],
|
||||
suggested_questions=json.dumps(model_configuration['suggested_questions']),
|
||||
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
|
||||
speech_to_text=json.dumps(model_configuration['speech_to_text']),
|
||||
more_like_this=json.dumps(model_configuration['more_like_this']),
|
||||
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
|
||||
model=json.dumps(model_configuration['model']),
|
||||
user_input_form=json.dumps(model_configuration['user_input_form']),
|
||||
pre_prompt=model_configuration['pre_prompt'],
|
||||
agent_mode=json.dumps(model_configuration['agent_mode']),
|
||||
)
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config = app_model_config.from_model_config_dict(model_configuration)
|
||||
else:
|
||||
if 'mode' not in args or args['mode'] is None:
|
||||
abort(400, message="mode is required")
|
||||
@@ -165,20 +184,22 @@ class AppListApi(Resource):
|
||||
app = App(**model_config_template['app'])
|
||||
app_model_config = AppModelConfig(**model_config_template['model_config'])
|
||||
|
||||
default_model = ModelFactory.get_default_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.TEXT_GENERATION
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(
|
||||
current_user.current_tenant_id,
|
||||
app_model_config.model_dict["provider"]
|
||||
)
|
||||
|
||||
if default_model:
|
||||
model_dict = app_model_config.model_dict
|
||||
model_dict['provider'] = default_model.provider_name
|
||||
model_dict['name'] = default_model.model_name
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
else:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Text Generation Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
if not model_provider:
|
||||
if not default_model:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Default System Reasoning Model available. Please configure "
|
||||
f"in the Settings -> Model Provider.")
|
||||
else:
|
||||
model_dict = app_model_config.model_dict
|
||||
model_dict['provider'] = default_model.model_provider.provider_name
|
||||
model_dict['name'] = default_model.name
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
|
||||
app.name = args['name']
|
||||
app.mode = args['mode']
|
||||
@@ -297,7 +318,7 @@ class AppApi(Resource):
|
||||
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
app = _get_app(app_id, current_user.current_tenant_id)
|
||||
|
||||
db.session.delete(app)
|
||||
@@ -416,22 +437,9 @@ class AppCopy(Resource):
|
||||
|
||||
@staticmethod
|
||||
def create_app_model_config_copy(app_config, copy_app_id):
|
||||
copy_app_model_config = AppModelConfig(
|
||||
app_id=copy_app_id,
|
||||
provider=app_config.provider,
|
||||
model_id=app_config.model_id,
|
||||
configs=app_config.configs,
|
||||
opening_statement=app_config.opening_statement,
|
||||
suggested_questions=app_config.suggested_questions,
|
||||
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
|
||||
speech_to_text=app_config.speech_to_text,
|
||||
more_like_this=app_config.more_like_this,
|
||||
sensitive_word_avoidance=app_config.sensitive_word_avoidance,
|
||||
model=app_config.model,
|
||||
user_input_form=app_config.user_input_form,
|
||||
pre_prompt=app_config.pre_prompt,
|
||||
agent_mode=app_config.agent_mode
|
||||
)
|
||||
copy_app_model_config = app_config.copy()
|
||||
copy_app_model_config.app_id = copy_app_id
|
||||
|
||||
return copy_app_model_config
|
||||
|
||||
@setup_required
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from core.login.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Generator, Union
|
||||
|
||||
import flask_login
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import login_required
|
||||
from core.login.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy import or_, func
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from typing import Union, Generator
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user, login_required
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, marshal_with, fields
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
@@ -16,6 +16,7 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.login.login import login_required
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
|
||||
@@ -3,12 +3,13 @@ import json
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.login.login import login_required
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppModelConfig
|
||||
@@ -35,20 +36,8 @@ class ModelConfigResource(Resource):
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
provider="",
|
||||
model_id="",
|
||||
configs={},
|
||||
opening_statement=model_configuration['opening_statement'],
|
||||
suggested_questions=json.dumps(model_configuration['suggested_questions']),
|
||||
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
|
||||
speech_to_text=json.dumps(model_configuration['speech_to_text']),
|
||||
more_like_this=json.dumps(model_configuration['more_like_this']),
|
||||
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
|
||||
model=json.dumps(model_configuration['model']),
|
||||
user_input_form=json.dumps(model_configuration['user_input_form']),
|
||||
pre_prompt=model_configuration['pre_prompt'],
|
||||
agent_mode=json.dumps(model_configuration['agent_mode']),
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
db.session.add(new_app_model_config)
|
||||
db.session.flush()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
|
||||
@@ -4,7 +4,8 @@ from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask import jsonify
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -5,9 +5,12 @@ from typing import Optional
|
||||
import flask_login
|
||||
import requests
|
||||
from flask import request, redirect, current_app, session
|
||||
from flask_login import current_user, login_required
|
||||
from flask_login import current_user
|
||||
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from core.login.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
from controllers.console import api
|
||||
from ..setup import setup_required
|
||||
|
||||
@@ -3,7 +3,8 @@ import json
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@@ -21,10 +22,6 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
FILE_SIZE_LIMIT = 15 * 1024 * 1024 # 15MB
|
||||
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm']
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
class DataSourceApi(Resource):
|
||||
integrate_icon_fields = {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
import services
|
||||
@@ -10,13 +11,15 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment, Document
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
dataset_detail_fields = {
|
||||
'id': fields.String,
|
||||
@@ -33,6 +36,9 @@ dataset_detail_fields = {
|
||||
'created_at': TimestampField,
|
||||
'updated_by': fields.String,
|
||||
'updated_at': TimestampField,
|
||||
'embedding_model': fields.String,
|
||||
'embedding_model_provider': fields.String,
|
||||
'embedding_available': fields.Boolean
|
||||
}
|
||||
|
||||
dataset_query_detail_fields = {
|
||||
@@ -74,8 +80,22 @@ class DatasetListApi(Resource):
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||
current_user.current_tenant_id, current_user)
|
||||
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
|
||||
# if len(valid_model_list) == 0:
|
||||
# raise ProviderNotInitializeError(
|
||||
# f"No Embedding Model available. Please configure a valid provider "
|
||||
# f"in the Settings -> Model Provider.")
|
||||
model_names = [item['model_name'] for item in valid_model_list]
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item['embedding_model'] in model_names:
|
||||
item['embedding_available'] = True
|
||||
else:
|
||||
item['embedding_available'] = False
|
||||
response = {
|
||||
'data': marshal(datasets, dataset_detail_fields),
|
||||
'data': data,
|
||||
'has_more': len(datasets) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
@@ -99,7 +119,6 @@ class DatasetListApi(Resource):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
@@ -233,6 +252,8 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
@@ -250,11 +271,14 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
args['process_rule'], args['doc_form'])
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
elif args['info_list']['data_source_type'] == 'notion_import':
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
@@ -262,11 +286,14 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
try:
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
args['info_list']['notion_info_list'],
|
||||
args['process_rule'], args['doc_form'])
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
return response, 200
|
||||
|
||||
@@ -4,7 +4,8 @@ from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import desc, asc
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
@@ -274,6 +275,7 @@ class DatasetDocumentListApi(Resource):
|
||||
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
|
||||
parser.add_argument('original_document_id', type=str, required=False, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||
@@ -282,14 +284,19 @@ class DatasetDocumentListApi(Resource):
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
|
||||
@@ -328,6 +335,7 @@ class DatasetInitApi(Resource):
|
||||
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@@ -406,11 +414,13 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
|
||||
data_process_rule_dict)
|
||||
data_process_rule_dict, None, dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
return response
|
||||
|
||||
@@ -473,22 +483,27 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
data_process_rule_dict)
|
||||
data_process_rule_dict, None, dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
elif dataset.data_source_type:
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
elif dataset.data_source_type == 'notion_import':
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
info_list,
|
||||
data_process_rule_dict)
|
||||
data_process_rule_dict,
|
||||
None, dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
return response
|
||||
@@ -575,7 +590,8 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
|
||||
if document.is_paused:
|
||||
document.indexing_status = 'paused'
|
||||
return marshal(document, self.document_status_fields)
|
||||
|
||||
|
||||
@@ -749,11 +765,13 @@ class DocumentMetadataApi(DocumentResource):
|
||||
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
|
||||
|
||||
document.doc_metadata = {}
|
||||
|
||||
for key, value_type in metadata_schema.items():
|
||||
value = doc_metadata.get(key)
|
||||
if value is not None and isinstance(value, value_type):
|
||||
document.doc_metadata[key] = value
|
||||
if doc_type == 'others':
|
||||
document.doc_metadata = doc_metadata
|
||||
else:
|
||||
for key, value_type in metadata_schema.items():
|
||||
value = doc_metadata.get(key)
|
||||
if value is not None and isinstance(value, value_type):
|
||||
document.doc_metadata[key] = value
|
||||
|
||||
document.doc_type = doc_type
|
||||
document.updated_at = datetime.utcnow()
|
||||
@@ -832,6 +850,22 @@ class DocumentStatusApi(DocumentResource):
|
||||
|
||||
remove_document_from_index_task.delay(document_id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
elif action == "un_archive":
|
||||
if not document.archived:
|
||||
raise InvalidActionError('Document is not archived.')
|
||||
|
||||
document.archived = False
|
||||
document.archived_at = None
|
||||
document.archived_by = None
|
||||
document.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
# Set cache to prevent indexing the same document multiple times
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
|
||||
add_document_to_index_task.delay(document_id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
else:
|
||||
raise InvalidActionError()
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import InvalidActionError
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.login.login import login_required
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import DocumentSegment
|
||||
@@ -17,7 +22,9 @@ from models.dataset import DocumentSegment
|
||||
from libs.helper import TimestampField
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
|
||||
from tasks.remove_segment_from_index_task import remove_segment_from_index_task
|
||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
import pandas as pd
|
||||
|
||||
segment_fields = {
|
||||
'id': fields.String,
|
||||
@@ -152,6 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
@@ -197,7 +218,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
# Set cache to prevent indexing the same segment multiple times
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
|
||||
remove_segment_from_index_task.delay(segment.id)
|
||||
disable_segment_from_index_task.delay(segment.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
else:
|
||||
@@ -222,6 +243,19 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
@@ -233,7 +267,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.create_segment(args, document)
|
||||
segment = SegmentService.create_segment(args, document, dataset)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
@@ -245,6 +279,61 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
if not segment:
|
||||
raise NotFound('Segment not found.')
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(args, segment, document, dataset)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@@ -270,17 +359,88 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(args, segment, document)
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id, document_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# get file from request
|
||||
file = request.files['file']
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
# check file type
|
||||
if not file.filename.endswith('.csv'):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
try:
|
||||
# Skip the first row
|
||||
df = pd.read_csv(file)
|
||||
result = []
|
||||
for index, row in df.iterrows():
|
||||
if document.doc_form == 'qa_model':
|
||||
data = {'content': row[0], 'answer': row[1]}
|
||||
else:
|
||||
data = {'content': row[0]}
|
||||
result.append(data)
|
||||
if len(result) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, 'waiting')
|
||||
batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
|
||||
current_user.current_tenant_id, current_user.id)
|
||||
except Exception as e:
|
||||
return {'error': str(e)}, 500
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, job_id):
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job is not exist.")
|
||||
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': cache_result.decode()
|
||||
}, 200
|
||||
|
||||
|
||||
@@ -292,3 +452,6 @@ api.add_resource(DatasetDocumentSegmentAddApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
|
||||
api.add_resource(DatasetDocumentSegmentUpdateApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
|
||||
api.add_resource(DatasetDocumentSegmentBatchImportApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
|
||||
'/datasets/batch_import_status/<uuid:job_id>')
|
||||
|
||||
@@ -8,7 +8,8 @@ from pathlib import Path
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, marshal_with, fields
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@@ -25,12 +26,28 @@ from models.model import UploadFile
|
||||
|
||||
cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
FILE_SIZE_LIMIT = 15 * 1024 * 1024 # 15MB
|
||||
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx']
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
class FileApi(Resource):
|
||||
upload_config_fields = {
|
||||
'file_size_limit': fields.Integer,
|
||||
'batch_count_limit': fields.Integer
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(upload_config_fields)
|
||||
def get(self):
|
||||
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
|
||||
batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
|
||||
return {
|
||||
'file_size_limit': file_size_limit,
|
||||
'batch_count_limit': batch_count_limit
|
||||
}, 200
|
||||
|
||||
file_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
@@ -60,8 +77,9 @@ class FileApi(Resource):
|
||||
file_content = file.read()
|
||||
file_size = len(file_content)
|
||||
|
||||
if file_size > FILE_SIZE_LIMIT:
|
||||
message = "({file_size} > {FILE_SIZE_LIMIT})"
|
||||
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
|
||||
if file_size > file_size_limit:
|
||||
message = "({file_size} > {file_size_limit})"
|
||||
raise FileTooLargeError(message)
|
||||
|
||||
extension = file.filename.split('.')[-1]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal, fields
|
||||
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
|
||||
|
||||
@@ -11,7 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
|
||||
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError
|
||||
from libs.helper import TimestampField
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@@ -102,6 +104,10 @@ class HitTestingApi(Resource):
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, inputs
|
||||
from sqlalchemy import and_
|
||||
from werkzeug.exceptions import NotFound, Forbidden, BadRequest
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from sqlalchemy import and_
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource
|
||||
from functools import wraps
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
|
||||
@@ -38,12 +38,20 @@ class StripeWebhookApi(Resource):
|
||||
logging.debug(event['data']['object']['payment_status'])
|
||||
logging.debug(event['data']['object']['metadata'])
|
||||
|
||||
session = stripe.checkout.Session.retrieve(
|
||||
event['data']['object']['id'],
|
||||
expand=['line_items'],
|
||||
)
|
||||
|
||||
logging.debug(session.line_items['data'][0]['quantity'])
|
||||
|
||||
# Fulfill the purchase...
|
||||
provider_checkout_service = ProviderCheckoutService()
|
||||
|
||||
try:
|
||||
provider_checkout_service.fulfill_provider_order(event)
|
||||
provider_checkout_service.fulfill_provider_order(event, session.line_items)
|
||||
except Exception as e:
|
||||
|
||||
logging.debug(str(e))
|
||||
return 'success', 200
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@ from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask import current_app, request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
|
||||
|
||||
import services
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, abort, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -2,10 +2,13 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse, marshal
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs
|
||||
from flask_restful.inputs import int_range
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
@@ -43,6 +46,13 @@ tenants_fields = {
|
||||
'current': fields.Boolean
|
||||
}
|
||||
|
||||
workspace_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'status': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
class TenantListApi(Resource):
|
||||
@setup_required
|
||||
@@ -57,6 +67,38 @@ class TenantListApi(Resource):
|
||||
return {'workspaces': marshal(tenants, tenants_fields)}, 200
|
||||
|
||||
|
||||
class WorkspaceListApi(Resource):
|
||||
@setup_required
|
||||
@admin_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
|
||||
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\
|
||||
.paginate(page=args['page'], per_page=args['limit'])
|
||||
|
||||
has_more = False
|
||||
if len(tenants.items) == args['limit']:
|
||||
current_page_first_tenant = tenants[-1]
|
||||
rest_count = db.session.query(Tenant).filter(
|
||||
Tenant.created_at < current_page_first_tenant.created_at,
|
||||
Tenant.id != current_page_first_tenant.id
|
||||
).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
total = db.session.query(Tenant).count()
|
||||
return {
|
||||
'data': marshal(tenants.items, workspace_fields),
|
||||
'has_more': has_more,
|
||||
'limit': args['limit'],
|
||||
'page': args['page'],
|
||||
'total': total
|
||||
}, 200
|
||||
|
||||
|
||||
class TenantApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -92,6 +134,7 @@ class SwitchWorkspaceApi(Resource):
|
||||
|
||||
|
||||
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
|
||||
api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants
|
||||
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
|
||||
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
|
||||
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant
|
||||
|
||||
@@ -17,7 +17,7 @@ def validate_app_token(view=None):
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token('app')
|
||||
|
||||
app_model = db.session.query(App).get(api_token.app_id)
|
||||
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
|
||||
@@ -44,7 +44,7 @@ def validate_dataset_token(view=None):
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token('dataset')
|
||||
|
||||
dataset = db.session.query(Dataset).get(api_token.dataset_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound()
|
||||
|
||||
@@ -64,14 +64,14 @@ def validate_and_get_api_token(scope=None):
|
||||
Validate and get API token.
|
||||
"""
|
||||
auth_header = request.headers.get('Authorization')
|
||||
if auth_header is None:
|
||||
raise Unauthorized()
|
||||
if auth_header is None or ' ' not in auth_header:
|
||||
raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized()
|
||||
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
||||
|
||||
api_token = db.session.query(ApiToken).filter(
|
||||
ApiToken.token == auth_token,
|
||||
@@ -79,7 +79,7 @@ def validate_and_get_api_token(scope=None):
|
||||
).first()
|
||||
|
||||
if not api_token:
|
||||
raise Unauthorized()
|
||||
raise Unauthorized("Access token is invalid")
|
||||
|
||||
api_token.last_used_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
@@ -59,7 +59,11 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
_, observation = intermediate_steps[-1]
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
return super().plan(intermediate_steps, callbacks, **kwargs)
|
||||
try:
|
||||
return super().plan(intermediate_steps, callbacks, **kwargs)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
|
||||
@@ -50,9 +50,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
)
|
||||
try:
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from core.model_providers.models.llm.base import BaseLLM
|
||||
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel
|
||||
summary_llm: BaseLanguageModel = None
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
@@ -66,12 +66,12 @@ class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
|
||||
|
||||
return new_messages
|
||||
|
||||
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
llm = cast(ChatOpenAI, llm)
|
||||
llm = cast(ChatOpenAI, model_instance.client)
|
||||
model, encoding = llm._get_encoding_model()
|
||||
if model.startswith("gpt-3.5-turbo"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
|
||||
@@ -50,9 +50,13 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
)
|
||||
try:
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
class StructuredChatOutputParser(LCStructuredChatOutputParser):
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
try:
|
||||
action_match = re.search(r"```(.*?)\n?(.*?)```", text, re.DOTALL)
|
||||
action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
|
||||
if action_match is not None:
|
||||
response = json.loads(action_match.group(2).strip(), strict=False)
|
||||
if isinstance(response, list):
|
||||
@@ -26,4 +26,4 @@ class StructuredChatOutputParser(LCStructuredChatOutputParser):
|
||||
else:
|
||||
return AgentFinish({"output": text}, text)
|
||||
except Exception as e:
|
||||
raise OutputParserException(f"Could not parse LLM output: {text}") from e
|
||||
raise OutputParserException(f"Could not parse LLM output: {text}")
|
||||
|
||||
@@ -94,7 +94,12 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
try:
|
||||
return self.output_parser.parse(full_output)
|
||||
|
||||
@@ -52,7 +52,7 @@ Action:
|
||||
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel
|
||||
summary_llm: BaseLanguageModel = None
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
@@ -89,8 +89,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
|
||||
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
|
||||
|
||||
messages = []
|
||||
if prompts:
|
||||
messages = prompts[0].to_messages()
|
||||
@@ -99,7 +99,11 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
if rest_tokens < 0:
|
||||
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
|
||||
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
try:
|
||||
return self.output_parser.parse(full_output)
|
||||
@@ -108,7 +112,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2:
|
||||
if len(intermediate_steps) >= 2 and self.summary_llm:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
for _, observation in should_summary_intermediate_steps]
|
||||
|
||||
@@ -32,7 +32,7 @@ class AgentConfiguration(BaseModel):
|
||||
strategy: PlanningStrategy
|
||||
model_instance: BaseLLM
|
||||
tools: list[BaseTool]
|
||||
summary_model_instance: BaseLLM
|
||||
summary_model_instance: BaseLLM = None
|
||||
memory: Optional[BaseChatMemory] = None
|
||||
callbacks: Callbacks = None
|
||||
max_iterations: int = 6
|
||||
@@ -65,7 +65,8 @@ class AgentExecutor:
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
summary_llm=self.configuration.summary_model_instance.client,
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
@@ -74,7 +75,8 @@ class AgentExecutor:
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_model_instance.client,
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
|
||||
@@ -83,7 +85,8 @@ class AgentExecutor:
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_model_instance.client,
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.ROUTER:
|
||||
|
||||
@@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
|
||||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
|
||||
@@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_loop.status = 'llm_end'
|
||||
if response.llm_output:
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
else:
|
||||
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.prompt)]
|
||||
)
|
||||
completion_generation = response.generations[0][0]
|
||||
if isinstance(completion_generation, ChatGeneration):
|
||||
completion_message = completion_generation.message
|
||||
@@ -81,11 +86,15 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
if response.llm_output:
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.completion)]
|
||||
)
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.exception(error)
|
||||
logging.debug("Agent on_llm_error: %s", error)
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
@@ -164,7 +173,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
logging.exception(error)
|
||||
logging.debug("Agent on_tool_error: %s", error)
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
|
||||
@@ -68,4 +68,4 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
logging.exception(error)
|
||||
logging.debug("Dataset tool on_llm_error: %s", error)
|
||||
|
||||
@@ -72,5 +72,5 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.exception(error)
|
||||
logging.debug("Dataset tool on_chain_error: %s", error)
|
||||
self.clear_chain_results()
|
||||
|
||||
@@ -126,7 +126,7 @@ class Completion:
|
||||
# the output of the agent can be used directly as the main output content without calling LLM again
|
||||
fake_response = None
|
||||
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
|
||||
and agent_execute_result.strategy != PlanningStrategy.ROUTER:
|
||||
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
|
||||
fake_response = agent_execute_result.output
|
||||
|
||||
# get llm prompt
|
||||
|
||||
@@ -60,17 +60,7 @@ class ConversationMessageTask:
|
||||
def init(self):
|
||||
override_model_configs = None
|
||||
if self.is_override:
|
||||
override_model_configs = {
|
||||
"model": self.app_model_config.model_dict,
|
||||
"pre_prompt": self.app_model_config.pre_prompt,
|
||||
"agent_mode": self.app_model_config.agent_mode_dict,
|
||||
"opening_statement": self.app_model_config.opening_statement,
|
||||
"suggested_questions": self.app_model_config.suggested_questions_list,
|
||||
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
|
||||
"more_like_this": self.app_model_config.more_like_this_dict,
|
||||
"sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
|
||||
"user_input_form": self.app_model_config.user_input_form_list,
|
||||
}
|
||||
override_model_configs = self.app_model_config.to_dict()
|
||||
|
||||
introduction = ''
|
||||
system_instruction = ''
|
||||
@@ -129,9 +119,11 @@ class ConversationMessageTask:
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer="",
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency=self.model_instance.get_currency(),
|
||||
@@ -150,17 +142,24 @@ class ConversationMessageTask:
|
||||
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
|
||||
message_tokens = llm_message.prompt_tokens
|
||||
answer_tokens = llm_message.completion_tokens
|
||||
message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
|
||||
answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
|
||||
|
||||
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
|
||||
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
|
||||
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
|
||||
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
|
||||
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
|
||||
total_price = message_total_price + answer_total_price
|
||||
|
||||
self.message.message = llm_message.prompt
|
||||
self.message.message_tokens = message_tokens
|
||||
self.message.message_unit_price = message_unit_price
|
||||
self.message.message_price_unit = message_price_unit
|
||||
self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
|
||||
self.message.answer_tokens = answer_tokens
|
||||
self.message.answer_unit_price = answer_unit_price
|
||||
self.message.answer_price_unit = answer_price_unit
|
||||
self.message.provider_response_latency = llm_message.latency
|
||||
self.message.total_price = total_price
|
||||
|
||||
@@ -202,7 +201,9 @@ class ConversationMessageTask:
|
||||
tool=agent_loop.tool_name,
|
||||
tool_input=agent_loop.tool_input,
|
||||
message=agent_loop.prompt,
|
||||
message_price_unit=0,
|
||||
answer=agent_loop.completion,
|
||||
answer_price_unit=0,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
@@ -216,25 +217,26 @@ class ConversationMessageTask:
|
||||
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
|
||||
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
|
||||
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
loop_message_tokens = agent_loop.prompt_tokens
|
||||
loop_answer_tokens = agent_loop.completion_tokens
|
||||
|
||||
loop_total_price = self.calc_total_price(
|
||||
loop_message_tokens,
|
||||
agent_message_unit_price,
|
||||
loop_answer_tokens,
|
||||
agent_answer_unit_price
|
||||
)
|
||||
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
|
||||
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
|
||||
loop_total_price = loop_message_total_price + loop_answer_total_price
|
||||
|
||||
message_agent_thought.observation = agent_loop.tool_output
|
||||
message_agent_thought.tool_process_data = '' # currently not support
|
||||
message_agent_thought.message_token = loop_message_tokens
|
||||
message_agent_thought.message_unit_price = agent_message_unit_price
|
||||
message_agent_thought.message_price_unit = agent_message_price_unit
|
||||
message_agent_thought.answer_token = loop_answer_tokens
|
||||
message_agent_thought.answer_unit_price = agent_answer_unit_price
|
||||
message_agent_thought.answer_price_unit = agent_answer_price_unit
|
||||
message_agent_thought.latency = agent_loop.latency
|
||||
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
|
||||
message_agent_thought.total_price = loop_total_price
|
||||
@@ -253,15 +255,6 @@ class ConversationMessageTask:
|
||||
|
||||
db.session.add(dataset_query)
|
||||
|
||||
def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
|
||||
message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def end(self):
|
||||
self._pub_handler.pub_end()
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
class DatesetDocumentStore:
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
user_id: str,
|
||||
document_id: Optional[str] = None,
|
||||
self,
|
||||
dataset: Dataset,
|
||||
user_id: str,
|
||||
document_id: Optional[str] = None,
|
||||
):
|
||||
self._dataset = dataset
|
||||
self._user_id = user_id
|
||||
@@ -59,7 +59,7 @@ class DatesetDocumentStore:
|
||||
return output
|
||||
|
||||
def add_documents(
|
||||
self, docs: Sequence[Document], allow_update: bool = True
|
||||
self, docs: Sequence[Document], allow_update: bool = True
|
||||
) -> None:
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == self._document_id
|
||||
@@ -69,7 +69,9 @@ class DatesetDocumentStore:
|
||||
max_position = 0
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=self._dataset.tenant_id
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
model_provider_name=self._dataset.embedding_model_provider,
|
||||
model_name=self._dataset.embedding_model
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
@@ -123,7 +125,7 @@ class DatesetDocumentStore:
|
||||
return result is not None
|
||||
|
||||
def get_document(
|
||||
self, doc_id: str, raise_error: bool = True
|
||||
self, doc_id: str, raise_error: bool = True
|
||||
) -> Optional[Document]:
|
||||
document_segment = self.get_document_segment(doc_id)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
|
||||
from langchain.schema import OutputParserException
|
||||
|
||||
from core.model_providers.error import LLMError
|
||||
from core.model_providers.error import LLMError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs
|
||||
@@ -51,6 +51,7 @@ class LLMGenerator:
|
||||
prompt_with_empty_context = prompt.format(context='')
|
||||
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
|
||||
max_context_token_length = model_instance.model_rules.max_tokens.max
|
||||
max_context_token_length = max_context_token_length if max_context_token_length else 1500
|
||||
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
|
||||
|
||||
context = ''
|
||||
@@ -108,13 +109,16 @@ class LLMGenerator:
|
||||
|
||||
_input = prompt.format_prompt(histories=histories)
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=256,
|
||||
temperature=0
|
||||
try:
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=256,
|
||||
temperature=0
|
||||
)
|
||||
)
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
return []
|
||||
|
||||
prompts = [PromptMessage(content=_input.to_string())]
|
||||
|
||||
@@ -175,8 +179,8 @@ class LLMGenerator:
|
||||
return rule_config
|
||||
|
||||
@classmethod
|
||||
def generate_qa_document(cls, tenant_id: str, query):
|
||||
prompt = GENERATOR_QA_PROMPT
|
||||
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
|
||||
prompt = GENERATOR_QA_PROMPT.format(language=document_language)
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -15,7 +15,9 @@ class IndexBuilder:
|
||||
return None
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
@@ -7,6 +7,7 @@ import time
|
||||
import uuid
|
||||
from typing import Optional, List, cast
|
||||
|
||||
from flask import current_app, Flask
|
||||
from flask_login import current_user
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
@@ -66,14 +67,6 @@ class IndexingRunner:
|
||||
dataset_document=dataset_document,
|
||||
processing_rule=processing_rule
|
||||
)
|
||||
# new_documents = []
|
||||
# for document in documents:
|
||||
# response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
|
||||
# document_qa_list = self.format_split_text(response)
|
||||
# for result in document_qa_list:
|
||||
# document = Document(page_content=result['question'], metadata={'source': result['answer']})
|
||||
# new_documents.append(document)
|
||||
# build index
|
||||
self._build_index(
|
||||
dataset=dataset,
|
||||
dataset_document=dataset_document,
|
||||
@@ -224,14 +217,25 @@ class IndexingRunner:
|
||||
db.session.commit()
|
||||
|
||||
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
|
||||
doc_form: str = None) -> dict:
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
total_segments = 0
|
||||
@@ -262,20 +266,19 @@ class IndexingRunner:
|
||||
|
||||
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
|
||||
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
@@ -283,18 +286,31 @@ class IndexingRunner:
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
|
||||
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# load data from notion
|
||||
tokens = 0
|
||||
@@ -343,20 +359,19 @@ class IndexingRunner:
|
||||
|
||||
tokens += embedding_model.get_num_tokens(document.page_content)
|
||||
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
@@ -364,7 +379,7 @@ class IndexingRunner:
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"preview": preview_texts
|
||||
}
|
||||
@@ -457,7 +472,8 @@ class IndexingRunner:
|
||||
splitter=splitter,
|
||||
processing_rule=processing_rule,
|
||||
tenant_id=dataset.tenant_id,
|
||||
document_form=dataset_document.doc_form
|
||||
document_form=dataset_document.doc_form,
|
||||
document_language=dataset_document.doc_language
|
||||
)
|
||||
|
||||
# save node to document segment
|
||||
@@ -493,7 +509,8 @@ class IndexingRunner:
|
||||
return documents
|
||||
|
||||
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
|
||||
processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
|
||||
processing_rule: DatasetProcessRule, tenant_id: str,
|
||||
document_form: str, document_language: str) -> List[Document]:
|
||||
"""
|
||||
Split the text documents into nodes.
|
||||
"""
|
||||
@@ -522,7 +539,9 @@ class IndexingRunner:
|
||||
sub_documents = all_documents[i:i + 10]
|
||||
for doc in sub_documents:
|
||||
document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
|
||||
'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents})
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
|
||||
'document_language': document_language})
|
||||
threads.append(document_format_thread)
|
||||
document_format_thread.start()
|
||||
for thread in threads:
|
||||
@@ -530,28 +549,29 @@ class IndexingRunner:
|
||||
return all_qa_documents
|
||||
return all_documents
|
||||
|
||||
def format_qa_document(self, tenant_id: str, document_node, all_qa_documents):
|
||||
def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
|
||||
format_documents = []
|
||||
if document_node.page_content is None or not document_node.page_content.strip():
|
||||
return
|
||||
try:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
qa_documents = []
|
||||
for result in document_qa_list:
|
||||
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(result['question'])
|
||||
qa_document.metadata['answer'] = result['answer']
|
||||
qa_document.metadata['doc_id'] = doc_id
|
||||
qa_document.metadata['doc_hash'] = hash
|
||||
qa_documents.append(qa_document)
|
||||
format_documents.extend(qa_documents)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
qa_documents = []
|
||||
for result in document_qa_list:
|
||||
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(result['question'])
|
||||
qa_document.metadata['answer'] = result['answer']
|
||||
qa_document.metadata['doc_id'] = doc_id
|
||||
qa_document.metadata['doc_hash'] = hash
|
||||
qa_documents.append(qa_document)
|
||||
format_documents.extend(qa_documents)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
all_qa_documents.extend(format_documents)
|
||||
all_qa_documents.extend(format_documents)
|
||||
|
||||
|
||||
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
|
||||
@@ -638,7 +658,9 @@ class IndexingRunner:
|
||||
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# chunk nodes by chunk size
|
||||
@@ -719,6 +741,32 @@ class IndexingRunner:
|
||||
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
|
||||
db.session.commit()
|
||||
|
||||
def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
|
||||
"""
|
||||
Batch add segments index processing
|
||||
"""
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
if index:
|
||||
index.add_texts(documents, duplicate_check=True)
|
||||
|
||||
# save keyword index
|
||||
index = IndexBuilder.get_index(dataset, 'economy')
|
||||
if index:
|
||||
index.add_texts(documents)
|
||||
|
||||
|
||||
class DocumentIsPausedException(Exception):
|
||||
pass
|
||||
|
||||
108
api/core/login/login.py
Normal file
108
api/core/login/login.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
|
||||
import flask_login
|
||||
from flask import current_app
|
||||
from flask import g
|
||||
from flask import has_request_context
|
||||
from flask import request
|
||||
from flask_login import user_logged_in
|
||||
from flask_login.config import EXEMPT_METHODS
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
|
||||
#: A proxy for the current user. If no user is logged in, this will be an
|
||||
#: anonymous user
|
||||
current_user = LocalProxy(lambda: _get_user())
|
||||
|
||||
|
||||
def login_required(func):
|
||||
"""
|
||||
If you decorate a view with this, it will ensure that the current user is
|
||||
logged in and authenticated before calling the actual view. (If they are
|
||||
not, it calls the :attr:`LoginManager.unauthorized` callback.) For
|
||||
example::
|
||||
|
||||
@app.route('/post')
|
||||
@login_required
|
||||
def post():
|
||||
pass
|
||||
|
||||
If there are only certain times you need to require that your user is
|
||||
logged in, you can do so with::
|
||||
|
||||
if not current_user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized()
|
||||
|
||||
...which is essentially the code that this function adds to your views.
|
||||
|
||||
It can be convenient to globally turn off authentication when unit testing.
|
||||
To enable this, if the application configuration variable `LOGIN_DISABLED`
|
||||
is set to `True`, this decorator will be ignored.
|
||||
|
||||
.. Note ::
|
||||
|
||||
Per `W3 guidelines for CORS preflight requests
|
||||
<http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
|
||||
HTTP ``OPTIONS`` requests are exempt from login checks.
|
||||
|
||||
:param func: The view function to decorate.
|
||||
:type func: function
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
auth_header = request.headers.get('Authorization')
|
||||
admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False')
|
||||
if admin_api_key_enable:
|
||||
if auth_header:
|
||||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
admin_api_key = os.getenv('ADMIN_API_KEY')
|
||||
|
||||
if admin_api_key:
|
||||
if os.getenv('ADMIN_API_KEY') == auth_token:
|
||||
workspace_id = request.headers.get('X-WORKSPACE-ID')
|
||||
if workspace_id:
|
||||
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
|
||||
.filter(Tenant.id == workspace_id) \
|
||||
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
||||
.filter(TenantAccountJoin.role == 'owner') \
|
||||
.one_or_none()
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = Account.query.filter_by(id=ta.account_id).first()
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account)
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user())
|
||||
if request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
|
||||
pass
|
||||
elif not current_user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized()
|
||||
|
||||
# flask 1.x compatibility
|
||||
# current_app.ensure_sync is only available in Flask >= 2.0
|
||||
if callable(getattr(current_app, "ensure_sync", None)):
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
def _get_user():
|
||||
if has_request_context():
|
||||
if "_login_user" not in g:
|
||||
current_app.login_manager._load_user()
|
||||
|
||||
return g._login_user
|
||||
|
||||
return None
|
||||
@@ -46,7 +46,8 @@ class ModelFactory:
|
||||
model_name: Optional[str] = None,
|
||||
model_kwargs: Optional[ModelKwargs] = None,
|
||||
streaming: bool = False,
|
||||
callbacks: Callbacks = None) -> Optional[BaseLLM]:
|
||||
callbacks: Callbacks = None,
|
||||
deduct_quota: bool = True) -> Optional[BaseLLM]:
|
||||
"""
|
||||
get text generation model.
|
||||
|
||||
@@ -56,6 +57,7 @@ class ModelFactory:
|
||||
:param model_kwargs:
|
||||
:param streaming:
|
||||
:param callbacks:
|
||||
:param deduct_quota:
|
||||
:return:
|
||||
"""
|
||||
is_default_model = False
|
||||
@@ -95,7 +97,7 @@ class ModelFactory:
|
||||
else:
|
||||
raise e
|
||||
|
||||
if is_default_model:
|
||||
if is_default_model or not deduct_quota:
|
||||
model_instance.deduct_quota = False
|
||||
|
||||
return model_instance
|
||||
|
||||
@@ -57,6 +57,12 @@ class ModelProviderFactory:
|
||||
elif provider_name == 'huggingface_hub':
|
||||
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
|
||||
return HuggingfaceHubProvider
|
||||
elif provider_name == 'xinference':
|
||||
from core.model_providers.providers.xinference_provider import XinferenceProvider
|
||||
return XinferenceProvider
|
||||
elif provider_name == 'openllm':
|
||||
from core.model_providers.providers.openllm_provider import OpenLLMProvider
|
||||
return OpenLLMProvider
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -26,10 +26,20 @@ class AzureOpenAIEmbedding(BaseEmbedding):
|
||||
openai_api_version=AZURE_OPENAI_API_VERSION,
|
||||
chunk_size=16,
|
||||
max_retries=1,
|
||||
**self.credentials
|
||||
openai_api_key=self.credentials.get('openai_api_key'),
|
||||
openai_api_base=self.credentials.get('openai_api_base')
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name (not deployment)
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.credentials.get("base_model_name")
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""
|
||||
@@ -48,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * decimal.Decimal('0.0001')
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to Azure OpenAI API.")
|
||||
@@ -71,7 +71,7 @@ class AzureOpenAIEmbedding(BaseEmbedding):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError('Azure ' + str(ex))
|
||||
return LLMAuthorizationError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
import decimal
|
||||
|
||||
import tiktoken
|
||||
from langchain.schema.language_model import _get_token_ids_default_method
|
||||
@@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseEmbedding(BaseProviderModel):
|
||||
name: str
|
||||
@@ -17,6 +19,63 @@ class BaseEmbedding(BaseProviderModel):
|
||||
super().__init__(model_provider, client)
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def price_config(self) -> dict:
|
||||
def get_or_default():
|
||||
default_price_config = {
|
||||
'completion': decimal.Decimal('0'),
|
||||
'unit': decimal.Decimal('0'),
|
||||
'currency': 'USD'
|
||||
}
|
||||
rules = self.model_provider.get_rules()
|
||||
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
|
||||
price_config = {
|
||||
'completion': decimal.Decimal(price_config['completion']),
|
||||
'unit': decimal.Decimal(price_config['unit']),
|
||||
'currency': price_config['currency']
|
||||
}
|
||||
return price_config
|
||||
|
||||
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
|
||||
|
||||
logger.debug(f"model: {self.name} price_config: {self._price_config}")
|
||||
return self._price_config
|
||||
|
||||
def calc_tokens_price(self, tokens: int) -> decimal.Decimal:
|
||||
"""
|
||||
calc tokens total price.
|
||||
|
||||
:param tokens:
|
||||
:return: decimal.Decimal('0.0000001')
|
||||
"""
|
||||
unit_price = self.price_config['completion']
|
||||
unit = self.price_config['unit']
|
||||
total_price = tokens * unit_price * unit
|
||||
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
|
||||
return total_price
|
||||
|
||||
def get_tokens_unit_price(self) -> decimal.Decimal:
|
||||
"""
|
||||
get token price.
|
||||
|
||||
:return: decimal.Decimal('0.0001')
|
||||
|
||||
"""
|
||||
unit_price = self.price_config['completion']
|
||||
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logger.debug(f'unit_price:{unit_price}')
|
||||
return unit_price
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""
|
||||
get num tokens of text.
|
||||
@@ -29,11 +88,14 @@ class BaseEmbedding(BaseProviderModel):
|
||||
|
||||
return len(_get_token_ids_default_method(text))
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
return 0
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
"""
|
||||
get token currency.
|
||||
|
||||
:return: get from price config, default 'USD'
|
||||
"""
|
||||
currency = self.price_config['currency']
|
||||
return currency
|
||||
|
||||
@abstractmethod
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
import decimal
|
||||
import logging
|
||||
|
||||
from langchain.embeddings import MiniMaxEmbeddings
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
@@ -22,12 +19,6 @@ class MinimaxEmbedding(BaseEmbedding):
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, ValueError):
|
||||
return LLMBadRequestError(f"Minimax: {str(ex)}")
|
||||
|
||||
@@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * decimal.Decimal('0.0001')
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to OpenAI API.")
|
||||
@@ -65,7 +55,7 @@ class OpenAIEmbedding(BaseEmbedding):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError(str(ex))
|
||||
return LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
@@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
# replicate only pay for prediction seconds
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, (ModelError, ReplicateError)):
|
||||
return LLMBadRequestError(f"Replicate: {str(ex)}")
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
|
||||
from replicate.exceptions import ModelError, ReplicateError
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
|
||||
|
||||
class XinferenceEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = XinferenceEmbeddings(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid'],
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, (ModelError, ReplicateError)):
|
||||
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
@@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'claude-instant-1': {
|
||||
'prompt': decimal.Decimal('1.63'),
|
||||
'completion': decimal.Decimal('5.51'),
|
||||
},
|
||||
'claude-2': {
|
||||
'prompt': decimal.Decimal('11.02'),
|
||||
'completion': decimal.Decimal('32.68'),
|
||||
},
|
||||
}
|
||||
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[self.name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[self.name]['completion']
|
||||
|
||||
tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1m * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
|
||||
@@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM):
|
||||
self.model_mode = ModelMode.COMPLETION
|
||||
else:
|
||||
self.model_mode = ModelMode.CHAT
|
||||
|
||||
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
@@ -83,6 +82,15 @@ class AzureOpenAIModel(BaseLLM):
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name (not deployment)
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.credentials.get("base_model_name")
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
@@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM):
|
||||
else:
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'gpt-4': {
|
||||
'prompt': decimal.Decimal('0.03'),
|
||||
'completion': decimal.Decimal('0.06'),
|
||||
},
|
||||
'gpt-4-32k': {
|
||||
'prompt': decimal.Decimal('0.06'),
|
||||
'completion': decimal.Decimal('0.12')
|
||||
},
|
||||
'gpt-35-turbo': {
|
||||
'prompt': decimal.Decimal('0.0015'),
|
||||
'completion': decimal.Decimal('0.002')
|
||||
},
|
||||
'gpt-35-turbo-16k': {
|
||||
'prompt': decimal.Decimal('0.003'),
|
||||
'completion': decimal.Decimal('0.004')
|
||||
},
|
||||
'text-davinci-003': {
|
||||
'prompt': decimal.Decimal('0.02'),
|
||||
'completion': decimal.Decimal('0.02')
|
||||
},
|
||||
}
|
||||
|
||||
base_model_name = self.credentials.get("base_model_name")
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[base_model_name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[base_model_name]['completion']
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
if self.name == 'text-davinci-003':
|
||||
@@ -166,7 +135,7 @@ class AzureOpenAIModel(BaseLLM):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError('Azure ' + str(ex))
|
||||
return LLMAuthorizationError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Any, Union
|
||||
import decimal
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
|
||||
@@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseLLM(BaseProviderModel):
|
||||
@@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
|
||||
def _init_client(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get llm base model name
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def price_config(self) -> dict:
|
||||
def get_or_default():
|
||||
default_price_config = {
|
||||
'prompt': decimal.Decimal('0'),
|
||||
'completion': decimal.Decimal('0'),
|
||||
'unit': decimal.Decimal('0'),
|
||||
'currency': 'USD'
|
||||
}
|
||||
rules = self.model_provider.get_rules()
|
||||
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
|
||||
price_config = {
|
||||
'prompt': decimal.Decimal(price_config['prompt']),
|
||||
'completion': decimal.Decimal(price_config['completion']),
|
||||
'unit': decimal.Decimal(price_config['unit']),
|
||||
'currency': price_config['currency']
|
||||
}
|
||||
return price_config
|
||||
|
||||
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
|
||||
|
||||
logger.debug(f"model: {self.name} price_config: {self._price_config}")
|
||||
return self._price_config
|
||||
|
||||
def run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
@@ -125,6 +161,8 @@ class BaseLLM(BaseProviderModel):
|
||||
completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
self.model_provider.update_last_used()
|
||||
|
||||
if self.deduct_quota:
|
||||
self.model_provider.deduct_quota(total_tokens)
|
||||
|
||||
@@ -159,25 +197,64 @@ class BaseLLM(BaseProviderModel):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
get token price.
|
||||
calc tokens total price.
|
||||
|
||||
:param tokens:
|
||||
:param message_type:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
unit = self.get_price_unit(message_type)
|
||||
|
||||
@abstractmethod
|
||||
def get_currency(self):
|
||||
total_price = tokens * unit_price * unit
|
||||
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
|
||||
return total_price
|
||||
|
||||
def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
get token price.
|
||||
|
||||
:param message_type:
|
||||
:return: decimal.Decimal('0.0001')
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"unit_price={unit_price}")
|
||||
return unit_price
|
||||
|
||||
def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
get price unit.
|
||||
|
||||
:param message_type:
|
||||
:return: decimal.Decimal('0.000001')
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
price_unit = self.price_config['unit']
|
||||
else:
|
||||
price_unit = self.price_config['unit']
|
||||
|
||||
price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"price_unit={price_unit}")
|
||||
return price_unit
|
||||
|
||||
def get_currency(self) -> str:
|
||||
"""
|
||||
get token currency.
|
||||
|
||||
:return:
|
||||
:return: get from price config, default 'USD'
|
||||
"""
|
||||
raise NotImplementedError
|
||||
currency = self.price_config['currency']
|
||||
return currency
|
||||
|
||||
def get_model_kwargs(self):
|
||||
return self.model_kwargs
|
||||
|
||||
@@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
||||
@@ -62,13 +62,6 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.get_num_tokens(prompts)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
# not support calc price
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
self.client.model_kwargs = provider_model_kwargs
|
||||
|
||||
@@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
||||
@@ -46,7 +46,8 @@ class OpenAIModel(BaseLLM):
|
||||
self.model_mode = ModelMode.COMPLETION
|
||||
else:
|
||||
self.model_mode = ModelMode.CHAT
|
||||
|
||||
|
||||
# TODO load price config from configs(db)
|
||||
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
@@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM):
|
||||
else:
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'gpt-4': {
|
||||
'prompt': decimal.Decimal('0.03'),
|
||||
'completion': decimal.Decimal('0.06'),
|
||||
},
|
||||
'gpt-4-32k': {
|
||||
'prompt': decimal.Decimal('0.06'),
|
||||
'completion': decimal.Decimal('0.12')
|
||||
},
|
||||
'gpt-3.5-turbo': {
|
||||
'prompt': decimal.Decimal('0.0015'),
|
||||
'completion': decimal.Decimal('0.002')
|
||||
},
|
||||
'gpt-3.5-turbo-16k': {
|
||||
'prompt': decimal.Decimal('0.003'),
|
||||
'completion': decimal.Decimal('0.004')
|
||||
},
|
||||
'text-davinci-003': {
|
||||
'prompt': decimal.Decimal('0.02'),
|
||||
'completion': decimal.Decimal('0.02')
|
||||
},
|
||||
}
|
||||
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[self.name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[self.name]['completion']
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
if self.name in COMPLETION_MODELS:
|
||||
@@ -185,7 +148,7 @@ class OpenAIModel(BaseLLM):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError(str(ex))
|
||||
return LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
60
api/core/model_providers/models/llm/openllm_model.py
Normal file
60
api/core/model_providers/models/llm/openllm_model.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.openllm import OpenLLM
|
||||
|
||||
|
||||
class OpenLLMModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
|
||||
client = OpenLLM(
|
||||
server_url=self.credentials.get('server_url'),
|
||||
callbacks=self.callbacks,
|
||||
llm_kwargs=self.provider_model_kwargs
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
pass
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"OpenLLM: {str(ex)}")
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return False
|
||||
@@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM):
|
||||
|
||||
return self._client.get_num_tokens(prompts)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
# replicate only pay for prediction seconds
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
self.client.input = provider_model_kwargs
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import decimal
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
@@ -19,6 +18,7 @@ class SparkModel(BaseLLM):
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return ChatSpark(
|
||||
model_name=self.name,
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
@@ -50,9 +50,6 @@ class SparkModel(BaseLLM):
|
||||
contents = [message.content for message in messages]
|
||||
return max(self._client.get_num_tokens("".join(contents)), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
||||
@@ -53,9 +53,6 @@ class TongyiModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ class WenxinModel(BaseLLM):
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
# TODO load price_config from configs(db)
|
||||
return Wenxin(
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
@@ -48,36 +49,6 @@ class WenxinModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'ernie-bot': {
|
||||
'prompt': decimal.Decimal('0.012'),
|
||||
'completion': decimal.Decimal('0.012'),
|
||||
},
|
||||
'ernie-bot-turbo': {
|
||||
'prompt': decimal.Decimal('0.008'),
|
||||
'completion': decimal.Decimal('0.008')
|
||||
},
|
||||
'bloomz-7b': {
|
||||
'prompt': decimal.Decimal('0.006'),
|
||||
'completion': decimal.Decimal('0.006')
|
||||
}
|
||||
}
|
||||
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[self.name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[self.name]['completion']
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
|
||||
70
api/core/model_providers/models/llm/xinference_model.py
Normal file
70
api/core/model_providers/models/llm/xinference_model.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
|
||||
|
||||
|
||||
class XinferenceModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
|
||||
client = XinferenceLLM(
|
||||
server_url=self.credentials['server_url'],
|
||||
model_uid=self.credentials['model_uid'],
|
||||
)
|
||||
|
||||
client.callbacks = self.callbacks
|
||||
|
||||
return client
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate(
|
||||
[prompts],
|
||||
stop,
|
||||
callbacks,
|
||||
generate_config={
|
||||
"stop": stop,
|
||||
"stream": self.streaming,
|
||||
**self.provider_model_kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
pass
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Xinference: {str(ex)}")
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
||||
@@ -41,7 +41,7 @@ class OpenAIModeration(BaseProviderModel):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError(str(ex))
|
||||
return LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
@@ -40,7 +40,7 @@ class OpenAIWhisper(BaseSpeech2Text):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError(str(ex))
|
||||
return LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
@@ -183,6 +183,8 @@ class AnthropicProvider(BaseModelProvider):
|
||||
return {
|
||||
'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
|
||||
'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
|
||||
'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
|
||||
'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
@@ -31,7 +31,9 @@ class HostedAnthropic(BaseModel):
|
||||
"""Quota limit for the anthropic hosted model. 0 means unlimited."""
|
||||
paid_enabled: bool = False
|
||||
paid_stripe_price_id: str = None
|
||||
paid_increase_quota: int = 1
|
||||
paid_increase_quota: int = 1000000
|
||||
paid_min_quantity: int = 20
|
||||
paid_max_quantity: int = 100
|
||||
|
||||
|
||||
class HostedModelProviders(BaseModel):
|
||||
@@ -73,4 +75,6 @@ def init_app(app: Flask):
|
||||
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
|
||||
paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
|
||||
paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
|
||||
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
|
||||
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
|
||||
)
|
||||
|
||||
138
api/core/model_providers/providers/openllm_provider.py
Normal file
138
api/core/model_providers/providers/openllm_provider.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
from core.model_providers.models.llm.openllm_model import OpenLLMModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.openllm import OpenLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class OpenLLMProvider(BaseModelProvider):
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'openllm'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = OpenLLMModel
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
if 'server_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('OpenLLM Server URL must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'server_url': credentials['server_url']
|
||||
}
|
||||
|
||||
llm = OpenLLM(
|
||||
llm_kwargs={
|
||||
'max_new_tokens': 10
|
||||
},
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
||||
return credentials
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
if self.provider.provider_type != ProviderType.CUSTOM.value:
|
||||
raise NotImplementedError
|
||||
|
||||
provider_model = self._get_provider_model(model_name, model_type)
|
||||
|
||||
if not provider_model.encrypted_config:
|
||||
return {
|
||||
'server_url': None
|
||||
}
|
||||
|
||||
credentials = json.loads(provider_model.encrypted_config)
|
||||
if credentials['server_url']:
|
||||
credentials['server_url'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['server_url']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
|
||||
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
return {}
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
return {}
|
||||
@@ -116,7 +116,8 @@ class ReplicateProvider(BaseModelProvider):
|
||||
and 'Embedding' not in rst.openapi_schema['components']['schemas']:
|
||||
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
|
||||
elif model_type == ModelType.TEXT_GENERATION \
|
||||
and ('type' not in rst.openapi_schema['components']['schemas']['Output']['items']
|
||||
and ('items' not in rst.openapi_schema['components']['schemas']['Output']
|
||||
or 'type' not in rst.openapi_schema['components']['schemas']['Output']['items']
|
||||
or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
|
||||
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
|
||||
except ReplicateError as e:
|
||||
|
||||
@@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider):
|
||||
return [
|
||||
{
|
||||
'id': 'spark',
|
||||
'name': '星火认知大模型',
|
||||
'name': 'Spark V1.5',
|
||||
},
|
||||
{
|
||||
'id': 'spark-v2',
|
||||
'name': 'Spark V2.0',
|
||||
}
|
||||
]
|
||||
else:
|
||||
|
||||
193
api/core/model_providers/providers/xinference_provider.py
Normal file
193
api/core/model_providers/providers/xinference_provider.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
from core.model_providers.models.llm.xinference_model import XinferenceModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class XinferenceProvider(BaseModelProvider):
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'xinference'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = XinferenceModel
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
model_class = XinferenceEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
credentials = self.get_model_credentials(model_name, model_type)
|
||||
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
elif credentials['model_format'] == "ggmlv3":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
else:
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
if 'server_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('Xinference Server URL must be provided.')
|
||||
|
||||
if 'model_uid' not in credentials:
|
||||
raise CredentialsValidateFailedError('Xinference Model UID must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'server_url': credentials['server_url'],
|
||||
'model_uid': credentials['model_uid'],
|
||||
}
|
||||
|
||||
llm = XinferenceLLM(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
extra_credentials = cls._get_extra_credentials(credentials)
|
||||
credentials.update(extra_credentials)
|
||||
|
||||
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
||||
|
||||
return credentials
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
if self.provider.provider_type != ProviderType.CUSTOM.value:
|
||||
raise NotImplementedError
|
||||
|
||||
provider_model = self._get_provider_model(model_name, model_type)
|
||||
|
||||
if not provider_model.encrypted_config:
|
||||
return {
|
||||
'server_url': None,
|
||||
'model_uid': None,
|
||||
}
|
||||
|
||||
credentials = json.loads(provider_model.encrypted_config)
|
||||
if credentials['server_url']:
|
||||
credentials['server_url'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['server_url']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
|
||||
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def _get_extra_credentials(self, credentials: dict) -> dict:
|
||||
url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
|
||||
response = requests.get(url)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Failed to get the model description, detail: {response.json()['detail']}"
|
||||
)
|
||||
desc = response.json()
|
||||
|
||||
extra_credentials = {
|
||||
'model_format': desc['model_format'],
|
||||
}
|
||||
if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
|
||||
extra_credentials['model_handle_type'] = 'chatglm'
|
||||
elif "generate" in desc["model_ability"]:
|
||||
extra_credentials['model_handle_type'] = 'generate'
|
||||
elif "chat" in desc["model_ability"]:
|
||||
extra_credentials['model_handle_type'] = 'chat'
|
||||
else:
|
||||
raise NotImplementedError(f"Model handle type not supported.")
|
||||
|
||||
return extra_credentials
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
return {}
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
return {}
|
||||
@@ -8,5 +8,7 @@
|
||||
"wenxin",
|
||||
"chatglm",
|
||||
"replicate",
|
||||
"huggingface_hub"
|
||||
"huggingface_hub",
|
||||
"xinference",
|
||||
"openllm"
|
||||
]
|
||||
@@ -5,10 +5,25 @@
|
||||
],
|
||||
"system_config": {
|
||||
"supported_quota_types": [
|
||||
"paid",
|
||||
"trial"
|
||||
],
|
||||
"quota_unit": "times",
|
||||
"quota_limit": 1000
|
||||
"quota_unit": "tokens",
|
||||
"quota_limit": 600000
|
||||
},
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"claude-instant-1": {
|
||||
"prompt": "1.63",
|
||||
"completion": "5.51",
|
||||
"unit": "0.000001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"claude-2": {
|
||||
"prompt": "11.02",
|
||||
"completion": "32.68",
|
||||
"unit": "0.000001",
|
||||
"currency": "USD"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,5 +3,48 @@
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "configurable"
|
||||
"model_flexibility": "configurable",
|
||||
"price_config":{
|
||||
"gpt-4": {
|
||||
"prompt": "0.03",
|
||||
"completion": "0.06",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-4-32k": {
|
||||
"prompt": "0.06",
|
||||
"completion": "0.12",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-35-turbo": {
|
||||
"prompt": "0.002",
|
||||
"completion": "0.0015",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-35-turbo-16k": {
|
||||
"prompt": "0.003",
|
||||
"completion": "0.004",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-davinci-002": {
|
||||
"prompt": "0.02",
|
||||
"completion": "0.02",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-davinci-003": {
|
||||
"prompt": "0.02",
|
||||
"completion": "0.02",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-embedding-ada-002":{
|
||||
"completion": "0.0001",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -9,5 +9,24 @@
|
||||
],
|
||||
"quota_unit": "tokens"
|
||||
},
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"abab5.5-chat": {
|
||||
"prompt": "0.015",
|
||||
"completion": "0.015",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"abab5-chat": {
|
||||
"prompt": "0.015",
|
||||
"completion": "0.015",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"embo-01": {
|
||||
"completion": "0",
|
||||
"unit": "0.0001",
|
||||
"currency": "RMB"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -10,5 +10,42 @@
|
||||
"quota_unit": "times",
|
||||
"quota_limit": 200
|
||||
},
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"gpt-4": {
|
||||
"prompt": "0.03",
|
||||
"completion": "0.06",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-4-32k": {
|
||||
"prompt": "0.06",
|
||||
"completion": "0.12",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo": {
|
||||
"prompt": "0.0015",
|
||||
"completion": "0.002",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
"prompt": "0.003",
|
||||
"completion": "0.004",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-davinci-003": {
|
||||
"prompt": "0.02",
|
||||
"completion": "0.02",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-embedding-ada-002":{
|
||||
"completion": "0.0001",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
}
|
||||
}
|
||||
}
|
||||
7
api/core/model_providers/rules/openllm.json
Normal file
7
api/core/model_providers/rules/openllm.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "configurable"
|
||||
}
|
||||
@@ -9,5 +9,19 @@
|
||||
],
|
||||
"quota_unit": "tokens"
|
||||
},
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"spark": {
|
||||
"prompt": "0.18",
|
||||
"completion": "0.18",
|
||||
"unit": "0.0001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"spark-v2": {
|
||||
"prompt": "0.36",
|
||||
"completion": "0.36",
|
||||
"unit": "0.0001",
|
||||
"currency": "RMB"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,5 +3,25 @@
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"ernie-bot": {
|
||||
"prompt": "0.012",
|
||||
"completion": "0.012",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"ernie-bot-turbo": {
|
||||
"prompt": "0.008",
|
||||
"completion": "0.008",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"bloomz-7b": {
|
||||
"prompt": "0.006",
|
||||
"completion": "0.006",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
}
|
||||
}
|
||||
}
|
||||
7
api/core/model_providers/rules/xinference.json
Normal file
7
api/core/model_providers/rules/xinference.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "configurable"
|
||||
}
|
||||
@@ -14,14 +14,16 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.current_datetime_tool import DatetimeTool
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
|
||||
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
|
||||
from core.tool.web_reader_tool import WebReaderTool
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.model import AppModelConfig
|
||||
|
||||
@@ -78,15 +80,22 @@ class OrchestratorRuleParser:
|
||||
elif planning_strategy == PlanningStrategy.ROUTER:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
|
||||
summary_model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=self.tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
temperature=0,
|
||||
max_tokens=500
|
||||
try:
|
||||
summary_model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=self.tenant_id,
|
||||
model_provider_name=agent_provider_name,
|
||||
model_name=agent_model_name,
|
||||
model_kwargs=ModelKwargs(
|
||||
temperature=0,
|
||||
max_tokens=500
|
||||
),
|
||||
deduct_quota=False
|
||||
)
|
||||
)
|
||||
except ProviderTokenNotInitError as e:
|
||||
summary_model_instance = None
|
||||
|
||||
tools = self.to_tools(
|
||||
agent_model_instance=agent_model_instance,
|
||||
tool_configs=tool_configs,
|
||||
conversation_message_task=conversation_message_task,
|
||||
rest_tokens=rest_tokens,
|
||||
@@ -136,11 +145,12 @@ class OrchestratorRuleParser:
|
||||
|
||||
return None
|
||||
|
||||
def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
|
||||
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
|
||||
rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
|
||||
"""
|
||||
Convert app agent tool configs to tools
|
||||
|
||||
:param agent_model_instance:
|
||||
:param rest_tokens:
|
||||
:param tool_configs: app agent tool configs
|
||||
:param conversation_message_task:
|
||||
@@ -158,7 +168,7 @@ class OrchestratorRuleParser:
|
||||
if tool_type == "dataset":
|
||||
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
|
||||
elif tool_type == "web_reader":
|
||||
tool = self.to_web_reader_tool()
|
||||
tool = self.to_web_reader_tool(agent_model_instance)
|
||||
elif tool_type == "google_search":
|
||||
tool = self.to_google_search_tool()
|
||||
elif tool_type == "wikipedia":
|
||||
@@ -203,24 +213,28 @@ class OrchestratorRuleParser:
|
||||
|
||||
return tool
|
||||
|
||||
def to_web_reader_tool(self) -> Optional[BaseTool]:
|
||||
def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
|
||||
"""
|
||||
A tool for reading web pages
|
||||
|
||||
:return:
|
||||
"""
|
||||
summary_model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=self.tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
temperature=0,
|
||||
max_tokens=500
|
||||
try:
|
||||
summary_model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=self.tenant_id,
|
||||
model_provider_name=agent_model_instance.model_provider.provider_name,
|
||||
model_name=agent_model_instance.name,
|
||||
model_kwargs=ModelKwargs(
|
||||
temperature=0,
|
||||
max_tokens=500
|
||||
),
|
||||
deduct_quota=False
|
||||
)
|
||||
)
|
||||
|
||||
summary_llm = summary_model_instance.client
|
||||
except ProviderTokenNotInitError:
|
||||
summary_model_instance = None
|
||||
|
||||
tool = WebReaderTool(
|
||||
llm=summary_llm,
|
||||
llm=summary_model_instance.client if summary_model_instance else None,
|
||||
max_chunk_length=4000,
|
||||
continue_reading=True,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
@@ -248,11 +262,7 @@ class OrchestratorRuleParser:
|
||||
return tool
|
||||
|
||||
def to_current_datetime_tool(self) -> Optional[BaseTool]:
|
||||
tool = Tool(
|
||||
name="current_datetime",
|
||||
description="A tool when you want to get the current date, time, week, month or year, "
|
||||
"and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".",
|
||||
func=helper.get_current_datetime,
|
||||
tool = DatetimeTool(
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
from core.model_providers.error import LLMError
|
||||
from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||
|
||||
|
||||
@@ -12,5 +15,10 @@ class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
|
||||
|
||||
def parse(self, text: str) -> Any:
|
||||
json_string = text.strip()
|
||||
json_obj = json.loads(json_string)
|
||||
action_match = re.search(r".*(\[\".+\"\]).*", json_string, re.DOTALL)
|
||||
if action_match is not None:
|
||||
json_obj = json.loads(action_match.group(1).strip(), strict=False)
|
||||
else:
|
||||
raise LLMError("Could not parse LLM output: {text}")
|
||||
|
||||
return json_obj
|
||||
|
||||
@@ -39,18 +39,18 @@ MORE_LIKE_THIS_GENERATE_PROMPT = (
|
||||
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
"Please help me predict the three most likely questions that human would ask, "
|
||||
"and keeping each question under 20 characters.\n"
|
||||
"The output must be in JSON format following the specified schema:\n"
|
||||
"The output must be an array in JSON format following the specified schema:\n"
|
||||
"[\"question1\",\"question2\",\"question3\"]\n"
|
||||
)
|
||||
|
||||
GENERATOR_QA_PROMPT = (
|
||||
"Please respond according to the language of the user's input text. If the text is in language [A], you must also reply in language [A].\n"
|
||||
'The user will send a long text. Please think step by step.'
|
||||
'Step 1: Understand and summarize the main content of this text.\n'
|
||||
'Step 2: What key information or concepts are mentioned in this text?\n'
|
||||
'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
|
||||
'Step 4: Generate 20 questions and answers based on these key information and concepts.'
|
||||
'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
|
||||
"Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
|
||||
"Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
|
||||
)
|
||||
|
||||
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
|
||||
|
||||
21
api/core/third_party/langchain/embeddings/xinference_embedding.py
vendored
Normal file
21
api/core/third_party/langchain/embeddings/xinference_embedding.py
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
|
||||
class XinferenceEmbedding(XinferenceEmbeddings):
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
vectors = super().embed_documents(texts)
|
||||
|
||||
normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
|
||||
|
||||
return normalized_vectors
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
vector = super().embed_query(text)
|
||||
|
||||
normalized_vector = (vector / np.linalg.norm(vector)).tolist()
|
||||
|
||||
return normalized_vector
|
||||
84
api/core/third_party/langchain/llms/openllm.py
vendored
Normal file
84
api/core/third_party/langchain/llms/openllm.py
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenLLM(LLM):
|
||||
"""OpenLLM, supporting both in-process model
|
||||
instance and remote OpenLLM servers.
|
||||
|
||||
If you have a OpenLLM server running, you can also use it remotely:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import OpenLLM
|
||||
llm = OpenLLM(server_url='http://localhost:3000')
|
||||
llm("What is the difference between a duck and a goose?")
|
||||
"""
|
||||
|
||||
server_url: Optional[str] = None
|
||||
"""Optional server URL that currently runs a LLMServer with 'openllm start'."""
|
||||
llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Key word arguments to be passed to openllm.LLM"""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "openllm"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"llm_config": self.llm_kwargs
|
||||
}
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = requests.post(
|
||||
f'{self.server_url}/v1/generate',
|
||||
headers=headers,
|
||||
json=params
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
|
||||
|
||||
json_response = response.json()
|
||||
completion = json_response["responses"][0]
|
||||
|
||||
if stop is not None:
|
||||
completion = enforce_stop_tokens(completion, stop)
|
||||
|
||||
return completion
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
raise NotImplementedError(
|
||||
"Async call is not supported for OpenLLM at the moment."
|
||||
)
|
||||
5
api/core/third_party/langchain/llms/spark.py
vendored
5
api/core/third_party/langchain/llms/spark.py
vendored
@@ -25,6 +25,7 @@ class ChatSpark(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
client = SparkLLMClient(
|
||||
model_name="<model_name>",
|
||||
app_id="<app_id>",
|
||||
api_key="<api_key>",
|
||||
api_secret="<api_secret>"
|
||||
@@ -32,6 +33,9 @@ class ChatSpark(BaseChatModel):
|
||||
"""
|
||||
client: Any = None #: :meta private:
|
||||
|
||||
model_name: str = "spark"
|
||||
"""The Spark model name."""
|
||||
|
||||
max_tokens: int = 256
|
||||
"""Denotes the number of tokens to predict per generation."""
|
||||
|
||||
@@ -66,6 +70,7 @@ class ChatSpark(BaseChatModel):
|
||||
)
|
||||
|
||||
values["client"] = SparkLLMClient(
|
||||
model_name=values["model_name"],
|
||||
app_id=values["app_id"],
|
||||
api_key=values["api_key"],
|
||||
api_secret=values["api_secret"],
|
||||
|
||||
26
api/core/third_party/langchain/llms/wenxin.py
vendored
26
api/core/third_party/langchain/llms/wenxin.py
vendored
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
@@ -223,11 +224,24 @@ class Wenxin(LLM):
|
||||
for token in self._client.post(request).iter_lines():
|
||||
if token:
|
||||
token = token.decode("utf-8")
|
||||
completion = json.loads(token[5:])
|
||||
|
||||
yield GenerationChunk(text=completion['result'])
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(completion['result'])
|
||||
if token.startswith('data:'):
|
||||
completion = json.loads(token[5:])
|
||||
|
||||
if completion['is_end']:
|
||||
break
|
||||
yield GenerationChunk(text=completion['result'])
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(completion['result'])
|
||||
|
||||
if completion['is_end']:
|
||||
break
|
||||
else:
|
||||
try:
|
||||
json_response = json.loads(token)
|
||||
except JSONDecodeError:
|
||||
raise ValueError(f"Wenxin Response Error {token}")
|
||||
|
||||
raise ValueError(
|
||||
f"Wenxin API {json_response['error_code']}"
|
||||
f" error: {json_response['error_msg']}, "
|
||||
f"please confirm if the model you have chosen is already paid for."
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user