SGTY 3 kuukautta sitten
vanhempi
commit
65e86f2820

+ 29 - 0
.env

@@ -0,0 +1,29 @@
+POSTGRESQL_HOST = localhost
+POSTGRESQL_DATABASE = kg
+POSTGRESQL_DATABASE_AGENT = agent
+POSTGRESQL_USER = postgres
+POSTGRESQL_PASSWORD = difyai123456
+
+#indexing
+ELASTICSEARCH_HOST=https://localhost:9200
+ELASTICSEARCH_USER=juan
+ELASTICSEARCH_PWD="p@ssw0rd"
+WORD_INDEX=word_index
+TITLE_INDEX=title_index
+CHUNC_INDEX=chunc_index
+# DeepSeek API
+# SECURITY(review): a live-looking API key and DB passwords are committed in
+# this .env — rotate these credentials and keep .env out of version control.
+DEEPSEEK_API_URL=https://api.siliconflow.cn
+DEEPSEEK_API_KEY=sk-vecnpjmtmelcefdbtbbpqvzcegopxrherbnbjhscugbpxuif
+CACHED_DATA_PATH=D:\work\03\cached_data\new
+UPDATE_DATA_PATH=D:\work\03\qz_data\update
+FACTOR_DATA_PATH=D:\work\03\qz_data\factor
+GRAPH_API_URL=http://localhost:8000
+#Embedding
+EMBEDDING_MODEL=C:\Users\jiyua\.cache\modelscope\hub\models\BAAI\bge-m3
+#EMBEDDING_MODEL=C:\Users\jiyua\.cache\modelscope\hub\models\deepseek-ai\DeepSeek-R1-Distill-Qwen-1___5b
+DOC_PATH=D:/work/03/regulations.json
+DOC_STORAGE_PATH=D:/work/03/output/docs
+TRUNC_OUTPUT_PATH=D:/work/03/output/chunc_data
+DOC_ABSTRACT_OUTPUT_PATH=D:/work/03/output/doc_abstract
+JIEBA_USER_DICT=D:/work/03/ins_expert/dict/legal_terms.txt
+JIEBA_STOP_DICT=D:/work/03/ins_expert/dict/stop_words.txt

+ 46 - 0
agent/cdss/capbility.py

@@ -0,0 +1,46 @@
+import time
+
+from agent.cdss.models.schemas import CDSSDict, CDSSInput,CDSSInt,CDSSOutput,CDSSText
+from agent.cdss.libs.cdss_helper import CDSSHelper
+import logging
+logger = logging.getLogger(__name__)
+
class CDSSCapability:
    """Entry point of the CDSS (clinical decision support) capability.

    Wraps a CDSSHelper graph instance and turns a chief-complaint keyword
    list into a ranked diagnosis output.
    """

    # populated in __init__; annotated here for readability
    cdss_helper: CDSSHelper = None

    def __init__(self):
        self.cdss_helper = CDSSHelper()
        logger.debug("CDSSCapability initialized")

    def process(self, input: CDSSInput, embeding_search: bool = True) -> CDSSOutput:
        """Run the CDSS pipeline for one request.

        Args:
            input: request payload; its "chief_complaint" value is a list of
                symptom keywords used as graph entry points.
            embeding_search: currently unused; kept for interface
                compatibility with existing callers.

        Returns:
            CDSSOutput with up to 5 scored diagnoses in ``diagnosis.value``.
        """
        logger.info("process input: %s", input)
        output = CDSSOutput()
        chief_complaint = input.get_value("chief_complaint")
        if not chief_complaint:
            # nothing to search from — return the empty output unchanged
            return output

        # Map each complaint keyword to graph node ids via the entity search;
        # only hits scoring above 1.9 are trusted as traversal start nodes.
        start_nodes = []
        start_time = time.time()
        for keyword in chief_complaint:
            results = self.cdss_helper.node_search(
                keyword, limit=10, node_type="word"
            )
            for item in results:
                if item['score'] > 1.9:
                    start_nodes.append(item['id'])
        # was a bare print(); use the module logger consistently
        logger.info("node_search执行完成,耗时:%.2f秒", time.time() - start_time)
        logger.info("cdss start from %s", start_nodes)

        result = self.cdss_helper.cdss_travel(input, start_nodes, max_hops=2)

        # Keep only the 5 best-scored diagnoses.
        for name, score in result["score_diags"][:5]:
            output.diagnosis.value[name] = score
        return output

+ 714 - 0
agent/cdss/libs/cdss_helper.py

@@ -0,0 +1,714 @@
+from hmac import new
+import os
+import sys
+import logging
+import json
+import time
+
+from service.kg_edge_service import KGEdgeService
+from db.session import get_db
+from service.kg_node_service import KGNodeService
+from service.kg_prop_service import KGPropService
+from utils.cache import Cache
+
+current_path = os.getcwd()
+sys.path.append(current_path)
+
+from community.graph_helper import GraphHelper
+from typing import List
+from agent.cdss.models.schemas import CDSSInput
+from config.site import SiteConfig
+import networkx as nx
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+
# NOTE(review): this duplicates the sys.path setup performed a few lines
# above — one of the two copies can be removed.
current_path = os.getcwd()
sys.path.append(current_path)

# Graph data cache path (generated by dump_graph_data.py)
CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
+
+
+class CDSSHelper(GraphHelper):
+
+    def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
+
+        kg_node_service = KGNodeService(next(get_db()))
+        es_result = kg_node_service.search_title_index("graph_entity_index", node_id, limit)
+        results = []
+        for item in es_result:
+            n = self.graph.nodes.get(item["id"])
+            score = item["score"]
+            if n:
+                results.append({
+                    'id': item["title"],
+                    'score': score,
+                    "name": item["title"],
+                })
+        return results
+
+    def _load_entity_data(self):
+        config = SiteConfig()
+        # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
+        print("load entity data")
+        # 这里设置了读取的属性
+        data = {"id": [], "name": [], "type": [],"is_symptom": [], "sex": [], "age": []}
+        with open(f"{CACHED_DATA_PATH}\\entities_med.json", "r", encoding="utf-8") as f:
+            entities = json.load(f)
+            for item in entities:
+                #如果id已经存在,则跳过
+                # if item[0] in data["id"]:
+                #     print(f"skip {item[0]}")
+                #     continue
+                data["id"].append(int(item[0]))
+                data["name"].append(item[1]["name"])
+                data["type"].append(item[1]["type"])
+                self._append_entity_attribute(data, item, "sex")
+                self._append_entity_attribute(data, item, "age")
+                self._append_entity_attribute(data, item, "is_symptom")
+                # item[1]["id"] = item[0]
+                # item[1]["name"] = item[0]
+                # attrs = item[1]
+                # self.graph.add_node(item[0], **attrs)
+        self.entity_data = pd.DataFrame(data)
+        self.entity_data.set_index("id", inplace=True)
+        print("load entity data finished")
+
+    def _append_entity_attribute(self, data, item, attr_name):
+        if attr_name in item[1]:
+            value = item[1][attr_name].split(":")
+            if len(value) < 2:
+                data[attr_name].append(value[0])
+            else:
+                data[attr_name].append(value[1])
+        else:
+            data[attr_name].append("")
+
+    def _load_relation_data(self):
+        config = SiteConfig()
+        # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
+        print("load relationship data")
+
+        for i in range(46):
+            if os.path.exists(f"{CACHED_DATA_PATH}\\relationship_med_{i}.json"):
+                print(f"load entity data {CACHED_DATA_PATH}\\relationship_med_{i}.json")
+                with open(f"{CACHED_DATA_PATH}\\relationship_med_{i}.json", "r", encoding="utf-8") as f:
+                    data = {"src": [], "dest": [], "type": [], "weight": []}
+                    relations = json.load(f)
+                    for item in relations:
+                        data["src"].append(int(item[0]))
+                        data["dest"].append(int(item[2]))
+                        data["type"].append(item[4]["type"])
+                        if "order" in item[4]:
+                            order = item[4]["order"].split(":")
+                            if len(order) < 2:
+                                data["weight"].append(order[0])
+                            else:
+                                data["weight"].append(order[1])
+                        else:
+                            data["weight"].append(1)
+
+                    self.relation_data = pd.concat([self.relation_data, pd.DataFrame(data)], ignore_index=True)
+
    def build_graph(self):
        """Build the in-memory NetworkX graph from cached and local update data.

        Loads entities and relations, applies local overrides, then creates a
        directed graph whose edges carry the relation attributes and whose
        nodes carry the entity attributes (name/type/sex/age/...).
        """
        # NOTE(review): these seed frames are fully replaced by
        # _load_entity_data/_load_relation_data, and their column sets
        # (e.g. "allowed_age_range") do not match the loaded ones — confirm
        # the seeds are intentional.
        self.entity_data = pd.DataFrame(
            {"id": [], "name": [], "type": [], "sex": [], "allowed_age_range": []})
        self.relation_data = pd.DataFrame({"src": [], "dest": [], "type": [], "weight": []})
        self._load_entity_data()
        self._load_relation_data()
        self._load_local_data()

        self.graph = nx.from_pandas_edgelist(self.relation_data, "src", "dest", edge_attr=True,
                                             create_using=nx.DiGraph())

        # Attach entity attributes to the graph nodes by entity id.
        nx.set_node_attributes(self.graph, self.entity_data.to_dict(orient="index"))
