|
- import sys,os
- current_path = os.getcwd()
- sys.path.append(current_path)
- from config.site import SiteConfig
- from fastapi import APIRouter, Depends, Query
- from db.database import get_db
- from sqlalchemy.orm import Session
- from sqlalchemy import text
- from agent.models.web.response import StandardResponse,FAILED,SUCCESS
- from agent.models.web.request import BasicRequest
- from agent.libs.graph import GraphBusiness
- from agent.libs.auth import verify_session_id, SessionValues
- import logging
- from typing import Optional, List
- from agent.models.db.graph import DbKgNode
- from agent.models.db.tree_structure import TreeStructure,KgGraphCategory
- import string
- import json
- from agent.tree_utils import get_tree_dto
# Module-level logger for the knowledge-build endpoints.
logger = logging.getLogger(__name__)

# Site configuration. TASK_LOG_DIR is read with current_path (the working directory
# captured at import time) as the second argument — presumably the fallback default.
config = SiteConfig()
LOG_DIR = config.get_config("TASK_LOG_DIR", current_path)

# Every endpoint in this module is mounted under the /kb prefix.
router = APIRouter(prefix="/kb", tags=["knowledge build interface"])

# Reference notes (not code): these appear to be the columns of the jobs model,
# whose job_name is joined in /user_sub_graphs — verify against the model definition.
#   job_category = Column(String(64), nullable=False)
#   job_name = Column(String(64))
#   job_details = Column(Text, nullable=False)
#   job_creator = Column(String(64), nullable=False)
#   job_logs = Column(Text, nullable=True)
#   job_files = Column(String(300), nullable=True)
@router.post('/summary', response_model=StandardResponse)
def summary_func(request: BasicRequest, db: Session = Depends(get_db), sess: SessionValues = Depends(verify_session_id)) -> StandardResponse:
    """Return the summary of one graph.

    Only ``request.action == "get_summary"`` is accepted; the graph is selected by
    the ``graph_id`` request parameter (default 0).

    :return: SUCCESS with ``records=[summary]`` when ``GraphBusiness.get_graph_summary``
        returns something truthy, FAILED "summary not found" otherwise, and
        FAILED "invalid action" for any other action.
    """
    if request.action == "get_summary":
        target_graph = request.get_param("graph_id", 0)
        business = GraphBusiness(db)
        graph_summary = business.get_graph_summary(graph_id=target_graph)
        if not graph_summary:
            return StandardResponse(code=FAILED, message="summary not found", records=[])
        logger.info(graph_summary)
        return StandardResponse(code=SUCCESS, message="summary found", records=[graph_summary])
    return StandardResponse(code=FAILED, message="invalid action")
@router.post('/schemas', response_model=StandardResponse)
def schemas_func(request: BasicRequest, db: Session = Depends(get_db), sess: SessionValues = Depends(verify_session_id)) -> StandardResponse:
    """Return the node or edge category schemas of a graph.

    Supported ``request.action`` values:
      - ``get_nodes_schemas`` -> ``GraphBusiness.get_nodes_categories``
      - ``get_edges_schemas`` -> ``GraphBusiness.get_edges_categories``

    The graph is selected by the ``graph_id`` request parameter (default 0).

    :return: SUCCESS "schemas found" with the schemas as records; FAILED
        "schemas not found" when a valid action yields nothing; FAILED
        "invalid action" for any other action.
    """
    if request.action not in ("get_nodes_schemas", "get_edges_schemas"):
        return StandardResponse(code=FAILED, message="invalid action")

    graph_id = request.get_param("graph_id", 0)
    biz = GraphBusiness(db)
    if request.action == "get_nodes_schemas":
        schemas = biz.get_nodes_categories(graph_id=graph_id)
    else:
        schemas = biz.get_edges_categories(graph_id=graph_id)

    if schemas:
        return StandardResponse(code=SUCCESS, message="schemas found", records=schemas)
    # A valid action with no data must not be reported as an unknown action.
    return StandardResponse(code=FAILED, message="schemas not found", records=[])
