mirror of
https://gitee.com/infiniflow/ragflow.git
synced 2025-12-06 15:29:03 +08:00
### What problem does this PR solve? Make RAGFlow more asynchronous 2. #11551, #11579, #11619. ### Type of change - [x] Refactoring - [x] Performance Improvement
394 lines
15 KiB
Python
394 lines
15 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from html import escape
|
|
from typing import Any
|
|
|
|
from quart import request, make_response
|
|
from google_auth_oauthlib.flow import Flow
|
|
|
|
from api.db import InputType
|
|
from api.db.services.connector_service import ConnectorService, SyncLogsService
|
|
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request
|
|
from common.constants import RetCode, TaskStatus
|
|
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
|
from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
|
from common.misc_utils import get_uuid
|
|
from rag.utils.redis_conn import REDIS_CONN
|
|
from api.apps import login_required, current_user
|
|
|
|
|
|
@manager.route("/set", methods=["POST"])  # noqa: F821
@login_required
async def set_connector():
    """Create a connector, or update an existing one, from the JSON request body.

    When the body carries a truthy ``id``, only the fields ``prune_freq``,
    ``refresh_freq``, ``config`` and ``timeout_secs`` that are present in the
    body are written to that connector. Otherwise a new connector is saved
    with a fresh id, ``tenant_id`` set to the current user's id, and defaults
    for the scheduling fields that the body omits.

    Returns:
        A JSON result wrapping the stored connector as a dict, or a data
        error result when the connector cannot be read back.
    """
    req = await get_request_json()
    if req.get("id"):
        updatable_fields = ("prune_freq", "refresh_freq", "config", "timeout_secs")
        conn = {fld: req[fld] for fld in updatable_fields if fld in req}
        ConnectorService.update_by_id(req["id"], conn)
    else:
        req["id"] = get_uuid()
        conn = {
            "id": req["id"],
            "tenant_id": current_user.id,
            "name": req["name"],
            "source": req["source"],
            "input_type": InputType.POLL,
            "config": req["config"],
            "refresh_freq": int(req.get("refresh_freq", 30)),
            "prune_freq": int(req.get("prune_freq", 720)),
            "timeout_secs": int(req.get("timeout_secs", 60 * 29)),
            "status": TaskStatus.SCHEDULE,
        }
        ConnectorService.save(**conn)

    # Pause kept from the original implementation before reading the row back.
    # NOTE(review): presumably lets the write become visible first - confirm.
    await asyncio.sleep(1)
    e, conn = ConnectorService.get_by_id(req["id"])
    if not e:
        # An update aimed at an unknown id used to crash on conn.to_dict();
        # report it the same way get_connector does.
        return get_data_error_result(message="Can't find this Connector!")

    return get_json_result(data=conn.to_dict())
|
|
|
|
|
|
@manager.route("/list", methods=["GET"])  # noqa: F821
@login_required
def list_connector():
    """Return what ``ConnectorService.list`` yields for the current user's id."""
    connectors = ConnectorService.list(current_user.id)
    return get_json_result(data=connectors)
|
|
|
|
|
|
@manager.route("/<connector_id>", methods=["GET"])  # noqa: F821
@login_required
def get_connector(connector_id):
    """Fetch a single connector by id and return it as a dict.

    Responds with a data error result when the lookup reports that the
    connector does not exist.
    """
    found, connector = ConnectorService.get_by_id(connector_id)
    if found:
        return get_json_result(data=connector.to_dict())
    return get_data_error_result(message="Can't find this Connector!")
|
|
|
|
|
|
@manager.route("/<connector_id>/logs", methods=["GET"])  # noqa: F821
@login_required
def list_logs(connector_id):
    """Return one page of sync tasks for a connector.

    Query parameters ``page`` (default 1) and ``page_size`` (default 15) are
    converted to integers and handed to ``SyncLogsService.list_sync_tasks``.
    """
    query = request.args.to_dict(flat=True)
    page = int(query.get("page", 1))
    page_size = int(query.get("page_size", 15))
    logs, total = SyncLogsService.list_sync_tasks(connector_id, page, page_size)
    return get_json_result(data={"total": total, "logs": logs})
|
|
|
|
|
|
@manager.route("/<connector_id>/resume", methods=["PUT"])  # noqa: F821
@login_required
async def resume(connector_id):
    """Schedule or cancel a connector depending on the ``resume`` flag in the body.

    A truthy ``resume`` value selects ``TaskStatus.SCHEDULE``; anything else
    selects ``TaskStatus.CANCEL``.
    """
    body = await get_request_json()
    target_status = TaskStatus.SCHEDULE if body.get("resume") else TaskStatus.CANCEL
    ConnectorService.resume(connector_id, target_status)
    return get_json_result(data=True)