+
+    def _load_local_data(self):
+        # 这里加载update数据和权重数据
+        config = SiteConfig()
+        self.update_data_path = config.get_config('UPDATE_DATA_PATH')
+        self.factor_data_path = config.get_config('FACTOR_DATA_PATH')
+        print(f"load update data from {self.update_data_path}")
+        for root, dirs, files in os.walk(self.update_data_path):
+            for file in files:
+                file_path = os.path.join(root, file)
+                if file_path.endswith(".json") and file.startswith("ent"):
+                    self._load_update_entity_json(file_path)
+                if file_path.endswith(".json") and file.startswith("rel"):
+                    self._load_update_relationship_json(file_path)
+
+    def _load_update_entity_json(self, file):
+        '''load json data from file'''
+        print(f"load entity update data from {file}")
+
+        # 这里加载update数据,update数据是一个json文件,格式同cached data如下:
+
+        with open(file, "r", encoding="utf-8") as f:
+            entities = json.load(f)
+            for item in entities:
+                original_data = self.entity_data[self.entity_data.index == item[0]]
+                if original_data.empty:
+                    continue
+                original_data = original_data.iloc[0]
+                id = int(item[0])
+                name = item[1]["name"] if "name" in item[1] else original_data['name']
+                type = item[1]["type"] if "type" in item[1] else original_data['type']
+                allowed_sex_list = item[1]["allowed_sex_list"] if "allowed_sex_list" in item[1] else original_data[
+                    'allowed_sex_list']
+                allowed_age_range = item[1]["allowed_age_range"] if "allowed_age_range" in item[1] else original_data[
+                    'allowed_age_range']
+
+                self.entity_data.loc[id, ["name", "type", "allowed_sex_list", "allowed_age_range"]] = [name, type,
+                                                                                                       allowed_sex_list,
+                                                                                                       allowed_age_range]
+
+    def _load_update_relationship_json(self, file):
+        '''load json data from file'''
+        print(f"load relationship update data from {file}")
+
+        with open(file, "r", encoding="utf-8") as f:
+            relations = json.load(f)
+            for item in relations:
+                data = {}
+                original_data = self.relation_data[(self.relation_data['src'] == data['src']) &
+                                                   (self.relation_data['dest'] == data['dest']) &
+                                                   (self.relation_data['type'] == data['type'])]
+                if original_data.empty:
+                    continue
+                original_data = original_data.iloc[0]
+                data["src"] = int(item[0])
+                data["dest"] = int(item[2])
+                data["type"] = item[4]["type"]
+                data["weight"] = item[4]["weight"] if "weight" in item[4] else original_data['weight']
+
+                self.relation_data.loc[(self.relation_data['src'] == data['src']) &
+                                       (self.relation_data['dest'] == data['dest']) &
+                                       (self.relation_data['type'] == data['type']), 'weight'] = data["weight"]
+
+    def check_sex_allowed(self, node, sex):
+        # 性别过滤,假设疾病节点有一个属性叫做allowed_sex_type,值为“0,1,2”,分别代表未知,男,女
+
+        sex_allowed = self.graph.nodes[node].get('sex', None)
+
+        #sexProps = self.propService.get_props_by_ref_id(node, 'sex')
+        #if len(sexProps) > 0 and sexProps[0]['prop_value'] is not None and sexProps[0][
+            #'prop_value'] != input.pat_sex.value:
+            #continue
+        if sex_allowed:
+            if len(sex_allowed) == 0:
+                # 如果性别列表为空,那么默认允许所有性别
+                return True
+            sex_allowed_list = sex_allowed.split(',')
+            if sex not in sex_allowed_list:
+                # 如果性别不匹配,跳过
+                return False
+        return True
+
+    def check_age_allowed(self, node, age):
+        # 年龄过滤,假设疾病节点有一个属性叫做allowed_age_range,值为“6-88”,代表年龄在0-88月之间是允许的
+        # 如果说年龄小于6岁,那么我们就认为是儿童,所以儿童的年龄范围是0-6月
+        age_allowed = self.graph.nodes[node].get('age', None)
+        if age_allowed:
+            if len(age_allowed) == 0:
+                # 如果年龄范围为空,那么默认允许所有年龄
+                return True
+            age_allowed_list = age_allowed.split('-')
+            age_min = int(age_allowed_list[0])
+            age_max = int(age_allowed_list[-1])
+            if age_max ==0:
+                return True
+            if age >= age_min and age < age_max:
+                # 如果年龄范围正常,那么返回True
+                return True
+        else:
+            # 如果没有设置年龄范围,那么默认返回True
+            return True
+        return False
+
+    def check_diease_allowed(self, node):
+        is_symptom = self.graph.nodes[node].get('is_symptom', None)
+        if is_symptom == "是":
+            return False
+        return True
+
    # Shared, class-level collaborators.
    # NOTE(review): KGPropService(next(get_db())) executes at
    # class-definition time, so a DB session is opened on import —
    # confirm this is intended rather than per-instance setup.
    propService = KGPropService(next(get_db()))
    cache = Cache()