def _node_record(node, direction: str) -> dict:
    """Serialize a node for the ``neighbors`` response, tagged with its direction."""
    return {"id": node.id, "name": node.name, "category": node.category, "direction": direction}


def _handle_search_nodes(request: BasicRequest, db: Session) -> StandardResponse:
    """Action ``search_nodes``: fuzzy-search nodes by name within one category of a graph."""
    node_name = request.get_param("name", "")
    category = request.get_param("category", "")
    graph_id = request.get_param("graph_id", 0)
    # `not x` treats an explicit null the same as a missing value.
    if not node_name:
        return StandardResponse(code=FAILED, message="node name is empty", records=[])
    if not category:
        return StandardResponse(code=FAILED, message="category is empty", records=[])
    if not graph_id:
        return StandardResponse(code=FAILED, message="graph id is empty", records=[])

    nodes = GraphBusiness(db).search_like_node_by_name(graph_id=graph_id, category=category, name=node_name)
    if nodes:
        return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
    return StandardResponse(code=FAILED, message="search job failed")


def _handle_get_nodes(request: BasicRequest, db: Session) -> StandardResponse:
    """Action ``get_nodes``: return one page of a graph's nodes."""
    graph_id = request.get_param("graph_id", 0)
    page = request.get_param("page", 1)
    # NOTE(review): a default page_size of 1 looks unintentionally small — confirm with callers.
    page_size = request.get_param("page_size", 1)
    nodes = GraphBusiness(db).get_nodes_by_page(graph_id=graph_id, page=page, page_size=page_size)
    if nodes:
        return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
    # An empty page is reported as such, not as an unknown action.
    return StandardResponse(code=FAILED, message="nodes not found", records=[])


def _handle_neighbors(request: BasicRequest, db: Session) -> StandardResponse:
    """Action ``neighbors``: return a node followed by its incoming and outgoing neighbors."""
    graph_id = request.get_param("graph_id", 0)
    node_id = request.get_param("node_id", 0)
    if not node_id or node_id <= 0:
        return StandardResponse(code=FAILED, message="invalid node id", records=[])

    biz = GraphBusiness(db)
    center = biz.get_node_by_id(graph_id=graph_id, node_id=node_id)
    if center is None:
        return StandardResponse(code=FAILED, message="node not found", records=[])
    nodes_in = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="in")
    nodes_out = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="out")

    records = [_node_record(center, "self")]
    records += [_node_record(neighbor, "in") for neighbor in nodes_in]
    records += [_node_record(neighbor, "out") for neighbor in nodes_out]
    return StandardResponse(code=SUCCESS, message="nodes found", records=records)


# Dispatch table for nodes_func: request.action -> handler(request, db).
_NODE_ACTION_HANDLERS = {
    "search_nodes": _handle_search_nodes,
    "get_nodes": _handle_get_nodes,
    "neighbors": _handle_neighbors,
}


@router.post('/nodes', response_model=StandardResponse)
def nodes_func(request: BasicRequest, db: Session = Depends(get_db), sess: SessionValues = Depends(verify_session_id)) -> StandardResponse:
    """Node queries on a graph, selected by ``request.action``.

    - ``search_nodes``: params ``name``, ``category``, ``graph_id`` (all required).
    - ``get_nodes``: params ``graph_id``, ``page`` (default 1), ``page_size`` (default 1).
    - ``neighbors``: params ``graph_id``, ``node_id`` (> 0); records are the node itself
      (direction "self") followed by its "in" and "out" neighbors.

    Any other action yields FAILED "invalid action". A recognized action that finds
    nothing reports that specifically instead of "invalid action".
    """
    handler = _NODE_ACTION_HANDLERS.get(request.action)
    if handler is None:
        return StandardResponse(code=FAILED, message="invalid action")
    return handler(request, db)
