123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279 |
- import sys,os
- current_path = os.getcwd()
- sys.path.append(current_path)
- from config.site import SiteConfig
- from fastapi import APIRouter, Depends, Query
- from db.database import get_db
- from sqlalchemy.orm import Session
- from sqlalchemy import text
- from agent.models.web.response import StandardResponse,FAILED,SUCCESS
- from agent.models.web.request import BasicRequest
- from agent.libs.graph import GraphBusiness
- from agent.libs.auth import verify_session_id, SessionValues
- import logging
- from typing import Optional, List
logger = logging.getLogger(__name__)

# Site configuration is loaded once at import time. LOG_DIR is looked up under
# the "TASK_LOG_DIR" key, with the process working directory passed as the
# second argument (presumably the fallback value -- confirm in SiteConfig).
config = SiteConfig()
LOG_DIR = config.get_config("TASK_LOG_DIR", current_path)

# Every endpoint declared in this module is registered under the /kb prefix.
router = APIRouter(prefix="/kb", tags=["knowledge build interface"])
- # job_category = Column(String(64), nullable=False)
- # job_name = Column(String(64))
- # job_details = Column(Text, nullable=False)
- # job_creator = Column(String(64), nullable=False)
- # job_logs = Column(Text, nullable=True)
- # job_files = Column(String(300), nullable=True)
@router.post('/summary', response_model=StandardResponse)
def summary_func(request: BasicRequest, db: Session = Depends(get_db), sess: SessionValues = Depends(verify_session_id)) -> StandardResponse:
    """Return the summary of a single graph.

    Only the ``get_summary`` action is accepted. The graph is selected by the
    ``graph_id`` request parameter, which defaults to 0 when absent.

    :param request: request carrying the action name and its parameters
    :param db: database session injected by FastAPI
    :param sess: session values produced by the session-id dependency
    :return: SUCCESS with the summary as the only record, otherwise FAILED
    """
    # Guard clause: anything but the single supported action is rejected.
    if request.action != "get_summary":
        return StandardResponse(code=FAILED, message="invalid action")

    target_graph = request.get_param("graph_id", 0)
    graph_summary = GraphBusiness(db).get_graph_summary(graph_id=target_graph)

    if not graph_summary:
        return StandardResponse(code=FAILED, message="summary not found", records=[])

    logger.info(graph_summary)
    return StandardResponse(code=SUCCESS, message="summary found", records=[graph_summary])
@router.post('/schemas', response_model=StandardResponse)
def schemas_func(request: BasicRequest, db: Session = Depends(get_db), sess: SessionValues = Depends(verify_session_id)) -> StandardResponse:
    """Return the node or edge categories (schemas) of a graph.

    Supported actions:

    - ``get_nodes_schemas``: categories from ``GraphBusiness.get_nodes_categories``
    - ``get_edges_schemas``: categories from ``GraphBusiness.get_edges_categories``

    The graph is selected by the ``graph_id`` request parameter (default 0).

    :param request: request carrying the action name and its parameters
    :param db: database session injected by FastAPI
    :param sess: session values produced by the session-id dependency
    :return: SUCCESS with the schemas as records; FAILED with
        "schemas not found" when a supported action yields nothing; FAILED
        with "invalid action" for any other action
    """
    action = request.action
    if action not in ("get_nodes_schemas", "get_edges_schemas"):
        return StandardResponse(code=FAILED, message="invalid action")

    graph_id = request.get_param("graph_id", 0)
    biz = GraphBusiness(db)
    if action == "get_nodes_schemas":
        schemas = biz.get_nodes_categories(graph_id=graph_id)
    else:
        schemas = biz.get_edges_categories(graph_id=graph_id)

    if schemas:
        return StandardResponse(code=SUCCESS, message="schemas found", records=schemas)

    # A supported action that yields nothing is an empty result, not an
    # invalid action; report it as such so clients can tell the two apart.
    return StandardResponse(code=FAILED, message="schemas not found", records=[])