+    def cdss_travel(self, input: CDSSInput, start_nodes: List, max_hops=3):
+        """
+        基于输入的症状节点,在知识图谱中进行遍历,查找相关疾病、科室、检查和药品
+
+        参数:
+            input: CDSSInput对象,包含患者的基本信息(年龄、性别等)
+            start_nodes: 症状节点名称列表,作为遍历的起点
+            max_hops: 最大遍历深度,默认为3
+
+        返回值:
+            返回一个包含以下信息的字典:
+                - details: 按科室汇总的结果
+                - diags: 按相关性排序的疾病列表
+                - checks: 按出现频率排序的检查列表
+                - drugs: 按出现频率排序的药品列表
+                - total_diags: 疾病总数
+                - total_checks: 检查总数
+                - total_drugs: 药品总数
+
+        主要步骤:
+            1. 初始化允许的节点类型和关系类型
+            2. 将症状名称转换为节点ID
+            3. 遍历图谱查找相关疾病(STEP 1)
+            4. 查找疾病对应的科室、检查和药品(STEP 2)
+            5. 按科室汇总结果(STEP 3)
+            6. 对结果进行排序和统计(STEP 4-6)
+        """
+        start_time = time.time()
+        # 定义允许的节点类型,包括科室、疾病、药品、检查和症状
+        # 这些类型用于后续的节点过滤和路径查找
+        DEPARTMENT = ['科室', 'Department']
+        DIESEASE = ['疾病', 'Disease']
+        DRUG = ['药品', 'Drug']
+        CHECK = ['检查', 'Check']
+        SYMPTOM = ['症状', 'Symptom']
+        #allowed_types = DEPARTMENT + DIESEASE + DRUG + CHECK + SYMPTOM
+        allowed_types = DEPARTMENT + DIESEASE + SYMPTOM
+        # 定义允许的关系类型,包括has_symptom、need_check、recommend_drug、belongs_to
+        # 这些关系类型用于后续的路径查找和过滤
+        allowed_links = ['has_symptom', 'need_check', 'recommend_drug', 'belongs_to']
+        # 将输入的症状名称转换为节点ID
+        # 由于可能存在同名节点,转换后的节点ID数量可能大于输入的症状数量
+        node_ids = []
+        node_id_names = {}
+
+        # start_nodes里面重复的症状,去重同样的症状
+        start_nodes = list(set(start_nodes))
+        for node in start_nodes:
+            #print(f"searching for node {node}")
+            result = self.entity_data[self.entity_data['name'] == node]
+            # print(f"searching for node {result}")
+            for index, data in result.iterrows():
+                node_id_names[index] = data["name"]
+                node_ids = node_ids + [index]
+        #print(f"start travel from {node_id_names}")
+
+        # 这里是一个队列,用于存储待遍历的症状:
+        node_ids_filtered = []
+        for node in node_ids:
+            if self.graph.has_node(node):
+                node_ids_filtered.append(node)
+            else:
+                logger.debug(f"node {node} not found")
+        node_ids = node_ids_filtered
+
+        end_time = time.time()
+        print(f"init执行完成,耗时:{end_time - start_time:.2f}秒")
+        start_time = time.time()
+        results = {}
+        for node in node_ids:
+            visited = set()
+            temp_results = {}
+            cache_key = f"symptom_{node}"
+            cache_data = self.cache.get(cache_key)
+            if cache_data:
+                logger.debug(f"cache hit for {cache_key}")
+                temp_results = cache_data
+           
+                if results=={}:
+                    results = temp_results      
+                else:
+                    for disease_id in temp_results:
+                        path = temp_results[disease_id]["path"][0]
+                        if disease_id in results.keys():
+                            results[disease_id]["count"] = temp_results[disease_id]["count"] + 1
+                            results[disease_id]["path"].append(path)
+                        else:
+                            results[disease_id] = temp_results[disease_id]
+
+              
+                continue
+            queue = [(node, 0, node_id_names[node], {'allowed_types': allowed_types, 'allowed_links': allowed_links})] 
+
+            # 整理input的数据,这里主要是要检查输入数据是否正确,也需要做转换
+            if input.pat_age.value > 0 and input.pat_age.type == 'year':
+                # 这里将年龄从年转换为月,因为我们的图里面的年龄都是以月为单位的
+                input.pat_age.value = input.pat_age.value * 12
+                input.pat_age.type = 'month'
+
+            # STEP 1: 假设start_nodes里面都是症状,第一步我们先找到这些症状对应的疾病
+            # TODO 由于这部分是按照症状逐一去寻找疾病,所以实际应用中可以缓存这些结果
+            while queue:
+                node, depth, path, data = queue.pop(0)
+                node = int(node)
+                # 这里是通过id去获取节点的name和type
+                entity_data = self.entity_data[self.entity_data.index == node]
+                # 如果节点不存在,那么跳过
+                if entity_data.empty:
+                    continue
+                if self.graph.nodes.get(node) is None:
+                    continue
+                node_type = self.entity_data[self.entity_data.index == node]['type'].tolist()[0]
+                node_name = self.entity_data[self.entity_data.index == node]['name'].tolist()[0]
+                # print(f"node {node} type {node_type}")
+                if node_type in DIESEASE:
+                    # print(f"node {node} type {node_type} is a disease")
+                    if self.check_diease_allowed(node) == False:
+                        continue
+                    if node in temp_results.keys():
+                        temp_results[node]["count"] = temp_results[node]["count"] + 1
+                        temp_results[node]["path"].append(path)
+                    else:
+                        temp_results[node] = {"type": node_type, "count": 1, "name": node_name, 'path': [path]}
+
+                    continue
+
+                if node in visited or depth > max_hops:
+                    # print(f"{node} already visited or reach max hops")
+                    continue
+
+                visited.add(node)
+                # print(f"check edges from {node}")
+                if node not in self.graph:
+                    # print(f"node {node} not found in graph")
+                    continue
+                # todo 目前是取入边,出边是不是也有用?
+                for edge in self.graph.in_edges(node, data=True):
+                    src, dest, edge_data = edge
+                    if src not in visited and depth + 1 < max_hops:
+                        # print(f"put into queue travel from {src} to {dest}")
+                        queue.append((src, depth + 1, path, data))
+                    # else:
+                    # print(f"skip travel from {src} to {dest}")
+            self.cache.set(cache_key, temp_results)
+            if results == {}:
+                results = temp_results
+            else:
+                for disease_id in temp_results:
+                    path = temp_results[disease_id]["path"][0]
+                    if disease_id in results.keys():
+                        results[disease_id]["count"] = temp_results[disease_id]["count"] + 1
+                        results[disease_id]["path"].append(path)
+                    else:
+                        results[disease_id] = temp_results[disease_id]
+
+        end_time = time.time()
+
+        # 这里我们需要对结果进行过滤,过滤掉不满足条件的疾病
+        new_results = {}
+        print(len(results))
+        for item in results:
+            if self.check_sex_allowed(node, input.pat_sex.value) == False:
+                continue
+            if self.check_age_allowed(node, input.pat_age.value) == False:
+                continue
+            new_results[item] = results[item]
+        results = new_results
+        print(len(results))
+        print(f"STEP 1 执行完成,耗时:{end_time - start_time:.2f}秒")
+        print(f"STEP 1 遍历图谱查找相关疾病 finished")
+        # 这里输出markdonw格式日志
+        log_data = ["|疾病|症状|出现次数|是否相关"]
+        log_data.append("|--|--|--|--|")
+        for item in results:
+            data = results[item]
+            data['relevant'] = False
+            if data["count"] / len(start_nodes) > 0.5:
+                # 疾病有50%以上的症状出现,才认为是相关的
+                data['relevant'] = True
+                disease_name = data["name"]
+                key = 'disease_name_parent_' +disease_name
+                cached_value = self.cache.get(key)
+                if cached_value is None:
+                    out_edges = self.graph.out_edges(item, data=True)
+
+                    for edge in out_edges:
+                        src, dest, edge_data = edge
+                        if edge_data["type"] != '疾病相关父类':
+                            continue
+                        dest_data = self.entity_data[self.entity_data.index == dest]
+                        if dest_data.empty:
+                            continue
+                        dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
+                        self.cache.set(key, dest_name)
+                        break
+
+            # 如果data['relevant']为False,那么我们就不输出
+            if data['relevant'] == False:
+                continue
+
+
+            log_data.append(f"|{data['name']}|{','.join(data['path'])}|{data['count']}|{data['relevant']}|")
+
+        content = "疾病和症状相关性统计表格\n" + "\n".join(log_data)
+        #print(f"\n{content}")
+        # STEP 2: 找到这些疾病对应的科室,检查和药品
+        # 由于这部分是按照疾病逐一去寻找,所以实际应用中可以缓存这些结果
+        start_time = time.time()
+        print("STEP 2 查找疾病对应的科室、检查和药品 start")
+
+        for disease in results.keys():
+            # TODO 这里需要对疾病对应的科室检查药品进行加载缓存,性能可以得到很大的提升
+            if results[disease]["relevant"] == False:
+                continue
+            print(f"search data for {disease}:{results[disease]['name']}")
+            queue = []
+            queue.append((disease, 0, disease, {'allowed_types': DEPARTMENT, 'allowed_links': ['belongs_to']}))
+
+            # 这里尝试过将visited放倒for disease循环外面,但是会造成一些问题,性能提升也不明显,所以这里还是放在for disease循环里面
+            visited = set()
+
+            while queue:
+                node, depth, disease, data = queue.pop(0)
+
+                if node in visited or depth > max_hops:
+                    continue
+                visited.add(node)
+
+                entity_data = self.entity_data[self.entity_data.index == node]
+
+                # 如果节点不存在,那么跳过
+                if entity_data.empty:
+                    continue
+                node_type = self.entity_data[self.entity_data.index == node]['type'].tolist()[0]
+                node_name = self.entity_data[self.entity_data.index == node]['name'].tolist()[0]
+
+                # print(f"node {results[disease].get("name", disease)} {node_name} type {node_type}")
+                # node_type = self.graph.nodes[node].get('type')
+                if node_type in DEPARTMENT:
+                    # 展开科室,重复次数为疾病出现的次数,为了方便后续统计
+                    department_data = [node_name] * results[disease]["count"]
+                    if 'department' in results[disease].keys():
+                        results[disease]["department"] = results[disease]["department"] + department_data
+                    else:
+                        results[disease]["department"] = department_data
+                    continue
+                # if node_type in CHECK:
+                #     if 'check' in results[disease].keys():
+                #         results[disease]["check"] = list(set(results[disease]["check"]+[node_name]))
+                #     else:
+                #         results[disease]["check"] = [node_name]
+                #     continue
+                # if node_type in DRUG:
+                #     if 'drug' in results[disease].keys():
+                #         results[disease]["drug"] = list(set(results[disease]["drug"]+[node_name]))
+                #     else:
+                #         results[disease]["drug"] = [node_name]
+                #     continue
+                out_edges = self.graph.out_edges(node, data=True)
+
+                for edge in out_edges:
+                    src, dest, edge_data = edge
+                    src_data = self.entity_data[self.entity_data.index == src]
+                    if src_data.empty:
+                        continue
+                    dest_data = self.entity_data[self.entity_data.index == dest]
+                    if dest_data.empty:
+                        continue
+                    src_name = self.entity_data[self.entity_data.index == src]['name'].tolist()[0]
+                    dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
+                    dest_type = self.entity_data[self.entity_data.index == dest]['type'].tolist()[0]
+
+                    if dest_type in allowed_types:
+                        if dest not in visited and depth + 1 < max_hops:
+                            # print(f"put travel request in queue from {src}:{src_name} to {dest}:{dest_name}")
+                            queue.append((edge[1], depth + 1, disease, data))
+
+                            # TODO 可以在这里将results里面的每个疾病对应的科室,检查和药品进行缓存,方便后续使用
+        # for item in results.keys():
+        #     department_data = results[item].get("department", [])
+        #     count_data = results[item].get("count")
+        #     check_data = results[item].get("check", [])
+        #     drug_data = results[item].get("drug", [])
+        #     #缓存代码放在这里
+
+        print(f"STEP 2 finished")
+        end_time = time.time()
+        print(f"STEP 2 执行完成,耗时:{end_time - start_time:.2f}秒")
+        # 这里输出日志
+        log_data = ["|disease|count|department|check|drug|"]
+        log_data.append("|--|--|--|--|--|")
+        for item in results.keys():
+            department_data = results[item].get("department", [])
+            count_data = results[item].get("count")
+            check_data = results[item].get("check", [])
+            drug_data = results[item].get("drug", [])
+            log_data.append(
+                f"|{results[item].get("name", item)}|{count_data}|{','.join(department_data)}|{','.join(check_data)}|{','.join(drug_data)}|")
+
+        #print("疾病科室检查药品相关统计\n" + "\n".join(log_data))
+        # 日志输出完毕
+
+        # STEP 3: 对于结果按照科室维度进行汇总
+        start_time = time.time()
+        print(f"STEP 3 对于结果按照科室维度进行汇总 start")
+        final_results = {}
+        total = 0
+        for disease in results.keys():
+            disease = int(disease)
+            # 由于存在有些疾病没有科室的情况,所以这里需要做一下处理
+            departments = ['DEFAULT']
+            if 'department' in results[disease].keys():
+                departments = results[disease]["department"]
+            # else:
+            #     edges = KGEdgeService(next(get_db())).get_edges_by_nodes(src_id=disease, category='belongs_to')
+            #     #edges有可能为空,这里需要做一下处理
+            #     if len(edges) == 0:
+            #         continue
+            #     departments = [edge['dest_node']['name'] for edge in edges]
+            # 处理查询结果
+            for department in departments:
+                total += 1
+                if not department in final_results.keys():
+                    final_results[department] = {
+                        "diseases": [results[disease].get("name", disease)],
+                        "checks": results[disease].get("check", []),
+                        "drugs": results[disease].get("drug", []),
+                        "count": 1
+                    }
+                else:
+                    final_results[department]["diseases"] = final_results[department]["diseases"] + [
+                        results[disease].get("name", disease)]
+                    final_results[department]["checks"] = final_results[department]["checks"] + results[disease].get(
+                        "check", [])
+                    final_results[department]["drugs"] = final_results[department]["drugs"] + results[disease].get(
+                        "drug", [])
+                    final_results[department]["count"] += 1
+
+        # 这里是统计科室出现的分布
+        for department in final_results.keys():
+            final_results[department]["score"] = final_results[department]["count"] / total
+
+        print(f"STEP 3 finished")
+        end_time = time.time()
+        print(f"STEP 3 执行完成,耗时:{end_time - start_time:.2f}秒")
+        # 这里输出日志
+        log_data = ["|department|disease|check|drug|count|score"]
+        log_data.append("|--|--|--|--|--|--|")
+        for department in final_results.keys():
+            diesease_data = final_results[department].get("diseases", [])
+            check_data = final_results[department].get("checks", [])
+            drug_data = final_results[department].get("drugs", [])
+            count_data = final_results[department].get("count", 0)
+            score_data = final_results[department].get("score", 0)
+            log_data.append(
+                f"|{department}|{','.join(diesease_data)}|{','.join(check_data)}|{','.join(drug_data)}|{count_data}|{score_data}|")
+
+        #print("\n" + "\n".join(log_data))
+
+        # STEP 4: 对于final_results里面的disease,checks和durgs统计出现的次数并且按照次数降序排序
+        print(f"STEP 4 start")
+        start_time = time.time()
+        def sort_data(data, count=5):
+            tmp = {}
+            for item in data:
+                if item in tmp.keys():
+                    tmp[item]["count"] += 1
+                else:
+                    tmp[item] = {"count": 1}
+            sorted_data = sorted(tmp.items(), key=lambda x: x[1]["count"], reverse=True)
+            return sorted_data[:count]
+
+        for department in final_results.keys():
+            final_results[department]['name'] = department
+            final_results[department]["diseases"] = sort_data(final_results[department]["diseases"])
+            final_results[department]["checks"] = sort_data(final_results[department]["checks"])
+            final_results[department]["drugs"] = sort_data(final_results[department]["drugs"])
+
+        # 这里把科室做一个排序,按照出现的次数降序排序
+        sorted_final_results = sorted(final_results.items(), key=lambda x: x[1]["count"], reverse=True)
+
+        print(f"STEP 4 finished")
+        end_time = time.time()
+        print(f"STEP 4 执行完成,耗时:{end_time - start_time:.2f}秒")
+        # 这里输出markdown日志
+        log_data = ["|department|disease|check|drug|count|score"]
+        log_data.append("|--|--|--|--|--|--|")
+        for department in final_results.keys():
+            diesease_data = final_results[department].get("diseases")
+            check_data = final_results[department].get("checks")
+            drug_data = final_results[department].get("drugs")
+            count_data = final_results[department].get("count", 0)
+            score_data = final_results[department].get("score", 0)
+            log_data.append(f"|{department}|{diesease_data}|{check_data}|{drug_data}|{count_data}|{score_data}|")
+
+        #print("\n" + "\n".join(log_data))
+        # STEP 5: 对于final_results里面的diseases, checks和drugs统计全局出现的次数并且按照次数降序排序
+        print(f"STEP 5 start")
+        start_time = time.time()
+        checks = {}
+        drugs = {}
+        diags = {}
+        total_check = 0
+        total_drug = 0
+        total_diags = 0
+        for department in final_results.keys():
+            # 这里是提取了科室出现的概率,对于缺省的科室设置了0.1
+            # 对于疾病来说用疾病在科室中出现的次数乘以科室出现的概率作为分数
+            if department == 'DEFAULT':
+                department_factor = 0.1 
+            else:
+                department_factor = final_results[department]["score"]
+            #input.department=department时,增加权重
+            #todo 科室同义词都需要增加权重
+            if input.department.value == department:
+                department_factor = department_factor*1.1
+            for disease, data in final_results[department]["diseases"]:
+                total_diags += 1
+                key = 'disease_name_parent_' + disease
+                cached_value = self.cache.get(key)
+                if cached_value is not None:
+                    disease = cached_value
+                if disease in diags.keys():
+                    diags[disease]["count"] += data["count"]
+                    diags[disease]["score"] += data["count"] * department_factor
+                else:
+                    diags[disease] = {"count": data["count"], "score": data["count"] * department_factor}
+            # 对于检查和药品直接累加出现的次数
+            # for check, data in final_results[department]["checks"]:
+            #     total_check += 1
+            #     if check in checks.keys():
+            #         checks[check]["count"] += data["count"]
+            #     else:
+            #         checks[check] = {"count":data["count"]}
+            # for drug, data in final_results[department]["drugs"]:
+            #     total_drug += 1
+            #     if drug in drugs.keys():
+            #         drugs[drug]["count"] += data["count"]
+            #     else:
+            #         drugs[drug] = {"count":data["count"]}
+
+        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
+
+        # sorted_checks = sorted(checks.items(), key=lambda x:x[1]["count"],reverse=True)
+        # sorted_drugs = sorted(drugs.items(), key=lambda x:x[1]["count"],reverse=True)
+        print(f"STEP 5 finished")
+        end_time = time.time()
+        print(f"STEP 5 执行完成,耗时:{end_time - start_time:.2f}秒")
+        # 这里输出markdown日志
+        log_data = ["|department|disease|check|drug|count|score"]
+        log_data.append("|--|--|--|--|--|--|")
+        for department in final_results.keys():
+            diesease_data = final_results[department].get("diseases")
+            # check_data = final_results[department].get("checks")
+            # drug_data = final_results[department].get("drugs")
+            count_data = final_results[department].get("count", 0)
+            score_data = final_results[department].get("score", 0)
+            log_data.append(f"|{department}|{diesease_data}|{check_data}|{drug_data}|{count_data}|{score_data}|")
+
+        #print("这里是经过排序的数据\n" + "\n".join(log_data))
+        # STEP 6: 整合数据并返回
+        # if "department" in item.keys():
+        #     final_results["department"] = list(set(final_results["department"]+item["department"]))
+        # if "diseases" in item.keys():
+        #     final_results["diseases"] = list(set(final_results["diseases"]+item["diseases"]))
+        # if "checks" in item.keys():
+        #     final_results["checks"] = list(set(final_results["checks"]+item["checks"]))
+        # if "drugs" in item.keys():
+        #     final_results["drugs"] = list(set(final_results["drugs"]+item["drugs"]))
+        # if "symptoms" in item.keys():
+        #     final_results["symptoms"] = list(set(final_results["symptoms"]+item["symptoms"]))
+
+        return {"details": sorted_final_results,
+                "score_diags": sorted_score_diags,"total_diags": total_diags,
+                # "checks":sorted_checks, "drugs":sorted_drugs,
+                # "total_checks":total_check, "total_drugs":total_drug
+                }

