Refactor function name (#11210)

### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2025-11-12 19:00:15 +08:00
committed by GitHub
parent a36a0fe71c
commit 296476ab89
20 changed files with 105 additions and 103 deletions

View File

@@ -807,7 +807,7 @@ def check_embedding():
offset=0, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
)
total = docStoreConn.getTotal(res0)
total = docStoreConn.get_total(res0)
if total <= 0:
return []
@@ -824,7 +824,7 @@ def check_embedding():
offset=off, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
)
ids = docStoreConn.getChunkIds(res1)
ids = docStoreConn.get_chunk_ids(res1)
if not ids:
continue

View File

@@ -309,7 +309,7 @@ class DocumentService(CommonService):
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
@@ -322,7 +322,7 @@ class DocumentService(CommonService):
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
graph_source = settings.docStoreConn.getFields(
graph_source = settings.docStoreConn.get_fields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:

View File

@@ -69,7 +69,7 @@ class KGSearch(Dealer):
def _ent_info_from_(self, es_res, sim_thr=0.3):
res = {}
flds = ["content_with_weight", "_score", "entity_kwd", "rank_flt", "n_hop_with_weight"]
es_res = self.dataStore.getFields(es_res, flds)
es_res = self.dataStore.get_fields(es_res, flds)
for _, ent in es_res.items():
for f in flds:
if f in ent and ent[f] is None:
@@ -88,7 +88,7 @@ class KGSearch(Dealer):
def _relation_info_from_(self, es_res, sim_thr=0.3):
res = {}
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
es_res = self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
"weight_int"])
for _, ent in es_res.items():
if get_float(ent["_score"]) < sim_thr:
@@ -300,7 +300,7 @@ class KGSearch(Dealer):
fltr["entities_kwd"] = entities
comm_res = self.dataStore.search(fields, [], fltr, [],
OrderByExpr(), 0, topn, idxnms, kb_ids)
comm_res_fields = self.dataStore.getFields(comm_res, fields)
comm_res_fields = self.dataStore.get_fields(comm_res, fields)
txts = []
for ii, (_, row) in enumerate(comm_res_fields.items()):
obj = json.loads(row["content_with_weight"])

View File

@@ -382,7 +382,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
"removed_kwd": "N",
}
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = settings.docStoreConn.getFields(res, fields)
fields2 = settings.docStoreConn.get_fields(res, fields)
graph_doc_ids = set()
for chunk_id in fields2.keys():
graph_doc_ids = set(fields2[chunk_id]["source_id"])
@@ -591,8 +591,8 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
es_res = await trio.to_thread.run_sync(
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
)
# tot = settings.docStoreConn.getTotal(es_res)
es_res = settings.docStoreConn.getFields(es_res, flds)
# tot = settings.docStoreConn.get_total(es_res)
es_res = settings.docStoreConn.get_fields(es_res, flds)
if len(es_res) == 0:
break

View File

@@ -38,11 +38,11 @@ class FulltextQueryer:
]
@staticmethod
def subSpecialChar(line):
def sub_special_char(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def isChinese(line):
def is_chinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
@@ -92,7 +92,7 @@ class FulltextQueryer:
otxt = txt
txt = FulltextQueryer.rmWWW(txt)
if not self.isChinese(txt):
if not self.is_chinese(txt):
txt = FulltextQueryer.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split()
keywords = [t for t in tks if t]
@@ -163,7 +163,7 @@ class FulltextQueryer:
)
for m in sm
]
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(keywords) < 32:
@@ -171,7 +171,7 @@ class FulltextQueryer:
keywords.extend(sm)
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
@@ -180,7 +180,7 @@ class FulltextQueryer:
if len(keywords) >= 32:
break
tk = FulltextQueryer.subSpecialChar(tk)
tk = FulltextQueryer.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
@@ -198,7 +198,7 @@ class FulltextQueryer:
syns = " OR ".join(
[
'"%s"'
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
% rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s))
for s in syns
]
)
@@ -217,17 +217,17 @@ class FulltextQueryer:
return None, keywords
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
sims = CosineSimilarity([avec], bvecs)
sims = cosine_similarity([avec], bvecs)
tksim = self.token_similarity(atks, btkss)
if np.sum(sims[0]) == 0:
return np.array(tksim), tksim, sims[0]
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
def token_similarity(self, atks, btkss):
def toDict(tks):
def to_dict(tks):
if isinstance(tks, str):
tks = tks.split()
d = defaultdict(int)
@@ -236,8 +236,8 @@ class FulltextQueryer:
d[t] += c
return d
atks = toDict(atks)
btkss = [toDict(tks) for tks in btkss]
atks = to_dict(atks)
btkss = [to_dict(tks) for tks in btkss]
return [self.similarity(atks, btks) for btks in btkss]
def similarity(self, qtwt, dtwt):
@@ -262,10 +262,10 @@ class FulltextQueryer:
keywords = [f'"{k.strip()}"' for k in keywords]
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
tk = FulltextQueryer.subSpecialChar(tk)
tk = FulltextQueryer.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:

