Browse Source

增加来源书籍名称

yuchengwei 1 month ago
parent
commit
0fadeffb68
1 changed files with 174 additions and 184 deletions
  1. 174 184
      router/text_search.py

+ 174 - 184
router/text_search.py

@@ -145,6 +145,14 @@ async def search_text(request: TextSearchRequest):
                     current_index = int(existing_ref["index"])
                 else:
                     # 添加到引用列表
+                    # 从referrence中提取文件名
+                    file_name = ""
+                    referrence = search_result.get("referrence", "")
+                    if referrence and "/books/" in referrence:
+                        file_name = referrence.split("/books/")[-1]
+                        # 去除文件扩展名
+                        file_name = os.path.splitext(file_name)[0]
+
                     reference = {
                         "index": str(reference_index),
                         "id": search_result["id"],
@@ -152,7 +160,8 @@ async def search_text(request: TextSearchRequest):
                         "file_path": search_result.get("file_path", ""),
                         "title": search_result.get("title", ""),
                         "distance": distance,
-                        "referrence": search_result.get("referrence", "")
+                        "file_name": file_name,
+                        "referrence": referrence
                     }
                     
                     
@@ -329,43 +338,132 @@ async def compare_text(request: TextCompareMultiRequest):
         logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
         raise HTTPException(status_code=500, detail=str(e))
 
+def _check_cache(node_id: int) -> Optional[dict]:
+    """检查并返回缓存结果"""
+    cache_key = f"xunzheng_{node_id}"
+    cached_result = cache.get(cache_key)
+    if cached_result:
+        logger.info(f"从缓存获取结果,node_id: {node_id}")
+        return cached_result
+    return None
+
+def _get_node_info(node_service: KGNodeService, node_id: int) -> dict:
+    """获取并验证节点信息"""
+    node = node_service.get_node(node_id)
+    if not node:
+        raise ValueError(f"节点不存在: {node_id}")
+    return {
+        "id": node_id,
+        "name": node.get('name', ''),
+        "category": node.get('category', ''),
+        "props": [],
+        "files": [],
+        "distance": 0
+    }
+
+def _process_search_result(search_result: dict, reference_index: int) -> tuple[dict, str]:
+    """处理搜索结果,返回引用信息和文件名"""
+    file_name = ""
+    referrence = search_result.get("referrence", "")
+    if referrence and "/books/" in referrence:
+        file_name = referrence.split("/books/")[-1]
+        file_name = os.path.splitext(file_name)[0]
+
+    reference = {
+        "index": str(reference_index),
+        "id": search_result["id"],
+        "content": search_result["content"],
+        "file_path": search_result.get("file_path", ""),
+        "title": search_result.get("title", ""),
+        "distance": search_result.get("distance", DISTANCE_THRESHOLD),
+        "page_no": search_result.get("page_no", ""),
+        "file_name": file_name,
+        "referrence": referrence
+    }
+    return reference, file_name
+
+def _get_file_type(file_name: str) -> str:
+    """根据文件名确定文件类型"""
+    file_name_lower = file_name.lower()
+    if file_name_lower.endswith(".pdf"):
+        return "pdf"
+    elif file_name_lower.endswith((".doc", ".docx")):
+        return "doc"
+    elif file_name_lower.endswith((".xls", ".xlsx")):
+        return "excel"
+    elif file_name_lower.endswith((".ppt", ".pptx")):
+        return "ppt"
+    return "other"
+
+def _process_sentence_search(node_name: str, prop_title: str, sentences: list, trunks_service: TrunksService) -> tuple[list, list]:
+    """处理句子搜索,返回结果句子和引用列表"""
+    result_sentences = []
+    all_references = []
+    reference_index = 1
+    i = 0
+    
+    while i < len(sentences):
+        sentence = sentences[i]
+        
+        if len(sentence) < 10 and i + 1 < len(sentences):
+            next_sentence = sentences[i + 1]
+            result_sentences.append({"sentence": sentence, "flag": ""})
+            search_text = f"{node_name}:{prop_title}:{sentence} {next_sentence}"
+            i += 1
+        elif len(sentence) < 10:
+            result_sentences.append({"sentence": sentence, "flag": ""})
+            i += 1
+            continue
+        else:
+            search_text = f"{node_name}:{prop_title}:{sentence}"
+        
+        i += 1
+        
+        search_results = trunks_service.search_by_vector(text=search_text, limit=1, type='trunk')
+        
+        if not search_results:
+            result_sentences.append({"sentence": sentence, "flag": ""})
+            continue
+            
+        for search_result in search_results:
+            if search_result.get("distance", DISTANCE_THRESHOLD) >= DISTANCE_THRESHOLD:
+                result_sentences.append({"sentence": sentence, "flag": ""})
+                continue
+                
+            existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
+            current_index = int(existing_ref["index"]) if existing_ref else reference_index
+            
+            if not existing_ref:
+                reference, _ = _process_search_result(search_result, reference_index)
+                all_references.append(reference)
+                reference_index += 1
+                
+            result_sentences.append({"sentence": sentence, "flag": str(current_index)})
+    
+    return result_sentences, all_references
+
 @router.post("/eb_search", response_model=StandardResponse)
 async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
     try:
         start_time = time.time()
         
         # 检查缓存
