Compare commits

...

2 Commits

Author SHA1 Message Date
John Wang
dd76c9aba0 fix: xinference embedding not normalized 2023-08-22 17:42:55 +08:00
John Wang
dad8d0e7be fix: xinference embd init err 2023-08-22 16:34:42 +08:00
2 changed files with 24 additions and 2 deletions

View File

@@ -1,4 +1,4 @@
from langchain.embeddings import XinferenceEmbeddings
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
@@ -14,7 +14,8 @@ class XinferenceEmbedding(BaseEmbedding):
)
client = XinferenceEmbeddings(
**credentials,
server_url=credentials['server_url'],
model_uid=credentials['model_uid'],
)
super().__init__(model_provider, client, name)

View File

@@ -0,0 +1,21 @@
from typing import List
import numpy as np
from langchain.embeddings import XinferenceEmbeddings
class XinferenceEmbedding(XinferenceEmbeddings):
def embed_documents(self, texts: List[str]) -> List[List[float]]:
vectors = super().embed_documents(texts)
normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
return normalized_vectors
def embed_query(self, text: str) -> List[float]:
vector = super().embed_query(text)
normalized_vector = (vector / np.linalg.norm(vector)).tolist()
return normalized_vector