Compare commits

...

69 Commits

Author SHA1 Message Date
Garfield Dai
b8db580833 feat: hugging face supports embeddings. 2023-09-19 21:08:37 +08:00
StyleZhang
9d3ba98d07 fix: frontend huggingface embedding hide task type 2023-09-19 11:02:54 +08:00
StyleZhang
8b2573efab feat: frontend add huggingface embedding support 2023-09-19 10:46:32 +08:00
Garfield Dai
757ee4d39f feat: hugging face supports embeddings. 2023-09-19 10:34:19 +08:00
takatost
435f804c6f fix: gpt-3.5-turbo-instruct context size to 8192 (#1196) 2023-09-19 02:10:22 +08:00
takatost
ae3f1ac0a9 feat: support gpt-3.5-turbo-instruct model (#1195) 2023-09-19 02:05:04 +08:00
Jyong
269a465fc4 Feat/improve vector database logic (#1193)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-18 18:15:41 +08:00
zxhlyh
60e0bbd713 Feat/provider add zhipuai (#1192)
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-09-18 18:02:05 +08:00
takatost
827c97f0d3 feat: add zhipuai (#1188) 2023-09-18 17:32:31 +08:00
takatost
c8bd76cd66 fix: inference embedding validate (#1187) 2023-09-16 03:09:36 +08:00
crazywoola
ec5f585df4 1111 wrong embedding model displayed in datasets (#1186) 2023-09-15 07:54:45 -05:00
Rhon Joe
1de48f33ca feat(web): service request return generics type (#1157) 2023-09-15 07:54:20 -05:00
Joel
6b41a9593e fix: text error (#1184) 2023-09-15 14:15:28 +08:00
Joel
82267083e8 fix: model param description error (#1183) 2023-09-15 11:36:01 +08:00
Joel
c385961d33 chore: Optimization model parameter description (#1181) 2023-09-15 11:14:14 +08:00
charli117
20bab6edec Restore the application template (#1174)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-09-14 08:28:32 -05:00
charli117
67bed54f32 Mermaid front end rendering (#1166)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-09-14 14:09:23 +08:00
leo
562a571281 fix: Improved fallback solution for avatar image loading failure (#1172) 2023-09-14 13:31:35 +08:00
Matri
fc68c81791 fix: correct invite url (#1173) 2023-09-14 12:07:34 +08:00
Jyong
5d9070bc60 Feat/add blocking mode resource return (#1171)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-13 18:53:35 +08:00
crazywoola
b11fb0dfd1 fix LocalAI is missing in lang/en (#1169) 2023-09-13 10:08:33 +08:00
crazywoola
d1c5c5f160 add video to cn readme (#1165) 2023-09-12 08:30:12 -05:00
crazywoola
0b1d1440aa Update README.md (#1164) 2023-09-12 07:48:35 -05:00
Joel
0c420d64b3 chore: hover conversation show option button (#1160) 2023-09-12 16:35:13 +08:00
takatost
f9082104ed feat: add hosted moderation (#1158) 2023-09-12 10:26:12 +08:00
takatost
983834cd52 feat: spark check (#1134) 2023-09-11 17:31:03 +08:00
zxhlyh
96d10c8b39 feat: spark free quota verify (#1152) 2023-09-11 17:30:54 +08:00
takatost
24cb992843 feat: bump version to 0.3.22 (#1153) 2023-09-11 12:04:06 +08:00
crazywoola
7907c0bf58 Update bug_report.yml (#1151) 2023-09-11 10:48:37 +08:00
crazywoola
ebf4fd9a09 Update issue template (#1150) 2023-09-11 10:45:10 +08:00
Rhon Joe
38b9901274 fix(web): complete some ts type (#1148) 2023-09-11 09:30:17 +08:00
Jyong
642842d61b Feat:dataset retiever resource (#1123)
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-09-10 15:17:43 +08:00
KVOJJJin
e161c511af Feat:csv & docx support (#1139)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-10 15:17:22 +08:00
takatost
f29e82685e feat: bump version to 0.3.21 (#1145) 2023-09-10 12:34:54 +08:00
takatost
3a5ae96e7b fix: TRANSFORMERS_OFFLINE orders in Dockerfile (#1144) 2023-09-10 12:26:13 +08:00
takatost
b63a685386 feat: set transformers offline default true (#1143) 2023-09-10 12:20:58 +08:00
takatost
877da82b06 feat: cache huggingface gpt2 tokenizer files (#1138) 2023-09-10 12:16:21 +08:00
takatost
6637629045 fix: remove the deprecated depends_on.condition format (#1142) 2023-09-10 12:07:20 +08:00
Joel
e925b6c572 fix: log page compatible old query (#1141) 2023-09-10 11:29:25 +08:00
Joel
5412f4aba5 fix: in log page not show user query (#1140) 2023-09-10 09:30:30 +08:00
Joel
2d5ad0d208 feat: support optional query content (#1097)
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
2023-09-10 00:12:34 +08:00
takatost
1ade70aa1e feat: bump version to 0.3.20 (#1135) 2023-09-09 23:47:14 +08:00
takatost
2658c4d57b fix: answer returned null when response_mode was blocking (#1133) 2023-09-09 23:22:21 +08:00
zxhlyh
84c76bc04a Feat/chat add origin (#1130) 2023-09-09 19:17:12 +08:00
takatost
6effcd3755 feat: optimize celery start cmd (#1129) 2023-09-09 13:48:29 +08:00
李锐东
d9866489f0 feat: add health check and depend condition in docker compose (#1113) 2023-09-09 13:47:08 +08:00
takatost
c4d8bdc3db fix: hf hosted inference check (#1128) 2023-09-09 00:29:48 +08:00
Joel
681eb1cfcc fix: click inner link no jump (#1118) 2023-09-08 10:21:42 +08:00
Matri
a5d21f3b09 fix: shortening invite url (#1100)
Co-authored-by: MatriQi <matri@aifi.io>
2023-09-07 17:15:57 +08:00
Joel
7ba068c3e4 fix: self host embedding missing base url config (#1116) 2023-09-07 14:56:38 +08:00
bowen
b201eeedbd fix: optimize styles (#1112) 2023-09-07 14:24:09 +08:00
Rhon Joe
f28cb84977 fix(web): fix AppCard Menu popover open bug (#1107) 2023-09-07 09:47:31 +08:00
Joel
714872cd58 chore: enchancment frontend readme (#1110) 2023-09-07 09:43:24 +08:00
Joel
0708bd60ee fix: try to fix chunk load error (#1109) 2023-09-06 15:47:53 +08:00
Joel
23a6c85b80 chore: handle workspace apps scrollbar (#1101) 2023-09-05 15:56:21 +08:00
bowen
4a28599fbd fix: optimize feedback and app icon (#1099) 2023-09-05 09:13:59 +08:00
seewhy
7c66d3c793 feat: Optimize the description for Azure deployment name (#1091) 2023-09-04 14:26:22 +08:00
Joel
cc9edfffd8 fix: markdown code lang capitalization and line number color (#1098) 2023-09-04 11:31:25 +08:00
Joel
6fa2454c9a fix: change frontend start script (#1096) 2023-09-04 11:10:32 +08:00
crazywoola
487e699021 fix: ui in chat openning statement (#1094) 2023-09-04 10:26:46 +08:00
takatost
a7cdb745c1 feat: support spark v2 validate (#1086) 2023-09-01 20:53:32 +08:00
takatost
73c86ee6a0 fix: prompt of title generation (#1084) 2023-09-01 14:55:58 +08:00
takatost
48eb590065 feat: optimize last_active_at update (#1083) 2023-09-01 13:58:26 +08:00
takatost
33562a9d8d feat: optimize prompt (#1080) 2023-09-01 11:46:06 +08:00
Rhon Joe
c9194ba382 chore(api): api image multistage build (#1069) 2023-09-01 11:13:22 +08:00
takatost
a199fa6388 feat: optimize high load sql query of document segment (#1078) 2023-09-01 10:52:39 +08:00
takatost
4c8608dc61 feat: optimize conversation title generation output must be a valid JSON (#1077) 2023-09-01 10:31:42 +08:00
Garfield Dai
a6b0f788e7 feat: add visual studio code debug config. (#1068)
Co-authored-by: Keruberosu <631677014@qq.com>
2023-09-01 09:15:06 +08:00
takatost
df6604a734 feat: optimize generation of conversation title (#1075) 2023-09-01 02:28:37 +08:00
291 changed files with 7939 additions and 1792 deletions

49
.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file
View File

@@ -0,0 +1,49 @@
name: "🕷️ Bug report"
description: Report errors or unexpected behavior
labels:
- bug
body:
- type: markdown
attributes:
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
- type: input
attributes:
label: Dify version
placeholder: 0.3.21
description: See about section in Dify console
validations:
required: true
- type: dropdown
attributes:
label: Cloud or Self Hosted
description: How / Where was Dify installed from?
multiple: true
options:
- Cloud
- Self Hosted
- Other (please specify in "Steps to Reproduce")
validations:
required: true
- type: textarea
attributes:
label: Steps to reproduce
description: We highly suggest including screenshots and a bug report log.
placeholder: Having detailed steps helps us reproduce the bug.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: What were you expecting?
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: What happened instead?
validations:
required: false

8
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,8 @@
blank_issues_enabled: false
contact_links:
- name: "\U0001F4DA Dify user documentation"
url: https://docs.dify.ai/getting-started/readme
about: Documentation for users of Dify
- name: "\U0001F4DA Dify dev documentation"
url: https://docs.dify.ai/getting-started/install-self-hosted
about: Documentation for people interested in developing and contributing for Dify

View File

@@ -0,0 +1,11 @@
name: "📚 Documentation Issue"
description: Report issues in our documentation
labels:
- documentation
body:
- type: textarea
attributes:
label: Provide a description of requested docs changes
placeholder: Briefly describe which document needs to be corrected and why.
validations:
required: true

View File

@@ -0,0 +1,26 @@
name: "⭐ Feature or enhancement request"
description: Propose something new.
labels:
- enhancement
body:
- type: textarea
attributes:
label: Description of the new feature / enhancement
placeholder: What is the expected behavior of the proposed feature?
validations:
required: true
- type: textarea
attributes:
label: Scenario when this would be used?
placeholder: What is the scenario in which this would be used? Why is this important to your workflow as a Dify user?
validations:
required: true
- type: textarea
attributes:
label: Supporting information
placeholder: "Having additional evidence, data, tweets, blog posts, research, ... anything is extremely helpful. This information provides context to the scenario that may otherwise be lost."
validations:
required: false
- type: markdown
attributes:
value: Please limit one request per issue.

View File

@@ -0,0 +1,46 @@
name: "🌐 Localization/Translation issue"
description: Report incorrect translations.
labels:
- translation
body:
- type: markdown
attributes:
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
- type: input
attributes:
label: Dify version
placeholder: 0.3.21
description: See about section in Dify console
validations:
required: true
- type: input
attributes:
label: Utility with translation issue
placeholder: Some area
description: Please input here the utility with the translation issue
validations:
required: true
- type: input
attributes:
label: 🌐 Language affected
placeholder: "German"
validations:
required: true
- type: textarea
attributes:
label: ❌ Actual phrase(s)
placeholder: What is there? Please include a screenshot as that is extremely helpful.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected phrase(s)
placeholder: What was expected?
validations:
required: true
- type: textarea
attributes:
label: Why is the current translation wrong
placeholder: Why do you feel this is incorrect?
validations:
required: true

View File

@@ -1,32 +0,0 @@
---
name: "\U0001F41B Bug report"
about: Create a report to help us improve
title: ''
labels: bug
assignees: ''
---
<!--
Please provide a clear and concise description of what the bug is. Include
screenshots if needed. Please test using the latest version of the relevant
Dify packages to make sure your issue has not already been fixed.
-->
Dify version: Cloud | Self Host
## Steps To Reproduce
<!--
Your bug will get fixed much faster if we can run your code and it doesn't
have dependencies other than Dify. Issues without reproduction steps or
code examples may be immediately closed as not actionable.
-->
1.
2.
## The current behavior
## The expected behavior

View File

@@ -1,20 +0,0 @@
---
name: "\U0001F680 Feature request"
about: Suggest an idea for this project
title: ''
labels: enhancement
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

View File

@@ -1,10 +0,0 @@
---
name: "\U0001F914 Questions and Help"
about: Ask a usage or consultation question
title: ''
labels: ''
assignees: ''
---

View File

@@ -20,7 +20,8 @@ def check_file_for_chinese_comments(file_path):
def main():
has_chinese = False
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py',
'prompts.py']
for root, _, files in os.walk("."):
for file in files:

3
.gitignore vendored
View File

@@ -149,4 +149,5 @@ sdks/python-client/build
sdks/python-client/dist
sdks/python-client/dify_client.egg-info
.vscode/
.vscode/*
!.vscode/launch.json

27
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,27 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Flask",
"type": "python",
"request": "launch",
"module": "flask",
"env": {
"FLASK_APP": "api/app.py",
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},
"args": [
"run",
"--host=0.0.0.0",
"--port=5001",
"--debug"
],
"jinja": true,
"justMyCode": true
}
]
}

View File

@@ -16,6 +16,10 @@ Out-of-the-box web sites supporting form mode and chat conversation mode
A single API encompassing plugin capabilities, context enhancement, and more, saving you backend coding effort
Visual data analysis, log review, and annotation for applications
https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
## Highlighted Features
**1. LLMs support:** Choose capabilities based on different models when building your Dify AI apps. Dify is compatible with Langchain, meaning it will support various LLMs. Currently supported:

View File

@@ -17,7 +17,7 @@
- 一套 API 即可包含插件、上下文增强等能力,替你省下了后端代码的编写工作
- 可视化的对应用进行数据分析,查阅日志或进行标注
https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
## 核心能力
1. **模型支持:** 你可以在 Dify 上选择基于不同模型的能力来开发你的 AI 应用。Dify 兼容 Langchain这意味着我们将逐步支持多种 LLMs ,目前支持的模型供应商:

View File

@@ -1,7 +1,18 @@
FROM python:3.10-slim
# packages install stage
FROM python:3.10-slim AS base
LABEL maintainer="takatost@gmail.com"
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ python3-dev libc-dev libffi-dev
COPY requirements.txt /requirements.txt
RUN pip install --prefix=/pkg -r requirements.txt
# build stage
FROM python:3.10-slim AS builder
ENV FLASK_APP app.py
ENV EDITION SELF_HOSTED
ENV DEPLOY_ENV PRODUCTION
@@ -15,15 +26,17 @@ EXPOSE 5001
WORKDIR /app/api
RUN apt-get update && \
apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev nodejs
COPY requirements.txt /app/api/requirements.txt
RUN pip install -r requirements.txt
RUN apt-get update \
&& apt-get install -y --no-install-recommends bash curl wget vim nodejs \
&& apt-get autoremove \
&& rm -rf /var/lib/apt/lists/*
COPY --from=base /pkg /usr/local
COPY . /app/api/
RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
ENV TRANSFORMERS_OFFLINE true
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

View File

@@ -52,7 +52,7 @@
flask run --host 0.0.0.0 --port=5001 --debug
```
7. Setup your application by visiting http://localhost:5001/console/api/setup or other apis...
8. If you need to debug local async processing, you can run `celery -A app.celery worker -Q dataset,generation,mail`, celery can do dataset importing and other async tasks.
8. If you need to debug local async processing, you can run `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`, celery can do dataset importing and other async tasks.
8. Start frontend

View File

@@ -1,6 +1,6 @@
# -*- coding:utf-8 -*-
import os
from datetime import datetime
from datetime import datetime, timedelta
from werkzeug.exceptions import Forbidden
@@ -145,8 +145,12 @@ def load_user(user_id):
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
account.last_active_at = datetime.utcnow()
db.session.commit()
current_time = datetime.utcnow()
# update last_active_at when last_active_at is more than 10 minutes ago
if current_time - account.last_active_at > timedelta(minutes=10):
account.last_active_at = current_time
db.session.commit()
# Log in the user with the updated user_id
flask_login.login_user(account, remember=True)

View File

@@ -4,8 +4,10 @@ import math
import random
import string
import time
import uuid
import click
from tqdm import tqdm
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound
@@ -21,9 +23,9 @@ from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant
from models.dataset import Dataset, DatasetQuery, Document
from models.model import Account
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
import secrets
import base64
@@ -238,7 +240,13 @@ def clean_unused_dataset_indexes():
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
vector_index.delete()
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
vector_index.delete()
kw_index.delete()
# update document
update_params = {
@@ -345,7 +353,8 @@ def create_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -363,7 +372,8 @@ def create_qdrant_indexes():
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
"vector_store": {
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
@@ -372,7 +382,8 @@ def create_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@@ -413,7 +424,8 @@ def update_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -434,11 +446,207 @@ def update_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
    """Attach every unbound high-quality dataset to a shared collection binding.

    Pages through ``Dataset`` rows whose ``indexing_technique`` is
    ``'high_quality'`` (50 per page, newest first). For each dataset that has
    no ``collection_binding_id`` yet, the command:

    1. resolves the dataset's embedding model through ``ModelFactory``; if
       that raises, it falls back to an OpenAI ``text-embedding-ada-002``
       model built from a placeholder provider record;
    2. looks up the oldest ``DatasetCollectionBinding`` for that
       provider/model pair, creating and committing a new one when none
       exists;
    3. calls ``restore_dataset_in_one`` and then
       ``delete_original_collection`` on freshly built ``QdrantVectorIndex``
       instances (what those methods do internally is not visible here).

    A failure on one dataset is reported in red and the loop continues with
    the next dataset. The session is rolled back first so that a failed
    flush/commit does not leave it unusable for the remaining datasets.
    """
    click.echo(click.style('Start normalization collections.', fg='green'))
    normalization_count = 0

    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
        except NotFound:
            # NotFound from paginate() is treated as "no more pages".
            break

        page += 1
        for dataset in datasets:
            # Datasets that already point at a shared collection need no work.
            if dataset.collection_binding_id:
                continue

            try:
                click.echo('restore dataset index: {}'.format(dataset.id))
                try:
                    embedding_model = ModelFactory.get_embedding_model(
                        tenant_id=dataset.tenant_id,
                        model_provider_name=dataset.embedding_model_provider,
                        model_name=dataset.embedding_model
                    )
                except Exception:
                    # The tenant's model could not be loaded: fall back to a
                    # stand-in OpenAI embedding model with a placeholder key.
                    provider = Provider(
                        id='provider_id',
                        tenant_id=dataset.tenant_id,
                        provider_name='openai',
                        provider_type=ProviderType.CUSTOM.value,
                        encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
                        is_valid=True,
                    )
                    model_provider = OpenAIProvider(provider=provider)
                    embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                                      model_provider=model_provider)
                embeddings = CacheEmbedding(embedding_model)

                # Reuse the oldest binding for this provider/model pair, or
                # create one with a freshly generated collection name.
                dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                    filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
                           DatasetCollectionBinding.model_name == embedding_model.name). \
                    order_by(DatasetCollectionBinding.created_at). \
                    first()

                if not dataset_collection_binding:
                    dataset_collection_binding = DatasetCollectionBinding(
                        provider_name=embedding_model.model_provider.provider_name,
                        model_name=embedding_model.name,
                        collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
                    )
                    db.session.add(dataset_collection_binding)
                    db.session.commit()

                from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

                index = QdrantVectorIndex(
                    dataset=dataset,
                    config=QdrantConfig(
                        endpoint=current_app.config.get('QDRANT_URL'),
                        api_key=current_app.config.get('QDRANT_API_KEY'),
                        root_path=current_app.root_path
                    ),
                    embeddings=embeddings
                )
                if index:
                    index.restore_dataset_in_one(dataset, dataset_collection_binding)
                else:
                    click.echo('passed.')

                original_index = QdrantVectorIndex(
                    dataset=dataset,
                    config=QdrantConfig(
                        endpoint=current_app.config.get('QDRANT_URL'),
                        api_key=current_app.config.get('QDRANT_API_KEY'),
                        root_path=current_app.root_path
                    ),
                    embeddings=embeddings
                )
                if original_index:
                    original_index.delete_original_collection(dataset, dataset_collection_binding)
                    normalization_count += 1
                else:
                    click.echo('passed.')
            except Exception as e:
                # Without a rollback, a failed flush/commit would make every
                # later query on this session fail as well.
                db.session.rollback()
                click.echo(
                    click.style('Normalize dataset collection error: {} {}'.format(e.__class__.__name__, str(e)),
                                fg='red'))
                continue

    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def update_app_model_configs(batch_size):
    """Add a ``default_input`` paragraph variable to completion-mode app configs.

    For every ``AppModelConfig`` joined to an ``App`` whose ``mode`` is
    ``'completion'``, the command appends ``{{default_input}}`` to
    ``pre_prompt`` and appends a matching paragraph item to
    ``user_input_form``. The item's label follows the ``interface_language``
    of the tenant's owner account.

    Records are skipped, and left completely untouched, when ``pre_prompt``
    already contains the placeholder or when the tenant has no owner account.

    Records are processed in batches of ``batch_size`` and each batch is
    committed on its own. If a batch fails, its pending changes are rolled
    back, the error is printed and processing continues with the next batch.
    The progress bar advances for failed batches too, so it tracks records
    attempted.

    :param batch_size: number of records fetched and committed per batch.
    """
    pre_prompt_template = '{{default_input}}'
    # Used when the owner's interface language has no template below.
    fallback_language = 'en-US'
    user_input_form_template = {
        "en-US": [
            {
                "paragraph": {
                    "label": "Query",
                    "variable": "default_input",
                    "required": False,
                    "default": ""
                }
            }
        ],
        "zh-Hans": [
            {
                "paragraph": {
                    "label": "查询内容",
                    "variable": "default_input",
                    "required": False,
                    "default": ""
                }
            }
        ]
    }

    click.secho("Start migrate old data that the text generator can support paragraph variable.", fg='green')

    total_records = db.session.query(AppModelConfig) \
        .join(App, App.app_model_config_id == AppModelConfig.id) \
        .filter(App.mode == 'completion') \
        .count()

    if total_records == 0:
        click.secho("No data to migrate.", fg='green')
        return

    # Ceiling division: the last batch may be partial.
    num_batches = (total_records + batch_size - 1) // batch_size

    with tqdm(total=total_records, desc="Migrating Data") as pbar:
        for i in range(num_batches):
            offset = i * batch_size
            limit = min(batch_size, total_records - offset)

            click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')

            data_batch = db.session.query(AppModelConfig) \
                .join(App, App.app_model_config_id == AppModelConfig.id) \
                .filter(App.mode == 'completion') \
                .order_by(App.created_at) \
                .offset(offset).limit(limit).all()

            if not data_batch:
                click.secho("No more data to migrate.", fg='green')
                break

            # Remembered for the error message; read before any rollback so
            # reporting never has to touch a failed session.
            current_app_id = None
            current_config_id = None
            try:
                click.secho(f"Migrating {len(data_batch)} records...", fg='green')
                for data in data_batch:
                    current_app_id, current_config_id = data.app_id, data.id

                    # Already migrated: the placeholder is in the prompt.
                    if data.pre_prompt is not None and pre_prompt_template in data.pre_prompt:
                        continue

                    app_data = db.session.query(App) \
                        .filter(App.id == data.app_id) \
                        .one()

                    account_data = db.session.query(Account) \
                        .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
                        .filter(TenantAccountJoin.role == 'owner') \
                        .filter(TenantAccountJoin.tenant_id == app_data.tenant_id) \
                        .one_or_none()

                    # Decide whether to skip BEFORE mutating the record, so a
                    # record without an owner is not left with a prompt
                    # placeholder and no matching form variable.
                    if not account_data:
                        continue

                    form_items = user_input_form_template.get(
                        account_data.interface_language,
                        user_input_form_template[fallback_language]
                    )

                    # Build the new form first: if the stored JSON is broken
                    # the record fails before anything on it has changed.
                    if data.user_input_form is None or data.user_input_form == 'null':
                        new_user_input_form = json.dumps(form_items)
                    else:
                        raw_json_data = json.loads(data.user_input_form)
                        raw_json_data.append(form_items[0])
                        new_user_input_form = json.dumps(raw_json_data)

                    if data.pre_prompt is None:
                        data.pre_prompt = pre_prompt_template
                    else:
                        data.pre_prompt += pre_prompt_template
                    data.user_input_form = new_user_input_form

                db.session.commit()
            except Exception as e:
                # Drop the half-applied batch; otherwise its pending edits
                # would be flushed by the next batch's commit.
                db.session.rollback()
                click.secho(f"Error while migrating data: {e}, app_id: {current_app_id}, "
                            f"app_model_config_id: {current_config_id}",
                            fg='red')
            else:
                click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')

            pbar.update(len(data_batch))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@@ -448,4 +656,6 @@ def register_commands(app):
app.cli.add_command(sync_anthropic_hosted_providers)
app.cli.add_command(clean_unused_dataset_indexes)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(update_qdrant_indexes)
app.cli.add_command(update_qdrant_indexes)
app.cli.add_command(update_app_model_configs)
app.cli.add_command(normalization_collections)

View File

@@ -61,6 +61,8 @@ DEFAULTS = {
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'HOSTED_MODERATION_ENABLED': 'False',
'HOSTED_MODERATION_PROVIDERS': '',
'TENANT_DOCUMENT_COUNT': 100,
'CLEAN_DAY_SETTING': 30,
'UPLOAD_FILE_SIZE_LIMIT': 15,
@@ -100,7 +102,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.19"
self.CURRENT_VERSION = "0.3.22"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -230,6 +232,9 @@ class Config:
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

View File

@@ -16,7 +16,7 @@ model_templates = {
},
'model_config': {
'provider': 'openai',
'model_id': 'text-davinci-003',
'model_id': 'gpt-3.5-turbo-instruct',
'configs': {
'prompt_template': '',
'prompt_variables': [],
@@ -30,7 +30,7 @@ model_templates = {
},
'model': json.dumps({
"provider": "openai",
"name": "text-davinci-003",
"name": "gpt-3.5-turbo-instruct",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -38,7 +38,18 @@ model_templates = {
"presence_penalty": 0,
"frequency_penalty": 0
}
})
}),
'user_input_form': json.dumps([
{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
}
]),
'pre_prompt': '{{query}}'
}
},
@@ -93,7 +104,7 @@ demo_model_templates = {
'mode': 'completion',
'model_config': AppModelConfig(
provider='openai',
model_id='text-davinci-003',
model_id='gpt-3.5-turbo-instruct',
configs={
'prompt_template': "Please translate the following text into {{target_language}}:\n",
'prompt_variables': [
@@ -129,7 +140,7 @@ demo_model_templates = {
pre_prompt="Please translate the following text into {{target_language}}:\n",
model=json.dumps({
"provider": "openai",
"name": "text-davinci-003",
"name": "gpt-3.5-turbo-instruct",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -211,7 +222,7 @@ demo_model_templates = {
'mode': 'completion',
'model_config': AppModelConfig(
provider='openai',
model_id='text-davinci-003',
model_id='gpt-3.5-turbo-instruct',
configs={
'prompt_template': "请将以下文本翻译为{{target_language}}:\n",
'prompt_variables': [
@@ -247,7 +258,7 @@ demo_model_templates = {
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
model=json.dumps({
"provider": "openai",
"name": "text-davinci-003",
"name": "gpt-3.5-turbo-instruct",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,

View File

@@ -29,6 +29,7 @@ model_config_fields = {
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'),

View File

@@ -39,9 +39,10 @@ class CompletionMessageApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args['response_mode'] != 'blocking'
@@ -115,6 +116,7 @@ class ChatMessageApi(Resource):
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args['response_mode'] != 'blocking'

View File

@@ -16,26 +16,25 @@ from services.account_service import RegisterService
class ActivateCheckApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='args')
parser.add_argument('email', type=email, required=True, nullable=False, location='args')
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args')
parser.add_argument('email', type=email, required=False, nullable=True, location='args')
parser.add_argument('token', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
workspaceId = args['workspace_id']
reg_email = args['email']
token = args['token']
tenant = db.session.query(Tenant).filter(
Tenant.id == args['workspace_id'],
Tenant.status == 'normal'
).first()
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
return {'is_valid': account is not None, 'workspace_name': tenant.name}
return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None}
class ActivateApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='json')
parser.add_argument('email', type=email, required=True, nullable=False, location='json')
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json')
parser.add_argument('email', type=email, required=False, nullable=True, location='json')
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
@@ -44,12 +43,13 @@ class ActivateApi(Resource):
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
args = parser.parse_args()
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
if account is None:
invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token'])
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
account = invitation['account']
account.name = args['name']
# generate password salt

View File

@@ -26,7 +26,7 @@ from models.model import UploadFile
cache = TTLCache(maxsize=None, ttl=30)
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx']
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
PREVIEW_WORDS_LIMIT = 3000

View File

@@ -31,8 +31,9 @@ class CompletionApi(InstalledAppResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
@@ -92,6 +93,7 @@ class ChatApi(InstalledAppResource):
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'

View File

@@ -30,6 +30,25 @@ class MessageListApi(InstalledAppResource):
'rating': fields.String
}
# Field map describing one retriever resource entry: its own id and message id,
# the dataset / document / segment it points at, the retrieval score, usage
# counters, the segment content and the creation time.
# It is the nested element type of the 'retriever_resources' list declared in
# message_fields.
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
# TimestampField is not defined in this excerpt -- presumably a project field type; confirm.
'created_at': TimestampField
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
@@ -37,6 +56,7 @@ class MessageListApi(InstalledAppResource):
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField
}

View File

@@ -24,6 +24,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
}
@@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@@ -29,6 +29,7 @@ class UniversalChatApi(UniversalChatResource):
parser.add_argument('provider', type=str, required=True, location='json')
parser.add_argument('model', type=str, required=True, location='json')
parser.add_argument('tools', type=list, required=True, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json')
args = parser.parse_args()
app_model_config = app_model.app_model_config

View File

@@ -36,6 +36,25 @@ class UniversalChatMessageListApi(UniversalChatResource):
'created_at': TimestampField
}
# Field map describing one retriever resource entry: its own id and message id,
# the dataset / document / segment it points at, the retrieval score, usage
# counters, the segment content and the creation time.
# It is the nested element type of the 'retriever_resources' list declared in
# message_fields.
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
# TimestampField is not defined in this excerpt -- presumably a project field type; confirm.
'created_at': TimestampField
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
@@ -43,6 +62,7 @@ class UniversalChatMessageListApi(UniversalChatResource):
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField,
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
}

View File

@@ -1,4 +1,6 @@
# -*- coding:utf-8 -*-
import json
from flask_restful import marshal_with, fields
from controllers.console import api
@@ -14,6 +16,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
}
@marshal_with(parameters_fields)
@@ -21,12 +24,14 @@ class UniversalChatParameterApi(UniversalChatResource):
"""Retrieve app parameters."""
app_model = universal_app
app_model_config = app_model.app_model_config
app_model_config.retriever_resource = json.dumps({'enabled': True})
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
}

View File

@@ -47,6 +47,7 @@ def universal_chat_app_required(view=None):
suggested_questions=json.dumps([]),
suggested_questions_after_answer=json.dumps({'enabled': True}),
speech_to_text=json.dumps({'enabled': True}),
retriever_resource=json.dumps({'enabled': True}),
more_like_this=None,
sensitive_word_avoidance=None,
model=json.dumps({

View File

@@ -72,7 +72,7 @@ class MemberInviteEmailApi(Resource):
invitation_results.append({
'status': 'success',
'email': invitee_email,
'url': f'{console_web_url}/activate?workspace_id={current_user.current_tenant_id}&email={invitee_email}&token={token}'
'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
})
account = marshal(account, account_fields)
account['role'] = role

View File

@@ -246,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource):
'enabled': v.enabled,
'min': v.min,
'max': v.max,
'default': v.default
'default': v.default,
'precision': v.precision
}
for k, v in vars(parameter_rules).items()
}
@@ -285,6 +286,25 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
return result
class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
    """Endpoint that verifies free-quota qualification for a model provider."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider_name: str):
        """Verify free-quota qualification for ``provider_name``.

        Reads the optional ``token`` query-string argument and passes it,
        together with the current user's tenant id and the provider name, to
        ``ProviderService.free_quota_qualification_verify``. That call's
        result is returned unchanged.
        """
        token_parser = reqparse.RequestParser()
        token_parser.add_argument('token', type=str, required=False, nullable=True, location='args')
        query_args = token_parser.parse_args()

        return ProviderService().free_quota_qualification_verify(
            tenant_id=current_user.current_tenant_id,
            provider_name=provider_name,
            token=query_args['token']
        )
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
@@ -300,3 +320,5 @@ api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
api.add_resource(ModelProviderFreeQuotaSubmitApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')

View File

@@ -25,6 +25,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
}
@@ -39,6 +40,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@@ -27,9 +27,11 @@ class CompletionApi(AppApiResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
@@ -91,6 +93,8 @@ class ChatApi(AppApiResource):
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'

View File

@@ -16,6 +16,24 @@ class MessageListApi(AppApiResource):
feedback_fields = {
'rating': fields.String
}
# Field map describing one retriever resource entry: its own id and message id,
# the dataset / document / segment it points at, the retrieval score, usage
# counters, the segment content and the creation time.
# It is the nested element type of the 'retriever_resources' list declared in
# message_fields.
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
# TimestampField is not defined in this excerpt -- presumably a project field type; confirm.
'created_at': TimestampField
}
message_fields = {
'id': fields.String,
@@ -24,6 +42,7 @@ class MessageListApi(AppApiResource):
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField
}

View File

@@ -24,6 +24,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
}
@@ -38,6 +39,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@@ -29,8 +29,10 @@ class CompletionApi(WebApiResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
@@ -88,6 +90,8 @@ class ChatApi(WebApiResource):
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'

View File

@@ -29,6 +29,25 @@ class MessageListApi(WebApiResource):
'rating': fields.String
}
# Field map describing one retriever resource entry: its own id and message id,
# the dataset / document / segment it points at, the retrieval score, usage
# counters, the segment content and the creation time.
# It is the nested element type of the 'retriever_resources' list declared in
# message_fields.
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
# TimestampField is not defined in this excerpt -- presumably a project field type; confirm.
'created_at': TimestampField
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
@@ -36,6 +55,7 @@ class MessageListApi(WebApiResource):
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField
}

View File

@@ -1,3 +1,4 @@
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
@@ -53,6 +54,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'query': kwargs['input']})
# output = ''
# rst_json = json.loads(rst)
# for item in rst_json:
# output += f'{item["content"]}\n'
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:

View File

@@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.helper import moderation
from core.model_providers.error import LLMError
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -116,6 +118,18 @@ class AgentExecutor:
return self.agent.should_use_agent(query)
def run(self, query: str) -> AgentExecuteResult:
moderation_result = moderation.check_moderation(
self.configuration.model_instance.model_provider,
query
)
if not moderation_result:
return AgentExecuteResult(
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
strategy=self.configuration.strategy,
configuration=self.configuration
)
agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent,
tools=self.configuration.tools,
@@ -128,7 +142,9 @@ class AgentExecutor:
try:
output = agent_executor.run(query)
except Exception:
except LLMError as ex:
raise ex
except Exception as ex:
logging.exception("agent_executor run failed")
output = None

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
@@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.model_instant = model_instant
self.model_instance = model_instance
self.conversation_message_task = conversation_message_task
self._agent_loops = []
self._current_loop = None
@@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Whether to ignore chain callbacks."""
return True
def on_chat_model_start(
    self,
    serialized: Dict[str, Any],
    messages: List[List[BaseMessage]],
    **kwargs: Any
) -> Any:
    """Open a new agent loop when a chat model starts and no loop is active.

    If ``self._current_loop`` is already set, nothing happens. Otherwise a
    new ``AgentLoop`` is stored there, positioned one past the loops
    gathered so far, with the contents of the first message batch joined by
    newlines as its prompt, status ``'llm_started'`` and the current
    ``time.perf_counter()`` reading as its start time.
    """
    if self._current_loop:
        # A loop is already in progress; keep it.
        return

    # Agent start with a LLM query
    next_position = len(self._agent_loops) + 1
    prompt_text = "\n".join(message.content for message in messages[0])
    self._current_loop = AgentLoop(
        position=next_position,
        prompt=prompt_text,
        status='llm_started',
        started_at=time.perf_counter()
    )
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
@@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
else:
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.prompt)]
)
completion_generation = response.generations[0][0]
@@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.completion)]
)
@@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_instant, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)
@@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_instant, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)

View File

@@ -64,12 +64,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
# kwargs={'name': 'Search'}
# llm_prefix='Thought:'
# observation_prefix='Observation: '
# output='53 years'
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:

View File

@@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
prompt_tokens: int = 0
completion: str = ''
completion_tokens: int = 0
latency: float = 0.0

View File

@@ -2,6 +2,7 @@ from typing import List
from langchain.schema import Document
from core.conversation_message_task import ConversationMessageTask
from extensions.ext_database import db
from models.dataset import DocumentSegment
@@ -9,8 +10,9 @@ from models.dataset import DocumentSegment
class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""
def __init__(self, dataset_id: str) -> None:
def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
self.dataset_id = dataset_id
self.conversation_message_task = conversation_message_task
def on_tool_end(self, documents: List[Document]) -> None:
"""Handle tool end."""
@@ -27,3 +29,7 @@ class DatasetIndexToolCallbackHandler:
)
db.session.commit()
def return_retriever_resource_info(self, resource: List):
    """Pass the retrieved resource list on to the conversation message task.

    ``resource`` is forwarded unchanged to
    ``self.conversation_message_task.on_dataset_query_finish``.
    """
    message_task = self.conversation_message_task
    message_task.on_dataset_query_finish(resource)

View File

@@ -1,5 +1,4 @@
import logging
import time
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
@@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
self.start_at = time.perf_counter()
real_prompts = []
for message in messages[0]:
if message.type == 'human':
@@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.start_at = time.perf_counter()
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
@@ -63,14 +59,22 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text
self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
if response.llm_output and 'token_usage' in response.llm_output:
if 'prompt_tokens' in response.llm_output['token_usage']:
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
if 'completion_tokens' in response.llm_output['token_usage']:
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
self.conversation_message_task.save_message(self.llm_message)
@@ -89,8 +93,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
"""Do nothing."""
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)

View File

@@ -1,15 +1,33 @@
import enum
import logging
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import BaseModel
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation import openai_moderation
class SensitiveWordAvoidanceRule(BaseModel):
    """Configuration for how sensitive input is detected and answered."""

    class Type(enum.Enum):
        """Detection strategies a rule can select."""

        MODERATION = "moderation"
        KEYWORDS = "keywords"

    # Detection strategy; the chain runs the keyword check for KEYWORDS and
    # the moderation check otherwise.
    type: Type
    # Message carried by SensitiveWordAvoidanceError when a check fails.
    canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
    # Strategy-specific settings; the keyword check reads 'sensitive_words' from it.
    extra_params: dict = {}
class SensitiveWordAvoidanceChain(Chain):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
sensitive_words: List[str] = []
canned_response: str = None
model_instance: BaseLLM
sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
@property
def _chain_type(self) -> str:
@@ -31,11 +49,24 @@ class SensitiveWordAvoidanceChain(Chain):
"""
return [self.output_key]
def _check_sensitive_word(self, text: str) -> str:
for word in self.sensitive_words:
def _check_sensitive_word(self, text: str) -> bool:
for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
if word in text:
return self.canned_response
return text
return False
return True
def _check_moderation(self, text: str) -> bool:
    """Run ``text`` through the OpenAI moderation model.

    The moderation model is resolved through ``ModelFactory`` for the tenant
    that owns ``self.model_instance``.

    :param text: the input text to check.
    :return: the result of the moderation model's ``run``; the caller treats
        a falsy result as a policy violation.
    :raises LLMBadRequestError: when the moderation call raises; the original
        exception is logged and chained as the cause.
    """
    moderation_model_instance = ModelFactory.get_moderation_model(
        tenant_id=self.model_instance.model_provider.provider.tenant_id,
        model_provider_name='openai',
        model_name=openai_moderation.DEFAULT_MODEL
    )

    try:
        return moderation_model_instance.run(text=text)
    except Exception as ex:
        logging.exception(ex)
        # NOTE(review): every failure is reported as a rate-limit error, whatever
        # the real cause; chaining keeps the underlying exception reachable.
        raise LLMBadRequestError('Rate limit exceeded, please try again later.') from ex
def _call(
self,
@@ -43,5 +74,19 @@ class SensitiveWordAvoidanceChain(Chain):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
output = self._check_sensitive_word(text)
return {self.output_key: output}
if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
result = self._check_sensitive_word(text)
else:
result = self._check_moderation(text)
if not result:
raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
return {self.output_key: text}
class SensitiveWordAvoidanceError(Exception):
    """Raised when input fails a sensitive-word check.

    The text passed in is kept on ``message`` and is also the exception's
    sole argument.
    """

    def __init__(self, message):
        self.message = message
        super().__init__(message)

View File

@@ -1,31 +1,32 @@
import json
import logging
import re
from typing import Optional, List, Union, Tuple
from typing import Optional, List, Union
from langchain.schema import BaseMessage
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
class Completion:
@classmethod
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
is_override: bool = False, retriever_from: str = 'dev'):
"""
errors: ProviderTokenNotInitError
"""
@@ -76,29 +77,53 @@ class Completion:
app_model_config=app_model_config
)
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
if sensitive_word_avoidance_chain:
query = sensitive_word_avoidance_chain.run(query)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback
)
# run agent executor
agent_execute_result = None
if agent_executor:
should_use_agent = agent_executor.should_use_agent(query)
if should_use_agent:
agent_execute_result = agent_executor.run(query)
# run the final llm
try:
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
final_model_instance, [chain_callback])
if sensitive_word_avoidance_chain:
try:
query = sensitive_word_avoidance_chain.run(query)
except SensitiveWordAvoidanceError as ex:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=ex.message
)
return
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback,
retriever_from=retriever_from
)
# run agent executor
agent_execute_result = None
if agent_executor:
should_use_agent = agent_executor.should_use_agent(query)
if should_use_agent:
agent_execute_result = agent_executor.run(query)
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
# run the final llm
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
@@ -107,7 +132,8 @@ class Completion:
inputs=inputs,
agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task,
memory=memory
memory=memory,
fake_response=fake_response
)
except ConversationTaskStoppedException:
return
@@ -118,17 +144,12 @@ class Completion:
return
@classmethod
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
# get llm prompt
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
@@ -150,7 +171,6 @@ class Completion:
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
return response
@classmethod

View File

@@ -1,6 +1,6 @@
import decimal
import json
from typing import Optional, Union
import time
from typing import Optional, Union, List
from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
@@ -15,13 +15,16 @@ from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
MessageChain, DatasetRetrieverResource
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
self.start_at = time.perf_counter()
self.task_id = task_id
self.app = app
@@ -41,6 +44,8 @@ class ConversationMessageTask:
self.message = None
self.retriever_resource = None
self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider')
self.model_name = self.model_dict.get('name')
@@ -58,6 +63,7 @@ class ConversationMessageTask:
)
def init(self):
override_model_configs = None
if self.is_override:
override_model_configs = self.app_model_config.to_dict()
@@ -157,11 +163,12 @@ class ConversationMessageTask:
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer = PromptBuilder.process_template(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
self.message.answer_price_unit = answer_price_unit
self.message.provider_response_latency = llm_message.latency
self.message.provider_response_latency = time.perf_counter() - self.start_at
self.message.total_price = total_price
db.session.commit()
@@ -216,18 +223,18 @@ class ConversationMessageTask:
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
message_agent_thought.observation = agent_loop.tool_output
@@ -241,7 +248,7 @@ class ConversationMessageTask:
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = agent_model_instant.get_currency()
message_agent_thought.currency = agent_model_instance.get_currency()
db.session.flush()
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
@@ -256,7 +263,36 @@ class ConversationMessageTask:
db.session.add(dataset_query)
def on_dataset_query_finish(self, resource: List):
    """Persist dataset retrieval hits for the current message and remember them.

    For every item in ``resource`` a ``DatasetRetrieverResource`` row is added
    to the session and flushed. The list is then stored on
    ``self.retriever_resource`` so ``message_end``/``end`` can publish it.
    Nothing is written when ``resource`` is empty or ``None``.

    :param resource: list of dict-like retrieval hits
    """
    if not resource:
        return

    for item in resource:
        # ``get`` already yields None for a missing key, so the optional
        # fields need no ``if key in item`` guard. The previous guard on
        # 'hit_count' tested the constant string instead of the item.
        dataset_retriever_resource = DatasetRetrieverResource(
            message_id=self.message.id,
            position=item.get('position'),
            dataset_id=item.get('dataset_id'),
            dataset_name=item.get('dataset_name'),
            document_id=item.get('document_id'),
            document_name=item.get('document_name'),
            data_source_type=item.get('data_source_type'),
            segment_id=item.get('segment_id'),
            score=item.get('score'),
            hit_count=item.get('hit_count'),
            word_count=item.get('word_count'),
            segment_position=item.get('segment_position'),
            index_node_hash=item.get('index_node_hash'),
            content=item.get('content'),
            retriever_from=item.get('retriever_from'),
            created_by=self.user.id
        )
        db.session.add(dataset_retriever_resource)
        db.session.flush()

    self.retriever_resource = resource
def message_end(self):
self._pub_handler.pub_message_end(self.retriever_resource)
def end(self):
self._pub_handler.pub_message_end(self.retriever_resource)
self._pub_handler.pub_end()
@@ -350,6 +386,23 @@ class PubHandler:
self.pub_end()
raise ConversationTaskStoppedException()
def pub_message_end(self, retriever_resource: List):
    """Publish the ``message_end`` event on this handler's channel.

    The event carries the task, message and conversation identifiers plus the
    conversation mode; ``retriever_resources`` is attached only when
    ``retriever_resource`` is truthy.

    :param retriever_resource: retrieval hits to attach to the event, if any
    :raises ConversationTaskStoppedException: when the task is reported as
        stopped after publishing (``pub_end`` is called first)
    """
    event_data = {
        'task_id': self._task_id,
        'message_id': self._message.id,
        'mode': self._conversation.mode,
        'conversation_id': self._conversation.id
    }
    if retriever_resource:
        event_data['retriever_resources'] = retriever_resource

    event = {'event': 'message_end', 'data': event_data}
    redis_client.publish(self._channel, json.dumps(event))

    if not self._is_stopped():
        return

    self.pub_end()
    raise ConversationTaskStoppedException()
def pub_end(self):
content = {

View File

@@ -1,3 +1,4 @@
import json
import logging
from langchain.schema import OutputParserException
@@ -22,18 +23,25 @@ class LLMGenerator:
if len(query) > 2000:
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
prompt = prompt.format(query=query)
query = query.replace("\n", " ")
prompt += query + "\n"
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=50
temperature=1,
max_tokens=100
)
)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
result_dict = json.loads(answer)
answer = result_dict['Your Output']
return answer.strip()
@classmethod

View File

@@ -0,0 +1,34 @@
import logging
import openai
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
from models.provider import ProviderType
def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
    """Check ``text`` against the hosted OpenAI moderation endpoint.

    Moderation is only performed when hosted moderation is enabled, a hosted
    OpenAI provider is configured, the given provider is of the SYSTEM type,
    and its name is listed in ``hosted_config.moderation.providers``. In every
    other case the text is accepted without a remote call.

    :param model_provider: provider whose type and name decide whether to moderate
    :param text: text to check
    :return: ``False`` if any moderation result is flagged, otherwise ``True``
    :raises LLMBadRequestError: when the moderation request raises any exception
    """
    if not (hosted_config.moderation.enabled is True and hosted_model_providers.openai):
        return True

    is_system_provider = model_provider.provider.provider_type == ProviderType.SYSTEM.value
    if not (is_system_provider and model_provider.provider_name in hosted_config.moderation.providers):
        return True

    # Split the text into 2000-character pieces, then send at most 32 pieces per request.
    chars_per_piece = 2000
    pieces_per_request = 32
    pieces = [text[start:start + chars_per_piece] for start in range(0, len(text), chars_per_piece)]
    batches = [pieces[start:start + pieces_per_request]
               for start in range(0, len(pieces), pieces_per_request)]

    for batch in batches:
        try:
            moderation_result = openai.Moderation.create(input=batch,
                                                         api_key=hosted_model_providers.openai.api_key)
        except Exception as ex:
            logging.exception(ex)
            raise LLMBadRequestError('Rate limit exceeded, please try again later.')

        if any(result['flagged'] is True for result in moderation_result.results):
            return False

    return True

View File

@@ -16,6 +16,10 @@ class BaseIndex(ABC):
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@@ -28,6 +32,10 @@ class BaseIndex(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError

View File

@@ -25,7 +25,33 @@ class KeywordTableIndex(BaseIndex):
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
self._save_dataset_keyword_table(keyword_table)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
@@ -52,7 +78,7 @@ class KeywordTableIndex(BaseIndex):
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
self._save_dataset_keyword_table(keyword_table)
@@ -74,7 +100,7 @@ class KeywordTableIndex(BaseIndex):
DocumentSegment.document_id == document_id
).all()
ids = [segment.id for segment in segments]
ids = [segment.index_node_id for segment in segments]
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
@@ -120,6 +146,12 @@ class KeywordTableIndex(BaseIndex):
db.session.delete(dataset_keyword_table)
db.session.commit()
def delete_by_group_id(self, group_id: str) -> None:
# Delete the keyword table row attached to self.dataset, if there is one.
# NOTE(review): `group_id` is not read here; the row is resolved through self.dataset only.
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
@@ -199,15 +231,18 @@ class KeywordTableIndex(BaseIndex):
return sorted_chunk_indices[: k]
def _update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id == node_id
).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()
def create_segment_keywords(self, node_id: str, keywords: List[str]):
keyword_table = self._get_dataset_keyword_table()
self._update_segment_keywords(node_id, keywords)
self._update_segment_keywords(self.dataset.id, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)

View File

@@ -10,7 +10,7 @@ from weaviate import UnexpectedStatusCodeException
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
from models.dataset import Document as DatasetDocument
@@ -110,6 +110,12 @@ class BaseVectorIndex(BaseIndex):
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_group_id(self, group_id: str) -> None:
# Base implementation: fetch the vector store and call its delete().
# NOTE(review): `group_id` is not used to filter anything here — confirm that deleting through
# vector_store.delete() is the intended behavior for stores that do not override this method.
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -243,3 +249,53 @@ class BaseVectorIndex(BaseIndex):
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
# Re-create this dataset's index entries inside the collection named by
# `dataset_collection_binding.collection_name`, using the segments stored in the database.
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
# Only documents that finished indexing and are enabled and not archived are considered.
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
# One query per document: its completed, enabled segments.
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
# Each segment becomes a Document whose metadata carries the index node id/hash
# and the owning document and dataset ids.
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
# Nothing is created when no segment qualified.
if documents:
try:
self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
except Exception as e:
# The handler re-raises the same exception, so failures propagate to the caller unchanged.
raise e
# NOTE(review): the message says "recreate" although this method is restore_dataset_in_one — confirm wording.
logging.info(f"Dataset {dataset.id} recreate successfully.")
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
# Delete this index via self.delete(), then point the dataset at the given collection binding
# and commit that change.
logging.info(f"delete original collection: {dataset.id}")
self.delete()
dataset.collection_binding_id = dataset_collection_binding.id
db.session.add(dataset)
db.session.commit()
# NOTE(review): the message says "recreate" although this method deletes and rebinds — confirm wording.
logging.info(f"Dataset {dataset.id} recreate successfully.")

View File

@@ -69,6 +69,19 @@ class MilvusVectorIndex(BaseVectorIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
# Build a vector store for `texts` under the explicit `collection_name`, keep it on
# self._vector_store, and return self.
# NOTE(review): this method belongs to MilvusVectorIndex but constructs a WeaviateVectorStore with
# Weaviate-style arguments (client, index_name, by_text). It looks copied from the Weaviate index;
# confirm whether the Milvus store class was intended.
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=collection_name,
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:

View File

@@ -28,6 +28,7 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
@@ -84,6 +85,7 @@ class Qdrant(VectorStore):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR_NAME = None
def __init__(
@@ -93,9 +95,12 @@ class Qdrant(VectorStore):
embeddings: Optional[Embeddings] = None,
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
distance_strategy: str = "COSINE",
vector_name: Optional[str] = VECTOR_NAME,
embedding_function: Optional[Callable] = None, # deprecated
is_new_collection: bool = False
):
"""Initialize with necessary components."""
try:
@@ -129,7 +134,10 @@ class Qdrant(VectorStore):
self.collection_name = collection_name
self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
self.group_payload_key = group_payload_key or self.GROUP_KEY
self.vector_name = vector_name or self.VECTOR_NAME
self.group_id = group_id
self.is_new_collection= is_new_collection
if embedding_function is not None:
warnings.warn(
@@ -170,6 +178,8 @@ class Qdrant(VectorStore):
batch_size:
How many vectors upload per-request.
Default: 64
group_id:
collection group
Returns:
List of ids from adding the texts into the vectorstore.
@@ -182,7 +192,11 @@ class Qdrant(VectorStore):
collection_name=self.collection_name, points=points, **kwargs
)
added_ids.extend(batch_ids)
# if is new collection, create payload index on group_id
if self.is_new_collection:
self.client.create_payload_index(self.collection_name, self.group_payload_key,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
return added_ids
@sync_call_fallback
@@ -970,6 +984,8 @@ class Qdrant(VectorStore):
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
@@ -1034,6 +1050,11 @@ class Qdrant(VectorStore):
metadata_payload_key:
A payload key used to store the metadata of the document.
Default: "metadata"
group_payload_key:
A payload key used to store the group id of the document.
Default: "group_id"
group_id:
collection group id
vector_name:
Name of the vector to be used internally in Qdrant.
Default: None
@@ -1107,6 +1128,8 @@ class Qdrant(VectorStore):
distance_func,
content_payload_key,
metadata_payload_key,
group_payload_key,
group_id,
vector_name,
shard_number,
replication_factor,
@@ -1321,6 +1344,8 @@ class Qdrant(VectorStore):
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
@@ -1350,6 +1375,7 @@ class Qdrant(VectorStore):
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper()
is_new_collection = False
client = qdrant_client.QdrantClient(
location=location,
url=url,
@@ -1454,6 +1480,7 @@ class Qdrant(VectorStore):
init_from=init_from,
timeout=timeout, # type: ignore[arg-type]
)
is_new_collection = True
qdrant = cls(
client=client,
collection_name=collection_name,
@@ -1462,6 +1489,9 @@ class Qdrant(VectorStore):
metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func,
vector_name=vector_name,
group_id=group_id,
group_payload_key=group_payload_key,
is_new_collection=is_new_collection
)
return qdrant
@@ -1516,6 +1546,8 @@ class Qdrant(VectorStore):
metadatas: Optional[List[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
) -> List[dict]:
payloads = []
for i, text in enumerate(texts):
@@ -1529,6 +1561,7 @@ class Qdrant(VectorStore):
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
@@ -1578,7 +1611,7 @@ class Qdrant(VectorStore):
else:
out.append(
rest.FieldCondition(
key=f"{self.metadata_payload_key}.{key}",
key=key,
match=rest.MatchValue(value=value),
)
)
@@ -1654,6 +1687,7 @@ class Qdrant(VectorStore):
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
@@ -1684,6 +1718,8 @@ class Qdrant(VectorStore):
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
self.group_id,
self.group_payload_key
),
)
]

View File

@@ -6,18 +6,20 @@ from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
@@ -43,16 +45,21 @@ class QdrantVectorIndex(BaseVectorIndex):
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
return dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
else:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
return class_prefix
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
@@ -68,6 +75,27 @@ class QdrantVectorIndex(BaseVectorIndex):
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
@@ -78,8 +106,6 @@ class QdrantVectorIndex(BaseVectorIndex):
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
@@ -88,16 +114,15 @@ class QdrantVectorIndex(BaseVectorIndex):
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='page_content'
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -113,6 +138,38 @@ class QdrantVectorIndex(BaseVectorIndex):
],
))
def delete_by_ids(self, ids: list[str]) -> None:
    """Delete, one id at a time, the entries whose ``metadata.doc_id`` equals the id.

    :param ids: index node ids to remove from the vector store
    """
    store = self._get_vector_store()
    store = cast(self._get_vector_store_class(), store)

    from qdrant_client.http import models

    for node_id in ids:
        doc_id_condition = models.FieldCondition(
            key="metadata.doc_id",
            match=models.MatchValue(value=node_id),
        )
        store.del_texts(models.Filter(must=[doc_id_condition]))
def delete_by_group_id(self, group_id: str) -> None:
    """Delete every entry whose ``group_id`` payload field equals ``group_id``.

    :param group_id: value matched against the ``group_id`` payload key
    """
    store = self._get_vector_store()
    store = cast(self._get_vector_store_class(), store)

    from qdrant_client.http import models

    group_condition = models.FieldCondition(
        key="group_id",
        match=models.MatchValue(value=group_id),
    )
    store.del_texts(models.Filter(must=[group_condition]))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']

View File

@@ -91,6 +91,20 @@ class WeaviateVectorIndex(BaseVectorIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
# Build a WeaviateVectorStore from `texts`, keep it on self._vector_store, and return self.
# NOTE(review): `collection_name` is not used — the index name comes from
# self.get_index_name(self.dataset) instead. Confirm whether the explicit name should be honored.
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:

View File

@@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
@@ -180,7 +181,7 @@ class ModelFactory:
def get_moderation_model(cls,
tenant_id: str,
model_provider_name: str,
model_name: str) -> Optional[BaseProviderModel]:
model_name: str) -> Optional[BaseModeration]:
"""
get moderation model.

View File

@@ -45,6 +45,9 @@ class ModelProviderFactory:
elif provider_name == 'wenxin':
from core.model_providers.providers.wenxin_provider import WenxinProvider
return WenxinProvider
elif provider_name == 'zhipuai':
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
return ZhipuAIProvider
elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider

View File

@@ -0,0 +1,24 @@
from replicate.exceptions import ModelError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
from core.model_providers.models.embedding.base import BaseEmbedding
class HuggingfaceEmbedding(BaseEmbedding):
    """Embedding model backed by ``HuggingfaceHubEmbeddings``."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        """Create the embeddings client from the provider's credentials for ``name``.

        :param model_provider: provider that supplies the model credentials
        :param name: model name, also passed to the client as ``model``
        """
        model_credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        embeddings_client = HuggingfaceHubEmbeddings(model=name, **model_credentials)

        super().__init__(model_provider, embeddings_client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Wrap any exception in an ``LLMBadRequestError`` carrying its message."""
        return LLMBadRequestError(f"Huggingface embedding: {ex!s}")

View File

@@ -0,0 +1,22 @@
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings
class ZhipuAIEmbedding(BaseEmbedding):
    """Embedding model backed by ``ZhipuAIEmbeddings``."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        """Create the embeddings client from the provider's credentials for ``name``.

        :param model_provider: provider that supplies the model credentials
        :param name: model name, also passed to the client as ``model``
        """
        model_credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        embeddings_client = ZhipuAIEmbeddings(model=name, **model_credentials)

        super().__init__(model_provider, embeddings_client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Wrap any exception in an ``LLMBadRequestError`` carrying its message."""
        return LLMBadRequestError(f"ZhipuAI embedding: {ex!s}")

View File

@@ -8,6 +8,7 @@ class LLMRunResult(BaseModel):
content: str
prompt_tokens: int
completion_tokens: int
source: list = None
class MessageType(enum.Enum):

View File

@@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel):
max: Optional[T] = None
default: Optional[T] = None
alias: Optional[str] = None
precision: Optional[int] = None
class ModelKwargsRules(BaseModel):

View File

@@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
@@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel):
:param callbacks:
:return:
"""
moderation_result = moderation.check_moderation(
self.model_provider,
"\n".join([message.content for message in messages])
)
if not moderation_result:
kwargs['fake_response'] = "I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest."
if self.deduct_quota:
self.model_provider.check_quota_over_limit()
@@ -342,7 +352,7 @@ class BaseLLM(BaseProviderModel):
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''
prompt += pre_prompt_content
query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'

View File

@@ -1,6 +1,5 @@
from typing import List, Optional, Any
from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
@@ -9,6 +8,7 @@ from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM
class HuggingfaceHubModel(BaseLLM):
@@ -31,7 +31,7 @@ class HuggingfaceHubModel(BaseLLM):
streaming=streaming
)
else:
client = HuggingFaceHub(
client = HuggingFaceHubLLM(
repo_id=self.name,
task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs,
@@ -88,4 +88,6 @@ class HuggingfaceHubModel(BaseLLM):
if 'baichuan' in self.name.lower():
return False
return True
return True
else:
return False

View File

@@ -17,6 +17,7 @@ from core.model_providers.models.entity.model_params import ModelMode, ModelKwar
from models.provider import ProviderType, ProviderQuotaType
COMPLETION_MODELS = [
'gpt-3.5-turbo-instruct', # 4,096 tokens
'text-davinci-003', # 4,097 tokens
]
@@ -31,6 +32,7 @@ MODEL_MAX_TOKENS = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}

View File

@@ -0,0 +1,61 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
class ZhipuAIModel(BaseLLM):
    """Chat-mode LLM wrapper around ``ZhipuAIChatLLM``."""

    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        """Build the chat client from the credentials and the converted model kwargs."""
        client_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ZhipuAIChatLLM(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **client_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run a prediction for the prompt messages with optional stop words.

        :param messages: prompt messages to send
        :param stop: stop words forwarded to the client
        :param callbacks: callbacks forwarded to the client
        :return: the client's generation result
        """
        chat_messages = self._get_prompt_from_messages(messages)
        return self._client.generate([chat_messages], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Count the tokens of the prompt messages, never returning less than zero.

        :param messages: prompt messages to count
        :return: token count reported by the client, floored at 0
        """
        chat_messages = self._get_prompt_from_messages(messages)
        token_count = self._client.get_num_tokens_from_messages(chat_messages)
        return max(token_count, 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Copy each converted model kwarg onto the client when it has that attribute."""
        converted_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for attr_name, attr_value in converted_kwargs.items():
            if hasattr(self.client, attr_name):
                setattr(self.client, attr_name, attr_value)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Wrap any exception in an ``LLMBadRequestError`` carrying its message."""
        return LLMBadRequestError(f"ZhipuAI: {ex!s}")

    @property
    def support_streaming(self):
        """This model always reports streaming support."""
        return True

View File

@@ -0,0 +1,29 @@
from abc import abstractmethod
from typing import Any
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
class BaseModeration(BaseProviderModel):
    """
    Abstract base for moderation models.

    Subclasses implement ``_run`` and ``handle_exceptions``; callers use ``run``,
    which routes every failure through ``handle_exceptions``.
    """

    name: str
    type: ModelType = ModelType.MODERATION

    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
        """
        :param model_provider: provider that owns this model
        :param client: underlying client object handed to the base class
        :param name: model name stored on the instance
        """
        super().__init__(model_provider, client)
        self.name = name

    def run(self, text: str) -> bool:
        """
        moderate the given text.

        :param text: text to check
        :return: the boolean verdict produced by ``_run``
        :raises Exception: whatever ``handle_exceptions`` returns for a failure
        """
        try:
            verdict = self._run(text)
        except Exception as ex:
            # Translate provider-specific errors into the subclass's chosen exception.
            raise self.handle_exceptions(ex)
        return verdict

    @abstractmethod
    def _run(self, text: str) -> bool:
        """Provider-specific moderation call; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """Map a raised exception to the exception ``run`` should raise."""
        raise NotImplementedError

View File

@@ -4,29 +4,39 @@ import openai
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.providers.base import BaseModelProvider
DEFAULT_AUDIO_MODEL = 'whisper-1'
DEFAULT_MODEL = 'whisper-1'
class OpenAIModeration(BaseProviderModel):
type: ModelType = ModelType.MODERATION
class OpenAIModeration(BaseModeration):
def __init__(self, model_provider: BaseModelProvider, name: str):
super().__init__(model_provider, openai.Moderation)
super().__init__(model_provider, openai.Moderation, name)
def run(self, text):
def _run(self, text: str) -> bool:
credentials = self.model_provider.get_model_credentials(
model_name=DEFAULT_AUDIO_MODEL,
model_name=self.name,
model_type=self.type
)
try:
return self._client.create(input=text, api_key=credentials['openai_api_key'])
except Exception as ex:
raise self.handle_exceptions(ex)
# 2000 text per chunk
length = 2000
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
for text_chunk in chunks:
moderation_result = self._client.create(input=text_chunk,
api_key=credentials['openai_api_key'])
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):

View File

@@ -69,11 +69,11 @@ class AnthropicProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
)
@classmethod

View File

@@ -164,14 +164,14 @@ class AzureOpenAIProvider(BaseModelProvider):
model_credentials = self.get_model_credentials(model_name, model_type)
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
model_credentials['base_model_name'],
4097
), default=16),
), default=16, precision=0),
)
@classmethod

View File

@@ -64,11 +64,11 @@ class ChatGLMProvider(BaseModelProvider):
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
)
@classmethod

View File

@@ -45,6 +45,18 @@ class HostedModelProviders(BaseModel):
hosted_model_providers = HostedModelProviders()
class HostedModerationConfig(BaseModel):
    """Hosted moderation settings: an on/off switch and a list of provider names."""

    # Disabled until init_app finds the moderation settings in the app config.
    enabled: bool = False
    # Provider names; empty by default.
    providers: list[str] = []


class HostedConfig(BaseModel):
    """Container for hosted settings; holds the moderation config."""

    moderation = HostedModerationConfig()


# Module-level instance; starts with the default (disabled) moderation config.
hosted_config = HostedConfig()
def init_app(app: Flask):
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
@@ -78,3 +90,9 @@ def init_app(app: Flask):
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
)
if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
hosted_config.moderation = HostedModerationConfig(
enabled=app.config.get("HOSTED_MODERATION_ENABLED"),
providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',')
)

View File

@@ -10,6 +10,8 @@ from core.model_providers.providers.base import BaseModelProvider, CredentialsVa
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
from models.provider import ProviderType
@@ -33,6 +35,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = HuggingfaceHubModel
elif model_type == ModelType.EMBEDDINGS:
model_class = HuggingfaceEmbedding
else:
raise NotImplementedError
@@ -47,11 +51,11 @@ class HuggingfaceHubProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0),
)
@classmethod
@@ -63,7 +67,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
:param model_type:
:param credentials:
"""
if model_type != ModelType.TEXT_GENERATION:
if model_type not in [ModelType.TEXT_GENERATION, ModelType.EMBEDDINGS]:
raise NotImplementedError
if 'huggingfacehub_api_type' not in credentials \
@@ -88,18 +92,15 @@ class HuggingfaceHubProvider(BaseModelProvider):
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Task Type must be provided.')
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization", 'feature-extraction'):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
'text-generation, summarization, feature-extraction.')
try:
llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)
llm("ping")
if credentials['task_type'] == 'feature-extraction':
cls.check_embedding_valid(credentials, model_name)
else:
cls.check_llm_valid(credentials)
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
else:
@@ -111,13 +112,33 @@ class HuggingfaceHubProvider(BaseModelProvider):
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
VALID_TASKS = ("text2text-generation", "text-generation", "summarization", "feature-extraction")
if model_info.pipeline_tag not in VALID_TASKS:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {VALID_TASKS}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
@classmethod
def check_llm_valid(cls, credentials: dict):
llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)
llm("ping")
@classmethod
def check_embedding_valid(cls, credentials: dict, model_name: str):
embedding_model = HuggingfaceHubEmbeddings(
model=model_name,
**credentials
)
embedding_model.embed_query("ping")
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:

View File

@@ -52,9 +52,9 @@ class LocalAIProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=0.7),
top_p=KwargRule[float](min=0, max=1, default=1),
max_tokens=KwargRule[int](min=10, max=4097, default=16),
temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0),
)
@classmethod

View File

@@ -74,11 +74,11 @@ class MinimaxProvider(BaseModelProvider):
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.9),
top_p=KwargRule[float](min=0, max=1, default=0.95),
temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0),
)
@classmethod

View File

@@ -40,6 +40,10 @@ class OpenAIProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
@@ -128,16 +132,17 @@ class OpenAIProvider(BaseModelProvider):
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0),
)
@classmethod

View File

@@ -45,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0),
)
@classmethod

View File

@@ -72,6 +72,7 @@ class ReplicateProvider(BaseModelProvider):
min=float(value.get('minimum')) if value.get('minimum') is not None else None,
max=float(value.get('maximum')) if value.get('maximum') is not None else None,
default=float(value.get('default')) if value.get('default') is not None else None,
precision = 2
)
if key == 'temperature':
model_kwargs_rules.temperature = kwarg_rule
@@ -84,6 +85,7 @@ class ReplicateProvider(BaseModelProvider):
min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
default=int(value.get('default')) if value.get('default') is not None else 500,
precision = 0
)
return model_kwargs_rules

View File

@@ -62,11 +62,11 @@ class SparkProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=0.5),
temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2),
top_p=KwargRule[float](enabled=False),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4096, default=2048),
max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0),
)
@classmethod
@@ -83,14 +83,15 @@ class SparkProvider(BaseModelProvider):
if 'api_secret' not in credentials:
raise CredentialsValidateFailedError('Spark api_secret must be provided.')
try:
credential_kwargs = {
'app_id': credentials['app_id'],
'api_key': credentials['api_key'],
'api_secret': credentials['api_secret'],
}
credential_kwargs = {
'app_id': credentials['app_id'],
'api_key': credentials['api_key'],
'api_secret': credentials['api_secret'],
}
try:
chat_llm = ChatSpark(
model_name='spark-v2',
max_tokens=10,
temperature=0.01,
**credential_kwargs
@@ -104,7 +105,27 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
# try spark v1.5 if v2.1 failed
try:
chat_llm = ChatSpark(
model_name='spark',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex

View File

@@ -64,10 +64,10 @@ class TongyiProvider(BaseModelProvider):
return ModelKwargsRules(
temperature=KwargRule[float](enabled=False),
top_p=KwargRule[float](min=0, max=1, default=0.8),
top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
)
@classmethod

View File

@@ -63,8 +63,8 @@ class WenxinProvider(BaseModelProvider):
"""
if model_name in ['ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95),
top_p=KwargRule[float](min=0.01, max=1, default=0.8),
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),

View File

@@ -2,6 +2,7 @@ import json
from typing import Type
import requests
from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -52,27 +53,27 @@ class XinferenceProvider(BaseModelProvider):
credentials = self.get_model_credentials(model_name, model_type)
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)
elif credentials['model_format'] == "ggmlv3":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)
else:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)
@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
'model_uid': credentials['model_uid'],
}
llm = XinferenceLLM(
**credential_kwargs
)
if model_type == ModelType.TEXT_GENERATION:
llm = XinferenceLLM(
**credential_kwargs
)
llm("ping")
llm("ping")
elif model_type == ModelType.EMBEDDINGS:
embedding = XinferenceEmbeddings(
**credential_kwargs
)
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider):
:param credentials:
:return:
"""
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
if model_type == ModelType.TEXT_GENERATION:
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])

View File

@@ -0,0 +1,176 @@
import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
from models.provider import ProviderType, ProviderQuotaType
class ZhipuAIProvider(BaseModelProvider):
    """Model provider for ZhipuAI: fixed model list, single ``api_key`` credential."""

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'zhipuai'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """
        Returns the fixed models offered for a model type.

        :param model_type:
        :return: list of ``{'id': ..., 'name': ...}`` dicts, empty for unsupported types
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_ids = ['chatglm_pro', 'chatglm_std', 'chatglm_lite', 'chatglm_lite_32k']
        elif model_type == ModelType.EMBEDDINGS:
            model_ids = ['text_embedding']
        else:
            model_ids = []

        # Every model uses its id as its display name.
        return [{'id': model_id, 'name': model_id} for model_id in model_ids]

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for any type other than text generation or embeddings
        """
        if model_type == ModelType.TEXT_GENERATION:
            return ZhipuAIModel
        if model_type == ModelType.EMBEDDINGS:
            return ZhipuAIEmbedding
        raise NotImplementedError

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        The same rules are returned regardless of model name or type: only
        temperature and top_p are adjustable, everything else is disabled.

        :param model_name:
        :param model_type:
        :return:
        """
        temperature_rule = KwargRule[float](min=0.01, max=1, default=0.95, precision=2)
        top_p_rule = KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1)
        presence_penalty_rule = KwargRule[float](enabled=False)
        frequency_penalty_rule = KwargRule[float](enabled=False)
        max_tokens_rule = KwargRule[int](enabled=False)

        return ModelKwargsRules(
            temperature=temperature_rule,
            top_p=top_p_rule,
            presence_penalty=presence_penalty_rule,
            frequency_penalty=frequency_penalty_rule,
            max_tokens=max_tokens_rule,
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.

        Sends a single "ping" message through ``ZhipuAIChatLLM`` with the given key.

        :param credentials: must contain ``api_key``
        :raises CredentialsValidateFailedError: if the key is missing or the call fails
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.')

        try:
            llm = ZhipuAIChatLLM(
                temperature=0.01,
                api_key=credentials['api_key']
            )
            llm([HumanMessage(content='ping')])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """
        Encrypts the api_key for storage.

        The passed dict is modified in place and also returned.

        :param tenant_id:
        :param credentials:
        :return:
        """
        encrypted_key = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        credentials['api_key'] = encrypted_key
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        Returns the provider credentials with the api_key decrypted.

        :param obfuscated: when True, the decrypted key is passed through
            ``encrypter.obfuscated_token`` before being returned
        :return: credentials dict, or ``{}`` when the provider is neither custom
            nor a system provider with the free quota type
        """
        provider = self.provider
        uses_stored_key = (
            provider.provider_type == ProviderType.CUSTOM.value
            or (provider.provider_type == ProviderType.SYSTEM.value
                and provider.quota_type == ProviderQuotaType.FREE.value)
        )
        if not uses_stored_key:
            return {}

        try:
            credentials = json.loads(provider.encrypted_config)
        except JSONDecodeError:
            # Unparseable config is treated as "no key configured".
            credentials = {
                'api_key': None,
            }

        stored_key = credentials['api_key']
        if stored_key:
            plain_key = encrypter.decrypt_token(provider.tenant_id, stored_key)
            if obfuscated:
                credentials['api_key'] = encrypter.obfuscated_token(plain_key)
            else:
                credentials['api_key'] = plain_key

        return credentials

    def should_deduct_quota(self):
        """Quota deduction is always requested for this provider."""
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        No per-model validation is performed; this always returns None.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        There are no per-model credentials, so an empty dict is returned.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        Every model uses the provider-level credentials.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)

View File

@@ -6,6 +6,7 @@
"tongyi",
"spark",
"wenxin",
"zhipuai",
"chatglm",
"replicate",
"huggingface_hub",

View File

@@ -30,6 +30,12 @@
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-instruct": {
"prompt": "0.0015",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-16k": {
"prompt": "0.003",
"completion": "0.004",

View File

@@ -0,0 +1,44 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"free"
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed",
"price_config": {
"chatglm_pro": {
"prompt": "0.01",
"completion": "0.01",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_std": {
"prompt": "0.005",
"completion": "0.005",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_lite": {
"prompt": "0.002",
"completion": "0.002",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_lite_32k": {
"prompt": "0.0004",
"completion": "0.0004",
"unit": "0.001",
"currency": "RMB"
},
"text_embedding": {
"completion": "0",
"unit": "0.001",
"currency": "RMB"
}
}
}

View File

@@ -1,6 +1,7 @@
import math
from typing import Optional
from flask import current_app
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
@@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
@@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
from models.provider import ProviderType
class OrchestratorRuleParser:
@@ -36,8 +38,8 @@ class OrchestratorRuleParser:
self.app_model_config = app_model_config
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
-> Optional[AgentExecutor]:
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]:
if not self.app_model_config.agent_mode_dict:
return None
@@ -63,7 +65,7 @@ class OrchestratorRuleParser:
# add agent callback to record agent thoughts
agent_callback = AgentLoopGatherCallbackHandler(
model_instant=agent_model_instance,
model_instance=agent_model_instance,
conversation_message_task=conversation_message_task
)
@@ -74,7 +76,7 @@ class OrchestratorRuleParser:
# only OpenAI chat model (include Azure) support function call, use ReACT instead
if agent_model_instance.model_mode != ModelMode.CHAT \
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
planning_strategy = PlanningStrategy.REACT
elif planning_strategy == PlanningStrategy.ROUTER:
@@ -99,7 +101,9 @@ class OrchestratorRuleParser:
tool_configs=tool_configs,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
return_resource=return_resource,
retriever_from=retriever_from
)
if len(tools) == 0:
@@ -121,23 +125,45 @@ class OrchestratorRuleParser:
return chain
def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
-> Optional[SensitiveWordAvoidanceChain]:
"""
Convert app sensitive word avoidance config to chain
:param model_instance: model instance
:param callbacks: callbacks for the chain
:param kwargs:
:return:
"""
if not self.app_model_config.sensitive_word_avoidance_dict:
return None
sensitive_word_avoidance_rule = None
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
sensitive_words = sensitive_word_avoidance_config.get("words", "")
if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
if self.app_model_config.sensitive_word_avoidance_dict:
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_config.get("enabled", False):
if sensitive_word_avoidance_config.get('type') == 'moderation':
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
type=SensitiveWordAvoidanceRule.Type.MODERATION,
canned_response=sensitive_word_avoidance_config.get("canned_response")
if sensitive_word_avoidance_config.get("canned_response")
else 'Your content violates our usage policy. Please revise and try again.',
)
else:
sensitive_words = sensitive_word_avoidance_config.get("words", "")
if sensitive_words:
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
canned_response=sensitive_word_avoidance_config.get("canned_response")
if sensitive_word_avoidance_config.get("canned_response")
else 'Your content violates our usage policy. Please revise and try again.',
extra_params={
'sensitive_words': sensitive_words.split(','),
}
)
if sensitive_word_avoidance_rule:
return SensitiveWordAvoidanceChain(
sensitive_words=sensitive_words.split(","),
canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
model_instance=model_instance,
sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
output_key="sensitive_word_avoidance_output",
callbacks=callbacks,
**kwargs
@@ -145,8 +171,10 @@ class OrchestratorRuleParser:
return None
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
conversation_message_task: ConversationMessageTask,
rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
retriever_from: str = 'dev') -> list[BaseTool]:
"""
Convert app agent tool configs to tools
@@ -155,6 +183,8 @@ class OrchestratorRuleParser:
:param tool_configs: app agent tool configs
:param conversation_message_task:
:param callbacks:
:param return_resource:
:param retriever_from:
:return:
"""
tools = []
@@ -166,7 +196,7 @@ class OrchestratorRuleParser:
tool = None
if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
elif tool_type == "web_reader":
tool = self.to_web_reader_tool(agent_model_instance)
elif tool_type == "google_search":
@@ -183,13 +213,15 @@ class OrchestratorRuleParser:
return tools
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
rest_tokens: int) \
rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens:
:param tool_config:
:param conversation_message_task:
:param return_resource:
:param retriever_from:
:return:
"""
# get dataset from dataset id
@@ -208,7 +240,10 @@ class OrchestratorRuleParser:
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
k=k,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task,
return_resource=return_resource,
retriever_from=retriever_from
)
return tool

View File

@@ -8,6 +8,6 @@
"pre_prompt",
"histories_prompt"
],
"query_prompt": "用户:{{query}}",
"query_prompt": "\n\n用户:{{query}}",
"stops": ["用户:"]
}

View File

@@ -8,6 +8,6 @@
"pre_prompt",
"histories_prompt"
],
"query_prompt": "Human: {{query}}\n\nAssistant: ",
"query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ",
"stops": ["\nHuman:", "</histories>"]
}
}

View File

@@ -1,10 +1,65 @@
CONVERSATION_TITLE_PROMPT = (
"Human:{query}\n-----\n"
"Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
"If what the human said is conducted in English, you should only return an English title.\n"
"If what the human said is conducted in Chinese, you should only return a Chinese title.\n"
"title:"
)
# Written by YORKI MINAKO🤡
CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
Notice: the language type user use could be diverse, which can be English, Chinese, Español, Arabic, Japanese, French, and etc.
MAKE SURE your output is the SAME language as the user's input!
Your output is restricted only to: (Input language) Intention + Subject(short as possible)
Your output MUST be a valid JSON.
Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun.
example 1:
User Input: hi, yesterday i had some burgers.
{
"Language Type": "The user's input is pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "sharing yesterday's food"
}
example 2:
User Input: hello
{
"Language Type": "The user's input is written in pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "Greeting myself☺"
}
example 3:
User Input: why mmap file: oom
{
"Language Type": "The user's input is written in pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "Asking about the reason for mmap file: oom"
}
example 4:
User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么
{
"Language Type": "The user's input English-Chinese mixed",
"Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.",
"Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv"
}
example 5:
User Input: why小红的年龄is老than小明
{
"Language Type": "The user's input is English-Chinese mixed",
"Your Reasoning": "The English parts are subjective particles, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
"Your Output": "询问小红和小明的年龄"
}
example 6:
User Input: yo, 你今天咋样?
{
"Language Type": "The user's input is English-Chinese mixed",
"Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.",
"Your Output": "查询今日我的状态☺️"
}
User Input:
"""
CONVERSATION_SUMMARY_PROMPT = (
"Please generate a short summary of the following conversation.\n"
@@ -50,7 +105,7 @@ GENERATOR_QA_PROMPT = (
'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
'Step 4: Generate 20 questions and answers based on these key information and concepts.'
'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
"Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
"Answer according to the the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
)
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \

View File

@@ -0,0 +1,68 @@
from typing import Any, Dict, List, Optional, Union
import json
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from huggingface_hub import InferenceClient
# Recognised values for `HuggingfaceHubEmbeddings.huggingfacehub_api_type`.
# Only HOSTED_INFERENCE_API is compared explicitly when choosing the request
# target; any other value makes the request go to `huggingfacehub_endpoint_url`.
HOSTED_INFERENCE_API = 'hosted_inference_api'
INFERENCE_ENDPOINTS = 'inference_endpoints'
class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
    """Embedding model that sends texts to Hugging Face through ``InferenceClient``.

    When ``huggingfacehub_api_type`` equals ``HOSTED_INFERENCE_API`` the request
    targets ``model``; for any other value it targets
    ``huggingfacehub_endpoint_url``.
    """

    client: Any
    model: str
    task_type: Optional[str] = None
    huggingfacehub_api_type: Optional[str] = None
    huggingfacehub_api_token: Optional[str] = None
    huggingfacehub_endpoint_url: Optional[str] = None

    class Config:
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API token (argument or ``HUGGINGFACEHUB_API_TOKEN``) and build the client.

        :param values: field values collected by pydantic
        :return: the same mapping with ``huggingfacehub_api_token`` and ``client`` filled in
        """
        values['huggingfacehub_api_token'] = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )

        # The first positional parameter of InferenceClient is `model`, not the
        # token, so the token has to be passed by keyword to actually be used
        # for authentication.
        values['client'] = InferenceClient(token=values['huggingfacehub_api_token'])

        return values

    def embeddings(self, inputs: Union[str, List[str]]) -> Any:
        """Post ``inputs`` to the configured target and return the decoded JSON body.

        :param inputs: a single text or a list of texts
        :return: whatever JSON structure the service responds with
        """
        if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
            model = self.model
        else:
            model = self.huggingfacehub_endpoint_url

        output = self.client.post(
            json={
                "inputs": inputs,
                "options": {
                    # Ask the service not to block until the model is loaded.
                    "wait_for_model": False
                }
            }, model=model)

        return json.loads(output.decode())

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts.

        :param texts: the texts to embed
        :return: the response as-is when it is already a list, otherwise each
            item converted to a list of floats
        """
        output = self.embeddings(texts)

        if isinstance(output, list):
            return output

        return [list(map(float, e)) for e in output]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single text.

        :param text: the text to embed
        :return: the response as-is when it is already a list, otherwise its
            items converted to floats
        """
        output = self.embeddings(text)

        if isinstance(output, list):
            return output

        return list(map(float, output))

View File

@@ -0,0 +1,64 @@
"""Wrapper around ZhipuAI embedding models."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI
class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """Embeddings implementation that delegates to the ZhipuAI model API.

    1024 dimensions.
    """

    client: Any  #: :meta private:
    # Model name to use.
    model: str

    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
    api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key (argument or ``ZHIPUAI_API_KEY``) and create the API client."""
        api_key = get_from_dict_or_env(values, "api_key", "ZHIPUAI_API_KEY")
        values["api_key"] = api_key
        values['client'] = ZhipuModelAPI(api_key=api_key, base_url=values['base_url'])
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text with one API call per text.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # First collect every raw embedding, then convert, so that all API
        # calls happen before any float conversion is attempted.
        raw_vectors = [
            self.client.invoke(model=self.model, prompt=text)["data"].get('embedding')
            for text in texts
        ]
        return [[float(value) for value in vector] for vector in raw_vectors]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single text by reusing ``embed_documents``.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        vectors = self.embed_documents([text])
        return vectors[0]

View File

@@ -0,0 +1,62 @@
from typing import Dict, Optional, List, Any
from huggingface_hub import HfApi, InferenceApi
from langchain import HuggingFaceHub
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_hub import VALID_TASKS
from pydantic import root_validator
from langchain.utils import get_from_dict_or_env
class HuggingFaceHubLLM(HuggingFaceHub):
    """HuggingFaceHub models.

    To use, you should have the ``huggingface_hub`` python package installed, and the
    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Only supports `text-generation`, `text2text-generation` and `summarization` for now.

    Compared with the parent class, this subclass builds its ``InferenceApi``
    client with ``wait_for_model`` and ``use_gpu`` both disabled, and checks the
    model's Hub metadata before every call.

    Example:
        .. code-block:: python

            from langchain.llms import HuggingFaceHub
            hf = HuggingFaceHub(repo_id="gpt2", huggingfacehub_api_token="my-api-key")
    """

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        client = InferenceApi(
            repo_id=values["repo_id"],
            token=huggingfacehub_api_token,
            task=values.get("task"),
        )
        client.options = {"wait_for_model": False, "use_gpu": False}
        values["client"] = client
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Check that the model can serve inference, then delegate to the parent ``_call``.

        :param prompt: the prompt to complete
        :param stop: optional stop sequences, forwarded to the parent
        :param run_manager: optional callback manager, forwarded to the parent
        :return: the generated text
        :raises ValueError: if the model is not found, has inference turned off,
            or its pipeline tag is not one of ``VALID_TASKS``
        """
        hfapi = HfApi(token=self.huggingfacehub_api_token)
        model_info = hfapi.model_info(repo_id=self.repo_id)
        if not model_info:
            raise ValueError(f"Model {self.repo_id} not found.")

        # A model without card metadata may expose no `cardData` (or None);
        # treat that as "no inference flag set" instead of crashing on `in`.
        card_data = getattr(model_info, "cardData", None) or {}
        if 'inference' in card_data and not card_data['inference']:
            raise ValueError(f"Inference API has been turned off for this model {self.repo_id}.")

        if model_info.pipeline_tag not in VALID_TASKS:
            raise ValueError(f"Model {self.repo_id} is not a valid task, "
                             f"must be one of {VALID_TASKS}.")

        return super()._call(prompt, stop, run_manager, **kwargs)

View File

@@ -14,6 +14,9 @@ class EnhanceOpenAI(OpenAI):
max_retries: int = 1
"""Maximum number of retries to make when generating."""
def __new__(cls, **data: Any): # type: ignore
return super(EnhanceOpenAI, cls).__new__(cls)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""

View File

@@ -0,0 +1,315 @@
"""Wrapper around ZhipuAI APIs."""
from __future__ import annotations
import json
import logging
import posixpath
from typing import (
Any,
Dict,
List,
Optional, Iterator, Sequence,
)
import zhipuai
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import Extra, root_validator, BaseModel
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.utils import get_from_dict_or_env
from zhipuai.model_api.api import InvokeType
from zhipuai.utils import jwt_token
from zhipuai.utils.http_client import post, stream
from zhipuai.utils.sse_client import SSEClient
logger = logging.getLogger(__name__)
class ZhipuModelAPI(BaseModel):
    """Small HTTP client for the ZhipuAI model API (blocking and SSE invocation)."""

    base_url: str
    api_key: str
    api_timeout_seconds = 60

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def invoke(self, **kwargs):
        """Call the synchronous endpoint and return the response.

        ``model`` is taken out of ``kwargs`` to build the URL; the remaining
        keyword arguments are sent as the request payload.

        :raises ValueError: when the response reports ``success`` as falsy
        """
        endpoint = self._build_api_url(kwargs, InvokeType.SYNC)
        result = post(endpoint, self._generate_token(), kwargs, self.api_timeout_seconds)
        if result['success']:
            return result
        raise ValueError(
            f"Error Code: {result['code']}, Message: {result['msg']} "
        )

    def sse_invoke(self, **kwargs):
        """Call the SSE endpoint and wrap the stream in an ``SSEClient``."""
        endpoint = self._build_api_url(kwargs, InvokeType.SSE)
        raw_stream = stream(endpoint, self._generate_token(), kwargs, self.api_timeout_seconds)
        return SSEClient(raw_stream)

    def _build_api_url(self, kwargs, *path):
        """Join base URL, model name and ``path``; removes ``model`` from ``kwargs``."""
        if not kwargs:
            model = "-"
        elif "model" in kwargs:
            model = kwargs.pop("model")
        else:
            raise Exception("model param missed")

        return posixpath.join(self.base_url, model, *path)

    def _generate_token(self):
        """Create the request token from ``api_key``.

        :raises Exception: when no api key is configured
        :raises ValueError: when token generation fails for the configured key
        """
        if not self.api_key:
            raise Exception(
                "api_key not provided, you could provide it."
            )

        try:
            token = jwt_token.generate_token(self.api_key)
        except Exception:
            raise ValueError(
                "Your api_key is invalid, please check it."
            )
        return token
class ZhipuAIChatLLM(BaseChatModel):
    """Wrapper around ZhipuAI large language models.

    To use, pass ``api_key`` as a named parameter to the constructor or set the
    ``ZHIPUAI_API_KEY`` environment variable.

    Example:
        .. code-block:: python

            from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
            model = ZhipuAIChatLLM(model="<model_name>", api_key="my-api-key")
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "API_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any = None  #: :meta private:
    # Model name to use.
    model: str = "chatglm_lite"
    # A non-negative float that tunes the degree of randomness in generation.
    temperature: float = 0.95
    # Total probability mass of tokens to consider at each step.
    top_p: float = 0.7
    # Whether to stream the response or return it all at once.
    streaming: bool = False
    api_key: Optional[str] = None

    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key (argument or ``ZHIPUAI_API_KEY``) and create the API client."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "ZHIPUAI_API_KEY"
        )

        # A base URL containing 'test' forces the test model name.
        if 'test' in values['base_url']:
            values['model'] = 'chatglm_130b_test'

        values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the ZhipuAI API.

        A new dict is returned on every access, so callers may mutate it.
        """
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "zhipuai"

    @staticmethod
    def _fill_token_usage_defaults(token_usage: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Fill in ``prompt_tokens`` / ``completion_tokens`` when the API omitted them.

        ``None`` is passed through unchanged. The mapping is updated in place and
        also returned for convenience.
        """
        if token_usage is not None:
            if 'prompt_tokens' not in token_usage:
                token_usage['prompt_tokens'] = 0
            if 'completion_tokens' not in token_usage:
                token_usage['completion_tokens'] = token_usage['total_tokens']
        return token_usage

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        """Convert a LangChain message into the ``{"role", "content"}`` dict sent to the API.

        :raises ValueError: for message types not handled here
        """
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            # System messages are sent with the "user" role.
            # NOTE(review): presumably the API has no system role - confirm.
            message_dict = {"role": "user", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_dict

    def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
        """Convert an API ``{"role", "content"}`` dict back into a LangChain message."""
        role = _dict["role"]
        if role == "user":
            return HumanMessage(content=_dict["content"])
        elif role == "assistant":
            return AIMessage(content=_dict["content"])
        elif role == "system":
            return SystemMessage(content=_dict["content"])
        else:
            return ChatMessage(content=_dict["content"], role=role)

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        """Convert messages to API dicts, merging consecutive messages of the same role.

        Merged contents are joined with a newline.
        """
        dict_messages = []
        for m in messages:
            message = self._convert_message_to_dict(m)
            if dict_messages and dict_messages[-1]['role'] == message['role']:
                dict_messages[-1]['content'] += f"\n{message['content']}"
            else:
                dict_messages.append(message)

        return dict_messages

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result, via the stream when ``streaming`` is enabled.

        ``stop`` is accepted for interface compatibility; it is not added to the
        request built here.

        :raises ValueError: in streaming mode when the stream yields no content chunk
        """
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            llm_output: Optional[Dict] = None
            for chunk in self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                # The chunk that carries token usage is bookkeeping only; it is
                # recorded in llm_output and not merged into the generation.
                if chunk.generation_info is not None \
                        and 'token_usage' in chunk.generation_info:
                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
                    continue

                if generation is None:
                    generation = chunk
                else:
                    generation += chunk

            # Explicit check instead of `assert`, which is stripped under -O.
            if generation is None:
                raise ValueError("ZhipuAI stream finished without returning any content.")
            return ChatResult(generations=[generation], llm_output=llm_output)
        else:
            message_dicts = self._create_message_dicts(messages)
            request = self._default_params
            request["prompt"] = message_dicts
            request.update(kwargs)
            response = self.client.invoke(**request)
            return self._create_chat_result(response)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield generation chunks from the SSE endpoint.

        "add" events yield content chunks, "finish" yields a final chunk whose
        ``generation_info`` carries ``token_usage``.

        :raises ValueError: on "error" or "interrupted" events
        """
        message_dicts = self._create_message_dicts(messages)
        request = self._default_params
        request["prompt"] = message_dicts
        request.update(kwargs)

        for event in self.client.sse_invoke(incremental=True, **request).events():
            if event.event == "add":
                yield ChatGenerationChunk(message=AIMessageChunk(content=event.data))
                if run_manager:
                    run_manager.on_llm_new_token(event.data)
            elif event.event == "error" or event.event == "interrupted":
                raise ValueError(
                    f"{event.data}"
                )
            elif event.event == "finish":
                meta = json.loads(event.meta)
                token_usage = self._fill_token_usage_defaults(meta['usage'])

                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=event.data),
                    generation_info=dict({'token_usage': token_usage})
                )

    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
        """Build a ``ChatResult`` from a synchronous API response."""
        data = response["data"]
        generations = []
        for res in data["choices"]:
            message = self._convert_dict_to_message(res)
            gen = ChatGeneration(
                message=message
            )
            generations.append(gen)

        token_usage = self._fill_token_usage_defaults(data.get("usage"))

        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    # def get_token_ids(self, text: str) -> List[int]:
    #     """Return the ordered ids of the tokens in a text.
    #
    #     Args:
    #         text: The string input to tokenize.
    #
    #     Returns:
    #         A list of ids corresponding to the tokens in the text, in order they occur
    #         in the text.
    #     """
    #     from core.third_party.transformers.Token import ChatGLMTokenizer
    #
    #     tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b")
    #     return tokenizer.encode(text)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Useful for checking if an input will fit in a model's context window.

        Args:
            messages: The message inputs to tokenize.

        Returns:
            The sum of the number of tokens across the messages.
        """
        return sum(self.get_num_tokens(m.content) for m in messages)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        """Sum the token usage of several outputs into one ``llm_output`` dict.

        Outputs that are ``None`` or carry no token usage are skipped.
        """
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            # token_usage can be None when the API response had no usage data.
            token_usage = output.get("token_usage")
            if not token_usage:
                continue
            for k, v in token_usage.items():
                if k in overall_token_usage:
                    overall_token_usage[k] += v
                else:
                    overall_token_usage[k] = v

        return {"token_usage": overall_token_usage, "model_name": self.model}

View File

@@ -1,3 +1,4 @@
import json
from typing import Type
from flask import current_app
@@ -5,13 +6,14 @@ from langchain.tools import BaseTool
from pydantic import Field, BaseModel
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.conversation_message_task import ConversationMessageTask
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset, DocumentSegment, Document
class DatasetRetrieverToolInput(BaseModel):
@@ -27,6 +29,9 @@ class DatasetRetrieverTool(BaseTool):
tenant_id: str
dataset_id: str
k: int = 3
conversation_message_task: ConversationMessageTask
return_resource: str
retriever_from: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -86,16 +91,23 @@ class DatasetRetrieverTool(BaseTool):
if self.k > 0:
documents = vector_index.search(
query,
search_type='similarity',
search_type='similarity_score_threshold',
search_kwargs={
'k': self.k
'k': self.k,
'filter': {
'group_id': [dataset.id]
}
}
)
else:
documents = []
hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in documents]
segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
@@ -112,9 +124,43 @@ class DatasetRetrieverTool(BaseTool):
float('inf')))
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}')
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
else:
document_context_list.append(segment.content)
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
context = {}
document = Document.query.filter(Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
'position': resource_number,
'dataset_id': dataset.id,
'dataset_name': dataset.name,
'document_id': document.id,
'document_name': document.name,
'data_source_type': document.data_source_type,
'segment_id': segment.id,
'retriever_from': self.retriever_from
}
if dataset.indexing_technique != "economy":
source['score'] = document_score_list.get(segment.index_node_id)
if self.retriever_from == 'dev':
source['hit_count'] = segment.hit_count
source['word_count'] = segment.word_count
source['segment_position'] = segment.position
source['index_node_hash'] = segment.index_node_hash
if segment.answer:
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
else:
source['content'] = segment.content
context_list.append(source)
resource_number += 1
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))

View File

@@ -46,6 +46,11 @@ class QdrantVectorStore(Qdrant):
self.client.delete_collection(collection_name=self.collection_name)
def delete_group(self):
self._reload_if_needed()
self.client.delete_collection(collection_name=self.collection_name)
@classmethod
def _document_from_scored_point(
cls,

View File

@@ -26,7 +26,7 @@ def handle(sender, **kwargs):
conversation.name = name
except:
conversation.name = 'New Chat'
conversation.name = 'New conversation'
db.session.add(conversation)
db.session.commit()

View File

@@ -0,0 +1,54 @@
"""add_dataset_retriever_resource
Revision ID: 6dcb43972bdc
Revises: 4bcffcd64aa4
Create Date: 2023-09-06 16:51:27.385844
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
# `revision` is this migration's id; `down_revision` is the id of the
# migration it revises (see the "Revises" line in the module docstring).
revision = '6dcb43972bdc'
down_revision = '4bcffcd64aa4'
branch_labels = None
depends_on = None
def upgrade():
    """Create the ``dataset_retriever_resources`` table and its ``message_id`` index."""
    table_name = 'dataset_retriever_resources'

    columns = [
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('message_id', postgresql.UUID(), nullable=False),
        sa.Column('position', sa.Integer(), nullable=False),
        sa.Column('dataset_id', postgresql.UUID(), nullable=False),
        sa.Column('dataset_name', sa.Text(), nullable=False),
        sa.Column('document_id', postgresql.UUID(), nullable=False),
        sa.Column('document_name', sa.Text(), nullable=False),
        sa.Column('data_source_type', sa.Text(), nullable=False),
        sa.Column('segment_id', postgresql.UUID(), nullable=False),
        sa.Column('score', sa.Float(), nullable=True),
        sa.Column('content', sa.Text(), nullable=False),
        sa.Column('hit_count', sa.Integer(), nullable=True),
        sa.Column('word_count', sa.Integer(), nullable=True),
        sa.Column('segment_position', sa.Integer(), nullable=True),
        sa.Column('index_node_hash', sa.Text(), nullable=True),
        sa.Column('retriever_from', sa.Text(), nullable=False),
        sa.Column('created_by', postgresql.UUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    ]
    primary_key = sa.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey')

    op.create_table(table_name, *columns, primary_key)

    # Non-unique index on message_id.
    with op.batch_alter_table(table_name, schema=None) as batch_op:
        batch_op.create_index('dataset_retriever_resource_message_id_idx', ['message_id'], unique=False)
def downgrade():
    """Drop the ``message_id`` index and then the ``dataset_retriever_resources`` table."""
    table_name = 'dataset_retriever_resources'

    # The index is removed first, then the table itself.
    with op.batch_alter_table(table_name, schema=None) as batch_op:
        batch_op.drop_index('dataset_retriever_resource_message_id_idx')

    op.drop_table(table_name)

Some files were not shown because too many files have changed in this diff Show More