mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-07 03:45:27 +08:00
Compare commits
12 Commits
fix/trigge
...
feat-integ
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a81c3f52c4 | ||
|
|
d9a8cd72d2 | ||
|
|
7fb50d59fc | ||
|
|
08b6630d63 | ||
|
|
f123d87429 | ||
|
|
57c989f74a | ||
|
|
cc9f6f82c7 | ||
|
|
1a2294d6d0 | ||
|
|
e51e2038b8 | ||
|
|
e05b20d5a4 | ||
|
|
0b2a42603b | ||
|
|
d7e1dd0a22 |
@@ -1,6 +1,9 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
import base64
|
||||
|
||||
from controllers.console import console_ns
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@@ -41,3 +44,42 @@ class Invoices(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||
|
||||
|
||||
@console_ns.route("/billing/partners/<string:partner_key>/tenants")
|
||||
class PartnerTenants(Resource):
|
||||
@api.doc("sync_partner_tenants_bindings")
|
||||
@api.doc(description="Sync partner tenants bindings")
|
||||
@api.doc(params={"partner_key": "Partner key"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"SyncPartnerTenantsBindingsRequest",
|
||||
{
|
||||
"click_id": fields.String(required=True, description="Click Id from partner referral link")
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Tenants synced to partner successfully")
|
||||
@api.response(400, "Invalid partner information")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("click_id", required=True, type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
click_id = args["click_id"]
|
||||
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
|
||||
except Exception:
|
||||
raise BadRequest("Invalid partner_key")
|
||||
|
||||
if not click_id or not decoded_partner_key or not current_user.id:
|
||||
raise BadRequest("Invalid partner information")
|
||||
|
||||
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Literal
|
||||
|
||||
import httpx
|
||||
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
@@ -62,13 +63,20 @@ class BillingService:
|
||||
retry=retry_if_exception_type(httpx.RequestError),
|
||||
reraise=True,
|
||||
)
|
||||
def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
|
||||
def _send_request(cls, method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, json=None, params=None):
|
||||
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = httpx.request(method, url, json=json, params=params, headers=headers)
|
||||
if method == "GET" and response.status_code != httpx.codes.OK:
|
||||
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
|
||||
if method == "PUT":
|
||||
if response.status_code == httpx.codes.INTERNAL_SERVER_ERROR:
|
||||
raise InternalServerError(
|
||||
"Unable to process billing request. Please try again later or contact support."
|
||||
)
|
||||
if response.status_code != httpx.codes.OK:
|
||||
raise ValueError("Invalid arguments.")
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
@@ -179,3 +187,8 @@ class BillingService:
|
||||
@classmethod
|
||||
def clean_billing_info_cache(cls, tenant_id: str):
|
||||
redis_client.delete(f"tenant:{tenant_id}:billing_info")
|
||||
|
||||
@classmethod
|
||||
def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str):
|
||||
json = {"account_id": account_id, "click_id": click_id}
|
||||
return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=json)
|
||||
|
||||
254
api/tests/unit_tests/controllers/console/billing/test_billing.py
Normal file
254
api/tests/unit_tests/controllers/console/billing/test_billing.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console.billing.billing import PartnerTenants
|
||||
from models.account import Account
|
||||
|
||||
|
||||
class TestPartnerTenants:
|
||||
"""Unit tests for PartnerTenants controller."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask app for testing."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SECRET_KEY"] = "test-secret-key"
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create a mock account."""
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "account-123"
|
||||
account.email = "test@example.com"
|
||||
account.current_tenant_id = "tenant-456"
|
||||
account.is_authenticated = True
|
||||
return account
|
||||
|
||||
@pytest.fixture
|
||||
def mock_billing_service(self):
|
||||
"""Mock BillingService."""
|
||||
with patch("controllers.console.billing.billing.BillingService") as mock_service:
|
||||
yield mock_service
|
||||
|
||||
@pytest.fixture
|
||||
def mock_decorators(self):
|
||||
"""Mock decorators to avoid database access."""
|
||||
with (
|
||||
patch("controllers.console.wraps.db") as mock_db,
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
|
||||
patch("libs.login.dify_config.LOGIN_DISABLED", False),
|
||||
patch("libs.login.check_csrf_token") as mock_csrf,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_csrf.return_value = None
|
||||
yield {"db": mock_db, "csrf": mock_csrf}
|
||||
|
||||
def test_put_success(self, app, mock_account, mock_billing_service, mock_decorators):
|
||||
"""Test successful partner tenants bindings sync."""
|
||||
# Arrange
|
||||
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
|
||||
click_id = "click-id-789"
|
||||
expected_response = {"result": "success", "data": {"synced": True}}
|
||||
|
||||
mock_billing_service.sync_partner_tenants_bindings.return_value = expected_response
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
json={"click_id": click_id},
|
||||
path=f"/billing/partners/{partner_key_encoded}/tenants",
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456")
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
):
|
||||
resource = PartnerTenants()
|
||||
result = resource.put(partner_key_encoded)
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
mock_billing_service.sync_partner_tenants_bindings.assert_called_once_with(
|
||||
mock_account.id, "partner-key-123", click_id
|
||||
)
|
||||
|
||||
def test_put_invalid_partner_key_base64(self, app, mock_account, mock_billing_service, mock_decorators):
|
||||
"""Test that invalid base64 partner_key raises BadRequest."""
|
||||
# Arrange
|
||||
invalid_partner_key = "invalid-base64-!@#$"
|
||||
click_id = "click-id-789"
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
json={"click_id": click_id},
|
||||
path=f"/billing/partners/{invalid_partner_key}/tenants",
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456")
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
):
|
||||
resource = PartnerTenants()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(BadRequest) as exc_info:
|
||||
resource.put(invalid_partner_key)
|
||||
assert "Invalid partner_key" in str(exc_info.value)
|
||||
|
||||
def test_put_missing_click_id(self, app, mock_account, mock_billing_service, mock_decorators):
|
||||
"""Test that missing click_id raises BadRequest."""
|
||||
# Arrange
|
||||
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
json={},
|
||||
path=f"/billing/partners/{partner_key_encoded}/tenants",
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456")
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
):
|
||||
resource = PartnerTenants()
|
||||
|
||||
# Act & Assert
|
||||
# reqparse will raise BadRequest for missing required field
|
||||
with pytest.raises(BadRequest):
|
||||
resource.put(partner_key_encoded)
|
||||
|
||||
def test_put_billing_service_json_decode_error(self, app, mock_account, mock_billing_service, mock_decorators):
|
||||
"""Test handling of billing service JSON decode error.
|
||||
|
||||
When billing service returns non-200 status code with invalid JSON response,
|
||||
response.json() raises JSONDecodeError. This exception propagates to the controller
|
||||
and should be handled by the global error handler (handle_general_exception),
|
||||
which returns a 500 status code with error details.
|
||||
|
||||
Note: In unit tests, when directly calling resource.put(), the exception is raised
|
||||
directly. In actual Flask application, the error handler would catch it and return
|
||||
a 500 response with JSON: {"code": "unknown", "message": "...", "status": 500}
|
||||
"""
|
||||
# Arrange
|
||||
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
|
||||
click_id = "click-id-789"
|
||||
|
||||
# Simulate JSON decode error when billing service returns invalid JSON
|
||||
# This happens when billing service returns non-200 with empty/invalid response body
|
||||
json_decode_error = json.JSONDecodeError("Expecting value", "", 0)
|
||||
mock_billing_service.sync_partner_tenants_bindings.side_effect = json_decode_error
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
json={"click_id": click_id},
|
||||
path=f"/billing/partners/{partner_key_encoded}/tenants",
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456")
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
):
|
||||
resource = PartnerTenants()
|
||||
|
||||
# Act & Assert
|
||||
# JSONDecodeError will be raised from the controller
|
||||
# In actual Flask app, this would be caught by handle_general_exception
|
||||
# which returns: {"code": "unknown", "message": str(e), "status": 500}
|
||||
with pytest.raises(json.JSONDecodeError) as exc_info:
|
||||
resource.put(partner_key_encoded)
|
||||
|
||||
# Verify the exception is JSONDecodeError
|
||||
assert isinstance(exc_info.value, json.JSONDecodeError)
|
||||
assert "Expecting value" in str(exc_info.value)
|
||||
|
||||
def test_put_empty_click_id(self, app, mock_account, mock_billing_service, mock_decorators):
|
||||
"""Test that empty click_id raises BadRequest."""
|
||||
# Arrange
|
||||
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
|
||||
click_id = ""
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
json={"click_id": click_id},
|
||||
path=f"/billing/partners/{partner_key_encoded}/tenants",
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456")
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
):
|
||||
resource = PartnerTenants()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(BadRequest) as exc_info:
|
||||
resource.put(partner_key_encoded)
|
||||
assert "Invalid partner information" in str(exc_info.value)
|
||||
|
||||
def test_put_empty_partner_key_after_decode(self, app, mock_account, mock_billing_service, mock_decorators):
|
||||
"""Test that empty partner_key after decode raises BadRequest."""
|
||||
# Arrange
|
||||
# Base64 encode an empty string
|
||||
empty_partner_key_encoded = base64.b64encode(b"").decode("utf-8")
|
||||
click_id = "click-id-789"
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
json={"click_id": click_id},
|
||||
path=f"/billing/partners/{empty_partner_key_encoded}/tenants",
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456")
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
):
|
||||
resource = PartnerTenants()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(BadRequest) as exc_info:
|
||||
resource.put(empty_partner_key_encoded)
|
||||
assert "Invalid partner information" in str(exc_info.value)
|
||||
|
||||
def test_put_empty_user_id(self, app, mock_account, mock_billing_service, mock_decorators):
|
||||
"""Test that empty user id raises BadRequest."""
|
||||
# Arrange
|
||||
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
|
||||
click_id = "click-id-789"
|
||||
mock_account.id = None # Empty user id
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
json={"click_id": click_id},
|
||||
path=f"/billing/partners/{partner_key_encoded}/tenants",
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456")
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
):
|
||||
resource = PartnerTenants()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(BadRequest) as exc_info:
|
||||
resource.put(partner_key_encoded)
|
||||
assert "Invalid partner information" in str(exc_info.value)
|
||||
|
||||
206
api/tests/unit_tests/services/test_billing_service.py
Normal file
206
api/tests/unit_tests/services/test_billing_service.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
class TestBillingServiceSendRequest:
|
||||
"""Unit tests for BillingService._send_request method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_request(self):
|
||||
"""Mock httpx.request for testing."""
|
||||
with patch("services.billing_service.httpx.request") as mock_request:
|
||||
yield mock_request
|
||||
|
||||
@pytest.fixture
|
||||
def mock_billing_config(self):
|
||||
"""Mock BillingService configuration."""
|
||||
with (
|
||||
patch.object(BillingService, "base_url", "https://billing-api.example.com"),
|
||||
patch.object(BillingService, "secret_key", "test-secret-key"),
|
||||
):
|
||||
yield
|
||||
|
||||
def test_get_request_success(self, mock_httpx_request, mock_billing_config):
|
||||
"""Test successful GET request."""
|
||||
# Arrange
|
||||
expected_response = {"result": "success", "data": {"info": "test"}}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = httpx.codes.OK
|
||||
mock_response.json.return_value = expected_response
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act
|
||||
result = BillingService._send_request("GET", "/test", params={"key": "value"})
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
mock_httpx_request.assert_called_once()
|
||||
call_args = mock_httpx_request.call_args
|
||||
assert call_args[0][0] == "GET"
|
||||
assert call_args[0][1] == "https://billing-api.example.com/test"
|
||||
assert call_args[1]["params"] == {"key": "value"}
|
||||
assert call_args[1]["headers"]["Billing-Api-Secret-Key"] == "test-secret-key"
|
||||
assert call_args[1]["headers"]["Content-Type"] == "application/json"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code", [httpx.codes.NOT_FOUND, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.BAD_REQUEST]
|
||||
)
|
||||
def test_get_request_non_200_status_code(self, mock_httpx_request, mock_billing_config, status_code):
|
||||
"""Test GET request with non-200 status code raises ValueError."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
BillingService._send_request("GET", "/test")
|
||||
assert "Unable to retrieve billing information" in str(exc_info.value)
|
||||
|
||||
def test_put_request_success(self, mock_httpx_request, mock_billing_config):
|
||||
"""Test successful PUT request."""
|
||||
# Arrange
|
||||
expected_response = {"result": "success"}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = httpx.codes.OK
|
||||
mock_response.json.return_value = expected_response
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act
|
||||
result = BillingService._send_request("PUT", "/test", json={"key": "value"})
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
call_args = mock_httpx_request.call_args
|
||||
assert call_args[0][0] == "PUT"
|
||||
|
||||
def test_put_request_internal_server_error(self, mock_httpx_request, mock_billing_config):
|
||||
"""Test PUT request with INTERNAL_SERVER_ERROR raises InternalServerError."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = httpx.codes.INTERNAL_SERVER_ERROR
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(InternalServerError) as exc_info:
|
||||
BillingService._send_request("PUT", "/test", json={"key": "value"})
|
||||
assert exc_info.value.code == 500
|
||||
assert "Unable to process billing request" in str(exc_info.value.description)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.NOT_FOUND, httpx.codes.UNAUTHORIZED, httpx.codes.FORBIDDEN]
|
||||
)
|
||||
def test_put_request_non_200_non_500(self, mock_httpx_request, mock_billing_config, status_code):
|
||||
"""Test PUT request with non-200 and non-500 status code raises ValueError."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
BillingService._send_request("PUT", "/test", json={"key": "value"})
|
||||
assert "Invalid arguments." in str(exc_info.value)
|
||||
|
||||
@pytest.mark.parametrize("method", ["POST", "DELETE"])
|
||||
def test_non_get_non_put_request_success(self, mock_httpx_request, mock_billing_config, method):
|
||||
"""Test successful POST/DELETE request."""
|
||||
# Arrange
|
||||
expected_response = {"result": "success"}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = httpx.codes.OK
|
||||
mock_response.json.return_value = expected_response
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act
|
||||
result = BillingService._send_request(method, "/test", json={"key": "value"})
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
call_args = mock_httpx_request.call_args
|
||||
assert call_args[0][0] == method
|
||||
|
||||
@pytest.mark.parametrize("method", ["POST", "DELETE"])
|
||||
@pytest.mark.parametrize(
|
||||
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
|
||||
)
|
||||
def test_non_get_non_put_request_non_200_with_valid_json(
|
||||
self, mock_httpx_request, mock_billing_config, method, status_code
|
||||
):
|
||||
"""Test POST/DELETE request with non-200 status code but valid JSON response."""
|
||||
# Arrange
|
||||
error_response = {"detail": "Error message"}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.json.return_value = error_response
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act
|
||||
result = BillingService._send_request(method, "/test", json={"key": "value"})
|
||||
|
||||
# Assert
|
||||
# POST and DELETE don't check status code, so they return the error JSON
|
||||
assert result == error_response
|
||||
|
||||
@pytest.mark.parametrize("method", ["POST", "DELETE"])
|
||||
@pytest.mark.parametrize(
|
||||
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
|
||||
)
|
||||
def test_non_get_non_put_request_non_200_with_invalid_json(
|
||||
self, mock_httpx_request, mock_billing_config, method, status_code
|
||||
):
|
||||
"""Test POST/DELETE request with non-200 status code and invalid JSON response raises exception.
|
||||
|
||||
When billing service raises HTTPException(status_code=500, detail=str(e)), it typically returns
|
||||
JSON like {"detail": "error"} which can be parsed. However, if the response cannot be parsed
|
||||
as JSON (e.g., empty response), response.json() will raise JSONDecodeError.
|
||||
"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.text = ""
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
BillingService._send_request(method, "/test", json={"key": "value"})
|
||||
|
||||
def test_retry_on_request_error(self, mock_httpx_request, mock_billing_config):
|
||||
"""Test that _send_request retries on httpx.RequestError."""
|
||||
# Arrange
|
||||
expected_response = {"result": "success"}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = httpx.codes.OK
|
||||
mock_response.json.return_value = expected_response
|
||||
|
||||
# First call raises RequestError, second succeeds
|
||||
mock_httpx_request.side_effect = [
|
||||
httpx.RequestError("Network error"),
|
||||
mock_response,
|
||||
]
|
||||
|
||||
# Act
|
||||
result = BillingService._send_request("GET", "/test")
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
assert mock_httpx_request.call_count == 2
|
||||
|
||||
def test_retry_exhausted_raises_exception(self, mock_httpx_request, mock_billing_config):
|
||||
"""Test that _send_request raises exception after retries are exhausted."""
|
||||
# Arrange
|
||||
mock_httpx_request.side_effect = httpx.RequestError("Network error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(httpx.RequestError):
|
||||
BillingService._send_request("GET", "/test")
|
||||
|
||||
# Should retry multiple times (wait=2, stop_before_delay=10 means ~5 attempts)
|
||||
assert mock_httpx_request.call_count > 1
|
||||
@@ -10,6 +10,7 @@ import { ProviderContextProvider } from '@/context/provider-context'
|
||||
import { ModalContextProvider } from '@/context/modal-context'
|
||||
import GotoAnything from '@/app/components/goto-anything'
|
||||
import Zendesk from '@/app/components/base/zendesk'
|
||||
import PartnerStack from '../components/billing/partner-stack'
|
||||
|
||||
const Layout = ({ children }: { children: ReactNode }) => {
|
||||
return (
|
||||
@@ -24,6 +25,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
|
||||
<Header />
|
||||
</HeaderWrapper>
|
||||
{children}
|
||||
<PartnerStack />
|
||||
<GotoAnything />
|
||||
</ModalContextProvider>
|
||||
</ProviderContextProvider>
|
||||
|
||||
23
web/app/components/billing/partner-stack/index.tsx
Normal file
23
web/app/components/billing/partner-stack/index.tsx
Normal file
@@ -0,0 +1,23 @@
|
||||
'use client'
|
||||
import { IS_CLOUD_EDITION } from '@/config'
|
||||
import type { FC } from 'react'
|
||||
import React, { useEffect } from 'react'
|
||||
import usePSInfo from './use-ps-info'
|
||||
|
||||
const PartnerStack: FC = () => {
|
||||
const { saveOrUpdate, bind } = usePSInfo()
|
||||
useEffect(() => {
|
||||
if (!IS_CLOUD_EDITION)
|
||||
return
|
||||
// Save PartnerStack info in cookie first. Because if user hasn't logged in, redirecting to login page would cause lose the partnerStack info in URL.
|
||||
saveOrUpdate()
|
||||
// bind PartnerStack info after user logged in
|
||||
bind()
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<>
|
||||
</>
|
||||
)
|
||||
}
|
||||
export default React.memo(PartnerStack)
|
||||
62
web/app/components/billing/partner-stack/use-ps-info.ts
Normal file
62
web/app/components/billing/partner-stack/use-ps-info.ts
Normal file
@@ -0,0 +1,62 @@
|
||||
import { PARTNER_STACK_CONFIG } from '@/config'
|
||||
import { useBindPartnerStackInfo } from '@/service/use-billing'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import Cookies from 'js-cookie'
|
||||
import { useSearchParams } from 'next/navigation'
|
||||
import { useCallback } from 'react'
|
||||
|
||||
const usePSInfo = () => {
|
||||
const searchParams = useSearchParams()
|
||||
const psInfoInCookie = JSON.parse(Cookies.get(PARTNER_STACK_CONFIG.cookieName) || '{}')
|
||||
const psPartnerKey = searchParams.get('ps_partner_key') || psInfoInCookie?.partnerKey
|
||||
const psClickId = searchParams.get('ps_xid') || psInfoInCookie?.clickId
|
||||
const isPSChanged = psInfoInCookie?.partnerKey !== psPartnerKey || psInfoInCookie?.clickId !== psClickId
|
||||
const [hasBind, {
|
||||
setTrue: setBind,
|
||||
}] = useBoolean(false)
|
||||
const { mutateAsync } = useBindPartnerStackInfo()
|
||||
// Save to top domain. cloud.dify.ai => .dify.ai
|
||||
const domain = globalThis.location.hostname.replace('cloud', '')
|
||||
|
||||
const saveOrUpdate = useCallback(() => {
|
||||
if(!psPartnerKey || !psClickId)
|
||||
return
|
||||
if(!isPSChanged)
|
||||
return
|
||||
Cookies.set(PARTNER_STACK_CONFIG.cookieName, JSON.stringify({
|
||||
partnerKey: psPartnerKey,
|
||||
clickId: psClickId,
|
||||
}), {
|
||||
expires: PARTNER_STACK_CONFIG.saveCookieDays,
|
||||
path: '/',
|
||||
domain,
|
||||
})
|
||||
}, [psPartnerKey, psClickId, isPSChanged])
|
||||
|
||||
const bind = useCallback(async () => {
|
||||
if (psPartnerKey && psClickId && !hasBind) {
|
||||
let shouldRemoveCookie = false
|
||||
try {
|
||||
await mutateAsync({
|
||||
partnerKey: psPartnerKey,
|
||||
clickId: psClickId,
|
||||
})
|
||||
shouldRemoveCookie = true
|
||||
}
|
||||
catch (error: unknown) {
|
||||
if((error as { status: number })?.status === 400)
|
||||
shouldRemoveCookie = true
|
||||
}
|
||||
if (shouldRemoveCookie)
|
||||
Cookies.remove(PARTNER_STACK_CONFIG.cookieName, { path: '/', domain })
|
||||
setBind()
|
||||
}
|
||||
}, [psPartnerKey, psClickId, mutateAsync, hasBind, setBind])
|
||||
return {
|
||||
psPartnerKey,
|
||||
psClickId,
|
||||
saveOrUpdate,
|
||||
bind,
|
||||
}
|
||||
}
|
||||
export default usePSInfo
|
||||
@@ -2,10 +2,17 @@
|
||||
import { useSearchParams } from 'next/navigation'
|
||||
import OneMoreStep from './one-more-step'
|
||||
import NormalForm from './normal-form'
|
||||
import { useEffect } from 'react'
|
||||
import usePSInfo from '../components/billing/partner-stack/use-ps-info'
|
||||
|
||||
const SignIn = () => {
|
||||
const searchParams = useSearchParams()
|
||||
const step = searchParams.get('step')
|
||||
const { saveOrUpdate } = usePSInfo()
|
||||
|
||||
useEffect(() => {
|
||||
saveOrUpdate()
|
||||
}, [])
|
||||
|
||||
if (step === 'next')
|
||||
return <OneMoreStep />
|
||||
|
||||
@@ -443,3 +443,8 @@ export const STOP_PARAMETER_RULE: ModelParameterRule = {
|
||||
zh_Hans: '输入序列并按 Tab 键',
|
||||
},
|
||||
}
|
||||
|
||||
export const PARTNER_STACK_CONFIG = {
|
||||
cookieName: 'partner_stack_info',
|
||||
saveCookieDays: 90,
|
||||
}
|
||||
|
||||
19
web/service/use-billing.ts
Normal file
19
web/service/use-billing.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import { useMutation } from '@tanstack/react-query'
|
||||
import { put } from './base'
|
||||
|
||||
const NAME_SPACE = 'billing'
|
||||
|
||||
export const useBindPartnerStackInfo = () => {
|
||||
return useMutation({
|
||||
mutationKey: [NAME_SPACE, 'bind-partner-stack'],
|
||||
mutationFn: (data: { partnerKey: string; clickId: string }) => {
|
||||
return put(`/billing/partners/${data.partnerKey}/tenants`, {
|
||||
body: {
|
||||
click_id: data.clickId,
|
||||
},
|
||||
}, {
|
||||
silent: true,
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user