123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- from fastapi import APIRouter, HTTPException
- from pydantic import BaseModel
- from typing import List, Optional
- from service.trunks_service import TrunksService
- from utils.text_splitter import TextSplitter
- from utils.vector_distance import VectorDistance
- from model.response import StandardResponse
- from utils.vectorizer import Vectorizer
- DISTANCE_THRESHOLD = 0.65
- import logging
- logger = logging.getLogger(__name__)
- router = APIRouter(prefix="/text", tags=["Text Search"])
class TextSearchRequest(BaseModel):
    """Request body for POST /text/search."""
    # Raw input text; it is split into sentences and annotated with references.
    text: str
    # Optional conversation id used to look up cached search results.
    conversation_id: Optional[str] = None
    # NOTE(review): flag appears intended to gate the newline -> <br>
    # conversion, but the handler currently applies that conversion
    # unconditionally — confirm intended semantics.
    need_convert: Optional[bool] = False
class TextCompareRequest(BaseModel):
    """Request body for POST /text/match."""
    # Reference sentence to match against.
    sentence: str
    # Text whose sentences are scored against `sentence` by vector distance.
    text: str
@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    """Split the input text into sentences and attach a reference to each.

    For every sentence (length >= 10), a best-matching trunk is found either
    from this conversation's cached results (when ``conversation_id`` is set)
    or via a fresh vector search. Matches closer than ``DISTANCE_THRESHOLD``
    are cited with a ``^[n]^`` marker; references are deduplicated by id.

    Returns:
        StandardResponse with ``data = {"answer": [...], "references": [...]}``.
    Raises:
        HTTPException(500) on any unexpected failure.
    """
    try:
        # Split the input into sentences.
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Cached results for this conversation, if any.
        cached_results = (
            trunks_service.get_cached_result(request.conversation_id)
            if request.conversation_id
            else []
        )

        for sentence in sentences:
            # NOTE(review): conversion applied unconditionally — the original
            # `if request.need_convert:` guard was commented out upstream.
            sentence = sentence.replace("\n", "<br>")

            # Very short fragments pass through without a lookup.
            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue

            if cached_results:
                # Pick the nearest cached entry by embedding distance.
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    distance = VectorDistance.calculate_distance(
                        sentence_vector, cached_result['embedding'])
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}

                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # No cache: vector-search for the single best trunk.
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1
                )

            # BUGFIX: an empty result set previously fell through the loop
            # below and the sentence was silently dropped from the answer.
            # Keep it, just without a citation marker.
            if not search_results:
                result_sentences.append(sentence)
                continue

            for result in search_results:
                distance = result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    # Match too weak: keep the sentence uncited.
                    result_sentences.append(sentence)
                    continue

                # Reuse the index of an already-collected reference with the
                # same id; otherwise register a new reference.
                existing_ref = next(
                    (ref for ref in all_references if ref["id"] == result["id"]),
                    None,
                )
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    current_index = reference_index
                    all_references.append({
                        "index": str(reference_index),
                        "id": result["id"],
                        "content": result["content"],
                        "file_path": result.get("file_path", ""),
                        "title": result.get("title", ""),
                        "distance": distance,
                        # Key spelling "referrence" kept for API compatibility.
                        "referrence": result.get("referrence", "")
                    })
                    reference_index += 1

                # Place the ^[n]^ citation marker before every <br> when the
                # sentence ends with one, otherwise append it at the end.
                if sentence.endswith('<br>'):
                    result_sentences.append(
                        sentence.replace('<br>', f'^[{current_index}]^<br>'))
                else:
                    result_sentences.append(f'{sentence}^[{current_index}]^')

        return StandardResponse(success=True, data={
            "answer": result_sentences,
            "references": all_references
        })

    except Exception as e:
        # logger.exception records the traceback; plain error() did not.
        logger.exception("Text search failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    """Find which sentence of ``request.text`` best matches ``request.sentence``.

    The text is split into sentences; each sentence of length >= 10 is
    embedded and compared to the reference sentence. The closest sentence
    under ``DISTANCE_THRESHOLD`` is flagged ``matched``; every returned
    record is ``{"sentence": str, "matched": bool}``.

    Raises:
        HTTPException(500) on any unexpected failure.
    """
    try:
        sentences = TextSplitter.split_text(request.text)
        sentence_vector = Vectorizer.get_embedding(request.sentence)

        # Single pass to find the closest qualifying sentence.
        min_distance = float('inf')
        best_sentence = ""
        for candidate in sentences:
            # Short fragments are returned but never considered for matching.
            if len(candidate) < 10:
                continue
            candidate_vector = Vectorizer.get_embedding(candidate)
            distance = VectorDistance.calculate_distance(sentence_vector, candidate_vector)
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = candidate

        # BUGFIX: guard on best_sentence so that, when nothing clears the
        # threshold, an empty sentence from the splitter is not falsely
        # flagged as matched. Duplicates of the best sentence are all
        # flagged, as before.
        records = [
            {"sentence": s, "matched": bool(best_sentence) and s == best_sentence}
            for s in sentences
        ]
        return StandardResponse(success=True, records=records)

    except Exception as e:
        # logger.exception records the traceback; plain error() did not.
        logger.exception("Text comparison failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
- text_search_router = router
|