from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field, validator
from typing import List, Optional
from service.trunks_service import TrunksService
from utils.text_splitter import TextSplitter
from utils.vector_distance import VectorDistance
from model.response import StandardResponse
from utils.vectorizer import Vectorizer
# from utils.find_text_in_pdf import find_text_in_pdf
import os
import logging
import time
from db.session import get_db
from sqlalchemy.orm import Session
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService

# Similarity cutoff: results at or above this distance are treated as non-matches.
DISTANCE_THRESHOLD = 0.8

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/text", tags=["Text Search"])


class TextSearchRequest(BaseModel):
    text: str
    conversation_id: Optional[str] = None
    need_convert: Optional[bool] = False


class TextCompareRequest(BaseModel):
    sentence: str
    text: str


class TextMatchRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=10000, description="Text content to search")

    @validator('text')
    def validate_text(cls, v):
        # Keep all printable characters plus newlines (Chinese characters are printable).
        v = ''.join(char for char in v if char.isprintable() or char in '\n\r')

        # Escape JSON special characters.
        # Handle backslashes first so later escapes are not double-escaped.
        v = v.replace('\\', '\\\\')
        # Handle quotes and other special characters.
        v = v.replace('"', '\\"')
        v = v.replace('/', '\\/')
        # Handle control characters.
        v = v.replace('\n', '\\n')
        v = v.replace('\r', '\\r')
        v = v.replace('\t', '\\t')
        v = v.replace('\b', '\\b')
        v = v.replace('\f', '\\f')
        # Handle Unicode escapes.
        # v = v.replace('\u', '\\u')
        return v


class TextCompareMultiRequest(BaseModel):
    origin: str
    similar: str


class NodePropsSearchRequest(BaseModel):
    node_id: int
    props_ids: List[int]


