import sys
import os

current_path = os.getcwd()
# Appended before the project imports below, presumably so that the
# `config`, `db` and `agent` packages resolve from the working directory.
sys.path.append(current_path)

from config.site import SiteConfig
from fastapi import APIRouter, Depends, Query
from db.database import get_db
from sqlalchemy.orm import Session
from sqlalchemy import text
from agent.models.web.response import StandardResponse, FAILED, SUCCESS
from agent.models.web.request import BasicRequest
from agent.libs.graph import GraphBusiness
from agent.libs.auth import verify_session_id, SessionValues
import logging
from typing import Optional, List
from agent.models.db.graph import DbKgNode
from agent.models.db.tree_structure import TreeStructure, KgGraphCategory
import string
import json
from agent.tree_utils import get_tree_dto

router = APIRouter(prefix="/kb", tags=["knowledge build interface"])
logger = logging.getLogger(__name__)
config = SiteConfig()
LOG_DIR = config.get_config("TASK_LOG_DIR", current_path)

# NOTE(review): reference copy of job model columns; not used in this module.
# job_category = Column(String(64), nullable=False)
# job_name = Column(String(64))
# job_details = Column(Text, nullable=False)
# job_creator = Column(String(64), nullable=False)
# job_logs = Column(Text, nullable=True)
# job_files = Column(String(300), nullable=True)

# Bucket name for entries whose first letter is not A-Z.
_OTHER_BUCKET = '其他'
_ASCII_UPPER = frozenset(string.ascii_uppercase)


@router.post('/summary', response_model=StandardResponse)
def summary_func(request: BasicRequest, db: Session = Depends(get_db),
                 sess: SessionValues = Depends(verify_session_id)) -> StandardResponse:
    """Return the summary of a graph (action ``get_summary``, param ``graph_id``)."""
    if request.action != "get_summary":
        return StandardResponse(code=FAILED, message="invalid action")
    graph_id = request.get_param("graph_id", 0)
    biz = GraphBusiness(db)
    summary = biz.get_graph_summary(graph_id=graph_id)
    if summary:
        logger.info(summary)
        return StandardResponse(code=SUCCESS, message="summary found", records=[summary])
    return StandardResponse(code=FAILED, message="summary not found", records=[])


@router.post('/schemas', response_model=StandardResponse)
def schemas_func(request: BasicRequest, db: Session = Depends(get_db),
                 sess: SessionValues = Depends(verify_session_id)) -> StandardResponse:
    """Return node or edge categories of a graph.

    Actions: ``get_nodes_schemas`` (node categories) and ``get_edges_schemas``
    (edge categories); both read the ``graph_id`` param (default 0).
    """
    if request.action not in ("get_nodes_schemas", "get_edges_schemas"):
        return StandardResponse(code=FAILED, message="invalid action")
    graph_id = request.get_param("graph_id", 0)
    biz = GraphBusiness(db)
    if request.action == "get_nodes_schemas":
        schemas = biz.get_nodes_categories(graph_id=graph_id)
    else:
        schemas = biz.get_edges_categories(graph_id=graph_id)
    if schemas:
        return StandardResponse(code=SUCCESS, message="schemas found", records=schemas)
    # A valid action with no data is not an "invalid action".
    return StandardResponse(code=FAILED, message="schemas not found", records=[])