|
|
|
|
|
|
@manager.route("/<connector_id>/rebuild", methods=["PUT"])  # noqa: F821
@login_required
@validate_request("kb_id")
async def rebuild(connector_id):
    """Ask ``ConnectorService.rebuild`` to rebuild a connector for the given ``kb_id``.

    A truthy return value from the service is treated as an error message and
    reported with ``RetCode.SERVER_ERROR``; otherwise the call succeeded.
    """
    body = await get_request_json()
    failure = ConnectorService.rebuild(body["kb_id"], connector_id, current_user.id)
    if not failure:
        return get_json_result(data=True)
    return get_json_result(data=False, message=failure, code=RetCode.SERVER_ERROR)
|
|
|
|
|
|
@manager.route("/<connector_id>/rm", methods=["POST"])  # noqa: F821
@login_required
def rm_connector(connector_id):
    """Cancel a connector and then delete it by id."""
    # Switch the connector to the CANCEL status before the record is removed.
    cancelled = TaskStatus.CANCEL
    ConnectorService.resume(connector_id, cancelled)

    ConnectorService.delete_by_id(connector_id)
    return get_json_result(data=True)
|
|
|
|
|
|
# Default Redis key prefixes for the web OAuth flow: one for the pending
# state, one for the finished result.
GOOGLE_WEB_FLOW_STATE_PREFIX = "google_drive_web_flow_state"
GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result"

# Lifetime in seconds (15 minutes) used for cached flow entries and reported
# to the client as "expires_in".
WEB_FLOW_TTL_SECS = 60 * 15
|
|
|
|
|
|
def _web_state_cache_key(flow_id: str, source_type: str | None = None) -> str:
|
|
"""Return Redis key for web OAuth state.
|
|
|
|
The default prefix keeps backward compatibility for Google Drive.
|
|
When source_type == "gmail", a different prefix is used so that
|
|
Drive/Gmail flows don't clash in Redis.
|
|
"""
|
|
if source_type == "gmail":
|
|
prefix = "gmail_web_flow_state"
|
|
else:
|
|
prefix = GOOGLE_WEB_FLOW_STATE_PREFIX
|
|
return f"{prefix}:{flow_id}"
|
|
|
|
|
|
def _web_result_cache_key(flow_id: str, source_type: str | None = None) -> str:
|
|
"""Return Redis key for web OAuth result.
|
|
|
|
Mirrors _web_state_cache_key logic for result storage.
|
|
"""
|
|
if source_type == "gmail":
|
|
prefix = "gmail_web_flow_result"
|
|
else:
|
|
prefix = GOOGLE_WEB_FLOW_RESULT_PREFIX
|
|
return f"{prefix}:{flow_id}"
|
|
|
|
|
|
def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]:
|
|
if isinstance(payload, dict):
|
|
return payload
|
|
try:
|
|
return json.loads(payload)
|
|
except json.JSONDecodeError as exc: # pragma: no cover - defensive
|
|
raise ValueError("Invalid Google credentials JSON.") from exc
|
|
|
|
|
|
def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
    """Extract the ``web`` client section from uploaded Google OAuth credentials.

    Returns:
        A new dict holding only the ``"web"`` key.

    Raises:
        ValueError: If ``credentials`` has no ``"web"`` entry or it is not a dict.
    """
    web_config = credentials.get("web")
    if isinstance(web_config, dict):
        return {"web": web_config}
    raise ValueError("Google OAuth JSON must include a 'web' client configuration to use browser-based authorization.")
