Quellcode durchsuchen

智能体查询接口2

yuchengwei vor 4 Tagen
Ursprung
Commit
284b937ae9
2 geänderte Dateien mit 115 neuen und 32 gelöschten Zeilen
  1. 10 8
      src/knowledge/router/medical_knowledge_api.py
  2. 105 24
      src/knowledge/service/search_service.py

+ 10 - 8
src/knowledge/router/medical_knowledge_api.py

@@ -23,14 +23,13 @@ class DiseaseInfoRequest(BaseModel):
 
 @router.post("/symptom_diseases", response_model=StandardResponse)
 async def get_symptom_diseases(
-    request: SymptomDiseasesRequest,
-    db: Session = Depends(get_db)
+    request: SymptomDiseasesRequest
 ):
     try:
         # 实现获取症状相关疾病的逻辑
         search = SearchBusiness()
         results = search.get_symptom_diseases(request.symptoms)    
-        return StandardResponse(success=True, data=[])
+        return StandardResponse(success=True, data=results)
     except Exception as e:
         logger.error(f"获取症状相关疾病失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
@@ -41,7 +40,9 @@ async def get_disease_symptoms(
 ):
     try:
         # 实现获取疾病症状的逻辑
-        return StandardResponse(success=True, data=[])
+        search = SearchBusiness()
+        results = search.get_disease_symptoms(request.disease_id)    
+        return StandardResponse(success=True, data=results)
     except Exception as e:
         logger.error(f"获取疾病症状失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
@@ -117,7 +118,7 @@ async def search_concept(
     try:
         # 实现搜索医学概念的逻辑
         search = SearchBusiness()
-        results = search.search_nodes(name=request.query, type=request.type,id=None)
+        results = search.search_concept(name=request.query, type=request.type)
         return StandardResponse(success=True, data=results)
     except Exception as e:
         logger.error(f"搜索医学概念失败: {str(e)}")
@@ -130,7 +131,7 @@ async def get_relations(
     try:
         # 实现获取概念关系的逻辑
         search = SearchBusiness()
-        results = search.search_edges(name=None,src_id=request.concept_id)
+        results = search.get_relations(src_id=request.concept_id)
         return StandardResponse(success=True, data=results)
     except Exception as e:
         logger.error(f"获取概念关系失败: {str(e)}")
@@ -141,9 +142,10 @@ async def get_similar_concepts(
     request: SimilarConceptsRequest
 ):
     try:
-        # 实现获取相似概念的逻辑
+        # 1. 先根据ID获取节点名称
         search = SearchBusiness()
-        results = search.search_nodes(name=None,type=None,id=request.concept_id, limit=request.top_k)
+        results = search.get_similar_concepts(id=request.concept_id, top_k=request.top_k)
+        
         return StandardResponse(success=True, data=results)
     except Exception as e:
         logger.error(f"获取相似概念失败: {str(e)}")

+ 105 - 24
src/knowledge/service/search_service.py

@@ -18,7 +18,7 @@ class SearchBusiness:
     def search_nodes_and_edges(self, index, query):
         try:
             response = self.es.search(index=index, body=query)
-           
+
             hits = response["hits"]["hits"]
             results = []
             for hit in hits:
@@ -31,8 +31,7 @@ class SearchBusiness:
             logger.error(f"Index '{index}' not found: {e}")
             return []
 
-    def search_nodes(self,name,type,id,limit=10,from_=0):
-
+    def search_nodes(self, name=None, type=None, id=None, ids=None, limit=10, from_=0):
         try:
             query = {
                 "explain": "true",
@@ -54,6 +53,8 @@ class SearchBusiness:
                 query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_name": name}})
             if id:
                 query["query"]["bool"]["must"].append({"term": {"public_kg_nodes_id": id}})
+            if ids:
+                query["query"]["bool"]["must"].append({"terms": {"public_kg_nodes_id": ids}})
             if type:
                 query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_category": type}})
 
@@ -96,9 +97,13 @@ class SearchBusiness:
         """
         根据症状名称列表获取相关疾病列表
         :param symptom_names: 症状名称列表
-        :return: 疾病节点列表(按命中症状次数降序排列)
+        :return: {
+            "diseases": [{"id": str, "name": str, "score": float}],
+            "reasoning": str
+        }
         """
         disease_dict = {}
+        matched_symptoms = set()
 
         # 对症状名称去重
         unique_symptoms = list(set(symptom_names))
@@ -147,42 +152,118 @@ class SearchBusiness:
                                     else:
                                         node["count"] = 1
                                         disease_dict[node_id] = node
+                                    matched_symptoms.add(symptom_name)
 
         # 按count降序排列并过滤掉count小于阈值的疾病
         sorted_diseases = sorted(disease_dict.values(), key=lambda x: x["count"], reverse=True)
         filtered_diseases = [disease for disease in sorted_diseases if disease["count"] >= threshold]
-        return filtered_diseases
+        
+        # 格式化返回结果
+        diseases = [
+            {
+                "id": disease["public_kg_nodes_id"],
+                "name": disease["public_kg_nodes_name"],
+                "score": min(1.0, disease["count"] / len(unique_symptoms))
+            }
+            for disease in filtered_diseases
+        ]
+        
+        reasoning = "、".join(matched_symptoms) + "是这些疾病的常见症状" if matched_symptoms else ""
+        
+        return {
+            "diseases": diseases,
+            "reasoning": reasoning
+        }
 
-    def search_related_nodes(self, src_id, relation_name):
+    def get_disease_symptoms(self, src_id):
         """
-        查询与指定节点有特定关系的所有节点
-        :param src_id: 源节点ID
-        :param relation_name: 关系名称
-        :return: 相关节点列表
+        查询疾病的所有症状及其严重程度权重
+        :param src_id: 疾病节点ID
+        :return: {"symptoms": [症状列表], "severity_weights": {症状:权重}}
         """
         try:
-            # 1. 查询关系表获取所有dest_id
-            edges = self.search_edges(name=relation_name, src_id=src_id)
+            # 1. 查询关系表获取所有症状关系
+            edges = self.search_edges(name="症状", src_id=src_id)
             if not edges:
-                return []
+                return {"symptoms": [], "severity_weights": {}}
 
-            # 2. 收集所有目标节点ID
+            # 2. 收集所有症状节点ID
             dest_ids = [edge["public_kg_edges_dest_id"] for edge in edges if "public_kg_edges_dest_id" in edge]
             if not dest_ids:
-                return []
+                return {"symptoms": [], "severity_weights": {}}
 
-            # 3. 查询节点表获取所有目标节点信息
-            nodes = []
-            for dest_id in dest_ids:
-                node_results = self.search_nodes(id=dest_id)
-                if node_results:
-                    nodes.extend(node_results)
+            # 3. 批量查询症状节点信息
+            nodes = self.search_nodes(ids=dest_ids)
+            if not nodes:
+                return {"symptoms": [], "severity_weights": {}}
 
-            return nodes
+            # 4. 构建返回结果
+            symptoms = [node["public_kg_nodes_name"] for node in nodes if "public_kg_nodes_name" in node]
+            
+            # 假设权重存储在节点的properties字段中,如果没有则使用默认权重
+            severity_weights = {
+                node["public_kg_nodes_name"]: float(node.get("severity", 0.8)) 
+                for node in nodes 
+                if "public_kg_nodes_name" in node
+            }
+            
+            return {"symptoms": symptoms, "severity_weights": severity_weights}
 
         except Exception as e:
-            logger.error(f"Search related nodes error: {e}")
-            return None
+            logger.error(f"Get disease symptoms error: {e}")
+            return {"symptoms": [], "severity_weights": {}}
+
+    def search_concept(self, name, type):
+        nodes = self.search_nodes(name=name, type=type)
+        concepts = [
+            {"id": node["public_kg_nodes_id"], "name": node["public_kg_nodes_name"], "type": node["public_kg_nodes_category"]}
+            for node in nodes
+        ] if nodes else []
+        return {"concepts": concepts}
+
+    def get_relations(self, src_id):
+        edges = self.search_edges(name=None, src_id=src_id)
+        if not edges:
+            return {"relations": []}
+            
+        dest_ids = [edge["public_kg_edges_dest_id"] for edge in edges]
+        nodes = self.search_nodes(ids=dest_ids)
+        node_map = {node["public_kg_nodes_id"]: node for node in nodes}
+        
+        relations = []
+        for edge in edges:
+            dest_id = edge["public_kg_edges_dest_id"]
+            if dest_id in node_map:
+                relations.append({
+                    "relation_type": edge["public_kg_edges_category"],
+                    "target_id": dest_id,
+                    "target_name": node_map[dest_id].get("public_kg_nodes_name", "")
+                })
+                
+        return {"relations": relations}
+
+    def get_similar_concepts(self, id, top_k=5):
+        # 1. 先根据ID获取节点名称
+        node = self.search_nodes(id=id, limit=1)
+        if not node:
+            return {"similar_concepts": []}
+            
+        # 2. 使用节点名称进行相似搜索
+        node_name = node[0]["public_kg_nodes_name"]
+        results = self.search_nodes(name=node_name, limit=top_k)
+        
+        # 3. 格式化返回结果
+        similar_concepts = [
+            {
+                "id": result["public_kg_nodes_id"],
+                "name": result["public_kg_nodes_name"],
+                "similarity": result.get("_score", 0.0)
+            }
+            for result in results
+        ]
+        
+        return {"similar_concepts": similar_concepts}
+    
 
 if __name__ == "__main__":
     search_biz = SearchBusiness()