@router.post('/nodes', response_model=StandardResponse)
def nodes_func(request: BasicRequest, db: Session = Depends(get_db),
               sess: SessionValues = Depends(verify_session_id)) -> StandardResponse:
    """Node queries on a graph.

    Actions:
      * ``search_nodes`` - fuzzy search by ``name`` within ``category`` of ``graph_id``.
      * ``get_nodes``    - paged listing (``page``, ``page_size``) of ``graph_id``.
      * ``neighbors``    - the node ``node_id`` plus its incoming and outgoing
        neighbours, each tagged with ``direction`` ("self", "in", "out").
    """
    action = request.action

    if action == "search_nodes":
        node_name = request.get_param("name", "")
        category = request.get_param("category", "")
        graph_id = request.get_param("graph_id", 0)
        biz = GraphBusiness(db)
        if node_name == "":
            return StandardResponse(code=FAILED, message="node name is empty", records=[])
        if category == "":
            return StandardResponse(code=FAILED, message="category is empty", records=[])
        if graph_id == 0:
            return StandardResponse(code=FAILED, message="graph id is empty", records=[])
        nodes = biz.search_like_node_by_name(graph_id=graph_id, category=category, name=node_name)
        if nodes:
            return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
        return StandardResponse(code=FAILED, message="search job failed")

    if action == "get_nodes":
        graph_id = request.get_param("graph_id", 0)
        page = request.get_param("page", 1)
        page_size = request.get_param("page_size", 1)
        biz = GraphBusiness(db)
        nodes = biz.get_nodes_by_page(graph_id=graph_id, page=page, page_size=page_size)
        if nodes:
            return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
        return StandardResponse(code=FAILED, message="nodes not found", records=[])

    if action == "neighbors":
        graph_id = request.get_param("graph_id", 0)
        node_id = request.get_param("node_id", 0)
        if not node_id > 0:
            return StandardResponse(code=FAILED, message="node id is empty", records=[])
        biz = GraphBusiness(db)
        center = biz.get_node_by_id(graph_id=graph_id, node_id=node_id)
        if center is None:
            return StandardResponse(code=FAILED, message="node not found", records=[])
        nodes_in = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="in")
        nodes_out = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="out")

        def _entry(n, direction: str) -> dict:
            return {"id": n.id, "name": n.name, "category": n.category, "direction": direction}

        nodes_all = [_entry(center, "self")]
        nodes_all.extend(_entry(n, "in") for n in nodes_in)
        nodes_all.extend(_entry(n, "out") for n in nodes_out)
        return StandardResponse(code=SUCCESS, message="nodes found", records=nodes_all)

    return StandardResponse(code=FAILED, message="invalid action")


async def get_node_properties(node_id: int, db: Session) -> dict:
    """
    Load a node's properties from ``kg_props``.

    :param node_id: node id (matched against ``kg_props.ref_id``)
    :param db: database session
    :return: mapping of ``prop_title`` -> ``prop_value``; if a title repeats,
             the last row returned wins
    """
    prop_sql = text("SELECT prop_title, prop_value FROM kg_props WHERE ref_id = :node_id")
    result = db.execute(prop_sql, {'node_id': node_id}).fetchall()
    return {row._mapping['prop_title']: row._mapping['prop_value'] for row in result}


# Any one active node of the requested category linked to the user; used when
# the caller does not supply ``input_str``.
_DEFAULT_NODE_NAME_SQL = text("""
    SELECT n.name
    FROM user_data_relations udr
    JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
    WHERE udr.user_id = :user_id
      AND n.category = :label_name
      AND n.status = '0'
      AND n.graph_id = :graph_id
    LIMIT 1
""")

# The center node itself.
_CENTER_NODE_SQL = text("""
    SELECT n.id, n.name, n.category
    FROM user_data_relations udr
    JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
    WHERE udr.user_id = :user_id
      AND n.category = :label_name
      AND n.name = :input_str
      AND n.status = '0'
      AND n.graph_id= :graph_id
    limit 1
""")

# Outgoing edges of the center node: at most 50 targets per edge name, each
# with its own outgoing-edge count (pCount).
_RELATIONS_SQL = text("""
    WITH RankedRelations AS (
        SELECT
            e.name as rType,
            m.id as target_id,
            m.name as target_name,
            m.category as target_label,
            (SELECT COUNT(*) FROM kg_edges WHERE src_id = m.id) as pCount,
            ROW_NUMBER() OVER(PARTITION BY e.name ORDER BY m.id) as rn
        FROM user_data_relations udr
        JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
        JOIN kg_edges e ON n.id = e.src_id
        JOIN kg_nodes m ON e.dest_id = m.id
        WHERE udr.user_id = :user_id
          AND n.category = :label_name
          AND n.name = :input_str
          AND n.status = '0'
          AND n.graph_id = :graph_id
    )
    SELECT rType, target_id, target_name, target_label, pCount
    FROM RankedRelations
    WHERE rn <= 50
    ORDER BY rType
""")


