from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from service.trunks_service import TrunksService
from utils.text_splitter import TextSplitter
from utils.vector_distance import VectorDistance
from model.response import StandardResponse
from utils.vectorizer import Vectorizer
# Vector-distance cutoff: results at or above this distance are treated as
# non-matches by both the /search and /match handlers below.
DISTANCE_THRESHOLD = 0.65
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/text", tags=["Text Search"])
class TextSearchRequest(BaseModel):
    """Request body for ``POST /text/search``."""

    # Raw text to split into sentences and annotate with references.
    text: str
    # Optional conversation id used to look up cached search results.
    conversation_id: Optional[str] = None
    # Declared flag for the escaped-newline conversion. NOTE(review): the
    # handler currently converts unconditionally (its guard is commented
    # out), so this field is effectively unused -- kept for API compatibility.
    need_convert: Optional[bool] = False
class TextCompareRequest(BaseModel):
    """Request body for ``POST /text/match``."""

    # Reference sentence to match against the sentences of ``text``.
    sentence: str
    # Text whose sentences are compared against ``sentence``.
    text: str
@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    """Split ``request.text`` into sentences and attach citation references.

    For every sentence of at least 10 characters, the closest knowledge trunk
    is looked up -- from the per-conversation cache when ``conversation_id``
    is provided, otherwise via a fresh vector search.  Hits closer than
    ``DISTANCE_THRESHOLD`` are collected as numbered references and the
    sentence is suffixed with a ``^[n]^`` citation marker.

    Returns:
        StandardResponse whose ``data`` is
        ``{"answer": [annotated sentences], "references": [reference dicts]}``.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Split the incoming text into sentences.
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Reuse per-conversation cached search results when available.
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        for sentence in sentences:
            # NOTE(review): the need_convert guard is commented out, so the
            # escaped-"\n" -> real-newline conversion runs unconditionally.
            # if request.need_convert:
            sentence = sentence.replace("\\n", "\n")

            # Very short fragments are passed through without annotation.
            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue

            if cached_results:
                # With a cache, pick the trunk with the smallest vector distance.
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    content_vector = cached_result['embedding']
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}
                # Keep the cached hit only when it is close enough.
                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # No cache: run a fresh vector search for the single best trunk.
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1
                )

            for result in search_results:
                # A missing distance is treated as "not close enough".
                distance = result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue

                # Reuse the index of an already-collected reference with the same id.
                existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
                current_index = reference_index
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    reference = {
                        "index": str(reference_index),
                        "id": result["id"],
                        "content": result["content"],
                        "file_path": result.get("file_path", ""),
                        "title": result.get("title", ""),
                        "distance": distance,
                        # NOTE(review): "referrence" spelling matches the
                        # upstream trunk schema -- do not rename without
                        # migrating the stored data.
                        "referrence": result.get("referrence", "")
                    }
                    all_references.append(reference)
                    reference_index += 1

                # Append the ^[n]^ citation marker, keeping newlines intact.
                if sentence.endswith('\n'):
                    # Insert ^[n]^ before every newline in the sentence.
                    result_sentence = sentence.replace('\n', f'^[{current_index}]^\n')
                else:
                    # Otherwise append ^[n]^ at the end of the sentence.
                    result_sentence = f'{sentence}^[{current_index}]^'
                result_sentences.append(result_sentence)

        response_data = {
            "answer": result_sentences,
            "references": all_references
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        # logger.exception records the stack trace; the client gets the message.
        logger.exception("Text search failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    """Flag which sentence of ``request.text`` best matches ``request.sentence``.

    All sentences are returned in their original order; the closest sentence
    by vector distance (only considered when its distance is below
    ``DISTANCE_THRESHOLD``) is flagged with ``matched=True``.  Sentences
    shorter than 10 characters are never candidates for matching.

    Returns:
        StandardResponse whose ``records`` is
        ``[{"sentence": str, "matched": bool}, ...]``.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        sentences = TextSplitter.split_text(request.text)
        query_vector = Vectorizer.get_embedding(request.sentence)

        # Find the closest sufficiently-long sentence under the threshold.
        min_distance = float('inf')
        best_sentence = ""
        for candidate in sentences:
            if len(candidate) < 10:
                continue
            candidate_vector = Vectorizer.get_embedding(candidate)
            distance = VectorDistance.calculate_distance(query_vector, candidate_vector)
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = candidate

        # Preserve order; any sentence equal to the best match is flagged
        # (duplicates of the best sentence are all marked, as before).
        records = [{"sentence": s, "matched": s == best_sentence} for s in sentences]
        return StandardResponse(success=True, records=records)
    except Exception as e:
        # logger.exception records the stack trace; the client gets the message.
        logger.exception("Text comparison failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
text_search_router = router