SGTY 4 dni temu
rodzic
commit
f6996ffba6

+ 4 - 4
src/knowledge/main.py

@@ -1,13 +1,13 @@
 # 导入FastAPI及相关模块
 import uvicorn
-#from fastapi_mcp import FastApiMCP
+from fastapi_mcp import FastApiMCP
 from py_tools.logging import logger
 
 from .settings import base_setting
 from .server import app
-# mcp = FastApiMCP(app)
-# mcp.mount()
-# mcp.setup_server()
+mcp = FastApiMCP(app)
+mcp.mount()
+mcp.setup_server()
 
 def main():
     logger.info(f"project run {base_setting.server_host}:{base_setting.server_port}")

+ 2 - 0
src/knowledge/router/medical_knowledge_api.py

@@ -28,6 +28,8 @@ async def get_symptom_diseases(
 ):
     try:
         # 实现获取症状相关疾病的逻辑
+        search = SearchBusiness()
+        results = search.get_symptom_diseases(request.symptoms)    
         return StandardResponse(success=True, data=[])
     except Exception as e:
         logger.error(f"获取症状相关疾病失败: {str(e)}")

+ 62 - 2
src/knowledge/service/search_service.py

@@ -4,8 +4,7 @@ sys.path.append(current_path)
 import logging
 logger = logging.getLogger(__name__)
 
-from agent.libs.schema import SchemaContent,SchemaData,SchemaDataItem
-from config.site import SiteConfig
+from ..config.site import SiteConfig
 from elasticsearch import Elasticsearch, helpers, exceptions
 config = SiteConfig()
 ELASTICSEARCH_USER = config.get_config("ELASTICSEARCH_USER", "/tmp")
@@ -92,6 +91,67 @@ class SearchBusiness:
         except Exception as e:
             logger.error(f"Search error: {e}")
             return None
+            
+    def get_symptom_diseases(self, symptom_names):
+        """
+        根据症状名称列表获取相关疾病列表
+        :param symptom_names: 症状名称列表
+        :return: 疾病节点列表(按命中症状次数降序排列)
+        """
+        disease_dict = {}
+        
+        # 对症状名称去重
+        unique_symptoms = list(set(symptom_names))
+        threshold = len(unique_symptoms) / 2
+        
+        # 存储已处理的疾病ID,避免同义词重复计算
+        processed_disease_ids = set()
+        
+        for symptom_name in unique_symptoms:
+            # 获取症状节点ID
+            symptom_nodes = self.search_nodes(name=symptom_name, type="症状")
+            if not symptom_nodes:
+                continue
+                
+            for symptom_node in symptom_nodes:
+                symptom_id = symptom_node.get("public_kg_nodes_id")
+                if not symptom_id:
+                    continue
+                    
+                # 查询该症状的'症状同义词'关系获取同义词节点
+                synonym_edges = self.search_edges(name="症状同义词", src_id=symptom_id)
+                synonym_ids = {symptom_id}  # 包含原始症状ID
+                if synonym_edges:
+                    for edge in synonym_edges:
+                        synonym_id = edge.get("public_kg_edges_dest_id")
+                        if synonym_id:
+                            synonym_ids.add(synonym_id)
+                
+                # 对每个症状ID(包括同义词)查询'常见疾病'关系
+                for symptom_id in synonym_ids:
+                    disease_edges = self.search_edges(name="常见疾病", src_id=symptom_id)
+                    if not disease_edges:
+                        continue
+                        
+                    # 收集所有疾病节点
+                    for edge in disease_edges:
+                        disease_id = edge.get("public_kg_edges_dest_id")
+                        if disease_id and disease_id not in processed_disease_ids:
+                            processed_disease_ids.add(disease_id)
+                            disease_nodes = self.search_nodes(id=disease_id, type="疾病")
+                            if disease_nodes:
+                                for node in disease_nodes:
+                                    node_id = node.get("public_kg_nodes_id")
+                                    if node_id in disease_dict:
+                                        disease_dict[node_id]["count"] += 1
+                                    else:
+                                        node["count"] = 1
+                                        disease_dict[node_id] = node
+                            
+        # 按count降序排列并过滤掉count小于阈值的疾病
+        sorted_diseases = sorted(disease_dict.values(), key=lambda x: x["count"], reverse=True)
+        filtered_diseases = [disease for disease in sorted_diseases if disease["count"] >= threshold]
+        return filtered_diseases
 
 if __name__ == "__main__":
     search_biz = SearchBusiness()