Selaa lähdekoodia

Merge branch 'feature/openapi' of http://173.18.12.196:3000/python/knowledge into feature/openapi

SGTY 2 viikkoa sitten
vanhempi
commit
58c1e0b6f2

+ 23 - 0
server.py

@@ -0,0 +1,23 @@
+# server.py
+from mcp.server.fastmcp import FastMCP
+
+# Create an MCP server
+mcp = FastMCP("Demo")
+
+
+# Add an addition tool
+@mcp.tool()
+def add(a: int, b: int) -> int:
+    """Add two numbers"""
+    return a + b
+
+
+# Add a dynamic greeting resource
+@mcp.resource("greeting://{name}")
+def get_greeting(name: str) -> str:
+    """Get a personalized greeting"""
+    return f"Hello, {name}!"
+
+if __name__ == "__main__":
+    print(111)
+    mcp.run()

+ 1 - 2
src/knowledge/.env

@@ -3,6 +3,5 @@ DB_NAME=medkg
 DB_PORT=5432
 DB_USER=knowledge
 DB_PASSWORD=qwer1234.
-
-license=E:\project\knowledge\src\knowledge\utils\license_issued
+LICENSE_PATH=E:\project\knowledge2\src\knowledge\utils\license_issued
 EMBEDDING_MODEL=E:\project\bge-m3

+ 7 - 28
src/knowledge/config/site.py

@@ -11,34 +11,13 @@ class SiteConfig:
     
     def load_config(self):        
         self.config = {
-            "SITE_NAME": os.getenv("SITE_NAME", "DEMO"),
-            "SITE_DESCRIPTION": os.getenv("SITE_DESCRIPTION", "ChatGPT"),
-            "SITE_URL": os.getenv("SITE_URL", ""),
-            "SITE_LOGO": os.getenv("SITE_LOGO", ""),
-            "SITE_FAVICON": os.getenv("SITE_FAVICON"),
-            'ELASTICSEARCH_HOST': os.getenv("ELASTICSEARCH_HOST"),
-            'ELASTICSEARCH_USER': os.getenv("ELASTICSEARCH_USER"),
-            'ELASTICSEARCH_PWD': os.getenv("ELASTICSEARCH_PWD"),
-            'WORD_INDEX': os.getenv("WORD_INDEX"),
-            'TITLE_INDEX': os.getenv("TITLE_INDEX"),
-            'CHUNC_INDEX': os.getenv("CHUNC_INDEX"),
-            'DEEPSEEK_API_URL': os.getenv("DEEPSEEK_API_URL"),
-            'DEEPSEEK_API_KEY': os.getenv("DEEPSEEK_API_KEY"),
-            'CACHED_DATA_PATH': os.getenv("CACHED_DATA_PATH"),
-            'UPDATE_DATA_PATH': os.getenv("UPDATE_DATA_PATH"),
-            'FACTOR_DATA_PATH': os.getenv("FACTOR_DATA_PATH"),
-            'GRAPH_API_URL': os.getenv("GRAPH_API_URL"),
-            'EMBEDDING_MODEL': os.getenv("EMBEDDING_MODEL"),
-            'DOC_PATH': os.getenv("DOC_PATH"),
-            'DOC_STORAGE_PATH': os.getenv("DOC_STORAGE_PATH"),
-            'TRUNC_OUTPUT_PATH': os.getenv("TRUNC_OUTPUT_PATH"),
-            'DOC_ABSTRACT_OUTPUT_PATH': os.getenv("DOC_ABSTRACT_OUTPUT_PATH"),
-            'JIEBA_USER_DICT': os.getenv("JIEBA_USER_DICT"),
-            'JIEBA_STOP_DICT': os.getenv("JIEBA_STOP_DICT"),
-            'POSTGRESQL_HOST':  os.getenv("POSTGRESQL_HOST","localhost"),
-            'POSTGRESQL_DATABASE':  os.getenv("POSTGRESQL_DATABASE","kg"),
-            'POSTGRESQL_USER':  os.getenv("POSTGRESQL_USER","dify"),
-            'POSTGRESQL_PASSWORD':  os.getenv("POSTGRESQL_PASSWORD",quote("difyai123456")),
+            "LICENSE_PATH": os.getenv("LICENSE_PATH", ""),        
+            'EMBEDDING_MODEL': os.getenv("EMBEDDING_MODEL",""),
+            'DB_HOST':  os.getenv("DB_HOST",""),
+            'DB_NAME':  os.getenv("DB_NAME",""),
+            'DB_PORT':  os.getenv("DB_PORT",""),
+            'DB_USER': os.getenv("DB_USER", ""),
+            'DB_PASSWORD': os.getenv("DB_PASSWORD", ""),
         }
     def get_config(self, config_name, default=None): 
         config_name = config_name.upper()     

