浏览代码

代码提交

SGTY 4 天之前
父节点
当前提交
aaa0fffc4d
共有 2 个文件被更改,包括 92 次插入和 81 次删除
  1. +1 -1
      src/knowledge/router/medical_knowledge_api.py
  2. +91 -80
      src/knowledge/service/search_service.py

+ 1 - 1
src/knowledge/router/medical_knowledge_api.py

@@ -31,7 +31,7 @@ async def get_symptom_diseases(
         results = search.get_symptom_diseases(request.symptoms)    
         return StandardResponse(success=True, data=results)
     except Exception as e:
-        logger.error(f"获取症状相关疾病失败: {str(e)}")
+        logger.exception(f"获取症状相关疾病失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
 @router.post("/disease_symptoms", response_model=StandardResponse)

+ 91 - 80
src/knowledge/service/search_service.py

@@ -3,7 +3,7 @@ current_path = os.getcwd()
 sys.path.append(current_path)
 import logging
 logger = logging.getLogger(__name__)
-
+import json
 from ..config.site import SiteConfig
 from elasticsearch import Elasticsearch, helpers, exceptions
 config = SiteConfig()
@@ -12,6 +12,8 @@ ELASTICSEARCH_PWD = config.get_config("ELASTICSEARCH_PWD", "/tmp")
 ELASTICSEARCH_HOST = config.get_config("ELASTICSEARCH_HOST", "/tmp")
 
 class SearchBusiness:
+    MIN_SCORE = 5
+
     def __init__(self):
         self.es = Elasticsearch(hosts=[ELASTICSEARCH_HOST], verify_certs=False, http_auth=(ELASTICSEARCH_USER, ELASTICSEARCH_PWD))
         pass
@@ -42,7 +44,8 @@ class SearchBusiness:
                         ]
                     }
                 },
-                "sort": [{"_score": {"order": "desc"}}]
+                "sort": [{"_score": {"order": "desc"}}],
+                "min_score": self.MIN_SCORE
             }
 
             if name:
@@ -80,7 +83,8 @@ class SearchBusiness:
                         ]
                     }
                 },
-                "sort": [{"_score": {"order": "desc"}}]
+                "sort": [{"_score": {"order": "desc"}}],
+                "min_score": self.MIN_SCORE
             }
             if name:
                 query["query"]["bool"]["must"].append({"match": {"public_kg_edges_category": name}})
@@ -98,87 +102,94 @@ class SearchBusiness:
             logger.error(f"Search error: {e}")
             return None
 
+    def _process_symptom_to_diseases(self, symptom_name):
+        """
+        处理单个症状到疾病的转换逻辑
+        :param symptom_name: 症状名称
+        :return: 疾病节点列表
+        """
+        # 存储已处理的疾病ID,避免同义词重复计算
+        processed_disease_ids = set()
+        simplified_diseases = []
+        # 获取症状节点ID
+        symptom_nodes = self.search_nodes(name=symptom_name, type="症状")
+        if not symptom_nodes:
+            return simplified_diseases
+
+        for symptom_node in symptom_nodes:
+            symptom_id = symptom_node.get("public_kg_nodes_id")
+            if not symptom_id:
+                continue
+
+            # 查询该症状的'症状同义词'关系获取同义词节点
+            # synonym_edges = self.search_edges(name="症状同义词", src_id=symptom_id)
+            synonym_ids = {symptom_id}  # 包含原始症状ID
+            # if synonym_edges:
+            #     for edge in synonym_edges:
+            #         synonym_id = edge.get("public_kg_edges_dest_id")
+            #         if synonym_id:
+            #             synonym_ids.add(synonym_id)
+
+            # 对每个症状ID(包括同义词)查询'常见疾病'关系
+            for symptom_id in synonym_ids:
+                disease_edges = self.search_edges(name="常见疾病", src_id=symptom_id)
+                if not disease_edges:
+                    continue
+
+                # 收集所有疾病节点
+                for edge in disease_edges:
+                    disease_id = edge.get("public_kg_edges_dest_id")
+                    if disease_id and disease_id not in processed_disease_ids:
+                        processed_disease_ids.add(disease_id)
+                        disease_nodes = self.search_nodes(id=disease_id, type="疾病")
+                        if disease_nodes:
+                            # 只返回需要的字段并重命名
+                            for disease in disease_nodes:
+                                simplified_diseases.append({
+                                    "disease_id": disease.get("public_kg_nodes_id"),
+                                    "disease_name": disease.get("public_kg_nodes_name"),
+                                    "score": disease.get("score", 0)
+                                })
+
+        return simplified_diseases
+
     def get_symptom_diseases(self, symptom_names):
         """
         根据症状名称列表获取相关疾病列表
         :param symptom_names: 症状名称列表
-        :return: {
-            "diseases": [{"id": str, "name": str, "score": float}],
-            "reasoning": str
-        }
+        :return: 字典,键为症状名称,值为对应的疾病列表
         """
+        symptom_disease_map = {}
         disease_dict = {}
-        matched_symptoms = set()
 
         # 对症状名称去重
         unique_symptoms = list(set(symptom_names))
         threshold = len(unique_symptoms) / 2
 
-        # 存储已处理的疾病ID,避免同义词重复计算
-        processed_disease_ids = set()
-
         for symptom_name in unique_symptoms:
