"""Text-search endpoints: citation lookup, sentence matching and knowledge-graph evidence search."""
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field, validator
from typing import List, Optional
from service.trunks_service import TrunksService
from utils.sentence_util import SentenceUtil
from utils.vector_distance import VectorDistance
from model.response import StandardResponse
from utils.vectorizer import Vectorizer
# from utils.find_text_in_pdf import find_text_in_pdf
import os
import logging
import time
from db.session import get_db
from sqlalchemy.orm import Session
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService
from cachetools import TTLCache

# A candidate counts as a match only when its vector distance is strictly below this value.
DISTANCE_THRESHOLD = 0.73

# Sentences shorter than this are never searched on their own.
MIN_SENTENCE_LENGTH = 10

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/text", tags=["Text Search"])

# Module-level result cache used by /eb_search (at most 1000 entries, ttl=3600 s).
cache = TTLCache(maxsize=1000, ttl=3600)


class TextSearchRequest(BaseModel):
    """Request body of POST /text/search."""

    text: str
    conversation_id: Optional[str] = None
    # NOTE(review): not read by search_text; the conversion it would gate always runs.
    need_convert: Optional[bool] = False


class TextCompareRequest(BaseModel):
    """Request body of POST /text/match."""

    sentence: str
    text: str


class TextMatchRequest(BaseModel):
    """Request body of POST /text/mr_search."""

    text: str = Field(..., min_length=1, max_length=10000, description="需要搜索的文本内容")

    @validator('text')
    def validate_text(cls, v):
        """Drop non-printable characters, then JSON-escape what remains.

        NOTE(review): the escaped string (e.g. a newline becomes a literal
        backslash followed by 'n', and '/' becomes backslash-slash) is what is
        embedded for the vector search. Confirm this escaping is intended,
        because FastAPI already JSON-encodes responses.
        """
        # Keep every printable character (CJK included) plus CR/LF.
        v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
        # Escape backslashes first so the escapes added below are not doubled.
        v = v.replace('\\', '\\\\')
        v = v.replace('"', '\\"')
        v = v.replace('/', '\\/')
        v = v.replace('\n', '\\n')
        v = v.replace('\r', '\\r')
        # '\t', '\b' and '\f' are not printable, so the filter above already removed
        # them; escaping them here would be dead code.
        return v


class TextCompareMultiRequest(BaseModel):
    """Request body of POST /text/mr_match."""

    origin: str
    similar: str


class NodePropsSearchRequest(BaseModel):
    """Request body of POST /text/eb_search."""

    node_id: int
    props_ids: List[int]


