from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field, validator
from typing import List, Optional
from service.trunks_service import TrunksService
from utils.sentence_util import SentenceUtil
from utils.vector_distance import VectorDistance
from model.response import StandardResponse
from utils.vectorizer import Vectorizer
# from utils.find_text_in_pdf import find_text_in_pdf
import os
import logging
import time

from db.session import get_db
from sqlalchemy.orm import Session
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService
from cachetools import TTLCache

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/text", tags=["Text Search"])

# Vector distance threshold: results at or above this value are treated as non-matches
DISTANCE_THRESHOLD = 0.73

# Global cache instance (up to 1000 entries, 1-hour TTL)
cache = TTLCache(maxsize=1000, ttl=3600)
class TextSearchRequest(BaseModel):
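    """Request body for POST /text/search."""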
text: str
conversation_id: Optional[str] = None
need_convert: Optional[bool] = False
class TextCompareRequest(BaseModel):
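    """Request body for POST /text/match: locate the sentence within the text."""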
sentence: str
text: str
class TextMatchRequest(BaseModel):
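    """Request body for POST /text/mr_search; the validator sanitizes and JSON-escapes the input text."""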
text: str = Field(..., min_length=1, max_length=10000, description="需要搜索的文本内容")
    @validator('text')
    def validate_text(cls, v):
        # Keep all printable characters, newlines, and Chinese characters
        v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
        # Escape JSON special characters.
        # Handle backslashes first so later escapes are not escaped twice.
        v = v.replace('\\', '\\\\')
        # Escape quotes and other special characters
        v = v.replace('"', '\\"')
        v = v.replace('/', '\\/')
        # Escape control characters
        v = v.replace('\n', '\\n')
        v = v.replace('\r', '\\r')
        v = v.replace('\t', '\\t')
        v = v.replace('\b', '\\b')
        v = v.replace('\f', '\\f')
        # Unicode escaping is left disabled
        # v = v.replace('\u', '\\u')
        return v
class TextCompareMultiRequest(BaseModel):
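    """Request body for POST /text/mr_match: the original and the similar text to compare."""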
origin: str
similar: str
class NodePropsSearchRequest(BaseModel):
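    """Request body for POST /text/eb_search: a node id and the property ids to gather evidence for."""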
node_id: int
props_ids: List[int]
@router.post("/clear_cache", response_model=StandardResponse)
async def clear_cache():
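    """Clear the global TTL cache."""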
try:
        # Clear the global cache
cache.clear()
return StandardResponse(success=True, data={"message": "缓存已清除"})
except Exception as e:
logger.error(f"清除缓存失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
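    """Split the request text into sentences, match each sentence against cached results or the trunk vector index, and return the sentences annotated with ^[n]^ reference markers plus the reference list."""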
try:
        # If request.text looks like JSON, convert it to plain text with JsonToTextConverter
if request.text.startswith('{') and request.text.endswith('}'):
from utils.json_to_text import JsonToTextConverter
converter = JsonToTextConverter()
request.text = converter.convert(request.text)
        # Split the text into sentences
sentences = SentenceUtil.split_text(request.text)
if not sentences:
return StandardResponse(success=True, data={"answer": "", "references": []})
        # Initialize services and result lists
trunks_service = TrunksService()
result_sentences = []
all_references = []
reference_index = 1
        # Fetch cached results for this conversation_id, if any
cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
for sentence in sentences:
# if request.need_convert:
            # Replace newlines with <br/> markers; the same markers are used below when inserting reference indexes
            sentence = sentence.replace("\n", "<br/>")
if len(sentence) < 10:
result_sentences.append(sentence)
continue
if cached_results:
                # With cached results, find the closest one by vector distance
min_distance = float('inf')
best_result = None
sentence_vector = Vectorizer.get_embedding(sentence)
for cached_result in cached_results:
content_vector = cached_result['embedding']
distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
if distance < min_distance:
min_distance = distance
best_result = {**cached_result, 'distance': distance}
if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
search_results = [best_result]
else:
search_results = []
else:
                # No cached results: fall back to a vector search
search_results = trunks_service.search_by_vector(
text=sentence,
limit=1,
type='trunk'
)
            # Keep the sentence unchanged when the search returns nothing
            if not search_results:
                result_sentences.append(sentence)
                continue
            # Process the search results
            for search_result in search_results:
                distance = search_result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue
                # Check whether this reference already exists
existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
current_index = reference_index
if existing_ref:
current_index = int(existing_ref["index"])
else:
                    # Add it to the reference list
                    # Extract the file name from the referrence path
file_name = ""
referrence = search_result.get("referrence", "")
if referrence and "/books/" in referrence:
file_name = referrence.split("/books/")[-1]
                        # Strip the file extension
file_name = os.path.splitext(file_name)[0]
reference = {
"index": str(reference_index),
"id": search_result["id"],
"content": search_result["content"],
"file_path": search_result.get("file_path", ""),
"title": search_result.get("title", ""),
"distance": distance,
"file_name": file_name,
"referrence": referrence
}
all_references.append(reference)
reference_index += 1
                # Insert the reference marker
                if sentence.endswith('<br/>'):
                    # If the sentence contains <br/> markers, insert ^[current_index]^ before each of them
                    result_sentence = sentence.replace('<br/>', f'^[{current_index}]^<br/>')
                else:
                    # Otherwise append ^[current_index]^ to the end of the sentence
                    result_sentence = f'{sentence}^[{current_index}]^'
result_sentences.append(result_sentence)
        # Assemble the response payload
response_data = {
"answer": result_sentences,
"references": all_references
}
return StandardResponse(success=True, data=response_data)
except Exception as e:
logger.error(f"Text search failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
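    """Split request.text into sentences and mark the sentence closest to request.sentence (within the distance threshold) as matched."""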
try:
sentences = SentenceUtil.split_text(request.text)
sentence_vector = Vectorizer.get_embedding(request.sentence)
min_distance = float('inf')
best_sentence = ""
result_sentences = []
for temp in sentences:
result_sentences.append(temp)
if len(temp) < 10:
continue
temp_vector = Vectorizer.get_embedding(temp)
distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
if distance < min_distance and distance < DISTANCE_THRESHOLD:
min_distance = distance
best_sentence = temp
for i in range(len(result_sentences)):
result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
if result_sentences[i]["sentence"] == best_sentence:
result_sentences[i]["matched"] = True
return StandardResponse(success=True, records=result_sentences)
except Exception as e:
logger.error(f"Text comparison failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
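    """Vector-search the "mr" collection for content similar to request.text and return the records within the distance threshold."""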
try:
        # Initialize the service
trunks_service = TrunksService()
        # Vectorize the text and search for similar content
search_results = trunks_service.search_by_vector(
text=request.text,
limit=10,
type="mr"
)
        # Process the search results
records = []
for result in search_results:
distance = result.get("distance", DISTANCE_THRESHOLD)
if distance >= DISTANCE_THRESHOLD:
continue
            # Add it to the record list
record = {
"content": result["content"],
"file_path": result.get("file_path", ""),
"title": result.get("title", ""),
"distance": distance,
}
records.append(record)
        # Assemble the response payload
response_data = {
"records": records
}
return StandardResponse(success=True, data=response_data)
except Exception as e:
logger.error(f"Mr search failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
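    """Compare two texts sentence by sentence: an origin sentence is marked as matched when some not-yet-matched similar sentence lies within the vector distance threshold."""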
start_time = time.time()
try:
        # Split both texts into sentences
origin_sentences = SentenceUtil.split_text(request.origin)
similar_sentences = SentenceUtil.split_text(request.similar)
end_time = time.time()
logger.info(f"mr_match接口处理文本耗时: {(end_time - start_time) * 1000:.2f}ms")
        # Initialize the result list
origin_results = []
        # Mark short sentences so vectors are only computed for the rest
valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]
        # Initialize similar_results with matched=False for every sentence
similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]
        # Batch-compute embeddings
origin_vectors = {}
similar_vectors = {}
origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]
if origin_batch:
origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
origin_vectors = dict(zip(origin_batch, origin_embeddings))
if similar_batch:
similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
similar_vectors = dict(zip(similar_batch, similar_embeddings))
end_time = time.time()
logger.info(f"mr_match接口处理向量耗时: {(end_time - start_time) * 1000:.2f}ms")
        # Walk through the origin sentences
for origin_sent, is_valid in valid_origin_sentences:
if not is_valid:
origin_results.append({"sentence": origin_sent, "matched": False})
continue
origin_vector = origin_vectors[origin_sent]
matched = False
            # Similarity check against the not-yet-matched similar sentences
for i, similar_result in enumerate(similar_results):
if similar_result["matched"]:
continue
similar_sent = similar_result["sentence"]
if len(similar_sent) < 10:
continue
similar_vector = similar_vectors.get(similar_sent)
                if similar_vector is None:
continue
distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
if distance < DISTANCE_THRESHOLD:
matched = True
similar_results[i]["matched"] = True
break
origin_results.append({"sentence": origin_sent, "matched": matched})
response_data = {
"origin": origin_results,
"similar": similar_results
}
end_time = time.time()
logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
return StandardResponse(success=True, data=response_data)
except Exception as e:
end_time = time.time()
logger.error(f"Text comparison failed: {str(e)}")
logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
raise HTTPException(status_code=500, detail=str(e))
def _check_cache(node_id: int) -> Optional[dict]:
"""检查并返回缓存结果"""
cache_key = f"xunzheng_{node_id}"
cached_result = cache.get(cache_key)
if cached_result:
logger.info(f"从缓存获取结果,node_id: {node_id}")
return cached_result
return None
def _get_node_info(node_service: KGNodeService, node_id: int) -> dict:
"""获取并验证节点信息"""
node = node_service.get_node(node_id)
if not node:
raise ValueError(f"节点不存在: {node_id}")
return {
"id": node_id,
"name": node.get('name', ''),
"category": node.get('category', ''),
"props": [],
"files": [],
"distance": 0
}
def _process_search_result(search_result: dict, reference_index: int) -> tuple[dict, str]:
"""处理搜索结果,返回引用信息和文件名"""
file_name = ""
referrence = search_result.get("referrence", "")
if referrence and "/books/" in referrence:
file_name = referrence.split("/books/")[-1]
file_name = os.path.splitext(file_name)[0]
reference = {
"index": str(reference_index),
"id": search_result["id"],
"content": search_result["content"],
"file_path": search_result.get("file_path", ""),
"title": search_result.get("title", ""),
"distance": search_result.get("distance", DISTANCE_THRESHOLD),
"page_no": search_result.get("page_no", ""),
"file_name": file_name,
"referrence": referrence
}
return reference, file_name
def _get_file_type(file_name: str) -> str:
"""根据文件名确定文件类型"""
file_name_lower = file_name.lower()
if file_name_lower.endswith(".pdf"):
return "pdf"
elif file_name_lower.endswith((".doc", ".docx")):
return "doc"
elif file_name_lower.endswith((".xls", ".xlsx")):
return "excel"
elif file_name_lower.endswith((".ppt", ".pptx")):
return "ppt"
return "other"
def _process_sentence_search(node_name: str, prop_title: str, sentences: list, trunks_service: TrunksService) -> tuple[list, list]:
"""处理句子搜索,返回结果句子和引用列表"""
result_sentences = []
all_references = []
reference_index = 1
i = 0
    while i < len(sentences):
        sentence = sentences[i]
        if len(sentence) < 10 and i + 1 < len(sentences):
            # Merge a short sentence with the next one for the search and attribute any match to the next sentence
            next_sentence = sentences[i + 1]
            result_sentences.append({"sentence": sentence, "flag": ""})
            search_text = f"{node_name}:{prop_title}:{sentence} {next_sentence}"
            sentence = next_sentence
            i += 2
        elif len(sentence) < 10:
            result_sentences.append({"sentence": sentence, "flag": ""})
            i += 1
            continue
        else:
            search_text = f"{node_name}:{prop_title}:{sentence}"
            i += 1
search_results = trunks_service.search_by_vector(text=search_text, limit=1, type='trunk')
if not search_results:
result_sentences.append({"sentence": sentence, "flag": ""})
continue
for search_result in search_results:
if search_result.get("distance", DISTANCE_THRESHOLD) >= DISTANCE_THRESHOLD:
result_sentences.append({"sentence": sentence, "flag": ""})
continue
existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
current_index = int(existing_ref["index"]) if existing_ref else reference_index
if not existing_ref:
reference, _ = _process_search_result(search_result, reference_index)
all_references.append(reference)
reference_index += 1
result_sentences.append({"sentence": sentence, "flag": str(current_index)})
return result_sentences, all_references
@router.post("/eb_search", response_model=StandardResponse)
async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
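    """Evidence search for a knowledge-graph node: for each requested property, look up supporting trunks, annotate sentences with reference flags, group references by source file, and cache the result."""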
try:
start_time = time.time()
        # Check the cache first
cached_result = _check_cache(request.node_id)
if cached_result:
return StandardResponse(success=True, data=cached_result)
        # Initialize services
trunks_service = TrunksService()
node_service = KGNodeService(db)
prop_service = KGPropService(db)
        # Fetch the node info
result = _get_node_info(node_service, request.node_id)
node_name = result["name"]
        # Look up each requested property
for prop_id in request.props_ids:
prop = prop_service.get_props_by_id(prop_id)
if not prop:
logger.warning(f"属性不存在: {prop_id}")
continue
prop_title = prop.get('prop_title', '')
prop_value = prop.get('prop_value', '')
            # Build the property result object
prop_result = {
"id": prop_id,
"category": prop.get('category', 0),
"prop_name": prop.get('prop_name', ''),
"prop_value": prop_value,
"prop_title": prop_title,
"type": prop.get('type', 1)
}
result["props"].append(prop_result)
            # First search with the full prop_value
search_text = f"{node_name}:{prop_title}:{prop_value}"
full_search_results = trunks_service.search_by_vector(
text=search_text,
limit=1,
type='trunk'
)
            # Process the search result
if full_search_results and full_search_results[0].get("distance", DISTANCE_THRESHOLD) < DISTANCE_THRESHOLD:
search_result = full_search_results[0]
reference, _ = _process_search_result(search_result, 1)
prop_result["references"] = [reference]
prop_result["answer"] = [{
"sentence": prop_value,
"flag": "1"
}]
else:
                # If the full-value search found no match, fall back to sentence-level search
sentences = SentenceUtil.split_text(prop_value)
result_sentences, references = _process_sentence_search(
node_name, prop_title, sentences, trunks_service
)
if references:
prop_result["references"] = references
if result_sentences:
prop_result["answer"] = result_sentences
        # Process file information
all_files = set()
file_index_map = {}
file_index = 1
        # Collect the files referenced by the props
for prop_result in result["props"]:
if "references" not in prop_result:
continue
for ref in prop_result["references"]:
referrence = ref.get("referrence", "")
if not (referrence and "/books/" in referrence):
continue
file_name = referrence.split("/books/")[-1]
if not file_name:
continue
file_type = _get_file_type(file_name)
if file_name not in file_index_map:
file_index_map[file_name] = file_index
file_index += 1
all_files.add((file_name, file_type))
        # Rewrite reference indexes as "<file index>-<reference index>"
for prop_result in result["props"]:
if "references" not in prop_result:
continue
for ref in prop_result["references"]:
referrence = ref.get("referrence", "")
if referrence and "/books/" in referrence:
file_name = referrence.split("/books/")[-1]
if file_name in file_index_map:
ref["index"] = f"{file_index_map[file_name]}-{ref['index']}"
            # Update the flags in the answers to the rewritten reference indexes
if "answer" in prop_result:
for sentence in prop_result["answer"]:
if sentence["flag"]:
for ref in prop_result["references"]:
if ref["index"].endswith(f"-{sentence['flag']}"):
sentence["flag"] = ref["index"]
break
        # Attach the file information to the result
result["files"] = sorted([{
"file_name": file_name,
"file_type": file_type,
"index": str(file_index_map[file_name])
} for file_name, file_type in all_files], key=lambda x: int(x["index"]))
end_time = time.time()
logger.info(f"node_props_search接口耗时: {(end_time - start_time) * 1000:.2f}ms")
        # Cache the result
cache_key = f"xunzheng_{request.node_id}"
cache[cache_key] = result
return StandardResponse(success=True, data=result)
except Exception as e:
logger.error(f"Node props search failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
text_search_router = router