|
|
|
|
|
|
async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"):
    """Render the HTML page shown in the OAuth popup window.

    Args:
        flow_id: Flow identifier echoed back in the payload as ``flowId``
            (an empty string is used when it is falsy).
        success: Whether authorization succeeded; selects the status, the
            heading and whether the page closes itself.
        message: Human-readable message shown on the page and in the payload.
        source: Source name used to build the payload type
            ``ragflow-<source>-oauth`` and the page title.

    Returns:
        A 200 response whose body is the filled-in popup template, served as
        ``text/html``.
    """
    status = "success" if success else "error"
    auto_close = "window.close();" if success else ""
    escaped_message = escape(message)

    # Drive: ragflow-google-drive-oauth
    # Gmail: ragflow-gmail-oauth
    payload_type = f"ragflow-{source}-oauth"
    payload_json = json.dumps(
        {
            "type": payload_type,
            "status": status,
            "flowId": flow_id or "",
            "message": message,
        }
    )
    # json.dumps leaves "<", ">" and "&" as-is, and the callbacks pass values
    # taken from the request query string as flow_id/message. Emit them as
    # \uXXXX escapes (same JSON value) so the payload cannot close a tag.
    # NOTE(review): assumes the template embeds payload_json in an inline
    # <script>; the template is defined in another module - confirm.
    for raw, safe in (("<", "\\u003c"), (">", "\\u003e"), ("&", "\\u0026")):
        payload_json = payload_json.replace(raw, safe)

    # "google-drive" used to render as "Google Google-drive Authorization";
    # drop the redundant prefix before capitalizing.
    product = source.removeprefix("google-").capitalize()
    html = GOOGLE_WEB_OAUTH_POPUP_TEMPLATE.format(
        title=f"Google {product} Authorization",
        heading="Authorization complete" if success else "Authorization failed",
        message=escaped_message,
        payload_json=payload_json,
        auto_close=auto_close,
    )
    response = await make_response(html, 200)
    response.headers["Content-Type"] = "text/html; charset=utf-8"
    return response
|
|
|
|
|
|
@manager.route("/google/oauth/web/start", methods=["POST"])  # noqa: F821
@login_required
@validate_request("credentials")
async def start_google_web_oauth():
    """Start a browser-based Google OAuth flow for Google Drive or Gmail.

    The ``type`` query parameter selects the source (``google-drive`` by
    default, or ``gmail``). The request body must carry ``credentials``: the
    OAuth client JSON (string or object) containing a ``web`` section and no
    refresh token.

    On success the flow state (user id, client config, creation time) is
    cached in Redis for ``WEB_FLOW_TTL_SECS`` seconds and the response
    contains ``flow_id``, ``authorization_url`` and ``expires_in``. Invalid
    input yields ``RetCode.ARGUMENT_ERROR``; a missing redirect URI or a
    failure to build the flow yields ``RetCode.SERVER_ERROR``.
    """
    source = request.args.get("type", "google-drive")
    if source not in ("google-drive", "gmail"):
        return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")

    # Only two sources pass the check above, so one branch per source suffices.
    if source == "gmail":
        redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
        scopes = GOOGLE_SCOPES[DocumentSource.GMAIL]
    else:
        redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
        scopes = GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]

    if not redirect_uri:
        return get_json_result(
            code=RetCode.SERVER_ERROR,
            message="Google OAuth redirect URI is not configured on the server.",
        )

    req = await get_request_json()
    raw_credentials = req.get("credentials", "")

    # The parsed credentials hold the OAuth client secret, so they are
    # deliberately never printed or logged here.
    try:
        credentials = _load_credentials(raw_credentials)
    except ValueError as exc:
        return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))

    if not isinstance(credentials, dict):
        # Guard the .get() calls below against JSON that is not an object.
        return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google credentials JSON.")

    if credentials.get("refresh_token"):
        return get_json_result(
            code=RetCode.ARGUMENT_ERROR,
            message="Uploaded credentials already include a refresh token.",
        )

    try:
        client_config = _get_web_client_config(credentials)
    except ValueError as exc:
        return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))

    # The flow id doubles as the OAuth "state" value returned to the callback.
    flow_id = str(uuid.uuid4())
    try:
        flow = Flow.from_client_config(client_config, scopes=scopes)
        flow.redirect_uri = redirect_uri
        authorization_url, _ = flow.authorization_url(
            access_type="offline",
            include_granted_scopes="true",
            prompt="consent",
            state=flow_id,
        )
    except Exception as exc:  # pragma: no cover - defensive
        logging.exception("Failed to create Google OAuth flow: %s", exc)
        return get_json_result(
            code=RetCode.SERVER_ERROR,
            message="Failed to initialize Google OAuth flow. Please verify the uploaded client configuration.",
        )

    cache_payload = {
        "user_id": current_user.id,
        "client_config": client_config,
        "created_at": int(time.time()),
    }
    REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS)

    return get_json_result(
        data={
            "flow_id": flow_id,
            "authorization_url": authorization_url,
            "expires_in": WEB_FLOW_TTL_SECS,
        }
    )