async def get_node_properties(node_id: int, db: Session) -> dict:
    """Load the properties attached to a node as a ``{prop_title: prop_value}`` dict.

    Reads every ``kg_props`` row whose ``ref_id`` equals ``node_id``. If several rows
    share a title, the one returned last by the database wins.

    Note: declared ``async`` so callers can ``await`` it, but the query itself runs
    synchronously on the given session.

    :param node_id: id of the node whose properties are loaded
    :param db: database session
    :return: property title -> property value
    """
    rows = db.execute(
        text("SELECT prop_title, prop_value FROM kg_props WHERE ref_id = :node_id"),
        {'node_id': node_id},
    ).fetchall()
    return {row._mapping['prop_title']: row._mapping['prop_value'] for row in rows}
# SQL used by get_graph_data. Every query is scoped to nodes the user is linked to via
# user_data_relations (data_category = 'DbKgNode'), with status '0', in one graph.
# Column aliases are lower case on purpose: rows are read back as row._mapping['rtype'] /
# ['pcount'], and only lower-case aliases produce those keys on every backend
# (PostgreSQL folds unquoted aliases to lower case; other databases may preserve case).
_GRAPH_DEFAULT_NAME_SQL = text("""
    SELECT n.name
    FROM user_data_relations udr
    JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
    WHERE udr.user_id = :user_id
      AND n.category = :label_name
      AND n.status = '0'
      AND n.graph_id = :graph_id
    LIMIT 1
""")

_GRAPH_CENTER_SQL = text("""
    SELECT n.id, n.name, n.category
    FROM user_data_relations udr
    JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
    WHERE udr.user_id = :user_id
      AND n.category = :label_name
      AND n.name = :input_str
      AND n.status = '0'
      AND n.graph_id = :graph_id
    LIMIT 1
""")

# Outgoing edges of the center node, capped at 50 targets per relation type (rn <= 50).
# pcount is the number of outgoing edges of each target node.
_GRAPH_RELATIONS_SQL = text("""
    WITH RankedRelations AS (
        SELECT
            e.name AS rtype,
            m.id AS target_id,
            m.name AS target_name,
            m.category AS target_label,
            (SELECT COUNT(*) FROM kg_edges WHERE src_id = m.id) AS pcount,
            ROW_NUMBER() OVER (PARTITION BY e.name ORDER BY m.id) AS rn
        FROM user_data_relations udr
        JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
        JOIN kg_edges e ON n.id = e.src_id
        JOIN kg_nodes m ON e.dest_id = m.id
        WHERE udr.user_id = :user_id
          AND n.category = :label_name
          AND n.name = :input_str
          AND n.status = '0'
          AND n.graph_id = :graph_id
    )
    SELECT rtype, target_id, target_name, target_label, pcount
    FROM RankedRelations
    WHERE rn <= 50
    ORDER BY rtype
""")


async def _load_graph_dto(db: Session, params: dict) -> dict:
    """Query the center node and its outgoing relations into an intermediate DTO.

    :param db: database session
    :param params: bind parameters user_id, label_name, input_str, graph_id
    :return: {"label", "name", "id", "properties", "ENodeRSDTOS"}. The center fields keep
        their empty defaults when no center node matches. ENodeRSDTOS holds one
        {"RType", "ENodeDTOS"} entry per relation type, in first-returned-row order.
    """
    graph_dto = {"label": "", "name": "", "id": 0, "properties": {}, "ENodeRSDTOS": []}

    center = db.execute(_GRAPH_CENTER_SQL, params).first()
    if center is not None:
        center_map = center._mapping
        graph_dto["label"] = center_map['category']
        graph_dto["name"] = center_map['name']
        graph_dto["id"] = center_map['id']
        graph_dto["properties"] = await get_node_properties(center_map['id'], db)

    # Group targets by relation type; dicts preserve insertion order, so relation
    # types keep the order in which their first row was returned.
    targets_by_type: dict = {}
    for row in db.execute(_GRAPH_RELATIONS_SQL, params).fetchall():
        rel = row._mapping
        e_node = {
            "Label": rel['target_label'],
            "Name": rel['target_name'],
            "Id": rel['target_id'],
            "PCount": rel['pcount'],
            # NOTE(review): one properties query per target node (N+1); consider batching.
            "properties": await get_node_properties(rel['target_id'], db),
        }
        targets_by_type.setdefault(rel['rtype'], []).append(e_node)

    graph_dto["ENodeRSDTOS"] = [
        {"RType": r_type, "ENodeDTOS": e_nodes} for r_type, e_nodes in targets_by_type.items()
    ]
    return graph_dto


