Compare commits

...

68 Commits

Author SHA1 Message Date
StyleZhang
2b17fe2f52 remove console 2023-09-07 16:43:32 +08:00
StyleZhang
078614f80c fix: onMessageEnd 2023-09-07 16:32:46 +08:00
StyleZhang
fd492f8fec replace onCitatin to onMessageEnd 2023-09-07 14:06:37 +08:00
StyleZhang
8ac9d98fd8 fix 2023-09-07 13:21:23 +08:00
StyleZhang
12449524f1 Merge branch 'main' into feat/chat-add-origin 2023-09-07 09:34:59 +08:00
StyleZhang
5fe2c41259 feat: chat add citation 2023-09-07 09:34:26 +08:00
Joel
0708bd60ee fix: try to fix chunk load error (#1109) 2023-09-06 15:47:53 +08:00
Joel
23a6c85b80 chore: handle workspace apps scrollbar (#1101) 2023-09-05 15:56:21 +08:00
bowen
4a28599fbd fix: optimize feedback and app icon (#1099) 2023-09-05 09:13:59 +08:00
seewhy
7c66d3c793 feat: Optimize the description for Azure deployment name (#1091) 2023-09-04 14:26:22 +08:00
Joel
cc9edfffd8 fix: markdown code lang capitalization and line number color (#1098) 2023-09-04 11:31:25 +08:00
Joel
6fa2454c9a fix: change frontend start script (#1096) 2023-09-04 11:10:32 +08:00
crazywoola
487e699021 fix: ui in chat openning statement (#1094) 2023-09-04 10:26:46 +08:00
takatost
a7cdb745c1 feat: support spark v2 validate (#1086) 2023-09-01 20:53:32 +08:00
takatost
73c86ee6a0 fix: prompt of title generation (#1084) 2023-09-01 14:55:58 +08:00
takatost
48eb590065 feat: optimize last_active_at update (#1083) 2023-09-01 13:58:26 +08:00
takatost
33562a9d8d feat: optimize prompt (#1080) 2023-09-01 11:46:06 +08:00
Rhon Joe
c9194ba382 chore(api): api image multistage build (#1069) 2023-09-01 11:13:22 +08:00
takatost
a199fa6388 feat: optimize high load sql query of document segment (#1078) 2023-09-01 10:52:39 +08:00
takatost
4c8608dc61 feat: optimize conversation title generation output must be a valid JSON (#1077) 2023-09-01 10:31:42 +08:00
Garfield Dai
a6b0f788e7 feat: add visual studio code debug config. (#1068)
Co-authored-by: Keruberosu <631677014@qq.com>
2023-09-01 09:15:06 +08:00
takatost
df6604a734 feat: optimize generation of conversation title (#1075) 2023-09-01 02:28:37 +08:00
takatost
1ca86cf9ce feat: bump version to 0.3.19 (#1074) 2023-08-31 21:42:58 +08:00
takatost
78e26f8b75 fix: summary no docs (#1073) 2023-08-31 20:19:26 +08:00
takatost
2191312bb9 fix: segments query missing idx hit (#1072) 2023-08-31 19:39:44 +08:00
takatost
fcc6b41ab7 feat: decrease claude model request time by set max top_k to 10 (#1071) 2023-08-31 18:23:44 +08:00
Joel
9458b8978f feat: siderbar operation support portal (#1061) 2023-08-31 17:46:51 +08:00
takatost
d75e8aeafa feat: disable anthropic retry (#1067) 2023-08-31 16:44:46 +08:00
takatost
2eba98a465 feat: optimize anthropic connection pool (#1066) 2023-08-31 16:18:59 +08:00
takatost
a7a7aab7a0 fix: csv import error (#1063) 2023-08-31 15:42:28 +08:00
crazywoola
86bfbb47d5 chore: doc issue (#1062) 2023-08-31 14:54:16 +08:00
yezhwi
d33a269548 refactor(file extractor): file extractor (#1059) 2023-08-31 14:45:31 +08:00
Matri
d3f8ea2df0 Feat/support to invite multiple users (#1011) 2023-08-31 01:18:31 +08:00
Jyong
7df56ed617 fix error weaviate vector (#1058)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-30 20:34:17 +08:00
Joel
e34dcc0406 feat: code support copy (#1057) 2023-08-30 18:08:47 +08:00
Joel
a834ba8759 feat: support rename conversation (#1056) 2023-08-30 17:32:32 +08:00
KVOJJJin
c67f345d0e Fix: disable operations of dataset when embedding unavailable (#1055)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-30 17:27:19 +08:00
yezhwi
8b8e510bfe fix: handle AttributeError for datasets and index (#1052) 2023-08-30 11:14:16 +08:00
crazywoola
3db839a5cb 773 change embed title welcome to use (#1053) 2023-08-30 11:03:25 +08:00
takatost
417c19577a feat: add LocalAI local embedding model support (#1021)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-08-29 22:22:02 +08:00
Jyong
b5953039de recreate qdrant vector (#1049)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-29 15:00:36 +08:00
Jyong
a43e80dd9c add qdrant migration (#1046)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-29 10:37:04 +08:00
WangBooth
ad5f27bc5f fix openpyxl dimensions error (#1041) 2023-08-29 10:36:48 +08:00
Joel
05e0985f29 chore: match new dataset tool format (#1044) 2023-08-29 09:07:45 +08:00
takatost
7b3314c5db fix: dataset desc (#1045) 2023-08-29 09:07:27 +08:00
Jyong
a55ba6e614 Fix/ignore economy dataset (#1043)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-29 03:37:45 +08:00
bowen
f9bec1edf8 chore: perfect type definition (#1003) 2023-08-28 19:48:53 +08:00
Jyong
16199e968e fix notion import limit check (#1042)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-28 16:49:03 +08:00
takatost
02452421d5 fix: pub generate message text return null (#1037) 2023-08-28 16:43:54 +08:00
zxhlyh
3a5c7c75ad Fix/model selector (#1032) 2023-08-28 10:54:41 +08:00
zxhlyh
a7415ecfd8 Fix/upload document limit (#1033) 2023-08-28 10:53:45 +08:00
KVOJJJin
934def5fcc Fix: eslint (#1030) 2023-08-27 17:06:16 +08:00
takatost
0796791de5 feat: hf inference endpoint stream support (#1028) 2023-08-26 19:48:34 +08:00
takatost
6c148b223d fix: dataset query truncated (#1026) 2023-08-26 17:35:17 +08:00
zxhlyh
4b168f4838 fix: maintenance notice (#1025) 2023-08-26 16:09:55 +08:00
takatost
1c114eaef3 feat: update contributing (#1020) 2023-08-25 21:19:13 +08:00
Jyong
e053215155 fix document estimate parameter (#1019)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 20:10:08 +08:00
zxhlyh
13482b0fc1 feat: maintenance notice (#1016) 2023-08-25 19:38:52 +08:00
Jyong
38fa152cc4 fix update document index technique (#1018)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 18:29:55 +08:00
Uranus
2d9616c29c fix: xinference last token being ignored (#1013) 2023-08-25 18:15:05 +08:00
Jyong
915e26527b update dataset index struct (#1012)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 15:52:33 +08:00
Jyong
2d604d9330 Fix/filter empty segment (#1004)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 15:50:29 +08:00
Jyong
e7199826cc embedding model available check (#1009)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 00:25:16 +08:00
crazywoola
70e24b7594 fix: loading and calc rem (#1006) 2023-08-24 23:24:33 +08:00
yezhwi
c1602aafc7 refactor:cache in place & function name (#1001) 2023-08-24 22:54:21 +08:00
crazywoola
a3fec11438 fix: styles (#1005) 2023-08-24 22:37:46 +08:00
Jyong
b1fd1b3ab3 Feat/vector db manage (#997)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-24 21:27:31 +08:00
Jyong
5397799aac document limit (#999)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-24 21:27:13 +08:00
353 changed files with 9956 additions and 2626 deletions

View File

@@ -20,7 +20,8 @@ def check_file_for_chinese_comments(file_path):
def main():
has_chinese = False
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py',
'prompts.py']
for root, _, files in os.walk("."):
for file in files:

3
.gitignore vendored
View File

@@ -149,4 +149,5 @@ sdks/python-client/build
sdks/python-client/dist
sdks/python-client/dify_client.egg-info
.vscode/
.vscode/*
!.vscode/launch.json

27
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,27 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Flask",
"type": "python",
"request": "launch",
"module": "flask",
"env": {
"FLASK_APP": "api/app.py",
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},
"args": [
"run",
"--host=0.0.0.0",
"--port=5001",
"--debug"
],
"jinja": true,
"justMyCode": true
}
]
}

View File

@@ -53,9 +53,9 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r
## Community channels
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk). We are here to help!
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
### i18n (Internationalization) Support
We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.
Also check out the [Frontend i18n README]((web/i18n/README_EN.md)) for more information.
Also check out the [Frontend i18n README]((web/i18n/README_EN.md)) for more information.

View File

@@ -16,15 +16,15 @@
## 本地开发
要设置一个可工作的开发环境,只需 fork 项目的 git 存储库,并使用适当的软件包管理器安装后端和前端依赖项,然后创建并运行 docker-compose 堆栈
要设置一个可工作的开发环境,只需 fork 项目的 git 存储库,并使用适当的软件包管理器安装后端和前端依赖项,然后创建并运行 docker-compose。
### Fork存储库
您需要 fork [存储](https://github.com/langgenius/dify)。
您需要 fork [Git 仓](https://github.com/langgenius/dify)。
### 克隆存储库
克隆您在 GitHub 上 fork 的存储库:
克隆您在 GitHub 上 fork 的库:
```
git clone git@github.com:<github_username>/dify.git

View File

@@ -52,4 +52,4 @@ git clone git@github.com:<github_username>/dify.git
## コミュニティチャンネル
お困りですか?何か質問がありますか? [Discord Community サーバ](https://discord.gg/AhzKf7dNgk)に参加してください。私たちがお手伝いします!
お困りですか?何か質問がありますか? [Discord Community サーバ](https://discord.gg/j3XRWSPBf7) に参加してください。私たちがお手伝いします!

View File

@@ -1,7 +1,18 @@
FROM python:3.10-slim
# packages install stage
FROM python:3.10-slim AS base
LABEL maintainer="takatost@gmail.com"
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ python3-dev libc-dev libffi-dev
COPY requirements.txt /requirements.txt
RUN pip install --prefix=/pkg -r requirements.txt
# build stage
FROM python:3.10-slim AS builder
ENV FLASK_APP app.py
ENV EDITION SELF_HOSTED
ENV DEPLOY_ENV PRODUCTION
@@ -15,13 +26,12 @@ EXPOSE 5001
WORKDIR /app/api
RUN apt-get update && \
apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev nodejs
COPY requirements.txt /app/api/requirements.txt
RUN pip install -r requirements.txt
RUN apt-get update \
&& apt-get install -y --no-install-recommends bash curl wget vim nodejs \
&& apt-get autoremove \
&& rm -rf /var/lib/apt/lists/*
COPY --from=base /pkg /usr/local
COPY . /app/api/
COPY docker/entrypoint.sh /entrypoint.sh

View File

@@ -54,9 +54,11 @@
7. Setup your application by visiting http://localhost:5001/console/api/setup or other apis...
8. If you need to debug local async processing, you can run `celery -A app.celery worker -Q dataset,generation,mail`, celery can do dataset importing and other async tasks.
8. Start frontend:
8. Start frontend
You can start the frontend by running `npm install && npm run dev` in web/ folder, or you can use docker to start the frontend, for example:
```
docker run -it -d --platform linux/amd64 -p 3000:3000 -e EDITION=SELF_HOSTED -e CONSOLE_URL=http://127.0.0.1:5000 --name web-self-hosted langgenius/dify-web:latest
docker run -it -d --platform linux/amd64 -p 3000:3000 -e EDITION=SELF_HOSTED -e CONSOLE_URL=http://127.0.0.1:5001 --name web-self-hosted langgenius/dify-web:latest
```
This will start a dify frontend, now you are all set, happy coding!

View File

@@ -1,6 +1,6 @@
# -*- coding:utf-8 -*-
import os
from datetime import datetime
from datetime import datetime, timedelta
from werkzeug.exceptions import Forbidden
@@ -145,8 +145,12 @@ def load_user(user_id):
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
account.last_active_at = datetime.utcnow()
db.session.commit()
current_time = datetime.utcnow()
# update last_active_at when last_active_at is more than 10 minutes ago
if current_time - account.last_active_at > timedelta(minutes=10):
account.last_active_at = current_time
db.session.commit()
# Log in the user with the updated user_id
flask_login.login_user(account, remember=True)

View File

@@ -1,4 +1,5 @@
import datetime
import json
import math
import random
import string
@@ -6,10 +7,16 @@ import time
import click
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound
from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.hosted import hosted_model_providers
from core.model_providers.providers.openai_provider import OpenAIProvider
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
@@ -296,6 +303,142 @@ def sync_anthropic_hosted_providers():
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
def create_qdrant_indexes():
click.echo(click.style('Start create qdrant indexes.', fg='green'))
create_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant':
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.SYSTEM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
create_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@click.command('update-qdrant-indexes', help='Update qdrant indexes.')
def update_qdrant_indexes():
click.echo(click.style('Start Update qdrant indexes.', fg='green'))
create_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant':
try:
click.echo('Update dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.update_qdrant_dataset(dataset)
create_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@@ -304,3 +447,5 @@ def register_commands(app):
app.cli.add_command(recreate_all_dataset_indexes)
app.cli.add_command(sync_anthropic_hosted_providers)
app.cli.add_command(clean_unused_dataset_indexes)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(update_qdrant_indexes)

View File

@@ -100,7 +100,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.18"
self.CURRENT_VERSION = "0.3.19"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')

View File

@@ -87,13 +87,19 @@ class DatasetListApi(Resource):
# raise ProviderNotInitializeError(
# f"No Embedding Model available. Please configure a valid provider "
# f"in the Settings -> Model Provider.")
model_names = [item['model_name'] for item in valid_model_list]
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['embedding_model'] in model_names:
item['embedding_available'] = True
if item['indexing_technique'] == 'high_quality':
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item['embedding_available'] = True
else:
item['embedding_available'] = False
else:
item['embedding_available'] = False
item['embedding_available'] = True
response = {
'data': data,
'has_more': len(datasets) == limit,
@@ -119,14 +125,6 @@ class DatasetListApi(Resource):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
try:
dataset = DatasetService.create_empty_dataset(
@@ -150,20 +148,39 @@ class DatasetApi(Resource):
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(
dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
return marshal(dataset, dataset_detail_fields), 200
data = marshal(dataset, dataset_detail_fields)
# check embedding setting
provider_service = ProviderService()
# get valid model list
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
if data['indexing_technique'] == 'high_quality':
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data['embedding_available'] = True
else:
data['embedding_available'] = False
else:
data['embedding_available'] = True
return data, 200
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False,
@@ -251,6 +268,7 @@ class DatasetIndexingEstimateApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
@@ -272,7 +290,8 @@ class DatasetIndexingEstimateApi(Resource):
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'])
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
@@ -287,7 +306,8 @@ class DatasetIndexingEstimateApi(Resource):
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'])
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "

View File

@@ -3,7 +3,7 @@ import random
from datetime import datetime
from typing import List
from flask import request
from flask import request, current_app
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
@@ -138,6 +138,10 @@ class GetProcessRuleApi(Resource):
req_data = request.args
document_id = req_data.get('document_id')
# get default rules
mode = DocumentService.DEFAULT_RULES['mode']
rules = DocumentService.DEFAULT_RULES['rules']
if document_id:
# get the latest process rule
document = Document.query.get_or_404(document_id)
@@ -158,11 +162,9 @@ class GetProcessRuleApi(Resource):
order_by(DatasetProcessRule.created_at.desc()). \
limit(1). \
one_or_none()
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict
else:
mode = DocumentService.DEFAULT_RULES['mode']
rules = DocumentService.DEFAULT_RULES['rules']
if dataset_process_rule:
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict
return {
'mode': mode,
@@ -275,7 +277,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
if not dataset.indexing_technique and not args['indexing_technique']:
@@ -284,20 +287,6 @@ class DatasetDocumentListApi(Resource):
# validate args
DocumentService.document_create_args_validate(args)
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
except ProviderTokenNotInitError as ex:
@@ -335,17 +324,20 @@ class DatasetInitApi(Resource):
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
if args['indexing_technique'] == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
DocumentService.document_create_args_validate(args)
@@ -414,7 +406,8 @@ class DocumentIndexingEstimateApi(DocumentResource):
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
data_process_rule_dict, None, dataset_id)
data_process_rule_dict, None,
'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
@@ -483,7 +476,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
data_process_rule_dict, None, dataset_id)
data_process_rule_dict, None,
'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
@@ -497,7 +491,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
info_list,
data_process_rule_dict,
None, dataset_id)
None, 'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
@@ -725,6 +719,12 @@ class DocumentDeleteApi(DocumentResource):
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
try:
@@ -787,6 +787,12 @@ class DocumentStatusApi(DocumentResource):
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
# The role of the current user in the ta table must be admin or owner
@@ -855,6 +861,14 @@ class DocumentStatusApi(DocumentResource):
if not document.archived:
raise InvalidActionError('Document is not archived.')
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
total_count = documents_count + 1
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if total_count > tenant_document_count:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
document.archived = False
document.archived_at = None
document.archived_by = None
@@ -872,6 +886,10 @@ class DocumentStatusApi(DocumentResource):
class DocumentPauseApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id):
"""pause document."""
dataset_id = str(dataset_id)
@@ -901,6 +919,9 @@ class DocumentPauseApi(DocumentResource):
class DocumentRecoverApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id):
"""recover document."""
dataset_id = str(dataset_id)
@@ -926,6 +947,21 @@ class DocumentRecoverApi(DocumentResource):
return {'result': 'success'}, 204
class DocumentLimitApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""get document limit"""
documents_count = DocumentService.get_tenant_documents_count()
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
return {
'documents_count': documents_count,
'documents_limit': tenant_document_count
}, 200
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
@@ -951,3 +987,4 @@ api.add_resource(DocumentStatusApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
api.add_resource(DocumentLimitApi, '/datasets/limit')

View File

@@ -149,7 +149,8 @@ class DatasetDocumentSegmentApi(Resource):
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
@@ -158,20 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
@@ -244,18 +245,19 @@ class DatasetDocumentSegmentAddApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
@@ -284,25 +286,28 @@ class DatasetDocumentSegmentUpdateApi(Resource):
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
@@ -339,6 +344,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
@@ -378,18 +385,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# get file from request
file = request.files['file']
# check file

View File

@@ -83,7 +83,7 @@ class FileApi(Resource):
raise FileTooLargeError(message)
extension = file.filename.split('.')[-1]
if extension not in ALLOWED_EXTENSIONS:
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
# user uuid as file name
@@ -136,7 +136,7 @@ class FilePreviewApi(Resource):
# extract text from file
extension = upload_file.extension
if extension not in ALLOWED_EXTENSIONS:
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
text = FileExtractor.load(upload_file, return_text=True)

View File

@@ -49,46 +49,43 @@ class MemberInviteEmailApi(Resource):
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('email', type=str, required=True, location='json')
parser.add_argument('emails', type=str, required=True, location='json', action='append')
parser.add_argument('role', type=str, required=True, default='admin', location='json')
args = parser.parse_args()
invitee_email = args['email']
invitee_emails = args['emails']
invitee_role = args['role']
if invitee_role not in ['admin', 'normal']:
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
inviter = current_user
try:
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
inviter=inviter)
account = db.session.query(Account, TenantAccountJoin.role).join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
).filter(Account.email == args['email']).first()
account, role = account
account = marshal(account, account_fields)
account['role'] = role
except services.errors.account.CannotOperateSelfError as e:
return {'code': 'cannot-operate-self', 'message': str(e)}, 400
except services.errors.account.NoPermissionError as e:
return {'code': 'forbidden', 'message': str(e)}, 403
except services.errors.account.AccountAlreadyInTenantError as e:
return {'code': 'email-taken', 'message': str(e)}, 409
except Exception as e:
return {'code': 'unexpected-error', 'message': str(e)}, 500
# todo:413
invitation_results = []
console_web_url = current_app.config.get("CONSOLE_WEB_URL")
for invitee_email in invitee_emails:
try:
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
inviter=inviter)
account = db.session.query(Account, TenantAccountJoin.role).join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
).filter(Account.email == invitee_email).first()
account, role = account
invitation_results.append({
'status': 'success',
'email': invitee_email,
'url': f'{console_web_url}/activate?workspace_id={current_user.current_tenant_id}&email={invitee_email}&token={token}'
})
account = marshal(account, account_fields)
account['role'] = role
except Exception as e:
invitation_results.append({
'status': 'failed',
'email': invitee_email,
'message': str(e)
})
return {
'result': 'success',
'account': account,
'invite_url': '{}/activate?workspace_id={}&email={}&token={}'.format(
current_app.config.get("CONSOLE_WEB_URL"),
str(current_user.current_tenant_id),
invitee_email,
token
)
'invitation_results': invitation_results,
}, 201

View File

@@ -52,7 +52,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
@@ -60,7 +60,13 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
return AgentFinish(return_values={"output": observation}, log=observation)
try:
return super().plan(intermediate_steps, callbacks, **kwargs)
agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception

View File

@@ -45,7 +45,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 15
self.llm.max_tokens = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
@@ -97,6 +97,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
@classmethod

View File

@@ -90,7 +90,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
@@ -102,7 +102,13 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
raise new_exception
try:
return self.output_parser.parse(full_output)
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")

View File

@@ -106,7 +106,13 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
raise new_exception
try:
return self.output_parser.parse(full_output)
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")

View File

@@ -1,5 +1,6 @@
import json
import logging
from json import JSONDecodeError
from typing import Any, Dict, List, Union, Optional
@@ -44,10 +45,15 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
input_str: str,
**kwargs: Any,
) -> None:
# tool_name = serialized.get('name')
input_dict = json.loads(input_str.replace("'", "\""))
dataset_id = input_dict.get('dataset_id')
query = input_dict.get('query')
tool_name: str = serialized.get('name')
dataset_id = tool_name.removeprefix('dataset-')
try:
input_dict = json.loads(input_str.replace("'", "\""))
query = input_dict.get('query')
except JSONDecodeError:
query = input_str
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
def on_tool_end(

View File

@@ -137,7 +137,8 @@ class ConversationMessageTask:
db.session.flush()
def append_message_text(self, text: str):
self._pub_handler.pub_text(text)
if text is not None:
self._pub_handler.pub_text(text)
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
message_tokens = llm_message.prompt_tokens

View File

@@ -6,7 +6,7 @@ import requests
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document
from core.data_loader.loader.csv import CSVLoader
from core.data_loader.loader.csv_loader import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
@@ -47,17 +47,18 @@ class FileExtractor:
upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
input_file = Path(file_path)
delimiter = '\n'
if input_file.suffix == '.xlsx':
file_extension = input_file.suffix.lower()
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif input_file.suffix == '.pdf':
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif input_file.suffix in ['.md', '.markdown']:
elif file_extension in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif input_file.suffix in ['.htm', '.html']:
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif input_file.suffix == '.docx':
elif file_extension == '.docx':
loader = Docx2txtLoader(file_path)
elif input_file.suffix == '.csv':
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt

View File

@@ -1,10 +1,10 @@
import logging
import csv
from typing import Optional, Dict, List
from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings
from models.dataset import Document
from langchain.schema import Document
logger = logging.getLogger(__name__)

View File

@@ -30,6 +30,8 @@ class ExcelLoader(BaseLoader):
wb = load_workbook(filename=self._file_path, read_only=True)
# loop over all sheets
for sheet in wb:
if 'A1:A1' == sheet.calculate_dimension():
sheet.reset_dimensions()
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
@@ -38,7 +40,7 @@ class ExcelLoader(BaseLoader):
else:
row_dict = dict(zip(keys, list(map(str, row))))
row_dict = {k: v for k, v in row_dict.items() if v}
item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
item = ''.join(f'{k}:{v};' for k, v in row_dict.items())
document = Document(page_content=item, metadata={'source': self._file_path})
data.append(document)

View File

@@ -67,12 +67,13 @@ class DatesetDocumentStore:
if max_position is None:
max_position = 0
embedding_model = ModelFactory.get_embedding_model(
tenant_id=self._dataset.tenant_id,
model_provider_name=self._dataset.embedding_model_provider,
model_name=self._dataset.embedding_model
)
embedding_model = None
if self._dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=self._dataset.tenant_id,
model_provider_name=self._dataset.embedding_model_provider,
model_name=self._dataset.embedding_model
)
for doc in docs:
if not isinstance(doc, Document):
@@ -88,7 +89,7 @@ class DatesetDocumentStore:
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(doc.page_content)
tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0
if not segment_document:
max_position += 1

View File

@@ -1,3 +1,4 @@
import json
import logging
from langchain.schema import OutputParserException
@@ -22,18 +23,25 @@ class LLMGenerator:
if len(query) > 2000:
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
prompt = prompt.format(query=query)
query = query.replace("\n", " ")
prompt += query + "\n"
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=50
temperature=1,
max_tokens=100
)
)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
result_dict = json.loads(answer)
answer = result_dict['Your Output']
return answer.strip()
@classmethod

View File

@@ -1,10 +1,18 @@
import json
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.providers.openai_provider import OpenAIProvider
from models.dataset import Dataset
from models.provider import Provider, ProviderType
class IndexBuilder:
@@ -35,4 +43,13 @@ class IndexBuilder:
)
)
else:
raise ValueError('Unknown indexing technique')
raise ValueError('Unknown indexing technique')
@classmethod
def get_default_high_quality_index(cls, dataset: Dataset):
embeddings = OpenAIEmbeddings(openai_api_key=' ')
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)

View File

@@ -25,7 +25,7 @@ class KeywordTableIndex(BaseIndex):
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
@@ -52,7 +52,7 @@ class KeywordTableIndex(BaseIndex):
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
self._save_dataset_keyword_table(keyword_table)
@@ -199,15 +199,18 @@ class KeywordTableIndex(BaseIndex):
return sorted_chunk_indices[: k]
def _update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id == node_id
).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()
def create_segment_keywords(self, node_id: str, keywords: List[str]):
keyword_table = self._get_dataset_keyword_table()
self._update_segment_keywords(node_id, keywords)
self._update_segment_keywords(self.dataset.id, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)

View File

@@ -15,12 +15,12 @@ from models.dataset import Document as DatasetDocument
class BaseVectorIndex(BaseIndex):
def __init__(self, dataset: Dataset, embeddings: Embeddings):
super().__init__(dataset)
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
raise NotImplementedError
@@ -143,7 +143,7 @@ class BaseVectorIndex(BaseIndex):
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
@@ -173,3 +173,73 @@ class BaseVectorIndex(BaseIndex):
self.dataset = dataset
logging.info(f"Dataset {dataset.id} recreate successfully.")
def create_qdrant_dataset(self, dataset: Dataset):
logging.info(f"create_qdrant_dataset {dataset.id}")
try:
self.delete()
except UnexpectedStatusCodeException as e:
if e.status_code != 400:
# 400 means index not exists
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.create(documents)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def update_qdrant_dataset(self, dataset: Dataset):
logging.info(f"update_qdrant_dataset {dataset.id}")
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).first()
if segment:
try:
exist = self.text_exists(segment.index_node_id)
if exist:
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")

View File

@@ -0,0 +1,114 @@
from typing import Optional, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore, milvus
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class MilvusConfig(BaseModel):
endpoint: str
user: str
password: str
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config MILVUS_ENDPOINT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
class MilvusVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
def get_type(self) -> str:
return 'milvus'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
)
def _get_vector_store_class(self) -> type:
return MilvusVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False

File diff suppressed because it is too large Load Diff

View File

@@ -44,15 +44,20 @@ class QdrantVectorIndex(BaseVectorIndex):
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
return self.dataset.index_struct_dict['vector_store']['collection_name']
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Index_" + dataset_id.replace("-", "_")
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"collection_name": self.get_index_name(self.dataset)}
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
@@ -62,7 +67,7 @@ class QdrantVectorIndex(BaseVectorIndex):
self._embeddings,
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='text',
content_payload_key='page_content',
**self._client_config.to_qdrant_params()
)
@@ -72,7 +77,9 @@ class QdrantVectorIndex(BaseVectorIndex):
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
@@ -81,7 +88,7 @@ class QdrantVectorIndex(BaseVectorIndex):
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='text'
content_payload_key='page_content'
)
def _get_vector_store_class(self) -> type:
@@ -108,8 +115,8 @@ class QdrantVectorIndex(BaseVectorIndex):
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
if class_prefix.startswith('Vector_'):
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True

View File

@@ -217,25 +217,29 @@ class IndexingRunner:
db.session.commit()
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = None
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
else:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
tokens = 0
preview_texts = []
total_segments = 0
@@ -263,8 +267,8 @@ class IndexingRunner:
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
if indexing_technique == 'high_quality' or embedding_model:
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
if doc_form and doc_form == 'qa_model':
text_generation_model = ModelFactory.get_text_generation_model(
@@ -286,32 +290,35 @@ class IndexingRunner:
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
"currency": embedding_model.get_currency(),
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
"currency": embedding_model.get_currency() if embedding_model else 'USD',
"preview": preview_texts
}
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = None
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
else:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
# load data from notion
tokens = 0
preview_texts = []
@@ -356,8 +363,8 @@ class IndexingRunner:
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
tokens += embedding_model.get_num_tokens(document.page_content)
if indexing_technique == 'high_quality' or embedding_model:
tokens += embedding_model.get_num_tokens(document.page_content)
if doc_form and doc_form == 'qa_model':
text_generation_model = ModelFactory.get_text_generation_model(
@@ -379,8 +386,8 @@ class IndexingRunner:
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
"currency": embedding_model.get_currency(),
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
"currency": embedding_model.get_currency() if embedding_model else 'USD',
"preview": preview_texts
}
@@ -399,7 +406,8 @@ class IndexingRunner:
filter(UploadFile.id == data_source_info['upload_file_id']). \
one_or_none()
text_docs = FileExtractor.load(file_detail)
if file_detail:
text_docs = FileExtractor.load(file_detail)
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
@@ -525,12 +533,13 @@ class IndexingRunner:
documents = splitter.split_documents([text_doc])
split_documents = []
for document_node in documents:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
split_documents.append(document_node)
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
split_documents.append(document_node)
all_documents.extend(split_documents)
# processing qa document
if document_form == 'qa_model':
@@ -656,12 +665,13 @@ class IndexingRunner:
"""
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
embedding_model = None
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
@@ -671,11 +681,11 @@ class IndexingRunner:
# check document is paused
self._check_document_paused_status(dataset_document.id)
chunk_documents = documents[i:i + chunk_size]
tokens += sum(
embedding_model.get_num_tokens(document.page_content)
for document in chunk_documents
)
if dataset.indexing_technique == 'high_quality' or embedding_model:
tokens += sum(
embedding_model.get_num_tokens(document.page_content)
for document in chunk_documents
)
# save vector index
if vector_index:

View File

@@ -63,6 +63,9 @@ class ModelProviderFactory:
elif provider_name == 'openllm':
from core.model_providers.providers.openllm_provider import OpenLLMProvider
return OpenLLMProvider
elif provider_name == 'localai':
from core.model_providers.providers.localai_provider import LocalAIProvider
return LocalAIProvider
else:
raise NotImplementedError

View File

@@ -0,0 +1,29 @@
from langchain.embeddings import LocalAIEmbeddings
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
class LocalAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = LocalAIEmbeddings(
model=name,
openai_api_key="1",
openai_api_base=credentials['server_url'],
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"LocalAI embedding: {str(ex)}")
else:
return ex

View File

@@ -1,11 +1,8 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any
import anthropic
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
@@ -13,6 +10,7 @@ from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.anthropic_llm import AnthropicLLM
class AnthropicModel(BaseLLM):
@@ -20,7 +18,7 @@ class AnthropicModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatAnthropic(
return AnthropicLLM(
model=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
@@ -75,7 +73,7 @@ class AnthropicModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True

View File

@@ -141,6 +141,6 @@ class AzureOpenAIModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
return True
@property
def support_streaming(self):
return True

View File

@@ -138,7 +138,7 @@ class BaseLLM(BaseProviderModel):
result = self._run(
messages=messages,
stop=stop,
callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
**kwargs
)
except Exception as ex:
@@ -149,7 +149,7 @@ class BaseLLM(BaseProviderModel):
else:
completion_content = result.generations[0][0].text
if self.streaming and not self.support_streaming():
if self.streaming and not self.support_streaming:
# use FakeLLM to simulate streaming when current model not support streaming but streaming is True
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM(
@@ -298,8 +298,8 @@ class BaseLLM(BaseProviderModel):
else:
self.client.callbacks.extend(callbacks)
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return False
def get_prompt(self, mode: str,

View File

@@ -61,7 +61,3 @@ class ChatGLMModel(BaseLLM):
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return False

View File

@@ -17,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
streaming = self.streaming
if 'baichuan' in self.name.lower():
streaming = False
client = HuggingFaceEndpointLLM(
endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs,
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
callbacks=self.callbacks
callbacks=self.callbacks,
streaming=streaming
)
else:
client = HuggingFaceHub(
@@ -76,7 +82,10 @@ class HuggingfaceHubModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
@classmethod
def support_streaming(cls):
return False
@property
def support_streaming(self):
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
if 'baichuan' in self.name.lower():
return False
return True

View File

@@ -0,0 +1,131 @@
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class LocalAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
if credentials['completion_type'] == 'chat_completion':
self.model_mode = ModelMode.CHAT
else:
self.model_mode = ModelMode.COMPLETION
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.model_mode == ModelMode.COMPLETION:
client = EnhanceOpenAI(
model_name=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
openai_api_key="1",
openai_api_base=self.credentials['server_url'] + '/v1',
**provider_model_kwargs
)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
client = EnhanceChatOpenAI(
model_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
openai_api_key="1",
openai_api_base=self.credentials['server_url'] + '/v1'
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, str):
return self._client.get_num_tokens(prompts)
else:
return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.model_mode == ModelMode.COMPLETION:
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to LocalAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to LocalAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("LocalAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File

@@ -154,8 +154,8 @@ class OpenAIModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True
# def is_model_valid_or_raise(self):

View File

@@ -63,7 +63,3 @@ class OpenLLMModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"OpenLLM: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@@ -91,6 +91,6 @@ class ReplicateModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
return True
@property
def support_streaming(self):
return True

View File

@@ -65,6 +65,6 @@ class SparkModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
return True
@property
def support_streaming(self):
return True

View File

@@ -69,6 +69,6 @@ class TongyiModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True

View File

@@ -57,7 +57,3 @@ class WenxinModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@@ -74,6 +74,6 @@ class XinferenceModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Xinference: {str(ex)}")
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True

View File

@@ -5,7 +5,6 @@ from typing import Type, Optional
import anthropic
from flask import current_app
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage
from core.helper import encrypter
@@ -16,6 +15,7 @@ from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
from core.third_party.langchain.llms.anthropic_llm import AnthropicLLM
from models.provider import ProviderType
@@ -92,7 +92,7 @@ class AnthropicProvider(BaseModelProvider):
if 'anthropic_api_url' in credentials:
credential_kwargs['anthropic_api_url'] = credentials['anthropic_api_url']
chat_llm = ChatAnthropic(
chat_llm = AnthropicLLM(
model='claude-instant-1',
max_tokens_to_sample=10,
temperature=0,

View File

@@ -0,0 +1,164 @@
import json
from typing import Type
from langchain.embeddings import LocalAIEmbeddings
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from models.provider import ProviderType
class LocalAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'localai'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = LocalAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = LocalAIEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=0.7),
top_p=KwargRule[float](min=0, max=1, default=1),
max_tokens=KwargRule[int](min=10, max=4097, default=16),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('LocalAI Server URL must be provided.')
try:
if model_type == ModelType.EMBEDDINGS:
model = LocalAIEmbeddings(
model=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url']
)
model.embed_query("ping")
else:
if ('completion_type' not in credentials
or credentials['completion_type'] not in ['completion', 'chat_completion']):
raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.')
if credentials['completion_type'] == 'chat_completion':
model = EnhanceChatOpenAI(
model_name=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url'] + '/v1',
max_tokens=10,
request_timeout=60,
)
model([HumanMessage(content='ping')])
else:
model = EnhanceOpenAI(
model_name=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url'] + '/v1',
max_tokens=10,
request_timeout=60,
)
model('ping')
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'server_url': None,
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['server_url']:
credentials['server_url'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['server_url']
)
if obfuscated:
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}

View File

@@ -83,14 +83,15 @@ class SparkProvider(BaseModelProvider):
if 'api_secret' not in credentials:
raise CredentialsValidateFailedError('Spark api_secret must be provided.')
try:
credential_kwargs = {
'app_id': credentials['app_id'],
'api_key': credentials['api_key'],
'api_secret': credentials['api_secret'],
}
credential_kwargs = {
'app_id': credentials['app_id'],
'api_key': credentials['api_key'],
'api_secret': credentials['api_secret'],
}
try:
chat_llm = ChatSpark(
model_name='spark-v2',
max_tokens=10,
temperature=0.01,
**credential_kwargs
@@ -104,7 +105,27 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
# try spark v1.5 if v2.1 failed
try:
chat_llm = ChatSpark(
model_name='spark',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex

View File

@@ -10,5 +10,6 @@
"replicate",
"huggingface_hub",
"xinference",
"openllm"
"openllm",
"localai"
]

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@@ -283,6 +283,7 @@ class OrchestratorRuleParser:
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MAX_K = 10
if rest_tokens == -1:
return DEFAULT_K
@@ -311,5 +312,5 @@ class OrchestratorRuleParser:
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space
return context_limit_tokens // segment_max_tokens
# Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
return min(context_limit_tokens // segment_max_tokens, MAX_K)

View File

@@ -1,10 +1,65 @@
CONVERSATION_TITLE_PROMPT = (
"Human:{query}\n-----\n"
"Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
"If what the human said is conducted in English, you should only return an English title.\n"
"If what the human said is conducted in Chinese, you should only return a Chinese title.\n"
"title:"
)
# Written by YORKI MINAKO🤡
CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
Notice: the language type user use could be diverse, which can be English, Chinese, Español, Arabic, Japanese, French, and etc.
MAKE SURE your output is the SAME language as the user's input!
Your output is restricted only to: (Input language) Intention + Subject(short as possible)
Your output MUST be a valid JSON.
Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun.
example 1:
User Input: hi, yesterday i had some burgers.
{
"Language Type": "The user's input is pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "sharing yesterday's food"
}
example 2:
User Input: hello
{
"Language Type": "The user's input is written in pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "Greeting myself☺"
}
example 3:
User Input: why mmap file: oom
{
"Language Type": "The user's input is written in pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "Asking about the reason for mmap file: oom"
}
example 4:
User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么
{
"Language Type": "The user's input English-Chinese mixed",
"Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.",
"Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv"
}
example 5:
User Input: why小红的年龄is老than小明
{
"Language Type": "The user's input is English-Chinese mixed",
"Your Reasoning": "The English parts are subjective particles, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
"Your Output": "询问小红和小明的年龄"
}
example 6:
User Input: yo, 你今天咋样?
{
"Language Type": "The user's input is English-Chinese mixed",
"Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.",
"Your Output": "查询今日我的状态☺️"
}
User Input:
"""
CONVERSATION_SUMMARY_PROMPT = (
"Please generate a short summary of the following conversation.\n"

View File

@@ -0,0 +1,48 @@
from typing import Dict
from httpx import Limits
from langchain.chat_models import ChatAnthropic
from langchain.utils import get_from_dict_or_env, check_package_version
from pydantic import root_validator
class AnthropicLLM(ChatAnthropic):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["anthropic_api_key"] = get_from_dict_or_env(
values, "anthropic_api_key", "ANTHROPIC_API_KEY"
)
# Get custom api url from environment.
values["anthropic_api_url"] = get_from_dict_or_env(
values,
"anthropic_api_url",
"ANTHROPIC_API_URL",
default="https://api.anthropic.com",
)
try:
import anthropic
check_package_version("anthropic", gte_version="0.3")
values["client"] = anthropic.Anthropic(
base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"],
timeout=values["default_request_timeout"],
max_retries=0,
connection_pool_limits=Limits(max_connections=200, max_keepalive_connections=100),
)
values["async_client"] = anthropic.AsyncAnthropic(
base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"],
timeout=values["default_request_timeout"],
)
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
values["AI_PROMPT"] = anthropic.AI_PROMPT
values["count_tokens"] = values["client"].count_tokens
except ImportError:
raise ImportError(
"Could not import anthropic python package. "
"Please it install it with `pip install anthropic`."
)
return values

View File

@@ -42,7 +42,8 @@ class EnhanceChatOpenAI(ChatOpenAI):
return {
**super()._default_params,
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,

View File

@@ -1,7 +1,11 @@
from typing import Dict
from typing import Dict, Any, Optional, List, Iterable, Iterator
from huggingface_hub import InferenceClient
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.embeddings.huggingface_hub import VALID_TASKS
from langchain.llms import HuggingFaceEndpoint
from pydantic import Extra, root_validator
from langchain.llms.utils import enforce_stop_tokens
from pydantic import root_validator
from langchain.utils import get_from_dict_or_env
@@ -27,6 +31,8 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
huggingfacehub_api_token="my-api-key"
)
"""
client: Any
streaming: bool = False
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
@@ -35,5 +41,88 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
values['client'] = InferenceClient(values['endpoint_url'], token=huggingfacehub_api_token)
values["huggingfacehub_api_token"] = huggingfacehub_api_token
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = hf("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
# payload samples
params = {**_model_kwargs, **kwargs}
# generation parameter
gen_kwargs = {
**params,
'stop_sequences': stop
}
response = self.client.text_generation(prompt, stream=self.streaming, details=True, **gen_kwargs)
if self.streaming and isinstance(response, Iterable):
combined_text_output = ""
for token in self._stream_response(response, run_manager):
combined_text_output += token
completion = combined_text_output
else:
completion = response.generated_text
if self.task == "text-generation":
text = completion
# Remove prompt if included in generated text.
if text.startswith(prompt):
text = text[len(prompt) :]
elif self.task == "text2text-generation":
text = completion
else:
raise ValueError(
f"Got invalid task {self.task}, "
f"currently only {VALID_TASKS} are supported"
)
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text
def _stream_response(
self,
response: Iterable,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> Iterator[str]:
for r in response:
# skip special tokens
if r.token.special:
continue
token = r.token.text
if run_manager:
run_manager.on_llm_new_token(
token=token, verbose=self.verbose, log_probs=None
)
# yield the generated token
yield token

View File

@@ -1,7 +1,10 @@
import os
from typing import Dict, Any, Mapping, Optional, Union, Tuple
from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator
from langchain import OpenAI
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk
from langchain.schema.output import GenerationChunk
from pydantic import root_validator
@@ -33,7 +36,8 @@ class EnhanceOpenAI(OpenAI):
def _invocation_params(self) -> Dict[str, Any]:
return {**super()._invocation_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
@@ -43,8 +47,33 @@ class EnhanceOpenAI(OpenAI):
def _identifying_params(self) -> Mapping[str, Any]:
return {**super()._identifying_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutates params
for stream_resp in completion_with_retry(
self, prompt=prompt, run_manager=run_manager, **params
):
if 'text' in stream_resp["choices"][0]:
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
else None,
)

View File

@@ -3,17 +3,20 @@ from typing import Optional, List, Any, Union, Generator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Xinference
from langchain.llms.utils import enforce_stop_tokens
from xinference.client import RESTfulChatglmCppChatModelHandle, \
RESTfulChatModelHandle, RESTfulGenerateModelHandle
from xinference.client import (
RESTfulChatglmCppChatModelHandle,
RESTfulChatModelHandle,
RESTfulGenerateModelHandle,
)
class XinferenceLLM(Xinference):
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call the xinference model and return the output.
@@ -29,7 +32,9 @@ class XinferenceLLM(Xinference):
model = self.client.get_model(self.model_uid)
if isinstance(model, RESTfulChatModelHandle):
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
generate_config: "LlamaCppGenerateConfig" = kwargs.get(
"generate_config", {}
)
if stop:
generate_config["stop"] = stop
@@ -37,10 +42,10 @@ class XinferenceLLM(Xinference):
if generate_config and generate_config.get("stream"):
combined_text_output = ""
for token in self._stream_generate(
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
):
combined_text_output += token
return combined_text_output
@@ -48,7 +53,9 @@ class XinferenceLLM(Xinference):
completion = model.chat(prompt=prompt, generate_config=generate_config)
return completion["choices"][0]["message"]["content"]
elif isinstance(model, RESTfulGenerateModelHandle):
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
generate_config: "LlamaCppGenerateConfig" = kwargs.get(
"generate_config", {}
)
if stop:
generate_config["stop"] = stop
@@ -56,27 +63,31 @@ class XinferenceLLM(Xinference):
if generate_config and generate_config.get("stream"):
combined_text_output = ""
for token in self._stream_generate(
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
):
combined_text_output += token
return combined_text_output
else:
completion = model.generate(prompt=prompt, generate_config=generate_config)
completion = model.generate(
prompt=prompt, generate_config=generate_config
)
return completion["choices"][0]["text"]
elif isinstance(model, RESTfulChatglmCppChatModelHandle):
generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
generate_config: "ChatglmCppGenerateConfig" = kwargs.get(
"generate_config", {}
)
if generate_config and generate_config.get("stream"):
combined_text_output = ""
for token in self._stream_generate(
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
):
combined_text_output += token
completion = combined_text_output
@@ -90,12 +101,21 @@ class XinferenceLLM(Xinference):
return completion
def _stream_generate(
self,
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
prompt: str,
run_manager: Optional[CallbackManagerForLLMRun] = None,
generate_config: Optional[
Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
self,
model: Union[
"RESTfulGenerateModelHandle",
"RESTfulChatModelHandle",
"RESTfulChatglmCppChatModelHandle",
],
prompt: str,
run_manager: Optional[CallbackManagerForLLMRun] = None,
generate_config: Optional[
Union[
"LlamaCppGenerateConfig",
"PytorchGenerateConfig",
"ChatglmCppGenerateConfig",
]
] = None,
) -> Generator[str, None, None]:
"""
Args:
@@ -108,7 +128,9 @@ class XinferenceLLM(Xinference):
Yields:
A string token.
"""
if isinstance(model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)):
if isinstance(
model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
):
streaming_response = model.chat(
prompt=prompt, generate_config=generate_config
)
@@ -123,14 +145,10 @@ class XinferenceLLM(Xinference):
if choices:
choice = choices[0]
if isinstance(choice, dict):
if 'finish_reason' in choice and choice['finish_reason'] \
and choice['finish_reason'] in ['stop', 'length']:
break
if 'text' in choice:
if "text" in choice:
token = choice.get("text", "")
elif 'delta' in choice and 'content' in choice['delta']:
token = choice.get('delta').get('content')
elif "delta" in choice and "content" in choice["delta"]:
token = choice.get("delta").get("content")
else:
continue
log_probs = choice.get("logprobs")

View File

@@ -1,4 +1,3 @@
import re
from typing import Type
from flask import current_app
@@ -16,7 +15,6 @@ from models.dataset import Dataset, DocumentSegment
class DatasetRetrieverToolInput(BaseModel):
dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
@@ -37,27 +35,22 @@ class DatasetRetrieverTool(BaseTool):
description = 'useful for when you want to answer queries about the ' + dataset.name
description = description.replace('\n', '').replace('\r', '')
description += '\nID of dataset MUST be ' + dataset.id
return cls(
name=f'dataset-{dataset.id}',
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
**kwargs
)
def _run(self, dataset_id: str, query: str) -> str:
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
match = re.search(pattern, dataset_id, re.IGNORECASE)
if match:
dataset_id = match.group()
def _run(self, query: str) -> str:
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == dataset_id
Dataset.id == self.dataset_id
).first()
if not dataset:
return f'[{self.name} failed to find dataset with id {dataset_id}.]'
return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
if dataset.indexing_technique == "economy":
# use keyword table query
@@ -105,7 +98,8 @@ class DatasetRetrieverTool(BaseTool):
hit_callback.on_tool_end(documents)
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in documents]
segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids)

View File

@@ -88,11 +88,9 @@ class WebReaderTool(BaseTool):
texts = character_splitter.split_text(page_contents)
docs = [Document(page_content=t) for t in texts]
if len(docs) == 0:
if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
return "No content found."
docs = docs[1:]
# only use first 5 docs
if len(docs) > 5:
docs = docs[:5]

View File

@@ -0,0 +1,38 @@
from langchain.vectorstores import Milvus
class MilvusVectorStore(Milvus):
def del_texts(self, where_filter: dict):
if not where_filter:
raise ValueError('where_filter must not be empty')
self._client.batch.delete_objects(
class_name=self._index_name,
where=where_filter,
output='minimal'
)
def del_text(self, uuid: str) -> None:
self._client.data_object.delete(
uuid,
class_name=self._index_name
)
def text_exists(self, uuid: str) -> bool:
result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": uuid,
}).with_limit(1).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][self._index_name]
if len(entries) == 0:
return False
return True
def delete(self):
self._client.schema.delete_class(self._index_name)

View File

@@ -1,10 +1,11 @@
from typing import cast, Any
from langchain.schema import Document
from langchain.vectorstores import Qdrant
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal
from core.index.vector_index.qdrant import Qdrant
class QdrantVectorStore(Qdrant):
def del_texts(self, filter: Filter):

View File

@@ -1,6 +1,5 @@
from events.dataset_event import dataset_was_deleted
from events.event_handlers.document_index_event import document_index_created
from tasks.clean_dataset_task import clean_dataset_task
import datetime
import logging
import time

View File

@@ -26,7 +26,7 @@ def handle(sender, **kwargs):
conversation.name = name
except:
conversation.name = 'New Chat'
conversation.name = 'New conversation'
db.session.add(conversation)
db.session.commit()

View File

@@ -0,0 +1,46 @@
"""update_dataset_model_field_null_available
Revision ID: 4bcffcd64aa4
Revises: 853f9b9cd3b6
Create Date: 2023-08-28 20:58:50.077056
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4bcffcd64aa4'
down_revision = '853f9b9cd3b6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.alter_column('embedding_model',
existing_type=sa.VARCHAR(length=255),
nullable=True,
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
batch_op.alter_column('embedding_model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True,
existing_server_default=sa.text("'openai'::character varying"))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.alter_column('embedding_model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False,
existing_server_default=sa.text("'openai'::character varying"))
batch_op.alter_column('embedding_model',
existing_type=sa.VARCHAR(length=255),
nullable=False,
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
# ### end Alembic commands ###

View File

@@ -36,10 +36,8 @@ class Dataset(db.Model):
updated_by = db.Column(UUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
embedding_model = db.Column(db.String(
255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying"))
embedding_model_provider = db.Column(db.String(
255), nullable=False, server_default=db.text("'openai'::character varying"))
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
@property
def dataset_keyword_table(self):

View File

@@ -10,6 +10,7 @@ from flask import current_app
from sqlalchemy import func
from core.index.index import IndexBuilder
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from flask_login import current_user
@@ -91,16 +92,18 @@ class DatasetService:
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(
f'Dataset with name {name} already exists.')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
embedding_model = None
if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
dataset = Dataset(name=name, indexing_technique=indexing_technique)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.tenant_id = tenant_id
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
dataset.embedding_model = embedding_model.name if embedding_model else None
db.session.add(dataset)
db.session.commit()
return dataset
@@ -115,17 +118,50 @@ class DatasetService:
else:
return dataset
@staticmethod
def check_dataset_model_setting(dataset):
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ValueError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ValueError(f"The dataset in unavailable, due to: "
f"{ex.description}")
@staticmethod
def update_dataset(dataset_id, data, user):
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_permission(dataset, user)
action = None
if dataset.indexing_technique != data['indexing_technique']:
# if update indexing_technique
if data['indexing_technique'] == 'economy':
deal_dataset_vector_index_task.delay(dataset_id, 'remove')
action = 'remove'
filtered_data['embedding_model'] = None
filtered_data['embedding_model_provider'] = None
elif data['indexing_technique'] == 'high_quality':
deal_dataset_vector_index_task.delay(dataset_id, 'add')
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
action = 'add'
# get embedding model setting
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
filtered_data['embedding_model'] = embedding_model.name
filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
except LLMBadRequestError:
raise ValueError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
filtered_data['updated_by'] = user.id
filtered_data['updated_at'] = datetime.datetime.now()
@@ -133,7 +169,8 @@ class DatasetService:
dataset.query.filter_by(id=dataset_id).update(filtered_data)
db.session.commit()
if action:
deal_dataset_vector_index_task.delay(dataset_id, action)
return dataset
@staticmethod
@@ -394,16 +431,26 @@ class DocumentService:
def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'):
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if documents_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
if 'original_document_id' not in document_data or not document_data['original_document_id']:
count = 0
if document_data["data_source"]["type"] == "upload_file":
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
count = len(upload_file_list)
elif document_data["data_source"]["type"] == "notion_import":
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
count = count + len(notion_info['pages'])
documents_count = DocumentService.get_tenant_documents_count()
total_count = documents_count + count
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if total_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
# if dataset is empty, update dataset data_source_type
if not dataset.data_source_type:
dataset.data_source_type = document_data["data_source"]["type"]
db.session.commit()
if not dataset.indexing_technique:
if 'indexing_technique' not in document_data \
@@ -411,6 +458,13 @@ class DocumentService:
raise ValueError("Indexing technique is required")
dataset.indexing_technique = document_data["indexing_technique"]
if document_data["indexing_technique"] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -455,12 +509,12 @@ class DocumentService:
data_source_info = {
"upload_file_id": file_id,
}
document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, file_name, batch)
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, file_name, batch)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -501,12 +555,12 @@ class DocumentService:
"notion_page_icon": page['page_icon'],
"type": page['type']
}
document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, page['page_name'], batch)
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, page['page_name'], batch)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -525,10 +579,10 @@ class DocumentService:
return documents, batch
@staticmethod
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
document_language: str, data_source_info: dict, created_from: str, position: int,
account: Account,
name: str, batch: str):
def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
document_language: str, data_source_info: dict, created_from: str, position: int,
account: Account,
name: str, batch: str):
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
@@ -557,6 +611,7 @@ class DocumentService:
def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'):
DatasetService.check_dataset_model_setting(dataset)
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
if document.display_status != 'available':
raise ValueError("Document is not available")
@@ -649,15 +704,26 @@ class DocumentService:
@staticmethod
def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
count = 0
if document_data["data_source"]["type"] == "upload_file":
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
count = len(upload_file_list)
elif document_data["data_source"]["type"] == "notion_import":
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
count = count + len(notion_info['pages'])
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
total_count = documents_count + count
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if documents_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
if total_count > tenant_document_count:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
embedding_model = None
if document_data['indexing_technique'] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
@@ -665,8 +731,8 @@ class DocumentService:
data_source_type=document_data["data_source"]["type"],
indexing_technique=document_data["indexing_technique"],
created_by=account.id,
embedding_model=embedding_model.name,
embedding_model_provider=embedding_model.model_provider.provider_name
embedding_model=embedding_model.name if embedding_model else None,
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
)
db.session.add(dataset)
@@ -874,21 +940,25 @@ class SegmentService:
if document.doc_form == 'qa_model':
if 'answer' not in args or not args['answer']:
raise ValueError("Answer is required")
if not args['answer'].strip():
raise ValueError("Answer is empty")
if 'content' not in args or not args['content'] or not args['content'].strip():
raise ValueError("Content is empty")
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
content = args['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
tokens = 0
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
).scalar()
@@ -950,15 +1020,16 @@ class SegmentService:
kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
segment.content = content
segment.index_node_hash = segment_hash
segment.word_count = len(content)
@@ -990,10 +1061,11 @@ class SegmentService:
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise ValueError("Segment is deleting.")
# send delete segment index task
redis_client.setex(indexing_cache_key, 600, 1)
# enabled segment need to delete index
if segment.enabled:
# send delete segment index task
redis_client.setex(indexing_cache_key, 600, 1)
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
db.session.delete(segment)
db.session.commit()

View File

@@ -49,18 +49,20 @@ def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: s
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
raise ValueError('Document is not available.')
document_segments = []
for segment in content:
content = segment['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
embedding_model = None
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
for segment in content:
content = segment['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
tokens = embedding_model.get_num_tokens(content) if embedding_model else 0
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == dataset_document.id
).scalar()

View File

@@ -3,8 +3,10 @@ import time
import click
from celery import shared_task
from flask import current_app
from core.index.index import IndexBuilder
from core.index.vector_index.vector_index import VectorIndex
from extensions.ext_database import db
from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
AppDatasetJoin, Document
@@ -35,11 +37,11 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
if dataset.indexing_technique == 'high_quality':
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
try:
vector_index.delete()
except Exception:

View File

@@ -31,7 +31,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
raise Exception('Dataset not found')
if action == "remove":
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
index.delete()
elif action == "add":
dataset_documents = db.session.query(DatasetDocument).filter(
@@ -43,7 +43,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
documents = []
for dataset_document in dataset_documents:
# delete from vector index
@@ -65,7 +65,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index.add_texts(documents)
index.create(documents)
end_at = time.perf_counter()
logging.info(

View File

@@ -39,4 +39,7 @@ XINFERENCE_SERVER_URL=
XINFERENCE_MODEL_UID=
# OpenLLM Credentials
OPENLLM_SERVER_URL=
OPENLLM_SERVER_URL=
# LocalAI Credentials
LOCALAI_SERVER_URL=

View File

@@ -0,0 +1,61 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='localai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_embedding_model(mocker):
model_name = 'text-embedding-ada-002'
server_url = os.environ['LOCALAI_SERVER_URL']
model_provider = LocalAIProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='localai',
model_name=model_name,
model_type=ModelType.EMBEDDINGS.value,
encrypted_config=json.dumps({
'server_url': server_url,
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return LocalAIEmbedding(
model_provider=model_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_documents(['test', 'test1'])
assert isinstance(rst, list)
assert len(rst) == 2
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)

View File

@@ -0,0 +1,68 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.localai_provider import LocalAIProvider
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider(server_url):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='localai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({}),
is_valid=True,
)
def get_mock_model(model_name, mocker):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0
)
server_url = os.environ['LOCALAI_SERVER_URL']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='localai',
model_name=model_name,
model_type=ModelType.TEXT_GENERATION.value,
encrypted_config=json.dumps({'server_url': server_url, 'completion_type': 'completion'}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
openai_provider = LocalAIProvider(provider=get_mock_provider(server_url))
return LocalAIModel(
model_provider=openai_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
assert rst > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
rst = openai_model.run(
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0

View File

@@ -63,7 +63,7 @@ def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_i
def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
mocker.patch('langchain.llms.huggingface_endpoint.HuggingFaceEndpoint._call', return_value="abc")
mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
@@ -71,8 +71,10 @@ def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
)
def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker):
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(

View File

@@ -0,0 +1,116 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import ProviderType, Provider, ProviderModel
PROVIDER_NAME = 'localai'
MODEL_PROVIDER_CLASS = LocalAIProvider
VALIDATE_CREDENTIAL = {
'server_url': 'http://127.0.0.1:8080/'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.embeddings.localai.LocalAIEmbeddings.embed_query',
return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='username/test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials=VALIDATE_CREDENTIAL.copy()
)
def test_is_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if server_url is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials={}
)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt, mocker):
server_url = 'http://127.0.0.1:8080/'
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials=VALIDATE_CREDENTIAL.copy()
)
mock_encrypt.assert_called_with('tenant_id', server_url)
assert result['server_url'] == f'encrypted_{server_url}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS
)
assert result['server_url'] == 'http://127.0.0.1:8080/'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
obfuscated=True
)
middle_token = result['server_url'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.3.18
image: langgenius/dify-api:0.3.19
restart: always
environment:
# Startup mode, 'api' starts the API server.
@@ -124,7 +124,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.3.18
image: langgenius/dify-api:0.3.19
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -176,7 +176,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.3.18
image: langgenius/dify-web:0.3.19
restart: always
environment:
EDITION: SELF_HOSTED

View File

@@ -33,20 +33,18 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
if (!response)
return <Loading />
const handleError = (err: Error | null) => {
if (!err) {
notify({
type: 'success',
message: t('common.actionMsg.modifiedSuccessfully'),
})
const handleCallbackResult = (err: Error | null, message?: string) => {
const type = err ? 'error' : 'success'
message ||= (type === 'success' ? 'modifiedSuccessfully' : 'modifiedUnsuccessfully')
if (type === 'success') {
mutate(detailParams)
}
else {
notify({
type: 'error',
message: t('common.actionMsg.modificationFailed'),
})
}
notify({
type,
message: t(`common.actionMsg.${message}`),
})
}
const onChangeSiteStatus = async (value: boolean) => {
@@ -56,7 +54,8 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
body: { enable_site: value },
}) as Promise<App>,
)
handleError(err)
handleCallbackResult(err)
}
const onChangeApiStatus = async (value: boolean) => {
@@ -66,7 +65,8 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
body: { enable_api: value },
}) as Promise<App>,
)
handleError(err)
handleCallbackResult(err)
}
const onSaveSiteConfig: IAppCardProps['onSaveSiteConfig'] = async (params) => {
@@ -79,7 +79,7 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
if (!err)
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
handleError(err)
handleCallbackResult(err)
}
const onGenerateCode = async () => {
@@ -88,7 +88,8 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
url: `/apps/${appId}/site/access-token-reset`,
}) as Promise<UpdateAppSiteCodeResponse>,
)
handleError(err)
handleCallbackResult(err, err ? 'generatedUnsuccessfully' : 'generatedSuccessfully')
}
return (

View File

@@ -2,7 +2,7 @@ import React from 'react'
import ChartView from './chartView'
import CardView from './cardView'
import { getLocaleOnServer } from '@/i18n/server'
import { useTranslation } from '@/i18n/i18next-serverside-config'
import { useTranslation as translate } from '@/i18n/i18next-serverside-config'
import ApikeyInfoPanel from '@/app/components/app/overview/apikey-info-panel'
export type IDevelopProps = {
@@ -13,7 +13,11 @@ const Overview = async ({
params: { appId },
}: IDevelopProps) => {
const locale = getLocaleOnServer()
const { t } = await useTranslation(locale, 'app-overview')
/*
rename useTranslation to avoid lint error
please check: https://github.com/i18next/next-13-app-dir-i18next-example/issues/24
*/
const { t } = await translate(locale, 'app-overview')
return (
<div className="h-full px-16 py-6 overflow-scroll">
<ApikeyInfoPanel />

View File

@@ -9,6 +9,7 @@ import style from '../list.module.css'
import AppModeLabel from './AppModeLabel'
import s from './style.module.css'
import SettingsModal from '@/app/components/app/overview/settings'
import type { ConfigParams } from '@/app/components/app/overview/settings'
import type { App } from '@/types/app'
import Confirm from '@/app/components/base/confirm'
import { ToastContext } from '@/app/components/base/toast'
@@ -73,7 +74,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
}
const onSaveSiteConfig = useCallback(
async (params: any) => {
async (params: ConfigParams) => {
const [err] = await asyncRunSafe<App>(
updateAppSiteConfig({
url: `/apps/${app.id}/site`,
@@ -92,7 +93,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
else {
notify({
type: 'error',
message: t('common.actionMsg.modificationFailed'),
message: t('common.actionMsg.modifiedUnsuccessfully'),
})
}
},
@@ -100,12 +101,12 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
)
const Operations = (props: any) => {
const onClickSettings = async (e: any) => {
const onClickSettings = async (e: React.MouseEvent<HTMLButtonElement>) => {
props?.onClose()
e.preventDefault()
await getAppDetail()
}
const onClickDelete = async (e: any) => {
const onClickDelete = async (e: React.MouseEvent<HTMLDivElement>) => {
props?.onClose()
e.preventDefault()
setShowConfirmDelete(true)

View File

@@ -1,5 +1,5 @@
'use client'
import type { FC } from 'react'
import type { FC, SVGProps } from 'react'
import React, { useEffect } from 'react'
import { usePathname } from 'next/navigation'
import useSWR from 'swr'
@@ -57,7 +57,7 @@ const LikedItem: FC<{ type?: 'plugin' | 'app'; appStatus?: boolean; detail: Rela
)
}
const TargetIcon: FC<{ className?: string }> = ({ className }) => {
const TargetIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<g clip-path="url(#clip0_4610_6951)">
<path d="M10.6666 5.33325V3.33325L12.6666 1.33325L13.3332 2.66659L14.6666 3.33325L12.6666 5.33325H10.6666ZM10.6666 5.33325L7.9999 7.99988M14.6666 7.99992C14.6666 11.6818 11.6818 14.6666 7.99992 14.6666C4.31802 14.6666 1.33325 11.6818 1.33325 7.99992C1.33325 4.31802 4.31802 1.33325 7.99992 1.33325M11.3333 7.99992C11.3333 9.84087 9.84087 11.3333 7.99992 11.3333C6.15897 11.3333 4.66659 9.84087 4.66659 7.99992C4.66659 6.15897 6.15897 4.66659 7.99992 4.66659" stroke="#344054" strokeWidth="1.25" strokeLinecap="round" strokeLinejoin="round" />
@@ -70,7 +70,7 @@ const TargetIcon: FC<{ className?: string }> = ({ className }) => {
</svg>
}
const TargetSolidIcon: FC<{ className?: string }> = ({ className }) => {
const TargetSolidIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path fillRule="evenodd" clipRule="evenodd" d="M12.7733 0.67512C12.9848 0.709447 13.1669 0.843364 13.2627 1.03504L13.83 2.16961L14.9646 2.73689C15.1563 2.83273 15.2902 3.01486 15.3245 3.22639C15.3588 3.43792 15.2894 3.65305 15.1379 3.80458L13.1379 5.80458C13.0128 5.92961 12.8433 5.99985 12.6665 5.99985H10.9426L8.47124 8.47124C8.21089 8.73159 7.78878 8.73159 7.52843 8.47124C7.26808 8.21089 7.26808 7.78878 7.52843 7.52843L9.9998 5.05707V3.33318C9.9998 3.15637 10.07 2.9868 10.1951 2.86177L12.1951 0.861774C12.3466 0.710244 12.5617 0.640794 12.7733 0.67512Z" fill="#155EEF" />
<path d="M1.99984 7.99984C1.99984 4.68613 4.68613 1.99984 7.99984 1.99984C8.36803 1.99984 8.6665 1.70136 8.6665 1.33317C8.6665 0.964981 8.36803 0.666504 7.99984 0.666504C3.94975 0.666504 0.666504 3.94975 0.666504 7.99984C0.666504 12.0499 3.94975 15.3332 7.99984 15.3332C12.0499 15.3332 15.3332 12.0499 15.3332 7.99984C15.3332 7.63165 15.0347 7.33317 14.6665 7.33317C14.2983 7.33317 13.9998 7.63165 13.9998 7.99984C13.9998 11.3135 11.3135 13.9998 7.99984 13.9998C4.68613 13.9998 1.99984 11.3135 1.99984 7.99984Z" fill="#155EEF" />
@@ -78,7 +78,7 @@ const TargetSolidIcon: FC<{ className?: string }> = ({ className }) => {
</svg>
}
const BookOpenIcon: FC<{ className?: string }> = ({ className }) => {
const BookOpenIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path opacity="0.12" d="M1 3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7V10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1Z" fill="#155EEF" />
<path d="M6 10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7M6 10.5V4.7M6 10.5L6.05003 10.425C6.39735 9.90398 6.57101 9.64349 6.80045 9.45491C7.00357 9.28796 7.23762 9.1627 7.4892 9.0863C7.77337 9 8.08645 9 8.71259 9H9.4C9.96005 9 10.2401 9 10.454 8.89101C10.6422 8.79513 10.7951 8.64215 10.891 8.45399C11 8.24008 11 7.96005 11 7.4V3.1C11 2.53995 11 2.25992 10.891 2.04601C10.7951 1.85785 10.6422 1.70487 10.454 1.60899C10.2401 1.5 9.96005 1.5 9.4 1.5H9.2C8.07989 1.5 7.51984 1.5 7.09202 1.71799C6.71569 1.90973 6.40973 2.21569 6.21799 2.59202C6 3.01984 6 3.5799 6 4.7" stroke="#155EEF" strokeLinecap="round" strokeLinejoin="round" />

View File

@@ -6,9 +6,11 @@ const Layout: FC<{
children: React.ReactNode
}> = ({ children }) => {
return (
<div className="min-w-[300px]">
<GA gaType={GaType.webapp} />
{children}
<div className=''>
<div className="min-w-[300px]">
<GA gaType={GaType.webapp} />
{children}
</div>
</div>
)
}

View File

@@ -1,8 +1,9 @@
import React from 'react'
import type { FC } from 'react'
import NavLink from './navLink'
import AppBasic from './basic'
import type { NavIcon } from './navLink'
export type IAppDetailNavProps = {
iconType?: 'app' | 'dataset' | 'notion'
title: string
@@ -12,13 +13,13 @@ export type IAppDetailNavProps = {
navigation: Array<{
name: string
href: string
icon: any
selectedIcon: any
icon: NavIcon
selectedIcon: NavIcon
}>
extraInfo?: React.ReactNode
}
const AppDetailNav: FC<IAppDetailNavProps> = ({ title, desc, icon, icon_background, navigation, extraInfo, iconType = 'app' }) => {
const AppDetailNav = ({ title, desc, icon, icon_background, navigation, extraInfo, iconType = 'app' }: IAppDetailNavProps) => {
return (
<div className="flex flex-col w-56 overflow-y-auto bg-white border-r border-gray-200 shrink-0">
<div className="flex flex-shrink-0 p-4">

View File

@@ -1,17 +1,30 @@
'use client'
import { useSelectedLayoutSegment } from 'next/navigation'
import classNames from 'classnames'
import Link from 'next/link'
export type NavIcon = React.ComponentType<
React.PropsWithoutRef<React.ComponentProps<'svg'>> & {
title?: string | undefined
titleId?: string | undefined
}
>
export type NavLinkProps = {
name: string
href: string
iconMap: {
selected: NavIcon
normal: NavIcon
}
}
export default function NavLink({
name,
href,
iconMap,
}: {
name: string
href: string
iconMap: { selected: any; normal: any }
}) {
}: NavLinkProps) {
const segment = useSelectedLayoutSegment()
const isActive = href.toLowerCase().split('/')?.pop() === segment?.toLowerCase()
const NavIcon = isActive ? iconMap.selected : iconMap.normal

View File

@@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import { UserCircleIcon } from '@heroicons/react/24/solid'
import cn from 'classnames'
import type { DisplayScene, FeedbackFunc, Feedbacktype, IChatItem, SubmitAnnotationFunc, ThoughtItem } from '../type'
import type { CitationItem, DisplayScene, FeedbackFunc, Feedbacktype, IChatItem, SubmitAnnotationFunc, ThoughtItem } from '../type'
import OperationBtn from '../operation'
import LoadingAnim from '../loading-anim'
import { EditIcon, EditIconSolid, OpeningStatementIcon, RatingIcon } from '../icon-component'
@@ -13,6 +13,7 @@ import s from '../style.module.css'
import MoreInfo from '../more-info'
import CopyBtn from '../copy-btn'
import Thought from '../thought'
import Citation from '../citation'
import { randomString } from '@/utils'
import type { Annotation, MessageRating } from '@/models/log'
import AppContext from '@/context/app-context'
@@ -45,11 +46,14 @@ export type IAnswerProps = {
isResponsing?: boolean
answerIconClassName?: string
thoughts?: ThoughtItem[]
citation?: CitationItem[]
isThinking?: boolean
dataSets?: DataSet[]
isShowCitation?: boolean
isShowCitationHitInfo?: boolean
}
// The component needs to maintain its own state to control whether to display input component
const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedbackEdit = false, onFeedback, onSubmitAnnotation, displayScene = 'web', isResponsing, answerIconClassName, thoughts, isThinking, dataSets }) => {
const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedbackEdit = false, onFeedback, onSubmitAnnotation, displayScene = 'web', isResponsing, answerIconClassName, thoughts, citation, isThinking, dataSets, isShowCitation, isShowCitationHitInfo = false }) => {
const { id, content, more, feedback, adminFeedback, annotation: initAnnotation } = item
const [showEdit, setShowEdit] = useState(false)
const [loading, setLoading] = useState(false)
@@ -237,6 +241,11 @@ const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedba
</div>
</>
}
{
!!citation?.length && !isThinking && isShowCitation && !isResponsing && (
<Citation data={citation} showHitInfo={isShowCitationHitInfo} />
)
}
</div>
<div className='absolute top-[-14px] right-[-14px] flex flex-row justify-end gap-1'>
{!item.isOpeningStatement && (

View File

@@ -0,0 +1,65 @@
import { useMemo } from 'react'
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import type { CitationItem } from '../type'
import Popup from './popup'
export type Resources = {
documentId: string
documentName: string
dataSourceType: string
sources: CitationItem[]
}
type CitationProps = {
data: CitationItem[]
showHitInfo?: boolean
}
const Citation: FC<CitationProps> = ({
data,
showHitInfo,
}) => {
const { t } = useTranslation()
const resources = useMemo(() => data.reduce((prev: Resources[], next) => {
const documentId = next.document_id
const documentName = next.document_name
const dataSourceType = next.data_source_type
const documentIndex = prev.findIndex(i => i.documentId === documentId)
if (documentIndex > -1) {
prev[documentIndex].sources.push(next)
}
else {
prev.push({
documentId,
documentName,
dataSourceType,
sources: [next],
})
}
return prev
}, []), [data])
return (
<div className='mt-3'>
<div className='flex items-center mb-2 text-xs font-medium text-gray-500'>
{t('common.chat.citation.title')}
<div className='grow ml-2 h-[1px] bg-black/5' />
</div>
<div className='flex'>
{
resources.map((res, index) => (
<Popup
key={index}
data={res}
showHitInfo={showHitInfo}
/>
))
}
</div>
</div>
)
}
export default Citation

View File

@@ -0,0 +1,113 @@
import { Fragment, useState } from 'react'
import type { FC } from 'react'
import Link from 'next/link'
import { useTranslation } from 'react-i18next'
import Tooltip from './tooltip'
import ProgressTooltip from './progress-tooltip'
import type { Resources } from './index'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import FileIcon from '@/app/components/base/file-icon'
import { Hash02 } from '@/app/components/base/icons/src/vender/line/general'
import { ArrowUpRight } from '@/app/components/base/icons/src/vender/line/arrows'
type PopupProps = {
data: Resources
showHitInfo?: boolean
}
const Popup: FC<PopupProps> = ({
data,
showHitInfo = false,
}) => {
const { t } = useTranslation()
const [open, setOpen] = useState(false)
const fileType = data.dataSourceType === 'upload_file'
? (/\.([^.]*)$/g.exec(data.documentName)?.[1] || '')
: 'notion'
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement='top-start'
offset={{
mainAxis: 8,
crossAxis: -2,
}}
>
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
<div className='flex items-center mr-1 px-2 max-w-[240px] h-7 bg-white rounded-lg'>
<FileIcon type={fileType} className='mr-1 w-4 h-4' />
<div className='text-xs text-gray-600 truncate'>{data.documentName}</div>
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent>
<div className='w-[360px] bg-gray-50 rounded-xl shadow-lg'>
<div className='px-4 pt-3 pb-2'>
<div className='flex items-center h-[18px]'>
<FileIcon type={fileType} className='mr-1 w-4 h-4' />
<div className='text-xs font-medium text-gray-600 truncate'>{data.documentName}</div>
</div>
</div>
<div className='px-4 py-0.5 max-h-[450px] bg-white rounded-lg overflow-auto'>
{
data.sources.map((source, index) => (
<Fragment key={index}>
<div className='group py-3'>
{
showHitInfo && (
<div className='flex items-center justify-between mb-2'>
<div className='flex items-center px-1.5 h-5 border border-gray-200 rounded-md'>
<Hash02 className='mr-0.5 w-3 h-3 text-gray-400' />
<div className='text-[11px] font-medium text-gray-500'>{source.segment_position}</div>
</div>
<Link
href={`/datasets/${source.dataset_id}/documents/${source.document_id}`}
className='hidden items-center h-[18px] text-xs text-primary-600 group-hover:flex'>
Link to dataset
<ArrowUpRight className='ml-1 w-3 h-3' />
</Link>
</div>
)
}
<div className='text-[13px] text-gray-800'>{source.content}</div>
{
showHitInfo && (
<div className='flex items-center mt-2 text-xs font-medium text-gray-500'>
<Tooltip
text={t('common.chat.citation.characters')}
data={source.word_count}
/>
<Tooltip
text={t('common.chat.citation.hitCount')}
data={source.hit_count}
/>
<Tooltip
text={t('common.chat.citation.vectorHash')}
data={source.index_node_hash.substring(0, 7)}
/>
<ProgressTooltip data={Number(source.score.toFixed(2))} />
</div>
)
}
</div>
{
index !== data.sources.length - 1 && (
<div className='my-1 h-[1px] bg-black/5' />
)
}
</Fragment>
))
}
</div>
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
)
}
export default Popup

View File

@@ -0,0 +1,46 @@
import { useState } from 'react'
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
type ProgressTooltipProps = {
data: number
}
const ProgressTooltip: FC<ProgressTooltipProps> = ({
data,
}) => {
const { t } = useTranslation()
const [open, setOpen] = useState(false)
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement='top-start'
>
<PortalToFollowElemTrigger
onMouseEnter={() => setOpen(true)}
onMouseLeave={() => setOpen(false)}
>
<div className='grow flex items-center'>
<div className='mr-1 w-16 h-1.5 rounded-[3px] border border-gray-400 overflow-hidden'>
<div className='bg-gray-400 h-full' style={{ width: `${data * 100}%` }}></div>
</div>
{data}
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent>
<div className='p-3 bg-white text-xs font-medium text-gray-500 rounded-lg shadow-lg'>
{t('common.chat.citation.hitScore')} {data}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
)
}
export default ProgressTooltip

View File

@@ -0,0 +1,45 @@
import { useState } from 'react'
import type { FC } from 'react'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import { TypeSquare } from '@/app/components/base/icons/src/vender/line/editor'
type TooltipProps = {
data: number | string
text: string
}
const Tooltip: FC<TooltipProps> = ({
data,
text,
}) => {
const [open, setOpen] = useState(false)
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement='top-start'
>
<PortalToFollowElemTrigger
onMouseEnter={() => setOpen(true)}
onMouseLeave={() => setOpen(false)}
>
<div className='flex items-center mr-6'>
<TypeSquare className='mr-1 w-3 h-3' />
{data}
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent>
<div className='p-3 bg-white text-xs font-medium text-gray-500 rounded-lg shadow-lg'>
{text} {data}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
)
}
export default Tooltip

View File

@@ -8,32 +8,36 @@ import Tooltip from '@/app/components/base/tooltip'
type ICopyBtnProps = {
value: string
className?: string
isPlain?: boolean
}
const CopyBtn = ({
value,
className,
isPlain,
}: ICopyBtnProps) => {
const [isCopied, setIsCopied] = React.useState(false)
return (
<div className={`${className}`}>
<Tooltip
selector="copy-btn-tooltip"
selector={`copy-btn-tooltip-${value}`}
content={(isCopied ? t('appApi.copied') : t('appApi.copy')) as string}
className='z-10'
>
<div
className={'box-border p-0.5 flex items-center justify-center rounded-md bg-white cursor-pointer'}
style={{
boxShadow: '0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06)',
}}
style={!isPlain
? {
boxShadow: '0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06)',
}
: {}}
onClick={() => {
copy(value)
setIsCopied(true)
}}
>
<div className={`w-6 h-6 hover:bg-gray-50 ${s.copyIcon} ${isCopied ? s.copied : ''}`}></div>
<div className={`w-6 h-6 rounded-md hover:bg-gray-50 ${s.copyIcon} ${isCopied ? s.copied : ''}`}></div>
</div>
</Tooltip>
</div>

View File

@@ -1,4 +1,4 @@
import type { FC } from 'react'
import type { FC, SVGProps } from 'react'
import { HandThumbDownIcon, HandThumbUpIcon } from '@heroicons/react/24/outline'
export const stopIcon = (
@@ -7,7 +7,7 @@ export const stopIcon = (
</svg>
)
export const OpeningStatementIcon: FC<{ className?: string }> = ({ className }) => (
export const OpeningStatementIcon = ({ className }: SVGProps<SVGElement>) => (
<svg className={className} width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fillRule="evenodd" clipRule="evenodd" d="M6.25002 1C3.62667 1 1.50002 3.12665 1.50002 5.75C1.50002 6.28 1.58702 6.79071 1.7479 7.26801C1.7762 7.35196 1.79285 7.40164 1.80368 7.43828L1.80722 7.45061L1.80535 7.45452C1.79249 7.48102 1.77339 7.51661 1.73766 7.58274L0.911727 9.11152C0.860537 9.20622 0.807123 9.30503 0.770392 9.39095C0.733879 9.47635 0.674738 9.63304 0.703838 9.81878C0.737949 10.0365 0.866092 10.2282 1.05423 10.343C1.21474 10.4409 1.38213 10.4461 1.475 10.4451C1.56844 10.444 1.68015 10.4324 1.78723 10.4213L4.36472 10.1549C4.406 10.1506 4.42758 10.1484 4.44339 10.1472L4.44542 10.147L4.45161 10.1492C4.47103 10.1562 4.49738 10.1663 4.54285 10.1838C5.07332 10.3882 5.64921 10.5 6.25002 10.5C8.87338 10.5 11 8.37335 11 5.75C11 3.12665 8.87338 1 6.25002 1ZM4.48481 4.29111C5.04844 3.81548 5.7986 3.9552 6.24846 4.47463C6.69831 3.9552 7.43879 3.82048 8.01211 4.29111C8.58544 4.76175 8.6551 5.562 8.21247 6.12453C7.93825 6.47305 7.24997 7.10957 6.76594 7.54348C6.58814 7.70286 6.49924 7.78255 6.39255 7.81466C6.30103 7.84221 6.19589 7.84221 6.10436 7.81466C5.99767 7.78255 5.90878 7.70286 5.73098 7.54348C5.24694 7.10957 4.55867 6.47305 4.28444 6.12453C3.84182 5.562 3.92117 4.76675 4.48481 4.29111Z" fill="#667085" />
</svg>
@@ -17,13 +17,13 @@ export const RatingIcon: FC<{ isLike: boolean }> = ({ isLike }) => {
return isLike ? <HandThumbUpIcon className='w-4 h-4' /> : <HandThumbDownIcon className='w-4 h-4' />
}
export const EditIcon: FC<{ className?: string }> = ({ className }) => {
export const EditIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className}>
<path d="M14 11.9998L13.3332 12.7292C12.9796 13.1159 12.5001 13.3332 12.0001 13.3332C11.5001 13.3332 11.0205 13.1159 10.6669 12.7292C10.3128 12.3432 9.83332 12.1265 9.33345 12.1265C8.83359 12.1265 8.35409 12.3432 7.99998 12.7292M2 13.3332H3.11636C3.44248 13.3332 3.60554 13.3332 3.75899 13.2963C3.89504 13.2637 4.0251 13.2098 4.1444 13.1367C4.27895 13.0542 4.39425 12.9389 4.62486 12.7083L13 4.33316C13.5523 3.78087 13.5523 2.88544 13 2.33316C12.4477 1.78087 11.5523 1.78087 11 2.33316L2.62484 10.7083C2.39424 10.9389 2.27894 11.0542 2.19648 11.1888C2.12338 11.3081 2.0695 11.4381 2.03684 11.5742C2 11.7276 2 11.8907 2 12.2168V13.3332Z" stroke="#6B7280" strokeLinecap="round" strokeLinejoin="round" />
</svg>
}
export const EditIconSolid: FC<{ className?: string }> = ({ className }) => {
export const EditIconSolid = ({ className }: SVGProps<SVGElement>) => {
return <svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg" className={className}>
<path fillRule="evenodd" clipRule="evenodd" d="M10.8374 8.63108C11.0412 8.81739 11.0554 9.13366 10.8691 9.33747L10.369 9.88449C10.0142 10.2725 9.52293 10.5001 9.00011 10.5001C8.47746 10.5001 7.98634 10.2727 7.63157 9.8849C7.45561 9.69325 7.22747 9.59515 7.00014 9.59515C6.77271 9.59515 6.54446 9.69335 6.36846 9.88517C6.18177 10.0886 5.86548 10.1023 5.66201 9.91556C5.45853 9.72888 5.44493 9.41259 5.63161 9.20911C5.98678 8.82201 6.47777 8.59515 7.00014 8.59515C7.52251 8.59515 8.0135 8.82201 8.36867 9.20911L8.36924 9.20974C8.54486 9.4018 8.77291 9.50012 9.00011 9.50012C9.2273 9.50012 9.45533 9.40182 9.63095 9.20979L10.131 8.66276C10.3173 8.45895 10.6336 8.44476 10.8374 8.63108Z" fill="#6B7280" />
<path fillRule="evenodd" clipRule="evenodd" d="M7.89651 1.39656C8.50599 0.787085 9.49414 0.787084 10.1036 1.39656C10.7131 2.00604 10.7131 2.99419 10.1036 3.60367L3.82225 9.88504C3.81235 9.89494 3.80254 9.90476 3.79281 9.91451C3.64909 10.0585 3.52237 10.1855 3.3696 10.2791C3.23539 10.3613 3.08907 10.4219 2.93602 10.4587C2.7618 10.5005 2.58242 10.5003 2.37897 10.5001C2.3652 10.5001 2.35132 10.5001 2.33732 10.5001H1.50005C1.22391 10.5001 1.00005 10.2763 1.00005 10.0001V9.16286C1.00005 9.14886 1.00004 9.13497 1.00003 9.1212C0.999836 8.91776 0.999669 8.73838 1.0415 8.56416C1.07824 8.4111 1.13885 8.26479 1.22109 8.13058C1.31471 7.97781 1.44166 7.85109 1.58566 7.70736C1.5954 7.69764 1.60523 7.68783 1.61513 7.67793L7.89651 1.39656Z" fill="#6B7280" />

View File

@@ -24,6 +24,7 @@ import type { DataSet } from '@/models/datasets'
export type IChatProps = {
configElem?: React.ReactNode
chatList: IChatItem[]
controlChatUpdateAllConversation?: number
/**
* Whether to display the editing area and rating status
*/
@@ -47,14 +48,17 @@ export type IChatProps = {
isShowSuggestion?: boolean
suggestionList?: string[]
isShowSpeechToText?: boolean
isShowCitation?: boolean
answerIconClassName?: string
isShowConfigElem?: boolean
dataSets?: DataSet[]
isShowCitationHitInfo?: boolean
}
const Chat: FC<IChatProps> = ({
configElem,
chatList,
feedbackDisabled = false,
isHideFeedbackEdit = false,
isHideSendInput = false,
@@ -72,9 +76,11 @@ const Chat: FC<IChatProps> = ({
isShowSuggestion,
suggestionList,
isShowSpeechToText,
isShowCitation,
answerIconClassName,
isShowConfigElem,
dataSets,
isShowCitationHitInfo,
}) => {
const { t } = useTranslation()
const { notify } = useContext(ToastContext)
@@ -160,6 +166,7 @@ const Chat: FC<IChatProps> = ({
if (item.isAnswer) {
const isLast = item.id === chatList[chatList.length - 1].id
const thoughts = item.agent_thoughts?.filter(item => item.thought !== '[DONE]')
const citation = item.citation
const isThinking = !item.content && item.agent_thoughts && item.agent_thoughts?.length > 0 && !item.agent_thoughts.some(item => item.thought === '[DONE]')
return <Answer
key={item.id}
@@ -172,8 +179,11 @@ const Chat: FC<IChatProps> = ({
isResponsing={isResponsing && isLast}
answerIconClassName={answerIconClassName}
thoughts={thoughts}
citation={citation}
isThinking={isThinking}
dataSets={dataSets}
isShowCitation={isShowCitation}
isShowCitationHitInfo={isShowCitationHitInfo}
/>
}
return <Question key={item.id} id={item.id} content={item.content} more={item.more} useCurrentUserAvatar={useCurrentUserAvatar} />

View File

@@ -37,12 +37,13 @@ const Thought: FC<IThoughtProps> = ({
const getThoughtText = (item: ThoughtItem) => {
try {
const input = JSON.parse(item.tool_input)
// dataset
if (item.tool.startsWith('dataset-')) {
const dataSetId = item.tool.replace('dataset-', '')
const datasetName = dataSets?.find(item => item.id === dataSetId)?.name || 'unknown dataset'
return t('explore.universalChat.thought.res.dataset').replace('{datasetName}', `<span class="text-gray-700">${datasetName}</span>`)
}
switch (item.tool) {
case 'dataset':
// eslint-disable-next-line no-case-declarations
const datasetName = dataSets?.find(item => item.id === input.dataset_id)?.name || 'unknown dataset'
return t('explore.universalChat.thought.res.dataset').replace('{datasetName}', `<span class="text-gray-700">${datasetName}</span>`)
case 'web_reader':
return t(`explore.universalChat.thought.res.webReader.${!input.cursor ? 'normal' : 'hasPageInfo'}`).replace('{url}', `<a href="${input.url}" class="text-[#155EEF]">${input.url}</a>`)
case 'google_search':

View File

@@ -23,10 +23,26 @@ export type ThoughtItem = {
tool_input: string
message_id: string
}
export type CitationItem = {
content: string
data_source_type: string
dataset_name: string
dataset_id: string
document_id: string
document_name: string
hit_count: number
index_node_hash: string
segment_id: string
segment_position: number
score: number
word_count: number
}
export type IChatItem = {
id: string
content: string
agent_thoughts?: ThoughtItem[]
citation?: CitationItem[]
/**
* Specific message type
*/
@@ -51,3 +67,8 @@ export type IChatItem = {
useCurrentUserAvatar?: boolean
isOpeningStatement?: boolean
}
export type MessageEnd = {
id: string
retriever_resources?: CitationItem[]
}

View File

@@ -10,6 +10,7 @@ import OperationBtn from '../base/operation-btn'
import VarIcon from '../base/icons/var-icon'
import EditModel from './config-model'
import IconTypeIcon from './input-type-icon'
import type { IInputTypeIconProps } from './input-type-icon'
import s from './style.module.css'
import Tooltip from '@/app/components/base/tooltip'
import type { PromptVariable } from '@/models/debug'
@@ -37,8 +38,8 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
return obj
})()
const updatePromptVariable = (key: string, updateKey: string, newValue: any) => {
const newPromptVariables = promptVariables.map((item, i) => {
const updatePromptVariable = (key: string, updateKey: string, newValue: string | boolean) => {
const newPromptVariables = promptVariables.map((item) => {
if (item.key === key) {
return {
...item,
@@ -179,7 +180,7 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
<tr key={index} className="h-9 leading-9">
<td className="w-[160px] border-b border-gray-100 pl-3">
<div className='flex items-center space-x-1'>
<IconTypeIcon type={type} />
<IconTypeIcon type={type as IInputTypeIconProps['type']} />
{!readonly
? (
<input

View File

@@ -2,7 +2,7 @@
import React from 'react'
import type { FC } from 'react'
type IInputTypeIconProps = {
export type IInputTypeIconProps = {
type: 'string' | 'select'
}

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 175 KiB

Some files were not shown because too many files have changed in this diff Show More