mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-06 19:42:42 +08:00
Compare commits
68 Commits
fix/filter
...
feat/chat-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b17fe2f52 | ||
|
|
078614f80c | ||
|
|
fd492f8fec | ||
|
|
8ac9d98fd8 | ||
|
|
12449524f1 | ||
|
|
5fe2c41259 | ||
|
|
0708bd60ee | ||
|
|
23a6c85b80 | ||
|
|
4a28599fbd | ||
|
|
7c66d3c793 | ||
|
|
cc9edfffd8 | ||
|
|
6fa2454c9a | ||
|
|
487e699021 | ||
|
|
a7cdb745c1 | ||
|
|
73c86ee6a0 | ||
|
|
48eb590065 | ||
|
|
33562a9d8d | ||
|
|
c9194ba382 | ||
|
|
a199fa6388 | ||
|
|
4c8608dc61 | ||
|
|
a6b0f788e7 | ||
|
|
df6604a734 | ||
|
|
1ca86cf9ce | ||
|
|
78e26f8b75 | ||
|
|
2191312bb9 | ||
|
|
fcc6b41ab7 | ||
|
|
9458b8978f | ||
|
|
d75e8aeafa | ||
|
|
2eba98a465 | ||
|
|
a7a7aab7a0 | ||
|
|
86bfbb47d5 | ||
|
|
d33a269548 | ||
|
|
d3f8ea2df0 | ||
|
|
7df56ed617 | ||
|
|
e34dcc0406 | ||
|
|
a834ba8759 | ||
|
|
c67f345d0e | ||
|
|
8b8e510bfe | ||
|
|
3db839a5cb | ||
|
|
417c19577a | ||
|
|
b5953039de | ||
|
|
a43e80dd9c | ||
|
|
ad5f27bc5f | ||
|
|
05e0985f29 | ||
|
|
7b3314c5db | ||
|
|
a55ba6e614 | ||
|
|
f9bec1edf8 | ||
|
|
16199e968e | ||
|
|
02452421d5 | ||
|
|
3a5c7c75ad | ||
|
|
a7415ecfd8 | ||
|
|
934def5fcc | ||
|
|
0796791de5 | ||
|
|
6c148b223d | ||
|
|
4b168f4838 | ||
|
|
1c114eaef3 | ||
|
|
e053215155 | ||
|
|
13482b0fc1 | ||
|
|
38fa152cc4 | ||
|
|
2d9616c29c | ||
|
|
915e26527b | ||
|
|
2d604d9330 | ||
|
|
e7199826cc | ||
|
|
70e24b7594 | ||
|
|
c1602aafc7 | ||
|
|
a3fec11438 | ||
|
|
b1fd1b3ab3 | ||
|
|
5397799aac |
@@ -20,7 +20,8 @@ def check_file_for_chinese_comments(file_path):
|
||||
def main():
|
||||
has_chinese = False
|
||||
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
|
||||
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
|
||||
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py',
|
||||
'prompts.py']
|
||||
|
||||
for root, _, files in os.walk("."):
|
||||
for file in files:
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -149,4 +149,5 @@ sdks/python-client/build
|
||||
sdks/python-client/dist
|
||||
sdks/python-client/dify_client.egg-info
|
||||
|
||||
.vscode/
|
||||
.vscode/*
|
||||
!.vscode/launch.json
|
||||
27
.vscode/launch.json
vendored
Normal file
27
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Flask",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "api/app.py",
|
||||
"FLASK_DEBUG": "1",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
"--host=0.0.0.0",
|
||||
"--port=5001",
|
||||
"--debug"
|
||||
],
|
||||
"jinja": true,
|
||||
"justMyCode": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -53,9 +53,9 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r
|
||||
|
||||
## Community channels
|
||||
|
||||
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk). We are here to help!
|
||||
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
|
||||
|
||||
### i18n (Internationalization) Support
|
||||
|
||||
We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.
|
||||
Also check out the [Frontend i18n README]((web/i18n/README_EN.md)) for more information.
|
||||
Also check out the [Frontend i18n README](web/i18n/README_EN.md) for more information.
|
||||
|
||||
@@ -16,15 +16,15 @@
|
||||
|
||||
## 本地开发
|
||||
|
||||
要设置一个可工作的开发环境,只需 fork 项目的 git 存储库,并使用适当的软件包管理器安装后端和前端依赖项,然后创建并运行 docker-compose 堆栈。
|
||||
要设置一个可工作的开发环境,只需 fork 项目的 git 存储库,并使用适当的软件包管理器安装后端和前端依赖项,然后创建并运行 docker-compose。
|
||||
|
||||
### Fork存储库
|
||||
|
||||
您需要 fork [存储库](https://github.com/langgenius/dify)。
|
||||
您需要 fork [Git 仓库](https://github.com/langgenius/dify)。
|
||||
|
||||
### 克隆存储库
|
||||
|
||||
克隆您在 GitHub 上 fork 的存储库:
|
||||
克隆您在 GitHub 上 fork 的仓库:
|
||||
|
||||
```
|
||||
git clone git@github.com:<github_username>/dify.git
|
||||
|
||||
@@ -52,4 +52,4 @@ git clone git@github.com:<github_username>/dify.git
|
||||
|
||||
## コミュニティチャンネル
|
||||
|
||||
お困りですか?何か質問がありますか? [Discord Community サーバ](https://discord.gg/AhzKf7dNgk)に参加してください。私たちがお手伝いします!
|
||||
お困りですか?何か質問がありますか? [Discord Community サーバ](https://discord.gg/j3XRWSPBf7) に参加してください。私たちがお手伝いします!
|
||||
|
||||
@@ -1,7 +1,18 @@
|
||||
FROM python:3.10-slim
|
||||
# packages install stage
|
||||
FROM python:3.10-slim AS base
|
||||
|
||||
LABEL maintainer="takatost@gmail.com"
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gcc g++ python3-dev libc-dev libffi-dev
|
||||
|
||||
COPY requirements.txt /requirements.txt
|
||||
|
||||
RUN pip install --prefix=/pkg -r requirements.txt
|
||||
|
||||
# build stage
|
||||
FROM python:3.10-slim AS builder
|
||||
|
||||
ENV FLASK_APP app.py
|
||||
ENV EDITION SELF_HOSTED
|
||||
ENV DEPLOY_ENV PRODUCTION
|
||||
@@ -15,13 +26,12 @@ EXPOSE 5001
|
||||
|
||||
WORKDIR /app/api
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev nodejs
|
||||
|
||||
COPY requirements.txt /app/api/requirements.txt
|
||||
|
||||
RUN pip install -r requirements.txt
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends bash curl wget vim nodejs \
|
||||
&& apt-get autoremove \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=base /pkg /usr/local
|
||||
COPY . /app/api/
|
||||
|
||||
COPY docker/entrypoint.sh /entrypoint.sh
|
||||
|
||||
@@ -54,9 +54,11 @@
|
||||
7. Setup your application by visiting http://localhost:5001/console/api/setup or other apis...
|
||||
8. If you need to debug local async processing, run `celery -A app.celery worker -Q dataset,generation,mail`. Celery handles dataset importing and other async tasks.
|
||||
|
||||
8. Start frontend:
|
||||
9. Start frontend
|
||||
|
||||
You can start the frontend by running `npm install && npm run dev` in web/ folder, or you can use docker to start the frontend, for example:
|
||||
|
||||
```
|
||||
docker run -it -d --platform linux/amd64 -p 3000:3000 -e EDITION=SELF_HOSTED -e CONSOLE_URL=http://127.0.0.1:5000 --name web-self-hosted langgenius/dify-web:latest
|
||||
docker run -it -d --platform linux/amd64 -p 3000:3000 -e EDITION=SELF_HOSTED -e CONSOLE_URL=http://127.0.0.1:5001 --name web-self-hosted langgenius/dify-web:latest
|
||||
```
|
||||
This will start a Dify frontend. Now you are all set. Happy coding!
|
||||
10
api/app.py
10
api/app.py
@@ -1,6 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@@ -145,8 +145,12 @@ def load_user(user_id):
|
||||
_create_tenant_for_account(account)
|
||||
session['workspace_id'] = account.current_tenant_id
|
||||
|
||||
account.last_active_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
current_time = datetime.utcnow()
|
||||
|
||||
# update last_active_at when last_active_at is more than 10 minutes ago
|
||||
if current_time - account.last_active_at > timedelta(minutes=10):
|
||||
account.last_active_at = current_time
|
||||
db.session.commit()
|
||||
|
||||
# Log in the user with the updated user_id
|
||||
flask_login.login_user(account, remember=True)
|
||||
|
||||
145
api/commands.py
145
api/commands.py
@@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
import string
|
||||
@@ -6,10 +7,16 @@ import time
|
||||
|
||||
import click
|
||||
from flask import current_app
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.hosted import hosted_model_providers
|
||||
from core.model_providers.providers.openai_provider import OpenAIProvider
|
||||
from libs.password import password_pattern, valid_password, hash_password
|
||||
from libs.helper import email as email_validate
|
||||
from extensions.ext_database import db
|
||||
@@ -296,6 +303,142 @@ def sync_anthropic_hosted_providers():
|
||||
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
|
||||
|
||||
|
||||
@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
|
||||
def create_qdrant_indexes():
|
||||
click.echo(click.style('Start create qdrant indexes.', fg='green'))
|
||||
create_count = 0
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
if dataset.index_struct_dict:
|
||||
if dataset.index_struct_dict['type'] != 'qdrant':
|
||||
try:
|
||||
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
)
|
||||
dataset.embedding_model = embedding_model.name
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
|
||||
except Exception:
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider_name='openai',
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
index.create_qdrant_dataset(dataset)
|
||||
index_struct = {
|
||||
"type": 'qdrant',
|
||||
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct)
|
||||
db.session.commit()
|
||||
create_count += 1
|
||||
else:
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
|
||||
|
||||
|
||||
@click.command('update-qdrant-indexes', help='Update qdrant indexes.')
|
||||
def update_qdrant_indexes():
|
||||
click.echo(click.style('Start Update qdrant indexes.', fg='green'))
|
||||
create_count = 0
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
if dataset.index_struct_dict:
|
||||
if dataset.index_struct_dict['type'] != 'qdrant':
|
||||
try:
|
||||
click.echo('Update dataset qdrant index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except Exception:
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider_name='openai',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
index.update_qdrant_dataset(dataset)
|
||||
create_count += 1
|
||||
else:
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
@@ -304,3 +447,5 @@ def register_commands(app):
|
||||
app.cli.add_command(recreate_all_dataset_indexes)
|
||||
app.cli.add_command(sync_anthropic_hosted_providers)
|
||||
app.cli.add_command(clean_unused_dataset_indexes)
|
||||
app.cli.add_command(create_qdrant_indexes)
|
||||
app.cli.add_command(update_qdrant_indexes)
|
||||
@@ -100,7 +100,7 @@ class Config:
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.18"
|
||||
self.CURRENT_VERSION = "0.3.19"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
|
||||
@@ -87,13 +87,19 @@ class DatasetListApi(Resource):
|
||||
# raise ProviderNotInitializeError(
|
||||
# f"No Embedding Model available. Please configure a valid provider "
|
||||
# f"in the Settings -> Model Provider.")
|
||||
model_names = [item['model_name'] for item in valid_model_list]
|
||||
model_names = []
|
||||
for valid_model in valid_model_list:
|
||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item['embedding_model'] in model_names:
|
||||
item['embedding_available'] = True
|
||||
if item['indexing_technique'] == 'high_quality':
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item['embedding_available'] = True
|
||||
else:
|
||||
item['embedding_available'] = False
|
||||
else:
|
||||
item['embedding_available'] = False
|
||||
item['embedding_available'] = True
|
||||
response = {
|
||||
'data': data,
|
||||
'has_more': len(datasets) == limit,
|
||||
@@ -119,14 +125,6 @@ class DatasetListApi(Resource):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
@@ -150,20 +148,39 @@ class DatasetApi(Resource):
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(
|
||||
dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 200
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
# get valid model list
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
|
||||
model_names = []
|
||||
for valid_model in valid_model_list:
|
||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||
if data['indexing_technique'] == 'high_quality':
|
||||
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
data['embedding_available'] = True
|
||||
else:
|
||||
data['embedding_available'] = False
|
||||
else:
|
||||
data['embedding_available'] = True
|
||||
return data, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', nullable=False,
|
||||
@@ -251,6 +268,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
@@ -272,7 +290,8 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'])
|
||||
args['doc_language'], args['dataset_id'],
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
@@ -287,7 +306,8 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
args['info_list']['notion_info_list'],
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'])
|
||||
args['doc_language'], args['dataset_id'],
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request
|
||||
from flask import request, current_app
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
@@ -138,6 +138,10 @@ class GetProcessRuleApi(Resource):
|
||||
req_data = request.args
|
||||
|
||||
document_id = req_data.get('document_id')
|
||||
|
||||
# get default rules
|
||||
mode = DocumentService.DEFAULT_RULES['mode']
|
||||
rules = DocumentService.DEFAULT_RULES['rules']
|
||||
if document_id:
|
||||
# get the latest process rule
|
||||
document = Document.query.get_or_404(document_id)
|
||||
@@ -158,11 +162,9 @@ class GetProcessRuleApi(Resource):
|
||||
order_by(DatasetProcessRule.created_at.desc()). \
|
||||
limit(1). \
|
||||
one_or_none()
|
||||
mode = dataset_process_rule.mode
|
||||
rules = dataset_process_rule.rules_dict
|
||||
else:
|
||||
mode = DocumentService.DEFAULT_RULES['mode']
|
||||
rules = DocumentService.DEFAULT_RULES['rules']
|
||||
if dataset_process_rule:
|
||||
mode = dataset_process_rule.mode
|
||||
rules = dataset_process_rule.rules_dict
|
||||
|
||||
return {
|
||||
'mode': mode,
|
||||
@@ -275,7 +277,8 @@ class DatasetDocumentListApi(Resource):
|
||||
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
|
||||
parser.add_argument('original_document_id', type=str, required=False, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||
@@ -284,20 +287,6 @@ class DatasetDocumentListApi(Resource):
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
@@ -335,17 +324,20 @@ class DatasetInitApi(Resource):
|
||||
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
if args['indexing_technique'] == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
@@ -414,7 +406,8 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
|
||||
data_process_rule_dict, None, dataset_id)
|
||||
data_process_rule_dict, None,
|
||||
'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
@@ -483,7 +476,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
data_process_rule_dict, None, dataset_id)
|
||||
data_process_rule_dict, None,
|
||||
'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
@@ -497,7 +491,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
info_list,
|
||||
data_process_rule_dict,
|
||||
None, dataset_id)
|
||||
None, 'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
@@ -725,6 +719,12 @@ class DocumentDeleteApi(DocumentResource):
|
||||
def delete(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
try:
|
||||
@@ -787,6 +787,12 @@ class DocumentStatusApi(DocumentResource):
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
@@ -855,6 +861,14 @@ class DocumentStatusApi(DocumentResource):
|
||||
if not document.archived:
|
||||
raise InvalidActionError('Document is not archived.')
|
||||
|
||||
# check document limit
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
total_count = documents_count + 1
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if total_count > tenant_document_count:
|
||||
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
|
||||
|
||||
document.archived = False
|
||||
document.archived_at = None
|
||||
document.archived_by = None
|
||||
@@ -872,6 +886,10 @@ class DocumentStatusApi(DocumentResource):
|
||||
|
||||
|
||||
class DocumentPauseApi(DocumentResource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, dataset_id, document_id):
|
||||
"""pause document."""
|
||||
dataset_id = str(dataset_id)
|
||||
@@ -901,6 +919,9 @@ class DocumentPauseApi(DocumentResource):
|
||||
|
||||
|
||||
class DocumentRecoverApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, dataset_id, document_id):
|
||||
"""recover document."""
|
||||
dataset_id = str(dataset_id)
|
||||
@@ -926,6 +947,21 @@ class DocumentRecoverApi(DocumentResource):
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class DocumentLimitApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""get document limit"""
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
|
||||
return {
|
||||
'documents_count': documents_count,
|
||||
'documents_limit': tenant_document_count
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
|
||||
api.add_resource(DatasetDocumentListApi,
|
||||
'/datasets/<uuid:dataset_id>/documents')
|
||||
@@ -951,3 +987,4 @@ api.add_resource(DocumentStatusApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
|
||||
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
|
||||
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
|
||||
api.add_resource(DocumentLimitApi, '/datasets/limit')
|
||||
|
||||
@@ -149,7 +149,8 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
@@ -158,20 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
@@ -244,18 +245,19 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
@@ -284,25 +286,28 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
@@ -339,6 +344,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
@@ -378,18 +385,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# get file from request
|
||||
file = request.files['file']
|
||||
# check file
|
||||
|
||||
@@ -83,7 +83,7 @@ class FileApi(Resource):
|
||||
raise FileTooLargeError(message)
|
||||
|
||||
extension = file.filename.split('.')[-1]
|
||||
if extension not in ALLOWED_EXTENSIONS:
|
||||
if extension.lower() not in ALLOWED_EXTENSIONS:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
# user uuid as file name
|
||||
@@ -136,7 +136,7 @@ class FilePreviewApi(Resource):
|
||||
|
||||
# extract text from file
|
||||
extension = upload_file.extension
|
||||
if extension not in ALLOWED_EXTENSIONS:
|
||||
if extension.lower() not in ALLOWED_EXTENSIONS:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
text = FileExtractor.load(upload_file, return_text=True)
|
||||
|
||||
@@ -49,46 +49,43 @@ class MemberInviteEmailApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('email', type=str, required=True, location='json')
|
||||
parser.add_argument('emails', type=str, required=True, location='json', action='append')
|
||||
parser.add_argument('role', type=str, required=True, default='admin', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
invitee_email = args['email']
|
||||
invitee_emails = args['emails']
|
||||
invitee_role = args['role']
|
||||
if invitee_role not in ['admin', 'normal']:
|
||||
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
|
||||
|
||||
inviter = current_user
|
||||
|
||||
try:
|
||||
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
|
||||
inviter=inviter)
|
||||
account = db.session.query(Account, TenantAccountJoin.role).join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
).filter(Account.email == args['email']).first()
|
||||
account, role = account
|
||||
account = marshal(account, account_fields)
|
||||
account['role'] = role
|
||||
except services.errors.account.CannotOperateSelfError as e:
|
||||
return {'code': 'cannot-operate-self', 'message': str(e)}, 400
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
return {'code': 'forbidden', 'message': str(e)}, 403
|
||||
except services.errors.account.AccountAlreadyInTenantError as e:
|
||||
return {'code': 'email-taken', 'message': str(e)}, 409
|
||||
except Exception as e:
|
||||
return {'code': 'unexpected-error', 'message': str(e)}, 500
|
||||
|
||||
# todo:413
|
||||
invitation_results = []
|
||||
console_web_url = current_app.config.get("CONSOLE_WEB_URL")
|
||||
for invitee_email in invitee_emails:
|
||||
try:
|
||||
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
|
||||
inviter=inviter)
|
||||
account = db.session.query(Account, TenantAccountJoin.role).join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
).filter(Account.email == invitee_email).first()
|
||||
account, role = account
|
||||
invitation_results.append({
|
||||
'status': 'success',
|
||||
'email': invitee_email,
|
||||
'url': f'{console_web_url}/activate?workspace_id={current_user.current_tenant_id}&email={invitee_email}&token={token}'
|
||||
})
|
||||
account = marshal(account, account_fields)
|
||||
account['role'] = role
|
||||
except Exception as e:
|
||||
invitation_results.append({
|
||||
'status': 'failed',
|
||||
'email': invitee_email,
|
||||
'message': str(e)
|
||||
})
|
||||
|
||||
return {
|
||||
'result': 'success',
|
||||
'account': account,
|
||||
'invite_url': '{}/activate?workspace_id={}&email={}&token={}'.format(
|
||||
current_app.config.get("CONSOLE_WEB_URL"),
|
||||
str(current_user.current_tenant_id),
|
||||
invitee_email,
|
||||
token
|
||||
)
|
||||
'invitation_results': invitation_results,
|
||||
}, 201
|
||||
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
elif len(self.tools) == 1:
|
||||
tool = next(iter(self.tools))
|
||||
tool = cast(DatasetRetrieverTool, tool)
|
||||
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
|
||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
if intermediate_steps:
|
||||
@@ -60,7 +60,13 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
try:
|
||||
return super().plan(intermediate_steps, callbacks, **kwargs)
|
||||
agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
|
||||
if isinstance(agent_decision, AgentAction):
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
return agent_decision
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
@@ -45,7 +45,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.llm.max_tokens
|
||||
self.llm.max_tokens = 15
|
||||
self.llm.max_tokens = 40
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
@@ -97,6 +97,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
|
||||
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
|
||||
return agent_decision
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -90,7 +90,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
elif len(self.dataset_tools) == 1:
|
||||
tool = next(iter(self.dataset_tools))
|
||||
tool = cast(DatasetRetrieverTool, tool)
|
||||
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
|
||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
@@ -102,7 +102,13 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
raise new_exception
|
||||
|
||||
try:
|
||||
return self.output_parser.parse(full_output)
|
||||
agent_decision = self.output_parser.parse(full_output)
|
||||
if isinstance(agent_decision, AgentAction):
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
return agent_decision
|
||||
except OutputParserException:
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
@@ -106,7 +106,13 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
raise new_exception
|
||||
|
||||
try:
|
||||
return self.output_parser.parse(full_output)
|
||||
agent_decision = self.output_parser.parse(full_output)
|
||||
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
return agent_decision
|
||||
except OutputParserException:
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
@@ -44,10 +45,15 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# tool_name = serialized.get('name')
|
||||
input_dict = json.loads(input_str.replace("'", "\""))
|
||||
dataset_id = input_dict.get('dataset_id')
|
||||
query = input_dict.get('query')
|
||||
tool_name: str = serialized.get('name')
|
||||
dataset_id = tool_name.removeprefix('dataset-')
|
||||
|
||||
try:
|
||||
input_dict = json.loads(input_str.replace("'", "\""))
|
||||
query = input_dict.get('query')
|
||||
except JSONDecodeError:
|
||||
query = input_str
|
||||
|
||||
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
|
||||
|
||||
def on_tool_end(
|
||||
|
||||
@@ -137,7 +137,8 @@ class ConversationMessageTask:
|
||||
db.session.flush()
|
||||
|
||||
def append_message_text(self, text: str):
|
||||
self._pub_handler.pub_text(text)
|
||||
if text is not None:
|
||||
self._pub_handler.pub_text(text)
|
||||
|
||||
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
|
||||
message_tokens = llm_message.prompt_tokens
|
||||
|
||||
@@ -6,7 +6,7 @@ import requests
|
||||
from langchain.document_loaders import TextLoader, Docx2txtLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.data_loader.loader.csv import CSVLoader
|
||||
from core.data_loader.loader.csv_loader import CSVLoader
|
||||
from core.data_loader.loader.excel import ExcelLoader
|
||||
from core.data_loader.loader.html import HTMLLoader
|
||||
from core.data_loader.loader.markdown import MarkdownLoader
|
||||
@@ -47,17 +47,18 @@ class FileExtractor:
|
||||
upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
|
||||
input_file = Path(file_path)
|
||||
delimiter = '\n'
|
||||
if input_file.suffix == '.xlsx':
|
||||
file_extension = input_file.suffix.lower()
|
||||
if file_extension == '.xlsx':
|
||||
loader = ExcelLoader(file_path)
|
||||
elif input_file.suffix == '.pdf':
|
||||
elif file_extension == '.pdf':
|
||||
loader = PdfLoader(file_path, upload_file=upload_file)
|
||||
elif input_file.suffix in ['.md', '.markdown']:
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
loader = MarkdownLoader(file_path, autodetect_encoding=True)
|
||||
elif input_file.suffix in ['.htm', '.html']:
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
loader = HTMLLoader(file_path)
|
||||
elif input_file.suffix == '.docx':
|
||||
elif file_extension == '.docx':
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif input_file.suffix == '.csv':
|
||||
elif file_extension == '.csv':
|
||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
# txt
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import logging
|
||||
import csv
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
from langchain.document_loaders import CSVLoader as LCCSVLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
|
||||
from models.dataset import Document
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,6 +30,8 @@ class ExcelLoader(BaseLoader):
|
||||
wb = load_workbook(filename=self._file_path, read_only=True)
|
||||
# loop over all sheets
|
||||
for sheet in wb:
|
||||
if 'A1:A1' == sheet.calculate_dimension():
|
||||
sheet.reset_dimensions()
|
||||
for row in sheet.iter_rows(values_only=True):
|
||||
if all(v is None for v in row):
|
||||
continue
|
||||
@@ -38,7 +40,7 @@ class ExcelLoader(BaseLoader):
|
||||
else:
|
||||
row_dict = dict(zip(keys, list(map(str, row))))
|
||||
row_dict = {k: v for k, v in row_dict.items() if v}
|
||||
item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
|
||||
item = ''.join(f'{k}:{v};' for k, v in row_dict.items())
|
||||
document = Document(page_content=item, metadata={'source': self._file_path})
|
||||
data.append(document)
|
||||
|
||||
|
||||
@@ -67,12 +67,13 @@ class DatesetDocumentStore:
|
||||
|
||||
if max_position is None:
|
||||
max_position = 0
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
model_provider_name=self._dataset.embedding_model_provider,
|
||||
model_name=self._dataset.embedding_model
|
||||
)
|
||||
embedding_model = None
|
||||
if self._dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
model_provider_name=self._dataset.embedding_model_provider,
|
||||
model_name=self._dataset.embedding_model
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
if not isinstance(doc, Document):
|
||||
@@ -88,7 +89,7 @@ class DatesetDocumentStore:
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(doc.page_content)
|
||||
tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0
|
||||
|
||||
if not segment_document:
|
||||
max_position += 1
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain.schema import OutputParserException
|
||||
@@ -22,18 +23,25 @@ class LLMGenerator:
|
||||
if len(query) > 2000:
|
||||
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
|
||||
|
||||
prompt = prompt.format(query=query)
|
||||
query = query.replace("\n", " ")
|
||||
|
||||
prompt += query + "\n"
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=50
|
||||
temperature=1,
|
||||
max_tokens=100
|
||||
)
|
||||
)
|
||||
|
||||
prompts = [PromptMessage(content=prompt)]
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
|
||||
result_dict = json.loads(answer)
|
||||
answer = result_dict['Your Output']
|
||||
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs
|
||||
from core.model_providers.models.llm.openai_model import OpenAIModel
|
||||
from core.model_providers.providers.openai_provider import OpenAIProvider
|
||||
from models.dataset import Dataset
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
class IndexBuilder:
|
||||
@@ -35,4 +43,13 @@ class IndexBuilder:
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError('Unknown indexing technique')
|
||||
raise ValueError('Unknown indexing technique')
|
||||
|
||||
@classmethod
def get_default_high_quality_index(cls, dataset: Dataset):
    """Return a VectorIndex for ``dataset`` built with a blank-key OpenAI embeddings client.

    The index is configured from the current Flask application's config.
    """
    # NOTE(review): the API key is a single space, so this client presumably
    # is never meant to issue real embedding requests - confirm with callers.
    placeholder_embeddings = OpenAIEmbeddings(openai_api_key=' ')

    index = VectorIndex(
        dataset=dataset,
        config=current_app.config,
        embeddings=placeholder_embeddings
    )
    return index
|
||||
|
||||
@@ -25,7 +25,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
keyword_table = {}
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
@@ -52,7 +52,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
@@ -199,15 +199,18 @@ class KeywordTableIndex(BaseIndex):
|
||||
|
||||
return sorted_chunk_indices[: k]
|
||||
|
||||
def _update_segment_keywords(self, node_id: str, keywords: List[str]):
|
||||
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
|
||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]):
|
||||
document_segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id == node_id
|
||||
).first()
|
||||
if document_segment:
|
||||
document_segment.keywords = keywords
|
||||
db.session.commit()
|
||||
|
||||
def create_segment_keywords(self, node_id: str, keywords: List[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
self._update_segment_keywords(node_id, keywords)
|
||||
self._update_segment_keywords(self.dataset.id, node_id, keywords)
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
|
||||
@@ -15,12 +15,12 @@ from models.dataset import Document as DatasetDocument
|
||||
|
||||
|
||||
class BaseVectorIndex(BaseIndex):
|
||||
|
||||
|
||||
def __init__(self, dataset: Dataset, embeddings: Embeddings):
|
||||
super().__init__(dataset)
|
||||
self._embeddings = embeddings
|
||||
self._vector_store = None
|
||||
|
||||
|
||||
def get_type(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -143,7 +143,7 @@ class BaseVectorIndex(BaseIndex):
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
@@ -173,3 +173,73 @@ class BaseVectorIndex(BaseIndex):
|
||||
|
||||
self.dataset = dataset
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def create_qdrant_dataset(self, dataset: Dataset):
    """Rebuild this index from the dataset's stored segments.

    Deletes the current index, gathers every completed and enabled segment
    belonging to a completed, enabled, non-archived document of ``dataset``,
    and passes them to ``self.create``. Nothing is created when no segment
    qualifies.
    """
    logging.info(f"create_qdrant_dataset {dataset.id}")

    try:
        self.delete()
    except UnexpectedStatusCodeException as delete_error:
        # A 400 status means the index does not exist, which is fine here;
        # anything else is propagated unchanged.
        if delete_error.status_code != 400:
            raise

    eligible_documents = db.session.query(DatasetDocument).filter(
        DatasetDocument.dataset_id == dataset.id,
        DatasetDocument.indexing_status == 'completed',
        DatasetDocument.enabled == True,
        DatasetDocument.archived == False,
    ).all()

    documents = []
    for eligible_document in eligible_documents:
        eligible_segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == eligible_document.id,
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True
        ).all()

        # One index Document per segment, carrying the ids needed to map a
        # hit back to its segment, document and dataset.
        documents.extend(
            Document(
                page_content=seg.content,
                metadata={
                    "doc_id": seg.index_node_id,
                    "doc_hash": seg.index_node_hash,
                    "document_id": seg.document_id,
                    "dataset_id": seg.dataset_id,
                }
            )
            for seg in eligible_segments
        )

    if documents:
        self.create(documents)

    logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def update_qdrant_dataset(self, dataset: Dataset):
    """Mark ``dataset`` as qdrant-backed when its data is already in this index.

    Looks up one completed, enabled segment of the dataset. If that
    segment's node id is found via ``self.text_exists``, the dataset's
    ``index_struct`` is rewritten with type ``'qdrant'`` (keeping the
    existing ``class_prefix``) and the session is committed. Otherwise
    nothing is changed.
    """
    logging.info(f"update_qdrant_dataset {dataset.id}")

    segment = db.session.query(DocumentSegment).filter(
        DocumentSegment.dataset_id == dataset.id,
        DocumentSegment.status == 'completed',
        DocumentSegment.enabled == True
    ).first()

    # Errors from text_exists / commit propagate unchanged; the previous
    # `except Exception as e: raise e` wrapper added nothing.
    if segment and self.text_exists(segment.index_node_id):
        index_struct = {
            "type": 'qdrant',
            "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
        }
        dataset.index_struct = json.dumps(index_struct)
        db.session.commit()

    # This method only rewrites index_struct; the old message claimed a
    # "recreate", which was copied from the recreate path.
    logging.info(f"Dataset {dataset.id} update successfully.")
|
||||
|
||||
114
api/core/index/vector_index/milvus_vector_index.py
Normal file
114
api/core/index/vector_index/milvus_vector_index.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Optional, cast
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document, BaseRetriever
|
||||
from langchain.vectorstores import VectorStore, milvus
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from core.vector_store.milvus_vector_store import MilvusVectorStore
|
||||
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
    """Settings object accepted by ``MilvusVectorIndex``.

    ``endpoint``, ``user`` and ``password`` must all be non-empty;
    ``batch_size`` defaults to 100.
    """
    endpoint: str
    user: str
    password: str
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        """Reject configurations where a required setting is missing or empty.

        :param values: field values collected so far
        :return: ``values`` unchanged when every required setting is present
        :raises ValueError: naming the first missing ``MILVUS_*`` setting
        """
        # A root validator still runs when a field failed its own validation,
        # and that field is then absent from `values`. Use .get() so a missing
        # setting yields the ValueError below rather than a bare KeyError.
        if not values.get('endpoint'):
            raise ValueError("config MILVUS_ENDPOINT is required")
        if not values.get('user'):
            raise ValueError("config MILVUS_USER is required")
        if not values.get('password'):
            raise ValueError("config MILVUS_PASSWORD is required")
        return values
|
||||
|
||||
|
||||
class MilvusVectorIndex(BaseVectorIndex):
    """Vector index registered under the type name ``'milvus'``.

    NOTE(review): ``create`` and ``_get_vector_store`` build a
    ``WeaviateVectorStore`` while ``_get_vector_store_class`` reports
    ``MilvusVectorStore`` - confirm which store is intended.
    """

    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        # NOTE(review): _init_client is not defined in this class; presumably
        # it is expected from a base class - verify it exists.
        self._client = self._init_client(config)

    def get_type(self) -> str:
        """Return the index type identifier."""
        return 'milvus'

    def get_index_name(self, dataset: Dataset) -> str:
        """Return the index name, always ending in ``_Node``.

        A ``class_prefix`` stored on ``self.dataset`` takes precedence (with
        ``_Node`` appended when missing); otherwise the name is derived from
        the id of the ``dataset`` argument.
        """
        if self.dataset.index_struct_dict:
            stored_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            # A prefix without the suffix is an original-format prefix.
            if stored_prefix.endswith('_Node'):
                return stored_prefix
            return stored_prefix + '_Node'

        normalized_id = dataset.id.replace("-", "_")
        return "Vector_index_" + normalized_id + '_Node'

    def to_index_struct(self) -> dict:
        """Return the structure describing this index (type and class prefix)."""
        vector_store_info = {"class_prefix": self.get_index_name(self.dataset)}
        return {
            "type": self.get_type(),
            "vector_store": vector_store_info
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        """Build the vector store from ``texts`` and return this index."""
        uuids = self._get_uuids(texts)
        index_name = self.get_index_name(self.dataset)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=index_name,
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        # Original-format indexes only expose doc_id.
        attributes = ['doc_id'] if self._is_origin() else ['doc_id', 'dataset_id', 'document_id']

        return WeaviateVectorStore(
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            text_key='text',
            embedding=self._embeddings,
            attributes=attributes,
            by_text=False
        )

    def _get_vector_store_class(self) -> type:
        """Return the vector store class used when casting the store."""
        return MilvusVectorStore

    def delete_by_document_id(self, document_id: str):
        """Remove every entry whose ``document_id`` equals the given id.

        For an original-format index the whole dataset is recreated instead.
        """
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        store = cast(self._get_vector_store_class(), self._get_vector_store())

        document_filter = {
            "operator": "Equal",
            "path": ["document_id"],
            "valueText": document_id
        }
        store.del_texts(document_filter)

    def _is_origin(self):
        """Return True when the stored class prefix lacks the ``_Node`` suffix."""
        index_struct = self.dataset.index_struct_dict
        if not index_struct:
            return False

        stored_prefix: str = index_struct['vector_store']['class_prefix']
        return not stored_prefix.endswith('_Node')
|
||||
1691
api/core/index/vector_index/qdrant.py
Normal file
1691
api/core/index/vector_index/qdrant.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -44,15 +44,20 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
|
||||
def get_index_name(self, dataset: Dataset) -> str:
|
||||
if self.dataset.index_struct_dict:
|
||||
return self.dataset.index_struct_dict['vector_store']['collection_name']
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
class_prefix += '_Node'
|
||||
|
||||
return class_prefix
|
||||
|
||||
dataset_id = dataset.id
|
||||
return "Index_" + dataset_id.replace("-", "_")
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"collection_name": self.get_index_name(self.dataset)}
|
||||
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
|
||||
}
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
@@ -62,7 +67,7 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
self._embeddings,
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
ids=uuids,
|
||||
content_payload_key='text',
|
||||
content_payload_key='page_content',
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
|
||||
@@ -72,7 +77,9 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id']
|
||||
if self._is_origin():
|
||||
attributes = ['doc_id']
|
||||
client = qdrant_client.QdrantClient(
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
@@ -81,7 +88,7 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
client=client,
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
embeddings=self._embeddings,
|
||||
content_payload_key='text'
|
||||
content_payload_key='page_content'
|
||||
)
|
||||
|
||||
def _get_vector_store_class(self) -> type:
|
||||
@@ -108,8 +115,8 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
|
||||
if class_prefix.startswith('Vector_'):
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
return True
|
||||
|
||||
|
||||
@@ -217,25 +217,29 @@ class IndexingRunner:
|
||||
db.session.commit()
|
||||
|
||||
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
|
||||
indexing_technique: str = 'economy') -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = None
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
total_segments = 0
|
||||
@@ -263,8 +267,8 @@ class IndexingRunner:
|
||||
for document in documents:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
|
||||
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
|
||||
if indexing_technique == 'high_quality' or embedding_model:
|
||||
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
@@ -286,32 +290,35 @@ class IndexingRunner:
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
|
||||
"currency": embedding_model.get_currency() if embedding_model else 'USD',
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
|
||||
indexing_technique: str = 'economy') -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = None
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
# load data from notion
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
@@ -356,8 +363,8 @@ class IndexingRunner:
|
||||
for document in documents:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
|
||||
tokens += embedding_model.get_num_tokens(document.page_content)
|
||||
if indexing_technique == 'high_quality' or embedding_model:
|
||||
tokens += embedding_model.get_num_tokens(document.page_content)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
@@ -379,8 +386,8 @@ class IndexingRunner:
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
|
||||
"currency": embedding_model.get_currency() if embedding_model else 'USD',
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
@@ -399,7 +406,8 @@ class IndexingRunner:
|
||||
filter(UploadFile.id == data_source_info['upload_file_id']). \
|
||||
one_or_none()
|
||||
|
||||
text_docs = FileExtractor.load(file_detail)
|
||||
if file_detail:
|
||||
text_docs = FileExtractor.load(file_detail)
|
||||
elif dataset_document.data_source_type == 'notion_import':
|
||||
loader = NotionLoader.from_document(dataset_document)
|
||||
text_docs = loader.load()
|
||||
@@ -525,12 +533,13 @@ class IndexingRunner:
|
||||
documents = splitter.split_documents([text_doc])
|
||||
split_documents = []
|
||||
for document_node in documents:
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document_node.page_content)
|
||||
document_node.metadata['doc_id'] = doc_id
|
||||
document_node.metadata['doc_hash'] = hash
|
||||
|
||||
split_documents.append(document_node)
|
||||
if document_node.page_content.strip():
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document_node.page_content)
|
||||
document_node.metadata['doc_id'] = doc_id
|
||||
document_node.metadata['doc_hash'] = hash
|
||||
split_documents.append(document_node)
|
||||
all_documents.extend(split_documents)
|
||||
# processing qa document
|
||||
if document_form == 'qa_model':
|
||||
@@ -656,12 +665,13 @@ class IndexingRunner:
|
||||
"""
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
@@ -671,11 +681,11 @@ class IndexingRunner:
|
||||
# check document is paused
|
||||
self._check_document_paused_status(dataset_document.id)
|
||||
chunk_documents = documents[i:i + chunk_size]
|
||||
|
||||
tokens += sum(
|
||||
embedding_model.get_num_tokens(document.page_content)
|
||||
for document in chunk_documents
|
||||
)
|
||||
if dataset.indexing_technique == 'high_quality' or embedding_model:
|
||||
tokens += sum(
|
||||
embedding_model.get_num_tokens(document.page_content)
|
||||
for document in chunk_documents
|
||||
)
|
||||
|
||||
# save vector index
|
||||
if vector_index:
|
||||
|
||||
@@ -63,6 +63,9 @@ class ModelProviderFactory:
|
||||
elif provider_name == 'openllm':
|
||||
from core.model_providers.providers.openllm_provider import OpenLLMProvider
|
||||
return OpenLLMProvider
|
||||
elif provider_name == 'localai':
|
||||
from core.model_providers.providers.localai_provider import LocalAIProvider
|
||||
return LocalAIProvider
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
from langchain.embeddings import LocalAIEmbeddings
|
||||
|
||||
from replicate.exceptions import ModelError, ReplicateError
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
|
||||
|
||||
class LocalAIEmbedding(BaseEmbedding):
    """Embedding model backed by LangChain's ``LocalAIEmbeddings`` client."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        """
        Build the LocalAI embeddings client for the given model.

        :param model_provider: provider used to look up the model credentials
        :param name: name of the embedding model
        """
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = LocalAIEmbeddings(
            model=name,
            # fixed placeholder key; NOTE(review): presumably LocalAI does not check it - confirm
            openai_api_key="1",
            openai_api_base=credentials['server_url'],
        )

        super().__init__(model_provider, client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Translate client errors into the project's LLM error types.

        The LocalAI client talks to the server through the ``openai`` package,
        so the errors to translate are ``openai.error.*`` (the same mapping that
        LocalAIModel uses), not the ``replicate`` exceptions checked previously.

        :param ex: exception raised while calling the embedding client
        :return: the translated exception, or ``ex`` unchanged if it is not recognised
        """
        # Imported locally so this block does not depend on module-level imports
        # that the embedding module does not have.
        import openai
        from core.model_providers.error import LLMAPIConnectionError, LLMAPIUnavailableError, \
            LLMRateLimitError, LLMAuthorizationError

        if isinstance(ex, openai.error.InvalidRequestError):
            return LLMBadRequestError(f"LocalAI embedding: {str(ex)}")
        elif isinstance(ex, openai.error.APIConnectionError):
            return LLMAPIConnectionError(f"LocalAI embedding: {ex.__class__.__name__}:{str(ex)}")
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            return LLMAPIUnavailableError(f"LocalAI embedding: {ex.__class__.__name__}:{str(ex)}")
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(f"LocalAI embedding: {str(ex)}")
        elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError(f"LocalAI embedding: {str(ex)}")
        elif isinstance(ex, openai.error.OpenAIError):
            # any other error coming from the openai client
            return LLMBadRequestError(f"LocalAI embedding: {ex.__class__.__name__}:{str(ex)}")
        else:
            return ex
|
||||
@@ -1,11 +1,8 @@
|
||||
import decimal
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Any
|
||||
|
||||
import anthropic
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
@@ -13,6 +10,7 @@ from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.anthropic_llm import AnthropicLLM
|
||||
|
||||
|
||||
class AnthropicModel(BaseLLM):
|
||||
@@ -20,7 +18,7 @@ class AnthropicModel(BaseLLM):
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return ChatAnthropic(
|
||||
return AnthropicLLM(
|
||||
model=self.name,
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
@@ -75,7 +73,7 @@ class AnthropicModel(BaseLLM):
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
|
||||
@@ -141,6 +141,6 @@ class AzureOpenAIModel(BaseLLM):
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@@ -138,7 +138,7 @@ class BaseLLM(BaseProviderModel):
|
||||
result = self._run(
|
||||
messages=messages,
|
||||
stop=stop,
|
||||
callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
|
||||
callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
|
||||
**kwargs
|
||||
)
|
||||
except Exception as ex:
|
||||
@@ -149,7 +149,7 @@ class BaseLLM(BaseProviderModel):
|
||||
else:
|
||||
completion_content = result.generations[0][0].text
|
||||
|
||||
if self.streaming and not self.support_streaming():
|
||||
if self.streaming and not self.support_streaming:
|
||||
# use FakeLLM to simulate streaming when current model not support streaming but streaming is True
|
||||
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
|
||||
fake_llm = FakeLLM(
|
||||
@@ -298,8 +298,8 @@ class BaseLLM(BaseProviderModel):
|
||||
else:
|
||||
self.client.callbacks.extend(callbacks)
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return False
|
||||
|
||||
def get_prompt(self, mode: str,
|
||||
|
||||
@@ -61,7 +61,3 @@ class ChatGLMModel(BaseLLM):
|
||||
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return False
|
||||
|
||||
@@ -17,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
|
||||
streaming = self.streaming
|
||||
|
||||
if 'baichuan' in self.name.lower():
|
||||
streaming = False
|
||||
|
||||
client = HuggingFaceEndpointLLM(
|
||||
endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
|
||||
task=self.credentials['task_type'],
|
||||
model_kwargs=provider_model_kwargs,
|
||||
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
|
||||
callbacks=self.callbacks
|
||||
callbacks=self.callbacks,
|
||||
streaming=streaming
|
||||
)
|
||||
else:
|
||||
client = HuggingFaceHub(
|
||||
@@ -76,7 +82,10 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return False
|
||||
@property
|
||||
def support_streaming(self):
|
||||
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
|
||||
if 'baichuan' in self.name.lower():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
131
api/core/model_providers/models/llm/localai_model.py
Normal file
131
api/core/model_providers/models/llm/localai_model.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import logging
|
||||
from typing import List, Optional, Any
|
||||
|
||||
import openai
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult, get_buffer_string
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, LLMAuthorizationError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
|
||||
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
|
||||
|
||||
class LocalAIModel(BaseLLM):
    """Text generation model that calls a LocalAI server through the OpenAI-style clients."""

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        """
        Pick the model mode from the stored credentials, then initialise the base model.

        :param model_provider: provider used to look up the model credentials
        :param name: model name
        :param model_kwargs: generation parameters
        :param streaming: whether responses should be streamed
        :param callbacks: LangChain callbacks
        """
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        # model_mode is set before the base constructor runs because _init_client() reads it;
        # NOTE(review): presumably BaseLLM.__init__ calls _init_client() - confirm
        if credentials['completion_type'] == 'chat_completion':
            self.model_mode = ModelMode.CHAT
        else:
            self.model_mode = ModelMode.COMPLETION

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        """
        Create the completion or chat client, depending on ``self.model_mode``.

        :return: an EnhanceOpenAI (completion) or EnhanceChatOpenAI (chat) client
        """
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.model_mode == ModelMode.COMPLETION:
            client = EnhanceOpenAI(
                model_name=self.name,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                openai_api_key="1",
                openai_api_base=self.credentials['server_url'] + '/v1',
                **provider_model_kwargs
            )
        else:
            # the chat client takes temperature/max_tokens directly; top_p goes through model_kwargs
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p')
            }

            client = EnhanceChatOpenAI(
                model_name=self.name,
                temperature=provider_model_kwargs.get('temperature'),
                max_tokens=provider_model_kwargs.get('max_tokens'),
                model_kwargs=extra_model_kwargs,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                openai_api_key="1",
                openai_api_base=self.credentials['server_url'] + '/v1'
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, str):
            return self._client.get_num_tokens(prompts)
        else:
            # count each message separately, subtract one per message, and never go below zero
            return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """
        Apply new generation parameters to the already created client.

        :param model_kwargs: generation parameters to apply
        """
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        if self.model_mode == ModelMode.COMPLETION:
            # only set attributes the completion client actually has
            for k, v in provider_model_kwargs.items():
                if hasattr(self.client, k):
                    setattr(self.client, k, v)
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p')
            }

            self.client.temperature = provider_model_kwargs.get('temperature')
            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
            self.client.model_kwargs = extra_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Translate ``openai.error`` exceptions into the project's LLM error types.

        :param ex: exception raised while calling the client
        :return: the translated exception, or ``ex`` unchanged if it is not recognised
        """
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to LocalAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to LocalAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("LocalAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

    # Declared as a property (not a classmethod) to match BaseLLM and the other
    # providers, whose callers read `self.support_streaming` without calling it.
    @property
    def support_streaming(self):
        return True
|
||||
@@ -154,8 +154,8 @@ class OpenAIModel(BaseLLM):
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
# def is_model_valid_or_raise(self):
|
||||
|
||||
@@ -63,7 +63,3 @@ class OpenLLMModel(BaseLLM):
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"OpenLLM: {str(ex)}")
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return False
|
||||
|
||||
@@ -91,6 +91,6 @@ class ReplicateModel(BaseLLM):
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@@ -65,6 +65,6 @@ class SparkModel(BaseLLM):
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@@ -69,6 +69,6 @@ class TongyiModel(BaseLLM):
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@@ -57,7 +57,3 @@ class WenxinModel(BaseLLM):
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Wenxin: {str(ex)}")
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return False
|
||||
|
||||
@@ -74,6 +74,6 @@ class XinferenceModel(BaseLLM):
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Xinference: {str(ex)}")
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Type, Optional
|
||||
|
||||
import anthropic
|
||||
from flask import current_app
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
@@ -16,6 +15,7 @@ from core.model_providers.models.llm.anthropic_model import AnthropicModel
|
||||
from core.model_providers.models.llm.base import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.model_providers.providers.hosted import hosted_model_providers
|
||||
from core.third_party.langchain.llms.anthropic_llm import AnthropicLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ class AnthropicProvider(BaseModelProvider):
|
||||
if 'anthropic_api_url' in credentials:
|
||||
credential_kwargs['anthropic_api_url'] = credentials['anthropic_api_url']
|
||||
|
||||
chat_llm = ChatAnthropic(
|
||||
chat_llm = AnthropicLLM(
|
||||
model='claude-instant-1',
|
||||
max_tokens_to_sample=10,
|
||||
temperature=0,
|
||||
|
||||
164
api/core/model_providers/providers/localai_provider.py
Normal file
164
api/core/model_providers/providers/localai_provider.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from langchain.embeddings import LocalAIEmbeddings
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
|
||||
from core.model_providers.models.llm.localai_model import LocalAIModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
|
||||
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class LocalAIProvider(BaseModelProvider):
    """Model provider for LocalAI; credentials are stored per model, not per provider."""

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'localai'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """Return the predefined models; LocalAI has none, so the list is always empty."""
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for model types other than text generation and embeddings
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = LocalAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = LocalAIEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        The same rules are returned for every model name and type.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=0.7),
            top_p=KwargRule[float](min=0, max=1, default=1),
            max_tokens=KwargRule[int](min=10, max=4097, default=16),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        Validation sends a short "ping" request to the configured server.

        :param model_name:
        :param model_type:
        :param credentials:
        :raises CredentialsValidateFailedError: if a required field is missing or the ping fails
        """
        # Reject a missing key as well as an empty/None value: an empty URL would
        # otherwise fail below with an unrelated TypeError message.
        if not credentials.get('server_url'):
            raise CredentialsValidateFailedError('LocalAI Server URL must be provided.')

        try:
            if model_type == ModelType.EMBEDDINGS:
                model = LocalAIEmbeddings(
                    model=model_name,
                    openai_api_key='1',
                    openai_api_base=credentials['server_url']
                )

                model.embed_query("ping")
            else:
                if ('completion_type' not in credentials
                        or credentials['completion_type'] not in ['completion', 'chat_completion']):
                    raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.')

                if credentials['completion_type'] == 'chat_completion':
                    model = EnhanceChatOpenAI(
                        model_name=model_name,
                        openai_api_key='1',
                        openai_api_base=credentials['server_url'] + '/v1',
                        max_tokens=10,
                        request_timeout=60,
                    )

                    model([HumanMessage(content='ping')])
                else:
                    model = EnhanceOpenAI(
                        model_name=model_name,
                        openai_api_key='1',
                        openai_api_base=credentials['server_url'] + '/v1',
                        max_tokens=10,
                        request_timeout=60,
                    )

                    model('ping')
        except Exception as ex:
            # every failure, including the completion-type check above, is reported
            # as a credentials validation error carrying the original message
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        Only ``server_url`` is encrypted; the passed dict is updated in place and returned.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated: when True, the server url is returned obfuscated
        :return:
        :raises NotImplementedError: if the provider is not a custom provider
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        # nothing stored yet for this model
        if not provider_model.encrypted_config:
            return {
                'server_url': None,
            }

        credentials = json.loads(provider_model.encrypted_config)
        if credentials['server_url']:
            credentials['server_url'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['server_url']
            )

            if obfuscated:
                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """No provider-level credentials exist for LocalAI, so there is nothing to validate."""
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """No provider-level credentials exist for LocalAI; always returns an empty dict."""
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """No provider-level credentials exist for LocalAI; always returns an empty dict."""
        return {}
|
||||
@@ -83,14 +83,15 @@ class SparkProvider(BaseModelProvider):
|
||||
if 'api_secret' not in credentials:
|
||||
raise CredentialsValidateFailedError('Spark api_secret must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'app_id': credentials['app_id'],
|
||||
'api_key': credentials['api_key'],
|
||||
'api_secret': credentials['api_secret'],
|
||||
}
|
||||
credential_kwargs = {
|
||||
'app_id': credentials['app_id'],
|
||||
'api_key': credentials['api_key'],
|
||||
'api_secret': credentials['api_secret'],
|
||||
}
|
||||
|
||||
try:
|
||||
chat_llm = ChatSpark(
|
||||
model_name='spark-v2',
|
||||
max_tokens=10,
|
||||
temperature=0.01,
|
||||
**credential_kwargs
|
||||
@@ -104,7 +105,27 @@ class SparkProvider(BaseModelProvider):
|
||||
|
||||
chat_llm(messages)
|
||||
except SparkError as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
# try spark v1.5 if v2.1 failed
|
||||
try:
|
||||
chat_llm = ChatSpark(
|
||||
model_name='spark',
|
||||
max_tokens=10,
|
||||
temperature=0.01,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
messages = [
|
||||
HumanMessage(
|
||||
content="ping"
|
||||
)
|
||||
]
|
||||
|
||||
chat_llm(messages)
|
||||
except SparkError as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
except Exception as ex:
|
||||
logging.exception('Spark config validation failed')
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logging.exception('Spark config validation failed')
|
||||
raise ex
|
||||
|
||||
@@ -10,5 +10,6 @@
|
||||
"replicate",
|
||||
"huggingface_hub",
|
||||
"xinference",
|
||||
"openllm"
|
||||
"openllm",
|
||||
"localai"
|
||||
]
|
||||
7
api/core/model_providers/rules/localai.json
Normal file
7
api/core/model_providers/rules/localai.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "configurable"
|
||||
}
|
||||
@@ -283,6 +283,7 @@ class OrchestratorRuleParser:
|
||||
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
|
||||
DEFAULT_K = 2
|
||||
CONTEXT_TOKENS_PERCENT = 0.3
|
||||
MAX_K = 10
|
||||
|
||||
if rest_tokens == -1:
|
||||
return DEFAULT_K
|
||||
@@ -311,5 +312,5 @@ class OrchestratorRuleParser:
|
||||
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
|
||||
return DEFAULT_K
|
||||
|
||||
# Expand the k value when there's still some room left in the 30% rest tokens space
|
||||
return context_limit_tokens // segment_max_tokens
|
||||
# Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
|
||||
return min(context_limit_tokens // segment_max_tokens, MAX_K)
|
||||
|
||||
@@ -1,10 +1,65 @@
|
||||
CONVERSATION_TITLE_PROMPT = (
|
||||
"Human:{query}\n-----\n"
|
||||
"Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
|
||||
"If what the human said is conducted in English, you should only return an English title.\n"
|
||||
"If what the human said is conducted in Chinese, you should only return a Chinese title.\n"
|
||||
"title:"
|
||||
)
|
||||
# Written by YORKI MINAKO🤡
|
||||
CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
|
||||
Notice: the language type user use could be diverse, which can be English, Chinese, Español, Arabic, Japanese, French, and etc.
|
||||
MAKE SURE your output is the SAME language as the user's input!
|
||||
Your output is restricted only to: (Input language) Intention + Subject(short as possible)
|
||||
Your output MUST be a valid JSON.
|
||||
|
||||
Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun.
|
||||
|
||||
|
||||
example 1:
|
||||
User Input: hi, yesterday i had some burgers.
|
||||
{
|
||||
"Language Type": "The user's input is pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "sharing yesterday's food"
|
||||
}
|
||||
|
||||
example 2:
|
||||
User Input: hello
|
||||
{
|
||||
"Language Type": "The user's input is written in pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "Greeting myself☺️"
|
||||
}
|
||||
|
||||
|
||||
example 3:
|
||||
User Input: why mmap file: oom
|
||||
{
|
||||
"Language Type": "The user's input is written in pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "Asking about the reason for mmap file: oom"
|
||||
}
|
||||
|
||||
|
||||
example 4:
|
||||
User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么?
|
||||
{
|
||||
"Language Type": "The user's input English-Chinese mixed",
|
||||
"Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.",
|
||||
"Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv"
|
||||
}
|
||||
|
||||
example 5:
|
||||
User Input: why小红的年龄is老than小明?
|
||||
{
|
||||
"Language Type": "The user's input is English-Chinese mixed",
|
||||
"Your Reasoning": "The English parts are subjective particles, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
|
||||
"Your Output": "询问小红和小明的年龄"
|
||||
}
|
||||
|
||||
example 6:
|
||||
User Input: yo, 你今天咋样?
|
||||
{
|
||||
"Language Type": "The user's input is English-Chinese mixed",
|
||||
"Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.",
|
||||
"Your Output": "查询今日我的状态☺️"
|
||||
}
|
||||
|
||||
User Input:
|
||||
"""
|
||||
|
||||
CONVERSATION_SUMMARY_PROMPT = (
|
||||
"Please generate a short summary of the following conversation.\n"
|
||||
|
||||
48
api/core/third_party/langchain/llms/anthropic_llm.py
vendored
Normal file
48
api/core/third_party/langchain/llms/anthropic_llm.py
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import Dict
|
||||
|
||||
from httpx import Limits
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.utils import get_from_dict_or_env, check_package_version
|
||||
from pydantic import root_validator
|
||||
|
||||
|
||||
class AnthropicLLM(ChatAnthropic):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["anthropic_api_key"] = get_from_dict_or_env(
|
||||
values, "anthropic_api_key", "ANTHROPIC_API_KEY"
|
||||
)
|
||||
# Get custom api url from environment.
|
||||
values["anthropic_api_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"anthropic_api_url",
|
||||
"ANTHROPIC_API_URL",
|
||||
default="https://api.anthropic.com",
|
||||
)
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
|
||||
check_package_version("anthropic", gte_version="0.3")
|
||||
values["client"] = anthropic.Anthropic(
|
||||
base_url=values["anthropic_api_url"],
|
||||
api_key=values["anthropic_api_key"],
|
||||
timeout=values["default_request_timeout"],
|
||||
max_retries=0,
|
||||
connection_pool_limits=Limits(max_connections=200, max_keepalive_connections=100),
|
||||
)
|
||||
values["async_client"] = anthropic.AsyncAnthropic(
|
||||
base_url=values["anthropic_api_url"],
|
||||
api_key=values["anthropic_api_key"],
|
||||
timeout=values["default_request_timeout"],
|
||||
)
|
||||
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
|
||||
values["AI_PROMPT"] = anthropic.AI_PROMPT
|
||||
values["count_tokens"] = values["client"].count_tokens
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import anthropic python package. "
|
||||
"Please it install it with `pip install anthropic`."
|
||||
)
|
||||
return values
|
||||
@@ -42,7 +42,8 @@ class EnhanceChatOpenAI(ChatOpenAI):
|
||||
return {
|
||||
**super()._default_params,
|
||||
"api_type": 'openai',
|
||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_base": self.openai_api_base if self.openai_api_base
|
||||
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_version": None,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from typing import Dict
|
||||
from typing import Dict, Any, Optional, List, Iterable, Iterator
|
||||
|
||||
from huggingface_hub import InferenceClient
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.embeddings.huggingface_hub import VALID_TASKS
|
||||
from langchain.llms import HuggingFaceEndpoint
|
||||
from pydantic import Extra, root_validator
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -27,6 +31,8 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
|
||||
huggingfacehub_api_token="my-api-key"
|
||||
)
|
||||
"""
|
||||
client: Any
|
||||
streaming: bool = False
|
||||
|
||||
@root_validator(allow_reuse=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@@ -35,5 +41,88 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
|
||||
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
|
||||
)
|
||||
|
||||
values['client'] = InferenceClient(values['endpoint_url'], token=huggingfacehub_api_token)
|
||||
|
||||
values["huggingfacehub_api_token"] = huggingfacehub_api_token
|
||||
return values
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to HuggingFace Hub's inference endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = hf("Tell me a joke.")
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
|
||||
# payload samples
|
||||
params = {**_model_kwargs, **kwargs}
|
||||
|
||||
# generation parameter
|
||||
gen_kwargs = {
|
||||
**params,
|
||||
'stop_sequences': stop
|
||||
}
|
||||
|
||||
response = self.client.text_generation(prompt, stream=self.streaming, details=True, **gen_kwargs)
|
||||
|
||||
if self.streaming and isinstance(response, Iterable):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_response(response, run_manager):
|
||||
combined_text_output += token
|
||||
completion = combined_text_output
|
||||
else:
|
||||
completion = response.generated_text
|
||||
|
||||
if self.task == "text-generation":
|
||||
text = completion
|
||||
# Remove prompt if included in generated text.
|
||||
if text.startswith(prompt):
|
||||
text = text[len(prompt) :]
|
||||
elif self.task == "text2text-generation":
|
||||
text = completion
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got invalid task {self.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
|
||||
if stop is not None:
|
||||
# This is a bit hacky, but I can't figure out a better way to enforce
|
||||
# stop tokens when making calls to huggingface_hub.
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
return text
|
||||
|
||||
def _stream_response(
|
||||
self,
|
||||
response: Iterable,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> Iterator[str]:
|
||||
for r in response:
|
||||
# skip special tokens
|
||||
if r.token.special:
|
||||
continue
|
||||
|
||||
token = r.token.text
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token=token, verbose=self.verbose, log_probs=None
|
||||
)
|
||||
|
||||
# yield the generated token
|
||||
yield token
|
||||
|
||||
35
api/core/third_party/langchain/llms/open_ai.py
vendored
35
api/core/third_party/langchain/llms/open_ai.py
vendored
@@ -1,7 +1,10 @@
|
||||
import os
|
||||
|
||||
from typing import Dict, Any, Mapping, Optional, Union, Tuple
|
||||
from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator
|
||||
from langchain import OpenAI
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from pydantic import root_validator
|
||||
|
||||
|
||||
@@ -33,7 +36,8 @@ class EnhanceOpenAI(OpenAI):
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
return {**super()._invocation_params, **{
|
||||
"api_type": 'openai',
|
||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_base": self.openai_api_base if self.openai_api_base
|
||||
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_version": None,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
@@ -43,8 +47,33 @@ class EnhanceOpenAI(OpenAI):
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {**super()._identifying_params, **{
|
||||
"api_type": 'openai',
|
||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_base": self.openai_api_base if self.openai_api_base
|
||||
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_version": None,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
params = {**self._invocation_params, **kwargs, "stream": True}
|
||||
self.get_sub_prompts(params, [prompt], stop) # this mutates params
|
||||
for stream_resp in completion_with_retry(
|
||||
self, prompt=prompt, run_manager=run_manager, **params
|
||||
):
|
||||
if 'text' in stream_resp["choices"][0]:
|
||||
chunk = _stream_response_to_generation_chunk(stream_resp)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
verbose=self.verbose,
|
||||
logprobs=chunk.generation_info["logprobs"]
|
||||
if chunk.generation_info
|
||||
else None,
|
||||
)
|
||||
|
||||
@@ -3,17 +3,20 @@ from typing import Optional, List, Any, Union, Generator
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import Xinference
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from xinference.client import RESTfulChatglmCppChatModelHandle, \
|
||||
RESTfulChatModelHandle, RESTfulGenerateModelHandle
|
||||
from xinference.client import (
|
||||
RESTfulChatglmCppChatModelHandle,
|
||||
RESTfulChatModelHandle,
|
||||
RESTfulGenerateModelHandle,
|
||||
)
|
||||
|
||||
|
||||
class XinferenceLLM(Xinference):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call the xinference model and return the output.
|
||||
|
||||
@@ -29,7 +32,9 @@ class XinferenceLLM(Xinference):
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
if isinstance(model, RESTfulChatModelHandle):
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get(
|
||||
"generate_config", {}
|
||||
)
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
@@ -37,10 +42,10 @@ class XinferenceLLM(Xinference):
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
@@ -48,7 +53,9 @@ class XinferenceLLM(Xinference):
|
||||
completion = model.chat(prompt=prompt, generate_config=generate_config)
|
||||
return completion["choices"][0]["message"]["content"]
|
||||
elif isinstance(model, RESTfulGenerateModelHandle):
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get(
|
||||
"generate_config", {}
|
||||
)
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
@@ -56,27 +63,31 @@ class XinferenceLLM(Xinference):
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
|
||||
else:
|
||||
completion = model.generate(prompt=prompt, generate_config=generate_config)
|
||||
completion = model.generate(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
return completion["choices"][0]["text"]
|
||||
elif isinstance(model, RESTfulChatglmCppChatModelHandle):
|
||||
generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
generate_config: "ChatglmCppGenerateConfig" = kwargs.get(
|
||||
"generate_config", {}
|
||||
)
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
completion = combined_text_output
|
||||
@@ -90,12 +101,21 @@ class XinferenceLLM(Xinference):
|
||||
return completion
|
||||
|
||||
def _stream_generate(
|
||||
self,
|
||||
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional[
|
||||
Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
|
||||
self,
|
||||
model: Union[
|
||||
"RESTfulGenerateModelHandle",
|
||||
"RESTfulChatModelHandle",
|
||||
"RESTfulChatglmCppChatModelHandle",
|
||||
],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional[
|
||||
Union[
|
||||
"LlamaCppGenerateConfig",
|
||||
"PytorchGenerateConfig",
|
||||
"ChatglmCppGenerateConfig",
|
||||
]
|
||||
] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Args:
|
||||
@@ -108,7 +128,9 @@ class XinferenceLLM(Xinference):
|
||||
Yields:
|
||||
A string token.
|
||||
"""
|
||||
if isinstance(model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)):
|
||||
if isinstance(
|
||||
model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
|
||||
):
|
||||
streaming_response = model.chat(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
@@ -123,14 +145,10 @@ class XinferenceLLM(Xinference):
|
||||
if choices:
|
||||
choice = choices[0]
|
||||
if isinstance(choice, dict):
|
||||
if 'finish_reason' in choice and choice['finish_reason'] \
|
||||
and choice['finish_reason'] in ['stop', 'length']:
|
||||
break
|
||||
|
||||
if 'text' in choice:
|
||||
if "text" in choice:
|
||||
token = choice.get("text", "")
|
||||
elif 'delta' in choice and 'content' in choice['delta']:
|
||||
token = choice.get('delta').get('content')
|
||||
elif "delta" in choice and "content" in choice["delta"]:
|
||||
token = choice.get("delta").get("content")
|
||||
else:
|
||||
continue
|
||||
log_probs = choice.get("logprobs")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
from typing import Type
|
||||
|
||||
from flask import current_app
|
||||
@@ -16,7 +15,6 @@ from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
|
||||
class DatasetRetrieverToolInput(BaseModel):
|
||||
dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
|
||||
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
|
||||
|
||||
|
||||
@@ -37,27 +35,22 @@ class DatasetRetrieverTool(BaseTool):
|
||||
description = 'useful for when you want to answer queries about the ' + dataset.name
|
||||
|
||||
description = description.replace('\n', '').replace('\r', '')
|
||||
description += '\nID of dataset MUST be ' + dataset.id
|
||||
return cls(
|
||||
name=f'dataset-{dataset.id}',
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
description=description,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _run(self, dataset_id: str, query: str) -> str:
|
||||
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
|
||||
match = re.search(pattern, dataset_id, re.IGNORECASE)
|
||||
if match:
|
||||
dataset_id = match.group()
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == self.tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
Dataset.id == self.dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
return f'[{self.name} failed to find dataset with id {dataset_id}.]'
|
||||
return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
@@ -105,7 +98,8 @@ class DatasetRetrieverTool(BaseTool):
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata['doc_id'] for document in documents]
|
||||
segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
||||
segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids)
|
||||
|
||||
@@ -88,11 +88,9 @@ class WebReaderTool(BaseTool):
|
||||
texts = character_splitter.split_text(page_contents)
|
||||
docs = [Document(page_content=t) for t in texts]
|
||||
|
||||
if len(docs) == 0:
|
||||
if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
|
||||
return "No content found."
|
||||
|
||||
docs = docs[1:]
|
||||
|
||||
# only use first 5 docs
|
||||
if len(docs) > 5:
|
||||
docs = docs[:5]
|
||||
|
||||
38
api/core/vector_store/milvus_vector_store.py
Normal file
38
api/core/vector_store/milvus_vector_store.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from langchain.vectorstores import Milvus
|
||||
|
||||
|
||||
class MilvusVectorStore(Milvus):
|
||||
def del_texts(self, where_filter: dict):
|
||||
if not where_filter:
|
||||
raise ValueError('where_filter must not be empty')
|
||||
|
||||
self._client.batch.delete_objects(
|
||||
class_name=self._index_name,
|
||||
where=where_filter,
|
||||
output='minimal'
|
||||
)
|
||||
|
||||
def del_text(self, uuid: str) -> None:
|
||||
self._client.data_object.delete(
|
||||
uuid,
|
||||
class_name=self._index_name
|
||||
)
|
||||
|
||||
def text_exists(self, uuid: str) -> bool:
|
||||
result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
|
||||
"path": ["doc_id"],
|
||||
"operator": "Equal",
|
||||
"valueText": uuid,
|
||||
}).with_limit(1).do()
|
||||
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
|
||||
entries = result["data"]["Get"][self._index_name]
|
||||
if len(entries) == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def delete(self):
|
||||
self._client.schema.delete_class(self._index_name)
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import cast, Any
|
||||
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import Qdrant
|
||||
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
|
||||
from core.index.vector_index.qdrant import Qdrant
|
||||
|
||||
|
||||
class QdrantVectorStore(Qdrant):
|
||||
def del_texts(self, filter: Filter):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from events.dataset_event import dataset_was_deleted
|
||||
from events.event_handlers.document_index_event import document_index_created
|
||||
from tasks.clean_dataset_task import clean_dataset_task
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
@@ -26,7 +26,7 @@ def handle(sender, **kwargs):
|
||||
|
||||
conversation.name = name
|
||||
except:
|
||||
conversation.name = 'New Chat'
|
||||
conversation.name = 'New conversation'
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""update_dataset_model_field_null_available
|
||||
|
||||
Revision ID: 4bcffcd64aa4
|
||||
Revises: 853f9b9cd3b6
|
||||
Create Date: 2023-08-28 20:58:50.077056
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '4bcffcd64aa4'
|
||||
down_revision = '853f9b9cd3b6'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.alter_column('embedding_model',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=True,
|
||||
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
|
||||
batch_op.alter_column('embedding_model_provider',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=True,
|
||||
existing_server_default=sa.text("'openai'::character varying"))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.alter_column('embedding_model_provider',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=False,
|
||||
existing_server_default=sa.text("'openai'::character varying"))
|
||||
batch_op.alter_column('embedding_model',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=False,
|
||||
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -36,10 +36,8 @@ class Dataset(db.Model):
|
||||
updated_by = db.Column(UUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False,
|
||||
server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
embedding_model = db.Column(db.String(
|
||||
255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying"))
|
||||
embedding_model_provider = db.Column(db.String(
|
||||
255), nullable=False, server_default=db.text("'openai'::character varying"))
|
||||
embedding_model = db.Column(db.String(255), nullable=True)
|
||||
embedding_model_provider = db.Column(db.String(255), nullable=True)
|
||||
|
||||
@property
|
||||
def dataset_keyword_table(self):
|
||||
|
||||
@@ -10,6 +10,7 @@ from flask import current_app
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask_login import current_user
|
||||
@@ -91,16 +92,18 @@ class DatasetService:
|
||||
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
|
||||
raise DatasetNameDuplicateError(
|
||||
f'Dataset with name {name} already exists.')
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
embedding_model = None
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
dataset = Dataset(name=name, indexing_technique=indexing_technique)
|
||||
# dataset = Dataset(name=name, provider=provider, config=config)
|
||||
dataset.created_by = account.id
|
||||
dataset.updated_by = account.id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
|
||||
dataset.embedding_model = embedding_model.name
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
|
||||
dataset.embedding_model = embedding_model.name if embedding_model else None
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
return dataset
|
||||
@@ -115,17 +118,50 @@ class DatasetService:
|
||||
else:
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_model_setting(dataset):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(f"The dataset in unavailable, due to: "
|
||||
f"{ex.description}")
|
||||
|
||||
@staticmethod
|
||||
def update_dataset(dataset_id, data, user):
|
||||
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
action = None
|
||||
if dataset.indexing_technique != data['indexing_technique']:
|
||||
# if update indexing_technique
|
||||
if data['indexing_technique'] == 'economy':
|
||||
deal_dataset_vector_index_task.delay(dataset_id, 'remove')
|
||||
action = 'remove'
|
||||
filtered_data['embedding_model'] = None
|
||||
filtered_data['embedding_model_provider'] = None
|
||||
elif data['indexing_technique'] == 'high_quality':
|
||||
deal_dataset_vector_index_task.delay(dataset_id, 'add')
|
||||
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
|
||||
action = 'add'
|
||||
# get embedding model setting
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
filtered_data['embedding_model'] = embedding_model.name
|
||||
filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
|
||||
filtered_data['updated_by'] = user.id
|
||||
filtered_data['updated_at'] = datetime.datetime.now()
|
||||
@@ -133,7 +169,8 @@ class DatasetService:
|
||||
dataset.query.filter_by(id=dataset_id).update(filtered_data)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if action:
|
||||
deal_dataset_vector_index_task.delay(dataset_id, action)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
@@ -394,16 +431,26 @@ class DocumentService:
|
||||
def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
|
||||
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
created_from: str = 'web'):
|
||||
|
||||
# check document limit
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if documents_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
if 'original_document_id' not in document_data or not document_data['original_document_id']:
|
||||
count = 0
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
|
||||
count = len(upload_file_list)
|
||||
elif document_data["data_source"]["type"] == "notion_import":
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
count = count + len(notion_info['pages'])
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
total_count = documents_count + count
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if total_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
# if dataset is empty, update dataset data_source_type
|
||||
if not dataset.data_source_type:
|
||||
dataset.data_source_type = document_data["data_source"]["type"]
|
||||
db.session.commit()
|
||||
|
||||
if not dataset.indexing_technique:
|
||||
if 'indexing_technique' not in document_data \
|
||||
@@ -411,6 +458,13 @@ class DocumentService:
|
||||
raise ValueError("Indexing technique is required")
|
||||
|
||||
dataset.indexing_technique = document_data["indexing_technique"]
|
||||
if document_data["indexing_technique"] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
)
|
||||
dataset.embedding_model = embedding_model.name
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
|
||||
|
||||
|
||||
documents = []
|
||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||
@@ -455,12 +509,12 @@ class DocumentService:
|
||||
data_source_info = {
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
document = DocumentService.save_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, file_name, batch)
|
||||
document = DocumentService.build_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, file_name, batch)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
@@ -501,12 +555,12 @@ class DocumentService:
|
||||
"notion_page_icon": page['page_icon'],
|
||||
"type": page['type']
|
||||
}
|
||||
document = DocumentService.save_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, page['page_name'], batch)
|
||||
document = DocumentService.build_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, page['page_name'], batch)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
@@ -525,10 +579,10 @@ class DocumentService:
|
||||
return documents, batch
|
||||
|
||||
@staticmethod
|
||||
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
|
||||
document_language: str, data_source_info: dict, created_from: str, position: int,
|
||||
account: Account,
|
||||
name: str, batch: str):
|
||||
def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
|
||||
document_language: str, data_source_info: dict, created_from: str, position: int,
|
||||
account: Account,
|
||||
name: str, batch: str):
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
@@ -557,6 +611,7 @@ class DocumentService:
|
||||
def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
|
||||
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
created_from: str = 'web'):
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
|
||||
if document.display_status != 'available':
|
||||
raise ValueError("Document is not available")
|
||||
@@ -649,15 +704,26 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
|
||||
count = 0
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
|
||||
count = len(upload_file_list)
|
||||
elif document_data["data_source"]["type"] == "notion_import":
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
count = count + len(notion_info['pages'])
|
||||
# check document limit
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
total_count = documents_count + count
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if documents_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if total_count > tenant_document_count:
|
||||
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
|
||||
embedding_model = None
|
||||
if document_data['indexing_technique'] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
# save dataset
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
@@ -665,8 +731,8 @@ class DocumentService:
|
||||
data_source_type=document_data["data_source"]["type"],
|
||||
indexing_technique=document_data["indexing_technique"],
|
||||
created_by=account.id,
|
||||
embedding_model=embedding_model.name,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name
|
||||
embedding_model=embedding_model.name if embedding_model else None,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
|
||||
)
|
||||
|
||||
db.session.add(dataset)
|
||||
@@ -874,21 +940,25 @@ class SegmentService:
|
||||
if document.doc_form == 'qa_model':
|
||||
if 'answer' not in args or not args['answer']:
|
||||
raise ValueError("Answer is required")
|
||||
if not args['answer'].strip():
|
||||
raise ValueError("Answer is empty")
|
||||
if 'content' not in args or not args['content'] or not args['content'].strip():
|
||||
raise ValueError("Content is empty")
|
||||
|
||||
@classmethod
|
||||
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
|
||||
content = args['content']
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == document.id
|
||||
).scalar()
|
||||
@@ -950,15 +1020,16 @@ class SegmentService:
|
||||
kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
|
||||
else:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
segment.content = content
|
||||
segment.index_node_hash = segment_hash
|
||||
segment.word_count = len(content)
|
||||
@@ -990,10 +1061,11 @@ class SegmentService:
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise ValueError("Segment is deleting.")
|
||||
# send delete segment index task
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
|
||||
# enabled segment need to delete index
|
||||
if segment.enabled:
|
||||
# send delete segment index task
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
|
||||
@@ -49,18 +49,20 @@ def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: s
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
|
||||
raise ValueError('Document is not available.')
|
||||
document_segments = []
|
||||
for segment in content:
|
||||
content = segment['content']
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
for segment in content:
|
||||
content = segment['content']
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
tokens = embedding_model.get_num_tokens(content) if embedding_model else 0
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == dataset_document.id
|
||||
).scalar()
|
||||
|
||||
@@ -3,8 +3,10 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from flask import current_app
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
|
||||
AppDatasetJoin, Document
|
||||
@@ -35,11 +37,11 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
|
||||
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
|
||||
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
|
||||
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
|
||||
try:
|
||||
vector_index.delete()
|
||||
except Exception:
|
||||
|
||||
@@ -31,7 +31,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
raise Exception('Dataset not found')
|
||||
|
||||
if action == "remove":
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
|
||||
index.delete()
|
||||
elif action == "add":
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
@@ -43,7 +43,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
|
||||
if dataset_documents:
|
||||
# save vector index
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
# delete from vector index
|
||||
@@ -65,7 +65,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
documents.append(document)
|
||||
|
||||
# save vector index
|
||||
index.add_texts(documents)
|
||||
index.create(documents)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
|
||||
@@ -39,4 +39,7 @@ XINFERENCE_SERVER_URL=
|
||||
XINFERENCE_MODEL_UID=
|
||||
|
||||
# OpenLLM Credentials
|
||||
OPENLLM_SERVER_URL=
|
||||
OPENLLM_SERVER_URL=
|
||||
|
||||
# LocalAI Credentials
|
||||
LOCALAI_SERVER_URL=
|
||||
@@ -0,0 +1,61 @@
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.localai_provider import LocalAIProvider
|
||||
from models.provider import Provider, ProviderType, ProviderModel
|
||||
|
||||
|
||||
def get_mock_provider():
|
||||
return Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='localai',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config='',
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
|
||||
def get_mock_embedding_model(mocker):
|
||||
model_name = 'text-embedding-ada-002'
|
||||
server_url = os.environ['LOCALAI_SERVER_URL']
|
||||
model_provider = LocalAIProvider(provider=get_mock_provider())
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
provider_name='localai',
|
||||
model_name=model_name,
|
||||
model_type=ModelType.EMBEDDINGS.value,
|
||||
encrypted_config=json.dumps({
|
||||
'server_url': server_url,
|
||||
}),
|
||||
is_valid=True,
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
return LocalAIEmbedding(
|
||||
model_provider=model_provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
return encrypted_api_key
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_embed_documents(mock_decrypt, mocker):
|
||||
embedding_model = get_mock_embedding_model(mocker)
|
||||
rst = embedding_model.client.embed_documents(['test', 'test1'])
|
||||
assert isinstance(rst, list)
|
||||
assert len(rst) == 2
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_embed_query(mock_decrypt, mocker):
|
||||
embedding_model = get_mock_embedding_model(mocker)
|
||||
rst = embedding_model.client.embed_query('test')
|
||||
assert isinstance(rst, list)
|
||||
68
api/tests/integration_tests/models/llm/test_localai_model.py
Normal file
68
api/tests/integration_tests/models/llm/test_localai_model.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from core.model_providers.models.llm.localai_model import LocalAIModel
|
||||
from core.model_providers.providers.localai_provider import LocalAIProvider
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
|
||||
from models.provider import Provider, ProviderType, ProviderModel
|
||||
|
||||
|
||||
def get_mock_provider(server_url):
|
||||
return Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='localai',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({}),
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
|
||||
def get_mock_model(model_name, mocker):
|
||||
model_kwargs = ModelKwargs(
|
||||
max_tokens=10,
|
||||
temperature=0
|
||||
)
|
||||
server_url = os.environ['LOCALAI_SERVER_URL']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
provider_name='localai',
|
||||
model_name=model_name,
|
||||
model_type=ModelType.TEXT_GENERATION.value,
|
||||
encrypted_config=json.dumps({'server_url': server_url, 'completion_type': 'completion'}),
|
||||
is_valid=True,
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
openai_provider = LocalAIProvider(provider=get_mock_provider(server_url))
|
||||
return LocalAIModel(
|
||||
model_provider=openai_provider,
|
||||
name=model_name,
|
||||
model_kwargs=model_kwargs
|
||||
)
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
|
||||
return encrypted_openai_api_key
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_num_tokens(mock_decrypt, mocker):
|
||||
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
|
||||
rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
|
||||
assert rst > 0
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
|
||||
rst = openai_model.run(
|
||||
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
|
||||
stop=['\nHuman:'],
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
@@ -63,7 +63,7 @@ def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_i
|
||||
|
||||
def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
|
||||
mocker.patch('langchain.llms.huggingface_endpoint.HuggingFaceEndpoint._call', return_value="abc")
|
||||
mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
@@ -71,8 +71,10 @@ def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
|
||||
credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
|
||||
)
|
||||
|
||||
|
||||
def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker):
|
||||
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
|
||||
mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
|
||||
116
api/tests/unit_tests/model_providers/test_localai_provider.py
Normal file
116
api/tests/unit_tests/model_providers/test_localai_provider.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.localai_provider import LocalAIProvider
|
||||
from models.provider import ProviderType, Provider, ProviderModel
|
||||
|
||||
PROVIDER_NAME = 'localai'
|
||||
MODEL_PROVIDER_CLASS = LocalAIProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'server_url': 'http://127.0.0.1:8080/'
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('langchain.embeddings.localai.LocalAIEmbeddings.embed_query',
|
||||
return_value="abc")
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='username/test_model_name',
|
||||
model_type=ModelType.EMBEDDINGS,
|
||||
credentials=VALIDATE_CREDENTIAL.copy()
|
||||
)
|
||||
|
||||
|
||||
def test_is_credentials_valid_or_raise_invalid():
|
||||
# raise CredentialsValidateFailedError if server_url is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.EMBEDDINGS,
|
||||
credentials={}
|
||||
)
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_model_credentials(mock_encrypt, mocker):
|
||||
server_url = 'http://127.0.0.1:8080/'
|
||||
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
|
||||
tenant_id='tenant_id',
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.EMBEDDINGS,
|
||||
credentials=VALIDATE_CREDENTIAL.copy()
|
||||
)
|
||||
mock_encrypt.assert_called_with('tenant_id', server_url)
|
||||
assert result['server_url'] == f'encrypted_{server_url}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_custom(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.EMBEDDINGS
|
||||
)
|
||||
assert result['server_url'] == 'http://127.0.0.1:8080/'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.EMBEDDINGS,
|
||||
obfuscated=True
|
||||
)
|
||||
middle_token = result['server_url'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
@@ -2,7 +2,7 @@ version: '3.1'
|
||||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:0.3.18
|
||||
image: langgenius/dify-api:0.3.19
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'api' starts the API server.
|
||||
@@ -124,7 +124,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:0.3.18
|
||||
image: langgenius/dify-api:0.3.19
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'worker' starts the Celery worker for processing the queue.
|
||||
@@ -176,7 +176,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:0.3.18
|
||||
image: langgenius/dify-web:0.3.19
|
||||
restart: always
|
||||
environment:
|
||||
EDITION: SELF_HOSTED
|
||||
|
||||
@@ -33,20 +33,18 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
|
||||
if (!response)
|
||||
return <Loading />
|
||||
|
||||
const handleError = (err: Error | null) => {
|
||||
if (!err) {
|
||||
notify({
|
||||
type: 'success',
|
||||
message: t('common.actionMsg.modifiedSuccessfully'),
|
||||
})
|
||||
const handleCallbackResult = (err: Error | null, message?: string) => {
|
||||
const type = err ? 'error' : 'success'
|
||||
|
||||
message ||= (type === 'success' ? 'modifiedSuccessfully' : 'modifiedUnsuccessfully')
|
||||
|
||||
if (type === 'success') {
|
||||
mutate(detailParams)
|
||||
}
|
||||
else {
|
||||
notify({
|
||||
type: 'error',
|
||||
message: t('common.actionMsg.modificationFailed'),
|
||||
})
|
||||
}
|
||||
notify({
|
||||
type,
|
||||
message: t(`common.actionMsg.${message}`),
|
||||
})
|
||||
}
|
||||
|
||||
const onChangeSiteStatus = async (value: boolean) => {
|
||||
@@ -56,7 +54,8 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
|
||||
body: { enable_site: value },
|
||||
}) as Promise<App>,
|
||||
)
|
||||
handleError(err)
|
||||
|
||||
handleCallbackResult(err)
|
||||
}
|
||||
|
||||
const onChangeApiStatus = async (value: boolean) => {
|
||||
@@ -66,7 +65,8 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
|
||||
body: { enable_api: value },
|
||||
}) as Promise<App>,
|
||||
)
|
||||
handleError(err)
|
||||
|
||||
handleCallbackResult(err)
|
||||
}
|
||||
|
||||
const onSaveSiteConfig: IAppCardProps['onSaveSiteConfig'] = async (params) => {
|
||||
@@ -79,7 +79,7 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
|
||||
if (!err)
|
||||
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
|
||||
|
||||
handleError(err)
|
||||
handleCallbackResult(err)
|
||||
}
|
||||
|
||||
const onGenerateCode = async () => {
|
||||
@@ -88,7 +88,8 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
|
||||
url: `/apps/${appId}/site/access-token-reset`,
|
||||
}) as Promise<UpdateAppSiteCodeResponse>,
|
||||
)
|
||||
handleError(err)
|
||||
|
||||
handleCallbackResult(err, err ? 'generatedUnsuccessfully' : 'generatedSuccessfully')
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
@@ -2,7 +2,7 @@ import React from 'react'
|
||||
import ChartView from './chartView'
|
||||
import CardView from './cardView'
|
||||
import { getLocaleOnServer } from '@/i18n/server'
|
||||
import { useTranslation } from '@/i18n/i18next-serverside-config'
|
||||
import { useTranslation as translate } from '@/i18n/i18next-serverside-config'
|
||||
import ApikeyInfoPanel from '@/app/components/app/overview/apikey-info-panel'
|
||||
|
||||
export type IDevelopProps = {
|
||||
@@ -13,7 +13,11 @@ const Overview = async ({
|
||||
params: { appId },
|
||||
}: IDevelopProps) => {
|
||||
const locale = getLocaleOnServer()
|
||||
const { t } = await useTranslation(locale, 'app-overview')
|
||||
/*
|
||||
rename useTranslation to avoid lint error
|
||||
please check: https://github.com/i18next/next-13-app-dir-i18next-example/issues/24
|
||||
*/
|
||||
const { t } = await translate(locale, 'app-overview')
|
||||
return (
|
||||
<div className="h-full px-16 py-6 overflow-scroll">
|
||||
<ApikeyInfoPanel />
|
||||
|
||||
@@ -9,6 +9,7 @@ import style from '../list.module.css'
|
||||
import AppModeLabel from './AppModeLabel'
|
||||
import s from './style.module.css'
|
||||
import SettingsModal from '@/app/components/app/overview/settings'
|
||||
import type { ConfigParams } from '@/app/components/app/overview/settings'
|
||||
import type { App } from '@/types/app'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
@@ -73,7 +74,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
||||
}
|
||||
|
||||
const onSaveSiteConfig = useCallback(
|
||||
async (params: any) => {
|
||||
async (params: ConfigParams) => {
|
||||
const [err] = await asyncRunSafe<App>(
|
||||
updateAppSiteConfig({
|
||||
url: `/apps/${app.id}/site`,
|
||||
@@ -92,7 +93,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
||||
else {
|
||||
notify({
|
||||
type: 'error',
|
||||
message: t('common.actionMsg.modificationFailed'),
|
||||
message: t('common.actionMsg.modifiedUnsuccessfully'),
|
||||
})
|
||||
}
|
||||
},
|
||||
@@ -100,12 +101,12 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
||||
)
|
||||
|
||||
const Operations = (props: any) => {
|
||||
const onClickSettings = async (e: any) => {
|
||||
const onClickSettings = async (e: React.MouseEvent<HTMLButtonElement>) => {
|
||||
props?.onClose()
|
||||
e.preventDefault()
|
||||
await getAppDetail()
|
||||
}
|
||||
const onClickDelete = async (e: any) => {
|
||||
const onClickDelete = async (e: React.MouseEvent<HTMLDivElement>) => {
|
||||
props?.onClose()
|
||||
e.preventDefault()
|
||||
setShowConfirmDelete(true)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import type { FC, SVGProps } from 'react'
|
||||
import React, { useEffect } from 'react'
|
||||
import { usePathname } from 'next/navigation'
|
||||
import useSWR from 'swr'
|
||||
@@ -57,7 +57,7 @@ const LikedItem: FC<{ type?: 'plugin' | 'app'; appStatus?: boolean; detail: Rela
|
||||
)
|
||||
}
|
||||
|
||||
const TargetIcon: FC<{ className?: string }> = ({ className }) => {
|
||||
const TargetIcon = ({ className }: SVGProps<SVGElement>) => {
|
||||
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
|
||||
<g clip-path="url(#clip0_4610_6951)">
|
||||
<path d="M10.6666 5.33325V3.33325L12.6666 1.33325L13.3332 2.66659L14.6666 3.33325L12.6666 5.33325H10.6666ZM10.6666 5.33325L7.9999 7.99988M14.6666 7.99992C14.6666 11.6818 11.6818 14.6666 7.99992 14.6666C4.31802 14.6666 1.33325 11.6818 1.33325 7.99992C1.33325 4.31802 4.31802 1.33325 7.99992 1.33325M11.3333 7.99992C11.3333 9.84087 9.84087 11.3333 7.99992 11.3333C6.15897 11.3333 4.66659 9.84087 4.66659 7.99992C4.66659 6.15897 6.15897 4.66659 7.99992 4.66659" stroke="#344054" strokeWidth="1.25" strokeLinecap="round" strokeLinejoin="round" />
|
||||
@@ -70,7 +70,7 @@ const TargetIcon: FC<{ className?: string }> = ({ className }) => {
|
||||
</svg>
|
||||
}
|
||||
|
||||
const TargetSolidIcon: FC<{ className?: string }> = ({ className }) => {
|
||||
const TargetSolidIcon = ({ className }: SVGProps<SVGElement>) => {
|
||||
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
|
||||
<path fillRule="evenodd" clipRule="evenodd" d="M12.7733 0.67512C12.9848 0.709447 13.1669 0.843364 13.2627 1.03504L13.83 2.16961L14.9646 2.73689C15.1563 2.83273 15.2902 3.01486 15.3245 3.22639C15.3588 3.43792 15.2894 3.65305 15.1379 3.80458L13.1379 5.80458C13.0128 5.92961 12.8433 5.99985 12.6665 5.99985H10.9426L8.47124 8.47124C8.21089 8.73159 7.78878 8.73159 7.52843 8.47124C7.26808 8.21089 7.26808 7.78878 7.52843 7.52843L9.9998 5.05707V3.33318C9.9998 3.15637 10.07 2.9868 10.1951 2.86177L12.1951 0.861774C12.3466 0.710244 12.5617 0.640794 12.7733 0.67512Z" fill="#155EEF" />
|
||||
<path d="M1.99984 7.99984C1.99984 4.68613 4.68613 1.99984 7.99984 1.99984C8.36803 1.99984 8.6665 1.70136 8.6665 1.33317C8.6665 0.964981 8.36803 0.666504 7.99984 0.666504C3.94975 0.666504 0.666504 3.94975 0.666504 7.99984C0.666504 12.0499 3.94975 15.3332 7.99984 15.3332C12.0499 15.3332 15.3332 12.0499 15.3332 7.99984C15.3332 7.63165 15.0347 7.33317 14.6665 7.33317C14.2983 7.33317 13.9998 7.63165 13.9998 7.99984C13.9998 11.3135 11.3135 13.9998 7.99984 13.9998C4.68613 13.9998 1.99984 11.3135 1.99984 7.99984Z" fill="#155EEF" />
|
||||
@@ -78,7 +78,7 @@ const TargetSolidIcon: FC<{ className?: string }> = ({ className }) => {
|
||||
</svg>
|
||||
}
|
||||
|
||||
const BookOpenIcon: FC<{ className?: string }> = ({ className }) => {
|
||||
const BookOpenIcon = ({ className }: SVGProps<SVGElement>) => {
|
||||
return <svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
|
||||
<path opacity="0.12" d="M1 3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7V10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1Z" fill="#155EEF" />
|
||||
<path d="M6 10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7M6 10.5V4.7M6 10.5L6.05003 10.425C6.39735 9.90398 6.57101 9.64349 6.80045 9.45491C7.00357 9.28796 7.23762 9.1627 7.4892 9.0863C7.77337 9 8.08645 9 8.71259 9H9.4C9.96005 9 10.2401 9 10.454 8.89101C10.6422 8.79513 10.7951 8.64215 10.891 8.45399C11 8.24008 11 7.96005 11 7.4V3.1C11 2.53995 11 2.25992 10.891 2.04601C10.7951 1.85785 10.6422 1.70487 10.454 1.60899C10.2401 1.5 9.96005 1.5 9.4 1.5H9.2C8.07989 1.5 7.51984 1.5 7.09202 1.71799C6.71569 1.90973 6.40973 2.21569 6.21799 2.59202C6 3.01984 6 3.5799 6 4.7" stroke="#155EEF" strokeLinecap="round" strokeLinejoin="round" />
|
||||
|
||||
@@ -6,9 +6,11 @@ const Layout: FC<{
|
||||
children: React.ReactNode
|
||||
}> = ({ children }) => {
|
||||
return (
|
||||
<div className="min-w-[300px]">
|
||||
<GA gaType={GaType.webapp} />
|
||||
{children}
|
||||
<div className=''>
|
||||
<div className="min-w-[300px]">
|
||||
<GA gaType={GaType.webapp} />
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import React from 'react'
|
||||
import type { FC } from 'react'
|
||||
import NavLink from './navLink'
|
||||
import AppBasic from './basic'
|
||||
|
||||
import type { NavIcon } from './navLink'
|
||||
|
||||
export type IAppDetailNavProps = {
|
||||
iconType?: 'app' | 'dataset' | 'notion'
|
||||
title: string
|
||||
@@ -12,13 +13,13 @@ export type IAppDetailNavProps = {
|
||||
navigation: Array<{
|
||||
name: string
|
||||
href: string
|
||||
icon: any
|
||||
selectedIcon: any
|
||||
icon: NavIcon
|
||||
selectedIcon: NavIcon
|
||||
}>
|
||||
extraInfo?: React.ReactNode
|
||||
}
|
||||
|
||||
const AppDetailNav: FC<IAppDetailNavProps> = ({ title, desc, icon, icon_background, navigation, extraInfo, iconType = 'app' }) => {
|
||||
const AppDetailNav = ({ title, desc, icon, icon_background, navigation, extraInfo, iconType = 'app' }: IAppDetailNavProps) => {
|
||||
return (
|
||||
<div className="flex flex-col w-56 overflow-y-auto bg-white border-r border-gray-200 shrink-0">
|
||||
<div className="flex flex-shrink-0 p-4">
|
||||
|
||||
@@ -1,17 +1,30 @@
|
||||
'use client'
|
||||
|
||||
import { useSelectedLayoutSegment } from 'next/navigation'
|
||||
import classNames from 'classnames'
|
||||
import Link from 'next/link'
|
||||
|
||||
export type NavIcon = React.ComponentType<
|
||||
React.PropsWithoutRef<React.ComponentProps<'svg'>> & {
|
||||
title?: string | undefined
|
||||
titleId?: string | undefined
|
||||
}
|
||||
>
|
||||
|
||||
export type NavLinkProps = {
|
||||
name: string
|
||||
href: string
|
||||
iconMap: {
|
||||
selected: NavIcon
|
||||
normal: NavIcon
|
||||
}
|
||||
}
|
||||
|
||||
export default function NavLink({
|
||||
name,
|
||||
href,
|
||||
iconMap,
|
||||
}: {
|
||||
name: string
|
||||
href: string
|
||||
iconMap: { selected: any; normal: any }
|
||||
}) {
|
||||
}: NavLinkProps) {
|
||||
const segment = useSelectedLayoutSegment()
|
||||
const isActive = href.toLowerCase().split('/')?.pop() === segment?.toLowerCase()
|
||||
const NavIcon = isActive ? iconMap.selected : iconMap.normal
|
||||
|
||||
@@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { UserCircleIcon } from '@heroicons/react/24/solid'
|
||||
import cn from 'classnames'
|
||||
import type { DisplayScene, FeedbackFunc, Feedbacktype, IChatItem, SubmitAnnotationFunc, ThoughtItem } from '../type'
|
||||
import type { CitationItem, DisplayScene, FeedbackFunc, Feedbacktype, IChatItem, SubmitAnnotationFunc, ThoughtItem } from '../type'
|
||||
import OperationBtn from '../operation'
|
||||
import LoadingAnim from '../loading-anim'
|
||||
import { EditIcon, EditIconSolid, OpeningStatementIcon, RatingIcon } from '../icon-component'
|
||||
@@ -13,6 +13,7 @@ import s from '../style.module.css'
|
||||
import MoreInfo from '../more-info'
|
||||
import CopyBtn from '../copy-btn'
|
||||
import Thought from '../thought'
|
||||
import Citation from '../citation'
|
||||
import { randomString } from '@/utils'
|
||||
import type { Annotation, MessageRating } from '@/models/log'
|
||||
import AppContext from '@/context/app-context'
|
||||
@@ -45,11 +46,14 @@ export type IAnswerProps = {
|
||||
isResponsing?: boolean
|
||||
answerIconClassName?: string
|
||||
thoughts?: ThoughtItem[]
|
||||
citation?: CitationItem[]
|
||||
isThinking?: boolean
|
||||
dataSets?: DataSet[]
|
||||
isShowCitation?: boolean
|
||||
isShowCitationHitInfo?: boolean
|
||||
}
|
||||
// The component needs to maintain its own state to control whether to display input component
|
||||
const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedbackEdit = false, onFeedback, onSubmitAnnotation, displayScene = 'web', isResponsing, answerIconClassName, thoughts, isThinking, dataSets }) => {
|
||||
const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedbackEdit = false, onFeedback, onSubmitAnnotation, displayScene = 'web', isResponsing, answerIconClassName, thoughts, citation, isThinking, dataSets, isShowCitation, isShowCitationHitInfo = false }) => {
|
||||
const { id, content, more, feedback, adminFeedback, annotation: initAnnotation } = item
|
||||
const [showEdit, setShowEdit] = useState(false)
|
||||
const [loading, setLoading] = useState(false)
|
||||
@@ -237,6 +241,11 @@ const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedba
|
||||
</div>
|
||||
</>
|
||||
}
|
||||
{
|
||||
!!citation?.length && !isThinking && isShowCitation && !isResponsing && (
|
||||
<Citation data={citation} showHitInfo={isShowCitationHitInfo} />
|
||||
)
|
||||
}
|
||||
</div>
|
||||
<div className='absolute top-[-14px] right-[-14px] flex flex-row justify-end gap-1'>
|
||||
{!item.isOpeningStatement && (
|
||||
|
||||
65
web/app/components/app/chat/citation/index.tsx
Normal file
65
web/app/components/app/chat/citation/index.tsx
Normal file
@@ -0,0 +1,65 @@
|
||||
import { useMemo } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { CitationItem } from '../type'
|
||||
import Popup from './popup'
|
||||
|
||||
/** All citations of one answer that come from the same document. */
export type Resources = {
  documentId: string
  documentName: string
  dataSourceType: string
  sources: CitationItem[]
}

type CitationProps = {
  data: CitationItem[]
  showHitInfo?: boolean
}

/**
 * Renders the citation section of an answer: a title row followed by one
 * `Popup` per cited document.
 *
 * @param data - Flat list of cited segments; segments sharing a `document_id`
 *   are shown under a single document entry.
 * @param showHitInfo - Forwarded unchanged to every `Popup`.
 */
const Citation: FC<CitationProps> = ({
  data,
  showHitInfo,
}) => {
  const { t } = useTranslation()
  const resources = useMemo(() => {
    // A Map iterates in insertion order, so documents keep the order in which
    // they first appear in `data` while lookups stay O(1) per item.
    const byDocument = new Map<string, Resources>()

    for (const item of data) {
      const group = byDocument.get(item.document_id)

      if (group) {
        group.sources.push(item)
      }
      else {
        byDocument.set(item.document_id, {
          documentId: item.document_id,
          documentName: item.document_name,
          dataSourceType: item.data_source_type,
          sources: [item],
        })
      }
    }

    return Array.from(byDocument.values())
  }, [data])

  return (
    <div className='mt-3'>
      <div className='flex items-center mb-2 text-xs font-medium text-gray-500'>
        {t('common.chat.citation.title')}
        <div className='grow ml-2 h-[1px] bg-black/5' />
      </div>
      <div className='flex'>
        {
          resources.map(res => (
            <Popup
              // documentId is unique per group, so it is a stable key.
              key={res.documentId}
              data={res}
              showHitInfo={showHitInfo}
            />
          ))
        }
      </div>
    </div>
  )
}
|
||||
|
||||
export default Citation
|
||||
113
web/app/components/app/chat/citation/popup.tsx
Normal file
113
web/app/components/app/chat/citation/popup.tsx
Normal file
@@ -0,0 +1,113 @@
|
||||
import { Fragment, useState } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import Link from 'next/link'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Tooltip from './tooltip'
|
||||
import ProgressTooltip from './progress-tooltip'
|
||||
import type { Resources } from './index'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import FileIcon from '@/app/components/base/file-icon'
|
||||
import { Hash02 } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import { ArrowUpRight } from '@/app/components/base/icons/src/vender/line/arrows'
|
||||
|
||||
type PopupProps = {
  data: Resources
  showHitInfo?: boolean
}

/**
 * Chip for one cited document. Clicking the chip toggles a popup that lists
 * every cited segment of that document.
 *
 * @param data - The document and its cited segments.
 * @param showHitInfo - When true, each segment additionally shows its
 *   position, a link to the document inside its dataset, and the retrieval
 *   statistics row. Defaults to false.
 */
const Popup: FC<PopupProps> = ({
  data,
  showHitInfo = false,
}) => {
  const { t } = useTranslation()
  const [open, setOpen] = useState(false)
  // For uploaded files the icon type is the text after the last dot of the
  // document name (empty when there is no dot); any other source type is
  // passed to FileIcon as 'notion'.
  const fileType = data.dataSourceType === 'upload_file'
    ? (/\.([^.]*)$/.exec(data.documentName)?.[1] ?? '')
    : 'notion'

  return (
    <PortalToFollowElem
      open={open}
      onOpenChange={setOpen}
      placement='top-start'
      offset={{
        mainAxis: 8,
        crossAxis: -2,
      }}
    >
      <PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
        <div className='flex items-center mr-1 px-2 max-w-[240px] h-7 bg-white rounded-lg'>
          <FileIcon type={fileType} className='mr-1 w-4 h-4' />
          <div className='text-xs text-gray-600 truncate'>{data.documentName}</div>
        </div>
      </PortalToFollowElemTrigger>
      <PortalToFollowElemContent>
        <div className='w-[360px] bg-gray-50 rounded-xl shadow-lg'>
          <div className='px-4 pt-3 pb-2'>
            <div className='flex items-center h-[18px]'>
              <FileIcon type={fileType} className='mr-1 w-4 h-4' />
              <div className='text-xs font-medium text-gray-600 truncate'>{data.documentName}</div>
            </div>
          </div>
          <div className='px-4 py-0.5 max-h-[450px] bg-white rounded-lg overflow-auto'>
            {
              data.sources.map((source, index) => (
                <Fragment key={index}>
                  <div className='group py-3'>
                    {
                      showHitInfo && (
                        <div className='flex items-center justify-between mb-2'>
                          <div className='flex items-center px-1.5 h-5 border border-gray-200 rounded-md'>
                            <Hash02 className='mr-0.5 w-3 h-3 text-gray-400' />
                            <div className='text-[11px] font-medium text-gray-500'>{source.segment_position}</div>
                          </div>
                          {/* NOTE(review): the link label below is hard-coded English while the other labels here go through t(). */}
                          <Link
                            href={`/datasets/${source.dataset_id}/documents/${source.document_id}`}
                            className='hidden items-center h-[18px] text-xs text-primary-600 group-hover:flex'>
                            Link to dataset
                            <ArrowUpRight className='ml-1 w-3 h-3' />
                          </Link>
                        </div>
                      )
                    }
                    <div className='text-[13px] text-gray-800'>{source.content}</div>
                    {
                      showHitInfo && (
                        <div className='flex items-center mt-2 text-xs font-medium text-gray-500'>
                          <Tooltip
                            text={t('common.chat.citation.characters')}
                            data={source.word_count}
                          />
                          <Tooltip
                            text={t('common.chat.citation.hitCount')}
                            data={source.hit_count}
                          />
                          {/* Guarded: presumably this data arrives from the server at runtime (TODO confirm), and a missing hash or score must not crash the render. */}
                          <Tooltip
                            text={t('common.chat.citation.vectorHash')}
                            data={source.index_node_hash?.substring(0, 7) ?? ''}
                          />
                          {
                            typeof source.score === 'number' && (
                              <ProgressTooltip data={Number(source.score.toFixed(2))} />
                            )
                          }
                        </div>
                      )
                    }
                  </div>
                  {
                    // Divider between segments, omitted after the last one.
                    index !== data.sources.length - 1 && (
                      <div className='my-1 h-[1px] bg-black/5' />
                    )
                  }
                </Fragment>
              ))
            }
          </div>
        </div>
      </PortalToFollowElemContent>
    </PortalToFollowElem>
  )
}
|
||||
|
||||
export default Popup
|
||||
46
web/app/components/app/chat/citation/progress-tooltip.tsx
Normal file
46
web/app/components/app/chat/citation/progress-tooltip.tsx
Normal file
@@ -0,0 +1,46 @@
|
||||
import { useState } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
|
||||
type ProgressTooltipProps = {
  data: number
}

/**
 * Small bar with the numeric value next to it; hovering the pair opens a
 * popup showing the translated hit-score label followed by the value.
 *
 * @param data - The value to display; the bar is filled to `data * 100` percent.
 */
const ProgressTooltip: FC<ProgressTooltipProps> = ({ data }) => {
  const { t } = useTranslation()
  const [open, setOpen] = useState(false)

  // Hover handlers for the trigger element.
  const show = () => setOpen(true)
  const hide = () => setOpen(false)

  // Inline style for the filled part of the bar.
  const fillStyle = { width: `${data * 100}%` }

  const trigger = (
    <PortalToFollowElemTrigger
      onMouseEnter={show}
      onMouseLeave={hide}
    >
      <div className='grow flex items-center'>
        <div className='mr-1 w-16 h-1.5 rounded-[3px] border border-gray-400 overflow-hidden'>
          <div className='bg-gray-400 h-full' style={fillStyle}></div>
        </div>
        {data}
      </div>
    </PortalToFollowElemTrigger>
  )

  const content = (
    <PortalToFollowElemContent>
      <div className='p-3 bg-white text-xs font-medium text-gray-500 rounded-lg shadow-lg'>
        {t('common.chat.citation.hitScore')} {data}
      </div>
    </PortalToFollowElemContent>
  )

  return (
    <PortalToFollowElem
      open={open}
      onOpenChange={setOpen}
      placement='top-start'
    >
      {trigger}
      {content}
    </PortalToFollowElem>
  )
}
|
||||
|
||||
export default ProgressTooltip
|
||||
45
web/app/components/app/chat/citation/tooltip.tsx
Normal file
45
web/app/components/app/chat/citation/tooltip.tsx
Normal file
@@ -0,0 +1,45 @@
|
||||
import { useState } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import { TypeSquare } from '@/app/components/base/icons/src/vender/line/editor'
|
||||
|
||||
type TooltipProps = {
  data: number | string
  text: string
}

/**
 * Icon followed by a value; hovering the pair opens a popup that shows the
 * label and the value together.
 *
 * @param data - The value rendered next to the icon and repeated in the popup.
 * @param text - The label shown before the value inside the popup.
 */
const Tooltip: FC<TooltipProps> = ({ data, text }) => {
  const [open, setOpen] = useState(false)

  // Hover handlers for the trigger element.
  const show = () => setOpen(true)
  const hide = () => setOpen(false)

  const trigger = (
    <PortalToFollowElemTrigger
      onMouseEnter={show}
      onMouseLeave={hide}
    >
      <div className='flex items-center mr-6'>
        <TypeSquare className='mr-1 w-3 h-3' />
        {data}
      </div>
    </PortalToFollowElemTrigger>
  )

  const content = (
    <PortalToFollowElemContent>
      <div className='p-3 bg-white text-xs font-medium text-gray-500 rounded-lg shadow-lg'>
        {text} {data}
      </div>
    </PortalToFollowElemContent>
  )

  return (
    <PortalToFollowElem
      open={open}
      onOpenChange={setOpen}
      placement='top-start'
    >
      {trigger}
      {content}
    </PortalToFollowElem>
  )
}
|
||||
|
||||
export default Tooltip
|
||||
@@ -8,32 +8,36 @@ import Tooltip from '@/app/components/base/tooltip'
|
||||
type ICopyBtnProps = {
|
||||
value: string
|
||||
className?: string
|
||||
isPlain?: boolean
|
||||
}
|
||||
|
||||
const CopyBtn = ({
|
||||
value,
|
||||
className,
|
||||
isPlain,
|
||||
}: ICopyBtnProps) => {
|
||||
const [isCopied, setIsCopied] = React.useState(false)
|
||||
|
||||
return (
|
||||
<div className={`${className}`}>
|
||||
<Tooltip
|
||||
selector="copy-btn-tooltip"
|
||||
selector={`copy-btn-tooltip-${value}`}
|
||||
content={(isCopied ? t('appApi.copied') : t('appApi.copy')) as string}
|
||||
className='z-10'
|
||||
>
|
||||
<div
|
||||
className={'box-border p-0.5 flex items-center justify-center rounded-md bg-white cursor-pointer'}
|
||||
style={{
|
||||
boxShadow: '0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06)',
|
||||
}}
|
||||
style={!isPlain
|
||||
? {
|
||||
boxShadow: '0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06)',
|
||||
}
|
||||
: {}}
|
||||
onClick={() => {
|
||||
copy(value)
|
||||
setIsCopied(true)
|
||||
}}
|
||||
>
|
||||
<div className={`w-6 h-6 hover:bg-gray-50 ${s.copyIcon} ${isCopied ? s.copied : ''}`}></div>
|
||||
<div className={`w-6 h-6 rounded-md hover:bg-gray-50 ${s.copyIcon} ${isCopied ? s.copied : ''}`}></div>
|
||||
</div>
|
||||
</Tooltip>
|
||||
</div>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { FC } from 'react'
|
||||
import type { FC, SVGProps } from 'react'
|
||||
import { HandThumbDownIcon, HandThumbUpIcon } from '@heroicons/react/24/outline'
|
||||
|
||||
export const stopIcon = (
|
||||
@@ -7,7 +7,7 @@ export const stopIcon = (
|
||||
</svg>
|
||||
)
|
||||
|
||||
export const OpeningStatementIcon: FC<{ className?: string }> = ({ className }) => (
|
||||
export const OpeningStatementIcon = ({ className }: SVGProps<SVGElement>) => (
|
||||
<svg className={className} width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fillRule="evenodd" clipRule="evenodd" d="M6.25002 1C3.62667 1 1.50002 3.12665 1.50002 5.75C1.50002 6.28 1.58702 6.79071 1.7479 7.26801C1.7762 7.35196 1.79285 7.40164 1.80368 7.43828L1.80722 7.45061L1.80535 7.45452C1.79249 7.48102 1.77339 7.51661 1.73766 7.58274L0.911727 9.11152C0.860537 9.20622 0.807123 9.30503 0.770392 9.39095C0.733879 9.47635 0.674738 9.63304 0.703838 9.81878C0.737949 10.0365 0.866092 10.2282 1.05423 10.343C1.21474 10.4409 1.38213 10.4461 1.475 10.4451C1.56844 10.444 1.68015 10.4324 1.78723 10.4213L4.36472 10.1549C4.406 10.1506 4.42758 10.1484 4.44339 10.1472L4.44542 10.147L4.45161 10.1492C4.47103 10.1562 4.49738 10.1663 4.54285 10.1838C5.07332 10.3882 5.64921 10.5 6.25002 10.5C8.87338 10.5 11 8.37335 11 5.75C11 3.12665 8.87338 1 6.25002 1ZM4.48481 4.29111C5.04844 3.81548 5.7986 3.9552 6.24846 4.47463C6.69831 3.9552 7.43879 3.82048 8.01211 4.29111C8.58544 4.76175 8.6551 5.562 8.21247 6.12453C7.93825 6.47305 7.24997 7.10957 6.76594 7.54348C6.58814 7.70286 6.49924 7.78255 6.39255 7.81466C6.30103 7.84221 6.19589 7.84221 6.10436 7.81466C5.99767 7.78255 5.90878 7.70286 5.73098 7.54348C5.24694 7.10957 4.55867 6.47305 4.28444 6.12453C3.84182 5.562 3.92117 4.76675 4.48481 4.29111Z" fill="#667085" />
|
||||
</svg>
|
||||
@@ -17,13 +17,13 @@ export const RatingIcon: FC<{ isLike: boolean }> = ({ isLike }) => {
|
||||
return isLike ? <HandThumbUpIcon className='w-4 h-4' /> : <HandThumbDownIcon className='w-4 h-4' />
|
||||
}
|
||||
|
||||
export const EditIcon: FC<{ className?: string }> = ({ className }) => {
|
||||
export const EditIcon = ({ className }: SVGProps<SVGElement>) => {
|
||||
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className}>
|
||||
<path d="M14 11.9998L13.3332 12.7292C12.9796 13.1159 12.5001 13.3332 12.0001 13.3332C11.5001 13.3332 11.0205 13.1159 10.6669 12.7292C10.3128 12.3432 9.83332 12.1265 9.33345 12.1265C8.83359 12.1265 8.35409 12.3432 7.99998 12.7292M2 13.3332H3.11636C3.44248 13.3332 3.60554 13.3332 3.75899 13.2963C3.89504 13.2637 4.0251 13.2098 4.1444 13.1367C4.27895 13.0542 4.39425 12.9389 4.62486 12.7083L13 4.33316C13.5523 3.78087 13.5523 2.88544 13 2.33316C12.4477 1.78087 11.5523 1.78087 11 2.33316L2.62484 10.7083C2.39424 10.9389 2.27894 11.0542 2.19648 11.1888C2.12338 11.3081 2.0695 11.4381 2.03684 11.5742C2 11.7276 2 11.8907 2 12.2168V13.3332Z" stroke="#6B7280" strokeLinecap="round" strokeLinejoin="round" />
|
||||
</svg>
|
||||
}
|
||||
|
||||
export const EditIconSolid: FC<{ className?: string }> = ({ className }) => {
|
||||
export const EditIconSolid = ({ className }: SVGProps<SVGElement>) => {
|
||||
return <svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg" className={className}>
|
||||
<path fillRule="evenodd" clipRule="evenodd" d="M10.8374 8.63108C11.0412 8.81739 11.0554 9.13366 10.8691 9.33747L10.369 9.88449C10.0142 10.2725 9.52293 10.5001 9.00011 10.5001C8.47746 10.5001 7.98634 10.2727 7.63157 9.8849C7.45561 9.69325 7.22747 9.59515 7.00014 9.59515C6.77271 9.59515 6.54446 9.69335 6.36846 9.88517C6.18177 10.0886 5.86548 10.1023 5.66201 9.91556C5.45853 9.72888 5.44493 9.41259 5.63161 9.20911C5.98678 8.82201 6.47777 8.59515 7.00014 8.59515C7.52251 8.59515 8.0135 8.82201 8.36867 9.20911L8.36924 9.20974C8.54486 9.4018 8.77291 9.50012 9.00011 9.50012C9.2273 9.50012 9.45533 9.40182 9.63095 9.20979L10.131 8.66276C10.3173 8.45895 10.6336 8.44476 10.8374 8.63108Z" fill="#6B7280" />
|
||||
<path fillRule="evenodd" clipRule="evenodd" d="M7.89651 1.39656C8.50599 0.787085 9.49414 0.787084 10.1036 1.39656C10.7131 2.00604 10.7131 2.99419 10.1036 3.60367L3.82225 9.88504C3.81235 9.89494 3.80254 9.90476 3.79281 9.91451C3.64909 10.0585 3.52237 10.1855 3.3696 10.2791C3.23539 10.3613 3.08907 10.4219 2.93602 10.4587C2.7618 10.5005 2.58242 10.5003 2.37897 10.5001C2.3652 10.5001 2.35132 10.5001 2.33732 10.5001H1.50005C1.22391 10.5001 1.00005 10.2763 1.00005 10.0001V9.16286C1.00005 9.14886 1.00004 9.13497 1.00003 9.1212C0.999836 8.91776 0.999669 8.73838 1.0415 8.56416C1.07824 8.4111 1.13885 8.26479 1.22109 8.13058C1.31471 7.97781 1.44166 7.85109 1.58566 7.70736C1.5954 7.69764 1.60523 7.68783 1.61513 7.67793L7.89651 1.39656Z" fill="#6B7280" />
|
||||
|
||||
@@ -24,6 +24,7 @@ import type { DataSet } from '@/models/datasets'
|
||||
export type IChatProps = {
|
||||
configElem?: React.ReactNode
|
||||
chatList: IChatItem[]
|
||||
controlChatUpdateAllConversation?: number
|
||||
/**
|
||||
* Whether to display the editing area and rating status
|
||||
*/
|
||||
@@ -47,14 +48,17 @@ export type IChatProps = {
|
||||
isShowSuggestion?: boolean
|
||||
suggestionList?: string[]
|
||||
isShowSpeechToText?: boolean
|
||||
isShowCitation?: boolean
|
||||
answerIconClassName?: string
|
||||
isShowConfigElem?: boolean
|
||||
dataSets?: DataSet[]
|
||||
isShowCitationHitInfo?: boolean
|
||||
}
|
||||
|
||||
const Chat: FC<IChatProps> = ({
|
||||
configElem,
|
||||
chatList,
|
||||
|
||||
feedbackDisabled = false,
|
||||
isHideFeedbackEdit = false,
|
||||
isHideSendInput = false,
|
||||
@@ -72,9 +76,11 @@ const Chat: FC<IChatProps> = ({
|
||||
isShowSuggestion,
|
||||
suggestionList,
|
||||
isShowSpeechToText,
|
||||
isShowCitation,
|
||||
answerIconClassName,
|
||||
isShowConfigElem,
|
||||
dataSets,
|
||||
isShowCitationHitInfo,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
@@ -160,6 +166,7 @@ const Chat: FC<IChatProps> = ({
|
||||
if (item.isAnswer) {
|
||||
const isLast = item.id === chatList[chatList.length - 1].id
|
||||
const thoughts = item.agent_thoughts?.filter(item => item.thought !== '[DONE]')
|
||||
const citation = item.citation
|
||||
const isThinking = !item.content && item.agent_thoughts && item.agent_thoughts?.length > 0 && !item.agent_thoughts.some(item => item.thought === '[DONE]')
|
||||
return <Answer
|
||||
key={item.id}
|
||||
@@ -172,8 +179,11 @@ const Chat: FC<IChatProps> = ({
|
||||
isResponsing={isResponsing && isLast}
|
||||
answerIconClassName={answerIconClassName}
|
||||
thoughts={thoughts}
|
||||
citation={citation}
|
||||
isThinking={isThinking}
|
||||
dataSets={dataSets}
|
||||
isShowCitation={isShowCitation}
|
||||
isShowCitationHitInfo={isShowCitationHitInfo}
|
||||
/>
|
||||
}
|
||||
return <Question key={item.id} id={item.id} content={item.content} more={item.more} useCurrentUserAvatar={useCurrentUserAvatar} />
|
||||
|
||||
@@ -37,12 +37,13 @@ const Thought: FC<IThoughtProps> = ({
|
||||
const getThoughtText = (item: ThoughtItem) => {
|
||||
try {
|
||||
const input = JSON.parse(item.tool_input)
|
||||
|
||||
// dataset
|
||||
if (item.tool.startsWith('dataset-')) {
|
||||
const dataSetId = item.tool.replace('dataset-', '')
|
||||
const datasetName = dataSets?.find(item => item.id === dataSetId)?.name || 'unknown dataset'
|
||||
return t('explore.universalChat.thought.res.dataset').replace('{datasetName}', `<span class="text-gray-700">${datasetName}</span>`)
|
||||
}
|
||||
switch (item.tool) {
|
||||
case 'dataset':
|
||||
// eslint-disable-next-line no-case-declarations
|
||||
const datasetName = dataSets?.find(item => item.id === input.dataset_id)?.name || 'unknown dataset'
|
||||
return t('explore.universalChat.thought.res.dataset').replace('{datasetName}', `<span class="text-gray-700">${datasetName}</span>`)
|
||||
case 'web_reader':
|
||||
return t(`explore.universalChat.thought.res.webReader.${!input.cursor ? 'normal' : 'hasPageInfo'}`).replace('{url}', `<a href="${input.url}" class="text-[#155EEF]">${input.url}</a>`)
|
||||
case 'google_search':
|
||||
|
||||
@@ -23,10 +23,26 @@ export type ThoughtItem = {
|
||||
tool_input: string
|
||||
message_id: string
|
||||
}
|
||||
export type CitationItem = {
|
||||
content: string
|
||||
data_source_type: string
|
||||
dataset_name: string
|
||||
dataset_id: string
|
||||
document_id: string
|
||||
document_name: string
|
||||
hit_count: number
|
||||
index_node_hash: string
|
||||
segment_id: string
|
||||
segment_position: number
|
||||
score: number
|
||||
word_count: number
|
||||
}
|
||||
|
||||
export type IChatItem = {
|
||||
id: string
|
||||
content: string
|
||||
agent_thoughts?: ThoughtItem[]
|
||||
citation?: CitationItem[]
|
||||
/**
|
||||
* Specific message type
|
||||
*/
|
||||
@@ -51,3 +67,8 @@ export type IChatItem = {
|
||||
useCurrentUserAvatar?: boolean
|
||||
isOpeningStatement?: boolean
|
||||
}
|
||||
|
||||
export type MessageEnd = {
|
||||
id: string
|
||||
retriever_resources?: CitationItem[]
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import OperationBtn from '../base/operation-btn'
|
||||
import VarIcon from '../base/icons/var-icon'
|
||||
import EditModel from './config-model'
|
||||
import IconTypeIcon from './input-type-icon'
|
||||
import type { IInputTypeIconProps } from './input-type-icon'
|
||||
import s from './style.module.css'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import type { PromptVariable } from '@/models/debug'
|
||||
@@ -37,8 +38,8 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
|
||||
return obj
|
||||
})()
|
||||
|
||||
const updatePromptVariable = (key: string, updateKey: string, newValue: any) => {
|
||||
const newPromptVariables = promptVariables.map((item, i) => {
|
||||
const updatePromptVariable = (key: string, updateKey: string, newValue: string | boolean) => {
|
||||
const newPromptVariables = promptVariables.map((item) => {
|
||||
if (item.key === key) {
|
||||
return {
|
||||
...item,
|
||||
@@ -179,7 +180,7 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
|
||||
<tr key={index} className="h-9 leading-9">
|
||||
<td className="w-[160px] border-b border-gray-100 pl-3">
|
||||
<div className='flex items-center space-x-1'>
|
||||
<IconTypeIcon type={type} />
|
||||
<IconTypeIcon type={type as IInputTypeIconProps['type']} />
|
||||
{!readonly
|
||||
? (
|
||||
<input
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import React from 'react'
|
||||
import type { FC } from 'react'
|
||||
|
||||
type IInputTypeIconProps = {
|
||||
export type IInputTypeIconProps = {
|
||||
type: 'string' | 'select'
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 175 KiB |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user