Compare commits

...

46 Commits

Author SHA1 Message Date
jyong
7a50d7ff75 fix error weaviate vector 2023-08-30 20:31:56 +08:00
Joel
e34dcc0406 feat: code support copy (#1057) 2023-08-30 18:08:47 +08:00
Joel
a834ba8759 feat: support rename conversation (#1056) 2023-08-30 17:32:32 +08:00
KVOJJJin
c67f345d0e Fix: disable operations of dataset when embedding unavailable (#1055)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-30 17:27:19 +08:00
yezhwi
8b8e510bfe fix: handle AttributeError for datasets and index (#1052) 2023-08-30 11:14:16 +08:00
crazywoola
3db839a5cb 773 change embed title welcome to use (#1053) 2023-08-30 11:03:25 +08:00
takatost
417c19577a feat: add LocalAI local embedding model support (#1021)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-08-29 22:22:02 +08:00
Jyong
b5953039de recreate qdrant vector (#1049)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-29 15:00:36 +08:00
Jyong
a43e80dd9c add qdrant migration (#1046)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-29 10:37:04 +08:00
WangBooth
ad5f27bc5f fix openpyxl dimensions error (#1041) 2023-08-29 10:36:48 +08:00
Joel
05e0985f29 chore: match new dataset tool format (#1044) 2023-08-29 09:07:45 +08:00
takatost
7b3314c5db fix: dataset desc (#1045) 2023-08-29 09:07:27 +08:00
Jyong
a55ba6e614 Fix/ignore economy dataset (#1043)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-29 03:37:45 +08:00
bowen
f9bec1edf8 chore: perfect type definition (#1003) 2023-08-28 19:48:53 +08:00
Jyong
16199e968e fix notion import limit check (#1042)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-28 16:49:03 +08:00
takatost
02452421d5 fix: pub generate message text return null (#1037) 2023-08-28 16:43:54 +08:00
zxhlyh
3a5c7c75ad Fix/model selector (#1032) 2023-08-28 10:54:41 +08:00
zxhlyh
a7415ecfd8 Fix/upload document limit (#1033) 2023-08-28 10:53:45 +08:00
KVOJJJin
934def5fcc Fix: eslint (#1030) 2023-08-27 17:06:16 +08:00
takatost
0796791de5 feat: hf inference endpoint stream support (#1028) 2023-08-26 19:48:34 +08:00
takatost
6c148b223d fix: dataset query truncated (#1026) 2023-08-26 17:35:17 +08:00
zxhlyh
4b168f4838 fix: maintenance notice (#1025) 2023-08-26 16:09:55 +08:00
takatost
1c114eaef3 feat: update contributing (#1020) 2023-08-25 21:19:13 +08:00
Jyong
e053215155 fix document estimate parameter (#1019)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 20:10:08 +08:00
zxhlyh
13482b0fc1 feat: maintenance notice (#1016) 2023-08-25 19:38:52 +08:00
Jyong
38fa152cc4 fix update document index technique (#1018)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 18:29:55 +08:00
Uranus
2d9616c29c fix: xinference last token being ignored (#1013) 2023-08-25 18:15:05 +08:00
Jyong
915e26527b update dataset index struct (#1012)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 15:52:33 +08:00
Jyong
2d604d9330 Fix/filter empty segment (#1004)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 15:50:29 +08:00
Jyong
e7199826cc embedding model available check (#1009)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 00:25:16 +08:00
crazywoola
70e24b7594 fix: loading and calc rem (#1006) 2023-08-24 23:24:33 +08:00
yezhwi
c1602aafc7 refactor:cache in place & function name (#1001) 2023-08-24 22:54:21 +08:00
crazywoola
a3fec11438 fix: styles (#1005) 2023-08-24 22:37:46 +08:00
Jyong
b1fd1b3ab3 Feat/vector db manage (#997)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-24 21:27:31 +08:00
Jyong
5397799aac document limit (#999)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-24 21:27:13 +08:00
takatost
8e837dde1a feat: bump version to 0.3.18 (#1000) 2023-08-24 18:13:18 +08:00
takatost
9ae91a2ec3 feat: optimize xinference request max token key and stop reason (#998) 2023-08-24 18:11:15 +08:00
Matri
276d3d10a0 fix: apps loading issue (#994) 2023-08-24 17:57:38 +08:00
crazywoola
f13623184a fix style in app share (#995) 2023-08-24 17:57:25 +08:00
takatost
ef61e1487f fix: safetensor arm complie error (#996) 2023-08-24 17:38:10 +08:00
takatost
701e2b334f feat: remove unnecessary prompt of baichuan (#993) 2023-08-24 15:30:59 +08:00
takatost
6ebd6e7890 feat: bump version to 0.3.17 (#992) 2023-08-24 15:12:47 +08:00
takatost
bd3a9b2f8d fix: xinference-chat-stream-response (#991) 2023-08-24 14:39:34 +08:00
takatost
18d3877151 feat: optimize xinference stream (#989) 2023-08-24 13:58:34 +08:00
takatost
53e83d8697 feat: optimize baichuan prompt (#988) 2023-08-24 12:07:10 +08:00
Matri
6377fc75c6 chore: update lintrc config (#986) 2023-08-24 11:46:59 +08:00
259 changed files with 5077 additions and 845 deletions

View File

@@ -53,9 +53,9 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r
## Community channels
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk). We are here to help!
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
### i18n (Internationalization) Support
We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.
Also check out the [Frontend i18n README]((web/i18n/README_EN.md)) for more information.

View File

@@ -16,15 +16,15 @@
## Local development
To set up a working development environment, just fork the project's git repository, install the backend and frontend dependencies with the appropriate package manager, then create and run the docker-compose stack
To set up a working development environment, just fork the project's git repository, install the backend and frontend dependencies with the appropriate package manager, then create and run docker-compose.
### Fork the repository
You need to fork the [repository](https://github.com/langgenius/dify).
You need to fork the [Git repository](https://github.com/langgenius/dify).
### Clone the repository
Clone the repository you forked on GitHub:
Clone the repo you forked on GitHub:
```
git clone git@github.com:<github_username>/dify.git

View File

@@ -52,4 +52,4 @@ git clone git@github.com:<github_username>/dify.git
## Community channels
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk). We are here to help!
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!

View File

@@ -1,4 +1,5 @@
import datetime
import json
import math
import random
import string
@@ -6,10 +7,16 @@ import time
import click
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound
from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.hosted import hosted_model_providers
from core.model_providers.providers.openai_provider import OpenAIProvider
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
@@ -296,6 +303,142 @@ def sync_anthropic_hosted_providers():
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
def create_qdrant_indexes():
click.echo(click.style('Start create qdrant indexes.', fg='green'))
create_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant':
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.SYSTEM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
create_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@click.command('update-qdrant-indexes', help='Update qdrant indexes.')
def update_qdrant_indexes():
click.echo(click.style('Start Update qdrant indexes.', fg='green'))
create_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant':
try:
click.echo('Update dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.update_qdrant_dataset(dataset)
create_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@@ -304,3 +447,5 @@ def register_commands(app):
app.cli.add_command(recreate_all_dataset_indexes)
app.cli.add_command(sync_anthropic_hosted_providers)
app.cli.add_command(clean_unused_dataset_indexes)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(update_qdrant_indexes)
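Both new commands share one traversal pattern: page through high-quality datasets fifty at a time until Flask-SQLAlchemy's `paginate` raises `NotFound`, skipping datasets whose index struct already says `qdrant`. A minimal sketch of that loop, with a hypothetical `migrate(dataset)` callback standing in for the per-dataset index work:

```python
from werkzeug.exceptions import NotFound

from extensions.ext_database import db  # repo-local imports; assumes an app context
from models.dataset import Dataset


def for_each_non_qdrant_dataset(migrate):
    """Walk all high-quality datasets page by page; paginate() raises NotFound past the last page."""
    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset) \
                .filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()) \
                .paginate(page=page, per_page=50)
        except NotFound:
            break
        page += 1
        for dataset in datasets:
            # only touch datasets that have an index but are not yet qdrant-backed
            if dataset.index_struct_dict and dataset.index_struct_dict['type'] != 'qdrant':
                migrate(dataset)
```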

View File

@@ -100,7 +100,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.16"
self.CURRENT_VERSION = "0.3.18"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')

View File

@@ -87,13 +87,19 @@ class DatasetListApi(Resource):
# raise ProviderNotInitializeError(
# f"No Embedding Model available. Please configure a valid provider "
# f"in the Settings -> Model Provider.")
model_names = [item['model_name'] for item in valid_model_list]
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['embedding_model'] in model_names:
item['embedding_available'] = True
if item['indexing_technique'] == 'high_quality':
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item['embedding_available'] = True
else:
item['embedding_available'] = False
else:
item['embedding_available'] = False
item['embedding_available'] = True
response = {
'data': data,
'has_more': len(datasets) == limit,
@@ -119,14 +125,6 @@ class DatasetListApi(Resource):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
try:
dataset = DatasetService.create_empty_dataset(
@@ -150,20 +148,39 @@ class DatasetApi(Resource):
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(
dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
return marshal(dataset, dataset_detail_fields), 200
data = marshal(dataset, dataset_detail_fields)
# check embedding setting
provider_service = ProviderService()
# get valid model list
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
if data['indexing_technique'] == 'high_quality':
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data['embedding_available'] = True
else:
data['embedding_available'] = False
else:
data['embedding_available'] = True
return data, 200
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False,
@@ -251,6 +268,7 @@ class DatasetIndexingEstimateApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
@@ -272,7 +290,8 @@ class DatasetIndexingEstimateApi(Resource):
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'])
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
@@ -287,7 +306,8 @@ class DatasetIndexingEstimateApi(Resource):
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'])
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "

View File

@@ -3,7 +3,7 @@ import random
from datetime import datetime
from typing import List
from flask import request
from flask import request, current_app
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
@@ -138,6 +138,10 @@ class GetProcessRuleApi(Resource):
req_data = request.args
document_id = req_data.get('document_id')
# get default rules
mode = DocumentService.DEFAULT_RULES['mode']
rules = DocumentService.DEFAULT_RULES['rules']
if document_id:
# get the latest process rule
document = Document.query.get_or_404(document_id)
@@ -158,11 +162,9 @@ class GetProcessRuleApi(Resource):
order_by(DatasetProcessRule.created_at.desc()). \
limit(1). \
one_or_none()
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict
else:
mode = DocumentService.DEFAULT_RULES['mode']
rules = DocumentService.DEFAULT_RULES['rules']
if dataset_process_rule:
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict
return {
'mode': mode,
@@ -275,7 +277,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
if not dataset.indexing_technique and not args['indexing_technique']:
@@ -284,20 +287,6 @@ class DatasetDocumentListApi(Resource):
# validate args
DocumentService.document_create_args_validate(args)
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
except ProviderTokenNotInitError as ex:
@@ -335,17 +324,20 @@ class DatasetInitApi(Resource):
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
if args['indexing_technique'] == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
DocumentService.document_create_args_validate(args)
@@ -414,7 +406,8 @@ class DocumentIndexingEstimateApi(DocumentResource):
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
data_process_rule_dict, None, dataset_id)
data_process_rule_dict, None,
'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
@@ -483,7 +476,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
data_process_rule_dict, None, dataset_id)
data_process_rule_dict, None,
'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
@@ -497,7 +491,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
info_list,
data_process_rule_dict,
None, dataset_id)
None, 'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
@@ -725,6 +719,12 @@ class DocumentDeleteApi(DocumentResource):
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
try:
@@ -787,6 +787,12 @@ class DocumentStatusApi(DocumentResource):
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
# The role of the current user in the ta table must be admin or owner
@@ -855,6 +861,14 @@ class DocumentStatusApi(DocumentResource):
if not document.archived:
raise InvalidActionError('Document is not archived.')
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
total_count = documents_count + 1
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if total_count > tenant_document_count:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
document.archived = False
document.archived_at = None
document.archived_by = None
@@ -872,6 +886,10 @@ class DocumentStatusApi(DocumentResource):
class DocumentPauseApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id):
"""pause document."""
dataset_id = str(dataset_id)
@@ -901,6 +919,9 @@ class DocumentPauseApi(DocumentResource):
class DocumentRecoverApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id):
"""recover document."""
dataset_id = str(dataset_id)
@@ -926,6 +947,21 @@ class DocumentRecoverApi(DocumentResource):
return {'result': 'success'}, 204
class DocumentLimitApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""get document limit"""
documents_count = DocumentService.get_tenant_documents_count()
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
return {
'documents_count': documents_count,
'documents_limit': tenant_document_count
}, 200
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
@@ -951,3 +987,4 @@ api.add_resource(DocumentStatusApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
api.add_resource(DocumentLimitApi, '/datasets/limit')
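Un-archiving a document is now subject to the same tenant-wide quota as creating one. Condensed into a helper (a sketch: `TENANT_DOCUMENT_COUNT` and `DocumentService.get_tenant_documents_count()` come from this diff, while the helper name and import path are our assumptions):

```python
from flask import current_app

from services.dataset_service import DocumentService  # assumed repo path


def check_document_quota(new_documents: int = 1):
    """Raise if adding documents would exceed the tenant quota; cloud edition only."""
    if current_app.config['EDITION'] != 'CLOUD':
        return
    documents_count = DocumentService.get_tenant_documents_count()
    limit = int(current_app.config['TENANT_DOCUMENT_COUNT'])
    if documents_count + new_documents > limit:
        raise ValueError(f"Document limit of {limit} exceeded.")
```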

View File

@@ -149,7 +149,8 @@ class DatasetDocumentSegmentApi(Resource):
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
@@ -158,20 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
@@ -244,18 +245,19 @@ class DatasetDocumentSegmentAddApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
@@ -284,25 +286,28 @@ class DatasetDocumentSegmentUpdateApi(Resource):
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
@@ -339,6 +344,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
@@ -378,18 +385,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# get file from request
file = request.files['file']
# check file

View File

@@ -52,7 +52,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
@@ -60,7 +60,13 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
return AgentFinish(return_values={"output": observation}, log=observation)
try:
return super().plan(intermediate_steps, callbacks, **kwargs)
agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
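All four agent variants in this comparison gain the same post-processing step: when the planner returns an `AgentAction` whose tool input carries a `query`, that query is overwritten with the raw user input, so the dataset retriever sees the original question instead of a model-paraphrased one. The step, reduced to a standalone function (sketch):

```python
from langchain.schema import AgentAction


def pin_query_to_user_input(agent_decision, user_input: str):
    """Force dataset tool calls to search with the user's own words."""
    if isinstance(agent_decision, AgentAction):
        tool_inputs = agent_decision.tool_input
        if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
            tool_inputs['query'] = user_input
            agent_decision.tool_input = tool_inputs
    return agent_decision
```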

View File

@@ -45,7 +45,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 15
self.llm.max_tokens = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
@@ -97,6 +97,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
@classmethod

View File

@@ -90,7 +90,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
@@ -102,7 +102,13 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
raise new_exception
try:
return self.output_parser.parse(full_output)
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")

View File

@@ -106,7 +106,13 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
raise new_exception
try:
return self.output_parser.parse(full_output)
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")

View File

@@ -1,5 +1,6 @@
import json
import logging
from json import JSONDecodeError
from typing import Any, Dict, List, Union, Optional
@@ -44,10 +45,15 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
input_str: str,
**kwargs: Any,
) -> None:
# tool_name = serialized.get('name')
input_dict = json.loads(input_str.replace("'", "\""))
dataset_id = input_dict.get('dataset_id')
query = input_dict.get('query')
tool_name: str = serialized.get('name')
dataset_id = tool_name.removeprefix('dataset-')
try:
input_dict = json.loads(input_str.replace("'", "\""))
query = input_dict.get('query')
except JSONDecodeError:
query = input_str
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
def on_tool_end(
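The handler stops trusting `input_str` to carry a `dataset_id`: the id is recovered from the tool name (`dataset-<uuid>`) via `str.removeprefix` (Python 3.9+), and query parsing falls back to the raw string when the input is not JSON. The same idea in isolation (sketch):

```python
import json
from json import JSONDecodeError


def parse_dataset_tool_call(tool_name: str, input_str: str):
    """Extract (dataset_id, query) from a dataset tool invocation."""
    dataset_id = tool_name.removeprefix('dataset-')
    try:
        # tool inputs often arrive as repr()'d dicts with single quotes
        query = json.loads(input_str.replace("'", '"')).get('query')
    except JSONDecodeError:
        query = input_str  # plain-string input: use it verbatim
    return dataset_id, query
```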

View File

@@ -137,7 +137,8 @@ class ConversationMessageTask:
db.session.flush()
def append_message_text(self, text: str):
self._pub_handler.pub_text(text)
if text is not None:
self._pub_handler.pub_text(text)
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
message_tokens = llm_message.prompt_tokens

View File

@@ -30,6 +30,8 @@ class ExcelLoader(BaseLoader):
wb = load_workbook(filename=self._file_path, read_only=True)
# loop over all sheets
for sheet in wb:
if 'A1:A1' == sheet.calculate_dimension():
sheet.reset_dimensions()
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
@@ -38,7 +40,7 @@ class ExcelLoader(BaseLoader):
else:
row_dict = dict(zip(keys, list(map(str, row))))
row_dict = {k: v for k, v in row_dict.items() if v}
item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
item = ''.join(f'{k}:{v};' for k, v in row_dict.items())
document = Document(page_content=item, metadata={'source': self._file_path})
data.append(document)
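In read-only mode openpyxl trusts the workbook's declared dimensions, and some writers emit a bogus `A1:A1` dimension that makes `iter_rows` yield nothing. The fix detects that case and calls `reset_dimensions()` so openpyxl re-scans the sheet. Reduced to a sketch:

```python
from openpyxl import load_workbook


def iter_excel_rows(path: str):
    """Yield non-empty rows from every sheet, tolerating bad dimension metadata."""
    wb = load_workbook(filename=path, read_only=True)
    for sheet in wb:
        if sheet.calculate_dimension() == 'A1:A1':
            # declared dimensions are wrong; force openpyxl to recompute them
            sheet.reset_dimensions()
        for row in sheet.iter_rows(values_only=True):
            if not all(v is None for v in row):
                yield row
```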

View File

@@ -67,12 +67,13 @@ class DatesetDocumentStore:
if max_position is None:
max_position = 0
embedding_model = ModelFactory.get_embedding_model(
tenant_id=self._dataset.tenant_id,
model_provider_name=self._dataset.embedding_model_provider,
model_name=self._dataset.embedding_model
)
embedding_model = None
if self._dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=self._dataset.tenant_id,
model_provider_name=self._dataset.embedding_model_provider,
model_name=self._dataset.embedding_model
)
for doc in docs:
if not isinstance(doc, Document):
@@ -88,7 +89,7 @@ class DatesetDocumentStore:
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(doc.page_content)
tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0
if not segment_document:
max_position += 1

View File

@@ -1,10 +1,18 @@
import json
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.providers.openai_provider import OpenAIProvider
from models.dataset import Dataset
from models.provider import Provider, ProviderType
class IndexBuilder:
@@ -35,4 +43,13 @@ class IndexBuilder:
)
)
else:
raise ValueError('Unknown indexing technique')
@classmethod
def get_default_high_quality_index(cls, dataset: Dataset):
embeddings = OpenAIEmbeddings(openai_api_key=' ')
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)

View File

@@ -15,12 +15,12 @@ from models.dataset import Document as DatasetDocument
class BaseVectorIndex(BaseIndex):
def __init__(self, dataset: Dataset, embeddings: Embeddings):
super().__init__(dataset)
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
raise NotImplementedError
@@ -143,7 +143,7 @@ class BaseVectorIndex(BaseIndex):
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
@@ -173,3 +173,73 @@ class BaseVectorIndex(BaseIndex):
self.dataset = dataset
logging.info(f"Dataset {dataset.id} recreate successfully.")
def create_qdrant_dataset(self, dataset: Dataset):
logging.info(f"create_qdrant_dataset {dataset.id}")
try:
self.delete()
except UnexpectedStatusCodeException as e:
if e.status_code != 400:
# 400 means index not exists
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.create(documents)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def update_qdrant_dataset(self, dataset: Dataset):
logging.info(f"update_qdrant_dataset {dataset.id}")
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).first()
if segment:
try:
exist = self.text_exists(segment.index_node_id)
if exist:
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")

View File

@@ -0,0 +1,114 @@
from typing import Optional, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore, milvus
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class MilvusConfig(BaseModel):
endpoint: str
user: str
password: str
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config MILVUS_ENDPOINT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
class MilvusVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
def get_type(self) -> str:
return 'milvus'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
)
def _get_vector_store_class(self) -> type:
return MilvusVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False

File diff suppressed because it is too large

View File

@@ -44,15 +44,20 @@ class QdrantVectorIndex(BaseVectorIndex):
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
return self.dataset.index_struct_dict['vector_store']['collection_name']
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Index_" + dataset_id.replace("-", "_")
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"collection_name": self.get_index_name(self.dataset)}
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
@@ -62,7 +67,7 @@ class QdrantVectorIndex(BaseVectorIndex):
self._embeddings,
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='text',
content_payload_key='page_content',
**self._client_config.to_qdrant_params()
)
@@ -72,7 +77,9 @@ class QdrantVectorIndex(BaseVectorIndex):
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
@@ -81,7 +88,7 @@ class QdrantVectorIndex(BaseVectorIndex):
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='text'
content_payload_key='page_content'
)
def _get_vector_store_class(self) -> type:
@@ -108,8 +115,8 @@ class QdrantVectorIndex(BaseVectorIndex):
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
if class_prefix.startswith('Vector_'):
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
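Index naming for qdrant now mirrors the weaviate scheme: legacy datasets keep their stored `class_prefix` (suffixed with `_Node` when missing), while new datasets derive `Vector_index_<uuid>_Node` from the dataset id. The rule as a sketch (helper name is ours):

```python
def qdrant_index_name(dataset_id: str, stored_class_prefix: str = None) -> str:
    """Resolve a dataset's qdrant collection name; pre-migration prefixes gain a '_Node' suffix."""
    if stored_class_prefix:
        if not stored_class_prefix.endswith('_Node'):
            stored_class_prefix += '_Node'  # original (pre-migration) class_prefix
        return stored_class_prefix
    return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
```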

View File

@@ -217,25 +217,29 @@ class IndexingRunner:
db.session.commit()
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = None
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
else:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
tokens = 0
preview_texts = []
total_segments = 0
@@ -263,8 +267,8 @@ class IndexingRunner:
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
if indexing_technique == 'high_quality' or embedding_model:
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
if doc_form and doc_form == 'qa_model':
text_generation_model = ModelFactory.get_text_generation_model(
@@ -286,32 +290,35 @@ class IndexingRunner:
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
"currency": embedding_model.get_currency(),
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
"currency": embedding_model.get_currency() if embedding_model else 'USD',
"preview": preview_texts
}
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = None
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
else:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
# load data from notion
tokens = 0
preview_texts = []
@@ -356,8 +363,8 @@ class IndexingRunner:
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
tokens += embedding_model.get_num_tokens(document.page_content)
if indexing_technique == 'high_quality' or embedding_model:
tokens += embedding_model.get_num_tokens(document.page_content)
if doc_form and doc_form == 'qa_model':
text_generation_model = ModelFactory.get_text_generation_model(
@@ -379,8 +386,8 @@ class IndexingRunner:
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
"currency": embedding_model.get_currency(),
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
"currency": embedding_model.get_currency() if embedding_model else 'USD',
"preview": preview_texts
}
@@ -399,7 +406,8 @@ class IndexingRunner:
filter(UploadFile.id == data_source_info['upload_file_id']). \
one_or_none()
text_docs = FileExtractor.load(file_detail)
if file_detail:
text_docs = FileExtractor.load(file_detail)
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
@@ -525,12 +533,13 @@ class IndexingRunner:
documents = splitter.split_documents([text_doc])
split_documents = []
for document_node in documents:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
split_documents.append(document_node)
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
split_documents.append(document_node)
all_documents.extend(split_documents)
# processing qa document
if document_form == 'qa_model':
@@ -656,12 +665,13 @@ class IndexingRunner:
"""
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
embedding_model = None
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
@@ -671,11 +681,11 @@ class IndexingRunner:
# check document is paused
self._check_document_paused_status(dataset_document.id)
chunk_documents = documents[i:i + chunk_size]
tokens += sum(
embedding_model.get_num_tokens(document.page_content)
for document in chunk_documents
)
if dataset.indexing_technique == 'high_quality' or embedding_model:
tokens += sum(
embedding_model.get_num_tokens(document.page_content)
for document in chunk_documents
)
# save vector index
if vector_index:

View File

@@ -63,6 +63,9 @@ class ModelProviderFactory:
elif provider_name == 'openllm':
from core.model_providers.providers.openllm_provider import OpenLLMProvider
return OpenLLMProvider
elif provider_name == 'localai':
from core.model_providers.providers.localai_provider import LocalAIProvider
return LocalAIProvider
else:
raise NotImplementedError

View File

@@ -0,0 +1,29 @@
from langchain.embeddings import LocalAIEmbeddings
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
class LocalAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = LocalAIEmbeddings(
model=name,
openai_api_key="1",
openai_api_base=credentials['server_url'],
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"LocalAI embedding: {str(ex)}")
else:
return ex

View File

@@ -75,7 +75,7 @@ class AnthropicModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True

View File

@@ -141,6 +141,6 @@ class AzureOpenAIModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
return True
@property
def support_streaming(self):
return True

View File

@@ -138,7 +138,7 @@ class BaseLLM(BaseProviderModel):
result = self._run(
messages=messages,
stop=stop,
callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
**kwargs
)
except Exception as ex:
@@ -149,7 +149,7 @@ class BaseLLM(BaseProviderModel):
else:
completion_content = result.generations[0][0].text
if self.streaming and not self.support_streaming():
if self.streaming and not self.support_streaming:
# use FakeLLM to simulate streaming when current model not support streaming but streaming is True
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM(
@@ -298,8 +298,8 @@ class BaseLLM(BaseProviderModel):
else:
self.client.callbacks.extend(callbacks)
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return False
def get_prompt(self, mode: str,

View File

@@ -61,7 +61,3 @@ class ChatGLMModel(BaseLLM):
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return False

View File

@@ -17,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
streaming = self.streaming
if 'baichuan' in self.name.lower():
streaming = False
client = HuggingFaceEndpointLLM(
endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs,
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
callbacks=self.callbacks
callbacks=self.callbacks,
streaming=streaming
)
else:
client = HuggingFaceHub(
@@ -76,7 +82,10 @@ class HuggingfaceHubModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
@classmethod
def support_streaming(cls):
return False
@property
def support_streaming(self):
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
if 'baichuan' in self.name.lower():
return False
return True
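`support_streaming` changes from a classmethod to an instance property across every model class because support can now depend on instance state; the Hugging Face model above is the motivating case, disabling streaming for baichuan endpoint deployments. The shape of the refactor (a sketch with a minimal constructor; the return value for the non-endpoint branch is our assumption):

```python
class BaseLLM:
    @property
    def support_streaming(self) -> bool:
        return False  # default: simulate streaming via FakeLLM


class HuggingfaceHubModel(BaseLLM):
    def __init__(self, name: str, credentials: dict):
        self.name = name
        self.credentials = credentials

    @property
    def support_streaming(self) -> bool:
        # a classmethod could not see self.name or self.credentials
        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
            return 'baichuan' not in self.name.lower()
        return False
```

Callers drop the parentheses accordingly: `self.streaming and not self.support_streaming`.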

View File

@@ -0,0 +1,131 @@
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class LocalAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
if credentials['completion_type'] == 'chat_completion':
self.model_mode = ModelMode.CHAT
else:
self.model_mode = ModelMode.COMPLETION
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.model_mode == ModelMode.COMPLETION:
client = EnhanceOpenAI(
model_name=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
openai_api_key="1",
openai_api_base=self.credentials['server_url'] + '/v1',
**provider_model_kwargs
)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
client = EnhanceChatOpenAI(
model_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
openai_api_key="1",
openai_api_base=self.credentials['server_url'] + '/v1'
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run prediction by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, str):
return self._client.get_num_tokens(prompts)
else:
return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.model_mode == ModelMode.COMPLETION:
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to LocalAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to LocalAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("LocalAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File
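`LocalAIModel` above works because LocalAI exposes an OpenAI-compatible API: `_init_client` simply points an `EnhanceOpenAI`/`EnhanceChatOpenAI` client at `server_url + '/v1'` with a dummy key. A sketch of the equivalent raw call using the pre-1.0 `openai` SDK (the same one the exception handling above targets); the server URL and model name are placeholders:

```python
# LocalAI speaks the OpenAI wire protocol, so the pre-1.0 openai SDK works
# against it directly; the URL and model name here are placeholders.
import openai

openai.api_key = "1"                          # LocalAI ignores the key, but the SDK requires one
openai.api_base = "http://localhost:8080/v1"  # assumed local LocalAI server

resp = openai.Completion.create(
    model="ggml-gpt4all-j",   # whichever model the server has loaded
    prompt="ping",
    max_tokens=10,
)
print(resp["choices"][0]["text"])
```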

@@ -154,8 +154,8 @@ class OpenAIModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True
# def is_model_valid_or_raise(self):

View File

@@ -63,7 +63,3 @@ class OpenLLMModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"OpenLLM: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@@ -91,6 +91,6 @@ class ReplicateModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
return True
@property
def support_streaming(self):
return True

View File

@@ -65,6 +65,6 @@ class SparkModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
return True
@property
def support_streaming(self):
return True

View File

@@ -69,6 +69,6 @@ class TongyiModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True

View File

@@ -57,7 +57,3 @@ class WenxinModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@@ -74,6 +74,6 @@ class XinferenceModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Xinference: {str(ex)}")
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True

View File

@@ -0,0 +1,164 @@
import json
from typing import Type
from langchain.embeddings import LocalAIEmbeddings
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from models.provider import ProviderType
class LocalAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'localai'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = LocalAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = LocalAIEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=0.7),
top_p=KwargRule[float](min=0, max=1, default=1),
max_tokens=KwargRule[int](min=10, max=4097, default=16),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('LocalAI Server URL must be provided.')
try:
if model_type == ModelType.EMBEDDINGS:
model = LocalAIEmbeddings(
model=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url']
)
model.embed_query("ping")
else:
if ('completion_type' not in credentials
or credentials['completion_type'] not in ['completion', 'chat_completion']):
raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.')
if credentials['completion_type'] == 'chat_completion':
model = EnhanceChatOpenAI(
model_name=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url'] + '/v1',
max_tokens=10,
request_timeout=60,
)
model([HumanMessage(content='ping')])
else:
model = EnhanceOpenAI(
model_name=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url'] + '/v1',
max_tokens=10,
request_timeout=60,
)
model('ping')
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'server_url': None,
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['server_url']:
credentials['server_url'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['server_url']
)
if obfuscated:
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}

View File
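`get_model_credentials` above can return an obfuscated `server_url`; the provider test at the end of this section asserts that the masking keeps the first six and last two characters. A sketch of that convention (the helper name is hypothetical; only the masking rule is taken from the test):

```python
def obfuscate_server_url(token: str) -> str:
    # keep the first 6 and last 2 characters, mask the middle -- the rule
    # asserted by test_get_model_credentials_obfuscated below
    return token[:6] + '*' * max(len(token) - 8, 0) + token[-2:]

assert obfuscate_server_url('http://127.0.0.1:8080/') == 'http:/**************0/'
```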

@@ -2,7 +2,6 @@ import json
from typing import Type
import requests
from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -73,7 +72,7 @@ class XinferenceProvider(BaseModelProvider):
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
)

View File

@@ -10,5 +10,6 @@
"replicate",
"huggingface_hub",
"xinference",
"openllm"
"openllm",
"localai"
]

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@@ -1,13 +1,13 @@
{
"human_prefix": "用户",
"assistant_prefix": "助手",
"context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n引用材料\n{{context}}\n```\n\n",
"context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n\n",
"histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{histories}}\n```\n\n",
"system_prompt_orders": [
"context_prompt",
"pre_prompt",
"histories_prompt"
],
"query_prompt": "用户:{{query}}\n助手",
"query_prompt": "用户:{{query}}",
"stops": ["用户:"]
}

View File

@@ -1,5 +1,5 @@
{
"context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n引用材料\n{{context}}\n```\n",
"context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n",
"system_prompt_orders": [
"context_prompt",
"pre_prompt"

View File

@@ -42,7 +42,8 @@ class EnhanceChatOpenAI(ChatOpenAI):
return {
**super()._default_params,
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,

View File
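This hunk (and the matching one in `EnhanceOpenAI` further down) stops hard-coding the environment fallback, so an instance-level `openai_api_base` such as a LocalAI endpoint is no longer silently overridden. A sketch of the resolution order:

```python
import os
from typing import Optional

def resolve_api_base(openai_api_base: Optional[str]) -> str:
    # the instance-level setting wins; only fall back to the environment and
    # then the public endpoint, as the _default_params change above does
    if openai_api_base:
        return openai_api_base
    return os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")

assert resolve_api_base("http://localhost:8080/v1") == "http://localhost:8080/v1"
```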

@@ -1,7 +1,11 @@
from typing import Dict
from typing import Dict, Any, Optional, List, Iterable, Iterator
from huggingface_hub import InferenceClient
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.embeddings.huggingface_hub import VALID_TASKS
from langchain.llms import HuggingFaceEndpoint
from pydantic import Extra, root_validator
from langchain.llms.utils import enforce_stop_tokens
from pydantic import root_validator
from langchain.utils import get_from_dict_or_env
@@ -27,6 +31,8 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
huggingfacehub_api_token="my-api-key"
)
"""
client: Any
streaming: bool = False
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
@@ -35,5 +41,88 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
values['client'] = InferenceClient(values['endpoint_url'], token=huggingfacehub_api_token)
values["huggingfacehub_api_token"] = huggingfacehub_api_token
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = hf("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
# payload samples
params = {**_model_kwargs, **kwargs}
# generation parameter
gen_kwargs = {
**params,
'stop_sequences': stop
}
response = self.client.text_generation(prompt, stream=self.streaming, details=True, **gen_kwargs)
if self.streaming and isinstance(response, Iterable):
combined_text_output = ""
for token in self._stream_response(response, run_manager):
combined_text_output += token
completion = combined_text_output
else:
completion = response.generated_text
if self.task == "text-generation":
text = completion
# Remove prompt if included in generated text.
if text.startswith(prompt):
text = text[len(prompt) :]
elif self.task == "text2text-generation":
text = completion
else:
raise ValueError(
f"Got invalid task {self.task}, "
f"currently only {VALID_TASKS} are supported"
)
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text
def _stream_response(
self,
response: Iterable,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> Iterator[str]:
for r in response:
# skip special tokens
if r.token.special:
continue
token = r.token.text
if run_manager:
run_manager.on_llm_new_token(
token=token, verbose=self.verbose, log_probs=None
)
# yield the generated token
yield token

View File
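`_call` above delegates to huggingface_hub's `InferenceClient.text_generation` with `stream=True, details=True`, and `_stream_response` filters special tokens out of the resulting iterator. A standalone sketch of that path; the endpoint URL and API token are placeholders:

```python
# Standalone sketch of the streaming path behind _call/_stream_response above;
# the endpoint URL and API token are placeholders.
from huggingface_hub import InferenceClient

client = InferenceClient("https://my-endpoint.example.com", token="hf_...")
for r in client.text_generation("Tell me a joke.", stream=True, details=True,
                                max_new_tokens=64):
    if r.token.special:          # skip special tokens, as _stream_response does
        continue
    print(r.token.text, end="", flush=True)
```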

@@ -1,7 +1,10 @@
import os
from typing import Dict, Any, Mapping, Optional, Union, Tuple
from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator
from langchain import OpenAI
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk
from langchain.schema.output import GenerationChunk
from pydantic import root_validator
@@ -33,7 +36,8 @@ class EnhanceOpenAI(OpenAI):
def _invocation_params(self) -> Dict[str, Any]:
return {**super()._invocation_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
@@ -43,8 +47,33 @@ class EnhanceOpenAI(OpenAI):
def _identifying_params(self) -> Mapping[str, Any]:
return {**super()._identifying_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutates params
for stream_resp in completion_with_retry(
self, prompt=prompt, run_manager=run_manager, **params
):
if 'text' in stream_resp["choices"][0]:
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
else None,
)

View File

@@ -3,8 +3,11 @@ from typing import Optional, List, Any, Union, Generator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Xinference
from langchain.llms.utils import enforce_stop_tokens
from xinference.client import RESTfulChatglmCppChatModelHandle, \
RESTfulChatModelHandle, RESTfulGenerateModelHandle
from xinference.client import (
RESTfulChatglmCppChatModelHandle,
RESTfulChatModelHandle,
RESTfulGenerateModelHandle,
)
class XinferenceLLM(Xinference):
@@ -29,7 +32,9 @@ class XinferenceLLM(Xinference):
model = self.client.get_model(self.model_uid)
if isinstance(model, RESTfulChatModelHandle):
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
generate_config: "LlamaCppGenerateConfig" = kwargs.get(
"generate_config", {}
)
if stop:
generate_config["stop"] = stop
@@ -37,10 +42,10 @@ class XinferenceLLM(Xinference):
if generate_config and generate_config.get("stream"):
combined_text_output = ""
for token in self._stream_generate(
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
):
combined_text_output += token
return combined_text_output
@@ -48,7 +53,9 @@ class XinferenceLLM(Xinference):
completion = model.chat(prompt=prompt, generate_config=generate_config)
return completion["choices"][0]["message"]["content"]
elif isinstance(model, RESTfulGenerateModelHandle):
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
generate_config: "LlamaCppGenerateConfig" = kwargs.get(
"generate_config", {}
)
if stop:
generate_config["stop"] = stop
@@ -65,10 +72,14 @@ class XinferenceLLM(Xinference):
return combined_text_output
else:
completion = model.generate(prompt=prompt, generate_config=generate_config)
completion = model.generate(
prompt=prompt, generate_config=generate_config
)
return completion["choices"][0]["text"]
elif isinstance(model, RESTfulChatglmCppChatModelHandle):
generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
generate_config: "ChatglmCppGenerateConfig" = kwargs.get(
"generate_config", {}
)
if generate_config and generate_config.get("stream"):
combined_text_output = ""
@@ -89,13 +100,22 @@ class XinferenceLLM(Xinference):
return completion
def _stream_generate(
self,
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
model: Union[
"RESTfulGenerateModelHandle",
"RESTfulChatModelHandle",
"RESTfulChatglmCppChatModelHandle",
],
prompt: str,
run_manager: Optional[CallbackManagerForLLMRun] = None,
generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
generate_config: Optional[
Union[
"LlamaCppGenerateConfig",
"PytorchGenerateConfig",
"ChatglmCppGenerateConfig",
]
] = None,
) -> Generator[str, None, None]:
"""
Args:
@@ -108,12 +128,14 @@ class XinferenceLLM(Xinference):
Yields:
A string token.
"""
if isinstance(model, RESTfulGenerateModelHandle):
streaming_response = model.generate(
if isinstance(
model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
):
streaming_response = model.chat(
prompt=prompt, generate_config=generate_config
)
else:
streaming_response = model.chat(
streaming_response = model.generate(
prompt=prompt, generate_config=generate_config
)
@@ -123,7 +145,12 @@ class XinferenceLLM(Xinference):
if choices:
choice = choices[0]
if isinstance(choice, dict):
token = choice.get("text", "")
if "text" in choice:
token = choice.get("text", "")
elif "delta" in choice and "content" in choice["delta"]:
token = choice.get("delta").get("content")
else:
continue
log_probs = choice.get("logprobs")
if run_manager:
run_manager.on_llm_new_token(

View File
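The last hunk above fixes stream parsing: generate-style chunks carry the token under `choice["text"]`, while chat-style chunks wrap it in an OpenAI-like `choice["delta"]["content"]`. A minimal sketch of that extraction:

```python
from typing import Optional

def extract_token(choice: dict) -> Optional[str]:
    # generate-style chunks carry the text directly; chat-style chunks wrap
    # it in an OpenAI-like delta, as _stream_generate above now handles
    if "text" in choice:
        return choice.get("text", "")
    if "delta" in choice and "content" in choice["delta"]:
        return choice["delta"]["content"]
    return None

assert extract_token({"text": "hi"}) == "hi"
assert extract_token({"delta": {"content": "hi"}}) == "hi"
assert extract_token({"delta": {}}) is None
```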

@@ -1,4 +1,3 @@
import re
from typing import Type
from flask import current_app
@@ -16,7 +15,6 @@ from models.dataset import Dataset, DocumentSegment
class DatasetRetrieverToolInput(BaseModel):
dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
@@ -37,27 +35,22 @@ class DatasetRetrieverTool(BaseTool):
description = 'useful for when you want to answer queries about the ' + dataset.name
description = description.replace('\n', '').replace('\r', '')
description += '\nID of dataset MUST be ' + dataset.id
return cls(
name=f'dataset-{dataset.id}',
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
**kwargs
)
def _run(self, dataset_id: str, query: str) -> str:
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
match = re.search(pattern, dataset_id, re.IGNORECASE)
if match:
dataset_id = match.group()
def _run(self, query: str) -> str:
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == dataset_id
Dataset.id == self.dataset_id
).first()
if not dataset:
return f'[{self.name} failed to find dataset with id {dataset_id}.]'
return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
if dataset.indexing_technique == "economy":
# use keyword table query

View File
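The tool change above bakes `dataset_id` into each tool instance at construction (one tool per dataset, named `dataset-<uuid>`), so the model only supplies `query` and the UUID-regex extraction becomes unnecessary. A toy sketch of the resulting shape; the dataclass is a simplified stand-in for the real `BaseTool` subclass:

```python
from dataclasses import dataclass

@dataclass
class DatasetRetrieverTool:
    # simplified stand-in for the real BaseTool subclass above
    name: str
    tenant_id: str
    dataset_id: str   # bound once here instead of being passed by the LLM

    def run(self, query: str) -> str:
        # _run now looks the dataset up by the bound id
        return f"retrieving {query!r} from dataset {self.dataset_id}"

tool = DatasetRetrieverTool(name='dataset-1234', tenant_id='t1', dataset_id='1234')
print(tool.run('what is dify?'))
```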

@@ -0,0 +1,38 @@
from langchain.vectorstores import Milvus
class MilvusVectorStore(Milvus):
def del_texts(self, where_filter: dict):
if not where_filter:
raise ValueError('where_filter must not be empty')
self._client.batch.delete_objects(
class_name=self._index_name,
where=where_filter,
output='minimal'
)
def del_text(self, uuid: str) -> None:
self._client.data_object.delete(
uuid,
class_name=self._index_name
)
def text_exists(self, uuid: str) -> bool:
result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": uuid,
}).with_limit(1).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][self._index_name]
if len(entries) == 0:
return False
return True
def delete(self):
self._client.schema.delete_class(self._index_name)

View File
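Note that the `MilvusVectorStore` body above issues Weaviate-style calls (`batch.delete_objects` with a where-filter, a GraphQL-style query, `schema.delete_class`), which pymilvus does not provide. A hedged sketch of how the same operations might look against pymilvus, assuming langchain's `Milvus` wrapper exposes the collection as `self.col` and its name as `self.collection_name`, and that documents carry a scalar `doc_id` field:

```python
# Hedged sketch only: this is not the code from the diff above, but an
# illustration of the same operations expressed via pymilvus, under the
# stated assumptions about self.col / collection_name / doc_id.
from langchain.vectorstores import Milvus
from pymilvus import utility


class MilvusVectorStoreSketch(Milvus):
    def del_text(self, doc_id: str) -> None:
        # delete by expression on the assumed scalar doc_id field
        self.col.delete(expr=f'doc_id == "{doc_id}"')

    def text_exists(self, doc_id: str) -> bool:
        res = self.col.query(expr=f'doc_id == "{doc_id}"', output_fields=["doc_id"])
        return len(res) > 0

    def delete(self) -> None:
        utility.drop_collection(self.collection_name)
```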

@@ -1,10 +1,11 @@
from typing import cast, Any
from langchain.schema import Document
from langchain.vectorstores import Qdrant
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal
from core.index.vector_index.qdrant import Qdrant
class QdrantVectorStore(Qdrant):
def del_texts(self, filter: Filter):

View File

@@ -1,6 +1,5 @@
from events.dataset_event import dataset_was_deleted
from events.event_handlers.document_index_event import document_index_created
from tasks.clean_dataset_task import clean_dataset_task
import datetime
import logging
import time

View File

@@ -0,0 +1,46 @@
"""update_dataset_model_field_null_available
Revision ID: 4bcffcd64aa4
Revises: 853f9b9cd3b6
Create Date: 2023-08-28 20:58:50.077056
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4bcffcd64aa4'
down_revision = '853f9b9cd3b6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.alter_column('embedding_model',
existing_type=sa.VARCHAR(length=255),
nullable=True,
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
batch_op.alter_column('embedding_model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True,
existing_server_default=sa.text("'openai'::character varying"))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.alter_column('embedding_model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False,
existing_server_default=sa.text("'openai'::character varying"))
batch_op.alter_column('embedding_model',
existing_type=sa.VARCHAR(length=255),
nullable=False,
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
# ### end Alembic commands ###

View File

@@ -36,10 +36,8 @@ class Dataset(db.Model):
updated_by = db.Column(UUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
embedding_model = db.Column(db.String(
255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying"))
embedding_model_provider = db.Column(db.String(
255), nullable=False, server_default=db.text("'openai'::character varying"))
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
@property
def dataset_keyword_table(self):

View File

@@ -49,4 +49,5 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.2.1
xinference==0.2.1
safetensors==0.3.2

View File

@@ -10,6 +10,7 @@ from flask import current_app
from sqlalchemy import func
from core.index.index import IndexBuilder
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from flask_login import current_user
@@ -91,16 +92,18 @@ class DatasetService:
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(
f'Dataset with name {name} already exists.')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
embedding_model = None
if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
dataset = Dataset(name=name, indexing_technique=indexing_technique)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.tenant_id = tenant_id
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
dataset.embedding_model = embedding_model.name if embedding_model else None
db.session.add(dataset)
db.session.commit()
return dataset
@@ -115,17 +118,50 @@ class DatasetService:
else:
return dataset
@staticmethod
def check_dataset_model_setting(dataset):
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ValueError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ValueError(f"The dataset in unavailable, due to: "
f"{ex.description}")
@staticmethod
def update_dataset(dataset_id, data, user):
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_permission(dataset, user)
action = None
if dataset.indexing_technique != data['indexing_technique']:
# if update indexing_technique
if data['indexing_technique'] == 'economy':
deal_dataset_vector_index_task.delay(dataset_id, 'remove')
action = 'remove'
filtered_data['embedding_model'] = None
filtered_data['embedding_model_provider'] = None
elif data['indexing_technique'] == 'high_quality':
deal_dataset_vector_index_task.delay(dataset_id, 'add')
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
action = 'add'
# get embedding model setting
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
filtered_data['embedding_model'] = embedding_model.name
filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
except LLMBadRequestError:
raise ValueError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
filtered_data['updated_by'] = user.id
filtered_data['updated_at'] = datetime.datetime.now()
@@ -133,7 +169,8 @@ class DatasetService:
dataset.query.filter_by(id=dataset_id).update(filtered_data)
db.session.commit()
if action:
deal_dataset_vector_index_task.delay(dataset_id, action)
return dataset
@staticmethod
@@ -394,16 +431,26 @@ class DocumentService:
def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'):
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if documents_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
if 'original_document_id' not in document_data or not document_data['original_document_id']:
count = 0
if document_data["data_source"]["type"] == "upload_file":
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
count = len(upload_file_list)
elif document_data["data_source"]["type"] == "notion_import":
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
count = count + len(notion_info['pages'])
documents_count = DocumentService.get_tenant_documents_count()
total_count = documents_count + count
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if total_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
# if dataset is empty, update dataset data_source_type
if not dataset.data_source_type:
dataset.data_source_type = document_data["data_source"]["type"]
db.session.commit()
if not dataset.indexing_technique:
if 'indexing_technique' not in document_data \
@@ -411,6 +458,13 @@ class DocumentService:
raise ValueError("Indexing technique is required")
dataset.indexing_technique = document_data["indexing_technique"]
if document_data["indexing_technique"] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -455,12 +509,12 @@ class DocumentService:
data_source_info = {
"upload_file_id": file_id,
}
document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, file_name, batch)
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, file_name, batch)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -501,12 +555,12 @@ class DocumentService:
"notion_page_icon": page['page_icon'],
"type": page['type']
}
document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, page['page_name'], batch)
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, page['page_name'], batch)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -525,10 +579,10 @@ class DocumentService:
return documents, batch
@staticmethod
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
document_language: str, data_source_info: dict, created_from: str, position: int,
account: Account,
name: str, batch: str):
def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
document_language: str, data_source_info: dict, created_from: str, position: int,
account: Account,
name: str, batch: str):
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
@@ -557,6 +611,7 @@ class DocumentService:
def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'):
DatasetService.check_dataset_model_setting(dataset)
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
if document.display_status != 'available':
raise ValueError("Document is not available")
@@ -649,15 +704,26 @@ class DocumentService:
@staticmethod
def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
count = 0
if document_data["data_source"]["type"] == "upload_file":
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
count = len(upload_file_list)
elif document_data["data_source"]["type"] == "notion_import":
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
count = count + len(notion_info['pages'])
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
total_count = documents_count + count
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if documents_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
if total_count > tenant_document_count:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
embedding_model = None
if document_data['indexing_technique'] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
@@ -665,8 +731,8 @@ class DocumentService:
data_source_type=document_data["data_source"]["type"],
indexing_technique=document_data["indexing_technique"],
created_by=account.id,
embedding_model=embedding_model.name,
embedding_model_provider=embedding_model.model_provider.provider_name
embedding_model=embedding_model.name if embedding_model else None,
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
)
db.session.add(dataset)
@@ -874,21 +940,25 @@ class SegmentService:
if document.doc_form == 'qa_model':
if 'answer' not in args or not args['answer']:
raise ValueError("Answer is required")
if not args['answer'].strip():
raise ValueError("Answer is empty")
if 'content' not in args or not args['content'] or not args['content'].strip():
raise ValueError("Content is empty")
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
content = args['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
tokens = 0
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
).scalar()
@@ -950,15 +1020,16 @@ class SegmentService:
kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
segment.content = content
segment.index_node_hash = segment_hash
segment.word_count = len(content)
@@ -990,10 +1061,11 @@ class SegmentService:
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise ValueError("Segment is deleting.")
# send delete segment index task
redis_client.setex(indexing_cache_key, 600, 1)
# enabled segment need to delete index
if segment.enabled:
# send delete segment index task
redis_client.setex(indexing_cache_key, 600, 1)
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
db.session.delete(segment)
db.session.commit()

View File
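The CLOUD-edition limit checks above now count the incoming upload or Notion import before comparing against the tenant quota, instead of only looking at the existing document count. A sketch of that arithmetic with placeholder numbers:

```python
def check_document_limit(existing_count: int, new_count: int, tenant_limit: int) -> None:
    # compare existing + incoming against the quota, as the CLOUD-edition
    # checks above now do, rather than only the existing count
    if existing_count + new_count > tenant_limit:
        raise ValueError(f"over document limit {tenant_limit}.")

check_document_limit(existing_count=48, new_count=2, tenant_limit=50)  # exactly at the limit: allowed
# check_document_limit(49, 2, 50) would raise ValueError
```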

@@ -49,18 +49,20 @@ def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: s
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
raise ValueError('Document is not available.')
document_segments = []
for segment in content:
content = segment['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
embedding_model = None
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
for segment in content:
content = segment['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
tokens = embedding_model.get_num_tokens(content) if embedding_model else 0
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == dataset_document.id
).scalar()

View File
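This hunk hoists the embedding-model lookup out of the per-segment loop and tokenizes only for high-quality datasets. A sketch of the hoisting; `CharCountModel` is a stand-in for the real embedding model:

```python
class CharCountModel:
    # stand-in for the real embedding model; counts characters as "tokens"
    def get_num_tokens(self, text: str) -> int:
        return len(text)

def count_tokens(segments: list[dict], high_quality: bool) -> list[int]:
    # resolve the (possibly absent) model once, outside the loop,
    # as batch_create_segment_to_index_task above now does
    embedding_model = CharCountModel() if high_quality else None
    return [
        embedding_model.get_num_tokens(s['content']) if embedding_model else 0
        for s in segments
    ]

print(count_tokens([{'content': 'hello'}, {'content': 'world!'}], high_quality=True))   # [5, 6]
print(count_tokens([{'content': 'hello'}], high_quality=False))                          # [0]
```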

@@ -3,8 +3,10 @@ import time
import click
from celery import shared_task
from flask import current_app
from core.index.index import IndexBuilder
from core.index.vector_index.vector_index import VectorIndex
from extensions.ext_database import db
from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
AppDatasetJoin, Document
@@ -35,11 +37,11 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
if dataset.indexing_technique == 'high_quality':
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
try:
vector_index.delete()
except Exception:

View File

@@ -31,7 +31,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
raise Exception('Dataset not found')
if action == "remove":
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
index.delete()
elif action == "add":
dataset_documents = db.session.query(DatasetDocument).filter(
@@ -43,7 +43,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
documents = []
for dataset_document in dataset_documents:
# delete from vector index
@@ -65,7 +65,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index.add_texts(documents)
index.create(documents)
end_at = time.perf_counter()
logging.info(

View File

@@ -39,4 +39,7 @@ XINFERENCE_SERVER_URL=
XINFERENCE_MODEL_UID=
# OpenLLM Credentials
OPENLLM_SERVER_URL=
OPENLLM_SERVER_URL=
# LocalAI Credentials
LOCALAI_SERVER_URL=

View File

@@ -0,0 +1,61 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='localai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_embedding_model(mocker):
model_name = 'text-embedding-ada-002'
server_url = os.environ['LOCALAI_SERVER_URL']
model_provider = LocalAIProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='localai',
model_name=model_name,
model_type=ModelType.EMBEDDINGS.value,
encrypted_config=json.dumps({
'server_url': server_url,
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return LocalAIEmbedding(
model_provider=model_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_documents(['test', 'test1'])
assert isinstance(rst, list)
assert len(rst) == 2
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)

View File

@@ -0,0 +1,68 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.localai_provider import LocalAIProvider
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider(server_url):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='localai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({}),
is_valid=True,
)
def get_mock_model(model_name, mocker):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0
)
server_url = os.environ['LOCALAI_SERVER_URL']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='localai',
model_name=model_name,
model_type=ModelType.TEXT_GENERATION.value,
encrypted_config=json.dumps({'server_url': server_url, 'completion_type': 'completion'}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
openai_provider = LocalAIProvider(provider=get_mock_provider(server_url))
return LocalAIModel(
model_provider=openai_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
rst = openai_model.get_num_tokens([PromptMessage(content='you are a kind Assistant.')])
assert rst > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
rst = openai_model.run(
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0

View File

@@ -63,7 +63,7 @@ def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_i
def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
mocker.patch('langchain.llms.huggingface_endpoint.HuggingFaceEndpoint._call', return_value="abc")
mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
@@ -71,8 +71,10 @@ def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
)
def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker):
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(

View File

@@ -0,0 +1,116 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import ProviderType, Provider, ProviderModel
PROVIDER_NAME = 'localai'
MODEL_PROVIDER_CLASS = LocalAIProvider
VALIDATE_CREDENTIAL = {
'server_url': 'http://127.0.0.1:8080/'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.embeddings.localai.LocalAIEmbeddings.embed_query',
return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='username/test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials=VALIDATE_CREDENTIAL.copy()
)
def test_is_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if server_url is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials={}
)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt, mocker):
server_url = 'http://127.0.0.1:8080/'
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials=VALIDATE_CREDENTIAL.copy()
)
mock_encrypt.assert_called_with('tenant_id', server_url)
assert result['server_url'] == f'encrypted_{server_url}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS
)
assert result['server_url'] == 'http://127.0.0.1:8080/'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
obfuscated=True
)
middle_token = result['server_url'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.3.16
image: langgenius/dify-api:0.3.18
restart: always
environment:
# Startup mode, 'api' starts the API server.
@@ -124,7 +124,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.3.16
image: langgenius/dify-api:0.3.18
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -176,7 +176,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.3.16
image: langgenius/dify-web:0.3.18
restart: always
environment:
EDITION: SELF_HOSTED

View File

@@ -1,7 +1,7 @@
{
"extends": [
"@antfu",
"plugin:react-hooks/recommended"
"next",
"@antfu"
],
"rules": {
"@typescript-eslint/consistent-type-definitions": [

View File

@@ -2,7 +2,7 @@ import React from 'react'
import ChartView from './chartView'
import CardView from './cardView'
import { getLocaleOnServer } from '@/i18n/server'
import { useTranslation } from '@/i18n/i18next-serverside-config'
import { useTranslation as translate } from '@/i18n/i18next-serverside-config'
import ApikeyInfoPanel from '@/app/components/app/overview/apikey-info-panel'
export type IDevelopProps = {
@@ -13,7 +13,11 @@ const Overview = async ({
params: { appId },
}: IDevelopProps) => {
const locale = getLocaleOnServer()
const { t } = await useTranslation(locale, 'app-overview')
/*
rename useTranslation to avoid lint error
please check: https://github.com/i18next/next-13-app-dir-i18next-example/issues/24
*/
const { t } = await translate(locale, 'app-overview')
return (
<div className="h-full px-16 py-6 overflow-scroll">
<ApikeyInfoPanel />

View File

@@ -9,6 +9,7 @@ import style from '../list.module.css'
import AppModeLabel from './AppModeLabel'
import s from './style.module.css'
import SettingsModal from '@/app/components/app/overview/settings'
import type { ConfigParams } from '@/app/components/app/overview/settings'
import type { App } from '@/types/app'
import Confirm from '@/app/components/base/confirm'
import { ToastContext } from '@/app/components/base/toast'
@@ -73,7 +74,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
}
const onSaveSiteConfig = useCallback(
async (params: any) => {
async (params: ConfigParams) => {
const [err] = await asyncRunSafe<App>(
updateAppSiteConfig({
url: `/apps/${app.id}/site`,
@@ -100,12 +101,12 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
)
const Operations = (props: any) => {
const onClickSettings = async (e: any) => {
const onClickSettings = async (e: React.MouseEvent<HTMLButtonElement>) => {
props?.onClose()
e.preventDefault()
await getAppDetail()
}
const onClickDelete = async (e: any) => {
const onClickDelete = async (e: React.MouseEvent<HTMLDivElement>) => {
props?.onClose()
e.preventDefault()
setShowConfirmDelete(true)

View File

@@ -1,15 +1,14 @@
'use client'
import { useEffect, useRef, useState } from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'
import useSWRInfinite from 'swr/infinite'
import { debounce } from 'lodash-es'
import { useTranslation } from 'react-i18next'
import AppCard from './AppCard'
import NewAppCard from './NewAppCard'
import type { AppListResponse } from '@/models/app'
import { fetchAppList } from '@/service/apps'
import { useAppContext, useSelector } from '@/context/app-context'
import { useAppContext } from '@/context/app-context'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import Confirm from '@/app/components/base/confirm/common'
@@ -24,15 +23,18 @@ const Apps = () => {
const { t } = useTranslation()
const { isCurrentWorkspaceManager } = useAppContext()
const { data, isLoading, setSize, mutate } = useSWRInfinite(getKey, fetchAppList, { revalidateFirstPage: false })
const loadingStateRef = useRef(false)
const pageContainerRef = useSelector(state => state.pageContainerRef)
const anchorRef = useRef<HTMLAnchorElement>(null)
const anchorRef = useRef<HTMLDivElement>(null)
const searchParams = useSearchParams()
const router = useRouter()
const payProviderName = searchParams.get('provider_name')
const payStatus = searchParams.get('payment_result')
const [showPayStatusModal, setShowPayStatusModal] = useState(false)
const handleCancelShowPayStatusModal = useCallback(() => {
setShowPayStatusModal(false)
router.replace('/', { forceOptimisticNavigation: false })
}, [router])
useEffect(() => {
document.title = `${t('app.title')} - Dify`
if (localStorage.getItem(NEED_REFRESH_APP_LIST_KEY) === '1') {
@@ -41,35 +43,24 @@ const Apps = () => {
}
if (payProviderName === ProviderEnum.anthropic && (payStatus === 'succeeded' || payStatus === 'cancelled'))
setShowPayStatusModal(true)
}, [])
}, [mutate, payProviderName, payStatus, t])
useEffect(() => {
loadingStateRef.current = isLoading
}, [isLoading])
useEffect(() => {
const onScroll = debounce(() => {
if (!loadingStateRef.current) {
const { scrollTop, clientHeight } = pageContainerRef.current!
const anchorOffset = anchorRef.current!.offsetTop
if (anchorOffset - scrollTop - clientHeight < 100)
let observer: IntersectionObserver | undefined
if (anchorRef.current) {
observer = new IntersectionObserver((entries) => {
if (entries[0].isIntersecting)
setSize(size => size + 1)
}
}, 50)
pageContainerRef.current?.addEventListener('scroll', onScroll)
return () => pageContainerRef.current?.removeEventListener('scroll', onScroll)
}, [])
const handleCancelShowPayStatusModal = () => {
setShowPayStatusModal(false)
router.replace('/', { forceOptimisticNavigation: false })
}
}, { rootMargin: '100px' })
observer.observe(anchorRef.current)
}
return () => observer?.disconnect()
}, [isLoading, setSize, anchorRef, mutate])
return (
<nav className='grid content-start grid-cols-1 gap-4 px-12 pt-8 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 grow shrink-0'>
<><nav className='grid content-start grid-cols-1 gap-4 px-12 pt-8 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 grow shrink-0'>
{ isCurrentWorkspaceManager
&& <NewAppCard ref={anchorRef} onSuccess={mutate} />}
&& <NewAppCard onSuccess={mutate} />}
{data?.map(({ data: apps }) => apps.map(app => (
<AppCard key={app.id} app={app} onRefresh={mutate} />
)))}
@@ -95,6 +86,8 @@ const Apps = () => {
)
}
</nav>
<div ref={anchorRef} className='h-0'> </div>
</>
)
}

View File

@@ -1,5 +1,5 @@
'use client'
import type { FC } from 'react'
import type { FC, SVGProps } from 'react'
import React, { useEffect } from 'react'
import { usePathname } from 'next/navigation'
import useSWR from 'swr'
@@ -57,7 +57,7 @@ const LikedItem: FC<{ type?: 'plugin' | 'app'; appStatus?: boolean; detail: Rela
)
}
const TargetIcon: FC<{ className?: string }> = ({ className }) => {
const TargetIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<g clip-path="url(#clip0_4610_6951)">
<path d="M10.6666 5.33325V3.33325L12.6666 1.33325L13.3332 2.66659L14.6666 3.33325L12.6666 5.33325H10.6666ZM10.6666 5.33325L7.9999 7.99988M14.6666 7.99992C14.6666 11.6818 11.6818 14.6666 7.99992 14.6666C4.31802 14.6666 1.33325 11.6818 1.33325 7.99992C1.33325 4.31802 4.31802 1.33325 7.99992 1.33325M11.3333 7.99992C11.3333 9.84087 9.84087 11.3333 7.99992 11.3333C6.15897 11.3333 4.66659 9.84087 4.66659 7.99992C4.66659 6.15897 6.15897 4.66659 7.99992 4.66659" stroke="#344054" strokeWidth="1.25" strokeLinecap="round" strokeLinejoin="round" />
@@ -70,7 +70,7 @@ const TargetIcon: FC<{ className?: string }> = ({ className }) => {
</svg>
}
const TargetSolidIcon: FC<{ className?: string }> = ({ className }) => {
const TargetSolidIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path fillRule="evenodd" clipRule="evenodd" d="M12.7733 0.67512C12.9848 0.709447 13.1669 0.843364 13.2627 1.03504L13.83 2.16961L14.9646 2.73689C15.1563 2.83273 15.2902 3.01486 15.3245 3.22639C15.3588 3.43792 15.2894 3.65305 15.1379 3.80458L13.1379 5.80458C13.0128 5.92961 12.8433 5.99985 12.6665 5.99985H10.9426L8.47124 8.47124C8.21089 8.73159 7.78878 8.73159 7.52843 8.47124C7.26808 8.21089 7.26808 7.78878 7.52843 7.52843L9.9998 5.05707V3.33318C9.9998 3.15637 10.07 2.9868 10.1951 2.86177L12.1951 0.861774C12.3466 0.710244 12.5617 0.640794 12.7733 0.67512Z" fill="#155EEF" />
<path d="M1.99984 7.99984C1.99984 4.68613 4.68613 1.99984 7.99984 1.99984C8.36803 1.99984 8.6665 1.70136 8.6665 1.33317C8.6665 0.964981 8.36803 0.666504 7.99984 0.666504C3.94975 0.666504 0.666504 3.94975 0.666504 7.99984C0.666504 12.0499 3.94975 15.3332 7.99984 15.3332C12.0499 15.3332 15.3332 12.0499 15.3332 7.99984C15.3332 7.63165 15.0347 7.33317 14.6665 7.33317C14.2983 7.33317 13.9998 7.63165 13.9998 7.99984C13.9998 11.3135 11.3135 13.9998 7.99984 13.9998C4.68613 13.9998 1.99984 11.3135 1.99984 7.99984Z" fill="#155EEF" />
@@ -78,7 +78,7 @@ const TargetSolidIcon: FC<{ className?: string }> = ({ className }) => {
</svg>
}
const BookOpenIcon: FC<{ className?: string }> = ({ className }) => {
const BookOpenIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path opacity="0.12" d="M1 3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7V10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1Z" fill="#155EEF" />
<path d="M6 10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7M6 10.5V4.7M6 10.5L6.05003 10.425C6.39735 9.90398 6.57101 9.64349 6.80045 9.45491C7.00357 9.28796 7.23762 9.1627 7.4892 9.0863C7.77337 9 8.08645 9 8.71259 9H9.4C9.96005 9 10.2401 9 10.454 8.89101C10.6422 8.79513 10.7951 8.64215 10.891 8.45399C11 8.24008 11 7.96005 11 7.4V3.1C11 2.53995 11 2.25992 10.891 2.04601C10.7951 1.85785 10.6422 1.70487 10.454 1.60899C10.2401 1.5 9.96005 1.5 9.4 1.5H9.2C8.07989 1.5 7.51984 1.5 7.09202 1.71799C6.71569 1.90973 6.40973 2.21569 6.21799 2.59202C6 3.01984 6 3.5799 6 4.7" stroke="#155EEF" strokeLinecap="round" strokeLinejoin="round" />
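Note: the recurring change in this file (repeated below for EditIcon, EditIconSolid, ThreeDotsIcon and others) swaps the hand-rolled FC<{ className?: string }> annotation for React's built-in SVGProps, so the icon signature matches what an <svg> element actually accepts. A minimal before/after sketch:

import type { FC, SVGProps } from 'react'

// Before: only className is typed; any other SVG attribute is rejected.
const IconBefore: FC<{ className?: string }> = ({ className }) => (
  <svg className={className} viewBox="0 0 16 16" />
)

// After: className plus all standard SVG props, with no custom prop type to maintain.
const IconAfter = ({ className }: SVGProps<SVGElement>) => (
  <svg className={className} viewBox="0 0 16 16" />
)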

View File

@@ -1,4 +1,4 @@
import React from "react";
import React from 'react'
import type { FC } from 'react'
import GA, { GaType } from '@/app/components/base/ga'
@@ -6,8 +6,8 @@ const Layout: FC<{
children: React.ReactNode
}> = ({ children }) => {
return (
<div className="overflow-x-auto">
<div className="w-screen h-screen min-w-[300px]">
<div className=''>
<div className="min-w-[300px]">
<GA gaType={GaType.webapp} />
{children}
</div>
@@ -15,4 +15,4 @@ const Layout: FC<{
)
}
export default Layout
export default Layout

View File

@@ -1,8 +1,9 @@
import React from 'react'
import type { FC } from 'react'
import NavLink from './navLink'
import AppBasic from './basic'
import type { NavIcon } from './navLink'
export type IAppDetailNavProps = {
iconType?: 'app' | 'dataset' | 'notion'
title: string
@@ -12,13 +13,13 @@ export type IAppDetailNavProps = {
navigation: Array<{
name: string
href: string
icon: any
selectedIcon: any
icon: NavIcon
selectedIcon: NavIcon
}>
extraInfo?: React.ReactNode
}
const AppDetailNav: FC<IAppDetailNavProps> = ({ title, desc, icon, icon_background, navigation, extraInfo, iconType = 'app' }) => {
const AppDetailNav = ({ title, desc, icon, icon_background, navigation, extraInfo, iconType = 'app' }: IAppDetailNavProps) => {
return (
<div className="flex flex-col w-56 overflow-y-auto bg-white border-r border-gray-200 shrink-0">
<div className="flex flex-shrink-0 p-4">

View File

@@ -1,17 +1,30 @@
'use client'
import { useSelectedLayoutSegment } from 'next/navigation'
import classNames from 'classnames'
import Link from 'next/link'
export type NavIcon = React.ComponentType<
React.PropsWithoutRef<React.ComponentProps<'svg'>> & {
title?: string | undefined
titleId?: string | undefined
}
>
export type NavLinkProps = {
name: string
href: string
iconMap: {
selected: NavIcon
normal: NavIcon
}
}
export default function NavLink({
name,
href,
iconMap,
}: {
name: string
href: string
iconMap: { selected: any; normal: any }
}) {
}: NavLinkProps) {
const segment = useSelectedLayoutSegment()
const isActive = href.toLowerCase().split('/')?.pop() === segment?.toLowerCase()
const NavIcon = isActive ? iconMap.selected : iconMap.normal
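Note: NavIcon spells out the component shape icon libraries such as @heroicons/react export (an SVG component with optional title/titleId), so the old iconMap: { selected: any; normal: any } becomes fully typed. A hedged usage sketch; the specific Heroicons imports are an assumption, not shown in the diff:

import { Squares2X2Icon } from '@heroicons/react/24/outline'
import { Squares2X2Icon as SelectedIcon } from '@heroicons/react/24/solid'
import NavLink from './navLink'

// Both components satisfy NavIcon, so this iconMap type-checks.
const appsLink = (
  <NavLink
    name="Apps"
    href="/apps"
    iconMap={{ normal: Squares2X2Icon, selected: SelectedIcon }}
  />
)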

View File

@@ -8,32 +8,36 @@ import Tooltip from '@/app/components/base/tooltip'
type ICopyBtnProps = {
value: string
className?: string
isPlain?: boolean
}
const CopyBtn = ({
value,
className,
isPlain,
}: ICopyBtnProps) => {
const [isCopied, setIsCopied] = React.useState(false)
return (
<div className={`${className}`}>
<Tooltip
selector="copy-btn-tooltip"
selector={`copy-btn-tooltip-${value}`}
content={(isCopied ? t('appApi.copied') : t('appApi.copy')) as string}
className='z-10'
>
<div
className={'box-border p-0.5 flex items-center justify-center rounded-md bg-white cursor-pointer'}
style={{
boxShadow: '0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06)',
}}
style={!isPlain
? {
boxShadow: '0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06)',
}
: {}}
onClick={() => {
copy(value)
setIsCopied(true)
}}
>
<div className={`w-6 h-6 hover:bg-gray-50 ${s.copyIcon} ${isCopied ? s.copied : ''}`}></div>
<div className={`w-6 h-6 rounded-md hover:bg-gray-50 ${s.copyIcon} ${isCopied ? s.copied : ''}`}></div>
</div>
</Tooltip>
</div>
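Note: two independent fixes in this hunk. Embedding the copied value in the Tooltip selector stops several CopyBtn instances on one page from sharing the single 'copy-btn-tooltip' anchor, and the new isPlain flag suppresses the floating-card shadow. A reduced sketch of the fixed shape, assuming the copy-to-clipboard import the full file uses:

import React from 'react'
import copy from 'copy-to-clipboard'

const CopyBtnSketch = ({ value, isPlain }: { value: string; isPlain?: boolean }) => {
  const [isCopied, setIsCopied] = React.useState(false)
  return (
    <div
      id={`copy-btn-tooltip-${value}`} // unique per value, unlike the old constant selector
      style={isPlain ? {} : { boxShadow: '0px 4px 8px -2px rgba(16, 24, 40, 0.1)' }}
      onClick={() => { copy(value); setIsCopied(true) }}
    >
      {isCopied ? 'Copied' : 'Copy'}
    </div>
  )
}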

View File

@@ -1,4 +1,4 @@
import type { FC } from 'react'
import type { FC, SVGProps } from 'react'
import { HandThumbDownIcon, HandThumbUpIcon } from '@heroicons/react/24/outline'
export const stopIcon = (
@@ -7,7 +7,7 @@ export const stopIcon = (
</svg>
)
export const OpeningStatementIcon: FC<{ className?: string }> = ({ className }) => (
export const OpeningStatementIcon = ({ className }: SVGProps<SVGElement>) => (
<svg className={className} width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fillRule="evenodd" clipRule="evenodd" d="M6.25002 1C3.62667 1 1.50002 3.12665 1.50002 5.75C1.50002 6.28 1.58702 6.79071 1.7479 7.26801C1.7762 7.35196 1.79285 7.40164 1.80368 7.43828L1.80722 7.45061L1.80535 7.45452C1.79249 7.48102 1.77339 7.51661 1.73766 7.58274L0.911727 9.11152C0.860537 9.20622 0.807123 9.30503 0.770392 9.39095C0.733879 9.47635 0.674738 9.63304 0.703838 9.81878C0.737949 10.0365 0.866092 10.2282 1.05423 10.343C1.21474 10.4409 1.38213 10.4461 1.475 10.4451C1.56844 10.444 1.68015 10.4324 1.78723 10.4213L4.36472 10.1549C4.406 10.1506 4.42758 10.1484 4.44339 10.1472L4.44542 10.147L4.45161 10.1492C4.47103 10.1562 4.49738 10.1663 4.54285 10.1838C5.07332 10.3882 5.64921 10.5 6.25002 10.5C8.87338 10.5 11 8.37335 11 5.75C11 3.12665 8.87338 1 6.25002 1ZM4.48481 4.29111C5.04844 3.81548 5.7986 3.9552 6.24846 4.47463C6.69831 3.9552 7.43879 3.82048 8.01211 4.29111C8.58544 4.76175 8.6551 5.562 8.21247 6.12453C7.93825 6.47305 7.24997 7.10957 6.76594 7.54348C6.58814 7.70286 6.49924 7.78255 6.39255 7.81466C6.30103 7.84221 6.19589 7.84221 6.10436 7.81466C5.99767 7.78255 5.90878 7.70286 5.73098 7.54348C5.24694 7.10957 4.55867 6.47305 4.28444 6.12453C3.84182 5.562 3.92117 4.76675 4.48481 4.29111Z" fill="#667085" />
</svg>
@@ -17,13 +17,13 @@ export const RatingIcon: FC<{ isLike: boolean }> = ({ isLike }) => {
return isLike ? <HandThumbUpIcon className='w-4 h-4' /> : <HandThumbDownIcon className='w-4 h-4' />
}
export const EditIcon: FC<{ className?: string }> = ({ className }) => {
export const EditIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className}>
<path d="M14 11.9998L13.3332 12.7292C12.9796 13.1159 12.5001 13.3332 12.0001 13.3332C11.5001 13.3332 11.0205 13.1159 10.6669 12.7292C10.3128 12.3432 9.83332 12.1265 9.33345 12.1265C8.83359 12.1265 8.35409 12.3432 7.99998 12.7292M2 13.3332H3.11636C3.44248 13.3332 3.60554 13.3332 3.75899 13.2963C3.89504 13.2637 4.0251 13.2098 4.1444 13.1367C4.27895 13.0542 4.39425 12.9389 4.62486 12.7083L13 4.33316C13.5523 3.78087 13.5523 2.88544 13 2.33316C12.4477 1.78087 11.5523 1.78087 11 2.33316L2.62484 10.7083C2.39424 10.9389 2.27894 11.0542 2.19648 11.1888C2.12338 11.3081 2.0695 11.4381 2.03684 11.5742C2 11.7276 2 11.8907 2 12.2168V13.3332Z" stroke="#6B7280" strokeLinecap="round" strokeLinejoin="round" />
</svg>
}
export const EditIconSolid: FC<{ className?: string }> = ({ className }) => {
export const EditIconSolid = ({ className }: SVGProps<SVGElement>) => {
return <svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg" className={className}>
<path fillRule="evenodd" clipRule="evenodd" d="M10.8374 8.63108C11.0412 8.81739 11.0554 9.13366 10.8691 9.33747L10.369 9.88449C10.0142 10.2725 9.52293 10.5001 9.00011 10.5001C8.47746 10.5001 7.98634 10.2727 7.63157 9.8849C7.45561 9.69325 7.22747 9.59515 7.00014 9.59515C6.77271 9.59515 6.54446 9.69335 6.36846 9.88517C6.18177 10.0886 5.86548 10.1023 5.66201 9.91556C5.45853 9.72888 5.44493 9.41259 5.63161 9.20911C5.98678 8.82201 6.47777 8.59515 7.00014 8.59515C7.52251 8.59515 8.0135 8.82201 8.36867 9.20911L8.36924 9.20974C8.54486 9.4018 8.77291 9.50012 9.00011 9.50012C9.2273 9.50012 9.45533 9.40182 9.63095 9.20979L10.131 8.66276C10.3173 8.45895 10.6336 8.44476 10.8374 8.63108Z" fill="#6B7280" />
<path fillRule="evenodd" clipRule="evenodd" d="M7.89651 1.39656C8.50599 0.787085 9.49414 0.787084 10.1036 1.39656C10.7131 2.00604 10.7131 2.99419 10.1036 3.60367L3.82225 9.88504C3.81235 9.89494 3.80254 9.90476 3.79281 9.91451C3.64909 10.0585 3.52237 10.1855 3.3696 10.2791C3.23539 10.3613 3.08907 10.4219 2.93602 10.4587C2.7618 10.5005 2.58242 10.5003 2.37897 10.5001C2.3652 10.5001 2.35132 10.5001 2.33732 10.5001H1.50005C1.22391 10.5001 1.00005 10.2763 1.00005 10.0001V9.16286C1.00005 9.14886 1.00004 9.13497 1.00003 9.1212C0.999836 8.91776 0.999669 8.73838 1.0415 8.56416C1.07824 8.4111 1.13885 8.26479 1.22109 8.13058C1.31471 7.97781 1.44166 7.85109 1.58566 7.70736C1.5954 7.69764 1.60523 7.68783 1.61513 7.67793L7.89651 1.39656Z" fill="#6B7280" />

View File

@@ -24,6 +24,7 @@ import type { DataSet } from '@/models/datasets'
export type IChatProps = {
configElem?: React.ReactNode
chatList: IChatItem[]
controlChatUpdateAllConversation?: number
/**
* Whether to display the editing area and rating status
*/
@@ -55,6 +56,7 @@ export type IChatProps = {
const Chat: FC<IChatProps> = ({
configElem,
chatList,
controlChatUpdateAllConversation,
feedbackDisabled = false,
isHideFeedbackEdit = false,
isHideSendInput = false,
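Note: controlChatUpdateAllConversation is a numeric signal prop: the parent bumps the number and Chat reacts in an effect keyed on it, a common pattern for cross-component triggers. A minimal sketch with illustrative names:

import { useEffect, useState } from 'react'

function ConversationPane({ signal, onRefresh }: { signal?: number; onRefresh: () => void }) {
  useEffect(() => {
    if (signal) // the initial 0 is skipped
      onRefresh()
  }, [signal, onRefresh])
  return null
}

function Parent() {
  const [signal, setSignal] = useState(0)
  return (
    <>
      <button onClick={() => setSignal(n => n + 1)}>refresh conversations</button>
      <ConversationPane signal={signal} onRefresh={() => { /* re-fetch the list */ }} />
    </>
  )
}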

View File

@@ -37,12 +37,13 @@ const Thought: FC<IThoughtProps> = ({
const getThoughtText = (item: ThoughtItem) => {
try {
const input = JSON.parse(item.tool_input)
// dataset
if (item.tool.startsWith('dataset-')) {
const dataSetId = item.tool.replace('dataset-', '')
const datasetName = dataSets?.find(item => item.id === dataSetId)?.name || 'unknown dataset'
return t('explore.universalChat.thought.res.dataset').replace('{datasetName}', `<span class="text-gray-700">${datasetName}</span>`)
}
switch (item.tool) {
case 'dataset':
// eslint-disable-next-line no-case-declarations
const datasetName = dataSets?.find(item => item.id === input.dataset_id)?.name || 'unknown dataset'
return t('explore.universalChat.thought.res.dataset').replace('{datasetName}', `<span class="text-gray-700">${datasetName}</span>`)
case 'web_reader':
return t(`explore.universalChat.thought.res.webReader.${!input.cursor ? 'normal' : 'hasPageInfo'}`).replace('{url}', `<a href="${input.url}" class="text-[#155EEF]">${input.url}</a>`)
case 'google_search':
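Note: dataset tools are now named dataset-<id> instead of a single dataset tool that carried the id in its JSON input, so the handler derives the dataset id from the tool name before the switch runs. A trimmed sketch of the new dispatch:

type ThoughtItem = { tool: string; tool_input: string }
type DataSet = { id: string; name: string }

// Returns a display name when the thought came from a dataset tool, else null.
function datasetNameFor(item: ThoughtItem, dataSets?: DataSet[]): string | null {
  if (item.tool.startsWith('dataset-')) {
    const dataSetId = item.tool.replace('dataset-', '')
    return dataSets?.find(d => d.id === dataSetId)?.name || 'unknown dataset'
  }
  return null // not a dataset tool: fall through to the switch on item.tool
}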

View File

@@ -10,6 +10,7 @@ import OperationBtn from '../base/operation-btn'
import VarIcon from '../base/icons/var-icon'
import EditModel from './config-model'
import IconTypeIcon from './input-type-icon'
import type { IInputTypeIconProps } from './input-type-icon'
import s from './style.module.css'
import Tooltip from '@/app/components/base/tooltip'
import type { PromptVariable } from '@/models/debug'
@@ -37,8 +38,8 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
return obj
})()
const updatePromptVariable = (key: string, updateKey: string, newValue: any) => {
const newPromptVariables = promptVariables.map((item, i) => {
const updatePromptVariable = (key: string, updateKey: string, newValue: string | boolean) => {
const newPromptVariables = promptVariables.map((item) => {
if (item.key === key) {
return {
...item,
@@ -179,7 +180,7 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
<tr key={index} className="h-9 leading-9">
<td className="w-[160px] border-b border-gray-100 pl-3">
<div className='flex items-center space-x-1'>
<IconTypeIcon type={type} />
<IconTypeIcon type={type as IInputTypeIconProps['type']} />
{!readonly
? (
<input

View File

@@ -2,7 +2,7 @@
import React from 'react'
import type { FC } from 'react'
type IInputTypeIconProps = {
export type IInputTypeIconProps = {
type: 'string' | 'select'
}
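Note: exporting IInputTypeIconProps is what lets the ConfigVar hunk above write type as IInputTypeIconProps['type'], an indexed-access type that reuses the 'string' | 'select' union instead of restating it. A small sketch of the idiom:

export type IInputTypeIconProps = {
  type: 'string' | 'select'
}

// Indexed access: the union is declared once and referenced everywhere else.
type InputType = IInputTypeIconProps['type'] // 'string' | 'select'

const raw: string = 'select' // e.g. a value read from parsed JSON, typed as plain string
const narrowed = raw as InputType // cast once at the boundary, as the hunk does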

View File

@@ -43,7 +43,7 @@ const CardItem: FC<ICardItemProps> = ({
selector={`unavailable-tag-${config.id}`}
htmlContent={t('dataset.unavailableTip')}
>
<span className='shrink-0 px-1 border boder-gray-200 rounded-md text-gray-500 text-xs font-normal leading-[18px]'>{t('dataset.unavailable')}</span>
<span className='shrink-0 inline-flex whitespace-nowrap px-1 border boder-gray-200 rounded-md text-gray-500 text-xs font-normal leading-[18px]'>{t('dataset.unavailable')}</span>
</Tooltip>
)}
</div>

View File

@@ -1,5 +1,5 @@
'use client'
import type { FC } from 'react'
import type { FC, SVGProps } from 'react'
import React, { useState } from 'react'
import useSWR from 'swr'
import { usePathname } from 'next/navigation'
@@ -29,7 +29,7 @@ export type QueryParam = {
// Custom page count is not currently supported.
const limit = 10
const ThreeDotsIcon: FC<{ className?: string }> = ({ className }) => {
const ThreeDotsIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M5 6.5V5M8.93934 7.56066L10 6.5M10.0103 11.5H11.5103" stroke="#374151" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" />
</svg>

File diff suppressed because one or more lines are too long

View File

@@ -1,5 +1,5 @@
'use client'
import type { FC } from 'react'
import type { HTMLProps } from 'react'
import React, { useMemo, useState } from 'react'
import {
Cog8ToothIcon,
@@ -37,7 +37,7 @@ export type IAppCardProps = {
onGenerateCode?: () => Promise<void>
}
const EmbedIcon: FC<{ className?: string }> = ({ className = '' }) => {
const EmbedIcon = ({ className = '' }: HTMLProps<HTMLDivElement>) => {
return <div className={`${style.codeBrowserIcon} ${className}`}></div>
}

View File

@@ -3,6 +3,7 @@
import type { ChangeEvent, FC } from 'react'
import React, { useState } from 'react'
import data from '@emoji-mart/data'
import type { Emoji, EmojiMartData } from '@emoji-mart/data'
import { SearchIndex, init } from 'emoji-mart'
import cn from 'classnames'
import {
@@ -30,9 +31,9 @@ declare global {
init({ data })
async function search(value: string) {
const emojis = await SearchIndex.search(value) || []
const emojis: Emoji[] = await SearchIndex.search(value) || []
const results = emojis.map((emoji: any) => {
const results = emojis.map((emoji) => {
return emoji.skins[0].native
})
return results
@@ -59,6 +60,7 @@ const backgroundColors = [
'#ECE9FE',
'#FFE4E8',
]
type IEmojiPickerProps = {
isModal?: boolean
onSelect?: (emoji: string, background: string) => void
@@ -69,14 +71,13 @@ const EmojiPicker: FC<IEmojiPickerProps> = ({
isModal = true,
onSelect,
onClose,
}) => {
const { t } = useTranslation()
const { categories } = data as any
const { categories } = data as EmojiMartData
const [selectedEmoji, setSelectedEmoji] = useState('')
const [selectedBackground, setSelectedBackground] = useState(backgroundColors[0])
const [searchedEmojis, setSearchedEmojis] = useState([])
const [searchedEmojis, setSearchedEmojis] = useState<string[]>([])
const [isSearching, setIsSearching] = useState(false)
return isModal ? <Modal
@@ -133,11 +134,11 @@ const EmojiPicker: FC<IEmojiPickerProps> = ({
</div>
</>}
{categories.map((category: any, index: number) => {
{categories.map((category, index: number) => {
return <div key={`category-${index}`} className='flex flex-col'>
<p className='font-medium uppercase text-xs text-[#101828] mb-1'>{category.id}</p>
<div className='w-full h-full grid grid-cols-8 gap-1'>
{category.emojis.map((emoji: string, index: number) => {
{category.emojis.map((emoji, index: number) => {
return <div
key={`emoji-${index}`}
className='inline-flex w-10 h-10 rounded-lg items-center justify-center'
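Note: the emoji-mart hunks replace any with the package's own types: SearchIndex.search results become Emoji[] (whose skins carry the native glyph) and data is asserted to EmojiMartData so categories is typed. A condensed sketch of the typed search, using the same imports the diff adds:

import data from '@emoji-mart/data'
import type { Emoji, EmojiMartData } from '@emoji-mart/data'
import { SearchIndex, init } from 'emoji-mart'

init({ data })

// Returns the native glyph for each match, with no any left in the mapper.
async function search(value: string): Promise<string[]> {
  const emojis: Emoji[] = await SearchIndex.search(value) || []
  return emojis.map(emoji => emoji.skins[0].native)
}

const { categories } = data as EmojiMartData // typed categories for rendering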

File diff suppressed because one or more lines are too long

[image: new file, 76 KiB]

File diff suppressed because one or more lines are too long

[image: new file, 73 KiB]

View File

@@ -0,0 +1,3 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8.3767 15.6163L2.71985 21.2732M11.6944 6.64181L10.1335 8.2027C10.0062 8.33003 9.94252 8.39369 9.86999 8.44427C9.80561 8.48917 9.73616 8.52634 9.66309 8.555C9.58077 8.58729 9.49249 8.60495 9.31592 8.64026L5.65145 9.37315C4.69915 9.56361 4.223 9.65884 4.00024 9.9099C3.80617 10.1286 3.71755 10.4213 3.75771 10.7109C3.8038 11.0434 4.14715 11.3867 4.83387 12.0735L11.9196 19.1592C12.6063 19.8459 12.9497 20.1893 13.2821 20.2354C13.5718 20.2755 13.8645 20.1869 14.0832 19.9928C14.3342 19.7701 14.4294 19.2939 14.6199 18.3416L15.3528 14.6771C15.3881 14.5006 15.4058 14.4123 15.4381 14.33C15.4667 14.2569 15.5039 14.1875 15.5488 14.1231C15.5994 14.0505 15.663 13.9869 15.7904 13.8596L17.3512 12.2987C17.4326 12.2173 17.4734 12.1766 17.5181 12.141C17.5578 12.1095 17.5999 12.081 17.644 12.0558C17.6936 12.0274 17.7465 12.0048 17.8523 11.9594L20.3467 10.8904C21.0744 10.5785 21.4383 10.4226 21.6035 10.1706C21.7481 9.95025 21.7998 9.68175 21.7474 9.42348C21.6875 9.12813 21.4076 8.84822 20.8478 8.28839L15.7047 3.14526C15.1448 2.58543 14.8649 2.30552 14.5696 2.24565C14.3113 2.19329 14.0428 2.245 13.8225 2.38953C13.5705 2.55481 13.4145 2.91866 13.1027 3.64636L12.0337 6.14071C11.9883 6.24653 11.9656 6.29944 11.9373 6.34905C11.9121 6.39313 11.8836 6.43522 11.852 6.47496C11.8165 6.51971 11.7758 6.56041 11.6944 6.64181Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

[image: new file, 1.5 KiB]

View File

@@ -64,6 +64,8 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = '<%= svgName %>'
export default Icon
`.trim())

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'Dify'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'Github'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'MessageChatSquare'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'Csv'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'Md'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'Anthropic'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'AnthropicText'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'AzureOpenaiService'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'AzureOpenaiServiceText'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'Azureai'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'AzureaiText'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'Chatglm'
export default Icon

View File

@@ -11,4 +11,6 @@ const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseP
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'ChatglmText'
export default Icon

Some files were not shown because too many files have changed in this diff.