|
@@ -1,305 +0,0 @@
|
|
|
-from fastapi import APIRouter, HTTPException
|
|
|
-from pydantic import BaseModel, Field, validator
|
|
|
-from typing import List, Optional
|
|
|
-from service.trunks_service import TrunksService
|
|
|
-from utils.text_splitter import TextSplitter
|
|
|
-from utils.vector_distance import VectorDistance
|
|
|
-from model.response import StandardResponse
|
|
|
-from utils.vectorizer import Vectorizer
|
|
|
# Vector-distance cutoff: results at or above this distance count as "no match".
# (Metric semantics are defined by VectorDistance.calculate_distance — TODO confirm.)
DISTANCE_THRESHOLD = 0.8
|
|
|
-import logging
|
|
|
-import time
|
|
|
-
|
|
|
# Module-level logger for this router.
logger = logging.getLogger(__name__)

# All endpoints below are mounted under the /text prefix.
router = APIRouter(prefix="/text", tags=["Text Search"])
|
|
|
-
|
|
|
class TextSearchRequest(BaseModel):
    """Request body for POST /text/search."""

    # Raw text (or a JSON object string, which the handler flattens) to search.
    text: str
    # When set, the handler reuses cached trunk results for this conversation.
    conversation_id: Optional[str] = None
    # NOTE(review): not read by the /search handler (its use is commented out
    # there) — confirm before removing.
    need_convert: Optional[bool] = False
|
|
|
-
|
|
|
class TextCompareRequest(BaseModel):
    """Request body for POST /text/match."""

    # Single sentence to locate within `text`.
    sentence: str
    # Full text whose split sentences are compared against `sentence`.
    text: str
|
|
|
-
|
|
|
class TextMatchRequest(BaseModel):
    """Request body for POST /text/mr_search."""

    text: str = Field(..., min_length=1, max_length=10000, description="需要搜索的文本内容")

    @validator('text')
    def validate_text(cls, v):
        """Strip non-printable characters and escape JSON specials in *v*."""
        # Keep printable characters plus newline/carriage-return (covers CJK text);
        # other control characters (\t, \b, \f, ...) are removed here.
        cleaned = ''.join(ch for ch in v if ch.isprintable() or ch in '\n\r')

        # Escape JSON special characters in a single pass. A one-pass translate
        # cannot double-escape, which is why the original applied the backslash
        # replacement first; the \t/\b/\f entries only matter defensively, since
        # those characters were already filtered out above.
        escape_table = str.maketrans({
            '\\': '\\\\',
            '"': '\\"',
            '/': '\\/',
            '\n': '\\n',
            '\r': '\\r',
            '\t': '\\t',
            '\b': '\\b',
            '\f': '\\f',
        })
        return cleaned.translate(escape_table)
|
|
|
-
|
|
|
class TextCompareMultiRequest(BaseModel):
    """Request body for POST /text/mr_match: two texts compared sentence-by-sentence."""

    # Reference text; each of its sentences is matched against `similar`.
    origin: str
    # Candidate text whose sentences may be claimed as matches.
    similar: str
