ragflow/common/settings.py
Wiratama cbdacf21f6 feat(gcs): Add support for Google Cloud Storage (GCS) integration (#11718)
### What problem does this PR solve?

This Pull Request introduces native support for Google Cloud Storage
(GCS) as an optional object storage backend.

Currently, RAGFlow relies on a limited set of storage options. This
feature addresses the need for seamless integration with GCP
environments, allowing users to leverage a fully managed, highly
durable, and scalable storage service (GCS) instead of needing to deploy
and maintain third-party object storage solutions. This simplifies
deployment, especially for users running on GCP infrastructure like GKE
or Cloud Run.

The implementation uses a single GCS bucket defined via configuration,
mapping RAGFlow's internal logical storage units (or "buckets") to
folder prefixes within that GCS bucket to maintain data separation.
This architectural choice avoids the operational complexity of
dynamically creating and managing a unique GCS bucket for every
logical unit.
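
As a rough sketch of the idea (illustrative only; the actual upload/download
logic lives in `rag/utils/gcs_conn.py`, and the helper names below are
hypothetical):

```python
# Illustrative sketch, not the RAGFlowGCS implementation: a logical RAGFlow
# bucket becomes a key prefix inside the single configured physical bucket.
from google.cloud import storage


def put_object(client: storage.Client, physical_bucket: str,
               logical_bucket: str, filename: str, payload: bytes) -> None:
    # e.g. logical bucket "kb-123" + file "doc.pdf" -> object key "kb-123/doc.pdf"
    blob = client.bucket(physical_bucket).blob(f"{logical_bucket}/{filename}")
    blob.upload_from_string(payload)
```

Listing or deleting a logical bucket then amounts to operating on the objects
that share the corresponding prefix.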

### Type of change
- [x] New Feature (non-breaking change which adds functionality)
2025-12-04 10:44:05 +08:00


#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import json
import secrets
from datetime import date
import logging
from common.constants import RAG_FLOW_SERVICE_NAME
from common.file_utils import get_project_base_directory
from common.config_utils import get_base_config, decrypt_database_config
from common.misc_utils import pip_install_torch
from common.constants import SVR_QUEUE_NAME, Storage
import rag.utils
import rag.utils.es_conn
import rag.utils.infinity_conn
import rag.utils.ob_conn
import rag.utils.opensearch_conn
from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
from rag.utils.gcs_conn import RAGFlowGCS
from rag.utils.minio_conn import RAGFlowMinio
from rag.utils.opendal_conn import OpenDALStorage
from rag.utils.s3_conn import RAGFlowS3
from rag.utils.oss_conn import RAGFlowOSS
from rag.nlp import search
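
# Module-level defaults; most of these are (re)populated by init_settings() at startup.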
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
CHAT_MDL = ""
EMBEDDING_MDL = ""
RERANK_MDL = ""
ASR_MDL = ""
IMAGE2TEXT_MDL = ""
CHAT_CFG = ""
EMBEDDING_CFG = ""
RERANK_CFG = ""
ASR_CFG = ""
IMAGE2TEXT_CFG = ""
API_KEY = None
PARSERS = None
HOST_IP = None
HOST_PORT = None
SECRET_KEY = None
FACTORY_LLM_INFOS = None
ALLOWED_LLM_FACTORIES = None
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
# authentication
AUTHENTICATION_CONF = None
# client
CLIENT_AUTHENTICATION = None
HTTP_APP_KEY = None
GITHUB_OAUTH = None
FEISHU_OAUTH = None
OAUTH_CONFIG = None
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")
docStoreConn = None
retriever = None
kg_retriever = None
# user registration switch
REGISTER_ENABLED = 1
# sandbox-executor-manager
SANDBOX_HOST = None
STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8"))
SMTP_CONF = None
MAIL_SERVER = ""
MAIL_PORT = 000
MAIL_USE_SSL = True
MAIL_USE_TLS = False
MAIL_USERNAME = ""
MAIL_PASSWORD = ""
MAIL_DEFAULT_SENDER = ()
MAIL_FRONTEND_URL = ""
# move from rag.settings
ES = {}
INFINITY = {}
AZURE = {}
S3 = {}
MINIO = {}
OB = {}
OSS = {}
OS = {}
GCS = {}
DOC_MAXIMUM_SIZE: int = 128 * 1024 * 1024
DOC_BULK_SIZE: int = 4
EMBEDDING_BATCH_SIZE: int = 16
PARALLEL_DEVICES: int = 0
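
# Object storage backend selector; the value must match a member of the Storage
# enum handled by StorageFactory below (e.g. MINIO, AWS_S3, AZURE_SPN, OSS, GCS).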
STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
STORAGE_IMPL = None


def get_svr_queue_name(priority: int) -> str:
    if priority == 0:
        return SVR_QUEUE_NAME
    return f"{SVR_QUEUE_NAME}_{priority}"


def get_svr_queue_names():
    return [get_svr_queue_name(priority) for priority in [1, 0]]


def _get_or_create_secret_key():
    secret_key = os.environ.get("RAGFLOW_SECRET_KEY")
    if secret_key and len(secret_key) >= 32:
        return secret_key
    # Check if there's a configured secret key
    configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key")
    if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32:
        return configured_key
    # Generate a new secure key and warn about it
    new_key = secrets.token_hex(32)
    logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.")
    return new_key


class StorageFactory:
    storage_mapping = {
        Storage.MINIO: RAGFlowMinio,
        Storage.AZURE_SPN: RAGFlowAzureSpnBlob,
        Storage.AZURE_SAS: RAGFlowAzureSasBlob,
        Storage.AWS_S3: RAGFlowS3,
        Storage.OSS: RAGFlowOSS,
        Storage.OPENDAL: OpenDALStorage,
        Storage.GCS: RAGFlowGCS,
    }

    @classmethod
    def create(cls, storage: Storage):
        return cls.storage_mapping[storage]()


def init_settings():
    global DATABASE_TYPE, DATABASE
    DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
    DATABASE = decrypt_database_config(name=DATABASE_TYPE)

    global ALLOWED_LLM_FACTORIES, LLM_FACTORY, LLM_BASE_URL
    llm_settings = get_base_config("user_default_llm", {}) or {}
    llm_default_models = llm_settings.get("default_models", {}) or {}
    LLM_FACTORY = llm_settings.get("factory", "") or ""
    LLM_BASE_URL = llm_settings.get("base_url", "") or ""
    ALLOWED_LLM_FACTORIES = llm_settings.get("allowed_factories", None)

    global REGISTER_ENABLED
    try:
        REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
    except Exception:
        pass

    global FACTORY_LLM_INFOS
    try:
        with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f:
            FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"]
    except Exception:
        FACTORY_LLM_INFOS = []

    global API_KEY
    API_KEY = llm_settings.get("api_key")

    global PARSERS
    PARSERS = llm_settings.get(
        "parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
    )

    global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
    chat_entry = _parse_model_entry(llm_default_models.get("chat_model", CHAT_MDL))
    embedding_entry = _parse_model_entry(llm_default_models.get("embedding_model", EMBEDDING_MDL))
    rerank_entry = _parse_model_entry(llm_default_models.get("rerank_model", RERANK_MDL))
    asr_entry = _parse_model_entry(llm_default_models.get("asr_model", ASR_MDL))
    image2text_entry = _parse_model_entry(llm_default_models.get("image2text_model", IMAGE2TEXT_MDL))

    global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
    CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
    EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
    RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
    ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
    IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)

    CHAT_MDL = CHAT_CFG.get("model", "") or ""
    EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
    RERANK_MDL = RERANK_CFG.get("model", "") or ""
    ASR_MDL = ASR_CFG.get("model", "") or ""
    IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""

    global HOST_IP, HOST_PORT
    HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
    HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")

    global SECRET_KEY
    SECRET_KEY = _get_or_create_secret_key()

    # authentication
    authentication_conf = get_base_config("authentication", {})

    global CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG
    # client
    CLIENT_AUTHENTICATION = authentication_conf.get("client", {}).get("switch", False)
    HTTP_APP_KEY = authentication_conf.get("client", {}).get("http_app_key")
    GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
    FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
    OAUTH_CONFIG = get_base_config("oauth", {})

    # document store (retrieval) engine
    global DOC_ENGINE, DOC_ENGINE_INFINITY, docStoreConn, ES, OB, OS, INFINITY
    DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
    DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")
    lower_case_doc_engine = DOC_ENGINE.lower()
    if lower_case_doc_engine == "elasticsearch":
        ES = get_base_config("es", {})
        docStoreConn = rag.utils.es_conn.ESConnection()
    elif lower_case_doc_engine == "infinity":
        INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
        docStoreConn = rag.utils.infinity_conn.InfinityConnection()
    elif lower_case_doc_engine == "opensearch":
        OS = get_base_config("os", {})
        docStoreConn = rag.utils.opensearch_conn.OSConnection()
    elif lower_case_doc_engine == "oceanbase":
        OB = get_base_config("oceanbase", {})
        docStoreConn = rag.utils.ob_conn.OBConnection()
    else:
        raise Exception(f"Not supported doc engine: {DOC_ENGINE}")

    # object storage backend configuration
    global AZURE, S3, MINIO, OSS, GCS
    if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
        AZURE = get_base_config("azure", {})
    elif STORAGE_IMPL_TYPE == 'AWS_S3':
        S3 = get_base_config("s3", {})
    elif STORAGE_IMPL_TYPE == 'MINIO':
        MINIO = decrypt_database_config(name="minio")
    elif STORAGE_IMPL_TYPE == 'OSS':
        OSS = get_base_config("oss", {})
    elif STORAGE_IMPL_TYPE == 'GCS':
        GCS = get_base_config("gcs", {})

    global STORAGE_IMPL
    STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])

    global retriever, kg_retriever
    retriever = search.Dealer(docStoreConn)
    from graphrag import search as kg_search
    kg_retriever = kg_search.KGSearch(docStoreConn)

    global SANDBOX_HOST
    if int(os.environ.get("SANDBOX_ENABLED", "0")):
        SANDBOX_HOST = os.environ.get("SANDBOX_HOST", "sandbox-executor-manager")

    global SMTP_CONF
    SMTP_CONF = get_base_config("smtp", {})

    global MAIL_SERVER, MAIL_PORT, MAIL_USE_SSL, MAIL_USE_TLS, MAIL_USERNAME, MAIL_PASSWORD, MAIL_DEFAULT_SENDER, MAIL_FRONTEND_URL
    MAIL_SERVER = SMTP_CONF.get("mail_server", "")
    MAIL_PORT = SMTP_CONF.get("mail_port", 000)
    MAIL_USE_SSL = SMTP_CONF.get("mail_use_ssl", True)
    MAIL_USE_TLS = SMTP_CONF.get("mail_use_tls", False)
    MAIL_USERNAME = SMTP_CONF.get("mail_username", "")
    MAIL_PASSWORD = SMTP_CONF.get("mail_password", "")
    mail_default_sender = SMTP_CONF.get("mail_default_sender", [])
    if mail_default_sender and len(mail_default_sender) >= 2:
        MAIL_DEFAULT_SENDER = (mail_default_sender[0], mail_default_sender[1])
    MAIL_FRONTEND_URL = SMTP_CONF.get("mail_frontend_url", "")

    global DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE
    DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
    DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4))
    EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16))


def check_and_install_torch():
    global PARALLEL_DEVICES
    try:
        pip_install_torch()
        import torch.cuda
        PARALLEL_DEVICES = torch.cuda.device_count()
        logging.info(f"found {PARALLEL_DEVICES} gpus")
    except Exception:
        logging.info("can't import package 'torch'")


def _parse_model_entry(entry):
    if isinstance(entry, str):
        return {"name": entry, "factory": None, "api_key": None, "base_url": None}
    if isinstance(entry, dict):
        name = entry.get("name") or entry.get("model") or ""
        return {
            "name": name,
            "factory": entry.get("factory"),
            "api_key": entry.get("api_key"),
            "base_url": entry.get("base_url"),
        }
    return {"name": "", "factory": None, "api_key": None, "base_url": None}


def _resolve_per_model_config(entry_dict, backup_factory, backup_api_key, backup_base_url):
    name = (entry_dict.get("name") or "").strip()
    m_factory = entry_dict.get("factory") or backup_factory or ""
    m_api_key = entry_dict.get("api_key") or backup_api_key or ""
    m_base_url = entry_dict.get("base_url") or backup_base_url or ""
    if name and "@" not in name and m_factory:
        name = f"{name}@{m_factory}"
    return {
        "model": name,
        "factory": m_factory,
        "api_key": m_api_key,
        "base_url": m_base_url,
    }


def print_rag_settings():
    logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
    logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}")