|
@@ -8,6 +8,7 @@ from model.response import StandardResponse
|
|
|
from utils.vectorizer import Vectorizer
|
|
|
DISTANCE_THRESHOLD = 0.8
|
|
|
import logging
|
|
|
+import time
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
router = APIRouter(prefix="/text", tags=["Text Search"])
|
|
@@ -101,7 +102,8 @@ async def search_text(request: TextSearchRequest):
|
|
|
# 如果没有缓存结果,进行向量搜索
|
|
|
search_results = trunks_service.search_by_vector(
|
|
|
text=sentence,
|
|
|
- limit=1
|
|
|
+ limit=1,
|
|
|
+ type='trunk'
|
|
|
)
|
|
|
|
|
|
# 处理搜索结果
|
|
@@ -222,77 +224,82 @@ async def mr_search_text_content(request: TextMatchRequest):
|
|
|
|
|
|
@router.post("/mr_match", response_model=StandardResponse)
|
|
|
async def compare_text(request: TextCompareMultiRequest):
|
|
|
+ start_time = time.time()
|
|
|
try:
|
|
|
# 拆分两段文本
|
|
|
origin_sentences = TextSplitter.split_text(request.origin)
|
|
|
similar_sentences = TextSplitter.split_text(request.similar)
|
|
|
+ end_time = time.time()
|
|
|
+ logger.info(f"mr_match接口处理文本耗时: {(end_time - start_time) * 1000:.2f}ms")
|
|
|
|
|
|
# 初始化结果列表
|
|
|
origin_results = []
|
|
|
- similar_results = []
|
|
|
|
|
|
- # 获取origin文本的向量
|
|
|
- for origin_sent in origin_sentences:
|
|
|
- if len(origin_sent) < 10:
|
|
|
+ # 过滤短句并预计算向量
|
|
|
+ valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
|
|
|
+ valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]
|
|
|
+
|
|
|
+ # 初始化similar_results,所有matched设为False
|
|
|
+ similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]
|
|
|
+
|
|
|
+ # 批量获取向量
|
|
|
+ origin_vectors = {}
|
|
|
+ similar_vectors = {}
|
|
|
+ origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
|
|
|
+ similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]
|
|
|
+
|
|
|
+ if origin_batch:
|
|
|
+ origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
|
|
|
+ origin_vectors = dict(zip(origin_batch, origin_embeddings))
|
|
|
+
|
|
|
+ if similar_batch:
|
|
|
+ similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
|
|
|
+ similar_vectors = dict(zip(similar_batch, similar_embeddings))
|
|
|
+
|
|
|
+ end_time = time.time()
|
|
|
+ logger.info(f"mr_match接口处理向量耗时: {(end_time - start_time) * 1000:.2f}ms")
|
|
|
+ # 处理origin文本
|
|
|
+ for origin_sent, is_valid in valid_origin_sentences:
|
|
|
+ if not is_valid:
|
|
|
origin_results.append({"sentence": origin_sent, "matched": False})
|
|
|
continue
|
|
|
-
|
|
|
- origin_vector = Vectorizer.get_embedding(origin_sent)
|
|
|
- min_distance = float('inf')
|
|
|
- matched_sent = ""
|
|
|
|
|
|
- # 与similar文本的每个句子计算相似度
|
|
|
- for similar_sent in similar_sentences:
|
|
|
+ origin_vector = origin_vectors[origin_sent]
|
|
|
+ matched = False
|
|
|
+
|
|
|
+ # 优化的相似度计算
|
|
|
+ for i, similar_result in enumerate(similar_results):
|
|
|
+ if similar_result["matched"]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ similar_sent = similar_result["sentence"]
|
|
|
if len(similar_sent) < 10:
|
|
|
continue
|
|
|
|
|
|
- similar_vector = Vectorizer.get_embedding(similar_sent)
|
|
|
- distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
|
|
|
-
|
|
|
- if distance < min_distance and distance < DISTANCE_THRESHOLD:
|
|
|
- min_distance = distance
|
|
|
- matched_sent = similar_sent
|
|
|
-
|
|
|
- origin_results.append({
|
|
|
- "sentence": origin_sent,
|
|
|
- "matched": bool(matched_sent)
|
|
|
- })
|
|
|
-
|
|
|
- # 获取similar文本的向量
|
|
|
- for similar_sent in similar_sentences:
|
|
|
- if len(similar_sent) < 10:
|
|
|
- similar_results.append({"sentence": similar_sent, "matched": False})
|
|
|
- continue
|
|
|
-
|
|
|
- similar_vector = Vectorizer.get_embedding(similar_sent)
|
|
|
- min_distance = float('inf')
|
|
|
- matched_sent = ""
|
|
|
-
|
|
|
- # 与origin文本的每个句子计算相似度
|
|
|
- for origin_sent in origin_sentences:
|
|
|
- if len(origin_sent) < 10:
|
|
|
+ similar_vector = similar_vectors.get(similar_sent)
|
|
|
+ if not similar_vector:
|
|
|
continue
|
|
|
|
|
|
- origin_vector = Vectorizer.get_embedding(origin_sent)
|
|
|
- distance = VectorDistance.calculate_distance(similar_vector, origin_vector)
|
|
|
-
|
|
|
- if distance < min_distance and distance < DISTANCE_THRESHOLD:
|
|
|
- min_distance = distance
|
|
|
- matched_sent = origin_sent
|
|
|
+ distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
|
|
|
+ if distance < DISTANCE_THRESHOLD:
|
|
|
+ matched = True
|
|
|
+ similar_results[i]["matched"] = True
|
|
|
+ break
|
|
|
|
|
|
- similar_results.append({
|
|
|
- "sentence": similar_sent,
|
|
|
- "matched": bool(matched_sent)
|
|
|
- })
|
|
|
+ origin_results.append({"sentence": origin_sent, "matched": matched})
|
|
|
|
|
|
response_data = {
|
|
|
"origin": origin_results,
|
|
|
"similar": similar_results
|
|
|
}
|
|
|
|
|
|
+ end_time = time.time()
|
|
|
+ logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
|
|
|
return StandardResponse(success=True, data=response_data)
|
|
|
except Exception as e:
|
|
|
+ end_time = time.time()
|
|
|
logger.error(f"Text comparison failed: {str(e)}")
|
|
|
+ logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
|
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
text_search_router = router
|