def _resolve_default_node_name(db: Session, user_id: int, label_name: str,
                               graph_id: int) -> Optional[str]:
    """Return the name of some user-linked node of ``label_name``, or None."""
    row = db.execute(_DEFAULT_NODE_NAME_SQL, {
        'user_id': user_id,
        'label_name': label_name,
        'graph_id': graph_id,
    }).fetchone()
    return None if row is None else row._mapping['name']


async def _fetch_center_node(db: Session, params: dict) -> Optional[dict]:
    """Return the center-node DTO (label/name/id/properties), or None if absent."""
    row = db.execute(_CENTER_NODE_SQL, params).fetchone()
    if row is None:
        return None
    m = row._mapping
    return {
        "label": m['category'],
        "name": m['name'],
        "id": m['id'],
        "properties": await get_node_properties(m['id'], db),
        "ENodeRSDTOS": [],
    }


async def _fetch_relation_groups(db: Session, params: dict) -> List[dict]:
    """Group the center node's outgoing targets by relation type.

    Groups keep the order in which each relation type first appears.
    """
    groups: dict = {}
    for row in db.execute(_RELATIONS_SQL, params).fetchall():
        m = row._mapping
        # Lower-case keys although the SQL aliases are rType/pCount; this relies on
        # the database folding unquoted identifiers to lower case (PostgreSQL does).
        # NOTE(review): confirm the backend if this module is ever ported.
        r_type = m['rtype']
        target = {
            "Label": m['target_label'],
            "Name": m['target_name'],
            "Id": m['target_id'],
            "PCount": m['pcount'],
            "properties": await get_node_properties(m['target_id'], db),
        }
        groups.setdefault(r_type, {"RType": r_type, "ENodeDTOS": []})["ENodeDTOS"].append(target)
    return list(groups.values())


def _build_echarts_graph(graph_dto: dict) -> dict:
    """Convert a center-node DTO into ``categories``/``node``/``links`` lists.

    Layout: node "0" is the center; each relation type gets a diamond node linked
    from the center, and every target node is linked from its relation node.
    Each node's ``name`` is its index in the node list, as a string.
    """
    categories = [{"name": "中心词"}, {"name": "关系"}]
    c_map = {"中心词": 0, "关系": 1}
    for rs in graph_dto["ENodeRSDTOS"]:
        r_type = rs["RType"]
        if r_type not in c_map:
            # The category index is its position in `categories`.
            c_map[r_type] = len(categories)
            categories.append({"name": r_type})

    nodes = [{
        "label": graph_dto["name"],
        'type': graph_dto["label"],
        "category": 0,
        "name": "0",
        "symbol": "circle",
        "symbolSize": 50,
        "properties": graph_dto["properties"],
        "nodeId": graph_dto["id"],
        "itemStyle": {"display": True},
    }]
    links = []

    for rs in graph_dto["ENodeRSDTOS"]:
        r_type = rs["RType"]
        relation_idx = len(nodes)
        nodes.append({
            "label": "",
            'type': graph_dto["label"],
            "category": 1,
            "name": str(relation_idx),
            "symbol": "diamond",
            "symbolSize": 10,
            "properties": graph_dto["properties"],
            "nodeId": relation_idx,
            "itemStyle": {"display": True},
        })
        links.append({
            "source": "0",
            "target": str(relation_idx),
            "value": r_type,
            "relationType": r_type,
        })

        for e_node in rs["ENodeDTOS"]:
            child_idx = len(nodes)
            nodes.append({
                "label": e_node["Name"],
                "type": e_node["Label"],
                "category": c_map[r_type],
                "name": str(child_idx),
                "symbol": "circle",
                "symbolSize": 28,
                "properties": e_node["properties"],
                "nodeId": e_node["Id"],
                # Show the expand affordance only for targets with their own edges.
                "itemStyle": {"display": e_node["PCount"] > 0},
            })
            links.append({
                "source": str(relation_idx),
                "target": str(child_idx),
                "value": "",
                "relationType": r_type,
            })

    return {"categories": categories, "node": nodes, "links": links}


