yuchengwei пре 4 дана
родитељ
комит
f3b464e7f8
2 измењених фајлова са 54 додато и 34 уклоњено
  1. 3 3
      src/knowledge/main.py
  2. 51 31
      src/knowledge/service/search_service.py

+ 3 - 3
src/knowledge/main.py

@@ -161,15 +161,15 @@ async def get_symptom_diseases(
     }"""
     }"""
           )
           )
 async def get_disease_symptoms(
 async def get_disease_symptoms(
-        disease_id: str = Query(...,
+        disease_name: str = Query(...,
                   description="疾病ID,应为标准医学术语ID。该参数用于查询该疾病的所有症状及其严重程度权重。",
                   description="疾病ID,应为标准医学术语ID。该参数用于查询该疾病的所有症状及其严重程度权重。",
-                  examples="D001",
+                  examples="艾滋病",
                   min_length=1)
                   min_length=1)
 ):
 ):
     try:
     try:
         # 实现获取疾病症状的逻辑
         # 实现获取疾病症状的逻辑
         search = SearchBusiness()
         search = SearchBusiness()
-        results = search.get_disease_symptoms(disease_id)
+        results = search.get_disease_symptoms(disease_name)
         return StandardResponse(success=True, data=results)
         return StandardResponse(success=True, data=results)
     except Exception as e:
     except Exception as e:
         logger.error(f"获取疾病症状失败: {str(e)}")
         logger.error(f"获取疾病症状失败: {str(e)}")

+ 51 - 31
src/knowledge/service/search_service.py

@@ -2,7 +2,7 @@ import sys,os
 current_path = os.getcwd()
 current_path = os.getcwd()
 sys.path.append(current_path)
 sys.path.append(current_path)
 import logging
 import logging
-logger = logging.getLogger(__name__)
+from py_tools.logging import logger
 import json
 import json
 from ..config.site import SiteConfig
 from ..config.site import SiteConfig
 from elasticsearch import Elasticsearch, helpers, exceptions
 from elasticsearch import Elasticsearch, helpers, exceptions
@@ -192,43 +192,63 @@ class SearchBusiness:
         print("filtered_diseases:",json.dumps(filtered_diseases, ensure_ascii=False, indent=4))
         print("filtered_diseases:",json.dumps(filtered_diseases, ensure_ascii=False, indent=4))
         return filtered_diseases
         return filtered_diseases
 
 
-    def get_disease_symptoms(self, src_id):
+    def get_disease_symptoms(self, disease_name):
         """
         """
-        查询疾病的所有症状及其严重程度权重
-        :param src_id: 疾病节点ID
-        :return: {"symptoms": [症状列表], "severity_weights": {症状:权重}}
+        查询疾病的所有症状
+        :param disease_name: 疾病名称
+        :return: [{"disease": 疾病名称, "symptoms": [症状列表]}]
+                  (同名疾病会被合并,症状列表会去重,空症状列表会被过滤)
         """
         """
         try:
         try:
-            # 1. 查询关系表获取所有症状关系
-            edges = self.search_edges(name="症状", src_id=src_id)
-            if not edges:
-                return {"symptoms": [], "severity_weights": {}}
-
-            # 2. 收集所有症状节点ID
-            dest_ids = [edge["public_kg_edges_dest_id"] for edge in edges if "public_kg_edges_dest_id" in edge]
-            if not dest_ids:
-                return {"symptoms": [], "severity_weights": {}}
-
-            # 3. 批量查询症状节点信息
-            nodes = self.search_nodes(ids=dest_ids)
-            if not nodes:
-                return {"symptoms": [], "severity_weights": {}}
-
-            # 4. 构建返回结果
-            symptoms = [node["public_kg_nodes_name"] for node in nodes if "public_kg_nodes_name" in node]
-
-            # 假设权重存储在节点的properties字段中,如果没有则使用默认权重
-            severity_weights = {
-                node["public_kg_nodes_name"]: float(node.get("severity", 0.8))
-                for node in nodes
-                if "public_kg_nodes_name" in node
-            }
+            # 1. 根据疾病名称查询疾病列表
+            diseases = self.search_nodes(name=disease_name, type="疾病")
+            logger.info(f"Get diseases by name: {disease_name}, result: {diseases}")
+            if not diseases:
+                return []
+
+            disease_map = {}
+
+            for disease in diseases:
+                disease_name = disease.get("public_kg_nodes_name", "")
 
 
-            return {"symptoms": symptoms, "severity_weights": severity_weights}
+                # 如果疾病名称已存在,则获取已有记录,否则创建新记录
+                if disease_name not in disease_map:
+                    disease_map[disease_name] = {
+                        "disease": disease_name,
+                        "symptoms": []
+                    }
+
+                disease_info = disease_map[disease_name]
+
+                # 2. 根据疾病ID查找症状关系
+                edges = self.search_edges(name="相关症状", src_id=disease["public_kg_nodes_id"])
+                logger.info(f"Get symptoms by disease id: {disease['public_kg_nodes_id']}, result: {edges}")
+                if not edges:
+                    continue
+
+                # 3. 收集症状节点ID
+                symptom_ids = [edge["public_kg_edges_dest_id"] for edge in edges if "public_kg_edges_dest_id" in edge]
+                if not symptom_ids:
+                    continue
+
+                # 4. 查询症状节点信息
+                symptoms = self.search_nodes(ids=symptom_ids)
+                if not symptoms:
+                    continue
+
+                # 5. 收集症状名称
+                for symptom in symptoms:
+                    if "public_kg_nodes_name" in symptom:
+                        symptom_name = symptom["public_kg_nodes_name"]
+                        if symptom_name not in disease_info["symptoms"]:
+                            disease_info["symptoms"].append(symptom_name)
+
+            # 过滤掉症状列表为空的疾病
+            return [d for d in disease_map.values() if d["symptoms"]]
 
 
         except Exception as e:
         except Exception as e:
             logger.error(f"Get disease symptoms error: {e}")
             logger.error(f"Get disease symptoms error: {e}")
-            return {"symptoms": [], "severity_weights": {}}
+            return []
 
 
     def search_concept(self, name, type):
     def search_concept(self, name, type):
         nodes = self.search_nodes(name=name, type=type)
         nodes = self.search_nodes(name=name, type=type)