|
@@ -51,13 +51,13 @@ class SearchBusiness:
|
|
query["from"] = from_
|
|
query["from"] = from_
|
|
|
|
|
|
if name:
|
|
if name:
|
|
- query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_category": name}})
|
|
|
|
|
|
+ query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_name": name}})
|
|
if id:
|
|
if id:
|
|
- query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_id": id}})
|
|
|
|
|
|
+ query["query"]["bool"]["must"].append({"term": {"public_kg_nodes_id": id}})
|
|
if type:
|
|
if type:
|
|
query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_category": type}})
|
|
query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_category": type}})
|
|
|
|
|
|
- results = self.search_nodes_and_edges(index="connector-postgres-test", query=query)
|
|
|
|
|
|
+ results = self.search_nodes_and_edges(index="connector-postgresql-all", query=query)
|
|
return results
|
|
return results
|
|
except exceptions.NotFoundError as e:
|
|
except exceptions.NotFoundError as e:
|
|
logger.error(f"Index not found: {e}")
|
|
logger.error(f"Index not found: {e}")
|
|
@@ -81,9 +81,9 @@ class SearchBusiness:
|
|
if name:
|
|
if name:
|
|
query["query"]["bool"]["must"].append({"match": {"public_kg_edges_category": name}})
|
|
query["query"]["bool"]["must"].append({"match": {"public_kg_edges_category": name}})
|
|
if src_id:
|
|
if src_id:
|
|
- query["query"]["bool"]["must"].append({"match": {"public_kg_edges_src_id": src_id}})
|
|
|
|
|
|
+ query["query"]["bool"]["must"].append({"term": {"public_kg_edges_src_id": src_id}})
|
|
|
|
|
|
- results = self.search_nodes_and_edges(index="connector-postgres-test", query=query)
|
|
|
|
|
|
+ results = self.search_nodes_and_edges(index="connector-postgresql-all", query=query)
|
|
return results
|
|
return results
|
|
except exceptions.NotFoundError as e:
|
|
except exceptions.NotFoundError as e:
|
|
logger.error(f"Index not found: {e}")
|
|
logger.error(f"Index not found: {e}")
|
|
@@ -91,7 +91,7 @@ class SearchBusiness:
|
|
except Exception as e:
|
|
except Exception as e:
|
|
logger.error(f"Search error: {e}")
|
|
logger.error(f"Search error: {e}")
|
|
return None
|
|
return None
|
|
-
|
|
|
|
|
|
+
|
|
def get_symptom_diseases(self, symptom_names):
|
|
def get_symptom_diseases(self, symptom_names):
|
|
"""
|
|
"""
|
|
根据症状名称列表获取相关疾病列表
|
|
根据症状名称列表获取相关疾病列表
|
|
@@ -99,25 +99,25 @@ class SearchBusiness:
|
|
:return: 疾病节点列表(按命中症状次数降序排列)
|
|
:return: 疾病节点列表(按命中症状次数降序排列)
|
|
"""
|
|
"""
|
|
disease_dict = {}
|
|
disease_dict = {}
|
|
-
|
|
|
|
|
|
+
|
|
# 对症状名称去重
|
|
# 对症状名称去重
|
|
unique_symptoms = list(set(symptom_names))
|
|
unique_symptoms = list(set(symptom_names))
|
|
threshold = len(unique_symptoms) / 2
|
|
threshold = len(unique_symptoms) / 2
|
|
-
|
|
|
|
|
|
+
|
|
# 存储已处理的疾病ID,避免同义词重复计算
|
|
# 存储已处理的疾病ID,避免同义词重复计算
|
|
processed_disease_ids = set()
|
|
processed_disease_ids = set()
|
|
-
|
|
|
|
|
|
+
|
|
for symptom_name in unique_symptoms:
|
|
for symptom_name in unique_symptoms:
|
|
# 获取症状节点ID
|
|
# 获取症状节点ID
|
|
symptom_nodes = self.search_nodes(name=symptom_name, type="症状")
|
|
symptom_nodes = self.search_nodes(name=symptom_name, type="症状")
|
|
if not symptom_nodes:
|
|
if not symptom_nodes:
|
|
continue
|
|
continue
|
|
-
|
|
|
|
|
|
+
|
|
for symptom_node in symptom_nodes:
|
|
for symptom_node in symptom_nodes:
|
|
symptom_id = symptom_node.get("public_kg_nodes_id")
|
|
symptom_id = symptom_node.get("public_kg_nodes_id")
|
|
if not symptom_id:
|
|
if not symptom_id:
|
|
continue
|
|
continue
|
|
-
|
|
|
|
|
|
+
|
|
# 查询该症状的'症状同义词'关系获取同义词节点
|
|
# 查询该症状的'症状同义词'关系获取同义词节点
|
|
synonym_edges = self.search_edges(name="症状同义词", src_id=symptom_id)
|
|
synonym_edges = self.search_edges(name="症状同义词", src_id=symptom_id)
|
|
synonym_ids = {symptom_id} # 包含原始症状ID
|
|
synonym_ids = {symptom_id} # 包含原始症状ID
|
|
@@ -126,13 +126,13 @@ class SearchBusiness:
|
|
synonym_id = edge.get("public_kg_edges_dest_id")
|
|
synonym_id = edge.get("public_kg_edges_dest_id")
|
|
if synonym_id:
|
|
if synonym_id:
|
|
synonym_ids.add(synonym_id)
|
|
synonym_ids.add(synonym_id)
|
|
-
|
|
|
|
|
|
+
|
|
# 对每个症状ID(包括同义词)查询'常见疾病'关系
|
|
# 对每个症状ID(包括同义词)查询'常见疾病'关系
|
|
for symptom_id in synonym_ids:
|
|
for symptom_id in synonym_ids:
|
|
disease_edges = self.search_edges(name="常见疾病", src_id=symptom_id)
|
|
disease_edges = self.search_edges(name="常见疾病", src_id=symptom_id)
|
|
if not disease_edges:
|
|
if not disease_edges:
|
|
continue
|
|
continue
|
|
-
|
|
|
|
|
|
+
|
|
# 收集所有疾病节点
|
|
# 收集所有疾病节点
|
|
for edge in disease_edges:
|
|
for edge in disease_edges:
|
|
disease_id = edge.get("public_kg_edges_dest_id")
|
|
disease_id = edge.get("public_kg_edges_dest_id")
|
|
@@ -147,12 +147,43 @@ class SearchBusiness:
|
|
else:
|
|
else:
|
|
node["count"] = 1
|
|
node["count"] = 1
|
|
disease_dict[node_id] = node
|
|
disease_dict[node_id] = node
|
|
-
|
|
|
|
|
|
+
|
|
# 按count降序排列并过滤掉count小于阈值的疾病
|
|
# 按count降序排列并过滤掉count小于阈值的疾病
|
|
sorted_diseases = sorted(disease_dict.values(), key=lambda x: x["count"], reverse=True)
|
|
sorted_diseases = sorted(disease_dict.values(), key=lambda x: x["count"], reverse=True)
|
|
filtered_diseases = [disease for disease in sorted_diseases if disease["count"] >= threshold]
|
|
filtered_diseases = [disease for disease in sorted_diseases if disease["count"] >= threshold]
|
|
return filtered_diseases
|
|
return filtered_diseases
|
|
|
|
|
|
|
|
def search_related_nodes(self, src_id, relation_name):
    """
    Return the nodes linked to ``src_id`` through edges named ``relation_name``.

    The edges are fetched with ``self.search_edges`` and every distinct
    destination ID found on them is resolved with ``self.search_nodes``.

    :param src_id: ID of the source node.
    :param relation_name: name of the relation (edge) to follow.
    :return: list of related nodes, in first-seen order of their destination
             IDs; ``[]`` when there is no edge or no usable destination ID;
             ``None`` when any step raises (the error is logged).
    """
    try:
        # 1. Look up the edges that start at the source node.
        edges = self.search_edges(name=relation_name, src_id=src_id)
        if not edges:
            return []

        # 2. Collect the destination IDs.
        #    - Falsy IDs (missing key, None, empty) are dropped: the node
        #      search appears to add its ID filter only for truthy IDs, so a
        #      falsy ID would turn into an unfiltered lookup.
        #    - dict.fromkeys removes duplicates while keeping first-seen order,
        #      so each node is queried and returned only once.
        dest_ids = list(dict.fromkeys(
            dest_id
            for dest_id in (edge.get("public_kg_edges_dest_id") for edge in edges)
            if dest_id
        ))
        if not dest_ids:
            return []

        # 3. Resolve each destination ID to its node record(s).
        nodes = []
        for dest_id in dest_ids:
            node_results = self.search_nodes(id=dest_id)
            if node_results:
                nodes.extend(node_results)

        return nodes

    except Exception as e:
        # Best-effort contract kept from the original: log and signal failure
        # with None instead of propagating.
        logger.error(f"Search related nodes error: {e}")
        return None
|
|
|
|
+
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
search_biz = SearchBusiness()
|
|
search_biz = SearchBusiness()
|
|
index=""
|
|
index=""
|