Mirror of https://gitee.com/dify_ai/dify.git, synced 2025-12-07 11:55:44 +08:00

Compare commits: fix/web-do...fix/pub-te (102 commits)
| SHA1 |
|---|
| 000bb5c2a4 |
| 3a5c7c75ad |
| a7415ecfd8 |
| 934def5fcc |
| 0796791de5 |
| 6c148b223d |
| 4b168f4838 |
| 1c114eaef3 |
| e053215155 |
| 13482b0fc1 |
| 38fa152cc4 |
| 2d9616c29c |
| 915e26527b |
| 2d604d9330 |
| e7199826cc |
| 70e24b7594 |
| c1602aafc7 |
| a3fec11438 |
| b1fd1b3ab3 |
| 5397799aac |
| 8e837dde1a |
| 9ae91a2ec3 |
| 276d3d10a0 |
| f13623184a |
| ef61e1487f |
| 701e2b334f |
| 6ebd6e7890 |
| bd3a9b2f8d |
| 18d3877151 |
| 53e83d8697 |
| 6377fc75c6 |
| 2c30d19cbe |
| 9b247fccd4 |
| 3d38aa7138 |
| 7d2552b3f2 |
| 117a209ad4 |
| 071e7800a0 |
| a76fde3d23 |
| 1fc57d7358 |
| 916d8be0ae |
| a38412de7b |
| 9c9f0ddb93 |
| f8fbe96da4 |
| 215a27fd95 |
| 5cba2e7087 |
| 5623839c71 |
| 78d3aa5fcd |
| a7c78d2cd2 |
| 4db35fa375 |
| e67a1413b6 |
| 4f3053a8cc |
| b3c2bf125f |
| 9d5299e9ec |
| aee15adf1b |
| b185a70c21 |
| a3aba7a9aa |
| 866ee5da91 |
| e8039a7da8 |
| 5e0540077a |
| b346bd9b83 |
| 062e2e915b |
| e0a48c4972 |
| f53242c081 |
| 4b53bb1a32 |
| 4c49ecedb5 |
| 4ff1870a4b |
| 6c832ee328 |
| 25264e7852 |
| 18dd0d569d |
| 3ea8d7a019 |
| da3f10a55e |
| 8c991b5b26 |
| 22c1aafb9b |
| 8d6d1c442b |
| 95b179fb39 |
| 3a0a9e2d8f |
| 0a0d63457d |
| 920fb6d0e1 |
| fd0fc8f4fe |
| 1c552ff23a |
| 5163dd38e5 |
| 2a27dad2fb |
| 930f74c610 |
| 3f250c9e12 |
| fa408d264c |
| 09ea27f1ee |
| db7156dafd |
| 4420281d96 |
| d9afebe216 |
| 1d9cc5ca05 |
| edb06f6aed |
| 6ca3bcbcfd |
| 71a9d63232 |
| fb62017e50 |
| 9adbeadeec |
| 2f7b234cc5 |
| 4f5f9506ab |
| 0cc0b6e052 |
| cd78adb0ab |
| f42e7d1a61 |
| c4d759dfba |
| a58f95fa91 |
@@ -53,9 +53,9 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r

 ## Community channels

-Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk). We are here to help!
+Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!

 ### i18n (Internationalization) Support

 We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.

 Also check out the [Frontend i18n README](web/i18n/README_EN.md) for more information.
@@ -16,15 +16,15 @@

 ## Local development

-To set up a working development environment, just fork the project's git repository, install the backend and frontend dependencies with the appropriate package managers, then create and run the docker-compose stack.
+To set up a working development environment, just fork the project's git repository, install the backend and frontend dependencies with the appropriate package managers, then create and run docker-compose.

 ### Fork the repository

-You need to fork the [repository](https://github.com/langgenius/dify).
+You need to fork the [Git repository](https://github.com/langgenius/dify).

 ### Clone the repository

-Clone the repository ("存储库") you forked on GitHub:
+Clone the repository ("仓库") you forked on GitHub:

```
git clone git@github.com:<github_username>/dify.git
```
@@ -52,4 +52,4 @@ git clone git@github.com:<github_username>/dify.git

 ## Community channels

-Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk). We are here to help!
+Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
@@ -30,10 +30,10 @@ Visual data analysis, log review, and annotation for applications

 - [x] **Spark**
 - [x] **Wenxin**
 - [x] **Tongyi**
 - [x] **ChatGLM**

 We provide the following free resources for registered Dify cloud users (sign up at [dify.ai](https://dify.ai)):

-* 1000 free Claude model queries to build Claude-powered apps
+* 600,000 free Claude model tokens to build Claude-powered apps
 * 200 free OpenAI queries to build OpenAI-based apps
@@ -36,7 +36,7 @@

 We provide all registered cloud users the following resources for free (log in at [dify.ai](https://cloud.dify.ai) to use them):

-* 1,000 Claude model message calls for building Claude-based AI apps
+* 600,000 Claude model tokens for building Claude-based AI apps
 * 200 OpenAI model message calls for building OpenAI-based AI apps
 * 3,000,000 iFLYTEK Spark (讯飞星火) model tokens for building Spark-based AI apps
 * 1,000,000 MiniMax tokens for building MiniMax-based AI apps
@@ -117,10 +117,12 @@ HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200

 HOSTED_ANTHROPIC_ENABLED=false
 HOSTED_ANTHROPIC_API_BASE=
 HOSTED_ANTHROPIC_API_KEY=
-HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
+HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
 HOSTED_ANTHROPIC_PAID_ENABLED=false
 HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
-HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
+HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
+HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
+HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100

 STRIPE_API_KEY=
 STRIPE_WEBHOOK_SECRET=
@@ -16,7 +16,7 @@ EXPOSE 5001

 WORKDIR /app/api

 RUN apt-get update && \
-    apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev
+    apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev nodejs

 COPY requirements.txt /app/api/requirements.txt
@@ -1,4 +1,5 @@
 import datetime
 import json
+import math
 import random
 import string

@@ -6,10 +7,16 @@ import time

 import click
 from flask import current_app
-from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound

+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.index import IndexBuilder
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.hosted import hosted_model_providers
+from core.model_providers.providers.openai_provider import OpenAIProvider
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db

@@ -20,7 +27,7 @@ from models.model import Account
 import secrets
 import base64

-from models.provider import Provider, ProviderType, ProviderQuotaType
+from models.provider import Provider, ProviderType, ProviderQuotaType, ProviderModel


 @click.command('reset-password', help='Reset the account password.')

@@ -102,6 +109,7 @@ def reset_encrypt_key_pair():
     tenant.encrypt_public_key = generate_key_pair(tenant.id)

     db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
+    db.session.query(ProviderModel).delete()
     db.session.commit()

     click.echo(click.style('Congratulations! '

@@ -258,6 +266,8 @@ def sync_anthropic_hosted_providers():
     click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
     count = 0

+    new_quota_limit = hosted_model_providers.anthropic.quota_limit
+
     page = 1
     while True:
         try:

@@ -265,6 +275,7 @@ def sync_anthropic_hosted_providers():
                 Provider.provider_name == 'anthropic',
                 Provider.provider_type == ProviderType.SYSTEM.value,
                 Provider.quota_type == ProviderQuotaType.TRIAL.value,
+                Provider.quota_limit != new_quota_limit
             ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break

@@ -272,9 +283,9 @@ def sync_anthropic_hosted_providers():
         page += 1
         for provider in providers:
             try:
-                click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
+                click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
+                           .format(provider.tenant_id, provider.quota_limit, provider.quota_used))
                 original_quota_limit = provider.quota_limit
-                new_quota_limit = hosted_model_providers.anthropic.quota_limit
+                division = math.ceil(new_quota_limit / 1000)

                 provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \

@@ -292,6 +303,72 @@ def sync_anthropic_hosted_providers():
     click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))


+@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
+def create_qdrant_indexes():
+    click.echo(click.style('Start create qdrant indexes.', fg='green'))
+    create_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            try:
+                click.echo('Create dataset qdrant index: {}'.format(dataset.id))
+                try:
+                    embedding_model = ModelFactory.get_embedding_model(
+                        tenant_id=dataset.tenant_id,
+                        model_provider_name=dataset.embedding_model_provider,
+                        model_name=dataset.embedding_model
+                    )
+                except Exception:
+                    provider = Provider(
+                        id='provider_id',
+                        tenant_id='tenant_id',
+                        provider_name='openai',
+                        provider_type=ProviderType.CUSTOM.value,
+                        encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                        is_valid=True,
+                    )
+                    model_provider = OpenAIProvider(provider=provider)
+                    embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+                embeddings = CacheEmbedding(embedding_model)
+
+                from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+
+                index = QdrantVectorIndex(
+                    dataset=dataset,
+                    config=QdrantConfig(
+                        endpoint=current_app.config.get('QDRANT_URL'),
+                        api_key=current_app.config.get('QDRANT_API_KEY'),
+                        root_path=current_app.root_path
+                    ),
+                    embeddings=embeddings
+                )
+                if index:
+                    index_struct = {
+                        "type": 'qdrant',
+                        "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
+                    }
+                    dataset.index_struct = json.dumps(index_struct)
+                    db.session.commit()
+                    index.create_qdrant_dataset(dataset)
+                    create_count += 1
+                else:
+                    click.echo('passed.')
+            except Exception as e:
+                click.echo(
+                    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)

@@ -300,3 +377,4 @@ def register_commands(app):
     app.cli.add_command(recreate_all_dataset_indexes)
     app.cli.add_command(sync_anthropic_hosted_providers)
     app.cli.add_command(clean_unused_dataset_indexes)
+    app.cli.add_command(create_qdrant_indexes)
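Both `sync_anthropic_hosted_providers` and the new `create_qdrant_indexes` command walk their tables with the same paginate-until-NotFound loop. A stripped-down sketch of that pattern (the helper name and arguments are illustrative, not from the commit):

```
from werkzeug.exceptions import NotFound

def walk_all(query, per_page=50):
    """Yield every row of `query`, one page at a time.

    Flask-SQLAlchemy's paginate() aborts with a 404 (raised as a
    werkzeug NotFound) once the requested page is past the end of the
    result set -- that exception is what terminates the `while True`
    loops in the commands above.
    """
    page = 1
    while True:
        try:
            pagination = query.paginate(page=page, per_page=per_page)
        except NotFound:
            break
        yield from pagination.items
        page += 1
```

Paging keeps memory bounded when the table holds many tenants or datasets, at the cost of one extra query to discover the end.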
@@ -48,19 +48,19 @@ DEFAULTS = {
     'WEAVIATE_GRPC_ENABLED': 'True',
     'WEAVIATE_BATCH_SIZE': 100,
     'CELERY_BACKEND': 'database',
     'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
     'HOSTED_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_OPENAI_ENABLED': 'False',
     'HOSTED_OPENAI_PAID_ENABLED': 'False',
     'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
     'HOSTED_AZURE_OPENAI_ENABLED': 'False',
     'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
-    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
+    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
     'HOSTED_ANTHROPIC_ENABLED': 'False',
     'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
-    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
+    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
     'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
     'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
     'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
@@ -100,13 +100,12 @@ class Config:
         self.CONSOLE_URL = get_env('CONSOLE_URL')
         self.API_URL = get_env('API_URL')
         self.APP_URL = get_env('APP_URL')
-        self.CURRENT_VERSION = "0.3.14"
+        self.CURRENT_VERSION = "0.3.18"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
         self.TESTING = False
         self.LOG_LEVEL = get_env('LOG_LEVEL')
         self.PDF_PREVIEW = get_bool_env('PDF_PREVIEW')

         # Your App secret key will be used for securely signing the session cookie
         # Make sure you are changing this key for your deployment with a strong key.
@@ -211,7 +210,7 @@ class Config:
         self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
         self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
-        self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
+        self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
         self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
         self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
         self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))

@@ -219,23 +218,21 @@ class Config:
         self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
         self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
         self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
-        self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
+        self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT'))

         self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
         self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
         self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
-        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
+        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
         self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
         self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
-        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
+        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
         self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
         self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))

         self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
         self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

         # By default it is False
         # You could disable it for compatibility with certain OpenAPI providers
         self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')

         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
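The `int(get_env(...))` wrappers added above matter because environment variables always arrive as strings. A minimal standalone sketch (not Dify code) of the failure mode the cast prevents:

```
import os

# Environment values are strings until explicitly cast.
os.environ["QUOTA_LIMIT"] = "600000"

raw = os.environ.get("QUOTA_LIMIT", "0")   # '600000' -- a str
quota = int(raw)                           # 600000   -- an int

used = 599999
# `used > raw` would raise TypeError in Python 3 ('>' is not
# supported between int and str), so quota checks need the cast.
assert used < quota
```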
@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 import flask_restful
 from flask_restful import Resource, fields, marshal_with
 from werkzeug.exceptions import Forbidden
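This import swap is the recurring change in the controller files below: the stock `flask_login.login_required` is replaced everywhere by a project-local `core.login.login.login_required`. The diff never shows that module's body, so the following is only a sketch of the usual shape of such a wrapper (every detail here is an assumption, not the real implementation):

```
from functools import wraps

from flask import request
from flask_login import current_user


def login_required(view):
    """Hypothetical drop-in for flask_login.login_required that can
    accept an alternative credential before falling back to the
    session check.  Sketch only -- not the code from core.login.login."""
    @wraps(view)
    def decorated(*args, **kwargs):
        if current_user.is_authenticated:
            return view(*args, **kwargs)
        # assumed extra path: allow an internal service header
        if request.headers.get("X-Service-Token"):
            return view(*args, **kwargs)
        return {"message": "unauthorized"}, 401
    return decorated
```

Keeping the decorator name identical lets every view keep its `@login_required` line while the behavior is swapped in one place.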
@@ -3,7 +3,9 @@ import json
 import logging
 from datetime import datetime

-from flask_login import login_required, current_user
+import flask
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
 from werkzeug.exceptions import Forbidden
@@ -316,7 +318,7 @@ class AppApi(Resource):

         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()

         app = _get_app(app_id, current_user.current_tenant_id)

         db.session.delete(app)
@@ -2,7 +2,7 @@
 import logging

 from flask import request
-from flask_login import login_required
+from core.login.login import login_required
 from werkzeug.exceptions import InternalServerError, NotFound

 import services
@@ -5,7 +5,7 @@ from typing import Generator, Union

 import flask_login
 from flask import Response, stream_with_context
-from flask_login import login_required
+from core.login.login import login_required
 from werkzeug.exceptions import InternalServerError, NotFound

 import services
@@ -1,7 +1,8 @@
 from datetime import datetime

 import pytz
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse, fields, marshal_with
 from flask_restful.inputs import int_range
 from sqlalchemy import or_, func
@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse

 from controllers.console import api
@@ -3,7 +3,7 @@ import logging
 from typing import Union, Generator

 from flask import Response, stream_with_context
-from flask_login import current_user, login_required
+from flask_login import current_user
 from flask_restful import Resource, reqparse, marshal_with, fields
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import InternalServerError, NotFound

@@ -16,6 +16,7 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.login.login import login_required
 from libs.helper import uuid_value, TimestampField
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from extensions.ext_database import db
@@ -3,12 +3,13 @@ import json

 from flask import request
 from flask_restful import Resource
-from flask_login import login_required, current_user
+from flask_login import current_user

 from controllers.console import api
 from controllers.console.app import _get_app
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.login.login import login_required
 from events.app_event import app_model_config_was_updated
 from extensions.ext_database import db
 from models.model import AppModelConfig
@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse, fields, marshal_with
 from werkzeug.exceptions import NotFound, Forbidden
@@ -4,7 +4,8 @@ from datetime import datetime

 import pytz
 from flask import jsonify
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse

 from controllers.console import api
@@ -5,9 +5,12 @@ from typing import Optional
 import flask_login
 import requests
 from flask import request, redirect, current_app, session
-from flask_login import current_user, login_required
+from flask_login import current_user

 from flask_restful import Resource
 from werkzeug.exceptions import Forbidden

+from core.login.login import login_required
 from libs.oauth_data_source import NotionOAuth
 from controllers.console import api
 from ..setup import setup_required
@@ -3,7 +3,8 @@ import json

 from cachetools import TTLCache
 from flask import request, current_app
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, marshal_with, fields, reqparse, marshal
 from werkzeug.exceptions import NotFound
@@ -1,6 +1,7 @@
 # -*- coding:utf-8 -*-
 from flask import request
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse, fields, marshal, marshal_with
 from werkzeug.exceptions import NotFound, Forbidden
 import services

@@ -10,13 +11,15 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
-from core.model_providers.error import LLMBadRequestError
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.model_params import ModelType
 from libs.helper import TimestampField
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Document
 from models.model import UploadFile
 from services.dataset_service import DatasetService, DocumentService
+from services.provider_service import ProviderService

 dataset_detail_fields = {
     'id': fields.String,

@@ -33,6 +36,9 @@ dataset_detail_fields = {
     'created_at': TimestampField,
     'updated_by': fields.String,
     'updated_at': TimestampField,
+    'embedding_model': fields.String,
+    'embedding_model_provider': fields.String,
+    'embedding_available': fields.Boolean
 }

 dataset_query_detail_fields = {
@@ -74,8 +80,25 @@ class DatasetListApi(Resource):
         datasets, total = DatasetService.get_datasets(page, limit, provider,
                                                       current_user.current_tenant_id, current_user)

+        # check embedding setting
+        provider_service = ProviderService()
+        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
+        # if len(valid_model_list) == 0:
+        #     raise ProviderNotInitializeError(
+        #         f"No Embedding Model available. Please configure a valid provider "
+        #         f"in the Settings -> Model Provider.")
+        model_names = []
+        for valid_model in valid_model_list:
+            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        data = marshal(datasets, dataset_detail_fields)
+        for item in data:
+            item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
+            if item_model in model_names:
+                item['embedding_available'] = True
+            else:
+                item['embedding_available'] = False
         response = {
-            'data': marshal(datasets, dataset_detail_fields),
+            'data': data,
             'has_more': len(datasets) == limit,
             'limit': limit,
             'total': total,
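The `embedding_available` flag computed above is a membership test of each dataset's `"model:provider"` key against the tenant's valid embedding models. A reduced, self-contained sketch (toy data; a `set` is used here for O(1) lookups, where the diff uses a plain list):

```
valid_model_list = [
    {"model_name": "text-embedding-ada-002",
     "model_provider": {"provider_name": "openai"}},
]
model_names = {
    f"{m['model_name']}:{m['model_provider']['provider_name']}"
    for m in valid_model_list
}

data = [
    {"embedding_model": "text-embedding-ada-002",
     "embedding_model_provider": "openai"},
    {"embedding_model": "embed-v1",
     "embedding_model_provider": "acme"},
]
for item in data:
    key = f"{item['embedding_model']}:{item['embedding_model_provider']}"
    item["embedding_available"] = key in model_names

assert [d["embedding_available"] for d in data] == [True, False]
```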
@@ -99,7 +122,6 @@ class DatasetListApi(Resource):
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
-
         try:
             ModelFactory.get_embedding_model(
                 tenant_id=current_user.current_tenant_id
@@ -233,6 +255,8 @@ class DatasetIndexingEstimateApi(Resource):
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+        parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
@@ -250,11 +274,14 @@

             try:
                 response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  args['process_rule'], args['doc_form'])
+                                                                  args['process_rule'], args['doc_form'],
+                                                                  args['doc_language'], args['dataset_id'])
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
                     f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         elif args['info_list']['data_source_type'] == 'notion_import':

             indexing_runner = IndexingRunner()

@@ -262,11 +289,14 @@
             try:
                 response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                     args['info_list']['notion_info_list'],
-                                                                    args['process_rule'], args['doc_form'])
+                                                                    args['process_rule'], args['doc_form'],
+                                                                    args['doc_language'], args['dataset_id'])
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
                     f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         else:
             raise ValueError('Data source type not support')
         return response, 200
@@ -3,8 +3,9 @@ import random
 from datetime import datetime
 from typing import List

-from flask import request
-from flask_login import login_required, current_user
+from flask import request, current_app
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, fields, marshal, marshal_with, reqparse
 from sqlalchemy import desc, asc
 from werkzeug.exceptions import NotFound, Forbidden

@@ -274,6 +275,8 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument('duplicate', type=bool, nullable=False, location='json')
         parser.add_argument('original_document_id', type=str, required=False, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()

         if not dataset.indexing_technique and not args['indexing_technique']:
@@ -282,14 +285,19 @@ class DatasetDocumentListApi(Resource):
         # validate args
         DocumentService.document_create_args_validate(args)

+        # check embedding model setting
         try:
             ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
+                tenant_id=current_user.current_tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 f"No Embedding Model available. Please configure a valid provider "
                 f"in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
@@ -328,6 +336,8 @@ class DatasetInitApi(Resource):
         parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()

         try:
@@ -406,11 +416,14 @@ class DocumentIndexingEstimateApi(DocumentResource):

             try:
                 response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
-                                                                  data_process_rule_dict)
+                                                                  data_process_rule_dict, None,
+                                                                  'English', dataset_id)
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
                     f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)

         return response
@@ -473,22 +486,28 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
             indexing_runner = IndexingRunner()
             try:
                 response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  data_process_rule_dict)
+                                                                  data_process_rule_dict, None,
+                                                                  'English', dataset_id)
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
                     f"in the Settings -> Model Provider.")
-        elif dataset.data_source_type:
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
+        elif dataset.data_source_type == 'notion_import':

             indexing_runner = IndexingRunner()
             try:
                 response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                     info_list,
-                                                                    data_process_rule_dict)
+                                                                    data_process_rule_dict,
+                                                                    None, 'English', dataset_id)
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
                     f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         else:
             raise ValueError('Data source type not support')
         return response
@@ -575,7 +594,8 @@ class DocumentIndexingStatusApi(DocumentResource):

         document.completed_segments = completed_segments
         document.total_segments = total_segments
-
+        if document.is_paused:
+            document.indexing_status = 'paused'
         return marshal(document, self.document_status_fields)
@@ -749,11 +769,13 @@ class DocumentMetadataApi(DocumentResource):
         metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]

         document.doc_metadata = {}
-
-        for key, value_type in metadata_schema.items():
-            value = doc_metadata.get(key)
-            if value is not None and isinstance(value, value_type):
-                document.doc_metadata[key] = value
+        if doc_type == 'others':
+            document.doc_metadata = doc_metadata
+        else:
+            for key, value_type in metadata_schema.items():
+                value = doc_metadata.get(key)
+                if value is not None and isinstance(value, value_type):
+                    document.doc_metadata[key] = value

         document.doc_type = doc_type
         document.updated_at = datetime.utcnow()
@@ -832,12 +854,40 @@ class DocumentStatusApi(DocumentResource):

             remove_document_from_index_task.delay(document_id)

             return {'result': 'success'}, 200
+        elif action == "un_archive":
+            if not document.archived:
+                raise InvalidActionError('Document is not archived.')
+
+            # check document limit
+            if current_app.config['EDITION'] == 'CLOUD':
+                documents_count = DocumentService.get_tenant_documents_count()
+                total_count = documents_count + 1
+                tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+                if total_count > tenant_document_count:
+                    raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
+
+            document.archived = False
+            document.archived_at = None
+            document.archived_by = None
+            document.updated_at = datetime.utcnow()
+            db.session.commit()
+
+            # Set cache to prevent indexing the same document multiple times
+            redis_client.setex(indexing_cache_key, 600, 1)
+
+            add_document_to_index_task.delay(document_id)
+
+            return {'result': 'success'}, 200
         else:
             raise InvalidActionError()


 class DocumentPauseApi(DocumentResource):

     @setup_required
     @login_required
     @account_initialization_required
     def patch(self, dataset_id, document_id):
         """pause document."""
         dataset_id = str(dataset_id)
@@ -867,6 +917,9 @@ class DocumentPauseApi(DocumentResource):


 class DocumentRecoverApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
     def patch(self, dataset_id, document_id):
         """recover document."""
         dataset_id = str(dataset_id)
@@ -892,6 +945,21 @@ class DocumentRecoverApi(DocumentResource):
         return {'result': 'success'}, 204


+class DocumentLimitApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        """get document limit"""
+        documents_count = DocumentService.get_tenant_documents_count()
+        tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+
+        return {
+            'documents_count': documents_count,
+            'documents_limit': tenant_document_count
+        }, 200
+
+
 api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
 api.add_resource(DatasetDocumentListApi,
                  '/datasets/<uuid:dataset_id>/documents')

@@ -917,3 +985,4 @@ api.add_resource(DocumentStatusApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
 api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
 api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
+api.add_resource(DocumentLimitApi, '/datasets/limit')
@@ -1,15 +1,20 @@
 # -*- coding:utf-8 -*-
+import uuid
 from datetime import datetime

-from flask_login import login_required, current_user
+from flask import request
+from flask_login import current_user
 from flask_restful import Resource, reqparse, fields, marshal
 from werkzeug.exceptions import NotFound, Forbidden

 import services
 from controllers.console import api
-from controllers.console.datasets.error import InvalidActionError
+from controllers.console.app.error import ProviderNotInitializeError
+from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.model_providers.model_factory import ModelFactory
+from core.login.login import login_required
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment

@@ -17,7 +22,9 @@ from models.dataset import DocumentSegment
 from libs.helper import TimestampField
 from services.dataset_service import DatasetService, DocumentService, SegmentService
 from tasks.enable_segment_to_index_task import enable_segment_to_index_task
-from tasks.remove_segment_from_index_task import remove_segment_from_index_task
+from tasks.disable_segment_from_index_task import disable_segment_from_index_task
+from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
+import pandas as pd

 segment_fields = {
     'id': fields.String,
@@ -152,6 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))

+        # check embedding model setting
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+
         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),
             DocumentSegment.tenant_id == current_user.current_tenant_id
@@ -197,7 +218,7 @@ class DatasetDocumentSegmentApi(Resource):
             # Set cache to prevent indexing the same segment multiple times
             redis_client.setex(indexing_cache_key, 600, 1)

-            remove_segment_from_index_task.delay(segment.id)
+            disable_segment_from_index_task.delay(segment.id)

             return {'result': 'success'}, 200
         else:
@@ -222,6 +243,19 @@ class DatasetDocumentSegmentAddApi(Resource):
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
+        # check embedding model setting
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         try:
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:

@@ -233,7 +267,7 @@ class DatasetDocumentSegmentAddApi(Resource):
         parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
-        segment = SegmentService.create_segment(args, document)
+        segment = SegmentService.create_segment(args, document, dataset)
         return {
             'data': marshal(segment, segment_fields),
             'doc_form': document.doc_form
@@ -245,6 +279,61 @@ class DatasetDocumentSegmentUpdateApi(Resource):
     @login_required
     @account_initialization_required
     def patch(self, dataset_id, document_id, segment_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        # check embedding model setting
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id),
+            DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound('Segment not found.')
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
+        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
+        args = parser.parse_args()
+        SegmentService.segment_create_args_validate(args, document)
+        segment = SegmentService.update_segment(args, segment, document, dataset)
+        return {
+            'data': marshal(segment, segment_fields),
+            'doc_form': document.doc_form
+        }, 200

     @setup_required
     @login_required
     @account_initialization_required
     def delete(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -270,17 +359,88 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-        # validate args
-        parser = reqparse.RequestParser()
-        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
-        args = parser.parse_args()
-        SegmentService.segment_create_args_validate(args, document)
-        segment = SegmentService.update_segment(args, segment, document)
-        return {
-            'data': marshal(segment, segment_fields),
-            'doc_form': document.doc_form
-        }, 200
+        SegmentService.delete_segment(segment, document, dataset)
+        return {'result': 'success'}, 200
+
+
+class DatasetDocumentSegmentBatchImportApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, dataset_id, document_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        # get file from request
+        file = request.files['file']
+        # check file
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+        # check file type
+        if not file.filename.endswith('.csv'):
+            raise ValueError("Invalid file type. Only CSV files are allowed")
+
+        try:
+            # Skip the first row
+            df = pd.read_csv(file)
+            result = []
+            for index, row in df.iterrows():
+                if document.doc_form == 'qa_model':
+                    data = {'content': row[0], 'answer': row[1]}
+                else:
+                    data = {'content': row[0]}
+                result.append(data)
+            if len(result) == 0:
+                raise ValueError("The CSV file is empty.")
+            # async job
+            job_id = str(uuid.uuid4())
+            indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
+            # send batch add segments task
+            redis_client.setnx(indexing_cache_key, 'waiting')
+            batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
+                                                     current_user.current_tenant_id, current_user.id)
+        except Exception as e:
+            return {'error': str(e)}, 500
+        return {
+            'job_id': job_id,
+            'job_status': 'waiting'
+        }, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, job_id):
+        job_id = str(job_id)
+        indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
+        cache_result = redis_client.get(indexing_cache_key)
+        if cache_result is None:
+            raise ValueError("The job is not exist.")
+
+        return {
+            'job_id': job_id,
+            'job_status': cache_result.decode()
+        }, 200

@@ -292,3 +452,6 @@ api.add_resource(DatasetDocumentSegmentAddApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
 api.add_resource(DatasetDocumentSegmentUpdateApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
+api.add_resource(DatasetDocumentSegmentBatchImportApi,
+                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
+                 '/datasets/batch_import_status/<uuid:job_id>')
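The new batch-import endpoint reads the uploaded CSV positionally: `pd.read_csv` consumes the first line as a header (hence the "Skip the first row" comment), then position 0 is the segment content and, for `qa_model` documents, position 1 is the answer. A minimal sketch of producing a compatible file (the rows are hypothetical):

```
import pandas as pd

# A header line is required -- read_csv treats line 1 as the header,
# so real data must start on line 2.  Only column positions matter:
# 0 = content/question, 1 = answer (qa_model documents only).
rows = [
    ["What is Dify?", "An open-source LLM app development platform."],
    ["How is the import tracked?", "By a job_id status key in Redis."],
]
pd.DataFrame(rows, columns=["content", "answer"]).to_csv(
    "segments.csv", index=False)
```

The POST responds with a `job_id`, and `/datasets/batch_import_status/<job_id>` echoes back the status string that the Celery task maintains in Redis.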
@@ -8,7 +8,8 @@ from pathlib import Path

 from cachetools import TTLCache
 from flask import request, current_app
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, marshal_with, fields
 from werkzeug.exceptions import NotFound
@@ -1,6 +1,7 @@
 import logging

-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse, marshal, fields
 from werkzeug.exceptions import InternalServerError, NotFound, Forbidden

@@ -11,7 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+    LLMBadRequestError
 from libs.helper import TimestampField
 from services.dataset_service import DatasetService
 from services.hit_testing_service import HitTestingService

@@ -102,6 +104,10 @@ class HitTestingApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
         except ValueError as e:
             raise ValueError(str(e))
         except Exception as e:
@@ -1,7 +1,8 @@
 # -*- coding:utf-8 -*-
 from datetime import datetime

-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse, fields, marshal_with, inputs
 from sqlalchemy import and_
 from werkzeug.exceptions import NotFound, Forbidden, BadRequest
@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, fields, marshal_with
 from sqlalchemy import and_
@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource
 from functools import wraps
@@ -1,7 +1,8 @@
 import json
 from functools import wraps

-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
@@ -38,12 +38,20 @@ class StripeWebhookApi(Resource):
             logging.debug(event['data']['object']['payment_status'])
             logging.debug(event['data']['object']['metadata'])

+            session = stripe.checkout.Session.retrieve(
+                event['data']['object']['id'],
+                expand=['line_items'],
+            )
+
+            logging.debug(session.line_items['data'][0]['quantity'])
+
             # Fulfill the purchase...
             provider_checkout_service = ProviderCheckoutService()

             try:
-                provider_checkout_service.fulfill_provider_order(event)
+                provider_checkout_service.fulfill_provider_order(event, session.line_items)
             except Exception as e:
                 logging.debug(str(e))
                 return 'success', 200
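The `expand=['line_items']` argument asks Stripe to inline the checkout session's line items in the retrieve response, so fulfillment can read the purchased quantity without a second list call. A minimal sketch with test-mode placeholders:

```
import stripe

stripe.api_key = "sk_test_..."  # placeholder key

session = stripe.checkout.Session.retrieve(
    "cs_test_...",           # session id taken from the webhook event
    expand=["line_items"],   # inline the line items in this response
)
quantity = session.line_items["data"][0]["quantity"]
print(quantity)
```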
@@ -3,7 +3,8 @@ from datetime import datetime

 import pytz
 from flask import current_app, request
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse, fields, marshal_with

 from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -1,6 +1,7 @@
 # -*- coding:utf-8 -*-
 from flask import current_app
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal

 import services
@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden
@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse

 from controllers.console import api
@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden
@@ -1,6 +1,7 @@
 import json

-from flask_login import login_required, current_user
+from flask_login import current_user
+from core.login.login import login_required
 from flask_restful import Resource, abort, reqparse
 from werkzeug.exceptions import Forbidden
@@ -2,10 +2,13 @@
 import logging

 from flask import request
-from flask_login import login_required, current_user
-from flask_restful import Resource, fields, marshal_with, reqparse, marshal
+from flask_login import current_user
+from core.login.login import login_required
+from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs
 from flask_restful.inputs import int_range

 from controllers.console import api
 from controllers.console.admin import admin_required
 from controllers.console.setup import setup_required
 from controllers.console.error import AccountNotLinkTenantError
 from controllers.console.wraps import account_initialization_required
@@ -43,6 +46,13 @@ tenants_fields = {
     'current': fields.Boolean
 }

+workspace_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'status': fields.String,
+    'created_at': TimestampField
+}
+

 class TenantListApi(Resource):
     @setup_required
@@ -57,6 +67,38 @@ class TenantListApi(Resource):
         return {'workspaces': marshal(tenants, tenants_fields)}, 200


+class WorkspaceListApi(Resource):
+    @setup_required
+    @admin_required
+    def get(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
+        parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
+        args = parser.parse_args()
+
+        tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\
+            .paginate(page=args['page'], per_page=args['limit'])
+
+        has_more = False
+        if len(tenants.items) == args['limit']:
+            current_page_first_tenant = tenants[-1]
+            rest_count = db.session.query(Tenant).filter(
+                Tenant.created_at < current_page_first_tenant.created_at,
+                Tenant.id != current_page_first_tenant.id
+            ).count()
+
+            if rest_count > 0:
+                has_more = True
+        total = db.session.query(Tenant).count()
+        return {
+            'data': marshal(tenants.items, workspace_fields),
+            'has_more': has_more,
+            'limit': args['limit'],
+            'page': args['page'],
+            'total': total
+        }, 200
+
+
 class TenantApi(Resource):
     @setup_required
     @login_required

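The `has_more` computation in `WorkspaceListApi.get` above is a keyset-style check: a page can only have more results when it came back full, and at least one tenant was created strictly before the page's last row. A minimal standalone sketch of the same logic (the data and helper names here are illustrative, not from the repository):

    from dataclasses import dataclass
    from datetime import datetime, timedelta

    @dataclass
    class Tenant:
        id: int
        created_at: datetime

    def page_has_more(page_rows, all_rows, limit):
        # Mirrors the hunk: a short page means nothing is left; otherwise
        # check for tenants created strictly before the page's last row.
        if len(page_rows) < limit:
            return False
        last = page_rows[-1]
        rest = [t for t in all_rows
                if t.created_at < last.created_at and t.id != last.id]
        return len(rest) > 0

    now = datetime.utcnow()
    tenants = [Tenant(i, now - timedelta(days=i)) for i in range(5)]  # newest first
    print(page_has_more(tenants[:2], tenants, limit=2))  # True: three older tenants remain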
@@ -92,6 +134,7 @@ class SwitchWorkspaceApi(Resource):


 api.add_resource(TenantListApi, '/workspaces')  # GET for getting all tenants
+api.add_resource(WorkspaceListApi, '/all-workspaces')  # GET for getting all tenants
 api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current')  # GET for getting current tenant info
 api.add_resource(TenantApi, '/info', endpoint='info')  # Deprecated
 api.add_resource(SwitchWorkspaceApi, '/workspaces/switch')  # POST for switching tenant

@@ -17,7 +17,7 @@ def validate_app_token(view=None):
     def decorated(*args, **kwargs):
         api_token = validate_and_get_api_token('app')

-        app_model = db.session.query(App).get(api_token.app_id)
+        app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
         if not app_model:
             raise NotFound()

@@ -44,7 +44,7 @@ def validate_dataset_token(view=None):
     def decorated(*args, **kwargs):
         api_token = validate_and_get_api_token('dataset')

-        dataset = db.session.query(Dataset).get(api_token.dataset_id)
+        dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first()
         if not dataset:
             raise NotFound()

@@ -64,14 +64,14 @@ def validate_and_get_api_token(scope=None):
     Validate and get API token.
     """
     auth_header = request.headers.get('Authorization')
-    if auth_header is None:
-        raise Unauthorized()
+    if auth_header is None or ' ' not in auth_header:
+        raise Unauthorized("Authorization header must be provided and start with 'Bearer'")

     auth_scheme, auth_token = auth_header.split(None, 1)
     auth_scheme = auth_scheme.lower()

     if auth_scheme != 'bearer':
-        raise Unauthorized()
+        raise Unauthorized("Authorization scheme must be 'Bearer'")

     api_token = db.session.query(ApiToken).filter(
         ApiToken.token == auth_token,
@@ -79,7 +79,7 @@ def validate_and_get_api_token(scope=None):
     ).first()

     if not api_token:
-        raise Unauthorized()
+        raise Unauthorized("Access token is invalid")

     api_token.last_used_at = datetime.utcnow()
     db.session.commit()

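The rewritten `validate_and_get_api_token` above rejects malformed headers with explicit messages instead of bare 401s. A minimal sketch of just the header-parsing step, assuming only werkzeug (the sample token is made up):

    from typing import Optional
    from werkzeug.exceptions import Unauthorized

    def parse_bearer_token(auth_header: Optional[str]) -> str:
        # Same checks as the hunk: header present, has a scheme, scheme is Bearer
        if auth_header is None or ' ' not in auth_header:
            raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
        auth_scheme, auth_token = auth_header.split(None, 1)
        if auth_scheme.lower() != 'bearer':
            raise Unauthorized("Authorization scheme must be 'Bearer'")
        return auth_token

    print(parse_bearer_token('Bearer app-1234'))  # -> 'app-1234'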
@@ -60,7 +60,13 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             return AgentFinish(return_values={"output": observation}, log=observation)

         try:
-            return super().plan(intermediate_steps, callbacks, **kwargs)
+            agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
+            if isinstance(agent_decision, AgentAction):
+                tool_inputs = agent_decision.tool_input
+                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
+                    tool_inputs['query'] = kwargs['input']
+                    agent_decision.tool_input = tool_inputs
+            return agent_decision
         except Exception as e:
             new_exception = self.model_instance.handle_exceptions(e)
             raise new_exception

@@ -97,6 +97,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
             messages, functions=self.functions, callbacks=callbacks
         )
         agent_decision = _parse_ai_message(predicted_message)
+
+        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
+            tool_inputs = agent_decision.tool_input
+            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
+                tool_inputs['query'] = kwargs['input']
+                agent_decision.tool_input = tool_inputs
+
         return agent_decision

     @classmethod

@@ -14,7 +14,7 @@ from core.model_providers.models.llm.base import BaseLLM
 class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel
+    summary_llm: BaseLanguageModel = None
     model_instance: BaseLLM

     class Config:

@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
 class StructuredChatOutputParser(LCStructuredChatOutputParser):
     def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
         try:
-            action_match = re.search(r"```(.*?)\n?(.*?)```", text, re.DOTALL)
+            action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
             if action_match is not None:
                 response = json.loads(action_match.group(2).strip(), strict=False)
                 if isinstance(response, list):
@@ -26,4 +26,4 @@ class StructuredChatOutputParser(LCStructuredChatOutputParser):
             else:
                 return AgentFinish({"output": text}, text)
         except Exception as e:
-            raise OutputParserException(f"Could not parse LLM output: {text}") from e
+            raise OutputParserException(f"Could not parse LLM output: {text}")

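The tightened pattern above only fires on fenced blocks whose payload opens with `{`, so stray fenced text can no longer reach `json.loads`. A quick check of both patterns on a made-up model reply:

    import json
    import re

    text = 'Thought: pick a tool\n```json\n{"action": "dataset", "action_input": {"query": "hi"}}\n```'

    old = re.search(r"```(.*?)\n?(.*?)```", text, re.DOTALL)
    new = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)

    # Both match here, but only the new pattern guarantees group(2) starts with '{'
    print(json.loads(new.group(2).strip(), strict=False)['action'])  # -> dataset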
@@ -102,7 +102,13 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             raise new_exception

         try:
-            return self.output_parser.parse(full_output)
+            agent_decision = self.output_parser.parse(full_output)
+            if isinstance(agent_decision, AgentAction):
+                tool_inputs = agent_decision.tool_input
+                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
+                    tool_inputs['query'] = kwargs['input']
+                    agent_decision.tool_input = tool_inputs
+            return agent_decision
         except OutputParserException:
             return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                           "I don't know how to respond to that."}, "")

@@ -52,7 +52,7 @@ Action:
 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel
+    summary_llm: BaseLanguageModel = None
     model_instance: BaseLLM

     class Config:

@@ -106,7 +106,13 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             raise new_exception

         try:
-            return self.output_parser.parse(full_output)
+            agent_decision = self.output_parser.parse(full_output)
+            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
+                tool_inputs = agent_decision.tool_input
+                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
+                    tool_inputs['query'] = kwargs['input']
+                    agent_decision.tool_input = tool_inputs
+            return agent_decision
         except OutputParserException:
             return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                           "I don't know how to respond to that."}, "")

@@ -32,7 +32,7 @@ class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
     model_instance: BaseLLM
     tools: list[BaseTool]
-    summary_model_instance: BaseLLM
+    summary_model_instance: BaseLLM = None
     memory: Optional[BaseChatMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6

@@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration

 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM


@@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._current_loop.status = 'llm_end'
         if response.llm_output:
             self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+        else:
+            self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
+                [PromptMessage(content=self._current_loop.prompt)]
+            )
         completion_generation = response.generations[0][0]
         if isinstance(completion_generation, ChatGeneration):
             completion_message = completion_generation.message
@@ -81,6 +86,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):

         if response.llm_output:
             self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+        else:
+            self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
+                [PromptMessage(content=self._current_loop.completion)]
+            )

     def on_llm_error(
             self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

@@ -126,17 +126,16 @@ class Completion:
         # the output of the agent can be used directly as the main output content without calling LLM again
         fake_response = None
         if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
-                and agent_execute_result.strategy != PlanningStrategy.ROUTER:
+                and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
             fake_response = agent_execute_result.output

         # get llm prompt
-        prompt_messages, stop_words = cls.get_main_llm_prompt(
+        prompt_messages, stop_words = model_instance.get_prompt(
             mode=mode,
-            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
-            query=query,
             inputs=inputs,
-            agent_execute_result=agent_execute_result,
+            query=query,
+            context=agent_execute_result.output if agent_execute_result else None,
             memory=memory
         )

@@ -154,113 +153,6 @@ class Completion:

         return response

-    @classmethod
-    def get_main_llm_prompt(cls, mode: str, model: dict,
-                            pre_prompt: str, query: str, inputs: dict,
-                            agent_execute_result: Optional[AgentExecuteResult],
-                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
-            Tuple[List[PromptMessage], Optional[List[str]]]:
-        if mode == 'completion':
-            prompt_template = JinjaPromptTemplate.from_template(
-                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
-
-<context>
-{{context}}
-</context>
-
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-""" if agent_execute_result else "")
-                + (pre_prompt + "\n" if pre_prompt else "")
-                + "{{query}}\n"
-            )
-
-            if agent_execute_result:
-                inputs['context'] = agent_execute_result.output
-
-            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
-            prompt_content = prompt_template.format(
-                query=query,
-                **prompt_inputs
-            )
-
-            return [PromptMessage(content=prompt_content)], None
-        else:
-            messages: List[BaseMessage] = []
-
-            human_inputs = {
-                "query": query
-            }
-
-            human_message_prompt = ""
-
-            if pre_prompt:
-                pre_prompt_inputs = {k: inputs[k] for k in
-                                     JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
-                                     if k in inputs}
-
-                if pre_prompt_inputs:
-                    human_inputs.update(pre_prompt_inputs)
-
-            if agent_execute_result:
-                human_inputs['context'] = agent_execute_result.output
-                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
-
-<context>
-{{context}}
-</context>
-
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-"""
-
-            if pre_prompt:
-                human_message_prompt += pre_prompt
-
-            query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
-
-            if memory:
-                # append chat histories
-                tmp_human_message = PromptBuilder.to_human_message(
-                    prompt_content=human_message_prompt + query_prompt,
-                    inputs=human_inputs
-                )
-
-                if memory.model_instance.model_rules.max_tokens.max:
-                    curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
-                    max_tokens = model.get("completion_params").get('max_tokens')
-                    rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
-                    rest_tokens = max(rest_tokens, 0)
-                else:
-                    rest_tokens = 2000
-
-                histories = cls.get_history_messages_from_memory(memory, rest_tokens)
-                human_message_prompt += "\n\n" if human_message_prompt else ""
-                human_message_prompt += "Here is the chat histories between human and assistant, " \
-                                        "inside <histories></histories> XML tags.\n\n<histories>\n"
-                human_message_prompt += histories + "\n</histories>"
-
-            human_message_prompt += query_prompt
-
-            # construct main prompt
-            human_message = PromptBuilder.to_human_message(
-                prompt_content=human_message_prompt,
-                inputs=human_inputs
-            )
-
-            messages.append(human_message)
-
-            for message in messages:
-                message.content = re.sub(r'<\|.*?\|>', '', message.content)
-
-            return to_prompt_messages(messages), ['\nHuman:', '</histories>']
-
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> str:

@@ -307,13 +199,12 @@ And answer according to the language of the user's question.
             max_tokens = 0

         # get prompt without memory and context
-        prompt_messages, _ = cls.get_main_llm_prompt(
+        prompt_messages, _ = model_instance.get_prompt(
             mode=mode,
-            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
-            query=query,
             inputs=inputs,
-            agent_execute_result=None,
+            query=query,
+            context=None,
             memory=None
         )

@@ -358,13 +249,12 @@ And answer according to the language of the user's question.
         )

         # get llm prompt
-        old_prompt_messages, _ = cls.get_main_llm_prompt(
-            mode="completion",
-            model=app_model_config.model_dict,
+        old_prompt_messages, _ = final_model_instance.get_prompt(
+            mode='completion',
             pre_prompt=pre_prompt,
-            query=message.query,
             inputs=message.inputs,
-            agent_execute_result=None,
+            query=message.query,
+            context=None,
             memory=None
         )

@@ -119,9 +119,11 @@ class ConversationMessageTask:
             message="",
             message_tokens=0,
             message_unit_price=0,
+            message_price_unit=0,
             answer="",
             answer_tokens=0,
             answer_unit_price=0,
+            answer_price_unit=0,
             provider_response_latency=0,
             total_price=0,
             currency=self.model_instance.get_currency(),

@@ -135,22 +137,30 @@ class ConversationMessageTask:
         db.session.flush()

     def append_message_text(self, text: str):
-        self._pub_handler.pub_text(text)
+        if text is not None:
+            self._pub_handler.pub_text(text)

     def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
         message_tokens = llm_message.prompt_tokens
         answer_tokens = llm_message.completion_tokens
-        message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
-        answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
-
-        total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
+        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
+        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
+
+        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
+        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
+        total_price = message_total_price + answer_total_price

         self.message.message = llm_message.prompt
         self.message.message_tokens = message_tokens
         self.message.message_unit_price = message_unit_price
+        self.message.message_price_unit = message_price_unit
         self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
         self.message.answer_tokens = answer_tokens
         self.message.answer_unit_price = answer_unit_price
+        self.message.answer_price_unit = answer_price_unit
         self.message.provider_response_latency = llm_message.latency
         self.message.total_price = total_price

@@ -192,7 +202,9 @@ class ConversationMessageTask:
             tool=agent_loop.tool_name,
             tool_input=agent_loop.tool_input,
             message=agent_loop.prompt,
+            message_price_unit=0,
             answer=agent_loop.completion,
+            answer_price_unit=0,
             created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
             created_by=self.user.id
         )

@@ -206,25 +218,26 @@ class ConversationMessageTask:

     def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
+        agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
+        agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)

         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens

-        loop_total_price = self.calc_total_price(
-            loop_message_tokens,
-            agent_message_unit_price,
-            loop_answer_tokens,
-            agent_answer_unit_price
-        )
+        loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_total_price = loop_message_total_price + loop_answer_total_price

         message_agent_thought.observation = agent_loop.tool_output
         message_agent_thought.tool_process_data = ''  # currently not support
         message_agent_thought.message_token = loop_message_tokens
         message_agent_thought.message_unit_price = agent_message_unit_price
+        message_agent_thought.message_price_unit = agent_message_price_unit
         message_agent_thought.answer_token = loop_answer_tokens
         message_agent_thought.answer_unit_price = agent_answer_unit_price
+        message_agent_thought.answer_price_unit = agent_answer_price_unit
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.total_price = loop_total_price

@@ -243,15 +256,6 @@ class ConversationMessageTask:

         db.session.add(dataset_query)

-    def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
-        message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                  rounding=decimal.ROUND_HALF_UP)
-        answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                rounding=decimal.ROUND_HALF_UP)
-
-        total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
     def end(self):
         self._pub_handler.pub_end()

@@ -38,7 +38,7 @@ class ExcelLoader(BaseLoader):
             else:
                 row_dict = dict(zip(keys, list(map(str, row))))
                 row_dict = {k: v for k, v in row_dict.items() if v}
-                item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
+                item = ''.join(f'{k}:{v};' for k, v in row_dict.items())
                 document = Document(page_content=item, metadata={'source': self._file_path})
                 data.append(document)

@@ -10,10 +10,10 @@ from models.dataset import Dataset, DocumentSegment

 class DatesetDocumentStore:
     def __init__(
-        self,
-        dataset: Dataset,
-        user_id: str,
-        document_id: Optional[str] = None,
+            self,
+            dataset: Dataset,
+            user_id: str,
+            document_id: Optional[str] = None,
     ):
         self._dataset = dataset
         self._user_id = user_id
@@ -59,7 +59,7 @@ class DatesetDocumentStore:
         return output

     def add_documents(
-        self, docs: Sequence[Document], allow_update: bool = True
+            self, docs: Sequence[Document], allow_update: bool = True
     ) -> None:
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
             DocumentSegment.document_id == self._document_id
@@ -69,7 +69,9 @@ class DatesetDocumentStore:
             max_position = 0

         embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=self._dataset.tenant_id
+            tenant_id=self._dataset.tenant_id,
+            model_provider_name=self._dataset.embedding_model_provider,
+            model_name=self._dataset.embedding_model
         )

         for doc in docs:
@@ -101,6 +103,7 @@ class DatesetDocumentStore:
                 content=doc.page_content,
                 word_count=len(doc.page_content),
                 tokens=tokens,
+                enabled=False,
                 created_by=self._user_id,
             )
             if 'answer' in doc.metadata and doc.metadata['answer']:
@@ -123,7 +126,7 @@ class DatesetDocumentStore:
         return result is not None

     def get_document(
-        self, doc_id: str, raise_error: bool = True
+            self, doc_id: str, raise_error: bool = True
     ) -> Optional[Document]:
         document_segment = self.get_document_segment(doc_id)

@@ -1,6 +1,7 @@
 import logging
 from typing import List

+import numpy as np
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError

@@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings):
             embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
         except Exception as ex:
             raise self._embeddings.handle_exceptions(ex)

         i = 0
+        normalized_embedding_results = []
         for text in embedding_queue_texts:
             hash = helper.generate_text_hash(text)

             try:
                 embedding = Embedding(model_name=self._embeddings.name, hash=hash)
-                embedding.set_embedding(embedding_results[i])
+                vector = embedding_results[i]
+                normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
+                normalized_embedding_results.append(normalized_embedding)
+                embedding.set_embedding(normalized_embedding)
                 db.session.add(embedding)
                 db.session.commit()
             except IntegrityError:
@@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings):
             finally:
                 i += 1

-        text_embeddings.extend(embedding_results)
+        text_embeddings.extend(normalized_embedding_results)
         return text_embeddings

     def embed_query(self, text: str) -> List[float]:
@@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings):

         try:
             embedding_results = self._embeddings.client.embed_query(text)
+            embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
         except Exception as ex:
             raise self._embeddings.handle_exceptions(ex)

@@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings):

         return embedding_results
-

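Because the hunks above cache unit-length vectors, cosine similarity between stored embeddings reduces to a plain dot product. A small numpy sketch of the same normalization (the vector is made up):

    import numpy as np

    vector = np.array([3.0, 4.0])
    normalized = (vector / np.linalg.norm(vector)).tolist()  # the operation added above
    print(normalized)  # [0.6, 0.8] -- unit length

    a = np.array(normalized)
    print(float(np.dot(a, a)))  # 1.0: dot product of unit vectors == cosine similarity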
@@ -51,6 +51,7 @@ class LLMGenerator:
         prompt_with_empty_context = prompt.format(context='')
         prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
         max_context_token_length = model_instance.model_rules.max_tokens.max
+        max_context_token_length = max_context_token_length if max_context_token_length else 1500
         rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1

         context = ''
@@ -178,8 +179,8 @@ class LLMGenerator:
         return rule_config

     @classmethod
-    def generate_qa_document(cls, tenant_id: str, query):
-        prompt = GENERATOR_QA_PROMPT
+    def generate_qa_document(cls, tenant_id: str, query, document_language: str):
+        prompt = GENERATOR_QA_PROMPT.format(language=document_language)

         model_instance = ModelFactory.get_text_generation_model(
             tenant_id=tenant_id,

@@ -15,7 +15,9 @@ class IndexBuilder:
             return None

         embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id
+            tenant_id=dataset.tenant_id,
+            model_provider_name=dataset.embedding_model_provider,
+            model_name=dataset.embedding_model
         )

         embeddings = CacheEmbedding(embedding_model)

@@ -173,3 +173,49 @@ class BaseVectorIndex(BaseIndex):

         self.dataset = dataset
         logging.info(f"Dataset {dataset.id} recreate successfully.")
+
+    def create_qdrant_dataset(self, dataset: Dataset):
+        logging.info(f"create_qdrant_dataset {dataset.id}")
+
+        try:
+            self.delete()
+        except UnexpectedStatusCodeException as e:
+            if e.status_code != 400:
+                # 400 means index not exists
+                raise e
+
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset.id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        documents = []
+        for dataset_document in dataset_documents:
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.status == 'completed',
+                DocumentSegment.enabled == True
+            ).all()
+
+            for segment in segments:
+                document = Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": segment.index_node_id,
+                        "doc_hash": segment.index_node_hash,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                )
+
+                documents.append(document)
+
+        if documents:
+            try:
+                self.create(documents)
+            except Exception as e:
+                raise e
+
+        logging.info(f"Dataset {dataset.id} recreate successfully.")

api/core/index/vector_index/milvus_vector_index.py (new file, 114 lines)
@@ -0,0 +1,114 @@
+from typing import Optional, cast
+
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document, BaseRetriever
+from langchain.vectorstores import VectorStore, milvus
+from pydantic import BaseModel, root_validator
+
+from core.index.base import BaseIndex
+from core.index.vector_index.base import BaseVectorIndex
+from core.vector_store.milvus_vector_store import MilvusVectorStore
+from core.vector_store.weaviate_vector_store import WeaviateVectorStore
+from models.dataset import Dataset
+
+
+class MilvusConfig(BaseModel):
+    endpoint: str
+    user: str
+    password: str
+    batch_size: int = 100
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['endpoint']:
+            raise ValueError("config MILVUS_ENDPOINT is required")
+        if not values['user']:
+            raise ValueError("config MILVUS_USER is required")
+        if not values['password']:
+            raise ValueError("config MILVUS_PASSWORD is required")
+        return values
+
+
+class MilvusVectorIndex(BaseVectorIndex):
+    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
+        super().__init__(dataset, embeddings)
+        self._client = self._init_client(config)
+
+    def get_type(self) -> str:
+        return 'milvus'
+
+    def get_index_name(self, dataset: Dataset) -> str:
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                class_prefix += '_Node'
+
+            return class_prefix
+
+        dataset_id = dataset.id
+        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
+        }
+
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = WeaviateVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            uuids=uuids,
+            by_text=False
+        )
+
+        return self
+
+    def _get_vector_store(self) -> VectorStore:
+        """Only for created index."""
+        if self._vector_store:
+            return self._vector_store
+
+        attributes = ['doc_id', 'dataset_id', 'document_id']
+        if self._is_origin():
+            attributes = ['doc_id']
+
+        return WeaviateVectorStore(
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            text_key='text',
+            embedding=self._embeddings,
+            attributes=attributes,
+            by_text=False
+        )
+
+    def _get_vector_store_class(self) -> type:
+        return MilvusVectorStore
+
+    def delete_by_document_id(self, document_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.del_texts({
+            "operator": "Equal",
+            "path": ["document_id"],
+            "valueText": document_id
+        })
+
+    def _is_origin(self):
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                return True
+
+        return False

api/core/index/vector_index/qdrant.py (new file, 1691 lines; diff suppressed because it is too large)
@@ -44,15 +44,20 @@ class QdrantVectorIndex(BaseVectorIndex):

     def get_index_name(self, dataset: Dataset) -> str:
         if self.dataset.index_struct_dict:
-            return self.dataset.index_struct_dict['vector_store']['collection_name']
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                class_prefix += '_Node'
+
+            return class_prefix

         dataset_id = dataset.id
-        return "Index_" + dataset_id.replace("-", "_")
+        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

     def to_index_struct(self) -> dict:
         return {
             "type": self.get_type(),
-            "vector_store": {"collection_name": self.get_index_name(self.dataset)}
+            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
         }

     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
@@ -62,7 +67,7 @@ class QdrantVectorIndex(BaseVectorIndex):
             self._embeddings,
             collection_name=self.get_index_name(self.dataset),
             ids=uuids,
-            content_payload_key='text',
+            content_payload_key='page_content',
             **self._client_config.to_qdrant_params()
         )

@@ -72,7 +77,9 @@ class QdrantVectorIndex(BaseVectorIndex):
         """Only for created index."""
         if self._vector_store:
             return self._vector_store
-
+        attributes = ['doc_id', 'dataset_id', 'document_id']
+        if self._is_origin():
+            attributes = ['doc_id']
         client = qdrant_client.QdrantClient(
             **self._client_config.to_qdrant_params()
         )
@@ -81,7 +88,7 @@ class QdrantVectorIndex(BaseVectorIndex):
             client=client,
             collection_name=self.get_index_name(self.dataset),
             embeddings=self._embeddings,
-            content_payload_key='text'
+            content_payload_key='page_content'
         )

     def _get_vector_store_class(self) -> type:
@@ -108,8 +115,8 @@ class QdrantVectorIndex(BaseVectorIndex):

     def _is_origin(self):
         if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
-            if class_prefix.startswith('Vector_'):
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
                 # original class_prefix
                 return True

@@ -67,14 +67,6 @@ class IndexingRunner:
             dataset_document=dataset_document,
             processing_rule=processing_rule
         )
-        # new_documents = []
-        # for document in documents:
-        #     response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
-        #     document_qa_list = self.format_split_text(response)
-        #     for result in document_qa_list:
-        #         document = Document(page_content=result['question'], metadata={'source': result['answer']})
-        #         new_documents.append(document)
         # build index
         self._build_index(
             dataset=dataset,
             dataset_document=dataset_document,

@@ -225,14 +217,25 @@ class IndexingRunner:
         db.session.commit()

     def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
-                               doc_form: str = None) -> dict:
+                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
         """
         Estimate the indexing for the document.
         """
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=tenant_id
-        )
-
+        if dataset_id:
+            dataset = Dataset.query.filter_by(
+                id=dataset_id
+            ).first()
+            if not dataset:
+                raise ValueError('Dataset not found.')
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+        else:
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=tenant_id
+            )
         tokens = 0
         preview_texts = []
         total_segments = 0

@@ -263,20 +266,19 @@ class IndexingRunner:

                 tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))

-        text_generation_model = ModelFactory.get_text_generation_model(
-            tenant_id=tenant_id
-        )
-
         if doc_form and doc_form == 'qa_model':
+            text_generation_model = ModelFactory.get_text_generation_model(
+                tenant_id=tenant_id
+            )
             if len(preview_texts) > 0:
                 # qa model document
-                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
+                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
                 document_qa_list = self.format_split_text(response)
             return {
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
                 "total_price": '{:f}'.format(
-                    text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                 "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts

@@ -284,18 +286,31 @@
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
             "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }

-    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
+    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
+                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
         """
         Estimate the indexing for the document.
         """
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=tenant_id
-        )
+        if dataset_id:
+            dataset = Dataset.query.filter_by(
+                id=dataset_id
+            ).first()
+            if not dataset:
+                raise ValueError('Dataset not found.')
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+        else:
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=tenant_id
+            )

         # load data from notion
         tokens = 0

@@ -344,20 +359,19 @@ class IndexingRunner:

                 tokens += embedding_model.get_num_tokens(document.page_content)

-        text_generation_model = ModelFactory.get_text_generation_model(
-            tenant_id=tenant_id
-        )
-
         if doc_form and doc_form == 'qa_model':
+            text_generation_model = ModelFactory.get_text_generation_model(
+                tenant_id=tenant_id
+            )
             if len(preview_texts) > 0:
                 # qa model document
-                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
+                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
                 document_qa_list = self.format_split_text(response)
             return {
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
                 "total_price": '{:f}'.format(
-                    text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                 "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts
@@ -365,7 +379,7 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
             "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }

@@ -458,7 +472,8 @@ class IndexingRunner:
             splitter=splitter,
             processing_rule=processing_rule,
             tenant_id=dataset.tenant_id,
-            document_form=dataset_document.doc_form
+            document_form=dataset_document.doc_form,
+            document_language=dataset_document.doc_language
         )

         # save node to document segment
@@ -494,7 +509,8 @@ class IndexingRunner:
         return documents

     def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
-                            processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
+                            processing_rule: DatasetProcessRule, tenant_id: str,
+                            document_form: str, document_language: str) -> List[Document]:
         """
         Split the text documents into nodes.
         """
@@ -509,12 +525,13 @@ class IndexingRunner:
             documents = splitter.split_documents([text_doc])
             split_documents = []
             for document_node in documents:
-                doc_id = str(uuid.uuid4())
-                hash = helper.generate_text_hash(document_node.page_content)
-                document_node.metadata['doc_id'] = doc_id
-                document_node.metadata['doc_hash'] = hash
-
-                split_documents.append(document_node)
+                if document_node.page_content.strip():
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(document_node.page_content)
+                    document_node.metadata['doc_id'] = doc_id
+                    document_node.metadata['doc_hash'] = hash
+                    split_documents.append(document_node)
             all_documents.extend(split_documents)
             # processing qa document
             if document_form == 'qa_model':
@@ -523,8 +540,9 @@ class IndexingRunner:
                     sub_documents = all_documents[i:i + 10]
                     for doc in sub_documents:
                         document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
-                            'flask_app': current_app._get_current_object(), 'tenant_id': tenant_id, 'document_node': doc,
-                            'all_qa_documents': all_qa_documents})
+                            'flask_app': current_app._get_current_object(),
+                            'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
+                            'document_language': document_language})
                         threads.append(document_format_thread)
                         document_format_thread.start()
                     for thread in threads:
@@ -532,14 +550,14 @@ class IndexingRunner:
                 return all_qa_documents
         return all_documents

-    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents):
+    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
         format_documents = []
         if document_node.page_content is None or not document_node.page_content.strip():
            return
         with flask_app.app_context():
             try:
                 # qa model document
-                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
+                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                 document_qa_list = self.format_split_text(response)
                 qa_documents = []
                 for result in document_qa_list:

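The threads above are handed `current_app._get_current_object()` rather than `current_app` itself: `current_app` is a context-local proxy that is unbound inside a worker thread, so each worker receives the real app object and re-enters an app context, as `format_qa_document` does. A minimal sketch of the pattern (the app and worker here are illustrative):

    import threading

    from flask import Flask, current_app

    app = Flask(__name__)

    def worker(flask_app):
        # Using current_app directly here would raise
        # "Working outside of application context"
        with flask_app.app_context():
            print(current_app.name)  # safe: a context is pushed for this thread

    with app.app_context():
        t = threading.Thread(target=worker, args=(current_app._get_current_object(),))
        t.start()
        t.join()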
@@ -641,7 +659,9 @@ class IndexingRunner:
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

         embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id
+            tenant_id=dataset.tenant_id,
+            model_provider_name=dataset.embedding_model_provider,
+            model_name=dataset.embedding_model
         )

         # chunk nodes by chunk size
@@ -672,6 +692,7 @@ class IndexingRunner:
             DocumentSegment.status == "indexing"
         ).update({
             DocumentSegment.status: "completed",
+            DocumentSegment.enabled: True,
             DocumentSegment.completed_at: datetime.datetime.utcnow()
         })

@@ -722,6 +743,32 @@ class IndexingRunner:
         DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
         db.session.commit()

+    def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
+        """
+        Batch add segments index processing
+        """
+        documents = []
+        for segment in segments:
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                }
+            )
+            documents.append(document)
+        # save vector index
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts(documents, duplicate_check=True)
+
+        # save keyword index
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts(documents)
+

 class DocumentIsPausedException(Exception):
     pass

api/core/login/login.py (new file, 108 lines)
@@ -0,0 +1,108 @@
+import os
+from functools import wraps
+
+import flask_login
+from flask import current_app
+from flask import g
+from flask import has_request_context
+from flask import request
+from flask_login import user_logged_in
+from flask_login.config import EXEMPT_METHODS
+from werkzeug.exceptions import Unauthorized
+from werkzeug.local import LocalProxy
+
+from extensions.ext_database import db
+from models.account import Account, Tenant, TenantAccountJoin
+
+#: A proxy for the current user. If no user is logged in, this will be an
+#: anonymous user
+current_user = LocalProxy(lambda: _get_user())
+
+
+def login_required(func):
+    """
+    If you decorate a view with this, it will ensure that the current user is
+    logged in and authenticated before calling the actual view. (If they are
+    not, it calls the :attr:`LoginManager.unauthorized` callback.) For
+    example::
+
+        @app.route('/post')
+        @login_required
+        def post():
+            pass
+
+    If there are only certain times you need to require that your user is
+    logged in, you can do so with::
+
+        if not current_user.is_authenticated:
+            return current_app.login_manager.unauthorized()
+
+    ...which is essentially the code that this function adds to your views.
+
+    It can be convenient to globally turn off authentication when unit testing.
+    To enable this, if the application configuration variable `LOGIN_DISABLED`
+    is set to `True`, this decorator will be ignored.
+
+    .. Note ::
+
+        Per `W3 guidelines for CORS preflight requests
+        <http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
+        HTTP ``OPTIONS`` requests are exempt from login checks.
+
+    :param func: The view function to decorate.
+    :type func: function
+    """
+
+    @wraps(func)
+    def decorated_view(*args, **kwargs):
+        auth_header = request.headers.get('Authorization')
+        admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False')
+        if admin_api_key_enable:
+            if auth_header:
+                if ' ' not in auth_header:
+                    raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+                auth_scheme, auth_token = auth_header.split(None, 1)
+                auth_scheme = auth_scheme.lower()
+                if auth_scheme != 'bearer':
+                    raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+                admin_api_key = os.getenv('ADMIN_API_KEY')
+
+                if admin_api_key:
+                    if os.getenv('ADMIN_API_KEY') == auth_token:
+                        workspace_id = request.headers.get('X-WORKSPACE-ID')
+                        if workspace_id:
+                            tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
+                                .filter(Tenant.id == workspace_id) \
+                                .filter(TenantAccountJoin.tenant_id == Tenant.id) \
+                                .filter(TenantAccountJoin.role == 'owner') \
+                                .one_or_none()
+                            if tenant_account_join:
+                                tenant, ta = tenant_account_join
+                                account = Account.query.filter_by(id=ta.account_id).first()
+                                # Login admin
+                                if account:
+                                    account.current_tenant = tenant
+                                    current_app.login_manager._update_request_context_with_user(account)
+                                    user_logged_in.send(current_app._get_current_object(), user=_get_user())
+        if request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
+            pass
+        elif not current_user.is_authenticated:
+            return current_app.login_manager.unauthorized()
+
+        # flask 1.x compatibility
+        # current_app.ensure_sync is only available in Flask >= 2.0
+        if callable(getattr(current_app, "ensure_sync", None)):
+            return current_app.ensure_sync(func)(*args, **kwargs)
+        return func(*args, **kwargs)
+
+    return decorated_view
+
+
+def _get_user():
+    if has_request_context():
+        if "_login_user" not in g:
+            current_app.login_manager._load_user()
+
+        return g._login_user
+
+    return None

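Based on the decorator above, an admin caller authenticates with the shared key and selects a workspace through a header. A hypothetical request sketch (the endpoint path, key, and workspace id are placeholders, and `ADMIN_API_KEY_ENABLE`/`ADMIN_API_KEY` must be set in the API's environment):

    import requests

    resp = requests.get(
        'http://localhost:5001/console/api/all-workspaces',  # illustrative URL
        headers={
            'Authorization': 'Bearer <ADMIN_API_KEY>',  # placeholder admin key
            'X-WORKSPACE-ID': '<workspace-uuid>',       # placeholder workspace id
        },
    )
    print(resp.status_code)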
@@ -46,7 +46,8 @@ class ModelFactory:
                                   model_name: Optional[str] = None,
                                   model_kwargs: Optional[ModelKwargs] = None,
                                   streaming: bool = False,
-                                  callbacks: Callbacks = None) -> Optional[BaseLLM]:
+                                  callbacks: Callbacks = None,
+                                  deduct_quota: bool = True) -> Optional[BaseLLM]:
         """
         get text generation model.

@@ -56,6 +57,7 @@ class ModelFactory:
         :param model_kwargs:
         :param streaming:
         :param callbacks:
+        :param deduct_quota:
         :return:
         """
         is_default_model = False
@@ -95,7 +97,7 @@ class ModelFactory:
             else:
                 raise e

-        if is_default_model:
+        if is_default_model or not deduct_quota:
             model_instance.deduct_quota = False

         return model_instance

@@ -57,6 +57,12 @@ class ModelProviderFactory:
         elif provider_name == 'huggingface_hub':
             from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
             return HuggingfaceHubProvider
+        elif provider_name == 'xinference':
+            from core.model_providers.providers.xinference_provider import XinferenceProvider
+            return XinferenceProvider
+        elif provider_name == 'openllm':
+            from core.model_providers.providers.openllm_provider import OpenLLMProvider
+            return OpenLLMProvider
         else:
             raise NotImplementedError

@@ -26,10 +26,20 @@ class AzureOpenAIEmbedding(BaseEmbedding):
             openai_api_version=AZURE_OPENAI_API_VERSION,
             chunk_size=16,
             max_retries=1,
-            **self.credentials
+            openai_api_key=self.credentials.get('openai_api_key'),
+            openai_api_base=self.credentials.get('openai_api_base')
         )

         super().__init__(model_provider, client, name)

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name (not deployment)
+
+        :return: str
+        """
+        return self.credentials.get("base_model_name")
+
     def get_num_tokens(self, text: str) -> int:
         """
@@ -48,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)

-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
             logging.warning("Invalid request to Azure OpenAI API.")
@@ -71,7 +71,7 @@ class AzureOpenAIEmbedding(BaseEmbedding):
         elif isinstance(ex, openai.error.RateLimitError):
             return LLMRateLimitError('Azure ' + str(ex))
         elif isinstance(ex, openai.error.AuthenticationError):
-            raise LLMAuthorizationError('Azure ' + str(ex))
+            return LLMAuthorizationError('Azure ' + str(ex))
         elif isinstance(ex, openai.error.OpenAIError):
             return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
         else:

@@ -1,5 +1,6 @@
 from abc import abstractmethod
 from typing import Any
+import decimal

 import tiktoken
 from langchain.schema.language_model import _get_token_ids_default_method
@@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import BaseModelProvider
-
+import logging
+logger = logging.getLogger(__name__)

 class BaseEmbedding(BaseProviderModel):
     name: str
@@ -17,6 +19,63 @@ class BaseEmbedding(BaseProviderModel):
         super().__init__(model_provider, client)
         self.name = name

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name
+
+        :return: str
+        """
+        return self.name
+
+    @property
+    def price_config(self) -> dict:
+        def get_or_default():
+            default_price_config = {
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
+            rules = self.model_provider.get_rules()
+            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = {
+                'completion': decimal.Decimal(price_config['completion']),
+                'unit': decimal.Decimal(price_config['unit']),
+                'currency': price_config['currency']
+            }
+            return price_config
+
+        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+
+        logger.debug(f"model: {self.name} price_config: {self._price_config}")
+        return self._price_config
+
+    def calc_tokens_price(self, tokens: int) -> decimal.Decimal:
+        """
+        calc tokens total price.
+
+        :param tokens:
+        :return: decimal.Decimal('0.0000001')
+        """
+        unit_price = self.price_config['completion']
+        unit = self.price_config['unit']
+        total_price = tokens * unit_price * unit
+        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
+        return total_price
+
+    def get_tokens_unit_price(self) -> decimal.Decimal:
+        """
+        get token price.
+
+        :return: decimal.Decimal('0.0001')
+        """
+        unit_price = self.price_config['completion']
+        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
+        logger.debug(f'unit_price:{unit_price}')
+        return unit_price
+
     def get_num_tokens(self, text: str) -> int:
         """
         get num tokens of text.
@@ -29,11 +88,14 @@ class BaseEmbedding(BaseProviderModel):

         return len(_get_token_ids_default_method(text))

-    def get_token_price(self, tokens: int):
-        return 0
-
     def get_currency(self):
-        return 'USD'
+        """
+        get token currency.
+
+        :return: get from price config, default 'USD'
+        """
+        currency = self.price_config['currency']
+        return currency

     @abstractmethod
     def handle_exceptions(self, ex: Exception) -> Exception:

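Under the `price_config` scheme above, `calc_tokens_price` is `tokens * unit_price * unit`, so a price quoted per 1K tokens is expressed with `unit = 0.001`. A worked example with made-up numbers:

    import decimal

    unit_price = decimal.Decimal('0.0001')  # hypothetical price per 1K tokens
    unit = decimal.Decimal('0.001')         # converts a raw token count to K-tokens
    tokens = 615

    total = (tokens * unit_price * unit).quantize(
        decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
    print(total)  # 0.0000615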
@@ -1,6 +1,3 @@
-import decimal
-import logging
-
 from langchain.embeddings import MiniMaxEmbeddings

 from core.model_providers.error import LLMBadRequestError
@@ -22,12 +19,6 @@ class MinimaxEmbedding(BaseEmbedding):

         super().__init__(model_provider, client, name)

-    def get_token_price(self, tokens: int):
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'RMB'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, ValueError):
             return LLMBadRequestError(f"Minimax: {str(ex)}")

@@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
@@ -65,7 +55,7 @@ class OpenAIEmbedding(BaseEmbedding):
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
-            raise LLMAuthorizationError(str(ex))
+            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
@@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):
        super().__init__(model_provider, client, name)

-    def get_token_price(self, tokens: int):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, (ModelError, ReplicateError)):
            return LLMBadRequestError(f"Replicate: {str(ex)}")
@@ -0,0 +1,27 @@
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
from replicate.exceptions import ModelError, ReplicateError

from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding


class XinferenceEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = XinferenceEmbeddings(
            server_url=credentials['server_url'],
            model_uid=credentials['model_uid'],
        )

        super().__init__(model_provider, client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, (ModelError, ReplicateError)):
            return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
        else:
            return ex
@@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM):
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'claude-instant-1': {
-                'prompt': decimal.Decimal('1.63'),
-                'completion': decimal.Decimal('5.51'),
-            },
-            'claude-2': {
-                'prompt': decimal.Decimal('11.02'),
-                'completion': decimal.Decimal('32.68'),
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
-                                                                     rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1m * unit_price
-        return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
@@ -101,7 +75,7 @@ class AnthropicModel(BaseLLM):
        else:
            return ex

-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
        return True
@@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM):
            self.model_mode = ModelMode.COMPLETION
        else:
            self.model_mode = ModelMode.CHAT

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
@@ -83,6 +82,15 @@ class AzureOpenAIModel(BaseLLM):
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    @property
    def base_model_name(self) -> str:
        """
        get base model name (not deployment)

        :return: str
        """
        return self.credentials.get("base_model_name")

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
@@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM):
        else:
            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-35-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-35-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        base_model_name = self.credentials.get("base_model_name")
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[base_model_name]['prompt']
-        else:
-            unit_price = model_unit_prices[base_model_name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        if self.name == 'text-davinci-003':
@@ -166,12 +135,12 @@ class AzureOpenAIModel(BaseLLM):
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
-            raise LLMAuthorizationError('Azure ' + str(ex))
+            return LLMAuthorizationError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

-    @classmethod
-    def support_streaming(cls):
-        return True
+    @property
+    def support_streaming(self):
+        return True
@@ -1,15 +1,25 @@
import json
import os
import re
from abc import abstractmethod
-from typing import List, Optional, Any, Union
+from typing import List, Optional, Any, Union, Tuple
import decimal

from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.third_party.langchain.llms.fake import FakeLLM
import logging

logger = logging.getLogger(__name__)


class BaseLLM(BaseProviderModel):
@@ -60,6 +70,40 @@ class BaseLLM(BaseProviderModel):
    def _init_client(self) -> Any:
        raise NotImplementedError

    @property
    def base_model_name(self) -> str:
        """
        get llm base model name

        :return: str
        """
        return self.name

    @property
    def price_config(self) -> dict:
        def get_or_default():
            default_price_config = {
                'prompt': decimal.Decimal('0'),
                'completion': decimal.Decimal('0'),
                'unit': decimal.Decimal('0'),
                'currency': 'USD'
            }
            rules = self.model_provider.get_rules()
            price_config = rules['price_config'][
                self.base_model_name] if 'price_config' in rules else default_price_config
            price_config = {
                'prompt': decimal.Decimal(price_config['prompt']),
                'completion': decimal.Decimal(price_config['completion']),
                'unit': decimal.Decimal(price_config['unit']),
                'currency': price_config['currency']
            }
            return price_config

        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()

        logger.debug(f"model: {self.name} price_config: {self._price_config}")
        return self._price_config

    def run(self, messages: List[PromptMessage],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
@@ -94,7 +138,7 @@ class BaseLLM(BaseProviderModel):
            result = self._run(
                messages=messages,
                stop=stop,
-                callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
+                callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
                **kwargs
            )
        except Exception as ex:
@@ -105,7 +149,7 @@ class BaseLLM(BaseProviderModel):
        else:
            completion_content = result.generations[0][0].text

-        if self.streaming and not self.support_streaming():
+        if self.streaming and not self.support_streaming:
            # use FakeLLM to simulate streaming when current model not support streaming but streaming is True
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
@@ -122,9 +166,12 @@ class BaseLLM(BaseProviderModel):
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            prompt_tokens = self.get_num_tokens(messages)
-            completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
+            completion_tokens = self.get_num_tokens(
+                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
            total_tokens = prompt_tokens + completion_tokens

        self.model_provider.update_last_used()

        if self.deduct_quota:
            self.model_provider.deduct_quota(total_tokens)
@@ -159,25 +206,64 @@ class BaseLLM(BaseProviderModel):
        """
        raise NotImplementedError

-    @abstractmethod
-    def get_token_price(self, tokens: int, message_type: MessageType):
+    def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
        """
-        get token price.
+        calc tokens total price.

        :param tokens:
        :param message_type:
        :return:
        """
-        raise NotImplementedError
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']
        unit = self.get_price_unit(message_type)

-    @abstractmethod
-    def get_currency(self):
        total_price = tokens * unit_price * unit
        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
        return total_price

    def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
        """
        get token price.

        :param message_type:
        :return: decimal.Decimal('0.0001')
        """
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']
        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
        logging.debug(f"unit_price={unit_price}")
        return unit_price

    def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
        """
        get price unit.

        :param message_type:
        :return: decimal.Decimal('0.000001')
        """
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            price_unit = self.price_config['unit']
        else:
            price_unit = self.price_config['unit']

        price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
        logging.debug(f"price_unit={price_unit}")
        return price_unit

    def get_currency(self) -> str:
        """
        get token currency.

-        :return:
+        :return: get from price config, default 'USD'
        """
-        raise NotImplementedError
+        currency = self.price_config['currency']
+        return currency

    def get_model_kwargs(self):
        return self.model_kwargs
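With prompt and completion priced separately, a full run costs the sum of both sides. A quick standalone sketch of that arithmetic, using an assumed `price_config` in the shape the rules JSON files below declare:

```python
import decimal

# Hypothetical price_config: prompt/completion are per-unit prices,
# unit scales raw token counts (0.001 => priced per 1K tokens).
price_config = {
    'prompt': decimal.Decimal('0.0015'),
    'completion': decimal.Decimal('0.002'),
    'unit': decimal.Decimal('0.001'),
    'currency': 'USD',
}

def price(tokens: int, kind: str) -> decimal.Decimal:
    total = tokens * price_config[kind] * price_config['unit']
    return total.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

run_cost = price(1200, 'prompt') + price(300, 'completion')
print(run_cost, price_config['currency'])  # 0.0024000 USD
```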
@@ -212,10 +298,123 @@ class BaseLLM(BaseProviderModel):
        else:
            self.client.callbacks.extend(callbacks)

-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
        return False

    def get_prompt(self, mode: str,
                   pre_prompt: str, inputs: dict,
                   query: str,
                   context: Optional[str],
                   memory: Optional[BaseChatMemory]) -> \
            Tuple[List[PromptMessage], Optional[List[str]]]:
        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
        return [PromptMessage(content=prompt)], stops

    def prompt_file_name(self, mode: str) -> str:
        if mode == 'completion':
            return 'common_completion'
        else:
            return 'common_chat'

    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                             query: str,
                             context: Optional[str],
                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
            context_prompt_content = prompt_template.format(
                context=context
            )

        pre_prompt_content = ''
        if pre_prompt:
            prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            pre_prompt_content = prompt_template.format(
                **prompt_inputs
            )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''

        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'

        if memory and 'histories_prompt' in prompt_rules:
            # append chat histories
            tmp_human_message = PromptBuilder.to_human_message(
                prompt_content=prompt + query_prompt,
                inputs={
                    'query': query
                }
            )

            if self.model_rules.max_tokens.max:
                curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
                max_tokens = self.model_kwargs.max_tokens
                rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
            else:
                rest_tokens = 2000

            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'

            histories = self._get_history_messages_from_memory(memory, rest_tokens)
            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
            histories_prompt_content = prompt_template.format(
                histories=histories
            )

            prompt = ''
            for order in prompt_rules['system_prompt_orders']:
                if order == 'context_prompt':
                    prompt += context_prompt_content
                elif order == 'pre_prompt':
                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
                elif order == 'histories_prompt':
                    prompt += histories_prompt_content

        prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
        query_prompt_content = prompt_template.format(
            query=query
        )

        prompt += query_prompt_content

        prompt = re.sub(r'<\|.*?\|>', '', prompt)

        stops = prompt_rules.get('stops')
        if stops is not None and len(stops) == 0:
            stops = None

        return prompt, stops

    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
            'prompt/generate_prompts')

        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
        # Open the JSON file and read its content
        with open(json_file_path, 'r') as json_file:
            return json.load(json_file)

    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
                                          max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
        if not model_mode:
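The shipped rule files under `prompt/generate_prompts` are not part of this diff, but the keys `_get_prompt_and_stop` reads pin down their shape. A hypothetical `common_chat`-style rules dict, limited to those keys, with illustrative values:

```python
# Hypothetical rules dict in the shape _read_prompt_rules_from_file returns;
# only keys that _get_prompt_and_stop actually reads are shown.
prompt_rules = {
    "context_prompt": "Use the following context:\n{{context}}\n",
    "histories_prompt": "Conversation so far:\n{{histories}}\n",
    "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    "query_prompt": "Human: {{query}}\nAssistant: ",
    "human_prefix": "Human",
    "assistant_prefix": "Assistant",
    "stops": ["\nHuman:"],
}
```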
@@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM):
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
    def get_currency(self):
        return 'RMB'
@@ -64,7 +61,3 @@ class ChatGLMModel(BaseLLM):
            return LLMBadRequestError(f"ChatGLM: {str(ex)}")
        else:
            return ex
-
-    @classmethod
-    def support_streaming(cls):
-        return False
@@ -1,16 +1,14 @@
-import decimal
from functools import wraps
from typing import List, Optional, Any

from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
-from langchain.llms import HuggingFaceEndpoint
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM


class HuggingfaceHubModel(BaseLLM):
@@ -19,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM):
    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
-            client = HuggingFaceEndpoint(
+            streaming = self.streaming
+
+            if 'baichuan' in self.name.lower():
+                streaming = False
+
+            client = HuggingFaceEndpointLLM(
                endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
-                task='text2text-generation',
+                task=self.credentials['task_type'],
                model_kwargs=provider_model_kwargs,
                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
                callbacks=self.callbacks,
+                streaming=streaming
            )
        else:
            client = HuggingFaceHub(
@@ -62,12 +66,14 @@ class HuggingfaceHubModel(BaseLLM):
        prompts = self._get_prompt_from_messages(messages)
        return self._client.get_num_tokens(prompts)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # not support calc price
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
+    def prompt_file_name(self, mode: str) -> str:
+        if 'baichuan' in self.name.lower():
+            if mode == 'completion':
+                return 'baichuan_completion'
+            else:
+                return 'baichuan_chat'
+        else:
+            return super().prompt_file_name(mode)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -76,7 +82,10 @@ class HuggingfaceHubModel(BaseLLM):
    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")

-    @classmethod
-    def support_streaming(cls):
-        return False
+    @property
+    def support_streaming(self):
+        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            if 'baichuan' in self.name.lower():
+                return False
+
+            return True
@@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM):
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
    def get_currency(self):
        return 'RMB'
@@ -46,7 +46,8 @@ class OpenAIModel(BaseLLM):
            self.model_mode = ModelMode.COMPLETION
        else:
            self.model_mode = ModelMode.CHAT

+        # TODO load price config from configs(db)
        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
@@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM):
        else:
            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-3.5-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-3.5-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        if self.name in COMPLETION_MODELS:
@@ -185,14 +148,14 @@ class OpenAIModel(BaseLLM):
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
-            raise LLMAuthorizationError(str(ex))
+            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
        return True

    # def is_model_valid_or_raise(self):
api/core/model_providers/models/llm/openllm_model.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.openllm import OpenLLM


class OpenLLMModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

        client = OpenLLM(
            server_url=self.credentials.get('server_url'),
            callbacks=self.callbacks,
            llm_kwargs=self.provider_model_kwargs
        )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def prompt_file_name(self, mode: str) -> str:
        if 'baichuan' in self.name.lower():
            if mode == 'completion':
                return 'baichuan_completion'
            else:
                return 'baichuan_chat'
        else:
            return super().prompt_file_name(mode)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        pass

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"OpenLLM: {str(ex)}")
@@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM):

        return self._client.get_num_tokens(prompts)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        self.client.input = provider_model_kwargs
@@ -98,6 +91,6 @@ class ReplicateModel(BaseLLM):
        else:
            return ex

-    @classmethod
-    def support_streaming(cls):
-        return True
+    @property
+    def support_streaming(self):
+        return True
@@ -1,5 +1,4 @@
-import decimal
from functools import wraps
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
@@ -19,6 +18,7 @@ class SparkModel(BaseLLM):
    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ChatSpark(
+            model_name=self.name,
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
@@ -50,9 +50,6 @@ class SparkModel(BaseLLM):
        contents = [message.content for message in messages]
        return max(self._client.get_num_tokens("".join(contents)), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
    def get_currency(self):
        return 'RMB'
@@ -68,6 +65,6 @@ class SparkModel(BaseLLM):
        else:
            return ex

-    @classmethod
-    def support_streaming(cls):
-        return True
+    @property
+    def support_streaming(self):
+        return True
@@ -53,9 +53,6 @@ class TongyiModel(BaseLLM):
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
    def get_currency(self):
        return 'RMB'
@@ -72,6 +69,6 @@ class TongyiModel(BaseLLM):
        else:
            return ex

-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
        return True
@@ -16,6 +16,7 @@ class WenxinModel(BaseLLM):

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        # TODO load price_config from configs(db)
        return Wenxin(
            streaming=self.streaming,
            callbacks=self.callbacks,
@@ -48,36 +49,6 @@ class WenxinModel(BaseLLM):
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'ernie-bot': {
-                'prompt': decimal.Decimal('0.012'),
-                'completion': decimal.Decimal('0.012'),
-            },
-            'ernie-bot-turbo': {
-                'prompt': decimal.Decimal('0.008'),
-                'completion': decimal.Decimal('0.008')
-            },
-            'bloomz-7b': {
-                'prompt': decimal.Decimal('0.006'),
-                'completion': decimal.Decimal('0.006')
-            }
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'RMB'
-
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
@@ -86,7 +57,3 @@ class WenxinModel(BaseLLM):

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Wenxin: {str(ex)}")
-
-    @classmethod
-    def support_streaming(cls):
-        return False
api/core/model_providers/models/llm/xinference_model.py (new file, 79 lines)
@@ -0,0 +1,79 @@
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM


class XinferenceModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

        client = XinferenceLLM(
            server_url=self.credentials['server_url'],
            model_uid=self.credentials['model_uid'],
        )

        client.callbacks = self.callbacks

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate(
            [prompts],
            stop,
            callbacks,
            generate_config={
                "stop": stop,
                "stream": self.streaming,
                **self.provider_model_kwargs,
            }
        )

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def prompt_file_name(self, mode: str) -> str:
        if 'baichuan' in self.name.lower():
            if mode == 'completion':
                return 'baichuan_completion'
            else:
                return 'baichuan_chat'
        else:
            return super().prompt_file_name(mode)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        pass

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Xinference: {str(ex)}")

    @property
    def support_streaming(self):
        return True
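For reference, the provider-side credential check later in this diff exercises the same client directly. A minimal standalone sketch of that call pattern, where the URL and model UID are placeholders for a real Xinference deployment:

```python
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

# Placeholder endpoint and model UID; a real Xinference server supplies both.
llm = XinferenceLLM(
    server_url="http://127.0.0.1:9997",
    model_uid="my-model-uid",
)

# The same smoke test the provider's credential validation performs.
print(llm("ping"))
```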
@@ -41,7 +41,7 @@ class OpenAIModeration(BaseProviderModel):
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
-            raise LLMAuthorizationError(str(ex))
+            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
@@ -40,7 +40,7 @@ class OpenAIWhisper(BaseSpeech2Text):
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
-            raise LLMAuthorizationError(str(ex))
+            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
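The repeated raise-to-return change across these handlers makes `handle_exceptions` a pure mapper: it converts provider-specific errors into Dify error types and leaves the raising to the caller. A sketch of the calling convention this implies (the surrounding `run` wiring is assumed here, not shown in these hunks):

```python
# Assumed call site: handle_exceptions only maps the error; the caller raises.
try:
    result = self._run(messages=messages, stop=stop)
except Exception as ex:
    raise self.handle_exceptions(ex)
```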
@@ -183,6 +183,8 @@ class AnthropicProvider(BaseModelProvider):
            return {
                'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
                'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
+                'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
+                'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
            }

        return None
@@ -31,7 +31,9 @@ class HostedAnthropic(BaseModel):
    """Quota limit for the anthropic hosted model. 0 means unlimited."""
    paid_enabled: bool = False
    paid_stripe_price_id: str = None
-    paid_increase_quota: int = 1
+    paid_increase_quota: int = 1000000
+    paid_min_quantity: int = 20
+    paid_max_quantity: int = 100


class HostedModelProviders(BaseModel):
@@ -73,4 +75,6 @@ def init_app(app: Flask):
        paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
        paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
        paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
+        paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
+        paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
    )
@@ -2,7 +2,6 @@ import json
from typing import Type

from huggingface_hub import HfApi
-from langchain.llms import HuggingFaceEndpoint

from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
@@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from models.provider import ProviderType


@@ -51,7 +51,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
            top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=1500, default=200),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
        )

    @classmethod
@@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
        if 'huggingfacehub_endpoint_url' not in credentials:
            raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')

+        if 'task_type' not in credentials:
+            raise CredentialsValidateFailedError('Task Type must be provided.')
+
+        if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
+            raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
+
        try:
-            llm = HuggingFaceEndpoint(
+            llm = HuggingFaceEndpointLLM(
                endpoint_url=credentials['huggingfacehub_endpoint_url'],
-                task="text2text-generation",
+                task=credentials['task_type'],
                model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
                huggingfacehub_api_token=credentials['huggingfacehub_api_token']
            )
@@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
            }

        credentials = json.loads(provider_model.encrypted_config)
+
+        if 'task_type' not in credentials:
+            credentials['task_type'] = 'text-generation'
+
        if credentials['huggingfacehub_api_token']:
            credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
                self.provider.tenant_id,
api/core/model_providers/providers/openllm_provider.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import json
from typing import Type

from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.openllm import OpenLLM
from models.provider import ProviderType


class OpenLLMProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'openllm'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = OpenLLMModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0.01, max=2, default=1),
            top_p=KwargRule[float](min=0, max=1, default=0.7),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        if 'server_url' not in credentials:
            raise CredentialsValidateFailedError('OpenLLM Server URL must be provided.')

        try:
            credential_kwargs = {
                'server_url': credentials['server_url']
            }

            llm = OpenLLM(
                llm_kwargs={
                    'max_new_tokens': 10
                },
                **credential_kwargs
            )

            llm("ping")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'server_url': None
            }

        credentials = json.loads(provider_model.encrypted_config)
        if credentials['server_url']:
            credentials['server_url'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['server_url']
            )

            if obfuscated:
                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        return {}
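Both new providers follow the same credential lifecycle: encrypt on save, decrypt on load, optionally obfuscate for display. A condensed sketch of that flow using the encrypter helpers seen above, with a placeholder tenant ID and endpoint URL:

```python
from core.helper import encrypter

tenant_id = "tenant-123"              # placeholder
server_url = "http://127.0.0.1:3000"  # placeholder OpenLLM endpoint

stored = encrypter.encrypt_token(tenant_id, server_url)  # what goes into encrypted_config
plain = encrypter.decrypt_token(tenant_id, stored)       # what the model client receives
masked = encrypter.obfuscated_token(plain)               # what the UI displays
```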
@@ -116,7 +116,8 @@ class ReplicateProvider(BaseModelProvider):
                    and 'Embedding' not in rst.openapi_schema['components']['schemas']:
                raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
            elif model_type == ModelType.TEXT_GENERATION \
-                    and ('type' not in rst.openapi_schema['components']['schemas']['Output']['items']
+                    and ('items' not in rst.openapi_schema['components']['schemas']['Output']
+                         or 'type' not in rst.openapi_schema['components']['schemas']['Output']['items']
                         or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
                raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
        except ReplicateError as e:
@@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider):
            return [
                {
                    'id': 'spark',
-                    'name': '星火认知大模型',
+                    'name': 'Spark V1.5',
                },
+                {
+                    'id': 'spark-v2',
+                    'name': 'Spark V2.0',
+                }
            ]
        else:
api/core/model_providers/providers/xinference_provider.py (new file, 192 lines)
@@ -0,0 +1,192 @@
import json
from typing import Type

import requests

from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
from models.provider import ProviderType


class XinferenceProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'xinference'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = XinferenceModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = XinferenceEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        credentials = self.get_model_credentials(model_name, model_type)
        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
            return ModelKwargsRules(
                temperature=KwargRule[float](min=0.01, max=2, default=1),
                top_p=KwargRule[float](min=0, max=1, default=0.7),
                presence_penalty=KwargRule[float](enabled=False),
                frequency_penalty=KwargRule[float](enabled=False),
                max_tokens=KwargRule[int](min=10, max=4000, default=256),
            )
        elif credentials['model_format'] == "ggmlv3":
            return ModelKwargsRules(
                temperature=KwargRule[float](min=0.01, max=2, default=1),
                top_p=KwargRule[float](min=0, max=1, default=0.7),
                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
                max_tokens=KwargRule[int](min=10, max=4000, default=256),
            )
        else:
            return ModelKwargsRules(
                temperature=KwargRule[float](min=0.01, max=2, default=1),
                top_p=KwargRule[float](min=0, max=1, default=0.7),
                presence_penalty=KwargRule[float](enabled=False),
                frequency_penalty=KwargRule[float](enabled=False),
                max_tokens=KwargRule[int](min=10, max=4000, default=256),
            )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        if 'server_url' not in credentials:
            raise CredentialsValidateFailedError('Xinference Server URL must be provided.')

        if 'model_uid' not in credentials:
            raise CredentialsValidateFailedError('Xinference Model UID must be provided.')

        try:
            credential_kwargs = {
                'server_url': credentials['server_url'],
                'model_uid': credentials['model_uid'],
            }

            llm = XinferenceLLM(
                **credential_kwargs
            )

            llm("ping")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        extra_credentials = cls._get_extra_credentials(credentials)
        credentials.update(extra_credentials)

        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])

        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'server_url': None,
                'model_uid': None,
            }

        credentials = json.loads(provider_model.encrypted_config)
        if credentials['server_url']:
            credentials['server_url'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['server_url']
            )

            if obfuscated:
                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])

        return credentials

    @classmethod
    def _get_extra_credentials(self, credentials: dict) -> dict:
        url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
        response = requests.get(url)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get the model description, detail: {response.json()['detail']}"
            )
        desc = response.json()

        extra_credentials = {
            'model_format': desc['model_format'],
        }
        if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
            extra_credentials['model_handle_type'] = 'chatglm'
        elif "generate" in desc["model_ability"]:
            extra_credentials['model_handle_type'] = 'generate'
        elif "chat" in desc["model_ability"]:
            extra_credentials['model_handle_type'] = 'chat'
        else:
            raise NotImplementedError(f"Model handle type not supported.")

        return extra_credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        return {}
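`_get_extra_credentials` derives `model_format` and `model_handle_type` from the Xinference model-description endpoint. An assumed example of the fields it relies on, with the derivation restated for clarity (other payload fields omitted):

```python
# Assumed example of the JSON returned by GET {server_url}/v1/models/{model_uid};
# only the keys _get_extra_credentials reads are shown.
desc = {
    "model_name": "chatglm2",
    "model_format": "ggmlv3",
    "model_ability": ["generate", "chat"],
}

# Mirrors the provider logic: ggmlv3 + "chatglm" in the name -> chatglm handle;
# otherwise "generate" takes precedence over "chat".
if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
    handle = "chatglm"
elif "generate" in desc["model_ability"]:
    handle = "generate"
elif "chat" in desc["model_ability"]:
    handle = "chat"
```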
@@ -8,5 +8,7 @@
    "wenxin",
    "chatglm",
    "replicate",
-    "huggingface_hub"
+    "huggingface_hub",
+    "xinference",
+    "openllm"
]
@@ -5,10 +5,25 @@
  ],
  "system_config": {
    "supported_quota_types": [
      "paid",
      "trial"
    ],
-    "quota_unit": "times",
-    "quota_limit": 1000
+    "quota_unit": "tokens",
+    "quota_limit": 600000
  },
-  "model_flexibility": "fixed"
+  "model_flexibility": "fixed",
+  "price_config": {
+    "claude-instant-1": {
+      "prompt": "1.63",
+      "completion": "5.51",
+      "unit": "0.000001",
+      "currency": "USD"
+    },
+    "claude-2": {
+      "prompt": "11.02",
+      "completion": "32.68",
+      "unit": "0.000001",
+      "currency": "USD"
+    }
+  }
}
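Reading these rule files back through `price_config`: prompt and completion are prices per `unit` of tokens, so Anthropic's `unit` of 0.000001 means the figures are quoted per 1M tokens. A quick check of the arithmetic for claude-2:

```python
import decimal

# claude-2 prompt price from the rules above: 11.02 USD per 1M tokens
# (unit 0.000001 scales a raw token count down to millions).
prompt_price = decimal.Decimal('11.02')
unit = decimal.Decimal('0.000001')

tokens = 1000
total = tokens * prompt_price * unit  # 0.01102 USD for 1K prompt tokens
print(total.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP))
```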
@@ -3,5 +3,48 @@
    "custom"
  ],
  "system_config": null,
-  "model_flexibility": "configurable"
+  "model_flexibility": "configurable",
+  "price_config": {
+    "gpt-4": {
+      "prompt": "0.03",
+      "completion": "0.06",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-4-32k": {
+      "prompt": "0.06",
+      "completion": "0.12",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-35-turbo": {
+      "prompt": "0.002",
+      "completion": "0.0015",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-35-turbo-16k": {
+      "prompt": "0.003",
+      "completion": "0.004",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-davinci-002": {
+      "prompt": "0.02",
+      "completion": "0.02",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-davinci-003": {
+      "prompt": "0.02",
+      "completion": "0.02",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-embedding-ada-002": {
+      "completion": "0.0001",
+      "unit": "0.001",
+      "currency": "USD"
+    }
+  }
}
@@ -9,5 +9,24 @@
    ],
    "quota_unit": "tokens"
  },
-  "model_flexibility": "fixed"
+  "model_flexibility": "fixed",
+  "price_config": {
+    "abab5.5-chat": {
+      "prompt": "0.015",
+      "completion": "0.015",
+      "unit": "0.001",
+      "currency": "RMB"
+    },
+    "abab5-chat": {
+      "prompt": "0.015",
+      "completion": "0.015",
+      "unit": "0.001",
+      "currency": "RMB"
+    },
+    "embo-01": {
+      "completion": "0",
+      "unit": "0.0001",
+      "currency": "RMB"
+    }
+  }
}
@@ -10,5 +10,42 @@
    "quota_unit": "times",
    "quota_limit": 200
  },
-  "model_flexibility": "fixed"
+  "model_flexibility": "fixed",
+  "price_config": {
+    "gpt-4": {
+      "prompt": "0.03",
+      "completion": "0.06",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-4-32k": {
+      "prompt": "0.06",
+      "completion": "0.12",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-3.5-turbo": {
+      "prompt": "0.0015",
+      "completion": "0.002",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-3.5-turbo-16k": {
+      "prompt": "0.003",
+      "completion": "0.004",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-davinci-003": {
+      "prompt": "0.02",
+      "completion": "0.02",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-embedding-ada-002": {
+      "completion": "0.0001",
+      "unit": "0.001",
+      "currency": "USD"
+    }
+  }
}
api/core/model_providers/rules/openllm.json (new file)
@@ -0,0 +1,7 @@
+{
+    "support_provider_types": [
+        "custom"
+    ],
+    "system_config": null,
+    "model_flexibility": "configurable"
+}
@@ -9,5 +9,19 @@
         ],
         "quota_unit": "tokens"
     },
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "price_config": {
+        "spark": {
+            "prompt": "0.18",
+            "completion": "0.18",
+            "unit": "0.0001",
+            "currency": "RMB"
+        },
+        "spark-v2": {
+            "prompt": "0.36",
+            "completion": "0.36",
+            "unit": "0.0001",
+            "currency": "RMB"
+        }
+    }
 }
@@ -3,5 +3,25 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "price_config": {
+        "ernie-bot": {
+            "prompt": "0.012",
+            "completion": "0.012",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "ernie-bot-turbo": {
+            "prompt": "0.008",
+            "completion": "0.008",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "bloomz-7b": {
+            "prompt": "0.006",
+            "completion": "0.006",
+            "unit": "0.001",
+            "currency": "RMB"
+        }
+    }
 }
api/core/model_providers/rules/xinference.json (new file)
@@ -0,0 +1,7 @@
+{
+    "support_provider_types": [
+        "custom"
+    ],
+    "system_config": null,
+    "model_flexibility": "configurable"
+}
@@ -17,12 +17,13 @@ from core.conversation_message_task import ConversationMessageTask
 from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
+from core.model_providers.models.llm.base import BaseLLM
+from core.tool.current_datetime_tool import DatetimeTool
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tool.provider.serpapi_provider import SerpAPIToolProvider
 from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
 from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
-from libs import helper
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
 
@@ -82,15 +83,19 @@ class OrchestratorRuleParser:
         try:
             summary_model_instance = ModelFactory.get_text_generation_model(
                 tenant_id=self.tenant_id,
+                model_provider_name=agent_provider_name,
+                model_name=agent_model_name,
                 model_kwargs=ModelKwargs(
                     temperature=0,
                     max_tokens=500
-                )
+                ),
+                deduct_quota=False
             )
         except ProviderTokenNotInitError as e:
             summary_model_instance = None
 
         tools = self.to_tools(
+            agent_model_instance=agent_model_instance,
            tool_configs=tool_configs,
            conversation_message_task=conversation_message_task,
            rest_tokens=rest_tokens,
@@ -140,11 +145,12 @@ class OrchestratorRuleParser:
 
         return None
 
-    def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
+    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
                  rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools
 
+        :param agent_model_instance:
         :param rest_tokens:
         :param tool_configs: app agent tool configs
         :param conversation_message_task:
@@ -162,7 +168,7 @@ class OrchestratorRuleParser:
             if tool_type == "dataset":
                 tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
             elif tool_type == "web_reader":
-                tool = self.to_web_reader_tool()
+                tool = self.to_web_reader_tool(agent_model_instance)
             elif tool_type == "google_search":
                 tool = self.to_google_search_tool()
             elif tool_type == "wikipedia":
@@ -207,24 +213,28 @@ class OrchestratorRuleParser:
 
         return tool
 
-    def to_web_reader_tool(self) -> Optional[BaseTool]:
+    def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
         """
         A tool for reading web pages
 
         :return:
         """
-        summary_model_instance = ModelFactory.get_text_generation_model(
-            tenant_id=self.tenant_id,
-            model_kwargs=ModelKwargs(
-                temperature=0,
-                max_tokens=500
-            )
-        )
-
-        summary_llm = summary_model_instance.client
+        try:
+            summary_model_instance = ModelFactory.get_text_generation_model(
+                tenant_id=self.tenant_id,
+                model_provider_name=agent_model_instance.model_provider.provider_name,
+                model_name=agent_model_instance.name,
+                model_kwargs=ModelKwargs(
+                    temperature=0,
+                    max_tokens=500
+                ),
+                deduct_quota=False
+            )
+        except ProviderTokenNotInitError:
+            summary_model_instance = None
 
         tool = WebReaderTool(
-            llm=summary_llm,
+            llm=summary_model_instance.client if summary_model_instance else None,
             max_chunk_length=4000,
             continue_reading=True,
             callbacks=[DifyStdOutCallbackHandler()]
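The refactor above applies one pattern in both places: the summarizer now reuses the agent's own provider and model instead of an implicitly resolved default, deduct_quota=False avoids charging the tenant's quota twice for a single agent run, and a ProviderTokenNotInitError degrades the tool to llm=None rather than failing construction. A stripped-down sketch of that fallback shape — names here are stand-ins, not Dify's API:

```python
from typing import Optional


class ProviderTokenNotInitError(Exception):
    """Stand-in for the error imported above."""


def build_summary_llm(factory, tenant_id: str,
                      provider_name: str, model_name: str) -> Optional[object]:
    # Try to build a summary model on the agent's own provider; a missing
    # provider token disables summarization instead of breaking the tool.
    try:
        model = factory.get_text_generation_model(
            tenant_id=tenant_id,
            model_provider_name=provider_name,
            model_name=model_name,
            deduct_quota=False,  # the agent run itself already pays quota
        )
        return model.client
    except ProviderTokenNotInitError:
        return None
```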
@@ -252,11 +262,7 @@ class OrchestratorRuleParser:
         return tool
 
     def to_current_datetime_tool(self) -> Optional[BaseTool]:
-        tool = Tool(
-            name="current_datetime",
-            description="A tool when you want to get the current date, time, week, month or year, "
-                        "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".",
-            func=helper.get_current_datetime,
+        tool = DatetimeTool(
             callbacks=[DifyStdOutCallbackHandler()]
         )
 
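core/tool/current_datetime_tool.py itself is not shown in this diff; presumably DatetimeTool packages the name and description that the removed Tool(...) wiring carried inline. A plausible sketch as a LangChain BaseTool subclass — the name and description below are taken from the removed lines, while the implementation details are guesses:

```python
from datetime import datetime, timezone

from langchain.tools import BaseTool


class DatetimeTool(BaseTool):
    """Sketch of what the imported DatetimeTool plausibly looks like."""

    # Name/description copied from the removed Tool(...) call; body is a guess.
    name: str = "current_datetime"
    description: str = ("A tool when you want to get the current date, time, week, month or year, "
                        "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".")

    def _run(self, tool_input: str = "") -> str:
        # e.g. "2023-08-20 09:15:00 UTC Sunday"
        return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z %A")

    async def _arun(self, tool_input: str = "") -> str:
        return self._run(tool_input)
```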
Some files were not shown because too many files have changed in this diff.