"""Knowledge-base (``/kb``) HTTP endpoints: graph summary, schemas, nodes and graph data."""
import logging
import os
import sys
from typing import Optional, List

# The working directory is appended to sys.path before the project imports
# below; presumably they rely on it to be importable, so the order matters.
current_path = os.getcwd()
sys.path.append(current_path)

from config.site import SiteConfig
from fastapi import APIRouter, Depends, Query
from db.database import get_db
from sqlalchemy.orm import Session
from sqlalchemy import text
from agent.models.web.response import StandardResponse, FAILED, SUCCESS
from agent.models.web.request import BasicRequest
from agent.libs.graph import GraphBusiness
from agent.libs.auth import verify_session_id, SessionValues

router = APIRouter(prefix="/kb", tags=["knowledge build interface"])
logger = logging.getLogger(__name__)
config = SiteConfig()
LOG_DIR = config.get_config("TASK_LOG_DIR", current_path)

# NOTE(review): leftover column sketch, not used anywhere in this module.
# job_category = Column(String(64), nullable=False)
# job_name = Column(String(64))
# job_details = Column(Text, nullable=False)
# job_creator = Column(String(64), nullable=False)
# job_logs = Column(Text, nullable=True)
# job_files = Column(String(300), nullable=True)

# Properties attached to a node, keyed by ``ref_id``.
_NODE_PROPS_SQL = text(
    "SELECT prop_title, prop_value FROM kg_props WHERE ref_id = :node_id"
)

# Center node: a node owned by the user that matches category and name.
_CENTER_NODE_SQL = text("""
    SELECT n.id, n.name, n.category
    FROM user_data_relations udr
    JOIN kg_nodes n ON udr.data_id = n.id
        AND udr.data_category = 'DbKgNode'
    WHERE udr.user_id = :user_id
        AND n.category = :label_name
        AND n.name = :input_str
        AND n.status = '0'
""")

# Outgoing edges of the center node and their targets, capped at 50 targets
# per relation name.  Aliases are deliberately lowercase: the result rows are
# read by the key 'rtype', and a mixed-case alias only matches that key on
# backends that fold unquoted identifiers to lowercase.
_RELATIONS_SQL = text("""
    WITH RankedRelations AS (
        SELECT
            e.name as rtype,
            m.id as target_id,
            m.name as target_name,
            m.category as target_label,
            (SELECT COUNT(*) FROM kg_edges WHERE src_id = m.id) as pcount,
            ROW_NUMBER() OVER(PARTITION BY e.name ORDER BY m.id) as rn
        FROM user_data_relations udr
        JOIN kg_nodes n ON udr.data_id = n.id
            AND udr.data_category = 'DbKgNode'
        JOIN kg_edges e ON n.id = e.src_id
        JOIN kg_nodes m ON e.dest_id = m.id
        WHERE udr.user_id = :user_id
            AND n.category = :label_name
            AND n.name = :input_str
            AND n.status = '0'
    )
    SELECT rtype, target_id, target_name, target_label, pcount
    FROM RankedRelations
    WHERE rn <= 50
    ORDER BY rtype
""")


@router.post('/summary', response_model=StandardResponse)
def summary_func(
    request: BasicRequest,
    db: Session = Depends(get_db),
    sess: SessionValues = Depends(verify_session_id),
) -> StandardResponse:
    """Return the summary of a graph.

    Only the ``get_summary`` action is accepted.  ``graph_id`` is read from
    the request parameters (default ``0``).
    """
    if request.action != "get_summary":
        return StandardResponse(code=FAILED, message="invalid action")

    graph_id = request.get_param("graph_id", 0)
    summary = GraphBusiness(db).get_graph_summary(graph_id=graph_id)
    if not summary:
        return StandardResponse(code=FAILED, message="summary not found", records=[])

    logger.info(summary)
    return StandardResponse(code=SUCCESS, message="summary found", records=[summary])