-        cache_key = f"xunzheng_{request.node_id}"
-        cached_result = cache.get(cache_key)
+        cached_result = _check_cache(request.node_id)
         if cached_result:
-            logger.info(f"从缓存获取结果,node_id: {request.node_id}")
             return StandardResponse(success=True, data=cached_result)
+
         # 初始化服务
         trunks_service = TrunksService()
         node_service = KGNodeService(db)
         prop_service = KGPropService(db)
 
-        # 根据node_id查询节点信息
-        node = node_service.get_node(request.node_id)
-        if not node:
-            raise ValueError(f"节点不存在: {request.node_id}")
-
-        node_name = node.get('name', '')
-
-        # 初始化结果
-        result = {
-            "id": request.node_id,
-            "name": node_name,
-            "category": node.get('category', ''),
-            "props": [],
-            "files": [],
-            "distance": 0
-        }
+        # 获取节点信息
+        result = _get_node_info(node_service, request.node_id)
+        node_name = result["name"]
 
         # 遍历props_ids查询属性信息
         for prop_id in request.props_ids:
             prop = prop_service.get_props_by_id(prop_id)
-
             if not prop:
                 logger.warning(f"属性不存在: {prop_id}")
                 continue
@@ -373,14 +471,7 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
             prop_title = prop.get('prop_title', '')
             prop_value = prop.get('prop_value', '')
 
-            # 先用完整的prop_value进行搜索
-            search_text = f"{node_name}:{prop_title}:{prop_value}"
-            full_search_results = trunks_service.search_by_vector(
-                text=search_text,
-                limit=1,
-                type='trunk'
-            )
-
+            # 创建属性结果对象
             prop_result = {
                 "id": prop_id,
                 "category": prop.get('category', 0),
@@ -389,171 +480,69 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
                 "prop_title": prop_title,
                 "type": prop.get('type', 1)
             }
-
-            # 添加到结果中
             result["props"].append(prop_result)
 
-            # 处理属性值中的句子
-            result_sentences = []
-            all_references = []
-            reference_index = 1
+            # 先用完整的prop_value进行搜索
+            search_text = f"{node_name}:{prop_title}:{prop_value}"
+            full_search_results = trunks_service.search_by_vector(
+                text=search_text,
+                limit=1,
+                type='trunk'
+            )
 
-            # 如果整体搜索找到匹配结果且距离小于阈值,直接使用整体结果
+            # 处理搜索结果
             if full_search_results and full_search_results[0].get("distance", DISTANCE_THRESHOLD) < DISTANCE_THRESHOLD:
                 search_result = full_search_results[0]
