text_search.py

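"""Text search router: sentence-level vector search, pairwise text matching,
and knowledge-graph node/property search.

Input text is split into sentences with TextSplitter, embedded with
Vectorizer, and compared with VectorDistance; a distance below
DISTANCE_THRESHOLD counts as a match.
"""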
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field, validator
from typing import List, Optional
from sqlalchemy.orm import Session
import logging
import time

from db.session import get_db
from model.response import StandardResponse
from service.trunks_service import TrunksService
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService
from utils.text_splitter import TextSplitter
from utils.vector_distance import VectorDistance
from utils.vectorizer import Vectorizer

DISTANCE_THRESHOLD = 0.8

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/text", tags=["Text Search"])
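

# Request models. TextMatchRequest additionally sanitizes its input at
# validation time so the text can be embedded safely in a JSON payload.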
class TextSearchRequest(BaseModel):
    text: str
    conversation_id: Optional[str] = None
    need_convert: Optional[bool] = False


class TextCompareRequest(BaseModel):
    sentence: str
    text: str


class TextMatchRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=10000, description="Text content to search")

    @validator('text')
    def validate_text(cls, v):
        # Keep printable characters plus newlines (Chinese text is printable).
        v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
        # Escape JSON special characters.
        # Backslashes first, so later escapes are not themselves re-escaped.
        v = v.replace('\\', '\\\\')
        # Quotes and slashes.
        v = v.replace('"', '\\"')
        v = v.replace('/', '\\/')
        # Control characters.
        v = v.replace('\n', '\\n')
        v = v.replace('\r', '\\r')
        v = v.replace('\t', '\\t')
        v = v.replace('\b', '\\b')
        v = v.replace('\f', '\\f')
        # Unicode escapes are deliberately left untouched:
        # v = v.replace('\u', '\\u')
        return v


class TextCompareMultiRequest(BaseModel):
    origin: str
    similar: str


class NodePropsSearchRequest(BaseModel):
    node_id: int
    props_ids: List[int]
    conversation_id: Optional[str] = None
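

# POST /text/search: split the input into sentences, find the closest trunk
# for each sentence (from the conversation cache or a fresh vector search),
# and annotate matched sentences with ^[n]^ reference markers.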
@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    try:
        # If request.text looks like JSON, convert it to plain text first.
        if request.text.startswith('{') and request.text.endswith('}'):
            from utils.json_to_text import JsonToTextConverter
            converter = JsonToTextConverter()
            request.text = converter.convert(request.text)

        # Split the text into sentences.
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        # Initialize the service and result lists.
        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Fetch cached results for this conversation, if any.
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        for sentence in sentences:
            # if request.need_convert:
            sentence = sentence.replace("\n", "<br>")
            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue
            if cached_results:
                # With cached results, keep the one closest by vector distance.
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    content_vector = cached_result['embedding']
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}
                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # No cache: run a vector search.
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                    type='trunk'
                )

            # Process the search results.
            for search_result in search_results:
                distance = search_result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue
                # Reuse the index if this reference is already recorded.
                existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                current_index = reference_index
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    # Add a new reference.
                    reference = {
                        "index": str(reference_index),
                        "id": search_result["id"],
                        "content": search_result["content"],
                        "file_path": search_result.get("file_path", ""),
                        "title": search_result.get("title", ""),
                        "distance": distance,
                        "referrence": search_result.get("referrence", "")
                    }
                    all_references.append(reference)
                    reference_index += 1

                # Attach the reference marker.
                if sentence.endswith('<br>'):
                    # Insert ^[current_index]^ before every <br>.
                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                else:
                    # Append ^[current_index]^ at the end of the sentence.
                    result_sentence = f'{sentence}^[{current_index}]^'
                result_sentences.append(result_sentence)

        # Assemble the response payload.
        response_data = {
            "answer": result_sentences,
            "references": all_references
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
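

# POST /text/match: split request.text into sentences and flag the single
# sentence closest to request.sentence, if it falls within DISTANCE_THRESHOLD.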
@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    try:
        sentences = TextSplitter.split_text(request.text)
        sentence_vector = Vectorizer.get_embedding(request.sentence)
        min_distance = float('inf')
        best_sentence = ""
        result_sentences = []
        for temp in sentences:
            result_sentences.append(temp)
            if len(temp) < 10:
                continue
            temp_vector = Vectorizer.get_embedding(temp)
            distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = temp
        for i in range(len(result_sentences)):
            result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
            if result_sentences[i]["sentence"] == best_sentence:
                result_sentences[i]["matched"] = True
        return StandardResponse(success=True, records=result_sentences)
    except Exception as e:
        logger.error(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
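

# POST /text/mr_search: vector-search "mr"-type trunks for the given text and
# return up to 10 hits whose distance is below DISTANCE_THRESHOLD.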
@router.post("/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
    try:
        # Initialize the service.
        trunks_service = TrunksService()

        # Embed the text and search for similar content.
        search_results = trunks_service.search_by_vector(
            text=request.text,
            limit=10,
            type="mr"
        )

        # Keep only results below the distance threshold.
        records = []
        for result in search_results:
            distance = result.get("distance", DISTANCE_THRESHOLD)
            if distance >= DISTANCE_THRESHOLD:
                continue
            record = {
                "content": result["content"],
                "file_path": result.get("file_path", ""),
                "title": result.get("title", ""),
                "distance": distance,
            }
            records.append(record)

        # Assemble the response payload.
        response_data = {
            "records": records
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Mr search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
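

# POST /text/mr_match: sentence-level comparison of two texts. Each origin
# sentence is matched greedily against the first unmatched similar sentence
# within DISTANCE_THRESHOLD; both sides return per-sentence matched flags.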
@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
    start_time = time.time()
    try:
        # Split both texts into sentences.
        origin_sentences = TextSplitter.split_text(request.origin)
        similar_sentences = TextSplitter.split_text(request.similar)
        end_time = time.time()
        logger.info(f"mr_match text splitting took: {(end_time - start_time) * 1000:.2f}ms")

        # Initialize the result list.
        origin_results = []

        # Flag short sentences; only sentences of 10+ characters are vectorized.
        valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
        valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]

        # Initialize similar_results with matched=False for every sentence.
        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]

        # Fetch embeddings in batches.
        origin_vectors = {}
        similar_vectors = {}
        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]
        if origin_batch:
            origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
            origin_vectors = dict(zip(origin_batch, origin_embeddings))
        if similar_batch:
            similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
            similar_vectors = dict(zip(similar_batch, similar_embeddings))
        end_time = time.time()
        logger.info(f"mr_match vectorization took: {(end_time - start_time) * 1000:.2f}ms")

        # Match each origin sentence against the similar sentences.
        for origin_sent, is_valid in valid_origin_sentences:
            if not is_valid:
                origin_results.append({"sentence": origin_sent, "matched": False})
                continue
            origin_vector = origin_vectors[origin_sent]
            matched = False
            # Optimized similarity scan over still-unmatched similar sentences.
            for i, similar_result in enumerate(similar_results):
                if similar_result["matched"]:
                    continue
                similar_sent = similar_result["sentence"]
                if len(similar_sent) < 10:
                    continue
                similar_vector = similar_vectors.get(similar_sent)
                # `is None` avoids an ambiguous truth test on array-like embeddings.
                if similar_vector is None:
                    continue
                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
                if distance < DISTANCE_THRESHOLD:
                    matched = True
                    similar_results[i]["matched"] = True
                    break
            origin_results.append({"sentence": origin_sent, "matched": matched})

        response_data = {
            "origin": origin_results,
            "similar": similar_results
        }
        end_time = time.time()
        logger.info(f"mr_match total took: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        end_time = time.time()
        logger.error(f"Text comparison failed: {str(e)}")
        logger.info(f"mr_match total took: {(end_time - start_time) * 1000:.2f}ms")
        raise HTTPException(status_code=500, detail=str(e))
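

# POST /text/eb_search: load a knowledge-graph node and the requested
# properties, then vector-search "<node name>:<prop title>:<sentence>" for
# each sentence of every property value, attaching references and annotated
# answers to each property.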
@router.post("/eb_search", response_model=StandardResponse)
async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
    try:
        start_time = time.time()
        # Initialize the services.
        trunks_service = TrunksService()
        node_service = KGNodeService(db)
        prop_service = KGPropService(db)

        # Look up the node by node_id.
        node = node_service.get_node(request.node_id)
        if not node:
            raise ValueError(f"Node not found: {request.node_id}")
        node_name = node.get('name', '')

        # Initialize the result payload.
        result = {
            "id": request.node_id,
            "name": node_name,
            "category": node.get('category', ''),
            "props": [],
            "distance": 0
        }

        # Fetch cached results for this conversation, if any.
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        # Look up each requested property.
        for prop_id in request.props_ids:
            prop = prop_service.get_props_by_id(prop_id)
            if not prop:
                logger.warning(f"Property not found: {prop_id}")
                continue
            prop_title = prop.get('prop_title', '')
            prop_value = prop.get('prop_value', '')

            # Split the property value into sentences.
            sentences = TextSplitter.split_text(prop_value)
            prop_result = {
                "id": prop_id,
                "category": prop.get('category', 0),
                "prop_name": prop.get('prop_name', ''),
                "prop_value": prop_value,
                "prop_title": prop_title,
                "type": prop.get('type', 1)
            }
            result["props"].append(prop_result)

            # Process the sentences of this property value.
            result_sentences = []
            all_references = []
            reference_index = 1

            # Vector-search each sentence.
            for sentence in sentences:
                original_sentence = sentence
                sentence = sentence.replace("\n", "<br>")
                if len(sentence) < 10:
                    result_sentences.append(sentence)
                    continue

                # Build the search text.
                search_text = f"{node_name}:{prop_title}:{sentence}"

                # Check the cache first.
                if cached_results:
                    min_distance = float('inf')
                    best_result = None
                    sentence_vector = Vectorizer.get_embedding(search_text)
                    for cached_result in cached_results:
                        content_vector = cached_result['embedding']
                        distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                        if distance < min_distance:
                            min_distance = distance
                            best_result = {**cached_result, 'distance': distance}
                    if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                        search_results = [best_result]
                    else:
                        search_results = []
                else:
                    # No cache: run a vector search.
                    search_results = trunks_service.search_by_vector(
                        text=search_text,
                        limit=1,
                        type='trunk',
                        conversation_id=request.conversation_id
                    )

                # Process the search results.
                for search_result in search_results:
                    distance = search_result.get("distance", DISTANCE_THRESHOLD)
                    if distance >= DISTANCE_THRESHOLD:
                        result_sentences.append(sentence)
                        continue
                    # Reuse the index if this reference is already recorded.
                    existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                    current_index = reference_index
                    if existing_ref:
                        current_index = int(existing_ref["index"])
                    else:
                        # Add a new reference.
                        reference = {
                            "index": str(reference_index),
                            "id": search_result["id"],
                            "content": search_result["content"],
                            "file_path": search_result.get("file_path", ""),
                            "title": search_result.get("title", ""),
                            "distance": distance,
                            "referrence": search_result.get("referrence", "")
                        }
                        all_references.append(reference)
                        reference_index += 1

                    # Attach the reference marker.
                    if sentence.endswith('<br>'):
                        # Insert ^[current_index]^ before every <br>.
                        result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                    else:
                        # Append ^[current_index]^ at the end of the sentence.
                        result_sentence = f'{sentence}^[{current_index}]^'
                    result_sentences.append(result_sentence)

            # Attach the references gathered for this property.
            if all_references:
                prop_result["references"] = all_references
            # Attach the annotated sentences.
            if result_sentences:
                prop_result["answer"] = result_sentences

        end_time = time.time()
        logger.info(f"node_props_search took: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.error(f"Node props search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


text_search_router = router
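

# Wiring this router into an application (a minimal sketch; the FastAPI app
# below is an assumption, not part of this module):
#
#     from fastapi import FastAPI
#     app = FastAPI()
#     app.include_router(text_search_router)
#
# The endpoints are then served under the /text prefix, e.g. POST /text/search.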