def _build_echarts_graph(graph_dto: dict) -> dict:
    """Convert the DTO from _load_graph_dto into the {categories, node, links} payload.

    Layout: node "0" is the center. Each relation type gets a diamond relation node
    linked from the center, and each target node is linked from its relation node.
    A node's "name" is its string index in the node list, and links refer to nodes
    by that index.
    """
    categories = [{"name": "中心词"}, {"name": "关系"}]
    c_map = {"中心词": 0, "关系": 1}
    for rs in graph_dto["ENodeRSDTOS"]:
        if rs["RType"] not in c_map:
            c_map[rs["RType"]] = len(categories)
            categories.append({"name": rs["RType"]})

    nodes = [{
        "label": graph_dto["name"],
        "type": graph_dto["label"],
        "category": 0,
        "name": "0",
        "symbol": "circle",
        "symbolSize": 50,
        "properties": graph_dto["properties"],
        "nodeId": graph_dto["id"],
        "itemStyle": {"display": True},
    }]
    links = []

    for rs in graph_dto["ENodeRSDTOS"]:
        r_type = rs["RType"]
        # Indices are tracked directly; a nodes.index(...) lookup would be an O(n)
        # dict-equality scan per link.
        relation_idx = len(nodes)
        nodes.append({
            "label": "",
            "type": graph_dto["label"],
            "category": 1,
            "name": str(relation_idx),
            "symbol": "diamond",
            "symbolSize": 10,
            # Relation nodes carry the center node's properties.
            "properties": graph_dto["properties"],
            "nodeId": relation_idx,
            "itemStyle": {"display": True},
        })
        links.append({
            "source": "0",
            "target": str(relation_idx),
            "value": r_type,
            "relationType": r_type,
        })

        for e_node in rs["ENodeDTOS"]:
            child_idx = len(nodes)
            nodes.append({
                "label": e_node["Name"],
                "type": e_node["Label"],
                "category": c_map[r_type],
                "name": str(child_idx),
                "symbol": "circle",
                "symbolSize": 28,
                "properties": e_node["properties"],
                "nodeId": e_node["Id"],
                # display is True only when the target has outgoing edges itself.
                "itemStyle": {"display": e_node["PCount"] > 0},
            })
            links.append({
                "source": str(relation_idx),
                "target": str(child_idx),
                "value": "",
                "relationType": r_type,
            })

    return {"categories": categories, "node": nodes, "links": links}


@router.get("/graph_data", response_model=StandardResponse)
async def get_graph_data(
    label_name: str,
    user_id: int,
    graph_id: int,
    db: Session = Depends(get_db),
    input_str: Optional[str] = None
):
    """Return the graph around one of the user's nodes.

    The payload mirrors the data structure of the Java implementation.

    - user_id is taken from the query string. No session dependency is declared here.
    - The center node is the user's node named input_str with category label_name
      in graph graph_id. When input_str is empty, the first such node name is used.
      If none exists, the response is FAILED "No node found for given parameters".
    - At most 50 target nodes per relation type are returned.

    :return: records=[{"records": {"categories", "node", "links"}}]. On any error the
        exception is logged and a response with code 500 and the exception text is returned.
    """
    # NOTE(review): user_id is client-supplied and not checked against a session, so any
    # caller can read any user's graph — consider restoring verify_session_id.
    try:
        base_params = {'user_id': user_id, 'label_name': label_name, 'graph_id': graph_id}

        if not input_str:
            name_row = db.execute(_GRAPH_DEFAULT_NAME_SQL, base_params).first()
            if not name_row:
                return StandardResponse(code=FAILED, message="No node found for given parameters", records=[])
            input_str = name_row._mapping['name']

        graph_dto = await _load_graph_dto(db, {**base_params, 'input_str': input_str})
        logger.debug("graph_dto: %s", graph_dto)

        return StandardResponse(
            records=[{"records": _build_echarts_graph(graph_dto)}],
            message="Graph data retrieved"
        )
    except Exception as e:
        logger.exception("graph_data failed (user_id=%s, graph_id=%s)", user_id, graph_id)
        return StandardResponse(
            code=500,
            message=str(e)
        )