-                reference = {
-                    "index": str(reference_index),
-                    "id": search_result["id"],
-                    "content": search_result["content"],
-                    "file_path": search_result.get("file_path", ""),
-                    "title": search_result.get("title", ""),
-                    "distance": search_result.get("distance", DISTANCE_THRESHOLD),
-                    "page_no": search_result.get("page_no", ""),
-                    "referrence": search_result.get("referrence", "")
-                }
-                all_references.append(reference)
-                result_sentences.append({
+                reference, _ = _process_search_result(search_result, 1)
+                prop_result["references"] = [reference]
+                prop_result["answer"] = [{
                     "sentence": prop_value,
-                    "flag": str(reference_index)
-                })
+                    "flag": "1"
+                }]
             else:
                 # 如果整体搜索没有找到匹配结果,则进行句子拆分搜索
                 sentences = TextSplitter.split_text(prop_value)
-                i = 0
-                while i < len(sentences):
-                    original_sentence = sentences[i]
-                    sentence = original_sentence
-                    
-                    # 如果当前句子长度小于10且不是最后一句,则与下一句合并
-                    if len(sentence) < 10 and i + 1 < len(sentences):
-                        next_sentence = sentences[i + 1]
-                        combined_sentence = sentence + " " + next_sentence
-                        # 添加原短句到结果,flag为空
-                        result_sentences.append({
-                            "sentence": sentence,
-                            "flag": ""
-                        })
-                        # 使用合并后的句子进行搜索
-                        search_text = f"{node_name}:{prop_title}:{combined_sentence}"
-                        i += 1  # 跳过下一句,因为已经合并使用
-                    elif len(sentence) < 10:
-                        # 如果是最后一句且长度小于10,直接添加到结果,flag为空
-                        result_sentences.append({
-                            "sentence": sentence,
-                            "flag": ""
-                        })
-                        i += 1
-                        continue
-                    else:
-                        # 句子长度足够,直接使用
-                        search_text = f"{node_name}:{prop_title}:{sentence}"
-                    
-                    i += 1
-
-                # 进行向量搜索
-                search_results = trunks_service.search_by_vector(
-                        text=search_text,
-                        limit=1,
-                        type='trunk'
+                result_sentences, references = _process_sentence_search(
+                    node_name, prop_title, sentences, trunks_service
                 )
+                if references:
+                    prop_result["references"] = references
+                if result_sentences:
+                    prop_result["answer"] = result_sentences
 
-                # 处理搜索结果
-                if not search_results:
-                    # 没有搜索结果,添加原句子,flag为空
-                    result_sentences.append({
-                        "sentence": sentence,
-                        "flag": ""
-                    })
-                    continue
-                    
-                for search_result in search_results:
-                    distance = search_result.get("distance", DISTANCE_THRESHOLD)
-                    if distance >= DISTANCE_THRESHOLD:
-                        # 距离过大,添加原句子,flag为空
-                        result_sentences.append({
-                            "sentence": sentence,
-                            "flag": ""
-                        })
-                        continue
-
-                    # 检查是否已存在相同引用
-                    existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
-                    current_index = reference_index
-                    if existing_ref:
-                        current_index = int(existing_ref["index"])
-                    else:
-                        # 添加到引用列表
-                        reference = {
-                            "index": str(reference_index),
-                            "id": search_result["id"],
-                            "content": search_result["content"],
-                            "file_path": search_result.get("file_path", ""),
-                            "title": search_result.get("title", ""),
-                            "distance": distance,
-                            "page_no": search_result.get("page_no", ""),
-                            "referrence": search_result.get("referrence", "")
-                        }
-                        
-                        all_references.append(reference)
-                        reference_index += 1
-
-                    # 添加句子和引用标记(作为单独的flag字段)
-                    result_sentences.append({
-                        "sentence": sentence,
-                        "flag": str(current_index)
-                    })
-
-            # 更新属性值,添加引用信息
-            if all_references:
-                prop_result["references"] = all_references
-
-            # 将处理后的句子添加到结果中
-            if result_sentences:
-                prop_result["answer"] = result_sentences
-
-        # 处理所有引用中的文件信息
+        # 处理文件信息
         all_files = set()
         file_index_map = {}
         file_index = 1
-        
-        # 第一次遍历收集所有文件信息
+
+        # 收集文件信息
         for prop_result in result["props"]:
-            if "references" in prop_result:
-                for ref in prop_result["references"]:
-                    referrence = ref.get("referrence", "")
-                    if referrence and "/books/" in referrence:
-                        # 提取/books/后面的文件名
-                        file_name = referrence.split("/books/")[-1]
-                        if file_name:
-                            # 根据文件名后缀确定文件类型
-                            file_type = ""
-                            if file_name.lower().endswith(".pdf"):
-                                file_type = "pdf"
-                            elif file_name.lower().endswith(".doc") or file_name.lower().endswith(".docx"):
-                                file_type = "doc"
-                            elif file_name.lower().endswith(".xls") or file_name.lower().endswith(".xlsx"):
-                                file_type = "excel"
-                            elif file_name.lower().endswith(".ppt") or file_name.lower().endswith(".pptx"):
-                                file_type = "ppt"
-                            else:
-                                file_type = "other"
-                            
-                            if file_name not in file_index_map:
-                                file_index_map[file_name] = file_index
-                                file_index += 1
-                            all_files.add((file_name, file_type))
-        
-        # 第二次遍历更新引用的index
+            if "references" not in prop_result:
+                continue
+            for ref in prop_result["references"]:
+                referrence = ref.get("referrence", "")
+                if not (referrence and "/books/" in referrence):
+                    continue
+                file_name = referrence.split("/books/")[-1]
+                if not file_name:
+                    continue
+                file_type = _get_file_type(file_name)
+                if file_name not in file_index_map:
+                    file_index_map[file_name] = file_index
+                    file_index += 1
+                all_files.add((file_name, file_type))
+
+        # 更新引用索引
         for prop_result in result["props"]:
-            if "references" in prop_result:
-                for ref in prop_result["references"]:
-                    referrence = ref.get("referrence", "")
-                    if referrence and "/books/" in referrence:
-                        file_name = referrence.split("/books/")[-1]
-                        if file_name in file_index_map:
-                            # 更新reference的index为"文件索引-原索引"
-                            ref["index"] = f"{file_index_map[file_name]}-{ref['index']}"
-            
+            if "references" not in prop_result:
+                continue
+            for ref in prop_result["references"]:
+                referrence = ref.get("referrence", "")
+                if referrence and "/books/" in referrence:
+                    file_name = referrence.split("/books/")[-1]
+                    if file_name in file_index_map:
+                        ref["index"] = f"{file_index_map[file_name]}-{ref['index']}"
+
             # 更新answer中的flag
             if "answer" in prop_result:
                 for sentence in prop_result["answer"]:
@@ -562,20 +551,21 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
                             if ref["index"].endswith(f"-{sentence['flag']}"):
                                 sentence["flag"] = ref["index"]
                                 break
-        
-        # 将文件信息添加到结果中
+
+        # 添加文件信息到结果
         result["files"] = sorted([{
             "file_name": file_name,
             "file_type": file_type,
             "index": str(file_index_map[file_name])
         } for file_name, file_type in all_files], key=lambda x: int(x["index"]))
-        
+
         end_time = time.time()
         logger.info(f"node_props_search接口耗时: {(end_time - start_time) * 1000:.2f}ms")
-        
-        # 将结果存入缓存
+
+        # 缓存结果
+        cache_key = f"xunzheng_{request.node_id}"
         cache[cache_key] = result
-        
+
         return StandardResponse(success=True, data=result)
     except Exception as e:
         logger.error(f"Node props search failed: {str(e)}")