mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-07 03:45:27 +08:00
Compare commits
65 Commits
fix/json-f
...
feat/add-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
765c2b92a1 | ||
|
|
eff115267f | ||
|
|
07cde4f8fe | ||
|
|
9f28a48a92 | ||
|
|
0d3cd3b16a | ||
|
|
3dc82fb044 | ||
|
|
cb6e73347e | ||
|
|
ecd6cbaee6 | ||
|
|
d54e942264 | ||
|
|
28ba721455 | ||
|
|
784dd7848e | ||
|
|
e2a5f8ba1a | ||
|
|
8e11200306 | ||
|
|
7599f79a17 | ||
|
|
510389909c | ||
|
|
2c6e00174b | ||
|
|
24f3456990 | ||
|
|
20514ff288 | ||
|
|
381d255290 | ||
|
|
7f320f9146 | ||
|
|
cd51d3323b | ||
|
|
004b3caa43 | ||
|
|
dbe10799e3 | ||
|
|
054ba88434 | ||
|
|
da82a11b26 | ||
|
|
fec607db81 | ||
|
|
397a92f2ee | ||
|
|
b91e226063 | ||
|
|
da5782df92 | ||
|
|
9af0da4450 | ||
|
|
d49ac1e4ac | ||
|
|
57de19a5ca | ||
|
|
7c00a0b6a3 | ||
|
|
a93506df18 | ||
|
|
a03a92e9db | ||
|
|
feebb5dd1f | ||
|
|
6eee7cb42c | ||
|
|
11baff6740 | ||
|
|
cde1797cc0 | ||
|
|
d143284d99 | ||
|
|
2b94545190 | ||
|
|
ed6648a41e | ||
|
|
5e2c3eeac3 | ||
|
|
b23d8a912b | ||
|
|
4f13f8fd0a | ||
|
|
561c9cabd5 | ||
|
|
39ea967b30 | ||
|
|
da04ff040b | ||
|
|
b9b0866a46 | ||
|
|
c6ab7eebd9 | ||
|
|
db4e6d81c5 | ||
|
|
df68a7c82b | ||
|
|
838825d747 | ||
|
|
a87f6f2837 | ||
|
|
9d98669e7d | ||
|
|
408fbb0c70 | ||
|
|
998f819b04 | ||
|
|
6194b82752 | ||
|
|
334f46d0b6 | ||
|
|
2eea114ac0 | ||
|
|
97e9ebd29a | ||
|
|
ec261aea54 | ||
|
|
accc5faae3 | ||
|
|
0462f09ecc | ||
|
|
1226d73159 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -109,6 +109,7 @@ venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
.conda/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
@@ -147,3 +148,5 @@ docker/volumes/weaviate/*
|
||||
sdks/python-client/build
|
||||
sdks/python-client/dist
|
||||
sdks/python-client/dify_client.egg-info
|
||||
|
||||
.vscode/
|
||||
12
README.md
12
README.md
@@ -17,9 +17,15 @@ A single API encompassing plugin capabilities, context enhancement, and more, sa
|
||||
Visual data analysis, log review, and annotation for applications
|
||||
Dify is compatible with Langchain, meaning we'll gradually support multiple LLMs, currently supported:
|
||||
|
||||
- GPT 3 (text-davinci-003)
|
||||
- GPT 3.5 Turbo(ChatGPT)
|
||||
- GPT-4
|
||||
* **OpenAI** :GPT4、GPT3.5-turbo、GPT3.5-turbo-16k、text-davinci-003
|
||||
|
||||
* **Azure OpenAI**
|
||||
|
||||
* **Antropic**:Claude2、Claude-instant
|
||||
> We've got 1000 free trial credits available for all cloud service users to try out the Claude model.Visit [Dify.ai](https://dify.ai) and
|
||||
try it now.
|
||||
|
||||
* **hugging face Hub**:Coming soon.
|
||||
|
||||
## Use Cloud Services
|
||||
|
||||
|
||||
13
README_CN.md
13
README_CN.md
@@ -17,11 +17,16 @@
|
||||
- 一套 API 即可包含插件、上下文增强等能力,替你省下了后端代码的编写工作
|
||||
- 可视化的对应用进行数据分析,查阅日志或进行标注
|
||||
|
||||
Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs ,目前已支持:
|
||||
Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs ,目前支持的模型供应商:
|
||||
|
||||
- GPT 3 (text-davinci-003)
|
||||
- GPT 3.5 Turbo(ChatGPT)
|
||||
- GPT-4
|
||||
* **OpenAI**:GPT4、GPT3.5-turbo、GPT3.5-turbo-16k、text-davinci-003
|
||||
|
||||
* **Azure OpenAI Service**
|
||||
* **Anthropic**:Claude2、Claude-instant
|
||||
|
||||
> 我们为所有注册云端版的用户免费提供了 1000 次 Claude 模型的消息调用额度,登录 [dify.ai](https://cloud.dify.ai) 即可使用。
|
||||
|
||||
* **Hugging Face Hub**(即将推出)
|
||||
|
||||
## 使用云服务
|
||||
|
||||
|
||||
@@ -8,13 +8,19 @@ EDITION=SELF_HOSTED
|
||||
SECRET_KEY=
|
||||
|
||||
# Console API base URL
|
||||
CONSOLE_URL=http://127.0.0.1:5001
|
||||
CONSOLE_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Console frontend web base URL
|
||||
CONSOLE_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# Service API base URL
|
||||
API_URL=http://127.0.0.1:5001
|
||||
SERVICE_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Web APP base URL
|
||||
APP_URL=http://127.0.0.1:3000
|
||||
# Web APP API base URL
|
||||
APP_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Web APP frontend web base URL
|
||||
APP_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||
@@ -79,6 +85,11 @@ WEAVIATE_BATCH_SIZE=100
|
||||
QDRANT_URL=path:storage/qdrant
|
||||
QDRANT_API_KEY=your-qdrant-api-key
|
||||
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE=
|
||||
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
|
||||
RESEND_API_KEY=
|
||||
|
||||
# Sentry configuration
|
||||
SENTRY_DSN=
|
||||
|
||||
|
||||
@@ -5,9 +5,11 @@ LABEL maintainer="takatost@gmail.com"
|
||||
ENV FLASK_APP app.py
|
||||
ENV EDITION SELF_HOSTED
|
||||
ENV DEPLOY_ENV PRODUCTION
|
||||
ENV CONSOLE_URL http://127.0.0.1:5001
|
||||
ENV API_URL http://127.0.0.1:5001
|
||||
ENV APP_URL http://127.0.0.1:5001
|
||||
ENV CONSOLE_API_URL http://127.0.0.1:5001
|
||||
ENV CONSOLE_WEB_URL http://127.0.0.1:3000
|
||||
ENV SERVICE_API_URL http://127.0.0.1:5001
|
||||
ENV APP_API_URL http://127.0.0.1:5001
|
||||
ENV APP_WEB_URL http://127.0.0.1:3000
|
||||
|
||||
EXPOSE 5001
|
||||
|
||||
@@ -25,4 +27,4 @@ RUN chmod +x /entrypoint.sh
|
||||
ARG COMMIT_SHA
|
||||
ENV COMMIT_SHA ${COMMIT_SHA}
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
|
||||
16
api/app.py
16
api/app.py
@@ -2,6 +2,8 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
||||
from gevent import monkey
|
||||
monkey.patch_all()
|
||||
@@ -15,7 +17,7 @@ import flask_login
|
||||
from flask_cors import CORS
|
||||
|
||||
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
ext_database, ext_storage
|
||||
ext_database, ext_storage, ext_mail
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
|
||||
@@ -27,7 +29,7 @@ from events import event_handlers
|
||||
import core
|
||||
from config import Config, CloudEditionConfig
|
||||
from commands import register_commands
|
||||
from models.account import TenantAccountJoin
|
||||
from models.account import TenantAccountJoin, AccountStatus
|
||||
from models.model import Account, EndUser, App
|
||||
|
||||
import warnings
|
||||
@@ -83,6 +85,7 @@ def initialize_extensions(app):
|
||||
ext_celery.init_app(app)
|
||||
ext_session.init_app(app)
|
||||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
|
||||
|
||||
@@ -100,6 +103,9 @@ def load_user(user_id):
|
||||
account = db.session.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
if account:
|
||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||
raise Forbidden('Account is banned or closed.')
|
||||
|
||||
workspace_id = session.get('workspace_id')
|
||||
if workspace_id:
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
@@ -149,13 +155,17 @@ def register_blueprints(app):
|
||||
from controllers.web import bp as web_bp
|
||||
from controllers.console import bp as console_app_bp
|
||||
|
||||
CORS(service_api_bp,
|
||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
CORS(web_bp,
|
||||
resources={
|
||||
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
|
||||
supports_credentials=True,
|
||||
allow_headers=['Content-Type', 'Authorization'],
|
||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
|
||||
expose_headers=['X-Version', 'X-Env']
|
||||
)
|
||||
|
||||
@@ -18,7 +18,8 @@ from models.model import Account
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
from models.provider import Provider
|
||||
from models.provider import Provider, ProviderName
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
@click.command('reset-password', help='Reset the account password.')
|
||||
@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes():
|
||||
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
|
||||
|
||||
|
||||
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
|
||||
def sync_anthropic_hosted_providers():
|
||||
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
|
||||
count = 0
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for tenant in tenants:
|
||||
try:
|
||||
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.ANTHROPIC.value,
|
||||
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
click.echo(click.style('Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
app.cli.add_command(generate_invitation_codes)
|
||||
app.cli.add_command(reset_encrypt_key_pair)
|
||||
app.cli.add_command(recreate_all_dataset_indexes)
|
||||
app.cli.add_command(sync_anthropic_hosted_providers)
|
||||
|
||||
@@ -28,9 +28,11 @@ DEFAULTS = {
|
||||
'SESSION_REDIS_USE_SSL': 'False',
|
||||
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
|
||||
'OAUTH_REDIRECT_INDEX_PATH': '/',
|
||||
'CONSOLE_URL': 'https://cloud.dify.ai',
|
||||
'API_URL': 'https://api.dify.ai',
|
||||
'APP_URL': 'https://udify.app',
|
||||
'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
|
||||
'CONSOLE_API_URL': 'https://cloud.dify.ai',
|
||||
'SERVICE_API_URL': 'https://api.dify.ai',
|
||||
'APP_WEB_URL': 'https://udify.app',
|
||||
'APP_API_URL': 'https://udify.app',
|
||||
'STORAGE_TYPE': 'local',
|
||||
'STORAGE_LOCAL_PATH': 'storage',
|
||||
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
|
||||
@@ -48,7 +50,10 @@ DEFAULTS = {
|
||||
'PDF_PREVIEW': 'True',
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
|
||||
'DEFAULT_LLM_PROVIDER': 'openai'
|
||||
'DEFAULT_LLM_PROVIDER': 'openai',
|
||||
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
|
||||
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
|
||||
'TENANT_DOCUMENT_COUNT': 100
|
||||
}
|
||||
|
||||
|
||||
@@ -76,10 +81,15 @@ class Config:
|
||||
|
||||
def __init__(self):
|
||||
# app settings
|
||||
self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
|
||||
self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
|
||||
self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
|
||||
self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
|
||||
self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.5"
|
||||
self.CURRENT_VERSION = "0.3.10"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
@@ -147,10 +157,15 @@ class Config:
|
||||
|
||||
# cors settings
|
||||
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL)
|
||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
|
||||
self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'WEB_API_CORS_ALLOW_ORIGINS', '*')
|
||||
|
||||
# mail settings
|
||||
self.MAIL_TYPE = get_env('MAIL_TYPE')
|
||||
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
|
||||
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
|
||||
|
||||
# sentry settings
|
||||
self.SENTRY_DSN = get_env('SENTRY_DSN')
|
||||
self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
|
||||
@@ -179,6 +194,10 @@ class Config:
|
||||
|
||||
# hosted provider credentials
|
||||
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
|
||||
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
|
||||
|
||||
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
|
||||
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
|
||||
|
||||
# By default it is False
|
||||
# You could disable it for compatibility with certain OpenAPI providers
|
||||
@@ -195,6 +214,8 @@ class Config:
|
||||
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
|
||||
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
|
||||
|
||||
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
|
||||
|
||||
|
||||
class CloudEditionConfig(Config):
|
||||
|
||||
|
||||
@@ -9,10 +9,10 @@ api = ExternalApi(bp)
|
||||
from . import setup, version, apikey, admin
|
||||
|
||||
# Import app controllers
|
||||
from .app import app, site, completion, model_config, statistic, conversation, message, generator
|
||||
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import login, oauth, data_source_oauth
|
||||
from .auth import login, oauth, data_source_oauth, activate
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
|
||||
@@ -21,4 +21,4 @@ from .datasets import datasets, datasets_document, datasets_segments, file, hit_
|
||||
from .workspace import workspace, members, providers, account
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message
|
||||
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
|
||||
|
||||
@@ -22,6 +22,7 @@ model_config_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
|
||||
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
|
||||
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
|
||||
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'user_input_form': fields.Raw(attribute='user_input_form_list'),
|
||||
@@ -144,6 +145,7 @@ class AppListApi(Resource):
|
||||
opening_statement=model_configuration['opening_statement'],
|
||||
suggested_questions=json.dumps(model_configuration['suggested_questions']),
|
||||
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
|
||||
speech_to_text=json.dumps(model_configuration['speech_to_text']),
|
||||
more_like_this=json.dumps(model_configuration['more_like_this']),
|
||||
model=json.dumps(model_configuration['model']),
|
||||
user_input_form=json.dumps(model_configuration['user_input_form']),
|
||||
@@ -434,6 +436,7 @@ class AppCopy(Resource):
|
||||
opening_statement=app_config.opening_statement,
|
||||
suggested_questions=app_config.suggested_questions,
|
||||
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
|
||||
speech_to_text=app_config.speech_to_text,
|
||||
more_like_this=app_config.more_like_this,
|
||||
model=app_config.model,
|
||||
user_input_form=app_config.user_input_form,
|
||||
|
||||
69
api/controllers/console/app/audio.py
Normal file
69
api/controllers/console/app/audio.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.app.error import AppUnavailableError, \
|
||||
ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \
|
||||
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from flask_restful import Resource
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
|
||||
|
||||
class ChatMessageAudioApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
app_model = _get_app(app_id, 'chat')
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
|
||||
@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -209,6 +209,26 @@ class CompletionConversationDetailApi(Resource):
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_id, conversation_id, 'completion')
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id, conversation_id):
|
||||
app_id = str(app_id)
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
app = _get_app(app_id, 'chat')
|
||||
|
||||
conversation = db.session.query(Conversation) \
|
||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
conversation.is_deleted = True
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class ChatConversationApi(Resource):
|
||||
@@ -356,6 +376,27 @@ class ChatConversationDetailApi(Resource):
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_id, conversation_id, 'chat')
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id, conversation_id):
|
||||
app_id = str(app_id)
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
# get app info
|
||||
app = _get_app(app_id, 'chat')
|
||||
|
||||
conversation = db.session.query(Conversation) \
|
||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
conversation.is_deleted = True
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
|
||||
|
||||
class ProviderQuotaExceededError(BaseHTTPException):
|
||||
error_code = 'provider_quota_exceeded'
|
||||
description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
|
||||
description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||
code = 400
|
||||
|
||||
@@ -49,3 +49,27 @@ class AppMoreLikeThisDisabledError(BaseHTTPException):
|
||||
error_code = 'app_more_like_this_disabled'
|
||||
description = "The 'More like this' feature is disabled. Please refresh your page."
|
||||
code = 403
|
||||
|
||||
|
||||
class NoAudioUploadedError(BaseHTTPException):
|
||||
error_code = 'no_audio_uploaded'
|
||||
description = "Please upload your audio."
|
||||
code = 400
|
||||
|
||||
|
||||
class AudioTooLargeError(BaseHTTPException):
|
||||
error_code = 'audio_too_large'
|
||||
description = "Audio size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedAudioTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_audio_type'
|
||||
description = "Audio type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
error_code = 'provider_not_support_speech_to_text'
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
|
||||
account.current_tenant_id,
|
||||
args['prompt_template']
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
|
||||
args['audiences'],
|
||||
args['hoping_to_solve']
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -41,6 +41,7 @@ class ModelConfigResource(Resource):
|
||||
opening_statement=model_configuration['opening_statement'],
|
||||
suggested_questions=json.dumps(model_configuration['suggested_questions']),
|
||||
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
|
||||
speech_to_text=json.dumps(model_configuration['speech_to_text']),
|
||||
more_like_this=json.dumps(model_configuration['more_like_this']),
|
||||
model=json.dumps(model_configuration['model']),
|
||||
user_input_form=json.dumps(model_configuration['user_input_form']),
|
||||
|
||||
75
api/controllers/console/auth/activate.py
Normal file
75
api/controllers/console/auth/activate.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, str_len, supported_language, timezone
|
||||
from libs.password import valid_password, hash_password
|
||||
from models.account import AccountStatus, Tenant
|
||||
from services.account_service import RegisterService
|
||||
|
||||
|
||||
class ActivateCheckApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('email', type=email, required=True, nullable=False, location='args')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
|
||||
tenant = db.session.query(Tenant).filter(
|
||||
Tenant.id == args['workspace_id'],
|
||||
Tenant.status == 'normal'
|
||||
).first()
|
||||
|
||||
return {'is_valid': account is not None, 'workspace_name': tenant.name}
|
||||
|
||||
|
||||
class ActivateApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('email', type=email, required=True, nullable=False, location='json')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
|
||||
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
|
||||
parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
if account is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
|
||||
|
||||
account.name = args['name']
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(args['password'], salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
account.interface_language = args['interface_language']
|
||||
account.timezone = args['timezone']
|
||||
account.interface_theme = 'light'
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.initialized_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(ActivateCheckApi, '/activate/check')
|
||||
api.add_resource(ActivateApi, '/activate')
|
||||
@@ -20,7 +20,7 @@ def get_oauth_providers():
|
||||
client_secret=current_app.config.get(
|
||||
'NOTION_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/data-source/callback/notion')
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/data-source/callback/notion')
|
||||
|
||||
OAUTH_PROVIDERS = {
|
||||
'notion': notion_oauth
|
||||
@@ -42,7 +42,7 @@ class OAuthDataSource(Resource):
|
||||
if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
|
||||
internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
|
||||
oauth_provider.save_internal_access_token(internal_secret)
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
else:
|
||||
auth_url = oauth_provider.get_authorization_url()
|
||||
return redirect(auth_url)
|
||||
@@ -66,12 +66,12 @@ class OAuthDataSourceCallback(Resource):
|
||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||
return {'error': 'OAuth data source process failed'}, 400
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
elif 'error' in request.args:
|
||||
error = request.args.get('error')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source={error}')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source={error}')
|
||||
else:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=access_denied')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=access_denied')
|
||||
|
||||
|
||||
class OAuthDataSourceSync(Resource):
|
||||
|
||||
@@ -20,13 +20,13 @@ def get_oauth_providers():
|
||||
client_secret=current_app.config.get(
|
||||
'GITHUB_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/authorize/github')
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/authorize/github')
|
||||
|
||||
google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'),
|
||||
client_secret=current_app.config.get(
|
||||
'GOOGLE_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/authorize/google')
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/authorize/google')
|
||||
|
||||
OAUTH_PROVIDERS = {
|
||||
'github': github_oauth,
|
||||
@@ -80,7 +80,7 @@ class OAuthCallback(Resource):
|
||||
flask_login.login_user(account, remember=True)
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_login=success')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||
|
||||
@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -324,8 +324,8 @@ class DatasetInitApi(Resource):
|
||||
document_data=args,
|
||||
account=current_user
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -95,8 +95,8 @@ class HitTestingApi(Resource):
|
||||
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -18,3 +18,9 @@ class AccountNotLinkTenantError(BaseHTTPException):
|
||||
error_code = 'account_not_link_tenant'
|
||||
description = "Account not link tenant."
|
||||
code = 403
|
||||
|
||||
|
||||
class AlreadyActivateError(BaseHTTPException):
|
||||
error_code = 'already_activate'
|
||||
description = "Auth Token is invalid or account already activated, please check again."
|
||||
code = 403
|
||||
|
||||
66
api/controllers/console/explore/audio.py
Normal file
66
api/controllers/console/explore/audio.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \
|
||||
NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
from models.model import AppModelConfig
|
||||
|
||||
|
||||
class ChatAudioApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
|
||||
@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -21,6 +21,7 @@ class AppParameterApi(InstalledAppResource):
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
}
|
||||
@@ -35,6 +36,7 @@ class AppParameterApi(InstalledAppResource):
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
@@ -6,22 +6,23 @@ from flask import current_app, request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.workspace.error import AccountAlreadyInitedError, InvalidInvitationCodeError, \
|
||||
RepeatPasswordNotMatchError
|
||||
RepeatPasswordNotMatchError, CurrentPasswordIncorrectError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import TimestampField, supported_language, timezone
|
||||
from extensions.ext_database import db
|
||||
from models.account import InvitationCode, AccountIntegrate
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'avatar': fields.String,
|
||||
'email': fields.String,
|
||||
'is_password_set': fields.Boolean,
|
||||
'interface_language': fields.String,
|
||||
'interface_theme': fields.String,
|
||||
'timezone': fields.String,
|
||||
@@ -194,8 +195,11 @@ class AccountPasswordApi(Resource):
|
||||
if args['new_password'] != args['repeat_new_password']:
|
||||
raise RepeatPasswordNotMatchError()
|
||||
|
||||
AccountService.update_account_password(
|
||||
current_user, args['password'], args['new_password'])
|
||||
try:
|
||||
AccountService.update_account_password(
|
||||
current_user, args['password'], args['new_password'])
|
||||
except ServiceCurrentPasswordIncorrectError:
|
||||
raise CurrentPasswordIncorrectError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -7,6 +7,12 @@ class RepeatPasswordNotMatchError(BaseHTTPException):
|
||||
code = 400
|
||||
|
||||
|
||||
class CurrentPasswordIncorrectError(BaseHTTPException):
|
||||
error_code = 'current_password_incorrect'
|
||||
description = "Current password is incorrect."
|
||||
code = 400
|
||||
|
||||
|
||||
class ProviderRequestFailedError(BaseHTTPException):
|
||||
error_code = 'provider_request_failed'
|
||||
description = None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
|
||||
|
||||
@@ -60,7 +60,8 @@ class MemberInviteEmailApi(Resource):
|
||||
inviter = current_user
|
||||
|
||||
try:
|
||||
RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role, inviter=inviter)
|
||||
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
|
||||
inviter=inviter)
|
||||
account = db.session.query(Account, TenantAccountJoin.role).join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
).filter(Account.email == args['email']).first()
|
||||
@@ -78,7 +79,16 @@ class MemberInviteEmailApi(Resource):
|
||||
|
||||
# todo:413
|
||||
|
||||
return {'result': 'success', 'account': account}, 201
|
||||
return {
|
||||
'result': 'success',
|
||||
'account': account,
|
||||
'invite_url': '{}/activate?workspace_id={}&email={}&token={}'.format(
|
||||
current_app.config.get("CONSOLE_WEB_URL"),
|
||||
str(current_user.current_tenant_id),
|
||||
invitee_email,
|
||||
token
|
||||
)
|
||||
}, 201
|
||||
|
||||
|
||||
class MemberCancelInviteApi(Resource):
|
||||
@@ -88,7 +98,7 @@ class MemberCancelInviteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id):
|
||||
member = Account.query.get(str(member_id))
|
||||
member = db.session.query(Account).filter(Account.id == str(member_id)).first()
|
||||
if not member:
|
||||
abort(404)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, abort
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
|
||||
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
|
||||
"""
|
||||
|
||||
ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
|
||||
ProviderService.init_supported_provider(current_user.current_tenant)
|
||||
providers = Provider.query.filter_by(tenant_id=tenant_id).all()
|
||||
|
||||
provider_list = [
|
||||
@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
|
||||
'quota_used': p.quota_used
|
||||
} if p.provider_type == ProviderType.SYSTEM.value else {}),
|
||||
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
|
||||
ProviderName(p.provider_name))
|
||||
ProviderName(p.provider_name), only_custom=True)
|
||||
if p.provider_type == ProviderType.CUSTOM.value else None
|
||||
}
|
||||
for p in providers
|
||||
]
|
||||
@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
|
||||
is_valid=token_is_valid)
|
||||
db.session.add(provider_model)
|
||||
|
||||
if provider_model.is_valid:
|
||||
if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
|
||||
other_providers = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant.id,
|
||||
Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
|
||||
Provider.provider_name != provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).all()
|
||||
@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
|
||||
if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
|
||||
ProviderName.HUGGINGFACEHUB.value]:
|
||||
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
|
||||
|
||||
@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
# todo: remove this when the provider is supported
|
||||
if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
|
||||
if provider in [ProviderName.COHERE.value,
|
||||
ProviderName.HUGGINGFACEHUB.value]:
|
||||
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
|
||||
|
||||
@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
|
||||
provider_model.is_valid = args['is_enabled']
|
||||
db.session.commit()
|
||||
elif not provider_model:
|
||||
ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
|
||||
elif provider == ProviderName.ANTHROPIC.value:
|
||||
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
|
||||
else:
|
||||
quota_limit = 0
|
||||
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
provider,
|
||||
quota_limit,
|
||||
args['is_enabled']
|
||||
)
|
||||
else:
|
||||
abort(403)
|
||||
|
||||
|
||||
@@ -7,6 +7,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from .app import completion, app, conversation, message
|
||||
from .app import completion, app, conversation, message, audio
|
||||
|
||||
from .dataset import document
|
||||
|
||||
@@ -22,6 +22,7 @@ class AppParameterApi(AppApiResource):
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
}
|
||||
@@ -35,6 +36,7 @@ class AppParameterApi(AppApiResource):
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
61
api/controllers/service_api/app/audio.py
Normal file
61
api/controllers/service_api/app/audio.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import AppUnavailableError, ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \
|
||||
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
|
||||
ProviderNotSupportSpeechToTextError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from models.model import App, AppModelConfig
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
|
||||
class AudioApi(AppApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
api.add_resource(AudioApi, '/audio-to-text')
|
||||
@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import request
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -48,6 +49,24 @@ class ConversationApi(AppApiResource):
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
class ConversationDetailApi(AppApiResource):
|
||||
@marshal_with(conversation_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
user = request.get_json().get('user')
|
||||
|
||||
if end_user is None and user is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
return {"result": "success"}
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
class ConversationRenameApi(AppApiResource):
|
||||
|
||||
@@ -74,3 +93,5 @@ class ConversationRenameApi(AppApiResource):
|
||||
|
||||
api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='conversation_name')
|
||||
api.add_resource(ConversationApi, '/conversations')
|
||||
api.add_resource(ConversationApi, '/conversations/<uuid:c_id>', endpoint='conversation')
|
||||
api.add_resource(ConversationDetailApi, '/conversations/<uuid:c_id>', endpoint='conversation_detail')
|
||||
|
||||
@@ -51,3 +51,27 @@ class CompletionRequestError(BaseHTTPException):
|
||||
description = "Completion request failed."
|
||||
code = 400
|
||||
|
||||
|
||||
class NoAudioUploadedError(BaseHTTPException):
|
||||
error_code = 'no_audio_uploaded'
|
||||
description = "Please upload your audio."
|
||||
code = 400
|
||||
|
||||
|
||||
class AudioTooLargeError(BaseHTTPException):
|
||||
error_code = 'audio_too_large'
|
||||
description = "Audio size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedAudioTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_audio_type'
|
||||
description = "Audio type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
error_code = 'provider_not_support_speech_to_text'
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
|
||||
|
||||
@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
|
||||
dataset_process_rule=dataset.latest_process_rule,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
if doc_type and doc_metadata:
|
||||
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
|
||||
|
||||
@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import completion, app, conversation, message, site, saved_message
|
||||
from . import completion, app, conversation, message, site, saved_message, audio, passport
|
||||
|
||||
@@ -21,6 +21,7 @@ class AppParameterApi(WebApiResource):
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
}
|
||||
@@ -34,6 +35,7 @@ class AppParameterApi(WebApiResource):
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
63
api/controllers/web/audio.py
Normal file
63
api/controllers/web/audio.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web.error import AppUnavailableError, ProviderNotInitializeError, CompletionRequestError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
from models.model import App, AppModelConfig
|
||||
|
||||
|
||||
class AudioApi(WebApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
api.add_resource(AudioApi, '/audio-to-text')
|
||||
@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -62,3 +62,27 @@ class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
|
||||
error_code = 'app_suggested_questions_after_answer_disabled'
|
||||
description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page."
|
||||
code = 403
|
||||
|
||||
|
||||
class NoAudioUploadedError(BaseHTTPException):
|
||||
error_code = 'no_audio_uploaded'
|
||||
description = "Please upload your audio."
|
||||
code = 400
|
||||
|
||||
|
||||
class AudioTooLargeError(BaseHTTPException):
|
||||
error_code = 'audio_too_large'
|
||||
description = "Audio size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedAudioTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_audio_type'
|
||||
description = "Audio type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
error_code = 'provider_not_support_speech_to_text'
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
64
api/controllers/web/passport.py
Normal file
64
api/controllers/web/passport.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from controllers.web import api
|
||||
from flask_restful import Resource
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Unauthorized, NotFound
|
||||
from models.model import Site, EndUser, App
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
|
||||
class PassportResource(Resource):
|
||||
"""Base resource for passport."""
|
||||
def get(self):
|
||||
app_id = request.headers.get('X-App-Code')
|
||||
if app_id is None:
|
||||
raise Unauthorized('X-App-Code header is missing.')
|
||||
|
||||
# get site from db and check if it is normal
|
||||
site = db.session.query(Site).filter(
|
||||
Site.code == app_id,
|
||||
Site.status == 'normal'
|
||||
).first()
|
||||
if not site:
|
||||
raise NotFound()
|
||||
# get app from db and check if it is normal and enable_site
|
||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||
if not app_model or app_model.status != 'normal' or not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='browser',
|
||||
is_anonymous=True,
|
||||
session_id=generate_session_id(),
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
payload = {
|
||||
"iss": site.app_id,
|
||||
'sub': 'Web API Passport',
|
||||
'app_id': site.app_id,
|
||||
'end_user_id': end_user.id,
|
||||
}
|
||||
|
||||
tk = PassportService().issue(payload)
|
||||
|
||||
return {
|
||||
'access_token': tk,
|
||||
}
|
||||
|
||||
api.add_resource(PassportResource, '/passport')
|
||||
|
||||
def generate_session_id():
|
||||
"""
|
||||
Generate a unique session ID.
|
||||
"""
|
||||
while True:
|
||||
session_id = str(uuid.uuid4())
|
||||
existing_count = db.session.query(EndUser) \
|
||||
.filter(EndUser.session_id == session_id).count()
|
||||
if existing_count == 0:
|
||||
return session_id
|
||||
@@ -1,110 +1,48 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from functools import wraps
|
||||
|
||||
from flask import request, session
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Site, EndUser
|
||||
from models.model import App, EndUser
|
||||
from libs.passport import PassportService
|
||||
|
||||
|
||||
def validate_token(view=None):
|
||||
def validate_jwt_token(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
site = validate_and_get_site()
|
||||
|
||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
|
||||
if app_model.status != 'normal':
|
||||
raise NotFound()
|
||||
|
||||
if not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
end_user = create_or_update_end_user_for_session(app_model)
|
||||
app_model, end_user = decode_jwt_token()
|
||||
|
||||
return view(app_model, end_user, *args, **kwargs)
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
def validate_and_get_site():
|
||||
"""
|
||||
Validate and get API token.
|
||||
"""
|
||||
def decode_jwt_token():
|
||||
auth_header = request.headers.get('Authorization')
|
||||
if auth_header is None:
|
||||
raise Unauthorized('Authorization header is missing.')
|
||||
|
||||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
|
||||
auth_scheme, tk = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
site = db.session.query(Site).filter(
|
||||
Site.code == auth_token,
|
||||
Site.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not site:
|
||||
decoded = PassportService().verify(tk)
|
||||
app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
def create_or_update_end_user_for_session(app_model):
|
||||
"""
|
||||
Create or update session terminal based on session ID.
|
||||
"""
|
||||
if 'session_id' not in session:
|
||||
session['session_id'] = generate_session_id()
|
||||
|
||||
session_id = session.get('session_id')
|
||||
end_user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.session_id == session_id,
|
||||
EndUser.type == 'browser'
|
||||
).first()
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='browser',
|
||||
is_anonymous=True,
|
||||
session_id=session_id
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
|
||||
def generate_session_id():
|
||||
"""
|
||||
Generate a unique session ID.
|
||||
"""
|
||||
count = 1
|
||||
session_id = ''
|
||||
while count != 0:
|
||||
session_id = str(uuid.uuid4())
|
||||
count = db.session.query(EndUser) \
|
||||
.filter(EndUser.session_id == session_id).count()
|
||||
|
||||
return session_id
|
||||
|
||||
return app_model, end_user
|
||||
|
||||
class WebApiResource(Resource):
|
||||
method_decorators = [validate_token]
|
||||
method_decorators = [validate_jwt_token]
|
||||
|
||||
@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
class HostedAnthropicCredential(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
class HostedLLMCredentials(BaseModel):
|
||||
openai: Optional[HostedOpenAICredential] = None
|
||||
anthropic: Optional[HostedAnthropicCredential] = None
|
||||
|
||||
|
||||
hosted_llm_credentials = HostedLLMCredentials()
|
||||
@@ -26,3 +31,6 @@ def init_app(app: Flask):
|
||||
|
||||
if app.config.get("OPENAI_API_KEY"):
|
||||
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
|
||||
|
||||
if app.config.get("ANTHROPIC_API_KEY"):
|
||||
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
|
||||
|
||||
@@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
})
|
||||
|
||||
self.llm_message.prompt = real_prompts
|
||||
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
|
||||
self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
@@ -69,9 +69,8 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
if not self.conversation_message_task.streaming:
|
||||
self.conversation_message_task.append_message_text(response.generations[0][0].text)
|
||||
self.llm_message.completion = response.generations[0][0].text
|
||||
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
|
||||
|
||||
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
|
||||
|
||||
self.conversation_message_task.save_message(self.llm_message)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import math
|
||||
import re
|
||||
from typing import Mapping, List, Dict, Any, Optional
|
||||
|
||||
from langchain import PromptTemplate
|
||||
@@ -178,13 +179,20 @@ class MultiDatasetRouterChain(Chain):
|
||||
|
||||
route = self.router_chain.route(inputs)
|
||||
|
||||
if not route.destination:
|
||||
destination = ''
|
||||
if route.destination:
|
||||
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
|
||||
match = re.search(pattern, route.destination, re.IGNORECASE)
|
||||
if match:
|
||||
destination = match.group()
|
||||
|
||||
if not destination:
|
||||
return {"text": ''}
|
||||
elif route.destination in self.dataset_tools:
|
||||
return {"text": self.dataset_tools[route.destination].run(
|
||||
elif destination in self.dataset_tools:
|
||||
return {"text": self.dataset_tools[destination].run(
|
||||
route.next_inputs['input']
|
||||
)}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Received invalid destination chain name '{route.destination}'"
|
||||
f"Received invalid destination chain name '{destination}'"
|
||||
)
|
||||
|
||||
@@ -118,6 +118,7 @@ class Completion:
|
||||
prompt, stop_words = cls.get_main_llm_prompt(
|
||||
mode=mode,
|
||||
llm=final_llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
@@ -129,6 +130,7 @@ class Completion:
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
final_llm=final_llm,
|
||||
model=app_model_config.model_dict,
|
||||
prompt=prompt,
|
||||
mode=mode
|
||||
)
|
||||
@@ -138,7 +140,8 @@ class Completion:
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
|
||||
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
|
||||
pre_prompt: str, query: str, inputs: dict,
|
||||
chain_output: Optional[str],
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
|
||||
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
|
||||
@@ -151,10 +154,11 @@ class Completion:
|
||||
|
||||
if mode == 'completion':
|
||||
prompt_template = JinjaPromptTemplate.from_template(
|
||||
template=("""Use the following CONTEXT as your learned knowledge:
|
||||
[CONTEXT]
|
||||
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
{{context}}
|
||||
[END CONTEXT]
|
||||
</context>
|
||||
|
||||
When answer to user:
|
||||
- If you don't know, just say that you don't know.
|
||||
@@ -204,10 +208,11 @@ And answer according to the language of the user's question.
|
||||
|
||||
if chain_output:
|
||||
human_inputs['context'] = chain_output
|
||||
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
|
||||
[CONTEXT]
|
||||
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
{{context}}
|
||||
[END CONTEXT]
|
||||
</context>
|
||||
|
||||
When answer to user:
|
||||
- If you don't know, just say that you don't know.
|
||||
@@ -219,7 +224,7 @@ And answer according to the language of the user's question.
|
||||
if pre_prompt:
|
||||
human_message_prompt += pre_prompt
|
||||
|
||||
query_prompt = "\nHuman: {{query}}\nAI: "
|
||||
query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
|
||||
|
||||
if memory:
|
||||
# append chat histories
|
||||
@@ -228,9 +233,11 @@ And answer according to the language of the user's question.
|
||||
inputs=human_inputs
|
||||
)
|
||||
|
||||
curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
|
||||
rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
|
||||
- memory.llm.max_tokens - curr_message_tokens
|
||||
curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
|
||||
model_name = model['name']
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
rest_tokens = llm_constant.max_context_token_length[model_name] \
|
||||
- max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
|
||||
|
||||
@@ -241,7 +248,10 @@ And answer according to the language of the user's question.
|
||||
# if histories_param not in human_inputs:
|
||||
# human_inputs[histories_param] = '{{' + histories_param + '}}'
|
||||
|
||||
human_message_prompt += "\n\n" + histories
|
||||
human_message_prompt += "\n\n" if human_message_prompt else ""
|
||||
human_message_prompt += "Here is the chat histories between human and assistant, " \
|
||||
"inside <histories></histories> XML tags.\n\n<histories>"
|
||||
human_message_prompt += histories + "</histories>"
|
||||
|
||||
human_message_prompt += query_prompt
|
||||
|
||||
@@ -307,13 +317,15 @@ And answer according to the language of the user's question.
|
||||
model=app_model_config.model_dict
|
||||
)
|
||||
|
||||
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
|
||||
max_tokens = llm.max_tokens
|
||||
model_name = app_model_config.model_dict.get("name")
|
||||
model_limited_tokens = llm_constant.max_context_token_length[model_name]
|
||||
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
|
||||
|
||||
# get prompt without memory and context
|
||||
prompt, _ = cls.get_main_llm_prompt(
|
||||
mode=mode,
|
||||
llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
@@ -332,16 +344,17 @@ And answer according to the language of the user's question.
|
||||
return rest_tokens
|
||||
|
||||
@classmethod
|
||||
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
||||
def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
|
||||
prompt: Union[str, List[BaseMessage]], mode: str):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
|
||||
max_tokens = final_llm.max_tokens
|
||||
model_name = model.get("name")
|
||||
model_limited_tokens = llm_constant.max_context_token_length[model_name]
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
|
||||
if mode == 'completion' and isinstance(final_llm, BaseLLM):
|
||||
prompt_tokens = final_llm.get_num_tokens(prompt)
|
||||
else:
|
||||
prompt_tokens = final_llm.get_messages_tokens(prompt)
|
||||
prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
|
||||
|
||||
if prompt_tokens + max_tokens > model_limited_tokens:
|
||||
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
|
||||
@@ -350,9 +363,10 @@ And answer according to the language of the user's question.
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
|
||||
app_model_config: AppModelConfig, user: Account, streaming: bool):
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
|
||||
llm = LLMBuilder.to_llm_from_model(
|
||||
tenant_id=app.tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
model=app_model_config.model_dict,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
@@ -360,6 +374,7 @@ And answer according to the language of the user's question.
|
||||
original_prompt, _ = cls.get_main_llm_prompt(
|
||||
mode="completion",
|
||||
llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=pre_prompt,
|
||||
query=message.query,
|
||||
inputs=message.inputs,
|
||||
@@ -390,6 +405,7 @@ And answer according to the language of the user's question.
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
final_llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
prompt=prompt,
|
||||
mode='completion'
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from _decimal import Decimal
|
||||
|
||||
models = {
|
||||
'claude-instant-1': 'anthropic', # 100,000 tokens
|
||||
'claude-2': 'anthropic', # 100,000 tokens
|
||||
'gpt-4': 'openai', # 8,192 tokens
|
||||
'gpt-4-32k': 'openai', # 32,768 tokens
|
||||
'gpt-3.5-turbo': 'openai', # 4,096 tokens
|
||||
@@ -10,10 +12,13 @@ models = {
|
||||
'text-curie-001': 'openai', # 2,049 tokens
|
||||
'text-babbage-001': 'openai', # 2,049 tokens
|
||||
'text-ada-001': 'openai', # 2,049 tokens
|
||||
'text-embedding-ada-002': 'openai' # 8191 tokens, 1536 dimensions
|
||||
'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
|
||||
'whisper-1': 'openai'
|
||||
}
|
||||
|
||||
max_context_token_length = {
|
||||
'claude-instant-1': 100000,
|
||||
'claude-2': 100000,
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
@@ -23,17 +28,21 @@ max_context_token_length = {
|
||||
'text-curie-001': 2049,
|
||||
'text-babbage-001': 2049,
|
||||
'text-ada-001': 2049,
|
||||
'text-embedding-ada-002': 8191
|
||||
'text-embedding-ada-002': 8191,
|
||||
}
|
||||
|
||||
models_by_mode = {
|
||||
'chat': [
|
||||
'claude-instant-1', # 100,000 tokens
|
||||
'claude-2', # 100,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k', # 16,384 tokens
|
||||
],
|
||||
'completion': [
|
||||
'claude-instant-1', # 100,000 tokens
|
||||
'claude-2', # 100,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
@@ -52,6 +61,14 @@ models_by_mode = {
|
||||
model_currency = 'USD'
|
||||
|
||||
model_prices = {
|
||||
'claude-instant-1': {
|
||||
'prompt': Decimal('0.00163'),
|
||||
'completion': Decimal('0.00551'),
|
||||
},
|
||||
'claude-2': {
|
||||
'prompt': Decimal('0.01102'),
|
||||
'completion': Decimal('0.03268'),
|
||||
},
|
||||
'gpt-4': {
|
||||
'prompt': Decimal('0.03'),
|
||||
'completion': Decimal('0.06'),
|
||||
|
||||
@@ -56,7 +56,7 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
def init(self):
|
||||
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
|
||||
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
|
||||
self.model_dict['provider'] = provider_name
|
||||
|
||||
override_model_configs = None
|
||||
@@ -89,7 +89,7 @@ class ConversationMessageTask:
|
||||
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
|
||||
system_instruction = system_message.content
|
||||
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
|
||||
system_instruction_tokens = llm.get_messages_tokens([system_message])
|
||||
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
|
||||
|
||||
if not self.conversation:
|
||||
self.is_new_conversation = True
|
||||
@@ -185,6 +185,7 @@ class ConversationMessageTask:
|
||||
if provider and provider.provider_type == ProviderType.SYSTEM.value:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == self.app.tenant_id,
|
||||
Provider.provider_name == provider.provider_name,
|
||||
Provider.quota_limit > Provider.quota_used
|
||||
).update({'quota_used': Provider.quota_used + 1})
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ class ExcelLoader(BaseLoader):
|
||||
row_dict = dict(zip(keys, list(map(str, row))))
|
||||
row_dict = {k: v for k, v in row_dict.items() if v}
|
||||
item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
|
||||
data.append(item)
|
||||
document = Document(page_content=item)
|
||||
data.append(document)
|
||||
|
||||
return [Document(page_content='\n\n'.join(data))]
|
||||
return data
|
||||
|
||||
@@ -81,8 +81,8 @@ class NotionLoader(BaseLoader):
|
||||
docs = []
|
||||
if notion_page_type == 'database':
|
||||
# get all the pages in the database
|
||||
page_text = self._get_notion_database_data(notion_obj_id)
|
||||
docs.append(Document(page_content=page_text))
|
||||
page_text_documents = self._get_notion_database_data(notion_obj_id)
|
||||
docs.extend(page_text_documents)
|
||||
elif notion_page_type == 'page':
|
||||
page_text_list = self._get_notion_block_data(notion_obj_id)
|
||||
for page_text in page_text_list:
|
||||
@@ -94,7 +94,7 @@ class NotionLoader(BaseLoader):
|
||||
|
||||
def _get_notion_database_data(
|
||||
self, database_id: str, query_dict: Dict[str, Any] = {}
|
||||
) -> str:
|
||||
) -> List[Document]:
|
||||
"""Get all the pages from a Notion database."""
|
||||
res = requests.post(
|
||||
DATABASE_URL_TMPL.format(database_id=database_id),
|
||||
@@ -110,7 +110,7 @@ class NotionLoader(BaseLoader):
|
||||
|
||||
database_content_list = []
|
||||
if 'results' not in data or data["results"] is None:
|
||||
return ""
|
||||
return []
|
||||
for result in data["results"]:
|
||||
properties = result['properties']
|
||||
data = {}
|
||||
@@ -143,10 +143,10 @@ class NotionLoader(BaseLoader):
|
||||
row_content = row_content + f'{key}:{value_content}\n'
|
||||
else:
|
||||
row_content = row_content + f'{key}:{value}\n'
|
||||
database_content_list.append(row_content)
|
||||
database_content_list.append(json.dumps(data, ensure_ascii=False))
|
||||
document = Document(page_content=row_content)
|
||||
database_content_list.append(document)
|
||||
|
||||
return "\n\n".join(database_content_list)
|
||||
return database_content_list
|
||||
|
||||
def _get_notion_block_data(self, page_id: str) -> List[str]:
|
||||
result_lines_arr = []
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import Embedding
|
||||
@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
|
||||
text_embeddings.extend(embedding_results)
|
||||
return text_embeddings
|
||||
|
||||
@handle_openai_exceptions
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import HumanMessage, OutputParserException
|
||||
from langchain.schema import HumanMessage, OutputParserException, BaseMessage
|
||||
|
||||
from core.constant import llm_constant
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
@@ -23,10 +23,14 @@ class LLMGenerator:
|
||||
@classmethod
|
||||
def generate_conversation_name(cls, tenant_id: str, query, answer):
|
||||
prompt = CONVERSATION_TITLE_PROMPT
|
||||
prompt = prompt.format(query=query, answer=answer)
|
||||
|
||||
if len(query) > 2000:
|
||||
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
|
||||
|
||||
prompt = prompt.format(query=query)
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name=generate_base_model,
|
||||
model_name='gpt-3.5-turbo',
|
||||
max_tokens=50
|
||||
)
|
||||
|
||||
@@ -40,26 +44,40 @@ class LLMGenerator:
|
||||
@classmethod
|
||||
def generate_conversation_summary(cls, tenant_id: str, messages):
|
||||
max_tokens = 200
|
||||
model = 'gpt-3.5-turbo'
|
||||
|
||||
prompt = CONVERSATION_SUMMARY_PROMPT
|
||||
prompt_with_empty_context = prompt.format(context='')
|
||||
prompt_tokens = TokenCalculator.get_num_tokens(generate_base_model, prompt_with_empty_context)
|
||||
rest_tokens = llm_constant.max_context_token_length[generate_base_model] - prompt_tokens - max_tokens
|
||||
prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)
|
||||
rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1
|
||||
|
||||
context = ''
|
||||
for message in messages:
|
||||
if not message.answer:
|
||||
continue
|
||||
|
||||
message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
|
||||
if rest_tokens - TokenCalculator.get_num_tokens(generate_base_model, context + message_qa_text) > 0:
|
||||
if len(message.query) > 2000:
|
||||
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
|
||||
else:
|
||||
query = message.query
|
||||
|
||||
if len(message.answer) > 2000:
|
||||
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
|
||||
else:
|
||||
answer = message.answer
|
||||
|
||||
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
|
||||
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
|
||||
context += message_qa_text
|
||||
|
||||
if not context:
|
||||
return '[message too long, no summary]'
|
||||
|
||||
prompt = prompt.format(context=context)
|
||||
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name=generate_base_model,
|
||||
model_name=model,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
@@ -102,7 +120,7 @@ class LLMGenerator:
|
||||
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name=generate_base_model,
|
||||
model_name='gpt-3.5-turbo',
|
||||
temperature=0,
|
||||
max_tokens=256
|
||||
)
|
||||
@@ -114,6 +132,8 @@ class LLMGenerator:
|
||||
|
||||
try:
|
||||
output = llm(query)
|
||||
if isinstance(output, BaseMessage):
|
||||
output = output.content
|
||||
questions = output_parser.parse(output)
|
||||
except Exception:
|
||||
logging.exception("Error generating suggested questions after answer")
|
||||
|
||||
@@ -17,7 +17,7 @@ class IndexBuilder:
|
||||
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
|
||||
@@ -235,7 +235,8 @@ class IndexingRunner:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
|
||||
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
|
||||
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
|
||||
self.filter_string(document.page_content))
|
||||
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
@@ -345,8 +346,10 @@ class IndexingRunner:
|
||||
return text_docs
|
||||
|
||||
def filter_string(self, text):
|
||||
pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
|
||||
return pattern.sub('', text)
|
||||
text = re.sub(r'<\|', '<', text)
|
||||
text = re.sub(r'\|>', '>', text)
|
||||
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
|
||||
return text
|
||||
|
||||
def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
|
||||
"""
|
||||
@@ -425,7 +428,7 @@ class IndexingRunner:
|
||||
return documents
|
||||
|
||||
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
|
||||
processing_rule: DatasetProcessRule) -> List[Document]:
|
||||
processing_rule: DatasetProcessRule) -> List[Document]:
|
||||
"""
|
||||
Split the text documents into nodes.
|
||||
"""
|
||||
|
||||
@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
|
||||
"""
|
||||
description = "Provider Token Not Init"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.description = args[0] if args else self.description
|
||||
|
||||
|
||||
class QuotaExceededError(Exception):
|
||||
"""
|
||||
|
||||
@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.llm_provider_service import LLMProviderService
|
||||
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
|
||||
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
|
||||
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from models.provider import ProviderType
|
||||
from models.provider import ProviderType, ProviderName
|
||||
|
||||
|
||||
class LLMBuilder:
|
||||
@@ -32,43 +33,43 @@ class LLMBuilder:
|
||||
|
||||
@classmethod
|
||||
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
|
||||
provider = cls.get_default_provider(tenant_id)
|
||||
provider = cls.get_default_provider(tenant_id, model_name)
|
||||
|
||||
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
|
||||
|
||||
llm_cls = None
|
||||
mode = cls.get_mode_by_model(model_name)
|
||||
if mode == 'chat':
|
||||
if provider == 'openai':
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
llm_cls = StreamableChatOpenAI
|
||||
else:
|
||||
elif provider == ProviderName.AZURE_OPENAI.value:
|
||||
llm_cls = StreamableAzureChatOpenAI
|
||||
elif provider == ProviderName.ANTHROPIC.value:
|
||||
llm_cls = StreamableChatAnthropic
|
||||
elif mode == 'completion':
|
||||
if provider == 'openai':
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
llm_cls = StreamableOpenAI
|
||||
else:
|
||||
elif provider == ProviderName.AZURE_OPENAI.value:
|
||||
llm_cls = StreamableAzureOpenAI
|
||||
else:
|
||||
|
||||
if not llm_cls:
|
||||
raise ValueError(f"model name {model_name} is not supported.")
|
||||
|
||||
|
||||
model_kwargs = {
|
||||
'model_name': model_name,
|
||||
'temperature': kwargs.get('temperature', 0),
|
||||
'max_tokens': kwargs.get('max_tokens', 256),
|
||||
'top_p': kwargs.get('top_p', 1),
|
||||
'frequency_penalty': kwargs.get('frequency_penalty', 0),
|
||||
'presence_penalty': kwargs.get('presence_penalty', 0),
|
||||
'callbacks': kwargs.get('callbacks', None),
|
||||
'streaming': kwargs.get('streaming', False),
|
||||
}
|
||||
|
||||
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
|
||||
model_kwargs.update(model_credentials)
|
||||
model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
|
||||
|
||||
return llm_cls(
|
||||
model_name=model_name,
|
||||
temperature=kwargs.get('temperature', 0),
|
||||
max_tokens=kwargs.get('max_tokens', 256),
|
||||
**model_extras_kwargs,
|
||||
callbacks=kwargs.get('callbacks', None),
|
||||
streaming=kwargs.get('streaming', False),
|
||||
# request_timeout=None
|
||||
**model_credentials
|
||||
)
|
||||
return llm_cls(**model_kwargs)
|
||||
|
||||
@classmethod
|
||||
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
|
||||
@@ -118,14 +119,30 @@ class LLMBuilder:
|
||||
return provider_service.get_credentials(model_name)
|
||||
|
||||
@classmethod
|
||||
def get_default_provider(cls, tenant_id: str) -> str:
|
||||
provider = BaseProvider.get_valid_provider(tenant_id)
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError()
|
||||
def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
|
||||
provider_name = llm_constant.models[model_name]
|
||||
|
||||
if provider_name == 'openai':
|
||||
# get the default provider (openai / azure_openai) for the tenant
|
||||
openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
|
||||
azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
|
||||
|
||||
provider = None
|
||||
if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
|
||||
provider = openai_provider
|
||||
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
|
||||
provider = azure_openai_provider
|
||||
elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
|
||||
provider = openai_provider
|
||||
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
|
||||
provider = azure_openai_provider
|
||||
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {provider_name} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
provider_name = 'openai'
|
||||
else:
|
||||
provider_name = provider.provider_name
|
||||
|
||||
return provider_name
|
||||
|
||||
@@ -1,23 +1,138 @@
|
||||
from typing import Optional
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import anthropic
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from models.provider import ProviderName
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
from models.provider import ProviderName, ProviderType
|
||||
|
||||
|
||||
class AnthropicProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id)
|
||||
# todo
|
||||
return []
|
||||
return [
|
||||
{
|
||||
'id': 'claude-instant-1',
|
||||
'name': 'claude-instant-1',
|
||||
},
|
||||
{
|
||||
'id': 'claude-2',
|
||||
'name': 'claude-2',
|
||||
},
|
||||
]
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
|
||||
The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
|
||||
"""
|
||||
return {
|
||||
'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
|
||||
}
|
||||
return self.get_provider_api_key(model_id=model_id)
|
||||
|
||||
def get_provider_name(self):
|
||||
return ProviderName.ANTHROPIC
|
||||
return ProviderName.ANTHROPIC
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = {
|
||||
'anthropic_api_key': ''
|
||||
}
|
||||
|
||||
if obfuscated:
|
||||
if not config.get('anthropic_api_key'):
|
||||
config = {
|
||||
'anthropic_api_key': ''
|
||||
}
|
||||
|
||||
config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
|
||||
return config
|
||||
|
||||
return config
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
"""
|
||||
Returns the encrypted token.
|
||||
"""
|
||||
return json.dumps({
|
||||
'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
|
||||
})
|
||||
|
||||
def get_decrypted_token(self, token: str):
|
||||
"""
|
||||
Returns the decrypted token.
|
||||
"""
|
||||
config = json.loads(token)
|
||||
config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
|
||||
return config
|
||||
|
||||
def get_token_type(self):
|
||||
return dict
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
Validates the given config.
|
||||
"""
|
||||
# check OpenAI / Azure OpenAI credential is valid
|
||||
openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
|
||||
azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
|
||||
|
||||
provider = None
|
||||
if openai_provider:
|
||||
provider = openai_provider
|
||||
elif azure_openai_provider:
|
||||
provider = azure_openai_provider
|
||||
|
||||
if not provider:
|
||||
raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
quota_used = provider.quota_used if provider.quota_used is not None else 0
|
||||
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
|
||||
if quota_used >= quota_limit:
|
||||
raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
|
||||
f"please configure OpenAI or Azure OpenAI provider first.")
|
||||
|
||||
try:
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError('Config must be a object.')
|
||||
|
||||
if 'anthropic_api_key' not in config:
|
||||
raise ValueError('anthropic_api_key must be provided.')
|
||||
|
||||
chat_llm = ChatAnthropic(
|
||||
model='claude-instant-1',
|
||||
anthropic_api_key=config['anthropic_api_key'],
|
||||
max_tokens_to_sample=10,
|
||||
temperature=0,
|
||||
default_request_timeout=60
|
||||
)
|
||||
|
||||
messages = [
|
||||
HumanMessage(
|
||||
content="ping"
|
||||
)
|
||||
]
|
||||
|
||||
chat_llm(messages)
|
||||
except anthropic.APIConnectionError as ex:
|
||||
raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
|
||||
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
|
||||
raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
|
||||
f"{ex.body['error']['type']}: {ex.body['error']['message']}")
|
||||
except Exception as ex:
|
||||
logging.exception('Anthropic config validation failed')
|
||||
raise ex
|
||||
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
|
||||
|
||||
@@ -44,6 +44,7 @@ class AzureProvider(BaseProvider):
|
||||
config['openai_api_type'] = 'azure'
|
||||
if model_id == 'text-embedding-ada-002':
|
||||
config['deployment'] = model_id.replace('.', '') if model_id else None
|
||||
config['chunk_size'] = 1
|
||||
else:
|
||||
config['deployment_name'] = model_id.replace('.', '') if model_id else None
|
||||
return config
|
||||
@@ -51,12 +52,12 @@ class AzureProvider(BaseProvider):
|
||||
def get_provider_name(self):
|
||||
return ProviderName.AZURE_OPENAI
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key()
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
@@ -80,7 +81,6 @@ class AzureProvider(BaseProvider):
|
||||
return config
|
||||
|
||||
def get_token_type(self):
|
||||
# TODO: change to dict when implemented
|
||||
return dict
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
@@ -97,12 +97,11 @@ class AzureProvider(BaseProvider):
|
||||
models = self.get_models(credentials=config)
|
||||
|
||||
if not models:
|
||||
raise ValidateFailedError("Please add deployments for 'text-davinci-003', "
|
||||
raise ValidateFailedError("Please add deployments for "
|
||||
"'gpt-3.5-turbo', 'text-embedding-ada-002' (required) "
|
||||
"and 'gpt-4', 'gpt-35-turbo-16k' (optional).")
|
||||
"and 'gpt-4', 'gpt-35-turbo-16k', 'text-davinci-003' (optional).")
|
||||
|
||||
fixed_model_ids = [
|
||||
'text-davinci-003',
|
||||
'gpt-35-turbo',
|
||||
'text-embedding-ada-002'
|
||||
]
|
||||
|
||||
@@ -2,7 +2,7 @@ import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.constant import llm_constant
|
||||
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from libs import rsa
|
||||
@@ -14,15 +14,18 @@ class BaseProvider(ABC):
|
||||
def __init__(self, tenant_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
|
||||
def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the decrypted API key for the given tenant_id and provider_name.
|
||||
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
|
||||
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
|
||||
"""
|
||||
provider = self.get_provider(prefer_custom)
|
||||
provider = self.get_provider(only_custom)
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError()
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {llm_constant.models[model_id]} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
quota_used = provider.quota_used if provider.quota_used is not None else 0
|
||||
@@ -38,18 +41,19 @@ class BaseProvider(ABC):
|
||||
else:
|
||||
return self.get_decrypted_token(provider.encrypted_config)
|
||||
|
||||
def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
|
||||
def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and provider_name.
|
||||
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
|
||||
"""
|
||||
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
|
||||
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
|
||||
|
||||
@classmethod
|
||||
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
|
||||
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
|
||||
Provider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and provider_name.
|
||||
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
|
||||
If both CUSTOM and System providers exist.
|
||||
"""
|
||||
query = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant_id
|
||||
@@ -58,39 +62,31 @@ class BaseProvider(ABC):
|
||||
if provider_name:
|
||||
query = query.filter(Provider.provider_name == provider_name)
|
||||
|
||||
providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
|
||||
if only_custom:
|
||||
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
|
||||
|
||||
custom_provider = None
|
||||
system_provider = None
|
||||
providers = query.order_by(Provider.provider_type.asc()).all()
|
||||
|
||||
for provider in providers:
|
||||
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
|
||||
custom_provider = provider
|
||||
return provider
|
||||
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
|
||||
system_provider = provider
|
||||
return provider
|
||||
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
elif system_provider:
|
||||
return system_provider
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_hosted_credentials(self) -> str:
|
||||
if self.get_provider_name() != ProviderName.OPENAI:
|
||||
raise ProviderTokenNotInitError()
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
|
||||
raise ProviderTokenNotInitError()
|
||||
|
||||
return hosted_llm_credentials.openai.api_key
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key()
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = ''
|
||||
|
||||
|
||||
@@ -31,11 +31,11 @@ class LLMProviderService:
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
return self.provider.get_credentials(model_id)
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
return self.provider.get_provider_configs(obfuscated)
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
|
||||
|
||||
def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
|
||||
return self.provider.get_provider(prefer_custom)
|
||||
def get_provider_db_record(self) -> Optional[Provider]:
|
||||
return self.provider.get_provider()
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Optional, Union
|
||||
import openai
|
||||
from openai.error import AuthenticationError, OpenAIError
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.moderation import Moderation
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
@@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider):
|
||||
except Exception as ex:
|
||||
logging.exception('OpenAI config validation failed')
|
||||
raise ex
|
||||
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
return hosted_llm_credentials.openai.api_key
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
|
||||
from langchain.schema import BaseMessage, ChatResult, LLMResult
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import BaseMessage, LLMResult
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableAzureChatOpenAI(AzureChatOpenAI):
|
||||
@@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}
|
||||
|
||||
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
Args:
|
||||
messages: The messages to count the tokens of.
|
||||
|
||||
Returns:
|
||||
The number of tokens in the messages.
|
||||
"""
|
||||
tokens_per_message = 5
|
||||
tokens_per_request = 3
|
||||
|
||||
message_tokens = tokens_per_request
|
||||
message_strs = ''
|
||||
for message in messages:
|
||||
message_strs += message.content
|
||||
message_tokens += tokens_per_message
|
||||
|
||||
# calc once
|
||||
message_tokens += self.get_num_tokens(message_strs)
|
||||
|
||||
return message_tokens
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
@@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(messages, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(messages, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
model_kwargs = {
|
||||
'top_p': params.get('top_p', 1),
|
||||
'frequency_penalty': params.get('frequency_penalty', 0),
|
||||
'presence_penalty': params.get('presence_penalty', 0),
|
||||
}
|
||||
|
||||
del params['top_p']
|
||||
del params['frequency_penalty']
|
||||
del params['presence_penalty']
|
||||
|
||||
params['model_kwargs'] = model_kwargs
|
||||
|
||||
return params
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableAzureOpenAI(AzureOpenAI):
|
||||
@@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(prompts, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(prompts, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
return params
|
||||
|
||||
39
api/core/llm/streamable_chat_anthropic.py
Normal file
39
api/core/llm/streamable_chat_anthropic.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import List, Optional, Any, Dict
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import BaseMessage, LLMResult
|
||||
|
||||
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
|
||||
|
||||
|
||||
class StreamableChatAnthropic(ChatAnthropic):
|
||||
"""
|
||||
Wrapper around Anthropic's large language model.
|
||||
"""
|
||||
|
||||
@handle_anthropic_exceptions
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
params['model'] = params.get('model_name')
|
||||
del params['model_name']
|
||||
|
||||
params['max_tokens_to_sample'] = params.get('max_tokens')
|
||||
del params['max_tokens']
|
||||
|
||||
del params['frequency_penalty']
|
||||
del params['presence_penalty']
|
||||
|
||||
return params
|
||||
@@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableChatOpenAI(ChatOpenAI):
|
||||
@@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}
|
||||
|
||||
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
Args:
|
||||
messages: The messages to count the tokens of.
|
||||
|
||||
Returns:
|
||||
The number of tokens in the messages.
|
||||
"""
|
||||
tokens_per_message = 5
|
||||
tokens_per_request = 3
|
||||
|
||||
message_tokens = tokens_per_request
|
||||
message_strs = ''
|
||||
for message in messages:
|
||||
message_strs += message.content
|
||||
message_tokens += tokens_per_message
|
||||
|
||||
# calc once
|
||||
message_tokens += self.get_num_tokens(message_strs)
|
||||
|
||||
return message_tokens
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
@@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(messages, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(messages, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
model_kwargs = {
|
||||
'top_p': params.get('top_p', 1),
|
||||
'frequency_penalty': params.get('frequency_penalty', 0),
|
||||
'presence_penalty': params.get('presence_penalty', 0),
|
||||
}
|
||||
|
||||
del params['top_p']
|
||||
del params['frequency_penalty']
|
||||
del params['presence_penalty']
|
||||
|
||||
params['model_kwargs'] = model_kwargs
|
||||
|
||||
return params
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping
|
||||
from langchain import OpenAI
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableOpenAI(OpenAI):
|
||||
@@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(prompts, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(prompts, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
return params
|
||||
|
||||
26
api/core/llm/whisper.py
Normal file
26
api/core/llm/whisper.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import openai
|
||||
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
from models.provider import ProviderName
|
||||
from core.llm.provider.base import BaseProvider
|
||||
|
||||
|
||||
class Whisper:
|
||||
|
||||
def __init__(self, provider: BaseProvider):
|
||||
self.provider = provider
|
||||
|
||||
if self.provider.get_provider_name() == ProviderName.OPENAI:
|
||||
self.client = openai.Audio
|
||||
self.credentials = provider.get_credentials()
|
||||
|
||||
@handle_openai_exceptions
|
||||
def transcribe(self, file):
|
||||
return self.client.transcribe(
|
||||
model='whisper-1',
|
||||
file=file,
|
||||
api_key=self.credentials.get('openai_api_key'),
|
||||
api_base=self.credentials.get('openai_api_base'),
|
||||
api_type=self.credentials.get('openai_api_type'),
|
||||
api_version=self.credentials.get('openai_api_version'),
|
||||
)
|
||||
27
api/core/llm/wrappers/anthropic_wrapper.py
Normal file
27
api/core/llm/wrappers/anthropic_wrapper.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import logging
|
||||
from functools import wraps
|
||||
|
||||
import anthropic
|
||||
|
||||
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
|
||||
LLMBadRequestError
|
||||
|
||||
|
||||
def handle_anthropic_exceptions(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except anthropic.APIConnectionError as e:
|
||||
logging.exception("Failed to connect to Anthropic API.")
|
||||
raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
|
||||
except anthropic.RateLimitError:
|
||||
raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
|
||||
except anthropic.AuthenticationError as e:
|
||||
raise LLMAuthorizationError(f"Anthropic: {e.message}")
|
||||
except anthropic.BadRequestError as e:
|
||||
raise LLMBadRequestError(f"Anthropic: {e.message}")
|
||||
except anthropic.APIStatusError as e:
|
||||
raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
|
||||
|
||||
return wrapper
|
||||
@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat
|
||||
LLMBadRequestError
|
||||
|
||||
|
||||
def handle_llm_exceptions(func):
|
||||
def handle_openai_exceptions(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
@@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
|
||||
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def handle_llm_exceptions_async(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except openai.error.InvalidRequestError as e:
|
||||
logging.exception("Invalid request to OpenAI API.")
|
||||
raise LLMBadRequestError(str(e))
|
||||
except openai.error.APIConnectionError as e:
|
||||
logging.exception("Failed to connect to OpenAI API.")
|
||||
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
|
||||
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
|
||||
logging.exception("OpenAI service unavailable.")
|
||||
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
|
||||
except openai.error.RateLimitError as e:
|
||||
raise LLMRateLimitError(str(e))
|
||||
except openai.error.AuthenticationError as e:
|
||||
raise LLMAuthorizationError(str(e))
|
||||
except openai.error.OpenAIError as e:
|
||||
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
|
||||
|
||||
return wrapper
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, List, Dict, Union
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
|
||||
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
|
||||
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
@@ -12,8 +12,8 @@ from models.model import Conversation, Message
|
||||
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
||||
conversation: Conversation
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: Union[StreamableChatOpenAI | StreamableOpenAI]
|
||||
ai_prefix: str = "Assistant"
|
||||
llm: BaseLanguageModel
|
||||
memory_key: str = "chat_history"
|
||||
max_token_limit: int = 2000
|
||||
message_limit: int = 10
|
||||
@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
||||
return chat_messages
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit and chat_messages:
|
||||
pruned_memory.append(chat_messages.pop(0))
|
||||
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
|
||||
|
||||
return chat_messages
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
CONVERSATION_TITLE_PROMPT = (
|
||||
"Human:{{query}}\n-----\n"
|
||||
"Human:{query}\n-----\n"
|
||||
"Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
|
||||
"If the human said is conducted in Chinese, you should return a Chinese title.\n"
|
||||
"If the human said is conducted in English, you should return an English title.\n"
|
||||
"If what the human said is conducted in English, you should only return an English title.\n"
|
||||
"If what the human said is conducted in Chinese, you should only return a Chinese title.\n"
|
||||
"title:"
|
||||
)
|
||||
|
||||
CONVERSATION_SUMMARY_PROMPT = (
|
||||
"Please generate a short summary of the following conversation.\n"
|
||||
"If the conversation communicating in Chinese, you should return a Chinese summary.\n"
|
||||
"If the conversation communicating in English, you should return an English summary.\n"
|
||||
"If the following conversation communicating in English, you should only return an English summary.\n"
|
||||
"If the following conversation communicating in Chinese, you should only return a Chinese summary.\n"
|
||||
"[Conversation Start]\n"
|
||||
"{context}\n"
|
||||
"[Conversation End]\n\n"
|
||||
@@ -19,7 +19,7 @@ CONVERSATION_SUMMARY_PROMPT = (
|
||||
INTRODUCTION_GENERATE_PROMPT = (
|
||||
"I am designing a product for users to interact with an AI through dialogue. "
|
||||
"The Prompt given to the AI before the conversation is:\n\n"
|
||||
"```\n{{prompt}}\n```\n\n"
|
||||
"```\n{prompt}\n```\n\n"
|
||||
"Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
|
||||
"Do not reveal the developer's motivation or deep logic behind the Prompt, "
|
||||
"but focus on building a relationship with the user:\n"
|
||||
@@ -27,13 +27,13 @@ INTRODUCTION_GENERATE_PROMPT = (
|
||||
|
||||
MORE_LIKE_THIS_GENERATE_PROMPT = (
|
||||
"-----\n"
|
||||
"{{original_completion}}\n"
|
||||
"{original_completion}\n"
|
||||
"-----\n\n"
|
||||
"Please use the above content as a sample for generating the result, "
|
||||
"and include key information points related to the original sample in the result. "
|
||||
"Try to rephrase this information in different ways and predict according to the rules below.\n\n"
|
||||
"-----\n"
|
||||
"{{prompt}}\n"
|
||||
"{prompt}\n"
|
||||
)
|
||||
|
||||
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
|
||||
@@ -30,7 +30,7 @@ class DatasetTool(BaseTool):
|
||||
else:
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=self.dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ class DatasetTool(BaseTool):
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=self.dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from flask import current_app
|
||||
|
||||
from events.tenant_event import tenant_was_updated
|
||||
from models.provider import ProviderName
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
|
||||
def handle(sender, **kwargs):
|
||||
tenant = sender
|
||||
if tenant.status == 'normal':
|
||||
ProviderService.create_system_provider(tenant)
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.OPENAI.value,
|
||||
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.ANTHROPIC.value,
|
||||
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from flask import current_app
|
||||
|
||||
from events.tenant_event import tenant_was_created
|
||||
from models.provider import ProviderName
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
|
||||
def handle(sender, **kwargs):
|
||||
tenant = sender
|
||||
if tenant.status == 'normal':
|
||||
ProviderService.create_system_provider(tenant)
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.OPENAI.value,
|
||||
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.ANTHROPIC.value,
|
||||
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
|
||||
61
api/extensions/ext_mail.py
Normal file
61
api/extensions/ext_mail.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import Optional
|
||||
|
||||
import resend
|
||||
from flask import Flask
|
||||
|
||||
|
||||
class Mail:
|
||||
def __init__(self):
|
||||
self._client = None
|
||||
self._default_send_from = None
|
||||
|
||||
def is_inited(self) -> bool:
|
||||
return self._client is not None
|
||||
|
||||
def init_app(self, app: Flask):
|
||||
if app.config.get('MAIL_TYPE'):
|
||||
if app.config.get('MAIL_DEFAULT_SEND_FROM'):
|
||||
self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM')
|
||||
|
||||
if app.config.get('MAIL_TYPE') == 'resend':
|
||||
api_key = app.config.get('RESEND_API_KEY')
|
||||
if not api_key:
|
||||
raise ValueError('RESEND_API_KEY is not set')
|
||||
|
||||
resend.api_key = api_key
|
||||
self._client = resend.Emails
|
||||
else:
|
||||
raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE')))
|
||||
|
||||
def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
|
||||
if not self._client:
|
||||
raise ValueError('Mail client is not initialized')
|
||||
|
||||
if not from_ and self._default_send_from:
|
||||
from_ = self._default_send_from
|
||||
|
||||
if not from_:
|
||||
raise ValueError('mail from is not set')
|
||||
|
||||
if not to:
|
||||
raise ValueError('mail to is not set')
|
||||
|
||||
if not subject:
|
||||
raise ValueError('mail subject is not set')
|
||||
|
||||
if not html:
|
||||
raise ValueError('mail html is not set')
|
||||
|
||||
self._client.send({
|
||||
"from": from_,
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"html": html
|
||||
})
|
||||
|
||||
|
||||
def init_app(app: Flask):
|
||||
mail.init_app(app)
|
||||
|
||||
|
||||
mail = Mail()
|
||||
20
api/libs/passport.py
Normal file
20
api/libs/passport.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import jwt
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
from flask import current_app
|
||||
class PassportService:
|
||||
def __init__(self):
|
||||
self.sk = current_app.config.get('SECRET_KEY')
|
||||
|
||||
def issue(self, payload):
|
||||
return jwt.encode(payload, self.sk, algorithm='HS256')
|
||||
|
||||
def verify(self, token):
|
||||
try:
|
||||
return jwt.decode(token, self.sk, algorithms=['HS256'])
|
||||
except jwt.exceptions.InvalidSignatureError:
|
||||
raise Unauthorized('Invalid token signature.')
|
||||
except jwt.exceptions.DecodeError:
|
||||
raise Unauthorized('Invalid token.')
|
||||
except jwt.exceptions.ExpiredSignatureError:
|
||||
raise Unauthorized('Token has expired.')
|
||||
@@ -0,0 +1,32 @@
|
||||
"""app config add speech_to_text
|
||||
|
||||
Revision ID: a5b56fb053ef
|
||||
Revises: d3d503a3471c
|
||||
Create Date: 2023-07-06 17:55:20.894149
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'a5b56fb053ef'
|
||||
down_revision = 'd3d503a3471c'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('speech_to_text', sa.Text(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.drop_column('speech_to_text')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,32 @@
|
||||
"""add is_deleted to conversations
|
||||
|
||||
Revision ID: d3d503a3471c
|
||||
Revises: e32f6ccb87c6
|
||||
Create Date: 2023-06-27 19:13:30.897981
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd3d503a3471c'
|
||||
down_revision = 'e32f6ccb87c6'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.drop_column('is_deleted')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -38,6 +38,10 @@ class Account(UserMixin, db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
@property
|
||||
def is_password_set(self):
|
||||
return self.password is not None
|
||||
|
||||
@property
|
||||
def current_tenant(self):
|
||||
return self._current_tenant
|
||||
|
||||
@@ -56,7 +56,8 @@ class App(db.Model):
|
||||
|
||||
@property
|
||||
def api_base_url(self):
|
||||
return (current_app.config['API_URL'] if current_app.config['API_URL'] else request.host_url.rstrip('/')) + '/v1'
|
||||
return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
|
||||
else request.host_url.rstrip('/')) + '/v1'
|
||||
|
||||
@property
|
||||
def tenant(self):
|
||||
@@ -81,6 +82,7 @@ class AppModelConfig(db.Model):
|
||||
opening_statement = db.Column(db.Text)
|
||||
suggested_questions = db.Column(db.Text)
|
||||
suggested_questions_after_answer = db.Column(db.Text)
|
||||
speech_to_text = db.Column(db.Text)
|
||||
more_like_this = db.Column(db.Text)
|
||||
model = db.Column(db.Text)
|
||||
user_input_form = db.Column(db.Text)
|
||||
@@ -104,6 +106,11 @@ class AppModelConfig(db.Model):
|
||||
def suggested_questions_after_answer_dict(self) -> dict:
|
||||
return json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer \
|
||||
else {"enabled": False}
|
||||
|
||||
@property
|
||||
def speech_to_text_dict(self) -> dict:
|
||||
return json.loads(self.speech_to_text) if self.speech_to_text \
|
||||
else {"enabled": False}
|
||||
|
||||
@property
|
||||
def more_like_this_dict(self) -> dict:
|
||||
@@ -206,6 +213,8 @@ class Conversation(db.Model):
|
||||
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
|
||||
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
|
||||
|
||||
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
|
||||
|
||||
@property
|
||||
def model_config(self):
|
||||
model_config = {}
|
||||
@@ -221,6 +230,9 @@ class Conversation(db.Model):
|
||||
model_config['suggested_questions_after_answer'] = override_model_configs[
|
||||
'suggested_questions_after_answer'] \
|
||||
if 'suggested_questions_after_answer' in override_model_configs else {"enabled": False}
|
||||
model_config['speech_to_text'] = override_model_configs[
|
||||
'speech_to_text'] \
|
||||
if 'speech_to_text' in override_model_configs else {"enabled": False}
|
||||
model_config['more_like_this'] = override_model_configs['more_like_this'] \
|
||||
if 'more_like_this' in override_model_configs else {"enabled": False}
|
||||
model_config['user_input_form'] = override_model_configs['user_input_form']
|
||||
@@ -237,6 +249,7 @@ class Conversation(db.Model):
|
||||
model_config['opening_statement'] = app_model_config.opening_statement
|
||||
model_config['suggested_questions'] = app_model_config.suggested_questions_list
|
||||
model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
|
||||
model_config['speech_to_text'] = app_model_config.speech_to_text_dict
|
||||
model_config['more_like_this'] = app_model_config.more_like_this_dict
|
||||
model_config['user_input_form'] = app_model_config.user_input_form_list
|
||||
|
||||
@@ -503,7 +516,7 @@ class Site(db.Model):
|
||||
|
||||
@property
|
||||
def app_base_url(self):
|
||||
return (current_app.config['APP_URL'] if current_app.config['APP_URL'] else request.host_url.rstrip('/'))
|
||||
return (current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/'))
|
||||
|
||||
|
||||
class ApiToken(db.Model):
|
||||
|
||||
@@ -10,7 +10,7 @@ flask-session2==1.3.1
|
||||
flask-cors==3.0.10
|
||||
gunicorn~=20.1.0
|
||||
gevent~=22.10.2
|
||||
langchain==0.0.209
|
||||
langchain==0.0.230
|
||||
openai~=0.27.5
|
||||
psycopg2-binary~=2.9.6
|
||||
pycryptodome==3.17
|
||||
@@ -21,7 +21,7 @@ Authlib==1.2.0
|
||||
boto3~=1.26.123
|
||||
tenacity==8.2.2
|
||||
cachetools~=5.3.0
|
||||
weaviate-client~=3.16.2
|
||||
weaviate-client~=3.21.0
|
||||
qdrant_client~=1.1.6
|
||||
mailchimp-transactional~=1.0.50
|
||||
scikit-learn==1.2.2
|
||||
@@ -32,4 +32,7 @@ redis~=4.5.4
|
||||
openpyxl==3.1.2
|
||||
chardet~=5.1.0
|
||||
docx2txt==0.8
|
||||
pypdfium2==4.16.0
|
||||
pypdfium2==4.16.0
|
||||
resend~=0.5.1
|
||||
pyjwt~=2.6.0
|
||||
anthropic~=0.3.4
|
||||
|
||||
@@ -2,13 +2,16 @@
|
||||
import base64
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import Optional
|
||||
|
||||
from flask import session
|
||||
from sqlalchemy import func
|
||||
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.errors.account import AccountLoginError, CurrentPasswordIncorrectError, LinkAccountIntegrateError, \
|
||||
TenantNotFound, AccountNotLinkTenantError, InvalidActionError, CannotOperateSelfError, MemberNotInTenantError, \
|
||||
RoleAlreadyAssignedError, NoPermissionError, AccountRegisterError, AccountAlreadyInTenantError
|
||||
@@ -16,6 +19,7 @@ from libs.helper import get_remote_ip
|
||||
from libs.password import compare_password, hash_password
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import *
|
||||
from tasks.mail_invite_member_task import send_invite_member_mail_task
|
||||
|
||||
|
||||
class AccountService:
|
||||
@@ -48,12 +52,18 @@ class AccountService:
|
||||
@staticmethod
|
||||
def update_account_password(account, password, new_password):
|
||||
"""update account password"""
|
||||
# todo: split validation and update
|
||||
if account.password and not compare_password(password, account.password, account.password_salt):
|
||||
raise CurrentPasswordIncorrectError("Current password is incorrect.")
|
||||
password_hashed = hash_password(new_password, account.password_salt)
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
return account
|
||||
|
||||
@@ -283,8 +293,6 @@ class TenantService:
|
||||
@staticmethod
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
||||
"""Remove member from tenant"""
|
||||
# todo: check permission
|
||||
|
||||
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'):
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
@@ -293,6 +301,12 @@ class TenantService:
|
||||
raise MemberNotInTenantError("Member not in tenant.")
|
||||
|
||||
db.session.delete(ta)
|
||||
|
||||
account.initialized_at = None
|
||||
account.status = AccountStatus.PENDING.value
|
||||
account.password = None
|
||||
account.password_salt = None
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
@@ -332,8 +346,8 @@ class TenantService:
|
||||
|
||||
class RegisterService:
|
||||
|
||||
@staticmethod
|
||||
def register(email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
|
||||
@classmethod
|
||||
def register(cls, email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
|
||||
db.session.begin_nested()
|
||||
"""Register account"""
|
||||
try:
|
||||
@@ -359,9 +373,9 @@ class RegisterService:
|
||||
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def invite_new_member(tenant: Tenant, email: str, role: str = 'normal',
|
||||
inviter: Account = None) -> TenantAccountJoin:
|
||||
@classmethod
|
||||
def invite_new_member(cls, tenant: Tenant, email: str, role: str = 'normal',
|
||||
inviter: Account = None) -> str:
|
||||
"""Invite new member"""
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
|
||||
@@ -380,5 +394,71 @@ class RegisterService:
|
||||
if ta:
|
||||
raise AccountAlreadyInTenantError("Account already in tenant.")
|
||||
|
||||
ta = TenantService.create_tenant_member(tenant, account, role)
|
||||
return ta
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
|
||||
token = cls.generate_invite_token(tenant, account)
|
||||
|
||||
# send email
|
||||
send_invite_member_mail_task.delay(
|
||||
to=email,
|
||||
token=cls.generate_invite_token(tenant, account),
|
||||
inviter_name=inviter.name if inviter else 'Dify',
|
||||
workspace_id=tenant.id,
|
||||
workspace_name=tenant.name,
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
email_hash = sha256(account.email.encode()).hexdigest()
|
||||
cache_key = 'member_invite_token:{}, {}:{}'.format(str(tenant.id), email_hash, token)
|
||||
redis_client.setex(cache_key, 3600, str(account.id))
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def revoke_token(cls, workspace_id: str, email: str, token: str):
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
|
||||
redis_client.delete(cache_key)
|
||||
|
||||
@classmethod
|
||||
def get_account_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[Account]:
|
||||
tenant = db.session.query(Tenant).filter(
|
||||
Tenant.id == workspace_id,
|
||||
Tenant.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
tenant_account = db.session.query(Account, TenantAccountJoin.role).join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
).filter(Account.email == email, TenantAccountJoin.tenant_id == tenant.id).first()
|
||||
|
||||
if not tenant_account:
|
||||
return None
|
||||
|
||||
account_id = cls._get_account_id_by_invite_token(workspace_id, email, token)
|
||||
if not account_id:
|
||||
return None
|
||||
|
||||
account = tenant_account[0]
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if account_id != str(account.id):
|
||||
return None
|
||||
|
||||
return account
|
||||
|
||||
@classmethod
|
||||
def _get_account_id_by_invite_token(cls, workspace_id: str, email: str, token: str) -> Optional[str]:
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
|
||||
account_id = redis_client.get(cache_key)
|
||||
if not account_id:
|
||||
return None
|
||||
|
||||
return account_id.decode('utf-8')
|
||||
|
||||
@@ -4,7 +4,32 @@ import uuid
|
||||
from core.constant import llm_constant
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
|
||||
MODEL_PROVIDERS = [
|
||||
'openai',
|
||||
'anthropic',
|
||||
]
|
||||
|
||||
MODELS_BY_APP_MODE = {
|
||||
'chat': [
|
||||
'claude-instant-1',
|
||||
'claude-2',
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-16k',
|
||||
],
|
||||
'completion': [
|
||||
'claude-instant-1',
|
||||
'claude-2',
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-16k',
|
||||
'text-davinci-003',
|
||||
]
|
||||
}
|
||||
|
||||
class AppModelConfigService:
|
||||
@staticmethod
|
||||
@@ -109,6 +134,26 @@ class AppModelConfigService:
|
||||
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
|
||||
raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")
|
||||
|
||||
# speech_to_text
|
||||
if 'speech_to_text' not in config or not config["speech_to_text"]:
|
||||
config["speech_to_text"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["speech_to_text"], dict):
|
||||
raise ValueError("speech_to_text must be of dict type")
|
||||
|
||||
if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]:
|
||||
config["speech_to_text"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["speech_to_text"]["enabled"], bool):
|
||||
raise ValueError("enabled in speech_to_text must be of boolean type")
|
||||
|
||||
provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1')
|
||||
|
||||
if config["speech_to_text"]["enabled"] and provider_name != 'openai':
|
||||
raise ValueError("provider not support speech to text")
|
||||
|
||||
# more_like_this
|
||||
if 'more_like_this' not in config or not config["more_like_this"]:
|
||||
config["more_like_this"] = {
|
||||
@@ -132,14 +177,14 @@ class AppModelConfigService:
|
||||
raise ValueError("model must be of object type")
|
||||
|
||||
# model.provider
|
||||
if 'provider' not in config["model"] or config["model"]["provider"] != "openai":
|
||||
raise ValueError("model.provider must be 'openai'")
|
||||
if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
|
||||
raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")
|
||||
|
||||
# model.name
|
||||
if 'name' not in config["model"]:
|
||||
raise ValueError("model.name is required")
|
||||
|
||||
if config["model"]["name"] not in llm_constant.models_by_mode[mode]:
|
||||
if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
# model.completion_params
|
||||
@@ -277,6 +322,7 @@ class AppModelConfigService:
|
||||
"opening_statement": config["opening_statement"],
|
||||
"suggested_questions": config["suggested_questions"],
|
||||
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
|
||||
"speech_to_text": config["speech_to_text"],
|
||||
"more_like_this": config["more_like_this"],
|
||||
"model": {
|
||||
"provider": config["model"]["provider"],
|
||||
|
||||
39
api/services/audio_service.py
Normal file
39
api/services/audio_service.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import io
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.llm.provider.llm_provider_service import LLMProviderService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
from core.llm.whisper import Whisper
|
||||
from models.provider import ProviderName
|
||||
|
||||
FILE_SIZE = 15
|
||||
FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
|
||||
ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm']
|
||||
|
||||
class AudioService:
|
||||
@classmethod
|
||||
def transcript(cls, tenant_id: str, file: FileStorage):
|
||||
if file is None:
|
||||
raise NoAudioUploadedServiceError()
|
||||
|
||||
extension = file.mimetype
|
||||
if extension not in [f'audio/{ext}' for ext in ALLOWED_EXTENSIONS]:
|
||||
raise UnsupportedAudioTypeServiceError()
|
||||
|
||||
file_content = file.read()
|
||||
file_size = len(file_content)
|
||||
|
||||
if file_size > FILE_SIZE_LIMIT:
|
||||
message = f"Audio size larger than {FILE_SIZE} mb"
|
||||
raise AudioTooLargeServiceError(message)
|
||||
|
||||
provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1')
|
||||
if provider_name != ProviderName.OPENAI.value:
|
||||
raise ProviderNotSupportSpeechToTextServiceError()
|
||||
|
||||
provider_service = LLMProviderService(tenant_id, provider_name)
|
||||
|
||||
buffer = io.BytesIO(file_content)
|
||||
buffer.name = 'temp.mp3'
|
||||
|
||||
return Whisper(provider_service.provider).transcribe(buffer)
|
||||
@@ -16,6 +16,7 @@ class ConversationService:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
base_query = db.session.query(Conversation).filter(
|
||||
Conversation.is_deleted == False,
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
@@ -79,6 +80,7 @@ class ConversationService:
|
||||
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
Conversation.is_deleted == False
|
||||
).first()
|
||||
|
||||
if not conversation:
|
||||
@@ -90,5 +92,5 @@ class ConversationService:
|
||||
def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account | EndUser]]):
|
||||
conversation = cls.get_conversation(app_model, conversation_id, user)
|
||||
|
||||
db.session.delete(conversation)
|
||||
conversation.is_deleted = True
|
||||
db.session.commit()
|
||||
|
||||
@@ -4,6 +4,9 @@ import datetime
|
||||
import time
|
||||
import random
|
||||
from typing import Optional, List
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask_login import current_user
|
||||
|
||||
@@ -35,6 +38,7 @@ class DatasetService:
|
||||
permission_filter = Dataset.permission == 'all_team_members'
|
||||
datasets = Dataset.query.filter(
|
||||
db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
|
||||
.order_by(Dataset.created_at.desc()) \
|
||||
.paginate(
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
@@ -373,6 +377,12 @@ class DocumentService:
|
||||
def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
|
||||
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
created_from: str = 'web'):
|
||||
# check document limit
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if documents_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
# if dataset is empty, update dataset data_source_type
|
||||
if not dataset.data_source_type:
|
||||
dataset.data_source_type = document_data["data_source"]["type"]
|
||||
@@ -520,6 +530,14 @@ class DocumentService:
|
||||
)
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_documents_count():
|
||||
documents_count = Document.query.filter(Document.completed_at.isnot(None),
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.tenant_id == current_user.current_tenant_id).count()
|
||||
return documents_count
|
||||
|
||||
@staticmethod
|
||||
def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
|
||||
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
@@ -615,6 +633,12 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
|
||||
# check document limit
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if documents_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
# save dataset
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
__all__ = [
|
||||
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
|
||||
'app', 'completion'
|
||||
'app', 'completion', 'audio'
|
||||
]
|
||||
|
||||
from . import *
|
||||
|
||||
13
api/services/errors/audio.py
Normal file
13
api/services/errors/audio.py
Normal file
@@ -0,0 +1,13 @@
|
||||
class NoAudioUploadedServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AudioTooLargeServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedAudioTypeServiceError(Exception):
|
||||
pass
|
||||
|
||||
class ProviderNotSupportSpeechToTextServiceError(Exception):
|
||||
pass
|
||||
@@ -31,7 +31,7 @@ class HitTestingService:
|
||||
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
|
||||
@@ -10,50 +10,40 @@ from models.provider import *
|
||||
class ProviderService:
|
||||
|
||||
@staticmethod
|
||||
def init_supported_provider(tenant, edition):
|
||||
def init_supported_provider(tenant):
|
||||
"""Initialize the model provider, check whether the supported provider has a record"""
|
||||
|
||||
providers = Provider.query.filter_by(tenant_id=tenant.id).all()
|
||||
need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value]
|
||||
|
||||
openai_provider_exists = False
|
||||
azure_openai_provider_exists = False
|
||||
|
||||
# TODO: The cloud version needs to construct the data of the SYSTEM type
|
||||
providers = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant.id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name.in_(need_init_provider_names)
|
||||
).all()
|
||||
|
||||
exists_provider_names = []
|
||||
for provider in providers:
|
||||
if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
|
||||
openai_provider_exists = True
|
||||
if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
|
||||
azure_openai_provider_exists = True
|
||||
exists_provider_names.append(provider.provider_name)
|
||||
|
||||
# Initialize the model provider, check whether the supported provider has a record
|
||||
not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names))
|
||||
|
||||
# Create default providers if they don't exist
|
||||
if not openai_provider_exists:
|
||||
openai_provider = Provider(
|
||||
tenant_id=tenant.id,
|
||||
provider_name=ProviderName.OPENAI.value,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
is_valid=False
|
||||
)
|
||||
db.session.add(openai_provider)
|
||||
if not_exists_provider_names:
|
||||
# Initialize the model provider, check whether the supported provider has a record
|
||||
for provider_name in not_exists_provider_names:
|
||||
provider = Provider(
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
is_valid=False
|
||||
)
|
||||
db.session.add(provider)
|
||||
|
||||
if not azure_openai_provider_exists:
|
||||
azure_openai_provider = Provider(
|
||||
tenant_id=tenant.id,
|
||||
provider_name=ProviderName.AZURE_OPENAI.value,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
is_valid=False
|
||||
)
|
||||
db.session.add(azure_openai_provider)
|
||||
|
||||
if not openai_provider_exists or not azure_openai_provider_exists:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_obfuscated_api_key(tenant, provider_name: ProviderName):
|
||||
def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
|
||||
llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
|
||||
return llm_provider_service.get_provider_configs(obfuscated=True)
|
||||
return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)
|
||||
|
||||
@staticmethod
|
||||
def get_token_type(tenant, provider_name: ProviderName):
|
||||
@@ -73,7 +63,7 @@ class ProviderService:
|
||||
return llm_provider_service.get_encrypted_token(configs)
|
||||
|
||||
@staticmethod
|
||||
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value,
|
||||
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
|
||||
is_valid: bool = True):
|
||||
if current_app.config['EDITION'] != 'CLOUD':
|
||||
return
|
||||
@@ -90,7 +80,7 @@ class ProviderService:
|
||||
provider_name=provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
quota_limit=200,
|
||||
quota_limit=quota_limit,
|
||||
encrypted_config='',
|
||||
is_valid=is_valid,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider import Provider, ProviderType, ProviderName
|
||||
|
||||
|
||||
class WorkspaceService:
|
||||
@@ -33,7 +33,7 @@ class WorkspaceService:
|
||||
if provider.is_valid and provider.encrypted_config:
|
||||
custom_provider = provider
|
||||
elif provider.provider_type == ProviderType.SYSTEM.value:
|
||||
if provider.is_valid:
|
||||
if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid:
|
||||
system_provider = provider
|
||||
|
||||
if system_provider and not custom_provider:
|
||||
|
||||
@@ -7,7 +7,7 @@ from celery import shared_task
|
||||
from core.index.index import IndexBuilder
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
|
||||
AppDatasetJoin
|
||||
AppDatasetJoin, Document
|
||||
|
||||
|
||||
@shared_task
|
||||
@@ -32,7 +32,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
|
||||
index_struct=index_struct
|
||||
)
|
||||
|
||||
documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
|
||||
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
|
||||
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
|
||||
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
|
||||
@@ -44,14 +44,13 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
if dataset_documents:
|
||||
# save vector index
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
# delete from vector index
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True
|
||||
) .order_by(DocumentSegment.position.asc()).all()
|
||||
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
@@ -65,8 +64,8 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
|
||||
documents.append(document)
|
||||
|
||||
# save vector index
|
||||
index.add_texts(documents)
|
||||
# save vector index
|
||||
index.add_texts(documents)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
|
||||
@@ -28,7 +28,7 @@ def generate_conversation_summary_task(conversation_id: str):
|
||||
try:
|
||||
# get conversation messages count
|
||||
history_message_count = conversation.message_count
|
||||
if history_message_count >= 5:
|
||||
if history_message_count >= 5 and not conversation.summary:
|
||||
app_model = conversation.app
|
||||
if not app_model:
|
||||
return
|
||||
|
||||
52
api/tasks/mail_invite_member_task.py
Normal file
52
api/tasks/mail_invite_member_task.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from flask import current_app
|
||||
|
||||
from extensions.ext_mail import mail
|
||||
|
||||
|
||||
@shared_task
|
||||
def send_invite_member_mail_task(to: str, token: str, inviter_name: str, workspace_id: str, workspace_name: str):
|
||||
"""
|
||||
Async Send invite member mail
|
||||
:param to
|
||||
:param token
|
||||
:param inviter_name
|
||||
:param workspace_id
|
||||
:param workspace_name
|
||||
|
||||
Usage: send_invite_member_mail_task.delay(to, token, inviter_name, workspace_id, workspace_name)
|
||||
"""
|
||||
if not mail.is_inited():
|
||||
return
|
||||
|
||||
logging.info(click.style('Start send invite member mail to {} in workspace {}'.format(to, workspace_name),
|
||||
fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
mail.send(
|
||||
to=to,
|
||||
subject="{} invited you to join {}".format(inviter_name, workspace_name),
|
||||
html="""<p>Hi there,</p>
|
||||
<p>{inviter_name} invited you to join {workspace_name}.</p>
|
||||
<p>Click <a href="{url}">here</a> to join.</p>
|
||||
<p>Thanks,</p>
|
||||
<p>Dify Team</p>""".format(inviter_name=inviter_name, workspace_name=workspace_name,
|
||||
url='{}/activate?workspace_id={}&email={}&token={}'.format(
|
||||
current_app.config.get("CONSOLE_WEB_URL"),
|
||||
workspace_id,
|
||||
to,
|
||||
token)
|
||||
)
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('Send invite member mail to {} succeeded: latency: {}'.format(to, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception:
|
||||
logging.exception("Send invite member mail to {} failed".format(to))
|
||||
@@ -41,7 +41,8 @@ def remove_document_from_index_task(document_id: str):
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
|
||||
# delete from vector index
|
||||
vector_index.delete_by_document_id(document.id)
|
||||
if vector_index:
|
||||
vector_index.delete_by_document_id(document.id)
|
||||
|
||||
# delete from keyword index
|
||||
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user