-            # 获取症状节点ID
-            symptom_nodes = self.search_nodes(name=symptom_name, type="症状")
-            if not symptom_nodes:
-                continue
-
-            for symptom_node in symptom_nodes:
-                symptom_id = symptom_node.get("public_kg_nodes_id")
-                if not symptom_id:
-                    continue
-
-                # 查询该症状的'症状同义词'关系获取同义词节点
-                synonym_edges = self.search_edges(name="症状同义词", src_id=symptom_id)
-                synonym_ids = {symptom_id}  # 包含原始症状ID
-                if synonym_edges:
-                    for edge in synonym_edges:
-                        synonym_id = edge.get("public_kg_edges_dest_id")
-                        if synonym_id:
-                            synonym_ids.add(synonym_id)
-
-                # 对每个症状ID(包括同义词)查询'常见疾病'关系
-                for symptom_id in synonym_ids:
-                    disease_edges = self.search_edges(name="常见疾病", src_id=symptom_id)
-                    if not disease_edges:
-                        continue
-
-                    # 收集所有疾病节点
-                    for edge in disease_edges:
-                        disease_id = edge.get("public_kg_edges_dest_id")
-                        if disease_id and disease_id not in processed_disease_ids:
-                            processed_disease_ids.add(disease_id)
-                            disease_nodes = self.search_nodes(id=disease_id, type="疾病")
-                            if disease_nodes:
-                                for node in disease_nodes:
-                                    node_id = node.get("public_kg_nodes_id")
-                                    if node_id in disease_dict:
-                                        disease_dict[node_id]["count"] += 1
-                                    else:
-                                        node["count"] = 1
-                                        disease_dict[node_id] = node
-                                    matched_symptoms.add(symptom_name)
-
+            diseases = self._process_symptom_to_diseases(symptom_name)
+            if diseases:
+                symptom_disease_map[symptom_name] = diseases
+        #json打印symptom_disease_map
+        print("symptom_disease_map:",json.dumps(symptom_disease_map, ensure_ascii=False, indent=4))
+        #统计每个疾病出现的次数
+        for symptom, diseases in symptom_disease_map.items():
+            for disease in diseases:
+                disease_id = disease.get("disease_id")
+                if disease_id in disease_dict:
+                    disease_dict[disease_id]["count"] += 1
+                else:
+                    disease_dict[disease_id] = {
+                        "disease_id": disease_id,
+                        "disease_name": disease.get("disease_name"),
+                        "score": disease.get("score", 0),
+                        "count": 1
+                    }
         # 按count降序排列并过滤掉count小于阈值的疾病
         sorted_diseases = sorted(disease_dict.values(), key=lambda x: x["count"], reverse=True)
         filtered_diseases = [disease for disease in sorted_diseases if disease["count"] >= threshold]
-        
-        # 格式化返回结果
-        diseases = [
-            {
-                "id": disease["public_kg_nodes_id"],
-                "name": disease["public_kg_nodes_name"],
-                "score": min(1.0, disease["count"] / len(unique_symptoms))
-            }
-            for disease in filtered_diseases
-        ]
-        
-        reasoning = "、".join(matched_symptoms) + "是这些疾病的常见症状" if matched_symptoms else ""
-        
-        return {
-            "diseases": diseases,
-            "reasoning": reasoning
-        }
+
+        return filtered_diseases
 
     def get_disease_symptoms(self, src_id):
         """
@@ -204,14 +215,14 @@ class SearchBusiness:
 
             # 4. 构建返回结果
             symptoms = [node["public_kg_nodes_name"] for node in nodes if "public_kg_nodes_name" in node]
-            
+
             # 假设权重存储在节点的properties字段中,如果没有则使用默认权重
             severity_weights = {
-                node["public_kg_nodes_name"]: float(node.get("severity", 0.8)) 
-                for node in nodes 
+                node["public_kg_nodes_name"]: float(node.get("severity", 0.8))
+                for node in nodes
                 if "public_kg_nodes_name" in node
             }
-            
+
             return {"symptoms": symptoms, "severity_weights": severity_weights}
 
         except Exception as e:
@@ -230,11 +241,11 @@ class SearchBusiness:
         edges = self.search_edges(src_id=src_id, size=2000)
         if not edges:
             return {"relations": []}
-            
+
         dest_ids = [edge["public_kg_edges_dest_id"] for edge in edges]
         nodes = self.search_nodes(ids=dest_ids)
         node_map = {node["public_kg_nodes_id"]: node for node in nodes}
-        
+
         relations = []
         for edge in edges:
             dest_id = edge["public_kg_edges_dest_id"]
@@ -244,7 +255,7 @@ class SearchBusiness:
                     "target_id": dest_id,
                     "target_name": node_map[dest_id].get("public_kg_nodes_name", "")
                 })
-                
+
         return {"relations": relations}
 
     def get_similar_concepts(self, id, top_k=5):
@@ -252,11 +263,11 @@ class SearchBusiness:
         node = self.search_nodes(id=id, limit=1)
         if not node:
             return {"similar_concepts": []}
-            
+
         # 2. 使用节点名称进行相似搜索
         node_name = node[0]["public_kg_nodes_name"]
         results = self.search_nodes(name=node_name, limit=top_k)
-        
+
         # 3. 格式化返回结果
         similar_concepts = [
             {
@@ -266,9 +277,9 @@ class SearchBusiness:
             }
             for result in results
         ]
-        
+
         return {"similar_concepts": similar_concepts}
-    
+
 
 if __name__ == "__main__":
     search_biz = SearchBusiness()