@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    try:
        # If request.text looks like JSON, convert it to plain text with JsonToTextConverter.
        if request.text.startswith('{') and request.text.endswith('}'):
            from utils.json_to_text import JsonToTextConverter
            converter = JsonToTextConverter()
            request.text = converter.convert(request.text)

        # Split the text with TextSplitter.
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        # Initialize the service and result lists.
        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Fetch cached results for the conversation_id, if any.
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        for sentence in sentences:
            # if request.need_convert:
            # NOTE: the literal line-break token below was lost during extraction;
            # "<br>" is assumed here and used consistently throughout this handler.
            sentence = sentence.replace("\n", "<br>")

            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue

            if cached_results:
                # With cached results, compute vector distances locally.
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    content_vector = cached_result['embedding']
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}
                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # Without cached results, run a vector search.
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                    type='trunk'
                )

            # Process the search results.
            for search_result in search_results:
                distance = search_result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue

                # Check whether the same reference already exists.
                existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                current_index = reference_index
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    # Add it to the reference list.
                    reference = {
                        "index": str(reference_index),
                        "id": search_result["id"],
                        "content": search_result["content"],
                        "file_path": search_result.get("file_path", ""),
                        "title": search_result.get("title", ""),
                        "distance": distance,
                        "referrence": search_result.get("referrence", "")
                    }
                    all_references.append(reference)
                    reference_index += 1

                # Append the citation marker (using the assumed "<br>" line-break token).
                if sentence.endswith('<br>'):
                    # If the sentence contains <br>, insert ^[current_index]^ before each <br>.
                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                else:
                    # Otherwise append ^[current_index]^ at the end of the sentence.
                    result_sentence = f'{sentence}^[{current_index}]^'
                result_sentences.append(result_sentence)

        # Assemble the response payload.
        response_data = {
            "answer": result_sentences,
            "references": all_references
        }
        return StandardResponse(success=True, data=response_data)

    except Exception as e:
        logger.error(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    try:
        sentences = TextSplitter.split_text(request.text)
        sentence_vector = Vectorizer.get_embedding(request.sentence)

        min_distance = float('inf')
        best_sentence = ""
        result_sentences = []

        for temp in sentences:
            result_sentences.append(temp)
            if len(temp) < 10:
                continue
            temp_vector = Vectorizer.get_embedding(temp)
            distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = temp

        for i in range(len(result_sentences)):
            result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
            if result_sentences[i]["sentence"] == best_sentence:
                result_sentences[i]["matched"] = True

        return StandardResponse(success=True, records=result_sentences)
    except Exception as e:
        logger.error(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
    try:
        # Initialize the service.
        trunks_service = TrunksService()

        # Vectorize the text and search for similar content.
        search_results = trunks_service.search_by_vector(
            text=request.text,
            limit=10,
            type="mr"
        )

        # Process the search results.
        records = []
        for result in search_results:
            distance = result.get("distance", DISTANCE_THRESHOLD)
            if distance >= DISTANCE_THRESHOLD:
                continue

            # Add to the record list.
            record = {
                "content": result["content"],
                "file_path": result.get("file_path", ""),
                "title": result.get("title", ""),
                "distance": distance,
            }
            records.append(record)

        # Assemble the response payload.
        response_data = {
            "records": records
        }

        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Mr search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
    start_time = time.time()
    try:
        # Split both texts into sentences.
        origin_sentences = TextSplitter.split_text(request.origin)
        similar_sentences = TextSplitter.split_text(request.similar)
        end_time = time.time()
        logger.info(f"mr_match text split time: {(end_time - start_time) * 1000:.2f}ms")

        # Initialize the result list.
        origin_results = []

        # Flag sentences that are long enough; their vectors are precomputed below.
        valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
        valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]

        # Initialize similar_results with matched set to False.
        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]

        # Fetch embeddings in batch.
        origin_vectors = {}
        similar_vectors = {}
        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]

        if origin_batch:
            origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
            origin_vectors = dict(zip(origin_batch, origin_embeddings))

        if similar_batch:
            similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
            similar_vectors = dict(zip(similar_batch, similar_embeddings))

        end_time = time.time()
        logger.info(f"mr_match vectorization time: {(end_time - start_time) * 1000:.2f}ms")
        # Process the origin text.
        for origin_sent, is_valid in valid_origin_sentences:
            if not is_valid:
                origin_results.append({"sentence": origin_sent, "matched": False})
                continue

            origin_vector = origin_vectors[origin_sent]
            matched = False

            # Optimized similarity computation against not-yet-matched sentences.
            for i, similar_result in enumerate(similar_results):
                if similar_result["matched"]:
                    continue

                similar_sent = similar_result["sentence"]
                if len(similar_sent) < 10:
                    continue

                similar_vector = similar_vectors.get(similar_sent)
                if not similar_vector:
                    continue

                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
                if distance < DISTANCE_THRESHOLD:
                    matched = True
                    similar_results[i]["matched"] = True
                    break

            origin_results.append({"sentence": origin_sent, "matched": matched})

        response_data = {
            "origin": origin_results,
            "similar": similar_results
        }

        end_time = time.time()
        logger.info(f"mr_match total time: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        end_time = time.time()
        logger.error(f"Text comparison failed: {str(e)}")
        logger.info(f"mr_match total time: {(end_time - start_time) * 1000:.2f}ms")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/eb_search", response_model=StandardResponse)
async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
    try:
        start_time = time.time()

        # Initialize the services.
        trunks_service = TrunksService()
        node_service = KGNodeService(db)
        prop_service = KGPropService(db)

        # Look up the node by node_id.
        node = node_service.get_node(request.node_id)
        if not node:
            raise ValueError(f"Node not found: {request.node_id}")

        node_name = node.get('name', '')

        # Initialize the result payload.
        result = {
            "id": request.node_id,
            "name": node_name,
            "category": node.get('category', ''),
            "props": [],
            "files": [],
            "distance": 0
        }

        # Look up each property in props_ids.
        for prop_id in request.props_ids:
            prop = prop_service.get_props_by_id(prop_id)
            if not prop:
                logger.warning(f"Property not found: {prop_id}")
                continue

            prop_title = prop.get('prop_title', '')
            prop_value = prop.get('prop_value', '')

            # Split the property value into sentences.
            sentences = TextSplitter.split_text(prop_value)

            prop_result = {
                "id": prop_id,
                "category": prop.get('category', 0),
                "prop_name": prop.get('prop_name', ''),
                "prop_value": prop_value,
                "prop_title": prop_title,
                "type": prop.get('type', 1)
            }
            # Add the property to the result.
            result["props"].append(prop_result)

            # Process the sentences in the property value.
            result_sentences = []
            all_references = []
            reference_index = 1

            # Run a vector search for each sentence.
            i = 0
            while i < len(sentences):
                original_sentence = sentences[i]
                sentence = original_sentence

                # If the current sentence is shorter than 10 characters and not the last
                # one, merge it with the next sentence for searching.
                if len(sentence) < 10 and i + 1 < len(sentences):
                    next_sentence = sentences[i + 1]
                    combined_sentence = sentence + " " + next_sentence
                    # Add the original short sentence to the results with an empty flag.
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    # Search with the merged sentence.
                    search_text = f"{node_name}:{prop_title}:{combined_sentence}"
                    i += 1  # Skip the next sentence; it has already been merged.
                elif len(sentence) < 10:
                    # Last sentence and shorter than 10 characters: add it with an empty flag.
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    i += 1
                    continue
                else:
                    # The sentence is long enough; use it directly.
                    search_text = f"{node_name}:{prop_title}:{sentence}"
                    i += 1

                # Run the vector search.
                search_results = trunks_service.search_by_vector(
                    text=search_text,
                    limit=1,
                    type='trunk'
                )

                # Process the search results.
                if not search_results:
                    # No search results: keep the original sentence with an empty flag.
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    continue

                for search_result in search_results:
                    distance = search_result.get("distance", DISTANCE_THRESHOLD)
                    if distance >= DISTANCE_THRESHOLD:
                        # Distance too large: keep the original sentence with an empty flag.
                        result_sentences.append({
                            "sentence": sentence,
                            "flag": ""
                        })
                        continue

                    # Check whether the same reference already exists.
                    existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                    current_index = reference_index
                    if existing_ref:
                        current_index = int(existing_ref["index"])
                    else:
                        # Add to the reference list.
                        reference = {
                            "index": str(reference_index),
                            "id": search_result["id"],
                            "content": search_result["content"],
                            "file_path": search_result.get("file_path", ""),
                            "title": search_result.get("title", ""),
                            "distance": distance,
                            "page_no": search_result.get("page_no", ""),
                            "referrence": search_result.get("referrence", "")
                        }
                        all_references.append(reference)
                        reference_index += 1

                    # Add the sentence with its citation index (as a separate flag field).
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": str(current_index)
                    })

            # Attach the reference information to the property.
            if all_references:
                prop_result["references"] = all_references

            # Attach the processed sentences to the property.
            if result_sentences:
                prop_result["answer"] = result_sentences

        # Collect file information from all references.
        all_files = set()
        for prop_result in result["props"]:
            if "references" in prop_result:
                for ref in prop_result["references"]:
                    referrence = ref.get("referrence", "")
                    if referrence and "/books/" in referrence:
                        # Extract the file name after /books/.
                        file_name = referrence.split("/books/")[-1]
                        if file_name:
                            # Determine the file type from the file extension.
                            file_type = ""
                            if file_name.lower().endswith(".pdf"):
                                file_type = "pdf"
                            elif file_name.lower().endswith(".doc") or file_name.lower().endswith(".docx"):
                                file_type = "doc"
                            elif file_name.lower().endswith(".xls") or file_name.lower().endswith(".xlsx"):
                                file_type = "excel"
                            elif file_name.lower().endswith(".ppt") or file_name.lower().endswith(".pptx"):
                                file_type = "ppt"
                            else:
                                file_type = "other"
                            all_files.add((file_name, file_type))

        # Add the file information to the result.
        result["files"] = [{
            "file_name": file_name,
            "file_type": file_type
        } for file_name, file_type in all_files]

        end_time = time.time()
        logger.info(f"node_props_search total time: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=result)

    except Exception as e:
        logger.error(f"Node props search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


text_search_router = router