Compare commits

...

68 Commits

Author SHA1 Message Date
John Wang
0d942d99dc fix: web reader tool retrieve content empty 2023-08-22 17:52:57 +08:00
takatost
78d3aa5fcd fix: embedding init err (#956) 2023-08-22 17:43:59 +08:00
zxhlyh
a7c78d2cd2 fix: spark provider field name (#955) 2023-08-22 17:28:18 +08:00
Joel
4db35fa375 chore: obsolete info api use new api (#954) 2023-08-22 16:59:57 +08:00
Joel
e67a1413b6 chore: create btn to first place (#953) 2023-08-22 16:20:56 +08:00
takatost
4f3053a8cc fix: xinference chat completion error (#952) 2023-08-22 15:58:04 +08:00
zxhlyh
b3c2bf125f Feat/model providers (#951) 2023-08-22 15:38:12 +08:00
zxhlyh
9d5299e9ec fix: segment error tip & save segment disable when loading (#949) 2023-08-22 15:22:16 +08:00
Jyong
aee15adf1b update document segment (#948)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-22 15:19:09 +08:00
zxhlyh
b185a70c21 Fix/speech to text button (#947) 2023-08-22 14:55:20 +08:00
takatost
a3aba7a9aa fix: provider model not delete when reset key pair (#946) 2023-08-22 13:48:58 +08:00
takatost
866ee5da91 fix: openllm generate cutoff (#945) 2023-08-22 13:43:36 +08:00
Matri
e8039a7da8 fix: add flex-wrap to categories container (#944) 2023-08-22 13:39:52 +08:00
bowen
5e0540077a chore: perfect type definition (#940) 2023-08-22 10:58:06 +08:00
Matri
b346bd9b83 fix: default language improvement in activation page (#942) 2023-08-22 09:28:37 +08:00
Matri
062e2e915b fix: login improvement (#941) 2023-08-21 21:26:32 +08:00
takatost
e0a48c4972 fix: xinference chat support (#939) 2023-08-21 20:44:29 +08:00
zxhlyh
f53242c081 Feat/add document status tooltip (#937) 2023-08-21 18:07:51 +08:00
Jyong
4b53bb1a32 Feat/token support (#909)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: jyong <jyong@dify.ai>
2023-08-21 13:57:18 +08:00
takatost
4c49ecedb5 feat: optimize web reader summary in 3.5 (#933) 2023-08-21 11:58:01 +08:00
takatost
4ff1870a4b fix: web reader tool missing nodejs (#932) 2023-08-21 11:26:11 +08:00
takatost
6c832ee328 fix: remove openllm pypi package because of this package too large (#931) 2023-08-21 02:12:28 +08:00
takatost
25264e7852 feat: add xinference embedding model support (#930) 2023-08-20 19:35:07 +08:00
takatost
18dd0d569d fix: xinference max_tokens alisa error (#929) 2023-08-20 19:12:52 +08:00
takatost
3ea8d7a019 feat: add openllm support (#928) 2023-08-20 19:04:33 +08:00
takatost
da3f10a55e feat: server xinference support (#927) 2023-08-20 17:46:41 +08:00
Benjamin
8c991b5b26 Fix Readme.md typo error. (#926) 2023-08-20 12:02:04 +08:00
takatost
22c1aafb9b fix: document paused at format error (#925) 2023-08-20 01:54:12 +08:00
takatost
8d6d1c442b feat: optimize generate name length (#924) 2023-08-19 23:34:38 +08:00
takatost
95b179fb39 fix: replicate text generation model validate (#923) 2023-08-19 21:40:42 +08:00
takatost
3a0a9e2d8f fix: embedding get price definition missing (#922) 2023-08-19 21:31:40 +08:00
takatost
0a0d63457d feat: record price unit in messages (#919) 2023-08-19 18:51:40 +08:00
takatost
920fb6d0e1 fix: embedding price config (#918) 2023-08-19 16:54:08 +08:00
Krasus.Chen
fd0fc8f4fe Fix/price calc (#862) 2023-08-19 16:41:35 +08:00
takatost
1c552ff23a fix: azure embedding model credentials include base_model_name is invalid for openai sdk (#917) 2023-08-19 16:24:18 +08:00
takatost
5163dd38e5 fix: run extra model serval ex not return (#916) 2023-08-19 14:35:16 +08:00
takatost
2a27dad2fb fix: run model serval ex not return (#915) 2023-08-19 14:16:41 +08:00
takatost
930f74c610 feat: remove unuse envs (#912) 2023-08-18 21:34:28 +08:00
takatost
3f250c9e12 Update README_CN.md 2023-08-18 20:39:40 +08:00
takatost
fa408d264c Update README.md 2023-08-18 20:38:52 +08:00
takatost
09ea27f1ee feat: optimize service api authorization header invalid error (#910) 2023-08-18 20:32:44 +08:00
Jyong
db7156dafd Feature/mutil embedding model (#908)
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-08-18 17:37:31 +08:00
zxhlyh
4420281d96 Feat/segment add tag (#907) 2023-08-18 17:18:58 +08:00
takatost
d9afebe216 feat: optimize output parse (#906) 2023-08-18 17:00:40 +08:00
takatost
1d9cc5ca05 fix: universal chat when default model invalid (#905) 2023-08-18 16:20:42 +08:00
takatost
edb06f6aed fix: react router agent direct output (#904) 2023-08-18 14:31:20 +08:00
takatost
6ca3bcbcfd fix: sensitive_word_avoidance npe (#902) 2023-08-18 11:43:56 +08:00
Rhon Joe
71a9d63232 fix entrypoint script line endings (#900) 2023-08-18 10:42:44 +08:00
zxhlyh
fb62017e50 Fix/embedding chat (#899)
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-08-18 10:39:05 +08:00
takatost
9adbeadeec feat: claude paid optimize (#890) 2023-08-17 16:56:20 +08:00
takatost
2f7b234cc5 fix: max token not exist in generate summary when calc rest tokens (#891) 2023-08-17 16:33:32 +08:00
zxhlyh
4f5f9506ab Feat/pay modal (#889) 2023-08-17 15:49:22 +08:00
takatost
0cc0b6e052 fix: error raise status code not exist (#888) 2023-08-17 15:33:35 +08:00
Joel
cd78adb0ab feat: support show model display name (#887) 2023-08-17 15:13:35 +08:00
takatost
f42e7d1a61 feat: add spark v2 support (#885) 2023-08-17 15:08:57 +08:00
takatost
c4d759dfba fix: wenxin error not raise when stream mode (#884) 2023-08-17 13:40:00 +08:00
zxhlyh
a58f95fa91 fix: web dockfile (#883) 2023-08-17 13:07:07 +08:00
takatost
39574dcf6b feat: optimize prompt of suggested_questions_after_answer (#881) 2023-08-17 10:46:33 +08:00
Matri
5b06ded0b1 build: improve dockerfile (#851)
Co-authored-by: MatriQi <matri@aifi.io>
2023-08-17 10:25:11 +08:00
Matri
155a4733f6 Feat/customizable file upload config (#818) 2023-08-16 23:14:27 +08:00
takatost
b7c29ea1b6 feat: optimize model when app create (#875) 2023-08-16 22:29:18 +08:00
takatost
cc2d71c253 feat: optimize override app model config convert (#874) 2023-08-16 20:48:42 +08:00
Panmuse
cd11613952 Update README.md (#865) 2023-08-16 19:26:35 +08:00
Panmuse
e0d6d00a87 Update README_CN.md (#867) 2023-08-16 19:26:11 +08:00
takatost
2dfb3e95f6 feat: optimize error record in agent (#869) 2023-08-16 15:55:42 +08:00
Jyong
f207e180df fix multi thread app context (#868)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-16 15:39:31 +08:00
takatost
948d64bbef fix: get_num_tokens_from_messages params error (#866) 2023-08-16 14:58:44 +08:00
Joel
01e912e543 fix: promptEng menu in wrong place (#864) 2023-08-16 14:56:17 +08:00
242 changed files with 10622 additions and 1560 deletions

.github/workflows/api-unit-tests.yml vendored Normal file (38 lines)
View File

@@ -0,0 +1,38 @@
name: Run Pytest

on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - deploy/dev

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.10'
      - name: Cache pip dependencies
        uses: actions/cache@v2
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
          restore-keys: ${{ runner.os }}-pip-
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest
          pip install -r api/requirements.txt
      - name: Run pytest
        run: pytest api/tests/unit_tests
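
For contributors, the job above is easy to reproduce locally before pushing. A minimal stand-in, assuming the repository root as the working directory and api/requirements.txt already installed:

# run_unit_tests.py: local equivalent of the "Run pytest" step
import sys

import pytest

# pytest.main takes the same arguments as the CLI and returns the exit code.
sys.exit(pytest.main(["api/tests/unit_tests"]))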

View File

@@ -24,17 +24,18 @@ Visual data analysis, log review, and annotation for applications
- [x] **Anthropic**: Claude2, Claude-instant
- [x] **Replicate**
- [x] **Hugging Face Hub**
-- [x] **ChatGLM**
- [x] **Llama2**
- [x] **MiniMax**
- [x] **Spark**
- [x] **Wenxin**
- [x] **Tongyi**
+- [x] **ChatGLM**
We provide the following free resources for registered Dify cloud users (sign up at [dify.ai](https://dify.ai)):
-* 1000 free Claude model queries to build Claude-powered apps
+* 600,000 free Claude model tokens to build Claude-powered apps
* 200 free OpenAI queries to build OpenAI-based apps
* 3 million Xunfei Spark Tokens are provided for creating AI applications based on Spark.
* 1 million MiniMax Tokens are provided for creating AI applications based on the MiniMax.
**2. Visual orchestration:** Build an AI app in minutes by writing and debugging prompts visually.
@@ -94,8 +95,6 @@ Features under development:
We will support more datasets, including text, webpages, and even Notion content. Users can build AI applications based on their own data sources.
- **Plugins**, introducing ChatGPT Plugin-standard plugins for applications, or using Dify-produced plugins
We will release plugins complying with ChatGPT standard, or Dify's own plugins to enable more capabilities in applications.
-- **Open-source models**, e.g. adopting Llama as a model provider or for further fine-tuning
-We will work with excellent open-source models like Llama, by providing them as model options in our platform, or using them for further fine-tuning.
## Q&A

View File

@@ -27,14 +27,16 @@
- [x] **Anthropic**: Claude2, Claude-instant
- [x] **Replicate**
- [x] **Hugging Face Hub**
-- [x] **ChatGLM**
- [x] **Llama2**
- [x] **MiniMax**
- [x] **Xunfei Spark (讯飞星火大模型)**
- [x] **Wenxin (文心一言)**
- [x] **Tongyi (通义千问)**
+- [x] **ChatGLM**
We offer all registered cloud-edition users the following resources free of charge (sign in at [dify.ai](https://cloud.dify.ai) to use them):
-* 1,000 Claude model message calls for building AI apps based on the Claude models
+* 600,000 Claude model tokens for building AI apps based on the Claude models
* 200 OpenAI model message calls for building AI apps based on the OpenAI models
+* 3 million Xunfei Spark tokens for building AI apps based on the Spark model
+* 1 million MiniMax tokens for building AI apps based on the MiniMax models
@@ -90,8 +92,6 @@ docker compose up -d
- **Datasets**: support syncing content into more datasets from web pages and APIs, so users can build AI applications on their own data sources.
- **Plugins**: we will release plugins that comply with the ChatGPT standard, support more of Dify's own plugins, and support user-defined plugin capabilities to enable more features in applications, e.g. goal-oriented decomposition-and-reasoning tasks.
-- **Open-source model support**: support open-source models on Hugging Face Hub, e.g. adopting Llama as a model provider or for further fine-tuning.
-We will work with excellent open-source models by offering them as options in our platform or using them for further fine-tuning.
## Q&A

View File

@@ -117,10 +117,12 @@ HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
HOSTED_ANTHROPIC_ENABLED=false
HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
-HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
+HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
-HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
+HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=

View File

@@ -16,7 +16,7 @@ EXPOSE 5001
WORKDIR /app/api
RUN apt-get update && \
-apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev
+apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev nodejs
COPY requirements.txt /app/api/requirements.txt

View File

@@ -20,7 +20,7 @@ from models.model import Account
import secrets
import base64
-from models.provider import Provider, ProviderType, ProviderQuotaType
+from models.provider import Provider, ProviderType, ProviderQuotaType, ProviderModel
@click.command('reset-password', help='Reset the account password.')
@@ -102,6 +102,7 @@ def reset_encrypt_key_pair():
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
db.session.query(ProviderModel).delete()
db.session.commit()
click.echo(click.style('Congratulations! '
@@ -258,6 +259,8 @@ def sync_anthropic_hosted_providers():
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
count = 0
+new_quota_limit = hosted_model_providers.anthropic.quota_limit
page = 1
while True:
try:
@@ -265,6 +268,7 @@ def sync_anthropic_hosted_providers():
Provider.provider_name == 'anthropic',
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.quota_limit != new_quota_limit
).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
except NotFound:
break
@@ -272,9 +276,9 @@ def sync_anthropic_hosted_providers():
page += 1
for provider in providers:
try:
-click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
+click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
+.format(provider.tenant_id, provider.quota_limit, provider.quota_used))
original_quota_limit = provider.quota_limit
-new_quota_limit = hosted_model_providers.anthropic.quota_limit
division = math.ceil(new_quota_limit / 1000)
provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
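
The sync command walks providers page by page (the final quota-scaling statement above is truncated in this view). A minimal sketch of the pagination pattern it relies on, assuming Flask-SQLAlchemy's paginate(), which raises NotFound once the requested page is past the last row:

from werkzeug.exceptions import NotFound

page = 1
while True:
    try:
        providers = Provider.query.filter(
            Provider.provider_name == 'anthropic'
        ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
    except NotFound:
        # No more pages: paginate() aborts with a 404 (werkzeug NotFound).
        break
    page += 1
    for provider in providers:
        pass  # per-tenant quota update, as in the hunk above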

View File

@@ -48,21 +48,23 @@ DEFAULTS = {
'WEAVIATE_GRPC_ENABLED': 'True',
'WEAVIATE_BATCH_SIZE': 100,
'CELERY_BACKEND': 'database',
-'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
-'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
+'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
'HOSTED_ANTHROPIC_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
-'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
+'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'TENANT_DOCUMENT_COUNT': 100,
-'CLEAN_DAY_SETTING': 30
+'CLEAN_DAY_SETTING': 30,
+'UPLOAD_FILE_SIZE_LIMIT': 15,
+'UPLOAD_FILE_BATCH_LIMIT': 5,
}
@@ -104,7 +106,6 @@ class Config:
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
self.TESTING = False
self.LOG_LEVEL = get_env('LOG_LEVEL')
-self.PDF_PREVIEW = get_bool_env('PDF_PREVIEW')
# Your App secret key will be used for securely signing the session cookie
# Make sure you are changing this key for your deployment with a strong key.
@@ -209,7 +210,7 @@ class Config:
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
-self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
+self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
@@ -217,23 +218,21 @@ class Config:
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
-self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
+self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
-self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
+self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
-self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
+self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
# By default it is False
# You could disable it for compatibility with certain OpenAPI providers
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
@@ -244,6 +243,10 @@ class Config:
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
# uploading settings
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
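
get_env and get_bool_env themselves are outside this diff; a plausible sketch of their behavior, assuming they fall back to the DEFAULTS table above (names from the diff, bodies an assumption):

import os

DEFAULTS = {'UPLOAD_FILE_SIZE_LIMIT': 15, 'UPLOAD_FILE_BATCH_LIMIT': 5}

def get_env(key):
    # Raw value from the environment, falling back to DEFAULTS.
    return os.environ.get(key, DEFAULTS.get(key))

def get_bool_env(key):
    # Environment values are strings, so booleans are compared textually.
    return str(get_env(key)).lower() == 'true'

# This is why the quota limits gain an explicit int(...) in this commit:
# get_env returns a string when the variable is set, and numeric
# comparisons against a string quota would misbehave downstream.
UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))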
class CloudEditionConfig(Config):

View File

@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
import flask_restful
from flask_restful import Resource, fields, marshal_with
from werkzeug.exceptions import Forbidden
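
core/login/login.py is not part of this compare view. A minimal sketch of what such a drop-in replacement for flask_login.login_required typically looks like (only the import path comes from the diff, the body is an assumption):

from functools import wraps

from flask_login import current_user
from werkzeug.exceptions import Unauthorized

def login_required(view):
    # Same contract as flask_login.login_required, but owned by the project,
    # so console-specific auth checks can evolve in one place.
    @wraps(view)
    def decorated(*args, **kwargs):
        if not current_user.is_authenticated:
            raise Unauthorized()
        return view(*args, **kwargs)
    return decorated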

View File

@@ -1,8 +1,11 @@
# -*- coding:utf-8 -*-
import json
import logging
from datetime import datetime
-from flask_login import login_required, current_user
+import flask
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
from werkzeug.exceptions import Forbidden
@@ -11,7 +14,9 @@ from controllers.console import api
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from events.app_event import app_was_created, app_was_deleted
from libs.helper import TimestampField
@@ -124,12 +129,39 @@ class AppListApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
default_model = ModelFactory.get_text_generation_model(
tenant_id=current_user.current_tenant_id
)
except (ProviderTokenNotInitError, LLMBadRequestError):
default_model = None
except Exception as e:
logging.exception(e)
default_model = None
if args['model_config'] is not None:
# validate config
model_config_dict = args['model_config']
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
model_config_dict["model"]["provider"]
)
if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
model_config_dict["model"]["name"] = default_model.name
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
-config=args['model_config']
+config=model_config_dict
)
app = App(
@@ -141,21 +173,8 @@ class AppListApi(Resource):
status='normal'
)
-app_model_config = AppModelConfig(
-provider="",
-model_id="",
-configs={},
-opening_statement=model_configuration['opening_statement'],
-suggested_questions=json.dumps(model_configuration['suggested_questions']),
-suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
-speech_to_text=json.dumps(model_configuration['speech_to_text']),
-more_like_this=json.dumps(model_configuration['more_like_this']),
-sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
-model=json.dumps(model_configuration['model']),
-user_input_form=json.dumps(model_configuration['user_input_form']),
-pre_prompt=model_configuration['pre_prompt'],
-agent_mode=json.dumps(model_configuration['agent_mode']),
-)
+app_model_config = AppModelConfig()
+app_model_config = app_model_config.from_model_config_dict(model_configuration)
else:
if 'mode' not in args or args['mode'] is None:
abort(400, message="mode is required")
@@ -165,20 +184,22 @@ class AppListApi(Resource):
app = App(**model_config_template['app'])
app_model_config = AppModelConfig(**model_config_template['model_config'])
-default_model = ModelFactory.get_default_model(
-tenant_id=current_user.current_tenant_id,
-model_type=ModelType.TEXT_GENERATION
+# get model provider
+model_provider = ModelProviderFactory.get_preferred_model_provider(
+current_user.current_tenant_id,
+app_model_config.model_dict["provider"]
)
-if default_model:
-model_dict = app_model_config.model_dict
-model_dict['provider'] = default_model.provider_name
-model_dict['name'] = default_model.model_name
-app_model_config.model = json.dumps(model_dict)
-else:
-raise ProviderNotInitializeError(
-f"No Text Generation Model available. Please configure a valid provider "
-f"in the Settings -> Model Provider.")
+if not model_provider:
+if not default_model:
+raise ProviderNotInitializeError(
+f"No Default System Reasoning Model available. Please configure "
+f"in the Settings -> Model Provider.")
+else:
+model_dict = app_model_config.model_dict
+model_dict['provider'] = default_model.model_provider.provider_name
+model_dict['name'] = default_model.name
+app_model_config.model = json.dumps(model_dict)
app.name = args['name']
app.mode = args['mode']
@@ -297,7 +318,7 @@ class AppApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app = _get_app(app_id, current_user.current_tenant_id)
db.session.delete(app)
@@ -416,22 +437,9 @@ class AppCopy(Resource):
@staticmethod
def create_app_model_config_copy(app_config, copy_app_id):
-copy_app_model_config = AppModelConfig(
-app_id=copy_app_id,
-provider=app_config.provider,
-model_id=app_config.model_id,
-configs=app_config.configs,
-opening_statement=app_config.opening_statement,
-suggested_questions=app_config.suggested_questions,
-suggested_questions_after_answer=app_config.suggested_questions_after_answer,
-speech_to_text=app_config.speech_to_text,
-more_like_this=app_config.more_like_this,
-sensitive_word_avoidance=app_config.sensitive_word_avoidance,
-model=app_config.model,
-user_input_form=app_config.user_input_form,
-pre_prompt=app_config.pre_prompt,
-agent_mode=app_config.agent_mode
-)
+copy_app_model_config = app_config.copy()
+copy_app_model_config.app_id = copy_app_id
return copy_app_model_config
@setup_required
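
copy() and from_model_config_dict() on AppModelConfig are introduced elsewhere in this changeset and are not shown here. A sketch of the shape the call sites above require, with the field list taken from the deleted constructor calls and the bodies assumed:

import json

class AppModelConfig(db.Model):
    # ... column definitions elided ...

    def from_model_config_dict(self, config: dict):
        # Populate columns from a validated model-config dict; JSON columns
        # are serialized exactly as the deleted inline constructor did.
        self.opening_statement = config['opening_statement']
        self.suggested_questions = json.dumps(config['suggested_questions'])
        self.model = json.dumps(config['model'])
        self.pre_prompt = config['pre_prompt']
        self.agent_mode = json.dumps(config['agent_mode'])
        # ... remaining fields follow the same pattern ...
        return self

    def copy(self):
        # Detached duplicate; AppCopy re-points app_id afterwards.
        return AppModelConfig(
            opening_statement=self.opening_statement,
            suggested_questions=self.suggested_questions,
            model=self.model,
            pre_prompt=self.pre_prompt,
            agent_mode=self.agent_mode,
        )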

View File

@@ -2,7 +2,7 @@
import logging
from flask import request
-from flask_login import login_required
+from core.login.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
import services

View File

@@ -5,7 +5,7 @@ from typing import Generator, Union
import flask_login
from flask import Response, stream_with_context
-from flask_login import login_required
+from core.login.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
import services

View File

@@ -1,7 +1,8 @@
from datetime import datetime
import pytz
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from flask_restful.inputs import int_range
from sqlalchemy import or_, func

View File

@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api

View File

@@ -3,7 +3,7 @@ import logging
from typing import Union, Generator
from flask import Response, stream_with_context
-from flask_login import current_user, login_required
+from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, fields
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
@@ -16,6 +16,7 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.login.login import login_required
from libs.helper import uuid_value, TimestampField
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db

View File

@@ -3,12 +3,13 @@ import json
from flask import request
from flask_restful import Resource
-from flask_login import login_required, current_user
+from flask_login import current_user
from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.login.login import login_required
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from models.model import AppModelConfig
@@ -35,20 +36,8 @@ class ModelConfigResource(Resource):
new_app_model_config = AppModelConfig(
app_id=app_model.id,
-provider="",
-model_id="",
-configs={},
-opening_statement=model_configuration['opening_statement'],
-suggested_questions=json.dumps(model_configuration['suggested_questions']),
-suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
-speech_to_text=json.dumps(model_configuration['speech_to_text']),
-more_like_this=json.dumps(model_configuration['more_like_this']),
-sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
-model=json.dumps(model_configuration['model']),
-user_input_form=json.dumps(model_configuration['user_input_form']),
-pre_prompt=model_configuration['pre_prompt'],
-agent_mode=json.dumps(model_configuration['agent_mode']),
)
+new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
db.session.add(new_app_model_config)
db.session.flush()

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from werkzeug.exceptions import NotFound, Forbidden

View File

@@ -4,7 +4,8 @@ from datetime import datetime
import pytz
from flask import jsonify
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api

View File

@@ -5,9 +5,12 @@ from typing import Optional
import flask_login
import requests
from flask import request, redirect, current_app, session
-from flask_login import current_user, login_required
+from flask_login import current_user
from flask_restful import Resource
from werkzeug.exceptions import Forbidden
from core.login.login import login_required
from libs.oauth_data_source import NotionOAuth
from controllers.console import api
from ..setup import setup_required

View File

@@ -3,7 +3,8 @@ import json
from cachetools import TTLCache
from flask import request, current_app
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
from werkzeug.exceptions import NotFound
@@ -21,10 +22,6 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task
cache = TTLCache(maxsize=None, ttl=30)
-FILE_SIZE_LIMIT = 15 * 1024 * 1024 # 15MB
-ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm']
-PREVIEW_WORDS_LIMIT = 3000
class DataSourceApi(Resource):
integrate_icon_fields = {

View File

@@ -1,6 +1,7 @@
# -*- coding:utf-8 -*-
from flask import request
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
import services
@@ -10,13 +11,15 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
-from core.model_providers.error import LLMBadRequestError
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelType
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService
from services.provider_service import ProviderService
dataset_detail_fields = {
'id': fields.String,
@@ -33,6 +36,9 @@ dataset_detail_fields = {
'created_at': TimestampField,
'updated_by': fields.String,
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean
}
dataset_query_detail_fields = {
@@ -74,8 +80,22 @@ class DatasetListApi(Resource):
datasets, total = DatasetService.get_datasets(page, limit, provider,
current_user.current_tenant_id, current_user)
# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
# if len(valid_model_list) == 0:
# raise ProviderNotInitializeError(
# f"No Embedding Model available. Please configure a valid provider "
# f"in the Settings -> Model Provider.")
model_names = [item['model_name'] for item in valid_model_list]
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['embedding_model'] in model_names:
item['embedding_available'] = True
else:
item['embedding_available'] = False
response = {
-'data': marshal(datasets, dataset_detail_fields),
+'data': data,
'has_more': len(datasets) == limit,
'limit': limit,
'total': total,
@@ -99,7 +119,6 @@ class DatasetListApi(Resource):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
@@ -233,6 +252,8 @@ class DatasetIndexingEstimateApi(Resource):
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
@@ -250,11 +271,14 @@ class DatasetIndexingEstimateApi(Resource):
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-args['process_rule'], args['doc_form'])
+args['process_rule'], args['doc_form'],
+args['doc_language'], args['dataset_id'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner()
@@ -262,11 +286,14 @@ class DatasetIndexingEstimateApi(Resource):
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
args['info_list']['notion_info_list'],
-args['process_rule'], args['doc_form'])
+args['process_rule'], args['doc_form'],
+args['doc_language'], args['dataset_id'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
else:
raise ValueError('Data source type not support')
return response, 200

View File

@@ -4,7 +4,8 @@ from datetime import datetime
from typing import List
from flask import request
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import desc, asc
from werkzeug.exceptions import NotFound, Forbidden
@@ -274,6 +275,7 @@ class DatasetDocumentListApi(Resource):
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args()
if not dataset.indexing_technique and not args['indexing_technique']:
@@ -282,14 +284,19 @@ class DatasetDocumentListApi(Resource):
# validate args
DocumentService.document_create_args_validate(args)
# check embedding model setting
try:
ModelFactory.get_embedding_model(
-tenant_id=current_user.current_tenant_id
+tenant_id=current_user.current_tenant_id,
+model_provider_name=dataset.embedding_model_provider,
+model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
@@ -328,6 +335,7 @@ class DatasetInitApi(Resource):
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args()
try:
@@ -406,11 +414,13 @@ class DocumentIndexingEstimateApi(DocumentResource):
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
-data_process_rule_dict)
+data_process_rule_dict, None, dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
return response
@@ -473,22 +483,27 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-data_process_rule_dict)
+data_process_rule_dict, None, dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
-elif dataset.data_source_type:
+except ProviderTokenNotInitError as ex:
+raise ProviderNotInitializeError(ex.description)
+elif dataset.data_source_type == 'notion_import':
indexing_runner = IndexingRunner()
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
info_list,
-data_process_rule_dict)
+data_process_rule_dict,
+None, dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
else:
raise ValueError('Data source type not support')
return response
@@ -575,7 +590,8 @@ class DocumentIndexingStatusApi(DocumentResource):
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = 'paused'
return marshal(document, self.document_status_fields)
@@ -749,11 +765,13 @@ class DocumentMetadataApi(DocumentResource):
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
document.doc_metadata = {}
-for key, value_type in metadata_schema.items():
-value = doc_metadata.get(key)
-if value is not None and isinstance(value, value_type):
-document.doc_metadata[key] = value
+if doc_type == 'others':
+document.doc_metadata = doc_metadata
+else:
+for key, value_type in metadata_schema.items():
+value = doc_metadata.get(key)
+if value is not None and isinstance(value, value_type):
+document.doc_metadata[key] = value
document.doc_type = doc_type
document.updated_at = datetime.utcnow()
@@ -832,6 +850,22 @@ class DocumentStatusApi(DocumentResource):
remove_document_from_index_task.delay(document_id)
return {'result': 'success'}, 200
elif action == "un_archive":
if not document.archived:
raise InvalidActionError('Document is not archived.')
document.archived = False
document.archived_at = None
document.archived_by = None
document.updated_at = datetime.utcnow()
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id)
return {'result': 'success'}, 200
else:
raise InvalidActionError()

View File

@@ -1,15 +1,20 @@
# -*- coding:utf-8 -*-
import uuid
from datetime import datetime
-from flask_login import login_required, current_user
+from flask import request
+from flask_login import current_user
from flask_restful import Resource, reqparse, fields, marshal
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api
-from controllers.console.datasets.error import InvalidActionError
+from controllers.console.app.error import ProviderNotInitializeError
+from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.login.login import login_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@@ -17,7 +22,9 @@ from models.dataset import DocumentSegment
from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
-from tasks.remove_segment_from_index_task import remove_segment_from_index_task
+from tasks.disable_segment_from_index_task import disable_segment_from_index_task
+from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
+import pandas as pd
segment_fields = {
'id': fields.String,
@@ -152,6 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
@@ -197,7 +218,7 @@ class DatasetDocumentSegmentApi(Resource):
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
-remove_segment_from_index_task.delay(segment.id)
+disable_segment_from_index_task.delay(segment.id)
return {'result': 'success'}, 200
else:
@@ -222,6 +243,19 @@ class DatasetDocumentSegmentAddApi(Resource):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
@@ -233,7 +267,7 @@ class DatasetDocumentSegmentAddApi(Resource):
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
-segment = SegmentService.create_segment(args, document)
+segment = SegmentService.create_segment(args, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
@@ -245,6 +279,61 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -270,17 +359,88 @@ class DatasetDocumentSegmentUpdateApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
-# validate args
-parser = reqparse.RequestParser()
-parser.add_argument('content', type=str, required=True, nullable=False, location='json')
-parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
-parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
-args = parser.parse_args()
-SegmentService.segment_create_args_validate(args, document)
-segment = SegmentService.update_segment(args, segment, document)
+SegmentService.delete_segment(segment, document, dataset)
return {'result': 'success'}, 200
class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# get file from request
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:
# Skip the first row
df = pd.read_csv(file)
result = []
for index, row in df.iterrows():
if document.doc_form == 'qa_model':
data = {'content': row[0], 'answer': row[1]}
else:
data = {'content': row[0]}
result.append(data)
if len(result) == 0:
raise ValueError("The CSV file is empty.")
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
# send batch add segments task
redis_client.setnx(indexing_cache_key, 'waiting')
batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
current_user.current_tenant_id, current_user.id)
except Exception as e:
return {'error': str(e)}, 500
return {
-'data': marshal(segment, segment_fields),
-'doc_form': document.doc_form
+'job_id': job_id,
+'job_status': 'waiting'
}, 200
@setup_required
@login_required
@account_initialization_required
def get(self, job_id):
job_id = str(job_id)
indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job is not exist.")
return {
'job_id': job_id,
'job_status': cache_result.decode()
}, 200
@@ -292,3 +452,6 @@ api.add_resource(DatasetDocumentSegmentAddApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
api.add_resource(DatasetDocumentSegmentBatchImportApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
'/datasets/batch_import_status/<uuid:job_id>')
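
Putting the two routes together, a client-side sketch of the batch-import flow (base URL, auth header, IDs, and file name are placeholders; the endpoint paths come from the add_resource calls above):

import time

import requests

BASE = "https://example.com/console/api"  # hypothetical console base URL
AUTH = {"Authorization": "Bearer <console-session-token>"}  # placeholder auth
dataset_id, document_id = "<dataset-uuid>", "<document-uuid>"

# 1. Upload a CSV; per the handler above, column 0 is the content and,
#    for qa_model documents, column 1 is the answer.
with open("segments.csv", "rb") as f:
    r = requests.post(
        f"{BASE}/datasets/{dataset_id}/documents/{document_id}/segments/batch_import",
        headers=AUTH, files={"file": f})
job_id = r.json()["job_id"]

# 2. Poll the Redis-backed status until the async task moves it past 'waiting'.
while True:
    s = requests.get(f"{BASE}/datasets/batch_import_status/{job_id}", headers=AUTH)
    if s.json()["job_status"] != "waiting":
        break
    time.sleep(2)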

View File

@@ -8,7 +8,8 @@ from pathlib import Path
from cachetools import TTLCache
from flask import request, current_app
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, marshal_with, fields
from werkzeug.exceptions import NotFound
@@ -25,12 +26,28 @@ from models.model import UploadFile
cache = TTLCache(maxsize=None, ttl=30)
-FILE_SIZE_LIMIT = 15 * 1024 * 1024 # 15MB
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx']
PREVIEW_WORDS_LIMIT = 3000
class FileApi(Resource):
upload_config_fields = {
'file_size_limit': fields.Integer,
'batch_count_limit': fields.Integer
}
@setup_required
@login_required
@account_initialization_required
@marshal_with(upload_config_fields)
def get(self):
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
return {
'file_size_limit': file_size_limit,
'batch_count_limit': batch_count_limit
}, 200
file_fields = {
'id': fields.String,
'name': fields.String,
@@ -60,8 +77,9 @@ class FileApi(Resource):
file_content = file.read()
file_size = len(file_content)
-if file_size > FILE_SIZE_LIMIT:
-message = "({file_size} > {FILE_SIZE_LIMIT})"
+file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
+if file_size > file_size_limit:
+message = "({file_size} > {file_size_limit})"
raise FileTooLargeError(message)
extension = file.filename.split('.')[-1]
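
One detail worth flagging in the hunk above: message is a plain string literal, so {file_size} and {file_size_limit} are never interpolated. If interpolation is the intent, the check would read (a sketch, adding only the f prefix):

file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
if file_size > file_size_limit:
    message = f"({file_size} > {file_size_limit})"
    raise FileTooLargeError(message)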

View File

@@ -1,6 +1,7 @@
import logging
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse, marshal, fields
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
@@ -11,7 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+LLMBadRequestError
from libs.helper import TimestampField
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@@ -102,6 +104,10 @@ class HitTestingApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ValueError as e:
raise ValueError(str(e))
except Exception as e:

View File

@@ -1,7 +1,8 @@
# -*- coding:utf-8 -*-
from datetime import datetime
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with, inputs
from sqlalchemy import and_
from werkzeug.exceptions import NotFound, Forbidden, BadRequest

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, fields, marshal_with
from sqlalchemy import and_

View File

@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource
from functools import wraps

View File

@@ -1,7 +1,8 @@
import json
from functools import wraps
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required

View File

@@ -38,12 +38,20 @@ class StripeWebhookApi(Resource):
logging.debug(event['data']['object']['payment_status'])
logging.debug(event['data']['object']['metadata'])
session = stripe.checkout.Session.retrieve(
event['data']['object']['id'],
expand=['line_items'],
)
logging.debug(session.line_items['data'][0]['quantity'])
# Fulfill the purchase...
provider_checkout_service = ProviderCheckoutService()
try:
-provider_checkout_service.fulfill_provider_order(event)
+provider_checkout_service.fulfill_provider_order(event, session.line_items)
except Exception as e:
logging.debug(str(e))
return 'success', 200
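
Upstream of this handler, the raw payload is normally verified before event is trusted. A sketch of that step with stripe-python (variable sourcing assumed; construct_event parses and checks the signature in one call):

import stripe

payload = request.get_data()
sig_header = request.headers.get('Stripe-Signature')

# Raises stripe.error.SignatureVerificationError on a bad signature.
event = stripe.Webhook.construct_event(
    payload, sig_header, current_app.config.get('STRIPE_WEBHOOK_SECRET'))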

View File

@@ -3,7 +3,8 @@ from datetime import datetime
import pytz
from flask import current_app, request
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError

View File

@@ -1,6 +1,7 @@
# -*- coding:utf-8 -*-
from flask import current_app
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
import services

View File

@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -1,6 +1,7 @@
import json
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
from flask_restful import Resource, abort, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -2,10 +2,13 @@
import logging
from flask import request
-from flask_login import login_required, current_user
-from flask_restful import Resource, fields, marshal_with, reqparse, marshal
+from flask_login import current_user
+from core.login.login import login_required
+from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs
from flask_restful.inputs import int_range
from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.setup import setup_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import account_initialization_required
@@ -43,6 +46,13 @@ tenants_fields = {
'current': fields.Boolean
}
workspace_fields = {
'id': fields.String,
'name': fields.String,
'status': fields.String,
'created_at': TimestampField
}
class TenantListApi(Resource):
@setup_required
@@ -57,6 +67,38 @@ class TenantListApi(Resource):
return {'workspaces': marshal(tenants, tenants_fields)}, 200
class WorkspaceListApi(Resource):
@setup_required
@admin_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\
.paginate(page=args['page'], per_page=args['limit'])
has_more = False
if len(tenants.items) == args['limit']:
current_page_first_tenant = tenants[-1]
rest_count = db.session.query(Tenant).filter(
Tenant.created_at < current_page_first_tenant.created_at,
Tenant.id != current_page_first_tenant.id
).count()
if rest_count > 0:
has_more = True
total = db.session.query(Tenant).count()
return {
'data': marshal(tenants.items, workspace_fields),
'has_more': has_more,
'limit': args['limit'],
'page': args['page'],
'total': total
}, 200
class TenantApi(Resource):
@setup_required
@login_required
@@ -92,6 +134,7 @@ class SwitchWorkspaceApi(Resource):
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant

View File

@@ -17,7 +17,7 @@ def validate_app_token(view=None):
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token('app')
-app_model = db.session.query(App).get(api_token.app_id)
+app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
if not app_model:
raise NotFound()
@@ -44,7 +44,7 @@ def validate_dataset_token(view=None):
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token('dataset')
dataset = db.session.query(Dataset).get(api_token.dataset_id)
dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first()
if not dataset:
raise NotFound()
@@ -64,14 +64,14 @@ def validate_and_get_api_token(scope=None):
Validate and get API token.
"""
auth_header = request.headers.get('Authorization')
if auth_header is None:
raise Unauthorized()
if auth_header is None or ' ' not in auth_header:
raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized()
raise Unauthorized("Authorization scheme must be 'Bearer'")
api_token = db.session.query(ApiToken).filter(
ApiToken.token == auth_token,
@@ -79,7 +79,7 @@ def validate_and_get_api_token(scope=None):
).first()
if not api_token:
raise Unauthorized()
raise Unauthorized("Access token is invalid")
api_token.last_used_at = datetime.utcnow()
db.session.commit()
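The stricter parsing above splits the header once on any whitespace and lower-cases the scheme before comparing, while returning descriptive messages for each failure mode. A self-contained sketch of the same checks, with illustrative header values:

from werkzeug.exceptions import Unauthorized

def parse_bearer_token(auth_header):
    # reject a missing header or one without a scheme/token separator
    if auth_header is None or ' ' not in auth_header:
        raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
    auth_scheme, auth_token = auth_header.split(None, 1)  # split once on any run of whitespace
    if auth_scheme.lower() != 'bearer':
        raise Unauthorized("Authorization scheme must be 'Bearer'")
    return auth_token

parse_bearer_token('Bearer  abc123')  # -> 'abc123', extra whitespace tolerated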

View File

@@ -59,7 +59,11 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
return super().plan(intermediate_steps, callbacks, **kwargs)
try:
return super().plan(intermediate_steps, callbacks, **kwargs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
async def aplan(
self,

View File

@@ -50,9 +50,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})

View File

@@ -14,7 +14,7 @@ from core.model_providers.models.llm.base import BaseLLM
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
class Config:
@@ -66,12 +66,12 @@ class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
return new_messages
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
llm = cast(ChatOpenAI, llm)
llm = cast(ChatOpenAI, model_instance.client)
model, encoding = llm._get_encoding_model()
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
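The signature change means the mixin now receives the wrapped model instance and unwraps the underlying ChatOpenAI client itself before counting. For intuition, a rough standalone count in the same spirit, assuming the cl100k_base encoding used by gpt-3.5/gpt-4 class models (the per-message overhead constants are approximate):

import tiktoken

enc = tiktoken.get_encoding('cl100k_base')  # assumed encoding
messages = [{'role': 'user', 'content': 'Hello there'}]
# roughly 4 framing tokens per message (<im_start>{role}\n{content}<im_end>\n) plus ~3 reply-priming tokens
num_tokens = sum(4 + len(enc.encode(m['content'])) for m in messages) + 3
print(num_tokens)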

View File

@@ -50,9 +50,13 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})

View File

@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
class StructuredChatOutputParser(LCStructuredChatOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
try:
action_match = re.search(r"```(.*?)\n?(.*?)```", text, re.DOTALL)
action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
if action_match is not None:
response = json.loads(action_match.group(2).strip(), strict=False)
if isinstance(response, list):
@@ -26,4 +26,4 @@ class StructuredChatOutputParser(LCStructuredChatOutputParser):
else:
return AgentFinish({"output": text}, text)
except Exception as e:
raise OutputParserException(f"Could not parse LLM output: {text}") from e
raise OutputParserException(f"Could not parse LLM output: {text}")
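The tightened pattern requires the fenced block's body to start with '{', so a stray code fence without a JSON payload no longer matches and instead falls through to the AgentFinish branch. A quick illustration with a made-up agent reply:

import json
import re

text = 'Thought: call a tool\n```json\n{"action": "search", "action_input": "dify"}\n```'
action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
if action_match is not None:
    response = json.loads(action_match.group(2).strip(), strict=False)
    print(response['action'], response['action_input'])  # search dify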

View File

@@ -94,7 +94,12 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
return AgentFinish(return_values={"output": rst}, log=rst)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
try:
return self.output_parser.parse(full_output)

View File

@@ -52,7 +52,7 @@ Action:
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
class Config:
@@ -89,8 +89,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
Action specifying what tool to use.
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
messages = []
if prompts:
messages = prompts[0].to_messages()
@@ -99,7 +99,11 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
try:
return self.output_parser.parse(full_output)
@@ -108,7 +112,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2:
if len(intermediate_steps) >= 2 and self.summary_llm:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]

View File

@@ -32,7 +32,7 @@ class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_instance: BaseLLM
tools: list[BaseTool]
summary_model_instance: BaseLLM
summary_model_instance: BaseLLM = None
memory: Optional[BaseChatMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
@@ -65,7 +65,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_model_instance.client,
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
@@ -74,7 +75,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client,
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
@@ -83,7 +85,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client,
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:

View File

@@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
@@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.status = 'llm_end'
if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
else:
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
[PromptMessage(content=self._current_loop.prompt)]
)
completion_generation = response.generations[0][0]
if isinstance(completion_generation, ChatGeneration):
completion_message = completion_generation.message
@@ -81,11 +86,15 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
[PromptMessage(content=self._current_loop.completion)]
)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.exception(error)
logging.debug("Agent on_llm_error: %s", error)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
@@ -164,7 +173,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.exception(error)
logging.debug("Agent on_tool_error: %s", error)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None

View File

@@ -68,4 +68,4 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.exception(error)
logging.debug("Dataset tool on_llm_error: %s", error)

View File

@@ -72,5 +72,5 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.exception(error)
logging.debug("Dataset tool on_chain_error: %s", error)
self.clear_chain_results()

View File

@@ -126,7 +126,7 @@ class Completion:
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy != PlanningStrategy.ROUTER:
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
# get llm prompt

View File

@@ -60,17 +60,7 @@ class ConversationMessageTask:
def init(self):
override_model_configs = None
if self.is_override:
override_model_configs = {
"model": self.app_model_config.model_dict,
"pre_prompt": self.app_model_config.pre_prompt,
"agent_mode": self.app_model_config.agent_mode_dict,
"opening_statement": self.app_model_config.opening_statement,
"suggested_questions": self.app_model_config.suggested_questions_list,
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
"more_like_this": self.app_model_config.more_like_this_dict,
"sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
"user_input_form": self.app_model_config.user_input_form_list,
}
override_model_configs = self.app_model_config.to_dict()
introduction = ''
system_instruction = ''
@@ -129,9 +119,11 @@ class ConversationMessageTask:
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency=self.model_instance.get_currency(),
@@ -150,17 +142,24 @@ class ConversationMessageTask:
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price
self.message.message = llm_message.prompt
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
self.message.answer_price_unit = answer_price_unit
self.message.provider_response_latency = llm_message.latency
self.message.total_price = total_price
@@ -202,7 +201,9 @@ class ConversationMessageTask:
tool=agent_loop.tool_name,
tool_input=agent_loop.tool_input,
message=agent_loop.prompt,
message_price_unit=0,
answer=agent_loop.completion,
answer_price_unit=0,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
@@ -216,25 +217,26 @@ class ConversationMessageTask:
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_total_price = self.calc_total_price(
loop_message_tokens,
agent_message_unit_price,
loop_answer_tokens,
agent_answer_unit_price
)
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
message_agent_thought.observation = agent_loop.tool_output
message_agent_thought.tool_process_data = '' # currently not support
message_agent_thought.message_token = loop_message_tokens
message_agent_thought.message_unit_price = agent_message_unit_price
message_agent_thought.message_price_unit = agent_message_price_unit
message_agent_thought.answer_token = loop_answer_tokens
message_agent_thought.answer_unit_price = agent_answer_unit_price
message_agent_thought.answer_price_unit = agent_answer_price_unit
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
@@ -253,15 +255,6 @@ class ConversationMessageTask:
db.session.add(dataset_query)
def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def end(self):
self._pub_handler.pub_end()
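With per-model price tables now living in the provider rules, the message totals above reduce to two calc_tokens_price calls, which is why the per-1k calc_total_price helper could be deleted. Condensed, under the new pricing API:

message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price  # already a quantized decimal.Decimal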

View File

@@ -10,10 +10,10 @@ from models.dataset import Dataset, DocumentSegment
class DatesetDocumentStore:
def __init__(
self,
dataset: Dataset,
user_id: str,
document_id: Optional[str] = None,
):
self._dataset = dataset
self._user_id = user_id
@@ -59,7 +59,7 @@ class DatesetDocumentStore:
return output
def add_documents(
self, docs: Sequence[Document], allow_update: bool = True
) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == self._document_id
@@ -69,7 +69,9 @@ class DatesetDocumentStore:
max_position = 0
embedding_model = ModelFactory.get_embedding_model(
tenant_id=self._dataset.tenant_id
tenant_id=self._dataset.tenant_id,
model_provider_name=self._dataset.embedding_model_provider,
model_name=self._dataset.embedding_model
)
for doc in docs:
@@ -123,7 +125,7 @@ class DatesetDocumentStore:
return result is not None
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[Document]:
document_segment = self.get_document_segment(doc_id)

View File

@@ -2,7 +2,7 @@ import logging
from langchain.schema import OutputParserException
from core.model_providers.error import LLMError
from core.model_providers.error import LLMError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
@@ -51,6 +51,7 @@ class LLMGenerator:
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
max_context_token_length = max_context_token_length if max_context_token_length else 1500
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
context = ''
@@ -108,13 +109,16 @@ class LLMGenerator:
_input = prompt.format_prompt(histories=histories)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=256,
temperature=0
try:
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=256,
temperature=0
)
)
)
except ProviderTokenNotInitError:
return []
prompts = [PromptMessage(content=_input.to_string())]
@@ -175,8 +179,8 @@ class LLMGenerator:
return rule_config
@classmethod
def generate_qa_document(cls, tenant_id: str, query):
prompt = GENERATOR_QA_PROMPT
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
prompt = GENERATOR_QA_PROMPT.format(language=document_language)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,

View File

@@ -15,7 +15,9 @@ class IndexBuilder:
return None
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)

View File

@@ -7,6 +7,7 @@ import time
import uuid
from typing import Optional, List, cast
from flask import current_app, Flask
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@@ -66,14 +67,6 @@ class IndexingRunner:
dataset_document=dataset_document,
processing_rule=processing_rule
)
# new_documents = []
# for document in documents:
# response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
# document_qa_list = self.format_split_text(response)
# for result in document_qa_list:
# document = Document(page_content=result['question'], metadata={'source': result['answer']})
# new_documents.append(document)
# build index
self._build_index(
dataset=dataset,
dataset_document=dataset_document,
@@ -224,14 +217,25 @@ class IndexingRunner:
db.session.commit()
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
doc_form: str = None) -> dict:
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
else:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
tokens = 0
preview_texts = []
total_segments = 0
@@ -262,20 +266,19 @@ class IndexingRunner:
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if doc_form and doc_form == 'qa_model':
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -283,18 +286,31 @@ class IndexingRunner:
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
"currency": embedding_model.get_currency(),
"preview": preview_texts
}
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
else:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
# load data from notion
tokens = 0
@@ -343,20 +359,19 @@ class IndexingRunner:
tokens += embedding_model.get_num_tokens(document.page_content)
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if doc_form and doc_form == 'qa_model':
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -364,7 +379,7 @@ class IndexingRunner:
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
"currency": embedding_model.get_currency(),
"preview": preview_texts
}
@@ -457,7 +472,8 @@ class IndexingRunner:
splitter=splitter,
processing_rule=processing_rule,
tenant_id=dataset.tenant_id,
document_form=dataset_document.doc_form
document_form=dataset_document.doc_form,
document_language=dataset_document.doc_language
)
# save node to document segment
@@ -493,7 +509,8 @@ class IndexingRunner:
return documents
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
processing_rule: DatasetProcessRule, tenant_id: str,
document_form: str, document_language: str) -> List[Document]:
"""
Split the text documents into nodes.
"""
@@ -522,7 +539,9 @@ class IndexingRunner:
sub_documents = all_documents[i:i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents})
'flask_app': current_app._get_current_object(),
'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
'document_language': document_language})
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
@@ -530,28 +549,29 @@ class IndexingRunner:
return all_qa_documents
return all_documents
def format_qa_document(self, tenant_id: str, document_node, all_qa_documents):
def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
return
try:
# qa model document
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result['question'])
qa_document.metadata['answer'] = result['answer']
qa_document.metadata['doc_id'] = doc_id
qa_document.metadata['doc_hash'] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
logging.exception(e)
with flask_app.app_context():
try:
# qa model document
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result['question'])
qa_document.metadata['answer'] = result['answer']
qa_document.metadata['doc_id'] = doc_id
qa_document.metadata['doc_hash'] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
logging.exception(e)
all_qa_documents.extend(format_documents)
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
@@ -638,7 +658,9 @@ class IndexingRunner:
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# chunk nodes by chunk size
@@ -719,6 +741,32 @@ class IndexingRunner:
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
"""
Batch add segments index processing
"""
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(documents, duplicate_check=True)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts(documents)
class DocumentIsPausedException(Exception):
pass
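One detail worth calling out in the threading change above: format_qa_document now runs in worker threads, and anything there that touches db.session or current_app needs an application context, which is why the real app object (current_app._get_current_object(), not the proxy) is passed in and entered via flask_app.app_context(). A minimal sketch of the pattern, with a throwaway app:

import threading

from flask import Flask, current_app

app = Flask(__name__)

def work(flask_app):
    # without this context manager, current_app / db.session would raise 'working outside of application context'
    with flask_app.app_context():
        print(current_app.name)

thread = threading.Thread(target=work, kwargs={'flask_app': app})
thread.start()
thread.join()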

api/core/login/login.py (new file, 108 lines)
View File

@@ -0,0 +1,108 @@
import os
from functools import wraps
import flask_login
from flask import current_app
from flask import g
from flask import has_request_context
from flask import request
from flask_login import user_logged_in
from flask_login.config import EXEMPT_METHODS
from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
current_user = LocalProxy(lambda: _get_user())
def login_required(func):
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are
not, it calls the :attr:`LoginManager.unauthorized` callback.) For
example::
@app.route('/post')
@login_required
def post():
pass
If there are only certain times you need to require that your user is
logged in, you can do so with::
if not current_user.is_authenticated:
return current_app.login_manager.unauthorized()
...which is essentially the code that this function adds to your views.
It can be convenient to globally turn off authentication when unit testing.
To enable this, if the application configuration variable `LOGIN_DISABLED`
is set to `True`, this decorator will be ignored.
.. Note ::
Per `W3 guidelines for CORS preflight requests
<http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
HTTP ``OPTIONS`` requests are exempt from login checks.
:param func: The view function to decorate.
:type func: function
"""
@wraps(func)
def decorated_view(*args, **kwargs):
auth_header = request.headers.get('Authorization')
# compare the env string explicitly; a bare non-empty string like 'False' would otherwise be truthy
admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False').lower() == 'true'
if admin_api_key_enable:
if auth_header:
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
admin_api_key = os.getenv('ADMIN_API_KEY')
if admin_api_key:
if os.getenv('ADMIN_API_KEY') == auth_token:
workspace_id = request.headers.get('X-WORKSPACE-ID')
if workspace_id:
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
.filter(Tenant.id == workspace_id) \
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
.filter(TenantAccountJoin.role == 'owner') \
.one_or_none()
if tenant_account_join:
tenant, ta = tenant_account_join
account = Account.query.filter_by(id=ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account)
user_logged_in.send(current_app._get_current_object(), user=_get_user())
if request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
pass
elif not current_user.is_authenticated:
return current_app.login_manager.unauthorized()
# flask 1.x compatibility
# current_app.ensure_sync is only available in Flask >= 2.0
if callable(getattr(current_app, "ensure_sync", None)):
return current_app.ensure_sync(func)(*args, **kwargs)
return func(*args, **kwargs)
return decorated_view
def _get_user():
if has_request_context():
if "_login_user" not in g:
current_app.login_manager._load_user()
return g._login_user
return None
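In effect, a request that presents the admin API key plus an X-WORKSPACE-ID header is handled as the owner account of that workspace. A hedged client-side sketch (host, route, and tenant id are placeholders, and ADMIN_API_KEY_ENABLE / ADMIN_API_KEY must be set server-side):

import os

import requests

resp = requests.get(
    'http://localhost:5001/console/api/workspaces/current',  # assumed console route
    headers={
        'Authorization': f"Bearer {os.environ['ADMIN_API_KEY']}",
        'X-WORKSPACE-ID': '<tenant-uuid>',  # workspace whose owner to act as
    },
)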

View File

@@ -46,7 +46,8 @@ class ModelFactory:
model_name: Optional[str] = None,
model_kwargs: Optional[ModelKwargs] = None,
streaming: bool = False,
callbacks: Callbacks = None) -> Optional[BaseLLM]:
callbacks: Callbacks = None,
deduct_quota: bool = True) -> Optional[BaseLLM]:
"""
get text generation model.
@@ -56,6 +57,7 @@ class ModelFactory:
:param model_kwargs:
:param streaming:
:param callbacks:
:param deduct_quota:
:return:
"""
is_default_model = False
@@ -95,7 +97,7 @@ class ModelFactory:
else:
raise e
if is_default_model:
if is_default_model or not deduct_quota:
model_instance.deduct_quota = False
return model_instance
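The new deduct_quota flag lets internal callers, such as estimates and other dry runs, obtain a model instance without charging the workspace's hosted quota. A sketch of such a call, assuming a tenant_id in scope:

from core.model_providers.model_factory import ModelFactory

model_instance = ModelFactory.get_text_generation_model(
    tenant_id=tenant_id,   # assumed to be defined by the caller
    deduct_quota=False,    # skip provider quota deduction for this call
)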

View File

@@ -57,6 +57,12 @@ class ModelProviderFactory:
elif provider_name == 'huggingface_hub':
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
return HuggingfaceHubProvider
elif provider_name == 'xinference':
from core.model_providers.providers.xinference_provider import XinferenceProvider
return XinferenceProvider
elif provider_name == 'openllm':
from core.model_providers.providers.openllm_provider import OpenLLMProvider
return OpenLLMProvider
else:
raise NotImplementedError

View File

@@ -26,10 +26,20 @@ class AzureOpenAIEmbedding(BaseEmbedding):
openai_api_version=AZURE_OPENAI_API_VERSION,
chunk_size=16,
max_retries=1,
**self.credentials
openai_api_key=self.credentials.get('openai_api_key'),
openai_api_base=self.credentials.get('openai_api_base')
)
super().__init__(model_provider, client, name)
@property
def base_model_name(self) -> str:
"""
get base model name (not deployment)
:return: str
"""
return self.credentials.get("base_model_name")
def get_num_tokens(self, text: str) -> int:
"""
@@ -48,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
# calculate the number of tokens in the encoded text
return len(tokenized_text)
def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to Azure OpenAI API.")
@@ -71,7 +71,7 @@ class AzureOpenAIEmbedding(BaseEmbedding):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError('Azure ' + str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError('Azure ' + str(ex))
return LLMAuthorizationError('Azure ' + str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
else:

View File

@@ -1,5 +1,6 @@
from abc import abstractmethod
from typing import Any
import decimal
import tiktoken
from langchain.schema.language_model import _get_token_ids_default_method
@@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
import logging
logger = logging.getLogger(__name__)
class BaseEmbedding(BaseProviderModel):
name: str
@@ -17,6 +19,63 @@ class BaseEmbedding(BaseProviderModel):
super().__init__(model_provider, client)
self.name = name
@property
def base_model_name(self) -> str:
"""
get base model name
:return: str
"""
return self.name
@property
def price_config(self) -> dict:
def get_or_default():
default_price_config = {
'completion': decimal.Decimal('0'),
'unit': decimal.Decimal('0'),
'currency': 'USD'
}
rules = self.model_provider.get_rules()
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
price_config = {
'completion': decimal.Decimal(price_config['completion']),
'unit': decimal.Decimal(price_config['unit']),
'currency': price_config['currency']
}
return price_config
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
logger.debug(f"model: {self.name} price_config: {self._price_config}")
return self._price_config
def calc_tokens_price(self, tokens: int) -> decimal.Decimal:
"""
calc tokens total price.
:param tokens:
:return: decimal.Decimal('0.0000001')
"""
unit_price = self.price_config['completion']
unit = self.price_config['unit']
total_price = tokens * unit_price * unit
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
return total_price
def get_tokens_unit_price(self) -> decimal.Decimal:
"""
get tokens unit price.
:return: decimal.Decimal('0.0001')
"""
unit_price = self.price_config['completion']
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
logger.debug(f'unit_price:{unit_price}')
return unit_price
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
@@ -29,11 +88,14 @@ class BaseEmbedding(BaseProviderModel):
return len(_get_token_ids_default_method(text))
def get_token_price(self, tokens: int):
return 0
def get_currency(self):
return 'USD'
"""
get token currency.
:return: get from price config, default 'USD'
"""
currency = self.price_config['currency']
return currency
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:

View File

@@ -1,6 +1,3 @@
import decimal
import logging
from langchain.embeddings import MiniMaxEmbeddings
from core.model_providers.error import LLMBadRequestError
@@ -22,12 +19,6 @@ class MinimaxEmbedding(BaseEmbedding):
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"Minimax: {str(ex)}")

View File

@@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
# calculate the number of tokens in the encoded text
return len(tokenized_text)
def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
@@ -65,7 +55,7 @@ class OpenAIEmbedding(BaseEmbedding):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError(str(ex))
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:

View File

@@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
# replicate only pay for prediction seconds
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Replicate: {str(ex)}")

View File

@@ -0,0 +1,27 @@
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
class XinferenceEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = XinferenceEmbeddings(
server_url=credentials['server_url'],
model_uid=credentials['model_uid'],
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
else:
return ex

View File

@@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'claude-instant-1': {
'prompt': decimal.Decimal('1.63'),
'completion': decimal.Decimal('5.51'),
},
'claude-2': {
'prompt': decimal.Decimal('11.02'),
'completion': decimal.Decimal('32.68'),
},
}
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1m * unit_price
return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():

View File

@@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM):
self.model_mode = ModelMode.COMPLETION
else:
self.model_mode = ModelMode.CHAT
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
@@ -83,6 +82,15 @@ class AzureOpenAIModel(BaseLLM):
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
@property
def base_model_name(self) -> str:
"""
get base model name (not deployment)
:return: str
"""
return self.credentials.get("base_model_name")
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
@@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM):
else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'gpt-4': {
'prompt': decimal.Decimal('0.03'),
'completion': decimal.Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': decimal.Decimal('0.06'),
'completion': decimal.Decimal('0.12')
},
'gpt-35-turbo': {
'prompt': decimal.Decimal('0.0015'),
'completion': decimal.Decimal('0.002')
},
'gpt-35-turbo-16k': {
'prompt': decimal.Decimal('0.003'),
'completion': decimal.Decimal('0.004')
},
'text-davinci-003': {
'prompt': decimal.Decimal('0.02'),
'completion': decimal.Decimal('0.02')
},
}
base_model_name = self.credentials.get("base_model_name")
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[base_model_name]['prompt']
else:
unit_price = model_unit_prices[base_model_name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name == 'text-davinci-003':
@@ -166,7 +135,7 @@ class AzureOpenAIModel(BaseLLM):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError('Azure ' + str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError('Azure ' + str(ex))
return LLMAuthorizationError('Azure ' + str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
else:

View File

@@ -1,5 +1,6 @@
from abc import abstractmethod
from typing import List, Optional, Any, Union
import decimal
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
@@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM
import logging
logger = logging.getLogger(__name__)
class BaseLLM(BaseProviderModel):
@@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
def _init_client(self) -> Any:
raise NotImplementedError
@property
def base_model_name(self) -> str:
"""
get llm base model name
:return: str
"""
return self.name
@property
def price_config(self) -> dict:
def get_or_default():
default_price_config = {
'prompt': decimal.Decimal('0'),
'completion': decimal.Decimal('0'),
'unit': decimal.Decimal('0'),
'currency': 'USD'
}
rules = self.model_provider.get_rules()
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
price_config = {
'prompt': decimal.Decimal(price_config['prompt']),
'completion': decimal.Decimal(price_config['completion']),
'unit': decimal.Decimal(price_config['unit']),
'currency': price_config['currency']
}
return price_config
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
logger.debug(f"model: {self.name} price_config: {self._price_config}")
return self._price_config
def run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
@@ -125,6 +161,8 @@ class BaseLLM(BaseProviderModel):
completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
total_tokens = prompt_tokens + completion_tokens
self.model_provider.update_last_used()
if self.deduct_quota:
self.model_provider.deduct_quota(total_tokens)
@@ -159,25 +197,64 @@ class BaseLLM(BaseProviderModel):
"""
raise NotImplementedError
@abstractmethod
def get_token_price(self, tokens: int, message_type: MessageType):
def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
"""
get token price.
calc tokens total price.
:param tokens:
:param message_type:
:return:
"""
raise NotImplementedError
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
unit = self.get_price_unit(message_type)
total_price = tokens * unit_price * unit
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
return total_price
@abstractmethod
def get_currency(self):
def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
"""
get tokens unit price.
:param message_type:
:return: decimal.Decimal('0.0001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"unit_price={unit_price}")
return unit_price
def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
"""
get price unit.
:param message_type:
:return: decimal.Decimal('0.000001')
"""
# prompt and completion tokens share the same pricing unit
price_unit = self.price_config['unit']
price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"price_unit={price_unit}")
return price_unit
def get_currency(self) -> str:
"""
get token currency.
:return:
:return: get from price config, default 'USD'
"""
raise NotImplementedError
currency = self.price_config['currency']
return currency
def get_model_kwargs(self):
return self.model_kwargs
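Since prices now come from each provider's rules as (unit_price, unit) pairs, a cost is just tokens * unit_price * unit. A worked example with an assumed prompt price of 0.0015 per 1k tokens (so unit = 0.001):

import decimal

tokens = 1200
unit_price = decimal.Decimal('0.0015')  # assumed price per 1k prompt tokens
unit = decimal.Decimal('0.001')         # pricing unit: 1/1000
total = (tokens * unit_price * unit).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total)  # 0.0018000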

View File

@@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'

View File

@@ -62,13 +62,6 @@ class HuggingfaceHubModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return self._client.get_num_tokens(prompts)
def get_token_price(self, tokens: int, message_type: MessageType):
# not support calc price
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.model_kwargs = provider_model_kwargs

View File

@@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'

View File

@@ -46,7 +46,8 @@ class OpenAIModel(BaseLLM):
self.model_mode = ModelMode.COMPLETION
else:
self.model_mode = ModelMode.CHAT
# TODO load price config from configs(db)
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
@@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM):
else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'gpt-4': {
'prompt': decimal.Decimal('0.03'),
'completion': decimal.Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': decimal.Decimal('0.06'),
'completion': decimal.Decimal('0.12')
},
'gpt-3.5-turbo': {
'prompt': decimal.Decimal('0.0015'),
'completion': decimal.Decimal('0.002')
},
'gpt-3.5-turbo-16k': {
'prompt': decimal.Decimal('0.003'),
'completion': decimal.Decimal('0.004')
},
'text-davinci-003': {
'prompt': decimal.Decimal('0.02'),
'completion': decimal.Decimal('0.02')
},
}
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name in COMPLETION_MODELS:
@@ -185,7 +148,7 @@ class OpenAIModel(BaseLLM):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError(str(ex))
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:

View File

@@ -0,0 +1,60 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.openllm import OpenLLM
class OpenLLMModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
client = OpenLLM(
server_url=self.credentials.get('server_url'),
callbacks=self.callbacks,
llm_kwargs=self.provider_model_kwargs
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"OpenLLM: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM):
return self._client.get_num_tokens(prompts)
def get_token_price(self, tokens: int, message_type: MessageType):
# replicate only pay for prediction seconds
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.input = provider_model_kwargs

View File

@@ -1,5 +1,4 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
@@ -19,6 +18,7 @@ class SparkModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatSpark(
model_name=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
@@ -50,9 +50,6 @@ class SparkModel(BaseLLM):
contents = [message.content for message in messages]
return max(self._client.get_num_tokens("".join(contents)), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'

View File

@@ -53,9 +53,6 @@ class TongyiModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'

View File

@@ -16,6 +16,7 @@ class WenxinModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin(
streaming=self.streaming,
callbacks=self.callbacks,
@@ -48,36 +49,6 @@ class WenxinModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'ernie-bot': {
'prompt': decimal.Decimal('0.012'),
'completion': decimal.Decimal('0.012'),
},
'ernie-bot-turbo': {
'prompt': decimal.Decimal('0.008'),
'completion': decimal.Decimal('0.008')
},
'bloomz-7b': {
'prompt': decimal.Decimal('0.006'),
'completion': decimal.Decimal('0.006')
}
}
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
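The per-model pricing removed above now lives in the wenxin provider's price_config JSON (shown further down in this diff). As a worked example of the Decimal arithmetic this method performed, here is a minimal standalone sketch, assuming 1,234 prompt tokens billed at ernie-bot's 0.012 RMB per 1K tokens:

import decimal

# Worked example of the removed get_token_price logic (illustrative values).
tokens = 1234
unit_price = decimal.Decimal('0.012')  # ernie-bot prompt rate, RMB per 1K tokens
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(
    decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)  # 1.234
total_price = (tokens_per_1k * unit_price).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total_price)  # 0.0148080 RMB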

View File

@@ -0,0 +1,70 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
class XinferenceModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
client = XinferenceLLM(
server_url=self.credentials['server_url'],
model_uid=self.credentials['model_uid'],
)
client.callbacks = self.callbacks
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate(
[prompts],
stop,
callbacks,
generate_config={
"stop": stop,
"stream": self.streaming,
**self.provider_model_kwargs,
}
)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Xinference: {str(ex)}")
@classmethod
def support_streaming(cls):
return True
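As a usage sketch (not part of the diff), the new Xinference client can be exercised the same way the provider's credential validation does below; the server URL and model UID are placeholder values:

# Hypothetical smoke test of the XinferenceLLM client; server_url and
# model_uid are placeholders for a running Xinference deployment.
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

llm = XinferenceLLM(
    server_url="http://127.0.0.1:9997",
    model_uid="my-model-uid",
)
print(llm("ping"))  # mirrors the provider's credential check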

View File

@@ -41,7 +41,7 @@ class OpenAIModeration(BaseProviderModel):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError(str(ex))
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:

View File

@@ -40,7 +40,7 @@ class OpenAIWhisper(BaseSpeech2Text):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError(str(ex))
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:

View File

@@ -183,6 +183,8 @@ class AnthropicProvider(BaseModelProvider):
return {
'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
}
return None

View File

@@ -31,7 +31,9 @@ class HostedAnthropic(BaseModel):
"""Quota limit for the anthropic hosted model. 0 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1
paid_increase_quota: int = 1000000
paid_min_quantity: int = 20
paid_max_quantity: int = 100
class HostedModelProviders(BaseModel):
@@ -73,4 +75,6 @@ def init_app(app: Flask):
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
)

View File

@@ -0,0 +1,138 @@
import json
from typing import Type
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.openllm import OpenLLM
from models.provider import ProviderType
class OpenLLMProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of the provider.
"""
return 'openllm'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = OpenLLMModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('OpenLLM Server URL must be provided.')
try:
credential_kwargs = {
'server_url': credentials['server_url']
}
llm = OpenLLM(
llm_kwargs={
'max_new_tokens': 10
},
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'server_url': None
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['server_url']:
credentials['server_url'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['server_url']
)
if obfuscated:
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}

View File

@@ -116,7 +116,8 @@ class ReplicateProvider(BaseModelProvider):
and 'Embedding' not in rst.openapi_schema['components']['schemas']:
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
elif model_type == ModelType.TEXT_GENERATION \
and ('type' not in rst.openapi_schema['components']['schemas']['Output']['items']
and ('items' not in rst.openapi_schema['components']['schemas']['Output']
or 'type' not in rst.openapi_schema['components']['schemas']['Output']['items']
or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
except ReplicateError as e:
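A standalone sketch of the strengthened condition: the added 'items' membership test prevents a KeyError when a model's Output schema has no items entry. The sample schemas are illustrative:

# Mirrors the fixed validation: accept a model only when Output.items.type
# exists and is 'string'.
def is_text_generation_output(schema: dict) -> bool:
    output = schema['components']['schemas']['Output']
    return ('items' in output
            and 'type' in output['items']
            and output['items']['type'] == 'string')

ok = {'components': {'schemas': {'Output': {'items': {'type': 'string'}}}}}
no_items = {'components': {'schemas': {'Output': {}}}}  # raised KeyError before the fix
print(is_text_generation_output(ok), is_text_generation_output(no_items))  # True False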

View File

@@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider):
return [
{
'id': 'spark',
'name': '星火认知大模型',
'name': 'Spark V1.5',
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
}
]
else:

View File

@@ -0,0 +1,193 @@
import json
from typing import Type
import requests
from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
from models.provider import ProviderType
class XinferenceProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of the provider.
"""
return 'xinference'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = XinferenceModel
elif model_type == ModelType.EMBEDDINGS:
model_class = XinferenceEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
credentials = self.get_model_credentials(model_name, model_type)
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
)
elif credentials['model_format'] == "ggmlv3":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
)
else:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('Xinference Server URL must be provided.')
if 'model_uid' not in credentials:
raise CredentialsValidateFailedError('Xinference Model UID must be provided.')
try:
credential_kwargs = {
'server_url': credentials['server_url'],
'model_uid': credentials['model_uid'],
}
llm = XinferenceLLM(
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'server_url': None,
'model_uid': None,
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['server_url']:
credentials['server_url'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['server_url']
)
if obfuscated:
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
return credentials
@classmethod
def _get_extra_credentials(cls, credentials: dict) -> dict:
url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
response = requests.get(url)
if response.status_code != 200:
raise RuntimeError(
f"Failed to get the model description, detail: {response.json()['detail']}"
)
desc = response.json()
extra_credentials = {
'model_format': desc['model_format'],
}
if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
extra_credentials['model_handle_type'] = 'chatglm'
elif "generate" in desc["model_ability"]:
extra_credentials['model_handle_type'] = 'generate'
elif "chat" in desc["model_ability"]:
extra_credentials['model_handle_type'] = 'chat'
else:
raise NotImplementedError(f"Model handle type not supported.")
return extra_credentials
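For reference, a hypothetical response from GET {server_url}/v1/models/{model_uid} that this classification consumes; the field values are illustrative:

# Illustrative model description; actual payloads depend on the deployment.
desc = {
    "model_format": "ggmlv3",
    "model_name": "chatglm2-ggmlv3",
    "model_ability": ["chat"],
}
# Per the branching above, this yields model_handle_type == 'chatglm'
# because the format is ggmlv3 and the name contains 'chatglm'.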
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}

View File

@@ -8,5 +8,7 @@
"wenxin",
"chatglm",
"replicate",
"huggingface_hub"
"huggingface_hub",
"xinference",
"openllm"
]

View File

@@ -5,10 +5,25 @@
],
"system_config": {
"supported_quota_types": [
"paid",
"trial"
],
"quota_unit": "times",
"quota_limit": 1000
"quota_unit": "tokens",
"quota_limit": 600000
},
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"claude-instant-1": {
"prompt": "1.63",
"completion": "5.51",
"unit": "0.000001",
"currency": "USD"
},
"claude-2": {
"prompt": "11.02",
"completion": "32.68",
"unit": "0.000001",
"currency": "USD"
}
}
}
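Assuming the runtime derives a charge as tokens * unit * rate (consistent with the unit fields above), 1,000 claude-instant-1 prompt tokens would cost roughly the following; the helper below is a sketch, not the actual billing code:

from decimal import Decimal

# Sketch of a cost estimate against the price_config entries above.
def estimate_cost(tokens: int, rate: str, unit: str) -> Decimal:
    return Decimal(tokens) * Decimal(unit) * Decimal(rate)

print(estimate_cost(1000, rate="1.63", unit="0.000001"))  # 0.00163000 USD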

View File

@@ -3,5 +3,48 @@
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
"model_flexibility": "configurable",
"price_config":{
"gpt-4": {
"prompt": "0.03",
"completion": "0.06",
"unit": "0.001",
"currency": "USD"
},
"gpt-4-32k": {
"prompt": "0.06",
"completion": "0.12",
"unit": "0.001",
"currency": "USD"
},
"gpt-35-turbo": {
"prompt": "0.002",
"completion": "0.0015",
"unit": "0.001",
"currency": "USD"
},
"gpt-35-turbo-16k": {
"prompt": "0.003",
"completion": "0.004",
"unit": "0.001",
"currency": "USD"
},
"text-davinci-002": {
"prompt": "0.02",
"completion": "0.02",
"unit": "0.001",
"currency": "USD"
},
"text-davinci-003": {
"prompt": "0.02",
"completion": "0.02",
"unit": "0.001",
"currency": "USD"
},
"text-embedding-ada-002":{
"completion": "0.0001",
"unit": "0.001",
"currency": "USD"
}
}
}

View File

@@ -9,5 +9,24 @@
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"abab5.5-chat": {
"prompt": "0.015",
"completion": "0.015",
"unit": "0.001",
"currency": "RMB"
},
"abab5-chat": {
"prompt": "0.015",
"completion": "0.015",
"unit": "0.001",
"currency": "RMB"
},
"embo-01": {
"completion": "0",
"unit": "0.0001",
"currency": "RMB"
}
}
}

View File

@@ -10,5 +10,42 @@
"quota_unit": "times",
"quota_limit": 200
},
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"gpt-4": {
"prompt": "0.03",
"completion": "0.06",
"unit": "0.001",
"currency": "USD"
},
"gpt-4-32k": {
"prompt": "0.06",
"completion": "0.12",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo": {
"prompt": "0.0015",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-16k": {
"prompt": "0.003",
"completion": "0.004",
"unit": "0.001",
"currency": "USD"
},
"text-davinci-003": {
"prompt": "0.02",
"completion": "0.02",
"unit": "0.001",
"currency": "USD"
},
"text-embedding-ada-002":{
"completion": "0.0001",
"unit": "0.001",
"currency": "USD"
}
}
}

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@@ -9,5 +9,19 @@
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"spark": {
"prompt": "0.18",
"completion": "0.18",
"unit": "0.0001",
"currency": "RMB"
},
"spark-v2": {
"prompt": "0.36",
"completion": "0.36",
"unit": "0.0001",
"currency": "RMB"
}
}
}

View File

@@ -3,5 +3,25 @@
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"ernie-bot": {
"prompt": "0.012",
"completion": "0.012",
"unit": "0.001",
"currency": "RMB"
},
"ernie-bot-turbo": {
"prompt": "0.008",
"completion": "0.008",
"unit": "0.001",
"currency": "RMB"
},
"bloomz-7b": {
"prompt": "0.006",
"completion": "0.006",
"unit": "0.001",
"currency": "RMB"
}
}
}

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@@ -14,14 +14,16 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from libs import helper
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
@@ -78,15 +80,22 @@ class OrchestratorRuleParser:
elif planning_strategy == PlanningStrategy.ROUTER:
planning_strategy = PlanningStrategy.REACT_ROUTER
summary_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_kwargs=ModelKwargs(
temperature=0,
max_tokens=500
try:
summary_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_provider_name=agent_provider_name,
model_name=agent_model_name,
model_kwargs=ModelKwargs(
temperature=0,
max_tokens=500
),
deduct_quota=False
)
)
except ProviderTokenNotInitError:
summary_model_instance = None
tools = self.to_tools(
agent_model_instance=agent_model_instance,
tool_configs=tool_configs,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
@@ -136,11 +145,12 @@ class OrchestratorRuleParser:
return None
def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
"""
Convert app agent tool configs to tools
:param agent_model_instance:
:param rest_tokens:
:param tool_configs: app agent tool configs
:param conversation_message_task:
@@ -158,7 +168,7 @@ class OrchestratorRuleParser:
if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
elif tool_type == "web_reader":
tool = self.to_web_reader_tool()
tool = self.to_web_reader_tool(agent_model_instance)
elif tool_type == "google_search":
tool = self.to_google_search_tool()
elif tool_type == "wikipedia":
@@ -203,24 +213,28 @@ class OrchestratorRuleParser:
return tool
def to_web_reader_tool(self) -> Optional[BaseTool]:
def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
"""
A tool for reading web pages
:return:
"""
summary_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_kwargs=ModelKwargs(
temperature=0,
max_tokens=500
try:
summary_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_provider_name=agent_model_instance.model_provider.provider_name,
model_name=agent_model_instance.name,
model_kwargs=ModelKwargs(
temperature=0,
max_tokens=500
),
deduct_quota=False
)
)
summary_llm = summary_model_instance.client
except ProviderTokenNotInitError:
summary_model_instance = None
tool = WebReaderTool(
llm=summary_llm,
llm=summary_model_instance.client if summary_model_instance else None,
max_chunk_length=4000,
continue_reading=True,
callbacks=[DifyStdOutCallbackHandler()]
@@ -248,11 +262,7 @@ class OrchestratorRuleParser:
return tool
def to_current_datetime_tool(self) -> Optional[BaseTool]:
tool = Tool(
name="current_datetime",
description="A tool when you want to get the current date, time, week, month or year, "
"and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".",
func=helper.get_current_datetime,
tool = DatetimeTool(
callbacks=[DifyStdOutCallbackHandler()]
)

View File

@@ -1,7 +1,10 @@
import json
import re
from typing import Any
from langchain.schema import BaseOutputParser
from core.model_providers.error import LLMError
from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
@@ -12,5 +15,10 @@ class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
def parse(self, text: str) -> Any:
json_string = text.strip()
json_obj = json.loads(json_string)
action_match = re.search(r".*(\[\".+\"\]).*", json_string, re.DOTALL)
if action_match is not None:
json_obj = json.loads(action_match.group(1).strip(), strict=False)
else:
raise LLMError("Could not parse LLM output: {text}")
return json_obj
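A quick standalone check of the new fallback path (the sample LLM output is made up):

import json
import re

# The regex tolerates chatter around the JSON array, as the parser above does.
text = 'Sure! Here you go: ["q1","q2","q3"] Hope that helps.'
match = re.search(r".*(\[\".+\"\]).*", text, re.DOTALL)
print(json.loads(match.group(1).strip(), strict=False))  # ['q1', 'q2', 'q3']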

View File

@@ -39,18 +39,18 @@ MORE_LIKE_THIS_GENERATE_PROMPT = (
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"
"The output must be in JSON format following the specified schema:\n"
"The output must be an array in JSON format following the specified schema:\n"
"[\"question1\",\"question2\",\"question3\"]\n"
)
GENERATOR_QA_PROMPT = (
"Please respond according to the language of the user's input text. If the text is in language [A], you must also reply in language [A].\n"
'The user will send a long text. Please think step by step.'
'Step 1: Understand and summarize the main content of this text.\n'
'Step 2: What key information or concepts are mentioned in this text?\n'
'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
'Step 4: Generate 20 questions and answers based on these key information and concepts.'
'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
"Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
"Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
)
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \

View File

@@ -0,0 +1,21 @@
from typing import List
import numpy as np
from langchain.embeddings import XinferenceEmbeddings
class XinferenceEmbedding(XinferenceEmbeddings):
def embed_documents(self, texts: List[str]) -> List[List[float]]:
vectors = super().embed_documents(texts)
normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
return normalized_vectors
def embed_query(self, text: str) -> List[float]:
vector = super().embed_query(text)
normalized_vector = (vector / np.linalg.norm(vector)).tolist()
return normalized_vector
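The normalization gives every embedding unit length, so a plain dot product downstream equals cosine similarity. A quick sanity check with a toy vector:

import numpy as np

# Toy vector: (3, 4) has norm 5, so it normalizes to (0.6, 0.8).
vector = np.array([3.0, 4.0])
normalized = (vector / np.linalg.norm(vector)).tolist()
print(normalized)                  # [0.6, 0.8]
print(np.linalg.norm(normalized))  # 1.0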

View File

@@ -0,0 +1,84 @@
from __future__ import annotations
import logging
from typing import (
Any,
Dict,
List,
Optional,
)
import requests
from langchain.llms.utils import enforce_stop_tokens
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
logger = logging.getLogger(__name__)
class OpenLLM(LLM):
"""OpenLLM, supporting both in-process model
instance and remote OpenLLM servers.
If you have a OpenLLM server running, you can also use it remotely:
.. code-block:: python
from langchain.llms import OpenLLM
llm = OpenLLM(server_url='http://localhost:3000')
llm("What is the difference between a duck and a goose?")
"""
server_url: Optional[str] = None
"""Optional server URL that currently runs a LLMServer with 'openllm start'."""
llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Key word arguments to be passed to openllm.LLM"""
@property
def _llm_type(self) -> str:
return "openllm"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> str:
params = {
"prompt": prompt,
"llm_config": self.llm_kwargs
}
headers = {"Content-Type": "application/json"}
response = requests.post(
f'{self.server_url}/v1/generate',
headers=headers,
json=params
)
if not response.ok:
raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
json_response = response.json()
completion = json_response["responses"][0]
if stop is not None:
completion = enforce_stop_tokens(completion, stop)
return completion
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
raise NotImplementedError(
"Async call is not supported for OpenLLM at the moment."
)
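Usage follows the class docstring; a minimal sketch against a locally started server (the URL and llm_kwargs values are placeholders):

# Hypothetical usage; assumes a server started with `openllm start`
# listening on localhost:3000.
from core.third_party.langchain.llms.openllm import OpenLLM

llm = OpenLLM(
    server_url="http://localhost:3000",
    llm_kwargs={"max_new_tokens": 64},
)
print(llm("What is the difference between a duck and a goose?"))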

View File

@@ -25,6 +25,7 @@ class ChatSpark(BaseChatModel):
.. code-block:: python
client = SparkLLMClient(
model_name="<model_name>",
app_id="<app_id>",
api_key="<api_key>",
api_secret="<api_secret>"
@@ -32,6 +33,9 @@ class ChatSpark(BaseChatModel):
"""
client: Any = None #: :meta private:
model_name: str = "spark"
"""The Spark model name."""
max_tokens: int = 256
"""Denotes the number of tokens to predict per generation."""
@@ -66,6 +70,7 @@ class ChatSpark(BaseChatModel):
)
values["client"] = SparkLLMClient(
model_name=values["model_name"],
app_id=values["app_id"],
api_key=values["api_key"],
api_secret=values["api_secret"],

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import json
import logging
from json import JSONDecodeError
from typing import (
Any,
Dict,
@@ -223,11 +224,24 @@ class Wenxin(LLM):
for token in self._client.post(request).iter_lines():
if token:
token = token.decode("utf-8")
completion = json.loads(token[5:])
yield GenerationChunk(text=completion['result'])
if run_manager:
run_manager.on_llm_new_token(completion['result'])
if token.startswith('data:'):
completion = json.loads(token[5:])
if completion['is_end']:
break
yield GenerationChunk(text=completion['result'])
if run_manager:
run_manager.on_llm_new_token(completion['result'])
if completion['is_end']:
break
else:
try:
json_response = json.loads(token)
except JSONDecodeError:
raise ValueError(f"Wenxin Response Error {token}")
raise ValueError(
f"Wenxin API {json_response['error_code']}"
f" error: {json_response['error_msg']}, "
f"please confirm if the model you have chosen is already paid for."
)
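For reference, the streaming branch above consumes server-sent lines shaped like this (payload illustrative); token[5:] strips the 'data:' prefix before JSON decoding:

import json

# Illustrative SSE line from the Wenxin streaming endpoint.
token = 'data:{"result": "Hello", "is_end": false}'
if token.startswith('data:'):
    completion = json.loads(token[5:])
    print(completion['result'], completion['is_end'])  # Hello False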

Some files were not shown because too many files have changed in this diff Show More