Compare commits

...

65 Commits

Author SHA1 Message Date
John Wang
765c2b92a1 feat: add bash before entrypoint.sh in Dockerfile 2023-07-18 16:13:39 +08:00
John Wang
eff115267f fix: anthropic completion error in blocking mode (#591) 2023-07-18 15:12:52 +08:00
John Wang
07cde4f8fe feat: bump 0.3.10 (#589) 2023-07-18 15:04:49 +08:00
Jyong
9f28a48a92 index add to db when dataset updated (#588) 2023-07-18 15:02:33 +08:00
John Wang
0d3cd3b16a fix: azure provider select error when use custom azure provider (#587) 2023-07-18 14:34:09 +08:00
John Wang
3dc82fb044 feat: remove davinci required model from azure provider (#586) 2023-07-18 14:14:56 +08:00
crazywoola
cb6e73347e Feat/add ruby sdk (#583) 2023-07-18 10:18:58 +08:00
zxhlyh
ecd6cbaee6 Fix/use embedded chatbot with no track mode (#582) 2023-07-18 09:45:17 +08:00
KVOJJJin
d54e942264 Feat: hide password setting and invitation link in cloud version (#581) 2023-07-18 08:54:14 +08:00
Panmuse
28ba721455 Update README_CN.md (#575) 2023-07-17 11:08:26 +08:00
Panmuse
784dd7848e Update README.md (#576) 2023-07-17 11:08:03 +08:00
John Wang
e2a5f8ba1a feat: bump version to 0.3.9 (#574) 2023-07-17 09:47:23 +08:00
Joel
8e11200306 feat: frontend support claude (#573)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-07-17 00:14:32 +08:00
John Wang
7599f79a17 feat: claude api support (#572) 2023-07-17 00:14:19 +08:00
Joel
510389909c fix: change chatbot avart to dify icon (#571) 2023-07-16 16:30:55 +08:00
Jyong
2c6e00174b add document limit check (#570) 2023-07-16 13:21:56 +08:00
John Wang
24f3456990 fix: account check in runtime (#569) 2023-07-15 23:58:15 +08:00
Joel
20514ff288 fix: table too wide fix text generation ui (#566) 2023-07-14 18:15:56 +08:00
zxhlyh
381d255290 fix setting-modal provider encrypted tip style (#565) 2023-07-14 17:10:02 +08:00
John Wang
7f320f9146 feat: bump version to 0.3.8 (#559) 2023-07-14 11:53:15 +08:00
KVOJJJin
cd51d3323b feat: member invitation and activation (#535)
Co-authored-by: John Wang <takatost@gmail.com>
2023-07-14 11:19:26 +08:00
crazywoola
004b3caa43 Feature/add delete to service (#555) 2023-07-14 10:37:33 +08:00
Joel
dbe10799e3 fix: user cancel conversation show error (#558) 2023-07-13 10:32:45 +08:00
Joel
054ba88434 fix: regeneration not clear like status and sub more items (#557) 2023-07-13 10:31:07 +08:00
Joel
da82a11b26 feat: batch run support export as csv file (#556) 2023-07-13 09:30:16 +08:00
zxhlyh
fec607db81 Feat/embedding (#553)
Co-authored-by: Gillian97 <jinling.sunshine@gmail.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-07-12 17:27:50 +08:00
zxhlyh
397a92f2ee convert audio wav to mp3 (#552) 2023-07-12 17:18:56 +08:00
Joel
b91e226063 fix: api doc update conversation list api to real response (#548) 2023-07-12 13:53:06 +08:00
Joel
da5782df92 fix: mobile not auto show generation res (#544) 2023-07-11 17:16:28 +08:00
zxhlyh
9af0da4450 fix jwt in web (#545) 2023-07-11 17:07:52 +08:00
crazywoola
d49ac1e4ac Feature/use jwt in web (#533)
Co-authored-by: crazywoola <li.zheng@dentsplysirona.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-07-11 15:21:20 +08:00
John Wang
57de19a5ca feat: bump version to 0.3.7 (#540) 2023-07-10 15:23:38 +08:00
zxhlyh
7c00a0b6a3 fix voice input in safari (#537) 2023-07-10 10:16:38 +08:00
Jyong
a93506df18 Fix/dataset clean task (#534) 2023-07-08 17:29:56 +08:00
zxhlyh
a03a92e9db Feat/chat support voice input (#532) 2023-07-07 17:50:42 +08:00
John Wang
feebb5dd1f feat: dataset list add order by created at (#531) 2023-07-07 11:51:48 +08:00
John Wang
6eee7cb42c feat: fix azure embedding Too many inputs problem (#530) 2023-07-07 11:17:36 +08:00
Joel
11baff6740 feat: text generation application support run batch (#529) 2023-07-07 10:35:05 +08:00
zxhlyh
cde1797cc0 feat: max token add tip (#525) 2023-07-06 15:57:04 +08:00
KVOJJJin
d143284d99 Fix: stop embedding status display (#523) 2023-07-06 10:51:30 +08:00
zxhlyh
2b94545190 fix check version api (#520) 2023-07-05 11:11:38 +08:00
John Wang
ed6648a41e feat: dataset list add order by created at (#487) 2023-07-05 11:00:21 +08:00
Joel
5e2c3eeac3 fix: chat app added new var old conversation not work (#511) 2023-07-04 14:33:41 +08:00
Joel
b23d8a912b fix: add missing like i18n (#512) 2023-07-04 14:21:51 +08:00
Joel
4f13f8fd0a fix: change langenius text to dify (#498) 2023-07-02 14:01:11 +08:00
Joel
561c9cabd5 fix: input text repeat (#492) 2023-06-29 17:27:48 +08:00
zxhlyh
39ea967b30 refact common layout (#490) 2023-06-29 15:30:12 +08:00
John Wang
da04ff040b fix: remove document from dataset error when vector index npe (#489) 2023-06-29 13:09:22 +08:00
John Wang
b9b0866a46 fix: generate summary error when tokens=4097 (#488) 2023-06-29 12:54:50 +08:00
Joel
c6ab7eebd9 fix: delete operation style error (#485) 2023-06-29 09:24:31 +08:00
Joel
db4e6d81c5 fix: choose dataset not selected after one page (#481) 2023-06-29 09:22:42 +08:00
John Wang
df68a7c82b feat: Optimize the quality of the title generate (#484) 2023-06-28 19:59:20 +08:00
Joel
838825d747 feat: optimize conversation operation (#479) 2023-06-28 17:53:23 +08:00
crazywoola
a87f6f2837 fix: modal disappear (#478) 2023-06-28 16:44:17 +08:00
John Wang
9d98669e7d fix: dataset destination error (#477) 2023-06-28 15:51:07 +08:00
John Wang
408fbb0c70 fix: title, summary, suggested questions generate (#476) 2023-06-28 15:43:33 +08:00
crazywoola
998f819b04 use sub to operate all (#475) 2023-06-28 14:58:40 +08:00
John Wang
6194b82752 feat: bump to 0.3.6 (#474) 2023-06-28 14:23:20 +08:00
Jyong
334f46d0b6 Fix/json format (#466) 2023-06-28 13:58:50 +08:00
Jyong
2eea114ac0 fix special code (#473) 2023-06-28 13:58:36 +08:00
crazywoola
97e9ebd29a Feature/add is deleted to conversations (#470) 2023-06-28 13:31:51 +08:00
Joel
ec261aea54 feat: conversation app support pin and delete conversation (#467) 2023-06-28 11:16:54 +08:00
Joel
accc5faae3 fix: delete dataset not trigger show start new conversation message (#471) 2023-06-28 10:39:40 +08:00
Joel
0462f09ecc fix: app nav call detail match explore app detail page (#469) 2023-06-27 18:40:24 +08:00
zxhlyh
1226d73159 Feat/refact header (#468) 2023-06-27 18:02:01 +08:00
326 changed files with 9688 additions and 1554 deletions

3
.gitignore vendored
View File

@@ -109,6 +109,7 @@ venv/
ENV/
env.bak/
venv.bak/
.conda/
# Spyder project settings
.spyderproject
@@ -147,3 +148,5 @@ docker/volumes/weaviate/*
sdks/python-client/build
sdks/python-client/dist
sdks/python-client/dify_client.egg-info
.vscode/

View File

@@ -17,9 +17,15 @@ A single API encompassing plugin capabilities, context enhancement, and more, sa
Visual data analysis, log review, and annotation for applications
Dify is compatible with Langchain, meaning we'll gradually support multiple LLMs, currently supported:
- GPT 3 (text-davinci-003)
- GPT 3.5 Turbo(ChatGPT)
- GPT-4
* **OpenAI** GPT4、GPT3.5-turbo、GPT3.5-turbo-16k、text-davinci-003
* **Azure OpenAI**
* **Antropic**Claude2、Claude-instant
> We've got 1000 free trial credits available for all cloud service users to try out the Claude model.Visit [Dify.ai](https://dify.ai) and
try it now.
* **hugging face Hub**Coming soon.
## Use Cloud Services

View File

@@ -17,11 +17,16 @@
- 一套 API 即可包含插件、上下文增强等能力,替你省下了后端代码的编写工作
- 可视化的对应用进行数据分析,查阅日志或进行标注
Dify 兼容 Langchain这意味着我们将逐步支持多种 LLMs ,目前支持:
Dify 兼容 Langchain这意味着我们将逐步支持多种 LLMs ,目前支持的模型供应商
- GPT 3 (text-davinci-003)
- GPT 3.5 Turbo(ChatGPT)
- GPT-4
* **OpenAI**GPT4、GPT3.5-turbo、GPT3.5-turbo-16k、text-davinci-003
* **Azure OpenAI Service**
* **Anthropic**Claude2、Claude-instant
> 我们为所有注册云端版的用户免费提供了 1000 次 Claude 模型的消息调用额度,登录 [dify.ai](https://cloud.dify.ai) 即可使用。
* **Hugging Face Hub**(即将推出)
## 使用云服务

View File

@@ -8,13 +8,19 @@ EDITION=SELF_HOSTED
SECRET_KEY=
# Console API base URL
CONSOLE_URL=http://127.0.0.1:5001
CONSOLE_API_URL=http://127.0.0.1:5001
# Console frontend web base URL
CONSOLE_WEB_URL=http://127.0.0.1:3000
# Service API base URL
API_URL=http://127.0.0.1:5001
SERVICE_API_URL=http://127.0.0.1:5001
# Web APP base URL
APP_URL=http://127.0.0.1:3000
# Web APP API base URL
APP_API_URL=http://127.0.0.1:5001
# Web APP frontend web base URL
APP_WEB_URL=http://127.0.0.1:3000
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
@@ -79,6 +85,11 @@ WEAVIATE_BATCH_SIZE=100
QDRANT_URL=path:storage/qdrant
QDRANT_API_KEY=your-qdrant-api-key
# Mail configuration, support: resend
MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
RESEND_API_KEY=
# Sentry configuration
SENTRY_DSN=

View File

@@ -5,9 +5,11 @@ LABEL maintainer="takatost@gmail.com"
ENV FLASK_APP app.py
ENV EDITION SELF_HOSTED
ENV DEPLOY_ENV PRODUCTION
ENV CONSOLE_URL http://127.0.0.1:5001
ENV API_URL http://127.0.0.1:5001
ENV APP_URL http://127.0.0.1:5001
ENV CONSOLE_API_URL http://127.0.0.1:5001
ENV CONSOLE_WEB_URL http://127.0.0.1:3000
ENV SERVICE_API_URL http://127.0.0.1:5001
ENV APP_API_URL http://127.0.0.1:5001
ENV APP_WEB_URL http://127.0.0.1:3000
EXPOSE 5001
@@ -25,4 +27,4 @@ RUN chmod +x /entrypoint.sh
ARG COMMIT_SHA
ENV COMMIT_SHA ${COMMIT_SHA}
ENTRYPOINT ["/entrypoint.sh"]
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]

View File

@@ -2,6 +2,8 @@
import os
from datetime import datetime
from werkzeug.exceptions import Forbidden
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
monkey.patch_all()
@@ -15,7 +17,7 @@ import flask_login
from flask_cors import CORS
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage
ext_database, ext_storage, ext_mail
from extensions.ext_database import db
from extensions.ext_login import login_manager
@@ -27,7 +29,7 @@ from events import event_handlers
import core
from config import Config, CloudEditionConfig
from commands import register_commands
from models.account import TenantAccountJoin
from models.account import TenantAccountJoin, AccountStatus
from models.model import Account, EndUser, App
import warnings
@@ -83,6 +85,7 @@ def initialize_extensions(app):
ext_celery.init_app(app)
ext_session.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_sentry.init_app(app)
@@ -100,6 +103,9 @@ def load_user(user_id):
account = db.session.query(Account).filter(Account.id == account_id).first()
if account:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise Forbidden('Account is banned or closed.')
workspace_id = session.get('workspace_id')
if workspace_id:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
@@ -149,13 +155,17 @@ def register_blueprints(app):
from controllers.web import bp as web_bp
from controllers.console import bp as console_app_bp
CORS(service_api_bp,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
)
app.register_blueprint(service_api_bp)
CORS(web_bp,
resources={
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
supports_credentials=True,
allow_headers=['Content-Type', 'Authorization'],
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
)

View File

@@ -18,7 +18,8 @@ from models.model import Account
import secrets
import base64
from models.provider import Provider
from models.provider import Provider, ProviderName
from services.provider_service import ProviderService
@click.command('reset-password', help='Reset the account password.')
@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes():
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
def sync_anthropic_hosted_providers():
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
count = 0
page = 1
while True:
try:
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for tenant in tenants:
try:
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
count += 1
except Exception as e:
click.echo(click.style('Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(generate_invitation_codes)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(recreate_all_dataset_indexes)
app.cli.add_command(sync_anthropic_hosted_providers)

View File

@@ -28,9 +28,11 @@ DEFAULTS = {
'SESSION_REDIS_USE_SSL': 'False',
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
'OAUTH_REDIRECT_INDEX_PATH': '/',
'CONSOLE_URL': 'https://cloud.dify.ai',
'API_URL': 'https://api.dify.ai',
'APP_URL': 'https://udify.app',
'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
'CONSOLE_API_URL': 'https://cloud.dify.ai',
'SERVICE_API_URL': 'https://api.dify.ai',
'APP_WEB_URL': 'https://udify.app',
'APP_API_URL': 'https://udify.app',
'STORAGE_TYPE': 'local',
'STORAGE_LOCAL_PATH': 'storage',
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
@@ -48,7 +50,10 @@ DEFAULTS = {
'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'DEFAULT_LLM_PROVIDER': 'openai'
'DEFAULT_LLM_PROVIDER': 'openai',
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
'TENANT_DOCUMENT_COUNT': 100
}
@@ -76,10 +81,15 @@ class Config:
def __init__(self):
# app settings
self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.5"
self.CURRENT_VERSION = "0.3.10"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -147,10 +157,15 @@ class Config:
# cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL)
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'WEB_API_CORS_ALLOW_ORIGINS', '*')
# mail settings
self.MAIL_TYPE = get_env('MAIL_TYPE')
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
# sentry settings
self.SENTRY_DSN = get_env('SENTRY_DSN')
self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
@@ -179,6 +194,10 @@ class Config:
# hosted provider credentials
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
# By default it is False
# You could disable it for compatibility with certain OpenAPI providers
@@ -195,6 +214,8 @@ class Config:
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
class CloudEditionConfig(Config):

View File

@@ -9,10 +9,10 @@ api = ExternalApi(bp)
from . import setup, version, apikey, admin
# Import app controllers
from .app import app, site, completion, model_config, statistic, conversation, message, generator
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
# Import auth controllers
from .auth import login, oauth, data_source_oauth
from .auth import login, oauth, data_source_oauth, activate
# Import datasets controllers
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
@@ -21,4 +21,4 @@ from .datasets import datasets, datasets_document, datasets_segments, file, hit_
from .workspace import workspace, members, providers, account
# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio

View File

@@ -22,6 +22,7 @@ model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
@@ -144,6 +145,7 @@ class AppListApi(Resource):
opening_statement=model_configuration['opening_statement'],
suggested_questions=json.dumps(model_configuration['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),
@@ -434,6 +436,7 @@ class AppCopy(Resource):
opening_statement=app_config.opening_statement,
suggested_questions=app_config.suggested_questions,
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
speech_to_text=app_config.speech_to_text,
more_like_this=app_config.more_like_this,
model=app_config.model,
user_input_form=app_config.user_input_form,

View File

@@ -0,0 +1,69 @@
# -*- coding:utf-8 -*-
import logging
from flask import request
from flask_login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import AppUnavailableError, \
ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from flask_restful import Resource
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
class ChatMessageAudioApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_id):
app_id = str(app_id)
app_model = _get_app(app_id, 'chat')
file = request.files['file']
try:
response = AudioService.transcript(
tenant_id=app_model.tenant_id,
file=file,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')

View File

@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:

View File

@@ -209,6 +209,26 @@ class CompletionConversationDetailApi(Resource):
conversation_id = str(conversation_id)
return _get_conversation(app_id, conversation_id, 'completion')
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id, conversation_id):
app_id = str(app_id)
conversation_id = str(conversation_id)
app = _get_app(app_id, 'chat')
conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
if not conversation:
raise NotFound("Conversation Not Exists.")
conversation.is_deleted = True
db.session.commit()
return {'result': 'success'}, 204
class ChatConversationApi(Resource):
@@ -356,6 +376,27 @@ class ChatConversationDetailApi(Resource):
conversation_id = str(conversation_id)
return _get_conversation(app_id, conversation_id, 'chat')
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id, conversation_id):
app_id = str(app_id)
conversation_id = str(conversation_id)
# get app info
app = _get_app(app_id, 'chat')
conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
if not conversation:
raise NotFound("Conversation Not Exists.")
conversation.is_deleted = True
db.session.commit()
return {'result': 'success'}, 204

View File

@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
class ProviderQuotaExceededError(BaseHTTPException):
error_code = 'provider_quota_exceeded'
description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
"Please go to Settings -> Model Provider to complete your own provider credentials."
code = 400
@@ -49,3 +49,27 @@ class AppMoreLikeThisDisabledError(BaseHTTPException):
error_code = 'app_more_like_this_disabled'
description = "The 'More like this' feature is disabled. Please refresh your page."
code = 403
class NoAudioUploadedError(BaseHTTPException):
error_code = 'no_audio_uploaded'
description = "Please upload your audio."
code = 400
class AudioTooLargeError(BaseHTTPException):
error_code = 'audio_too_large'
description = "Audio size exceeded. {message}"
code = 413
class UnsupportedAudioTypeError(BaseHTTPException):
error_code = 'unsupported_audio_type'
description = "Audio type not allowed."
code = 415
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400

View File

@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
account.current_tenant_id,
args['prompt_template']
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
args['audiences'],
args['hoping_to_solve']
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -41,6 +41,7 @@ class ModelConfigResource(Resource):
opening_statement=model_configuration['opening_statement'],
suggested_questions=json.dumps(model_configuration['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),

View File

@@ -0,0 +1,75 @@
import base64
import secrets
from datetime import datetime
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.helper import email, str_len, supported_language, timezone
from libs.password import valid_password, hash_password
from models.account import AccountStatus, Tenant
from services.account_service import RegisterService
class ActivateCheckApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='args')
parser.add_argument('email', type=email, required=True, nullable=False, location='args')
parser.add_argument('token', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
tenant = db.session.query(Tenant).filter(
Tenant.id == args['workspace_id'],
Tenant.status == 'normal'
).first()
return {'is_valid': account is not None, 'workspace_name': tenant.name}
class ActivateApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='json')
parser.add_argument('email', type=email, required=True, nullable=False, location='json')
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
location='json')
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
args = parser.parse_args()
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
if account is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
account.name = args['name']
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(args['password'], salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
account.interface_language = args['interface_language']
account.timezone = args['timezone']
account.interface_theme = 'light'
account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.utcnow()
db.session.commit()
return {'result': 'success'}
api.add_resource(ActivateCheckApi, '/activate/check')
api.add_resource(ActivateApi, '/activate')

View File

@@ -20,7 +20,7 @@ def get_oauth_providers():
client_secret=current_app.config.get(
'NOTION_CLIENT_SECRET'),
redirect_uri=current_app.config.get(
'CONSOLE_URL') + '/console/api/oauth/data-source/callback/notion')
'CONSOLE_API_URL') + '/console/api/oauth/data-source/callback/notion')
OAUTH_PROVIDERS = {
'notion': notion_oauth
@@ -42,7 +42,7 @@ class OAuthDataSource(Resource):
if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
oauth_provider.save_internal_access_token(internal_secret)
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
else:
auth_url = oauth_provider.get_authorization_url()
return redirect(auth_url)
@@ -66,12 +66,12 @@ class OAuthDataSourceCallback(Resource):
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
return {'error': 'OAuth data source process failed'}, 400
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
elif 'error' in request.args:
error = request.args.get('error')
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source={error}')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source={error}')
else:
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=access_denied')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=access_denied')
class OAuthDataSourceSync(Resource):

View File

@@ -20,13 +20,13 @@ def get_oauth_providers():
client_secret=current_app.config.get(
'GITHUB_CLIENT_SECRET'),
redirect_uri=current_app.config.get(
'CONSOLE_URL') + '/console/api/oauth/authorize/github')
'CONSOLE_API_URL') + '/console/api/oauth/authorize/github')
google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'),
client_secret=current_app.config.get(
'GOOGLE_CLIENT_SECRET'),
redirect_uri=current_app.config.get(
'CONSOLE_URL') + '/console/api/oauth/authorize/google')
'CONSOLE_API_URL') + '/console/api/oauth/authorize/google')
OAUTH_PROVIDERS = {
'github': github_oauth,
@@ -80,7 +80,7 @@ class OAuthCallback(Resource):
flask_login.login_user(account, remember=True)
AccountService.update_last_login(account, request)
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_login=success')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:

View File

@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -324,8 +324,8 @@ class DatasetInitApi(Resource):
document_data=args,
account=current_user
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -95,8 +95,8 @@ class HitTestingApi(Resource):
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -18,3 +18,9 @@ class AccountNotLinkTenantError(BaseHTTPException):
error_code = 'account_not_link_tenant'
description = "Account not link tenant."
code = 403
class AlreadyActivateError(BaseHTTPException):
error_code = 'already_activate'
description = "Auth Token is invalid or account already activated, please check again."
code = 403

View File

@@ -0,0 +1,66 @@
# -*- coding:utf-8 -*-
import logging
from flask import request
from werkzeug.exceptions import InternalServerError
import services
from controllers.console import api
from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.explore.wraps import InstalledAppResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
from models.model import AppModelConfig
class ChatAudioApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()
file = request.files['file']
try:
response = AudioService.transcript(
tenant_id=app_model.tenant_id,
file=file,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')

View File

@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:

View File

@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -21,6 +21,7 @@ class AppParameterApi(InstalledAppResource):
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
}
@@ -35,6 +36,7 @@ class AppParameterApi(InstalledAppResource):
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@@ -6,22 +6,23 @@ from flask import current_app, request
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal_with
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.workspace.error import AccountAlreadyInitedError, InvalidInvitationCodeError, \
RepeatPasswordNotMatchError
RepeatPasswordNotMatchError, CurrentPasswordIncorrectError
from controllers.console.wraps import account_initialization_required
from libs.helper import TimestampField, supported_language, timezone
from extensions.ext_database import db
from models.account import InvitationCode, AccountIntegrate
from services.account_service import AccountService
account_fields = {
'id': fields.String,
'name': fields.String,
'avatar': fields.String,
'email': fields.String,
'is_password_set': fields.Boolean,
'interface_language': fields.String,
'interface_theme': fields.String,
'timezone': fields.String,
@@ -194,8 +195,11 @@ class AccountPasswordApi(Resource):
if args['new_password'] != args['repeat_new_password']:
raise RepeatPasswordNotMatchError()
AccountService.update_account_password(
current_user, args['password'], args['new_password'])
try:
AccountService.update_account_password(
current_user, args['password'], args['new_password'])
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
return {"result": "success"}

View File

@@ -7,6 +7,12 @@ class RepeatPasswordNotMatchError(BaseHTTPException):
code = 400
class CurrentPasswordIncorrectError(BaseHTTPException):
error_code = 'current_password_incorrect'
description = "Current password is incorrect."
code = 400
class ProviderRequestFailedError(BaseHTTPException):
error_code = 'provider_request_failed'
description = None

View File

@@ -1,5 +1,5 @@
# -*- coding:utf-8 -*-
from flask import current_app
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
@@ -60,7 +60,8 @@ class MemberInviteEmailApi(Resource):
inviter = current_user
try:
RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role, inviter=inviter)
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
inviter=inviter)
account = db.session.query(Account, TenantAccountJoin.role).join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
).filter(Account.email == args['email']).first()
@@ -78,7 +79,16 @@ class MemberInviteEmailApi(Resource):
# todo:413
return {'result': 'success', 'account': account}, 201
return {
'result': 'success',
'account': account,
'invite_url': '{}/activate?workspace_id={}&email={}&token={}'.format(
current_app.config.get("CONSOLE_WEB_URL"),
str(current_user.current_tenant_id),
invitee_email,
token
)
}, 201
class MemberCancelInviteApi(Resource):
@@ -88,7 +98,7 @@ class MemberCancelInviteApi(Resource):
@login_required
@account_initialization_required
def delete(self, member_id):
member = Account.query.get(str(member_id))
member = db.session.query(Account).filter(Account.id == str(member_id)).first()
if not member:
abort(404)

View File

@@ -3,6 +3,7 @@ import base64
import json
import logging
from flask import current_app
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, abort
from werkzeug.exceptions import Forbidden
@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
"""
ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
ProviderService.init_supported_provider(current_user.current_tenant)
providers = Provider.query.filter_by(tenant_id=tenant_id).all()
provider_list = [
@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
'quota_used': p.quota_used
} if p.provider_type == ProviderType.SYSTEM.value else {}),
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
ProviderName(p.provider_name))
ProviderName(p.provider_name), only_custom=True)
if p.provider_type == ProviderType.CUSTOM.value else None
}
for p in providers
]
@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
is_valid=token_is_valid)
db.session.add(provider_model)
if provider_model.is_valid:
if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
other_providers = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id,
Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
Provider.provider_name != provider,
Provider.provider_type == ProviderType.CUSTOM.value
).all()
@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
db.session.commit()
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
args = parser.parse_args()
# todo: remove this when the provider is supported
if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
if provider in [ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
provider_model.is_valid = args['is_enabled']
db.session.commit()
elif not provider_model:
ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
if provider == ProviderName.OPENAI.value:
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
elif provider == ProviderName.ANTHROPIC.value:
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
else:
quota_limit = 0
ProviderService.create_system_provider(
tenant,
provider,
quota_limit,
args['is_enabled']
)
else:
abort(403)

View File

@@ -7,6 +7,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1')
api = ExternalApi(bp)
from .app import completion, app, conversation, message
from .app import completion, app, conversation, message, audio
from .dataset import document

View File

@@ -22,6 +22,7 @@ class AppParameterApi(AppApiResource):
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
}
@@ -35,6 +36,7 @@ class AppParameterApi(AppApiResource):
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@@ -0,0 +1,61 @@
import logging
from flask import request
from werkzeug.exceptions import InternalServerError
import services
from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError, ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
ProviderNotSupportSpeechToTextError
from controllers.service_api.wraps import AppApiResource
from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from models.model import App, AppModelConfig
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
class AudioApi(AppApiResource):
def post(self, app_model: App, end_user):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()
file = request.files['file']
try:
response = AudioService.transcript(
tenant_id=app_model.tenant_id,
file=file,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
api.add_resource(AudioApi, '/audio-to-text')

View File

@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:

View File

@@ -1,4 +1,5 @@
# -*- coding:utf-8 -*-
from flask import request
from flask_restful import fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
@@ -48,6 +49,24 @@ class ConversationApi(AppApiResource):
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(AppApiResource):
@marshal_with(conversation_fields)
def delete(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
conversation_id = str(c_id)
user = request.get_json().get('user')
if end_user is None and user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
try:
ConversationService.delete(app_model, conversation_id, end_user)
return {"result": "success"}
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
class ConversationRenameApi(AppApiResource):
@@ -74,3 +93,5 @@ class ConversationRenameApi(AppApiResource):
api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='conversation_name')
api.add_resource(ConversationApi, '/conversations')
api.add_resource(ConversationApi, '/conversations/<uuid:c_id>', endpoint='conversation')
api.add_resource(ConversationDetailApi, '/conversations/<uuid:c_id>', endpoint='conversation_detail')

View File

@@ -51,3 +51,27 @@ class CompletionRequestError(BaseHTTPException):
description = "Completion request failed."
code = 400
class NoAudioUploadedError(BaseHTTPException):
error_code = 'no_audio_uploaded'
description = "Please upload your audio."
code = 400
class AudioTooLargeError(BaseHTTPException):
error_code = 'audio_too_large'
description = "Audio size exceeded. {message}"
code = 413
class UnsupportedAudioTypeError(BaseHTTPException):
error_code = 'unsupported_audio_type'
description = "Audio type not allowed."
code = 415
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400

View File

@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
dataset_process_rule=dataset.latest_process_rule,
created_from='api'
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
if doc_type and doc_metadata:
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]

View File

@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
api = ExternalApi(bp)
from . import completion, app, conversation, message, site, saved_message
from . import completion, app, conversation, message, site, saved_message, audio, passport

View File

@@ -21,6 +21,7 @@ class AppParameterApi(WebApiResource):
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
}
@@ -34,6 +35,7 @@ class AppParameterApi(WebApiResource):
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@@ -0,0 +1,63 @@
# -*- coding:utf-8 -*-
import logging
from flask import request
from werkzeug.exceptions import InternalServerError
import services
from controllers.web import api
from controllers.web.error import AppUnavailableError, ProviderNotInitializeError, CompletionRequestError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.web.wraps import WebApiResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
from models.model import App, AppModelConfig
class AudioApi(WebApiResource):
def post(self, app_model: App, end_user):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()
file = request.files['file']
try:
response = AudioService.transcript(
tenant_id=app_model.tenant_id,
file=file,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
api.add_resource(AudioApi, '/audio-to-text')

View File

@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:

View File

@@ -62,3 +62,27 @@ class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
error_code = 'app_suggested_questions_after_answer_disabled'
description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page."
code = 403
class NoAudioUploadedError(BaseHTTPException):
error_code = 'no_audio_uploaded'
description = "Please upload your audio."
code = 400
class AudioTooLargeError(BaseHTTPException):
error_code = 'audio_too_large'
description = "Audio size exceeded. {message}"
code = 413
class UnsupportedAudioTypeError(BaseHTTPException):
error_code = 'unsupported_audio_type'
description = "Audio type not allowed."
code = 415
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400

View File

@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -0,0 +1,64 @@
# -*- coding:utf-8 -*-
import uuid
from controllers.web import api
from flask_restful import Resource
from flask import request
from werkzeug.exceptions import Unauthorized, NotFound
from models.model import Site, EndUser, App
from extensions.ext_database import db
from libs.passport import PassportService
class PassportResource(Resource):
"""Base resource for passport."""
def get(self):
app_id = request.headers.get('X-App-Code')
if app_id is None:
raise Unauthorized('X-App-Code header is missing.')
# get site from db and check if it is normal
site = db.session.query(Site).filter(
Site.code == app_id,
Site.status == 'normal'
).first()
if not site:
raise NotFound()
# get app from db and check if it is normal and enable_site
app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model or app_model.status != 'normal' or not app_model.enable_site:
raise NotFound()
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='browser',
is_anonymous=True,
session_id=generate_session_id(),
)
db.session.add(end_user)
db.session.commit()
payload = {
"iss": site.app_id,
'sub': 'Web API Passport',
'app_id': site.app_id,
'end_user_id': end_user.id,
}
tk = PassportService().issue(payload)
return {
'access_token': tk,
}
api.add_resource(PassportResource, '/passport')
def generate_session_id():
"""
Generate a unique session ID.
"""
while True:
session_id = str(uuid.uuid4())
existing_count = db.session.query(EndUser) \
.filter(EndUser.session_id == session_id).count()
if existing_count == 0:
return session_id

View File

@@ -1,110 +1,48 @@
# -*- coding:utf-8 -*-
import uuid
from functools import wraps
from flask import request, session
from flask import request
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized
from extensions.ext_database import db
from models.model import App, Site, EndUser
from models.model import App, EndUser
from libs.passport import PassportService
def validate_token(view=None):
def validate_jwt_token(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
site = validate_and_get_site()
app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model:
raise NotFound()
if app_model.status != 'normal':
raise NotFound()
if not app_model.enable_site:
raise NotFound()
end_user = create_or_update_end_user_for_session(app_model)
app_model, end_user = decode_jwt_token()
return view(app_model, end_user, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
def validate_and_get_site():
"""
Validate and get API token.
"""
def decode_jwt_token():
auth_header = request.headers.get('Authorization')
if auth_header is None:
raise Unauthorized('Authorization header is missing.')
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme, tk = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
site = db.session.query(Site).filter(
Site.code == auth_token,
Site.status == 'normal'
).first()
if not site:
decoded = PassportService().verify(tk)
app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
if not app_model:
raise NotFound()
end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
if not end_user:
raise NotFound()
return site
def create_or_update_end_user_for_session(app_model):
"""
Create or update session terminal based on session ID.
"""
if 'session_id' not in session:
session['session_id'] = generate_session_id()
session_id = session.get('session_id')
end_user = db.session.query(EndUser) \
.filter(
EndUser.session_id == session_id,
EndUser.type == 'browser'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='browser',
is_anonymous=True,
session_id=session_id
)
db.session.add(end_user)
db.session.commit()
return end_user
def generate_session_id():
"""
Generate a unique session ID.
"""
count = 1
session_id = ''
while count != 0:
session_id = str(uuid.uuid4())
count = db.session.query(EndUser) \
.filter(EndUser.session_id == session_id).count()
return session_id
return app_model, end_user
class WebApiResource(Resource):
method_decorators = [validate_token]
method_decorators = [validate_jwt_token]

View File

@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
api_key: str
class HostedAnthropicCredential(BaseModel):
api_key: str
class HostedLLMCredentials(BaseModel):
openai: Optional[HostedOpenAICredential] = None
anthropic: Optional[HostedAnthropicCredential] = None
hosted_llm_credentials = HostedLLMCredentials()
@@ -26,3 +31,6 @@ def init_app(app: Flask):
if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
if app.config.get("ANTHROPIC_API_KEY"):
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))

View File

@@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
})
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -69,9 +69,8 @@ class LLMCallbackHandler(BaseCallbackHandler):
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.conversation_message_task.save_message(self.llm_message)

View File

@@ -1,4 +1,5 @@
import math
import re
from typing import Mapping, List, Dict, Any, Optional
from langchain import PromptTemplate
@@ -178,13 +179,20 @@ class MultiDatasetRouterChain(Chain):
route = self.router_chain.route(inputs)
if not route.destination:
destination = ''
if route.destination:
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
match = re.search(pattern, route.destination, re.IGNORECASE)
if match:
destination = match.group()
if not destination:
return {"text": ''}
elif route.destination in self.dataset_tools:
return {"text": self.dataset_tools[route.destination].run(
elif destination in self.dataset_tools:
return {"text": self.dataset_tools[destination].run(
route.next_inputs['input']
)}
else:
raise ValueError(
f"Received invalid destination chain name '{route.destination}'"
f"Received invalid destination chain name '{destination}'"
)

View File

@@ -118,6 +118,7 @@ class Completion:
prompt, stop_words = cls.get_main_llm_prompt(
mode=mode,
llm=final_llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
@@ -129,6 +130,7 @@ class Completion:
cls.recale_llm_max_tokens(
final_llm=final_llm,
model=app_model_config.model_dict,
prompt=prompt,
mode=mode
)
@@ -138,7 +140,8 @@ class Completion:
return response
@classmethod
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
pre_prompt: str, query: str, inputs: dict,
chain_output: Optional[str],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
@@ -151,10 +154,11 @@ class Completion:
if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following CONTEXT as your learned knowledge:
[CONTEXT]
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
[END CONTEXT]
</context>
When answer to user:
- If you don't know, just say that you don't know.
@@ -204,10 +208,11 @@ And answer according to the language of the user's question.
if chain_output:
human_inputs['context'] = chain_output
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
[END CONTEXT]
</context>
When answer to user:
- If you don't know, just say that you don't know.
@@ -219,7 +224,7 @@ And answer according to the language of the user's question.
if pre_prompt:
human_message_prompt += pre_prompt
query_prompt = "\nHuman: {{query}}\nAI: "
query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
if memory:
# append chat histories
@@ -228,9 +233,11 @@ And answer according to the language of the user's question.
inputs=human_inputs
)
curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
- memory.llm.max_tokens - curr_message_tokens
curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
model_name = model['name']
max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = llm_constant.max_context_token_length[model_name] \
- max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
@@ -241,7 +248,10 @@ And answer according to the language of the user's question.
# if histories_param not in human_inputs:
# human_inputs[histories_param] = '{{' + histories_param + '}}'
human_message_prompt += "\n\n" + histories
human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += "Here is the chat histories between human and assistant, " \
"inside <histories></histories> XML tags.\n\n<histories>"
human_message_prompt += histories + "</histories>"
human_message_prompt += query_prompt
@@ -307,13 +317,15 @@ And answer according to the language of the user's question.
model=app_model_config.model_dict
)
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
max_tokens = llm.max_tokens
model_name = app_model_config.model_dict.get("name")
model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
# get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt(
mode=mode,
llm=llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
@@ -332,16 +344,17 @@ And answer according to the language of the user's question.
return rest_tokens
@classmethod
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
prompt: Union[str, List[BaseMessage]], mode: str):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
max_tokens = final_llm.max_tokens
model_name = model.get("name")
model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = model.get("completion_params").get('max_tokens')
if mode == 'completion' and isinstance(final_llm, BaseLLM):
prompt_tokens = final_llm.get_num_tokens(prompt)
else:
prompt_tokens = final_llm.get_messages_tokens(prompt)
prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
@@ -350,9 +363,10 @@ And answer according to the language of the user's question.
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
llm: StreamableOpenAI = LLMBuilder.to_llm(
llm = LLMBuilder.to_llm_from_model(
tenant_id=app.tenant_id,
model_name='gpt-3.5-turbo',
model=app_model_config.model_dict,
streaming=streaming
)
@@ -360,6 +374,7 @@ And answer according to the language of the user's question.
original_prompt, _ = cls.get_main_llm_prompt(
mode="completion",
llm=llm,
model=app_model_config.model_dict,
pre_prompt=pre_prompt,
query=message.query,
inputs=message.inputs,
@@ -390,6 +405,7 @@ And answer according to the language of the user's question.
cls.recale_llm_max_tokens(
final_llm=llm,
model=app_model_config.model_dict,
prompt=prompt,
mode='completion'
)

View File

@@ -1,6 +1,8 @@
from _decimal import Decimal
models = {
'claude-instant-1': 'anthropic', # 100,000 tokens
'claude-2': 'anthropic', # 100,000 tokens
'gpt-4': 'openai', # 8,192 tokens
'gpt-4-32k': 'openai', # 32,768 tokens
'gpt-3.5-turbo': 'openai', # 4,096 tokens
@@ -10,10 +12,13 @@ models = {
'text-curie-001': 'openai', # 2,049 tokens
'text-babbage-001': 'openai', # 2,049 tokens
'text-ada-001': 'openai', # 2,049 tokens
'text-embedding-ada-002': 'openai' # 8191 tokens, 1536 dimensions
'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
'whisper-1': 'openai'
}
max_context_token_length = {
'claude-instant-1': 100000,
'claude-2': 100000,
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
@@ -23,17 +28,21 @@ max_context_token_length = {
'text-curie-001': 2049,
'text-babbage-001': 2049,
'text-ada-001': 2049,
'text-embedding-ada-002': 8191
'text-embedding-ada-002': 8191,
}
models_by_mode = {
'chat': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
],
'completion': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
@@ -52,6 +61,14 @@ models_by_mode = {
model_currency = 'USD'
model_prices = {
'claude-instant-1': {
'prompt': Decimal('0.00163'),
'completion': Decimal('0.00551'),
},
'claude-2': {
'prompt': Decimal('0.01102'),
'completion': Decimal('0.03268'),
},
'gpt-4': {
'prompt': Decimal('0.03'),
'completion': Decimal('0.06'),

View File

@@ -56,7 +56,7 @@ class ConversationMessageTask:
)
def init(self):
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
self.model_dict['provider'] = provider_name
override_model_configs = None
@@ -89,7 +89,7 @@ class ConversationMessageTask:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
system_instruction_tokens = llm.get_messages_tokens([system_message])
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
if not self.conversation:
self.is_new_conversation = True
@@ -185,6 +185,7 @@ class ConversationMessageTask:
if provider and provider.provider_type == ProviderType.SYSTEM.value:
db.session.query(Provider).filter(
Provider.tenant_id == self.app.tenant_id,
Provider.provider_name == provider.provider_name,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + 1})

View File

@@ -39,6 +39,7 @@ class ExcelLoader(BaseLoader):
row_dict = dict(zip(keys, list(map(str, row))))
row_dict = {k: v for k, v in row_dict.items() if v}
item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
data.append(item)
document = Document(page_content=item)
data.append(document)
return [Document(page_content='\n\n'.join(data))]
return data

View File

@@ -81,8 +81,8 @@ class NotionLoader(BaseLoader):
docs = []
if notion_page_type == 'database':
# get all the pages in the database
page_text = self._get_notion_database_data(notion_obj_id)
docs.append(Document(page_content=page_text))
page_text_documents = self._get_notion_database_data(notion_obj_id)
docs.extend(page_text_documents)
elif notion_page_type == 'page':
page_text_list = self._get_notion_block_data(notion_obj_id)
for page_text in page_text_list:
@@ -94,7 +94,7 @@ class NotionLoader(BaseLoader):
def _get_notion_database_data(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> str:
) -> List[Document]:
"""Get all the pages from a Notion database."""
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
@@ -110,7 +110,7 @@ class NotionLoader(BaseLoader):
database_content_list = []
if 'results' not in data or data["results"] is None:
return ""
return []
for result in data["results"]:
properties = result['properties']
data = {}
@@ -143,10 +143,10 @@ class NotionLoader(BaseLoader):
row_content = row_content + f'{key}:{value_content}\n'
else:
row_content = row_content + f'{key}:{value}\n'
database_content_list.append(row_content)
database_content_list.append(json.dumps(data, ensure_ascii=False))
document = Document(page_content=row_content)
database_content_list.append(document)
return "\n\n".join(database_content_list)
return database_content_list
def _get_notion_block_data(self, page_id: str) -> List[str]:
result_lines_arr = []

View File

@@ -4,6 +4,7 @@ from typing import List
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
text_embeddings.extend(embedding_results)
return text_embeddings
@handle_openai_exceptions
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists

View File

@@ -2,7 +2,7 @@ import logging
from langchain import PromptTemplate
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, OutputParserException
from langchain.schema import HumanMessage, OutputParserException, BaseMessage
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
@@ -23,10 +23,14 @@ class LLMGenerator:
@classmethod
def generate_conversation_name(cls, tenant_id: str, query, answer):
prompt = CONVERSATION_TITLE_PROMPT
prompt = prompt.format(query=query, answer=answer)
if len(query) > 2000:
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
prompt = prompt.format(query=query)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
model_name='gpt-3.5-turbo',
max_tokens=50
)
@@ -40,26 +44,40 @@ class LLMGenerator:
@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200
model = 'gpt-3.5-turbo'
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = TokenCalculator.get_num_tokens(generate_base_model, prompt_with_empty_context)
rest_tokens = llm_constant.max_context_token_length[generate_base_model] - prompt_tokens - max_tokens
prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)
rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1
context = ''
for message in messages:
if not message.answer:
continue
message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
if rest_tokens - TokenCalculator.get_num_tokens(generate_base_model, context + message_qa_text) > 0:
if len(message.query) > 2000:
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
else:
query = message.query
if len(message.answer) > 2000:
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
else:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
context += message_qa_text
if not context:
return '[message too long, no summary]'
prompt = prompt.format(context=context)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
model_name=model,
max_tokens=max_tokens
)
@@ -102,7 +120,7 @@ class LLMGenerator:
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=256
)
@@ -114,6 +132,8 @@ class LLMGenerator:
try:
output = llm(query)
if isinstance(output, BaseMessage):
output = output.content
questions = output_parser.parse(output)
except Exception:
logging.exception("Error generating suggested questions after answer")

View File

@@ -17,7 +17,7 @@ class IndexBuilder:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)

View File

@@ -235,7 +235,8 @@ class IndexingRunner:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
self.filter_string(document.page_content))
return {
"total_segments": total_segments,
@@ -345,8 +346,10 @@ class IndexingRunner:
return text_docs
def filter_string(self, text):
pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
return pattern.sub('', text)
text = re.sub(r'<\|', '<', text)
text = re.sub(r'\|>', '>', text)
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
return text
def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
"""
@@ -425,7 +428,7 @@ class IndexingRunner:
return documents
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]:
processing_rule: DatasetProcessRule) -> List[Document]:
"""
Split the text documents into nodes.
"""

View File

@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
"""
description = "Provider Token Not Init"
def __init__(self, *args, **kwargs):
self.description = args[0] if args else self.description
class QuotaExceededError(Exception):
"""

View File

@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType
from models.provider import ProviderType, ProviderName
class LLMBuilder:
@@ -32,43 +33,43 @@ class LLMBuilder:
@classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
provider = cls.get_default_provider(tenant_id)
provider = cls.get_default_provider(tenant_id, model_name)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
llm_cls = None
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
if provider == 'openai':
if provider == ProviderName.OPENAI.value:
llm_cls = StreamableChatOpenAI
else:
elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureChatOpenAI
elif provider == ProviderName.ANTHROPIC.value:
llm_cls = StreamableChatAnthropic
elif mode == 'completion':
if provider == 'openai':
if provider == ProviderName.OPENAI.value:
llm_cls = StreamableOpenAI
else:
elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureOpenAI
else:
if not llm_cls:
raise ValueError(f"model name {model_name} is not supported.")
model_kwargs = {
'model_name': model_name,
'temperature': kwargs.get('temperature', 0),
'max_tokens': kwargs.get('max_tokens', 256),
'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0),
'callbacks': kwargs.get('callbacks', None),
'streaming': kwargs.get('streaming', False),
}
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
model_kwargs.update(model_credentials)
model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
return llm_cls(
model_name=model_name,
temperature=kwargs.get('temperature', 0),
max_tokens=kwargs.get('max_tokens', 256),
**model_extras_kwargs,
callbacks=kwargs.get('callbacks', None),
streaming=kwargs.get('streaming', False),
# request_timeout=None
**model_credentials
)
return llm_cls(**model_kwargs)
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
@@ -118,14 +119,30 @@ class LLMBuilder:
return provider_service.get_credentials(model_name)
@classmethod
def get_default_provider(cls, tenant_id: str) -> str:
provider = BaseProvider.get_valid_provider(tenant_id)
if not provider:
raise ProviderTokenNotInitError()
def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
provider_name = llm_constant.models[model_name]
if provider_name == 'openai':
# get the default provider (openai / azure_openai) for the tenant
openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
provider = openai_provider
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
provider = azure_openai_provider
elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
provider = openai_provider
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
provider = azure_openai_provider
if not provider:
raise ProviderTokenNotInitError(
f"No valid {provider_name} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if provider.provider_type == ProviderType.SYSTEM.value:
provider_name = 'openai'
else:
provider_name = provider.provider_name
return provider_name

View File

@@ -1,23 +1,138 @@
from typing import Optional
import json
import logging
from typing import Optional, Union
import anthropic
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from models.provider import ProviderName
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName, ProviderType
class AnthropicProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
# todo
return []
return [
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
},
{
'id': 'claude-2',
'name': 'claude-2',
},
]
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
"""
return {
'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
}
return self.get_provider_api_key(model_id=model_id)
def get_provider_name(self):
return ProviderName.ANTHROPIC
return ProviderName.ANTHROPIC
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'anthropic_api_key': ''
}
if obfuscated:
if not config.get('anthropic_api_key'):
config = {
'anthropic_api_key': ''
}
config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
return config
return config
def get_encrypted_token(self, config: Union[dict | str]):
"""
Returns the encrypted token.
"""
return json.dumps({
'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
})
def get_decrypted_token(self, token: str):
"""
Returns the decrypted token.
"""
config = json.loads(token)
config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
return config
def get_token_type(self):
return dict
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
# check OpenAI / Azure OpenAI credential is valid
openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider:
provider = openai_provider
elif azure_openai_provider:
provider = azure_openai_provider
if not provider:
raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
if quota_used >= quota_limit:
raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
f"please configure OpenAI or Azure OpenAI provider first.")
try:
if not isinstance(config, dict):
raise ValueError('Config must be a object.')
if 'anthropic_api_key' not in config:
raise ValueError('anthropic_api_key must be provided.')
chat_llm = ChatAnthropic(
model='claude-instant-1',
anthropic_api_key=config['anthropic_api_key'],
max_tokens_to_sample=10,
temperature=0,
default_request_timeout=60
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except anthropic.APIConnectionError as ex:
raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
f"{ex.body['error']['type']}: {ex.body['error']['message']}")
except Exception as ex:
logging.exception('Anthropic config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}

View File

@@ -44,6 +44,7 @@ class AzureProvider(BaseProvider):
config['openai_api_type'] = 'azure'
if model_id == 'text-embedding-ada-002':
config['deployment'] = model_id.replace('.', '') if model_id else None
config['chunk_size'] = 1
else:
config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config
@@ -51,12 +52,12 @@ class AzureProvider(BaseProvider):
def get_provider_name(self):
return ProviderName.AZURE_OPENAI
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key()
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'openai_api_type': 'azure',
@@ -80,7 +81,6 @@ class AzureProvider(BaseProvider):
return config
def get_token_type(self):
# TODO: change to dict when implemented
return dict
def config_validate(self, config: Union[dict | str]):
@@ -97,12 +97,11 @@ class AzureProvider(BaseProvider):
models = self.get_models(credentials=config)
if not models:
raise ValidateFailedError("Please add deployments for 'text-davinci-003', "
raise ValidateFailedError("Please add deployments for "
"'gpt-3.5-turbo', 'text-embedding-ada-002' (required) "
"and 'gpt-4', 'gpt-35-turbo-16k' (optional).")
"and 'gpt-4', 'gpt-35-turbo-16k', 'text-davinci-003' (optional).")
fixed_model_ids = [
'text-davinci-003',
'gpt-35-turbo',
'text-embedding-ada-002'
]

View File

@@ -2,7 +2,7 @@ import base64
from abc import ABC, abstractmethod
from typing import Optional, Union
from core import hosted_llm_credentials
from core.constant import llm_constant
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db
from libs import rsa
@@ -14,15 +14,18 @@ class BaseProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
"""
provider = self.get_provider(prefer_custom)
provider = self.get_provider(only_custom)
if not provider:
raise ProviderTokenNotInitError()
raise ProviderTokenNotInitError(
f"No valid {llm_constant.models[model_id]} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
@@ -38,18 +41,19 @@ class BaseProvider(ABC):
else:
return self.get_decrypted_token(provider.encrypted_config)
def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
"""
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
@classmethod
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
If both CUSTOM and System providers exist.
"""
query = db.session.query(Provider).filter(
Provider.tenant_id == tenant_id
@@ -58,39 +62,31 @@ class BaseProvider(ABC):
if provider_name:
query = query.filter(Provider.provider_name == provider_name)
providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
if only_custom:
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
custom_provider = None
system_provider = None
providers = query.order_by(Provider.provider_type.asc()).all()
for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
custom_provider = provider
return provider
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
system_provider = provider
return provider
if custom_provider:
return custom_provider
elif system_provider:
return system_provider
else:
return None
return None
def get_hosted_credentials(self) -> str:
if self.get_provider_name() != ProviderName.OPENAI:
raise ProviderTokenNotInitError()
def get_hosted_credentials(self) -> Union[str | dict]:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
raise ProviderTokenNotInitError()
return hosted_llm_credentials.openai.api_key
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key()
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = ''

View File

@@ -31,11 +31,11 @@ class LLMProviderService:
def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.provider.get_credentials(model_id)
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
return self.provider.get_provider_configs(obfuscated)
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
return self.provider.get_provider(prefer_custom)
def get_provider_db_record(self) -> Optional[Provider]:
return self.provider.get_provider()
def config_validate(self, config: Union[dict | str]):
"""

View File

@@ -4,6 +4,8 @@ from typing import Optional, Union
import openai
from openai.error import AuthenticationError, OpenAIError
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
@@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider):
except Exception as ex:
logging.exception('OpenAI config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return hosted_llm_credentials.openai.api_key

View File

@@ -1,11 +1,11 @@
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureChatOpenAI(AzureChatOpenAI):
@@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in a list of messages.
Args:
messages: The messages to count the tokens of.
Returns:
The number of tokens in the messages.
"""
tokens_per_message = 5
tokens_per_request = 3
message_tokens = tokens_per_request
message_strs = ''
for message in messages:
message_strs += message.content
message_tokens += tokens_per_message
# calc once
message_tokens += self.get_num_tokens(message_strs)
return message_tokens
@handle_llm_exceptions
@handle_openai_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
@@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
) -> LLMResult:
return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(messages, stop, callbacks, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
model_kwargs = {
'top_p': params.get('top_p', 1),
'frequency_penalty': params.get('frequency_penalty', 0),
'presence_penalty': params.get('presence_penalty', 0),
}
del params['top_p']
del params['frequency_penalty']
del params['presence_penalty']
params['model_kwargs'] = model_kwargs
return params

View File

@@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureOpenAI(AzureOpenAI):
@@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}}
@handle_llm_exceptions
@handle_openai_exceptions
def generate(
self,
prompts: List[str],
@@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
) -> LLMResult:
return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop, callbacks, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
return params

View File

@@ -0,0 +1,39 @@
from typing import List, Optional, Any, Dict
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
class StreamableChatAnthropic(ChatAnthropic):
"""
Wrapper around Anthropic's large language model.
"""
@handle_anthropic_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
params['model'] = params.get('model_name')
del params['model_name']
params['max_tokens_to_sample'] = params.get('max_tokens')
del params['max_tokens']
del params['frequency_penalty']
del params['presence_penalty']
return params

View File

@@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableChatOpenAI(ChatOpenAI):
@@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in a list of messages.
Args:
messages: The messages to count the tokens of.
Returns:
The number of tokens in the messages.
"""
tokens_per_message = 5
tokens_per_request = 3
message_tokens = tokens_per_request
message_strs = ''
for message in messages:
message_strs += message.content
message_tokens += tokens_per_message
# calc once
message_tokens += self.get_num_tokens(message_strs)
return message_tokens
@handle_llm_exceptions
@handle_openai_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
@@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI):
) -> LLMResult:
return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(messages, stop, callbacks, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
model_kwargs = {
'top_p': params.get('top_p', 1),
'frequency_penalty': params.get('frequency_penalty', 0),
'presence_penalty': params.get('presence_penalty', 0),
}
del params['top_p']
del params['frequency_penalty']
del params['presence_penalty']
params['model_kwargs'] = model_kwargs
return params

View File

@@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableOpenAI(OpenAI):
@@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}}
@handle_llm_exceptions
@handle_openai_exceptions
def generate(
self,
prompts: List[str],
@@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI):
) -> LLMResult:
return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop, callbacks, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
return params

26
api/core/llm/whisper.py Normal file
View File

@@ -0,0 +1,26 @@
import openai
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from models.provider import ProviderName
from core.llm.provider.base import BaseProvider
class Whisper:
def __init__(self, provider: BaseProvider):
self.provider = provider
if self.provider.get_provider_name() == ProviderName.OPENAI:
self.client = openai.Audio
self.credentials = provider.get_credentials()
@handle_openai_exceptions
def transcribe(self, file):
return self.client.transcribe(
model='whisper-1',
file=file,
api_key=self.credentials.get('openai_api_key'),
api_base=self.credentials.get('openai_api_base'),
api_type=self.credentials.get('openai_api_type'),
api_version=self.credentials.get('openai_api_version'),
)

View File

@@ -0,0 +1,27 @@
import logging
from functools import wraps
import anthropic
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_anthropic_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except anthropic.APIConnectionError as e:
logging.exception("Failed to connect to Anthropic API.")
raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
except anthropic.RateLimitError:
raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
except anthropic.AuthenticationError as e:
raise LLMAuthorizationError(f"Anthropic: {e.message}")
except anthropic.BadRequestError as e:
raise LLMBadRequestError(f"Anthropic: {e.message}")
except anthropic.APIStatusError as e:
raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
return wrapper

View File

@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat
LLMBadRequestError
def handle_llm_exceptions(func):
def handle_openai_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
@@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper
def handle_llm_exceptions_async(func):
@wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except openai.error.InvalidRequestError as e:
logging.exception("Invalid request to OpenAI API.")
raise LLMBadRequestError(str(e))
except openai.error.APIConnectionError as e:
logging.exception("Failed to connect to OpenAI API.")
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
logging.exception("OpenAI service unavailable.")
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
except openai.error.RateLimitError as e:
raise LLMRateLimitError(str(e))
except openai.error.AuthenticationError as e:
raise LLMAuthorizationError(str(e))
except openai.error.OpenAIError as e:
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Dict, Union
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
@@ -12,8 +12,8 @@ from models.model import Conversation, Message
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
conversation: Conversation
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: Union[StreamableChatOpenAI | StreamableOpenAI]
ai_prefix: str = "Assistant"
llm: BaseLanguageModel
memory_key: str = "chat_history"
max_token_limit: int = 2000
message_limit: int = 10
@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
return chat_messages
# prune the chat message if it exceeds the max token limit
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0))
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
return chat_messages

View File

@@ -1,15 +1,15 @@
CONVERSATION_TITLE_PROMPT = (
"Human:{{query}}\n-----\n"
"Human:{query}\n-----\n"
"Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
"If the human said is conducted in Chinese, you should return a Chinese title.\n"
"If the human said is conducted in English, you should return an English title.\n"
"If what the human said is conducted in English, you should only return an English title.\n"
"If what the human said is conducted in Chinese, you should only return a Chinese title.\n"
"title:"
)
CONVERSATION_SUMMARY_PROMPT = (
"Please generate a short summary of the following conversation.\n"
"If the conversation communicating in Chinese, you should return a Chinese summary.\n"
"If the conversation communicating in English, you should return an English summary.\n"
"If the following conversation communicating in English, you should only return an English summary.\n"
"If the following conversation communicating in Chinese, you should only return a Chinese summary.\n"
"[Conversation Start]\n"
"{context}\n"
"[Conversation End]\n\n"
@@ -19,7 +19,7 @@ CONVERSATION_SUMMARY_PROMPT = (
INTRODUCTION_GENERATE_PROMPT = (
"I am designing a product for users to interact with an AI through dialogue. "
"The Prompt given to the AI before the conversation is:\n\n"
"```\n{{prompt}}\n```\n\n"
"```\n{prompt}\n```\n\n"
"Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
"Do not reveal the developer's motivation or deep logic behind the Prompt, "
"but focus on building a relationship with the user:\n"
@@ -27,13 +27,13 @@ INTRODUCTION_GENERATE_PROMPT = (
MORE_LIKE_THIS_GENERATE_PROMPT = (
"-----\n"
"{{original_completion}}\n"
"{original_completion}\n"
"-----\n\n"
"Please use the above content as a sample for generating the result, "
"and include key information points related to the original sample in the result. "
"Try to rephrase this information in different ways and predict according to the rules below.\n\n"
"-----\n"
"{{prompt}}\n"
"{prompt}\n"
)
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (

View File

@@ -30,7 +30,7 @@ class DatasetTool(BaseTool):
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)
@@ -60,7 +60,7 @@ class DatasetTool(BaseTool):
async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)

View File

@@ -1,4 +1,7 @@
from flask import current_app
from events.tenant_event import tenant_was_updated
from models.provider import ProviderName
from services.provider_service import ProviderService
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs):
tenant = sender
if tenant.status == 'normal':
ProviderService.create_system_provider(tenant)
ProviderService.create_system_provider(
tenant,
ProviderName.OPENAI.value,
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
True
)
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)

View File

@@ -1,4 +1,7 @@
from flask import current_app
from events.tenant_event import tenant_was_created
from models.provider import ProviderName
from services.provider_service import ProviderService
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs):
tenant = sender
if tenant.status == 'normal':
ProviderService.create_system_provider(tenant)
ProviderService.create_system_provider(
tenant,
ProviderName.OPENAI.value,
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
True
)
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)

View File

@@ -0,0 +1,61 @@
from typing import Optional
import resend
from flask import Flask
class Mail:
def __init__(self):
self._client = None
self._default_send_from = None
def is_inited(self) -> bool:
return self._client is not None
def init_app(self, app: Flask):
if app.config.get('MAIL_TYPE'):
if app.config.get('MAIL_DEFAULT_SEND_FROM'):
self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM')
if app.config.get('MAIL_TYPE') == 'resend':
api_key = app.config.get('RESEND_API_KEY')
if not api_key:
raise ValueError('RESEND_API_KEY is not set')
resend.api_key = api_key
self._client = resend.Emails
else:
raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE')))
def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
if not self._client:
raise ValueError('Mail client is not initialized')
if not from_ and self._default_send_from:
from_ = self._default_send_from
if not from_:
raise ValueError('mail from is not set')
if not to:
raise ValueError('mail to is not set')
if not subject:
raise ValueError('mail subject is not set')
if not html:
raise ValueError('mail html is not set')
self._client.send({
"from": from_,
"to": to,
"subject": subject,
"html": html
})
def init_app(app: Flask):
mail.init_app(app)
mail = Mail()

20
api/libs/passport.py Normal file
View File

@@ -0,0 +1,20 @@
# -*- coding:utf-8 -*-
import jwt
from werkzeug.exceptions import Unauthorized
from flask import current_app
class PassportService:
def __init__(self):
self.sk = current_app.config.get('SECRET_KEY')
def issue(self, payload):
return jwt.encode(payload, self.sk, algorithm='HS256')
def verify(self, token):
try:
return jwt.decode(token, self.sk, algorithms=['HS256'])
except jwt.exceptions.InvalidSignatureError:
raise Unauthorized('Invalid token signature.')
except jwt.exceptions.DecodeError:
raise Unauthorized('Invalid token.')
except jwt.exceptions.ExpiredSignatureError:
raise Unauthorized('Token has expired.')

View File

@@ -0,0 +1,32 @@
"""app config add speech_to_text
Revision ID: a5b56fb053ef
Revises: d3d503a3471c
Create Date: 2023-07-06 17:55:20.894149
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a5b56fb053ef'
down_revision = 'd3d503a3471c'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('speech_to_text', sa.Text(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('speech_to_text')
# ### end Alembic commands ###

View File

@@ -0,0 +1,32 @@
"""add is_deleted to conversations
Revision ID: d3d503a3471c
Revises: e32f6ccb87c6
Create Date: 2023-06-27 19:13:30.897981
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd3d503a3471c'
down_revision = 'e32f6ccb87c6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('false'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.drop_column('is_deleted')
# ### end Alembic commands ###

View File

@@ -38,6 +38,10 @@ class Account(UserMixin, db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property
def is_password_set(self):
return self.password is not None
@property
def current_tenant(self):
return self._current_tenant

View File

@@ -56,7 +56,8 @@ class App(db.Model):
@property
def api_base_url(self):
return (current_app.config['API_URL'] if current_app.config['API_URL'] else request.host_url.rstrip('/')) + '/v1'
return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
else request.host_url.rstrip('/')) + '/v1'
@property
def tenant(self):
@@ -81,6 +82,7 @@ class AppModelConfig(db.Model):
opening_statement = db.Column(db.Text)
suggested_questions = db.Column(db.Text)
suggested_questions_after_answer = db.Column(db.Text)
speech_to_text = db.Column(db.Text)
more_like_this = db.Column(db.Text)
model = db.Column(db.Text)
user_input_form = db.Column(db.Text)
@@ -104,6 +106,11 @@ class AppModelConfig(db.Model):
def suggested_questions_after_answer_dict(self) -> dict:
return json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer \
else {"enabled": False}
@property
def speech_to_text_dict(self) -> dict:
return json.loads(self.speech_to_text) if self.speech_to_text \
else {"enabled": False}
@property
def more_like_this_dict(self) -> dict:
@@ -206,6 +213,8 @@ class Conversation(db.Model):
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@property
def model_config(self):
model_config = {}
@@ -221,6 +230,9 @@ class Conversation(db.Model):
model_config['suggested_questions_after_answer'] = override_model_configs[
'suggested_questions_after_answer'] \
if 'suggested_questions_after_answer' in override_model_configs else {"enabled": False}
model_config['speech_to_text'] = override_model_configs[
'speech_to_text'] \
if 'speech_to_text' in override_model_configs else {"enabled": False}
model_config['more_like_this'] = override_model_configs['more_like_this'] \
if 'more_like_this' in override_model_configs else {"enabled": False}
model_config['user_input_form'] = override_model_configs['user_input_form']
@@ -237,6 +249,7 @@ class Conversation(db.Model):
model_config['opening_statement'] = app_model_config.opening_statement
model_config['suggested_questions'] = app_model_config.suggested_questions_list
model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
model_config['speech_to_text'] = app_model_config.speech_to_text_dict
model_config['more_like_this'] = app_model_config.more_like_this_dict
model_config['user_input_form'] = app_model_config.user_input_form_list
@@ -503,7 +516,7 @@ class Site(db.Model):
@property
def app_base_url(self):
return (current_app.config['APP_URL'] if current_app.config['APP_URL'] else request.host_url.rstrip('/'))
return (current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/'))
class ApiToken(db.Model):

View File

@@ -10,7 +10,7 @@ flask-session2==1.3.1
flask-cors==3.0.10
gunicorn~=20.1.0
gevent~=22.10.2
langchain==0.0.209
langchain==0.0.230
openai~=0.27.5
psycopg2-binary~=2.9.6
pycryptodome==3.17
@@ -21,7 +21,7 @@ Authlib==1.2.0
boto3~=1.26.123
tenacity==8.2.2
cachetools~=5.3.0
weaviate-client~=3.16.2
weaviate-client~=3.21.0
qdrant_client~=1.1.6
mailchimp-transactional~=1.0.50
scikit-learn==1.2.2
@@ -32,4 +32,7 @@ redis~=4.5.4
openpyxl==3.1.2
chardet~=5.1.0
docx2txt==0.8
pypdfium2==4.16.0
pypdfium2==4.16.0
resend~=0.5.1
pyjwt~=2.6.0
anthropic~=0.3.4

View File

@@ -2,13 +2,16 @@
import base64
import logging
import secrets
import uuid
from datetime import datetime
from hashlib import sha256
from typing import Optional
from flask import session
from sqlalchemy import func
from events.tenant_event import tenant_was_created
from extensions.ext_redis import redis_client
from services.errors.account import AccountLoginError, CurrentPasswordIncorrectError, LinkAccountIntegrateError, \
TenantNotFound, AccountNotLinkTenantError, InvalidActionError, CannotOperateSelfError, MemberNotInTenantError, \
RoleAlreadyAssignedError, NoPermissionError, AccountRegisterError, AccountAlreadyInTenantError
@@ -16,6 +19,7 @@ from libs.helper import get_remote_ip
from libs.password import compare_password, hash_password
from libs.rsa import generate_key_pair
from models.account import *
from tasks.mail_invite_member_task import send_invite_member_mail_task
class AccountService:
@@ -48,12 +52,18 @@ class AccountService:
@staticmethod
def update_account_password(account, password, new_password):
"""update account password"""
# todo: split validation and update
if account.password and not compare_password(password, account.password, account.password_salt):
raise CurrentPasswordIncorrectError("Current password is incorrect.")
password_hashed = hash_password(new_password, account.password_salt)
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
return account
@@ -283,8 +293,6 @@ class TenantService:
@staticmethod
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
"""Remove member from tenant"""
# todo: check permission
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'):
raise CannotOperateSelfError("Cannot operate self.")
@@ -293,6 +301,12 @@ class TenantService:
raise MemberNotInTenantError("Member not in tenant.")
db.session.delete(ta)
account.initialized_at = None
account.status = AccountStatus.PENDING.value
account.password = None
account.password_salt = None
db.session.commit()
@staticmethod
@@ -332,8 +346,8 @@ class TenantService:
class RegisterService:
@staticmethod
def register(email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
@classmethod
def register(cls, email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
db.session.begin_nested()
"""Register account"""
try:
@@ -359,9 +373,9 @@ class RegisterService:
return account
@staticmethod
def invite_new_member(tenant: Tenant, email: str, role: str = 'normal',
inviter: Account = None) -> TenantAccountJoin:
@classmethod
def invite_new_member(cls, tenant: Tenant, email: str, role: str = 'normal',
inviter: Account = None) -> str:
"""Invite new member"""
account = Account.query.filter_by(email=email).first()
@@ -380,5 +394,71 @@ class RegisterService:
if ta:
raise AccountAlreadyInTenantError("Account already in tenant.")
ta = TenantService.create_tenant_member(tenant, account, role)
return ta
TenantService.create_tenant_member(tenant, account, role)
token = cls.generate_invite_token(tenant, account)
# send email
send_invite_member_mail_task.delay(
to=email,
token=cls.generate_invite_token(tenant, account),
inviter_name=inviter.name if inviter else 'Dify',
workspace_id=tenant.id,
workspace_name=tenant.name,
)
return token
@classmethod
def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
token = str(uuid.uuid4())
email_hash = sha256(account.email.encode()).hexdigest()
cache_key = 'member_invite_token:{}, {}:{}'.format(str(tenant.id), email_hash, token)
redis_client.setex(cache_key, 3600, str(account.id))
return token
@classmethod
def revoke_token(cls, workspace_id: str, email: str, token: str):
email_hash = sha256(email.encode()).hexdigest()
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
redis_client.delete(cache_key)
@classmethod
def get_account_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[Account]:
tenant = db.session.query(Tenant).filter(
Tenant.id == workspace_id,
Tenant.status == 'normal'
).first()
if not tenant:
return None
tenant_account = db.session.query(Account, TenantAccountJoin.role).join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
).filter(Account.email == email, TenantAccountJoin.tenant_id == tenant.id).first()
if not tenant_account:
return None
account_id = cls._get_account_id_by_invite_token(workspace_id, email, token)
if not account_id:
return None
account = tenant_account[0]
if not account:
return None
if account_id != str(account.id):
return None
return account
@classmethod
def _get_account_id_by_invite_token(cls, workspace_id: str, email: str, token: str) -> Optional[str]:
email_hash = sha256(email.encode()).hexdigest()
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
account_id = redis_client.get(cache_key)
if not account_id:
return None
return account_id.decode('utf-8')

View File

@@ -4,7 +4,32 @@ import uuid
from core.constant import llm_constant
from models.account import Account
from services.dataset_service import DatasetService
from core.llm.llm_builder import LLMBuilder
MODEL_PROVIDERS = [
'openai',
'anthropic',
]
MODELS_BY_APP_MODE = {
'chat': [
'claude-instant-1',
'claude-2',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
],
'completion': [
'claude-instant-1',
'claude-2',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
'text-davinci-003',
]
}
class AppModelConfigService:
@staticmethod
@@ -109,6 +134,26 @@ class AppModelConfigService:
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")
# speech_to_text
if 'speech_to_text' not in config or not config["speech_to_text"]:
config["speech_to_text"] = {
"enabled": False
}
if not isinstance(config["speech_to_text"], dict):
raise ValueError("speech_to_text must be of dict type")
if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]:
config["speech_to_text"]["enabled"] = False
if not isinstance(config["speech_to_text"]["enabled"], bool):
raise ValueError("enabled in speech_to_text must be of boolean type")
provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1')
if config["speech_to_text"]["enabled"] and provider_name != 'openai':
raise ValueError("provider not support speech to text")
# more_like_this
if 'more_like_this' not in config or not config["more_like_this"]:
config["more_like_this"] = {
@@ -132,14 +177,14 @@ class AppModelConfigService:
raise ValueError("model must be of object type")
# model.provider
if 'provider' not in config["model"] or config["model"]["provider"] != "openai":
raise ValueError("model.provider must be 'openai'")
if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")
# model.name
if 'name' not in config["model"]:
raise ValueError("model.name is required")
if config["model"]["name"] not in llm_constant.models_by_mode[mode]:
if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
raise ValueError("model.name must be in the specified model list")
# model.completion_params
@@ -277,6 +322,7 @@ class AppModelConfigService:
"opening_statement": config["opening_statement"],
"suggested_questions": config["suggested_questions"],
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
"speech_to_text": config["speech_to_text"],
"more_like_this": config["more_like_this"],
"model": {
"provider": config["model"]["provider"],

View File

@@ -0,0 +1,39 @@
import io
from werkzeug.datastructures import FileStorage
from core.llm.llm_builder import LLMBuilder
from core.llm.provider.llm_provider_service import LLMProviderService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
from core.llm.whisper import Whisper
from models.provider import ProviderName
FILE_SIZE = 15
FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm']
class AudioService:
@classmethod
def transcript(cls, tenant_id: str, file: FileStorage):
if file is None:
raise NoAudioUploadedServiceError()
extension = file.mimetype
if extension not in [f'audio/{ext}' for ext in ALLOWED_EXTENSIONS]:
raise UnsupportedAudioTypeServiceError()
file_content = file.read()
file_size = len(file_content)
if file_size > FILE_SIZE_LIMIT:
message = f"Audio size larger than {FILE_SIZE} mb"
raise AudioTooLargeServiceError(message)
provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1')
if provider_name != ProviderName.OPENAI.value:
raise ProviderNotSupportSpeechToTextServiceError()
provider_service = LLMProviderService(tenant_id, provider_name)
buffer = io.BytesIO(file_content)
buffer.name = 'temp.mp3'
return Whisper(provider_service.provider).transcribe(buffer)

View File

@@ -16,6 +16,7 @@ class ConversationService:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
base_query = db.session.query(Conversation).filter(
Conversation.is_deleted == False,
Conversation.app_id == app_model.id,
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
@@ -79,6 +80,7 @@ class ConversationService:
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
Conversation.is_deleted == False
).first()
if not conversation:
@@ -90,5 +92,5 @@ class ConversationService:
def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account | EndUser]]):
conversation = cls.get_conversation(app_model, conversation_id, user)
db.session.delete(conversation)
conversation.is_deleted = True
db.session.commit()

View File

@@ -4,6 +4,9 @@ import datetime
import time
import random
from typing import Optional, List
from flask import current_app
from extensions.ext_redis import redis_client
from flask_login import current_user
@@ -35,6 +38,7 @@ class DatasetService:
permission_filter = Dataset.permission == 'all_team_members'
datasets = Dataset.query.filter(
db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
.order_by(Dataset.created_at.desc()) \
.paginate(
page=page,
per_page=per_page,
@@ -373,6 +377,12 @@ class DocumentService:
def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'):
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if documents_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
# if dataset is empty, update dataset data_source_type
if not dataset.data_source_type:
dataset.data_source_type = document_data["data_source"]["type"]
@@ -520,6 +530,14 @@ class DocumentService:
)
return document
@staticmethod
def get_tenant_documents_count():
documents_count = Document.query.filter(Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id).count()
return documents_count
@staticmethod
def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
@@ -615,6 +633,12 @@ class DocumentService:
@staticmethod
def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if documents_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
# save dataset
dataset = Dataset(
tenant_id=tenant_id,

View File

@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
__all__ = [
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
'app', 'completion'
'app', 'completion', 'audio'
]
from . import *

View File

@@ -0,0 +1,13 @@
class NoAudioUploadedServiceError(Exception):
pass
class AudioTooLargeServiceError(Exception):
pass
class UnsupportedAudioTypeServiceError(Exception):
pass
class ProviderNotSupportSpeechToTextServiceError(Exception):
pass

View File

@@ -31,7 +31,7 @@ class HitTestingService:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)

View File

@@ -10,50 +10,40 @@ from models.provider import *
class ProviderService:
@staticmethod
def init_supported_provider(tenant, edition):
def init_supported_provider(tenant):
"""Initialize the model provider, check whether the supported provider has a record"""
providers = Provider.query.filter_by(tenant_id=tenant.id).all()
need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value]
openai_provider_exists = False
azure_openai_provider_exists = False
# TODO: The cloud version needs to construct the data of the SYSTEM type
providers = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(need_init_provider_names)
).all()
exists_provider_names = []
for provider in providers:
if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
openai_provider_exists = True
if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
azure_openai_provider_exists = True
exists_provider_names.append(provider.provider_name)
# Initialize the model provider, check whether the supported provider has a record
not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names))
# Create default providers if they don't exist
if not openai_provider_exists:
openai_provider = Provider(
tenant_id=tenant.id,
provider_name=ProviderName.OPENAI.value,
provider_type=ProviderType.CUSTOM.value,
is_valid=False
)
db.session.add(openai_provider)
if not_exists_provider_names:
# Initialize the model provider, check whether the supported provider has a record
for provider_name in not_exists_provider_names:
provider = Provider(
tenant_id=tenant.id,
provider_name=provider_name,
provider_type=ProviderType.CUSTOM.value,
is_valid=False
)
db.session.add(provider)
if not azure_openai_provider_exists:
azure_openai_provider = Provider(
tenant_id=tenant.id,
provider_name=ProviderName.AZURE_OPENAI.value,
provider_type=ProviderType.CUSTOM.value,
is_valid=False
)
db.session.add(azure_openai_provider)
if not openai_provider_exists or not azure_openai_provider_exists:
db.session.commit()
@staticmethod
def get_obfuscated_api_key(tenant, provider_name: ProviderName):
def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
return llm_provider_service.get_provider_configs(obfuscated=True)
return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)
@staticmethod
def get_token_type(tenant, provider_name: ProviderName):
@@ -73,7 +63,7 @@ class ProviderService:
return llm_provider_service.get_encrypted_token(configs)
@staticmethod
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value,
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
is_valid: bool = True):
if current_app.config['EDITION'] != 'CLOUD':
return
@@ -90,7 +80,7 @@ class ProviderService:
provider_name=provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=200,
quota_limit=quota_limit,
encrypted_config='',
is_valid=is_valid,
)

View File

@@ -1,6 +1,6 @@
from extensions.ext_database import db
from models.account import Tenant
from models.provider import Provider, ProviderType
from models.provider import Provider, ProviderType, ProviderName
class WorkspaceService:
@@ -33,7 +33,7 @@ class WorkspaceService:
if provider.is_valid and provider.encrypted_config:
custom_provider = provider
elif provider.provider_type == ProviderType.SYSTEM.value:
if provider.is_valid:
if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid:
system_provider = provider
if system_provider and not custom_provider:

View File

@@ -7,7 +7,7 @@ from celery import shared_task
from core.index.index import IndexBuilder
from extensions.ext_database import db
from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
AppDatasetJoin
AppDatasetJoin, Document
@shared_task
@@ -32,7 +32,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
index_struct=index_struct
)
documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
vector_index = IndexBuilder.get_index(dataset, 'high_quality')

View File

@@ -44,14 +44,13 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
documents = []
for dataset_document in dataset_documents:
# delete from vector index
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True
) .order_by(DocumentSegment.position.asc()).all()
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@@ -65,8 +64,8 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index.add_texts(documents)
# save vector index
index.add_texts(documents)
end_at = time.perf_counter()
logging.info(

View File

@@ -28,7 +28,7 @@ def generate_conversation_summary_task(conversation_id: str):
try:
# get conversation messages count
history_message_count = conversation.message_count
if history_message_count >= 5:
if history_message_count >= 5 and not conversation.summary:
app_model = conversation.app
if not app_model:
return

View File

@@ -0,0 +1,52 @@
import logging
import time
import click
from celery import shared_task
from flask import current_app
from extensions.ext_mail import mail
@shared_task
def send_invite_member_mail_task(to: str, token: str, inviter_name: str, workspace_id: str, workspace_name: str):
"""
Async Send invite member mail
:param to
:param token
:param inviter_name
:param workspace_id
:param workspace_name
Usage: send_invite_member_mail_task.delay(to, token, inviter_name, workspace_id, workspace_name)
"""
if not mail.is_inited():
return
logging.info(click.style('Start send invite member mail to {} in workspace {}'.format(to, workspace_name),
fg='green'))
start_at = time.perf_counter()
try:
mail.send(
to=to,
subject="{} invited you to join {}".format(inviter_name, workspace_name),
html="""<p>Hi there,</p>
<p>{inviter_name} invited you to join {workspace_name}.</p>
<p>Click <a href="{url}">here</a> to join.</p>
<p>Thanks,</p>
<p>Dify Team</p>""".format(inviter_name=inviter_name, workspace_name=workspace_name,
url='{}/activate?workspace_id={}&email={}&token={}'.format(
current_app.config.get("CONSOLE_WEB_URL"),
workspace_id,
to,
token)
)
)
end_at = time.perf_counter()
logging.info(
click.style('Send invite member mail to {} succeeded: latency: {}'.format(to, end_at - start_at),
fg='green'))
except Exception:
logging.exception("Send invite member mail to {} failed".format(to))

View File

@@ -41,7 +41,8 @@ def remove_document_from_index_task(document_id: str):
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
vector_index.delete_by_document_id(document.id)
if vector_index:
vector_index.delete_by_document_id(document.id)
# delete from keyword index
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()

Some files were not shown because too many files have changed in this diff Show More