Compare commits

...

102 Commits

Author SHA1 Message Date
John Wang
000bb5c2a4 fix: pub generate message text return null 2023-08-28 15:36:35 +08:00
zxhlyh
3a5c7c75ad Fix/model selector (#1032) 2023-08-28 10:54:41 +08:00
zxhlyh
a7415ecfd8 Fix/upload document limit (#1033) 2023-08-28 10:53:45 +08:00
KVOJJJin
934def5fcc Fix: eslint (#1030) 2023-08-27 17:06:16 +08:00
takatost
0796791de5 feat: hf inference endpoint stream support (#1028) 2023-08-26 19:48:34 +08:00
takatost
6c148b223d fix: dataset query truncated (#1026) 2023-08-26 17:35:17 +08:00
zxhlyh
4b168f4838 fix: maintenance notice (#1025) 2023-08-26 16:09:55 +08:00
takatost
1c114eaef3 feat: update contributing (#1020) 2023-08-25 21:19:13 +08:00
Jyong
e053215155 fix document estimate parameter (#1019)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 20:10:08 +08:00
zxhlyh
13482b0fc1 feat: maintenance notice (#1016) 2023-08-25 19:38:52 +08:00
Jyong
38fa152cc4 fix update document index technique (#1018)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 18:29:55 +08:00
Uranus
2d9616c29c fix: xinference last token being ignored (#1013) 2023-08-25 18:15:05 +08:00
Jyong
915e26527b update dataset index struct (#1012)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 15:52:33 +08:00
Jyong
2d604d9330 Fix/filter empty segment (#1004)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 15:50:29 +08:00
Jyong
e7199826cc embedding model available check (#1009)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 00:25:16 +08:00
crazywoola
70e24b7594 fix: loading and calc rem (#1006) 2023-08-24 23:24:33 +08:00
yezhwi
c1602aafc7 refactor:cache in place & function name (#1001) 2023-08-24 22:54:21 +08:00
crazywoola
a3fec11438 fix: styles (#1005) 2023-08-24 22:37:46 +08:00
Jyong
b1fd1b3ab3 Feat/vector db manage (#997)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-24 21:27:31 +08:00
Jyong
5397799aac document limit (#999)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-24 21:27:13 +08:00
takatost
8e837dde1a feat: bump version to 0.3.18 (#1000) 2023-08-24 18:13:18 +08:00
takatost
9ae91a2ec3 feat: optimize xinference request max token key and stop reason (#998) 2023-08-24 18:11:15 +08:00
Matri
276d3d10a0 fix: apps loading issue (#994) 2023-08-24 17:57:38 +08:00
crazywoola
f13623184a fix style in app share (#995) 2023-08-24 17:57:25 +08:00
takatost
ef61e1487f fix: safetensor arm complie error (#996) 2023-08-24 17:38:10 +08:00
takatost
701e2b334f feat: remove unnecessary prompt of baichuan (#993) 2023-08-24 15:30:59 +08:00
takatost
6ebd6e7890 feat: bump version to 0.3.17 (#992) 2023-08-24 15:12:47 +08:00
takatost
bd3a9b2f8d fix: xinference-chat-stream-response (#991) 2023-08-24 14:39:34 +08:00
takatost
18d3877151 feat: optimize xinference stream (#989) 2023-08-24 13:58:34 +08:00
takatost
53e83d8697 feat: optimize baichuan prompt (#988) 2023-08-24 12:07:10 +08:00
Matri
6377fc75c6 chore: update lintrc config (#986) 2023-08-24 11:46:59 +08:00
takatost
2c30d19cbe feat: add baichuan prompt (#985) 2023-08-24 10:22:36 +08:00
takatost
9b247fccd4 feat: adjust hf max tokens (#979) 2023-08-23 22:24:50 +08:00
John Wang
3d38aa7138 feat: bump version to 0.3.16 2023-08-23 20:16:54 +08:00
takatost
7d2552b3f2 feat: upgrade xinference to 0.2.1 which support stream response (#977) 2023-08-23 20:15:45 +08:00
yezhwi
117a209ad4 Fix:condition for dataset availability check (#973) 2023-08-23 19:57:27 +08:00
takatost
071e7800a0 fix: add hf task field (#976)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-08-23 19:48:31 +08:00
takatost
a76fde3d23 feat: optimize hf inference endpoint (#975) 2023-08-23 19:47:50 +08:00
Jyong
1fc57d7358 normalize embedding (#974)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-23 19:10:11 +08:00
Matri
916d8be0ae fix: activation page reload issue after activating (#964) 2023-08-23 13:54:40 +08:00
crazywoola
a38412de7b update doc (#965) 2023-08-23 12:29:52 +08:00
Matri
9c9f0ddb93 fix: user activation request 404 issue (#963) 2023-08-23 08:57:25 +08:00
takatost
f8fbe96da4 feat: bump version to 0.3.15 (#959) 2023-08-22 18:20:33 +08:00
zxhlyh
215a27fd95 Feat/add xinference openllm provider (#958) 2023-08-22 18:19:10 +08:00
takatost
5cba2e7087 fix: web reader tool retrieve content empty (#957) 2023-08-22 18:01:16 +08:00
Jyong
5623839c71 update document segment (#950)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-22 17:59:24 +08:00
takatost
78d3aa5fcd fix: embedding init err (#956) 2023-08-22 17:43:59 +08:00
zxhlyh
a7c78d2cd2 fix: spark provider field name (#955) 2023-08-22 17:28:18 +08:00
Joel
4db35fa375 chore: obsolete info api use new api (#954) 2023-08-22 16:59:57 +08:00
Joel
e67a1413b6 chore: create btn to first place (#953) 2023-08-22 16:20:56 +08:00
takatost
4f3053a8cc fix: xinference chat completion error (#952) 2023-08-22 15:58:04 +08:00
zxhlyh
b3c2bf125f Feat/model providers (#951) 2023-08-22 15:38:12 +08:00
zxhlyh
9d5299e9ec fix: segment error tip & save segment disable when loading (#949) 2023-08-22 15:22:16 +08:00
Jyong
aee15adf1b update document segment (#948)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-22 15:19:09 +08:00
zxhlyh
b185a70c21 Fix/speech to text button (#947) 2023-08-22 14:55:20 +08:00
takatost
a3aba7a9aa fix: provider model not delete when reset key pair (#946) 2023-08-22 13:48:58 +08:00
takatost
866ee5da91 fix: openllm generate cutoff (#945) 2023-08-22 13:43:36 +08:00
Matri
e8039a7da8 fix: add flex-wrap to categories container (#944) 2023-08-22 13:39:52 +08:00
bowen
5e0540077a chore: perfect type definition (#940) 2023-08-22 10:58:06 +08:00
Matri
b346bd9b83 fix: default language improvement in activation page (#942) 2023-08-22 09:28:37 +08:00
Matri
062e2e915b fix: login improvement (#941) 2023-08-21 21:26:32 +08:00
takatost
e0a48c4972 fix: xinference chat support (#939) 2023-08-21 20:44:29 +08:00
zxhlyh
f53242c081 Feat/add document status tooltip (#937) 2023-08-21 18:07:51 +08:00
Jyong
4b53bb1a32 Feat/token support (#909)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: jyong <jyong@dify.ai>
2023-08-21 13:57:18 +08:00
takatost
4c49ecedb5 feat: optimize web reader summary in 3.5 (#933) 2023-08-21 11:58:01 +08:00
takatost
4ff1870a4b fix: web reader tool missing nodejs (#932) 2023-08-21 11:26:11 +08:00
takatost
6c832ee328 fix: remove openllm pypi package because of this package too large (#931) 2023-08-21 02:12:28 +08:00
takatost
25264e7852 feat: add xinference embedding model support (#930) 2023-08-20 19:35:07 +08:00
takatost
18dd0d569d fix: xinference max_tokens alisa error (#929) 2023-08-20 19:12:52 +08:00
takatost
3ea8d7a019 feat: add openllm support (#928) 2023-08-20 19:04:33 +08:00
takatost
da3f10a55e feat: server xinference support (#927) 2023-08-20 17:46:41 +08:00
Benjamin
8c991b5b26 Fix Readme.md typo error. (#926) 2023-08-20 12:02:04 +08:00
takatost
22c1aafb9b fix: document paused at format error (#925) 2023-08-20 01:54:12 +08:00
takatost
8d6d1c442b feat: optimize generate name length (#924) 2023-08-19 23:34:38 +08:00
takatost
95b179fb39 fix: replicate text generation model validate (#923) 2023-08-19 21:40:42 +08:00
takatost
3a0a9e2d8f fix: embedding get price definition missing (#922) 2023-08-19 21:31:40 +08:00
takatost
0a0d63457d feat: record price unit in messages (#919) 2023-08-19 18:51:40 +08:00
takatost
920fb6d0e1 fix: embedding price config (#918) 2023-08-19 16:54:08 +08:00
Krasus.Chen
fd0fc8f4fe Fix/price calc (#862) 2023-08-19 16:41:35 +08:00
takatost
1c552ff23a fix: azure embedding model credentials include base_model_name is invalid for openai sdk (#917) 2023-08-19 16:24:18 +08:00
takatost
5163dd38e5 fix: run extra model serval ex not return (#916) 2023-08-19 14:35:16 +08:00
takatost
2a27dad2fb fix: run model serval ex not return (#915) 2023-08-19 14:16:41 +08:00
takatost
930f74c610 feat: remove unuse envs (#912) 2023-08-18 21:34:28 +08:00
takatost
3f250c9e12 Update README_CN.md 2023-08-18 20:39:40 +08:00
takatost
fa408d264c Update README.md 2023-08-18 20:38:52 +08:00
takatost
09ea27f1ee feat: optimize service api authorization header invalid error (#910) 2023-08-18 20:32:44 +08:00
Jyong
db7156dafd Feature/mutil embedding model (#908)
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-08-18 17:37:31 +08:00
zxhlyh
4420281d96 Feat/segment add tag (#907) 2023-08-18 17:18:58 +08:00
takatost
d9afebe216 feat: optimize output parse (#906) 2023-08-18 17:00:40 +08:00
takatost
1d9cc5ca05 fix: universal chat when default model invalid (#905) 2023-08-18 16:20:42 +08:00
takatost
edb06f6aed fix: react router agent direct output (#904) 2023-08-18 14:31:20 +08:00
takatost
6ca3bcbcfd fix: sensitive_word_avoidance npe (#902) 2023-08-18 11:43:56 +08:00
Rhon Joe
71a9d63232 fix entrypoint script line endings (#900) 2023-08-18 10:42:44 +08:00
zxhlyh
fb62017e50 Fix/embedding chat (#899)
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-08-18 10:39:05 +08:00
takatost
9adbeadeec feat: claude paid optimize (#890) 2023-08-17 16:56:20 +08:00
takatost
2f7b234cc5 fix: max token not exist in generate summary when calc rest tokens (#891) 2023-08-17 16:33:32 +08:00
zxhlyh
4f5f9506ab Feat/pay modal (#889) 2023-08-17 15:49:22 +08:00
takatost
0cc0b6e052 fix: error raise status code not exist (#888) 2023-08-17 15:33:35 +08:00
Joel
cd78adb0ab feat: support show model display name (#887) 2023-08-17 15:13:35 +08:00
takatost
f42e7d1a61 feat: add spark v2 support (#885) 2023-08-17 15:08:57 +08:00
takatost
c4d759dfba fix: wenxin error not raise when stream mode (#884) 2023-08-17 13:40:00 +08:00
zxhlyh
a58f95fa91 fix: web dockfile (#883) 2023-08-17 13:07:07 +08:00
280 changed files with 8239 additions and 1574 deletions

View File

@@ -53,9 +53,9 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r
## Community channels
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk). We are here to help!
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
### i18n (Internationalization) Support
We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.
Also check out the [Frontend i18n README](web/i18n/README_EN.md) for more information.

View File

@@ -16,15 +16,15 @@
## 本地开发
To set up a working development environment, just fork the project's git repository, install the backend and frontend dependencies with the appropriate package managers, then create and run the docker-compose stack
To set up a working development environment, just fork the project's git repository, install the backend and frontend dependencies with the appropriate package managers, then create and run docker-compose.
### Fork the repository
You need to fork the [repository](https://github.com/langgenius/dify).
You need to fork the [Git repository](https://github.com/langgenius/dify).
### Clone the repository
Clone the repository you forked on GitHub:
Clone your forked repository from GitHub:
```
git clone git@github.com:<github_username>/dify.git
```

View File

@@ -52,4 +52,4 @@ git clone git@github.com:<github_username>/dify.git
## Community channels
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk). We are here to help!
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!

View File

@@ -30,10 +30,10 @@ Visual data analysis, log review, and annotation for applications
- [x] **Spark**
- [x] **Wenxin**
- [x] **Tongyi**
- [x] **ChatGLM**
We provide the following free resources for registered Dify cloud users (sign up at [dify.ai](https://dify.ai)):
* 1000 free Claude model queries to build Claude-powered apps
* 600,000 free Claude model tokens to build Claude-powered apps
* 200 free OpenAI queries to build OpenAI-based apps

View File

@@ -36,7 +36,7 @@
We provide the following free resources for all registered cloud users (log in at [dify.ai](https://cloud.dify.ai) to use them):
* 1,000 free Claude model message calls for building Claude-powered AI apps
* 600,000 free Claude model tokens for building Claude-powered AI apps
* 200 free OpenAI model message calls for building OpenAI-based AI apps
* 3,000,000 free iFLYTEK Spark model tokens for building Spark-powered AI apps
* 1,000,000 free MiniMax tokens for building MiniMax-powered AI apps

View File

@@ -117,10 +117,12 @@ HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
HOSTED_ANTHROPIC_ENABLED=false
HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=

View File

@@ -16,7 +16,7 @@ EXPOSE 5001
WORKDIR /app/api
RUN apt-get update && \
apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev
apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev nodejs
COPY requirements.txt /app/api/requirements.txt

View File

@@ -1,4 +1,5 @@
import datetime
import json
import math
import random
import string
@@ -6,10 +7,16 @@ import time
import click
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound
from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.hosted import hosted_model_providers
from core.model_providers.providers.openai_provider import OpenAIProvider
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
@@ -20,7 +27,7 @@ from models.model import Account
import secrets
import base64
from models.provider import Provider, ProviderType, ProviderQuotaType
from models.provider import Provider, ProviderType, ProviderQuotaType, ProviderModel
@click.command('reset-password', help='Reset the account password.')
@@ -102,6 +109,7 @@ def reset_encrypt_key_pair():
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
db.session.query(ProviderModel).delete()
db.session.commit()
click.echo(click.style('Congratulations! '
@@ -258,6 +266,8 @@ def sync_anthropic_hosted_providers():
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
count = 0
new_quota_limit = hosted_model_providers.anthropic.quota_limit
page = 1
while True:
try:
@@ -265,6 +275,7 @@ def sync_anthropic_hosted_providers():
Provider.provider_name == 'anthropic',
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.quota_limit != new_quota_limit
).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
except NotFound:
break
@@ -272,9 +283,9 @@ def sync_anthropic_hosted_providers():
page += 1
for provider in providers:
try:
click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
.format(provider.tenant_id, provider.quota_limit, provider.quota_used))
original_quota_limit = provider.quota_limit
new_quota_limit = hosted_model_providers.anthropic.quota_limit
division = math.ceil(new_quota_limit / 1000)
provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
@@ -292,6 +303,72 @@ def sync_anthropic_hosted_providers():
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
def create_qdrant_indexes():
click.echo(click.style('Start create qdrant indexes.', fg='green'))
create_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
index.create_qdrant_dataset(dataset)
create_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@@ -300,3 +377,4 @@ def register_commands(app):
app.cli.add_command(recreate_all_dataset_indexes)
app.cli.add_command(sync_anthropic_hosted_providers)
app.cli.add_command(clean_unused_dataset_indexes)
app.cli.add_command(create_qdrant_indexes)
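The new `create-qdrant-indexes` command follows the same shape as this file's other maintenance commands: a `click`-decorated function that pages through rows with `paginate()` until it raises `NotFound`, and is then registered on the Flask CLI in `register_commands`. A minimal runnable sketch of that loop, with a stand-in model rather than Dify's real `Dataset`:

```python
# Minimal sketch of the paginate-until-NotFound maintenance-command pattern;
# the model and the per-row step are stand-ins, not Dify's real code.
import click
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from werkzeug.exceptions import NotFound

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///:memory:'
db = SQLAlchemy(app)

class Dataset(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    indexing_technique = db.Column(db.String)

@click.command('create-demo-indexes', help='Demo of the paginated maintenance loop.')
def create_demo_indexes():
    create_count = 0
    page = 1
    while True:
        try:
            datasets = Dataset.query.filter(
                Dataset.indexing_technique == 'high_quality'
            ).paginate(page=page, per_page=50)
        except NotFound:
            break  # paginate() 404s once the page runs past the last row
        page += 1
        for dataset in datasets:
            create_count += 1  # the real command rebuilds the qdrant index here
    click.echo(click.style(f'Processed {create_count} datasets.', fg='green'))

app.cli.add_command(create_demo_indexes)
```

Note the diff's ordering: the index struct is swapped and committed before `create_qdrant_dataset` runs, and any per-row failure is logged and skipped so one bad dataset cannot stall the whole migration.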

View File

@@ -48,19 +48,19 @@ DEFAULTS = {
'WEAVIATE_GRPC_ENABLED': 'True',
'WEAVIATE_BATCH_SIZE': 100,
'CELERY_BACKEND': 'database',
'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
'HOSTED_ANTHROPIC_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'TENANT_DOCUMENT_COUNT': 100,
'CLEAN_DAY_SETTING': 30,
'UPLOAD_FILE_SIZE_LIMIT': 15,
@@ -100,13 +100,12 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.14"
self.CURRENT_VERSION = "0.3.18"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
self.TESTING = False
self.LOG_LEVEL = get_env('LOG_LEVEL')
self.PDF_PREVIEW = get_bool_env('PDF_PREVIEW')
# Your App secret key will be used for securely signing the session cookie
# Make sure you are changing this key for your deployment with a strong key.
@@ -211,7 +210,7 @@ class Config:
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
@@ -219,23 +218,21 @@ class Config:
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
# By default it is False
# You could disable it for compatibility with certain OpenAPI providers
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
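Several of these hunks are the same class of fix: quota limits were read with `get_env()` and therefore arrived as strings, which breaks arithmetic such as the `math.ceil(new_quota_limit / 1000)` in `sync_anthropic_hosted_providers`. A small sketch of why the `int(...)` cast matters, with hypothetical stand-ins for the repo's helpers:

```python
import os

# Hypothetical stand-ins for the repo's get_env helper and DEFAULTS table;
# environment variables always come back as strings.
DEFAULTS = {'HOSTED_ANTHROPIC_QUOTA_LIMIT': '600000'}

def get_env(key):
    return os.environ.get(key, DEFAULTS.get(key))

raw = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
try:
    raw > 1000  # '600000' > 1000
except TypeError as e:
    print(e)    # '>' not supported between instances of 'str' and 'int'

quota_limit = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
assert quota_limit > 1000  # the cast makes comparisons and arithmetic safe
```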

View File

@@ -1,4 +1,5 @@
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
import flask_restful
from flask_restful import Resource, fields, marshal_with
from werkzeug.exceptions import Forbidden
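This is the first of roughly twenty hunks that swap `flask_login.login_required` for a project-local `core.login.login.login_required`. The new module's body is not shown in this compare view, so the following is only a guess at its shape: a drop-in decorator that keeps flask_login's session check while leaving room for other auth schemes.

```python
# Speculative sketch only: core/login/login.py is not included in this diff,
# so everything here is an assumption about what a drop-in replacement looks like.
from functools import wraps

import flask_login
from werkzeug.exceptions import Unauthorized

def login_required(view):
    @wraps(view)
    def decorated(*args, **kwargs):
        # Keep flask_login's session check, but a project-local wrapper can
        # also accept alternative credentials (e.g. API tokens) here.
        if not flask_login.current_user.is_authenticated:
            raise Unauthorized('Login required.')
        return view(*args, **kwargs)
    return decorated
```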

View File

@@ -3,7 +3,9 @@ import json
import logging
from datetime import datetime
from flask_login import login_required, current_user
import flask
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
from werkzeug.exceptions import Forbidden
@@ -316,7 +318,7 @@ class AppApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app = _get_app(app_id, current_user.current_tenant_id)
db.session.delete(app)

View File

@@ -2,7 +2,7 @@
import logging
from flask import request
from flask_login import login_required
from core.login.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
import services

View File

@@ -5,7 +5,7 @@ from typing import Generator, Union
import flask_login
from flask import Response, stream_with_context
from flask_login import login_required
from core.login.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
import services

View File

@@ -1,7 +1,8 @@
from datetime import datetime
import pytz
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from flask_restful.inputs import int_range
from sqlalchemy import or_, func

View File

@@ -1,4 +1,5 @@
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api

View File

@@ -3,7 +3,7 @@ import logging
from typing import Union, Generator
from flask import Response, stream_with_context
from flask_login import current_user, login_required
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, fields
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
@@ -16,6 +16,7 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.login.login import login_required
from libs.helper import uuid_value, TimestampField
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db

View File

@@ -3,12 +3,13 @@ import json
from flask import request
from flask_restful import Resource
from flask_login import login_required, current_user
from flask_login import current_user
from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.login.login import login_required
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from models.model import AppModelConfig

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from werkzeug.exceptions import NotFound, Forbidden

View File

@@ -4,7 +4,8 @@ from datetime import datetime
import pytz
from flask import jsonify
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api

View File

@@ -5,9 +5,12 @@ from typing import Optional
import flask_login
import requests
from flask import request, redirect, current_app, session
from flask_login import current_user, login_required
from flask_login import current_user
from flask_restful import Resource
from werkzeug.exceptions import Forbidden
from core.login.login import login_required
from libs.oauth_data_source import NotionOAuth
from controllers.console import api
from ..setup import setup_required

View File

@@ -3,7 +3,8 @@ import json
from cachetools import TTLCache
from flask import request, current_app
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
from werkzeug.exceptions import NotFound

View File

@@ -1,6 +1,7 @@
# -*- coding:utf-8 -*-
from flask import request
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
import services
@@ -10,13 +11,15 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelType
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService
from services.provider_service import ProviderService
dataset_detail_fields = {
'id': fields.String,
@@ -33,6 +36,9 @@ dataset_detail_fields = {
'created_at': TimestampField,
'updated_by': fields.String,
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean
}
dataset_query_detail_fields = {
@@ -74,8 +80,25 @@ class DatasetListApi(Resource):
datasets, total = DatasetService.get_datasets(page, limit, provider,
current_user.current_tenant_id, current_user)
# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
# if len(valid_model_list) == 0:
# raise ProviderNotInitializeError(
# f"No Embedding Model available. Please configure a valid provider "
# f"in the Settings -> Model Provider.")
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
data = marshal(datasets, dataset_detail_fields)
for item in data:
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item['embedding_available'] = True
else:
item['embedding_available'] = False
response = {
'data': marshal(datasets, dataset_detail_fields),
'data': data,
'has_more': len(datasets) == limit,
'limit': limit,
'total': total,
@@ -99,7 +122,6 @@ class DatasetListApi(Resource):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
@@ -233,6 +255,8 @@ class DatasetIndexingEstimateApi(Resource):
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
@@ -250,11 +274,14 @@ class DatasetIndexingEstimateApi(Resource):
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
args['process_rule'], args['doc_form'])
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner()
@@ -262,11 +289,14 @@ class DatasetIndexingEstimateApi(Resource):
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'])
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
else:
raise ValueError('Data source type not support')
return response, 200
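The list endpoint now decorates every dataset with `embedding_available` by comparing its `model:provider` pair against the tenant's valid embedding models. The same check replayed on plain data (the diff builds a list; a set would make each membership test O(1)):

```python
# The embedding_available check from DatasetListApi, replayed on fake data.
valid_model_list = [
    {'model_name': 'text-embedding-ada-002',
     'model_provider': {'provider_name': 'openai'}},
]
model_names = [
    f"{m['model_name']}:{m['model_provider']['provider_name']}"
    for m in valid_model_list
]

data = [
    {'embedding_model': 'text-embedding-ada-002', 'embedding_model_provider': 'openai'},
    {'embedding_model': 'embed-english-v2.0', 'embedding_model_provider': 'cohere'},
]
for item in data:
    item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
    item['embedding_available'] = item_model in model_names

print([d['embedding_available'] for d in data])  # [True, False]
```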

View File

@@ -3,8 +3,9 @@ import random
from datetime import datetime
from typing import List
from flask import request
from flask_login import login_required, current_user
from flask import request, current_app
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import desc, asc
from werkzeug.exceptions import NotFound, Forbidden
@@ -274,6 +275,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
if not dataset.indexing_technique and not args['indexing_technique']:
@@ -282,14 +285,19 @@ class DatasetDocumentListApi(Resource):
# validate args
DocumentService.document_create_args_validate(args)
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
@@ -328,6 +336,8 @@ class DatasetInitApi(Resource):
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
try:
@@ -406,11 +416,14 @@ class DocumentIndexingEstimateApi(DocumentResource):
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
data_process_rule_dict)
data_process_rule_dict, None,
'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
return response
@@ -473,22 +486,28 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
data_process_rule_dict)
data_process_rule_dict, None,
'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
elif dataset.data_source_type:
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
elif dataset.data_source_type == 'notion_import':
indexing_runner = IndexingRunner()
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
info_list,
data_process_rule_dict)
data_process_rule_dict,
None, 'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
else:
raise ValueError('Data source type not support')
return response
@@ -575,7 +594,8 @@ class DocumentIndexingStatusApi(DocumentResource):
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = 'paused'
return marshal(document, self.document_status_fields)
@@ -749,11 +769,13 @@ class DocumentMetadataApi(DocumentResource):
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
document.doc_metadata = {}
for key, value_type in metadata_schema.items():
value = doc_metadata.get(key)
if value is not None and isinstance(value, value_type):
document.doc_metadata[key] = value
if doc_type == 'others':
document.doc_metadata = doc_metadata
else:
for key, value_type in metadata_schema.items():
value = doc_metadata.get(key)
if value is not None and isinstance(value, value_type):
document.doc_metadata[key] = value
document.doc_type = doc_type
document.updated_at = datetime.utcnow()
@@ -832,12 +854,40 @@ class DocumentStatusApi(DocumentResource):
remove_document_from_index_task.delay(document_id)
return {'result': 'success'}, 200
elif action == "un_archive":
if not document.archived:
raise InvalidActionError('Document is not archived.')
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
total_count = documents_count + 1
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if total_count > tenant_document_count:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
document.archived = False
document.archived_at = None
document.archived_by = None
document.updated_at = datetime.utcnow()
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id)
return {'result': 'success'}, 200
else:
raise InvalidActionError()
class DocumentPauseApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id):
"""pause document."""
dataset_id = str(dataset_id)
@@ -867,6 +917,9 @@ class DocumentPauseApi(DocumentResource):
class DocumentRecoverApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id):
"""recover document."""
dataset_id = str(dataset_id)
@@ -892,6 +945,21 @@ class DocumentRecoverApi(DocumentResource):
return {'result': 'success'}, 204
class DocumentLimitApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""get document limit"""
documents_count = DocumentService.get_tenant_documents_count()
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
return {
'documents_count': documents_count,
'documents_limit': tenant_document_count
}, 200
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
@@ -917,3 +985,4 @@ api.add_resource(DocumentStatusApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
api.add_resource(DocumentLimitApi, '/datasets/limit')
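`DocumentLimitApi` gives the console a cheap way to show tenant usage before an upload. A hedged client-side sketch; the base URL and session cookie are placeholders, since console endpoints authenticate with the login session rather than a service API token:

```python
# Hypothetical console client call; base URL and session cookie are placeholders.
import requests

CONSOLE_API = 'http://localhost:5001/console/api'
cookies = {'session': '<console-session-cookie>'}

resp = requests.get(f'{CONSOLE_API}/datasets/limit', cookies=cookies)
resp.raise_for_status()
body = resp.json()
print(f"{body['documents_count']} of {body['documents_limit']} documents used")
```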

View File

@@ -1,15 +1,20 @@
# -*- coding:utf-8 -*-
import uuid
from datetime import datetime
from flask_login import login_required, current_user
from flask import request
from flask_login import current_user
from flask_restful import Resource, reqparse, fields, marshal
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api
from controllers.console.datasets.error import InvalidActionError
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.login.login import login_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@@ -17,7 +22,9 @@ from models.dataset import DocumentSegment
from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
from tasks.remove_segment_from_index_task import remove_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
import pandas as pd
segment_fields = {
'id': fields.String,
@@ -152,6 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
@@ -197,7 +218,7 @@ class DatasetDocumentSegmentApi(Resource):
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
remove_segment_from_index_task.delay(segment.id)
disable_segment_from_index_task.delay(segment.id)
return {'result': 'success'}, 200
else:
@@ -222,6 +243,19 @@ class DatasetDocumentSegmentAddApi(Resource):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
@@ -233,7 +267,7 @@ class DatasetDocumentSegmentAddApi(Resource):
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document)
segment = SegmentService.create_segment(args, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
@@ -245,6 +279,61 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -270,17 +359,88 @@ class DatasetDocumentSegmentUpdateApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document)
SegmentService.delete_segment(segment, document, dataset)
return {'result': 'success'}, 200
class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# get file from request
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:
# Skip the first row
df = pd.read_csv(file)
result = []
for index, row in df.iterrows():
if document.doc_form == 'qa_model':
data = {'content': row[0], 'answer': row[1]}
else:
data = {'content': row[0]}
result.append(data)
if len(result) == 0:
raise ValueError("The CSV file is empty.")
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
# send batch add segments task
redis_client.setnx(indexing_cache_key, 'waiting')
batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
current_user.current_tenant_id, current_user.id)
except Exception as e:
return {'error': str(e)}, 500
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
'job_id': job_id,
'job_status': 'waiting'
}, 200
@setup_required
@login_required
@account_initialization_required
def get(self, job_id):
job_id = str(job_id)
indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job is not exist.")
return {
'job_id': job_id,
'job_status': cache_result.decode()
}, 200
@@ -292,3 +452,6 @@ api.add_resource(DatasetDocumentSegmentAddApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
api.add_resource(DatasetDocumentSegmentBatchImportApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
'/datasets/batch_import_status/<uuid:job_id>')
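The batch-import endpoint shapes CSV rows into segment dicts with pandas, hands them to `batch_create_segment_to_index_task`, and tracks progress under a redis key that the companion GET polls. The row-shaping step in isolation (the diff indexes rows positionally as `row[0]`; `iloc` is the explicit spelling of the same access):

```python
# The CSV row shaping from the batch import endpoint, on an in-memory file.
import io

import pandas as pd

doc_form = 'qa_model'  # qa_model rows carry a content column and an answer column
csv_file = io.StringIO('question,answer\nWhat is Dify?,An LLM app platform\n')

df = pd.read_csv(csv_file)  # the first row is consumed as the header
result = []
for _, row in df.iterrows():
    if doc_form == 'qa_model':
        data = {'content': row.iloc[0], 'answer': row.iloc[1]}
    else:
        data = {'content': row.iloc[0]}
    result.append(data)

assert result == [{'content': 'What is Dify?', 'answer': 'An LLM app platform'}]
```

One thing the hunk leaves as-is: the handler reads `request.files['file']` before checking `'file' not in request.files`, so a missing file triggers werkzeug's own 400 before `NoFileUploadedError` can fire.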

View File

@@ -8,7 +8,8 @@ from pathlib import Path
from cachetools import TTLCache
from flask import request, current_app
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, marshal_with, fields
from werkzeug.exceptions import NotFound

View File

@@ -1,6 +1,7 @@
import logging
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, marshal, fields
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
@@ -11,7 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
from libs.helper import TimestampField
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@@ -102,6 +104,10 @@ class HitTestingApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ValueError as e:
raise ValueError(str(e))
except Exception as e:

View File

@@ -1,7 +1,8 @@
# -*- coding:utf-8 -*-
from datetime import datetime
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with, inputs
from sqlalchemy import and_
from werkzeug.exceptions import NotFound, Forbidden, BadRequest

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, fields, marshal_with
from sqlalchemy import and_

View File

@@ -1,4 +1,5 @@
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource
from functools import wraps

View File

@@ -1,7 +1,8 @@
import json
from functools import wraps
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required

View File

@@ -38,12 +38,20 @@ class StripeWebhookApi(Resource):
logging.debug(event['data']['object']['payment_status'])
logging.debug(event['data']['object']['metadata'])
session = stripe.checkout.Session.retrieve(
event['data']['object']['id'],
expand=['line_items'],
)
logging.debug(session.line_items['data'][0]['quantity'])
# Fulfill the purchase...
provider_checkout_service = ProviderCheckoutService()
try:
provider_checkout_service.fulfill_provider_order(event)
provider_checkout_service.fulfill_provider_order(event, session.line_items)
except Exception as e:
logging.debug(str(e))
return 'success', 200
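Stripe omits `line_items` from webhook payloads, so the handler now re-fetches the checkout session with `expand=['line_items']` and passes them into `fulfill_provider_order`, letting fulfillment read the purchased quantity. The retrieval call in isolation (ids and key are placeholders):

```python
# Placeholder ids/key; shows the expand pattern used by the webhook handler.
import stripe

stripe.api_key = 'sk_test_xxx'

session = stripe.checkout.Session.retrieve(
    'cs_test_xxx',          # event['data']['object']['id'] in the webhook
    expand=['line_items'],  # line_items are not delivered in webhook payloads
)
quantity = session.line_items['data'][0]['quantity']
print(quantity)
```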

View File

@@ -3,7 +3,8 @@ from datetime import datetime
import pytz
from flask import current_app, request
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError

View File

@@ -1,6 +1,7 @@
# -*- coding:utf-8 -*-
from flask import current_app
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
import services

View File

@@ -1,4 +1,5 @@
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -1,4 +1,5 @@
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -1,6 +1,7 @@
import json
from flask_login import login_required, current_user
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, abort, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -2,10 +2,13 @@
import logging
from flask import request
from flask_login import login_required, current_user
from flask_restful import Resource, fields, marshal_with, reqparse, marshal
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs
from flask_restful.inputs import int_range
from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.setup import setup_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import account_initialization_required
@@ -43,6 +46,13 @@ tenants_fields = {
'current': fields.Boolean
}
workspace_fields = {
'id': fields.String,
'name': fields.String,
'status': fields.String,
'created_at': TimestampField
}
class TenantListApi(Resource):
@setup_required
@@ -57,6 +67,38 @@ class TenantListApi(Resource):
return {'workspaces': marshal(tenants, tenants_fields)}, 200
class WorkspaceListApi(Resource):
@setup_required
@admin_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\
.paginate(page=args['page'], per_page=args['limit'])
has_more = False
if len(tenants.items) == args['limit']:
current_page_first_tenant = tenants[-1]
rest_count = db.session.query(Tenant).filter(
Tenant.created_at < current_page_first_tenant.created_at,
Tenant.id != current_page_first_tenant.id
).count()
if rest_count > 0:
has_more = True
total = db.session.query(Tenant).count()
return {
'data': marshal(tenants.items, workspace_fields),
'has_more': has_more,
'limit': args['limit'],
'page': args['page'],
'total': total
}, 200
class TenantApi(Resource):
@setup_required
@login_required
@@ -92,6 +134,7 @@ class SwitchWorkspaceApi(Resource):
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant
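`WorkspaceListApi` computes `has_more` by probing for rows older than the last item on a full page instead of trusting `page * limit < total`, which stays correct even if tenants are inserted while an admin pages through. The same logic on plain data (note the diff's `current_page_first_tenant` actually holds the page's last, i.e. oldest, tenant):

```python
# The has_more probe from WorkspaceListApi, replayed on plain timestamps.
created_at = [100, 90, 80, 70, 60, 50]  # tenants ordered newest-first
page, limit = 1, 3

items = created_at[(page - 1) * limit: page * limit]
has_more = False
if len(items) == limit:
    oldest_on_page = items[-1]  # tenants[-1] in the diff
    rest_count = sum(1 for t in created_at if t < oldest_on_page)
    if rest_count > 0:
        has_more = True

print(has_more)  # True: 70, 60 and 50 are still unpaged
```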

View File

@@ -17,7 +17,7 @@ def validate_app_token(view=None):
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token('app')
app_model = db.session.query(App).get(api_token.app_id)
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
if not app_model:
raise NotFound()
@@ -44,7 +44,7 @@ def validate_dataset_token(view=None):
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token('dataset')
dataset = db.session.query(Dataset).get(api_token.dataset_id)
dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first()
if not dataset:
raise NotFound()
@@ -64,14 +64,14 @@ def validate_and_get_api_token(scope=None):
Validate and get API token.
"""
auth_header = request.headers.get('Authorization')
if auth_header is None:
raise Unauthorized()
if auth_header is None or ' ' not in auth_header:
raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized()
raise Unauthorized("Authorization scheme must be 'Bearer'")
api_token = db.session.query(ApiToken).filter(
ApiToken.token == auth_token,
@@ -79,7 +79,7 @@ def validate_and_get_api_token(scope=None):
).first()
if not api_token:
raise Unauthorized()
raise Unauthorized("Access token is invalid")
api_token.last_used_at = datetime.utcnow()
db.session.commit()
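The stricter Authorization handling above distills to a small parser; a minimal sketch using werkzeug's Unauthorized as the diff does (the helper name is made up):

from typing import Optional
from werkzeug.exceptions import Unauthorized

def parse_bearer_token(auth_header: Optional[str]) -> str:
    if auth_header is None or ' ' not in auth_header:
        raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
    auth_scheme, auth_token = auth_header.split(None, 1)  # split on first whitespace
    if auth_scheme.lower() != 'bearer':
        raise Unauthorized("Authorization scheme must be 'Bearer'")
    return auth_token

assert parse_bearer_token('Bearer abc123') == 'abc123'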


@@ -60,7 +60,13 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
return AgentFinish(return_values={"output": observation}, log=observation)
try:
return super().plan(intermediate_steps, callbacks, **kwargs)
agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
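The same query-rewrite pattern recurs in the next few agent diffs: when the routed tool takes a dict input with a 'query' key, the user's original input replaces whatever query the model wrote. A toy illustration, with a dataclass standing in for langchain's AgentAction:

from dataclasses import dataclass
from typing import Any

@dataclass
class AgentAction:  # stand-in for langchain's AgentAction, illustration only
    tool: str
    tool_input: Any
    log: str

decision = AgentAction(tool='dataset', tool_input={'query': 'query rewritten by the LLM'}, log='')
user_input = 'what the user actually asked'

if isinstance(decision, AgentAction):
    tool_inputs = decision.tool_input
    if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
        tool_inputs['query'] = user_input
        decision.tool_input = tool_inputs

print(decision.tool_input)  # {'query': 'what the user actually asked'}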


@@ -97,6 +97,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
@classmethod


@@ -14,7 +14,7 @@ from core.model_providers.models.llm.base import BaseLLM
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
class Config:


@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
class StructuredChatOutputParser(LCStructuredChatOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
try:
action_match = re.search(r"```(.*?)\n?(.*?)```", text, re.DOTALL)
action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
if action_match is not None:
response = json.loads(action_match.group(2).strip(), strict=False)
if isinstance(response, list):
@@ -26,4 +26,4 @@ class StructuredChatOutputParser(LCStructuredChatOutputParser):
else:
return AgentFinish({"output": text}, text)
except Exception as e:
raise OutputParserException(f"Could not parse LLM output: {text}") from e
raise OutputParserException(f"Could not parse LLM output: {text}")
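The tightened pattern requires the second capture group to open with '{', so fenced blocks that contain no JSON no longer match; a minimal check of the new regex (the sample text is made up):

import json
import re

pattern = re.compile(r"```(\w*)\n?({.*?)```", re.DOTALL)

text = 'Thought: call a tool\n```json\n{"action": "dataset", "action_input": {"query": "q"}}\n```'
match = pattern.search(text)
print(json.loads(match.group(2).strip(), strict=False))
# {'action': 'dataset', 'action_input': {'query': 'q'}}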


@@ -102,7 +102,13 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
raise new_exception
try:
return self.output_parser.parse(full_output)
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")


@@ -52,7 +52,7 @@ Action:
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
class Config:
@@ -106,7 +106,13 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
raise new_exception
try:
return self.output_parser.parse(full_output)
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")


@@ -32,7 +32,7 @@ class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_instance: BaseLLM
tools: list[BaseTool]
summary_model_instance: BaseLLM
summary_model_instance: BaseLLM = None
memory: Optional[BaseChatMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6


@@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
@@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.status = 'llm_end'
if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
else:
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
[PromptMessage(content=self._current_loop.prompt)]
)
completion_generation = response.generations[0][0]
if isinstance(completion_generation, ChatGeneration):
completion_message = completion_generation.message
@@ -81,6 +86,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
[PromptMessage(content=self._current_loop.completion)]
)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any


@@ -126,17 +126,16 @@ class Completion:
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy != PlanningStrategy.ROUTER:
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
# get llm prompt
prompt_messages, stop_words = cls.get_main_llm_prompt(
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
agent_execute_result=agent_execute_result,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
@@ -154,113 +153,6 @@ class Completion:
return response
@classmethod
def get_main_llm_prompt(cls, mode: str, model: dict,
pre_prompt: str, query: str, inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if agent_execute_result else "")
+ (pre_prompt + "\n" if pre_prompt else "")
+ "{{query}}\n"
)
if agent_execute_result:
inputs['context'] = agent_execute_result.output
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_content = prompt_template.format(
query=query,
**prompt_inputs
)
return [PromptMessage(content=prompt_content)], None
else:
messages: List[BaseMessage] = []
human_inputs = {
"query": query
}
human_message_prompt = ""
if pre_prompt:
pre_prompt_inputs = {k: inputs[k] for k in
JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
if k in inputs}
if pre_prompt_inputs:
human_inputs.update(pre_prompt_inputs)
if agent_execute_result:
human_inputs['context'] = agent_execute_result.output
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
"""
if pre_prompt:
human_message_prompt += pre_prompt
query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
if memory:
# append chat histories
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=human_message_prompt + query_prompt,
inputs=human_inputs
)
if memory.model_instance.model_rules.max_tokens.max:
curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += "Here is the chat histories between human and assistant, " \
"inside <histories></histories> XML tags.\n\n<histories>\n"
human_message_prompt += histories + "\n</histories>"
human_message_prompt += query_prompt
# construct main prompt
human_message = PromptBuilder.to_human_message(
prompt_content=human_message_prompt,
inputs=human_inputs
)
messages.append(human_message)
for message in messages:
message.content = re.sub(r'<\|.*?\|>', '', message.content)
return to_prompt_messages(messages), ['\nHuman:', '</histories>']
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
max_token_limit: int) -> str:
@@ -307,13 +199,12 @@ And answer according to the language of the user's question.
max_tokens = 0
# get prompt without memory and context
prompt_messages, _ = cls.get_main_llm_prompt(
prompt_messages, _ = model_instance.get_prompt(
mode=mode,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
agent_execute_result=None,
query=query,
context=None,
memory=None
)
@@ -358,13 +249,12 @@ And answer according to the language of the user's question.
)
# get llm prompt
old_prompt_messages, _ = cls.get_main_llm_prompt(
mode="completion",
model=app_model_config.model_dict,
old_prompt_messages, _ = final_model_instance.get_prompt(
mode='completion',
pre_prompt=pre_prompt,
query=message.query,
inputs=message.inputs,
agent_execute_result=None,
query=message.query,
context=None,
memory=None
)


@@ -119,9 +119,11 @@ class ConversationMessageTask:
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency=self.model_instance.get_currency(),
@@ -135,22 +137,30 @@ class ConversationMessageTask:
db.session.flush()
def append_message_text(self, text: str):
self._pub_handler.pub_text(text)
if text is not None:
self._pub_handler.pub_text(text)
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price
self.message.message = llm_message.prompt
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
self.message.answer_price_unit = answer_price_unit
self.message.provider_response_latency = llm_message.latency
self.message.total_price = total_price
@@ -192,7 +202,9 @@ class ConversationMessageTask:
tool=agent_loop.tool_name,
tool_input=agent_loop.tool_input,
message=agent_loop.prompt,
message_price_unit=0,
answer=agent_loop.completion,
answer_price_unit=0,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
@@ -206,25 +218,26 @@ class ConversationMessageTask:
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_total_price = self.calc_total_price(
loop_message_tokens,
agent_message_unit_price,
loop_answer_tokens,
agent_answer_unit_price
)
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
message_agent_thought.observation = agent_loop.tool_output
message_agent_thought.tool_process_data = '' # currently not support
message_agent_thought.message_token = loop_message_tokens
message_agent_thought.message_unit_price = agent_message_unit_price
message_agent_thought.message_price_unit = agent_message_price_unit
message_agent_thought.answer_token = loop_answer_tokens
message_agent_thought.answer_unit_price = agent_answer_unit_price
message_agent_thought.answer_price_unit = agent_answer_price_unit
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
@@ -243,15 +256,6 @@ class ConversationMessageTask:
db.session.add(dataset_query)
def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def end(self):
self._pub_handler.pub_end()
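The removed calc_total_price helper priced per 1K tokens directly; the replacement spreads the same arithmetic across calc_tokens_price, get_tokens_unit_price and get_price_unit. A minimal sketch of the new shape, assuming the unit expresses the pricing granularity (e.g. 0.001 for per-1K pricing; the sample figures are invented):

import decimal

def calc_tokens_price(tokens: int, unit_price: str, unit: str) -> decimal.Decimal:
    # total = tokens * unit price * unit, as in the BaseLLM diff further below
    total = decimal.Decimal(tokens) * decimal.Decimal(unit_price) * decimal.Decimal(unit)
    return total.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

message_total = calc_tokens_price(1200, '0.0015', '0.001')  # prompt side
answer_total = calc_tokens_price(800, '0.002', '0.001')     # completion side
print(message_total + answer_total)  # 0.0034000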


@@ -38,7 +38,7 @@ class ExcelLoader(BaseLoader):
else:
row_dict = dict(zip(keys, list(map(str, row))))
row_dict = {k: v for k, v in row_dict.items() if v}
item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
item = ''.join(f'{k}:{v};' for k, v in row_dict.items())
document = Document(page_content=item, metadata={'source': self._file_path})
data.append(document)
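The change joins each spreadsheet row into a single 'key:value;' string instead of one pair per line; a quick before/after on made-up data:

keys = ['name', 'dept', 'note']
row = ('Alice', 'R&D', '')  # empty cells are dropped by the dict filter
row_dict = {k: v for k, v in dict(zip(keys, map(str, row))).items() if v}
print(''.join(f'{k}:{v}\n' for k, v in row_dict.items()))  # old: one pair per line
print(''.join(f'{k}:{v};' for k, v in row_dict.items()))   # new: name:Alice;dept:R&D;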


@@ -10,10 +10,10 @@ from models.dataset import Dataset, DocumentSegment
class DatesetDocumentStore:
def __init__(
self,
dataset: Dataset,
user_id: str,
document_id: Optional[str] = None,
self,
dataset: Dataset,
user_id: str,
document_id: Optional[str] = None,
):
self._dataset = dataset
self._user_id = user_id
@@ -59,7 +59,7 @@ class DatesetDocumentStore:
return output
def add_documents(
self, docs: Sequence[Document], allow_update: bool = True
self, docs: Sequence[Document], allow_update: bool = True
) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == self._document_id
@@ -69,7 +69,9 @@ class DatesetDocumentStore:
max_position = 0
embedding_model = ModelFactory.get_embedding_model(
tenant_id=self._dataset.tenant_id
tenant_id=self._dataset.tenant_id,
model_provider_name=self._dataset.embedding_model_provider,
model_name=self._dataset.embedding_model
)
for doc in docs:
@@ -101,6 +103,7 @@ class DatesetDocumentStore:
content=doc.page_content,
word_count=len(doc.page_content),
tokens=tokens,
enabled=False,
created_by=self._user_id,
)
if 'answer' in doc.metadata and doc.metadata['answer']:
@@ -123,7 +126,7 @@ class DatesetDocumentStore:
return result is not None
def get_document(
self, doc_id: str, raise_error: bool = True
self, doc_id: str, raise_error: bool = True
) -> Optional[Document]:
document_segment = self.get_document_segment(doc_id)


@@ -1,6 +1,7 @@
import logging
from typing import List
import numpy as np
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
@@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings):
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
i = 0
normalized_embedding_results = []
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
try:
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results[i])
vector = embedding_results[i]
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
normalized_embedding_results.append(normalized_embedding)
embedding.set_embedding(normalized_embedding)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
@@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings):
finally:
i += 1
text_embeddings.extend(embedding_results)
text_embeddings.extend(normalized_embedding_results)
return text_embeddings
def embed_query(self, text: str) -> List[float]:
@@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings):
try:
embedding_results = self._embeddings.client.embed_query(text)
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
@@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings):
return embedding_results
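The added numpy step scales every cached vector to unit L2 norm, so cosine similarity later reduces to a plain dot product; a minimal sketch of that normalization in isolation:

import numpy as np

def normalize(vector: list[float]) -> list[float]:
    v = np.asarray(vector, dtype=float)
    return (v / np.linalg.norm(v)).tolist()

print(normalize([3.0, 4.0]))                  # [0.6, 0.8]
print(np.linalg.norm(normalize([3.0, 4.0])))  # ~1.0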


@@ -51,6 +51,7 @@ class LLMGenerator:
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
max_context_token_length = max_context_token_length if max_context_token_length else 1500
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
context = ''
@@ -178,8 +179,8 @@ class LLMGenerator:
return rule_config
@classmethod
def generate_qa_document(cls, tenant_id: str, query):
prompt = GENERATOR_QA_PROMPT
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
prompt = GENERATOR_QA_PROMPT.format(language=document_language)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
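The first hunk above adds a fallback for models whose rules leave the max context length unset; a small sketch of the budget arithmetic (the 1500 default comes from the diff, the sample numbers do not):

def rest_tokens(prompt_tokens: int, max_tokens: int, max_context_token_length=None) -> int:
    max_context_token_length = max_context_token_length if max_context_token_length else 1500
    return max_context_token_length - prompt_tokens - max_tokens - 1

print(rest_tokens(prompt_tokens=200, max_tokens=256))  # 1043 tokens left for context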


@@ -15,7 +15,9 @@ class IndexBuilder:
return None
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)


@@ -173,3 +173,49 @@ class BaseVectorIndex(BaseIndex):
self.dataset = dataset
logging.info(f"Dataset {dataset.id} recreate successfully.")
def create_qdrant_dataset(self, dataset: Dataset):
logging.info(f"create_qdrant_dataset {dataset.id}")
try:
self.delete()
except UnexpectedStatusCodeException as e:
if e.status_code != 400:
# 400 means index not exists
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.create(documents)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")


@@ -0,0 +1,114 @@
from typing import Optional, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore, milvus
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class MilvusConfig(BaseModel):
endpoint: str
user: str
password: str
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config MILVUS_ENDPOINT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
class MilvusVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
def get_type(self) -> str:
return 'milvus'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
)
def _get_vector_store_class(self) -> type:
return MilvusVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False

File diff suppressed because it is too large


@@ -44,15 +44,20 @@ class QdrantVectorIndex(BaseVectorIndex):
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
return self.dataset.index_struct_dict['vector_store']['collection_name']
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Index_" + dataset_id.replace("-", "_")
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"collection_name": self.get_index_name(self.dataset)}
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
@@ -62,7 +67,7 @@ class QdrantVectorIndex(BaseVectorIndex):
self._embeddings,
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='text',
content_payload_key='page_content',
**self._client_config.to_qdrant_params()
)
@@ -72,7 +77,9 @@ class QdrantVectorIndex(BaseVectorIndex):
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
@@ -81,7 +88,7 @@ class QdrantVectorIndex(BaseVectorIndex):
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='text'
content_payload_key='page_content'
)
def _get_vector_store_class(self) -> type:
@@ -108,8 +115,8 @@ class QdrantVectorIndex(BaseVectorIndex):
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
if class_prefix.startswith('Vector_'):
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True


@@ -67,14 +67,6 @@ class IndexingRunner:
dataset_document=dataset_document,
processing_rule=processing_rule
)
# new_documents = []
# for document in documents:
# response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
# document_qa_list = self.format_split_text(response)
# for result in document_qa_list:
# document = Document(page_content=result['question'], metadata={'source': result['answer']})
# new_documents.append(document)
# build index
self._build_index(
dataset=dataset,
dataset_document=dataset_document,
@@ -225,14 +217,25 @@ class IndexingRunner:
db.session.commit()
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
doc_form: str = None) -> dict:
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
else:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
tokens = 0
preview_texts = []
total_segments = 0
@@ -263,20 +266,19 @@ class IndexingRunner:
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if doc_form and doc_form == 'qa_model':
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -284,18 +286,31 @@ class IndexingRunner:
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
"currency": embedding_model.get_currency(),
"preview": preview_texts
}
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
else:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
# load data from notion
tokens = 0
@@ -344,20 +359,19 @@ class IndexingRunner:
tokens += embedding_model.get_num_tokens(document.page_content)
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if doc_form and doc_form == 'qa_model':
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -365,7 +379,7 @@ class IndexingRunner:
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
"currency": embedding_model.get_currency(),
"preview": preview_texts
}
@@ -458,7 +472,8 @@ class IndexingRunner:
splitter=splitter,
processing_rule=processing_rule,
tenant_id=dataset.tenant_id,
document_form=dataset_document.doc_form
document_form=dataset_document.doc_form,
document_language=dataset_document.doc_language
)
# save node to document segment
@@ -494,7 +509,8 @@ class IndexingRunner:
return documents
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
processing_rule: DatasetProcessRule, tenant_id: str,
document_form: str, document_language: str) -> List[Document]:
"""
Split the text documents into nodes.
"""
@@ -509,12 +525,13 @@ class IndexingRunner:
documents = splitter.split_documents([text_doc])
split_documents = []
for document_node in documents:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
split_documents.append(document_node)
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
split_documents.append(document_node)
all_documents.extend(split_documents)
# processing qa document
if document_form == 'qa_model':
@@ -523,8 +540,9 @@ class IndexingRunner:
sub_documents = all_documents[i:i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
'flask_app': current_app._get_current_object(), 'tenant_id': tenant_id, 'document_node': doc,
'all_qa_documents': all_qa_documents})
'flask_app': current_app._get_current_object(),
'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
'document_language': document_language})
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
@@ -532,14 +550,14 @@ class IndexingRunner:
return all_qa_documents
return all_documents
def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents):
def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
return
with flask_app.app_context():
try:
# qa model document
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
@@ -641,7 +659,9 @@ class IndexingRunner:
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# chunk nodes by chunk size
@@ -672,6 +692,7 @@ class IndexingRunner:
DocumentSegment.status == "indexing"
).update({
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.utcnow()
})
@@ -722,6 +743,32 @@ class IndexingRunner:
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
"""
Batch add segments index processing
"""
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(documents, duplicate_check=True)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts(documents)
class DocumentIsPausedException(Exception):
pass
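Both create_qdrant_dataset above and the new batch_add_segments build index documents from completed segments the same way; a hedged sketch with a dataclass standing in for the ORM segment model and plain dicts for langchain Documents:

from dataclasses import dataclass

@dataclass
class SegmentStub:  # stand-in for models.dataset.DocumentSegment
    content: str
    index_node_id: str
    index_node_hash: str
    document_id: str
    dataset_id: str

def to_documents(segments: list[SegmentStub]) -> list[dict]:
    return [{'page_content': s.content,
             'metadata': {'doc_id': s.index_node_id,
                          'doc_hash': s.index_node_hash,
                          'document_id': s.document_id,
                          'dataset_id': s.dataset_id}}
            for s in segments]

print(to_documents([SegmentStub('hello', 'n1', 'h1', 'doc1', 'ds1')]))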

api/core/login/login.py (new file, 108 lines)

@@ -0,0 +1,108 @@
import os
from functools import wraps
import flask_login
from flask import current_app
from flask import g
from flask import has_request_context
from flask import request
from flask_login import user_logged_in
from flask_login.config import EXEMPT_METHODS
from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
current_user = LocalProxy(lambda: _get_user())
def login_required(func):
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are
not, it calls the :attr:`LoginManager.unauthorized` callback.) For
example::
@app.route('/post')
@login_required
def post():
pass
If there are only certain times you need to require that your user is
logged in, you can do so with::
if not current_user.is_authenticated:
return current_app.login_manager.unauthorized()
...which is essentially the code that this function adds to your views.
It can be convenient to globally turn off authentication when unit testing.
To enable this, if the application configuration variable `LOGIN_DISABLED`
is set to `True`, this decorator will be ignored.
.. Note ::
Per `W3 guidelines for CORS preflight requests
<http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
HTTP ``OPTIONS`` requests are exempt from login checks.
:param func: The view function to decorate.
:type func: function
"""
@wraps(func)
def decorated_view(*args, **kwargs):
auth_header = request.headers.get('Authorization')
admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False')
if admin_api_key_enable:
if auth_header:
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
admin_api_key = os.getenv('ADMIN_API_KEY')
if admin_api_key:
if os.getenv('ADMIN_API_KEY') == auth_token:
workspace_id = request.headers.get('X-WORKSPACE-ID')
if workspace_id:
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
.filter(Tenant.id == workspace_id) \
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
.filter(TenantAccountJoin.role == 'owner') \
.one_or_none()
if tenant_account_join:
tenant, ta = tenant_account_join
account = Account.query.filter_by(id=ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account)
user_logged_in.send(current_app._get_current_object(), user=_get_user())
if request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
pass
elif not current_user.is_authenticated:
return current_app.login_manager.unauthorized()
# flask 1.x compatibility
# current_app.ensure_sync is only available in Flask >= 2.0
if callable(getattr(current_app, "ensure_sync", None)):
return current_app.ensure_sync(func)(*args, **kwargs)
return func(*args, **kwargs)
return decorated_view
def _get_user():
if has_request_context():
if "_login_user" not in g:
current_app.login_manager._load_user()
return g._login_user
return None
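A hedged usage sketch of the new admin bypass: with ADMIN_API_KEY_ENABLE and ADMIN_API_KEY set on the server, a request can authenticate with the key as a Bearer token and select a workspace via the X-WORKSPACE-ID header (the base URL and console path prefix below are assumptions):

import requests

resp = requests.get(
    'http://localhost:5001/console/api/all-workspaces',  # hypothetical deployment URL
    headers={
        'Authorization': 'Bearer <ADMIN_API_KEY value>',
        'X-WORKSPACE-ID': '<tenant uuid>',
    },
)
print(resp.status_code)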


@@ -46,7 +46,8 @@ class ModelFactory:
model_name: Optional[str] = None,
model_kwargs: Optional[ModelKwargs] = None,
streaming: bool = False,
callbacks: Callbacks = None) -> Optional[BaseLLM]:
callbacks: Callbacks = None,
deduct_quota: bool = True) -> Optional[BaseLLM]:
"""
get text generation model.
@@ -56,6 +57,7 @@ class ModelFactory:
:param model_kwargs:
:param streaming:
:param callbacks:
:param deduct_quota:
:return:
"""
is_default_model = False
@@ -95,7 +97,7 @@ class ModelFactory:
else:
raise e
if is_default_model:
if is_default_model or not deduct_quota:
model_instance.deduct_quota = False
return model_instance


@@ -57,6 +57,12 @@ class ModelProviderFactory:
elif provider_name == 'huggingface_hub':
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
return HuggingfaceHubProvider
elif provider_name == 'xinference':
from core.model_providers.providers.xinference_provider import XinferenceProvider
return XinferenceProvider
elif provider_name == 'openllm':
from core.model_providers.providers.openllm_provider import OpenLLMProvider
return OpenLLMProvider
else:
raise NotImplementedError


@@ -26,10 +26,20 @@ class AzureOpenAIEmbedding(BaseEmbedding):
openai_api_version=AZURE_OPENAI_API_VERSION,
chunk_size=16,
max_retries=1,
**self.credentials
openai_api_key=self.credentials.get('openai_api_key'),
openai_api_base=self.credentials.get('openai_api_base')
)
super().__init__(model_provider, client, name)
@property
def base_model_name(self) -> str:
"""
get base model name (not deployment)
:return: str
"""
return self.credentials.get("base_model_name")
def get_num_tokens(self, text: str) -> int:
"""
@@ -48,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
# calculate the number of tokens in the encoded text
return len(tokenized_text)
def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to Azure OpenAI API.")
@@ -71,7 +71,7 @@ class AzureOpenAIEmbedding(BaseEmbedding):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError('Azure ' + str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError('Azure ' + str(ex))
return LLMAuthorizationError('Azure ' + str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
else:


@@ -1,5 +1,6 @@
from abc import abstractmethod
from typing import Any
import decimal
import tiktoken
from langchain.schema.language_model import _get_token_ids_default_method
@@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
import logging
logger = logging.getLogger(__name__)
class BaseEmbedding(BaseProviderModel):
name: str
@@ -17,6 +19,63 @@ class BaseEmbedding(BaseProviderModel):
super().__init__(model_provider, client)
self.name = name
@property
def base_model_name(self) -> str:
"""
get base model name
:return: str
"""
return self.name
@property
def price_config(self) -> dict:
def get_or_default():
default_price_config = {
'completion': decimal.Decimal('0'),
'unit': decimal.Decimal('0'),
'currency': 'USD'
}
rules = self.model_provider.get_rules()
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
price_config = {
'completion': decimal.Decimal(price_config['completion']),
'unit': decimal.Decimal(price_config['unit']),
'currency': price_config['currency']
}
return price_config
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
logger.debug(f"model: {self.name} price_config: {self._price_config}")
return self._price_config
def calc_tokens_price(self, tokens: int) -> decimal.Decimal:
"""
calc tokens total price.
:param tokens:
:return: decimal.Decimal('0.0000001')
"""
unit_price = self.price_config['completion']
unit = self.price_config['unit']
total_price = tokens * unit_price * unit
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
return total_price
def get_tokens_unit_price(self) -> decimal.Decimal:
"""
get token price.
:return: decimal.Decimal('0.0001')
"""
unit_price = self.price_config['completion']
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
logger.debug(f'unit_price:{unit_price}')
return unit_price
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
@@ -29,11 +88,14 @@ class BaseEmbedding(BaseProviderModel):
return len(_get_token_ids_default_method(text))
def get_token_price(self, tokens: int):
return 0
def get_currency(self):
return 'USD'
"""
get token currency.
:return: get from price config, default 'USD'
"""
currency = self.price_config['currency']
return currency
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
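price_config here lazily loads per-model pricing from the provider rules and falls back to zero-cost defaults; the lookup reduces to roughly this (the rules dict below is invented, and the diff indexes rules['price_config'][base_model_name] directly rather than using .get):

import decimal

def load_price_config(rules: dict, base_model_name: str) -> dict:
    default = {'completion': '0', 'unit': '0', 'currency': 'USD'}
    raw = rules['price_config'][base_model_name] if 'price_config' in rules else default
    return {'completion': decimal.Decimal(raw['completion']),
            'unit': decimal.Decimal(raw['unit']),
            'currency': raw['currency']}

rules = {'price_config': {'text-embedding-ada-002':
                          {'completion': '0.0001', 'unit': '0.001', 'currency': 'USD'}}}
print(load_price_config(rules, 'text-embedding-ada-002'))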


@@ -1,6 +1,3 @@
import decimal
import logging
from langchain.embeddings import MiniMaxEmbeddings
from core.model_providers.error import LLMBadRequestError
@@ -22,12 +19,6 @@ class MinimaxEmbedding(BaseEmbedding):
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"Minimax: {str(ex)}")


@@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
# calculate the number of tokens in the encoded text
return len(tokenized_text)
def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
@@ -65,7 +55,7 @@ class OpenAIEmbedding(BaseEmbedding):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError(str(ex))
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:


@@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
# replicate only pay for prediction seconds
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Replicate: {str(ex)}")


@@ -0,0 +1,27 @@
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
class XinferenceEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = XinferenceEmbeddings(
server_url=credentials['server_url'],
model_uid=credentials['model_uid'],
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
else:
return ex


@@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'claude-instant-1': {
'prompt': decimal.Decimal('1.63'),
'completion': decimal.Decimal('5.51'),
},
'claude-2': {
'prompt': decimal.Decimal('11.02'),
'completion': decimal.Decimal('32.68'),
},
}
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1m * unit_price
return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
@@ -101,7 +75,7 @@ class AnthropicModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True


@@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM):
self.model_mode = ModelMode.COMPLETION
else:
self.model_mode = ModelMode.CHAT
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
@@ -83,6 +82,15 @@ class AzureOpenAIModel(BaseLLM):
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
@property
def base_model_name(self) -> str:
"""
get base model name (not deployment)
:return: str
"""
return self.credentials.get("base_model_name")
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
@@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM):
else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'gpt-4': {
'prompt': decimal.Decimal('0.03'),
'completion': decimal.Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': decimal.Decimal('0.06'),
'completion': decimal.Decimal('0.12')
},
'gpt-35-turbo': {
'prompt': decimal.Decimal('0.0015'),
'completion': decimal.Decimal('0.002')
},
'gpt-35-turbo-16k': {
'prompt': decimal.Decimal('0.003'),
'completion': decimal.Decimal('0.004')
},
'text-davinci-003': {
'prompt': decimal.Decimal('0.02'),
'completion': decimal.Decimal('0.02')
},
}
base_model_name = self.credentials.get("base_model_name")
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[base_model_name]['prompt']
else:
unit_price = model_unit_prices[base_model_name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name == 'text-davinci-003':
@@ -166,12 +135,12 @@ class AzureOpenAIModel(BaseLLM):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError('Azure ' + str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError('Azure ' + str(ex))
return LLMAuthorizationError('Azure ' + str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True
@property
def support_streaming(self):
return True


@@ -1,15 +1,25 @@
import json
import os
import re
from abc import abstractmethod
from typing import List, Optional, Any, Union
from typing import List, Optional, Any, Union, Tuple
import decimal
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.third_party.langchain.llms.fake import FakeLLM
import logging
logger = logging.getLogger(__name__)
class BaseLLM(BaseProviderModel):
@@ -60,6 +70,40 @@ class BaseLLM(BaseProviderModel):
def _init_client(self) -> Any:
raise NotImplementedError
@property
def base_model_name(self) -> str:
"""
get llm base model name
:return: str
"""
return self.name
@property
def price_config(self) -> dict:
def get_or_default():
default_price_config = {
'prompt': decimal.Decimal('0'),
'completion': decimal.Decimal('0'),
'unit': decimal.Decimal('0'),
'currency': 'USD'
}
rules = self.model_provider.get_rules()
price_config = rules['price_config'][
self.base_model_name] if 'price_config' in rules else default_price_config
price_config = {
'prompt': decimal.Decimal(price_config['prompt']),
'completion': decimal.Decimal(price_config['completion']),
'unit': decimal.Decimal(price_config['unit']),
'currency': price_config['currency']
}
return price_config
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
logger.debug(f"model: {self.name} price_config: {self._price_config}")
return self._price_config
def run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
@@ -94,7 +138,7 @@ class BaseLLM(BaseProviderModel):
result = self._run(
messages=messages,
stop=stop,
callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
**kwargs
)
except Exception as ex:
@@ -105,7 +149,7 @@ class BaseLLM(BaseProviderModel):
else:
completion_content = result.generations[0][0].text
if self.streaming and not self.support_streaming():
if self.streaming and not self.support_streaming:
# use FakeLLM to simulate streaming when current model not support streaming but streaming is True
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM(
@@ -122,9 +166,12 @@ class BaseLLM(BaseProviderModel):
total_tokens = result.llm_output['token_usage']['total_tokens']
else:
prompt_tokens = self.get_num_tokens(messages)
completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
completion_tokens = self.get_num_tokens(
[PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
total_tokens = prompt_tokens + completion_tokens
self.model_provider.update_last_used()
if self.deduct_quota:
self.model_provider.deduct_quota(total_tokens)
@@ -159,25 +206,64 @@ class BaseLLM(BaseProviderModel):
"""
raise NotImplementedError
@abstractmethod
def get_token_price(self, tokens: int, message_type: MessageType):
def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
"""
get token price.
calc tokens total price.
:param tokens:
:param message_type:
:return:
"""
raise NotImplementedError
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
unit = self.get_price_unit(message_type)
@abstractmethod
def get_currency(self):
total_price = tokens * unit_price * unit
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
return total_price
def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
"""
get token price.
:param message_type:
:return: decimal.Decimal('0.0001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"unit_price={unit_price}")
return unit_price
def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
"""
get price unit.
:param message_type:
:return: decimal.Decimal('0.000001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
price_unit = self.price_config['unit']
else:
price_unit = self.price_config['unit']
price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"price_unit={price_unit}")
return price_unit
+    def get_currency(self) -> str:
        """
        get token currency.
-        :return:
+        :return: get from price config, default 'USD'
        """
-        raise NotImplementedError
+        currency = self.price_config['currency']
+        return currency
def get_model_kwargs(self):
return self.model_kwargs
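
A worked example (illustrative only) of how the pricing methods above combine; the figures mirror the gpt-3.5-turbo entry added to openai.json later in this diff (prompt "0.0015", unit "0.001", i.e. the listed price is per 1,000 tokens):

import decimal

# shape of a rules['price_config'][base_model_name] entry, after the
# Decimal conversion done by the price_config property above
price_config = {
    'prompt': decimal.Decimal('0.0015'),
    'completion': decimal.Decimal('0.002'),
    'unit': decimal.Decimal('0.001'),
    'currency': 'USD',
}

# calc_tokens_price: tokens * unit_price * unit, quantized to 7 places
prompt_tokens = 1200
total = prompt_tokens * price_config['prompt'] * price_config['unit']
total = total.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total)  # 0.0018000 -> USD 0.0018 for 1,200 prompt tokens
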
@@ -212,10 +298,123 @@ class BaseLLM(BaseProviderModel):
else:
self.client.callbacks.extend(callbacks)
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
return False
def get_prompt(self, mode: str,
pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'common_completion'
else:
return 'common_chat'
def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
context=context
)
pre_prompt_content = ''
if pre_prompt:
prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
pre_prompt_content = prompt_template.format(
**prompt_inputs
)
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''
query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
if memory and 'histories_prompt' in prompt_rules:
# append chat histories
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=prompt + query_prompt,
inputs={
'query': query
}
)
if self.model_rules.max_tokens.max:
curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = self.model_kwargs.max_tokens
rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format(
histories=histories
)
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
elif order == 'histories_prompt':
prompt += histories_prompt_content
prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
query_prompt_content = prompt_template.format(
query=query
)
prompt += query_prompt_content
prompt = re.sub(r'<\|.*?\|>', '', prompt)
stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None
return prompt, stops
def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
# Get the absolute path of the subdirectory
prompt_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
'prompt/generate_prompts')
json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
# Open the JSON file and read its content
with open(json_file_path, 'r') as json_file:
return json.load(json_file)
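
For orientation, a hypothetical rules file in the shape _read_prompt_rules_from_file and _get_prompt_and_stop consume, shown as a dict. The keys are inferred from the reads above; the actual contents of prompt/generate_prompts/common_chat.json may differ:

prompt_rules = {
    "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    "context_prompt": "Use the following context:\n{{context}}\n",
    "histories_prompt": "Conversation history:\n{{histories}}\n",
    "query_prompt": "Human: {{query}}\nAssistant: ",
    "human_prefix": "Human",
    "assistant_prefix": "Assistant",
    "stops": ["\nHuman:"]
}
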
def _get_history_messages_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> str:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
return external_context[memory_key]
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
if not model_mode:

View File

@@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')

-    def get_currency(self):
-        return 'RMB'
@@ -64,7 +61,3 @@ class ChatGLMModel(BaseLLM):
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
else:
return ex
-    @classmethod
-    def support_streaming(cls):
-        return False

View File

@@ -1,16 +1,14 @@
-import decimal
-from functools import wraps
from typing import List, Optional, Any
from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
-from langchain.llms import HuggingFaceEndpoint
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
class HuggingfaceHubModel(BaseLLM):
@@ -19,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
-            client = HuggingFaceEndpoint(
+            streaming = self.streaming
+            if 'baichuan' in self.name.lower():
+                streaming = False

+            client = HuggingFaceEndpointLLM(
                endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
-                task='text2text-generation',
+                task=self.credentials['task_type'],
                model_kwargs=provider_model_kwargs,
                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
+                callbacks=self.callbacks,
+                streaming=streaming
)
else:
client = HuggingFaceHub(
@@ -62,12 +66,14 @@ class HuggingfaceHubModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return self._client.get_num_tokens(prompts)
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # not support calc price
-        return decimal.Decimal('0')

-    def get_currency(self):
-        return 'USD'
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -76,7 +82,10 @@ class HuggingfaceHubModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
-    @classmethod
-    def support_streaming(cls):
-        return False
+    @property
+    def support_streaming(self):
+        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            if 'baichuan' in self.name.lower():
+                return False

+            return True
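
The classmethod-to-property migration above means call sites drop the parentheses (self.support_streaming() becomes self.support_streaming) and overrides can consult instance state, as the baichuan gate does. A minimal sketch of the pattern, using a hypothetical subclass:

class Base:
    @property
    def support_streaming(self) -> bool:
        return False

class EndpointModel(Base):  # hypothetical subclass mirroring the baichuan gate
    def __init__(self, name: str):
        self.name = name

    @property
    def support_streaming(self) -> bool:
        # streaming is available unless the deployed model is a baichuan variant
        return 'baichuan' not in self.name.lower()

assert EndpointModel('llama-2-7b').support_streaming is True
assert EndpointModel('Baichuan-13B-Chat').support_streaming is False
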

View File

@@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')

-    def get_currency(self):
-        return 'RMB'

View File

@@ -46,7 +46,8 @@ class OpenAIModel(BaseLLM):
self.model_mode = ModelMode.COMPLETION
else:
self.model_mode = ModelMode.CHAT
# TODO load price config from configs(db)
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
@@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM):
else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-3.5-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-3.5-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }

-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']

-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

-    def get_currency(self):
-        return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name in COMPLETION_MODELS:
@@ -185,14 +148,14 @@ class OpenAIModel(BaseLLM):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
-            raise LLMAuthorizationError(str(ex))
+            return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
return True
# def is_model_valid_or_raise(self):

View File

@@ -0,0 +1,65 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.openllm import OpenLLM
class OpenLLMModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
client = OpenLLM(
server_url=self.credentials.get('server_url'),
callbacks=self.callbacks,
llm_kwargs=self.provider_model_kwargs
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"OpenLLM: {str(ex)}")

View File

@@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM):
return self._client.get_num_tokens(prompts)
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')

-    def get_currency(self):
-        return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.input = provider_model_kwargs
@@ -98,6 +91,6 @@ class ReplicateModel(BaseLLM):
else:
return ex
-    @classmethod
-    def support_streaming(cls):
-        return True
+    @property
+    def support_streaming(self):
+        return True

View File

@@ -1,5 +1,4 @@
-import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
@@ -19,6 +18,7 @@ class SparkModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatSpark(
model_name=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
@@ -50,9 +50,6 @@ class SparkModel(BaseLLM):
contents = [message.content for message in messages]
return max(self._client.get_num_tokens("".join(contents)), 0)
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')

-    def get_currency(self):
-        return 'RMB'
@@ -68,6 +65,6 @@ class SparkModel(BaseLLM):
else:
return ex
-    @classmethod
-    def support_streaming(cls):
-        return True
+    @property
+    def support_streaming(self):
+        return True

View File

@@ -53,9 +53,6 @@ class TongyiModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')

-    def get_currency(self):
-        return 'RMB'
@@ -72,6 +69,6 @@ class TongyiModel(BaseLLM):
else:
return ex
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
return True

View File

@@ -16,6 +16,7 @@ class WenxinModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin(
streaming=self.streaming,
callbacks=self.callbacks,
@@ -48,36 +49,6 @@ class WenxinModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'ernie-bot': {
-                'prompt': decimal.Decimal('0.012'),
-                'completion': decimal.Decimal('0.012'),
-            },
-            'ernie-bot-turbo': {
-                'prompt': decimal.Decimal('0.008'),
-                'completion': decimal.Decimal('0.008')
-            },
-            'bloomz-7b': {
-                'prompt': decimal.Decimal('0.006'),
-                'completion': decimal.Decimal('0.006')
-            }
-        }

-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']

-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

-    def get_currency(self):
-        return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
@@ -86,7 +57,3 @@ class WenxinModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")
-    @classmethod
-    def support_streaming(cls):
-        return False

View File

@@ -0,0 +1,79 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
class XinferenceModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
client = XinferenceLLM(
server_url=self.credentials['server_url'],
model_uid=self.credentials['model_uid'],
)
client.callbacks = self.callbacks
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate(
[prompts],
stop,
callbacks,
generate_config={
"stop": stop,
"stream": self.streaming,
**self.provider_model_kwargs,
}
)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Xinference: {str(ex)}")
@property
def support_streaming(self):
return True

View File

@@ -41,7 +41,7 @@ class OpenAIModeration(BaseProviderModel):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
-            raise LLMAuthorizationError(str(ex))
+            return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:

View File

@@ -40,7 +40,7 @@ class OpenAIWhisper(BaseSpeech2Text):
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
-            raise LLMAuthorizationError(str(ex))
+            return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:

View File

@@ -183,6 +183,8 @@ class AnthropicProvider(BaseModelProvider):
return {
'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
}
return None

View File

@@ -31,7 +31,9 @@ class HostedAnthropic(BaseModel):
"""Quota limit for the anthropic hosted model. 0 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
-    paid_increase_quota: int = 1
+    paid_increase_quota: int = 1000000
paid_min_quantity: int = 20
paid_max_quantity: int = 100
class HostedModelProviders(BaseModel):
@@ -73,4 +75,6 @@ def init_app(app: Flask):
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
)

View File

@@ -2,7 +2,6 @@ import json
from typing import Type
from huggingface_hub import HfApi
-from langchain.llms import HuggingFaceEndpoint
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
@@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from models.provider import ProviderType
@@ -51,7 +51,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=1500, default=200),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
)
@classmethod
@@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
if 'huggingfacehub_endpoint_url' not in credentials:
raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Task Type must be provided.')
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
try:
-            llm = HuggingFaceEndpoint(
+            llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task="text2text-generation",
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)
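
A hypothetical credentials dict that would pass the stricter validation above (task_type is now required and restricted to three values); the URL and token are placeholders, not real endpoints:

credentials = {
    'huggingfacehub_api_type': 'inference_endpoints',
    'huggingfacehub_api_token': 'hf_xxx',                              # placeholder token
    'huggingfacehub_endpoint_url': 'https://my-endpoint.example.com',  # placeholder URL
    'task_type': 'text-generation',  # one of: text2text-generation, text-generation, summarization
}
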
@@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
}
credentials = json.loads(provider_model.encrypted_config)
if 'task_type' not in credentials:
credentials['task_type'] = 'text-generation'
if credentials['huggingfacehub_api_token']:
credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
self.provider.tenant_id,

View File

@@ -0,0 +1,138 @@
import json
from typing import Type
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.openllm import OpenLLM
from models.provider import ProviderType
class OpenLLMProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'openllm'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = OpenLLMModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('OpenLLM Server URL must be provided.')
try:
credential_kwargs = {
'server_url': credentials['server_url']
}
llm = OpenLLM(
llm_kwargs={
'max_new_tokens': 10
},
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'server_url': None
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['server_url']:
credentials['server_url'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['server_url']
)
if obfuscated:
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}

View File

@@ -116,7 +116,8 @@ class ReplicateProvider(BaseModelProvider):
and 'Embedding' not in rst.openapi_schema['components']['schemas']:
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
elif model_type == ModelType.TEXT_GENERATION \
-                    and ('type' not in rst.openapi_schema['components']['schemas']['Output']['items']
+                    and ('items' not in rst.openapi_schema['components']['schemas']['Output']
+                         or 'type' not in rst.openapi_schema['components']['schemas']['Output']['items']
or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
except ReplicateError as e:

View File

@@ -29,7 +29,11 @@ class SparkProvider(BaseModelProvider):
return [
{
'id': 'spark',
-                    'name': '星火认知大模型',
+                    'name': 'Spark V1.5',
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
}
]
else:

View File

@@ -0,0 +1,192 @@
import json
from typing import Type
import requests
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
from models.provider import ProviderType
class XinferenceProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'xinference'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = XinferenceModel
elif model_type == ModelType.EMBEDDINGS:
model_class = XinferenceEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
credentials = self.get_model_credentials(model_name, model_type)
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
)
elif credentials['model_format'] == "ggmlv3":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
)
else:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('Xinference Server URL must be provided.')
if 'model_uid' not in credentials:
raise CredentialsValidateFailedError('Xinference Model UID must be provided.')
try:
credential_kwargs = {
'server_url': credentials['server_url'],
'model_uid': credentials['model_uid'],
}
llm = XinferenceLLM(
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'server_url': None,
'model_uid': None,
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['server_url']:
credentials['server_url'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['server_url']
)
if obfuscated:
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
return credentials
@classmethod
def _get_extra_credentials(self, credentials: dict) -> dict:
url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
response = requests.get(url)
if response.status_code != 200:
raise RuntimeError(
f"Failed to get the model description, detail: {response.json()['detail']}"
)
desc = response.json()
extra_credentials = {
'model_format': desc['model_format'],
}
if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
extra_credentials['model_handle_type'] = 'chatglm'
elif "generate" in desc["model_ability"]:
extra_credentials['model_handle_type'] = 'generate'
elif "chat" in desc["model_ability"]:
extra_credentials['model_handle_type'] = 'chat'
else:
raise NotImplementedError(f"Model handle type not supported.")
return extra_credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}
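
To illustrate _get_extra_credentials above: a hypothetical response from GET {server_url}/v1/models/{model_uid} (only the fields the code reads are shown) and what would be derived from it:

desc = {  # hypothetical payload; field names taken from the reads above
    'model_format': 'ggmlv3',
    'model_name': 'chatglm2-ggml',
    'model_ability': ['embed', 'generate', 'chat'],
}
# ggmlv3 plus a chatglm model name matches before the ability checks, so:
extra_credentials = {'model_format': 'ggmlv3', 'model_handle_type': 'chatglm'}
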

View File

@@ -8,5 +8,7 @@
"wenxin",
"chatglm",
"replicate",
"huggingface_hub"
"huggingface_hub",
"xinference",
"openllm"
]

View File

@@ -5,10 +5,25 @@
],
"system_config": {
"supported_quota_types": [
"paid",
"trial"
],
"quota_unit": "times",
"quota_limit": 1000
"quota_unit": "tokens",
"quota_limit": 600000
},
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"claude-instant-1": {
"prompt": "1.63",
"completion": "5.51",
"unit": "0.000001",
"currency": "USD"
},
"claude-2": {
"prompt": "11.02",
"completion": "32.68",
"unit": "0.000001",
"currency": "USD"
}
}
}
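
A note on the unit field: with unit "0.000001" the listed prices are effectively per million tokens. A quick check (illustrative) for claude-instant-1 prompt tokens, using the calc_tokens_price arithmetic from base.py:

import decimal
tokens = 10_000
cost = tokens * decimal.Decimal('1.63') * decimal.Decimal('0.000001')
print(cost)  # 0.01630000, i.e. about USD 0.0163 for 10k prompt tokens
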

View File

@@ -3,5 +3,48 @@
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
"model_flexibility": "configurable",
"price_config":{
"gpt-4": {
"prompt": "0.03",
"completion": "0.06",
"unit": "0.001",
"currency": "USD"
},
"gpt-4-32k": {
"prompt": "0.06",
"completion": "0.12",
"unit": "0.001",
"currency": "USD"
},
"gpt-35-turbo": {
"prompt": "0.002",
"completion": "0.0015",
"unit": "0.001",
"currency": "USD"
},
"gpt-35-turbo-16k": {
"prompt": "0.003",
"completion": "0.004",
"unit": "0.001",
"currency": "USD"
},
"text-davinci-002": {
"prompt": "0.02",
"completion": "0.02",
"unit": "0.001",
"currency": "USD"
},
"text-davinci-003": {
"prompt": "0.02",
"completion": "0.02",
"unit": "0.001",
"currency": "USD"
},
"text-embedding-ada-002":{
"completion": "0.0001",
"unit": "0.001",
"currency": "USD"
}
}
}

View File

@@ -9,5 +9,24 @@
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"abab5.5-chat": {
"prompt": "0.015",
"completion": "0.015",
"unit": "0.001",
"currency": "RMB"
},
"abab5-chat": {
"prompt": "0.015",
"completion": "0.015",
"unit": "0.001",
"currency": "RMB"
},
"embo-01": {
"completion": "0",
"unit": "0.0001",
"currency": "RMB"
}
}
}

View File

@@ -10,5 +10,42 @@
"quota_unit": "times",
"quota_limit": 200
},
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"gpt-4": {
"prompt": "0.03",
"completion": "0.06",
"unit": "0.001",
"currency": "USD"
},
"gpt-4-32k": {
"prompt": "0.06",
"completion": "0.12",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo": {
"prompt": "0.0015",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-16k": {
"prompt": "0.003",
"completion": "0.004",
"unit": "0.001",
"currency": "USD"
},
"text-davinci-003": {
"prompt": "0.02",
"completion": "0.02",
"unit": "0.001",
"currency": "USD"
},
"text-embedding-ada-002":{
"completion": "0.0001",
"unit": "0.001",
"currency": "USD"
}
}
}
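
These figures match the per-model table deleted from OpenAIModel.get_token_price earlier in this diff; with unit "0.001" the new calc_tokens_price reproduces the old per-1K math, since an integer token count divided by 1000 never needs the 3-place rounding. A quick sanity check (illustrative):

import decimal
tokens, unit_price = 1234, decimal.Decimal('0.002')  # e.g. gpt-3.5-turbo completion
old = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                rounding=decimal.ROUND_HALF_UP) * unit_price
new = tokens * unit_price * decimal.Decimal('0.001')
assert old == new == decimal.Decimal('0.002468')
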

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@@ -9,5 +9,19 @@
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"spark": {
"prompt": "0.18",
"completion": "0.18",
"unit": "0.0001",
"currency": "RMB"
},
"spark-v2": {
"prompt": "0.36",
"completion": "0.36",
"unit": "0.0001",
"currency": "RMB"
}
}
}

View File

@@ -3,5 +3,25 @@
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"ernie-bot": {
"prompt": "0.012",
"completion": "0.012",
"unit": "0.001",
"currency": "RMB"
},
"ernie-bot-turbo": {
"prompt": "0.008",
"completion": "0.008",
"unit": "0.001",
"currency": "RMB"
},
"bloomz-7b": {
"prompt": "0.006",
"completion": "0.006",
"unit": "0.001",
"currency": "RMB"
}
}
}

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@@ -17,12 +17,13 @@ from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from libs import helper
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
@@ -82,15 +83,19 @@ class OrchestratorRuleParser:
try:
summary_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_provider_name=agent_provider_name,
model_name=agent_model_name,
model_kwargs=ModelKwargs(
temperature=0,
max_tokens=500
-                )
+                ),
+                deduct_quota=False
)
except ProviderTokenNotInitError as e:
summary_model_instance = None
tools = self.to_tools(
agent_model_instance=agent_model_instance,
tool_configs=tool_configs,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
@@ -140,11 +145,12 @@ class OrchestratorRuleParser:
return None
-    def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
+    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
"""
Convert app agent tool configs to tools
:param agent_model_instance:
:param rest_tokens:
:param tool_configs: app agent tool configs
:param conversation_message_task:
@@ -162,7 +168,7 @@ class OrchestratorRuleParser:
if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
elif tool_type == "web_reader":
-                tool = self.to_web_reader_tool()
+                tool = self.to_web_reader_tool(agent_model_instance)
elif tool_type == "google_search":
tool = self.to_google_search_tool()
elif tool_type == "wikipedia":
@@ -207,24 +213,28 @@ class OrchestratorRuleParser:
return tool
-    def to_web_reader_tool(self) -> Optional[BaseTool]:
+    def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
"""
A tool for reading web pages
:return:
"""
-        summary_model_instance = ModelFactory.get_text_generation_model(
-            tenant_id=self.tenant_id,
-            model_kwargs=ModelKwargs(
-                temperature=0,
-                max_tokens=500
-            )
-        )
-        summary_llm = summary_model_instance.client
+        try:
+            summary_model_instance = ModelFactory.get_text_generation_model(
+                tenant_id=self.tenant_id,
+                model_provider_name=agent_model_instance.model_provider.provider_name,
+                model_name=agent_model_instance.name,
+                model_kwargs=ModelKwargs(
+                    temperature=0,
+                    max_tokens=500
+                ),
+                deduct_quota=False
+            )
+        except ProviderTokenNotInitError:
+            summary_model_instance = None

        tool = WebReaderTool(
-            llm=summary_llm,
+            llm=summary_model_instance.client if summary_model_instance else None,
max_chunk_length=4000,
continue_reading=True,
callbacks=[DifyStdOutCallbackHandler()]
@@ -252,11 +262,7 @@ class OrchestratorRuleParser:
return tool
def to_current_datetime_tool(self) -> Optional[BaseTool]:
-        tool = Tool(
-            name="current_datetime",
-            description="A tool when you want to get the current date, time, week, month or year, "
-                        "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".",
-            func=helper.get_current_datetime,
+        tool = DatetimeTool(
callbacks=[DifyStdOutCallbackHandler()]
)

Some files were not shown because too many files have changed in this diff.