View File

@@ -35,7 +35,7 @@ class RagTokenizer:
def rkey_(self, line):
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
def loadDict_(self, fnm):
def _load_dict(self, fnm):
logging.info(f"[HUQIE]:Build trie from {fnm}")
try:
of = open(fnm, "r", encoding='utf-8')
@@ -85,18 +85,18 @@ class RagTokenizer:
self.trie_ = datrie.Trie(string.printable)
# load data from dict file and save to trie file
self.loadDict_(self.DIR_ + ".txt")
self._load_dict(self.DIR_ + ".txt")
def loadUserDict(self, fnm):
def load_user_dict(self, fnm):
try:
self.trie_ = datrie.Trie.load(fnm + ".trie")
return
except Exception:
self.trie_ = datrie.Trie(string.printable)
self.loadDict_(fnm)
self._load_dict(fnm)
def addUserDict(self, fnm):
self.loadDict_(fnm)
def add_user_dict(self, fnm):
self._load_dict(fnm)
def _strQ2B(self, ustring):
"""Convert full-width characters to half-width characters"""
@@ -221,7 +221,7 @@ class RagTokenizer:
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
return tks, B / len(tks) + L + F
def sortTks_(self, tkslist):
def _sort_tokens(self, tkslist):
res = []
for tfts in tkslist:
tks, s = self.score_(tfts)
@@ -246,7 +246,7 @@ class RagTokenizer:
return " ".join(res)
def maxForward_(self, line):
def _max_forward(self, line):
res = []
s = 0
while s < len(line):
@@ -270,7 +270,7 @@ class RagTokenizer:
return self.score_(res)
def maxBackward_(self, line):
def _max_backward(self, line):
res = []
s = len(line) - 1
while s >= 0:
@@ -336,8 +336,8 @@ class RagTokenizer:
continue
# use maxforward for the first time
tks, s = self.maxForward_(L)
tks1, s1 = self.maxBackward_(L)
tks, s = self._max_forward(L)
tks1, s1 = self._max_backward(L)
if self.DEBUG:
logging.debug("[FW] {} {}".format(tks, s))
logging.debug("[BW] {} {}".format(tks1, s1))
@@ -369,7 +369,7 @@ class RagTokenizer:
# backward tokens from_i to i are different from forward tokens from _j to j.
tkslist = []
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
same = 1
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
@@ -385,7 +385,7 @@ class RagTokenizer:
assert "".join(tks1[_i:]) == "".join(tks[_j:])
tkslist = []
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
res = " ".join(res)
logging.debug("[TKS] {}".format(self.merge_(res)))
@@ -413,7 +413,7 @@ class RagTokenizer:
if len(tkslist) < 2:
res.append(tk)
continue
stk = self.sortTks_(tkslist)[1][0]
stk = self._sort_tokens(tkslist)[1][0]
if len(stk) == len(tk):
stk = tk
else:
@@ -447,14 +447,13 @@ def is_number(s):
def is_alphabet(s):
if (s >= u'\u0041' and s <= u'\u005a') or (
s >= u'\u0061' and s <= u'\u007a'):
if (u'\u0041' <= s <= u'\u005a') or (u'\u0061' <= s <= u'\u007a'):
return True
else:
return False
def naiveQie(txt):
def naive_qie(txt):
tks = []
for t in txt.split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
@@ -469,14 +468,14 @@ tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
load_user_dict = tokenizer.load_user_dict
add_user_dict = tokenizer.add_user_dict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B
if __name__ == '__main__':
tknzr = RagTokenizer(debug=True)
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
# huqie.add_user_dict("/tmp/tmp.new.tks.dict")
tks = tknzr.tokenize(
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
logging.info(tknzr.fine_grained_tokenize(tks))
@@ -506,7 +505,7 @@ if __name__ == '__main__':
if len(sys.argv) < 2:
sys.exit()
tknzr.DEBUG = False
tknzr.loadUserDict(sys.argv[1])
tknzr.load_user_dict(sys.argv[1])
of = open(sys.argv[2], "r")
while True:
line = of.readline()

View File

@@ -102,7 +102,7 @@ class Dealer:
orderBy.asc("top_int")
orderBy.desc("create_timestamp_flt")
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
highlightFields = ["content_ltks", "title_tks"]
@@ -115,7 +115,7 @@ class Dealer:
matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
@@ -127,20 +127,20 @@ class Dealer:
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
# If result is empty, try again with lower min_match
if total == 0:
if filters.get("doc_id"):
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
else:
matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
for k in keywords:
@@ -153,17 +153,17 @@ class Dealer:
kwds.add(kk)
logging.debug(f"TOTAL: {total}")
ids = self.dataStore.getChunkIds(res)
ids = self.dataStore.get_chunk_ids(res)
keywords = list(kwds)
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight")
aggs = self.dataStore.get_aggregation(res, "docnm_kwd")
return self.SearchResult(
total=total,
ids=ids,
query_vector=q_vec,
aggregation=aggs,
highlight=highlight,
field=self.dataStore.getFields(res, src + ["_score"]),
field=self.dataStore.get_fields(res, src + ["_score"]),
keywords=keywords
)
@@ -488,7 +488,7 @@ class Dealer:
for p in range(offset, max_count, bs):
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
kb_ids)
dict_chunks = self.dataStore.getFields(es_res, fields)
dict_chunks = self.dataStore.get_fields(es_res, fields)
for id, doc in dict_chunks.items():
doc["id"] = id
if dict_chunks:
@@ -501,11 +501,11 @@ class Dealer:
if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
return []
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
return self.dataStore.getAggregation(res, "tag_kwd")
return self.dataStore.get_aggregation(res, "tag_kwd")
def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
res = self.dataStore.getAggregation(res, "tag_kwd")
res = self.dataStore.get_aggregation(res, "tag_kwd")
total = np.sum([c for _, c in res])
return {t: (c + 1) / (total + S) for t, c in res}
@@ -513,7 +513,7 @@ class Dealer:
idx_nm = index_name(tenant_id)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
if not aggs:
return False
cnt = np.sum([c for _, c in aggs])
@@ -529,7 +529,7 @@ class Dealer:
idx_nms = [index_name(tid) for tid in tenant_ids]
match_txt, _ = self.qryr.question(question, min_match=0.0)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
if not aggs:
return {}
cnt = np.sum([c for _, c in aggs])
@@ -552,7 +552,7 @@ class Dealer:
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
kb_ids)
toc = []
dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])
for _, doc in dict_chunks.items():
try:
toc.extend(json.loads(doc["content_with_weight"]))