@router.post('/nodes', response_model=StandardResponse)
def nodes_func(request: BasicRequest, db: Session = Depends(get_db), sess: SessionValues = Depends(verify_session_id)) -> StandardResponse:
    """Search, page through, or expand the nodes of a graph.

    Supported actions:

    - ``search_nodes``: name search within a category; requires the
      ``name``, ``category`` and ``graph_id`` parameters.
    - ``get_nodes``: one page of nodes, selected by ``graph_id``, ``page``
      (default 1) and ``page_size`` (default 1).
    - ``neighbors``: the node identified by ``node_id`` followed by its
      incoming and outgoing neighbors, each tagged with a ``direction``.

    :param request: request carrying the action name and its parameters
    :param db: database session injected by FastAPI
    :param sess: session values produced by the session-id dependency
    :return: SUCCESS with the matching records, or FAILED with a message
        naming the reason (missing parameter, nothing found, invalid action)
    """
    action = request.action

    if action == "search_nodes":
        node_name = request.get_param("name", "")
        category = request.get_param("category", "")
        graph_id = request.get_param("graph_id", 0)
        biz = GraphBusiness(db)
        if node_name == "":
            return StandardResponse(code=FAILED, message="node name is empty", records=[])
        if category == "":
            return StandardResponse(code=FAILED, message="category is empty", records=[])
        if graph_id == 0:
            return StandardResponse(code=FAILED, message="graph id is empty", records=[])

        nodes = biz.search_like_node_by_name(graph_id=graph_id, category=category, name=node_name)
        if nodes:
            return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
        return StandardResponse(code=FAILED, message="search job failed")

    if action == "get_nodes":
        graph_id = request.get_param("graph_id", 0)
        page = request.get_param("page", 1)
        page_size = request.get_param("page_size", 1)
        biz = GraphBusiness(db)
        nodes = biz.get_nodes_by_page(graph_id=graph_id, page=page, page_size=page_size)
        if nodes:
            return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
        # An empty page is a valid request with no data, not an invalid action.
        return StandardResponse(code=FAILED, message="nodes not found", records=[])

    if action == "neighbors":
        graph_id = request.get_param("graph_id", 0)
        node_id = request.get_param("node_id", 0)
        if node_id > 0:
            biz = GraphBusiness(db)
            center = biz.get_node_by_id(graph_id=graph_id, node_id=node_id)
            if center is None:
                return StandardResponse(code=FAILED, message="node not found", records=[])
            nodes_in = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="in")
            nodes_out = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="out")

            # The requested node comes first, then incoming, then outgoing.
            nodes_all = [{"id": center.id, "name": center.name, "category": center.category, "direction": "self"}]
            for direction, group in (("in", nodes_in), ("out", nodes_out)):
                for neighbor in group:
                    nodes_all.append({
                        "id": neighbor.id,
                        "name": neighbor.name,
                        "category": neighbor.category,
                        "direction": direction,
                    })

            return StandardResponse(code=SUCCESS, message="nodes found", records=nodes_all)
        # A missing or non-positive node_id is a parameter problem; name it
        # instead of falling through to the "invalid action" reply.
        return StandardResponse(code=FAILED, message="node id is empty", records=[])

    return StandardResponse(code=FAILED, message="invalid action")
async def get_node_properties(node_id: int, db: Session) -> dict:
    """
    Look up the properties attached to a node.

    :param node_id: node ID, matched against ``kg_props.ref_id``
    :param db: database session
    :return: dict mapping each ``prop_title`` to its ``prop_value``
    """
    query = text("SELECT prop_title, prop_value FROM kg_props WHERE ref_id = :node_id")
    rows = db.execute(query, {'node_id': node_id}).fetchall()
    # When several rows share a prop_title, the last one fetched wins.
    return {
        record._mapping['prop_title']: record._mapping['prop_value']
        for record in rows
    }