@router.get("/graph_data", response_model=StandardResponse)
async def get_graph_data(
    label_name: str,
    user_id: int,
    graph_id: int,
    db: Session = Depends(get_db),
    input_str: Optional[str] = None
):
    """
    Return the one-hop neighbourhood of a user-linked node as graph data.

    - ``user_id`` is taken from the query string; no session check is done here.
      NOTE(review): the caller-supplied user_id is trusted — confirm this is intended.
    - When ``input_str`` is empty, an arbitrary user-linked node of ``label_name``
      is used as the center.
    - Returns ``records=[{"records": {"categories", "node", "links"}}]``.
    """
    # NOTE(review): this async endpoint performs blocking DB calls on the event loop.
    try:
        if not input_str:
            input_str = _resolve_default_node_name(db, user_id, label_name, graph_id)
            if input_str is None:
                return StandardResponse(code=FAILED, message="No node found for given parameters", records=[])

        params = {
            'user_id': user_id,
            'label_name': label_name,
            'input_str': input_str,
            'graph_id': graph_id,
        }
        graph_dto = await _fetch_center_node(db, params)
        if graph_dto is None:
            # Previously an empty graph with a placeholder center (id 0) was returned.
            return StandardResponse(code=FAILED, message="No node found for given parameters", records=[])
        graph_dto["ENodeRSDTOS"] = await _fetch_relation_groups(db, params)
        logger.debug("graph_dto: %s", graph_dto)

        final_data = _build_echarts_graph(graph_dto)
        return StandardResponse(
            records=[{"records": final_data}],
            message="Graph data retrieved"
        )
    except Exception as e:
        logger.exception("get_graph_data failed")
        return StandardResponse(
            code=500,
            message=str(e)
        )


@router.get("/user_sub_graphs", response_model=StandardResponse)
async def get_user_sub_graphs(
    user_id: int,
    pageNo: int = 1,
    pageSize: int = 10,
    db: Session = Depends(get_db)
):
    """
    List the sub-graphs linked to a user, newest (highest data_id) first.

    Rows come from ``user_data_relations`` with ``data_category='sub_graph'``,
    left-joined with ``jobs`` for the name. Paging is 1-based (``pageNo``, default 1;
    ``pageSize``, default 10); values below 1 are clamped to 1.
    """
    try:
        page_no = max(pageNo, 1)
        page_size = max(pageSize, 1)
        offset = (page_no - 1) * page_size
        sql = text("""
            SELECT udr.data_id, j.job_name
            FROM user_data_relations udr
            LEFT JOIN jobs j ON udr.data_id = j.id
            WHERE udr.user_id = :user_id AND udr.data_category = 'sub_graph'
            order by udr.data_id desc
            LIMIT :pageSize OFFSET :offset
        """)
        result = db.execute(sql, {'user_id': user_id, 'pageSize': page_size, 'offset': offset}).fetchall()
        records = [
            {"graph_id": row._mapping['data_id'], "graph_name": row._mapping['job_name']}
            for row in result
        ]
        return StandardResponse(
            records=records,
            message="User sub graphs retrieved"
        )
    except Exception as e:
        logger.exception("get_user_sub_graphs failed")
        return StandardResponse(
            code=500,
            message=str(e)
        )


def _leaf_nodes(names: list) -> list:
    """Sorted leaf entries of the tree for the given names."""
    return [{"name": name, "sNode": []} for name in sorted(names)]


