Compare commits

...

1 Commits

Author SHA1 Message Date
takatost
b716d8fab7 feat: jina embedding support 2023-11-29 00:05:05 +08:00
7 changed files with 379 additions and 1 deletions

View File

@@ -0,0 +1,25 @@
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.jina_embedding import JinaEmbeddings
class JinaEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = JinaEmbeddings(
model=name,
**credentials
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"Jina: {str(ex)}")
else:
return ex

View File

@@ -0,0 +1,141 @@
import json
from json import JSONDecodeError
from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.jina_embedding import JinaEmbedding
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.embeddings.jina_embedding import JinaEmbeddings
from models.provider import ProviderType
class JinaProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'jina'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'jina-embeddings-v2-base-en',
'name': 'jina-embeddings-v2-base-en',
},
{
'id': 'jina-embeddings-v2-small-en',
'name': 'jina-embeddings-v2-small-en',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.EMBEDDINGS:
model_class = JinaEmbedding
else:
raise NotImplementedError
return model_class
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Jina API Key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key'],
}
embedding = JinaEmbeddings(
model='jina-embeddings-v2-small-en',
**credential_kwargs
)
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
return credentials
return {}
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
def _get_text_generation_model_mode(self, model_name) -> str:
raise NotImplementedError
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
raise NotImplementedError

View File

@@ -0,0 +1,10 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed",
"supported_model_types": [
"embeddings"
]
}

View File

@@ -0,0 +1,69 @@
"""Wrapper around Jina embedding models."""
from typing import Any, List
import requests
from pydantic import BaseModel, Extra
from langchain.embeddings.base import Embeddings
class JinaEmbeddings(BaseModel, Embeddings):
"""Wrapper around Jina embedding models.
"""
client: Any #: :meta private:
api_key: str
model: str
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to Jina's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
for text in texts:
result = self.invoke_embedding(text=text)
embeddings.append(result)
return [list(map(float, e)) for e in embeddings]
def invoke_embedding(self, text):
params = {
"model": self.model,
"input": [
text
]
}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
response = requests.post(
'https://api.jina.ai/v1/embeddings',
headers=headers,
json=params
)
if not response.ok:
raise ValueError(f"Jina HTTP {response.status_code} error: {response.text}")
json_response = response.json()
return json_response["data"][0]["embedding"]
def embed_query(self, text: str) -> List[float]:
"""Call out to Jina's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]

View File

@@ -53,4 +53,7 @@ OPENLLM_SERVER_URL=
LOCALAI_SERVER_URL=
# Cohere Credentials
COHERE_API_KEY=
COHERE_API_KEY=
# Jina Credentials
JINA_API_KEY=

View File

@@ -0,0 +1,42 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.embedding.jina_embedding import JinaEmbedding
from core.model_providers.providers.jina_provider import JinaProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='jina',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'api_key': valid_api_key
}),
is_valid=True,
)
def get_mock_embedding_model():
model_name = 'jina-embeddings-v2-small-en'
valid_api_key = os.environ['JINA_API_KEY']
provider = JinaProvider(provider=get_mock_provider(valid_api_key))
return JinaEmbedding(
model_provider=provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
embedding_model = get_mock_embedding_model()
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 512

View File

@@ -0,0 +1,88 @@
import pytest
from unittest.mock import patch
import json
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.jina_provider import JinaProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'jina'
MODEL_PROVIDER_CLASS = JinaProvider
VALIDATE_CREDENTIAL = {
'api_key': 'valid_key'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.embeddings.jina_embedding.JinaEmbeddings.embed_query',
return_value=[1, 2])
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result['api_key'] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['api_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['api_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
assert all(char == '*' for char in middle_token)