mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-06 11:29:30 +08:00
Compare commits
1 Commits
d7755dbbda
...
feat/add-j
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b716d8fab7 |
25
api/core/model_providers/models/embedding/jina_embedding.py
Normal file
25
api/core/model_providers/models/embedding/jina_embedding.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.embeddings.jina_embedding import JinaEmbeddings
|
||||
|
||||
|
||||
class JinaEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = JinaEmbeddings(
|
||||
model=name,
|
||||
**credentials
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, ValueError):
|
||||
return LLMBadRequestError(f"Jina: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
141
api/core/model_providers/providers/jina_provider.py
Normal file
141
api/core/model_providers/providers/jina_provider.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.jina_embedding import JinaEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.embeddings.jina_embedding import JinaEmbeddings
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class JinaProvider(BaseModelProvider):
|
||||
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'jina'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
if model_type == ModelType.EMBEDDINGS:
|
||||
return [
|
||||
{
|
||||
'id': 'jina-embeddings-v2-base-en',
|
||||
'name': 'jina-embeddings-v2-base-en',
|
||||
},
|
||||
{
|
||||
'id': 'jina-embeddings-v2-small-en',
|
||||
'name': 'jina-embeddings-v2-small-en',
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.EMBEDDINGS:
|
||||
model_class = JinaEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
"""
|
||||
if 'api_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('Jina API Key must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'api_key': credentials['api_key'],
|
||||
}
|
||||
|
||||
embedding = JinaEmbeddings(
|
||||
model='jina-embeddings-v2-small-en',
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
embedding.embed_query("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
|
||||
return credentials
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
if self.provider.provider_type == ProviderType.CUSTOM.value:
|
||||
try:
|
||||
credentials = json.loads(self.provider.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
credentials = {
|
||||
'api_key': None,
|
||||
}
|
||||
|
||||
if credentials['api_key']:
|
||||
credentials['api_key'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['api_key']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
|
||||
|
||||
return credentials
|
||||
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
return self.get_provider_credentials(obfuscated)
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
raise NotImplementedError
|
||||
10
api/core/model_providers/rules/jina.json
Normal file
10
api/core/model_providers/rules/jina.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "fixed",
|
||||
"supported_model_types": [
|
||||
"embeddings"
|
||||
]
|
||||
}
|
||||
69
api/core/third_party/langchain/embeddings/jina_embedding.py
vendored
Normal file
69
api/core/third_party/langchain/embeddings/jina_embedding.py
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Wrapper around Jina embedding models."""
|
||||
from typing import Any, List
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class JinaEmbeddings(BaseModel, Embeddings):
|
||||
"""Wrapper around Jina embedding models.
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
api_key: str
|
||||
model: str
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to Jina's embedding endpoint.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
result = self.invoke_embedding(text=text)
|
||||
embeddings.append(result)
|
||||
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def invoke_embedding(self, text):
|
||||
params = {
|
||||
"model": self.model,
|
||||
"input": [
|
||||
text
|
||||
]
|
||||
}
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
response = requests.post(
|
||||
'https://api.jina.ai/v1/embeddings',
|
||||
headers=headers,
|
||||
json=params
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise ValueError(f"Jina HTTP {response.status_code} error: {response.text}")
|
||||
|
||||
json_response = response.json()
|
||||
return json_response["data"][0]["embedding"]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to Jina's embedding endpoint.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
@@ -53,4 +53,7 @@ OPENLLM_SERVER_URL=
|
||||
LOCALAI_SERVER_URL=
|
||||
|
||||
# Cohere Credentials
|
||||
COHERE_API_KEY=
|
||||
COHERE_API_KEY=
|
||||
|
||||
# Jina Credentials
|
||||
JINA_API_KEY=
|
||||
@@ -0,0 +1,42 @@
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.model_providers.models.embedding.jina_embedding import JinaEmbedding
|
||||
from core.model_providers.providers.jina_provider import JinaProvider
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
def get_mock_provider(valid_api_key):
|
||||
return Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='jina',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({
|
||||
'api_key': valid_api_key
|
||||
}),
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
|
||||
def get_mock_embedding_model():
|
||||
model_name = 'jina-embeddings-v2-small-en'
|
||||
valid_api_key = os.environ['JINA_API_KEY']
|
||||
provider = JinaProvider(provider=get_mock_provider(valid_api_key))
|
||||
return JinaEmbedding(
|
||||
model_provider=provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
return encrypted_api_key
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_embedding(mock_decrypt):
|
||||
embedding_model = get_mock_embedding_model()
|
||||
rst = embedding_model.client.embed_query('test')
|
||||
assert isinstance(rst, list)
|
||||
assert len(rst) == 512
|
||||
88
api/tests/unit_tests/model_providers/test_jina_provider.py
Normal file
88
api/tests/unit_tests/model_providers/test_jina_provider.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.jina_provider import JinaProvider
|
||||
from models.provider import ProviderType, Provider
|
||||
|
||||
|
||||
PROVIDER_NAME = 'jina'
|
||||
MODEL_PROVIDER_CLASS = JinaProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'api_key': 'valid_key'
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('core.third_party.langchain.embeddings.jina_embedding.JinaEmbeddings.embed_query',
|
||||
return_value=[1, 2])
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_invalid():
|
||||
# raise CredentialsValidateFailedError if api_key is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
|
||||
|
||||
credential = VALIDATE_CREDENTIAL.copy()
|
||||
credential['api_key'] = 'invalid_key'
|
||||
|
||||
# raise CredentialsValidateFailedError if api_key is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_credentials(mock_encrypt):
|
||||
api_key = 'valid_key'
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
|
||||
mock_encrypt.assert_called_with('tenant_id', api_key)
|
||||
assert result['api_key'] == f'encrypted_{api_key}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_custom(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result['api_key'] == 'valid_key'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_obfuscated(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials(obfuscated=True)
|
||||
middle_token = result['api_key'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
Reference in New Issue
Block a user