mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-07 03:45:27 +08:00
Compare commits
20 Commits
feat/optim
...
fix/web-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0d942d99dc | ||
|
|
78d3aa5fcd | ||
|
|
a7c78d2cd2 | ||
|
|
4db35fa375 | ||
|
|
e67a1413b6 | ||
|
|
4f3053a8cc | ||
|
|
b3c2bf125f | ||
|
|
9d5299e9ec | ||
|
|
aee15adf1b | ||
|
|
b185a70c21 | ||
|
|
a3aba7a9aa | ||
|
|
866ee5da91 | ||
|
|
e8039a7da8 | ||
|
|
5e0540077a | ||
|
|
b346bd9b83 | ||
|
|
062e2e915b | ||
|
|
e0a48c4972 | ||
|
|
f53242c081 | ||
|
|
4b53bb1a32 | ||
|
|
4c49ecedb5 |
@@ -20,7 +20,7 @@ from models.model import Account
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
from models.provider import Provider, ProviderType, ProviderQuotaType
|
||||
from models.provider import Provider, ProviderType, ProviderQuotaType, ProviderModel
|
||||
|
||||
|
||||
@click.command('reset-password', help='Reset the account password.')
|
||||
@@ -102,6 +102,7 @@ def reset_encrypt_key_pair():
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
|
||||
db.session.query(ProviderModel).delete()
|
||||
db.session.commit()
|
||||
|
||||
click.echo(click.style('Congratulations! '
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
import flask_restful
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@@ -3,7 +3,9 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
import flask
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@@ -316,7 +318,7 @@ class AppApi(Resource):
|
||||
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
app = _get_app(app_id, current_user.current_tenant_id)
|
||||
|
||||
db.session.delete(app)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from core.login.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Generator, Union
|
||||
|
||||
import flask_login
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import login_required
|
||||
from core.login.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy import or_, func
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from typing import Union, Generator
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user, login_required
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, marshal_with, fields
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
@@ -16,6 +16,7 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.login.login import login_required
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
|
||||
@@ -3,12 +3,13 @@ import json
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.login.login import login_required
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppModelConfig
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
|
||||
@@ -4,7 +4,8 @@ from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask import jsonify
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -5,9 +5,12 @@ from typing import Optional
|
||||
import flask_login
|
||||
import requests
|
||||
from flask import request, redirect, current_app, session
|
||||
from flask_login import current_user, login_required
|
||||
from flask_login import current_user
|
||||
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from core.login.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
from controllers.console import api
|
||||
from ..setup import setup_required
|
||||
|
||||
@@ -3,7 +3,8 @@ import json
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
import services
|
||||
|
||||
@@ -4,7 +4,8 @@ from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import desc, asc
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
@@ -764,11 +765,13 @@ class DocumentMetadataApi(DocumentResource):
|
||||
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
|
||||
|
||||
document.doc_metadata = {}
|
||||
|
||||
for key, value_type in metadata_schema.items():
|
||||
value = doc_metadata.get(key)
|
||||
if value is not None and isinstance(value, value_type):
|
||||
document.doc_metadata[key] = value
|
||||
if doc_type == 'others':
|
||||
document.doc_metadata = doc_metadata
|
||||
else:
|
||||
for key, value_type in metadata_schema.items():
|
||||
value = doc_metadata.get(key)
|
||||
if value is not None and isinstance(value, value_type):
|
||||
document.doc_metadata[key] = value
|
||||
|
||||
document.doc_type = doc_type
|
||||
document.updated_at = datetime.utcnow()
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
@@ -15,6 +14,7 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.login.login import login_required
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
@@ -8,7 +8,8 @@ from pathlib import Path
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, marshal_with, fields
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal, fields
|
||||
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, inputs
|
||||
from sqlalchemy import and_
|
||||
from werkzeug.exceptions import NotFound, Forbidden, BadRequest
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from sqlalchemy import and_
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource
|
||||
from functools import wraps
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
|
||||
@@ -3,7 +3,8 @@ from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask import current_app, request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
|
||||
|
||||
import services
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, abort, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -2,10 +2,13 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse, marshal
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs
|
||||
from flask_restful.inputs import int_range
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
@@ -43,6 +46,13 @@ tenants_fields = {
|
||||
'current': fields.Boolean
|
||||
}
|
||||
|
||||
workspace_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'status': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
class TenantListApi(Resource):
|
||||
@setup_required
|
||||
@@ -57,6 +67,38 @@ class TenantListApi(Resource):
|
||||
return {'workspaces': marshal(tenants, tenants_fields)}, 200
|
||||
|
||||
|
||||
class WorkspaceListApi(Resource):
|
||||
@setup_required
|
||||
@admin_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
|
||||
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\
|
||||
.paginate(page=args['page'], per_page=args['limit'])
|
||||
|
||||
has_more = False
|
||||
if len(tenants.items) == args['limit']:
|
||||
current_page_first_tenant = tenants[-1]
|
||||
rest_count = db.session.query(Tenant).filter(
|
||||
Tenant.created_at < current_page_first_tenant.created_at,
|
||||
Tenant.id != current_page_first_tenant.id
|
||||
).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
total = db.session.query(Tenant).count()
|
||||
return {
|
||||
'data': marshal(tenants.items, workspace_fields),
|
||||
'has_more': has_more,
|
||||
'limit': args['limit'],
|
||||
'page': args['page'],
|
||||
'total': total
|
||||
}, 200
|
||||
|
||||
|
||||
class TenantApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -92,6 +134,7 @@ class SwitchWorkspaceApi(Resource):
|
||||
|
||||
|
||||
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
|
||||
api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants
|
||||
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
|
||||
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
|
||||
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant
|
||||
|
||||
108
api/core/login/login.py
Normal file
108
api/core/login/login.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
|
||||
import flask_login
|
||||
from flask import current_app
|
||||
from flask import g
|
||||
from flask import has_request_context
|
||||
from flask import request
|
||||
from flask_login import user_logged_in
|
||||
from flask_login.config import EXEMPT_METHODS
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
|
||||
#: A proxy for the current user. If no user is logged in, this will be an
|
||||
#: anonymous user
|
||||
current_user = LocalProxy(lambda: _get_user())
|
||||
|
||||
|
||||
def login_required(func):
|
||||
"""
|
||||
If you decorate a view with this, it will ensure that the current user is
|
||||
logged in and authenticated before calling the actual view. (If they are
|
||||
not, it calls the :attr:`LoginManager.unauthorized` callback.) For
|
||||
example::
|
||||
|
||||
@app.route('/post')
|
||||
@login_required
|
||||
def post():
|
||||
pass
|
||||
|
||||
If there are only certain times you need to require that your user is
|
||||
logged in, you can do so with::
|
||||
|
||||
if not current_user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized()
|
||||
|
||||
...which is essentially the code that this function adds to your views.
|
||||
|
||||
It can be convenient to globally turn off authentication when unit testing.
|
||||
To enable this, if the application configuration variable `LOGIN_DISABLED`
|
||||
is set to `True`, this decorator will be ignored.
|
||||
|
||||
.. Note ::
|
||||
|
||||
Per `W3 guidelines for CORS preflight requests
|
||||
<http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
|
||||
HTTP ``OPTIONS`` requests are exempt from login checks.
|
||||
|
||||
:param func: The view function to decorate.
|
||||
:type func: function
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
auth_header = request.headers.get('Authorization')
|
||||
admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False')
|
||||
if admin_api_key_enable:
|
||||
if auth_header:
|
||||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
admin_api_key = os.getenv('ADMIN_API_KEY')
|
||||
|
||||
if admin_api_key:
|
||||
if os.getenv('ADMIN_API_KEY') == auth_token:
|
||||
workspace_id = request.headers.get('X-WORKSPACE-ID')
|
||||
if workspace_id:
|
||||
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
|
||||
.filter(Tenant.id == workspace_id) \
|
||||
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
||||
.filter(TenantAccountJoin.role == 'owner') \
|
||||
.one_or_none()
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = Account.query.filter_by(id=ta.account_id).first()
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account)
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user())
|
||||
if request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
|
||||
pass
|
||||
elif not current_user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized()
|
||||
|
||||
# flask 1.x compatibility
|
||||
# current_app.ensure_sync is only available in Flask >= 2.0
|
||||
if callable(getattr(current_app, "ensure_sync", None)):
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
def _get_user():
|
||||
if has_request_context():
|
||||
if "_login_user" not in g:
|
||||
current_app.login_manager._load_user()
|
||||
|
||||
return g._login_user
|
||||
|
||||
return None
|
||||
@@ -1,4 +1,4 @@
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
|
||||
from replicate.exceptions import ModelError, ReplicateError
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
@@ -14,7 +14,8 @@ class XinferenceEmbedding(BaseEmbedding):
|
||||
)
|
||||
|
||||
client = XinferenceEmbeddings(
|
||||
**credentials,
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid'],
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import Xinference
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
|
||||
|
||||
|
||||
class XinferenceModel(BaseLLM):
|
||||
@@ -16,8 +16,9 @@ class XinferenceModel(BaseLLM):
|
||||
def _init_client(self) -> Any:
|
||||
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
|
||||
client = Xinference(
|
||||
**self.credentials,
|
||||
client = XinferenceLLM(
|
||||
server_url=self.credentials['server_url'],
|
||||
model_uid=self.credentials['model_uid'],
|
||||
)
|
||||
|
||||
client.callbacks = self.callbacks
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from langchain.llms import Xinference
|
||||
import requests
|
||||
from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
@@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
@@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider):
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
credentials = self.get_model_credentials(model_name, model_type)
|
||||
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
elif credentials['model_format'] == "ggmlv3":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
else:
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
@@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider):
|
||||
'model_uid': credentials['model_uid'],
|
||||
}
|
||||
|
||||
llm = Xinference(
|
||||
llm = XinferenceLLM(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping", generate_config={'max_tokens': 10})
|
||||
llm("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider):
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
extra_credentials = cls._get_extra_credentials(credentials)
|
||||
credentials.update(extra_credentials)
|
||||
|
||||
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
||||
|
||||
return credentials
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
@@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider):
|
||||
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def _get_extra_credentials(self, credentials: dict) -> dict:
|
||||
url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
|
||||
response = requests.get(url)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Failed to get the model description, detail: {response.json()['detail']}"
|
||||
)
|
||||
desc = response.json()
|
||||
|
||||
extra_credentials = {
|
||||
'model_format': desc['model_format'],
|
||||
}
|
||||
if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
|
||||
extra_credentials['model_handle_type'] = 'chatglm'
|
||||
elif "generate" in desc["model_ability"]:
|
||||
extra_credentials['model_handle_type'] = 'generate'
|
||||
elif "chat" in desc["model_ability"]:
|
||||
extra_credentials['model_handle_type'] = 'chat'
|
||||
else:
|
||||
raise NotImplementedError(f"Model handle type not supported.")
|
||||
|
||||
return extra_credentials
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
return
|
||||
|
||||
21
api/core/third_party/langchain/embeddings/xinference_embedding.py
vendored
Normal file
21
api/core/third_party/langchain/embeddings/xinference_embedding.py
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
|
||||
class XinferenceEmbedding(XinferenceEmbeddings):
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
vectors = super().embed_documents(texts)
|
||||
|
||||
normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
|
||||
|
||||
return normalized_vectors
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
vector = super().embed_query(text)
|
||||
|
||||
normalized_vector = (vector / np.linalg.norm(vector)).tolist()
|
||||
|
||||
return normalized_vector
|
||||
@@ -67,9 +67,6 @@ class OpenLLM(LLM):
|
||||
json_response = response.json()
|
||||
completion = json_response["responses"][0]
|
||||
|
||||
if completion:
|
||||
completion = completion[len(prompt):]
|
||||
|
||||
if stop is not None:
|
||||
completion = enforce_stop_tokens(completion, stop)
|
||||
|
||||
|
||||
132
api/core/third_party/langchain/llms/xinference_llm.py
vendored
Normal file
132
api/core/third_party/langchain/llms/xinference_llm.py
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
from typing import Optional, List, Any, Union, Generator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import Xinference
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from xinference.client import RESTfulChatglmCppChatModelHandle, \
|
||||
RESTfulChatModelHandle, RESTfulGenerateModelHandle
|
||||
|
||||
|
||||
class XinferenceLLM(Xinference):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call the xinference model and return the output.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Returns:
|
||||
The generated string by the model.
|
||||
"""
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
if isinstance(model, RESTfulChatModelHandle):
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
else:
|
||||
completion = model.chat(prompt=prompt, generate_config=generate_config)
|
||||
return completion["choices"][0]["message"]["content"]
|
||||
elif isinstance(model, RESTfulGenerateModelHandle):
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
|
||||
else:
|
||||
completion = model.generate(prompt=prompt, generate_config=generate_config)
|
||||
return completion["choices"][0]["text"]
|
||||
elif isinstance(model, RESTfulChatglmCppChatModelHandle):
|
||||
generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
completion = combined_text_output
|
||||
else:
|
||||
completion = model.chat(prompt=prompt, generate_config=generate_config)
|
||||
completion = completion["choices"][0]["message"]["content"]
|
||||
|
||||
if stop is not None:
|
||||
completion = enforce_stop_tokens(completion, stop)
|
||||
|
||||
return completion
|
||||
|
||||
|
||||
def _stream_generate(
|
||||
self,
|
||||
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
model: The model used for generation.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Yields:
|
||||
A string token.
|
||||
"""
|
||||
if isinstance(model, RESTfulGenerateModelHandle):
|
||||
streaming_response = model.generate(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
else:
|
||||
streaming_response = model.chat(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
|
||||
for chunk in streaming_response:
|
||||
if isinstance(chunk, dict):
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
choice = choices[0]
|
||||
if isinstance(choice, dict):
|
||||
token = choice.get("text", "")
|
||||
log_probs = choice.get("logprobs")
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token=token, verbose=self.verbose, log_probs=log_probs
|
||||
)
|
||||
yield token
|
||||
@@ -88,6 +88,11 @@ class WebReaderTool(BaseTool):
|
||||
texts = character_splitter.split_text(page_contents)
|
||||
docs = [Document(page_content=t) for t in texts]
|
||||
|
||||
if len(docs) == 0:
|
||||
return "No content found."
|
||||
|
||||
docs = docs[1:]
|
||||
|
||||
# only use first 5 docs
|
||||
if len(docs) > 5:
|
||||
docs = docs[:5]
|
||||
|
||||
@@ -284,8 +284,9 @@ class DocumentService:
|
||||
"github_link": str,
|
||||
"open_source_license": str,
|
||||
"commit_date": str,
|
||||
"commit_author": str
|
||||
}
|
||||
"commit_author": str,
|
||||
},
|
||||
"others": dict
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -972,7 +973,7 @@ class SegmentService:
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
# update segment vector index
|
||||
VectorService.create_segment_vector(args['keywords'], segment, dataset)
|
||||
VectorService.update_segment_vector(args['keywords'], segment, dataset)
|
||||
except Exception as e:
|
||||
logging.exception("update segment index failed")
|
||||
segment.enabled = False
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.replicate_provider import ReplicateProvider
|
||||
from core.model_providers.providers.xinference_provider import XinferenceProvider
|
||||
from models.provider import ProviderType, Provider, ProviderModel
|
||||
|
||||
@@ -25,7 +24,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
|
||||
|
||||
def test_is_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('langchain.llms.xinference.Xinference._call',
|
||||
mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
|
||||
return_value="abc")
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
@@ -53,8 +52,15 @@ def test_is_credentials_valid_or_raise_invalid():
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_model_credentials(mock_encrypt):
|
||||
def test_encrypt_model_credentials(mock_encrypt, mocker):
|
||||
api_key = 'http://127.0.0.1:9997/'
|
||||
|
||||
mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
|
||||
return_value={
|
||||
'model_handle_type': 'generate',
|
||||
'model_format': 'ggmlv3'
|
||||
})
|
||||
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
|
||||
tenant_id='tenant_id',
|
||||
model_name='test_model_name',
|
||||
|
||||
@@ -17,6 +17,7 @@ import type { App } from '@/types/app'
|
||||
import type { UpdateAppSiteCodeResponse } from '@/models/app'
|
||||
import { asyncRunSafe } from '@/utils'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import type { IAppCardProps } from '@/app/components/app/overview/appCard'
|
||||
|
||||
export type ICardViewProps = {
|
||||
appId: string
|
||||
@@ -68,7 +69,7 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
|
||||
handleError(err)
|
||||
}
|
||||
|
||||
const onSaveSiteConfig = async (params: any) => {
|
||||
const onSaveSiteConfig: IAppCardProps['onSaveSiteConfig'] = async (params) => {
|
||||
const [err] = await asyncRunSafe<App>(
|
||||
updateAppSiteConfig({
|
||||
url: `/apps/${appId}/site`,
|
||||
|
||||
@@ -16,7 +16,6 @@ const Overview = async ({
|
||||
const { t } = await useTranslation(locale, 'app-overview')
|
||||
return (
|
||||
<div className="h-full px-16 py-6 overflow-scroll">
|
||||
{/* <WelcomeBanner /> */}
|
||||
<ApikeyInfoPanel />
|
||||
<div className='flex flex-row items-center justify-between mb-4 text-xl text-gray-900'>
|
||||
{t('overview.title')}
|
||||
|
||||
@@ -1,200 +0,0 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useState } from 'react'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Link from 'next/link'
|
||||
import useSWR, { useSWRConfig } from 'swr'
|
||||
import { ArrowTopRightOnSquareIcon } from '@heroicons/react/24/outline'
|
||||
import { ExclamationCircleIcon } from '@heroicons/react/24/solid'
|
||||
import { debounce } from 'lodash-es'
|
||||
import Popover from '@/app/components/base/popover'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Tag from '@/app/components/base/tag'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { updateOpenAIKey, validateOpenAIKey } from '@/service/apps'
|
||||
import { fetchTenantInfo } from '@/service/common'
|
||||
import I18n from '@/context/i18n'
|
||||
|
||||
type IStatusType = 'normal' | 'verified' | 'error' | 'error-api-key-exceed-bill'
|
||||
|
||||
const STATUS_COLOR_MAP = {
|
||||
'normal': { color: '', bgColor: 'bg-primary-50', borderColor: 'border-primary-100' },
|
||||
'error': { color: 'text-red-600', bgColor: 'bg-red-50', borderColor: 'border-red-100' },
|
||||
'verified': { color: '', bgColor: 'bg-green-50', borderColor: 'border-green-100' },
|
||||
'error-api-key-exceed-bill': { color: 'text-red-600', bgColor: 'bg-red-50', borderColor: 'border-red-100' },
|
||||
}
|
||||
|
||||
const CheckCircleIcon: FC<{ className?: string }> = ({ className }) => {
|
||||
return <svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
|
||||
<rect width="20" height="20" rx="10" fill="#DEF7EC" />
|
||||
<path fillRule="evenodd" clipRule="evenodd" d="M14.6947 6.70495C14.8259 6.83622 14.8996 7.01424 14.8996 7.19985C14.8996 7.38547 14.8259 7.56348 14.6947 7.69475L9.0947 13.2948C8.96343 13.426 8.78541 13.4997 8.5998 13.4997C8.41418 13.4997 8.23617 13.426 8.1049 13.2948L5.3049 10.4948C5.17739 10.3627 5.10683 10.1859 5.10842 10.0024C5.11002 9.81883 5.18364 9.64326 5.31342 9.51348C5.44321 9.38369 5.61878 9.31007 5.80232 9.30848C5.98585 9.30688 6.16268 9.37744 6.2947 9.50495L8.5998 11.8101L13.7049 6.70495C13.8362 6.57372 14.0142 6.5 14.1998 6.5C14.3854 6.5 14.5634 6.57372 14.6947 6.70495Z" fill="#046C4E" />
|
||||
</svg>
|
||||
}
|
||||
|
||||
type IEditKeyDiv = {
|
||||
className?: string
|
||||
showInPopover?: boolean
|
||||
onClose?: () => void
|
||||
getTenantInfo?: () => void
|
||||
}
|
||||
|
||||
const EditKeyDiv: FC<IEditKeyDiv> = ({ className = '', showInPopover = false, onClose, getTenantInfo }) => {
|
||||
const [inputValue, setInputValue] = useState<string | undefined>()
|
||||
const [editStatus, setEditStatus] = useState<IStatusType>('normal')
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [validating, setValidating] = useState(false)
|
||||
const { notify } = useContext(ToastContext)
|
||||
const { t } = useTranslation()
|
||||
const { locale } = useContext(I18n)
|
||||
|
||||
// Hide the pop-up window and need to get the latest key again
|
||||
// If the key is valid, the edit button will be hidden later
|
||||
const onClosePanel = () => {
|
||||
getTenantInfo && getTenantInfo()
|
||||
onClose && onClose()
|
||||
}
|
||||
|
||||
const onSaveKey = async () => {
|
||||
if (editStatus === 'verified') {
|
||||
setLoading(true)
|
||||
try {
|
||||
await updateOpenAIKey({ url: '/providers/openai/token', body: { token: inputValue ?? '' } })
|
||||
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
|
||||
onClosePanel()
|
||||
}
|
||||
catch (err) {
|
||||
notify({ type: 'error', message: t('common.actionMsg.modificationFailed') })
|
||||
}
|
||||
finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const validateKey = async (value: string) => {
|
||||
try {
|
||||
setValidating(true)
|
||||
const res = await validateOpenAIKey({ url: '/providers/openai/token-validate', body: { token: value ?? '' } })
|
||||
setEditStatus(res.result === 'success' ? 'verified' : 'error')
|
||||
}
|
||||
catch (err: any) {
|
||||
if (err.status === 400) {
|
||||
err.json().then(({ code }: any) => {
|
||||
if (code === 'provider_request_failed')
|
||||
setEditStatus('error-api-key-exceed-bill')
|
||||
})
|
||||
}
|
||||
else {
|
||||
setEditStatus('error')
|
||||
}
|
||||
}
|
||||
finally {
|
||||
setValidating(false)
|
||||
}
|
||||
}
|
||||
const renderErrorMessage = () => {
|
||||
if (validating) {
|
||||
return (
|
||||
<div className={'text-primary-600 mt-2 text-xs'}>
|
||||
{t('common.provider.validating')}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
if (editStatus === 'error-api-key-exceed-bill') {
|
||||
return (
|
||||
<div className={'text-[#D92D20] mt-2 text-xs'}>
|
||||
{t('common.provider.apiKeyExceedBill')}
|
||||
{locale === 'en' ? ' ' : ''}
|
||||
<Link
|
||||
className='underline'
|
||||
href="https://platform.openai.com/account/api-keys"
|
||||
target={'_blank'}>
|
||||
{locale === 'en' ? 'this link' : '这篇文档'}
|
||||
</Link>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
if (editStatus === 'error') {
|
||||
return (
|
||||
<div className={'text-[#D92D20] mt-2 text-xs'}>
|
||||
{t('common.provider.invalidKey')}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={`flex flex-col w-full rounded-lg px-8 py-6 border-solid border-[0.5px] ${className} ${Object.values(STATUS_COLOR_MAP[editStatus]).join(' ')}`}>
|
||||
{!showInPopover && <p className='text-xl font-medium text-gray-800'>{t('appOverview.welcome.firstStepTip')}</p>}
|
||||
<p className={`${showInPopover ? 'text-sm' : 'text-xl'} font-medium text-gray-800`}>{t('appOverview.welcome.enterKeyTip')} {showInPopover ? '' : '👇'}</p>
|
||||
<div className='relative mt-2'>
|
||||
<input type="text"
|
||||
className={`h-9 w-96 max-w-full py-2 pl-2 text-gray-900 rounded-lg bg-white sm:text-xs focus:ring-blue-500 focus:border-blue-500 shadow-sm ${editStatus === 'normal' ? 'pr-2' : 'pr-8'}`}
|
||||
placeholder={t('appOverview.welcome.placeholder') || ''}
|
||||
onChange={debounce((e) => {
|
||||
setInputValue(e.target.value)
|
||||
if (!e.target.value) {
|
||||
setEditStatus('normal')
|
||||
return
|
||||
}
|
||||
validateKey(e.target.value)
|
||||
}, 300)}
|
||||
/>
|
||||
{editStatus === 'verified' && <div className="absolute inset-y-0 right-0 flex flex-row-reverse items-center pr-6 pointer-events-none">
|
||||
<CheckCircleIcon className="rounded-lg" />
|
||||
</div>}
|
||||
{(editStatus === 'error' || editStatus === 'error-api-key-exceed-bill') && <div className="absolute inset-y-0 right-0 flex flex-row-reverse items-center pr-6 pointer-events-none">
|
||||
<ExclamationCircleIcon className="w-5 h-5 text-red-800" />
|
||||
</div>}
|
||||
{showInPopover ? null : <Button type='primary' onClick={onSaveKey} className='!h-9 !inline-block ml-2' loading={loading} disabled={editStatus !== 'verified'}>{t('common.operation.save')}</Button>}
|
||||
</div>
|
||||
{renderErrorMessage()}
|
||||
<Link className="inline-flex items-center mt-2 text-xs font-normal cursor-pointer text-primary-600 w-fit" href="https://platform.openai.com/account/api-keys" target={'_blank'}>
|
||||
{t('appOverview.welcome.getKeyTip')}
|
||||
<ArrowTopRightOnSquareIcon className='w-3 h-3 ml-1 text-primary-600' aria-hidden="true" />
|
||||
</Link>
|
||||
{showInPopover && <div className='flex justify-end mt-6'>
|
||||
<Button className='flex-shrink-0 mr-2' onClick={onClosePanel}>{t('common.operation.cancel')}</Button>
|
||||
<Button type='primary' className='flex-shrink-0' onClick={onSaveKey} loading={loading} disabled={editStatus !== 'verified'}>{t('common.operation.save')}</Button>
|
||||
</div>}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const WelcomeBanner: FC = () => {
|
||||
const { data: userInfo } = useSWR({ url: '/info' }, fetchTenantInfo)
|
||||
if (!userInfo)
|
||||
return null
|
||||
return userInfo?.providers?.find(({ token_is_set }) => token_is_set) ? null : <EditKeyDiv className='mb-8' />
|
||||
}
|
||||
|
||||
export const EditKeyPopover: FC = () => {
|
||||
const { data: userInfo } = useSWR({ url: '/info' }, fetchTenantInfo)
|
||||
const { mutate } = useSWRConfig()
|
||||
if (!userInfo)
|
||||
return null
|
||||
|
||||
const getTenantInfo = () => {
|
||||
mutate({ url: '/info' })
|
||||
}
|
||||
// In this case, the edit button is displayed
|
||||
const targetProvider = userInfo?.providers?.some(({ token_is_set, is_valid }) => token_is_set && is_valid)
|
||||
return (
|
||||
!targetProvider
|
||||
? <div className='flex items-center'>
|
||||
<Tag className='mr-2 h-fit' color='red'><ExclamationCircleIcon className='h-3.5 w-3.5 mr-2' />OpenAI API key invalid</Tag>
|
||||
<Popover
|
||||
htmlContent={<EditKeyDiv className='!border-0' showInPopover={true} getTenantInfo={getTenantInfo} />}
|
||||
trigger='click'
|
||||
position='br'
|
||||
btnElement='Edit'
|
||||
btnClassName='text-primary-600 !text-xs px-3 py-1.5'
|
||||
className='!p-0 !w-[464px] h-[200px]'
|
||||
/>
|
||||
</div>
|
||||
: null)
|
||||
}
|
||||
|
||||
export default WelcomeBanner
|
||||
@@ -68,11 +68,11 @@ const Apps = () => {
|
||||
|
||||
return (
|
||||
<nav className='grid content-start grid-cols-1 gap-4 px-12 pt-8 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 grow shrink-0'>
|
||||
{ isCurrentWorkspaceManager
|
||||
&& <NewAppCard ref={anchorRef} onSuccess={mutate} />}
|
||||
{data?.map(({ data: apps }) => apps.map(app => (
|
||||
<AppCard key={app.id} app={app} onRefresh={mutate} />
|
||||
)))}
|
||||
{ isCurrentWorkspaceManager
|
||||
&& <NewAppCard ref={anchorRef} onSuccess={mutate} />}
|
||||
{
|
||||
showPayStatusModal && (
|
||||
<Confirm
|
||||
|
||||
@@ -42,10 +42,10 @@ const Datasets = () => {
|
||||
|
||||
return (
|
||||
<nav className='grid content-start grid-cols-1 gap-4 px-12 pt-8 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 grow shrink-0'>
|
||||
{ isCurrentWorkspaceManager && <NewDatasetCard ref={anchorRef} /> }
|
||||
{data?.map(({ data: datasets }) => datasets.map(dataset => (
|
||||
<DatasetCard key={dataset.id} dataset={dataset} onDelete={mutate} />),
|
||||
))}
|
||||
{ isCurrentWorkspaceManager && <NewDatasetCard ref={anchorRef} /> }
|
||||
</nav>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/
|
||||
|
||||
const ActivateForm = () => {
|
||||
const { t } = useTranslation()
|
||||
const { locale } = useContext(I18n)
|
||||
const { locale, setLocaleOnClient } = useContext(I18n)
|
||||
const searchParams = useSearchParams()
|
||||
const workspaceID = searchParams.get('workspace_id')
|
||||
const email = searchParams.get('email')
|
||||
@@ -45,6 +45,7 @@ const ActivateForm = () => {
|
||||
const [timezone, setTimezone] = useState('Asia/Shanghai')
|
||||
const [language, setLanguage] = useState('en-US')
|
||||
const [showSuccess, setShowSuccess] = useState(false)
|
||||
const defaultLanguage = (navigator.language?.startsWith('zh') ? languageMaps['zh-Hans'] : languageMaps.en) || languageMaps.en
|
||||
|
||||
const showErrorMessage = (message: string) => {
|
||||
Toast.notify({
|
||||
@@ -83,6 +84,7 @@ const ActivateForm = () => {
|
||||
timezone,
|
||||
},
|
||||
})
|
||||
setLocaleOnClient(language.startsWith('en') ? 'en' : 'zh-Hans')
|
||||
setShowSuccess(true)
|
||||
}
|
||||
catch {
|
||||
@@ -93,7 +95,7 @@ const ActivateForm = () => {
|
||||
return (
|
||||
<div className={
|
||||
cn(
|
||||
'flex flex-col items-center w-full grow items-center justify-center',
|
||||
'flex flex-col items-center w-full grow justify-center',
|
||||
'px-6',
|
||||
'md:px-[108px]',
|
||||
)
|
||||
@@ -167,7 +169,7 @@ const ActivateForm = () => {
|
||||
</label>
|
||||
<div className="relative mt-1 rounded-md shadow-sm">
|
||||
<SimpleSelect
|
||||
defaultValue={languageMaps.en}
|
||||
defaultValue={defaultLanguage}
|
||||
items={languages}
|
||||
onSelect={(item) => {
|
||||
setLanguage(item.value as string)
|
||||
|
||||
@@ -225,7 +225,7 @@ const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedba
|
||||
setLoading(true)
|
||||
const res = await onSubmitAnnotation?.(id, inputValue)
|
||||
if (res)
|
||||
setAnnotation({ ...annotation, content: inputValue } as any)
|
||||
setAnnotation({ ...annotation, content: inputValue } as Annotation)
|
||||
setLoading(false)
|
||||
setShowEdit(false)
|
||||
}}>{t('common.operation.confirm')}</Button>
|
||||
|
||||
@@ -81,7 +81,7 @@ const Chat: FC<IChatProps> = ({
|
||||
const isUseInputMethod = useRef(false)
|
||||
|
||||
const [query, setQuery] = React.useState('')
|
||||
const handleContentChange = (e: any) => {
|
||||
const handleContentChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
const value = e.target.value
|
||||
setQuery(value)
|
||||
}
|
||||
@@ -111,7 +111,7 @@ const Chat: FC<IChatProps> = ({
|
||||
setQuery('')
|
||||
}
|
||||
|
||||
const handleKeyUp = (e: any) => {
|
||||
const handleKeyUp = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.code === 'Enter') {
|
||||
e.preventDefault()
|
||||
// prevent send message when using input method enter
|
||||
@@ -120,7 +120,7 @@ const Chat: FC<IChatProps> = ({
|
||||
}
|
||||
}
|
||||
|
||||
const handleKeyDown = (e: any) => {
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
isUseInputMethod.current = e.nativeEvent.isComposing
|
||||
if (e.code === 'Enter' && !e.shiftKey) {
|
||||
setQuery(query.replace(/\n$/, ''))
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
'use client'
|
||||
import React, { FC, ReactNode } from 'react'
|
||||
import { ReactElement } from 'react-markdown/lib/react-markdown'
|
||||
import React from 'react'
|
||||
import type { FC } from 'react'
|
||||
|
||||
export interface IInputTypeIconProps {
|
||||
type: string
|
||||
type IInputTypeIconProps = {
|
||||
type: 'string' | 'select'
|
||||
}
|
||||
|
||||
const IconMap = (type: string) => {
|
||||
const icons: Record<string, ReactNode> = {
|
||||
'string': (
|
||||
const IconMap = (type: IInputTypeIconProps['type']) => {
|
||||
const icons = {
|
||||
string: (
|
||||
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fillRule="evenodd" clipRule="evenodd" d="M3.52593 0.166672H8.47411C8.94367 0.166665 9.33123 0.166659 9.64692 0.192452C9.97481 0.219242 10.2762 0.276738 10.5593 0.420991C10.9984 0.644695 11.3553 1.00165 11.579 1.44069C11.7233 1.72381 11.7808 2.02522 11.8076 2.35311C11.8334 2.6688 11.8334 3.05634 11.8334 3.5259V8.47411C11.8334 8.94367 11.8334 9.33121 11.8076 9.6469C11.7808 9.97479 11.7233 10.2762 11.579 10.5593C11.3553 10.9984 10.9984 11.3553 10.5593 11.579C10.2762 11.7233 9.97481 11.7808 9.64692 11.8076C9.33123 11.8334 8.94369 11.8333 8.47413 11.8333H3.52592C3.05636 11.8333 2.66882 11.8334 2.35312 11.8076C2.02523 11.7808 1.72382 11.7233 1.44071 11.579C1.00167 11.3553 0.644711 10.9984 0.421006 10.5593C0.276753 10.2762 0.219257 9.97479 0.192468 9.6469C0.166674 9.33121 0.16668 8.94366 0.166687 8.4741V3.52591C0.16668 3.05635 0.166674 2.6688 0.192468 2.35311C0.219257 2.02522 0.276753 1.72381 0.421006 1.44069C0.644711 1.00165 1.00167 0.644695 1.44071 0.420991C1.72382 0.276738 2.02523 0.219242 2.35312 0.192452C2.66882 0.166659 3.05637 0.166665 3.52593 0.166672ZM3.08335 3.08334C3.08335 2.76117 3.34452 2.50001 3.66669 2.50001H8.33335C8.65552 2.50001 8.91669 2.76117 8.91669 3.08334C8.91669 3.4055 8.65552 3.66667 8.33335 3.66667H6.58335V8.91667C6.58335 9.23884 6.32219 9.5 6.00002 9.5C5.67785 9.5 5.41669 9.23884 5.41669 8.91667V3.66667H3.66669C3.34452 3.66667 3.08335 3.4055 3.08335 3.08334Z" fill="#98A2B3" />
|
||||
</svg>
|
||||
),
|
||||
'select': (
|
||||
select: (
|
||||
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fillRule="evenodd" clipRule="evenodd" d="M7.48913 4.08334H3.01083C2.70334 4.08333 2.43804 4.08333 2.21955 4.10118C1.98893 4.12002 1.75955 4.16162 1.53883 4.27408C1.20955 4.44186 0.941831 4.70958 0.774053 5.03886C0.66159 5.25958 0.619989 5.48896 0.601147 5.71958C0.583295 5.93807 0.583304 6.20334 0.583313 6.51084V10.9892C0.583304 11.2967 0.583295 11.5619 0.601147 11.7804C0.619989 12.0111 0.66159 12.2404 0.774053 12.4612C0.941831 12.7904 1.20955 13.0582 1.53883 13.2259C1.75955 13.3384 1.98893 13.38 2.21955 13.3988C2.43803 13.4167 2.70329 13.4167 3.01077 13.4167H7.48912C7.7966 13.4167 8.06193 13.4167 8.28041 13.3988C8.51103 13.38 8.74041 13.3384 8.96113 13.2259C9.29041 13.0582 9.55813 12.7904 9.72591 12.4612C9.83837 12.2404 9.87997 12.0111 9.89882 11.7804C9.91667 11.5619 9.91666 11.2967 9.91665 10.9892V6.51087C9.91666 6.20336 9.91667 5.93808 9.89882 5.71958C9.87997 5.48896 9.83837 5.25958 9.72591 5.03886C9.55813 4.70958 9.29041 4.44186 8.96113 4.27408C8.74041 4.16162 8.51103 4.12002 8.28041 4.10118C8.06192 4.08333 7.79663 4.08333 7.48913 4.08334ZM7.70413 7.70416C7.93193 7.47635 7.93193 7.107 7.70413 6.8792C7.47632 6.65139 7.10697 6.65139 6.87917 6.8792L4.66665 9.09172L3.91246 8.33753C3.68465 8.10973 3.31531 8.10973 3.0875 8.33753C2.8597 8.56534 2.8597 8.93468 3.0875 9.16249L4.25417 10.3292C4.48197 10.557 4.85132 10.557 5.07913 10.3292L7.70413 7.70416Z" fill="#98A2B3" />
|
||||
<path d="M10.9891 0.583344H6.51083C6.20334 0.583334 5.93804 0.583326 5.71955 0.601177C5.48893 0.620019 5.25955 0.66162 5.03883 0.774084C4.70955 0.941862 4.44183 1.20958 4.27405 1.53886C4.16159 1.75958 4.11999 1.98896 4.10115 2.21958C4.08514 2.41545 4.08349 2.64892 4.08333 2.91669L7.51382 2.91668C7.79886 2.91662 8.10791 2.91654 8.37541 2.9384C8.67818 2.96314 9.07818 3.02436 9.49078 3.23459C10.0396 3.51422 10.4858 3.96041 10.7654 4.50922C10.9756 4.92182 11.0369 5.32182 11.0616 5.62459C11.0835 5.8921 11.0834 6.20115 11.0833 6.48619L11.0833 9.91666C11.3511 9.9165 11.5845 9.91485 11.7804 9.89885C12.011 9.88 12.2404 9.8384 12.4611 9.72594C12.7904 9.55816 13.0581 9.29045 13.2259 8.96116C13.3384 8.74044 13.38 8.51106 13.3988 8.28044C13.4167 8.06196 13.4167 7.7967 13.4166 7.48922V3.01087C13.4167 2.70339 13.4167 2.43807 13.3988 2.21958C13.38 1.98896 13.3384 1.75958 13.2259 1.53886C13.0581 1.20958 12.7904 0.941862 12.4611 0.774084C12.2404 0.66162 12.011 0.620019 11.7804 0.601177C11.5619 0.583326 11.2966 0.583334 10.9891 0.583344Z" fill="#98A2B3" />
|
||||
@@ -21,11 +21,11 @@ const IconMap = (type: string) => {
|
||||
),
|
||||
}
|
||||
|
||||
return icons[type] as any
|
||||
return icons[type]
|
||||
}
|
||||
|
||||
const InputTypeIcon: FC<IInputTypeIconProps> = ({
|
||||
type
|
||||
type,
|
||||
}) => {
|
||||
const Icon = IconMap(type)
|
||||
return Icon
|
||||
|
||||
@@ -38,7 +38,7 @@ const Config: FC = () => {
|
||||
setSpeechToTextConfig,
|
||||
} = useContext(ConfigContext)
|
||||
const isChatApp = mode === AppType.chat
|
||||
const { currentProvider } = useProviderContext()
|
||||
const { speech2textDefaultModel } = useProviderContext()
|
||||
|
||||
const promptTemplate = modelConfig.configs.prompt_template
|
||||
const promptVariables = modelConfig.configs.prompt_variables
|
||||
@@ -90,7 +90,7 @@ const Config: FC = () => {
|
||||
},
|
||||
})
|
||||
|
||||
const hasChatConfig = isChatApp && (featureConfig.openingStatement || featureConfig.suggestedQuestionsAfterAnswer || (featureConfig.speechToText && currentProvider?.provider_name === 'openai'))
|
||||
const hasChatConfig = isChatApp && (featureConfig.openingStatement || featureConfig.suggestedQuestionsAfterAnswer || (featureConfig.speechToText && !!speech2textDefaultModel))
|
||||
const hasToolbox = false
|
||||
|
||||
const [showAutomatic, { setTrue: showAutomaticTrue, setFalse: showAutomaticFalse }] = useBoolean(false)
|
||||
@@ -120,7 +120,7 @@ const Config: FC = () => {
|
||||
isChatApp={isChatApp}
|
||||
config={featureConfig}
|
||||
onChange={handleFeatureChange}
|
||||
showSpeechToTextItem={currentProvider?.provider_name === 'openai'}
|
||||
showSpeechToTextItem={!!speech2textDefaultModel}
|
||||
/>
|
||||
)}
|
||||
{showAutomatic && (
|
||||
@@ -160,7 +160,7 @@ const Config: FC = () => {
|
||||
}
|
||||
}
|
||||
isShowSuggestedQuestionsAfterAnswer={featureConfig.suggestedQuestionsAfterAnswer}
|
||||
isShowSpeechText={featureConfig.speechToText && currentProvider?.provider_name === 'openai'}
|
||||
isShowSpeechText={featureConfig.speechToText && !!speech2textDefaultModel}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -6,22 +6,17 @@ import { useTranslation } from 'react-i18next'
|
||||
import TypeIcon from '../type-icon'
|
||||
import RemoveIcon from '../../base/icons/remove-icon'
|
||||
import s from './style.module.css'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { formatNumber } from '@/utils/format'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
|
||||
export type ICardItemProps = {
|
||||
className?: string
|
||||
config: any
|
||||
config: DataSet
|
||||
onRemove: (id: string) => void
|
||||
readonly?: boolean
|
||||
}
|
||||
|
||||
// const RemoveIcon = ({ className, onClick }: { className: string, onClick: () => void }) => (
|
||||
// <svg className={className} onClick={onClick} width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
// <path d="M10 6H14M6 8H18M16.6667 8L16.1991 15.0129C16.129 16.065 16.0939 16.5911 15.8667 16.99C15.6666 17.3412 15.3648 17.6235 15.0011 17.7998C14.588 18 14.0607 18 13.0062 18H10.9938C9.93927 18 9.41202 18 8.99889 17.7998C8.63517 17.6235 8.33339 17.3412 8.13332 16.99C7.90607 16.5911 7.871 16.065 7.80086 15.0129L7.33333 8M10.6667 11V14.3333M13.3333 11V14.3333" stroke="#667085" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
|
||||
// </svg>
|
||||
// )
|
||||
|
||||
const CardItem: FC<ICardItemProps> = ({
|
||||
className,
|
||||
config,
|
||||
|
||||
@@ -52,7 +52,7 @@ const Debug: FC<IDebug> = ({
|
||||
modelConfig,
|
||||
completionParams,
|
||||
} = useContext(ConfigContext)
|
||||
const { currentProvider } = useProviderContext()
|
||||
const { speech2textDefaultModel } = useProviderContext()
|
||||
const [chatList, setChatList, getChatList] = useGetState<IChatItem[]>([])
|
||||
const chatListDomRef = useRef<HTMLDivElement>(null)
|
||||
useEffect(() => {
|
||||
@@ -390,7 +390,7 @@ const Debug: FC<IDebug> = ({
|
||||
}}
|
||||
isShowSuggestion={doShowSuggestion}
|
||||
suggestionList={suggestQuestions}
|
||||
isShowSpeechToText={speechToTextConfig.enabled && currentProvider?.provider_name === 'openai'}
|
||||
isShowSpeechToText={speechToTextConfig.enabled && !!speech2textDefaultModel}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -16,22 +16,21 @@ import ConfigModel from '@/app/components/app/configuration/config-model'
|
||||
import Config from '@/app/components/app/configuration/config'
|
||||
import Debug from '@/app/components/app/configuration/debug'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import { ProviderType } from '@/types/app'
|
||||
import { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import type { AppDetailResponse } from '@/models/app'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { fetchTenantInfo } from '@/service/common'
|
||||
import { fetchAppDetail, updateAppModelConfig } from '@/service/apps'
|
||||
import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
import AccountSetting from '@/app/components/header/account-setting'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
|
||||
const Configuration: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
|
||||
const [hasFetchedDetail, setHasFetchedDetail] = useState(false)
|
||||
const [hasFetchedKey, setHasFetchedKey] = useState(false)
|
||||
const isLoading = !hasFetchedDetail || !hasFetchedKey
|
||||
const isLoading = !hasFetchedDetail
|
||||
const pathname = usePathname()
|
||||
const matched = pathname.match(/\/app\/([^/]+)/)
|
||||
const appId = (matched?.length && matched[1]) ? matched[1] : ''
|
||||
@@ -68,7 +67,7 @@ const Configuration: FC = () => {
|
||||
frequency_penalty: 1, // -2-2
|
||||
})
|
||||
const [modelConfig, doSetModelConfig] = useState<ModelConfig>({
|
||||
provider: ProviderType.openai,
|
||||
provider: ProviderEnum.openai,
|
||||
model_id: 'gpt-3.5-turbo',
|
||||
configs: {
|
||||
prompt_template: '',
|
||||
@@ -85,7 +84,7 @@ const Configuration: FC = () => {
|
||||
doSetModelConfig(newModelConfig)
|
||||
}
|
||||
|
||||
const setModelId = (modelId: string, provider: ProviderType) => {
|
||||
const setModelId = (modelId: string, provider: ProviderEnum) => {
|
||||
const newModelConfig = produce(modelConfig, (draft: any) => {
|
||||
draft.provider = provider
|
||||
draft.model_id = modelId
|
||||
@@ -113,25 +112,27 @@ const Configuration: FC = () => {
|
||||
})
|
||||
}
|
||||
|
||||
const [hasSetCustomAPIKEY, setHasSetCustomerAPIKEY] = useState(true)
|
||||
const [isTrailFinished, setIsTrailFinished] = useState(false)
|
||||
const { textGenerationModelList } = useProviderContext()
|
||||
const hasSetCustomAPIKEY = !!textGenerationModelList?.find(({ model_provider: provider }) => {
|
||||
if (provider.provider_type === 'system' && provider.quota_type === 'paid')
|
||||
return true
|
||||
|
||||
if (provider.provider_type === 'custom')
|
||||
return true
|
||||
|
||||
return false
|
||||
})
|
||||
const isTrailFinished = !hasSetCustomAPIKEY && textGenerationModelList
|
||||
.filter(({ model_provider: provider }) => provider.quota_type === 'trial')
|
||||
.every(({ model_provider: provider }) => {
|
||||
const { quota_used, quota_limit } = provider
|
||||
return quota_used === quota_limit
|
||||
})
|
||||
|
||||
const hasSetAPIKEY = hasSetCustomAPIKEY || !isTrailFinished
|
||||
|
||||
const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean()
|
||||
|
||||
const checkAPIKey = async () => {
|
||||
const { in_trail, trial_end_reason } = await fetchTenantInfo({ url: '/info' })
|
||||
const isTrailFinished = in_trail && trial_end_reason === 'trial_exceeded'
|
||||
const hasSetCustomAPIKEY = trial_end_reason === 'using_custom'
|
||||
setHasSetCustomerAPIKEY(hasSetCustomAPIKEY)
|
||||
setIsTrailFinished(isTrailFinished)
|
||||
setHasFetchedKey(true)
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
checkAPIKey()
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
(fetchAppDetail({ url: '/apps', id: appId }) as any).then(async (res: AppDetailResponse) => {
|
||||
setMode(res.mode)
|
||||
@@ -284,7 +285,7 @@ const Configuration: FC = () => {
|
||||
{/* Model and Parameters */}
|
||||
<ConfigModel
|
||||
mode={mode}
|
||||
provider={modelConfig.provider as ProviderType}
|
||||
provider={modelConfig.provider as ProviderEnum}
|
||||
completionParams={completionParams}
|
||||
modelId={modelConfig.model_id}
|
||||
setModelId={setModelId}
|
||||
@@ -338,7 +339,6 @@ const Configuration: FC = () => {
|
||||
)
|
||||
}
|
||||
{isShowSetAPIKey && <AccountSetting activeTab="provider" onCancel={async () => {
|
||||
await checkAPIKey()
|
||||
hideSetAPIkey()
|
||||
}} />}
|
||||
</>
|
||||
|
||||
@@ -13,6 +13,7 @@ import SettingsModal from './settings'
|
||||
import EmbeddedModal from './embedded'
|
||||
import CustomizeModal from './customize'
|
||||
import style from './style.module.css'
|
||||
import type { ConfigParams } from './settings'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import AppBasic from '@/app/components/app-sidebar/basic'
|
||||
import { asyncRunSafe, randomString } from '@/utils'
|
||||
@@ -31,9 +32,9 @@ export type IAppCardProps = {
|
||||
appInfo: AppDetailResponse
|
||||
cardType?: 'api' | 'webapp'
|
||||
customBgColor?: string
|
||||
onChangeStatus: (val: boolean) => Promise<any>
|
||||
onSaveSiteConfig?: (params: any) => Promise<any>
|
||||
onGenerateCode?: () => Promise<any>
|
||||
onChangeStatus: (val: boolean) => Promise<void>
|
||||
onSaveSiteConfig?: (params: ConfigParams) => Promise<void>
|
||||
onGenerateCode?: () => Promise<void>
|
||||
}
|
||||
|
||||
const EmbedIcon: FC<{ className?: string }> = ({ className = '' }) => {
|
||||
@@ -193,7 +194,7 @@ function AppCard({
|
||||
</div>
|
||||
<div className={'pt-2 flex flex-row items-center'}>
|
||||
{!isApp && <SecretKeyButton className='flex-shrink-0 !h-8 bg-white mr-2' textCls='!text-gray-700 font-medium' iconCls='stroke-[1.2px]' appId={appInfo.id} />}
|
||||
{OPERATIONS_MAP[cardType].map((op: any) => {
|
||||
{OPERATIONS_MAP[cardType].map((op) => {
|
||||
const disabled
|
||||
= op.opName === t('appOverview.overview.appInfo.settings.entry')
|
||||
? false
|
||||
|
||||
@@ -18,7 +18,7 @@ export type ISettingsModalProps = {
|
||||
isShow: boolean
|
||||
defaultValue?: string
|
||||
onClose: () => void
|
||||
onSave?: (params: ConfigParams) => Promise<any>
|
||||
onSave?: (params: ConfigParams) => Promise<void>
|
||||
}
|
||||
|
||||
export type ConfigParams = {
|
||||
@@ -26,6 +26,10 @@ export type ConfigParams = {
|
||||
description: string
|
||||
default_language: string
|
||||
prompt_public: boolean
|
||||
copyright: string
|
||||
privacy_policy: string
|
||||
icon: string
|
||||
icon_background: string
|
||||
}
|
||||
|
||||
const LANGUAGE_MAP: Record<Language, string> = {
|
||||
@@ -82,7 +86,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
|
||||
}
|
||||
|
||||
const onChange = (field: string) => {
|
||||
return (e: any) => {
|
||||
return (e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement>) => {
|
||||
setInputInfo(item => ({ ...item, [field]: e.target.value }))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import { HashtagIcon } from '@heroicons/react/24/solid'
|
||||
import { Markdown } from '@/app/components/base/markdown'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import type { Feedbacktype } from '@/app/components/app/chat'
|
||||
import type { Feedbacktype } from '@/app/components/app/chat/type'
|
||||
import { fetchMoreLikeThis, updateFeedback } from '@/service/share'
|
||||
|
||||
const MAX_DEPTH = 3
|
||||
@@ -136,7 +136,7 @@ const GenerationItem: FC<IGenerationItemProps> = ({
|
||||
}
|
||||
|
||||
const mainStyle = (() => {
|
||||
const res: any = !isTop
|
||||
const res: React.CSSProperties = !isTop
|
||||
? {
|
||||
background: depth % 2 === 0 ? 'linear-gradient(90.07deg, #F9FAFB 0.05%, rgba(249, 250, 251, 0) 99.93%)' : '#fff',
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ export type IButtonProps = {
|
||||
className?: string
|
||||
disabled?: boolean
|
||||
loading?: boolean
|
||||
tabIndex?: number
|
||||
children: React.ReactNode
|
||||
onClick?: MouseEventHandler<HTMLDivElement>
|
||||
}
|
||||
@@ -18,6 +19,7 @@ const Button: FC<IButtonProps> = ({
|
||||
className,
|
||||
onClick,
|
||||
loading = false,
|
||||
tabIndex,
|
||||
}) => {
|
||||
let style = 'cursor-pointer'
|
||||
switch (type) {
|
||||
@@ -35,6 +37,7 @@ const Button: FC<IButtonProps> = ({
|
||||
return (
|
||||
<div
|
||||
className={`inline-flex justify-center items-center content-center h-9 leading-5 rounded-lg px-4 py-2 text-base ${style} ${className && className}`}
|
||||
tabIndex={tabIndex}
|
||||
onClick={disabled ? undefined : onClick}
|
||||
>
|
||||
{children}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
'use client'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import React, { FC, useEffect, useState, useRef } from 'react'
|
||||
import React, { useEffect, useRef, useState } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import { createRoot } from 'react-dom/client'
|
||||
|
||||
export interface IPortalToFollowElementProps {
|
||||
type IPortalToFollowElementProps = {
|
||||
portalElem: React.ReactNode
|
||||
children: React.ReactNode
|
||||
controlShow?: number
|
||||
@@ -14,44 +15,42 @@ const PortalToFollowElement: FC<IPortalToFollowElementProps> = ({
|
||||
portalElem,
|
||||
children,
|
||||
controlShow,
|
||||
controlHide
|
||||
controlHide,
|
||||
}) => {
|
||||
const [isShowContent, { setTrue: showContent, setFalse: hideContent, toggle: toggleContent }] = useBoolean(false)
|
||||
const [wrapElem, setWrapElem] = useState<HTMLDivElement | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (controlShow) {
|
||||
if (controlShow)
|
||||
showContent()
|
||||
}
|
||||
}, [controlShow])
|
||||
|
||||
useEffect(() => {
|
||||
if (controlHide) {
|
||||
if (controlHide)
|
||||
hideContent()
|
||||
}
|
||||
}, [controlHide])
|
||||
|
||||
// todo use click outside hidden
|
||||
const triggerElemRef = useRef<HTMLElement>(null)
|
||||
const triggerElemRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const calLoc = () => {
|
||||
const triggerElem = triggerElemRef.current
|
||||
if (!triggerElem) {
|
||||
return {
|
||||
display: 'none'
|
||||
display: 'none',
|
||||
}
|
||||
}
|
||||
const {
|
||||
left: triggerLeft,
|
||||
top: triggerTop,
|
||||
height
|
||||
} = triggerElem.getBoundingClientRect();
|
||||
height,
|
||||
} = triggerElem.getBoundingClientRect()
|
||||
|
||||
return {
|
||||
position: 'fixed',
|
||||
left: triggerLeft,
|
||||
top: triggerTop + height,
|
||||
zIndex: 999
|
||||
zIndex: 999,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,19 +62,20 @@ const PortalToFollowElement: FC<IPortalToFollowElementProps> = ({
|
||||
root.render(
|
||||
<div style={style as React.CSSProperties}>
|
||||
{portalElem}
|
||||
</div>
|
||||
</div>,
|
||||
)
|
||||
document.body.appendChild(holder)
|
||||
setWrapElem(holder)
|
||||
console.log(holder)
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
wrapElem?.remove?.()
|
||||
setWrapElem(null)
|
||||
}
|
||||
}, [isShowContent])
|
||||
|
||||
return (
|
||||
<div ref={triggerElemRef as any} onClick={toggleContent}>
|
||||
<div ref={triggerElemRef as React.RefObject<HTMLDivElement>} onClick={toggleContent}>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -26,7 +26,7 @@ export default function Radio({
|
||||
}: IRadioProps): JSX.Element {
|
||||
const groupContext = useContext(RadioGroupContext)
|
||||
const labelId = useId()
|
||||
const handleChange = (e: any) => {
|
||||
const handleChange = (e: IRadioProps['value']) => {
|
||||
if (disabled)
|
||||
return
|
||||
|
||||
|
||||
@@ -9,9 +9,10 @@ import StepTwo from './step-two'
|
||||
import StepThree from './step-three'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import type { DataSet, FileItem, createDocumentResponse } from '@/models/datasets'
|
||||
import { fetchDataSource, fetchTenantInfo } from '@/service/common'
|
||||
import { fetchDataSource } from '@/service/common'
|
||||
import { fetchDataDetail } from '@/service/datasets'
|
||||
import type { DataSourceNotionPage } from '@/models/common'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
|
||||
import AccountSetting from '@/app/components/header/account-setting'
|
||||
|
||||
@@ -23,7 +24,6 @@ type DatasetUpdateFormProps = {
|
||||
|
||||
const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [hasSetAPIKEY, setHasSetAPIKEY] = useState(true)
|
||||
const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean()
|
||||
const [hasConnection, setHasConnection] = useState(true)
|
||||
const [isShowDataSourceSetting, { setTrue: showDataSourceSetting, setFalse: hideDataSourceSetting }] = useBoolean()
|
||||
@@ -33,6 +33,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
|
||||
const [fileList, setFiles] = useState<FileItem[]>([])
|
||||
const [result, setResult] = useState<createDocumentResponse | undefined>()
|
||||
const [hasError, setHasError] = useState(false)
|
||||
const { embeddingsDefaultModel } = useProviderContext()
|
||||
|
||||
const [notionPages, setNotionPages] = useState<Page[]>([])
|
||||
const updateNotionPages = (value: Page[]) => {
|
||||
@@ -77,11 +78,6 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
|
||||
setStep(step + delta)
|
||||
}, [step, setStep])
|
||||
|
||||
const checkAPIKey = async () => {
|
||||
const data = await fetchTenantInfo({ url: '/info' })
|
||||
const hasSetKey = data.providers.some(({ is_valid }) => is_valid)
|
||||
setHasSetAPIKEY(hasSetKey)
|
||||
}
|
||||
const checkNotionConnection = async () => {
|
||||
const { data } = await fetchDataSource({ url: '/data-source/integrates' })
|
||||
const hasConnection = data.filter(item => item.provider === 'notion') || []
|
||||
@@ -89,7 +85,6 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
checkAPIKey()
|
||||
checkNotionConnection()
|
||||
}, [])
|
||||
|
||||
@@ -132,7 +127,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
|
||||
onStepChange={nextStep}
|
||||
/>}
|
||||
{(step === 2 && (!datasetId || (datasetId && !!detail))) && <StepTwo
|
||||
hasSetAPIKEY={hasSetAPIKEY}
|
||||
hasSetAPIKEY={!!embeddingsDefaultModel}
|
||||
onSetting={showSetAPIKey}
|
||||
indexingType={detail?.indexing_technique || ''}
|
||||
datasetId={datasetId}
|
||||
@@ -151,7 +146,6 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
|
||||
/>}
|
||||
</div>
|
||||
{isShowSetAPIKey && <AccountSetting activeTab="provider" onCancel={async () => {
|
||||
await checkAPIKey()
|
||||
hideSetAPIkey()
|
||||
}} />}
|
||||
{isShowDataSourceSetting && <AccountSetting activeTab="data-source" onCancel={hideDataSourceSetting}/>}
|
||||
|
||||
@@ -28,6 +28,7 @@ import AutoHeightTextarea from '@/app/components/base/auto-height-textarea/commo
|
||||
import Button from '@/app/components/base/button'
|
||||
import NewSegmentModal from '@/app/components/datasets/documents/detail/new-segment-modal'
|
||||
import TagInput from '@/app/components/base/tag-input'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
|
||||
export const SegmentIndexTag: FC<{ positionId: string | number; className?: string }> = ({ positionId, className }) => {
|
||||
const localPositionId = useMemo(() => {
|
||||
@@ -66,6 +67,15 @@ export const SegmentDetail: FC<ISegmentDetailProps> = memo(({
|
||||
const [question, setQuestion] = useState(segInfo?.content || '')
|
||||
const [answer, setAnswer] = useState(segInfo?.answer || '')
|
||||
const [keywords, setKeywords] = useState<string[]>(segInfo?.keywords || [])
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
eventEmitter?.useSubscription((v) => {
|
||||
if (v === 'update-segment')
|
||||
setLoading(true)
|
||||
else
|
||||
setLoading(false)
|
||||
})
|
||||
|
||||
const handleCancel = () => {
|
||||
setIsEditing(false)
|
||||
@@ -129,7 +139,9 @@ export const SegmentDetail: FC<ISegmentDetailProps> = memo(({
|
||||
<Button
|
||||
type='primary'
|
||||
className='!h-7 !px-3 !py-[5px] text-xs font-medium !rounded-md'
|
||||
onClick={handleSave}>
|
||||
onClick={handleSave}
|
||||
disabled={loading}
|
||||
>
|
||||
{t('common.operation.save')}
|
||||
</Button>
|
||||
</>
|
||||
@@ -225,6 +237,7 @@ const Completed: FC<ICompletedProps> = ({
|
||||
const [allSegments, setAllSegments] = useState<Array<SegmentDetailModel[]>>([]) // all segments data
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [total, setTotal] = useState<number | undefined>()
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
|
||||
const onChangeStatus = ({ value }: Item) => {
|
||||
setSelectedStatus(value === 'all' ? 'all' : !!value)
|
||||
@@ -318,23 +331,29 @@ const Completed: FC<ICompletedProps> = ({
|
||||
if (keywords.length)
|
||||
params.keywords = keywords
|
||||
|
||||
const res = await updateSegment({ datasetId, documentId, segmentId, body: params })
|
||||
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
|
||||
onCloseModal()
|
||||
for (const item of allSegments) {
|
||||
for (const seg of item) {
|
||||
if (seg.id === segmentId) {
|
||||
seg.answer = res.data.answer
|
||||
seg.content = res.data.content
|
||||
seg.keywords = res.data.keywords
|
||||
seg.word_count = res.data.word_count
|
||||
seg.hit_count = res.data.hit_count
|
||||
seg.index_node_hash = res.data.index_node_hash
|
||||
seg.enabled = res.data.enabled
|
||||
try {
|
||||
eventEmitter?.emit('update-segment')
|
||||
const res = await updateSegment({ datasetId, documentId, segmentId, body: params })
|
||||
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
|
||||
onCloseModal()
|
||||
for (const item of allSegments) {
|
||||
for (const seg of item) {
|
||||
if (seg.id === segmentId) {
|
||||
seg.answer = res.data.answer
|
||||
seg.content = res.data.content
|
||||
seg.keywords = res.data.keywords
|
||||
seg.word_count = res.data.word_count
|
||||
seg.hit_count = res.data.hit_count
|
||||
seg.index_node_hash = res.data.index_node_hash
|
||||
seg.enabled = res.data.enabled
|
||||
}
|
||||
}
|
||||
}
|
||||
setAllSegments([...allSegments])
|
||||
}
|
||||
finally {
|
||||
eventEmitter?.emit('')
|
||||
}
|
||||
setAllSegments([...allSegments])
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
@@ -127,7 +127,7 @@ const DocumentDetail: FC<Props> = ({ datasetId, documentId }) => {
|
||||
</div>
|
||||
<Divider className='!h-4' type='vertical' />
|
||||
<DocumentTitle extension={documentDetail?.data_source_info?.upload_file?.extension} name={documentDetail?.name} />
|
||||
<StatusItem status={documentDetail?.display_status || 'available'} scene='detail' />
|
||||
<StatusItem status={documentDetail?.display_status || 'available'} scene='detail' errorMessage={documentDetail?.error || ''} />
|
||||
{documentDetail && !documentDetail.archived && (
|
||||
<SegmentAdd
|
||||
importStatus={importStatus}
|
||||
@@ -170,7 +170,7 @@ const DocumentDetail: FC<Props> = ({ datasetId, documentId }) => {
|
||||
</div>
|
||||
}
|
||||
{showMetadata && <Metadata
|
||||
docDetail={{ ...documentDetail, ...documentMetadata } as any}
|
||||
docDetail={{ ...documentDetail, ...documentMetadata, doc_type: documentDetail?.doc_type === 'others' ? '' : documentDetail?.doc_type } as any}
|
||||
loading={isMetadataLoading}
|
||||
onUpdate={metadataMutate}
|
||||
/>}
|
||||
|
||||
@@ -31,6 +31,7 @@ const NewSegmentModal: FC<NewSegmentModalProps> = memo(({
|
||||
const [answer, setAnswer] = useState('')
|
||||
const { datasetId, documentId } = useParams()
|
||||
const [keywords, setKeywords] = useState<string[]>([])
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const handleCancel = () => {
|
||||
setQuestion('')
|
||||
@@ -60,10 +61,16 @@ const NewSegmentModal: FC<NewSegmentModalProps> = memo(({
|
||||
if (keywords?.length)
|
||||
params.keywords = keywords
|
||||
|
||||
await addSegment({ datasetId, documentId, body: params })
|
||||
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
|
||||
handleCancel()
|
||||
onSave()
|
||||
setLoading(true)
|
||||
try {
|
||||
await addSegment({ datasetId, documentId, body: params })
|
||||
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
|
||||
handleCancel()
|
||||
onSave()
|
||||
}
|
||||
finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
const renderContent = () => {
|
||||
@@ -136,7 +143,9 @@ const NewSegmentModal: FC<NewSegmentModalProps> = memo(({
|
||||
<Button
|
||||
type='primary'
|
||||
className='!h-9 !px-4 !py-2 text-sm font-medium !rounded-lg'
|
||||
onClick={handleSave}>
|
||||
onClick={handleSave}
|
||||
disabled={loading}
|
||||
>
|
||||
{t('common.operation.save')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
@@ -6,7 +6,6 @@ import { useContext } from 'use-context-selector'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import DatasetDetailContext from '@/context/dataset-detail'
|
||||
import type { FullDocumentDetail } from '@/models/datasets'
|
||||
import { fetchTenantInfo } from '@/service/common'
|
||||
import type { MetadataType } from '@/service/datasets'
|
||||
import { fetchDocumentDetail } from '@/service/datasets'
|
||||
|
||||
@@ -14,6 +13,7 @@ import Loading from '@/app/components/base/loading'
|
||||
import StepTwo from '@/app/components/datasets/create/step-two'
|
||||
import AccountSetting from '@/app/components/header/account-setting'
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
|
||||
type DocumentSettingsProps = {
|
||||
datasetId: string
|
||||
@@ -23,25 +23,15 @@ type DocumentSettingsProps = {
|
||||
const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
|
||||
const { t } = useTranslation()
|
||||
const router = useRouter()
|
||||
const [hasSetAPIKEY, setHasSetAPIKEY] = useState(true)
|
||||
const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean()
|
||||
const [hasError, setHasError] = useState(false)
|
||||
const { indexingTechnique, dataset } = useContext(DatasetDetailContext)
|
||||
const { embeddingsDefaultModel } = useProviderContext()
|
||||
|
||||
const saveHandler = () => router.push(`/datasets/${datasetId}/documents/${documentId}`)
|
||||
|
||||
const cancelHandler = () => router.back()
|
||||
|
||||
const checkAPIKey = async () => {
|
||||
const data = await fetchTenantInfo({ url: '/info' })
|
||||
const hasSetKey = data.providers.some(({ is_valid }) => is_valid)
|
||||
setHasSetAPIKEY(hasSetKey)
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
checkAPIKey()
|
||||
}, [])
|
||||
|
||||
const [documentDetail, setDocumentDetail] = useState<FullDocumentDetail | null>(null)
|
||||
const currentPage = useMemo(() => {
|
||||
return {
|
||||
@@ -77,7 +67,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
|
||||
{!documentDetail && <Loading type='app' />}
|
||||
{dataset && documentDetail && (
|
||||
<StepTwo
|
||||
hasSetAPIKEY={hasSetAPIKEY}
|
||||
hasSetAPIKEY={!!embeddingsDefaultModel}
|
||||
onSetting={showSetAPIKey}
|
||||
datasetId={datasetId}
|
||||
dataSourceType={documentDetail.data_source_type}
|
||||
@@ -92,7 +82,6 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
|
||||
)}
|
||||
</div>
|
||||
{isShowSetAPIKey && <AccountSetting activeTab="provider" onCancel={async () => {
|
||||
await checkAPIKey()
|
||||
hideSetAPIkey()
|
||||
}} />}
|
||||
</div>
|
||||
|
||||
@@ -27,7 +27,7 @@ import NotionIcon from '@/app/components/base/notion-icon'
|
||||
import ProgressBar from '@/app/components/base/progress-bar'
|
||||
import { DataSourceType, type DocumentDisplayStatus, type SimpleDocumentDetail } from '@/models/datasets'
|
||||
import type { CommonResponse } from '@/models/common'
|
||||
import { DotsHorizontal } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import { DotsHorizontal, HelpCircle } from '@/app/components/base/icons/src/vender/line/general'
|
||||
|
||||
export const SettingsIcon: FC<{ className?: string }> = ({ className }) => {
|
||||
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
|
||||
@@ -73,7 +73,8 @@ export const StatusItem: FC<{
|
||||
reverse?: boolean
|
||||
scene?: 'list' | 'detail'
|
||||
textCls?: string
|
||||
}> = ({ status, reverse = false, scene = 'list', textCls = '' }) => {
|
||||
errorMessage?: string
|
||||
}> = ({ status, reverse = false, scene = 'list', textCls = '', errorMessage }) => {
|
||||
const DOC_INDEX_STATUS_MAP = useIndexStatus()
|
||||
const localStatus = status.toLowerCase() as keyof typeof DOC_INDEX_STATUS_MAP
|
||||
return <div className={
|
||||
@@ -83,6 +84,18 @@ export const StatusItem: FC<{
|
||||
}>
|
||||
<Indicator color={DOC_INDEX_STATUS_MAP[localStatus]?.color as IndicatorProps['color']} className={reverse ? 'ml-2' : 'mr-2'} />
|
||||
<span className={cn('text-gray-700 text-sm', textCls)}>{DOC_INDEX_STATUS_MAP[localStatus]?.text}</span>
|
||||
{
|
||||
errorMessage && (
|
||||
<Tooltip
|
||||
selector='dataset-document-detail-item-status'
|
||||
htmlContent={
|
||||
<div className='max-w-[260px]'>{errorMessage}</div>
|
||||
}
|
||||
>
|
||||
<HelpCircle className='ml-1 w-[14px] h-[14px] text-gray-700' />
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
}
|
||||
|
||||
@@ -140,73 +153,6 @@ export const OperationAction: FC<{
|
||||
onUpdate(operationName)
|
||||
}
|
||||
|
||||
const Operations = (props: any) => <div className='w-full py-1'>
|
||||
{!isListScene && <>
|
||||
<div className='flex justify-between items-center mx-4 pt-2'>
|
||||
<span className={cn(s.actionName, 'font-medium')}>
|
||||
{!archived && enabled ? t('datasetDocuments.list.index.enable') : t('datasetDocuments.list.index.disable')}
|
||||
</span>
|
||||
<Tooltip
|
||||
selector={`detail-switch-${id}`}
|
||||
content={t('datasetDocuments.list.action.enableWarning') as string}
|
||||
className='!font-semibold'
|
||||
disabled={!archived}
|
||||
>
|
||||
<div>
|
||||
<Switch
|
||||
defaultValue={archived ? false : enabled}
|
||||
onChange={v => !archived && onOperate(v ? 'enable' : 'disable')}
|
||||
disabled={archived}
|
||||
size='md'
|
||||
/>
|
||||
</div>
|
||||
</Tooltip>
|
||||
</div>
|
||||
<div className='mx-4 pb-1 pt-0.5 text-xs text-gray-500'>
|
||||
{!archived && enabled ? t('datasetDocuments.list.index.enableTip') : t('datasetDocuments.list.index.disableTip')}
|
||||
</div>
|
||||
<Divider />
|
||||
</>}
|
||||
{!archived && (
|
||||
<>
|
||||
<div className={s.actionItem} onClick={() => router.push(`/datasets/${datasetId}/documents/${detail.id}/settings`)}>
|
||||
<SettingsIcon />
|
||||
<span className={s.actionName}>{t('datasetDocuments.list.action.settings')}</span>
|
||||
</div>
|
||||
{
|
||||
!isListScene && (
|
||||
<div className={s.actionItem} onClick={showNewSegmentModal}>
|
||||
<FilePlus02 className='w-4 h-4 text-gray-500' />
|
||||
<span className={s.actionName}>{t('datasetDocuments.list.action.add')}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
data_source_type === 'notion_import' && (
|
||||
<div className={s.actionItem} onClick={() => onOperate('sync')}>
|
||||
<SyncIcon />
|
||||
<span className={s.actionName}>{t('datasetDocuments.list.action.sync')}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<Divider className='my-1' />
|
||||
</>
|
||||
)}
|
||||
{!archived && <div className={s.actionItem} onClick={() => onOperate('archive')}>
|
||||
<ArchiveIcon />
|
||||
<span className={s.actionName}>{t('datasetDocuments.list.action.archive')}</span>
|
||||
</div>}
|
||||
<div
|
||||
className={cn(s.actionItem, s.deleteActionItem, 'group')}
|
||||
onClick={() => {
|
||||
setShowModal(true)
|
||||
props?.onClose()
|
||||
}}>
|
||||
<TrashIcon className={'w-4 h-4 stroke-current text-gray-500 stroke-2 group-hover:text-red-500'} />
|
||||
<span className={cn(s.actionName, 'group-hover:text-red-500')}>{t('datasetDocuments.list.action.delete')}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
return <div className='flex items-center' onClick={e => e.stopPropagation()}>
|
||||
{isListScene && <>
|
||||
{archived
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
'use client'
|
||||
import { useEffect, useState } from 'react'
|
||||
import type { Dispatch } from 'react'
|
||||
import useSWR from 'swr'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { BookOpenIcon } from '@heroicons/react/24/outline'
|
||||
@@ -24,7 +25,7 @@ const labelClass = `
|
||||
const inputClass = `
|
||||
w-[480px] px-3 bg-gray-100 text-sm text-gray-800 rounded-lg outline-none appearance-none
|
||||
`
|
||||
const useInitialValue = (depend: any, dispatch: any) => {
|
||||
const useInitialValue: <T>(depend: T, dispatch: Dispatch<T>) => void = (depend, dispatch) => {
|
||||
useEffect(() => {
|
||||
dispatch(depend)
|
||||
}, [depend])
|
||||
|
||||
@@ -79,7 +79,7 @@ const SecretKeyModal = ({
|
||||
return `${token.slice(0, 3)}...${token.slice(-20)}`
|
||||
}
|
||||
|
||||
const formatDate = (timestamp: any) => {
|
||||
const formatDate = (timestamp: string) => {
|
||||
if (locale === 'en')
|
||||
return new Intl.DateTimeFormat('en-US', { year: 'numeric', month: 'long', day: 'numeric' }).format((+timestamp) * 1000)
|
||||
else
|
||||
|
||||
@@ -13,8 +13,10 @@ import AppCard from '@/app/components/explore/app-card'
|
||||
import { fetchAppDetail, fetchAppList, installApp } from '@/service/explore'
|
||||
import { createApp } from '@/service/apps'
|
||||
import CreateAppModal from '@/app/components/explore/create-app-modal'
|
||||
import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { type AppMode } from '@/types/app'
|
||||
|
||||
const Apps: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
@@ -50,7 +52,7 @@ const Apps: FC = () => {
|
||||
|
||||
const [currApp, setCurrApp] = React.useState<App | null>(null)
|
||||
const [isShowCreateModal, setIsShowCreateModal] = React.useState(false)
|
||||
const onCreate = async ({ name, icon, icon_background }: any) => {
|
||||
const onCreate: CreateAppModalProps['onConfirm'] = async ({ name, icon, icon_background }) => {
|
||||
const { app_model_config: model_config } = await fetchAppDetail(currApp?.app.id as string)
|
||||
|
||||
try {
|
||||
@@ -58,7 +60,7 @@ const Apps: FC = () => {
|
||||
name,
|
||||
icon,
|
||||
icon_background,
|
||||
mode: currApp?.app.mode as any,
|
||||
mode: currApp?.app.mode as AppMode,
|
||||
config: model_config,
|
||||
})
|
||||
setIsShowCreateModal(false)
|
||||
|
||||
@@ -25,7 +25,7 @@ const Category: FC<ICategoryProps> = ({
|
||||
const itemClassName = (isSelected: boolean) => cn(isSelected ? 'bg-white text-primary-600 border-gray-200 font-semibold' : 'border-transparent font-medium', 'flex items-center h-7 px-3 border cursor-pointer rounded-lg')
|
||||
const itemStyle = (isSelected: boolean) => isSelected ? { boxShadow: '0px 1px 2px rgba(16, 24, 40, 0.05)' } : {}
|
||||
return (
|
||||
<div className={cn(className, 'flex space-x-1 text-[13px]')}>
|
||||
<div className={cn(className, 'flex space-x-1 text-[13px] flex-wrap')}>
|
||||
<div
|
||||
className={itemClassName(value === '')}
|
||||
style={itemStyle(value === '')}
|
||||
|
||||
@@ -9,10 +9,14 @@ import Toast from '@/app/components/base/toast'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import EmojiPicker from '@/app/components/base/emoji-picker'
|
||||
|
||||
type IProps = {
|
||||
export type CreateAppModalProps = {
|
||||
appName: string
|
||||
show: boolean
|
||||
onConfirm: (info: any) => void
|
||||
onConfirm: (info: {
|
||||
name: string
|
||||
icon: string
|
||||
icon_background: string
|
||||
}) => Promise<void>
|
||||
onHide: () => void
|
||||
}
|
||||
|
||||
@@ -21,7 +25,7 @@ const CreateAppModal = ({
|
||||
show = false,
|
||||
onConfirm,
|
||||
onHide,
|
||||
}: IProps) => {
|
||||
}: CreateAppModalProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const [name, setName] = React.useState('')
|
||||
|
||||
@@ -6,12 +6,13 @@ import { useTranslation } from 'react-i18next'
|
||||
import s from './style.module.css'
|
||||
import Config from '@/app/components/explore/universal-chat/config'
|
||||
import type { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
|
||||
type Props = {
|
||||
modelId: string
|
||||
providerName: ProviderEnum
|
||||
plugins: Record<string, boolean>
|
||||
dataSets: any[]
|
||||
dataSets: DataSet[]
|
||||
}
|
||||
const ConfigViewPanel: FC<Props> = ({
|
||||
modelId,
|
||||
|
||||
@@ -10,12 +10,13 @@ import ConfigDetail from '@/app/components/explore/universal-chat/config-view/de
|
||||
import type { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import ModelName from '@/app/components/app/configuration/config-model/model-name'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
|
||||
export type ISummaryProps = {
|
||||
modelId: string
|
||||
providerName: ProviderEnum
|
||||
plugins: Record<string, boolean>
|
||||
dataSets: any[]
|
||||
dataSets: DataSet[]
|
||||
}
|
||||
|
||||
const getColorInfo = (modelId: string) => {
|
||||
|
||||
@@ -5,6 +5,7 @@ import ModelConfig from './model-config'
|
||||
import DataConfig from './data-config'
|
||||
import PluginConfig from './plugins-config'
|
||||
import type { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
|
||||
export type IConfigProps = {
|
||||
className?: string
|
||||
@@ -14,8 +15,8 @@ export type IConfigProps = {
|
||||
onModelChange?: (modelId: string, providerName: ProviderEnum) => void
|
||||
plugins: Record<string, boolean>
|
||||
onPluginChange?: (key: string, value: boolean) => void
|
||||
dataSets: any[]
|
||||
onDataSetsChange?: (contexts: any[]) => void
|
||||
dataSets: DataSet[]
|
||||
onDataSetsChange?: (contexts: DataSet[]) => void
|
||||
}
|
||||
|
||||
const Config: FC<IConfigProps> = ({
|
||||
|
||||
@@ -19,7 +19,8 @@ const plugins = [
|
||||
{ key: 'google_search', icon: <Google /> },
|
||||
{ key: 'web_reader', icon: <WebReader /> },
|
||||
{ key: 'wikipedia', icon: <Wikipedia /> },
|
||||
]
|
||||
] as const
|
||||
|
||||
const Plugins: FC<IPluginsProps> = ({
|
||||
readonly,
|
||||
config,
|
||||
|
||||
@@ -42,12 +42,12 @@ const config: ProviderConfig = {
|
||||
key: 'app_id',
|
||||
required: true,
|
||||
label: {
|
||||
'en': 'API ID',
|
||||
'zh-Hans': 'API ID',
|
||||
'en': 'APPID',
|
||||
'zh-Hans': 'APPID',
|
||||
},
|
||||
placeholder: {
|
||||
'en': 'Enter your API ID here',
|
||||
'zh-Hans': '在此输入您的 API ID',
|
||||
'en': 'Enter your APPID here',
|
||||
'zh-Hans': '在此输入您的 APPID',
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -55,12 +55,12 @@ const config: ProviderConfig = {
|
||||
key: 'api_secret',
|
||||
required: true,
|
||||
label: {
|
||||
'en': 'API Secret',
|
||||
'zh-Hans': 'API Secret',
|
||||
'en': 'APISecret',
|
||||
'zh-Hans': 'APISecret',
|
||||
},
|
||||
placeholder: {
|
||||
'en': 'Enter your API Secret here',
|
||||
'zh-Hans': '在此输入您的 API Secret',
|
||||
'en': 'Enter your APISecret here',
|
||||
'zh-Hans': '在此输入您的 APISecret',
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -68,12 +68,12 @@ const config: ProviderConfig = {
|
||||
key: 'api_key',
|
||||
required: true,
|
||||
label: {
|
||||
'en': 'API Key',
|
||||
'zh-Hans': 'API Key',
|
||||
'en': 'APIKey',
|
||||
'zh-Hans': 'APIKey',
|
||||
},
|
||||
placeholder: {
|
||||
'en': 'Enter your API key here',
|
||||
'zh-Hans': '在此输入您的 API Key',
|
||||
'en': 'Enter your APIKey here',
|
||||
'zh-Hans': '在此输入您的 APIKey',
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useState } from 'react'
|
||||
import useSWR from 'swr'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import type {
|
||||
BackendModel,
|
||||
FormValue,
|
||||
@@ -30,23 +31,13 @@ import { ModelType } from '@/app/components/header/account-setting/model-page/de
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import I18n from '@/context/i18n'
|
||||
|
||||
const MODEL_CARD_LIST = [
|
||||
config.openai,
|
||||
config.anthropic,
|
||||
]
|
||||
|
||||
const MODEL_LIST = [
|
||||
config.azure_openai,
|
||||
config.replicate,
|
||||
config.huggingface_hub,
|
||||
config.minimax,
|
||||
config.spark,
|
||||
config.tongyi,
|
||||
config.wenxin,
|
||||
config.chatglm,
|
||||
]
|
||||
|
||||
const titleClassName = `
|
||||
flex items-center h-9 text-sm font-medium text-gray-900
|
||||
`
|
||||
@@ -61,11 +52,16 @@ type DeleteModel = {
|
||||
|
||||
const ModelPage = () => {
|
||||
const { t } = useTranslation()
|
||||
const { updateModelList } = useProviderContext()
|
||||
const { locale } = useContext(I18n)
|
||||
const {
|
||||
updateModelList,
|
||||
embeddingsDefaultModel,
|
||||
mutateEmbeddingsDefaultModel,
|
||||
speech2textDefaultModel,
|
||||
mutateSpeech2textDefaultModel,
|
||||
} = useProviderContext()
|
||||
const { data: providers, mutate: mutateProviders } = useSWR('/workspaces/current/model-providers', fetchModelProviders)
|
||||
const { data: textGenerationDefaultModel, mutate: mutateTextGenerationDefaultModel } = useSWR('/workspaces/current/default-model?model_type=text-generation', fetchDefaultModal)
|
||||
const { data: embeddingsDefaultModel, mutate: mutateEmbeddingsDefaultModel } = useSWR('/workspaces/current/default-model?model_type=embeddings', fetchDefaultModal)
|
||||
const { data: speech2textDefaultModel, mutate: mutateSpeech2textDefaultModel } = useSWR('/workspaces/current/default-model?model_type=speech2text', fetchDefaultModal)
|
||||
const [showMoreModel, setShowMoreModel] = useState(false)
|
||||
const [showModal, setShowModal] = useState(false)
|
||||
const { notify } = useToastContext()
|
||||
@@ -75,6 +71,33 @@ const ModelPage = () => {
|
||||
const [deleteModel, setDeleteModel] = useState<DeleteModel & { providerKey: ProviderEnum }>()
|
||||
const [modalMode, setModalMode] = useState('add')
|
||||
|
||||
let modelList = []
|
||||
|
||||
if (locale === 'en') {
|
||||
modelList = [
|
||||
config.azure_openai,
|
||||
config.replicate,
|
||||
config.huggingface_hub,
|
||||
config.minimax,
|
||||
config.spark,
|
||||
config.tongyi,
|
||||
config.wenxin,
|
||||
config.chatglm,
|
||||
]
|
||||
}
|
||||
else {
|
||||
modelList = [
|
||||
config.huggingface_hub,
|
||||
config.minimax,
|
||||
config.spark,
|
||||
config.azure_openai,
|
||||
config.replicate,
|
||||
config.tongyi,
|
||||
config.wenxin,
|
||||
config.chatglm,
|
||||
]
|
||||
}
|
||||
|
||||
const handleOpenModal = (newModelModalConfig: ProviderConfigModal | undefined, editValue?: FormValue) => {
|
||||
if (newModelModalConfig) {
|
||||
setShowModal(true)
|
||||
@@ -280,7 +303,7 @@ const ModelPage = () => {
|
||||
}
|
||||
</div>
|
||||
{
|
||||
MODEL_LIST.slice(0, showMoreModel ? MODEL_LIST.length : 3).map((model, index) => (
|
||||
modelList.slice(0, showMoreModel ? modelList.length : 3).map((model, index) => (
|
||||
<ModelItem
|
||||
key={index}
|
||||
modelItem={model.item}
|
||||
|
||||
@@ -8,7 +8,7 @@ import Button from '@/app/components/base/button'
|
||||
|
||||
type CardProps = {
|
||||
providerType: ProviderEnum
|
||||
models: any[]
|
||||
models: Model[]
|
||||
onOpenModal: (v: any) => void
|
||||
onOperate: (v: Record<string, any>) => void
|
||||
}
|
||||
@@ -33,7 +33,7 @@ const Card: FC<CardProps> = ({
|
||||
return (
|
||||
<div className='px-3 pb-3'>
|
||||
{
|
||||
models.map((model: Model) => (
|
||||
models.map(model => (
|
||||
<div key={`${model.model_name}-${model.model_type}`} className='flex mb-1 px-3 py-2 bg-white rounded-lg shadow-xs last:mb-0'>
|
||||
<div className='grow'>
|
||||
<div className='flex items-center mb-0.5 h-[18px] text-[13px] font-medium text-gray-700'>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { formatNumber } from '@/utils/format'
|
||||
|
||||
type QuotaCardProps = {
|
||||
remainTokens: number
|
||||
@@ -17,7 +18,7 @@ const QuotaCard: FC<QuotaCardProps> = ({
|
||||
{t('common.modelProvider.item.freeQuota')}
|
||||
</div>
|
||||
<div className='flex items-center h-5 text-sm font-medium text-gray-700'>
|
||||
{remainTokens}
|
||||
{formatNumber(remainTokens)}
|
||||
<div className='ml-1 font-normal'>Tokens</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -15,6 +15,7 @@ import ModelIcon from '@/app/components/app/configuration/config-model/model-ico
|
||||
import ModelName, { supportI18nModelName } from '@/app/components/app/configuration/config-model/model-name'
|
||||
import ProviderName from '@/app/components/app/configuration/config-model/provider-name'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
|
||||
type Props = {
|
||||
value: {
|
||||
providerName: ProviderEnum
|
||||
@@ -28,6 +29,16 @@ type Props = {
|
||||
triggerIconSmall?: boolean
|
||||
}
|
||||
|
||||
type ModelOption = {
|
||||
type: 'model'
|
||||
value: string
|
||||
providerName: ProviderEnum
|
||||
modelDisplayName: string
|
||||
} | {
|
||||
type: 'provider'
|
||||
value: ProviderEnum
|
||||
}
|
||||
|
||||
const ModelSelector: FC<Props> = ({
|
||||
value,
|
||||
modelType,
|
||||
@@ -69,9 +80,9 @@ const ModelSelector: FC<Props> = ({
|
||||
|
||||
const hasRemoved = value && !modelList.find(({ model_name }) => model_name === value.modelName)
|
||||
|
||||
const modelOptions: any[] = (() => {
|
||||
const modelOptions: ModelOption[] = (() => {
|
||||
const providers = _.uniq(filteredModelList.map(item => item.model_provider.provider_name))
|
||||
const res: any[] = []
|
||||
const res: ModelOption[] = []
|
||||
providers.forEach((providerName) => {
|
||||
res.push({
|
||||
type: 'provider',
|
||||
@@ -162,7 +173,7 @@ const ModelSelector: FC<Props> = ({
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
modelOptions.map((model: any) => {
|
||||
modelOptions.map((model) => {
|
||||
if (model.type === 'provider') {
|
||||
return (
|
||||
<div
|
||||
|
||||
@@ -10,7 +10,7 @@ const PluginPage = () => {
|
||||
const { t } = useTranslation()
|
||||
const { data: plugins, mutate } = useSWR('/workspaces/current/tool-providers', fetchPluginProviders)
|
||||
|
||||
const Plugin_MAP: Record<string, any> = {
|
||||
const Plugin_MAP: Record<string, (plugin: PluginProvider) => JSX.Element> = {
|
||||
serpapi: (plugin: PluginProvider) => <SerpapiPlugin key='serpapi' plugin={plugin} onUpdate={() => mutate()} />,
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ import {
|
||||
} from '@/service/share'
|
||||
import type { ConversationItem, SiteInfo } from '@/models/share'
|
||||
import type { PromptConfig, SuggestedQuestionsAfterAnswerConfig } from '@/models/debug'
|
||||
import type { Feedbacktype, IChatItem } from '@/app/components/app/chat'
|
||||
import type { Feedbacktype, IChatItem } from '@/app/components/app/chat/type'
|
||||
import Chat from '@/app/components/app/chat'
|
||||
import { changeLanguage } from '@/i18n/i18next-config'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import Header from './header'
|
||||
import type { Feedbacktype } from '@/app/components/app/chat'
|
||||
import type { Feedbacktype } from '@/app/components/app/chat/type'
|
||||
import { format } from '@/service/base'
|
||||
|
||||
export type IResultProps = {
|
||||
|
||||
@@ -8,7 +8,7 @@ import TextGenerationRes from '@/app/components/app/text-generate/item'
|
||||
import NoData from '@/app/components/share/text-generation/no-data'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { sendCompletionMessage, updateFeedback } from '@/service/share'
|
||||
import type { Feedbacktype } from '@/app/components/app/chat'
|
||||
import type { Feedbacktype } from '@/app/components/app/chat/type'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import type { PromptConfig } from '@/models/debug'
|
||||
import type { InstalledApp } from '@/models/explore'
|
||||
|
||||
@@ -22,7 +22,11 @@ type IState = {
|
||||
google: boolean
|
||||
}
|
||||
|
||||
function reducer(state: IState, action: { type: string; payload: any }) {
|
||||
type IAction = {
|
||||
type: 'login' | 'login_failed' | 'github_login' | 'github_login_failed' | 'google_login' | 'google_login_failed'
|
||||
}
|
||||
|
||||
function reducer(state: IState, action: IAction) {
|
||||
switch (action.type) {
|
||||
case 'login':
|
||||
return {
|
||||
@@ -120,14 +124,14 @@ const NormalForm = () => {
|
||||
|
||||
useEffect(() => {
|
||||
if (github_error !== undefined)
|
||||
dispatch({ type: 'github_login_failed', payload: null })
|
||||
dispatch({ type: 'github_login_failed' })
|
||||
if (github)
|
||||
window.location.href = github.redirect_url
|
||||
}, [github, github_error])
|
||||
|
||||
useEffect(() => {
|
||||
if (google_error !== undefined)
|
||||
dispatch({ type: 'google_login_failed', payload: null })
|
||||
dispatch({ type: 'google_login_failed' })
|
||||
if (google)
|
||||
window.location.href = google.redirect_url
|
||||
}, [google, google])
|
||||
@@ -237,6 +241,10 @@ const NormalForm = () => {
|
||||
id="password"
|
||||
value={password}
|
||||
onChange={e => setPassword(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter')
|
||||
handleEmailPasswordLogin()
|
||||
}}
|
||||
type={showPassword ? 'text' : 'password'}
|
||||
autoComplete="current-password"
|
||||
placeholder={t('login.passwordPlaceholder') || ''}
|
||||
@@ -256,6 +264,7 @@ const NormalForm = () => {
|
||||
|
||||
<div className='mb-2'>
|
||||
<Button
|
||||
tabIndex={0}
|
||||
type='primary'
|
||||
onClick={handleEmailPasswordLogin}
|
||||
disabled={isLoading}
|
||||
|
||||
@@ -2,29 +2,29 @@
|
||||
|
||||
import { createContext, useContext } from 'use-context-selector'
|
||||
import useSWR from 'swr'
|
||||
import { fetchModelList, fetchTenantInfo } from '@/service/common'
|
||||
import { fetchDefaultModal, fetchModelList } from '@/service/common'
|
||||
import { ModelFeature, ModelType } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import type { BackendModel } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
const ProviderContext = createContext<{
|
||||
currentProvider: {
|
||||
provider: string
|
||||
provider_name: string
|
||||
token_is_set: boolean
|
||||
is_valid: boolean
|
||||
token_is_valid: boolean
|
||||
} | null | undefined
|
||||
textGenerationModelList: BackendModel[]
|
||||
embeddingsModelList: BackendModel[]
|
||||
speech2textModelList: BackendModel[]
|
||||
agentThoughtModelList: BackendModel[]
|
||||
updateModelList: (type: ModelType) => void
|
||||
embeddingsDefaultModel?: BackendModel
|
||||
mutateEmbeddingsDefaultModel: () => void
|
||||
speech2textDefaultModel?: BackendModel
|
||||
mutateSpeech2textDefaultModel: () => void
|
||||
}>({
|
||||
currentProvider: null,
|
||||
textGenerationModelList: [],
|
||||
embeddingsModelList: [],
|
||||
speech2textModelList: [],
|
||||
agentThoughtModelList: [],
|
||||
updateModelList: () => {},
|
||||
speech2textDefaultModel: undefined,
|
||||
mutateSpeech2textDefaultModel: () => {},
|
||||
embeddingsDefaultModel: undefined,
|
||||
mutateEmbeddingsDefaultModel: () => {},
|
||||
})
|
||||
|
||||
export const useProviderContext = () => useContext(ProviderContext)
|
||||
@@ -35,8 +35,8 @@ type ProviderContextProviderProps = {
|
||||
export const ProviderContextProvider = ({
|
||||
children,
|
||||
}: ProviderContextProviderProps) => {
|
||||
const { data: userInfo } = useSWR({ url: '/info' }, fetchTenantInfo)
|
||||
const currentProvider = userInfo?.providers?.find(({ token_is_set, is_valid, provider_name }) => token_is_set && is_valid && (provider_name === 'openai' || provider_name === 'azure_openai'))
|
||||
const { data: embeddingsDefaultModel, mutate: mutateEmbeddingsDefaultModel } = useSWR('/workspaces/current/default-model?model_type=embeddings', fetchDefaultModal)
|
||||
const { data: speech2textDefaultModel, mutate: mutateSpeech2textDefaultModel } = useSWR('/workspaces/current/default-model?model_type=speech2text', fetchDefaultModal)
|
||||
const fetchModelListUrlPrefix = '/workspaces/current/models/model-type/'
|
||||
const { data: textGenerationModelList, mutate: mutateTextGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.textGeneration}`, fetchModelList)
|
||||
const { data: embeddingsModelList, mutate: mutateEmbeddingsModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.embeddings}`, fetchModelList)
|
||||
@@ -54,12 +54,15 @@ export const ProviderContextProvider = ({
|
||||
|
||||
return (
|
||||
<ProviderContext.Provider value={{
|
||||
currentProvider,
|
||||
textGenerationModelList: textGenerationModelList || [],
|
||||
embeddingsModelList: embeddingsModelList || [],
|
||||
speech2textModelList: speech2textModelList || [],
|
||||
agentThoughtModelList: agentThoughtModelList || [],
|
||||
updateModelList,
|
||||
embeddingsDefaultModel,
|
||||
mutateEmbeddingsDefaultModel,
|
||||
speech2textDefaultModel,
|
||||
mutateSpeech2textDefaultModel,
|
||||
}}>
|
||||
{children}
|
||||
</ProviderContext.Provider>
|
||||
|
||||
@@ -242,7 +242,7 @@ export type FullDocumentDetail = SimpleDocumentDetail & {
|
||||
archived_reason: 'rule_modified' | 're_upload'
|
||||
archived_by: string
|
||||
archived_at: number
|
||||
doc_type?: DocType | null
|
||||
doc_type?: DocType | null | 'others'
|
||||
doc_metadata?: DocMetadata | null
|
||||
segment_count: number
|
||||
[key: string]: any
|
||||
|
||||
@@ -6,7 +6,7 @@ import type {
|
||||
ICurrentWorkspace,
|
||||
IWorkspace, LangGeniusVersionResponse, Member,
|
||||
OauthResponse, PluginProvider, Provider, ProviderAnthropicToken, ProviderAzureToken,
|
||||
SetupStatusResponse, TenantInfoResponse, UserProfileOriginResponse,
|
||||
SetupStatusResponse, UserProfileOriginResponse,
|
||||
} from '@/models/common'
|
||||
import type {
|
||||
UpdateOpenAIKeyResponse,
|
||||
@@ -34,10 +34,6 @@ export const updateUserProfile: Fetcher<CommonResponse, { url: string; body: Rec
|
||||
return post(url, { body }) as Promise<CommonResponse>
|
||||
}
|
||||
|
||||
export const fetchTenantInfo: Fetcher<TenantInfoResponse, { url: string }> = ({ url }) => {
|
||||
return get(url) as Promise<TenantInfoResponse>
|
||||
}
|
||||
|
||||
export const logout: Fetcher<CommonResponse, { url: string; params: Record<string, any> }> = ({ url, params }) => {
|
||||
return get(url, params) as Promise<CommonResponse>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user