@router.post('/schemas', response_model=StandardResponse)
def schemas_func(
    request: BasicRequest,
    db: Session = Depends(get_db),
    sess: SessionValues = Depends(verify_session_id),
) -> StandardResponse:
    """Return node or edge categories of a graph.

    Supported actions: ``get_nodes_schemas`` and ``get_edges_schemas``.
    A supported action that yields nothing is reported as
    "schemas not found" rather than "invalid action".
    """
    if request.action not in ("get_nodes_schemas", "get_edges_schemas"):
        return StandardResponse(code=FAILED, message="invalid action")

    graph_id = request.get_param("graph_id", 0)
    biz = GraphBusiness(db)
    if request.action == "get_nodes_schemas":
        schemas = biz.get_nodes_categories(graph_id=graph_id)
    else:
        schemas = biz.get_edges_categories(graph_id=graph_id)

    if schemas:
        return StandardResponse(code=SUCCESS, message="schemas found", records=schemas)
    return StandardResponse(code=FAILED, message="schemas not found", records=[])


@router.post('/nodes', response_model=StandardResponse)
def nodes_func(
    request: BasicRequest,
    db: Session = Depends(get_db),
    sess: SessionValues = Depends(verify_session_id),
) -> StandardResponse:
    """Search, page through, or list the neighbors of graph nodes.

    Supported actions:

    * ``search_nodes`` -- needs ``name``, ``category`` and ``graph_id``.
    * ``get_nodes``    -- paged listing via ``graph_id``, ``page``, ``page_size``.
    * ``neighbors``    -- the node itself plus its incoming and outgoing
      neighbors, each tagged with a ``direction`` of self/in/out.

    Any other action yields "invalid action".
    """
    action = request.action

    if action == "search_nodes":
        node_name = request.get_param("name", "")
        category = request.get_param("category", "")
        graph_id = request.get_param("graph_id", 0)
        if node_name == "":
            return StandardResponse(code=FAILED, message="node name is empty", records=[])
        if category == "":
            return StandardResponse(code=FAILED, message="category is empty", records=[])
        if graph_id == 0:
            return StandardResponse(code=FAILED, message="graph id is empty", records=[])

        nodes = GraphBusiness(db).search_like_node_by_name(
            graph_id=graph_id, category=category, name=node_name
        )
        if nodes:
            return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
        return StandardResponse(code=FAILED, message="search job failed")

    if action == "get_nodes":
        graph_id = request.get_param("graph_id", 0)
        page = request.get_param("page", 1)
        page_size = request.get_param("page_size", 1)
        nodes = GraphBusiness(db).get_nodes_by_page(
            graph_id=graph_id, page=page, page_size=page_size
        )
        if nodes:
            return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
        # Previously fell through to "invalid action" although the action is valid.
        return StandardResponse(code=FAILED, message="nodes not found", records=[])

    if action == "neighbors":
        graph_id = request.get_param("graph_id", 0)
        node_id = request.get_param("node_id", 0)
        if not node_id > 0:
            # Previously fell through to "invalid action" although the action is valid.
            return StandardResponse(code=FAILED, message="node id is empty", records=[])

        biz = GraphBusiness(db)
        node = biz.get_node_by_id(graph_id=graph_id, node_id=node_id)
        if node is None:
            return StandardResponse(code=FAILED, message="node not found", records=[])

        nodes_in = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="in")
        nodes_out = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="out")

        def describe(item, direction):
            return {
                "id": item.id,
                "name": item.name,
                "category": item.category,
                "direction": direction,
            }

        nodes_all = [describe(node, "self")]
        nodes_all.extend(describe(neighbor, "in") for neighbor in nodes_in)
        nodes_all.extend(describe(neighbor, "out") for neighbor in nodes_out)
        return StandardResponse(code=SUCCESS, message="nodes found", records=nodes_all)

    return StandardResponse(code=FAILED, message="invalid action")