+ 6 - 2
src/knowledge/model/response.py

@@ -5,6 +5,10 @@ class StandardResponse(BaseModel):
     success: bool
     requestId: Optional[str] = None
     errorCode: Optional[int] = None
+    #验证结果编码:0:成功;1:时间过期;2:并发过高;3:次数耗尽;9:系统异常
+    vaildCode: Optional[int] = None
     errorMsg: Optional[str] = None
-    records: Optional[Any] = None
-    data: Optional[Any] = None
+    data: Optional[Any] = None
+    
+    def getSuccess(self):
+        return self.success

+ 3 - 3
src/knowledge/router/knowledge_nodes_api.py

@@ -59,7 +59,7 @@ async def paginated_search(
             data=ObjectToJsonArrayConverter.convert(result)
         )
     except Exception as e:
-        logger.error(f"分页查询失败: {str(e)}")
+        logger.exception(f"分页查询失败: {str(e)}")
         raise HTTPException(
             status_code=500,
             detail=StandardResponse(
@@ -107,7 +107,7 @@ async def get_node_relationships_condition(
             data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
         )
     except Exception as e:
-        logger.error(f"获取节点关系失败: {str(e)}")
+        logger.exception(f"获取节点关系失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
 
@@ -155,7 +155,7 @@ async def get_node_relationships(
             data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
         )
     except Exception as e:
-        logger.error(f"获取节点关系失败: {str(e)}")
+        logger.exception(f"获取节点关系失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
 class GetNodeProperties(BaseModel):

+ 64 - 19
src/knowledge/server.py

@@ -2,7 +2,7 @@ from contextlib import asynccontextmanager
 from datetime import datetime
 from typing import Optional
 
-from fastapi import FastAPI, Depends, Security, HTTPException
+from fastapi import FastAPI, Depends, Security, HTTPException, Query
 from py_tools.connections.http import AsyncHttpClient
 from py_tools.logging import logger
 from pydantic import BaseModel
@@ -13,6 +13,7 @@ from .db.session import get_db
 from .middlewares.base import register_middlewares
 from .model.response import StandardResponse
 from .router.knowledge_nodes_api import knowledge_nodes_api_router, get_request_id, api_key_header
+from .service.kg_edge_service import KGEdgeService
 from .service.kg_node_service import KGNodeService
 from .utils import log_util
 from .utils.ObjectToJsonArrayConverter import ObjectToJsonArrayConverter
@@ -39,26 +40,69 @@ async def health_check():
         "service": "knowledge-graph"
     }
 
-class Agent(BaseModel):
-    node_name: str
-    relation: str
-@app.post("/nodes/paginated_search", response_model=StandardResponse,operation_id="实体查询")
-async def paginated_search(
-    payload: Agent,
-    db: Session = Depends(get_db),
-    request_id: str = Depends(get_request_id),
-    api_key: str = Security(api_key_header)
-):
+
+@app.post("/test", operation_id="医疗知识图谱目标节点查询", summary="根据医疗知识图谱获取医疗相关信息",
+         description="""根据三元组的起始节点名称和关系名称,查询目标节点列表。
+         该接口主要用于医疗知识图谱查询场景,例如:通过输入疾病名称和相关关系类型,
+         返回该疾病对应的相关症状、治疗方法等信息。
+         典型应用场景包括:
+         - 症状查询:输入疾病名称和"疾病相关症状"关系
+         - 诊断依据查询:输入疾病名称和"诊断依据"关系
+         - 鉴别诊断查询:输入疾病名称和"疾病相关鉴别诊断"关系""",
+         response_description="""返回目标节点名称的字符串列表,格式为:
+         ["节点名称1", "节点名称2", ...]""")
+async def test(node_name: str = Query(...,
+                  description="""知识图谱三元组的起始节点名称,通常是疾病名称。
+                 示例值:感冒、高血压、糖尿病等""",
+                  example="糖尿病"),
+                node_category: str = Query(...,
+                  description="""知识图谱三元组的起始节点类型,通常是疾病。
+                 示例值:疾病、症状等""",
+                  example="疾病"),
+                relation_name: str= Query(...,
+                  description="""知识图谱三元组的关系名称,描述节点间的关系类型。
+                 常见关系类型包括:
+                 - 疾病相关症状
+                 - 诊断依据
+                 - 疾病相关鉴别诊断""",
+                  example="疾病相关症状"), 
+                db: Session = Depends(get_db)) -> list[str]:    
+    """
+    根据起始节点名称和关系名称查询目标节点名称列表
+    
+    参数:
+        node_name: 起始节点名称(通常是疾病名称)
+        relation_name: 关系类型名称
+        
+    返回:
+        目标节点名称的字符串列表,如果查询不到结果则返回空列表
+    """
     try:
         service = KGNodeService(db)
-        service.search_title_index()
-        return StandardResponse(
-            success=True,
-            requestId=request_id,
-            data=ObjectToJsonArrayConverter.convert(result)
-        )
+        search_params = {
+            'keyword': node_name,
+            'category': node_category,
+            'pageNo': 1,
+            'limit': 1,
+            'load_props': False,
+            'distance': 0.45,
+        }
+        node_list = service.paginated_search(search_params)
+        edge_service = KGEdgeService(db)
+        
+        results = []
+        if node_list and node_list.get('records'):
+            first_node = node_list['records'][0]
+            src_id = first_node['id']
+            edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None,
+                                                  name=relation_name)
+           
+            for edge in edges:
+                dest_node = edge['dest_node']
+                results.append(dest_node['name'])
+        return results
     except Exception as e:
-        logger.error(f"分页查询失败: {str(e)}")
+        logger.exception(f"分页查询失败: {str(e)}")
         raise HTTPException(
             status_code=500,
             detail=StandardResponse(
@@ -68,6 +112,7 @@ async def paginated_search(
             )
         )
 
+
 async def init_setup():
     """初始化项目配置"""
 
@@ -79,7 +124,7 @@ async def startup():
     await init_setup()
 
     # 加载路由
-    #app.include_router(knowledge_nodes_api_router)
+    app.include_router(knowledge_nodes_api_router)
 
     logger.info("fastapi startup success")
 

+ 0 - 2
src/knowledge/service/kg_node_service.py

@@ -35,8 +35,6 @@ class KGNodeService:
                 KGNode.embedding.l2_distance(query_embedding).label('distance')
             )
             .filter(KGNode.status == 0)
-            #过滤掉version不等于'er'的节点
-            .filter(KGNode.version != 'ER')
             .filter(KGNode.embedding.l2_distance(query_embedding) <= DISTANCE_THRESHOLD2)
             .order_by('distance').limit(top_k).all()
         )

+ 65 - 0
src/knowledge/utils/DeepseekUtil.py

@@ -0,0 +1,65 @@
+import requests
+import json
+
+def chat(question):
+    url = "https://api.lkeap.cloud.tencent.com/v1/chat/completions"
+
+    payload = json.dumps({
+        "model": "deepseek-v3",
+        "messages": [
+            {
+                "role": "user",
+                "content": question
+            }
+        ]
+    }, ensure_ascii=False)
+    headers = {
+        'Content-Type': 'application/json',
+        'appid': '',
+        'Authorization': 'Bearer sk-vxnoy9XGv0xG3qZ2XfuMlKzY6eKB9XST1nTSn5PQJxDLKDjY'
+    }
+
+    response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8"))
+    #response.text是个json,获取choices数组中第一个元素的message的content字段
+    answer = json.loads(response.text)
+    answer = answer["choices"][0]["message"]["content"]
+
+    print(answer)
+    #返回answer
+    return answer
+
+
+
+if __name__ == '__main__':
+    print(chat('''## 核心能力
+1. 医学语义理解:
+   - 精准解析临床表现描述
+   - 将口语化表达转换为ICD-11标准术语
+   - 完整提取症状特征(性质/部位/放射等),保持症状描述的临床完整性
+
+2. 专业能力:
+   - 精通ICD-11症状学术语体系
+   - 识别症状关联性
+   - 具备临床鉴别诊断思维
+
+## 输出要求
+- 仅输出JSON格式结果,禁止添加解释性文字
+- 非症状不要抽取
+ - 症状术语必须标准化
+- 尽量保留原始症状的临床特征(如部位、性质、放射等)
+- 抽取的症状应该简洁明了,尽量保持在5个字符以内,最多不宜超过7个字符。如果超过可以分多个症状词进行抽取。
+
+## 处理流程
+1. 接收患者主诉文本
+2. 识别并提取所有症状描述
+3. 转换为ICD-11标准术语
+4. 结构化输出症状列表
+
+示例1:
+
+输入:突然感觉胸口压榨样疼痛,持续不缓解,向左肩和下颌放射,伴大汗、恶心,已经30分钟了。
+
+输出: { "symptoms": ["胸痛", "左肩放射痛", "下颌放射痛","大汗","恶心"] }
+本次用户输入:
+主诉:右上腹痛绞痛2小时
+现病史:2小时前无诱因下出现持续性右上腹部绞痛,剧痛难忍,伴恶心,发热,黄疸;无呕吐,无大小便异常。'''))

+ 148 - 29
src/knowledge/utils/license.py

@@ -1,39 +1,158 @@
-from cryptography.hazmat.primitives.asymmetric import padding
-from cryptography.hazmat.primitives import hashes,serialization
+import os
 import json
 import time
-import traceback
+import base64
+import logging
+from threading import Thread
+from cachetools import TTLCache
+from cryptography.hazmat.primitives.asymmetric import padding
+from cryptography.hazmat.primitives import hashes,serialization
+
+from ..db.session import get_db
+from ..model.dict_system import DictSystem
+from src.knowledge.config.site import SiteConfig
+from src.knowledge.model.response import StandardResponse
+
+logger = logging.getLogger(__name__)
+
+#并发缓存
+concurrency_cache = TTLCache(maxsize=10, ttl=5)
+#证书验证缓存
+vaild_cache = {}
+#5分值之内触发一次保存
+need_save = False
+
+#验证入口
+def validate_license():
+    if vaild_cache == {}:
+        return StandardResponse(success=False,vaildCode=9, errorMsg="系统异常:系统未加载证书信息")
+    elif 'result' in vaild_cache:   #已经是失败状态就直接返回失败
+        return vaild_cache["result"]
+    else:
+        timestamp = int(time.time())
+        timestamp_str = str(timestamp)
+        print(concurrency_cache,'=concurrency_cache')
+        if timestamp_str in concurrency_cache:  # 检查并发是否过高
+            if concurrency_cache[timestamp_str] > vaild_cache["api_concurrent_limit"]:
+                logger.info("当前["+timestamp_str+"]请求的并发数为"+str(concurrency_cache[timestamp_str]))
+                return StandardResponse(success=False,vaildCode=2, errorMsg="并发数过高")
+            else:
+                concurrency_cache[timestamp_str] += 1
+        else:
+            concurrency_cache[timestamp_str] = 1
+        
+        if timestamp > vaild_cache["expiration_time"]: # 检查日期是否过期
+            logger.info("证书过期时间为"+str(vaild_cache["expiration_time"]))
+            vaild_cache["result"] = StandardResponse(success=False,vaildCode=1, errorMsg="许可证已过期")
+            return vaild_cache["result"]
+        
+        if vaild_cache["api_use_times"] > vaild_cache["api_invoke_max_count"]:  # 检查次数
+            logger.info("请求使用次数为"+str(vaild_cache["api_concurrent_limit"]))
+            vaild_cache["result"] = StandardResponse(success=False,vaildCode=3, errorMsg="使用次数已经用完")
+            return vaild_cache["result"]
+        vaild_cache["api_use_times"] += 1
 
-def validate_license(public_key_pem, license_json, signature):
-    public_key = serialization.load_pem_public_key(public_key_pem)
+        global need_save
+        if need_save == False:
+            Thread(target=save_vaild_cache).start()
+            need_save = True
 
+        return StandardResponse(success=True,vaildCode=0, data=vaild_cache)
+        
+#初始化验证参数
+def init_vaild_cache():
     try:
-        public_key.verify(
-        signature,
-        license_json,
-        padding.PKCS1v15(),
-        hashes.SHA256()
-        )
+        config = SiteConfig()
+        license_path = config.get_config('LICENSE_PATH')
+
+        public_key_pem = None
+        license_json = None
+        signature = None
+
+        with open(os.path.join(license_path, "public.key"), "rb") as f:
+            public_key_pem = f.read()
+        with open(os.path.join(license_path, "license_issued.lic"), "rb") as f:
+            data = json.loads(f.read())
+            license_json = json.dumps(data, sort_keys=True).encode()
+        with open(os.path.join(license_path, "license_issued.key"), "rb") as f:
+            signature = f.read()
+            
+        public_key = serialization.load_pem_public_key(public_key_pem)
+        public_key.verify(signature,license_json,padding.PKCS1v15(),hashes.SHA256())
+        license_data = json.loads(license_json.decode())
+
+        content = license_data['content']
+        for c in content:
+            if c['name'] == 'api_invoke_max_count':
+                vaild_cache['api_invoke_max_count'] = c['value']
+            elif c['name'] == 'api_concurrent_limit':
+                vaild_cache['api_concurrent_limit'] = c['value']
+        vaild_cache['expiration_time'] = license_data['expiration_time']
+        vaild_cache['api_use_times'] = 0
+        
+        #从数据库加载 
+        db = next(get_db())
+        result = (db.query(DictSystem).filter_by(dict_name = 'vaild_str',is_deleted = 0).first())
+        print(result)
+        if result is not None: #说明还没有存储,不清楚是不是被删除,从证书上同步
+            dict_value = str(base64.b64decode(bytes(result.dict_value,encoding="utf-8") ),encoding='utf-8')
+            dict_value_json=json.loads(dict_value)
+            vaild_cache['api_use_times'] = dict_value_json['api_use_times']
+            vaild_cache['id'] = result.id
     except:
-        #打印异常信息
-        traceback.print_exc()
-        return False
+        vaild_cache["result"] =  StandardResponse(success=False,vaildCode=9, errorMsg="许可证签名验证失败:是否上传了许可证文件")
+        logger.exception("许可证签名验证失败")
+
+#保存验证的参数,有请求时延迟5分钟触发保存
+def save_vaild_cache():
+    time.sleep(10) #5分钟
+    global need_save
+    if need_save == False:
+        return
+    # time.sleep(5*60) #5分钟
+    if vaild_cache is not None:
+        #保存数据到数据库
+        tmp_cache = {}
+        tmp_cache['api_use_times'] = vaild_cache['api_use_times']
+        new_dict = DictSystem(
+            dict_name='vaild_str',
+            dict_value=str(base64.b64encode(json.dumps(tmp_cache, ensure_ascii=False, indent=1).encode("utf-8")),encoding="utf-8"),
+            is_deleted=0
+        )
+        db = next(get_db())
+        print(vaild_cache)
+        if 'id' in vaild_cache:
+            db.query(DictSystem).filter_by(id=vaild_cache['id']).update({DictSystem.dict_value: new_dict.dict_value})
+            new_dict.id = vaild_cache['id']
+        else:
+            db.add(new_dict)
+        db.commit()
+        vaild_cache['id'] = new_dict.id
+    
+    need_save = False
 
-    license_data=json.loads(license_json.decode())
-    # 检查是否过期
-    if time.time()>license_data["expiration_time"]:
-        return False
-    return True
+#重置验证参数
+def reset_vaild_cache():
+    global vaild_cache
+    if vaild_cache != {} and 'id' in vaild_cache:
+        db = next(get_db())
+        db.delete( db.query(DictSystem).filter_by(id=vaild_cache['id']).first()) #删除信息
+        db.commit()
+    vaild_cache={}
+    global need_save
+    need_save = False
+    init_vaild_cache()  #再初始化
 
 if __name__ == '__main__':
-    with open("license_issued/public.key","rb") as f:
-        public_key_pem = f.read()
-    with open("license_issued/license_issued.lic","rb") as f:
-        data = json.loads(f.read())
-        license_json = json.dumps(data, sort_keys=True).encode()
-    with open("license_issued/license_issued.key","rb") as f:
-        signature = f.read()
-    if validate_license(public_key_pem,license_json, signature):
-        print("许可证有效!")
+    init_vaild_cache()
+    response = validate_license()
+    if response.success:
+        print("1许可证有效!")
+    else:
+        print(response.errorMsg)
+    reset_vaild_cache()
+    response = validate_license()
+    if response.success:
+        print("3许可证有效!")
     else:
-        print("许可证无效或已过期!")
+        print(response.errorMsg)

+ 134 - 0
src/knowledge/utils/mcp_client.py

@@ -0,0 +1,134 @@
+"""
+MCP客户端实现,使用fastmcp库连接SSE协议的MCP服务端
+
+功能包括:
+1. 建立与MCP服务端的连接
+2. 发送消息到服务端
+3. 接收服务端的SSE事件流
+4. 错误处理和重连机制
+"""
+import fastmcp
+import asyncio
+from typing import Optional, Callable
+
+class MCPClient:
+    """
+    MCP客户端类,用于连接SSE协议的MCP服务端
+    
+    参数:
+        server_url (str): MCP服务端的URL地址
+        on_message (Callable): 接收到消息时的回调函数
+        reconnect_interval (int): 重连间隔时间(秒)
+    """
+    def __init__(self, server_url: str, on_message: Callable, reconnect_interval: int = 5):
+        self.server_url = server_url
+        self.on_message = on_message
+        self.reconnect_interval = reconnect_interval
+        self.client: Optional[fastmcp.Client] = None
+        self.is_connected = False
+
+    async def connect(self):
+        """
+        连接到MCP服务端
+        
+        如果连接失败会自动重试,间隔时间为reconnect_interval
+        """
+        while True:
+            try:
+                # 创建fastmcp客户端实例
+                self.client = fastmcp.Client(self.server_url)
+                
+                # 建立连接
+                await self.client.connect()
+                self.is_connected = True
+                print(f"成功连接到MCP服务端: {self.server_url}")
+                
+                # 开始监听SSE事件流
+                await self._listen_events()
+                
+            except Exception as e:
+                print(f"连接MCP服务端失败: {e}")
+                self.is_connected = False
+                
+                # 等待重连间隔后重试
+                await asyncio.sleep(self.reconnect_interval)
+                print(f"尝试重新连接...")
+                continue
+
+    async def _listen_events(self):
+        """
+        监听SSE事件流
+        
+        这是一个内部方法,用于持续接收服务端推送的事件
+        """
+        try:
+            async for event in self.client.listen():
+                # 调用回调函数处理接收到的消息
+                self.on_message(event.data)
+                
+        except Exception as e:
+            print(f"监听事件流时出错: {e}")
+            self.is_connected = False
+            
+    async def send_message(self, message: str):
+        """
+        发送消息到MCP服务端
+        
+        参数:
+            message (str): 要发送的消息内容
+        """
+        if not self.is_connected:
+            raise ConnectionError("客户端未连接到MCP服务端")
+            
+        try:
+            await self.client.send(message)
+            print(f"消息已发送: {message}")
+            
+        except Exception as e:
+            print(f"发送消息失败: {e}")
+            self.is_connected = False
+            raise
+
+    async def close(self):
+        """
+        关闭MCP客户端连接
+        """
+        if self.client and self.is_connected:
+            await self.client.close()
+            self.is_connected = False
+            print("MCP客户端连接已关闭")
+
+
+async def example_usage():
+    """
+    示例用法: 展示如何使用MCPClient类
+    """
+    def on_message_callback(data):
+        """消息接收回调函数示例"""
+        print(f"收到服务端消息: {data}")
+    
+    # 创建客户端实例
+    client = MCPClient(
+        server_url="http://localhost:8081/mcp",
+        on_message=on_message_callback
+    )
+    
+    try:
+        # 启动客户端
+        await client.connect()
+        
+        # 发送测试消息
+        await client.send_message("Hello MCP Server!")
+        
+        # 保持运行
+        while True:
+            await asyncio.sleep(1)
+            
+    except KeyboardInterrupt:
+        # 捕获Ctrl+C关闭客户端
+        await client.close()
+
+
+if __name__ == "__main__":
+    # 运行示例
+    asyncio.run(example_usage())

+ 70 - 0
utils/excel_importer.py

@@ -0,0 +1,70 @@
+import pandas as pd
+from sqlalchemy.orm import Session
+from service.kg_node_service import KGNodeService
+from service.kg_prop_service import KGPropService
+import logging
+
+from utils.vectorizer import Vectorizer
+
+logger = logging.getLogger(__name__)
+
+class ExcelImporter:
+    def __init__(self, db: Session):
+        self.node_service = KGNodeService(db)
+        self.prop_service = KGPropService(db)
+    
+    def import_from_excel(self, file_path: str, category: str, prop_name: str):
+        try:
+            # 读取Excel文件
+            df = pd.read_excel(file_path, header=None)
+            
+            # 遍历每一行数据
+            for _, row in df.iterrows():
+                entity_name = str(row[0]).strip()
+                prop_value = str(row[1]).strip() if len(row) > 1 else ''
+                
+                if not entity_name:
+                    continue
+                
+                # 检查节点是否存在
+                node = self.node_service.get_node_by_name_category(entity_name, category)
+                
+                if not node:
+                    # 创建新节点
+                    node_data = {
+                        'name': entity_name,
+                        'category': category,
+                        'version': 'xysy',
+                        'embedding': Vectorizer.get_embedding(entity_name),
+                        'status': 0
+                    }
+                    node = self.node_service.create_node(node_data)
+                
+                # 创建属性
+                if prop_value:
+                    node_id = node['id'] if isinstance(node, dict) else node.id
+                    prop = self.prop_service.get_prop_by_ref_id(node_id, prop_name)
+                    if not prop:
+                        prop_data = {
+                            'ref_id': node_id,
+                            'category': 1,
+                            'prop_name': prop_name,
+                            'prop_value': prop_value,
+                            'type': 1
+                        }
+                        self.prop_service.create_prop(prop_data)
+            
+            return True
+        except Exception as e:
+            logger.error(f"导入Excel数据失败: {str(e)}")
+            raise ValueError(f"导入失败: {str(e)}")
+
+if __name__ == "__main__":
+    file_path = "C:\\Users\\17664\\Desktop\\入院主诊断-诊疗计划.xlsx"
+    category = "疾病"
+    prop_name = "intramural_treatment_plan"
+
+    from db.session import get_db
+    db = next(get_db())
+    importer = ExcelImporter(db)
+    importer.import_from_excel(file_path, category, prop_name)