|
@@ -1,12 +1,12 @@
|
|
|
from fastapi import APIRouter, HTTPException
|
|
|
-from pydantic import BaseModel
|
|
|
+from pydantic import BaseModel, Field, validator
|
|
|
from typing import List, Optional
|
|
|
from service.trunks_service import TrunksService
|
|
|
from utils.text_splitter import TextSplitter
|
|
|
from utils.vector_distance import VectorDistance
|
|
|
from model.response import StandardResponse
|
|
|
from utils.vectorizer import Vectorizer
|
|
|
-DISTANCE_THRESHOLD = 0.65
|
|
|
+DISTANCE_THRESHOLD = 0.8
|
|
|
import logging
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
@@ -21,6 +21,35 @@ class TextCompareRequest(BaseModel):
|
|
|
sentence: str
|
|
|
text: str
|
|
|
|
|
|
+class TextMatchRequest(BaseModel):
|
|
|
+ text: str = Field(..., min_length=1, max_length=10000, description="需要搜索的文本内容")
|
|
|
+
|
|
|
+ @validator('text')
|
|
|
+ def validate_text(cls, v):
|
|
|
+ # 保留所有可打印字符、换行符和中文字符
|
|
|
+ v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
|
|
|
+
|
|
|
+ # 转义JSON特殊字符
|
|
|
+ # 先处理反斜杠,避免后续转义时出现问题
|
|
|
+ v = v.replace('\\', '\\\\')
|
|
|
+ # 处理引号和其他特殊字符
|
|
|
+ v = v.replace('"', '\\"')
|
|
|
+ v = v.replace('/', '\\/')
|
|
|
+ # 处理控制字符
|
|
|
+ v = v.replace('\n', '\\n')
|
|
|
+ v = v.replace('\r', '\\r')
|
|
|
+ v = v.replace('\t', '\\t')
|
|
|
+ v = v.replace('\b', '\\b')
|
|
|
+ v = v.replace('\f', '\\f')
|
|
|
+ # 处理Unicode转义
|
|
|
+ # v = v.replace('\u', '\\u')
|
|
|
+
|
|
|
+ return v
|
|
|
+
|
|
|
+class TextCompareMultiRequest(BaseModel):
|
|
|
+ origin: str
|
|
|
+ similar: str
|
|
|
+
|
|
|
@router.post("/search", response_model=StandardResponse)
|
|
|
async def search_text(request: TextSearchRequest):
|
|
|
try:
|
|
@@ -145,4 +174,119 @@ async def match_text(request: TextCompareRequest):
|
|
|
logger.error(f"Text comparison failed: {str(e)}")
|
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
+@router.post("/mr_search", response_model=StandardResponse)
|
|
|
+async def mr_search_text_content(request: TextMatchRequest):
|
|
|
+ try:
|
|
|
+ # 初始化服务
|
|
|
+ trunks_service = TrunksService()
|
|
|
+
|
|
|
+ # 获取文本向量并搜索相似内容
|
|
|
+ search_results = trunks_service.search_by_vector(
|
|
|
+ text=request.text,
|
|
|
+ limit=10,
|
|
|
+ type="mr"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 处理搜索结果
|
|
|
+ records = []
|
|
|
+ for result in search_results:
|
|
|
+ distance = result.get("distance", DISTANCE_THRESHOLD)
|
|
|
+ if distance >= DISTANCE_THRESHOLD:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 添加到引用列表
|
|
|
+ record = {
|
|
|
+ "content": result["content"],
|
|
|
+ "file_path": result.get("file_path", ""),
|
|
|
+ "title": result.get("title", ""),
|
|
|
+ "distance": distance,
|
|
|
+ }
|
|
|
+ records.append(record)
|
|
|
+
|
|
|
+ # 组装返回数据
|
|
|
+ response_data = {
|
|
|
+ "records": records
|
|
|
+ }
|
|
|
+
|
|
|
+ return StandardResponse(success=True, data=response_data)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Mr search failed: {str(e)}")
|
|
|
+ raise HTTPException(status_code=500, detail=str(e))
|
|
|
+
|
|
|
+@router.post("/mr_match", response_model=StandardResponse)
|
|
|
+async def compare_text(request: TextCompareMultiRequest):
|
|
|
+ try:
|
|
|
+ # 拆分两段文本
|
|
|
+ origin_sentences = TextSplitter.split_text(request.origin)
|
|
|
+ similar_sentences = TextSplitter.split_text(request.similar)
|
|
|
+
|
|
|
+ # 初始化结果列表
|
|
|
+ origin_results = []
|
|
|
+ similar_results = []
|
|
|
+
|
|
|
+ # 获取origin文本的向量
|
|
|
+ for origin_sent in origin_sentences:
|
|
|
+ if len(origin_sent) < 10:
|
|
|
+ origin_results.append({"sentence": origin_sent, "matched": False})
|
|
|
+ continue
|
|
|
+
|
|
|
+ origin_vector = Vectorizer.get_embedding(origin_sent)
|
|
|
+ min_distance = float('inf')
|
|
|
+ matched_sent = ""
|
|
|
+
|
|
|
+ # 与similar文本的每个句子计算相似度
|
|
|
+ for similar_sent in similar_sentences:
|
|
|
+ if len(similar_sent) < 10:
|
|
|
+ continue
|
|
|
+
|
|
|
+ similar_vector = Vectorizer.get_embedding(similar_sent)
|
|
|
+ distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
|
|
|
+
|
|
|
+ if distance < min_distance and distance < DISTANCE_THRESHOLD:
|
|
|
+ min_distance = distance
|
|
|
+ matched_sent = similar_sent
|
|
|
+
|
|
|
+ origin_results.append({
|
|
|
+ "sentence": origin_sent,
|
|
|
+ "matched": bool(matched_sent)
|
|
|
+ })
|
|
|
+
|
|
|
+ # 获取similar文本的向量
|
|
|
+ for similar_sent in similar_sentences:
|
|
|
+ if len(similar_sent) < 10:
|
|
|
+ similar_results.append({"sentence": similar_sent, "matched": False})
|
|
|
+ continue
|
|
|
+
|
|
|
+ similar_vector = Vectorizer.get_embedding(similar_sent)
|
|
|
+ min_distance = float('inf')
|
|
|
+ matched_sent = ""
|
|
|
+
|
|
|
+ # 与origin文本的每个句子计算相似度
|
|
|
+ for origin_sent in origin_sentences:
|
|
|
+ if len(origin_sent) < 10:
|
|
|
+ continue
|
|
|
+
|
|
|
+ origin_vector = Vectorizer.get_embedding(origin_sent)
|
|
|
+ distance = VectorDistance.calculate_distance(similar_vector, origin_vector)
|
|
|
+
|
|
|
+ if distance < min_distance and distance < DISTANCE_THRESHOLD:
|
|
|
+ min_distance = distance
|
|
|
+ matched_sent = origin_sent
|
|
|
+
|
|
|
+ similar_results.append({
|
|
|
+ "sentence": similar_sent,
|
|
|
+ "matched": bool(matched_sent)
|
|
|
+ })
|
|
|
+
|
|
|
+ response_data = {
|
|
|
+ "origin": origin_results,
|
|
|
+ "similar": similar_results
|
|
|
+ }
|
|
|
+
|
|
|
+ return StandardResponse(success=True, data=response_data)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Text comparison failed: {str(e)}")
|
|
|
+ raise HTTPException(status_code=500, detail=str(e))
|
|
|
+
|
|
|
text_search_router = router
|