@router.get("/graph_data", response_model=StandardResponse)
async def get_graph_data(
    label_name: Optional[str] = Query(None),
    input_str: Optional[str] = Query(None),
    db: Session = Depends(get_db),
    sess: SessionValues = Depends(verify_session_id)
):
    """
    Get the graph data associated with the current user.

    - Take ``user_id`` from the verified session.
    - Find the user's node whose category is ``label_name`` and whose name is
      ``input_str`` (the center node), then its outgoing edges and their
      target nodes, limited to 50 targets per edge name.
    - Return ``categories`` / ``nodes`` / ``links`` in the structure the
      original comment describes as consistent with the Java side.

    :param label_name: category of the center node
    :param input_str: name of the center node
    :param db: database session injected by FastAPI
    :param sess: session values produced by the session-id dependency
    :return: a StandardResponse whose single record wraps the graph data;
        FAILED when the session has no user; code 500 with the exception
        text when anything raises
    """
    try:
        # 1. Take the user id from the session.
        user_id = sess.user_id
        if not user_id:
            return StandardResponse(code=FAILED, message="user not found", records=[])

        # 2. Outgoing relations of the user's matching node, joined to their
        #    targets and capped at 50 rows per edge name via ROW_NUMBER().
        sql = text("""
            WITH RankedRelations AS (
                SELECT
                    e.name as rType,
                    m.id as target_id,
                    m.name as target_name,
                    m.category as target_label,
                    (SELECT COUNT(*) FROM kg_edges WHERE src_id = m.id) as pCount,
                    ROW_NUMBER() OVER(PARTITION BY e.name ORDER BY m.id) as rn
                FROM user_data_relations udr
                JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
                JOIN kg_edges e ON n.id = e.src_id
                JOIN kg_nodes m ON e.dest_id = m.id
                WHERE udr.user_id = :user_id
                AND n.category = :label_name
                AND n.name = :input_str
                AND n.status = '0'
            )
            SELECT rType, target_id, target_name, target_label, pCount
            FROM RankedRelations
            WHERE rn <= 50
            ORDER BY rType
        """)

        # The center node itself.
        center_sql = text("""
            SELECT n.id, n.name, n.category
            FROM user_data_relations udr
            JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
            WHERE udr.user_id = :user_id
            AND n.category = :label_name
            AND n.name = :input_str
            AND n.status = '0'
        """)

        # Both statements take the same bind parameters.
        query_params = {
            'user_id': user_id,
            'label_name': label_name,
            'input_str': input_str
        }

        # 3. Containers for the response structure.
        categories = ["中心词", "关系"]
        nodes = []
        links = []
        center_node = None
        rtype_map = {}

        # 3.1 Center node: only the first matching row is used.
        center_result = db.execute(center_sql, query_params).fetchall()
        if center_result:
            first_row = center_result[0]._mapping
            center_node = {
                "id": 0,
                "name": first_row['name'],
                "label": first_row['category'],
                "symbolSize": 50,
                "symbol": "circle",
                "properties": await get_node_properties(first_row['id'], db)
            }
            nodes.append(center_node)

        # 3.2 Edges and target nodes, grouped by relation type.
        relation_result = db.execute(sql, query_params).fetchall()
        for row in relation_result:
            # NOTE(review): the SQL aliases this column as rType but it is
            # read back as 'rtype' -- presumably relying on the database
            # folding unquoted aliases to lower case; confirm against the
            # deployed database.
            r_type = row._mapping['rtype']
            # NOTE(review): only target['name'] is used below; the id, label
            # and per-target properties (one extra query per row) never reach
            # the response. Either append targets to `nodes` or drop the
            # lookup -- confirm the intended output before changing it.
            target_node = {
                "id": row._mapping['target_id'],
                "name": row._mapping['target_name'],
                "label": row._mapping['target_label'],
                "symbolSize": 28,
                "symbol": "circle",
                "properties": await get_node_properties(row._mapping['target_id'], db)
            }
            rtype_map.setdefault(r_type, []).append(target_node)

            if r_type not in categories:
                categories.append(r_type)

        # 3.3 Assemble relation nodes and links around the center node.
        if center_node:
            for r_type, targets in rtype_map.items():
                # One unnamed diamond node per relation type.
                nodes.append({
                    "id": len(nodes),
                    "name": "",
                    "label": center_node['label'],
                    "symbolSize": 10,
                    "symbol": "diamond",
                    "properties": {}
                })

                # Link from the center node to the relation node.
                links.append({
                    "source": center_node['name'],
                    "target": "",
                    "name": r_type,
                    "category": r_type
                })

                # Links from the relation node to each target node.
                for target in targets:
                    links.append({
                        "source": "",
                        "target": target['name'],
                        "name": "",
                        "category": r_type
                    })

        final_data = {
            "categories": categories,
            "nodes": nodes,
            "links": links
        }
        return StandardResponse(
            records=[{"records": final_data}],
            message="Graph data retrieved"
        )
    except Exception as e:
        # Keep the existing error reply, but record the traceback so the
        # failure is visible in the server logs instead of being swallowed.
        logger.exception("get_graph_data failed")
        return StandardResponse(
            code=500,
            message=str(e)
        )
-
-
# Second name bound to the same APIRouter object as `router`; presumably the
# name the application imports when mounting these routes -- verify at the
# include_router call site.
kb_router = router
|