@router.get("/user_sub_graphs", response_model=StandardResponse)
async def get_user_sub_graphs(
    user_id: int,
    pageNo: int = 1,
    pageSize: int = 10,
    db: Session = Depends(get_db)
):
    """List the sub graphs linked to a user, newest data_id first.

    - Reads user_data_relations rows with data_category = 'sub_graph' for user_id.
    - LEFT JOINs jobs (jobs.id = data_id) for job_name, which may be None.
    - Paginated by pageNo (default 1) and pageSize (default 10). Values below 1 are
      clamped to 1 so the query never gets a negative OFFSET or a non-positive LIMIT.

    :return: records of {"graph_id", "graph_name"}. On any error the exception is
        logged and a response with code 500 and the exception text is returned.
    """
    try:
        page_no = max(pageNo, 1)
        page_size = max(pageSize, 1)
        offset = (page_no - 1) * page_size
        sql = text("""
            SELECT udr.data_id, j.job_name
            FROM user_data_relations udr
            LEFT JOIN jobs j ON udr.data_id = j.id
            WHERE udr.user_id = :user_id
              AND udr.data_category = 'sub_graph' order by udr.data_id desc
            LIMIT :pageSize OFFSET :offset
        """)

        rows = db.execute(sql, {'user_id': user_id, 'pageSize': page_size, 'offset': offset}).fetchall()
        records = [
            {"graph_id": row._mapping['data_id'], "graph_name": row._mapping['job_name']}
            for row in rows
        ]

        return StandardResponse(
            records=records,
            message="User sub graphs retrieved"
        )
    except Exception as e:
        logger.exception("user_sub_graphs failed (user_id=%s)", user_id)
        return StandardResponse(
            code=500,
            message=str(e)
        )
def build_disease_tree(disease_nodes: list, root_name: str = "疾病") -> dict:
    """Group disease names into a two-level tree keyed by pinyin initial.

    Each element's ``name`` attribute is used when present, otherwise ``str(element)``.
    Names are bucketed by ``get_first_letter``. Buckets appear in A-Z order, followed by
    the ``'其他'`` ("other") bucket, and empty buckets are omitted. Names inside a bucket
    are sorted with the default string ordering.

    :param disease_nodes: disease nodes (or plain values); empty/None yields an empty tree
    :param root_name: name of the root node, default "疾病" ("disease")
    :return: {"name": root_name, "sNode": [{"name": letter, "sNode": [{"name": disease, "sNode": []}, ...]}, ...]}
    """
    tree = {"name": root_name, "sNode": []}
    if not disease_nodes:
        return tree

    bucket_order = (*string.ascii_uppercase, '其他')
    buckets = {key: [] for key in bucket_order}
    for item in disease_nodes:
        label = item.name if hasattr(item, 'name') else str(item)
        buckets[get_first_letter(label)].append(label)

    for key in bucket_order:
        members = buckets[key]
        if members:
            tree["sNode"].append({
                "name": key,
                "sNode": [{"name": member, "sNode": []} for member in sorted(members)],
            })
    return tree
