|
@@ -4,8 +4,7 @@ sys.path.append(current_path)
|
|
import logging
|
|
import logging
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
-from agent.libs.schema import SchemaContent,SchemaData,SchemaDataItem
|
|
|
|
-from config.site import SiteConfig
|
|
|
|
|
|
+from ..config.site import SiteConfig
|
|
from elasticsearch import Elasticsearch, helpers, exceptions
|
|
from elasticsearch import Elasticsearch, helpers, exceptions
|
|
config = SiteConfig()
|
|
config = SiteConfig()
|
|
ELASTICSEARCH_USER = config.get_config("ELASTICSEARCH_USER", "/tmp")
|
|
ELASTICSEARCH_USER = config.get_config("ELASTICSEARCH_USER", "/tmp")
|
|
@@ -92,6 +91,67 @@ class SearchBusiness:
|
|
except Exception as e:
|
|
except Exception as e:
|
|
logger.error(f"Search error: {e}")
|
|
logger.error(f"Search error: {e}")
|
|
return None
|
|
return None
|
|
|
|
+
|
|
|
|
def get_symptom_diseases(self, symptom_names):
    """Return the diseases related to a list of symptom names.

    For every distinct symptom name, the matching symptom nodes and the
    nodes reachable through their "症状同义词" (symptom synonym) edges are
    collected, and each of those is followed through its "常见疾病"
    (common disease) edges. A disease is credited at most once per input
    symptom, no matter how many synonyms or duplicate nodes lead to it.

    :param symptom_names: iterable of symptom names; duplicates are
        ignored. ``None`` or an empty iterable yields an empty list.
    :return: list of disease nodes, each carrying a ``"count"`` entry with
        the number of distinct input symptoms that hit it, sorted by that
        count in descending order. Only diseases hit by at least half of
        the distinct input symptoms are returned.
    """
    if not symptom_names:
        return []

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # order of equally ranked diseases is reproducible between runs.
    unique_symptoms = list(dict.fromkeys(symptom_names))
    threshold = len(unique_symptoms) / 2

    # disease id -> number of distinct input symptoms that lead to it
    hit_counts = {}

    for symptom_name in unique_symptoms:
        symptom_nodes = self.search_nodes(name=symptom_name, type="症状")
        if not symptom_nodes:
            continue

        # Ordered set (dict keys) of the ids of every node matching this
        # name plus the ids of their synonyms.
        related_symptom_ids = {}
        for symptom_node in symptom_nodes:
            symptom_id = symptom_node.get("public_kg_nodes_id")
            if not symptom_id:
                continue
            related_symptom_ids[symptom_id] = None

            synonym_edges = self.search_edges(name="症状同义词", src_id=symptom_id)
            for edge in synonym_edges or ():
                synonym_id = edge.get("public_kg_edges_dest_id")
                if synonym_id:
                    related_symptom_ids[synonym_id] = None

        # The de-duplication scope is one input symptom: synonyms must not
        # inflate the count, but different symptoms must accumulate.
        diseases_for_symptom = {}
        for related_id in related_symptom_ids:
            disease_edges = self.search_edges(name="常见疾病", src_id=related_id)
            for edge in disease_edges or ():
                disease_id = edge.get("public_kg_edges_dest_id")
                if disease_id:
                    diseases_for_symptom[disease_id] = None

        for disease_id in diseases_for_symptom:
            hit_counts[disease_id] = hit_counts.get(disease_id, 0) + 1

    # Resolve nodes only for diseases that pass the threshold; everything
    # else would be filtered out anyway, so looking it up is wasted work.
    disease_dict = {}
    for disease_id, hits in hit_counts.items():
        if hits < threshold:
            continue
        disease_nodes = self.search_nodes(id=disease_id, type="疾病")
        for node in disease_nodes or ():
            node_id = node.get("public_kg_nodes_id")
            if node_id in disease_dict:
                continue
            node["count"] = hits
            disease_dict[node_id] = node

    # sorted() is stable, so ties keep their first-hit order.
    return sorted(disease_dict.values(), key=lambda d: d["count"], reverse=True)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
search_biz = SearchBusiness()
|
|
search_biz = SearchBusiness()
|