|
@@ -6,6 +6,8 @@ from utils.text_splitter import TextSplitter
|
|
|
from utils.vector_distance import VectorDistance
|
|
|
from model.response import StandardResponse
|
|
|
from utils.vectorizer import Vectorizer
|
|
|
+# from utils.find_text_in_pdf import find_text_in_pdf
|
|
|
+import os
|
|
|
DISTANCE_THRESHOLD = 0.8
|
|
|
import logging
|
|
|
import time
|
|
@@ -58,7 +60,6 @@ class TextCompareMultiRequest(BaseModel):
|
|
|
class NodePropsSearchRequest(BaseModel):
|
|
|
node_id: int
|
|
|
props_ids: List[int]
|
|
|
- conversation_id: Optional[str] = None
|
|
|
|
|
|
@router.post("/search", response_model=StandardResponse)
|
|
|
async def search_text(request: TextSearchRequest):
|
|
@@ -138,6 +139,8 @@ async def search_text(request: TextSearchRequest):
|
|
|
"distance": distance,
|
|
|
"referrence": search_result.get("referrence", "")
|
|
|
}
|
|
|
+
|
|
|
+
|
|
|
all_references.append(reference)
|
|
|
reference_index += 1
|
|
|
|
|
@@ -337,9 +340,6 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
|
|
|
"distance": 0
|
|
|
}
|
|
|
|
|
|
- # 缓存结果
|
|
|
- cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
|
|
|
-
|
|
|
# 遍历props_ids查询属性信息
|
|
|
for prop_id in request.props_ids:
|
|
|
prop = prop_service.get_props_by_id(prop_id)
|
|
@@ -374,11 +374,11 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
|
|
|
i = 0
|
|
|
while i < len(sentences):
|
|
|
original_sentence = sentences[i]
|
|
|
- sentence = original_sentence.replace("\n", "<br>")
|
|
|
+ sentence = original_sentence
|
|
|
|
|
|
# 如果当前句子长度小于10且不是最后一句,则与下一句合并
|
|
|
if len(sentence) < 10 and i + 1 < len(sentences):
|
|
|
- next_sentence = sentences[i + 1].replace("\n", "<br>")
|
|
|
+ next_sentence = sentences[i + 1]
|
|
|
combined_sentence = sentence + " " + next_sentence
|
|
|
# 添加原短句到结果,flag为空
|
|
|
result_sentences.append({
|
|
@@ -402,31 +402,12 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
|
|
|
|
|
|
i += 1
|
|
|
|
|
|
- # 检查缓存
|
|
|
- if cached_results:
|
|
|
- min_distance = float('inf')
|
|
|
- best_result = None
|
|
|
- sentence_vector = Vectorizer.get_embedding(search_text)
|
|
|
-
|
|
|
- for cached_result in cached_results:
|
|
|
- content_vector = cached_result['embedding']
|
|
|
- distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
|
|
|
- if distance < min_distance:
|
|
|
- min_distance = distance
|
|
|
- best_result = {**cached_result, 'distance': distance}
|
|
|
-
|
|
|
- if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
|
|
|
- search_results = [best_result]
|
|
|
- else:
|
|
|
- search_results = []
|
|
|
- else:
|
|
|
- # 进行向量搜索
|
|
|
- search_results = trunks_service.search_by_vector(
|
|
|
+ # 进行向量搜索
|
|
|
+ search_results = trunks_service.search_by_vector(
|
|
|
text=search_text,
|
|
|
limit=1,
|
|
|
- type='trunk',
|
|
|
- conversation_id=request.conversation_id
|
|
|
- )
|
|
|
+ type='trunk'
|
|
|
+ )
|
|
|
|
|
|
# 处理搜索结果
|
|
|
if not search_results:
|
|
@@ -461,8 +442,10 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
|
|
|
"file_path": search_result.get("file_path", ""),
|
|
|
"title": search_result.get("title", ""),
|
|
|
"distance": distance,
|
|
|
+ "page_no": search_result.get("page_no", ""),
|
|
|
"referrence": search_result.get("referrence", "")
|
|
|
}
|
|
|
+
|
|
|
all_references.append(reference)
|
|
|
reference_index += 1
|
|
|
|