mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-07 11:55:44 +08:00
Compare commits
47 Commits
feat/optim
...
0.3.15
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
916d8be0ae | ||
|
|
a38412de7b | ||
|
|
9c9f0ddb93 | ||
|
|
f8fbe96da4 | ||
|
|
215a27fd95 | ||
|
|
5cba2e7087 | ||
|
|
5623839c71 | ||
|
|
78d3aa5fcd | ||
|
|
a7c78d2cd2 | ||
|
|
4db35fa375 | ||
|
|
e67a1413b6 | ||
|
|
4f3053a8cc | ||
|
|
b3c2bf125f | ||
|
|
9d5299e9ec | ||
|
|
aee15adf1b | ||
|
|
b185a70c21 | ||
|
|
a3aba7a9aa | ||
|
|
866ee5da91 | ||
|
|
e8039a7da8 | ||
|
|
5e0540077a | ||
|
|
b346bd9b83 | ||
|
|
062e2e915b | ||
|
|
e0a48c4972 | ||
|
|
f53242c081 | ||
|
|
4b53bb1a32 | ||
|
|
4c49ecedb5 | ||
|
|
4ff1870a4b | ||
|
|
6c832ee328 | ||
|
|
25264e7852 | ||
|
|
18dd0d569d | ||
|
|
3ea8d7a019 | ||
|
|
da3f10a55e | ||
|
|
8c991b5b26 | ||
|
|
22c1aafb9b | ||
|
|
8d6d1c442b | ||
|
|
95b179fb39 | ||
|
|
3a0a9e2d8f | ||
|
|
0a0d63457d | ||
|
|
920fb6d0e1 | ||
|
|
fd0fc8f4fe | ||
|
|
1c552ff23a | ||
|
|
5163dd38e5 | ||
|
|
2a27dad2fb | ||
|
|
930f74c610 | ||
|
|
3f250c9e12 | ||
|
|
fa408d264c | ||
|
|
09ea27f1ee |
@@ -30,10 +30,10 @@ Visual data analysis, log review, and annotation for applications
|
||||
- [x] **Spark**
|
||||
- [x] **Wenxin**
|
||||
- [x] **Tongyi**
|
||||
- [x] **ChatGLM**
|
||||
|
||||
|
||||
We provide the following free resources for registered Dify cloud users (sign up at [dify.ai](https://dify.ai)):
|
||||
* 1000 free Claude model queries to build Claude-powered apps
|
||||
* 600,000 free Claude model tokens to build Claude-powered apps
|
||||
* 200 free OpenAI queries to build OpenAI-based apps
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
|
||||
|
||||
我们为所有注册云端版的用户免费提供以下资源(登录 [dify.ai](https://cloud.dify.ai) 即可使用):
|
||||
* 1000 次 Claude 模型的消息调用额度,用于创建基于 Claude 模型的 AI 应用
|
||||
* 60 万 Tokens Claude 模型的消息调用额度,用于创建基于 Claude 模型的 AI 应用
|
||||
* 200 次 OpenAI 模型的消息调用额度,用于创建基于 OpenAI 模型的 AI 应用
|
||||
* 300 万 讯飞星火大模型 Token 的调用额度,用于创建基于讯飞星火大模型的 AI 应用
|
||||
* 100 万 MiniMax Token 的调用额度,用于创建基于 MiniMax 模型的 AI 应用
|
||||
|
||||
@@ -16,7 +16,7 @@ EXPOSE 5001
|
||||
WORKDIR /app/api
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev
|
||||
apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev nodejs
|
||||
|
||||
COPY requirements.txt /app/api/requirements.txt
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from models.model import Account
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
from models.provider import Provider, ProviderType, ProviderQuotaType
|
||||
from models.provider import Provider, ProviderType, ProviderQuotaType, ProviderModel
|
||||
|
||||
|
||||
@click.command('reset-password', help='Reset the account password.')
|
||||
@@ -102,6 +102,7 @@ def reset_encrypt_key_pair():
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
|
||||
db.session.query(ProviderModel).delete()
|
||||
db.session.commit()
|
||||
|
||||
click.echo(click.style('Congratulations! '
|
||||
|
||||
@@ -48,9 +48,7 @@ DEFAULTS = {
|
||||
'WEAVIATE_GRPC_ENABLED': 'True',
|
||||
'WEAVIATE_BATCH_SIZE': 100,
|
||||
'CELERY_BACKEND': 'database',
|
||||
'PDF_PREVIEW': 'True',
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
|
||||
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_OPENAI_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_PAID_ENABLED': 'False',
|
||||
@@ -102,13 +100,12 @@ class Config:
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.14"
|
||||
self.CURRENT_VERSION = "0.3.15"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
self.TESTING = False
|
||||
self.LOG_LEVEL = get_env('LOG_LEVEL')
|
||||
self.PDF_PREVIEW = get_bool_env('PDF_PREVIEW')
|
||||
|
||||
# Your App secret key will be used for securely signing the session cookie
|
||||
# Make sure you are changing this key for your deployment with a strong key.
|
||||
@@ -236,10 +233,6 @@ class Config:
|
||||
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
|
||||
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
|
||||
|
||||
# By default it is False
|
||||
# You could disable it for compatibility with certain OpenAPI providers
|
||||
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
|
||||
|
||||
# notion import setting
|
||||
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
|
||||
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
import flask_restful
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@@ -3,7 +3,9 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
import flask
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@@ -316,7 +318,7 @@ class AppApi(Resource):
|
||||
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
app = _get_app(app_id, current_user.current_tenant_id)
|
||||
|
||||
db.session.delete(app)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from core.login.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Generator, Union
|
||||
|
||||
import flask_login
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import login_required
|
||||
from core.login.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy import or_, func
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from typing import Union, Generator
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user, login_required
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, marshal_with, fields
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
@@ -16,6 +16,7 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.login.login import login_required
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
|
||||
@@ -3,12 +3,13 @@ import json
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.login.login import login_required
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppModelConfig
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
|
||||
@@ -4,7 +4,8 @@ from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask import jsonify
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -5,9 +5,12 @@ from typing import Optional
|
||||
import flask_login
|
||||
import requests
|
||||
from flask import request, redirect, current_app, session
|
||||
from flask_login import current_user, login_required
|
||||
from flask_login import current_user
|
||||
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from core.login.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
from controllers.console import api
|
||||
from ..setup import setup_required
|
||||
|
||||
@@ -3,7 +3,8 @@ import json
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
import services
|
||||
|
||||
@@ -4,7 +4,8 @@ from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import desc, asc
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
@@ -764,11 +765,13 @@ class DocumentMetadataApi(DocumentResource):
|
||||
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
|
||||
|
||||
document.doc_metadata = {}
|
||||
|
||||
for key, value_type in metadata_schema.items():
|
||||
value = doc_metadata.get(key)
|
||||
if value is not None and isinstance(value, value_type):
|
||||
document.doc_metadata[key] = value
|
||||
if doc_type == 'others':
|
||||
document.doc_metadata = doc_metadata
|
||||
else:
|
||||
for key, value_type in metadata_schema.items():
|
||||
value = doc_metadata.get(key)
|
||||
if value is not None and isinstance(value, value_type):
|
||||
document.doc_metadata[key] = value
|
||||
|
||||
document.doc_type = doc_type
|
||||
document.updated_at = datetime.utcnow()
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
@@ -15,6 +14,7 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.login.login import login_required
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
@@ -8,7 +8,8 @@ from pathlib import Path
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, marshal_with, fields
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal, fields
|
||||
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, inputs
|
||||
from sqlalchemy import and_
|
||||
from werkzeug.exceptions import NotFound, Forbidden, BadRequest
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from sqlalchemy import and_
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource
|
||||
from functools import wraps
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
|
||||
@@ -3,7 +3,8 @@ from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask import current_app, request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
|
||||
|
||||
import services
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, abort, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -2,10 +2,13 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse, marshal
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs
|
||||
from flask_restful.inputs import int_range
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
@@ -43,6 +46,13 @@ tenants_fields = {
|
||||
'current': fields.Boolean
|
||||
}
|
||||
|
||||
workspace_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'status': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
class TenantListApi(Resource):
|
||||
@setup_required
|
||||
@@ -57,6 +67,38 @@ class TenantListApi(Resource):
|
||||
return {'workspaces': marshal(tenants, tenants_fields)}, 200
|
||||
|
||||
|
||||
class WorkspaceListApi(Resource):
|
||||
@setup_required
|
||||
@admin_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
|
||||
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\
|
||||
.paginate(page=args['page'], per_page=args['limit'])
|
||||
|
||||
has_more = False
|
||||
if len(tenants.items) == args['limit']:
|
||||
current_page_first_tenant = tenants[-1]
|
||||
rest_count = db.session.query(Tenant).filter(
|
||||
Tenant.created_at < current_page_first_tenant.created_at,
|
||||
Tenant.id != current_page_first_tenant.id
|
||||
).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
total = db.session.query(Tenant).count()
|
||||
return {
|
||||
'data': marshal(tenants.items, workspace_fields),
|
||||
'has_more': has_more,
|
||||
'limit': args['limit'],
|
||||
'page': args['page'],
|
||||
'total': total
|
||||
}, 200
|
||||
|
||||
|
||||
class TenantApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -92,6 +134,7 @@ class SwitchWorkspaceApi(Resource):
|
||||
|
||||
|
||||
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
|
||||
api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants
|
||||
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
|
||||
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
|
||||
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant
|
||||
|
||||
@@ -17,7 +17,7 @@ def validate_app_token(view=None):
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token('app')
|
||||
|
||||
app_model = db.session.query(App).get(api_token.app_id)
|
||||
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
|
||||
@@ -44,7 +44,7 @@ def validate_dataset_token(view=None):
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token('dataset')
|
||||
|
||||
dataset = db.session.query(Dataset).get(api_token.dataset_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound()
|
||||
|
||||
@@ -64,14 +64,14 @@ def validate_and_get_api_token(scope=None):
|
||||
Validate and get API token.
|
||||
"""
|
||||
auth_header = request.headers.get('Authorization')
|
||||
if auth_header is None:
|
||||
raise Unauthorized()
|
||||
if auth_header is None or ' ' not in auth_header:
|
||||
raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized()
|
||||
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
||||
|
||||
api_token = db.session.query(ApiToken).filter(
|
||||
ApiToken.token == auth_token,
|
||||
@@ -79,7 +79,7 @@ def validate_and_get_api_token(scope=None):
|
||||
).first()
|
||||
|
||||
if not api_token:
|
||||
raise Unauthorized()
|
||||
raise Unauthorized("Access token is invalid")
|
||||
|
||||
api_token.last_used_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
@@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
|
||||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
|
||||
@@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_loop.status = 'llm_end'
|
||||
if response.llm_output:
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
else:
|
||||
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.prompt)]
|
||||
)
|
||||
completion_generation = response.generations[0][0]
|
||||
if isinstance(completion_generation, ChatGeneration):
|
||||
completion_message = completion_generation.message
|
||||
@@ -81,6 +86,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
if response.llm_output:
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.completion)]
|
||||
)
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
|
||||
@@ -119,9 +119,11 @@ class ConversationMessageTask:
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer="",
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency=self.model_instance.get_currency(),
|
||||
@@ -140,17 +142,24 @@ class ConversationMessageTask:
|
||||
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
|
||||
message_tokens = llm_message.prompt_tokens
|
||||
answer_tokens = llm_message.completion_tokens
|
||||
message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
|
||||
answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
|
||||
|
||||
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
|
||||
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
|
||||
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
|
||||
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
|
||||
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
|
||||
total_price = message_total_price + answer_total_price
|
||||
|
||||
self.message.message = llm_message.prompt
|
||||
self.message.message_tokens = message_tokens
|
||||
self.message.message_unit_price = message_unit_price
|
||||
self.message.message_price_unit = message_price_unit
|
||||
self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
|
||||
self.message.answer_tokens = answer_tokens
|
||||
self.message.answer_unit_price = answer_unit_price
|
||||
self.message.answer_price_unit = answer_price_unit
|
||||
self.message.provider_response_latency = llm_message.latency
|
||||
self.message.total_price = total_price
|
||||
|
||||
@@ -192,7 +201,9 @@ class ConversationMessageTask:
|
||||
tool=agent_loop.tool_name,
|
||||
tool_input=agent_loop.tool_input,
|
||||
message=agent_loop.prompt,
|
||||
message_price_unit=0,
|
||||
answer=agent_loop.completion,
|
||||
answer_price_unit=0,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
@@ -206,25 +217,26 @@ class ConversationMessageTask:
|
||||
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
|
||||
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
|
||||
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
loop_message_tokens = agent_loop.prompt_tokens
|
||||
loop_answer_tokens = agent_loop.completion_tokens
|
||||
|
||||
loop_total_price = self.calc_total_price(
|
||||
loop_message_tokens,
|
||||
agent_message_unit_price,
|
||||
loop_answer_tokens,
|
||||
agent_answer_unit_price
|
||||
)
|
||||
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
|
||||
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
|
||||
loop_total_price = loop_message_total_price + loop_answer_total_price
|
||||
|
||||
message_agent_thought.observation = agent_loop.tool_output
|
||||
message_agent_thought.tool_process_data = '' # currently not support
|
||||
message_agent_thought.message_token = loop_message_tokens
|
||||
message_agent_thought.message_unit_price = agent_message_unit_price
|
||||
message_agent_thought.message_price_unit = agent_message_price_unit
|
||||
message_agent_thought.answer_token = loop_answer_tokens
|
||||
message_agent_thought.answer_unit_price = agent_answer_unit_price
|
||||
message_agent_thought.answer_price_unit = agent_answer_price_unit
|
||||
message_agent_thought.latency = agent_loop.latency
|
||||
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
|
||||
message_agent_thought.total_price = loop_total_price
|
||||
@@ -243,15 +255,6 @@ class ConversationMessageTask:
|
||||
|
||||
db.session.add(dataset_query)
|
||||
|
||||
def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
|
||||
message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def end(self):
|
||||
self._pub_handler.pub_end()
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ class DatesetDocumentStore:
|
||||
content=doc.page_content,
|
||||
word_count=len(doc.page_content),
|
||||
tokens=tokens,
|
||||
enabled=False,
|
||||
created_by=self._user_id,
|
||||
)
|
||||
if 'answer' in doc.metadata and doc.metadata['answer']:
|
||||
|
||||
@@ -278,7 +278,7 @@ class IndexingRunner:
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
@@ -286,7 +286,7 @@ class IndexingRunner:
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"preview": preview_texts
|
||||
}
|
||||
@@ -371,7 +371,7 @@ class IndexingRunner:
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
@@ -379,7 +379,7 @@ class IndexingRunner:
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"preview": preview_texts
|
||||
}
|
||||
@@ -691,6 +691,7 @@ class IndexingRunner:
|
||||
DocumentSegment.status == "indexing"
|
||||
).update({
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: datetime.datetime.utcnow()
|
||||
})
|
||||
|
||||
|
||||
108
api/core/login/login.py
Normal file
108
api/core/login/login.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
|
||||
import flask_login
|
||||
from flask import current_app
|
||||
from flask import g
|
||||
from flask import has_request_context
|
||||
from flask import request
|
||||
from flask_login import user_logged_in
|
||||
from flask_login.config import EXEMPT_METHODS
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
|
||||
#: A proxy for the current user. If no user is logged in, this will be an
|
||||
#: anonymous user
|
||||
current_user = LocalProxy(lambda: _get_user())
|
||||
|
||||
|
||||
def login_required(func):
|
||||
"""
|
||||
If you decorate a view with this, it will ensure that the current user is
|
||||
logged in and authenticated before calling the actual view. (If they are
|
||||
not, it calls the :attr:`LoginManager.unauthorized` callback.) For
|
||||
example::
|
||||
|
||||
@app.route('/post')
|
||||
@login_required
|
||||
def post():
|
||||
pass
|
||||
|
||||
If there are only certain times you need to require that your user is
|
||||
logged in, you can do so with::
|
||||
|
||||
if not current_user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized()
|
||||
|
||||
...which is essentially the code that this function adds to your views.
|
||||
|
||||
It can be convenient to globally turn off authentication when unit testing.
|
||||
To enable this, if the application configuration variable `LOGIN_DISABLED`
|
||||
is set to `True`, this decorator will be ignored.
|
||||
|
||||
.. Note ::
|
||||
|
||||
Per `W3 guidelines for CORS preflight requests
|
||||
<http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
|
||||
HTTP ``OPTIONS`` requests are exempt from login checks.
|
||||
|
||||
:param func: The view function to decorate.
|
||||
:type func: function
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
auth_header = request.headers.get('Authorization')
|
||||
admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False')
|
||||
if admin_api_key_enable:
|
||||
if auth_header:
|
||||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
admin_api_key = os.getenv('ADMIN_API_KEY')
|
||||
|
||||
if admin_api_key:
|
||||
if os.getenv('ADMIN_API_KEY') == auth_token:
|
||||
workspace_id = request.headers.get('X-WORKSPACE-ID')
|
||||
if workspace_id:
|
||||
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
|
||||
.filter(Tenant.id == workspace_id) \
|
||||
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
||||
.filter(TenantAccountJoin.role == 'owner') \
|
||||
.one_or_none()
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = Account.query.filter_by(id=ta.account_id).first()
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account)
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user())
|
||||
if request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
|
||||
pass
|
||||
elif not current_user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized()
|
||||
|
||||
# flask 1.x compatibility
|
||||
# current_app.ensure_sync is only available in Flask >= 2.0
|
||||
if callable(getattr(current_app, "ensure_sync", None)):
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
def _get_user():
|
||||
if has_request_context():
|
||||
if "_login_user" not in g:
|
||||
current_app.login_manager._load_user()
|
||||
|
||||
return g._login_user
|
||||
|
||||
return None
|
||||
@@ -57,6 +57,12 @@ class ModelProviderFactory:
|
||||
elif provider_name == 'huggingface_hub':
|
||||
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
|
||||
return HuggingfaceHubProvider
|
||||
elif provider_name == 'xinference':
|
||||
from core.model_providers.providers.xinference_provider import XinferenceProvider
|
||||
return XinferenceProvider
|
||||
elif provider_name == 'openllm':
|
||||
from core.model_providers.providers.openllm_provider import OpenLLMProvider
|
||||
return OpenLLMProvider
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -26,10 +26,20 @@ class AzureOpenAIEmbedding(BaseEmbedding):
|
||||
openai_api_version=AZURE_OPENAI_API_VERSION,
|
||||
chunk_size=16,
|
||||
max_retries=1,
|
||||
**self.credentials
|
||||
openai_api_key=self.credentials.get('openai_api_key'),
|
||||
openai_api_base=self.credentials.get('openai_api_base')
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name (not deployment)
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.credentials.get("base_model_name")
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""
|
||||
@@ -48,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * decimal.Decimal('0.0001')
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to Azure OpenAI API.")
|
||||
@@ -71,7 +71,7 @@ class AzureOpenAIEmbedding(BaseEmbedding):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError('Azure ' + str(ex))
|
||||
return LLMAuthorizationError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
import decimal
|
||||
|
||||
import tiktoken
|
||||
from langchain.schema.language_model import _get_token_ids_default_method
|
||||
@@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseEmbedding(BaseProviderModel):
|
||||
name: str
|
||||
@@ -17,6 +19,63 @@ class BaseEmbedding(BaseProviderModel):
|
||||
super().__init__(model_provider, client)
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def price_config(self) -> dict:
|
||||
def get_or_default():
|
||||
default_price_config = {
|
||||
'completion': decimal.Decimal('0'),
|
||||
'unit': decimal.Decimal('0'),
|
||||
'currency': 'USD'
|
||||
}
|
||||
rules = self.model_provider.get_rules()
|
||||
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
|
||||
price_config = {
|
||||
'completion': decimal.Decimal(price_config['completion']),
|
||||
'unit': decimal.Decimal(price_config['unit']),
|
||||
'currency': price_config['currency']
|
||||
}
|
||||
return price_config
|
||||
|
||||
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
|
||||
|
||||
logger.debug(f"model: {self.name} price_config: {self._price_config}")
|
||||
return self._price_config
|
||||
|
||||
def calc_tokens_price(self, tokens: int) -> decimal.Decimal:
|
||||
"""
|
||||
calc tokens total price.
|
||||
|
||||
:param tokens:
|
||||
:return: decimal.Decimal('0.0000001')
|
||||
"""
|
||||
unit_price = self.price_config['completion']
|
||||
unit = self.price_config['unit']
|
||||
total_price = tokens * unit_price * unit
|
||||
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
|
||||
return total_price
|
||||
|
||||
def get_tokens_unit_price(self) -> decimal.Decimal:
|
||||
"""
|
||||
get token price.
|
||||
|
||||
:return: decimal.Decimal('0.0001')
|
||||
|
||||
"""
|
||||
unit_price = self.price_config['completion']
|
||||
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logger.debug(f'unit_price:{unit_price}')
|
||||
return unit_price
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""
|
||||
get num tokens of text.
|
||||
@@ -29,11 +88,14 @@ class BaseEmbedding(BaseProviderModel):
|
||||
|
||||
return len(_get_token_ids_default_method(text))
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
return 0
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
"""
|
||||
get token currency.
|
||||
|
||||
:return: get from price config, default 'USD'
|
||||
"""
|
||||
currency = self.price_config['currency']
|
||||
return currency
|
||||
|
||||
@abstractmethod
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
import decimal
|
||||
import logging
|
||||
|
||||
from langchain.embeddings import MiniMaxEmbeddings
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
@@ -22,12 +19,6 @@ class MinimaxEmbedding(BaseEmbedding):
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, ValueError):
|
||||
return LLMBadRequestError(f"Minimax: {str(ex)}")
|
||||
|
||||
@@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * decimal.Decimal('0.0001')
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to OpenAI API.")
|
||||
@@ -65,7 +55,7 @@ class OpenAIEmbedding(BaseEmbedding):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError(str(ex))
|
||||
return LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
@@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
# replicate only pay for prediction seconds
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, (ModelError, ReplicateError)):
|
||||
return LLMBadRequestError(f"Replicate: {str(ex)}")
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
|
||||
from replicate.exceptions import ModelError, ReplicateError
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
|
||||
|
||||
class XinferenceEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = XinferenceEmbeddings(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid'],
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, (ModelError, ReplicateError)):
|
||||
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
@@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'claude-instant-1': {
|
||||
'prompt': decimal.Decimal('1.63'),
|
||||
'completion': decimal.Decimal('5.51'),
|
||||
},
|
||||
'claude-2': {
|
||||
'prompt': decimal.Decimal('11.02'),
|
||||
'completion': decimal.Decimal('32.68'),
|
||||
},
|
||||
}
|
||||
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[self.name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[self.name]['completion']
|
||||
|
||||
tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1m * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
|
||||
@@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM):
|
||||
self.model_mode = ModelMode.COMPLETION
|
||||
else:
|
||||
self.model_mode = ModelMode.CHAT
|
||||
|
||||
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
@@ -83,6 +82,15 @@ class AzureOpenAIModel(BaseLLM):
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name (not deployment)
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.credentials.get("base_model_name")
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
@@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM):
|
||||
else:
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'gpt-4': {
|
||||
'prompt': decimal.Decimal('0.03'),
|
||||
'completion': decimal.Decimal('0.06'),
|
||||
},
|
||||
'gpt-4-32k': {
|
||||
'prompt': decimal.Decimal('0.06'),
|
||||
'completion': decimal.Decimal('0.12')
|
||||
},
|
||||
'gpt-35-turbo': {
|
||||
'prompt': decimal.Decimal('0.0015'),
|
||||
'completion': decimal.Decimal('0.002')
|
||||
},
|
||||
'gpt-35-turbo-16k': {
|
||||
'prompt': decimal.Decimal('0.003'),
|
||||
'completion': decimal.Decimal('0.004')
|
||||
},
|
||||
'text-davinci-003': {
|
||||
'prompt': decimal.Decimal('0.02'),
|
||||
'completion': decimal.Decimal('0.02')
|
||||
},
|
||||
}
|
||||
|
||||
base_model_name = self.credentials.get("base_model_name")
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[base_model_name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[base_model_name]['completion']
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
if self.name == 'text-davinci-003':
|
||||
@@ -166,7 +135,7 @@ class AzureOpenAIModel(BaseLLM):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError('Azure ' + str(ex))
|
||||
return LLMAuthorizationError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Any, Union
|
||||
import decimal
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
|
||||
@@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseLLM(BaseProviderModel):
|
||||
@@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
|
||||
def _init_client(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get llm base model name
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def price_config(self) -> dict:
|
||||
def get_or_default():
|
||||
default_price_config = {
|
||||
'prompt': decimal.Decimal('0'),
|
||||
'completion': decimal.Decimal('0'),
|
||||
'unit': decimal.Decimal('0'),
|
||||
'currency': 'USD'
|
||||
}
|
||||
rules = self.model_provider.get_rules()
|
||||
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
|
||||
price_config = {
|
||||
'prompt': decimal.Decimal(price_config['prompt']),
|
||||
'completion': decimal.Decimal(price_config['completion']),
|
||||
'unit': decimal.Decimal(price_config['unit']),
|
||||
'currency': price_config['currency']
|
||||
}
|
||||
return price_config
|
||||
|
||||
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
|
||||
|
||||
logger.debug(f"model: {self.name} price_config: {self._price_config}")
|
||||
return self._price_config
|
||||
|
||||
def run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
@@ -161,25 +197,64 @@ class BaseLLM(BaseProviderModel):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
get token price.
|
||||
calc tokens total price.
|
||||
|
||||
:param tokens:
|
||||
:param message_type:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
unit = self.get_price_unit(message_type)
|
||||
|
||||
@abstractmethod
|
||||
def get_currency(self):
|
||||
total_price = tokens * unit_price * unit
|
||||
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
|
||||
return total_price
|
||||
|
||||
def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
get token price.
|
||||
|
||||
:param message_type:
|
||||
:return: decimal.Decimal('0.0001')
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"unit_price={unit_price}")
|
||||
return unit_price
|
||||
|
||||
def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
get price unit.
|
||||
|
||||
:param message_type:
|
||||
:return: decimal.Decimal('0.000001')
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
price_unit = self.price_config['unit']
|
||||
else:
|
||||
price_unit = self.price_config['unit']
|
||||
|
||||
price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"price_unit={price_unit}")
|
||||
return price_unit
|
||||
|
||||
def get_currency(self) -> str:
|
||||
"""
|
||||
get token currency.
|
||||
|
||||
:return:
|
||||
:return: get from price config, default 'USD'
|
||||
"""
|
||||
raise NotImplementedError
|
||||
currency = self.price_config['currency']
|
||||
return currency
|
||||
|
||||
def get_model_kwargs(self):
|
||||
return self.model_kwargs
|
||||
|
||||
@@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
||||
@@ -62,13 +62,6 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.get_num_tokens(prompts)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
# not support calc price
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
self.client.model_kwargs = provider_model_kwargs
|
||||
|
||||
@@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
||||
@@ -46,7 +46,8 @@ class OpenAIModel(BaseLLM):
|
||||
self.model_mode = ModelMode.COMPLETION
|
||||
else:
|
||||
self.model_mode = ModelMode.CHAT
|
||||
|
||||
|
||||
# TODO load price config from configs(db)
|
||||
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
@@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM):
|
||||
else:
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'gpt-4': {
|
||||
'prompt': decimal.Decimal('0.03'),
|
||||
'completion': decimal.Decimal('0.06'),
|
||||
},
|
||||
'gpt-4-32k': {
|
||||
'prompt': decimal.Decimal('0.06'),
|
||||
'completion': decimal.Decimal('0.12')
|
||||
},
|
||||
'gpt-3.5-turbo': {
|
||||
'prompt': decimal.Decimal('0.0015'),
|
||||
'completion': decimal.Decimal('0.002')
|
||||
},
|
||||
'gpt-3.5-turbo-16k': {
|
||||
'prompt': decimal.Decimal('0.003'),
|
||||
'completion': decimal.Decimal('0.004')
|
||||
},
|
||||
'text-davinci-003': {
|
||||
'prompt': decimal.Decimal('0.02'),
|
||||
'completion': decimal.Decimal('0.02')
|
||||
},
|
||||
}
|
||||
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[self.name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[self.name]['completion']
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
if self.name in COMPLETION_MODELS:
|
||||
@@ -185,7 +148,7 @@ class OpenAIModel(BaseLLM):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError(str(ex))
|
||||
return LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
60
api/core/model_providers/models/llm/openllm_model.py
Normal file
60
api/core/model_providers/models/llm/openllm_model.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.openllm import OpenLLM
|
||||
|
||||
|
||||
class OpenLLMModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
|
||||
client = OpenLLM(
|
||||
server_url=self.credentials.get('server_url'),
|
||||
callbacks=self.callbacks,
|
||||
llm_kwargs=self.provider_model_kwargs
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
pass
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"OpenLLM: {str(ex)}")
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return False
|
||||
@@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM):
|
||||
|
||||
return self._client.get_num_tokens(prompts)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
# replicate only pay for prediction seconds
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
self.client.input = provider_model_kwargs
|
||||
|
||||
@@ -50,9 +50,6 @@ class SparkModel(BaseLLM):
|
||||
contents = [message.content for message in messages]
|
||||
return max(self._client.get_num_tokens("".join(contents)), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
||||
@@ -53,9 +53,6 @@ class TongyiModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ class WenxinModel(BaseLLM):
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
# TODO load price_config from configs(db)
|
||||
return Wenxin(
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
@@ -48,36 +49,6 @@ class WenxinModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'ernie-bot': {
|
||||
'prompt': decimal.Decimal('0.012'),
|
||||
'completion': decimal.Decimal('0.012'),
|
||||
},
|
||||
'ernie-bot-turbo': {
|
||||
'prompt': decimal.Decimal('0.008'),
|
||||
'completion': decimal.Decimal('0.008')
|
||||
},
|
||||
'bloomz-7b': {
|
||||
'prompt': decimal.Decimal('0.006'),
|
||||
'completion': decimal.Decimal('0.006')
|
||||
}
|
||||
}
|
||||
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[self.name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[self.name]['completion']
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
|
||||
70
api/core/model_providers/models/llm/xinference_model.py
Normal file
70
api/core/model_providers/models/llm/xinference_model.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
|
||||
|
||||
|
||||
class XinferenceModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
|
||||
client = XinferenceLLM(
|
||||
server_url=self.credentials['server_url'],
|
||||
model_uid=self.credentials['model_uid'],
|
||||
)
|
||||
|
||||
client.callbacks = self.callbacks
|
||||
|
||||
return client
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate(
|
||||
[prompts],
|
||||
stop,
|
||||
callbacks,
|
||||
generate_config={
|
||||
"stop": stop,
|
||||
"stream": self.streaming,
|
||||
**self.provider_model_kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
pass
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Xinference: {str(ex)}")
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
||||
@@ -41,7 +41,7 @@ class OpenAIModeration(BaseProviderModel):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError(str(ex))
|
||||
return LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
@@ -40,7 +40,7 @@ class OpenAIWhisper(BaseSpeech2Text):
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError(str(ex))
|
||||
return LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
|
||||
138
api/core/model_providers/providers/openllm_provider.py
Normal file
138
api/core/model_providers/providers/openllm_provider.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
from core.model_providers.models.llm.openllm_model import OpenLLMModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.openllm import OpenLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class OpenLLMProvider(BaseModelProvider):
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'openllm'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = OpenLLMModel
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
if 'server_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('OpenLLM Server URL must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'server_url': credentials['server_url']
|
||||
}
|
||||
|
||||
llm = OpenLLM(
|
||||
llm_kwargs={
|
||||
'max_new_tokens': 10
|
||||
},
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
||||
return credentials
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
if self.provider.provider_type != ProviderType.CUSTOM.value:
|
||||
raise NotImplementedError
|
||||
|
||||
provider_model = self._get_provider_model(model_name, model_type)
|
||||
|
||||
if not provider_model.encrypted_config:
|
||||
return {
|
||||
'server_url': None
|
||||
}
|
||||
|
||||
credentials = json.loads(provider_model.encrypted_config)
|
||||
if credentials['server_url']:
|
||||
credentials['server_url'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['server_url']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
|
||||
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
return {}
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
return {}
|
||||
@@ -116,7 +116,8 @@ class ReplicateProvider(BaseModelProvider):
|
||||
and 'Embedding' not in rst.openapi_schema['components']['schemas']:
|
||||
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
|
||||
elif model_type == ModelType.TEXT_GENERATION \
|
||||
and ('type' not in rst.openapi_schema['components']['schemas']['Output']['items']
|
||||
and ('items' not in rst.openapi_schema['components']['schemas']['Output']
|
||||
or 'type' not in rst.openapi_schema['components']['schemas']['Output']['items']
|
||||
or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
|
||||
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
|
||||
except ReplicateError as e:
|
||||
|
||||
193
api/core/model_providers/providers/xinference_provider.py
Normal file
193
api/core/model_providers/providers/xinference_provider.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
from core.model_providers.models.llm.xinference_model import XinferenceModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class XinferenceProvider(BaseModelProvider):
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'xinference'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = XinferenceModel
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
model_class = XinferenceEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
credentials = self.get_model_credentials(model_name, model_type)
|
||||
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
elif credentials['model_format'] == "ggmlv3":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
else:
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
if 'server_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('Xinference Server URL must be provided.')
|
||||
|
||||
if 'model_uid' not in credentials:
|
||||
raise CredentialsValidateFailedError('Xinference Model UID must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'server_url': credentials['server_url'],
|
||||
'model_uid': credentials['model_uid'],
|
||||
}
|
||||
|
||||
llm = XinferenceLLM(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
extra_credentials = cls._get_extra_credentials(credentials)
|
||||
credentials.update(extra_credentials)
|
||||
|
||||
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
||||
|
||||
return credentials
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
if self.provider.provider_type != ProviderType.CUSTOM.value:
|
||||
raise NotImplementedError
|
||||
|
||||
provider_model = self._get_provider_model(model_name, model_type)
|
||||
|
||||
if not provider_model.encrypted_config:
|
||||
return {
|
||||
'server_url': None,
|
||||
'model_uid': None,
|
||||
}
|
||||
|
||||
credentials = json.loads(provider_model.encrypted_config)
|
||||
if credentials['server_url']:
|
||||
credentials['server_url'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['server_url']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
|
||||
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def _get_extra_credentials(self, credentials: dict) -> dict:
|
||||
url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
|
||||
response = requests.get(url)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Failed to get the model description, detail: {response.json()['detail']}"
|
||||
)
|
||||
desc = response.json()
|
||||
|
||||
extra_credentials = {
|
||||
'model_format': desc['model_format'],
|
||||
}
|
||||
if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
|
||||
extra_credentials['model_handle_type'] = 'chatglm'
|
||||
elif "generate" in desc["model_ability"]:
|
||||
extra_credentials['model_handle_type'] = 'generate'
|
||||
elif "chat" in desc["model_ability"]:
|
||||
extra_credentials['model_handle_type'] = 'chat'
|
||||
else:
|
||||
raise NotImplementedError(f"Model handle type not supported.")
|
||||
|
||||
return extra_credentials
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
return {}
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
return {}
|
||||
@@ -8,5 +8,7 @@
|
||||
"wenxin",
|
||||
"chatglm",
|
||||
"replicate",
|
||||
"huggingface_hub"
|
||||
"huggingface_hub",
|
||||
"xinference",
|
||||
"openllm"
|
||||
]
|
||||
@@ -11,5 +11,19 @@
|
||||
"quota_unit": "tokens",
|
||||
"quota_limit": 600000
|
||||
},
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"claude-instant-1": {
|
||||
"prompt": "1.63",
|
||||
"completion": "5.51",
|
||||
"unit": "0.000001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"claude-2": {
|
||||
"prompt": "11.02",
|
||||
"completion": "32.68",
|
||||
"unit": "0.000001",
|
||||
"currency": "USD"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,5 +3,48 @@
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "configurable"
|
||||
"model_flexibility": "configurable",
|
||||
"price_config":{
|
||||
"gpt-4": {
|
||||
"prompt": "0.03",
|
||||
"completion": "0.06",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-4-32k": {
|
||||
"prompt": "0.06",
|
||||
"completion": "0.12",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-35-turbo": {
|
||||
"prompt": "0.002",
|
||||
"completion": "0.0015",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-35-turbo-16k": {
|
||||
"prompt": "0.003",
|
||||
"completion": "0.004",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-davinci-002": {
|
||||
"prompt": "0.02",
|
||||
"completion": "0.02",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-davinci-003": {
|
||||
"prompt": "0.02",
|
||||
"completion": "0.02",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-embedding-ada-002":{
|
||||
"completion": "0.0001",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -9,5 +9,24 @@
|
||||
],
|
||||
"quota_unit": "tokens"
|
||||
},
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"abab5.5-chat": {
|
||||
"prompt": "0.015",
|
||||
"completion": "0.015",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"abab5-chat": {
|
||||
"prompt": "0.015",
|
||||
"completion": "0.015",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"embo-01": {
|
||||
"completion": "0",
|
||||
"unit": "0.0001",
|
||||
"currency": "RMB"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -10,5 +10,42 @@
|
||||
"quota_unit": "times",
|
||||
"quota_limit": 200
|
||||
},
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"gpt-4": {
|
||||
"prompt": "0.03",
|
||||
"completion": "0.06",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-4-32k": {
|
||||
"prompt": "0.06",
|
||||
"completion": "0.12",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo": {
|
||||
"prompt": "0.0015",
|
||||
"completion": "0.002",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
"prompt": "0.003",
|
||||
"completion": "0.004",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-davinci-003": {
|
||||
"prompt": "0.02",
|
||||
"completion": "0.02",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-embedding-ada-002":{
|
||||
"completion": "0.0001",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
}
|
||||
}
|
||||
}
|
||||
7
api/core/model_providers/rules/openllm.json
Normal file
7
api/core/model_providers/rules/openllm.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "configurable"
|
||||
}
|
||||
@@ -9,5 +9,19 @@
|
||||
],
|
||||
"quota_unit": "tokens"
|
||||
},
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"spark": {
|
||||
"prompt": "0.18",
|
||||
"completion": "0.18",
|
||||
"unit": "0.0001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"spark-v2": {
|
||||
"prompt": "0.36",
|
||||
"completion": "0.36",
|
||||
"unit": "0.0001",
|
||||
"currency": "RMB"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,5 +3,25 @@
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"ernie-bot": {
|
||||
"prompt": "0.012",
|
||||
"completion": "0.012",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"ernie-bot-turbo": {
|
||||
"prompt": "0.008",
|
||||
"completion": "0.008",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"bloomz-7b": {
|
||||
"prompt": "0.006",
|
||||
"completion": "0.006",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
}
|
||||
}
|
||||
}
|
||||
7
api/core/model_providers/rules/xinference.json
Normal file
7
api/core/model_providers/rules/xinference.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "configurable"
|
||||
}
|
||||
21
api/core/third_party/langchain/embeddings/xinference_embedding.py
vendored
Normal file
21
api/core/third_party/langchain/embeddings/xinference_embedding.py
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
|
||||
class XinferenceEmbedding(XinferenceEmbeddings):
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
vectors = super().embed_documents(texts)
|
||||
|
||||
normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
|
||||
|
||||
return normalized_vectors
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
vector = super().embed_query(text)
|
||||
|
||||
normalized_vector = (vector / np.linalg.norm(vector)).tolist()
|
||||
|
||||
return normalized_vector
|
||||
84
api/core/third_party/langchain/llms/openllm.py
vendored
Normal file
84
api/core/third_party/langchain/llms/openllm.py
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenLLM(LLM):
|
||||
"""OpenLLM, supporting both in-process model
|
||||
instance and remote OpenLLM servers.
|
||||
|
||||
If you have a OpenLLM server running, you can also use it remotely:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import OpenLLM
|
||||
llm = OpenLLM(server_url='http://localhost:3000')
|
||||
llm("What is the difference between a duck and a goose?")
|
||||
"""
|
||||
|
||||
server_url: Optional[str] = None
|
||||
"""Optional server URL that currently runs a LLMServer with 'openllm start'."""
|
||||
llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Key word arguments to be passed to openllm.LLM"""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "openllm"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"llm_config": self.llm_kwargs
|
||||
}
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = requests.post(
|
||||
f'{self.server_url}/v1/generate',
|
||||
headers=headers,
|
||||
json=params
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
|
||||
|
||||
json_response = response.json()
|
||||
completion = json_response["responses"][0]
|
||||
|
||||
if stop is not None:
|
||||
completion = enforce_stop_tokens(completion, stop)
|
||||
|
||||
return completion
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
raise NotImplementedError(
|
||||
"Async call is not supported for OpenLLM at the moment."
|
||||
)
|
||||
132
api/core/third_party/langchain/llms/xinference_llm.py
vendored
Normal file
132
api/core/third_party/langchain/llms/xinference_llm.py
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
from typing import Optional, List, Any, Union, Generator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import Xinference
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from xinference.client import RESTfulChatglmCppChatModelHandle, \
|
||||
RESTfulChatModelHandle, RESTfulGenerateModelHandle
|
||||
|
||||
|
||||
class XinferenceLLM(Xinference):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call the xinference model and return the output.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Returns:
|
||||
The generated string by the model.
|
||||
"""
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
if isinstance(model, RESTfulChatModelHandle):
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
else:
|
||||
completion = model.chat(prompt=prompt, generate_config=generate_config)
|
||||
return completion["choices"][0]["message"]["content"]
|
||||
elif isinstance(model, RESTfulGenerateModelHandle):
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
|
||||
else:
|
||||
completion = model.generate(prompt=prompt, generate_config=generate_config)
|
||||
return completion["choices"][0]["text"]
|
||||
elif isinstance(model, RESTfulChatglmCppChatModelHandle):
|
||||
generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
completion = combined_text_output
|
||||
else:
|
||||
completion = model.chat(prompt=prompt, generate_config=generate_config)
|
||||
completion = completion["choices"][0]["message"]["content"]
|
||||
|
||||
if stop is not None:
|
||||
completion = enforce_stop_tokens(completion, stop)
|
||||
|
||||
return completion
|
||||
|
||||
|
||||
def _stream_generate(
|
||||
self,
|
||||
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
model: The model used for generation.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Yields:
|
||||
A string token.
|
||||
"""
|
||||
if isinstance(model, RESTfulGenerateModelHandle):
|
||||
streaming_response = model.generate(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
else:
|
||||
streaming_response = model.chat(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
|
||||
for chunk in streaming_response:
|
||||
if isinstance(chunk, dict):
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
choice = choices[0]
|
||||
if isinstance(choice, dict):
|
||||
token = choice.get("text", "")
|
||||
log_probs = choice.get("logprobs")
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token=token, verbose=self.verbose, log_probs=log_probs
|
||||
)
|
||||
yield token
|
||||
@@ -88,6 +88,11 @@ class WebReaderTool(BaseTool):
|
||||
texts = character_splitter.split_text(page_contents)
|
||||
docs = [Document(page_content=t) for t in texts]
|
||||
|
||||
if len(docs) == 0:
|
||||
return "No content found."
|
||||
|
||||
docs = docs[1:]
|
||||
|
||||
# only use first 5 docs
|
||||
if len(docs) > 5:
|
||||
docs = docs[:5]
|
||||
|
||||
@@ -20,6 +20,10 @@ def handle(sender, **kwargs):
|
||||
# generate conversation name
|
||||
try:
|
||||
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query, message.answer)
|
||||
|
||||
if len(name) > 75:
|
||||
name = name[:75] + '...'
|
||||
|
||||
conversation.name = name
|
||||
except:
|
||||
conversation.name = 'New Chat'
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
"""add message price unit
|
||||
|
||||
Revision ID: 853f9b9cd3b6
|
||||
Revises: e8883b0148c9
|
||||
Create Date: 2023-08-19 17:01:57.471562
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '853f9b9cd3b6'
|
||||
down_revision = 'e8883b0148c9'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
|
||||
batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
|
||||
|
||||
with op.batch_alter_table('messages', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
|
||||
batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('messages', schema=None) as batch_op:
|
||||
batch_op.drop_column('answer_price_unit')
|
||||
batch_op.drop_column('message_price_unit')
|
||||
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.drop_column('answer_price_unit')
|
||||
batch_op.drop_column('message_price_unit')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -421,9 +421,11 @@ class Message(db.Model):
|
||||
message = db.Column(db.JSON, nullable=False)
|
||||
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
|
||||
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
|
||||
answer = db.Column(db.Text, nullable=False)
|
||||
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
|
||||
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
|
||||
provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0'))
|
||||
total_price = db.Column(db.Numeric(10, 7))
|
||||
currency = db.Column(db.String(255), nullable=False)
|
||||
@@ -705,9 +707,11 @@ class MessageAgentThought(db.Model):
|
||||
message = db.Column(db.Text, nullable=True)
|
||||
message_token = db.Column(db.Integer, nullable=True)
|
||||
message_unit_price = db.Column(db.Numeric, nullable=True)
|
||||
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
|
||||
answer = db.Column(db.Text, nullable=True)
|
||||
answer_token = db.Column(db.Integer, nullable=True)
|
||||
answer_unit_price = db.Column(db.Numeric, nullable=True)
|
||||
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
|
||||
tokens = db.Column(db.Integer, nullable=True)
|
||||
total_price = db.Column(db.Numeric, nullable=True)
|
||||
currency = db.Column(db.String, nullable=True)
|
||||
|
||||
@@ -48,4 +48,5 @@ dashscope~=1.5.0
|
||||
huggingface_hub~=0.16.4
|
||||
transformers~=4.31.0
|
||||
stripe~=5.5.0
|
||||
pandas==1.5.3
|
||||
pandas==1.5.3
|
||||
xinference==0.2.0
|
||||
@@ -284,8 +284,9 @@ class DocumentService:
|
||||
"github_link": str,
|
||||
"open_source_license": str,
|
||||
"commit_date": str,
|
||||
"commit_author": str
|
||||
}
|
||||
"commit_author": str,
|
||||
},
|
||||
"others": dict
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -370,8 +371,8 @@ class DocumentService:
|
||||
raise DocumentIndexingError()
|
||||
# update document to be recover
|
||||
document.is_paused = False
|
||||
document.paused_by = current_user.id
|
||||
document.paused_at = time.time()
|
||||
document.paused_by = None
|
||||
document.paused_at = None
|
||||
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
@@ -972,7 +973,7 @@ class SegmentService:
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
# update segment vector index
|
||||
VectorService.create_segment_vector(args['keywords'], segment, dataset)
|
||||
VectorService.update_segment_vector(args['keywords'], segment, dataset)
|
||||
except Exception as e:
|
||||
logging.exception("update segment index failed")
|
||||
segment.enabled = False
|
||||
|
||||
@@ -32,4 +32,11 @@ WENXIN_API_KEY=
|
||||
WENXIN_SECRET_KEY=
|
||||
|
||||
# ChatGLM Credentials
|
||||
CHATGLM_API_BASE=
|
||||
CHATGLM_API_BASE=
|
||||
|
||||
# Xinference Credentials
|
||||
XINFERENCE_SERVER_URL=
|
||||
XINFERENCE_MODEL_UID=
|
||||
|
||||
# OpenLLM Credentials
|
||||
OPENLLM_SERVER_URL=
|
||||
@@ -0,0 +1,65 @@
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.xinference_provider import XinferenceProvider
|
||||
from models.provider import Provider, ProviderType, ProviderModel
|
||||
|
||||
|
||||
def get_mock_provider():
|
||||
return Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='xinference',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config='',
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
|
||||
def get_mock_embedding_model(mocker):
|
||||
model_name = 'vicuna-v1.3'
|
||||
server_url = os.environ['XINFERENCE_SERVER_URL']
|
||||
model_uid = os.environ['XINFERENCE_MODEL_UID']
|
||||
model_provider = XinferenceProvider(provider=get_mock_provider())
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
provider_name='xinference',
|
||||
model_name=model_name,
|
||||
model_type=ModelType.EMBEDDINGS.value,
|
||||
encrypted_config=json.dumps({
|
||||
'server_url': server_url,
|
||||
'model_uid': model_uid
|
||||
}),
|
||||
is_valid=True,
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
return XinferenceEmbedding(
|
||||
model_provider=model_provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
return encrypted_api_key
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_embed_documents(mock_decrypt, mocker):
|
||||
embedding_model = get_mock_embedding_model(mocker)
|
||||
rst = embedding_model.client.embed_documents(['test', 'test1'])
|
||||
assert isinstance(rst, list)
|
||||
assert len(rst) == 2
|
||||
assert len(rst[0]) == 4096
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_embed_query(mock_decrypt, mocker):
|
||||
embedding_model = get_mock_embedding_model(mocker)
|
||||
rst = embedding_model.client.embed_query('test')
|
||||
assert isinstance(rst, list)
|
||||
assert len(rst) == 4096
|
||||
@@ -50,7 +50,9 @@ def test_get_num_tokens(mock_decrypt):
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_run(mock_decrypt):
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model('claude-2')
|
||||
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: ')]
|
||||
rst = model.run(
|
||||
@@ -58,4 +60,3 @@ def test_run(mock_decrypt):
|
||||
stop=['\nHuman:'],
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
assert rst.content.strip() == '2'
|
||||
|
||||
@@ -76,6 +76,8 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
|
||||
messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
|
||||
rst = openai_model.run(
|
||||
@@ -83,4 +85,3 @@ def test_run(mock_decrypt, mocker):
|
||||
stop=['\nHuman:'],
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
assert rst.content.strip() == 'n'
|
||||
|
||||
@@ -95,6 +95,8 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_hosted_inference_api_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model(
|
||||
'google/flan-t5-base',
|
||||
'hosted_inference_api',
|
||||
@@ -111,6 +113,8 @@ def test_hosted_inference_api_run(mock_decrypt, mocker):
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_inference_endpoints_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model(
|
||||
'',
|
||||
'inference_endpoints',
|
||||
@@ -121,4 +125,3 @@ def test_inference_endpoints_run(mock_decrypt, mocker):
|
||||
[PromptMessage(content='Answer the following yes/no question. Can you write a whole Haiku in a single tweet?')],
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
assert rst.content.strip() == 'no'
|
||||
|
||||
@@ -54,11 +54,12 @@ def test_get_num_tokens(mock_decrypt):
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_run(mock_decrypt):
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model('abab5.5-chat')
|
||||
rst = model.run(
|
||||
[PromptMessage(content='Human: Are you a real Human? you MUST only answer `y` or `n`? \nAssistant: ')],
|
||||
stop=['\nHuman:'],
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
assert rst.content.strip() == 'n'
|
||||
|
||||
@@ -58,7 +58,9 @@ def test_chat_get_num_tokens(mock_decrypt):
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_run(mock_decrypt):
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
openai_model = get_mock_openai_model('text-davinci-003')
|
||||
rst = openai_model.run(
|
||||
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
|
||||
@@ -69,7 +71,9 @@ def test_run(mock_decrypt):
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_chat_run(mock_decrypt):
|
||||
def test_chat_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
openai_model = get_mock_openai_model('gpt-3.5-turbo')
|
||||
messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
|
||||
rst = openai_model.run(
|
||||
@@ -77,4 +81,3 @@ def test_chat_run(mock_decrypt):
|
||||
stop=['\nHuman:'],
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
assert rst.content.strip() == 'n'
|
||||
|
||||
72
api/tests/integration_tests/models/llm/test_openllm_model.py
Normal file
72
api/tests/integration_tests/models/llm/test_openllm_model.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
|
||||
from core.model_providers.models.llm.openllm_model import OpenLLMModel
|
||||
from core.model_providers.providers.openllm_provider import OpenLLMProvider
|
||||
from models.provider import Provider, ProviderType, ProviderModel
|
||||
|
||||
|
||||
def get_mock_provider():
|
||||
return Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='openllm',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config='',
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
|
||||
def get_mock_model(model_name, mocker):
|
||||
model_kwargs = ModelKwargs(
|
||||
max_tokens=10,
|
||||
temperature=0.01
|
||||
)
|
||||
server_url = os.environ['OPENLLM_SERVER_URL']
|
||||
model_provider = OpenLLMProvider(provider=get_mock_provider())
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
provider_name='openllm',
|
||||
model_name=model_name,
|
||||
model_type=ModelType.TEXT_GENERATION.value,
|
||||
encrypted_config=json.dumps({
|
||||
'server_url': server_url
|
||||
}),
|
||||
is_valid=True,
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
return OpenLLMModel(
|
||||
model_provider=model_provider,
|
||||
name=model_name,
|
||||
model_kwargs=model_kwargs
|
||||
)
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
return encrypted_api_key
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_num_tokens(mock_decrypt, mocker):
|
||||
model = get_mock_model('facebook/opt-125m', mocker)
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 5
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model('facebook/opt-125m', mocker)
|
||||
messages = [PromptMessage(content='Human: who are you? \nAnswer: ')]
|
||||
rst = model.run(
|
||||
messages
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
@@ -65,6 +65,8 @@ def test_get_num_tokens(mock_decrypt, mocker):
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
|
||||
messages = [PromptMessage(content='Human: 1+1=? \nAnswer: ')]
|
||||
rst = model.run(
|
||||
|
||||
@@ -58,7 +58,9 @@ def test_get_num_tokens(mock_decrypt):
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_run(mock_decrypt):
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model('spark')
|
||||
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
|
||||
rst = model.run(
|
||||
@@ -66,4 +68,3 @@ def test_run(mock_decrypt):
|
||||
stop=['\nHuman:'],
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
assert rst.content.strip() == '2'
|
||||
|
||||
@@ -52,7 +52,9 @@ def test_get_num_tokens(mock_decrypt):
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_run(mock_decrypt):
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model('qwen-v1')
|
||||
rst = model.run(
|
||||
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
|
||||
|
||||
@@ -52,7 +52,9 @@ def test_get_num_tokens(mock_decrypt):
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_run(mock_decrypt):
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model('ernie-bot')
|
||||
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
|
||||
rst = model.run(
|
||||
@@ -60,4 +62,3 @@ def test_run(mock_decrypt):
|
||||
stop=['\nHuman:'],
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
assert rst.content.strip() == '2'
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
|
||||
from core.model_providers.models.llm.xinference_model import XinferenceModel
|
||||
from core.model_providers.providers.xinference_provider import XinferenceProvider
|
||||
from models.provider import Provider, ProviderType, ProviderModel
|
||||
|
||||
|
||||
def get_mock_provider():
|
||||
return Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='xinference',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config='',
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
|
||||
def get_mock_model(model_name, mocker):
|
||||
model_kwargs = ModelKwargs(
|
||||
max_tokens=10,
|
||||
temperature=0.01
|
||||
)
|
||||
server_url = os.environ['XINFERENCE_SERVER_URL']
|
||||
model_uid = os.environ['XINFERENCE_MODEL_UID']
|
||||
model_provider = XinferenceProvider(provider=get_mock_provider())
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
provider_name='xinference',
|
||||
model_name=model_name,
|
||||
model_type=ModelType.TEXT_GENERATION.value,
|
||||
encrypted_config=json.dumps({
|
||||
'server_url': server_url,
|
||||
'model_uid': model_uid
|
||||
}),
|
||||
is_valid=True,
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
return XinferenceModel(
|
||||
model_provider=model_provider,
|
||||
name=model_name,
|
||||
model_kwargs=model_kwargs
|
||||
)
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
return encrypted_api_key
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_num_tokens(mock_decrypt, mocker):
|
||||
model = get_mock_model('llama-2-chat', mocker)
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 5
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model('llama-2-chat', mocker)
|
||||
messages = [PromptMessage(content='Human: 1+1=? \nAnswer: ')]
|
||||
rst = model.run(
|
||||
messages
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
122
api/tests/unit_tests/model_providers/test_openllm_provider.py
Normal file
122
api/tests/unit_tests/model_providers/test_openllm_provider.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.openllm_provider import OpenLLMProvider
|
||||
from models.provider import ProviderType, Provider, ProviderModel
|
||||
|
||||
PROVIDER_NAME = 'openllm'
|
||||
MODEL_PROVIDER_CLASS = OpenLLMProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'server_url': 'http://127.0.0.1:3333/'
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
|
||||
return_value="abc")
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='username/test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials=VALIDATE_CREDENTIAL.copy()
|
||||
)
|
||||
|
||||
|
||||
def test_is_credentials_valid_or_raise_invalid(mocker):
|
||||
# raise CredentialsValidateFailedError if credential is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={}
|
||||
)
|
||||
|
||||
# raise CredentialsValidateFailedError if credential is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={'server_url': 'invalid'})
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_model_credentials(mock_encrypt):
|
||||
api_key = 'http://127.0.0.1:3333/'
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
|
||||
tenant_id='tenant_id',
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials=VALIDATE_CREDENTIAL.copy()
|
||||
)
|
||||
mock_encrypt.assert_called_with('tenant_id', api_key)
|
||||
assert result['server_url'] == f'encrypted_{api_key}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_custom(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION
|
||||
)
|
||||
assert result['server_url'] == 'http://127.0.0.1:3333/'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
obfuscated=True
|
||||
)
|
||||
middle_token = result['server_url'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
130
api/tests/unit_tests/model_providers/test_xinference_provider.py
Normal file
130
api/tests/unit_tests/model_providers/test_xinference_provider.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.xinference_provider import XinferenceProvider
|
||||
from models.provider import ProviderType, Provider, ProviderModel
|
||||
|
||||
PROVIDER_NAME = 'xinference'
|
||||
MODEL_PROVIDER_CLASS = XinferenceProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'model_uid': 'fake-model-uid',
|
||||
'server_url': 'http://127.0.0.1:9997/'
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
|
||||
return_value="abc")
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='username/test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials=VALIDATE_CREDENTIAL.copy()
|
||||
)
|
||||
|
||||
|
||||
def test_is_credentials_valid_or_raise_invalid():
|
||||
# raise CredentialsValidateFailedError if replicate_api_token is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={}
|
||||
)
|
||||
|
||||
# raise CredentialsValidateFailedError if replicate_api_token is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={'server_url': 'invalid'})
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_model_credentials(mock_encrypt, mocker):
|
||||
api_key = 'http://127.0.0.1:9997/'
|
||||
|
||||
mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
|
||||
return_value={
|
||||
'model_handle_type': 'generate',
|
||||
'model_format': 'ggmlv3'
|
||||
})
|
||||
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
|
||||
tenant_id='tenant_id',
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials=VALIDATE_CREDENTIAL.copy()
|
||||
)
|
||||
mock_encrypt.assert_called_with('tenant_id', api_key)
|
||||
assert result['server_url'] == f'encrypted_{api_key}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_custom(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION
|
||||
)
|
||||
assert result['server_url'] == 'http://127.0.0.1:9997/'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
obfuscated=True
|
||||
)
|
||||
middle_token = result['server_url'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
@@ -2,7 +2,7 @@ version: '3.1'
|
||||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:0.3.14
|
||||
image: langgenius/dify-api:0.3.15
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'api' starts the API server.
|
||||
@@ -124,7 +124,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:0.3.14
|
||||
image: langgenius/dify-api:0.3.15
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'worker' starts the Celery worker for processing the queue.
|
||||
@@ -176,7 +176,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:0.3.14
|
||||
image: langgenius/dify-web:0.3.15
|
||||
restart: always
|
||||
environment:
|
||||
EDITION: SELF_HOSTED
|
||||
|
||||
@@ -17,6 +17,7 @@ import type { App } from '@/types/app'
|
||||
import type { UpdateAppSiteCodeResponse } from '@/models/app'
|
||||
import { asyncRunSafe } from '@/utils'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import type { IAppCardProps } from '@/app/components/app/overview/appCard'
|
||||
|
||||
export type ICardViewProps = {
|
||||
appId: string
|
||||
@@ -68,7 +69,7 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
|
||||
handleError(err)
|
||||
}
|
||||
|
||||
const onSaveSiteConfig = async (params: any) => {
|
||||
const onSaveSiteConfig: IAppCardProps['onSaveSiteConfig'] = async (params) => {
|
||||
const [err] = await asyncRunSafe<App>(
|
||||
updateAppSiteConfig({
|
||||
url: `/apps/${appId}/site`,
|
||||
|
||||
@@ -16,7 +16,6 @@ const Overview = async ({
|
||||
const { t } = await useTranslation(locale, 'app-overview')
|
||||
return (
|
||||
<div className="h-full px-16 py-6 overflow-scroll">
|
||||
{/* <WelcomeBanner /> */}
|
||||
<ApikeyInfoPanel />
|
||||
<div className='flex flex-row items-center justify-between mb-4 text-xl text-gray-900'>
|
||||
{t('overview.title')}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user