|
@@ -6,11 +6,13 @@ from config.site import SiteConfig
|
|
|
from fastapi import APIRouter, Depends, Query
|
|
|
from db.database import get_db
|
|
|
from sqlalchemy.orm import Session
|
|
|
+from sqlalchemy import text
|
|
|
from agent.models.web.response import StandardResponse,FAILED,SUCCESS
|
|
|
from agent.models.web.request import BasicRequest
|
|
|
from agent.libs.graph import GraphBusiness
|
|
|
from agent.libs.auth import verify_session_id, SessionValues
|
|
|
import logging
|
|
|
+from typing import Optional, List
|
|
|
|
|
|
router = APIRouter(prefix="/kb", tags=["knowledge build interface"])
|
|
|
logger = logging.getLogger(__name__)
|
|
@@ -100,7 +102,178 @@ def nodes_func(request:BasicRequest, db: Session = Depends(get_db), sess:Session
|
|
|
|
|
|
return StandardResponse(code=SUCCESS, message="nodes found", records=nodes_all)
|
|
|
return StandardResponse(code=FAILED, message="invalid action")
|
|
|
+
|
|
|
+
|
|
|
+async def get_node_properties(node_id: int, db: Session) -> dict:
|
|
|
+ """
|
|
|
+ 查询节点的属性
|
|
|
+ :param node_id: 节点ID
|
|
|
+ :param db: 数据库会话
|
|
|
+ :return: 属性字典
|
|
|
+ """
|
|
|
+ prop_sql = text("SELECT prop_title, prop_value FROM kg_props WHERE ref_id = :node_id")
|
|
|
+ result = db.execute(prop_sql, {'node_id': node_id}).fetchall()
|
|
|
+ properties = {}
|
|
|
+ for row in result:
|
|
|
+ properties[row._mapping['prop_title']] = row._mapping['prop_value']
|
|
|
+ return properties
|
|
|
+
|
|
|
+@router.get("/graph_data", response_model=StandardResponse)
|
|
|
+async def get_graph_data(
|
|
|
+ label_name: Optional[str] = Query(None),
|
|
|
+ input_str: Optional[str] = Query(None),
|
|
|
+ db: Session = Depends(get_db),
|
|
|
+ sess:SessionValues = Depends(verify_session_id)
|
|
|
+):
|
|
|
+ """
|
|
|
+ 获取用户关联的图谱数据
|
|
|
+ - 从session_id获取user_id
|
|
|
+ - 查询DbUserDataRelation获取用户关联的数据
|
|
|
+ - 返回与Java端一致的数据结构
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 1. 从session获取user_id
|
|
|
+ user_id = sess.user_id
|
|
|
+ if not user_id:
|
|
|
+ return StandardResponse(code=FAILED, message="user not found", records=[])
|
|
|
+
|
|
|
+ # 2. 使用JOIN查询用户关联的图谱数据
|
|
|
+ sql = text("""
|
|
|
+ WITH RankedRelations AS (
|
|
|
+ SELECT
|
|
|
+ e.name as rType,
|
|
|
+ m.id as target_id,
|
|
|
+ m.name as target_name,
|
|
|
+ m.category as target_label,
|
|
|
+ (SELECT COUNT(*) FROM kg_edges WHERE src_id = m.id) as pCount,
|
|
|
+ ROW_NUMBER() OVER(PARTITION BY e.name ORDER BY m.id) as rn
|
|
|
+ FROM user_data_relations udr
|
|
|
+ JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
|
|
|
+ JOIN kg_edges e ON n.id = e.src_id
|
|
|
+ JOIN kg_nodes m ON e.dest_id = m.id
|
|
|
+ WHERE udr.user_id = :user_id
|
|
|
+ AND n.category = :label_name
|
|
|
+ AND n.name = :input_str
|
|
|
+ AND n.status = '0'
|
|
|
+ )
|
|
|
+ SELECT rType, target_id, target_name, target_label, pCount
|
|
|
+ FROM RankedRelations
|
|
|
+ WHERE rn <= 50
|
|
|
+ ORDER BY rType
|
|
|
+ """)
|
|
|
|
|
|
+ # 3. 组装返回数据结构
|
|
|
+ categories = ["中心词", "关系"]
|
|
|
+ nodes = []
|
|
|
+ links = []
|
|
|
+
|
|
|
+ # 查询中心节点
|
|
|
+ center_sql = text("""
|
|
|
+ SELECT n.id, n.name, n.category
|
|
|
+ FROM user_data_relations udr
|
|
|
+ JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
|
|
|
+ WHERE udr.user_id = :user_id
|
|
|
+ AND n.category = :label_name
|
|
|
+ AND n.name = :input_str
|
|
|
+ AND n.status = '0'
|
|
|
+ """)
|
|
|
+
|
|
|
+ # 执行查询并处理结果
|
|
|
+ center_node = None
|
|
|
+ rtype_map = {}
|
|
|
+
|
|
|
+ # 1. 查询中心节点
|
|
|
+ center_result = db.execute(center_sql, {
|
|
|
+ 'user_id': user_id,
|
|
|
+ 'label_name': label_name,
|
|
|
+ 'input_str': input_str
|
|
|
+ }).fetchall()
|
|
|
+
|
|
|
+ if center_result:
|
|
|
+ for row in center_result:
|
|
|
+ center_node = {
|
|
|
+ "id": 0,
|
|
|
+ "name": row._mapping['name'],
|
|
|
+ "label": row._mapping['category'],
|
|
|
+ "symbolSize": 50,
|
|
|
+ "symbol": "circle",
|
|
|
+ "properties": await get_node_properties(row._mapping['id'], db)
|
|
|
+ }
|
|
|
+ nodes.append(center_node)
|
|
|
+ break
|
|
|
+
|
|
|
+ # 2. 查询关联的边和目标节点
|
|
|
+ relation_result = db.execute(sql, {
|
|
|
+ 'user_id': user_id,
|
|
|
+ 'label_name': label_name,
|
|
|
+ 'input_str': input_str
|
|
|
+ }).fetchall()
|
|
|
+
|
|
|
+ if relation_result:
|
|
|
+ for row in relation_result:
|
|
|
+ r_type = row._mapping['rtype']
|
|
|
+ target_node = {
|
|
|
+ "id": row._mapping['target_id'],
|
|
|
+ "name": row._mapping['target_name'],
|
|
|
+ "label": row._mapping['target_label'],
|
|
|
+ "symbolSize": 28,
|
|
|
+ "symbol": "circle",
|
|
|
+ "properties": await get_node_properties(row._mapping['target_id'], db)
|
|
|
+ }
|
|
|
+
|
|
|
+ if r_type not in rtype_map:
|
|
|
+ rtype_map[r_type] = []
|
|
|
+ rtype_map[r_type].append(target_node)
|
|
|
+
|
|
|
+ if r_type not in categories:
|
|
|
+ categories.append(r_type)
|
|
|
+
|
|
|
+ # 3. 组装返回结果
|
|
|
+ if center_node:
|
|
|
+ for r_type, targets in rtype_map.items():
|
|
|
+ # 添加关系节点
|
|
|
+ relation_node = {
|
|
|
+ "id": len(nodes),
|
|
|
+ "name": "",
|
|
|
+ "label": center_node['label'],
|
|
|
+ "symbolSize": 10,
|
|
|
+ "symbol": "diamond",
|
|
|
+ "properties": {}
|
|
|
+ }
|
|
|
+ nodes.append(relation_node)
|
|
|
+
|
|
|
+ # 添加中心节点到关系节点的链接
|
|
|
+ links.append({
|
|
|
+ "source": center_node['name'],
|
|
|
+ "target": "",
|
|
|
+ "name": r_type,
|
|
|
+ "category": r_type
|
|
|
+ })
|
|
|
+
|
|
|
+ # 添加关系节点到目标节点的链接
|
|
|
+ for target in targets:
|
|
|
+ links.append({
|
|
|
+ "source": "",
|
|
|
+ "target": target['name'],
|
|
|
+ "name": "",
|
|
|
+ "category": r_type
|
|
|
+ })
|
|
|
+
|
|
|
+ final_data={
|
|
|
+ "categories": categories,
|
|
|
+ "nodes": nodes,
|
|
|
+ "links": links
|
|
|
+ }
|
|
|
+ return StandardResponse(
|
|
|
+ records=[{"records": final_data}],
|
|
|
+ message="Graph data retrieved"
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ return StandardResponse(
|
|
|
+ code=500,
|
|
|
+ message=str(e)
|
|
|
+ )
|
|
|
+
|
|
|
|
|
|
|
|
|
kb_router = router
|