|
|
|
|
|
|
@manager.route("/gmail/oauth/web/callback", methods=["GET"])  # noqa: F821
async def google_gmail_web_oauth_callback():
    """Handle the redirect back from Google for the Gmail web OAuth flow.

    Reads ``state``, ``error``, ``error_description`` and ``code`` from the
    query string, loads the cached flow state for ``state``, exchanges the
    code for tokens and caches the resulting credentials together with the
    originating user id for ``WEB_FLOW_TTL_SECS`` seconds. Every outcome is
    reported by rendering the OAuth popup page.
    """
    source = "gmail"
    state_id = request.args.get("state")
    error = request.args.get("error")
    error_description = request.args.get("error_description") or error

    if not state_id:
        return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)

    state_key = _web_state_cache_key(state_id, source)
    state_cache = REDIS_CONN.get(state_key)
    if not state_cache:
        return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)

    state_obj = json.loads(state_cache)
    client_config = state_obj.get("client_config")
    if not client_config:
        REDIS_CONN.delete(state_key)
        return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)

    if error:
        REDIS_CONN.delete(state_key)
        return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)

    code = request.args.get("code")
    if not code:
        return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)

    try:
        flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL])
        flow.redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
        # The token exchange is a synchronous call; run it in a worker thread
        # so this async handler does not stall the event loop while waiting.
        await asyncio.to_thread(flow.fetch_token, code=code)
    except Exception as exc:  # pragma: no cover - defensive
        logging.exception("Failed to exchange Google OAuth code: %s", exc)
        REDIS_CONN.delete(state_key)
        return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)

    creds_json = flow.credentials.to_json()
    result_payload = {
        "user_id": state_obj.get("user_id"),
        "credentials": creds_json,
    }
    REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)

    # The state entry is single-use: drop it once the result has been stored.
    REDIS_CONN.delete(state_key)

    return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
|
|
|
|
|
|
@manager.route("/google-drive/oauth/web/callback", methods=["GET"])  # noqa: F821
async def google_drive_web_oauth_callback():
    """Handle the redirect back from Google for the Google Drive web OAuth flow.

    Reads ``state``, ``error``, ``error_description`` and ``code`` from the
    query string, loads the cached flow state for ``state``, exchanges the
    code for tokens and caches the resulting credentials together with the
    originating user id for ``WEB_FLOW_TTL_SECS`` seconds. Every outcome is
    reported by rendering the OAuth popup page.
    """
    source = "google-drive"
    state_id = request.args.get("state")
    error = request.args.get("error")
    error_description = request.args.get("error_description") or error

    if not state_id:
        return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)

    state_key = _web_state_cache_key(state_id, source)
    state_cache = REDIS_CONN.get(state_key)
    if not state_cache:
        return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)

    state_obj = json.loads(state_cache)
    client_config = state_obj.get("client_config")
    if not client_config:
        REDIS_CONN.delete(state_key)
        return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)

    if error:
        REDIS_CONN.delete(state_key)
        return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)

    code = request.args.get("code")
    if not code:
        return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)

    try:
        flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
        flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
        # The token exchange is a synchronous call; run it in a worker thread
        # so this async handler does not stall the event loop while waiting.
        await asyncio.to_thread(flow.fetch_token, code=code)
    except Exception as exc:  # pragma: no cover - defensive
        logging.exception("Failed to exchange Google OAuth code: %s", exc)
        REDIS_CONN.delete(state_key)
        return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)

    creds_json = flow.credentials.to_json()
    result_payload = {
        "user_id": state_obj.get("user_id"),
        "credentials": creds_json,
    }
    REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
    # The state entry is single-use: drop it once the result has been stored.
    REDIS_CONN.delete(state_key)

    return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
|
|
|
|
@manager.route("/google/oauth/web/result", methods=["POST"])  # noqa: F821
@login_required
@validate_request("flow_id")
async def poll_google_web_result():
    """Return the credentials produced by a finished Google web OAuth flow.

    The ``type`` query parameter selects the source (``google-drive`` by
    default, or ``gmail``) and the body must carry ``flow_id``. Responds with
    ``RetCode.RUNNING`` while no result is cached, and with
    ``RetCode.PERMISSION_ERROR`` when the cached result belongs to another
    user. Otherwise the cached result is deleted and its credentials returned.
    """
    req = await request.json or {}
    # Default matches the start endpoint, so a flow started without an
    # explicit type can also be polled without one.
    source = request.args.get("type", "google-drive")
    if source not in ("google-drive", "gmail"):
        return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")

    flow_id = req.get("flow_id")
    result_key = _web_result_cache_key(flow_id, source)
    cache_raw = REDIS_CONN.get(result_key)
    if not cache_raw:
        return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")

    result = json.loads(cache_raw)
    if result.get("user_id") != current_user.id:
        return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")

    # Results are handed out once: remove the entry before returning it.
    REDIS_CONN.delete(result_key)
    return get_json_result(data={"credentials": result.get("credentials")})
|