+ 4 - 1
cdss/models/schemas.py

@@ -48,7 +48,8 @@ class CDSSDict:
 class CDSSInput:
     pat_age: CDSSInt = CDSSInt(type='year', value=0)
     pat_sex: CDSSInt= CDSSInt(type='sex', value=0)
-    values: List[CDSSText] 
+    department: CDSSText = CDSSText(type='department', value="")
+    values: List[CDSSText]
     
     def __init__(self, **kwargs):
         #提取kwargs中的所有字段,并将它们添加到类的属性中。这样,在创建子类时,就可以直接使用这些字段了。
@@ -66,6 +67,8 @@ class CDSSInput:
                 return value.value
         
 class CDSSOutput:
+    departments: CDSSDict = CDSSDict(type='departments', value={})
     diagnosis: CDSSDict = CDSSDict(type='diagnosis', value={})
+    diagnosis2: CDSSDict = CDSSDict(type='diagnosis', value={})
     checks: CDSSDict = CDSSDict(type='checks', value={})
     drugs: CDSSDict = CDSSDict(type='drugs', value={})

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 0 - 277
app.log


+ 0 - 64
cdss/capbility.py

@@ -1,64 +0,0 @@
-from cdss.models.schemas import CDSSDict, CDSSInput,CDSSInt,CDSSOutput,CDSSText
-from cdss.libs.cdss_helper import CDSSHelper
-
-# CDSS能力核心类,负责处理临床决策支持系统的核心逻辑
-class CDSSCapability:
-    # CDSS帮助类实例
-    cdss_helper: CDSSHelper = None
-    def __init__(self):
-        # 初始化CDSS帮助类
-        self.cdss_helper = CDSSHelper()
-    
-    # 核心处理方法,接收输入数据并返回决策支持结果
-    # @param input: CDSS输入数据
-    # @param embeding_search: 是否使用embedding搜索,默认为True
-    # @return: CDSS输出结果
-    def process(self, input: CDSSInput, embeding_search:bool = True) -> CDSSOutput:
-        start_nodes = []
-        # 获取主诉信息
-        chief_complaint = input.get_value("chief_complaint")
-        # 初始化输出对象
-        output = CDSSOutput()
-        # 如果存在主诉信息,则进行处理
-        if chief_complaint:
-            for keyword in chief_complaint:
-                # 使用帮助类进行节点搜索,查找与关键词相关的节点
-                results = self.cdss_helper.node_search(
-                    keyword, limit=10, node_type="word"
-                )
-                for item in results:
-                    if item['score']>1.9:
-                        start_nodes.append(item['id'])
-            print("cdss start from :", start_nodes)    
-            # 使用帮助类进行CDSS路径遍历,最大跳数为2
-            result = self.cdss_helper.cdss_travel(input, start_nodes,max_hops=2)
-                
-            for item in result["details"]:
-                name, data = item
-                output.diagnosis.value[name] = data
-                print(f"{name}  {data['score'] *100:.2f} %")
-
-                for disease in data["diseases"]:
-                     print(f"\t疾病:{disease[0]} ")
-
-
-                for check in data["checks"]:
-                     print(f"\t检查:{check[0]} ")
-
-
-                for drug in data["drugs"]:
-                     print(f"\t药品:{drug[0]} ")
-
-            # 输出最终推荐的检查和药品信息
-            print("最终推荐的检查和药品如下:")
-            for item in result["checks"][:5]:        
-                item[1]['score'] = item[1]['count'] / result["total_checks"]
-                output.checks.value[item[0]] = item[1]                
-                # print(f"\t检查:{item[0]}  {item[1]['score'] * 100:.2f} %")    
-                # print(f"\t检查:{item[0]}  {item[1]['count'] / result["total_checks"] * 100:.2f} %")
-            
-            for item in result["drugs"][:5]:          
-                item[1]['score'] = item[1]['count'] / result["total_checks"]
-                output.drugs.value[item[0]] = item[1]  
-                #print(f"\t药品:{item[0]} {item[1]['count'] / result["total_drugs"] * 100:.2f} %")
-        return output

+ 0 - 214
cdss/libs/cdss_helper.py