def get_first_letter(word):
    """Return the upper-case pinyin initial of the first character of *word*.

    The first character is converted with ``pypinyin`` (``FIRST_LETTER`` style) and
    upper-cased. Anything that does not end up as exactly one ASCII letter A-Z returns
    ``'其他'`` ("other"). That includes empty/None input, digits, symbols, a missing
    ``pypinyin`` package, and conversion errors.
    """
    if not word:
        return '其他'

    first_char = word[0]
    try:
        import pypinyin  # optional dependency, imported lazily
        initial = pypinyin.pinyin(first_char, style=pypinyin.FIRST_LETTER)[0][0].upper()
    except Exception as e:
        # Best effort: an unconvertible name goes to "other" instead of failing the caller.
        logging.getLogger(__name__).warning("get_first_letter(%r) failed: %s", word, e)
        return '其他'

    # Exact single-letter check. `initial in string.ascii_uppercase` alone is a
    # *substring* test and would accept "" or multi-letter results such as "ST"
    # ('ﬆ'.upper()). Those are not keys of a per-letter A-Z bucket dict such as the
    # one built in build_disease_tree.
    if len(initial) == 1 and initial in string.ascii_uppercase:
        return initial
    return '其他'
@router.get('/disease_tree')
async def get_disease_tree(graph_id: int, db: Session = Depends(get_db)):
    """Build the A-Z disease tree for one graph.

    Loads every kg_nodes row of ``graph_id`` whose category is '疾病' ("disease") and
    groups the names strictly by letter A-Z. Chinese names are filed under the pinyin
    initial of their first character (see build_disease_tree).

    NOTE(review): no status filter is applied here — confirm inactive nodes should appear.
    """
    disease_query = db.query(DbKgNode).filter(
        DbKgNode.graph_id == graph_id,
        DbKgNode.category == '疾病',
    )
    tree = build_disease_tree(disease_query.all())
    return StandardResponse(records=[{"records": tree}])

@router.get('/graph_categories')
async def get_graph_categories(
    user_id: int,
    graph_id: int,
    db: Session = Depends(get_db)
):
    """Return the category names configured for a user's graph.

    Reads KgGraphCategory.category for (user_id, graph_id), ordered by id ascending.

    :return: records=[{"records": [category, ...]}]; FAILED "No categories found" when
        there are none. On any error the exception is logged and a response with code
        500 and the exception text is returned.
    """
    try:
        rows = db.query(KgGraphCategory.category).filter(
            KgGraphCategory.user_id == user_id,
            KgGraphCategory.graph_id == graph_id
        ).order_by(KgGraphCategory.id).all()

        if not rows:
            return StandardResponse(code=FAILED, message="No categories found")

        # Each row is a one-column tuple; keep just the category string.
        category_list = [row[0] for row in rows]

        return StandardResponse(
            records=[{"records": category_list}],
            message="Graph categories retrieved"
        )
    except Exception as e:
        logger.exception("graph_categories failed (user_id=%s, graph_id=%s)", user_id, graph_id)
        return StandardResponse(
            code=500,
            message=str(e)
        )

@router.get('/tree_structure')
async def get_tree_structure(
    user_id: int,
    graph_id: int,
    db: Session = Depends(get_db)
):
    """Return a user's stored tree structure converted for the frontend.

    1. Loads a TreeStructure row for (user_id, graph_id).
    2. Converts its ``content`` with get_tree_dto.
    3. Returns the result as records=[{"records": tree_dto}].

    :return: FAILED "Tree structure not found" when no row exists. On any other error
        the exception is logged and a response with code 500 and the exception text
        is returned.
    """
    # NOTE(review): user_id is client-supplied; no session check is performed here.
    try:
        # NOTE(review): no ORDER BY — if several rows match, which one .first() returns
        # is database-dependent.
        tree_structure = db.query(TreeStructure).filter(
            TreeStructure.user_id == user_id,
            TreeStructure.graph_id == graph_id
        ).first()

        if not tree_structure:
            return StandardResponse(code=FAILED, message="Tree structure not found")

        tree_dto = get_tree_dto(tree_structure.content)

        return StandardResponse(
            records=[{"records": tree_dto}],
            message="Tree structure retrieved"
        )
    except Exception as e:
        logger.exception("tree_structure failed (user_id=%s, graph_id=%s)", user_id, graph_id)
        return StandardResponse(
            code=500,
            message=str(e)
        )
# Alias of this module's APIRouter (prefix /kb). Presumably imported under this name
# when the application registers these routes — verify against the app entry point.
kb_router = router
|