Compare commits

...

23 Commits

Author SHA1 Message Date
takatost b24d1216e7 fix: memory with files 2023-11-11 15:41:28 +08:00
takatost 1be8cbb1bc fix: bugs 2023-11-11 15:24:00 +08:00
takatost be94656903 fix: created_by_role not fill in 2023-11-11 02:37:39 +08:00
takatost e43f53c875 fix: bug 2023-11-11 02:31:29 +08:00
takatost 6b97396422 fix: message files 2023-11-10 22:13:18 +08:00
Garfield Dai 0140517aae update. 2023-11-10 22:12:37 +08:00
Garfield Dai f525159330 update. 2023-11-10 21:58:31 +08:00
takatost 136c473526 fix: token calc bug with images 2023-11-10 21:47:57 +08:00
Garfield Dai ad81e5e877 update. 2023-11-10 21:30:34 +08:00
takatost b33561b444 feat: add file arg validate 2023-11-10 21:14:38 +08:00
takatost b06319c5e3 feat: optimize file route uri 2023-11-10 21:09:14 +08:00
takatost 45b8bd347f feat: add image preview route 2023-11-10 20:46:59 +08:00
takatost 8778a82862 feat: add preview_url 2023-11-10 18:34:05 +08:00
takatost c0115490c6 feat: optimize image parse 2023-11-10 18:25:27 +08:00
Garfield Dai 39d1de1227 update. 2023-11-10 18:09:48 +08:00
Garfield Dai b46070088c update. 2023-11-10 17:35:26 +08:00
Garfield Dai f3a6856f9c feat: update model. 2023-11-09 16:56:40 +08:00
Garfield Dai d4948179db feat: add parameters. 2023-11-09 15:57:58 +08:00
Garfield Dai 5290b4d930 feat: update model config. 2023-11-09 15:37:48 +08:00
takatost e1a3c1ed95 feat: add image parse 2023-11-09 15:31:40 +08:00
Garfield Dai 1575e15222 feat: update entities 2023-11-09 13:56:45 +08:00
Garfield Dai d759fa952f feat: add entities. 2023-11-09 11:49:58 +08:00
takatost 28778b5aa7 [WIP] Vision support in OpenAI Chat Model 2023-11-08 21:25:45 +08:00
42 changed files with 1209 additions and 240 deletions

View File

@@ -18,6 +18,9 @@ SERVICE_API_URL=http://127.0.0.1:5001
APP_API_URL=http://127.0.0.1:5001
APP_WEB_URL=http://127.0.0.1:3000
# Files URL
FILES_URL=http://127.0.0.1:5001
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
@@ -70,6 +73,14 @@ MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
# Model Configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
# Mail configuration, support: resend
MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>

View File

@@ -126,6 +126,7 @@ def register_blueprints(app):
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp
CORS(service_api_bp,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
@@ -155,6 +156,12 @@ def register_blueprints(app):
app.register_blueprint(console_app_bp)
CORS(files_bp,
allow_headers=['Content-Type'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
)
app.register_blueprint(files_bp)
# create app
app = create_app()

View File

@@ -26,6 +26,7 @@ DEFAULTS = {
'SERVICE_API_URL': 'https://api.dify.ai',
'APP_WEB_URL': 'https://udify.app',
'APP_API_URL': 'https://udify.app',
'FILES_URL': 'https://files.dify.ai',
'STORAGE_TYPE': 'local',
'STORAGE_LOCAL_PATH': 'storage',
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
@@ -57,7 +58,9 @@ DEFAULTS = {
'CLEAN_DAY_SETTING': 30,
'UPLOAD_FILE_SIZE_LIMIT': 15,
'UPLOAD_FILE_BATCH_LIMIT': 5,
'OUTPUT_MODERATION_BUFFER_SIZE': 300
'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
'OUTPUT_MODERATION_BUFFER_SIZE': 300,
'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64'
}
@@ -84,15 +87,9 @@ class Config:
"""Application configuration class."""
def __init__(self):
# app settings
self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.3.29"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
@@ -100,70 +97,55 @@ class Config:
self.TESTING = False
self.LOG_LEVEL = get_env('LOG_LEVEL')
# The backend URL prefix of the console API.
# used to concatenate the login authorization callback or notion integration callback.
self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
# The front-end URL prefix of the console web.
# used to concatenate some front-end addresses and for CORS configuration use.
self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
# WebApp API backend Url prefix.
# used to declare the back-end URL for the front-end API.
self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
# WebApp Url prefix.
# used to display WebAPP API Base Url to the front-end.
self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
# Service API Url prefix.
# used to display Service API Base Url to the front-end.
self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
# File preview or download Url prefix.
# used to display file preview or download Urls to the front-end or as multimodal inputs;
# the Url is signed and has an expiration time.
self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL
# Fallback Url prefix.
# Will be deprecated in the future.
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
# Your App secret key will be used for securely signing the session cookie
# Make sure you are changing this key for your deployment with a strong key.
# You can generate a strong key using `openssl rand -base64 42`.
# Alternatively you can set it with `SECRET_KEY` environment variable.
self.SECRET_KEY = get_env('SECRET_KEY')
# redis settings
self.REDIS_HOST = get_env('REDIS_HOST')
self.REDIS_PORT = get_env('REDIS_PORT')
self.REDIS_USERNAME = get_env('REDIS_USERNAME')
self.REDIS_PASSWORD = get_env('REDIS_PASSWORD')
self.REDIS_DB = get_env('REDIS_DB')
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
# storage settings
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
self.S3_ENDPOINT = get_env('S3_ENDPOINT')
self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
self.S3_REGION = get_env('S3_REGION')
# vector store settings, only support weaviate, qdrant
self.VECTOR_STORE = get_env('VECTOR_STORE')
# weaviate settings
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
# qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
# milvus setting
self.MILVUS_HOST = get_env('MILVUS_HOST')
self.MILVUS_PORT = get_env('MILVUS_PORT')
self.MILVUS_USER = get_env('MILVUS_USER')
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
# cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'WEB_API_CORS_ALLOW_ORIGINS', '*')
# mail settings
self.MAIL_TYPE = get_env('MAIL_TYPE')
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
# sentry settings
self.SENTRY_DSN = get_env('SENTRY_DSN')
self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))
# check update url
self.CHECK_UPDATE_URL = get_env('CHECK_UPDATE_URL')
# database settings
# ------------------------
# Database Configurations.
# ------------------------
db_credentials = {
key: get_env(key) for key in
['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE']
@@ -177,14 +159,102 @@ class Config:
self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
# celery settings
# ------------------------
# Redis Configurations.
# ------------------------
self.REDIS_HOST = get_env('REDIS_HOST')
self.REDIS_PORT = get_env('REDIS_PORT')
self.REDIS_USERNAME = get_env('REDIS_USERNAME')
self.REDIS_PASSWORD = get_env('REDIS_PASSWORD')
self.REDIS_DB = get_env('REDIS_DB')
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
# ------------------------
# Celery worker Configurations.
# ------------------------
self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL')
self.CELERY_BACKEND = get_env('CELERY_BACKEND')
self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
# hosted provider credentials
# ------------------------
# File Storage Configurations.
# ------------------------
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
self.S3_ENDPOINT = get_env('S3_ENDPOINT')
self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
self.S3_REGION = get_env('S3_REGION')
# ------------------------
# Vector Store Configurations.
# Currently supported: qdrant, milvus, zilliz, weaviate
# ------------------------
self.VECTOR_STORE = get_env('VECTOR_STORE')
# qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
# milvus / zilliz setting
self.MILVUS_HOST = get_env('MILVUS_HOST')
self.MILVUS_PORT = get_env('MILVUS_PORT')
self.MILVUS_USER = get_env('MILVUS_USER')
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
# weaviate settings
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
# ------------------------
# Mail Configurations.
# ------------------------
self.MAIL_TYPE = get_env('MAIL_TYPE')
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
# ------------------------
# Sentry Configurations.
# ------------------------
self.SENTRY_DSN = get_env('SENTRY_DSN')
self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))
# ------------------------
# Business Configurations.
# ------------------------
# multimodal send image format; supports base64 and url, default is base64
self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT')
# Dataset Configurations.
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
# File upload Configurations.
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))
# Moderation in app Configurations.
self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
# Notion integration setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
# ------------------------
# Platform Configurations.
# ------------------------
self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
@@ -212,26 +282,6 @@ class Config:
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
# uploading settings
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
# moderation settings
self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
class CloudEditionConfig(Config):
@@ -246,18 +296,5 @@ class CloudEditionConfig(Config):
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
class TestConfig(Config):
def __init__(self):
super().__init__()
self.EDITION = "SELF_HOSTED"
self.TESTING = True
db_credentials = {
key: get_env(key) for key in ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT']
}
# use a different database for testing: dify_test
self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/dify_test"
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
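
All of the config values above flow through get_env / get_bool_env, which this diff does not show. A minimal sketch of how such helpers plausibly resolve against the DEFAULTS table earlier in this file (an assumption, not the actual Dify implementation):

import os

# Hypothetical sketch only: resolve an environment variable, falling back to
# the DEFAULTS dict shown earlier in this file. The real helpers may differ.
def get_env(key: str):
    return os.environ.get(key, DEFAULTS.get(key, ''))

def get_bool_env(key: str) -> bool:
    # Interpret the conventional 'true' string; anything else is False.
    return str(get_env(key)).lower() == 'true'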

View File

@@ -40,6 +40,7 @@ class CompletionMessageApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
@@ -113,6 +114,7 @@ class ChatMessageApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')

View File

@@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
conversation_id = str(conversation_id)
return _get_conversation(app_id, conversation_id, 'completion')
@setup_required
@login_required
@account_initialization_required
@@ -230,7 +230,7 @@ class ChatConversationDetailApi(Resource):
conversation_id = str(conversation_id)
return _get_conversation(app_id, conversation_id, 'chat')
@setup_required
@login_required
@account_initialization_required
@@ -253,8 +253,6 @@ class ChatConversationDetailApi(Resource):
return {'result': 'success'}, 204
api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')

View File

@@ -1,7 +1,6 @@
import datetime
import json
from cachetools import TTLCache
from flask import request
from flask_login import current_user
from libs.login import login_required
@@ -20,8 +19,6 @@ from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService
from tasks.document_indexing_sync_task import document_indexing_sync_task
cache = TTLCache(maxsize=None, ttl=30)
class DataSourceApi(Resource):

View File

@@ -1,4 +1,3 @@
from cachetools import TTLCache
from flask import request, current_app
import services
@@ -15,9 +14,6 @@ from fields.file_fields import upload_config_fields, file_fields
from services.file_service import FileService
cache = TTLCache(maxsize=None, ttl=30)
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
PREVIEW_WORDS_LIMIT = 3000
@@ -30,9 +26,11 @@ class FileApi(Resource):
def get(self):
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT")
return {
'file_size_limit': file_size_limit,
'batch_count_limit': batch_count_limit
'batch_count_limit': batch_count_limit,
'image_file_size_limit': image_file_size_limit
}, 200
@setup_required

View File

@@ -32,6 +32,7 @@ class CompletionApi(InstalledAppResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
args = parser.parse_args()
@@ -91,6 +92,7 @@ class ChatApi(InstalledAppResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, fields
from flask import current_app
from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
@@ -19,6 +20,10 @@ class AppParameterApi(InstalledAppResource):
'options': fields.List(fields.String)
}
system_parameters_fields = {
'image_file_size_limit': fields.String
}
parameters_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(InstalledAppResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
'sensitive_word_avoidance': fields.Raw,
'file_upload': fields.Raw,
'system_parameters': fields.Nested(system_parameters_fields)
}
@marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(InstalledAppResource):
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
'file_upload': app_model_config.file_upload_dict,
'system_parameters': {
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
}
}

View File

@@ -9,6 +9,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from libs.helper import uuid_value, TimestampField
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
from fields.conversation_fields import message_file_fields
feedback_fields = {
'rating': fields.String
@@ -19,6 +20,7 @@ message_fields = {
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'created_at': TimestampField
}

View File

@@ -0,0 +1,10 @@
# -*- coding:utf-8 -*-
from flask import Blueprint
from libs.external_api import ExternalApi
bp = Blueprint('files', __name__)
api = ExternalApi(bp)
from . import image_preview

View File

@@ -0,0 +1,40 @@
from flask import request, Response
from flask_restful import Resource
import services
from controllers.files import api
from libs.exception import BaseHTTPException
from services.file_service import FileService
class ImagePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
sign = request.args.get('sign')
if not timestamp or not nonce or not sign:
return {'content': 'Invalid request.'}, 400
try:
generator, mimetype = FileService.get_image_preview(
file_id,
timestamp,
nonce,
sign
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
class UnsupportedFileTypeError(BaseHTTPException):
error_code = 'unsupported_file_type'
description = "File type not allowed."
code = 415
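
A hedged usage sketch for this endpoint: the timestamp, nonce and sign query parameters must come from a server-issued signed URL (see UploadFileParser.get_signed_temp_image_url later in this diff); the base URL, file id and parameter values below are placeholders.

import requests

# Fetch a signed image preview. Hand-built query params fail signature
# verification, and signed links expire after 5 minutes.
signed_url = ('http://127.0.0.1:5001/files/'
              '0b1f2a3c-0000-0000-0000-000000000000/image-preview'
              '?timestamp=1699600000&nonce=abcd1234&sign=<server-signature>')
resp = requests.get(signed_url)
if resp.status_code == 200:
    image_bytes = resp.content                # streamed image body
    print(resp.headers.get('Content-Type'))   # e.g. image/png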

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_restful import fields, marshal_with
from flask import current_app
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
@@ -20,6 +21,10 @@ class AppParameterApi(AppApiResource):
'options': fields.List(fields.String)
}
system_parameters_fields = {
'image_file_size_limit': fields.String
}
parameters_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
@@ -28,7 +33,9 @@ class AppParameterApi(AppApiResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
'sensitive_word_avoidance': fields.Raw,
'file_upload': fields.Raw,
'system_parameters': fields.Nested(system_parameters_fields)
}
@marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(AppApiResource):
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
'file_upload': app_model_config.file_upload_dict,
'system_parameters': {
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
}
}

View File

@@ -28,6 +28,7 @@ class CompletionApi(AppApiResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
@@ -90,6 +91,7 @@ class ChatApi(AppApiResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, location='json')
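
Judging from the validation rules in MessageFileParser later in this diff, the new files argument carries dicts shaped like the ones below. A sketch of a Service API request; only the body keys are taken from this diff, while the endpoint path, host and API key are assumptions.

import requests

# Hedged example request body with image files attached. The item keys
# (type, transfer_method, url / upload_file_id) mirror MessageFileParser.
payload = {
    'inputs': {},
    'query': 'What is in this image?',
    'response_mode': 'blocking',
    'user': 'end-user-123',
    'files': [
        {'type': 'image', 'transfer_method': 'remote_url',
         'url': 'https://example.com/cat.png'},
        {'type': 'image', 'transfer_method': 'local_file',
         'upload_file_id': '<uuid-from-upload-api>'},
    ],
}
resp = requests.post('https://api.example.com/v1/chat-messages',  # assumed path
                     json=payload,
                     headers={'Authorization': 'Bearer <api-key>'})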

View File

@@ -12,7 +12,7 @@ from libs.helper import TimestampField, uuid_value
from services.message_service import MessageService
from extensions.ext_database import db
from models.model import Message, EndUser
from fields.conversation_fields import message_file_fields
class MessageListApi(AppApiResource):
feedback_fields = {
@@ -43,6 +43,7 @@ class MessageListApi(AppApiResource):
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, fields
from flask import current_app
from controllers.web import api
from controllers.web.wraps import WebApiResource
@@ -19,6 +20,10 @@ class AppParameterApi(WebApiResource):
'options': fields.List(fields.String)
}
system_parameters_fields = {
'image_file_size_limit': fields.String
}
parameters_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(WebApiResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
'sensitive_word_avoidance': fields.Raw,
'file_upload': fields.Raw,
'system_parameters': fields.Nested(system_parameters_fields)
}
@marshal_with(parameters_fields)
@@ -43,7 +50,11 @@ class AppParameterApi(WebApiResource):
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
'file_upload': app_model_config.file_upload_dict,
'system_parameters': {
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
}
}

View File

@@ -30,6 +30,7 @@ class CompletionApi(WebApiResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
@@ -88,6 +89,7 @@ class ChatApi(WebApiResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')

View File

@@ -22,6 +22,7 @@ from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
from fields.conversation_fields import message_file_fields
class MessageListApi(WebApiResource):
@@ -54,6 +55,7 @@ class MessageListApi(WebApiResource):
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField

View File

@@ -8,6 +8,8 @@ from controllers.web.wraps import WebApiResource
from libs.helper import uuid_value, TimestampField
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
from fields.conversation_fields import message_file_fields
feedback_fields = {
'rating': fields.String
@@ -18,6 +20,7 @@ message_fields = {
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'created_at': TimestampField
}

View File

@@ -13,11 +13,12 @@ from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
@@ -30,8 +31,8 @@ from core.moderation.factory import ModerationFactory
class Completion:
@classmethod
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
is_override: bool = False, retriever_from: str = 'dev'):
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
streaming: bool, is_override: bool = False, retriever_from: str = 'dev'):
"""
errors: ProviderTokenNotInitError
"""
@@ -64,16 +65,20 @@ class Completion:
is_override=is_override,
inputs=inputs,
query=query,
files=files,
streaming=streaming,
model_instance=final_model_instance
)
prompt_message_files = [file.prompt_message_file for file in files]
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
model_instance=final_model_instance,
app_model_config=app_model_config,
query=query,
inputs=inputs
inputs=inputs,
files=prompt_message_files
)
# init orchestrator rule parser
@@ -95,6 +100,7 @@ class Completion:
app_model_config=app_model_config,
query=query,
inputs=inputs,
files=prompt_message_files,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
@@ -146,6 +152,7 @@ class Completion:
app_model_config=app_model_config,
query=query,
inputs=inputs,
files=prompt_message_files,
agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task,
memory=memory,
@@ -257,6 +264,7 @@ class Completion:
@classmethod
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
inputs: dict,
files: List[PromptMessageFile],
agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
@@ -270,6 +278,7 @@ class Completion:
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
files=files,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
@@ -280,6 +289,7 @@ class Completion:
app_model_config=app_model_config,
inputs=inputs,
query=query,
files=files,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
@@ -337,7 +347,7 @@ class Completion:
@classmethod
def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
query: str, inputs: dict) -> int:
query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
model_limited_tokens = model_instance.model_rules.max_tokens.max
max_tokens = model_instance.get_model_kwargs().max_tokens
@@ -348,7 +358,6 @@ class Completion:
max_tokens = 0
prompt_transform = PromptTransform()
prompt_messages = []
# get prompt without memory and context
if app_model_config.prompt_type == 'simple':
@@ -357,6 +366,7 @@ class Completion:
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
files=files,
context=None,
memory=None,
model_instance=model_instance
@@ -367,6 +377,7 @@ class Completion:
app_model_config=app_model_config,
inputs=inputs,
query=query,
files=files,
context=None,
memory=None,
model_instance=model_instance

View File

@@ -6,8 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.file.file_obj import FileObj
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
@@ -16,13 +17,13 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
MessageChain, DatasetRetrieverResource
MessageChain, DatasetRetrieverResource, MessageFile
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
inputs: dict, query: str, files: List[FileObj], streaming: bool,
model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False):
self.start_at = time.perf_counter()
self.task_id = task_id
@@ -35,6 +36,7 @@ class ConversationMessageTask:
self.user = user
self.inputs = inputs
self.query = query
self.files = files
self.streaming = streaming
self.conversation = conversation
@@ -142,6 +144,19 @@ class ConversationMessageTask:
db.session.add(self.message)
db.session.commit()
for file in self.files:
message_file = MessageFile(
message_id=self.message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
url=file.url,
upload_file_id=file.upload_file_id,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_file)
db.session.commit()
def append_message_text(self, text: str):
if text is not None:
self._pub_handler.pub_text(text)

View File

api/core/file/file_obj.py (new file, 79 lines)

@@ -0,0 +1,79 @@
import enum
from typing import Optional
from pydantic import BaseModel
from core.file.upload_file_parser import UploadFileParser
from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
from extensions.ext_database import db
from models.model import UploadFile
class FileType(enum.Enum):
IMAGE = 'image'
@staticmethod
def value_of(value):
for member in FileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(enum.Enum):
REMOTE_URL = 'remote_url'
LOCAL_FILE = 'local_file'
@staticmethod
def value_of(value):
for member in FileTransferMethod:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileObj(BaseModel):
id: Optional[str]
tenant_id: str
type: FileType
transfer_method: FileTransferMethod
url: Optional[str]
upload_file_id: Optional[str]
file_config: dict
@property
def data(self) -> Optional[str]:
return self._get_data()
@property
def preview_url(self) -> Optional[str]:
return self._get_data(force_url=True)
@property
def prompt_message_file(self) -> PromptMessageFile:
if self.type == FileType.IMAGE:
image_config = self.file_config.get('image')
return ImagePromptMessageFile(
data=self.data,
detail=ImagePromptMessageFile.DETAIL.HIGH
if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == self.upload_file_id,
UploadFile.tenant_id == self.tenant_id
).first())
return UploadFileParser.get_image_data(
upload_file=upload_file,
force_url=force_url
)
return None
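
A small sketch constructing a FileObj directly, to show how prompt_message_file picks the image detail out of file_config. The file_config shape is inferred from the image_config lookups in this class and in MessageFileParser; all values are illustrative.

# Illustrative only; not part of the diff.
file_obj = FileObj(
    tenant_id='tenant-1',
    type=FileType.IMAGE,
    transfer_method=FileTransferMethod.REMOTE_URL,
    url='https://example.com/cat.png',
    upload_file_id=None,
    file_config={'image': {'enabled': True, 'number_limits': 3,
                           'detail': 'high',
                           'transfer_methods': ['remote_url', 'local_file']}},
)
pmf = file_obj.prompt_message_file
assert pmf.detail == ImagePromptMessageFile.DETAIL.HIGH
assert file_obj.data == file_obj.url   # remote_url images pass through as-is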

View File

@@ -0,0 +1,179 @@
from typing import List, Union, Optional, Dict
import requests
from core.file.file_obj import FileObj, FileType, FileTransferMethod
from core.file.upload_file_parser import SUPPORT_EXTENSIONS
from extensions.ext_database import db
from models.account import Account
from models.model import MessageFile, EndUser, AppModelConfig, UploadFile
class MessageFileParser:
def __init__(self, tenant_id: str, app_id: str) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
user: Union[Account, EndUser]) -> List[FileObj]:
"""
validate and transform files arg
:param files:
:param app_model_config:
:param user:
:return:
"""
file_upload_config = app_model_config.file_upload_dict
for file in files:
if not isinstance(file, dict):
raise ValueError('Invalid file format')
if not file.get('type'):
raise ValueError('Missing file type')
FileType.value_of(file.get('type'))
if not file.get('transfer_method'):
raise ValueError('Missing file transfer method')
FileTransferMethod.value_of(file.get('transfer_method'))
if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
if not file.get('url'):
raise ValueError('Missing file url')
if not file.get('url').startswith('http'):
raise ValueError('Invalid file url')
if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
raise ValueError('Missing file upload_file_id')
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_upload_config)
# validate files
new_files = []
for file_type, file_objs in type_file_objs.items():
if file_type == FileType.IMAGE:
# parse and validate files
image_config = file_upload_config.get('image')
# check if image file feature is enabled
if not image_config['enabled']:
continue
# Validate number of files
if len(files) > image_config['number_limits']:
raise ValueError('Number of image files exceeds the maximum limit')
for file_obj in file_objs:
# Validate transfer method
if file_obj.transfer_method.value not in image_config['transfer_methods']:
raise ValueError('Invalid transfer method')
# Validate file type
if file_obj.type != FileType.IMAGE:
raise ValueError('Invalid file type')
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
# check that the remote url is valid and points to an image
result, error = self._check_image_remote_url(file_obj.url)
if result is False:
raise ValueError(error)
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
# get upload file from upload_file_id
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == file_obj.upload_file_id,
UploadFile.tenant_id == self.tenant_id,
UploadFile.created_by == user.id,
UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
UploadFile.extension.in_(SUPPORT_EXTENSIONS)
).first())
# check that the upload file belongs to the tenant and user
if not upload_file:
raise ValueError('Invalid upload file')
new_files.append(file_obj)
# return all file objs
return new_files
def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
"""
transform message files
:param files:
:param app_model_config:
:return:
"""
# transform files to file objs
type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict)
# return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
"""
transform files to file objs
:param files:
:param file_upload_config:
:return:
"""
type_file_objs: Dict[FileType, List[FileObj]] = {
# Currently only support image
FileType.IMAGE: []
}
if not files:
return type_file_objs
# group by file type and convert file args or message files to FileObj
for file in files:
file_obj = self._to_file_obj(file, file_upload_config)
if file_obj.type not in type_file_objs:
continue
type_file_objs[file_obj.type].append(file_obj)
return type_file_objs
def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj:
"""
transform file to file obj
:param file:
:return:
"""
if isinstance(file, dict):
return FileObj(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
transfer_method=FileTransferMethod.value_of(file.get('transfer_method')),
url=file.get('url'),
upload_file_id=file.get('upload_file_id') or None,
file_config=file_upload_config
)
else:
return FileObj(
id=file.id,
tenant_id=self.tenant_id,
type=FileType.value_of(file.type),
transfer_method=FileTransferMethod.value_of(file.transfer_method),
url=file.url,
upload_file_id=file.upload_file_id or None,
file_config=file_upload_config
)
def _check_image_remote_url(self, url):
try:
response = requests.head(url, allow_redirects=True)
if response.status_code == 200:
content_type = response.headers.get('Content-Type', '')
if content_type.startswith('image/'):
return True, "URL exists and is an image."
else:
return False, "URL exists but is not an image."
else:
return False, "URL does not exist."
except requests.RequestException as e:
return False, f"Error checking URL: {e}"

View File

@@ -0,0 +1,79 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from typing import Optional
from flask import current_app
from extensions.ext_storage import storage
SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
class UploadFileParser:
@classmethod
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
if not upload_file:
return None
if upload_file.extension not in SUPPORT_EXTENSIONS:
return None
if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:
return cls.get_signed_temp_image_url(upload_file)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.error(f'File not found: {upload_file.key}')
return None
encoded_string = base64.b64encode(data).decode('utf-8')
return f'data:{upload_file.mime_type};base64,{encoded_string}'
@classmethod
def get_signed_temp_image_url(cls, upload_file) -> str:
"""
get signed url from upload file
:param upload_file: UploadFile object
:return:
"""
base_url = current_app.config.get('FILES_URL')
image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview'
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@classmethod
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
:param upload_file_id: file id
:param timestamp: timestamp
:param nonce: nonce
:param sign: signature
:return:
"""
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= 300 # expired after 5 minutes
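
For clarity, a worked round-trip of this signing scheme using only the standard library. The secret and file id are placeholders; the payload format is exactly the data_to_sign string above.

import base64
import hashlib
import hmac
import os
import time

secret_key = b'example-secret-key'   # stands in for the app SECRET_KEY
upload_file_id = '0b1f2a3c-0000-0000-0000-000000000000'
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()

# Sign, exactly as get_signed_temp_image_url does.
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
sign = base64.urlsafe_b64encode(
    hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
).decode()

# Verify, exactly as verify_image_file_signature does: recompute, compare,
# then enforce the 5-minute window.
recalculated = base64.urlsafe_b64encode(
    hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
).decode()
assert sign == recalculated
assert int(time.time()) - int(timestamp) <= 300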

View File

@@ -3,6 +3,7 @@ from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage
from core.file.message_file_parser import MessageFileParser
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
from core.model_providers.models.llm.base import BaseLLM
from extensions.ext_database import db
@@ -21,6 +22,8 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
@property
def buffer(self) -> List[BaseMessage]:
"""String buffer of memory."""
app_model = self.conversation.app
# fetch limited messages desc, and return reversed
messages = db.session.query(Message).filter(
Message.conversation_id == self.conversation.id,
@@ -28,10 +31,25 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
).order_by(Message.created_at.desc()).limit(self.message_limit).all()
messages = list(reversed(messages))
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)
chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
files = message.message_files
if files:
file_objs = message_file_parser.transform_message_files(
files, message.app_model_config
)
prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
chat_messages.append(PromptMessage(
content=message.query,
type=MessageType.USER,
files=prompt_message_files
))
else:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:

View File

@@ -1,4 +1,5 @@
import enum
from typing import Any, cast, Union, List, Dict
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel
@@ -18,17 +19,53 @@ class MessageType(enum.Enum):
SYSTEM = 'system'
class PromptMessageFileType(enum.Enum):
IMAGE = 'image'
@staticmethod
def value_of(value):
for member in PromptMessageFileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class PromptMessageFile(BaseModel):
type: PromptMessageFileType
data: Any
class ImagePromptMessageFile(PromptMessageFile):
class DETAIL(enum.Enum):
LOW = 'low'
HIGH = 'high'
type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW
class PromptMessage(BaseModel):
type: MessageType = MessageType.USER
content: str = ''
files: list[PromptMessageFile] = []
function_call: dict = None
class LCHumanMessageWithFiles(HumanMessage):
# content: Union[str, List[Union[str, Dict]]]
content: str
files: list[PromptMessageFile]
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content))
if not message.files:
lc_messages.append(HumanMessage(content=message.content))
else:
lc_messages.append(LCHumanMessageWithFiles(content=message.content, files=message.files))
elif message.type == MessageType.ASSISTANT:
additional_kwargs = {}
if message.function_call:
@@ -44,7 +81,14 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
if isinstance(message, LCHumanMessageWithFiles):
prompt_messages.append(PromptMessage(
content=message.content,
type=MessageType.USER,
files=message.files
))
else:
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage):
message_kwargs = {
'content': message.content,
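
A short sketch of the round-trip these changes enable: files attached to a PromptMessage survive conversion to LangChain messages and back via LCHumanMessageWithFiles (illustrative values only).

img = ImagePromptMessageFile(data='data:image/png;base64,AAAA',
                             detail=ImagePromptMessageFile.DETAIL.LOW)
msg = PromptMessage(type=MessageType.USER,
                    content='Describe this image.',
                    files=[img])

lc = to_lc_messages([msg])
assert isinstance(lc[0], LCHumanMessageWithFiles)

back = to_prompt_messages(lc)
assert back[0].files == [img]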

View File

@@ -1,11 +1,9 @@
import decimal
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from openai import api_requestor
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI

View File

@@ -8,7 +8,7 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMessage
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
@@ -17,31 +17,36 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
class AppMode(enum.Enum):
COMPLETION = 'completion'
CHAT = 'chat'
class PromptTransform:
def get_prompt(self, mode: str,
pre_prompt: str, inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(mode, model_instance))
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory, model_instance)
return [PromptMessage(content=prompt)], stops
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory,
model_instance)
return [PromptMessage(content=prompt, files=files)], stops
def get_advanced_prompt(self,
app_mode: str,
app_model_config: str,
inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
def get_advanced_prompt(self,
app_mode: str,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
model_mode = app_model_config.model_dict['mode']
app_mode_enum = AppMode(app_mode)
@@ -51,15 +56,20 @@ class PromptTransform:
if app_mode_enum == AppMode.CHAT:
if model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query,
files, context, memory,
model_instance)
elif model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, files,
context, memory, model_instance)
elif app_mode_enum == AppMode.COMPLETION:
if model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs,
files, context)
elif model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context)
prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs,
files, context)
return prompt_messages
def _get_history_messages_from_memory(self, memory: BaseChatMemory,
@@ -71,7 +81,7 @@ class PromptTransform:
return external_context[memory_key]
def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> List[PromptMessage]:
max_token_limit: int) -> List[PromptMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory.return_messages = True
@@ -79,7 +89,7 @@ class PromptTransform:
external_context = memory.load_memory_variables({})
memory.return_messages = False
return to_prompt_messages(external_context[memory_key])
def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str:
# baichuan
if isinstance(model_instance, BaichuanModel):
@@ -94,13 +104,13 @@ class PromptTransform:
return 'common_completion'
else:
return 'common_chat'
def _prompt_file_name_for_baichuan(self, mode: str) -> str:
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
# Get the absolute path of the subdirectory
prompt_path = os.path.join(
@@ -111,7 +121,7 @@ class PromptTransform:
# Open the JSON file and read its content
with open(json_file_path, 'r') as json_file:
return json.load(json_file)
def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
@@ -180,11 +190,11 @@ class PromptTransform:
stops = None
return prompt, stops
def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
if '#context#' in prompt_template.variable_keys:
if context:
prompt_inputs['#context#'] = context
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''
@@ -195,17 +205,18 @@ class PromptTransform:
else:
prompt_inputs['#query#'] = ''
def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
prompt_template: PromptTemplateParser, prompt_inputs: dict, model_instance: BaseLLM) -> None:
def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
prompt_template: PromptTemplateParser, prompt_inputs: dict,
model_instance: BaseLLM) -> None:
if '#histories#' in prompt_template.variable_keys:
if memory:
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=raw_prompt,
inputs={ '#histories#': '', **prompt_inputs }
inputs={'#histories#': '', **prompt_inputs}
)
rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, rest_tokens)
@@ -213,7 +224,8 @@ class PromptTransform:
else:
prompt_inputs['#histories#'] = ''
def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage], model_instance: BaseLLM) -> None:
def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage],
model_instance: BaseLLM) -> None:
if memory:
rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)
@@ -242,19 +254,20 @@ class PromptTransform:
return prompt
def _get_chat_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
app_model_config: str,
inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
prompt_messages = []
prompt = ''
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -262,21 +275,23 @@ class PromptTransform:
self._set_query_variable(query, prompt_template, prompt_inputs)
self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance)
self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs,
model_instance)
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt, files=files))
return prompt_messages
def _get_chat_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
app_model_config: str,
inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
prompt_messages = []
@@ -292,23 +307,24 @@ class PromptTransform:
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))
self._append_chat_histories(memory, prompt_messages, model_instance)
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))
return prompt_messages
def _get_completion_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
app_model_config: str,
inputs: dict,
files: List[PromptMessageFile],
context: Optional[str]) -> List[PromptMessage]:
raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
prompt_messages = []
prompt = ''
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -316,14 +332,15 @@ class PromptTransform:
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
prompt_messages.append(PromptMessage(type=MessageType(MessageType.USER), content=prompt, files=files))
return prompt_messages
def _get_completion_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
app_model_config: str,
inputs: dict,
files: List[PromptMessageFile],
context: Optional[str]) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
prompt_messages = []
@@ -339,6 +356,11 @@ class PromptTransform:
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
return prompt_messages
prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))
for prompt_message in prompt_messages[::-1]:
if prompt_message.type == MessageType.USER:
prompt_message.files = files
break
return prompt_messages
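(For reference: the loop above walks the message list from the end and attaches the uploaded files to the most recent user message only. A minimal standalone sketch of the same idea, with `Msg` as a simplified stand-in for `PromptMessage`:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Msg:
    type: str                      # 'user' / 'assistant' / 'system'
    content: str
    files: List[str] = field(default_factory=list)

def attach_files_to_last_user_message(messages: List[Msg], files: List[str]) -> None:
    # Walk backwards so the files land on the most recent user turn only.
    for msg in reversed(messages):
        if msg.type == 'user':
            msg.files = files
            break

msgs = [Msg('system', 'You are helpful.'), Msg('user', 'What is in this image?')]
attach_files_to_last_user_message(msgs, ['https://example.com/cat.png'])
assert msgs[-1].files  # files attached to the trailing user message
)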

View File

@@ -1,10 +1,13 @@
import os
from typing import Dict, Any, Optional, Union, Tuple
from typing import Dict, Any, Optional, Union, Tuple, List, cast
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
from pydantic import root_validator
from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile
class EnhanceChatOpenAI(ChatOpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
@@ -48,3 +51,102 @@ class EnhanceChatOpenAI(ChatOpenAI):
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._client_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [self._convert_message_to_dict(m) for m in messages]
return message_dicts, params
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
model, encoding = self._get_encoding_model()
if model.startswith("gpt-3.5-turbo-0301"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
messages_dict = [self._convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
# TODO: Token calculation for the image type is not implemented yet;
# it would require downloading each image to read its resolution,
# which would add request latency.
if isinstance(value, list):
text = ''
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
value = text
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
return num_tokens
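(The arithmetic above follows OpenAI's published counting scheme: 3 tokens of per-message overhead for the gpt-3.5-turbo/gpt-4 family, +1 when a name field is present, and a final +3 because every reply is primed with the assistant role. A minimal sketch of the same count outside the class, assuming the tiktoken package is installed:

import tiktoken  # assumption: tiktoken is available

def count_chat_tokens(messages: list[dict], model: str = "gpt-3.5-turbo") -> int:
    enc = tiktoken.encoding_for_model(model)
    tokens_per_message, tokens_per_name = 3, 1  # gpt-3.5-turbo / gpt-4 family
    total = 0
    for message in messages:
        total += tokens_per_message
        for key, value in message.items():
            total += len(enc.encode(str(value)))
            if key == "name":
                total += tokens_per_name
    return total + 3  # every reply is primed with <im_start>assistant

print(count_chat_tokens([{"role": "user", "content": "Hello!"}]))
)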
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, LCHumanMessageWithFiles):
content = [
{
"type": "text",
"text": message.content
}
]
for file in message.files:
if file.type == PromptMessageFileType.IMAGE:
file = cast(ImagePromptMessageFile, file)
content.append({
"type": "image_url",
"image_url": {
"url": file.data,
"detail": file.detail.value
}
})
message_dict = {"role": "user", "content": content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
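(For a message carrying one image, the LCHumanMessageWithFiles branch above produces the multi-part content form expected by the OpenAI vision chat API; the values below are illustrative:

# Shape of the payload built for an LCHumanMessageWithFiles with one image:
{
    "role": "user",
    "content": [
        {"type": "text", "text": "Whats in first image?"},
        {
            "type": "image_url",
            "image_url": {
                "url": "https://example.com/photo.jpg",  # or a base64 data URI
                "detail": "high",                        # or "low"
            },
        },
    ],
}
)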

View File

@@ -1,6 +1,7 @@
import os
import shutil
from contextlib import closing
from typing import Union, Generator
import boto3
from botocore.exceptions import ClientError
@@ -45,7 +46,13 @@ class Storage:
with open(os.path.join(os.getcwd(), filename), "wb") as f:
f.write(data)
def load(self, filename):
def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]:
if stream:
return self.load_stream(filename)
else:
return self.load_once(filename)
def load_once(self, filename: str) -> bytes:
if self.storage_type == 's3':
try:
with closing(self.client) as client:
@@ -69,6 +76,34 @@ class Storage:
return data
def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator:
if self.storage_type == 's3':
try:
with closing(self.client) as client:
response = client.get_object(Bucket=self.bucket_name, Key=filename)
for chunk in response['Body'].iter_chunks():
yield chunk
except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey':
raise FileNotFoundError("File not found")
else:
raise
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
if not os.path.exists(filename):
raise FileNotFoundError("File not found")
with open(filename, "rb") as f:
while chunk := f.read(4096): # Read in chunks of 4KB
yield chunk
return generate()
def download(self, filename, target_filepath):
if self.storage_type == 's3':
with closing(self.client) as client:
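(Callers can consume the new streaming path without loading the whole object into memory. A hypothetical Flask usage built on it — function name and wiring are illustrative, not part of this change:

from flask import Response

# Hypothetical usage of Storage.load(..., stream=True); names are illustrative.
def serve_file(storage, key: str, mime_type: str) -> Response:
    generator = storage.load(key, stream=True)  # yields byte chunks
    return Response(generator, mimetype=mime_type)
)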

View File

@@ -32,7 +32,8 @@ model_config_fields = {
'prompt_type': fields.String,
'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
'dataset_configs': fields.Raw(attribute='dataset_configs_dict'),
'file_upload': fields.Raw(attribute='file_upload_dict'),
}
app_detail_fields = {
@@ -140,4 +141,4 @@ app_site_fields = {
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean
}
}

View File

@@ -28,6 +28,12 @@ annotation_fields = {
'created_at': TimestampField
}
message_file_fields = {
'id': fields.String,
'type': fields.String,
'url': fields.String,
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
@@ -43,7 +49,8 @@ message_detail_fields = {
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'created_at': TimestampField
'created_at': TimestampField,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
}
feedback_stat_fields = {
@@ -111,11 +118,6 @@ conversation_message_detail_fields = {
'message': fields.Nested(message_detail_fields, attribute='first_message'),
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
conversation_with_summary_fields = {
'id': fields.String,
'status': fields.String,
@@ -180,4 +182,4 @@ conversation_with_model_config_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_with_model_config_fields))
}
}

View File

@@ -4,7 +4,8 @@ from libs.helper import TimestampField
upload_config_fields = {
'file_size_limit': fields.Integer,
'batch_count_limit': fields.Integer
'batch_count_limit': fields.Integer,
'image_file_size_limit': fields.Integer,
}
file_fields = {

View File

@@ -1,6 +1,7 @@
from flask_restful import fields
from libs.helper import TimestampField
from fields.conversation_fields import message_file_fields
feedback_fields = {
'rating': fields.String
@@ -31,6 +32,7 @@ message_fields = {
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField

View File

@@ -0,0 +1,59 @@
"""add gpt4v supports
Revision ID: 8fe468ba0ca5
Revises: a9836e3baeee
Create Date: 2023-11-09 11:39:00.006432
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '8fe468ba0ca5'
down_revision = 'a9836e3baeee'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('message_files',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('message_id', postgresql.UUID(), nullable=False),
sa.Column('type', sa.String(length=255), nullable=False),
sa.Column('transfer_method', sa.String(length=255), nullable=False),
sa.Column('url', sa.String(length=255), nullable=True),
sa.Column('upload_file_id', postgresql.UUID(), nullable=True),
sa.Column('created_by_role', sa.String(length=255), nullable=False),
sa.Column('created_by', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='message_file_pkey')
)
with op.batch_alter_table('message_files', schema=None) as batch_op:
batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False)
batch_op.create_index('message_file_message_idx', ['message_id'], unique=False)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True))
with op.batch_alter_table('upload_files', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'account'::character varying"), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('upload_files', schema=None) as batch_op:
batch_op.drop_column('created_by_role')
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('file_upload')
with op.batch_alter_table('message_files', schema=None) as batch_op:
batch_op.drop_index('message_file_message_idx')
batch_op.drop_index('message_file_created_by_idx')
op.drop_table('message_files')
# ### end Alembic commands ###

View File

@@ -1,10 +1,10 @@
import json
from json import JSONDecodeError
from flask import current_app, request
from flask_login import UserMixin
from sqlalchemy.dialects.postgresql import UUID
from core.file.upload_file_parser import UploadFileParser
from libs.helper import generate_string
from extensions.ext_database import db
from .account import Account, Tenant
@@ -98,6 +98,7 @@ class AppModelConfig(db.Model):
completion_prompt_config = db.Column(db.Text)
dataset_configs = db.Column(db.Text)
external_data_tools = db.Column(db.Text)
file_upload = db.Column(db.Text)
@property
def app(self):
@@ -161,6 +162,10 @@ class AppModelConfig(db.Model):
def dataset_configs_dict(self) -> dict:
return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
@property
def file_upload_dict(self) -> dict:
return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}}
def to_dict(self) -> dict:
return {
"provider": "",
@@ -182,7 +187,8 @@ class AppModelConfig(db.Model):
"prompt_type": self.prompt_type,
"chat_prompt_config": self.chat_prompt_config_dict,
"completion_prompt_config": self.completion_prompt_config_dict,
"dataset_configs": self.dataset_configs_dict
"dataset_configs": self.dataset_configs_dict,
"file_upload": self.file_upload_dict
}
def from_model_config_dict(self, model_config: dict):
@@ -213,6 +219,8 @@ class AppModelConfig(db.Model):
if model_config.get('completion_prompt_config') else None
self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
if model_config.get('dataset_configs') else None
self.file_upload = json.dumps(model_config.get('file_upload')) \
if model_config.get('file_upload') else None
return self
def copy(self):
@@ -238,7 +246,8 @@ class AppModelConfig(db.Model):
prompt_type=self.prompt_type,
chat_prompt_config=self.chat_prompt_config,
completion_prompt_config=self.completion_prompt_config,
dataset_configs=self.dataset_configs
dataset_configs=self.dataset_configs,
file_upload=self.file_upload
)
return new_app_model_config
@@ -512,6 +521,37 @@ class Message(db.Model):
return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \
.order_by(DatasetRetrieverResource.position.asc()).all()
@property
def message_files(self):
return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
@property
def files(self):
message_files = self.message_files
files = []
for message_file in message_files:
url = message_file.url
if message_file.type == 'image':
if message_file.transfer_method == 'local_file':
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == message_file.upload_file_id
).first())
url = UploadFileParser.get_image_data(
upload_file=upload_file,
force_url=True
)
files.append({
'id': message_file.id,
'type': message_file.type,
'url': url
})
return files
class MessageFeedback(db.Model):
__tablename__ = 'message_feedbacks'
@@ -540,6 +580,25 @@ class MessageFeedback(db.Model):
return account
class MessageFile(db.Model):
__tablename__ = 'message_files'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='message_file_pkey'),
db.Index('message_file_message_idx', 'message_id'),
db.Index('message_file_created_by_idx', 'created_by')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(UUID, nullable=False)
type = db.Column(db.String(255), nullable=False)
transfer_method = db.Column(db.String(255), nullable=False)
url = db.Column(db.String(255), nullable=True)
upload_file_id = db.Column(UUID, nullable=True)
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
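(A hedged sketch of how a row in the new table might be created for a locally uploaded image; the helper name and session handling are illustrative, MessageFile is the model defined above:

# Illustrative only: persist an image attachment for a message.
def save_image_attachment(db_session, message_id: str, upload_file_id: str,
                          user_id: str, role: str = 'account') -> 'MessageFile':
    message_file = MessageFile(
        message_id=message_id,
        type='image',
        transfer_method='local_file',  # or 'remote_url' with url set instead
        upload_file_id=upload_file_id,
        created_by_role=role,          # 'end_user' for web/app visitors
        created_by=user_id,
    )
    db_session.add(message_file)
    db_session.commit()
    return message_file
)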
class MessageAnnotation(db.Model):
__tablename__ = 'message_annotations'
__table_args__ = (
@@ -683,6 +742,7 @@ class UploadFile(db.Model):
size = db.Column(db.Integer, nullable=False)
extension = db.Column(db.String(255), nullable=False)
mime_type = db.Column(db.String(255), nullable=True)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@@ -783,4 +843,3 @@ class DatasetRetrieverResource(db.Model):
retriever_from = db.Column(db.Text, nullable=False)
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

View File

@@ -315,6 +315,9 @@ class AppModelConfigService:
# moderation validation
cls.is_moderation_valid(tenant_id, config)
# file upload validation
cls.is_file_upload_valid(config)
# Filter out extra parameters
filtered_config = {
"opening_statement": config["opening_statement"],
@@ -338,7 +341,8 @@ class AppModelConfigService:
"prompt_type": config["prompt_type"],
"chat_prompt_config": config["chat_prompt_config"],
"completion_prompt_config": config["completion_prompt_config"],
"dataset_configs": config["dataset_configs"]
"dataset_configs": config["dataset_configs"],
"file_upload": config["file_upload"]
}
return filtered_config
@@ -371,6 +375,34 @@ class AppModelConfigService:
config=config
)
@classmethod
def is_file_upload_valid(cls, config: dict):
if 'file_upload' not in config or not config["file_upload"]:
config["file_upload"] = {}
if not isinstance(config["file_upload"], dict):
raise ValueError("file_upload must be of dict type")
# check image config
if 'image' not in config["file_upload"] or not config["file_upload"]["image"]:
config["file_upload"]["image"] = {"enabled": False}
if config['file_upload']['image']['enabled']:
number_limits = config['file_upload']['image']['number_limits']
if number_limits < 1 or number_limits > 6:
raise ValueError("number_limits must be in [1, 6]")
detail = config['file_upload']['image']['detail']
if detail not in ['high', 'low']:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config['file_upload']['image']['transfer_methods']
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in ['remote_url', 'local_file']:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
@classmethod
def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
if 'external_data_tools' not in config or not config["external_data_tools"]:

View File

@@ -3,7 +3,7 @@ import logging
import threading
import time
import uuid
from typing import Generator, Union, Any, Optional
from typing import Generator, Union, Any, Optional, List
from flask import current_app, Flask
from redis.client import PubSub
@@ -12,9 +12,11 @@ from sqlalchemy import and_
from core.completion import Completion
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.file.message_file_parser import MessageFileParser
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, \
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.models.entity.message import PromptMessageFile
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
@@ -35,6 +37,7 @@ class CompletionService:
# is streaming mode
inputs = args['inputs']
query = args['query']
files = args['files'] if 'files' in args and args['files'] else []
if app_model.mode != 'completion' and not query:
raise ValueError('query is required')
@@ -132,6 +135,14 @@ class CompletionService:
# clean input by app_model_config form rules
inputs = cls.get_cleaned_inputs(inputs, app_model_config)
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
app_model_config,
user
)
generate_task_id = str(uuid.uuid4())
pubsub = redis_client.pubsub()
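(The `files` argument accepted here is a list of plain dicts from the request body. Based on the parser call above and the MessageFile fields, each entry is expected to look roughly like this — the exact keys are an inference, not authoritative:

# Assumed shape of one entry in args['files'] (inferred from MessageFile):
files = [
    {
        "type": "image",
        "transfer_method": "remote_url",          # or 'local_file'
        "url": "https://example.com/photo.jpg",   # for remote_url
        # "upload_file_id": "<uuid>",             # for local_file
    }
]
)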
@@ -146,6 +157,7 @@ class CompletionService:
'app_model_config': app_model_config.copy(),
'query': query,
'inputs': inputs,
'files': file_objs,
'detached_user': user,
'detached_conversation': conversation,
'streaming': streaming,
@@ -156,7 +168,8 @@ class CompletionService:
generate_worker_thread.start()
# wait for 10 minutes to close the thread
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
generate_task_id)
return cls.compact_response(pubsub, streaming)
@@ -172,8 +185,10 @@ class CompletionService:
return user
@classmethod
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
query: str, inputs: dict, detached_user: Union[Account, EndUser],
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
app_model_config: AppModelConfig,
query: str, inputs: dict, files: List[PromptMessageFile],
detached_user: Union[Account, EndUser],
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev'):
with flask_app.app_context():
@@ -195,6 +210,7 @@ class CompletionService:
query=query,
inputs=inputs,
user=user,
files=files,
conversation=conversation,
streaming=streaming,
is_override=is_model_config_override,
@@ -215,7 +231,8 @@ class CompletionService:
db.session.commit()
@classmethod
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
generate_task_id) -> threading.Thread:
# wait for 10 minutes to close the thread
timeout = 600
@@ -274,6 +291,12 @@ class CompletionService:
model_dict['completion_params'] = completion_params
app_model_config.model = json.dumps(model_dict)
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_objs = message_file_parser.transform_message_files(
message.files, app_model_config
)
generate_task_id = str(uuid.uuid4())
pubsub = redis_client.pubsub()
@@ -288,6 +311,7 @@ class CompletionService:
'app_model_config': app_model_config.copy(),
'query': message.query,
'inputs': message.inputs,
'files': file_objs,
'detached_user': user,
'detached_conversation': None,
'streaming': streaming,
@@ -388,7 +412,8 @@ class CompletionService:
if event == 'message':
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
elif event == 'message_replace':
yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
yield "data: " + json.dumps(
cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
elif event == 'chain':
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
elif event == 'agent_thought':

View File

@@ -1,42 +1,51 @@
import datetime
import hashlib
import time
import uuid
from typing import Generator, Tuple
from cachetools import TTLCache
from flask import request, current_app
from flask import current_app
from flask_login import current_user
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
from core.data_loader.file_extractor import FileExtractor
from core.file.upload_file_parser import UploadFileParser
from extensions.ext_storage import storage
from extensions.ext_database import db
from models.account import Account
from models.model import UploadFile
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
'jpg', 'jpeg', 'png', 'webp', 'gif']
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
PREVIEW_WORDS_LIMIT = 3000
cache = TTLCache(maxsize=None, ttl=30)
class FileService:
@staticmethod
def upload_file(file: FileStorage) -> UploadFile:
# read file content
file_content = file.read()
# get file size
file_size = len(file_content)
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
if file_size > file_size_limit:
message = f'File size exceeded. {file_size} > {file_size_limit}'
raise FileTooLargeError(message)
def upload_file(file: FileStorage, only_image: bool = False) -> UploadFile:
extension = file.filename.split('.')[-1]
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
raise UnsupportedFileTypeError()
# read file content
file_content = file.read()
# get file size
file_size = len(file_content)
if extension.lower() in IMAGE_EXTENSIONS:
file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT") * 1024 * 1024
else:
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
if file_size > file_size_limit:
message = f'File size exceeded. {file_size} > {file_size_limit}'
raise FileTooLargeError(message)
# user uuid as file name
file_uuid = str(uuid.uuid4())
@@ -47,15 +56,17 @@ class FileService:
# save file to db
config = current_app.config
user = current_user
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
tenant_id=user.current_tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=file.filename,
size=file_size,
extension=extension,
mime_type=file.mimetype,
created_by=current_user.id,
created_by_role=('account' if isinstance(user, Account) else 'end_user'),
created_by=user.id,
created_at=datetime.datetime.utcnow(),
used=False,
hash=hashlib.sha3_256(file_content).hexdigest()
@@ -99,12 +110,6 @@ class FileService:
@staticmethod
def get_file_preview(file_id: str) -> str:
# get file storage key
key = file_id + request.path
cached_response = cache.get(key)
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
return cached_response['response']
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
@@ -121,3 +126,25 @@ class FileService:
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return text
@staticmethod
def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> Tuple[Generator, str]:
result = UploadFileParser.verify_image_file_signature(file_id, timestamp, nonce, sign)
if not result:
raise NotFound("File not found or signature is invalid")
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
# only image files may be served through this preview route
extension = upload_file.extension
if extension.lower() not in IMAGE_EXTENSIONS:
raise UnsupportedFileTypeError()
generator = storage.load(upload_file.key, stream=True)
return generator, upload_file.mime_type
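(The signature check guards this otherwise unauthenticated image route. verify_image_file_signature is defined elsewhere; a plausible HMAC-based scheme consistent with the file_id/timestamp/nonce/sign parameters would be the following — the secret key handling, message format, and expiry window are assumptions:

import base64, hashlib, hmac, time

IMAGE_URL_EXPIRY_SECONDS = 300  # assumed expiry window

def verify_signature(secret_key: bytes, file_id: str, timestamp: str,
                     nonce: str, sign: str) -> bool:
    # Reject links older than the expiry window.
    if time.time() - int(timestamp) > IMAGE_URL_EXPIRY_SECONDS:
        return False
    msg = f"image-preview|{file_id}|{timestamp}|{nonce}".encode()
    expected = base64.urlsafe_b64encode(
        hmac.new(secret_key, msg, hashlib.sha256).digest()
    ).decode()
    return hmac.compare_digest(expected, sign)
)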

View File

@@ -5,7 +5,7 @@ from unittest.mock import patch
from langchain.schema import Generation, ChatGeneration, AIMessage
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage, MessageType, ImageMessageFile
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.openai_model import OpenAIModel
from models.provider import Provider, ProviderType
@@ -57,6 +57,18 @@ def test_chat_get_num_tokens(mock_decrypt):
assert rst == 22
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_vision_chat_get_num_tokens(mock_decrypt):
openai_model = get_mock_openai_model('gpt-4-vision-preview')
messages = [
PromptMessage(content='Whats in first image?', files=[
ImageMessageFile(
data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
])
]
rst = openai_model.get_num_tokens(messages)
assert rst == 77
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
@@ -80,4 +92,20 @@ def test_chat_run(mock_decrypt, mocker):
messages,
stop=['\nHuman:'],
)
assert (len(rst.content) > 0)
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_vision_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
openai_model = get_mock_openai_model('gpt-4-vision-preview')
messages = [
PromptMessage(content='Whats in first image?', files=[
ImageMessageFile(data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
])
]
rst = openai_model.run(
messages,
)
assert len(rst.content) > 0

View File

@@ -19,18 +19,22 @@ services:
# different from api or web app domain.
# example: http://cloud.dify.ai
CONSOLE_API_URL: ''
# The URL for Service API endpoints, refers to the base URL of the current API service if api domain is
# The URL prefix for Service API endpoints, refers to the base URL of the current API service if api domain is
# different from console domain.
# example: http://api.dify.ai
SERVICE_API_URL: ''
# The URL for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
# The URL prefix for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
# console or api domain.
# example: http://udify.app
APP_API_URL: ''
# The URL for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
# The URL prefix for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
# console or api domain.
# example: http://udify.app
APP_WEB_URL: ''
# File preview or download URL prefix.
# Used to expose file preview or download URLs to the front-end or as multi-modal model inputs;
# the URL is signed and carries an expiration time.
FILES_URL: ''
# When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
MIGRATION_ENABLED: 'true'
# The configurations of postgres database connection.