Compare commits

...

4 Commits

Author SHA1 Message Date
Garfield Dai
b8db580833 feat: hugging face supports embeddings. 2023-09-19 21:08:37 +08:00
StyleZhang
9d3ba98d07 fix: frontend huggingface embedding hide task type 2023-09-19 11:02:54 +08:00
StyleZhang
8b2573efab feat: frontend add huggingface embedding support 2023-09-19 10:46:32 +08:00
Garfield Dai
757ee4d39f feat: hugging face supports embeddings. 2023-09-19 10:34:19 +08:00
6 changed files with 271 additions and 22 deletions

View File

@@ -0,0 +1,24 @@
from replicate.exceptions import ModelError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
from core.model_providers.models.embedding.base import BaseEmbedding
class HuggingfaceEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = HuggingfaceHubEmbeddings(
model=name,
**credentials
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Huggingface embedding: {str(ex)}")

View File

@@ -10,6 +10,8 @@ from core.model_providers.providers.base import BaseModelProvider, CredentialsVa
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
from models.provider import ProviderType
@@ -33,6 +35,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = HuggingfaceHubModel
elif model_type == ModelType.EMBEDDINGS:
model_class = HuggingfaceEmbedding
else:
raise NotImplementedError
@@ -63,7 +67,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
:param model_type:
:param credentials:
"""
if model_type != ModelType.TEXT_GENERATION:
if model_type not in [ModelType.TEXT_GENERATION, ModelType.EMBEDDINGS]:
raise NotImplementedError
if 'huggingfacehub_api_type' not in credentials \
@@ -88,19 +92,15 @@ class HuggingfaceHubProvider(BaseModelProvider):
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Task Type must be provided.')
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization", 'feature-extraction'):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
'text-generation, summarization.')
'text-generation, summarization, feature-extraction.')
try:
llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)
llm("ping")
if credentials['task_type'] == 'feature-extraction':
cls.check_embedding_valid(credentials, model_name)
else:
cls.check_llm_valid(credentials)
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
else:
@@ -112,13 +112,33 @@ class HuggingfaceHubProvider(BaseModelProvider):
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
VALID_TASKS = ("text2text-generation", "text-generation", "summarization", "feature-extraction")
if model_info.pipeline_tag not in VALID_TASKS:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {VALID_TASKS}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
@classmethod
def check_llm_valid(cls, credentials: dict):
llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)
llm("ping")
@classmethod
def check_embedding_valid(cls, credentials: dict, model_name: str):
embedding_model = HuggingfaceHubEmbeddings(
model=model_name,
**credentials
)
embedding_model.embed_query("ping")
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:

View File

@@ -0,0 +1,68 @@
from typing import Any, Dict, List, Optional, Union
import json
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from huggingface_hub import InferenceClient
HOSTED_INFERENCE_API = 'hosted_inference_api'
INFERENCE_ENDPOINTS = 'inference_endpoints'
class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
client: Any
model: str
task_type: Optional[str] = None
huggingfacehub_api_type: Optional[str] = None
huggingfacehub_api_token: Optional[str] = None
huggingfacehub_endpoint_url: Optional[str] = None
class Config:
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
values['huggingfacehub_api_token'] = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
values['client'] = InferenceClient(values['huggingfacehub_api_token'])
return values
def embeddings(self, inputs: Union[str, List[str]]) -> str:
model = ''
if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
model = self.model
else:
model = self.huggingfacehub_endpoint_url
output = self.client.post(
json={
"inputs": inputs,
"options": {
"wait_for_model": False
}
}, model=model)
return json.loads(output.decode())
def embed_documents(self, texts: List[str]) -> List[List[float]]:
output = self.embeddings(texts)
if isinstance(output, list):
return output
return [list(map(float, e)) for e in output]
def embed_query(self, text: str) -> List[float]:
output = self.embeddings(text)
if isinstance(output, list):
return output
return list(map(float, output))

View File

@@ -51,4 +51,4 @@ stripe~=5.5.0
pandas==1.5.3
xinference==0.4.2
safetensors==0.3.2
zhipuai==1.0.7
zhipuai==1.0.7

View File

@@ -0,0 +1,101 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
from models.provider import Provider, ProviderType, ProviderModel
DEFAULT_MODEL_NAME = 'obrizum/all-MiniLM-L6-v2'
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='huggingface_hub',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_embedding_model(model_name, huggingfacehub_api_type, mocker):
valid_api_key = os.environ['HUGGINGFACE_API_KEY']
endpoint_url = os.environ['HUGGINGFACE_ENDPOINT_URL']
model_provider = HuggingfaceHubProvider(provider=get_mock_provider())
credentials = {
'huggingfacehub_api_type': huggingfacehub_api_type,
'huggingfacehub_api_token': valid_api_key,
'task_type': 'feature-extraction'
}
if huggingfacehub_api_type == 'inference_endpoints':
credentials['huggingfacehub_endpoint_url'] = endpoint_url
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='huggingface_hub',
model_name=model_name,
model_type=ModelType.EMBEDDINGS.value,
encrypted_config=json.dumps(credentials),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query',
return_value=mock_query)
return HuggingfaceEmbedding(
model_provider=model_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_embed_documents(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(
DEFAULT_MODEL_NAME,
'hosted_inference_api',
mocker)
rst = embedding_model.client.embed_documents(['test', 'test1'])
assert isinstance(rst, list)
assert len(rst) == 2
assert len(rst[0]) == 384
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_embed_query(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(
DEFAULT_MODEL_NAME,
'hosted_inference_api',
mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 384
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_documents(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(
'',
'inference_endpoints',
mocker)
rst = embedding_model.client.embed_documents(['test', 'test1'])
assert isinstance(rst, list)
assert len(rst) == 2
assert len(rst[0]) == 384
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_query(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(
'',
'inference_endpoints',
mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 384

View File

@@ -68,14 +68,25 @@ const config: ProviderConfig = {
]
}
if (v?.huggingfacehub_api_type === 'inference_endpoints') {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
'model_type',
]
if (v?.model_type === 'embeddings') {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'model_type',
]
}
else {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
'model_type',
]
}
}
return filteredKeys.reduce((prev: FormValue, next: string) => {
prev[next] = v?.[next] || ''
@@ -83,6 +94,31 @@ const config: ProviderConfig = {
}, {})
},
fields: [
{
type: 'radio',
key: 'model_type',
required: true,
label: {
'en': 'Model Type',
'zh-Hans': '模型类型',
},
options: [
{
key: 'text-generation',
label: {
'en': 'Text Generation',
'zh-Hans': '文本生成',
},
},
{
key: 'embeddings',
label: {
'en': 'Embeddings',
'zh-Hans': 'Embeddings',
},
},
],
},
{
type: 'radio',
key: 'huggingfacehub_api_type',
@@ -148,7 +184,7 @@ const config: ProviderConfig = {
},
},
{
hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api',
hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api' || value?.model_type === 'embeddings',
type: 'radio',
key: 'task_type',
required: true,