def build_disease_tree(disease_nodes: list, root_name: str = "疾病") -> dict:
    """
    Build a two-level tree grouping names by the initial of their first character.

    :param disease_nodes: items with a ``name`` attribute (other items are used via ``str()``)
    :param root_name: name of the root node, "疾病" by default
    :return: ``{"name": root_name, "sNode": [...]}``; letter groups A-Z come first
             (empty ones omitted), then the non-A-Z bucket if it is non-empty
    """
    if not disease_nodes:
        return {"name": root_name, "sNode": []}

    letter_groups = {letter: [] for letter in string.ascii_uppercase}
    letter_groups[_OTHER_BUCKET] = []
    for node in disease_nodes:
        name = node.name if hasattr(node, 'name') else str(node)
        letter_groups[get_first_letter(name)].append(name)

    tree_structure = {"name": root_name, "sNode": []}
    for letter in string.ascii_uppercase:
        if letter_groups[letter]:
            tree_structure["sNode"].append({"name": letter, "sNode": _leaf_nodes(letter_groups[letter])})
    if letter_groups[_OTHER_BUCKET]:
        tree_structure["sNode"].append({"name": _OTHER_BUCKET, "sNode": _leaf_nodes(letter_groups[_OTHER_BUCKET])})
    return tree_structure


def get_first_letter(word):
    """Return the upper-case pinyin initial (A-Z) of the first character of ``word``.

    Returns '其他' for empty input, when pypinyin is unavailable or fails, or when
    the result is not a single ASCII letter.
    """
    if not word:
        return _OTHER_BUCKET
    first_char = word[0]
    try:
        import pypinyin
        first_letter = pypinyin.pinyin(first_char, style=pypinyin.FIRST_LETTER)[0][0].upper()
    except Exception as e:
        logger.warning("pinyin lookup failed for %r: %s", first_char, e)
        return _OTHER_BUCKET
    # Set membership, not substring: '' or multi-char results must not pass.
    return first_letter if first_letter in _ASCII_UPPER else _OTHER_BUCKET


@router.get('/disease_tree')
async def get_disease_tree(graph_id: int, db: Session = Depends(get_db)):
    """
    Build the letter-grouped tree of all ``kg_nodes`` with category '疾病' in ``graph_id``.

    Names are grouped A-Z by the pinyin initial of their first character.
    """
    disease_nodes = db.query(DbKgNode).filter(
        DbKgNode.graph_id == graph_id,
        DbKgNode.category == '疾病'
    ).all()
    tree_structure = build_disease_tree(disease_nodes)
    return StandardResponse(records=[{"records": tree_structure}])


@router.get('/graph_categories')
async def get_graph_categories(
    user_id: int,
    graph_id: int,
    db: Session = Depends(get_db)
):
    """
    Return the ``kg_graph_category`` categories for ``user_id``/``graph_id`` as a
    list of strings ordered by id ascending.
    """
    try:
        categories = db.query(KgGraphCategory.category).filter(
            KgGraphCategory.user_id == user_id,
            KgGraphCategory.graph_id == graph_id
        ).order_by(KgGraphCategory.id).all()
        if not categories:
            return StandardResponse(code=FAILED, message="No categories found")
        category_list = [category[0] for category in categories]
        return StandardResponse(
            records=[{"records": category_list}],
            message="Graph categories retrieved"
        )
    except Exception as e:
        logger.exception("get_graph_categories failed")
        return StandardResponse(
            code=500,
            message=str(e)
        )


@router.get('/tree_structure')
async def get_tree_structure(
    user_id: int,
    graph_id: int,
    db: Session = Depends(get_db)
):
    """
    Return the stored tree for ``user_id``/``graph_id``, converted by ``get_tree_dto``.

    Reads the first matching ``TreeStructure`` row's ``content``.
    """
    try:
        tree_structure = db.query(TreeStructure).filter(
            TreeStructure.user_id == user_id,
            TreeStructure.graph_id == graph_id
        ).first()
        if not tree_structure:
            return StandardResponse(code=FAILED, message="Tree structure not found")
        tree_dto = get_tree_dto(tree_structure.content)
        return StandardResponse(
            records=[{"records": tree_dto}],
            message="Tree structure retrieved"
        )
    except Exception as e:
        logger.exception("get_tree_structure failed")
        return StandardResponse(
            code=500,
            message=str(e)
        )


kb_router = router