View File

@@ -113,20 +113,20 @@ class Dealer:
res.append(tk)
return res
def tokenMerge(self, tks):
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
def token_merge(self, tks):
def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
res, i = [], 0
while i < len(tks):
j = i
if i == 0 and oneTerm(tks[i]) and len(
if i == 0 and one_term(tks[i]) and len(
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
res.append(" ".join(tks[0:2]))
i = 2
continue
while j < len(
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
tks) and tks[j] and tks[j] not in self.stop_words and one_term(tks[j]):
j += 1
if j - i > 1:
if j - i < 5:
@@ -232,7 +232,7 @@ class Dealer:
tw = list(zip(tks, wts))
else:
for tk in tks:
tt = self.tokenMerge(self.pretoken(tk, True))
tt = self.token_merge(self.pretoken(tk, True))
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
wts = (0.3 * idf1 + 0.7 * idf2) * \

View File

@@ -28,7 +28,7 @@ def collect():
logging.debug(doc_locations)
if len(doc_locations) == 0:
time.sleep(1)
return
return None
return doc_locations

View File

@@ -359,7 +359,7 @@ async def build_chunks(task, progress_callback):
task_canceled = has_canceled(task["id"])
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
return None
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else:
@@ -417,6 +417,7 @@ def build_TOC(task, docs, progress_callback):
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
def init_kb(row, vector_size: int):
@@ -719,7 +720,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
return False
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
@@ -737,7 +738,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
return
return False
return True

View File

@@ -67,6 +67,8 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"Fail put {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return None
return None
def rm(self, bucket, fnm):
try:
@@ -84,7 +86,7 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None
def obj_exist(self, bucket, fnm):
try:
@@ -102,4 +104,4 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None

View File

@@ -241,23 +241,23 @@ class DocStoreConnection(ABC):
"""
@abstractmethod
def getTotal(self, res):
def get_total(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getChunkIds(self, res):
def get_chunk_ids(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
@abstractmethod
def getHighlight(self, res, keywords: list[str], fieldnm: str):
def get_highlight(self, res, keywords: list[str], fieldnm: str):
raise NotImplementedError("Not implemented")
@abstractmethod
def getAggregation(self, res, fieldnm: str):
def get_aggregation(self, res, fieldnm: str):
raise NotImplementedError("Not implemented")
"""

View File

@@ -471,12 +471,12 @@ class ESConnection(DocStoreConnection):
Helper functions for search result
"""
def getTotal(self, res):
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def getChunkIds(self, res):
def get_chunk_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):
@@ -487,7 +487,7 @@ class ESConnection(DocStoreConnection):
rr.append(d["_source"])
return rr
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
@@ -509,7 +509,7 @@ class ESConnection(DocStoreConnection):
res_fields[d["id"]] = m
return res_fields
def getHighlight(self, res, keywords: list[str], fieldnm: str):
def get_highlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
@@ -534,7 +534,7 @@ class ESConnection(DocStoreConnection):
return ans
def getAggregation(self, res, fieldnm: str):
def get_aggregation(self, res, fieldnm: str):
agg_field = "aggs_" + fieldnm
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()

View File

@@ -470,7 +470,7 @@ class InfinityConnection(DocStoreConnection):
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, ["id"])
res_fields = self.getFields(res, res.columns.tolist())
res_fields = self.get_fields(res, res.columns.tolist())
return res_fields.get(chunkId, None)
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
@@ -599,7 +599,7 @@ class InfinityConnection(DocStoreConnection):
col_to_remove = list(removeValue.keys())
row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
row_to_opt = self.getFields(row_to_opt, col_to_remove)
row_to_opt = self.get_fields(row_to_opt, col_to_remove)
for id, old_v in row_to_opt.items():
for k, remove_v in removeValue.items():
if remove_v in old_v[k]:
@@ -639,17 +639,17 @@ class InfinityConnection(DocStoreConnection):
Helper functions for search result
"""
def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple):
res = res[0]
if not fields:
@@ -690,7 +690,7 @@ class InfinityConnection(DocStoreConnection):
return res2.set_index("id").to_dict(orient="index")
def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
if isinstance(res, tuple):
res = res[0]
ans = {}
@@ -732,7 +732,7 @@ class InfinityConnection(DocStoreConnection):
ans[id] = txt
return ans
def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
"""
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
"""

View File

@@ -92,7 +92,7 @@ class RAGFlowMinio:
logging.exception(f"Fail to get {bucket}/{filename}")
self.__open__()
time.sleep(1)
return
return None
def obj_exist(self, bucket, filename, tenant_id=None):
try:
@@ -130,7 +130,7 @@ class RAGFlowMinio:
logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
self.__open__()
time.sleep(1)
return
return None
def remove_bucket(self, bucket):
try:

View File

@@ -62,8 +62,7 @@ class OpenDALStorage:
def health(self):
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
r = self._operator.write(f"{bucket}/{fnm}", binary)
return r
return self._operator.write(f"{bucket}/{fnm}", binary)
def put(self, bucket, fnm, binary, tenant_id=None):
self._operator.write(f"{bucket}/{fnm}", binary)

View File

@@ -455,12 +455,12 @@ class OSConnection(DocStoreConnection):
Helper functions for search result
"""
def getTotal(self, res):
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def getChunkIds(self, res):
def get_chunk_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):
@@ -471,7 +471,7 @@ class OSConnection(DocStoreConnection):
rr.append(d["_source"])
return rr
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
@@ -490,7 +490,7 @@ class OSConnection(DocStoreConnection):
res_fields[d["id"]] = m
return res_fields
def getHighlight(self, res, keywords: list[str], fieldnm: str):
def get_highlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
@@ -515,7 +515,7 @@ class OSConnection(DocStoreConnection):
return ans
def getAggregation(self, res, fieldnm: str):
def get_aggregation(self, res, fieldnm: str):
agg_field = "aggs_" + fieldnm
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()

View File

@@ -141,7 +141,7 @@ class RAGFlowOSS:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None
@use_prefix_path
@use_default_bucket
@@ -170,5 +170,5 @@ class RAGFlowOSS:
logging.exception(f"fail get url {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None

View File

@@ -104,6 +104,7 @@ class RedisDB:
if self.REDIS.get(a) == b:
return True
return False
def info(self):
info = self.REDIS.info()
@@ -124,7 +125,7 @@ class RedisDB:
def exist(self, k):
if not self.REDIS:
return
return None
try:
return self.REDIS.exists(k)
except Exception as e:
@@ -133,7 +134,7 @@ class RedisDB:
def get(self, k):
if not self.REDIS:
return
return None
try:
return self.REDIS.get(k)
except Exception as e:

View File

@@ -164,7 +164,7 @@ class RAGFlowS3:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None
@use_prefix_path
@use_default_bucket
@@ -193,7 +193,7 @@ class RAGFlowS3:
logging.exception(f"fail get url {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None
@use_default_bucket
def rm_bucket(self, bucket, *args, **kwargs):