@@ -1,214 +0,0 @@
-import os
-import sys
-current_path = os.getcwd()
-sys.path.append(current_path)
-
-from community.graph_helper import GraphHelper
-from typing import List
-from cdss.models.schemas import CDSSInput
-class CDSSHelper(GraphHelper):
-    def check_sex_allowed(self, node, sex):        
-        #性别过滤,假设疾病节点有一个属性叫做allowed_sex_type,值为“0,1,2”,分别代表未知,男,女
-        sex_allowed = self.graph.nodes[node].get('allowed_sex_list', None)
-        if sex_allowed:
-            sex_allowed_list = sex_allowed.split(',')
-            if sex not in sex_allowed_list:
-                #如果性别不匹配,跳过
-                return False
-        return True
-    def check_age_allowed(self, node, age):
-        #年龄过滤,假设疾病节点有一个属性叫做allowed_age_range,值为“6-88”,代表年龄在0-88月之间是允许的
-        #如果说年龄小于6岁,那么我们就认为是儿童,所以儿童的年龄范围是0-6月
-        age_allowed = self.graph.nodes[node].get('allowed_age_range', None)
-        if age_allowed:
-            age_allowed_list = age_allowed.split('-')
-            age_min = int(age_allowed_list[0])
-            age_max = int(age_allowed_list[-1])
-            if age >= age_min and age < age_max:
-                #如果年龄范围正常,那么返回True
-                return True
-        else:
-            #如果没有设置年龄范围,那么默认返回True
-            return True
-        return False
-        
-    def cdss_travel(self, input:CDSSInput, start_nodes:List, max_hops=3):      
-        #这里设置了节点的type取值范围,可以根据实际情况进行修改,允许出现多个类型
-        DEPARTMENT=['科室']
-        DIESEASE=['疾病']
-        DRUG=['药品']
-        CHECK=['检查']
-        SYMPTOM=['症状']
-        #allowed_types = ['科室', '疾病', '药品', '检查', '症状']
-        allowed_types = DEPARTMENT + DIESEASE+ DRUG + CHECK + SYMPTOM
-        #这里设置了边的type取值范围,可以根据实际情况进行修改,允许出现多个类型
-        #不过后面的代码里面没有对边的type进行过滤,所以这里是留做以后扩展的
-        allowed_links = ['has_symptom', 'need_check', 'recommend_drug', 'belongs_to']
-        #详细解释下面一行代码
-        #queue是一个队列,里面存放了待遍历的节点,每个节点都有一个depth,表示当前节点的深度,
-        #一个path,表示当前节点的路径,一个data,表示当前节点的一些额外信息,比如allowed_types和allowed_links
-        #allowed_types表示当前节点的类型,allowed_links表示当前节点的边的类型
-        #这里的start_nodes是一个列表,里面存放了起始节点,每个起始节点都有一个depth为0,一个path为"/",一个data为{'allowed_types': allowed_types, 'allowed_links':allowed_links}
-        #这里的"/"表示根节点,因为根节点没有父节点,所以路径为"/"
-        #这里的data是一个字典,里面存放了allowed_types和allowed_links,这两个值都是列表,里面存放了允许的类型
-        queue = [(node, 0, "/", {'allowed_types': allowed_types, 'allowed_links':allowed_links}) for node in start_nodes]        
-        visited = set()      
-        results = {}
-        #整理input的数据,这里主要是要检查输入数据是否正确,也需要做转换
-        if input.pat_age.value > 0 and input.pat_age.type == 'year':
-            #这里将年龄从年转换为月,因为我们的图里面的年龄都是以月为单位的
-            input.pat_age.value = input.pat_age.value * 12
-            input.pat_age.type = 'month'
-            
-        #STEP 1: 假设start_nodes里面都是症状,第一步我们先找到这些症状对应的疾病
-        #由于这部分是按照症状逐一去寻找疾病,所以实际应用中可以缓存这些结果
-        while queue:
-            node, depth, path, data = queue.pop(0)
-            #allowed_types = data['allowed_types']
-            #allowed_links = data['allowed_links']
-            indent = depth * 4
-            node_type = self.graph.nodes[node].get('type')
-            if node_type in DIESEASE:
-                if self.check_sex_allowed(node, input.pat_sex.value) == False:
-                    continue
-                if self.check_age_allowed(node, input.pat_age.value) == False:
-                    continue
-                if node in results.keys():                 
-                    results[node]["count"] = results[node]["count"] + 1   
-                    #print("疾病", node, "出现的次数", results[node]["count"])
-                else:
-                    results[node] = {"type": node_type, "count":1, 'path':path}            
-                continue
-            
-            if node in visited or depth > max_hops:
-                #print(">>> already visited or reach max hops")
-                continue                
-            
-            visited.add(node)
-            for edge in self.graph.edges(node, data=True):                
-                src, dest, edge_data = edge
-                #if edge_data.get('type') not in allowed_links:
-                #    continue
-                if edge[1] not in visited and depth + 1 < max_hops:                    
-                    queue.append((edge[1], depth + 1, path+"/"+src, data))
-                    #print("-" * (indent+4), f"start travel from {src} to {dest}")    
-        
-        #STEP 2: 找到这些疾病对应的科室,检查和药品       
-        #由于这部分是按照疾病逐一去寻找,所以实际应用中可以缓存这些结果
-        for disease in results.keys():
-            queue = [(disease, 0, {'allowed_types': DEPARTMENT, 'allowed_links':['belongs_to']})]
-            visited = set()
-            while queue:
-                node, depth, data = queue.pop(0)                
-                indent = depth * 4
-                if node in visited or depth > max_hops:
-                    #print(">>> already visited or reach max hops")  
-                    continue                 
-                
-                visited.add(node)
-                node_type = self.graph.nodes[node].get('type')
-                if node_type in DEPARTMENT:
-                    #展开科室,重复次数为疾病出现的次数,为了方便后续统计
-                    department_data = [node] * results[disease]["count"]
-                    # if results[disease]["count"] > 1:
-                    #     print("展开了科室", node, "次数", results[disease]["count"], "次")
-                    if 'department' in results[disease].keys():
-                        results[disease]["department"] = results[disease]["department"] + department_data
-                    else:
-                        results[disease]["department"] = department_data
-                    continue
-                if node_type in CHECK:
-                    if 'check' in results[disease].keys():
-                        results[disease]["check"] = list(set(results[disease]["check"]+[node]))
-                    else:
-                        results[disease]["check"] = [node]
-                    continue
-                if node_type in DRUG:
-                    if 'drug' in results[disease].keys():
-                        results[disease]["drug"] = list(set(results[disease]["drug"]+[node]))
-                    else:
-                        results[disease]["drug"] = [node]
-                    continue
-                for edge in self.graph.edges(node, data=True):                
-                    src, dest, edge_data = edge
-                    #if edge_data.get('type') not in allowed_links:
-                    #    continue
-                    if edge[1] not in visited and depth + 1 < max_hops:                    
-                        queue.append((edge[1], depth + 1, data))
-                        #print("-" * (indent+4), f"start travel from {src} to {dest}")    
-        #STEP 3: 对于结果按照科室维度进行汇总
-        final_results = {}
-        total = 0
-        for disease in results.keys():
-            if 'department' in results[disease].keys():
-                total += 1
-                for department in results[disease]["department"]:
-                    if not department in final_results.keys():
-                        final_results[department] = {
-                            "diseases": [disease],
-                            "checks": results[disease].get("check",[]), 
-                            "drugs": results[disease].get("drug",[]),
-                            "count": 1
-                        }
-                    else:
-                        final_results[department]["diseases"] = final_results[department]["diseases"]+[disease]
-                        final_results[department]["checks"] = final_results[department]["checks"]+results[disease].get("check",[])
-                        final_results[department]["drugs"] = final_results[department]["drugs"]+results[disease].get("drug",[])
-                        final_results[department]["count"] += 1
-        
-        for department in final_results.keys():
-            final_results[department]["score"] = final_results[department]["count"] / total
-        #STEP 4: 对于final_results里面的disease,checks和durgs统计出现的次数并且按照次数降序排序
-        def sort_data(data, count=5):
-            tmp = {}
-            for item in data:
-                if item in tmp.keys():
-                    tmp[item]["count"] +=1
-                else:
-                    tmp[item] = {"count":1}
-            sorted_data = sorted(tmp.items(), key=lambda x:x[1]["count"],reverse=True)
-            return sorted_data[:count]
-        
-        for department in final_results.keys():
-            final_results[department]['name'] = department
-            final_results[department]["diseases"] = sort_data(final_results[department]["diseases"])
-            final_results[department]["checks"] = sort_data(final_results[department]["checks"])
-            final_results[department]["drugs"] = sort_data(final_results[department]["drugs"])
-        
-        sorted_final_results = sorted(final_results.items(), key=lambda x:x[1]["count"],reverse=True)
-        
-        #STEP 5: 对于final_results里面的checks和durgs统计全局出现的次数并且按照次数降序排序
-        checks = {}
-        drugs ={}
-        total_check = 0
-        total_drug = 0
-        for department in final_results.keys():
-            for check, data in final_results[department]["checks"]:
-                total_check += 1
-                if check in checks.keys():
-                    checks[check]["count"] += data["count"]
-                else:
-                    checks[check] = {"count":data["count"]}
-            for drug, data in final_results[department]["drugs"]:
-                total_drug += 1
-                if drug in drugs.keys():
-                    drugs[drug]["count"] += data["count"]
-                else:
-                    drugs[drug] = {"count":data["count"]}
-        
-        sorted_checks = sorted(checks.items(), key=lambda x:x[1]["count"],reverse=True)
-        sorted_drugs = sorted(drugs.items(), key=lambda x:x[1]["count"],reverse=True)
-        #STEP 6: 整合数据并返回
-            # if "department" in item.keys():
-            #     final_results["department"] = list(set(final_results["department"]+item["department"]))
-            # if "diseases" in item.keys():
-            #     final_results["diseases"] = list(set(final_results["diseases"]+item["diseases"]))
-            # if "checks" in item.keys():
-            #     final_results["checks"] = list(set(final_results["checks"]+item["checks"]))
-            # if "drugs" in item.keys():
-            #     final_results["drugs"] = list(set(final_results["drugs"]+item["drugs"]))
-            # if "symptoms" in item.keys():
-            #     final_results["symptoms"] = list(set(final_results["symptoms"]+item["symptoms"]))
-        return {"details":sorted_final_results, 
-                "checks":sorted_checks, "drugs":sorted_drugs, 
-                "total_checks":total_check, "total_drugs":total_drug}

+ 1 - 1
community/community_report.py

@@ -186,7 +186,7 @@ def generate_report(G, partition):
 
 if __name__ == "__main__":
     try:
-        from graph_helper2 import GraphHelper
+        from graph_helper import GraphHelper
         graph_helper = GraphHelper()
         G = graph_helper.graph
         print("graph loaded")

+ 15 - 7
community/dump_graph_data.py

@@ -67,7 +67,7 @@ def get_names(src_id, dest_id):
 
 def get_relationships():
     #COUNT_SQL = "select count(*) from kg_edges where version=:version"
-    COUNT_SQL = "select count(*) from kg_edges"
+    COUNT_SQL = "select count(*) from kg_edges where status=0"
     result = db.execute(text(COUNT_SQL))
     count = result.scalar()
 
@@ -78,14 +78,20 @@ def get_relationships():
     file_index = 1
     while start < count:    
         #sql = """select id,name,category,src_id,dest_id from kg_edges where version=:version order by id limit :batch OFFSET :start"""
-        sql = """select id,name,category,src_id,dest_id from kg_edges order by id limit :batch OFFSET :start"""
+        sql = """select id,name,category,src_id,dest_id from kg_edges where status=0 order by id limit :batch OFFSET :start"""
         result = db.execute(text(sql), {'start':start, 'batch':batch})
         #["发热",{"type":"症状","description":"发热是诊断的主要目的,用于明确发热病因。"}]
         row_count = 0
         for row in result:
             id,name,category,src_id,dest_id = row
             props = get_props(id)
-            src_id, src_name, src_category, dest_id, dest_name, dest_category = get_names(src_id, dest_id)
+            #如果get_names异常,跳过
+            try:
+                src_id, src_name, src_category, dest_id, dest_name, dest_category = get_names(src_id, dest_id)
+            except Exception as e:
+                print(e)
+                print("src_id: ", src_id, "dest_id: ", dest_id)
+                continue
             #src_name或dest_name为空,说明节点不存在,跳过
             if src_name == "" or dest_name == "":
                 continue
@@ -104,7 +110,9 @@ def get_relationships():
     with open(current_path+"\\relationship_med_0.json", "w", encoding="utf-8") as f:
         f.write(json.dumps(edges, ensure_ascii=False,indent=4))
 
-#导出节点数据
-get_entities()
-#导出关系数据
-get_relationships()
+
+if __name__ == "__main__":
+    #导出节点数据
+    get_entities()
+    #导出关系数据
+    get_relationships()

+ 1 - 23
community/graph_helper.py

@@ -24,7 +24,7 @@ def load_entity_data():
         return entities
 
 def load_relation_data(g):
-    for i in range(0):
+    for i in range(89):
         if not os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
             continue
         if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
@@ -32,12 +32,8 @@ def load_relation_data(g):
             with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
                 relations = json.load(f)
                 for item in relations:
-
                     if item[0] is None or item[1] is None or item[2] is None:
                         continue
-                    #删除item[2]['weight']属性
-                    if 'weight' in item[2]:
-                        del item[2]['weight']
                     g.add_edge(item[0], item[1], weight=1, **item[2])
 
 
@@ -89,24 +85,6 @@ class GraphHelper:
                     **n
                 })
         return results
-        
-        # for n in self.graph.nodes(data=True):
-        #     match = True
-        #     if node_id and n[0] != node_id:
-        #         continue
-        #     if node_type and n[1].get('type') != node_type:
-        #         continue
-        #     if filters:
-        #         for k, v in filters.items():
-        #             if n[1].get(k) != v:
-        #                 match = False
-        #                 break
-        #     if match:
-        #         results.append({
-        #             'id': n[0],
-        #             **n[1]
-        #         })
-        return results
 
     def edge_search(self, source=None, target=None, edge_type=None, min_weight=0):
         """边检索功能"""

+ 0 - 273
community/graph_helper2.bak

