|
@@ -6,6 +6,7 @@ from utils.text_splitter import TextSplitter
|
|
|
from utils.vector_distance import VectorDistance
|
|
|
from model.response import StandardResponse
|
|
|
from utils.vectorizer import Vectorizer
|
|
|
+DISTANCE_THRESHOLD = 0.65
|
|
|
import logging
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
@@ -16,6 +17,10 @@ class TextSearchRequest(BaseModel):
|
|
|
conversation_id: Optional[str] = None
|
|
|
need_convert: Optional[bool] = False
|
|
|
|
|
|
+class TextCompareRequest(BaseModel):
|
|
|
+ sentence: str
|
|
|
+ text: str
|
|
|
+
|
|
|
@router.post("/search", response_model=StandardResponse)
|
|
|
async def search_text(request: TextSearchRequest):
|
|
|
try:
|
|
@@ -34,6 +39,8 @@ async def search_text(request: TextSearchRequest):
|
|
|
cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
|
|
|
|
|
|
for sentence in sentences:
|
|
|
+ # if request.need_convert:
|
|
|
+ sentence = sentence.replace("\n", "<br>")
|
|
|
if len(sentence) < 10:
|
|
|
result_sentences.append(sentence)
|
|
|
continue
|
|
@@ -51,7 +58,7 @@ async def search_text(request: TextSearchRequest):
|
|
|
best_result = {**cached_result, 'distance': distance}
|
|
|
|
|
|
|
|
|
- if best_result and best_result['distance'] < 0.75:
|
|
|
+ if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
|
|
|
search_results = [best_result]
|
|
|
else:
|
|
|
search_results = []
|
|
@@ -64,32 +71,43 @@ async def search_text(request: TextSearchRequest):
|
|
|
|
|
|
# 处理搜索结果
|
|
|
for result in search_results:
|
|
|
- # 获取distance值,如果大于等于1则跳过
|
|
|
- distance = result.get("distance", 1.0)
|
|
|
- if distance >= 1:
|
|
|
+ distance = result.get("distance", DISTANCE_THRESHOLD)
|
|
|
+ if distance >= DISTANCE_THRESHOLD:
|
|
|
+ result_sentences.append(sentence)
|
|
|
continue
|
|
|
|
|
|
+ # 检查是否已存在相同引用
|
|
|
+ existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
|
|
|
+ current_index = reference_index
|
|
|
+ if existing_ref:
|
|
|
+ current_index = int(existing_ref["index"])
|
|
|
+ else:
|
|
|
+ # 添加到引用列表
|
|
|
+ reference = {
|
|
|
+ "index": str(reference_index),
|
|
|
+ "id": result["id"],
|
|
|
+ "content": result["content"],
|
|
|
+ "file_path": result.get("file_path", ""),
|
|
|
+ "title": result.get("title", ""),
|
|
|
+ "distance": distance,
|
|
|
+ "referrence": result.get("referrence", "")
|
|
|
+ }
|
|
|
+ all_references.append(reference)
|
|
|
+ reference_index += 1
|
|
|
+
|
|
|
# 添加引用标记
|
|
|
- result_sentence = sentence + f"^[{reference_index}]^"
|
|
|
- result_sentences.append(result_sentence)
|
|
|
+ if sentence.endswith('<br>'):
|
|
|
+ # 如果有多个<br>,在所有<br>前添加^[current_index]^
|
|
|
+ result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
|
|
|
+ else:
|
|
|
+ # 直接在句子末尾添加^[current_index]^
|
|
|
+ result_sentence = f'{sentence}^[{current_index}]^'
|
|
|
|
|
|
- # 添加到引用列表
|
|
|
- reference = {
|
|
|
- "index": str(reference_index),
|
|
|
- "content": result["content"],
|
|
|
- "file_path": result.get("file_path", ""),
|
|
|
- "title": result.get("title", ""),
|
|
|
- "distance": distance
|
|
|
- }
|
|
|
- all_references.append(reference)
|
|
|
- reference_index += 1
|
|
|
-
|
|
|
- answer = "\n".join(result_sentences)
|
|
|
- if request.need_convert:
|
|
|
- answer = answer.replace("\n", "</br>")
|
|
|
+ result_sentences.append(result_sentence)
|
|
|
+
|
|
|
# 组装返回数据
|
|
|
response_data = {
|
|
|
- "answer": answer,
|
|
|
+ "answer": result_sentences,
|
|
|
"references": all_references
|
|
|
}
|
|
|
|
|
@@ -99,4 +117,32 @@ async def search_text(request: TextSearchRequest):
|
|
|
logger.error(f"Text search failed: {str(e)}")
|
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
+@router.post("/match", response_model=StandardResponse)
|
|
|
+async def match_text(request: TextCompareRequest):
|
|
|
+ try:
|
|
|
+ sentences = TextSplitter.split_text(request.text)
|
|
|
+ sentence_vector = Vectorizer.get_embedding(request.sentence)
|
|
|
+ min_distance = float('inf')
|
|
|
+ best_sentence = ""
|
|
|
+ result_sentences = []
|
|
|
+ for temp in sentences:
|
|
|
+ result_sentences.append(temp)
|
|
|
+ if len(temp) < 10:
|
|
|
+ continue
|
|
|
+ temp_vector = Vectorizer.get_embedding(temp)
|
|
|
+ distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
|
|
|
+ if distance < min_distance and distance < DISTANCE_THRESHOLD:
|
|
|
+ min_distance = distance
|
|
|
+ best_sentence = temp
|
|
|
+
|
|
|
+ for i in range(len(result_sentences)):
|
|
|
+ result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
|
|
|
+ if result_sentences[i]["sentence"] == best_sentence:
|
|
|
+ result_sentences[i]["matched"] = True
|
|
|
+
|
|
|
+ return StandardResponse(success=True, records=result_sentences)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Text comparison failed: {str(e)}")
|
|
|
+ raise HTTPException(status_code=500, detail=str(e))
|
|
|
+
|
|
|
text_search_router = router
|