|
|
|
-
|
|
|
@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    """Split request.text into sentences and attach trunk references to each.

    Sentences shorter than 10 characters pass through unchanged. For the rest,
    the closest trunk is found — from the conversation cache when
    `conversation_id` is set, otherwise via vector search — and a `^[n]^`
    reference marker is appended when the distance is below DISTANCE_THRESHOLD.

    Returns:
        StandardResponse with data = {"answer": [sentences], "references": [...]}.

    Raises:
        HTTPException(500): any unexpected failure is logged and re-raised.
    """
    try:
        # If the payload looks like a JSON object, flatten it to plain text first.
        if request.text.startswith('{') and request.text.endswith('}'):
            from utils.json_to_text import JsonToTextConverter
            converter = JsonToTextConverter()
            request.text = converter.convert(request.text)

        # Split the text into sentences.
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Reuse per-conversation cached trunks when a conversation_id is given.
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        for sentence in sentences:
            sentence = sentence.replace("\n", "<br>")
            # Very short fragments are kept verbatim, without reference lookup.
            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue

            if cached_results:
                # Cached path: pick the cached trunk closest to this sentence.
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)

                for cached_result in cached_results:
                    content_vector = cached_result['embedding']
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}

                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # No cache: fall back to a vector search over stored trunks.
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                    type='trunk'
                )

            # BUGFIX: a sentence with no search result was previously dropped
            # from the answer entirely; keep it verbatim, consistent with the
            # over-threshold branch below.
            if not search_results:
                result_sentences.append(sentence)
                continue

            for result in search_results:
                distance = result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    # Too far from any trunk: keep the sentence without a marker.
                    result_sentences.append(sentence)
                    continue

                # Reuse the index of an already-collected reference to the same trunk.
                existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
                current_index = reference_index
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    reference = {
                        "index": str(reference_index),
                        "id": result["id"],
                        "content": result["content"],
                        "file_path": result.get("file_path", ""),
                        "title": result.get("title", ""),
                        "distance": distance,
                        # NOTE(review): the "referrence" spelling is part of the
                        # response contract — do not rename without checking consumers.
                        "referrence": result.get("referrence", "")
                    }
                    all_references.append(reference)
                    reference_index += 1

                # Append the ^[n]^ marker; when the sentence ends in <br> tags,
                # place the marker before each <br> instead of after it.
                if sentence.endswith('<br>'):
                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                else:
                    result_sentence = f'{sentence}^[{current_index}]^'

                result_sentences.append(result_sentence)

        response_data = {
            "answer": result_sentences,
            "references": all_references
        }

        return StandardResponse(success=True, data=response_data)

    except Exception as e:
        logger.error(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
-
|
|
|
@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    """Flag the sentence in request.text that best matches request.sentence.

    Every split sentence is returned in order with a `matched` flag; the
    single closest sentence within DISTANCE_THRESHOLD (length >= 10) is the
    one marked True.
    """
    try:
        candidates = TextSplitter.split_text(request.text)
        target_vector = Vectorizer.get_embedding(request.sentence)

        best_sentence = ""
        best_distance = float('inf')
        for candidate in candidates:
            # Too-short fragments are never embedded or matched.
            if len(candidate) < 10:
                continue
            candidate_vector = Vectorizer.get_embedding(candidate)
            dist = VectorDistance.calculate_distance(target_vector, candidate_vector)
            if dist < best_distance and dist < DISTANCE_THRESHOLD:
                best_distance = dist
                best_sentence = candidate

        records = [
            {"sentence": candidate, "matched": candidate == best_sentence}
            for candidate in candidates
        ]
        return StandardResponse(success=True, records=records)
    except Exception as e:
        logger.error(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
-
|
|
|
@router.post("/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
    """Vector-search stored 'mr' trunks for content similar to request.text.

    Returns up to 10 hits whose distance is below DISTANCE_THRESHOLD as
    {"records": [...]} inside a StandardResponse.
    """
    try:
        trunks_service = TrunksService()

        # Embed the query text and fetch the nearest 'mr' trunks.
        search_results = trunks_service.search_by_vector(
            text=request.text,
            limit=10,
            type="mr"
        )

        # Keep only hits below the distance cutoff.
        records = [
            {
                "content": hit["content"],
                "file_path": hit.get("file_path", ""),
                "title": hit.get("title", ""),
                "distance": dist,
            }
            for hit in search_results
            if (dist := hit.get("distance", DISTANCE_THRESHOLD)) < DISTANCE_THRESHOLD
        ]

        response_data = {
            "records": records
        }

        return StandardResponse(success=True, data=response_data)

    except Exception as e:
        logger.error(f"Mr search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
-
|
|
|
@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
    """Greedy sentence-level matching between two texts.

    Each origin sentence (length >= 10) is matched to the first still-unmatched
    similar sentence within DISTANCE_THRESHOLD; a matched similar sentence is
    then claimed and cannot match another origin sentence. Both sides are
    returned with per-sentence `matched` flags.

    Raises:
        HTTPException(500): any unexpected failure is logged and re-raised.
    """
    start_time = time.time()
    try:
        # Split both texts into sentences.
        origin_sentences = TextSplitter.split_text(request.origin)
        similar_sentences = TextSplitter.split_text(request.similar)
        end_time = time.time()
        logger.info(f"mr_match接口处理文本耗时: {(end_time - start_time) * 1000:.2f}ms")

        origin_results = []

        # Pair each sentence with a validity flag (long enough to embed).
        valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
        valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]

        # All similar sentences start unmatched.
        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]

        # Pre-compute embeddings for every valid sentence on both sides.
        origin_vectors = {}
        similar_vectors = {}
        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]

        if origin_batch:
            origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
            origin_vectors = dict(zip(origin_batch, origin_embeddings))

        if similar_batch:
            similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
            similar_vectors = dict(zip(similar_batch, similar_embeddings))

        end_time = time.time()
        logger.info(f"mr_match接口处理向量耗时: {(end_time - start_time) * 1000:.2f}ms")
        # Match each origin sentence greedily against unclaimed similar sentences.
        for origin_sent, is_valid in valid_origin_sentences:
            if not is_valid:
                origin_results.append({"sentence": origin_sent, "matched": False})
                continue

            origin_vector = origin_vectors[origin_sent]
            matched = False

            for i, similar_result in enumerate(similar_results):
                if similar_result["matched"]:
                    continue

                similar_sent = similar_result["sentence"]
                if len(similar_sent) < 10:
                    continue

                similar_vector = similar_vectors.get(similar_sent)
                # BUGFIX: explicit None check — the previous truthiness test
                # (`if not similar_vector`) is ambiguous (and raises) when the
                # embedding is an array type; .get() only misses with None here.
                if similar_vector is None:
                    continue

                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
                if distance < DISTANCE_THRESHOLD:
                    matched = True
                    similar_results[i]["matched"] = True
                    break

            origin_results.append({"sentence": origin_sent, "matched": matched})

        response_data = {
            "origin": origin_results,
            "similar": similar_results
        }

        end_time = time.time()
        logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        end_time = time.time()
        logger.error(f"Text comparison failed: {str(e)}")
        logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
-
|
|
|
# Public alias — presumably imported elsewhere as `text_search_router`; verify callers.
text_search_router = router
|