@router.post("/clear_cache", response_model=StandardResponse)
async def clear_cache():
    """Remove every entry from the module-level result cache."""
    try:
        cache.clear()
        return StandardResponse(success=True, data={"message": "缓存已清除"})
    except Exception as e:
        logger.exception(f"清除缓存失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


def _books_file_name(referrence: str) -> str:
    """Return the part of ``referrence`` after the last ``/books/``, or "" if absent."""
    if referrence and "/books/" in referrence:
        return referrence.split("/books/")[-1]
    return ""


def _match_cached_results(sentence: str, cached_results: list) -> list:
    """Return ``[closest cached trunk + 'distance']`` if it is under the threshold, else ``[]``.

    Each cached entry is expected to carry an ``'embedding'`` key.
    """
    sentence_vector = Vectorizer.get_embedding(sentence)
    min_distance = float('inf')
    best_result = None
    for cached_result in cached_results:
        distance = VectorDistance.calculate_distance(sentence_vector, cached_result['embedding'])
        if distance < min_distance:
            min_distance = distance
            best_result = {**cached_result, 'distance': distance}
    if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
        return [best_result]
    return []


@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    """Append a ``^[n]^`` citation marker to every sentence that matches a stored trunk.

    Candidates come from the conversation's cached results when
    ``conversation_id`` yields any. Otherwise a vector search is run per sentence.
    Sentences that are short or have no match are returned unchanged.

    Returns:
        StandardResponse whose ``data`` contains ``answer`` (list of sentences)
        and ``references`` (one entry per distinct matched trunk).

    Raises:
        HTTPException: status 500 on any failure.
    """
    try:
        text = request.text
        # Input that looks like a JSON object is flattened to prose before splitting.
        if text.startswith('{') and text.endswith('}'):
            from utils.json_to_text import JsonToTextConverter
            converter = JsonToTextConverter()
            text = converter.convert(text)

        sentences = SentenceUtil.split_text(text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        for sentence in sentences:
            # NOTE(review): runs unconditionally; the `need_convert` gate was commented out.
            sentence = sentence.replace("\n", "<br>")
            if len(sentence) < MIN_SENTENCE_LENGTH:
                result_sentences.append(sentence)
                continue

            if cached_results:
                search_results = _match_cached_results(sentence, cached_results)
            else:
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                    type='trunk'
                )

            if not search_results:
                # Keep unmatched sentences; previously they were silently dropped.
                result_sentences.append(sentence)
                continue

            for search_result in search_results:
                distance = search_result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue

                # Reuse the citation number if this trunk has already been cited.
                existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    current_index = reference_index
                    referrence = search_result.get("referrence", "")
                    # File name without its extension.
                    file_name = os.path.splitext(_books_file_name(referrence))[0]
                    all_references.append({
                        "index": str(reference_index),
                        "id": search_result["id"],
                        "content": search_result["content"],
                        "file_path": search_result.get("file_path", ""),
                        "title": search_result.get("title", ""),
                        "distance": distance,
                        "file_name": file_name,
                        "referrence": referrence
                    })
                    reference_index += 1

                if sentence.endswith('<br>'):
                    # Insert the marker in front of every <br> in the sentence.
                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                else:
                    result_sentence = f'{sentence}^[{current_index}]^'
                result_sentences.append(result_sentence)

        response_data = {
            "answer": result_sentences,
            "references": all_references
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.exception(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    """Flag the sentence of ``request.text`` closest to ``request.sentence``.

    Only sentences of at least MIN_SENTENCE_LENGTH characters whose distance is
    below DISTANCE_THRESHOLD can be flagged.

    Returns:
        StandardResponse whose ``records`` is ``[{"sentence", "matched"}, ...]``.
    """
    try:
        sentences = SentenceUtil.split_text(request.text)
        sentence_vector = Vectorizer.get_embedding(request.sentence)
        min_distance = float('inf')
        # None rather than "": when nothing matches, no sentence (not even "") is flagged.
        best_sentence = None
        for temp in sentences:
            if len(temp) < MIN_SENTENCE_LENGTH:
                continue
            temp_vector = Vectorizer.get_embedding(temp)
            distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = temp

        result_sentences = [
            {"sentence": temp, "matched": temp == best_sentence}
            for temp in sentences
        ]
        return StandardResponse(success=True, records=result_sentences)
    except Exception as e:
        logger.exception(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
    """Return up to 10 'mr' trunks whose distance to ``request.text`` is under the threshold."""
    try:
        trunks_service = TrunksService()
        search_results = trunks_service.search_by_vector(
            text=request.text,
            limit=10,
            type="mr"
        )

        records = []
        for result in search_results:
            distance = result.get("distance", DISTANCE_THRESHOLD)
            if distance >= DISTANCE_THRESHOLD:
                continue
            records.append({
                "content": result["content"],
                "file_path": result.get("file_path", ""),
                "title": result.get("title", ""),
                "distance": distance,
            })

        response_data = {
            "records": records
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.exception(f"Mr search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
    """Greedily pair similar sentences between ``origin`` and ``similar``.

    Each origin sentence of at least MIN_SENTENCE_LENGTH characters is paired with
    the first not-yet-matched similar sentence whose distance is below
    DISTANCE_THRESHOLD. A similar sentence is matched at most once.

    Returns:
        StandardResponse whose ``data`` holds ``origin`` and ``similar`` lists of
        ``{"sentence", "matched"}``.
    """
    start_time = time.time()
    try:
        origin_sentences = SentenceUtil.split_text(request.origin)
        similar_sentences = SentenceUtil.split_text(request.similar)
        end_time = time.time()
        logger.info(f"mr_match接口处理文本耗时: {(end_time - start_time) * 1000:.2f}ms")

        origin_results = []
        # Pair each sentence with whether it is long enough to be compared.
        valid_origin_sentences = [(sent, len(sent) >= MIN_SENTENCE_LENGTH) for sent in origin_sentences]
        valid_similar_sentences = [(sent, len(sent) >= MIN_SENTENCE_LENGTH) for sent in similar_sentences]

        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]

        # Embed every comparable sentence once, keyed by its text.
        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]
        origin_vectors = {sent: Vectorizer.get_embedding(sent) for sent in origin_batch}
        similar_vectors = {sent: Vectorizer.get_embedding(sent) for sent in similar_batch}
        end_time = time.time()
        logger.info(f"mr_match接口处理向量耗时: {(end_time - start_time) * 1000:.2f}ms")

        for origin_sent, is_valid in valid_origin_sentences:
            if not is_valid:
                origin_results.append({"sentence": origin_sent, "matched": False})
                continue

            origin_vector = origin_vectors[origin_sent]
            matched = False
            for similar_result in similar_results:
                if similar_result["matched"]:
                    continue
                similar_sent = similar_result["sentence"]
                if len(similar_sent) < MIN_SENTENCE_LENGTH:
                    continue
                similar_vector = similar_vectors.get(similar_sent)
                # Explicit check: `not vector` raises ValueError for NumPy arrays.
                if similar_vector is None or len(similar_vector) == 0:
                    continue
                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
                if distance < DISTANCE_THRESHOLD:
                    matched = True
                    similar_result["matched"] = True
                    break
            origin_results.append({"sentence": origin_sent, "matched": matched})

        response_data = {
            "origin": origin_results,
            "similar": similar_results
        }
        end_time = time.time()
        logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        end_time = time.time()
        logger.exception(f"Text comparison failed: {str(e)}")
        logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
        raise HTTPException(status_code=500, detail=str(e)) from e


def _cache_key(node_id: int, props_ids: Optional[List[int]] = None) -> str:
    """Build the /eb_search cache key.

    The requested ``props_ids`` are part of the key, because they determine the result.
    """
    if props_ids is None:
        return f"xunzheng_{node_id}"
    return f"xunzheng_{node_id}_{','.join(str(prop_id) for prop_id in props_ids)}"


def _check_cache(node_id: int, props_ids: Optional[List[int]] = None) -> Optional[dict]:
    """Return the cached /eb_search result for this node/props combination, or None."""
    cache_key = _cache_key(node_id, props_ids)
    cached_result = cache.get(cache_key)
    if cached_result:
        logger.info(f"从缓存获取结果,node_id: {node_id}")
        return cached_result
    return None


def _get_node_info(node_service: KGNodeService, node_id: int) -> dict:
    """Fetch the node and return the skeleton result dict.

    Raises:
        ValueError: if the node does not exist.
    """
    node = node_service.get_node(node_id)
    if not node:
        raise ValueError(f"节点不存在: {node_id}")
    return {
        "id": node_id,
        "name": node.get('name', ''),
        "category": node.get('category', ''),
        "props": [],
        "files": [],
        "distance": 0
    }


def _process_search_result(search_result: dict, reference_index: int) -> tuple[dict, str]:
    """Build a reference dict for ``search_result``.

    Returns:
        A tuple ``(reference, file_name)``. ``file_name`` is the ``/books/``
        file name without its extension ("" if absent).
    """
    referrence = search_result.get("referrence", "")
    file_name = os.path.splitext(_books_file_name(referrence))[0]
    reference = {
        "index": str(reference_index),
        "id": search_result["id"],
        "content": search_result["content"],
        "file_path": search_result.get("file_path", ""),
        "title": search_result.get("title", ""),
        "distance": search_result.get("distance", DISTANCE_THRESHOLD),
        "page_no": search_result.get("page_no", ""),
        "file_name": file_name,
        "referrence": referrence
    }
    return reference, file_name


def _get_file_type(file_name: str) -> str:
    """Map a file name's extension to pdf/doc/excel/ppt, or 'other' (case-insensitive)."""
    file_name_lower = file_name.lower()
    if file_name_lower.endswith(".pdf"):
        return "pdf"
    elif file_name_lower.endswith((".doc", ".docx")):
        return "doc"
    elif file_name_lower.endswith((".xls", ".xlsx")):
        return "excel"
    elif file_name_lower.endswith((".ppt", ".pptx")):
        return "ppt"
    return "other"


def _process_sentence_search(node_name: str, prop_title: str, sentences: list, trunks_service: TrunksService) -> tuple[list, list]:
    """Search trunks sentence by sentence for one property.

    A short sentence that has a successor is searched together with that
    successor for extra context. The successor is still processed on its own
    afterwards. A short final sentence is not searched.

    Returns:
        A tuple ``(result_sentences, references)``. ``result_sentences`` holds
        one ``{"sentence", "flag"}`` per input sentence; ``flag`` is the
        reference index, or "" when there is no match.
    """
    result_sentences = []
    all_references = []
    reference_index = 1
    i = 0
    while i < len(sentences):
        sentence = sentences[i]
        i += 1
        if len(sentence) < MIN_SENTENCE_LENGTH:
            if i >= len(sentences):
                result_sentences.append({"sentence": sentence, "flag": ""})
                continue
            # Borrow the next sentence as context. The short sentence is emitted once,
            # below; it was previously also emitted here, which duplicated it.
            search_text = f"{node_name}:{prop_title}:{sentence} {sentences[i]}"
        else:
            search_text = f"{node_name}:{prop_title}:{sentence}"

        search_results = trunks_service.search_by_vector(text=search_text, limit=1, type='trunk')
        if not search_results:
            result_sentences.append({"sentence": sentence, "flag": ""})
            continue

        for search_result in search_results:
            if search_result.get("distance", DISTANCE_THRESHOLD) >= DISTANCE_THRESHOLD:
                result_sentences.append({"sentence": sentence, "flag": ""})
                continue
            existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
            current_index = int(existing_ref["index"]) if existing_ref else reference_index
            if not existing_ref:
                reference, _ = _process_search_result(search_result, reference_index)
                all_references.append(reference)
                reference_index += 1
            result_sentences.append({"sentence": sentence, "flag": str(current_index)})

    return result_sentences, all_references


def _attach_prop_evidence(prop_result: dict, node_name: str, prop_title: str, prop_value: str,
                          trunks_service: TrunksService) -> None:
    """Fill ``prop_result['references']`` and ``prop_result['answer']`` in place.

    The whole property value is searched first. Only if that does not match
    is the value split into sentences and searched per sentence.
    """
    search_text = f"{node_name}:{prop_title}:{prop_value}"
    full_search_results = trunks_service.search_by_vector(
        text=search_text,
        limit=1,
        type='trunk'
    )
    if full_search_results and full_search_results[0].get("distance", DISTANCE_THRESHOLD) < DISTANCE_THRESHOLD:
        reference, _ = _process_search_result(full_search_results[0], 1)
        prop_result["references"] = [reference]
        prop_result["answer"] = [{
            "sentence": prop_value,
            "flag": "1"
        }]
        return

    sentences = SentenceUtil.split_text(prop_value)
    result_sentences, references = _process_sentence_search(
        node_name, prop_title, sentences, trunks_service
    )
    if references:
        prop_result["references"] = references
    if result_sentences:
        prop_result["answer"] = result_sentences


def _assign_file_indices(props: list) -> list:
    """Number the /books/ files cited by ``props`` and prefix reference indices with them.

    Files are numbered from 1 in order of first appearance. Each reference with
    a /books/ file has its ``index`` rewritten from "<ref>" to "<file>-<ref>",
    and answer flags are rewritten to the matching new index.

    Returns:
        The list of ``{"file_name", "file_type", "index"}``, ordered by file index.
    """
    file_index_map = {}
    for prop_result in props:
        for ref in prop_result.get("references", []):
            file_name = _books_file_name(ref.get("referrence", ""))
            if file_name and file_name not in file_index_map:
                file_index_map[file_name] = len(file_index_map) + 1

    for prop_result in props:
        if "references" not in prop_result:
            continue
        for ref in prop_result["references"]:
            file_name = _books_file_name(ref.get("referrence", ""))
            if file_name in file_index_map:
                ref["index"] = f"{file_index_map[file_name]}-{ref['index']}"
        for sentence in prop_result.get("answer", []):
            if not sentence["flag"]:
                continue
            for ref in prop_result["references"]:
                if ref["index"].endswith(f"-{sentence['flag']}"):
                    sentence["flag"] = ref["index"]
                    break

    # Dict insertion order equals file-index order, so no extra sort is needed.
    return [
        {"file_name": file_name, "file_type": _get_file_type(file_name), "index": str(index)}
        for file_name, index in file_index_map.items()
    ]


@router.post("/eb_search", response_model=StandardResponse)
async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
    """Find supporting trunks for the requested properties of a knowledge-graph node.

    Results are cached per (node_id, props_ids) combination.

    Raises:
        HTTPException: status 500 on any failure, including a missing node.
    """
    try:
        start_time = time.time()

        cached_result = _check_cache(request.node_id, request.props_ids)
        if cached_result:
            return StandardResponse(success=True, data=cached_result)

        trunks_service = TrunksService()
        node_service = KGNodeService(db)
        prop_service = KGPropService(db)

        result = _get_node_info(node_service, request.node_id)
        node_name = result["name"]

        for prop_id in request.props_ids:
            prop = prop_service.get_props_by_id(prop_id)
            if not prop:
                logger.warning(f"属性不存在: {prop_id}")
                continue

            prop_title = prop.get('prop_title', '')
            prop_value = prop.get('prop_value', '')
            prop_result = {
                "id": prop_id,
                "category": prop.get('category', 0),
                "prop_name": prop.get('prop_name', ''),
                "prop_value": prop_value,
                "prop_title": prop_title,
                "type": prop.get('type', 1)
            }
            result["props"].append(prop_result)
            _attach_prop_evidence(prop_result, node_name, prop_title, prop_value, trunks_service)

        result["files"] = _assign_file_indices(result["props"])

        end_time = time.time()
        logger.info(f"node_props_search接口耗时: {(end_time - start_time) * 1000:.2f}ms")

        cache[_cache_key(request.node_id, request.props_ids)] = result
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.exception(f"Node props search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


text_search_router = router