@@ -0,0 +1,877 @@
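+"""Text search endpoints: sentence-level vector search with citation markers,
+text-to-text matching, and knowledge-graph property evidence lookup.
+"""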
+import logging
+import os
+import time
+from typing import List, Optional
+
+from fastapi import APIRouter, HTTPException, Depends
+from pydantic import BaseModel, Field, validator
+from sqlalchemy.orm import Session
+from cachetools import TTLCache
+
+from ..db.session import get_db
+from ..model.response import StandardResponse
+from ..service.trunks_service import TrunksService
+from ..service.kg_node_service import KGNodeService
+from ..service.kg_prop_service import KGPropService
+from ..service.kg_edge_service import KGEdgeService
+from ..utils.sentence_util import SentenceUtil
+from ..utils.vector_distance import VectorDistance
+from utils.vectorizer import Vectorizer
+# TextSimilarityFinder performs lexical (TF-IDF) similarity matching
+from utils.text_similarity import TextSimilarityFinder
+# from utils.find_text_in_pdf import find_text_in_pdf
+
+logger = logging.getLogger(__name__)
+router = APIRouter(tags=["Text Search"])
+
+# Vector-distance cutoff: results at or above this distance are treated as non-matches
+DISTANCE_THRESHOLD = 0.73
+
+# Global cache instance (up to 1000 entries, 1-hour TTL)
+cache = TTLCache(maxsize=1000, ttl=3600)
+
+class TextSearchRequest(BaseModel):
+    text: str
+    conversation_id: Optional[str] = None
+    need_convert: Optional[bool] = False
+
+class TextCompareRequest(BaseModel):
+    sentence: str
+    text: str
+
+class TextMatchRequest(BaseModel):
+    text: str = Field(..., min_length=1, max_length=10000, description="Text content to search")
+
+    @validator('text')
+    def validate_text(cls, v):
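+        """Keep printable characters (plus newlines) and escape JSON special
+        characters: backslashes first, then quotes, slashes, and control chars.
+        """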
+        # Keep all printable characters plus newline/carriage return (covers Chinese text)
+        v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
+
+        # Escape JSON special characters.
+        # Handle backslashes first so later escapes are not double-escaped
+        v = v.replace('\\', '\\\\')
+        # Quotes and other special characters
+        v = v.replace('"', '\\"')
+        v = v.replace('/', '\\/')
+        # Control characters
+        v = v.replace('\n', '\\n')
+        v = v.replace('\r', '\\r')
+        v = v.replace('\t', '\\t')
+        v = v.replace('\b', '\\b')
+        v = v.replace('\f', '\\f')
+        # Unicode escapes (disabled)
+        # v = v.replace('\u', '\\u')
+
+        return v
+
+class TextCompareMultiRequest(BaseModel):
+    origin: str
+    similar: str
+
+class NodePropsSearchRequest(BaseModel):
+    node_id: int
+    props_ids: List[int]
+    symptoms: Optional[List[str]] = None
+
+@router.post("/kgrt_api/text/clear_cache", response_model=StandardResponse)
|
|
|
|
+async def clear_cache():
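+    """Clear the global TTL cache."""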
+    try:
+        # Clear the global cache
+        cache.clear()
+        return StandardResponse(success=True, data={"message": "Cache cleared"})
+    except Exception as e:
+        logger.error(f"Failed to clear cache: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.post("/kgrt_api/text/search", response_model=StandardResponse)
|
|
|
|
+@router.post("/knowledge/text/search", response_model=StandardResponse)
|
|
|
|
+async def search_text(request: TextSearchRequest):
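+    """Split the input text into sentences, match each against the trunk store
+    (per-conversation cache first, then vector search), and return the sentences
+    annotated with ^[n]^ citation markers together with the reference list.
+    """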
+    try:
+        # If request.text looks like JSON, convert it to plain text first
+        if request.text.startswith('{') and request.text.endswith('}'):
+            from utils.json_to_text import JsonToTextConverter
+            converter = JsonToTextConverter()
+            request.text = converter.convert(request.text)
+
+        # Split the text into sentences
+        sentences = SentenceUtil.split_text(request.text)
+        if not sentences:
+            return StandardResponse(success=True, data={"answer": "", "references": []})
+
+        # Initialize the service and result lists
+        trunks_service = TrunksService()
+        result_sentences = []
+        all_references = []
+        reference_index = 1
+
+        # Fetch cached results for this conversation, if any
+        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
+
+        for sentence in sentences:
+            # if request.need_convert:
+            sentence = sentence.replace("\n", "<br>")
+            if len(sentence) < 10:
+                result_sentences.append(sentence)
+                continue
+            if cached_results:
+                # Cached results available: pick the closest one by vector distance
+                min_distance = float('inf')
+                best_result = None
+                sentence_vector = Vectorizer.get_embedding(sentence)
+
+                for cached_result in cached_results:
+                    content_vector = cached_result['embedding']
+                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
+                    if distance < min_distance:
+                        min_distance = distance
+                        best_result = {**cached_result, 'distance': distance}
+
+                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
+                    search_results = [best_result]
+                else:
+                    search_results = []
+            else:
+                # No cached results: run a vector search
+                search_results = trunks_service.search_by_vector(
+                    text=sentence,
+                    limit=1,
+                    type='trunk'
+                )
+
+            # Process the search results
+            for search_result in search_results:
+                distance = search_result.get("distance", DISTANCE_THRESHOLD)
+                if distance >= DISTANCE_THRESHOLD:
+                    result_sentences.append(sentence)
+                    continue
+
+                # Check whether this reference already exists
+                existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
+                current_index = reference_index
+                if existing_ref:
+                    current_index = int(existing_ref["index"])
+                else:
+                    # New reference: extract the file name from the referrence field
+                    file_name = ""
+                    referrence = search_result.get("referrence", "")
+                    if referrence and "/books/" in referrence:
+                        file_name = referrence.split("/books/")[-1]
+                        # Drop the file extension
+                        file_name = os.path.splitext(file_name)[0]
+
+                    reference = {
+                        "index": str(reference_index),
+                        "id": search_result["id"],
+                        "content": search_result["content"],
+                        "file_path": search_result.get("file_path", ""),
+                        "title": search_result.get("title", ""),
+                        "distance": distance,
+                        "file_name": file_name,
+                        "referrence": referrence
+                    }
+
+                    all_references.append(reference)
+                    reference_index += 1
+
+                # Insert the citation marker
+                if sentence.endswith('<br>'):
+                    # Put ^[current_index]^ before every <br>
+                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
+                else:
+                    # Append ^[current_index]^ at the end of the sentence
+                    result_sentence = f'{sentence}^[{current_index}]^'
+
+                result_sentences.append(result_sentence)
+
+        # Assemble the response payload
+        response_data = {
+            "answer": result_sentences,
+            "references": all_references
+        }
+
+        return StandardResponse(success=True, data=response_data)
+
+    except Exception as e:
+        logger.error(f"Text search failed: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.post("/kgrt_api/text/match", response_model=StandardResponse)
|
|
|
|
+@router.post("/knowledge/text/match", response_model=StandardResponse)
|
|
|
|
+async def match_text(request: TextCompareRequest):
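+    """Find the sentence of request.text closest to request.sentence by embedding
+    distance and flag it as matched in the returned sentence list.
+    """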
+    try:
+        sentences = SentenceUtil.split_text(request.text)
+        sentence_vector = Vectorizer.get_embedding(request.sentence)
+        min_distance = float('inf')
+        best_sentence = ""
+        result_sentences = []
+        for temp in sentences:
+            result_sentences.append(temp)
+            if len(temp) < 10:
+                continue
+            temp_vector = Vectorizer.get_embedding(temp)
+            distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
+            if distance < min_distance and distance < DISTANCE_THRESHOLD:
+                min_distance = distance
+                best_sentence = temp
+
+        for i in range(len(result_sentences)):
+            result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
+            if result_sentences[i]["sentence"] == best_sentence:
+                result_sentences[i]["matched"] = True
+
+        # Wrapped in data: every other endpoint in this module returns its payload
+        # under data, and StandardResponse takes no "records" keyword there
+        return StandardResponse(success=True, data={"records": result_sentences})
+    except Exception as e:
+        logger.error(f"Text comparison failed: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.post("/kgrt_api/text/mr_search", response_model=StandardResponse)
|
|
|
|
+@router.post("/knowledge/text/mr_search", response_model=StandardResponse)
|
|
|
|
+async def mr_search_text_content(request: TextMatchRequest):
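+    """Vector search over the "mr" collection; returns matches whose distance is
+    below DISTANCE_THRESHOLD as records.
+    """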
+    try:
+        # Initialize the service
+        trunks_service = TrunksService()
+
+        # Embed the text and search for similar content
+        search_results = trunks_service.search_by_vector(
+            text=request.text,
+            limit=10,
+            type="mr"
+        )
+
+        # Process the search results
+        records = []
+        for result in search_results:
+            distance = result.get("distance", DISTANCE_THRESHOLD)
+            if distance >= DISTANCE_THRESHOLD:
+                continue
+
+            # Build the record for this match
+            record = {
+                "content": result["content"],
+                "file_path": result.get("file_path", ""),
+                "title": result.get("title", ""),
+                "distance": distance,
+            }
+            records.append(record)
+
+        # Assemble the response payload
+        response_data = {
+            "records": records
+        }
+
+        return StandardResponse(success=True, data=response_data)
+
+    except Exception as e:
+        logger.error(f"Mr search failed: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.post("/kgrt_api/text/mr_match", response_model=StandardResponse)
|
|
|
|
+@router.post("/knowledge/text/mr_match", response_model=StandardResponse)
|
|
|
|
+async def compare_text(request: TextCompareMultiRequest):
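+    """Sentence-align two texts: an origin sentence is matched when some
+    still-unmatched sentence of the similar text lies within DISTANCE_THRESHOLD.
+    """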
+    start_time = time.time()
+    try:
+        # Split both texts into sentences
+        origin_sentences = SentenceUtil.split_text(request.origin)
+        similar_sentences = SentenceUtil.split_text(request.similar)
+        end_time = time.time()
+        logger.info(f"mr_match text splitting took: {(end_time - start_time) * 1000:.2f}ms")
+
+        # Initialize the result list
+        origin_results = []
+
+        # Flag short sentences (< 10 chars) so their vectors are not computed
+        valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
+        valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]
+
+        # Initialize similar_results with matched=False everywhere
+        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]
+
+        # Batch-compute the embeddings
+        origin_vectors = {}
+        similar_vectors = {}
+        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
+        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]
+
+        if origin_batch:
+            origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
+            origin_vectors = dict(zip(origin_batch, origin_embeddings))
+
+        if similar_batch:
+            similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
+            similar_vectors = dict(zip(similar_batch, similar_embeddings))
+
+        end_time = time.time()
+        logger.info(f"mr_match embedding took: {(end_time - start_time) * 1000:.2f}ms")
+        # Walk the origin sentences
+        for origin_sent, is_valid in valid_origin_sentences:
+            if not is_valid:
+                origin_results.append({"sentence": origin_sent, "matched": False})
+                continue
+
+            origin_vector = origin_vectors[origin_sent]
+            matched = False
+
+            # Greedy matching: each similar sentence is consumed at most once
+            for i, similar_result in enumerate(similar_results):
+                if similar_result["matched"]:
+                    continue
+
+                similar_sent = similar_result["sentence"]
+                if len(similar_sent) < 10:
+                    continue
+
+                # .get() returns None for sentences whose vectors were skipped
+                similar_vector = similar_vectors.get(similar_sent)
+                if similar_vector is None:
+                    continue
+
+                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
+                if distance < DISTANCE_THRESHOLD:
+                    matched = True
+                    similar_results[i]["matched"] = True
+                    break
+
+            origin_results.append({"sentence": origin_sent, "matched": matched})
+
+        response_data = {
+            "origin": origin_results,
+            "similar": similar_results
+        }
+
+        end_time = time.time()
+        logger.info(f"mr_match took: {(end_time - start_time) * 1000:.2f}ms")
+        return StandardResponse(success=True, data=response_data)
+    except Exception as e:
+        end_time = time.time()
+        logger.error(f"Text comparison failed: {str(e)}")
+        logger.info(f"mr_match took: {(end_time - start_time) * 1000:.2f}ms")
+        raise HTTPException(status_code=500, detail=str(e))
+
+def _check_cache(node_id: int) -> Optional[dict]:
+    """Return the cached result for this node, if present."""
+    cache_key = f"xunzheng_{node_id}"
+    cached_result = cache.get(cache_key)
+    if cached_result:
+        logger.info(f"Cache hit for node_id: {node_id}")
+        return cached_result
+    return None
+
+def _get_node_info(node_service: KGNodeService, node_id: int) -> dict:
+    """Fetch the node and build the skeleton of the result payload."""
+    node = node_service.get_node(node_id)
+    if not node:
+        raise ValueError(f"Node not found: {node_id}")
+    return {
+        "id": node_id,
+        "name": node.get('name', ''),
+        "category": node.get('category', ''),
+        "props": [],
+        "files": [],
+        "distance": 0
+    }
+
+def _process_search_result(search_result: dict, reference_index: int) -> tuple[dict, str]:
+    """Build a reference entry from a search result; returns (reference, file_name)."""
+    file_name = ""
+    referrence = search_result.get("referrence", "")
+    if referrence and "/books/" in referrence:
+        file_name = referrence.split("/books/")[-1]
+        file_name = os.path.splitext(file_name)[0]
+
+    reference = {
+        "index": str(reference_index),
+        "id": search_result["id"],
+        "content": search_result["content"],
+        "file_path": search_result.get("file_path", ""),
+        "title": search_result.get("title", ""),
+        "distance": search_result.get("distance", DISTANCE_THRESHOLD),
+        "page_no": search_result.get("page_no", ""),
+        "file_name": file_name,
+        "referrence": referrence
+    }
+    return reference, file_name
+
+def _get_file_type(file_name: str) -> str:
+    """Map a file name to a coarse file type by extension."""
+    file_name_lower = file_name.lower()
+    if file_name_lower.endswith(".pdf"):
+        return "pdf"
+    elif file_name_lower.endswith((".doc", ".docx")):
+        return "doc"
+    elif file_name_lower.endswith((".xls", ".xlsx")):
+        return "excel"
+    elif file_name_lower.endswith((".ppt", ".pptx")):
+        return "ppt"
+    return "other"
+
+def _process_sentence_search(node_name: str, prop_title: str, sentences: list, trunks_service: TrunksService) -> tuple[list, list]:
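+    """Sentence-level search using the node name and property title as keywords."""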
+    keywords = [node_name, prop_title] if node_name and prop_title else None
+    return _process_sentence_search_keywords(sentences, trunks_service, keywords=keywords)
+
+def _process_sentence_search_keywords(sentences: list, trunks_service: TrunksService, keywords: Optional[List[str]] = None) -> tuple[list, list]:
+    """Search each sentence against the trunk corpus; returns (result_sentences, all_references)."""
+    result_sentences = []
+    all_references = []
+    reference_index = 1
+    i = 0
+
+    while i < len(sentences):
+        sentence = sentences[i]
+        search_text = sentence
+        if keywords:
+            # NOTE: interpolating the list renders its Python repr,
+            # e.g. "['name', 'title']:sentence"
+            search_text = f"{keywords}:{sentence}"
+        # if len(sentence) < 10 and i + 1 < len(sentences):
+        #     next_sentence = sentences[i + 1]
+        #     # result_sentences.append({"sentence": sentence, "flag": ""})
+        #     search_text = f"{node_name}:{prop_title}:{sentence} {next_sentence}"
+        #     i += 1
+        # elif len(sentence) < 10:
+        #     result_sentences.append({"sentence": sentence, "flag": ""})
+        #     i += 1
+        #     continue
+        # else:
+        i += 1
+
+        # Vector search for candidate content
+        search_results = trunks_service.search_by_vector(
+            text=search_text,
+            limit=500,
+            type='trunk',
+            distance=0.7
+        )
+
+        # Build the corpus for lexical matching
+        trunk_texts = []
+        trunk_ids = []
+        # Keep each trunk's details for later lookup
+        trunk_details = {}
+
+        for trunk in search_results:
+            trunk_texts.append(trunk.get('content'))
+            trunk_ids.append(trunk.get('id'))
+            trunk_details[trunk.get('id')] = {
+                'id': trunk.get('id'),
+                'content': trunk.get('content'),
+                'file_path': trunk.get('file_path'),
+                'title': trunk.get('title'),
+                'referrence': trunk.get('referrence'),
+                'page_no': trunk.get('page_no')
+            }
+        if len(trunk_texts) == 0:
+            continue
+        # Initialize TextSimilarityFinder and load the corpus
+        similarity_finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
+        similarity_finder.load_corpus(trunk_texts, trunk_ids)
+
+        # Lexical similarity matching over the candidates
+        similar_results = similarity_finder.find_most_similar(search_text, top_n=1)
+
+        if not similar_results:
+            result_sentences.append({"sentence": sentence, "flag": ""})
+            continue
+
+        # trunk_id of the most similar text
+        trunk_id = similar_results[0]['path']
+
+        # Look up the trunk details collected above
+        trunk_info = trunk_details.get(trunk_id)
+
+        if trunk_info:
+            search_result = {
+                **trunk_info,
+                # NOTE: the similarity score is stored under 'distance' without conversion
+                'distance': similar_results[0]['similarity']
+            }
+            # Threshold check (applied to the raw similarity score, see note above)
+            if search_result['distance'] >= DISTANCE_THRESHOLD:
+                result_sentences.append({"sentence": sentence, "flag": ""})
+                continue
+            # Reuse an existing reference if this trunk was already cited
+            existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
+            current_index = int(existing_ref["index"]) if existing_ref else reference_index
+
+            if not existing_ref:
+                reference, _ = _process_search_result(search_result, reference_index)
+                all_references.append(reference)
+                reference_index += 1
+
+            result_sentences.append({"sentence": sentence, "flag": str(current_index)})
+
+    return result_sentences, all_references
+
+def _mark_symptoms(text: str, symptom_list: List[str]) -> str:
+    """Wrap each symptom occurrence in red <i> markup, avoiding overlapping marks."""
+    if not symptom_list:
+        return text
+
+    marked_sentence = text
+    # Track which character positions have already been marked
+    marked_positions = [False] * len(marked_sentence)
+
+    # Symptoms already handled (used to skip substring/superstring duplicates)
+    processed_symptoms = []
+
+    for symptom in symptom_list:
+        # Skip symptoms that are substrings or superstrings of one already processed
+        if any(symptom in processed_sym or processed_sym in symptom for processed_sym in processed_symptoms):
+            continue
+
+        # Find every occurrence of this symptom
+        start_pos = 0
+        while True:
+            pos = marked_sentence.find(symptom, start_pos)
+            if pos == -1:
+                break
+
+            # Only mark spans that have not been marked yet
+            if not any(marked_positions[pos:pos + len(symptom)]):
+                # Flag the whole span as marked
+                for i in range(pos, pos + len(symptom)):
+                    marked_positions[i] = True
+                # Wrap the occurrence in markup
+                marked_sentence = marked_sentence[:pos] + f'<i style="color:red;">{symptom}</i>' + marked_sentence[pos + len(symptom):]
+                if symptom not in processed_symptoms:
+                    processed_symptoms.append(symptom)
+                # Grow the position array to account for the inserted tags
+                new_positions = [False] * (len('<i style="color:red;">') + len('</i>'))
+                marked_positions = marked_positions[:pos] + new_positions + marked_positions[pos:]
+                start_pos = pos + len('<i style="color:red;">') + len(symptom) + len('</i>')
+            else:
+                # Span already marked: advance past it without inserting tags
+                start_pos = pos + len(symptom)
+
+    return marked_sentence
+
+@router.post("/kgrt_api/text/eb_search", response_model=StandardResponse)
|
|
|
|
+@router.post("/knowledge/text/eb_search", response_model=StandardResponse)
|
|
|
|
+async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
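+    """For each requested property of a knowledge-graph node, look for supporting
+    evidence in the trunk corpus (whole-value match first, then sentence by
+    sentence), attach references and file info, mark symptom mentions in
+    clinical-manifestation answers, and cache the assembled result.
+    """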
+    try:
+        start_time = time.time()
+
+        # Check the cache first
+        cached_result = _check_cache(request.node_id)
+        if cached_result:
+            # If a symptom list was supplied, apply symptom marking
+            if request.symptoms:
+                symptom_list = []
+                try:
+                    # Initialize the services
+                    node_service = KGNodeService(db)
+                    edge_service = KGEdgeService(db)
+
+                    for symptom in request.symptoms:
+                        # Keep the original symptom
+                        symptom_list.append(symptom)
+                        try:
+                            # Look up the symptom node
+                            symptom_node = node_service.get_node_by_name_category(symptom, '症状')
+                            # Collect synonyms (both the 1.0 and 2.0 synonym edges)
+                            for category in ['症状同义词', '症状同义词2.0']:
+                                edges = edge_service.get_edges_by_nodes(src_id=symptom_node['id'], category=category)
+                                if edges:
+                                    for edge in edges:
+                                        if edge['dest_node'] and edge['dest_node'].get('name'):
+                                            symptom_list.append(edge['dest_node']['name'])
+                        except ValueError:
+                            # Node not found: keep only the original symptom
+                            continue
+
+                    # Sort by length, longest first, so longer symptoms are marked first
+                    symptom_list.sort(key=len, reverse=True)
+
+                    # Apply symptom marking to the cached result
+                    for prop in cached_result.get('props', []):
+                        if prop.get('prop_title') == '临床表现' and 'answer' in prop:
+                            for answer in prop['answer']:
+                                answer['sentence'] = _mark_symptoms(answer['sentence'], symptom_list)
+                except Exception as e:
+                    logger.error(f"Symptom marking failed: {str(e)}")
+
+            return StandardResponse(success=True, data=cached_result)
+
+        # Initialize the services
+        trunks_service = TrunksService()
+        node_service = KGNodeService(db)
+        prop_service = KGPropService(db)
+        edge_service = KGEdgeService(db)
+
+        # Fetch the node info
+        result = _get_node_info(node_service, request.node_id)
+        node_name = result["name"]
+
+        # Build the symptom list (original symptoms plus synonyms)
+        symptom_list = []
+        if request.symptoms:
+            for symptom in request.symptoms:
+                try:
+                    # Keep the original symptom
+                    symptom_list.append(symptom)
+                    # Look up the symptom node
+                    symptom_node = node_service.get_node_by_name_category(symptom, '症状')
+                    # Collect synonyms (both the 1.0 and 2.0 synonym edges)
+                    for category in ['症状同义词', '症状同义词2.0']:
+                        edges = edge_service.get_edges_by_nodes(src_id=symptom_node['id'], category=category)
+                        if edges:
+                            for edge in edges:
+                                if edge['dest_node'] and edge['dest_node'].get('name'):
+                                    symptom_list.append(edge['dest_node']['name'])
+                except ValueError:
+                    # Node not found: keep only the original symptom
+                    continue
+
+            # Sort by length, longest first
+            symptom_list.sort(key=len, reverse=True)
+
+        # Look up each requested property
+        for prop_id in request.props_ids:
+            prop = prop_service.get_prop_by_id(prop_id)
+            if not prop:
+                logger.warning(f"Property not found: {prop_id}")
+                continue
+
+            prop_title = prop.get('prop_title', '')
+            prop_value = prop.get('prop_value', '')
+
+            # Build the property result object
+            prop_result = {
+                "id": prop_id,
+                "category": prop.get('category', 0),
+                "prop_name": prop.get('prop_name', ''),
+                "prop_value": prop_value,
+                "prop_title": prop_title,
+                "type": prop.get('type', 1)
+            }
+            result["props"].append(prop_result)
+
+            # Skip the search when prop_value is '无' ("none")
+            if prop_value == '无':
+                prop_result["answer"] = [{
+                    "sentence": prop_value,
+                    "flag": ""
+                }]
+                continue
+
+            # First try the full prop_value as one query
+            search_text = f"{node_name}:{prop_title}:{prop_value}"
+            # Vector search for candidate content
+            search_results = trunks_service.search_by_vector(
+                text=search_text,
+                limit=500,
+                type='trunk',
+                distance=0.7
+            )
+
+            # Build the corpus for lexical matching
+            trunk_texts = []
+            trunk_ids = []
+
+            # Keep each trunk's details for later lookup
+            trunk_details = {}
+
+            for trunk in search_results:
+                trunk_texts.append(trunk.get('content'))
+                trunk_ids.append(trunk.get('id'))
+                trunk_details[trunk.get('id')] = {
+                    'id': trunk.get('id'),
+                    'content': trunk.get('content'),
+                    'file_path': trunk.get('file_path'),
+                    'title': trunk.get('title'),
+                    'referrence': trunk.get('referrence'),
+                    'page_no': trunk.get('page_no')
+                }
+
+            if len(trunk_texts) == 0:
+                continue
+
+            # Initialize TextSimilarityFinder and load the corpus
+            similarity_finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
+            similarity_finder.load_corpus(trunk_texts, trunk_ids)
+
+            similar_results = similarity_finder.find_most_similar(search_text, top_n=1)
+
+            # Accept the whole-value match only at or above the 0.3 similarity threshold
+            matched_whole = False
+            if similar_results and similar_results[0]['similarity'] >= 0.3:
+                # trunk_id of the most similar text
+                trunk_id = similar_results[0]['path']
+
+                # Look up the trunk details collected above
+                trunk_info = trunk_details.get(trunk_id)
+
+                if trunk_info:
+                    search_result = {
+                        **trunk_info,
+                        # NOTE: the similarity score is stored under 'distance' without conversion
+                        'distance': similar_results[0]['similarity']
+                    }
+
+                    reference, _ = _process_search_result(search_result, 1)
+                    prop_result["references"] = [reference]
+                    prop_result["answer"] = [{
+                        "sentence": prop_value,
+                        "flag": "1"
+                    }]
+                    matched_whole = True
+
+            if not matched_whole:
+                # No usable whole-value match: fall back to sentence-by-sentence search
+                sentences = SentenceUtil.split_text(prop_value, 10)
+                result_sentences, references = _process_sentence_search(
+                    node_name, prop_title, sentences, trunks_service
+                )
+                if references:
+                    prop_result["references"] = references
+                if result_sentences:
+                    prop_result["answer"] = result_sentences
+
+        # Collect file info from the references
+        all_files = set()
+        file_index_map = {}
+        file_index = 1
+
+        for prop_result in result["props"]:
+            if "references" not in prop_result:
+                continue
+            for ref in prop_result["references"]:
+                referrence = ref.get("referrence", "")
+                if not (referrence and "/books/" in referrence):
+                    continue
+                file_name = referrence.split("/books/")[-1]
+                if not file_name:
+                    continue
+                file_type = _get_file_type(file_name)
+                if file_name not in file_index_map:
+                    file_index_map[file_name] = file_index
+                    file_index += 1
+                all_files.add((file_name, file_type))
+
+        # Rewrite reference indexes as "<file_index>-<ref_index>"
+        for prop_result in result["props"]:
+            if "references" not in prop_result:
+                continue
+            for ref in prop_result["references"]:
+                referrence = ref.get("referrence", "")
+                if referrence and "/books/" in referrence:
+                    file_name = referrence.split("/books/")[-1]
+                    if file_name in file_index_map:
+                        ref["index"] = f"{file_index_map[file_name]}-{ref['index']}"
+
+            # Update each answer's flag to the rewritten reference index
+            # (answer entries carry the reference index in "flag", not "index")
+            if "answer" in prop_result:
+                for sentence in prop_result["answer"]:
+                    if sentence.get("flag"):
+                        for ref in prop_result["references"]:
+                            if ref["index"].endswith(f"-{sentence['flag']}"):
+                                sentence["flag"] = ref["index"]
+                                break
+
+
|
|
|
|
+ # 添加文件信息到结果
|
|
|
|
+ result["files"] = sorted([{
|
|
|
|
+ "file_name": file_name,
|
|
|
|
+ "file_type": file_type,
|
|
|
|
+ "index": str(file_index_map[file_name])
|
|
|
|
+ } for file_name, file_type in all_files], key=lambda x: int(x["index"]))
|
|
|
|
+
|
|
|
|
+ # 缓存结果
|
|
|
|
+ cache_key = f"xunzheng_{request.node_id}"
|
|
|
|
+ cache[cache_key] = result
|
|
|
|
+
|
|
|
|
+ # 处理症状标记
|
|
|
|
+ if request.symptoms:
|
|
|
|
+ for prop in result.get('props', []):
|
|
|
|
+ if prop.get('prop_title') == '临床表现' and 'answer' in prop:
|
|
|
|
+ for answer in prop['answer']:
|
|
|
|
+ answer['sentence'] = _mark_symptoms(answer['sentence'], symptom_list)
|
|
|
|
+
|
|
|
|
+ end_time = time.time()
|
|
|
|
+ logger.info(f"node_props_search接口耗时: {(end_time - start_time) * 1000:.2f}ms")
|
|
|
|
+
|
|
|
|
+ return StandardResponse(success=True, data=result)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Node props search failed: {str(e)}")
|
|
|
|
+ raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
+
|
|
|
|
+class FindSimilarTexts(BaseModel):
+    keywords: Optional[List[str]] = None
+    search_text: str
+
+@router.post("/knowledge/text/find_similar_texts", response_model=StandardResponse)
|
|
|
|
+async def find_similar_texts(request: FindSimilarTexts, db: Session = Depends(get_db)):
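+    """Find trunks similar to search_text (optionally prefixed with keywords):
+    whole-text lexical match first, then per-sentence fallback.
+    """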
+    trunks_service = TrunksService()
+    search_text = request.search_text
+    if request.keywords:
+        # NOTE: interpolating the list renders its Python repr
+        search_text = f"{request.keywords}:{search_text}"
+    # Vector search for candidate content
+    search_results = trunks_service.search_by_vector(
+        text=search_text,
+        limit=500,
+        type='trunk',
+        distance=0.7
+    )
+
+    # Build the corpus for lexical matching
+    trunk_texts = []
+    trunk_ids = []
+
+    # Keep each trunk's details for later lookup
+    trunk_details = {}
+
+    for trunk in search_results:
+        trunk_texts.append(trunk.get('content'))
+        trunk_ids.append(trunk.get('id'))
+        trunk_details[trunk.get('id')] = {
+            'id': trunk.get('id'),
+            'content': trunk.get('content'),
+            'file_path': trunk.get('file_path'),
+            'title': trunk.get('title'),
+            'referrence': trunk.get('referrence'),
+            'page_no': trunk.get('page_no')
+        }
+
+    if len(trunk_texts) == 0:
+        # Return an empty payload instead of bare None so the response model validates
+        return StandardResponse(success=True, data={})
+
+    # Initialize TextSimilarityFinder and load the corpus
+    similarity_finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
+    similarity_finder.load_corpus(trunk_texts, trunk_ids)
+
+    similar_results = similarity_finder.find_most_similar(search_text, top_n=1)
+    prop_result = {}
+    # Accept the whole-text match only at or above the 0.3 similarity threshold
+    if similar_results and similar_results[0]['similarity'] >= 0.3:
+        # trunk_id of the most similar text
+        trunk_id = similar_results[0]['path']
+
+        # Look up the trunk details collected above
+        trunk_info = trunk_details.get(trunk_id)
+
+        if trunk_info:
+            search_result = {
+                **trunk_info,
+                # NOTE: the similarity score is stored under 'distance' without conversion
+                'distance': similar_results[0]['similarity']
+            }
+
+            reference, _ = _process_search_result(search_result, 1)
+            prop_result["references"] = [reference]
+            prop_result["answer"] = [{
+                "sentence": request.search_text,
+                "flag": "1"
+            }]
+    else:
+        # No whole-text match: fall back to sentence-by-sentence search
+        sentences = SentenceUtil.split_text(request.search_text, 10)
+        result_sentences, references = _process_sentence_search_keywords(
+            sentences, trunks_service, keywords=request.keywords
+        )
+        if references:
+            prop_result["references"] = references
+        if result_sentences:
+            prop_result["answer"] = result_sentences
+    return StandardResponse(success=True, data=prop_result)
+
+
|
|
|
|
+
|
|
|
|
+text_search_router = router
|