mirror of
https://gitee.com/infiniflow/ragflow.git
synced 2025-12-06 07:19:03 +08:00
Refactor function name (#11210)
### What problem does this PR solve? As title ### Type of change - [x] Refactoring --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -807,7 +807,7 @@ def check_embedding():
|
||||
offset=0, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
total = docStoreConn.getTotal(res0)
|
||||
total = docStoreConn.get_total(res0)
|
||||
if total <= 0:
|
||||
return []
|
||||
|
||||
@@ -824,7 +824,7 @@ def check_embedding():
|
||||
offset=off, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
ids = docStoreConn.getChunkIds(res1)
|
||||
ids = docStoreConn.get_chunk_ids(res1)
|
||||
if not ids:
|
||||
continue
|
||||
|
||||
|
||||
@@ -309,7 +309,7 @@ class DocumentService(CommonService):
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
page * page_size, page_size, search.index_name(tenant_id),
|
||||
[doc.kb_id])
|
||||
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
|
||||
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
|
||||
if not chunk_ids:
|
||||
break
|
||||
all_chunk_ids.extend(chunk_ids)
|
||||
@@ -322,7 +322,7 @@ class DocumentService(CommonService):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
graph_source = settings.docStoreConn.getFields(
|
||||
graph_source = settings.docStoreConn.get_fields(
|
||||
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
|
||||
)
|
||||
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
|
||||
|
||||
@@ -69,7 +69,7 @@ class KGSearch(Dealer):
|
||||
def _ent_info_from_(self, es_res, sim_thr=0.3):
|
||||
res = {}
|
||||
flds = ["content_with_weight", "_score", "entity_kwd", "rank_flt", "n_hop_with_weight"]
|
||||
es_res = self.dataStore.getFields(es_res, flds)
|
||||
es_res = self.dataStore.get_fields(es_res, flds)
|
||||
for _, ent in es_res.items():
|
||||
for f in flds:
|
||||
if f in ent and ent[f] is None:
|
||||
@@ -88,7 +88,7 @@ class KGSearch(Dealer):
|
||||
|
||||
def _relation_info_from_(self, es_res, sim_thr=0.3):
|
||||
res = {}
|
||||
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
|
||||
es_res = self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
|
||||
"weight_int"])
|
||||
for _, ent in es_res.items():
|
||||
if get_float(ent["_score"]) < sim_thr:
|
||||
@@ -300,7 +300,7 @@ class KGSearch(Dealer):
|
||||
fltr["entities_kwd"] = entities
|
||||
comm_res = self.dataStore.search(fields, [], fltr, [],
|
||||
OrderByExpr(), 0, topn, idxnms, kb_ids)
|
||||
comm_res_fields = self.dataStore.getFields(comm_res, fields)
|
||||
comm_res_fields = self.dataStore.get_fields(comm_res, fields)
|
||||
txts = []
|
||||
for ii, (_, row) in enumerate(comm_res_fields.items()):
|
||||
obj = json.loads(row["content_with_weight"])
|
||||
|
||||
@@ -382,7 +382,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
|
||||
fields2 = settings.docStoreConn.getFields(res, fields)
|
||||
fields2 = settings.docStoreConn.get_fields(res, fields)
|
||||
graph_doc_ids = set()
|
||||
for chunk_id in fields2.keys():
|
||||
graph_doc_ids = set(fields2[chunk_id]["source_id"])
|
||||
@@ -591,8 +591,8 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
|
||||
es_res = await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
|
||||
)
|
||||
# tot = settings.docStoreConn.getTotal(es_res)
|
||||
es_res = settings.docStoreConn.getFields(es_res, flds)
|
||||
# tot = settings.docStoreConn.get_total(es_res)
|
||||
es_res = settings.docStoreConn.get_fields(es_res, flds)
|
||||
|
||||
if len(es_res) == 0:
|
||||
break
|
||||
|
||||
@@ -38,11 +38,11 @@ class FulltextQueryer:
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def subSpecialChar(line):
|
||||
def sub_special_char(line):
|
||||
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
|
||||
|
||||
@staticmethod
|
||||
def isChinese(line):
|
||||
def is_chinese(line):
|
||||
arr = re.split(r"[ \t]+", line)
|
||||
if len(arr) <= 3:
|
||||
return True
|
||||
@@ -92,7 +92,7 @@ class FulltextQueryer:
|
||||
otxt = txt
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
|
||||
if not self.isChinese(txt):
|
||||
if not self.is_chinese(txt):
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
tks = rag_tokenizer.tokenize(txt).split()
|
||||
keywords = [t for t in tks if t]
|
||||
@@ -163,7 +163,7 @@ class FulltextQueryer:
|
||||
)
|
||||
for m in sm
|
||||
]
|
||||
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
||||
sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
|
||||
if len(keywords) < 32:
|
||||
@@ -171,7 +171,7 @@ class FulltextQueryer:
|
||||
keywords.extend(sm)
|
||||
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
|
||||
if len(keywords) < 32:
|
||||
keywords.extend([s for s in tk_syns if s])
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
@@ -180,7 +180,7 @@ class FulltextQueryer:
|
||||
if len(keywords) >= 32:
|
||||
break
|
||||
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
tk = FulltextQueryer.sub_special_char(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
@@ -198,7 +198,7 @@ class FulltextQueryer:
|
||||
syns = " OR ".join(
|
||||
[
|
||||
'"%s"'
|
||||
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
|
||||
% rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s))
|
||||
for s in syns
|
||||
]
|
||||
)
|
||||
@@ -217,17 +217,17 @@ class FulltextQueryer:
|
||||
return None, keywords
|
||||
|
||||
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
|
||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import numpy as np
|
||||
|
||||
sims = CosineSimilarity([avec], bvecs)
|
||||
sims = cosine_similarity([avec], bvecs)
|
||||
tksim = self.token_similarity(atks, btkss)
|
||||
if np.sum(sims[0]) == 0:
|
||||
return np.array(tksim), tksim, sims[0]
|
||||
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
|
||||
|
||||
def token_similarity(self, atks, btkss):
|
||||
def toDict(tks):
|
||||
def to_dict(tks):
|
||||
if isinstance(tks, str):
|
||||
tks = tks.split()
|
||||
d = defaultdict(int)
|
||||
@@ -236,8 +236,8 @@ class FulltextQueryer:
|
||||
d[t] += c
|
||||
return d
|
||||
|
||||
atks = toDict(atks)
|
||||
btkss = [toDict(tks) for tks in btkss]
|
||||
atks = to_dict(atks)
|
||||
btkss = [to_dict(tks) for tks in btkss]
|
||||
return [self.similarity(atks, btks) for btks in btkss]
|
||||
|
||||
def similarity(self, qtwt, dtwt):
|
||||
@@ -262,10 +262,10 @@ class FulltextQueryer:
|
||||
keywords = [f'"{k.strip()}"' for k in keywords]
|
||||
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
tk = FulltextQueryer.sub_special_char(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
|
||||
@@ -35,7 +35,7 @@ class RagTokenizer:
|
||||
def rkey_(self, line):
|
||||
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
|
||||
|
||||
def loadDict_(self, fnm):
|
||||
def _load_dict(self, fnm):
|
||||
logging.info(f"[HUQIE]:Build trie from {fnm}")
|
||||
try:
|
||||
of = open(fnm, "r", encoding='utf-8')
|
||||
@@ -85,18 +85,18 @@ class RagTokenizer:
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
|
||||
# load data from dict file and save to trie file
|
||||
self.loadDict_(self.DIR_ + ".txt")
|
||||
self._load_dict(self.DIR_ + ".txt")
|
||||
|
||||
def loadUserDict(self, fnm):
|
||||
def load_user_dict(self, fnm):
|
||||
try:
|
||||
self.trie_ = datrie.Trie.load(fnm + ".trie")
|
||||
return
|
||||
except Exception:
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
self.loadDict_(fnm)
|
||||
self._load_dict(fnm)
|
||||
|
||||
def addUserDict(self, fnm):
|
||||
self.loadDict_(fnm)
|
||||
def add_user_dict(self, fnm):
|
||||
self._load_dict(fnm)
|
||||
|
||||
def _strQ2B(self, ustring):
|
||||
"""Convert full-width characters to half-width characters"""
|
||||
@@ -221,7 +221,7 @@ class RagTokenizer:
|
||||
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
|
||||
return tks, B / len(tks) + L + F
|
||||
|
||||
def sortTks_(self, tkslist):
|
||||
def _sort_tokens(self, tkslist):
|
||||
res = []
|
||||
for tfts in tkslist:
|
||||
tks, s = self.score_(tfts)
|
||||
@@ -246,7 +246,7 @@ class RagTokenizer:
|
||||
|
||||
return " ".join(res)
|
||||
|
||||
def maxForward_(self, line):
|
||||
def _max_forward(self, line):
|
||||
res = []
|
||||
s = 0
|
||||
while s < len(line):
|
||||
@@ -270,7 +270,7 @@ class RagTokenizer:
|
||||
|
||||
return self.score_(res)
|
||||
|
||||
def maxBackward_(self, line):
|
||||
def _max_backward(self, line):
|
||||
res = []
|
||||
s = len(line) - 1
|
||||
while s >= 0:
|
||||
@@ -336,8 +336,8 @@ class RagTokenizer:
|
||||
continue
|
||||
|
||||
# use maxforward for the first time
|
||||
tks, s = self.maxForward_(L)
|
||||
tks1, s1 = self.maxBackward_(L)
|
||||
tks, s = self._max_forward(L)
|
||||
tks1, s1 = self._max_backward(L)
|
||||
if self.DEBUG:
|
||||
logging.debug("[FW] {} {}".format(tks, s))
|
||||
logging.debug("[BW] {} {}".format(tks1, s1))
|
||||
@@ -369,7 +369,7 @@ class RagTokenizer:
|
||||
# backward tokens from_i to i are different from forward tokens from _j to j.
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
|
||||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||||
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
|
||||
|
||||
same = 1
|
||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
@@ -385,7 +385,7 @@ class RagTokenizer:
|
||||
assert "".join(tks1[_i:]) == "".join(tks[_j:])
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
|
||||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||||
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
|
||||
|
||||
res = " ".join(res)
|
||||
logging.debug("[TKS] {}".format(self.merge_(res)))
|
||||
@@ -413,7 +413,7 @@ class RagTokenizer:
|
||||
if len(tkslist) < 2:
|
||||
res.append(tk)
|
||||
continue
|
||||
stk = self.sortTks_(tkslist)[1][0]
|
||||
stk = self._sort_tokens(tkslist)[1][0]
|
||||
if len(stk) == len(tk):
|
||||
stk = tk
|
||||
else:
|
||||
@@ -447,14 +447,13 @@ def is_number(s):
|
||||
|
||||
|
||||
def is_alphabet(s):
|
||||
if (s >= u'\u0041' and s <= u'\u005a') or (
|
||||
s >= u'\u0061' and s <= u'\u007a'):
|
||||
if (u'\u0041' <= s <= u'\u005a') or (u'\u0061' <= s <= u'\u007a'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def naiveQie(txt):
|
||||
def naive_qie(txt):
|
||||
tks = []
|
||||
for t in txt.split():
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
|
||||
@@ -469,14 +468,14 @@ tokenize = tokenizer.tokenize
|
||||
fine_grained_tokenize = tokenizer.fine_grained_tokenize
|
||||
tag = tokenizer.tag
|
||||
freq = tokenizer.freq
|
||||
loadUserDict = tokenizer.loadUserDict
|
||||
addUserDict = tokenizer.addUserDict
|
||||
load_user_dict = tokenizer.load_user_dict
|
||||
add_user_dict = tokenizer.add_user_dict
|
||||
tradi2simp = tokenizer._tradi2simp
|
||||
strQ2B = tokenizer._strQ2B
|
||||
|
||||
if __name__ == '__main__':
|
||||
tknzr = RagTokenizer(debug=True)
|
||||
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
|
||||
# huqie.add_user_dict("/tmp/tmp.new.tks.dict")
|
||||
tks = tknzr.tokenize(
|
||||
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
@@ -506,7 +505,7 @@ if __name__ == '__main__':
|
||||
if len(sys.argv) < 2:
|
||||
sys.exit()
|
||||
tknzr.DEBUG = False
|
||||
tknzr.loadUserDict(sys.argv[1])
|
||||
tknzr.load_user_dict(sys.argv[1])
|
||||
of = open(sys.argv[2], "r")
|
||||
while True:
|
||||
line = of.readline()
|
||||
|
||||
@@ -102,7 +102,7 @@ class Dealer:
|
||||
orderBy.asc("top_int")
|
||||
orderBy.desc("create_timestamp_flt")
|
||||
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
||||
total = self.dataStore.getTotal(res)
|
||||
total = self.dataStore.get_total(res)
|
||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||
else:
|
||||
highlightFields = ["content_ltks", "title_tks"]
|
||||
@@ -115,7 +115,7 @@ class Dealer:
|
||||
matchExprs = [matchText]
|
||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||
idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.getTotal(res)
|
||||
total = self.dataStore.get_total(res)
|
||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||
else:
|
||||
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
|
||||
@@ -127,20 +127,20 @@ class Dealer:
|
||||
|
||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||
idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.getTotal(res)
|
||||
total = self.dataStore.get_total(res)
|
||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||
|
||||
# If result is empty, try again with lower min_match
|
||||
if total == 0:
|
||||
if filters.get("doc_id"):
|
||||
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
||||
total = self.dataStore.getTotal(res)
|
||||
total = self.dataStore.get_total(res)
|
||||
else:
|
||||
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
||||
matchDense.extra_options["similarity"] = 0.17
|
||||
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
||||
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.getTotal(res)
|
||||
total = self.dataStore.get_total(res)
|
||||
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
|
||||
|
||||
for k in keywords:
|
||||
@@ -153,17 +153,17 @@ class Dealer:
|
||||
kwds.add(kk)
|
||||
|
||||
logging.debug(f"TOTAL: {total}")
|
||||
ids = self.dataStore.getChunkIds(res)
|
||||
ids = self.dataStore.get_chunk_ids(res)
|
||||
keywords = list(kwds)
|
||||
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
|
||||
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
|
||||
highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight")
|
||||
aggs = self.dataStore.get_aggregation(res, "docnm_kwd")
|
||||
return self.SearchResult(
|
||||
total=total,
|
||||
ids=ids,
|
||||
query_vector=q_vec,
|
||||
aggregation=aggs,
|
||||
highlight=highlight,
|
||||
field=self.dataStore.getFields(res, src + ["_score"]),
|
||||
field=self.dataStore.get_fields(res, src + ["_score"]),
|
||||
keywords=keywords
|
||||
)
|
||||
|
||||
@@ -488,7 +488,7 @@ class Dealer:
|
||||
for p in range(offset, max_count, bs):
|
||||
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
|
||||
kb_ids)
|
||||
dict_chunks = self.dataStore.getFields(es_res, fields)
|
||||
dict_chunks = self.dataStore.get_fields(es_res, fields)
|
||||
for id, doc in dict_chunks.items():
|
||||
doc["id"] = id
|
||||
if dict_chunks:
|
||||
@@ -501,11 +501,11 @@ class Dealer:
|
||||
if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
|
||||
return []
|
||||
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
|
||||
return self.dataStore.getAggregation(res, "tag_kwd")
|
||||
return self.dataStore.get_aggregation(res, "tag_kwd")
|
||||
|
||||
def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
|
||||
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
|
||||
res = self.dataStore.getAggregation(res, "tag_kwd")
|
||||
res = self.dataStore.get_aggregation(res, "tag_kwd")
|
||||
total = np.sum([c for _, c in res])
|
||||
return {t: (c + 1) / (total + S) for t, c in res}
|
||||
|
||||
@@ -513,7 +513,7 @@ class Dealer:
|
||||
idx_nm = index_name(tenant_id)
|
||||
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
|
||||
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
|
||||
aggs = self.dataStore.getAggregation(res, "tag_kwd")
|
||||
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
|
||||
if not aggs:
|
||||
return False
|
||||
cnt = np.sum([c for _, c in aggs])
|
||||
@@ -529,7 +529,7 @@ class Dealer:
|
||||
idx_nms = [index_name(tid) for tid in tenant_ids]
|
||||
match_txt, _ = self.qryr.question(question, min_match=0.0)
|
||||
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
|
||||
aggs = self.dataStore.getAggregation(res, "tag_kwd")
|
||||
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
|
||||
if not aggs:
|
||||
return {}
|
||||
cnt = np.sum([c for _, c in aggs])
|
||||
@@ -552,7 +552,7 @@ class Dealer:
|
||||
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
|
||||
kb_ids)
|
||||
toc = []
|
||||
dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
|
||||
dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])
|
||||
for _, doc in dict_chunks.items():
|
||||
try:
|
||||
toc.extend(json.loads(doc["content_with_weight"]))
|
||||
|
||||
@@ -113,20 +113,20 @@ class Dealer:
|
||||
res.append(tk)
|
||||
return res
|
||||
|
||||
def tokenMerge(self, tks):
|
||||
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
|
||||
def token_merge(self, tks):
|
||||
def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
|
||||
|
||||
res, i = [], 0
|
||||
while i < len(tks):
|
||||
j = i
|
||||
if i == 0 and oneTerm(tks[i]) and len(
|
||||
if i == 0 and one_term(tks[i]) and len(
|
||||
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
|
||||
res.append(" ".join(tks[0:2]))
|
||||
i = 2
|
||||
continue
|
||||
|
||||
while j < len(
|
||||
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
|
||||
tks) and tks[j] and tks[j] not in self.stop_words and one_term(tks[j]):
|
||||
j += 1
|
||||
if j - i > 1:
|
||||
if j - i < 5:
|
||||
@@ -232,7 +232,7 @@ class Dealer:
|
||||
tw = list(zip(tks, wts))
|
||||
else:
|
||||
for tk in tks:
|
||||
tt = self.tokenMerge(self.pretoken(tk, True))
|
||||
tt = self.token_merge(self.pretoken(tk, True))
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
|
||||
@@ -28,7 +28,7 @@ def collect():
|
||||
logging.debug(doc_locations)
|
||||
if len(doc_locations) == 0:
|
||||
time.sleep(1)
|
||||
return
|
||||
return None
|
||||
return doc_locations
|
||||
|
||||
|
||||
|
||||
@@ -359,7 +359,7 @@ async def build_chunks(task, progress_callback):
|
||||
task_canceled = has_canceled(task["id"])
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
return None
|
||||
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
|
||||
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
|
||||
else:
|
||||
@@ -417,6 +417,7 @@ def build_TOC(task, docs, progress_callback):
|
||||
d["page_num_int"] = [100000000]
|
||||
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
|
||||
return d
|
||||
return None
|
||||
|
||||
|
||||
def init_kb(row, vector_size: int):
|
||||
@@ -719,7 +720,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
return False
|
||||
if b % 128 == 0:
|
||||
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
||||
if doc_store_result:
|
||||
@@ -737,7 +738,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
|
||||
for chunk_id in chunk_ids:
|
||||
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
|
||||
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
|
||||
return
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -67,6 +67,8 @@ class RAGFlowAzureSpnBlob:
|
||||
logging.exception(f"Fail put {bucket}/{fnm}")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return None
|
||||
return None
|
||||
|
||||
def rm(self, bucket, fnm):
|
||||
try:
|
||||
@@ -84,7 +86,7 @@ class RAGFlowAzureSpnBlob:
|
||||
logging.exception(f"fail get {bucket}/{fnm}")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
return None
|
||||
|
||||
def obj_exist(self, bucket, fnm):
|
||||
try:
|
||||
@@ -102,4 +104,4 @@ class RAGFlowAzureSpnBlob:
|
||||
logging.exception(f"fail get {bucket}/{fnm}")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
return None
|
||||
@@ -241,23 +241,23 @@ class DocStoreConnection(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def getTotal(self, res):
|
||||
def get_total(self, res):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getChunkIds(self, res):
|
||||
def get_chunk_ids(self, res):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getHighlight(self, res, keywords: list[str], fieldnm: str):
|
||||
def get_highlight(self, res, keywords: list[str], fieldnm: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getAggregation(self, res, fieldnm: str):
|
||||
def get_aggregation(self, res, fieldnm: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
|
||||
@@ -471,12 +471,12 @@ class ESConnection(DocStoreConnection):
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def getTotal(self, res):
|
||||
def get_total(self, res):
|
||||
if isinstance(res["hits"]["total"], type({})):
|
||||
return res["hits"]["total"]["value"]
|
||||
return res["hits"]["total"]
|
||||
|
||||
def getChunkIds(self, res):
|
||||
def get_chunk_ids(self, res):
|
||||
return [d["_id"] for d in res["hits"]["hits"]]
|
||||
|
||||
def __getSource(self, res):
|
||||
@@ -487,7 +487,7 @@ class ESConnection(DocStoreConnection):
|
||||
rr.append(d["_source"])
|
||||
return rr
|
||||
|
||||
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
res_fields = {}
|
||||
if not fields:
|
||||
return {}
|
||||
@@ -509,7 +509,7 @@ class ESConnection(DocStoreConnection):
|
||||
res_fields[d["id"]] = m
|
||||
return res_fields
|
||||
|
||||
def getHighlight(self, res, keywords: list[str], fieldnm: str):
|
||||
def get_highlight(self, res, keywords: list[str], fieldnm: str):
|
||||
ans = {}
|
||||
for d in res["hits"]["hits"]:
|
||||
hlts = d.get("highlight")
|
||||
@@ -534,7 +534,7 @@ class ESConnection(DocStoreConnection):
|
||||
|
||||
return ans
|
||||
|
||||
def getAggregation(self, res, fieldnm: str):
|
||||
def get_aggregation(self, res, fieldnm: str):
|
||||
agg_field = "aggs_" + fieldnm
|
||||
if "aggregations" not in res or agg_field not in res["aggregations"]:
|
||||
return list()
|
||||
|
||||
@@ -470,7 +470,7 @@ class InfinityConnection(DocStoreConnection):
|
||||
df_list.append(kb_res)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res = concat_dataframes(df_list, ["id"])
|
||||
res_fields = self.getFields(res, res.columns.tolist())
|
||||
res_fields = self.get_fields(res, res.columns.tolist())
|
||||
return res_fields.get(chunkId, None)
|
||||
|
||||
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
|
||||
@@ -599,7 +599,7 @@ class InfinityConnection(DocStoreConnection):
|
||||
col_to_remove = list(removeValue.keys())
|
||||
row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
|
||||
logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
|
||||
row_to_opt = self.getFields(row_to_opt, col_to_remove)
|
||||
row_to_opt = self.get_fields(row_to_opt, col_to_remove)
|
||||
for id, old_v in row_to_opt.items():
|
||||
for k, remove_v in removeValue.items():
|
||||
if remove_v in old_v[k]:
|
||||
@@ -639,17 +639,17 @@ class InfinityConnection(DocStoreConnection):
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
|
||||
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
|
||||
if isinstance(res, tuple):
|
||||
return res[1]
|
||||
return len(res)
|
||||
|
||||
def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
|
||||
def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
return list(res["id"])
|
||||
|
||||
def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
|
||||
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
if not fields:
|
||||
@@ -690,7 +690,7 @@ class InfinityConnection(DocStoreConnection):
|
||||
|
||||
return res2.set_index("id").to_dict(orient="index")
|
||||
|
||||
def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
|
||||
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
ans = {}
|
||||
@@ -732,7 +732,7 @@ class InfinityConnection(DocStoreConnection):
|
||||
ans[id] = txt
|
||||
return ans
|
||||
|
||||
def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
|
||||
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
|
||||
"""
|
||||
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
|
||||
"""
|
||||
|
||||
@@ -92,7 +92,7 @@ class RAGFlowMinio:
|
||||
logging.exception(f"Fail to get {bucket}/{filename}")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
return None
|
||||
|
||||
def obj_exist(self, bucket, filename, tenant_id=None):
|
||||
try:
|
||||
@@ -130,7 +130,7 @@ class RAGFlowMinio:
|
||||
logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
return None
|
||||
|
||||
def remove_bucket(self, bucket):
|
||||
try:
|
||||
|
||||
@@ -62,8 +62,7 @@ class OpenDALStorage:
|
||||
|
||||
def health(self):
|
||||
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
|
||||
r = self._operator.write(f"{bucket}/{fnm}", binary)
|
||||
return r
|
||||
return self._operator.write(f"{bucket}/{fnm}", binary)
|
||||
|
||||
def put(self, bucket, fnm, binary, tenant_id=None):
|
||||
self._operator.write(f"{bucket}/{fnm}", binary)
|
||||
|
||||
@@ -455,12 +455,12 @@ class OSConnection(DocStoreConnection):
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def getTotal(self, res):
|
||||
def get_total(self, res):
|
||||
if isinstance(res["hits"]["total"], type({})):
|
||||
return res["hits"]["total"]["value"]
|
||||
return res["hits"]["total"]
|
||||
|
||||
def getChunkIds(self, res):
|
||||
def get_chunk_ids(self, res):
|
||||
return [d["_id"] for d in res["hits"]["hits"]]
|
||||
|
||||
def __getSource(self, res):
|
||||
@@ -471,7 +471,7 @@ class OSConnection(DocStoreConnection):
|
||||
rr.append(d["_source"])
|
||||
return rr
|
||||
|
||||
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
res_fields = {}
|
||||
if not fields:
|
||||
return {}
|
||||
@@ -490,7 +490,7 @@ class OSConnection(DocStoreConnection):
|
||||
res_fields[d["id"]] = m
|
||||
return res_fields
|
||||
|
||||
def getHighlight(self, res, keywords: list[str], fieldnm: str):
|
||||
def get_highlight(self, res, keywords: list[str], fieldnm: str):
|
||||
ans = {}
|
||||
for d in res["hits"]["hits"]:
|
||||
hlts = d.get("highlight")
|
||||
@@ -515,7 +515,7 @@ class OSConnection(DocStoreConnection):
|
||||
|
||||
return ans
|
||||
|
||||
def getAggregation(self, res, fieldnm: str):
|
||||
def get_aggregation(self, res, fieldnm: str):
|
||||
agg_field = "aggs_" + fieldnm
|
||||
if "aggregations" not in res or agg_field not in res["aggregations"]:
|
||||
return list()
|
||||
|
||||
@@ -141,7 +141,7 @@ class RAGFlowOSS:
|
||||
logging.exception(f"fail get {bucket}/{fnm}")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
return None
|
||||
|
||||
@use_prefix_path
|
||||
@use_default_bucket
|
||||
@@ -170,5 +170,5 @@ class RAGFlowOSS:
|
||||
logging.exception(f"fail get url {bucket}/{fnm}")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
return None
|
||||
|
||||
|
||||
@@ -104,6 +104,7 @@ class RedisDB:
|
||||
|
||||
if self.REDIS.get(a) == b:
|
||||
return True
|
||||
return False
|
||||
|
||||
def info(self):
|
||||
info = self.REDIS.info()
|
||||
@@ -124,7 +125,7 @@ class RedisDB:
|
||||
|
||||
def exist(self, k):
|
||||
if not self.REDIS:
|
||||
return
|
||||
return None
|
||||
try:
|
||||
return self.REDIS.exists(k)
|
||||
except Exception as e:
|
||||
@@ -133,7 +134,7 @@ class RedisDB:
|
||||
|
||||
def get(self, k):
|
||||
if not self.REDIS:
|
||||
return
|
||||
return None
|
||||
try:
|
||||
return self.REDIS.get(k)
|
||||
except Exception as e:
|
||||
|
||||
@@ -164,7 +164,7 @@ class RAGFlowS3:
|
||||
logging.exception(f"fail get {bucket}/{fnm}")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
return None
|
||||
|
||||
@use_prefix_path
|
||||
@use_default_bucket
|
||||
@@ -193,7 +193,7 @@ class RAGFlowS3:
|
||||
logging.exception(f"fail get url {bucket}/{fnm}")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
return None
|
||||
|
||||
@use_default_bucket
|
||||
def rm_bucket(self, bucket, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user