|
@@ -13,6 +13,11 @@ from agent.libs.graph import GraphBusiness
|
|
|
from agent.libs.auth import verify_session_id, SessionValues
|
|
|
import logging
|
|
|
from typing import Optional, List
|
|
|
+from agent.models.db.graph import DbKgNode
|
|
|
+from agent.models.db.tree_structure import TreeStructure,KgGraphCategory
|
|
|
+import string
|
|
|
+import json
|
|
|
+from agent.tree_utils import get_tree_dto
|
|
|
|
|
|
router = APIRouter(prefix="/kb", tags=["knowledge build interface"])
|
|
|
logger = logging.getLogger(__name__)
|
|
@@ -120,10 +125,11 @@ async def get_node_properties(node_id: int, db: Session) -> dict:
|
|
|
|
|
|
@router.get("/graph_data", response_model=StandardResponse)
|
|
|
async def get_graph_data(
|
|
|
- label_name: Optional[str] = Query(None),
|
|
|
- input_str: Optional[str] = Query(None),
|
|
|
+ label_name: str,
|
|
|
+ user_id: int,
|
|
|
+ graph_id: int,
|
|
|
db: Session = Depends(get_db),
|
|
|
- sess:SessionValues = Depends(verify_session_id)
|
|
|
+ input_str: Optional[str] = None
|
|
|
):
|
|
|
"""
|
|
|
获取用户关联的图谱数据
|
|
@@ -133,9 +139,34 @@ async def get_graph_data(
|
|
|
"""
|
|
|
try:
|
|
|
# 1. 从session获取user_id
|
|
|
- user_id = sess.user_id
|
|
|
- if not user_id:
|
|
|
- return StandardResponse(code=FAILED, message="user not found", records=[])
|
|
|
+ # user_id = sess.user_id
|
|
|
+ # if not user_id:
|
|
|
+ # return StandardResponse(code=FAILED, message="user not found", records=[])
|
|
|
+
|
|
|
+ # 处理input_str为空的情况
|
|
|
+ if not input_str:
|
|
|
+ # 根据user_id、graph_id和label_name从kg_nodes表中获取一个name
|
|
|
+ get_name_sql = text("""
|
|
|
+ SELECT n.name
|
|
|
+ FROM user_data_relations udr
|
|
|
+ JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
|
|
|
+ WHERE udr.user_id = :user_id
|
|
|
+ AND n.category = :label_name
|
|
|
+ AND n.status = '0'
|
|
|
+ AND n.graph_id = :graph_id
|
|
|
+ LIMIT 1
|
|
|
+ """)
|
|
|
+
|
|
|
+ name_result = db.execute(get_name_sql, {
|
|
|
+ 'user_id': user_id,
|
|
|
+ 'label_name': label_name,
|
|
|
+ 'graph_id': graph_id
|
|
|
+ }).fetchone()
|
|
|
+
|
|
|
+ if not name_result:
|
|
|
+ return StandardResponse(code=FAILED, message="No node found for given parameters", records=[])
|
|
|
+
|
|
|
+ input_str = name_result._mapping['name']
|
|
|
|
|
|
# 2. 使用JOIN查询用户关联的图谱数据
|
|
|
sql = text("""
|
|
@@ -155,6 +186,7 @@ async def get_graph_data(
|
|
|
AND n.category = :label_name
|
|
|
AND n.name = :input_str
|
|
|
AND n.status = '0'
|
|
|
+ AND n.graph_id = :graph_id
|
|
|
)
|
|
|
SELECT rType, target_id, target_name, target_label, pCount
|
|
|
FROM RankedRelations
|
|
@@ -163,7 +195,7 @@ async def get_graph_data(
|
|
|
""")
|
|
|
|
|
|
# 3. 组装返回数据结构
|
|
|
- categories = ["中心词", "关系"]
|
|
|
+ categories = [{"name": "中心词"}, {"name": "关系"}]
|
|
|
nodes = []
|
|
|
links = []
|
|
|
|
|
@@ -175,94 +207,155 @@ async def get_graph_data(
|
|
|
WHERE udr.user_id = :user_id
|
|
|
AND n.category = :label_name
|
|
|
AND n.name = :input_str
|
|
|
- AND n.status = '0'
|
|
|
+ AND n.status = '0' AND n.graph_id= :graph_id limit 1
|
|
|
""")
|
|
|
|
|
|
# 执行查询并处理结果
|
|
|
center_node = None
|
|
|
- rtype_map = {}
|
|
|
-
|
|
|
+ c_map = {"中心词": 0, "关系": 1}
|
|
|
+ node_id = 0
|
|
|
+
|
|
|
+ # 构建graph_dto数据结构
|
|
|
+ graph_dto = {
|
|
|
+ "label": "",
|
|
|
+ "name": "",
|
|
|
+ "id": 0,
|
|
|
+ "properties": {},
|
|
|
+ "ENodeRSDTOS": []
|
|
|
+ }
|
|
|
+
|
|
|
# 1. 查询中心节点
|
|
|
center_result = db.execute(center_sql, {
|
|
|
'user_id': user_id,
|
|
|
'label_name': label_name,
|
|
|
- 'input_str': input_str
|
|
|
+ 'input_str': input_str,
|
|
|
+ 'graph_id': graph_id
|
|
|
}).fetchall()
|
|
|
|
|
|
if center_result:
|
|
|
for row in center_result:
|
|
|
- center_node = {
|
|
|
- "id": 0,
|
|
|
- "name": row._mapping['name'],
|
|
|
- "label": row._mapping['category'],
|
|
|
- "symbolSize": 50,
|
|
|
- "symbol": "circle",
|
|
|
- "properties": await get_node_properties(row._mapping['id'], db)
|
|
|
- }
|
|
|
- nodes.append(center_node)
|
|
|
+ graph_dto["label"] = row._mapping['category']
|
|
|
+ graph_dto["name"] = row._mapping['name']
|
|
|
+ graph_dto["id"] = row._mapping['id']
|
|
|
+ graph_dto["properties"] = await get_node_properties(row._mapping['id'], db)
|
|
|
break
|
|
|
-
|
|
|
+
|
|
|
# 2. 查询关联的边和目标节点
|
|
|
relation_result = db.execute(sql, {
|
|
|
'user_id': user_id,
|
|
|
'label_name': label_name,
|
|
|
- 'input_str': input_str
|
|
|
+ 'input_str': input_str,
|
|
|
+ 'graph_id': graph_id
|
|
|
}).fetchall()
|
|
|
|
|
|
if relation_result:
|
|
|
+ rs_id = 2
|
|
|
for row in relation_result:
|
|
|
r_type = row._mapping['rtype']
|
|
|
- target_node = {
|
|
|
- "id": row._mapping['target_id'],
|
|
|
- "name": row._mapping['target_name'],
|
|
|
- "label": row._mapping['target_label'],
|
|
|
- "symbolSize": 28,
|
|
|
- "symbol": "circle",
|
|
|
- "properties": await get_node_properties(row._mapping['target_id'], db)
|
|
|
- }
|
|
|
-
|
|
|
- if r_type not in rtype_map:
|
|
|
- rtype_map[r_type] = []
|
|
|
- rtype_map[r_type].append(target_node)
|
|
|
|
|
|
- if r_type not in categories:
|
|
|
- categories.append(r_type)
|
|
|
-
|
|
|
- # 3. 组装返回结果
|
|
|
- if center_node:
|
|
|
- for r_type, targets in rtype_map.items():
|
|
|
- # 添加关系节点
|
|
|
- relation_node = {
|
|
|
- "id": len(nodes),
|
|
|
- "name": "",
|
|
|
- "label": center_node['label'],
|
|
|
- "symbolSize": 10,
|
|
|
- "symbol": "diamond",
|
|
|
- "properties": {}
|
|
|
+ # 添加到graph_dto
|
|
|
+ e_node_rs = {
|
|
|
+ "RType": r_type,
|
|
|
+ "ENodeDTOS": [{
|
|
|
+ "Label": row._mapping['target_label'],
|
|
|
+ "Name": row._mapping['target_name'],
|
|
|
+ "Id": row._mapping['target_id'],
|
|
|
+ "PCount": row._mapping['pcount'],
|
|
|
+ "properties": await get_node_properties(row._mapping['target_id'], db)
|
|
|
+ }]
|
|
|
}
|
|
|
- nodes.append(relation_node)
|
|
|
|
|
|
- # 添加中心节点到关系节点的链接
|
|
|
- links.append({
|
|
|
- "source": center_node['name'],
|
|
|
- "target": "",
|
|
|
- "name": r_type,
|
|
|
- "category": r_type
|
|
|
- })
|
|
|
+ # 检查是否已有该关系类型
|
|
|
+ existing_rs = next((rs for rs in graph_dto["ENodeRSDTOS"] if rs["RType"] == r_type), None)
|
|
|
+ if existing_rs:
|
|
|
+ existing_rs["ENodeDTOS"].extend(e_node_rs["ENodeDTOS"])
|
|
|
+ else:
|
|
|
+ graph_dto["ENodeRSDTOS"].append(e_node_rs)
|
|
|
|
|
|
- # 添加关系节点到目标节点的链接
|
|
|
- for target in targets:
|
|
|
- links.append({
|
|
|
- "source": "",
|
|
|
- "target": target['name'],
|
|
|
- "name": "",
|
|
|
- "category": r_type
|
|
|
- })
|
|
|
+ if r_type not in c_map:
|
|
|
+ c_map[r_type] = rs_id
|
|
|
+ categories.append({"name": r_type})
|
|
|
+ rs_id += 1
|
|
|
+
|
|
|
+ print("graph_dto:", graph_dto) # 打印graph_dto
|
|
|
|
|
|
- final_data={
|
|
|
- "categories": categories,
|
|
|
- "nodes": nodes,
|
|
|
- "links": links
|
|
|
+ # 构建中心节点
|
|
|
+ center_node = {
|
|
|
+ "label": graph_dto["name"],
|
|
|
+ 'type': graph_dto["label"],
|
|
|
+ "category": 0,
|
|
|
+ "name": "0",
|
|
|
+ # "id": graph_dto["id"],
|
|
|
+ "symbol": "circle",
|
|
|
+ "symbolSize": 50,
|
|
|
+ "properties": graph_dto["properties"],
|
|
|
+ "nodeId": graph_dto["id"],
|
|
|
+ "itemStyle": {"display": True}
|
|
|
+ }
|
|
|
+ nodes.append(center_node)
|
|
|
+
|
|
|
+ # 处理关系类型
|
|
|
+ rs_id = 2
|
|
|
+
|
|
|
+ for rs in graph_dto["ENodeRSDTOS"]:
|
|
|
+ r_type = rs["RType"]
|
|
|
+
|
|
|
+ if r_type not in c_map:
|
|
|
+ c_map[r_type] = rs_id
|
|
|
+ categories.append({"name": r_type})
|
|
|
+ rs_id += 1
|
|
|
+
|
|
|
+ # 关系节点
|
|
|
+ relation_node = {
|
|
|
+ "label": "",
|
|
|
+ 'type': graph_dto["label"],
|
|
|
+ "category": 1,
|
|
|
+ "name": str(len(nodes)),
|
|
|
+ # "id": len(nodes),
|
|
|
+ "symbol": "diamond",
|
|
|
+ "symbolSize": 10,
|
|
|
+ "properties": graph_dto["properties"],
|
|
|
+ "nodeId": len(nodes),
|
|
|
+ "itemStyle": {"display": True}
|
|
|
+ }
|
|
|
+ nodes.append(relation_node)
|
|
|
+
|
|
|
+ # 添加链接
|
|
|
+ links.append({
|
|
|
+ "source": "0",
|
|
|
+ "target": str(nodes.index(relation_node)),
|
|
|
+ "value": r_type,
|
|
|
+ "relationType": r_type
|
|
|
+ })
|
|
|
+
|
|
|
+ # 处理子节点
|
|
|
+ for e_node in rs["ENodeDTOS"]:
|
|
|
+ item_style = {"display": e_node["PCount"] > 0}
|
|
|
+ child_node = {
|
|
|
+ "label": e_node["Name"],
|
|
|
+ "type": e_node["Label"],
|
|
|
+ "category": c_map[r_type],
|
|
|
+ "name": str(len(nodes)),
|
|
|
+ # "id": e_node["Id"],
|
|
|
+ "symbol": "circle",
|
|
|
+ "symbolSize": 28,
|
|
|
+ "properties": e_node["properties"],
|
|
|
+ "nodeId": e_node["Id"],
|
|
|
+ "itemStyle": item_style
|
|
|
+ }
|
|
|
+ nodes.append(child_node)
|
|
|
+
|
|
|
+ links.append({
|
|
|
+ "source": str(nodes.index(relation_node)),
|
|
|
+ "target": str(nodes.index(child_node)),
|
|
|
+ "value": "",
|
|
|
+ "relationType": r_type
|
|
|
+ })
|
|
|
+
|
|
|
+ final_data = {
|
|
|
+ "categories": categories,
|
|
|
+ "node": nodes,
|
|
|
+ "links": links
|
|
|
}
|
|
|
return StandardResponse(
|
|
|
records=[{"records": final_data}],
|
|
@@ -274,6 +367,199 @@ async def get_graph_data(
|
|
|
message=str(e)
|
|
|
)
|
|
|
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+@router.get("/user_sub_graphs", response_model=StandardResponse)
|
|
|
+async def get_user_sub_graphs(
|
|
|
+ user_id: int,
|
|
|
+ pageNo: int = 1,
|
|
|
+ pageSize: int = 10,
|
|
|
+ db: Session = Depends(get_db)
|
|
|
+):
|
|
|
+ """
|
|
|
+ 获取用户关联的子图列表
|
|
|
+ - 根据user_id和data_category='sub_graph'查询user_data_relations表
|
|
|
+ - 关联jobs表获取job_name
|
|
|
+ - 返回data_id和job_name列表
|
|
|
+ - 支持分页查询,参数pageNo(默认1)和pageSize(默认10)
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 查询用户关联的子图
|
|
|
+ offset = (pageNo - 1) * pageSize
|
|
|
+ sql = text("""
|
|
|
+ SELECT udr.data_id, j.job_name
|
|
|
+ FROM user_data_relations udr
|
|
|
+ LEFT JOIN jobs j ON udr.data_id = j.id
|
|
|
+ WHERE udr.user_id = :user_id
|
|
|
+ AND udr.data_category = 'sub_graph' order by udr.data_id desc
|
|
|
+ LIMIT :pageSize OFFSET :offset
|
|
|
+ """)
|
|
|
+
|
|
|
+ result = db.execute(sql, {'user_id': user_id, 'pageSize': pageSize, 'offset': offset}).fetchall()
|
|
|
+
|
|
|
+ records = []
|
|
|
+ for row in result:
|
|
|
+ records.append({
|
|
|
+ "graph_id": row._mapping['data_id'],
|
|
|
+ "graph_name": row._mapping['job_name']
|
|
|
+ })
|
|
|
+
|
|
|
+ return StandardResponse(
|
|
|
+ records=records,
|
|
|
+ message="User sub graphs retrieved"
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ return StandardResponse(
|
|
|
+ code=500,
|
|
|
+ message=str(e)
|
|
|
+ )
|
|
|
+
|
|
|
+def build_disease_tree(disease_nodes: list, root_name: str = "疾病") -> dict:
|
|
|
+ """
|
|
|
+ 构建疾病树状结构的公共方法
|
|
|
+ :param disease_nodes: 疾病节点列表,每个节点需包含name属性
|
|
|
+ :param root_name: 根节点名称,默认为"疾病"
|
|
|
+ :return: 树状结构字典
|
|
|
+ """
|
|
|
+ if not disease_nodes:
|
|
|
+ return {"name": root_name, "sNode": []}
|
|
|
+
|
|
|
+ # 按拼音首字母分类
|
|
|
+ letter_groups = {letter: [] for letter in string.ascii_uppercase}
|
|
|
+ letter_groups['其他'] = []
|
|
|
+
|
|
|
+ for node in disease_nodes:
|
|
|
+ name = node.name if hasattr(node, 'name') else str(node)
|
|
|
+ first_letter = get_first_letter(name)
|
|
|
+ letter_groups[first_letter].append(name)
|
|
|
+
|
|
|
+ # 构建JSON结构
|
|
|
+ tree_structure = {
|
|
|
+ "name": root_name,
|
|
|
+ "sNode": []
|
|
|
+ }
|
|
|
+
|
|
|
+ # 先添加A-Z的分类
|
|
|
+ for letter in string.ascii_uppercase:
|
|
|
+ if letter_groups[letter]:
|
|
|
+ letter_node = {
|
|
|
+ "name": letter,
|
|
|
+ "sNode": [{"name": disease, "sNode": []} for disease in sorted(letter_groups[letter])]
|
|
|
+ }
|
|
|
+ tree_structure["sNode"].append(letter_node)
|
|
|
+
|
|
|
+ # 最后添加"其他"分类(如果有的话)
|
|
|
+ if letter_groups['其他']:
|
|
|
+ other_node = {
|
|
|
+ "name": "其他",
|
|
|
+ "sNode": [{"name": disease, "sNode": []} for disease in sorted(letter_groups['其他'])]
|
|
|
+ }
|
|
|
+ tree_structure["sNode"].append(other_node)
|
|
|
+ # content=json.dumps(tree_structure, ensure_ascii=False)
|
|
|
+ # print(content)
|
|
|
+ # tree_dto=get_tree_dto(content)
|
|
|
+ # print(tree_dto)
|
|
|
+ return tree_structure
|
|
|
+
|
|
|
+def get_first_letter(word):
|
|
|
+ """获取中文词语的拼音首字母"""
|
|
|
+ if not word:
|
|
|
+ return '其他'
|
|
|
+
|
|
|
+ # 获取第一个汉字的拼音首字母
|
|
|
+ first_char = word[0]
|
|
|
+ try:
|
|
|
+ import pypinyin
|
|
|
+ first_letter = pypinyin.pinyin(first_char, style=pypinyin.FIRST_LETTER)[0][0].upper()
|
|
|
+ return first_letter if first_letter in string.ascii_uppercase else '其他'
|
|
|
+ except Exception as e:
|
|
|
+ print(str(e))
|
|
|
+ return '其他'
|
|
|
+
|
|
|
+@router.get('/disease_tree')
|
|
|
+async def get_disease_tree(graph_id: int, db: Session = Depends(get_db)):
|
|
|
+ """
|
|
|
+ 根据graph_id查询kg_nodes表中category是疾病的数据并构建树状结构
|
|
|
+ 严格按照字母A-Z顺序进行归类,中文首字母归类到对应拼音首字母
|
|
|
+ """
|
|
|
+ # 查询疾病节点
|
|
|
+ disease_nodes = db.query(DbKgNode).filter(
|
|
|
+ DbKgNode.graph_id == graph_id,
|
|
|
+ DbKgNode.category == '疾病'
|
|
|
+ ).all()
|
|
|
+
|
|
|
+ tree_structure = build_disease_tree(disease_nodes)
|
|
|
+ return StandardResponse(records=[{"records": tree_structure}])
|
|
|
+
|
|
|
+
|
|
|
+@router.get('/graph_categories')
|
|
|
+async def get_graph_categories(
|
|
|
+ user_id: int,
|
|
|
+ graph_id: int,
|
|
|
+ db: Session = Depends(get_db)
|
|
|
+):
|
|
|
+ """
|
|
|
+ 根据user_id和graph_id查询kg_graph_category表中的category列表
|
|
|
+ 返回category的字符串列表(按照id正序排列)
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 查询category列表(按照id正序)
|
|
|
+ categories = db.query(KgGraphCategory.category).filter(
|
|
|
+ KgGraphCategory.user_id == user_id,
|
|
|
+ KgGraphCategory.graph_id == graph_id
|
|
|
+ ).order_by(KgGraphCategory.id).all()
|
|
|
+
|
|
|
+ if not categories:
|
|
|
+ return StandardResponse(code=FAILED, message="No categories found")
|
|
|
+
|
|
|
+ # 转换为字符串列表
|
|
|
+ category_list = [category[0] for category in categories]
|
|
|
|
|
|
+ return StandardResponse(
|
|
|
+ records=[{"records": category_list}],
|
|
|
+ message="Graph categories retrieved"
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ return StandardResponse(
|
|
|
+ code=500,
|
|
|
+ message=str(e)
|
|
|
+ )
|
|
|
+
|
|
|
+@router.get('/tree_structure')
|
|
|
+async def get_tree_structure(
|
|
|
+ user_id: int,
|
|
|
+ graph_id: int,
|
|
|
+ db: Session = Depends(get_db)
|
|
|
+):
|
|
|
+ """
|
|
|
+ 根据user_id和graph_id获取树状结构数据
|
|
|
+ 1. 查询kg_tree_structures表获取content
|
|
|
+ 2. 调用get_tree_dto方法转换数据格式
|
|
|
+ 3. 返回转换后的数据
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 查询树状结构数据
|
|
|
+ tree_structure = db.query(TreeStructure).filter(
|
|
|
+ TreeStructure.user_id == user_id,
|
|
|
+ TreeStructure.graph_id == graph_id
|
|
|
+ ).first()
|
|
|
+
|
|
|
+ if not tree_structure:
|
|
|
+ return StandardResponse(code=FAILED, message="Tree structure not found")
|
|
|
+
|
|
|
+ # 转换数据格式
|
|
|
+ tree_dto = get_tree_dto(tree_structure.content)
|
|
|
|
|
|
+ return StandardResponse(
|
|
|
+ records=[{"records": tree_dto}],
|
|
|
+ message="Tree structure retrieved"
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ return StandardResponse(
|
|
|
+ code=500,
|
|
|
+ message=str(e)
|
|
|
+ )
|
|
|
+
|
|
|
kb_router = router
|