mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-07 03:45:27 +08:00
Compare commits
69 Commits
feat/optim
...
feat/huggi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b8db580833 | ||
|
|
9d3ba98d07 | ||
|
|
8b2573efab | ||
|
|
757ee4d39f | ||
|
|
435f804c6f | ||
|
|
ae3f1ac0a9 | ||
|
|
269a465fc4 | ||
|
|
60e0bbd713 | ||
|
|
827c97f0d3 | ||
|
|
c8bd76cd66 | ||
|
|
ec5f585df4 | ||
|
|
1de48f33ca | ||
|
|
6b41a9593e | ||
|
|
82267083e8 | ||
|
|
c385961d33 | ||
|
|
20bab6edec | ||
|
|
67bed54f32 | ||
|
|
562a571281 | ||
|
|
fc68c81791 | ||
|
|
5d9070bc60 | ||
|
|
b11fb0dfd1 | ||
|
|
d1c5c5f160 | ||
|
|
0b1d1440aa | ||
|
|
0c420d64b3 | ||
|
|
f9082104ed | ||
|
|
983834cd52 | ||
|
|
96d10c8b39 | ||
|
|
24cb992843 | ||
|
|
7907c0bf58 | ||
|
|
ebf4fd9a09 | ||
|
|
38b9901274 | ||
|
|
642842d61b | ||
|
|
e161c511af | ||
|
|
f29e82685e | ||
|
|
3a5ae96e7b | ||
|
|
b63a685386 | ||
|
|
877da82b06 | ||
|
|
6637629045 | ||
|
|
e925b6c572 | ||
|
|
5412f4aba5 | ||
|
|
2d5ad0d208 | ||
|
|
1ade70aa1e | ||
|
|
2658c4d57b | ||
|
|
84c76bc04a | ||
|
|
6effcd3755 | ||
|
|
d9866489f0 | ||
|
|
c4d8bdc3db | ||
|
|
681eb1cfcc | ||
|
|
a5d21f3b09 | ||
|
|
7ba068c3e4 | ||
|
|
b201eeedbd | ||
|
|
f28cb84977 | ||
|
|
714872cd58 | ||
|
|
0708bd60ee | ||
|
|
23a6c85b80 | ||
|
|
4a28599fbd | ||
|
|
7c66d3c793 | ||
|
|
cc9edfffd8 | ||
|
|
6fa2454c9a | ||
|
|
487e699021 | ||
|
|
a7cdb745c1 | ||
|
|
73c86ee6a0 | ||
|
|
48eb590065 | ||
|
|
33562a9d8d | ||
|
|
c9194ba382 | ||
|
|
a199fa6388 | ||
|
|
4c8608dc61 | ||
|
|
a6b0f788e7 | ||
|
|
df6604a734 |
49
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
49
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
name: "🕷️ Bug report"
|
||||
description: Report errors or unexpected behavior
|
||||
labels:
|
||||
- bug
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
|
||||
- type: input
|
||||
attributes:
|
||||
label: Dify version
|
||||
placeholder: 0.3.21
|
||||
description: See about section in Dify console
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: Cloud or Self Hosted
|
||||
description: How / Where was Dify installed from?
|
||||
multiple: true
|
||||
options:
|
||||
- Cloud
|
||||
- Self Hosted
|
||||
- Other (please specify in "Steps to Reproduce")
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Steps to reproduce
|
||||
description: We highly suggest including screenshots and a bug report log.
|
||||
placeholder: Having detailed steps helps us reproduce the bug.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ✔️ Expected Behavior
|
||||
placeholder: What were you expecting?
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ❌ Actual Behavior
|
||||
placeholder: What happened instead?
|
||||
validations:
|
||||
required: false
|
||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: "\U0001F4DA Dify user documentation"
|
||||
url: https://docs.dify.ai/getting-started/readme
|
||||
about: Documentation for users of Dify
|
||||
- name: "\U0001F4DA Dify dev documentation"
|
||||
url: https://docs.dify.ai/getting-started/install-self-hosted
|
||||
about: Documentation for people interested in developing and contributing for Dify
|
||||
11
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
Normal file
11
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
name: "📚 Documentation Issue"
|
||||
description: Report issues in our documentation
|
||||
labels:
|
||||
- ducumentation
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Provide a description of requested docs changes
|
||||
placeholder: Briefly describe which document needs to be corrected and why.
|
||||
validations:
|
||||
required: true
|
||||
26
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
26
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
name: "⭐ Feature or enhancement request"
|
||||
description: Propose something new.
|
||||
labels:
|
||||
- enhancement
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Description of the new feature / enhancement
|
||||
placeholder: What is the expected behavior of the proposed feature?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Scenario when this would be used?
|
||||
placeholder: What is the scenario this would be used? Why is this important to your workflow as a dify user?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Supporting information
|
||||
placeholder: "Having additional evidence, data, tweets, blog posts, research, ... anything is extremely helpful. This information provides context to the scenario that may otherwise be lost."
|
||||
validations:
|
||||
required: false
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: Please limit one request per issue.
|
||||
46
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
Normal file
46
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
name: "🌐 Localization/Translation issue"
|
||||
description: Report incorrect translations.
|
||||
labels:
|
||||
- translation
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
|
||||
- type: input
|
||||
attributes:
|
||||
label: Dify version
|
||||
placeholder: 0.3.21
|
||||
description: Hover over system tray icon or look at Settings
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
attributes:
|
||||
label: Utility with translation issue
|
||||
placeholder: Some area
|
||||
description: Please input here the utility with the translation issue
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
attributes:
|
||||
label: 🌐 Language affected
|
||||
placeholder: "German"
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ❌ Actual phrase(s)
|
||||
placeholder: What is there? Please include a screenshot as that is extremely helpful.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ✔️ Expected phrase(s)
|
||||
placeholder: What was expected?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ℹ Why is the current translation wrong
|
||||
placeholder: Why do you feel this is incorrect?
|
||||
validations:
|
||||
required: true
|
||||
32
.github/ISSUE_TEMPLATE/🐛-bug-report.md
vendored
32
.github/ISSUE_TEMPLATE/🐛-bug-report.md
vendored
@@ -1,32 +0,0 @@
|
||||
---
|
||||
name: "\U0001F41B Bug report"
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
<!--
|
||||
Please provide a clear and concise description of what the bug is. Include
|
||||
screenshots if needed. Please test using the latest version of the relevant
|
||||
Dify packages to make sure your issue has not already been fixed.
|
||||
-->
|
||||
|
||||
Dify version: Cloud | Self Host
|
||||
|
||||
## Steps To Reproduce
|
||||
<!--
|
||||
Your bug will get fixed much faster if we can run your code and it doesn't
|
||||
have dependencies other than Dify. Issues without reproduction steps or
|
||||
code examples may be immediately closed as not actionable.
|
||||
-->
|
||||
|
||||
1.
|
||||
2.
|
||||
|
||||
|
||||
## The current behavior
|
||||
|
||||
|
||||
## The expected behavior
|
||||
20
.github/ISSUE_TEMPLATE/🚀-feature-request.md
vendored
20
.github/ISSUE_TEMPLATE/🚀-feature-request.md
vendored
@@ -1,20 +0,0 @@
|
||||
---
|
||||
name: "\U0001F680 Feature request"
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: enhancement
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
10
.github/ISSUE_TEMPLATE/🤔-questions-and-help.md
vendored
10
.github/ISSUE_TEMPLATE/🤔-questions-and-help.md
vendored
@@ -1,10 +0,0 @@
|
||||
---
|
||||
name: "\U0001F914 Questions and Help"
|
||||
about: Ask a usage or consultation question
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -20,7 +20,8 @@ def check_file_for_chinese_comments(file_path):
|
||||
def main():
|
||||
has_chinese = False
|
||||
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
|
||||
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
|
||||
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py',
|
||||
'prompts.py']
|
||||
|
||||
for root, _, files in os.walk("."):
|
||||
for file in files:
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -149,4 +149,5 @@ sdks/python-client/build
|
||||
sdks/python-client/dist
|
||||
sdks/python-client/dify_client.egg-info
|
||||
|
||||
.vscode/
|
||||
.vscode/*
|
||||
!.vscode/launch.json
|
||||
27
.vscode/launch.json
vendored
Normal file
27
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Flask",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "api/app.py",
|
||||
"FLASK_DEBUG": "1",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
"--host=0.0.0.0",
|
||||
"--port=5001",
|
||||
"--debug"
|
||||
],
|
||||
"jinja": true,
|
||||
"justMyCode": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -16,6 +16,10 @@ Out-of-the-box web sites supporting form mode and chat conversation mode
|
||||
A single API encompassing plugin capabilities, context enhancement, and more, saving you backend coding effort
|
||||
Visual data analysis, log review, and annotation for applications
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
|
||||
|
||||
|
||||
## Highlighted Features
|
||||
**1. LLMs support:** Choose capabilities based on different models when building your Dify AI apps. Dify is compatible with Langchain, meaning it will support various LLMs. Currently supported:
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
- 一套 API 即可包含插件、上下文增强等能力,替你省下了后端代码的编写工作
|
||||
- 可视化的对应用进行数据分析,查阅日志或进行标注
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
|
||||
|
||||
## 核心能力
|
||||
1. **模型支持:** 你可以在 Dify 上选择基于不同模型的能力来开发你的 AI 应用。Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs ,目前支持的模型供应商:
|
||||
|
||||
@@ -1,7 +1,18 @@
|
||||
FROM python:3.10-slim
|
||||
# packages install stage
|
||||
FROM python:3.10-slim AS base
|
||||
|
||||
LABEL maintainer="takatost@gmail.com"
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gcc g++ python3-dev libc-dev libffi-dev
|
||||
|
||||
COPY requirements.txt /requirements.txt
|
||||
|
||||
RUN pip install --prefix=/pkg -r requirements.txt
|
||||
|
||||
# build stage
|
||||
FROM python:3.10-slim AS builder
|
||||
|
||||
ENV FLASK_APP app.py
|
||||
ENV EDITION SELF_HOSTED
|
||||
ENV DEPLOY_ENV PRODUCTION
|
||||
@@ -15,15 +26,17 @@ EXPOSE 5001
|
||||
|
||||
WORKDIR /app/api
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev nodejs
|
||||
|
||||
COPY requirements.txt /app/api/requirements.txt
|
||||
|
||||
RUN pip install -r requirements.txt
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends bash curl wget vim nodejs \
|
||||
&& apt-get autoremove \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=base /pkg /usr/local
|
||||
COPY . /app/api/
|
||||
|
||||
RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
|
||||
ENV TRANSFORMERS_OFFLINE true
|
||||
|
||||
COPY docker/entrypoint.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@
|
||||
flask run --host 0.0.0.0 --port=5001 --debug
|
||||
```
|
||||
7. Setup your application by visiting http://localhost:5001/console/api/setup or other apis...
|
||||
8. If you need to debug local async processing, you can run `celery -A app.celery worker -Q dataset,generation,mail`, celery can do dataset importing and other async tasks.
|
||||
8. If you need to debug local async processing, you can run `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`, celery can do dataset importing and other async tasks.
|
||||
|
||||
8. Start frontend
|
||||
|
||||
|
||||
10
api/app.py
10
api/app.py
@@ -1,6 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@@ -145,8 +145,12 @@ def load_user(user_id):
|
||||
_create_tenant_for_account(account)
|
||||
session['workspace_id'] = account.current_tenant_id
|
||||
|
||||
account.last_active_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
current_time = datetime.utcnow()
|
||||
|
||||
# update last_active_at when last_active_at is more than 10 minutes ago
|
||||
if current_time - account.last_active_at > timedelta(minutes=10):
|
||||
account.last_active_at = current_time
|
||||
db.session.commit()
|
||||
|
||||
# Log in the user with the updated user_id
|
||||
flask_login.login_user(account, remember=True)
|
||||
|
||||
230
api/commands.py
230
api/commands.py
@@ -4,8 +4,10 @@ import math
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import click
|
||||
from tqdm import tqdm
|
||||
from flask import current_app
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -21,9 +23,9 @@ from libs.password import password_pattern, valid_password, hash_password
|
||||
from libs.helper import email as email_validate
|
||||
from extensions.ext_database import db
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import InvitationCode, Tenant
|
||||
from models.dataset import Dataset, DatasetQuery, Document
|
||||
from models.model import Account
|
||||
from models.account import InvitationCode, Tenant, TenantAccountJoin
|
||||
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
|
||||
from models.model import Account, AppModelConfig, App
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
@@ -238,7 +240,13 @@ def clean_unused_dataset_indexes():
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
vector_index.delete()
|
||||
if dataset.collection_binding_id:
|
||||
vector_index.delete_by_group_id(dataset.id)
|
||||
else:
|
||||
if dataset.collection_binding_id:
|
||||
vector_index.delete_by_group_id(dataset.id)
|
||||
else:
|
||||
vector_index.delete()
|
||||
kw_index.delete()
|
||||
# update document
|
||||
update_params = {
|
||||
@@ -345,7 +353,8 @@ def create_qdrant_indexes():
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
@@ -363,7 +372,8 @@ def create_qdrant_indexes():
|
||||
index.create_qdrant_dataset(dataset)
|
||||
index_struct = {
|
||||
"type": 'qdrant',
|
||||
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
|
||||
"vector_store": {
|
||||
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct)
|
||||
db.session.commit()
|
||||
@@ -372,7 +382,8 @@ def create_qdrant_indexes():
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
|
||||
@@ -413,7 +424,8 @@ def update_qdrant_indexes():
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
@@ -434,11 +446,207 @@ def update_qdrant_indexes():
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
|
||||
|
||||
|
||||
@click.command('normalization-collections', help='restore all collections in one')
|
||||
def normalization_collections():
|
||||
click.echo(click.style('Start normalization collections.', fg='green'))
|
||||
normalization_count = 0
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
if not dataset.collection_binding_id:
|
||||
try:
|
||||
click.echo('restore dataset index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except Exception:
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider_name='openai',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
|
||||
DatasetCollectionBinding.model_name == embedding_model.name). \
|
||||
order_by(DatasetCollectionBinding.created_at). \
|
||||
first()
|
||||
|
||||
if not dataset_collection_binding:
|
||||
dataset_collection_binding = DatasetCollectionBinding(
|
||||
provider_name=embedding_model.model_provider.provider_name,
|
||||
model_name=embedding_model.name,
|
||||
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
|
||||
)
|
||||
db.session.add(dataset_collection_binding)
|
||||
db.session.commit()
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
index.restore_dataset_in_one(dataset, dataset_collection_binding)
|
||||
else:
|
||||
click.echo('passed.')
|
||||
|
||||
original_index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if original_index:
|
||||
original_index.delete_original_collection(dataset, dataset_collection_binding)
|
||||
normalization_count += 1
|
||||
else:
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
|
||||
|
||||
|
||||
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
|
||||
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
|
||||
def update_app_model_configs(batch_size):
|
||||
pre_prompt_template = '{{default_input}}'
|
||||
user_input_form_template = {
|
||||
"en-US": [
|
||||
{
|
||||
"paragraph": {
|
||||
"label": "Query",
|
||||
"variable": "default_input",
|
||||
"required": False,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
],
|
||||
"zh-Hans": [
|
||||
{
|
||||
"paragraph": {
|
||||
"label": "查询内容",
|
||||
"variable": "default_input",
|
||||
"required": False,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
click.secho("Start migrate old data that the text generator can support paragraph variable.", fg='green')
|
||||
|
||||
total_records = db.session.query(AppModelConfig) \
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.count()
|
||||
|
||||
if total_records == 0:
|
||||
click.secho("No data to migrate.", fg='green')
|
||||
return
|
||||
|
||||
num_batches = (total_records + batch_size - 1) // batch_size
|
||||
|
||||
with tqdm(total=total_records, desc="Migrating Data") as pbar:
|
||||
for i in range(num_batches):
|
||||
offset = i * batch_size
|
||||
limit = min(batch_size, total_records - offset)
|
||||
|
||||
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
|
||||
|
||||
data_batch = db.session.query(AppModelConfig) \
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.order_by(App.created_at) \
|
||||
.offset(offset).limit(limit).all()
|
||||
|
||||
if not data_batch:
|
||||
click.secho("No more data to migrate.", fg='green')
|
||||
break
|
||||
|
||||
try:
|
||||
click.secho(f"Migrating {len(data_batch)} records...", fg='green')
|
||||
for data in data_batch:
|
||||
# click.secho(f"Migrating data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')
|
||||
|
||||
if data.pre_prompt is None:
|
||||
data.pre_prompt = pre_prompt_template
|
||||
else:
|
||||
if pre_prompt_template in data.pre_prompt:
|
||||
continue
|
||||
data.pre_prompt += pre_prompt_template
|
||||
|
||||
app_data = db.session.query(App) \
|
||||
.filter(App.id == data.app_id) \
|
||||
.one()
|
||||
|
||||
account_data = db.session.query(Account) \
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
|
||||
.filter(TenantAccountJoin.role == 'owner') \
|
||||
.filter(TenantAccountJoin.tenant_id == app_data.tenant_id) \
|
||||
.one_or_none()
|
||||
|
||||
if not account_data:
|
||||
continue
|
||||
|
||||
if data.user_input_form is None or data.user_input_form == 'null':
|
||||
data.user_input_form = json.dumps(user_input_form_template[account_data.interface_language])
|
||||
else:
|
||||
raw_json_data = json.loads(data.user_input_form)
|
||||
raw_json_data.append(user_input_form_template[account_data.interface_language][0])
|
||||
data.user_input_form = json.dumps(raw_json_data)
|
||||
|
||||
# click.secho(f"Updated data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')
|
||||
|
||||
db.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
|
||||
fg='red')
|
||||
continue
|
||||
|
||||
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
|
||||
|
||||
pbar.update(len(data_batch))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
@@ -448,4 +656,6 @@ def register_commands(app):
|
||||
app.cli.add_command(sync_anthropic_hosted_providers)
|
||||
app.cli.add_command(clean_unused_dataset_indexes)
|
||||
app.cli.add_command(create_qdrant_indexes)
|
||||
app.cli.add_command(update_qdrant_indexes)
|
||||
app.cli.add_command(update_qdrant_indexes)
|
||||
app.cli.add_command(update_app_model_configs)
|
||||
app.cli.add_command(normalization_collections)
|
||||
|
||||
@@ -61,6 +61,8 @@ DEFAULTS = {
|
||||
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
|
||||
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
|
||||
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
|
||||
'HOSTED_MODERATION_ENABLED': 'False',
|
||||
'HOSTED_MODERATION_PROVIDERS': '',
|
||||
'TENANT_DOCUMENT_COUNT': 100,
|
||||
'CLEAN_DAY_SETTING': 30,
|
||||
'UPLOAD_FILE_SIZE_LIMIT': 15,
|
||||
@@ -100,7 +102,7 @@ class Config:
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.19"
|
||||
self.CURRENT_VERSION = "0.3.22"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
@@ -230,6 +232,9 @@ class Config:
|
||||
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
|
||||
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
|
||||
|
||||
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
|
||||
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
|
||||
|
||||
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
|
||||
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ model_templates = {
|
||||
},
|
||||
'model_config': {
|
||||
'provider': 'openai',
|
||||
'model_id': 'text-davinci-003',
|
||||
'model_id': 'gpt-3.5-turbo-instruct',
|
||||
'configs': {
|
||||
'prompt_template': '',
|
||||
'prompt_variables': [],
|
||||
@@ -30,7 +30,7 @@ model_templates = {
|
||||
},
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "text-davinci-003",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
@@ -38,7 +38,18 @@ model_templates = {
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
})
|
||||
}),
|
||||
'user_input_form': json.dumps([
|
||||
{
|
||||
"paragraph": {
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
]),
|
||||
'pre_prompt': '{{query}}'
|
||||
}
|
||||
},
|
||||
|
||||
@@ -93,7 +104,7 @@ demo_model_templates = {
|
||||
'mode': 'completion',
|
||||
'model_config': AppModelConfig(
|
||||
provider='openai',
|
||||
model_id='text-davinci-003',
|
||||
model_id='gpt-3.5-turbo-instruct',
|
||||
configs={
|
||||
'prompt_template': "Please translate the following text into {{target_language}}:\n",
|
||||
'prompt_variables': [
|
||||
@@ -129,7 +140,7 @@ demo_model_templates = {
|
||||
pre_prompt="Please translate the following text into {{target_language}}:\n",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "text-davinci-003",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
@@ -211,7 +222,7 @@ demo_model_templates = {
|
||||
'mode': 'completion',
|
||||
'model_config': AppModelConfig(
|
||||
provider='openai',
|
||||
model_id='text-davinci-003',
|
||||
model_id='gpt-3.5-turbo-instruct',
|
||||
configs={
|
||||
'prompt_template': "请将以下文本翻译为{{target_language}}:\n",
|
||||
'prompt_variables': [
|
||||
@@ -247,7 +258,7 @@ demo_model_templates = {
|
||||
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "text-davinci-003",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
|
||||
@@ -29,6 +29,7 @@ model_config_fields = {
|
||||
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
|
||||
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
|
||||
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
|
||||
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
|
||||
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
|
||||
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
|
||||
@@ -39,9 +39,10 @@ class CompletionMessageApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json')
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] != 'blocking'
|
||||
@@ -115,6 +116,7 @@ class ChatMessageApi(Resource):
|
||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] != 'blocking'
|
||||
|
||||
@@ -16,26 +16,25 @@ from services.account_service import RegisterService
|
||||
class ActivateCheckApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('email', type=email, required=True, nullable=False, location='args')
|
||||
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args')
|
||||
parser.add_argument('email', type=email, required=False, nullable=True, location='args')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
workspaceId = args['workspace_id']
|
||||
reg_email = args['email']
|
||||
token = args['token']
|
||||
|
||||
tenant = db.session.query(Tenant).filter(
|
||||
Tenant.id == args['workspace_id'],
|
||||
Tenant.status == 'normal'
|
||||
).first()
|
||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||
|
||||
return {'is_valid': account is not None, 'workspace_name': tenant.name}
|
||||
return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None}
|
||||
|
||||
|
||||
class ActivateApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('email', type=email, required=True, nullable=False, location='json')
|
||||
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('email', type=email, required=False, nullable=True, location='json')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
|
||||
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
|
||||
@@ -44,12 +43,13 @@ class ActivateApi(Resource):
|
||||
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
if account is None:
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
|
||||
|
||||
account = invitation['account']
|
||||
account.name = args['name']
|
||||
|
||||
# generate password salt
|
||||
|
||||
@@ -26,7 +26,7 @@ from models.model import UploadFile
|
||||
|
||||
cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx']
|
||||
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
|
||||
@@ -31,8 +31,9 @@ class CompletionApi(InstalledAppResource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json')
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
@@ -92,6 +93,7 @@ class ChatApi(InstalledAppResource):
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
@@ -30,6 +30,25 @@ class MessageListApi(InstalledAppResource):
|
||||
'rating': fields.String
|
||||
}
|
||||
|
||||
retriever_resource_fields = {
|
||||
'id': fields.String,
|
||||
'message_id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'dataset_id': fields.String,
|
||||
'dataset_name': fields.String,
|
||||
'document_id': fields.String,
|
||||
'document_name': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'segment_id': fields.String,
|
||||
'score': fields.Float,
|
||||
'hit_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'segment_position': fields.Integer,
|
||||
'index_node_hash': fields.String,
|
||||
'content': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
@@ -37,6 +56,7 @@ class MessageListApi(InstalledAppResource):
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ class AppParameterApi(InstalledAppResource):
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
}
|
||||
@@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ class UniversalChatApi(UniversalChatResource):
|
||||
parser.add_argument('provider', type=str, required=True, location='json')
|
||||
parser.add_argument('model', type=str, required=True, location='json')
|
||||
parser.add_argument('tools', type=list, required=True, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_model_config = app_model.app_model_config
|
||||
|
||||
@@ -36,6 +36,25 @@ class UniversalChatMessageListApi(UniversalChatResource):
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
retriever_resource_fields = {
|
||||
'id': fields.String,
|
||||
'message_id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'dataset_id': fields.String,
|
||||
'dataset_name': fields.String,
|
||||
'document_id': fields.String,
|
||||
'document_name': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'segment_id': fields.String,
|
||||
'score': fields.Float,
|
||||
'hit_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'segment_position': fields.Integer,
|
||||
'index_node_hash': fields.String,
|
||||
'content': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
@@ -43,6 +62,7 @@ class UniversalChatMessageListApi(UniversalChatResource):
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
'created_at': TimestampField,
|
||||
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask_restful import marshal_with, fields
|
||||
|
||||
from controllers.console import api
|
||||
@@ -14,6 +16,7 @@ class UniversalChatParameterApi(UniversalChatResource):
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
@@ -21,12 +24,14 @@ class UniversalChatParameterApi(UniversalChatResource):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = universal_app
|
||||
app_model_config = app_model.app_model_config
|
||||
app_model_config.retriever_resource = json.dumps({'enabled': True})
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ def universal_chat_app_required(view=None):
|
||||
suggested_questions=json.dumps([]),
|
||||
suggested_questions_after_answer=json.dumps({'enabled': True}),
|
||||
speech_to_text=json.dumps({'enabled': True}),
|
||||
retriever_resource=json.dumps({'enabled': True}),
|
||||
more_like_this=None,
|
||||
sensitive_word_avoidance=None,
|
||||
model=json.dumps({
|
||||
|
||||
@@ -72,7 +72,7 @@ class MemberInviteEmailApi(Resource):
|
||||
invitation_results.append({
|
||||
'status': 'success',
|
||||
'email': invitee_email,
|
||||
'url': f'{console_web_url}/activate?workspace_id={current_user.current_tenant_id}&email={invitee_email}&token={token}'
|
||||
'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
|
||||
})
|
||||
account = marshal(account, account_fields)
|
||||
account['role'] = role
|
||||
|
||||
@@ -246,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource):
|
||||
'enabled': v.enabled,
|
||||
'min': v.min,
|
||||
'max': v.max,
|
||||
'default': v.default
|
||||
'default': v.default,
|
||||
'precision': v.precision
|
||||
}
|
||||
for k, v in vars(parameter_rules).items()
|
||||
}
|
||||
@@ -285,6 +286,25 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
|
||||
return result
|
||||
|
||||
|
||||
class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
result = provider_service.free_quota_qualification_verify(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
token=args['token']
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
|
||||
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
|
||||
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
|
||||
@@ -300,3 +320,5 @@ api.add_resource(ModelProviderPaymentCheckoutUrlApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
|
||||
api.add_resource(ModelProviderFreeQuotaSubmitApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
|
||||
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')
|
||||
|
||||
@@ -25,6 +25,7 @@ class AppParameterApi(AppApiResource):
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
}
|
||||
@@ -39,6 +40,7 @@ class AppParameterApi(AppApiResource):
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
@@ -27,9 +27,11 @@ class CompletionApi(AppApiResource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json')
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
@@ -91,6 +93,8 @@ class ChatApi(AppApiResource):
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
@@ -16,6 +16,24 @@ class MessageListApi(AppApiResource):
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
retriever_resource_fields = {
|
||||
'id': fields.String,
|
||||
'message_id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'dataset_id': fields.String,
|
||||
'dataset_name': fields.String,
|
||||
'document_id': fields.String,
|
||||
'document_name': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'segment_id': fields.String,
|
||||
'score': fields.Float,
|
||||
'hit_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'segment_position': fields.Integer,
|
||||
'index_node_hash': fields.String,
|
||||
'content': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
'id': fields.String,
|
||||
@@ -24,6 +42,7 @@ class MessageListApi(AppApiResource):
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ class AppParameterApi(WebApiResource):
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
}
|
||||
@@ -38,6 +39,7 @@ class AppParameterApi(WebApiResource):
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
@@ -29,8 +29,10 @@ class CompletionApi(WebApiResource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json')
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
@@ -88,6 +90,8 @@ class ChatApi(WebApiResource):
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
@@ -29,6 +29,25 @@ class MessageListApi(WebApiResource):
|
||||
'rating': fields.String
|
||||
}
|
||||
|
||||
retriever_resource_fields = {
|
||||
'id': fields.String,
|
||||
'message_id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'dataset_id': fields.String,
|
||||
'dataset_name': fields.String,
|
||||
'document_id': fields.String,
|
||||
'document_name': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'segment_id': fields.String,
|
||||
'score': fields.Float,
|
||||
'hit_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'segment_position': fields.Integer,
|
||||
'index_node_hash': fields.String,
|
||||
'content': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
@@ -36,6 +55,7 @@ class MessageListApi(WebApiResource):
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
@@ -53,6 +54,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
tool = next(iter(self.tools))
|
||||
tool = cast(DatasetRetrieverTool, tool)
|
||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
||||
# output = ''
|
||||
# rst_json = json.loads(rst)
|
||||
# for item in rst_json:
|
||||
# output += f'{item["content"]}\n'
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
if intermediate_steps:
|
||||
|
||||
@@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
|
||||
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
|
||||
from langchain.agents import AgentExecutor as LCAgentExecutor
|
||||
|
||||
from core.helper import moderation
|
||||
from core.model_providers.error import LLMError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
@@ -116,6 +118,18 @@ class AgentExecutor:
|
||||
return self.agent.should_use_agent(query)
|
||||
|
||||
def run(self, query: str) -> AgentExecuteResult:
|
||||
moderation_result = moderation.check_moderation(
|
||||
self.configuration.model_instance.model_provider,
|
||||
query
|
||||
)
|
||||
|
||||
if not moderation_result:
|
||||
return AgentExecuteResult(
|
||||
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
|
||||
strategy=self.configuration.strategy,
|
||||
configuration=self.configuration
|
||||
)
|
||||
|
||||
agent_executor = LCAgentExecutor.from_agent_and_tools(
|
||||
agent=self.agent,
|
||||
tools=self.configuration.tools,
|
||||
@@ -128,7 +142,9 @@ class AgentExecutor:
|
||||
|
||||
try:
|
||||
output = agent_executor.run(query)
|
||||
except Exception:
|
||||
except LLMError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logging.exception("agent_executor run failed")
|
||||
output = None
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
|
||||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
@@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
|
||||
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self.model_instant = model_instant
|
||||
self.model_instance = model_instance
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
@@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return True
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
if not self._current_loop:
|
||||
# Agent start with a LLM query
|
||||
self._current_loop = AgentLoop(
|
||||
position=len(self._agent_loops) + 1,
|
||||
prompt="\n".join([message.content for message in messages[0]]),
|
||||
status='llm_started',
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
@@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
if response.llm_output:
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
else:
|
||||
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
|
||||
self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.prompt)]
|
||||
)
|
||||
completion_generation = response.generations[0][0]
|
||||
@@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
if response.llm_output:
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
|
||||
self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.completion)]
|
||||
)
|
||||
|
||||
@@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_instant, self._current_loop
|
||||
self._message_agent_thought, self.model_instance, self._current_loop
|
||||
)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
@@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
)
|
||||
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_instant, self._current_loop
|
||||
self._message_agent_thought, self.model_instance, self._current_loop
|
||||
)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
|
||||
@@ -64,12 +64,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# kwargs={'name': 'Search'}
|
||||
# llm_prefix='Thought:'
|
||||
# observation_prefix='Observation: '
|
||||
# output='53 years'
|
||||
pass
|
||||
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
|
||||
@@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion: str = ''
|
||||
completion_tokens: int = 0
|
||||
latency: float = 0.0
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import List
|
||||
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
@@ -9,8 +10,9 @@ from models.dataset import DocumentSegment
|
||||
class DatasetIndexToolCallbackHandler:
|
||||
"""Callback handler for dataset tool."""
|
||||
|
||||
def __init__(self, dataset_id: str) -> None:
|
||||
def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
|
||||
self.dataset_id = dataset_id
|
||||
self.conversation_message_task = conversation_message_task
|
||||
|
||||
def on_tool_end(self, documents: List[Document]) -> None:
|
||||
"""Handle tool end."""
|
||||
@@ -27,3 +29,7 @@ class DatasetIndexToolCallbackHandler:
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def return_retriever_resource_info(self, resource: List):
|
||||
"""Handle return_retriever_resource_info."""
|
||||
self.conversation_message_task.on_dataset_query_finish(resource)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
@@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
self.start_at = time.perf_counter()
|
||||
real_prompts = []
|
||||
for message in messages[0]:
|
||||
if message.type == 'human':
|
||||
@@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
self.start_at = time.perf_counter()
|
||||
|
||||
self.llm_message.prompt = [{
|
||||
"role": 'user',
|
||||
"text": prompts[0]
|
||||
@@ -63,14 +59,22 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
end_at = time.perf_counter()
|
||||
self.llm_message.latency = end_at - self.start_at
|
||||
|
||||
if not self.conversation_message_task.streaming:
|
||||
self.conversation_message_task.append_message_text(response.generations[0][0].text)
|
||||
self.llm_message.completion = response.generations[0][0].text
|
||||
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
|
||||
if response.llm_output and 'token_usage' in response.llm_output:
|
||||
if 'prompt_tokens' in response.llm_output['token_usage']:
|
||||
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
|
||||
if 'completion_tokens' in response.llm_output['token_usage']:
|
||||
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)])
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)])
|
||||
|
||||
self.conversation_message_task.save_message(self.llm_message)
|
||||
|
||||
@@ -89,8 +93,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
"""Do nothing."""
|
||||
if isinstance(error, ConversationTaskStoppedException):
|
||||
if self.conversation_message_task.streaming:
|
||||
end_at = time.perf_counter()
|
||||
self.llm_message.latency = end_at - self.start_at
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)]
|
||||
)
|
||||
|
||||
@@ -1,15 +1,33 @@
|
||||
import enum
|
||||
import logging
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.moderation import openai_moderation
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceRule(BaseModel):
|
||||
class Type(enum.Enum):
|
||||
MODERATION = "moderation"
|
||||
KEYWORDS = "keywords"
|
||||
|
||||
type: Type
|
||||
canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
|
||||
extra_params: dict = {}
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceChain(Chain):
|
||||
input_key: str = "input" #: :meta private:
|
||||
output_key: str = "output" #: :meta private:
|
||||
|
||||
sensitive_words: List[str] = []
|
||||
canned_response: str = None
|
||||
model_instance: BaseLLM
|
||||
sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
@@ -31,11 +49,24 @@ class SensitiveWordAvoidanceChain(Chain):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _check_sensitive_word(self, text: str) -> str:
|
||||
for word in self.sensitive_words:
|
||||
def _check_sensitive_word(self, text: str) -> bool:
|
||||
for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
|
||||
if word in text:
|
||||
return self.canned_response
|
||||
return text
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_moderation(self, text: str) -> bool:
|
||||
moderation_model_instance = ModelFactory.get_moderation_model(
|
||||
tenant_id=self.model_instance.model_provider.provider.tenant_id,
|
||||
model_provider_name='openai',
|
||||
model_name=openai_moderation.DEFAULT_MODEL
|
||||
)
|
||||
|
||||
try:
|
||||
return moderation_model_instance.run(text=text)
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
|
||||
|
||||
def _call(
|
||||
self,
|
||||
@@ -43,5 +74,19 @@ class SensitiveWordAvoidanceChain(Chain):
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
text = inputs[self.input_key]
|
||||
output = self._check_sensitive_word(text)
|
||||
return {self.output_key: output}
|
||||
|
||||
if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
|
||||
result = self._check_sensitive_word(text)
|
||||
else:
|
||||
result = self._check_moderation(text)
|
||||
|
||||
if not result:
|
||||
raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
|
||||
|
||||
return {self.output_key: text}
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceError(Exception):
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
@@ -1,31 +1,32 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Union, Tuple
|
||||
from typing import Optional, List, Union
|
||||
|
||||
from langchain.schema import BaseMessage
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
|
||||
from models.dataset import DocumentSegment, Dataset, Document
|
||||
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
|
||||
|
||||
|
||||
class Completion:
|
||||
@classmethod
|
||||
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
|
||||
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
|
||||
is_override: bool = False, retriever_from: str = 'dev'):
|
||||
"""
|
||||
errors: ProviderTokenNotInitError
|
||||
"""
|
||||
@@ -76,29 +77,53 @@ class Completion:
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
|
||||
# parse sensitive_word_avoidance_chain
|
||||
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
|
||||
if sensitive_word_avoidance_chain:
|
||||
query = sensitive_word_avoidance_chain.run(query)
|
||||
|
||||
# get agent executor
|
||||
agent_executor = orchestrator_rule_parser.to_agent_executor(
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
rest_tokens=rest_tokens_for_context_and_memory,
|
||||
chain_callback=chain_callback
|
||||
)
|
||||
|
||||
# run agent executor
|
||||
agent_execute_result = None
|
||||
if agent_executor:
|
||||
should_use_agent = agent_executor.should_use_agent(query)
|
||||
if should_use_agent:
|
||||
agent_execute_result = agent_executor.run(query)
|
||||
|
||||
# run the final llm
|
||||
try:
|
||||
# parse sensitive_word_avoidance_chain
|
||||
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
|
||||
final_model_instance, [chain_callback])
|
||||
if sensitive_word_avoidance_chain:
|
||||
try:
|
||||
query = sensitive_word_avoidance_chain.run(query)
|
||||
except SensitiveWordAvoidanceError as ex:
|
||||
cls.run_final_llm(
|
||||
model_instance=final_model_instance,
|
||||
mode=app.mode,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
agent_execute_result=None,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
fake_response=ex.message
|
||||
)
|
||||
return
|
||||
|
||||
# get agent executor
|
||||
agent_executor = orchestrator_rule_parser.to_agent_executor(
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
rest_tokens=rest_tokens_for_context_and_memory,
|
||||
chain_callback=chain_callback,
|
||||
retriever_from=retriever_from
|
||||
)
|
||||
|
||||
# run agent executor
|
||||
agent_execute_result = None
|
||||
if agent_executor:
|
||||
should_use_agent = agent_executor.should_use_agent(query)
|
||||
if should_use_agent:
|
||||
agent_execute_result = agent_executor.run(query)
|
||||
|
||||
# When no extra pre prompt is specified,
|
||||
# the output of the agent can be used directly as the main output content without calling LLM again
|
||||
fake_response = None
|
||||
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
|
||||
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
|
||||
PlanningStrategy.REACT_ROUTER]:
|
||||
fake_response = agent_execute_result.output
|
||||
|
||||
# run the final llm
|
||||
cls.run_final_llm(
|
||||
model_instance=final_model_instance,
|
||||
mode=app.mode,
|
||||
@@ -107,7 +132,8 @@ class Completion:
|
||||
inputs=inputs,
|
||||
agent_execute_result=agent_execute_result,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory
|
||||
memory=memory,
|
||||
fake_response=fake_response
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
return
|
||||
@@ -118,17 +144,12 @@ class Completion:
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
|
||||
inputs: dict,
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
|
||||
# When no extra pre prompt is specified,
|
||||
# the output of the agent can be used directly as the main output content without calling LLM again
|
||||
fake_response = None
|
||||
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
|
||||
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
|
||||
fake_response = agent_execute_result.output
|
||||
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
|
||||
fake_response: Optional[str]):
|
||||
# get llm prompt
|
||||
prompt_messages, stop_words = model_instance.get_prompt(
|
||||
mode=mode,
|
||||
@@ -150,7 +171,6 @@ class Completion:
|
||||
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
|
||||
fake_response=fake_response
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import decimal
|
||||
import json
|
||||
from typing import Optional, Union
|
||||
import time
|
||||
from typing import Optional, Union, List
|
||||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.callback_handler.entity.dataset_query import DatasetQueryObj
|
||||
@@ -15,13 +15,16 @@ from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import DatasetQuery
|
||||
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
|
||||
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
|
||||
MessageChain, DatasetRetrieverResource
|
||||
|
||||
|
||||
class ConversationMessageTask:
|
||||
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
|
||||
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
|
||||
conversation: Optional[Conversation] = None, is_override: bool = False):
|
||||
self.start_at = time.perf_counter()
|
||||
|
||||
self.task_id = task_id
|
||||
|
||||
self.app = app
|
||||
@@ -41,6 +44,8 @@ class ConversationMessageTask:
|
||||
|
||||
self.message = None
|
||||
|
||||
self.retriever_resource = None
|
||||
|
||||
self.model_dict = self.app_model_config.model_dict
|
||||
self.provider_name = self.model_dict.get('provider')
|
||||
self.model_name = self.model_dict.get('name')
|
||||
@@ -58,6 +63,7 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
def init(self):
|
||||
|
||||
override_model_configs = None
|
||||
if self.is_override:
|
||||
override_model_configs = self.app_model_config.to_dict()
|
||||
@@ -157,11 +163,12 @@ class ConversationMessageTask:
|
||||
self.message.message_tokens = message_tokens
|
||||
self.message.message_unit_price = message_unit_price
|
||||
self.message.message_price_unit = message_price_unit
|
||||
self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
|
||||
self.message.answer = PromptBuilder.process_template(
|
||||
llm_message.completion.strip()) if llm_message.completion else ''
|
||||
self.message.answer_tokens = answer_tokens
|
||||
self.message.answer_unit_price = answer_unit_price
|
||||
self.message.answer_price_unit = answer_price_unit
|
||||
self.message.provider_response_latency = llm_message.latency
|
||||
self.message.provider_response_latency = time.perf_counter() - self.start_at
|
||||
self.message.total_price = total_price
|
||||
|
||||
db.session.commit()
|
||||
@@ -216,18 +223,18 @@ class ConversationMessageTask:
|
||||
|
||||
return message_agent_thought
|
||||
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
|
||||
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
|
||||
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
|
||||
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
loop_message_tokens = agent_loop.prompt_tokens
|
||||
loop_answer_tokens = agent_loop.completion_tokens
|
||||
|
||||
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
|
||||
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
|
||||
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
|
||||
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
|
||||
loop_total_price = loop_message_total_price + loop_answer_total_price
|
||||
|
||||
message_agent_thought.observation = agent_loop.tool_output
|
||||
@@ -241,7 +248,7 @@ class ConversationMessageTask:
|
||||
message_agent_thought.latency = agent_loop.latency
|
||||
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
|
||||
message_agent_thought.total_price = loop_total_price
|
||||
message_agent_thought.currency = agent_model_instant.get_currency()
|
||||
message_agent_thought.currency = agent_model_instance.get_currency()
|
||||
db.session.flush()
|
||||
|
||||
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
|
||||
@@ -256,7 +263,36 @@ class ConversationMessageTask:
|
||||
|
||||
db.session.add(dataset_query)
|
||||
|
||||
def on_dataset_query_finish(self, resource: List):
|
||||
if resource and len(resource) > 0:
|
||||
for item in resource:
|
||||
dataset_retriever_resource = DatasetRetrieverResource(
|
||||
message_id=self.message.id,
|
||||
position=item.get('position'),
|
||||
dataset_id=item.get('dataset_id'),
|
||||
dataset_name=item.get('dataset_name'),
|
||||
document_id=item.get('document_id'),
|
||||
document_name=item.get('document_name'),
|
||||
data_source_type=item.get('data_source_type'),
|
||||
segment_id=item.get('segment_id'),
|
||||
score=item.get('score') if 'score' in item else None,
|
||||
hit_count=item.get('hit_count') if 'hit_count' else None,
|
||||
word_count=item.get('word_count') if 'word_count' in item else None,
|
||||
segment_position=item.get('segment_position') if 'segment_position' in item else None,
|
||||
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
|
||||
content=item.get('content'),
|
||||
retriever_from=item.get('retriever_from'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
db.session.add(dataset_retriever_resource)
|
||||
db.session.flush()
|
||||
self.retriever_resource = resource
|
||||
|
||||
def message_end(self):
|
||||
self._pub_handler.pub_message_end(self.retriever_resource)
|
||||
|
||||
def end(self):
|
||||
self._pub_handler.pub_message_end(self.retriever_resource)
|
||||
self._pub_handler.pub_end()
|
||||
|
||||
|
||||
@@ -350,6 +386,23 @@ class PubHandler:
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_message_end(self, retriever_resource: List):
|
||||
content = {
|
||||
'event': 'message_end',
|
||||
'data': {
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id
|
||||
}
|
||||
}
|
||||
if retriever_resource:
|
||||
content['data']['retriever_resources'] = retriever_resource
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
if self._is_stopped():
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_end(self):
|
||||
content = {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain.schema import OutputParserException
|
||||
@@ -22,18 +23,25 @@ class LLMGenerator:
|
||||
if len(query) > 2000:
|
||||
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
|
||||
|
||||
prompt = prompt.format(query=query)
|
||||
query = query.replace("\n", " ")
|
||||
|
||||
prompt += query + "\n"
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=50
|
||||
temperature=1,
|
||||
max_tokens=100
|
||||
)
|
||||
)
|
||||
|
||||
prompts = [PromptMessage(content=prompt)]
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
|
||||
result_dict = json.loads(answer)
|
||||
answer = result_dict['Your Output']
|
||||
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
|
||||
34
api/core/helper/moderation.py
Normal file
34
api/core/helper/moderation.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import logging
|
||||
|
||||
import openai
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
|
||||
if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
|
||||
if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
|
||||
and model_provider.provider_name in hosted_config.moderation.providers:
|
||||
# 2000 text per chunk
|
||||
length = 2000
|
||||
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
|
||||
|
||||
max_text_chunks = 32
|
||||
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
|
||||
|
||||
for text_chunk in chunks:
|
||||
try:
|
||||
moderation_result = openai.Moderation.create(input=text_chunk,
|
||||
api_key=hosted_model_providers.openai.api_key)
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
|
||||
|
||||
for result in moderation_result.results:
|
||||
if result['flagged'] is True:
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -16,6 +16,10 @@ class BaseIndex(ABC):
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
raise NotImplementedError
|
||||
@@ -28,6 +32,10 @@ class BaseIndex(ABC):
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -25,7 +25,33 @@ class KeywordTableIndex(BaseIndex):
|
||||
keyword_table = {}
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
dataset_id=self.dataset.id,
|
||||
keyword_table=json.dumps({
|
||||
'__type__': 'keyword_table',
|
||||
'__data__': {
|
||||
"index_id": self.dataset.id,
|
||||
"summary": None,
|
||||
"table": {}
|
||||
}
|
||||
}, cls=SetEncoder)
|
||||
)
|
||||
db.session.add(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = {}
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
@@ -52,7 +78,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
@@ -74,7 +100,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
DocumentSegment.document_id == document_id
|
||||
).all()
|
||||
|
||||
ids = [segment.id for segment in segments]
|
||||
ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
|
||||
@@ -120,6 +146,12 @@ class KeywordTableIndex(BaseIndex):
|
||||
db.session.delete(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
if dataset_keyword_table:
|
||||
db.session.delete(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
def _save_dataset_keyword_table(self, keyword_table):
|
||||
keyword_table_dict = {
|
||||
'__type__': 'keyword_table',
|
||||
@@ -199,15 +231,18 @@ class KeywordTableIndex(BaseIndex):
|
||||
|
||||
return sorted_chunk_indices[: k]
|
||||
|
||||
def _update_segment_keywords(self, node_id: str, keywords: List[str]):
|
||||
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
|
||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]):
|
||||
document_segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id == node_id
|
||||
).first()
|
||||
if document_segment:
|
||||
document_segment.keywords = keywords
|
||||
db.session.commit()
|
||||
|
||||
def create_segment_keywords(self, node_id: str, keywords: List[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
self._update_segment_keywords(node_id, keywords)
|
||||
self._update_segment_keywords(self.dataset.id, node_id, keywords)
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from weaviate import UnexpectedStatusCodeException
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
|
||||
@@ -110,6 +110,12 @@ class BaseVectorIndex(BaseIndex):
|
||||
for node_id in ids:
|
||||
vector_store.del_text(node_id)
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.delete()
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
@@ -243,3 +249,53 @@ class BaseVectorIndex(BaseIndex):
|
||||
raise e
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
|
||||
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
|
||||
logging.info(f"delete original collection: {dataset.id}")
|
||||
|
||||
self.delete()
|
||||
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
@@ -69,6 +69,19 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = WeaviateVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
client=self._client,
|
||||
index_name=collection_name,
|
||||
uuids=uuids,
|
||||
by_text=False
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
|
||||
@@ -28,6 +28,7 @@ from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
from qdrant_client.http.models import PayloadSchemaType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
@@ -84,6 +85,7 @@ class Qdrant(VectorStore):
|
||||
|
||||
CONTENT_KEY = "page_content"
|
||||
METADATA_KEY = "metadata"
|
||||
GROUP_KEY = "group_id"
|
||||
VECTOR_NAME = None
|
||||
|
||||
def __init__(
|
||||
@@ -93,9 +95,12 @@ class Qdrant(VectorStore):
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
group_payload_key: str = GROUP_KEY,
|
||||
group_id: str = None,
|
||||
distance_strategy: str = "COSINE",
|
||||
vector_name: Optional[str] = VECTOR_NAME,
|
||||
embedding_function: Optional[Callable] = None, # deprecated
|
||||
is_new_collection: bool = False
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
try:
|
||||
@@ -129,7 +134,10 @@ class Qdrant(VectorStore):
|
||||
self.collection_name = collection_name
|
||||
self.content_payload_key = content_payload_key or self.CONTENT_KEY
|
||||
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
|
||||
self.group_payload_key = group_payload_key or self.GROUP_KEY
|
||||
self.vector_name = vector_name or self.VECTOR_NAME
|
||||
self.group_id = group_id
|
||||
self.is_new_collection= is_new_collection
|
||||
|
||||
if embedding_function is not None:
|
||||
warnings.warn(
|
||||
@@ -170,6 +178,8 @@ class Qdrant(VectorStore):
|
||||
batch_size:
|
||||
How many vectors upload per-request.
|
||||
Default: 64
|
||||
group_id:
|
||||
collection group
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
@@ -182,7 +192,11 @@ class Qdrant(VectorStore):
|
||||
collection_name=self.collection_name, points=points, **kwargs
|
||||
)
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
# if is new collection, create payload index on group_id
|
||||
if self.is_new_collection:
|
||||
self.client.create_payload_index(self.collection_name, self.group_payload_key,
|
||||
field_schema=PayloadSchemaType.KEYWORD,
|
||||
field_type=PayloadSchemaType.KEYWORD)
|
||||
return added_ids
|
||||
|
||||
@sync_call_fallback
|
||||
@@ -970,6 +984,8 @@ class Qdrant(VectorStore):
|
||||
distance_func: str = "Cosine",
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
group_payload_key: str = GROUP_KEY,
|
||||
group_id: str = None,
|
||||
vector_name: Optional[str] = VECTOR_NAME,
|
||||
batch_size: int = 64,
|
||||
shard_number: Optional[int] = None,
|
||||
@@ -1034,6 +1050,11 @@ class Qdrant(VectorStore):
|
||||
metadata_payload_key:
|
||||
A payload key used to store the metadata of the document.
|
||||
Default: "metadata"
|
||||
group_payload_key:
|
||||
A payload key used to store the content of the document.
|
||||
Default: "group_id"
|
||||
group_id:
|
||||
collection group id
|
||||
vector_name:
|
||||
Name of the vector to be used internally in Qdrant.
|
||||
Default: None
|
||||
@@ -1107,6 +1128,8 @@ class Qdrant(VectorStore):
|
||||
distance_func,
|
||||
content_payload_key,
|
||||
metadata_payload_key,
|
||||
group_payload_key,
|
||||
group_id,
|
||||
vector_name,
|
||||
shard_number,
|
||||
replication_factor,
|
||||
@@ -1321,6 +1344,8 @@ class Qdrant(VectorStore):
|
||||
distance_func: str = "Cosine",
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
group_payload_key: str = GROUP_KEY,
|
||||
group_id: str = None,
|
||||
vector_name: Optional[str] = VECTOR_NAME,
|
||||
shard_number: Optional[int] = None,
|
||||
replication_factor: Optional[int] = None,
|
||||
@@ -1350,6 +1375,7 @@ class Qdrant(VectorStore):
|
||||
vector_size = len(partial_embeddings[0])
|
||||
collection_name = collection_name or uuid.uuid4().hex
|
||||
distance_func = distance_func.upper()
|
||||
is_new_collection = False
|
||||
client = qdrant_client.QdrantClient(
|
||||
location=location,
|
||||
url=url,
|
||||
@@ -1454,6 +1480,7 @@ class Qdrant(VectorStore):
|
||||
init_from=init_from,
|
||||
timeout=timeout, # type: ignore[arg-type]
|
||||
)
|
||||
is_new_collection = True
|
||||
qdrant = cls(
|
||||
client=client,
|
||||
collection_name=collection_name,
|
||||
@@ -1462,6 +1489,9 @@ class Qdrant(VectorStore):
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
distance_strategy=distance_func,
|
||||
vector_name=vector_name,
|
||||
group_id=group_id,
|
||||
group_payload_key=group_payload_key,
|
||||
is_new_collection=is_new_collection
|
||||
)
|
||||
return qdrant
|
||||
|
||||
@@ -1516,6 +1546,8 @@ class Qdrant(VectorStore):
|
||||
metadatas: Optional[List[dict]],
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
group_id: str,
|
||||
group_payload_key: str
|
||||
) -> List[dict]:
|
||||
payloads = []
|
||||
for i, text in enumerate(texts):
|
||||
@@ -1529,6 +1561,7 @@ class Qdrant(VectorStore):
|
||||
{
|
||||
content_payload_key: text,
|
||||
metadata_payload_key: metadata,
|
||||
group_payload_key: group_id
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1578,7 +1611,7 @@ class Qdrant(VectorStore):
|
||||
else:
|
||||
out.append(
|
||||
rest.FieldCondition(
|
||||
key=f"{self.metadata_payload_key}.{key}",
|
||||
key=key,
|
||||
match=rest.MatchValue(value=value),
|
||||
)
|
||||
)
|
||||
@@ -1654,6 +1687,7 @@ class Qdrant(VectorStore):
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
batch_size: int = 64,
|
||||
group_id: Optional[str] = None,
|
||||
) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
@@ -1684,6 +1718,8 @@ class Qdrant(VectorStore):
|
||||
batch_metadatas,
|
||||
self.content_payload_key,
|
||||
self.metadata_payload_key,
|
||||
self.group_id,
|
||||
self.group_payload_key
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -6,18 +6,20 @@ from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document, BaseRetriever
|
||||
from langchain.vectorstores import VectorStore
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client.http.models import HnswConfigDiff
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from core.vector_store.qdrant_vector_store import QdrantVectorStore
|
||||
from models.dataset import Dataset
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
root_path: Optional[str]
|
||||
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith('path:'):
|
||||
path = self.endpoint.replace('path:', '')
|
||||
@@ -43,16 +45,21 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
return 'qdrant'
|
||||
|
||||
def get_index_name(self, dataset: Dataset) -> str:
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
class_prefix += '_Node'
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
|
||||
one_or_none()
|
||||
if dataset_collection_binding:
|
||||
return dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError('Dataset Collection Bindings is not exist!')
|
||||
else:
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
return class_prefix
|
||||
|
||||
return class_prefix
|
||||
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
@@ -68,6 +75,27 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
ids=uuids,
|
||||
content_payload_key='page_content',
|
||||
group_id=self.dataset.id,
|
||||
group_payload_key='group_id',
|
||||
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
|
||||
max_indexing_threads=0, on_disk=False),
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = QdrantVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
collection_name=collection_name,
|
||||
ids=uuids,
|
||||
content_payload_key='page_content',
|
||||
group_id=self.dataset.id,
|
||||
group_payload_key='group_id',
|
||||
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
|
||||
max_indexing_threads=0, on_disk=False),
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
|
||||
@@ -78,8 +106,6 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id']
|
||||
if self._is_origin():
|
||||
attributes = ['doc_id']
|
||||
client = qdrant_client.QdrantClient(
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
@@ -88,16 +114,15 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
client=client,
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
embeddings=self._embeddings,
|
||||
content_payload_key='page_content'
|
||||
content_payload_key='page_content',
|
||||
group_id=self.dataset.id,
|
||||
group_payload_key='group_id'
|
||||
)
|
||||
|
||||
def _get_vector_store_class(self) -> type:
|
||||
return QdrantVectorStore
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
@@ -113,6 +138,38 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
],
|
||||
))
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
for node_id in ids:
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=group_id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
|
||||
@@ -91,6 +91,20 @@ class WeaviateVectorIndex(BaseVectorIndex):
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = WeaviateVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
client=self._client,
|
||||
index_name=self.get_index_name(self.dataset),
|
||||
uuids=uuids,
|
||||
by_text=False
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
|
||||
@@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.moderation.base import BaseModeration
|
||||
from core.model_providers.models.speech2text.base import BaseSpeech2Text
|
||||
from extensions.ext_database import db
|
||||
from models.provider import TenantDefaultModel
|
||||
@@ -180,7 +181,7 @@ class ModelFactory:
|
||||
def get_moderation_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: str,
|
||||
model_name: str) -> Optional[BaseProviderModel]:
|
||||
model_name: str) -> Optional[BaseModeration]:
|
||||
"""
|
||||
get moderation model.
|
||||
|
||||
|
||||
@@ -45,6 +45,9 @@ class ModelProviderFactory:
|
||||
elif provider_name == 'wenxin':
|
||||
from core.model_providers.providers.wenxin_provider import WenxinProvider
|
||||
return WenxinProvider
|
||||
elif provider_name == 'zhipuai':
|
||||
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
|
||||
return ZhipuAIProvider
|
||||
elif provider_name == 'chatglm':
|
||||
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
|
||||
return ChatGLMProvider
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
from replicate.exceptions import ModelError
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
|
||||
|
||||
class HuggingfaceEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = HuggingfaceHubEmbeddings(
|
||||
model=name,
|
||||
**credentials
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Huggingface embedding: {str(ex)}")
|
||||
@@ -0,0 +1,22 @@
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings
|
||||
|
||||
|
||||
class ZhipuAIEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = ZhipuAIEmbeddings(
|
||||
model=name,
|
||||
**credentials,
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}")
|
||||
@@ -8,6 +8,7 @@ class LLMRunResult(BaseModel):
|
||||
content: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
source: list = None
|
||||
|
||||
|
||||
class MessageType(enum.Enum):
|
||||
|
||||
@@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel):
|
||||
max: Optional[T] = None
|
||||
default: Optional[T] = None
|
||||
alias: Optional[str] = None
|
||||
precision: Optional[int] = None
|
||||
|
||||
|
||||
class ModelKwargsRules(BaseModel):
|
||||
|
||||
@@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
|
||||
|
||||
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
|
||||
from core.helper import moderation
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
||||
@@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel):
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
moderation_result = moderation.check_moderation(
|
||||
self.model_provider,
|
||||
"\n".join([message.content for message in messages])
|
||||
)
|
||||
|
||||
if not moderation_result:
|
||||
kwargs['fake_response'] = "I apologize for any confusion, " \
|
||||
"but I'm an AI assistant to be helpful, harmless, and honest."
|
||||
|
||||
if self.deduct_quota:
|
||||
self.model_provider.check_quota_over_limit()
|
||||
|
||||
@@ -342,7 +352,7 @@ class BaseLLM(BaseProviderModel):
|
||||
if order == 'context_prompt':
|
||||
prompt += context_prompt_content
|
||||
elif order == 'pre_prompt':
|
||||
prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''
|
||||
prompt += pre_prompt_content
|
||||
|
||||
query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain import HuggingFaceHub
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
@@ -9,6 +8,7 @@ from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
||||
from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM
|
||||
|
||||
|
||||
class HuggingfaceHubModel(BaseLLM):
|
||||
@@ -31,7 +31,7 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
streaming=streaming
|
||||
)
|
||||
else:
|
||||
client = HuggingFaceHub(
|
||||
client = HuggingFaceHubLLM(
|
||||
repo_id=self.name,
|
||||
task=self.credentials['task_type'],
|
||||
model_kwargs=provider_model_kwargs,
|
||||
@@ -88,4 +88,6 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
if 'baichuan' in self.name.lower():
|
||||
return False
|
||||
|
||||
return True
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@@ -17,6 +17,7 @@ from core.model_providers.models.entity.model_params import ModelMode, ModelKwar
|
||||
from models.provider import ProviderType, ProviderQuotaType
|
||||
|
||||
COMPLETION_MODELS = [
|
||||
'gpt-3.5-turbo-instruct', # 4,096 tokens
|
||||
'text-davinci-003', # 4,097 tokens
|
||||
]
|
||||
|
||||
@@ -31,6 +32,7 @@ MODEL_MAX_TOKENS = {
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
'gpt-3.5-turbo-instruct': 8192,
|
||||
'gpt-3.5-turbo-16k': 16384,
|
||||
'text-davinci-003': 4097,
|
||||
}
|
||||
|
||||
61
api/core/model_providers/models/llm/zhipuai_model.py
Normal file
61
api/core/model_providers/models/llm/zhipuai_model.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
|
||||
|
||||
|
||||
class ZhipuAIModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.CHAT
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return ZhipuAIChatLLM(
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
**provider_model_kwargs
|
||||
)
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens_from_messages(prompts), 0)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"ZhipuAI: {str(ex)}")
|
||||
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
29
api/core/model_providers/models/moderation/base.py
Normal file
29
api/core/model_providers/models/moderation/base.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
|
||||
class BaseModeration(BaseProviderModel):
|
||||
name: str
|
||||
type: ModelType = ModelType.MODERATION
|
||||
|
||||
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
|
||||
super().__init__(model_provider, client)
|
||||
self.name = name
|
||||
|
||||
def run(self, text: str) -> bool:
|
||||
try:
|
||||
return self._run(text)
|
||||
except Exception as ex:
|
||||
raise self.handle_exceptions(ex)
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, text: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
raise NotImplementedError
|
||||
@@ -4,29 +4,39 @@ import openai
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, LLMAuthorizationError
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.models.moderation.base import BaseModeration
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
DEFAULT_AUDIO_MODEL = 'whisper-1'
|
||||
DEFAULT_MODEL = 'whisper-1'
|
||||
|
||||
|
||||
class OpenAIModeration(BaseProviderModel):
|
||||
type: ModelType = ModelType.MODERATION
|
||||
class OpenAIModeration(BaseModeration):
|
||||
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
super().__init__(model_provider, openai.Moderation)
|
||||
super().__init__(model_provider, openai.Moderation, name)
|
||||
|
||||
def run(self, text):
|
||||
def _run(self, text: str) -> bool:
|
||||
credentials = self.model_provider.get_model_credentials(
|
||||
model_name=DEFAULT_AUDIO_MODEL,
|
||||
model_name=self.name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
try:
|
||||
return self._client.create(input=text, api_key=credentials['openai_api_key'])
|
||||
except Exception as ex:
|
||||
raise self.handle_exceptions(ex)
|
||||
# 2000 text per chunk
|
||||
length = 2000
|
||||
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
|
||||
|
||||
max_text_chunks = 32
|
||||
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
|
||||
|
||||
for text_chunk in chunks:
|
||||
moderation_result = self._client.create(input=text_chunk,
|
||||
api_key=credentials['openai_api_key'])
|
||||
|
||||
for result in moderation_result.results:
|
||||
if result['flagged'] is True:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
|
||||
@@ -69,11 +69,11 @@ class AnthropicProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=1, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
|
||||
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -164,14 +164,14 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
model_credentials = self.get_model_credentials(model_name, model_type)
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
|
||||
model_credentials['base_model_name'],
|
||||
4097
|
||||
), default=16),
|
||||
), default=16, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -64,11 +64,11 @@ class ChatGLMProvider(BaseModelProvider):
|
||||
}
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
|
||||
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -45,6 +45,18 @@ class HostedModelProviders(BaseModel):
|
||||
hosted_model_providers = HostedModelProviders()
|
||||
|
||||
|
||||
class HostedModerationConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
providers: list[str] = []
|
||||
|
||||
|
||||
class HostedConfig(BaseModel):
|
||||
moderation = HostedModerationConfig()
|
||||
|
||||
|
||||
hosted_config = HostedConfig()
|
||||
|
||||
|
||||
def init_app(app: Flask):
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||
langchain.verbose = True
|
||||
@@ -78,3 +90,9 @@ def init_app(app: Flask):
|
||||
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
|
||||
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
|
||||
)
|
||||
|
||||
if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
|
||||
hosted_config.moderation = HostedModerationConfig(
|
||||
enabled=app.config.get("HOSTED_MODERATION_ENABLED"),
|
||||
providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',')
|
||||
)
|
||||
|
||||
@@ -10,6 +10,8 @@ from core.model_providers.providers.base import BaseModelProvider, CredentialsVa
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
||||
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
|
||||
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
@@ -33,6 +35,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = HuggingfaceHubModel
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
model_class = HuggingfaceEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -47,11 +51,11 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
|
||||
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -63,7 +67,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
if model_type != ModelType.TEXT_GENERATION:
|
||||
if model_type not in [ModelType.TEXT_GENERATION, ModelType.EMBEDDINGS]:
|
||||
raise NotImplementedError
|
||||
|
||||
if 'huggingfacehub_api_type' not in credentials \
|
||||
@@ -88,18 +92,15 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
if 'task_type' not in credentials:
|
||||
raise CredentialsValidateFailedError('Task Type must be provided.')
|
||||
|
||||
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
|
||||
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
|
||||
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization", 'feature-extraction'):
|
||||
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
|
||||
'text-generation, summarization, feature-extraction.')
|
||||
|
||||
try:
|
||||
llm = HuggingFaceEndpointLLM(
|
||||
endpoint_url=credentials['huggingfacehub_endpoint_url'],
|
||||
task=credentials['task_type'],
|
||||
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
|
||||
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
if credentials['task_type'] == 'feature-extraction':
|
||||
cls.check_embedding_valid(credentials, model_name)
|
||||
else:
|
||||
cls.check_llm_valid(credentials)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
|
||||
else:
|
||||
@@ -111,13 +112,33 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
|
||||
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
|
||||
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization", "feature-extraction")
|
||||
if model_info.pipeline_tag not in VALID_TASKS:
|
||||
raise ValueError(f"Model {model_name} is not a valid task, "
|
||||
f"must be one of {VALID_TASKS}.")
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
|
||||
|
||||
@classmethod
|
||||
def check_llm_valid(cls, credentials: dict):
|
||||
llm = HuggingFaceEndpointLLM(
|
||||
endpoint_url=credentials['huggingfacehub_endpoint_url'],
|
||||
task=credentials['task_type'],
|
||||
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
|
||||
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
|
||||
@classmethod
|
||||
def check_embedding_valid(cls, credentials: dict, model_name: str):
|
||||
embedding_model = HuggingfaceHubEmbeddings(
|
||||
model=model_name,
|
||||
**credentials
|
||||
)
|
||||
|
||||
embedding_model.embed_query("ping")
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
|
||||
@@ -52,9 +52,9 @@ class LocalAIProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=0.7),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1),
|
||||
max_tokens=KwargRule[int](min=10, max=4097, default=16),
|
||||
temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
|
||||
max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -74,11 +74,11 @@ class MinimaxProvider(BaseModelProvider):
|
||||
}
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=1, default=0.9),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.95),
|
||||
temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -40,6 +40,10 @@ class OpenAIProvider(BaseModelProvider):
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'gpt-3.5-turbo-instruct',
|
||||
'name': 'GPT-3.5-Turbo-Instruct',
|
||||
},
|
||||
{
|
||||
'id': 'gpt-3.5-turbo-16k',
|
||||
'name': 'gpt-3.5-turbo-16k',
|
||||
@@ -128,16 +132,17 @@ class OpenAIProvider(BaseModelProvider):
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
'gpt-3.5-turbo-instruct': 8192,
|
||||
'gpt-3.5-turbo-16k': 16384,
|
||||
'text-davinci-003': 4097,
|
||||
}
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
|
||||
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -45,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -72,6 +72,7 @@ class ReplicateProvider(BaseModelProvider):
|
||||
min=float(value.get('minimum')) if value.get('minimum') is not None else None,
|
||||
max=float(value.get('maximum')) if value.get('maximum') is not None else None,
|
||||
default=float(value.get('default')) if value.get('default') is not None else None,
|
||||
precision = 2
|
||||
)
|
||||
if key == 'temperature':
|
||||
model_kwargs_rules.temperature = kwarg_rule
|
||||
@@ -84,6 +85,7 @@ class ReplicateProvider(BaseModelProvider):
|
||||
min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
|
||||
max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
|
||||
default=int(value.get('default')) if value.get('default') is not None else 500,
|
||||
precision = 0
|
||||
)
|
||||
|
||||
return model_kwargs_rules
|
||||
|
||||
@@ -62,11 +62,11 @@ class SparkProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=1, default=0.5),
|
||||
temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2),
|
||||
top_p=KwargRule[float](enabled=False),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=4096, default=2048),
|
||||
max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -83,14 +83,15 @@ class SparkProvider(BaseModelProvider):
|
||||
if 'api_secret' not in credentials:
|
||||
raise CredentialsValidateFailedError('Spark api_secret must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'app_id': credentials['app_id'],
|
||||
'api_key': credentials['api_key'],
|
||||
'api_secret': credentials['api_secret'],
|
||||
}
|
||||
credential_kwargs = {
|
||||
'app_id': credentials['app_id'],
|
||||
'api_key': credentials['api_key'],
|
||||
'api_secret': credentials['api_secret'],
|
||||
}
|
||||
|
||||
try:
|
||||
chat_llm = ChatSpark(
|
||||
model_name='spark-v2',
|
||||
max_tokens=10,
|
||||
temperature=0.01,
|
||||
**credential_kwargs
|
||||
@@ -104,7 +105,27 @@ class SparkProvider(BaseModelProvider):
|
||||
|
||||
chat_llm(messages)
|
||||
except SparkError as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
# try spark v1.5 if v2.1 failed
|
||||
try:
|
||||
chat_llm = ChatSpark(
|
||||
model_name='spark',
|
||||
max_tokens=10,
|
||||
temperature=0.01,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
messages = [
|
||||
HumanMessage(
|
||||
content="ping"
|
||||
)
|
||||
]
|
||||
|
||||
chat_llm(messages)
|
||||
except SparkError as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
except Exception as ex:
|
||||
logging.exception('Spark config validation failed')
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logging.exception('Spark config validation failed')
|
||||
raise ex
|
||||
|
||||
@@ -64,10 +64,10 @@ class TongyiProvider(BaseModelProvider):
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](enabled=False),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.8),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -63,8 +63,8 @@ class WenxinProvider(BaseModelProvider):
|
||||
"""
|
||||
if model_name in ['ernie-bot', 'ernie-bot-turbo']:
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=1, default=0.95),
|
||||
top_p=KwargRule[float](min=0.01, max=1, default=0.8),
|
||||
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
|
||||
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](enabled=False),
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
@@ -52,27 +53,27 @@ class XinferenceProvider(BaseModelProvider):
|
||||
credentials = self.get_model_credentials(model_name, model_type)
|
||||
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
|
||||
)
|
||||
elif credentials['model_format'] == "ggmlv3":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
|
||||
)
|
||||
else:
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
|
||||
)
|
||||
|
||||
|
||||
@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
|
||||
'model_uid': credentials['model_uid'],
|
||||
}
|
||||
|
||||
llm = XinferenceLLM(
|
||||
**credential_kwargs
|
||||
)
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
llm = XinferenceLLM(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
llm("ping")
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
embedding = XinferenceEmbeddings(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
embedding.embed_query("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider):
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
extra_credentials = cls._get_extra_credentials(credentials)
|
||||
credentials.update(extra_credentials)
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
extra_credentials = cls._get_extra_credentials(credentials)
|
||||
credentials.update(extra_credentials)
|
||||
|
||||
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
||||
|
||||
|
||||
176
api/core/model_providers/providers/zhipuai_provider.py
Normal file
176
api/core/model_providers/providers/zhipuai_provider.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
|
||||
from models.provider import ProviderType, ProviderQuotaType
|
||||
|
||||
|
||||
class ZhipuAIProvider(BaseModelProvider):
|
||||
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'zhipuai'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
return [
|
||||
{
|
||||
'id': 'chatglm_pro',
|
||||
'name': 'chatglm_pro',
|
||||
},
|
||||
{
|
||||
'id': 'chatglm_std',
|
||||
'name': 'chatglm_std',
|
||||
},
|
||||
{
|
||||
'id': 'chatglm_lite',
|
||||
'name': 'chatglm_lite',
|
||||
},
|
||||
{
|
||||
'id': 'chatglm_lite_32k',
|
||||
'name': 'chatglm_lite_32k',
|
||||
}
|
||||
]
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
return [
|
||||
{
|
||||
'id': 'text_embedding',
|
||||
'name': 'text_embedding',
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = ZhipuAIModel
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
model_class = ZhipuAIEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
|
||||
top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](enabled=False),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
"""
|
||||
if 'api_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'api_key': credentials['api_key']
|
||||
}
|
||||
|
||||
llm = ZhipuAIChatLLM(
|
||||
temperature=0.01,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm([HumanMessage(content='ping')])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
|
||||
return credentials
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
if self.provider.provider_type == ProviderType.CUSTOM.value \
|
||||
or (self.provider.provider_type == ProviderType.SYSTEM.value
|
||||
and self.provider.quota_type == ProviderQuotaType.FREE.value):
|
||||
try:
|
||||
credentials = json.loads(self.provider.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
credentials = {
|
||||
'api_key': None,
|
||||
}
|
||||
|
||||
if credentials['api_key']:
|
||||
credentials['api_key'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['api_key']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
|
||||
|
||||
return credentials
|
||||
else:
|
||||
return {}
|
||||
|
||||
def should_deduct_quota(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
return self.get_provider_credentials(obfuscated)
|
||||
@@ -6,6 +6,7 @@
|
||||
"tongyi",
|
||||
"spark",
|
||||
"wenxin",
|
||||
"zhipuai",
|
||||
"chatglm",
|
||||
"replicate",
|
||||
"huggingface_hub",
|
||||
|
||||
@@ -30,6 +30,12 @@
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo-instruct": {
|
||||
"prompt": "0.0015",
|
||||
"completion": "0.002",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
"prompt": "0.003",
|
||||
"completion": "0.004",
|
||||
|
||||
44
api/core/model_providers/rules/zhipuai.json
Normal file
44
api/core/model_providers/rules/zhipuai.json
Normal file
@@ -0,0 +1,44 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"system",
|
||||
"custom"
|
||||
],
|
||||
"system_config": {
|
||||
"supported_quota_types": [
|
||||
"free"
|
||||
],
|
||||
"quota_unit": "tokens"
|
||||
},
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"chatglm_pro": {
|
||||
"prompt": "0.01",
|
||||
"completion": "0.01",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"chatglm_std": {
|
||||
"prompt": "0.005",
|
||||
"completion": "0.005",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"chatglm_lite": {
|
||||
"prompt": "0.002",
|
||||
"completion": "0.002",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"chatglm_lite_32k": {
|
||||
"prompt": "0.0004",
|
||||
"completion": "0.0004",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"text_embedding": {
|
||||
"completion": "0",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
from langchain import WikipediaAPIWrapper
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
@@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
|
||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
@@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.model import AppModelConfig
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class OrchestratorRuleParser:
|
||||
@@ -36,8 +38,8 @@ class OrchestratorRuleParser:
|
||||
self.app_model_config = app_model_config
|
||||
|
||||
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
|
||||
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
|
||||
-> Optional[AgentExecutor]:
|
||||
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
|
||||
return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]:
|
||||
if not self.app_model_config.agent_mode_dict:
|
||||
return None
|
||||
|
||||
@@ -63,7 +65,7 @@ class OrchestratorRuleParser:
|
||||
|
||||
# add agent callback to record agent thoughts
|
||||
agent_callback = AgentLoopGatherCallbackHandler(
|
||||
model_instant=agent_model_instance,
|
||||
model_instance=agent_model_instance,
|
||||
conversation_message_task=conversation_message_task
|
||||
)
|
||||
|
||||
@@ -74,7 +76,7 @@ class OrchestratorRuleParser:
|
||||
|
||||
# only OpenAI chat model (include Azure) support function call, use ReACT instead
|
||||
if agent_model_instance.model_mode != ModelMode.CHAT \
|
||||
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
|
||||
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
|
||||
if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
|
||||
planning_strategy = PlanningStrategy.REACT
|
||||
elif planning_strategy == PlanningStrategy.ROUTER:
|
||||
@@ -99,7 +101,9 @@ class OrchestratorRuleParser:
|
||||
tool_configs=tool_configs,
|
||||
conversation_message_task=conversation_message_task,
|
||||
rest_tokens=rest_tokens,
|
||||
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
|
||||
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
|
||||
return_resource=return_resource,
|
||||
retriever_from=retriever_from
|
||||
)
|
||||
|
||||
if len(tools) == 0:
|
||||
@@ -121,23 +125,45 @@ class OrchestratorRuleParser:
|
||||
|
||||
return chain
|
||||
|
||||
def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
|
||||
def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
|
||||
-> Optional[SensitiveWordAvoidanceChain]:
|
||||
"""
|
||||
Convert app sensitive word avoidance config to chain
|
||||
|
||||
:param model_instance: model instance
|
||||
:param callbacks: callbacks for the chain
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if not self.app_model_config.sensitive_word_avoidance_dict:
|
||||
return None
|
||||
sensitive_word_avoidance_rule = None
|
||||
|
||||
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
|
||||
sensitive_words = sensitive_word_avoidance_config.get("words", "")
|
||||
if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
|
||||
if self.app_model_config.sensitive_word_avoidance_dict:
|
||||
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
|
||||
if sensitive_word_avoidance_config.get("enabled", False):
|
||||
if sensitive_word_avoidance_config.get('type') == 'moderation':
|
||||
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
|
||||
type=SensitiveWordAvoidanceRule.Type.MODERATION,
|
||||
canned_response=sensitive_word_avoidance_config.get("canned_response")
|
||||
if sensitive_word_avoidance_config.get("canned_response")
|
||||
else 'Your content violates our usage policy. Please revise and try again.',
|
||||
)
|
||||
else:
|
||||
sensitive_words = sensitive_word_avoidance_config.get("words", "")
|
||||
if sensitive_words:
|
||||
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
|
||||
type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
|
||||
canned_response=sensitive_word_avoidance_config.get("canned_response")
|
||||
if sensitive_word_avoidance_config.get("canned_response")
|
||||
else 'Your content violates our usage policy. Please revise and try again.',
|
||||
extra_params={
|
||||
'sensitive_words': sensitive_words.split(','),
|
||||
}
|
||||
)
|
||||
|
||||
if sensitive_word_avoidance_rule:
|
||||
return SensitiveWordAvoidanceChain(
|
||||
sensitive_words=sensitive_words.split(","),
|
||||
canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
|
||||
model_instance=model_instance,
|
||||
sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
|
||||
output_key="sensitive_word_avoidance_output",
|
||||
callbacks=callbacks,
|
||||
**kwargs
|
||||
@@ -145,8 +171,10 @@ class OrchestratorRuleParser:
|
||||
|
||||
return None
|
||||
|
||||
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
|
||||
rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
|
||||
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
|
||||
retriever_from: str = 'dev') -> list[BaseTool]:
|
||||
"""
|
||||
Convert app agent tool configs to tools
|
||||
|
||||
@@ -155,6 +183,8 @@ class OrchestratorRuleParser:
|
||||
:param tool_configs: app agent tool configs
|
||||
:param conversation_message_task:
|
||||
:param callbacks:
|
||||
:param return_resource:
|
||||
:param retriever_from:
|
||||
:return:
|
||||
"""
|
||||
tools = []
|
||||
@@ -166,7 +196,7 @@ class OrchestratorRuleParser:
|
||||
|
||||
tool = None
|
||||
if tool_type == "dataset":
|
||||
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
|
||||
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
|
||||
elif tool_type == "web_reader":
|
||||
tool = self.to_web_reader_tool(agent_model_instance)
|
||||
elif tool_type == "google_search":
|
||||
@@ -183,13 +213,15 @@ class OrchestratorRuleParser:
|
||||
return tools
|
||||
|
||||
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
|
||||
rest_tokens: int) \
|
||||
rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
|
||||
-> Optional[BaseTool]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param rest_tokens:
|
||||
:param tool_config:
|
||||
:param conversation_message_task:
|
||||
:param return_resource:
|
||||
:param retriever_from:
|
||||
:return:
|
||||
"""
|
||||
# get dataset from dataset id
|
||||
@@ -208,7 +240,10 @@ class OrchestratorRuleParser:
|
||||
tool = DatasetRetrieverTool.from_dataset(
|
||||
dataset=dataset,
|
||||
k=k,
|
||||
callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
|
||||
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
|
||||
conversation_message_task=conversation_message_task,
|
||||
return_resource=return_resource,
|
||||
retriever_from=retriever_from
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
@@ -8,6 +8,6 @@
|
||||
"pre_prompt",
|
||||
"histories_prompt"
|
||||
],
|
||||
"query_prompt": "用户:{{query}}",
|
||||
"query_prompt": "\n\n用户:{{query}}",
|
||||
"stops": ["用户:"]
|
||||
}
|
||||
@@ -8,6 +8,6 @@
|
||||
"pre_prompt",
|
||||
"histories_prompt"
|
||||
],
|
||||
"query_prompt": "Human: {{query}}\n\nAssistant: ",
|
||||
"query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ",
|
||||
"stops": ["\nHuman:", "</histories>"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,65 @@
|
||||
CONVERSATION_TITLE_PROMPT = (
|
||||
"Human:{query}\n-----\n"
|
||||
"Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
|
||||
"If what the human said is conducted in English, you should only return an English title.\n"
|
||||
"If what the human said is conducted in Chinese, you should only return a Chinese title.\n"
|
||||
"title:"
|
||||
)
|
||||
# Written by YORKI MINAKO🤡
|
||||
CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
|
||||
Notice: the language type user use could be diverse, which can be English, Chinese, Español, Arabic, Japanese, French, and etc.
|
||||
MAKE SURE your output is the SAME language as the user's input!
|
||||
Your output is restricted only to: (Input language) Intention + Subject(short as possible)
|
||||
Your output MUST be a valid JSON.
|
||||
|
||||
Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun.
|
||||
|
||||
|
||||
example 1:
|
||||
User Input: hi, yesterday i had some burgers.
|
||||
{
|
||||
"Language Type": "The user's input is pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "sharing yesterday's food"
|
||||
}
|
||||
|
||||
example 2:
|
||||
User Input: hello
|
||||
{
|
||||
"Language Type": "The user's input is written in pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "Greeting myself☺️"
|
||||
}
|
||||
|
||||
|
||||
example 3:
|
||||
User Input: why mmap file: oom
|
||||
{
|
||||
"Language Type": "The user's input is written in pure English",
|
||||
"Your Reasoning": "The language of my output must be pure English.",
|
||||
"Your Output": "Asking about the reason for mmap file: oom"
|
||||
}
|
||||
|
||||
|
||||
example 4:
|
||||
User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么?
|
||||
{
|
||||
"Language Type": "The user's input English-Chinese mixed",
|
||||
"Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.",
|
||||
"Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv"
|
||||
}
|
||||
|
||||
example 5:
|
||||
User Input: why小红的年龄is老than小明?
|
||||
{
|
||||
"Language Type": "The user's input is English-Chinese mixed",
|
||||
"Your Reasoning": "The English parts are subjective particles, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
|
||||
"Your Output": "询问小红和小明的年龄"
|
||||
}
|
||||
|
||||
example 6:
|
||||
User Input: yo, 你今天咋样?
|
||||
{
|
||||
"Language Type": "The user's input is English-Chinese mixed",
|
||||
"Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.",
|
||||
"Your Output": "查询今日我的状态☺️"
|
||||
}
|
||||
|
||||
User Input:
|
||||
"""
|
||||
|
||||
CONVERSATION_SUMMARY_PROMPT = (
|
||||
"Please generate a short summary of the following conversation.\n"
|
||||
@@ -50,7 +105,7 @@ GENERATOR_QA_PROMPT = (
|
||||
'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
|
||||
'Step 4: Generate 20 questions and answers based on these key information and concepts.'
|
||||
'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
|
||||
"Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
|
||||
"Answer according to the the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
|
||||
)
|
||||
|
||||
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
|
||||
|
||||
68
api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
vendored
Normal file
68
api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
HOSTED_INFERENCE_API = 'hosted_inference_api'
|
||||
INFERENCE_ENDPOINTS = 'inference_endpoints'
|
||||
|
||||
|
||||
class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
|
||||
client: Any
|
||||
model: str
|
||||
|
||||
task_type: Optional[str] = None
|
||||
huggingfacehub_api_type: Optional[str] = None
|
||||
huggingfacehub_api_token: Optional[str] = None
|
||||
huggingfacehub_endpoint_url: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
values['huggingfacehub_api_token'] = get_from_dict_or_env(
|
||||
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
|
||||
)
|
||||
|
||||
values['client'] = InferenceClient(values['huggingfacehub_api_token'])
|
||||
|
||||
return values
|
||||
|
||||
def embeddings(self, inputs: Union[str, List[str]]) -> str:
|
||||
model = ''
|
||||
|
||||
if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
|
||||
model = self.model
|
||||
else:
|
||||
model = self.huggingfacehub_endpoint_url
|
||||
|
||||
output = self.client.post(
|
||||
json={
|
||||
"inputs": inputs,
|
||||
"options": {
|
||||
"wait_for_model": False
|
||||
}
|
||||
}, model=model)
|
||||
|
||||
return json.loads(output.decode())
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
output = self.embeddings(texts)
|
||||
|
||||
if isinstance(output, list):
|
||||
return output
|
||||
|
||||
return [list(map(float, e)) for e in output]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
output = self.embeddings(text)
|
||||
|
||||
if isinstance(output, list):
|
||||
return output
|
||||
|
||||
return list(map(float, output))
|
||||
64
api/core/third_party/langchain/embeddings/zhipuai_embedding.py
vendored
Normal file
64
api/core/third_party/langchain/embeddings/zhipuai_embedding.py
vendored
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Wrapper around ZhipuAI embedding models."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI
|
||||
|
||||
|
||||
class ZhipuAIEmbeddings(BaseModel, Embeddings):
|
||||
"""Wrapper around ZhipuAI embedding models.
|
||||
1024 dimensions.
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model: str
|
||||
"""Model name to use."""
|
||||
|
||||
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
|
||||
api_key: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["api_key"] = get_from_dict_or_env(
|
||||
values, "api_key", "ZHIPUAI_API_KEY"
|
||||
)
|
||||
values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to ZhipuAI's embedding endpoint.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
response = self.client.invoke(model=self.model, prompt=text)
|
||||
data = response["data"]
|
||||
embeddings.append(data.get('embedding'))
|
||||
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to ZhipuAI's embedding endpoint.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
62
api/core/third_party/langchain/llms/huggingface_hub_llm.py
vendored
Normal file
62
api/core/third_party/langchain/llms/huggingface_hub_llm.py
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Dict, Optional, List, Any
|
||||
|
||||
from huggingface_hub import HfApi, InferenceApi
|
||||
from langchain import HuggingFaceHub
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.huggingface_hub import VALID_TASKS
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class HuggingFaceHubLLM(HuggingFaceHub):
|
||||
"""HuggingFaceHub models.
|
||||
|
||||
To use, you should have the ``huggingface_hub`` python package installed, and the
|
||||
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Only supports `text-generation`, `text2text-generation` and `summarization` for now.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import HuggingFaceHub
|
||||
hf = HuggingFaceHub(repo_id="gpt2", huggingfacehub_api_token="my-api-key")
|
||||
"""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
huggingfacehub_api_token = get_from_dict_or_env(
|
||||
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
|
||||
)
|
||||
client = InferenceApi(
|
||||
repo_id=values["repo_id"],
|
||||
token=huggingfacehub_api_token,
|
||||
task=values.get("task"),
|
||||
)
|
||||
client.options = {"wait_for_model": False, "use_gpu": False}
|
||||
values["client"] = client
|
||||
return values
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
hfapi = HfApi(token=self.huggingfacehub_api_token)
|
||||
model_info = hfapi.model_info(repo_id=self.repo_id)
|
||||
if not model_info:
|
||||
raise ValueError(f"Model {self.repo_id} not found.")
|
||||
|
||||
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
|
||||
raise ValueError(f"Inference API has been turned off for this model {self.repo_id}.")
|
||||
|
||||
if model_info.pipeline_tag not in VALID_TASKS:
|
||||
raise ValueError(f"Model {self.repo_id} is not a valid task, "
|
||||
f"must be one of {VALID_TASKS}.")
|
||||
|
||||
return super()._call(prompt, stop, run_manager, **kwargs)
|
||||
@@ -14,6 +14,9 @@ class EnhanceOpenAI(OpenAI):
|
||||
max_retries: int = 1
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
def __new__(cls, **data: Any): # type: ignore
|
||||
return super(EnhanceOpenAI, cls).__new__(cls)
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
|
||||
315
api/core/third_party/langchain/llms/zhipuai_llm.py
vendored
Normal file
315
api/core/third_party/langchain/llms/zhipuai_llm.py
vendored
Normal file
@@ -0,0 +1,315 @@
|
||||
"""Wrapper around ZhipuAI APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import posixpath
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional, Iterator, Sequence,
|
||||
)
|
||||
|
||||
import zhipuai
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
|
||||
from langchain.schema.messages import AIMessageChunk
|
||||
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
|
||||
from pydantic import Extra, root_validator, BaseModel
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from zhipuai.model_api.api import InvokeType
|
||||
from zhipuai.utils import jwt_token
|
||||
from zhipuai.utils.http_client import post, stream
|
||||
from zhipuai.utils.sse_client import SSEClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ZhipuModelAPI(BaseModel):
|
||||
base_url: str
|
||||
api_key: str
|
||||
api_timeout_seconds = 60
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def invoke(self, **kwargs):
|
||||
url = self._build_api_url(kwargs, InvokeType.SYNC)
|
||||
response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
|
||||
if not response['success']:
|
||||
raise ValueError(
|
||||
f"Error Code: {response['code']}, Message: {response['msg']} "
|
||||
)
|
||||
return response
|
||||
|
||||
def sse_invoke(self, **kwargs):
|
||||
url = self._build_api_url(kwargs, InvokeType.SSE)
|
||||
data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
|
||||
return SSEClient(data)
|
||||
|
||||
def _build_api_url(self, kwargs, *path):
|
||||
if kwargs:
|
||||
if "model" not in kwargs:
|
||||
raise Exception("model param missed")
|
||||
model = kwargs.pop("model")
|
||||
else:
|
||||
model = "-"
|
||||
|
||||
return posixpath.join(self.base_url, model, *path)
|
||||
|
||||
def _generate_token(self):
|
||||
if not self.api_key:
|
||||
raise Exception(
|
||||
"api_key not provided, you could provide it."
|
||||
)
|
||||
|
||||
try:
|
||||
return jwt_token.generate_token(self.api_key)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Your api_key is invalid, please check it."
|
||||
)
|
||||
|
||||
|
||||
class ZhipuAIChatLLM(BaseChatModel):
|
||||
"""Wrapper around ZhipuAI large language models.
|
||||
To use, you should pass the api_key as a named parameter to the constructor.
|
||||
Example:
|
||||
.. code-block:: python
|
||||
from core.third_party.langchain.llms.zhipuai import ZhipuAI
|
||||
model = ZhipuAI(model="<model_name>", api_key="my-api-key")
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"api_key": "API_KEY"}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
model: str = "chatglm_lite"
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.95
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
top_p: float = 0.7
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the response or return it all at once."""
|
||||
api_key: Optional[str] = None
|
||||
|
||||
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["api_key"] = get_from_dict_or_env(
|
||||
values, "api_key", "ZHIPUAI_API_KEY"
|
||||
)
|
||||
|
||||
if 'test' in values['base_url']:
|
||||
values['model'] = 'chatglm_130b_test'
|
||||
|
||||
values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return self._default_params
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "zhipuai"
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=_dict["content"])
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> List[Dict[str, Any]]:
|
||||
dict_messages = []
|
||||
for m in messages:
|
||||
message = self._convert_message_to_dict(m)
|
||||
if dict_messages:
|
||||
previous_message = dict_messages[-1]
|
||||
if previous_message['role'] == message['role']:
|
||||
dict_messages[-1]['content'] += f"\n{message['content']}"
|
||||
else:
|
||||
dict_messages.append(message)
|
||||
else:
|
||||
dict_messages.append(message)
|
||||
|
||||
return dict_messages
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
llm_output: Optional[Dict] = None
|
||||
for chunk in self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
if chunk.generation_info is not None \
|
||||
and 'token_usage' in chunk.generation_info:
|
||||
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
|
||||
continue
|
||||
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation], llm_output=llm_output)
|
||||
else:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
request = self._default_params
|
||||
request["prompt"] = message_dicts
|
||||
request.update(kwargs)
|
||||
response = self.client.invoke(**request)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
request = self._default_params
|
||||
request["prompt"] = message_dicts
|
||||
request.update(kwargs)
|
||||
|
||||
for event in self.client.sse_invoke(incremental=True, **request).events():
|
||||
if event.event == "add":
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=event.data))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(event.data)
|
||||
elif event.event == "error" or event.event == "interrupted":
|
||||
raise ValueError(
|
||||
f"{event.data}"
|
||||
)
|
||||
elif event.event == "finish":
|
||||
meta = json.loads(event.meta)
|
||||
token_usage = meta['usage']
|
||||
if token_usage is not None:
|
||||
if 'prompt_tokens' not in token_usage:
|
||||
token_usage['prompt_tokens'] = 0
|
||||
if 'completion_tokens' not in token_usage:
|
||||
token_usage['completion_tokens'] = token_usage['total_tokens']
|
||||
|
||||
yield ChatGenerationChunk(
|
||||
message=AIMessageChunk(content=event.data),
|
||||
generation_info=dict({'token_usage': token_usage})
|
||||
)
|
||||
|
||||
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
|
||||
data = response["data"]
|
||||
generations = []
|
||||
for res in data["choices"]:
|
||||
message = self._convert_dict_to_message(res)
|
||||
gen = ChatGeneration(
|
||||
message=message
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = data.get("usage")
|
||||
if token_usage is not None:
|
||||
if 'prompt_tokens' not in token_usage:
|
||||
token_usage['prompt_tokens'] = 0
|
||||
if 'completion_tokens' not in token_usage:
|
||||
token_usage['completion_tokens'] = token_usage['total_tokens']
|
||||
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
# def get_token_ids(self, text: str) -> List[int]:
|
||||
# """Return the ordered ids of the tokens in a text.
|
||||
#
|
||||
# Args:
|
||||
# text: The string input to tokenize.
|
||||
#
|
||||
# Returns:
|
||||
# A list of ids corresponding to the tokens in the text, in order they occur
|
||||
# in the text.
|
||||
# """
|
||||
# from core.third_party.transformers.Token import ChatGLMTokenizer
|
||||
#
|
||||
# tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b")
|
||||
# return tokenizer.encode(text)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
"""
|
||||
return sum([self.get_num_tokens(m.content) for m in messages])
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
for k, v in token_usage.items():
|
||||
if k in overall_token_usage:
|
||||
overall_token_usage[k] += v
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
return {"token_usage": overall_token_usage, "model_name": self.model}
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from flask import current_app
|
||||
@@ -5,13 +6,14 @@ from langchain.tools import BaseTool
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Dataset, DocumentSegment, Document
|
||||
|
||||
|
||||
class DatasetRetrieverToolInput(BaseModel):
|
||||
@@ -27,6 +29,9 @@ class DatasetRetrieverTool(BaseTool):
|
||||
tenant_id: str
|
||||
dataset_id: str
|
||||
k: int = 3
|
||||
conversation_message_task: ConversationMessageTask
|
||||
return_resource: str
|
||||
retriever_from: str
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
@@ -86,16 +91,23 @@ class DatasetRetrieverTool(BaseTool):
|
||||
if self.k > 0:
|
||||
documents = vector_index.search(
|
||||
query,
|
||||
search_type='similarity',
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': self.k
|
||||
'k': self.k,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
|
||||
hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
|
||||
hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_score_list = {}
|
||||
if dataset.indexing_technique != "economy":
|
||||
for item in documents:
|
||||
document_score_list[item.metadata['doc_id']] = item.metadata['score']
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata['doc_id'] for document in documents]
|
||||
segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
|
||||
@@ -112,9 +124,43 @@ class DatasetRetrieverTool(BaseTool):
|
||||
float('inf')))
|
||||
for segment in sorted_segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}')
|
||||
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
|
||||
else:
|
||||
document_context_list.append(segment.content)
|
||||
if self.return_resource:
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
context = {}
|
||||
document = Document.query.filter(Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
source = {
|
||||
'position': resource_number,
|
||||
'dataset_id': dataset.id,
|
||||
'dataset_name': dataset.name,
|
||||
'document_id': document.id,
|
||||
'document_name': document.name,
|
||||
'data_source_type': document.data_source_type,
|
||||
'segment_id': segment.id,
|
||||
'retriever_from': self.retriever_from
|
||||
}
|
||||
if dataset.indexing_technique != "economy":
|
||||
source['score'] = document_score_list.get(segment.index_node_id)
|
||||
if self.retriever_from == 'dev':
|
||||
source['hit_count'] = segment.hit_count
|
||||
source['word_count'] = segment.word_count
|
||||
source['segment_position'] = segment.position
|
||||
source['index_node_hash'] = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
|
||||
else:
|
||||
source['content'] = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
|
||||
|
||||
@@ -46,6 +46,11 @@ class QdrantVectorStore(Qdrant):
|
||||
|
||||
self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
def delete_group(self):
|
||||
self._reload_if_needed()
|
||||
|
||||
self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
cls,
|
||||
|
||||
@@ -26,7 +26,7 @@ def handle(sender, **kwargs):
|
||||
|
||||
conversation.name = name
|
||||
except:
|
||||
conversation.name = 'New Chat'
|
||||
conversation.name = 'New conversation'
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
"""add_dataset_retriever_resource
|
||||
|
||||
Revision ID: 6dcb43972bdc
|
||||
Revises: 4bcffcd64aa4
|
||||
Create Date: 2023-09-06 16:51:27.385844
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '6dcb43972bdc'
|
||||
down_revision = '4bcffcd64aa4'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('dataset_retriever_resources',
|
||||
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('message_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('position', sa.Integer(), nullable=False),
|
||||
sa.Column('dataset_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('dataset_name', sa.Text(), nullable=False),
|
||||
sa.Column('document_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('document_name', sa.Text(), nullable=False),
|
||||
sa.Column('data_source_type', sa.Text(), nullable=False),
|
||||
sa.Column('segment_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('score', sa.Float(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('hit_count', sa.Integer(), nullable=True),
|
||||
sa.Column('word_count', sa.Integer(), nullable=True),
|
||||
sa.Column('segment_position', sa.Integer(), nullable=True),
|
||||
sa.Column('index_node_hash', sa.Text(), nullable=True),
|
||||
sa.Column('retriever_from', sa.Text(), nullable=False),
|
||||
sa.Column('created_by', postgresql.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey')
|
||||
)
|
||||
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
|
||||
batch_op.create_index('dataset_retriever_resource_message_id_idx', ['message_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
|
||||
batch_op.drop_index('dataset_retriever_resource_message_id_idx')
|
||||
|
||||
op.drop_table('dataset_retriever_resources')
|
||||
# ### end Alembic commands ###
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user