|
@@ -18,7 +18,7 @@ class SearchBusiness:
|
|
|
def search_nodes_and_edges(self, index, query):
|
|
|
try:
|
|
|
response = self.es.search(index=index, body=query)
|
|
|
-
|
|
|
+
|
|
|
hits = response["hits"]["hits"]
|
|
|
results = []
|
|
|
for hit in hits:
|
|
@@ -31,8 +31,7 @@ class SearchBusiness:
|
|
|
logger.error(f"Index '{index}' not found: {e}")
|
|
|
return []
|
|
|
|
|
|
- def search_nodes(self,name,type,id,limit=10,from_=0):
|
|
|
-
|
|
|
+ def search_nodes(self, name=None, type=None, id=None, ids=None, limit=10, from_=0):
|
|
|
try:
|
|
|
query = {
|
|
|
"explain": "true",
|
|
@@ -54,6 +53,8 @@ class SearchBusiness:
|
|
|
query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_name": name}})
|
|
|
if id:
|
|
|
query["query"]["bool"]["must"].append({"term": {"public_kg_nodes_id": id}})
|
|
|
+ if ids:
|
|
|
+ query["query"]["bool"]["must"].append({"terms": {"public_kg_nodes_id": ids}})
|
|
|
if type:
|
|
|
query["query"]["bool"]["must"].append({"match": {"public_kg_nodes_category": type}})
|
|
|
|
|
@@ -96,9 +97,13 @@ class SearchBusiness:
|
|
|
"""
|
|
|
根据症状名称列表获取相关疾病列表
|
|
|
:param symptom_names: 症状名称列表
|
|
|
- :return: 疾病节点列表(按命中症状次数降序排列)
|
|
|
+ :return: {
|
|
|
+ "diseases": [{"id": str, "name": str, "score": float}],
|
|
|
+ "reasoning": str
|
|
|
+ }
|
|
|
"""
|
|
|
disease_dict = {}
|
|
|
+ matched_symptoms = set()
|
|
|
|
|
|
# 对症状名称去重
|
|
|
unique_symptoms = list(set(symptom_names))
|
|
@@ -147,42 +152,118 @@ class SearchBusiness:
|
|
|
else:
|
|
|
node["count"] = 1
|
|
|
disease_dict[node_id] = node
|
|
|
+ matched_symptoms.add(symptom_name)
|
|
|
|
|
|
# 按count降序排列并过滤掉count小于阈值的疾病
|
|
|
sorted_diseases = sorted(disease_dict.values(), key=lambda x: x["count"], reverse=True)
|
|
|
filtered_diseases = [disease for disease in sorted_diseases if disease["count"] >= threshold]
|
|
|
- return filtered_diseases
|
|
|
+
|
|
|
+ # 格式化返回结果
|
|
|
+ diseases = [
|
|
|
+ {
|
|
|
+ "id": disease["public_kg_nodes_id"],
|
|
|
+ "name": disease["public_kg_nodes_name"],
|
|
|
+ "score": min(1.0, disease["count"] / len(unique_symptoms))
|
|
|
+ }
|
|
|
+ for disease in filtered_diseases
|
|
|
+ ]
|
|
|
+
|
|
|
+ reasoning = "、".join(matched_symptoms) + "是这些疾病的常见症状" if matched_symptoms else ""
|
|
|
+
|
|
|
+ return {
|
|
|
+ "diseases": diseases,
|
|
|
+ "reasoning": reasoning
|
|
|
+ }
|
|
|
|
|
|
- def search_related_nodes(self, src_id, relation_name):
|
|
|
def get_disease_symptoms(self, src_id):
    """
    Look up every symptom linked to a disease, with a severity weight per symptom.

    Edges named "症状" whose source is ``src_id`` are fetched first; their
    destination nodes are then loaded in one batch query.

    :param src_id: ID of the disease node.
    :return: ``{"symptoms": [name, ...], "severity_weights": {name: float}}``.
             Both containers are empty when nothing is found or when any
             lookup raises (the error is logged, not propagated).
    """
    default_weight = 0.8
    try:
        # 1. Fetch all "symptom" edges that start at the disease node.
        edges = self.search_edges(name="症状", src_id=src_id)
        if not edges:
            return {"symptoms": [], "severity_weights": {}}

        # 2. Collect destination IDs, dropping duplicates but keeping order.
        dest_ids = list(dict.fromkeys(
            edge["public_kg_edges_dest_id"]
            for edge in edges
            if "public_kg_edges_dest_id" in edge
        ))
        if not dest_ids:
            return {"symptoms": [], "severity_weights": {}}

        # 3. Batch-load the symptom nodes. search_nodes defaults to limit=10,
        #    which presumably caps the hit count, so ask for one hit per ID
        #    to avoid silently dropping symptoms.
        nodes = self.search_nodes(ids=dest_ids, limit=len(dest_ids))
        if not nodes:
            return {"symptoms": [], "severity_weights": {}}

        # 4. Build the result. A missing, null or non-numeric "severity"
        #    falls back to the default weight instead of failing the whole
        #    request.
        symptoms = []
        severity_weights = {}
        for node in nodes:
            if "public_kg_nodes_name" not in node:
                continue
            symptom = node["public_kg_nodes_name"]
            raw_weight = node.get("severity")
            try:
                weight = default_weight if raw_weight is None else float(raw_weight)
            except (TypeError, ValueError):
                weight = default_weight
            if symptom not in severity_weights:
                symptoms.append(symptom)
            severity_weights[symptom] = weight

        return {"symptoms": symptoms, "severity_weights": severity_weights}

    except Exception as e:
        logger.error(f"Get disease symptoms error: {e}")
        return {"symptoms": [], "severity_weights": {}}
|
|
|
+
|
|
|
def search_concept(self, name, type):
    """
    Search nodes by name and category and return them as concept records.

    :param name: node name forwarded to ``search_nodes``.
    :param type: node category forwarded to ``search_nodes``.
    :return: ``{"concepts": [{"id": ..., "name": ..., "type": ...}, ...]}``;
             the list is empty when ``search_nodes`` yields nothing.
    """
    matches = self.search_nodes(name=name, type=type)
    if not matches:
        return {"concepts": []}

    concepts = []
    for hit in matches:
        concepts.append({
            "id": hit["public_kg_nodes_id"],
            "name": hit["public_kg_nodes_name"],
            "type": hit["public_kg_nodes_category"],
        })
    return {"concepts": concepts}
|
|
|
+
|
|
|
def get_relations(self, src_id):
    """
    List the outgoing relations of a node together with their target names.

    :param src_id: ID of the source node.
    :return: ``{"relations": [{"relation_type": ..., "target_id": ...,
             "target_name": ...}, ...]}``. Edges whose target node cannot
             be loaded are omitted; the list is empty when the node has no
             edges.
    """
    edges = self.search_edges(name=None, src_id=src_id)
    if not edges:
        return {"relations": []}

    # Ignore edges without a destination and collapse duplicates (order kept)
    # so the batch query asks for each node once.
    dest_ids = list(dict.fromkeys(
        edge["public_kg_edges_dest_id"]
        for edge in edges
        if "public_kg_edges_dest_id" in edge
    ))
    if not dest_ids:
        return {"relations": []}

    # search_nodes defaults to limit=10, which presumably caps the hit count;
    # request one hit per ID so nodes with many relations are not truncated.
    # `or []` covers a None/empty result from the lookup.
    nodes = self.search_nodes(ids=dest_ids, limit=len(dest_ids)) or []
    node_map = {
        node["public_kg_nodes_id"]: node
        for node in nodes
        if "public_kg_nodes_id" in node
    }

    relations = []
    for edge in edges:
        dest_id = edge.get("public_kg_edges_dest_id")
        target = node_map.get(dest_id)
        if target is None:
            continue
        relations.append({
            "relation_type": edge.get("public_kg_edges_category", ""),
            "target_id": dest_id,
            "target_name": target.get("public_kg_nodes_name", ""),
        })

    return {"relations": relations}
|
|
|
+
|
|
|
def get_similar_concepts(self, id, top_k=5):
    """
    Find nodes whose names resemble the name of the given node.

    The node is resolved by ID, then its name is used for a name search.
    The source node itself is excluded from the result, so up to ``top_k``
    *other* nodes are returned.

    :param id: ID of the reference node.
    :param top_k: maximum number of similar nodes to return.
    :return: ``{"similar_concepts": [{"id": ..., "name": ...,
             "similarity": ...}, ...]}`` where ``similarity`` is the hit's
             ``_score`` (0.0 when absent). Empty when the node is unknown,
             has no name, or ``top_k`` is not positive.
    """
    # 1. Resolve the reference node to obtain its name.
    source = self.search_nodes(id=id, limit=1)
    if not source or top_k <= 0:
        return {"similar_concepts": []}

    node_name = source[0].get("public_kg_nodes_name")
    if not node_name:
        return {"similar_concepts": []}

    # 2. Search by name. The reference node matches its own name, so fetch
    #    one extra hit to still be able to return top_k other nodes.
    candidates = self.search_nodes(name=node_name, limit=top_k + 1) or []

    # 3. Drop the reference node (IDs compared as strings in case the stored
    #    and supplied types differ) and format the remaining hits.
    source_key = str(id)
    similar_concepts = [
        {
            "id": result["public_kg_nodes_id"],
            "name": result.get("public_kg_nodes_name", ""),
            "similarity": result.get("_score", 0.0),
        }
        for result in candidates
        if "public_kg_nodes_id" in result
        and str(result["public_kg_nodes_id"]) != source_key
    ][:top_k]

    return {"similar_concepts": similar_concepts}
|
|
|
+
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
search_biz = SearchBusiness()
|