@@ -1,273 +0,0 @@
-"""
-医疗知识图谱助手模块
-
-本模块提供构建医疗知识图谱、执行社区检测、路径查找等功能
-
-主要功能:
-1. 构建医疗知识图谱
-2. 支持节点/关系检索
-3. 社区检测
-4. 路径查找
-5. 邻居分析
-"""
-import networkx as nx
-import argparse
-import json
-from tabulate import tabulate
-import leidenalg
-import igraph as ig
-import sys,os
-
-from db.session import get_db
-from service.kg_node_service import KGNodeService
-
-# 当前工作路径
-current_path = os.getcwd()
-sys.path.append(current_path)
-
-# Leiden算法社区检测的分辨率参数,控制社区划分的粒度
-RESOLUTION = 0.07
-
-# 图谱数据缓存路径(由dump_graph_data.py生成)
-CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
-
-def load_entity_data():
-    """
-    加载实体数据
-    
-    返回:
-        list: 实体数据列表,每个元素格式为[node_id, attributes_dict]
-    """
-    print("load entity data")
-    with open(os.path.join(CACHED_DATA_PATH,'entities_med.json'), "r", encoding="utf-8") as f:
-        entities = json.load(f)
-        return entities
-
-def load_relation_data(g):
-    """
-    分块加载关系数据
-    
-    参数:
-        g (nx.Graph): 要添加边的NetworkX图对象
-    
-    说明:
-        1. 支持分块加载多个关系文件(relationship_med_0.json ~ relationship_med_29.json)
-        2. 每个关系项格式为[source, target, {relation_attrs}]
-    """
-    for i in range(89):
-        if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
-            print("load entity data", os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"))
-            with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
-       
-                relations = json.load(f)
-                for item in relations:
-                    # 添加带权重的边,并存储关系属性
-                    weight = int(item[2].pop('weight', '8').replace('权重:', ''))
-                    #如果item[0]或者item[1]为空或null,则跳过
-                    if item[0] is None or item[1] is None:
-                        continue
-                    g.add_edge(item[0], item[1], weight=weight, **item[2])
-
-class GraphHelper:
-    """
-    医疗知识图谱助手类
-    
-    功能:
-        - 构建医疗知识图谱
-        - 支持节点/关系检索
-        - 社区检测
-        - 路径查找
-        - 邻居分析
-    
-    属性:
-        graph: NetworkX图对象,存储知识图谱
-    """
-    
-    def __init__(self):
-        """
-        初始化方法
-        
-        功能:
-            1. 初始化graph属性为None
-            2. 调用build_graph()方法构建知识图谱
-        """
-        self.graph = None
-        self.build_graph()
-
-    def build_graph(self):
-        """构建知识图谱
-        
-        步骤:
-            1. 初始化空图
-            2. 加载实体数据作为节点
-            3. 加载关系数据作为边
-        """
-        self.graph = nx.Graph()
-        
-        # 加载节点数据(疾病、症状等)
-        entities = load_entity_data()
-        for item in entities:
-            node_id = item[0]
-            attrs = item[1]
-            self.graph.add_node(node_id, **attrs)
-        
-        # 加载边数据(疾病-症状关系等)
-        load_relation_data(self.graph)
-
-    def node_search2(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
-        """节点检索功能"""
-
-        kg_node_service = KGNodeService(next(get_db()))
-        es_result = kg_node_service.search_title_index("graph_entity_index", node_id, limit)
-        results = []
-        for item in es_result:
-            n = self.graph.nodes.get(item["title"])
-            score = item["score"]
-            if n:
-                results.append({
-                    'id': item["title"],
-                    'score': score,
-                    **n
-                })
-        return results
-
-    def node_search(self, node_id=None, node_type=None, filters=None):
-        """节点检索
-        
-        参数:
-            node_id (str): 精确匹配节点ID
-            node_type (str): 按节点类型过滤
-            filters (dict): 自定义属性过滤,格式为{属性名: 期望值}
-            
-        返回:
-            list: 匹配的节点列表,每个节点包含id和所有属性
-        """
-        results = []
-        
-        # 遍历所有节点进行多条件过滤
-        for n in self.graph.nodes(data=True):
-            match = True
-            if node_id and n[0] != node_id:
-                continue
-            if node_type and n[1].get('type') != node_type:
-                continue
-            if filters:
-                for k, v in filters.items():
-                    if n[1].get(k) != v:
-                        match = False
-                        break
-            if match:
-                results.append({
-                    'id': n[0],
-                    **n[1]
-                })
-        return results
-
-    def neighbor_search(self, center_node, hops=2):
-        """邻居节点检索
-        
-        参数:
-            center_node (str): 中心节点ID
-            hops (int): 跳数(默认2跳)
-            
-        返回:
-            tuple: (邻居实体列表, 关联关系列表)
-            
-        算法说明:
-            使用BFS算法进行层级遍历,时间复杂度O(k^d),其中k为平均度数,d为跳数
-        """
-        # 执行BFS遍历
-        visited = {center_node: 0}
-        queue = [center_node]
-        relations = []
-        
-        while queue:
-            try:
-                current = queue.pop(0)
-                current_hop = visited[current]
-                
-                if current_hop >= hops:
-                    continue
-                
-                # 遍历相邻节点
-                for neighbor in self.graph.neighbors(current):
-                    if neighbor not in visited:
-                        visited[neighbor] = current_hop + 1
-                        queue.append(neighbor)
-                    
-                    # 记录边关系
-                    edge_data = self.graph.get_edge_data(current, neighbor)
-                    relations.append({
-                        'src_name': current,
-                        'dest_name': neighbor,
-                        **edge_data
-                    })
-            except Exception as e:
-                print(f"Error processing node {current}: {str(e)}")
-                continue
-        
-        # 提取邻居实体(排除中心节点)
-        entities = [
-            {'id': n, **self.graph.nodes[n]}
-            for n in visited if n != center_node
-        ]
-        
-        return entities, relations
-
-    def detect_communities(self):
-        """使用Leiden算法进行社区检测
-        
-        返回:
-            tuple: (添加社区属性的图对象, 社区划分结果)
-            
-        算法说明:
-            1. 将NetworkX图转换为igraph格式
-            2. 使用Leiden算法(分辨率参数RESOLUTION=0.07)
-            3. 将社区标签添加回原始图
-            4. 时间复杂度约为O(n log n)
-        """
-        # 转换图格式
-        ig_graph = ig.Graph.from_networkx(self.graph)
-        
-        # 执行Leiden算法
-        partition = leidenalg.find_partition(
-            ig_graph, 
-            leidenalg.CPMVertexPartition,
-            resolution_parameter=RESOLUTION,
-            n_iterations=2
-        )
-        
-        # 添加社区属性
-        for i, node in enumerate(self.graph.nodes()):
-            self.graph.nodes[node]['community'] = partition.membership[i]
-        
-        return self.graph, partition
-
-    def find_paths(self, source, target, max_paths=5):
-        """查找所有简单路径
-        
-        参数:
-            source (str): 起始节点
-            target (str): 目标节点
-            max_paths (int): 最大返回路径数
-            
-        返回:
-            dict: 包含最短路径和所有路径的结果字典
-            
-        注意:
-            使用Yen算法寻找top k最短路径,时间复杂度O(kn(m + n log n))
-        """
-        result = {'shortest_path': [], 'all_paths': []}
-        
-        try:
-            # 使用Dijkstra算法找最短路径
-            shortest_path = nx.shortest_path(self.graph, source, target, weight='weight')
-            result['shortest_path'] = shortest_path
-            
-            # 使用Yen算法找top k路径
-            all_paths = list(nx.shortest_simple_paths(self.graph, source, target, weight='weight'))[:max_paths]
-            result['all_paths'] = all_paths
-        except nx.NetworkXNoPath:
-            pass
-            
-        return result

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 0 - 2535071
community/web/cached_data/entities_med.json


Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 0 - 40343
community/web/cached_data/relationship_med_0.json


+ 53 - 0
config/site.py

@@ -0,0 +1,53 @@
+import os
+from dotenv import load_dotenv
+from urllib.parse import quote
+
+load_dotenv()
+
+
+class SiteConfig:
+    def __init__(self):
+        self.load_config()
+    
+    def load_config(self):        
+        self.config = {
+            "SITE_NAME": os.getenv("SITE_NAME", "DEMO"),
+            "SITE_DESCRIPTION": os.getenv("SITE_DESCRIPTION", "ChatGPT"),
+            "SITE_URL": os.getenv("SITE_URL", ""),
+            "SITE_LOGO": os.getenv("SITE_LOGO", ""),
+            "SITE_FAVICON": os.getenv("SITE_FAVICON"),
+            'ELASTICSEARCH_HOST': os.getenv("ELASTICSEARCH_HOST"),
+            'ELASTICSEARCH_USER': os.getenv("ELASTICSEARCH_USER"),
+            'ELASTICSEARCH_PWD': os.getenv("ELASTICSEARCH_PWD"),
+            'WORD_INDEX': os.getenv("WORD_INDEX"),
+            'TITLE_INDEX': os.getenv("TITLE_INDEX"),
+            'CHUNC_INDEX': os.getenv("CHUNC_INDEX"),
+            'DEEPSEEK_API_URL': os.getenv("DEEPSEEK_API_URL"),
+            'DEEPSEEK_API_KEY': os.getenv("DEEPSEEK_API_KEY"),
+            'CACHED_DATA_PATH': os.getenv("CACHED_DATA_PATH"),
+            'UPDATE_DATA_PATH': os.getenv("UPDATE_DATA_PATH"),
+            'FACTOR_DATA_PATH': os.getenv("FACTOR_DATA_PATH"),
+            'GRAPH_API_URL': os.getenv("GRAPH_API_URL"),
+            'EMBEDDING_MODEL': os.getenv("EMBEDDING_MODEL"),
+            'DOC_PATH': os.getenv("DOC_PATH"),
+            'DOC_STORAGE_PATH': os.getenv("DOC_STORAGE_PATH"),
+            'TRUNC_OUTPUT_PATH': os.getenv("TRUNC_OUTPUT_PATH"),
+            'DOC_ABSTRACT_OUTPUT_PATH': os.getenv("DOC_ABSTRACT_OUTPUT_PATH"),
+            'JIEBA_USER_DICT': os.getenv("JIEBA_USER_DICT"),
+            'JIEBA_STOP_DICT': os.getenv("JIEBA_STOP_DICT"),
+            'POSTGRESQL_HOST':  os.getenv("POSTGRESQL_HOST","localhost"),
+            'POSTGRESQL_DATABASE':  os.getenv("POSTGRESQL_DATABASE","kg"),
+            'POSTGRESQL_USER':  os.getenv("POSTGRESQL_USER","dify"),
+            'POSTGRESQL_PASSWORD':  os.getenv("POSTGRESQL_PASSWORD",quote("difyai123456")),
+        }
+    def get_config(self, config_name): 
+        config_name = config_name.upper()       
+        if config_name in self.config:            
+            return self.config[config_name]
+        else:
+            return None
+    def check_config(self, config_list):
+        for item in config_list:
+            if not self.get_config(item):
+                raise ValueError(f"Configuration '{item}' is not set.")
+      

+ 2 - 2
db/session.py

@@ -11,13 +11,13 @@ DB_HOST = os.getenv("DB_HOST", "173.18.12.203")
 DB_PORT = os.getenv("DB_PORT", "5432")
 DB_USER = os.getenv("DB_USER", "knowledge")
 DB_PASS = os.getenv("DB_PASSWORD", "qwer1234.")
-DB_NAME = os.getenv("DB_NAME", "postgres")
+DB_NAME = os.getenv("DB_NAME", "medkg")
 
 DATABASE_URL = f"postgresql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
 
 engine = create_engine(
     DATABASE_URL,
-    pool_size=20,
+    pool_size=30,
     max_overflow=10,
     pool_pre_ping=True,
     connect_args={'options': '-c search_path=public'},

+ 5 - 1
main.py

@@ -6,6 +6,8 @@ from typing import Optional, Set
 # 导入FastAPI及相关模块
 import os
 import uvicorn
+
+from agent.cdss.capbility import CDSSCapability
 from router.knowledge_dify import dify_kb_router
 from router.knowledge_saas import saas_kb_router
 from router.text_search import text_search_router
@@ -14,7 +16,7 @@ from router.knowledge_nodes_api import knowledge_nodes_api_router
 
 # 配置日志
 logging.basicConfig(
-    level=logging.INFO,
+    level=logging.ERROR,
     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
     handlers=[
         logging.StreamHandler(),
@@ -118,6 +120,8 @@ async def interceptor_middleware(request: Request, call_next):
     return response
 
 
+#capability = CDSSCapability()
+
 if __name__ == "__main__":
     logger.info('Starting uvicorn server...2222')
     #uvicorn main:app --host 0.0.0.0 --port 8000 --reload

+ 23 - 128
router/graph_router.py

@@ -1,6 +1,7 @@
 import sys,os
 
-from community.graph_helper import GraphHelper
+from agent.cdss.capbility import CDSSCapability
+from agent.cdss.models.schemas import CDSSInput, CDSSInt, CDSSText
 from model.response import StandardResponse
 
 current_path = os.getcwd()
@@ -17,20 +18,21 @@ import json
 
 
 router = APIRouter(prefix="/graph", tags=["Knowledge Graph"])
-graph_helper = GraphHelper()
-
-
 
 @router.get("/nodes/recommend", response_model=StandardResponse)
 async def recommend(
     chief: str
 ):
+    start_time = time.time()
     app_id = "256fd853-60b0-4357-b11b-8114b4e90ae0"
     conversation_id = get_conversation_id(app_id)
     result = call_chat_api(app_id, conversation_id, chief)
     json_data = json.loads(result)
-    keyword = " ".join(json_data["chief_complaint"])
-    return await neighbor_search(keyword=keyword, neighbor_type='Check',limit=10)
+    keyword = " ".join(json_data["symptoms"])
+    result = await neighbor_search(keyword=keyword, neighbor_type='Check', limit=10)
+    end_time = time.time()
+    print(f"recommend执行完成,耗时:{end_time - start_time:.2f}秒")
+    return result;
 
 
 @router.get("/nodes/neighbor_search", response_model=StandardResponse)
@@ -45,132 +47,22 @@ async def neighbor_search(
     根据关键词和属性过滤条件搜索图谱节点
     """
     try:
-        start_time = time.time()
-        print(f"开始执行neighbor_search,参数:keyword={keyword}, limit={limit}, node_type={node_type}, neighbor_type={neighbor_type}, min_degree={min_degree}")
-        scores_factor = 1.7
-        results = []
-        diseases = {}
-
-        has_good_result = False
-
-        if not has_good_result:
-            keywords = keyword.split(" ")
-            new_results = []
-            for item in keywords:
-                if len(item) > 1:
-                    results = graph_helper.node_search2(
-                        item,
-                        limit=limit,
-                        node_type=node_type,
-                        min_degree=min_degree
-                    )
-
-                    for result_item in results:
-                        if result_item["score"] > scores_factor:
-                            new_results.append(result_item)
-                            if result_item["type"] == "Disease":
-                                if result_item["id"] not in diseases:
-                                    diseases[result_item["id"]] =  {
-                                                        "id":result_item["id"],
-                                                        "type":1,
-                                                        "count":1
-                                                    }
-                                else:
-                                    diseases[result_item["id"]]["count"] = diseases[result_item["id"]]["count"] + 1
-                                has_good_result = True
-            results = new_results
-            print("扩展搜索的结果数量:",len(results))
-
-        neighbors_data = {}
 
-        for item in results:
-            entities, relations = graph_helper.neighbor_search(item["id"], 1)
-            max = 20 #因为类似发热这种疾病会有很多关联的疾病,所以需要防止检索范围过大,设置了上限
-            for neighbor in entities:
-
-                #有时候数据中会出现链接有的节点,但是节点数据中缺失的情况,所以这里要检查
-                if "type" not in neighbor.keys():
-                    continue
-                if neighbor["type"] == neighbor_type:
-                    #如果这里正好找到了要求检索的节点类型
-                    if neighbor["id"] not in neighbors_data:
-                        neighbors_data[neighbor["id"]] =  {
-                                            "id":neighbor["id"],
-                                            "type":neighbor["type"],
-                                            "count":1
-                                        }
-                    else:
-                         neighbors_data[neighbor["id"]]["count"] = neighbors_data[neighbor["id"]]["count"] + 1
-                else:
-                    #如果这里找到的节点是个疾病,那么就再检索一层,看看是否有符合要求的节点类型
-                    if neighbor["type"] == "Disease":
-                        if neighbor["id"] not in diseases:
-                            diseases[neighbor["id"]] =  {
-                                                "id":neighbor["id"],
-                                                "type":"Disease",
-                                                "count":1
-                                            }
-                        else:
-                            diseases[neighbor["id"]]["count"] = diseases[neighbor["id"]]["count"] + 1
-                        disease_entities, relations = graph_helper.neighbor_search(neighbor["id"], 1)
-                        for disease_neighbor in disease_entities:
-                            #有时候数据中会出现链接有的节点,但是节点数据中缺失的情况,所以这里要检查
-                            if "type" in disease_neighbor.keys():
-                                if disease_neighbor["type"] == neighbor_type:
-                                    if disease_neighbor["id"] not in neighbors_data:
-                                        neighbors_data[disease_neighbor["id"]] = {
-                                            "id":disease_neighbor["id"],
-                                            "type":disease_neighbor["type"],
-                                            "count":1
-                                        }
-                                    else:
-                                        neighbors_data[disease_neighbor["id"]]["count"] = neighbors_data[disease_neighbor["id"]]["count"] + 1
-                        #最多搜索的范围是max个疾病
-                        max = max - 1
-                        if max == 0:
-                            break
-        disease_data = [diseases[k] for k in diseases]
-        disease_data = sorted(disease_data, key=lambda x:x["count"],reverse=True)
-        data = [neighbors_data[k] for k in neighbors_data if neighbors_data[k]["type"] == "Check"]
-
-        data = sorted(data, key=lambda x:x["count"],reverse=True)
-
-        if len(data) > 10:
-            data = data[:10]
-            factor = 1.0
-            total = 0.0
-            for item in data:
-                total = item["count"] * factor + total
-            for item in data:
-                item["count"] = item["count"] / total
-            factor = factor * 0.9
-
-        if len(disease_data) > 10:
-            disease_data = disease_data[:10]
-            factor = 1.0
-            total = 0.0
-            for item in disease_data:
-                total = item["count"] * factor + total
-            for item in disease_data:
-                item["count"] = item["count"] / total
-            factor = factor * 0.9
+        print(f"开始执行neighbor_search,参数:keyword={keyword}, limit={limit}, node_type={node_type}, neighbor_type={neighbor_type}, min_degree={min_degree}")
+        keywords = keyword.split(" ")
 
-        for item in data:
-            item["type"] = 3
-            item["name"] = item["id"]
-            item["rate"] = round(item["count"] * 100, 2)
-        for item in disease_data:
-            item["type"] = 1
-            item["name"] = item["id"]
-            item["rate"] = round(item["count"] * 100, 2)
-        end_time = time.time()
-        print(f"neighbor_search执行完成,耗时:{end_time - start_time:.2f}秒")
+        record = CDSSInput(
+            #pat_age=CDSSInt(type="month", value=24),
+            #pat_sex=CDSSText(type="sex", value="女"),
+            chief_complaint=keywords,
+        )
+        # 使用从main.py导入的capability实例处理CDSS逻辑
+        output = capability.process(input=record)
 
+        print(output.diagnosis.value)
         return StandardResponse(
             success=True,
-            records={"可能诊断":disease_data,"推荐检验":data},
-            error_code = 0,
-            error_msg=f"Found {len(results)} nodes"
+            data={"可能诊断":output.diagnosis.value,"可能诊断2":output.diagnosis2.value,"推荐检验":output.checks.value,"症状":keywords}
         )
     except Exception as e:
         print(e)
@@ -180,5 +72,8 @@ async def neighbor_search(
             error_code=500,
             error_msg=str(e)
         )
-
+capability = CDSSCapability()
+#def get_capability():
+    #from main import capability
+    #return capability
 graph_router = router

+ 1 - 1
router/knowledge_saas.py

@@ -49,7 +49,7 @@ class VectorSearchRequest(BaseModel):
 class NodeRelationshipRequest(BaseModel):
     src_id: int
 
-@router.post("/nodes/paginated_search2", response_model=StandardResponse)
+@router.post("/nodes/paginated_search", response_model=StandardResponse)
 async def paginated_search(
     payload: PaginatedSearchRequest,
     db: Session = Depends(get_db)

+ 6 - 7
service/kg_edge_service.py

@@ -69,7 +69,7 @@ class KGEdgeService:
             logger.error(f"删除边失败: {str(e)}")
             raise ValueError("Delete failed")
 
-    def get_edges_by_nodes(self, src_id: Optional[int], dest_id: Optional[int], and_logic: bool = True):
+    def get_edges_by_nodes(self, src_id: Optional[int]= None, dest_id: Optional[int]= None, category: Optional[str] = None):
         if src_id is None and dest_id is None:
             raise ValueError("至少需要提供一个有效的查询条件")
         try:
@@ -78,11 +78,10 @@ class KGEdgeService:
                 filters.append(KGEdge.src_id == src_id)
             if dest_id is not None:
                 filters.append(KGEdge.dest_id == dest_id)
+            if category is not None:
+                filters.append(KGEdge.category == category)
+            edges = self.db.query(KGEdge).filter(*filters).all()
 
-            if and_logic:
-                edges = self.db.query(KGEdge).filter(*filters).all()
-            else:
-                edges = self.db.query(KGEdge).filter(or_(*filters)).all()
             from service.kg_node_service import KGNodeService
             node_service = KGNodeService(self.db)
             result = []
@@ -94,8 +93,8 @@ class KGEdgeService:
                         'dest_id': edge.dest_id,
                         'name': edge.name,
                         'version': edge.version,
-                        'src_node': node_service.get_node(edge.src_id),
-                        'dest_node': node_service.get_node(edge.dest_id)
+                        #'src_node': node_service.get_node(edge.src_id),
+                        #'dest_node': node_service.get_node(edge.dest_id)
                     }
                     result.append(edge_info)
                 except ValueError as e:

+ 4 - 4
service/kg_node_service.py

@@ -13,7 +13,7 @@ from service.kg_edge_service import KGEdgeService
 
 logger = logging.getLogger(__name__)
 DISTANCE_THRESHOLD = 0.65
-DISTANCE_THRESHOLD2 = 0.3
+DISTANCE_THRESHOLD2 = 0.4
 class KGNodeService:
     def __init__(self, db: Session):
         self.db = db
@@ -36,8 +36,7 @@ class KGNodeService:
                 KGNode.embedding.l2_distance(query_embedding).label('distance')
             )
             .filter(KGNode.status == 0)
-            #过滤掉version不等于'er'的节点
-            .filter(KGNode.version != 'ER')
+            #todo 是否能提高性能 改成余弦算法
             .filter(KGNode.embedding.l2_distance(query_embedding) <= DISTANCE_THRESHOLD2)
             .order_by('distance').limit(top_k).all()
         )
@@ -198,10 +197,11 @@ class KGNodeService:
 
         while True:
             try:
+                #下面的查询语句,增加根据id排序,防止并发问题
                 nodes = self.db.query(KGNode).filter(
                     #KGNode.version == 'ER',
                     KGNode.embedding == None
-                ).offset(offset).limit(batch_size).all()
+                ).order_by(KGNode.id).offset(offset).limit(batch_size).all()
 
                 if not nodes:
                     break

+ 5 - 2
service/kg_prop_service.py

@@ -11,9 +11,12 @@ class KGPropService:
     def __init__(self, db: Session):
         self.db = db
 
-    def get_props_by_ref_id(self, ref_id: int) -> List[dict]:
+    def get_props_by_ref_id(self, ref_id: int, prop_name: str = None) -> List[dict]:
         try:
-            props = self.db.query(KGProp).filter(KGProp.ref_id == ref_id).all()
+            query = self.db.query(KGProp).filter(KGProp.ref_id == ref_id)
+            if prop_name:
+                query = query.filter(KGProp.prop_name == prop_name)
+            props = query.all()
             return [{
                 'id': p.id,
                 'category': p.category,

+ 0 - 44
tests/community/test_dump_graph_data.py

@@ -1,44 +0,0 @@
-import unittest
-import unittest
-import json
-import os
-from community.dump_graph_data import get_props, get_entities, get_relationships
-
-class TestDumpGraphData(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.test_data_path = os.path.join(os.getcwd(), 'web', 'cached_data')
-        os.makedirs(cls.test_data_path, exist_ok=True)
-
-    def test_get_props(self):
-        """测试属性获取方法,验证多属性合并逻辑"""
-        props = get_props('1')  # 使用真实节点ID
-        self.assertIsInstance(props, dict)
-        self.assertTrue(len(props) > 0)
-
-    def test_entity_export(self):
-        """验证实体导出的分页查询和JSON序列化"""
-        get_entities()
-        
-        # 验证生成的实体文件
-        with open(os.path.join(self.test_data_path, 'entities_med.json'), 'r', encoding='utf-8') as f:
-            entities = json.load(f)
-            self.assertTrue(len(entities) > 0)
-
-    def test_relationship_chunking(self):
-        """测试关系数据分块写入逻辑,验证每批处理1000条"""
-        get_relationships()
-        
-        # 验证生成的关系文件
-        file_count = len([f for f in os.listdir(self.test_data_path) 
-                         if f.startswith('relationship_med_')])
-        self.assertTrue(file_count > 0)
-
-    #def tearDown(self):
-        # 清理测试生成的文件
-        #for f in os.listdir(self.test_data_path):
-            #if f.startswith(('entities_med', 'relationship_med_')):
-                #os.remove(os.path.join(self.test_data_path, f))
-
-if __name__ == '__main__':
-    unittest.main()

+ 0 - 97
tests/community/test_graph_helper.py

@@ -1,97 +0,0 @@
-import unittest
-import json
-import os
-from community.graph_helper2 import GraphHelper
-
-class TestGraphHelper(unittest.TestCase):
-    """
-    图谱助手测试套件
-    测试数据路径:web/cached_data 下的实际医疗知识图谱数据
-    """
-
-    @classmethod
-    def setUpClass(cls):
-        # 初始化图谱助手并构建图谱
-        cls.helper = GraphHelper()
-        cls.test_node = "感染性发热"  # 使用实际存在的测试节点
-        cls.test_community_node = "糖尿病"  # 用于社区检测的测试节点
-
-    def test_graph_construction(self):
-        """
-        测试图谱构建完整性
-        验证点:节点和边数量应大于0
-        """
-        node_count = len(self.helper.graph.nodes)
-        edge_count = len(self.helper.graph.edges)
-        self.assertGreater(node_count, 0, "节点数量应大于0")
-        self.assertGreater(edge_count, 0, "边数量应大于0")
-
-    def test_node_search(self):
-        """
-        测试节点搜索功能
-        场景:1.按节点ID精确搜索 2.按类型过滤 3.自定义属性过滤
-        """
-        # 精确ID搜索
-        result = self.helper.node_search(node_id=self.test_node)
-        self.assertEqual(len(result), 1, "应找到唯一匹配节点")
-
-        # 类型过滤搜索
-        type_results = self.helper.node_search(node_type="症状")
-        self.assertTrue(all(item['type'] == "症状" for item in type_results))
-
-        # 自定义属性过滤
-        filter_results = self.helper.node_search(filters={"description": "发热病因"})
-        self.assertTrue(len(filter_results) >= 1)
-
-    def test_neighbor_search(self):
-        """
-        测试邻居检索功能
-        验证点:1.跳数限制 2.不包含中心节点 3.关系完整性
-        """
-        entities, relations = self.helper.neighbor_search(self.test_node, hops=2)
-        
-        self.assertFalse(any(e['name'] == self.test_node for e in entities),
-                        "结果不应包含中心节点")
-        
-        # 验证关系双向连接
-        for rel in relations:
-            self.assertTrue(
-                any(e['name'] == rel['src_name'] for e in entities) or
-                any(e['name'] == rel['dest_name'] for e in entities)
-            )
-
-    def test_path_finding(self):
-        """
-        测试路径查找功能
-        使用已知存在路径的节点对进行验证
-        """
-        target_node = "肺炎"
-        result = self.helper.find_paths(self.test_node, target_node)
-        
-        self.assertIn('shortest_path', result)
-        self.assertTrue(len(result['shortest_path']) >= 2)
-        
-        # 验证所有路径都包含起始节点
-        for path in result['all_paths']:
-            self.assertEqual(path[0], self.test_node)
-            self.assertEqual(path[-1], target_node)
-
-    def test_community_detection(self):
-        """
-        测试社区检测功能
-        验证点:1.社区标签存在 2.同类节点聚集 3.社区数量合理
-        """
-        graph, partition = self.helper.detect_communities()
-        
-        # 验证节点社区属性
-        test_node_community = graph.nodes[self.test_community_node]['community']
-        self.assertIsInstance(test_node_community, int)
-        
-        # 验证同类节点聚集(例如糖尿病相关节点)
-        diabetes_nodes = [n for n, attr in graph.nodes(data=True) 
-                         if attr.get('type') == "代谢疾病"]
-        communities = set(graph.nodes[n]['community'] for n in diabetes_nodes)
-        self.assertLessEqual(len(communities), 3, "同类节点应集中在少数社区")
-
-if __name__ == '__main__':
-    unittest.main()

+ 1 - 1
tests/service/test_kg_node_service.py

@@ -51,5 +51,5 @@ class TestPaginatedSearch:
         assert results['pagination']['pageNo'] == 1
 
 class TestBatchProcess:
-    def test_batch_process_er_nodes(self, kg_node_service, test_node_data):
+    def test_batch_process_nodes(self, kg_node_service, test_node_data):
         kg_node_service.batch_process_er_nodes()

+ 17 - 9
tests/test.py

@@ -1,19 +1,27 @@
-from cdss.capbility import CDSSCapability
-from cdss.models.schemas import CDSSInput, CDSSOutput, CDSSInt
+from agent.cdss.capbility import CDSSCapability
+from agent.cdss.models.schemas import CDSSInput, CDSSOutput, CDSSInt, CDSSText
 
 capability = CDSSCapability()
 
 record = CDSSInput(
-    pat_age=CDSSInt(type="month", value=21), 
-    pat_sex=CDSSInt(type="sex", value=1),
-    chief_complaint=["腹痛", "发热", "腹泻"],
+    pat_age=CDSSInt(type="month", value=24),
+    pat_sex=CDSSText(type="sex", value="男"),
+    #chief_complaint=["腹痛", "发热", "腹泻"],
+    #chief_complaint=["呕血", "黑便", "头晕", "心悸"],
+    #chief_complaint=["流鼻涕"],
+
+    department=CDSSText(type='department',value="急诊医学科")
     )
 
 if __name__ == "__main__":
     output = capability.process(input=record)
+    #output = capability.process(input=record)
+    print(output.departments.value)
     for item in output.diagnosis.value:
         print(f"DIAG {item}  {output.diagnosis.value[item]} ")
-    for item in output.checks.value:
-        print(f"CHECK {item}  {output.checks.value[item]} ")
-    for item in output.drugs.value:
-        print(f"DRUG {item}  {output.drugs.value[item]} ")
+    for item in output.diagnosis2.value:
+        print(f"count DIAG {item}  {output.diagnosis2.value[item]} ")
+    # for item in output.checks.value:
+    #     print(f"CHECK {item}  {output.checks.value[item]} ")
+    # for item in output.drugs.value:
+    #     print(f"DRUG {item}  {output.drugs.value[item]} ")

+ 5 - 19
utils/vectorizer.py

@@ -64,22 +64,8 @@ class Vectorizer:
             logger.error(f"API请求失败: {str(e)}")
             raise
 
-    if __name__ == '__main__':
-        text ='''姓名:李XX  
-            性别:女  
-            年龄:55岁  
-            住院号:NJZY20231102  
-            主诉:突发胸痛伴呼吸困难2小时  
-            现病史:患者于下午3时许突发胸痛,位于心前区,呈压榨性疼痛,伴呼吸困难,持续不缓解,无恶心呕吐及二便失禁。急诊测血压150/90mmHg,心率100次/分。心电图示II、III、aVF导联ST段弓背向上抬高0.3-0.5mV。发病前2周曾诉间断性胸闷,每次持续数分钟自行缓解。  
-            既往史:高血压8年(间断服用降压药),无糖尿病史,无手术史。长期吸烟(20包/年),饮酒史5年(红酒约150g/日)。  
-            体格检查:BP 155/92mmHg,心率102次/分。神志清楚,痛苦面容,双肺呼吸音清,未闻及干湿性啰音。心界不大,心率102次/分,律齐,心尖区可闻及2/6级收缩期杂音。  
-            辅助检查:心电图示II、III、aVF导联ST段弓背向上抬高0.3-0.5mV;心肌酶谱:肌酸激酶同工酶(CK-MB)100U/L,肌钙蛋白T(cTnT)0.4ng/ml;心脏超声示左室下壁运动减弱。  
-            诊断:急性下壁心肌梗死  
-            治疗计划:  
-            1. 抗血小板治疗:阿司匹林300mg嚼服,氯吡格雷300mg负荷剂量后75mg qd。  
-            2. 抗凝治疗:低分子量肝素4000U皮下注射,每12小时1次。  
-            3. 冠状动脉介入治疗(PCI):急诊行冠状动脉造影,必要时行支架植入术。  
-            4. 调脂治疗:阿托伐他汀20mg qn。  
-            5. 血管扩张剂:硝酸甘油静脉泵入。'''
-        embedding = get_embedding(text)
-        print(f'生成的embedding向量:\n{embedding}')
+if __name__ == "__main__":
+    text ='腹痛'
+    embedding = Vectorizer.get_embedding(text)
+    print(embedding)
+    print(len(embedding))