async def get_node_properties(node_id: int, db: Session) -> dict:
    """Fetch the properties of a node.

    :param node_id: node ID, matched against ``kg_props.ref_id``
    :param db: database session
    :return: dict mapping ``prop_title`` to ``prop_value``
    """
    rows = db.execute(_NODE_PROPS_SQL, {'node_id': node_id}).fetchall()
    return {row._mapping['prop_title']: row._mapping['prop_value'] for row in rows}


@router.get("/graph_data", response_model=StandardResponse)
async def get_graph_data(
    label_name: Optional[str] = Query(None),
    input_str: Optional[str] = Query(None),
    db: Session = Depends(get_db),
    sess: SessionValues = Depends(verify_session_id),
) -> StandardResponse:
    """Return the graph data associated with the current user.

    - takes ``user_id`` from the verified session
    - looks up the user's related nodes through ``user_data_relations``
    - returns ``categories`` / ``nodes`` / ``links`` (the original note says
      this structure mirrors the one produced by the Java side)

    :param label_name: category of the center node
    :param input_str: name of the center node
    :return: on success, a response whose ``records`` holds one dict with
        the graph payload under the key ``records``; on any exception, a
        response with ``code=500`` and the exception text as message
    """
    try:
        # 1. Resolve the user from the session.
        user_id = sess.user_id
        if not user_id:
            return StandardResponse(code=FAILED, message="user not found", records=[])

        query_params = {
            'user_id': user_id,
            'label_name': label_name,
            'input_str': input_str,
        }

        categories = ["中心词", "关系"]
        nodes = []
        links = []
        center_node = None
        rtype_map = {}

        # 2. Center node: only the first matching row is used.
        center_rows = db.execute(_CENTER_NODE_SQL, query_params).fetchall()
        if center_rows:
            center = center_rows[0]._mapping
            center_node = {
                "id": 0,
                "name": center['name'],
                "label": center['category'],
                "symbolSize": 50,
                "symbol": "circle",
                "properties": await get_node_properties(center['id'], db),
            }
            nodes.append(center_node)

        # 3. Related edges and target nodes, grouped by relation name.
        relation_rows = db.execute(_RELATIONS_SQL, query_params).fetchall()
        for row in relation_rows:
            fields = row._mapping
            r_type = fields['rtype']
            target_node = {
                "id": fields['target_id'],
                "name": fields['target_name'],
                "label": fields['target_label'],
                "symbolSize": 28,
                "symbol": "circle",
                "properties": await get_node_properties(fields['target_id'], db),
            }
            rtype_map.setdefault(r_type, []).append(target_node)
            if r_type not in categories:
                categories.append(r_type)

        # 4. Assemble: one unnamed "relation" node per relation type, linked
        #    from the center node and on to each of its targets.
        # NOTE(review): target nodes are only referenced by name in `links`
        # and are never appended to `nodes` -- confirm this is intended.
        if center_node:
            for r_type, targets in rtype_map.items():
                nodes.append({
                    "id": len(nodes),
                    "name": "",
                    "label": center_node['label'],
                    "symbolSize": 10,
                    "symbol": "diamond",
                    "properties": {},
                })
                # Center node -> relation node.
                links.append({
                    "source": center_node['name'],
                    "target": "",
                    "name": r_type,
                    "category": r_type,
                })
                # Relation node -> each target node.
                links.extend(
                    {
                        "source": "",
                        "target": target['name'],
                        "name": "",
                        "category": r_type,
                    }
                    for target in targets
                )

        final_data = {
            "categories": categories,
            "nodes": nodes,
            "links": links,
        }
        return StandardResponse(
            records=[{"records": final_data}],
            message="Graph data retrieved",
        )
    except Exception as e:
        # Top-level boundary: keep the original error response, but record
        # the traceback instead of silently discarding it.
        logger.exception("failed to build graph data")
        return StandardResponse(
            code=500,
            message=str(e),
        )


kb_router = router