Browse Source

1、session时间问题修复
2、自构建图谱查询接口

yuchengwei 1 month ago
parent
commit
a5587d2075

+ 18 - 18
agent/libs/auth.py

@@ -35,25 +35,25 @@ def verify_session_id(request: Request)-> SessionValues:
         # 提取 session_id
         session_user_id = auth_header.split(" ")[1]
         session_id = auth_header.split(" ")[2]
-        return SessionValues(session_id, '', session_user_id, '')
+        # return SessionValues(session_id, '', session_user_id, '')
         # 在这里添加你的 session_id 校验逻辑
         # 例如,检查 session_id 是否在数据库中存在
-        # if not session_business.validate_session(session_user_id, session_id):
-        #     print("Invalid session_id", session_user_id, session_id)
-        #     raise HTTPException(
-        #         status_code=status.HTTP_401_UNAUTHORIZED,
-        #         detail="Invalid session_id",
-        #         headers={"WWW-Authenticate": "Beaver"}
-        #     )
-        #
-        # user = user_business.get_user_by_username(session_user_id)
-        # if user is None:
-        #     print("Invalid user_id", session_user_id)
-        #     raise HTTPException(
-        #         status_code=status.HTTP_401_UNAUTHORIZED,
-        #         detail="Invalid username",
-        #         headers={"WWW-Authenticate": "Beaver"}
-        #     )
-        # return SessionValues(session_id, user.id, user.username, user.full_name)
+        if not session_business.validate_session(session_user_id, session_id):
+            print("Invalid session_id", session_user_id, session_id)
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Invalid session_id",
+                headers={"WWW-Authenticate": "Beaver"}
+            )
+
+        user = user_business.get_user_by_username(session_user_id)
+        if user is None:
+            print("Invalid user_id", session_user_id)
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Invalid username",
+                headers={"WWW-Authenticate": "Beaver"}
+            )
+        return SessionValues(session_id, user.id, user.username, user.full_name)
     # 如果校验通过,返回 session_id 或其他需要的信息
     return None

+ 1 - 1
agent/libs/user.py

@@ -66,7 +66,7 @@ class SessionBusiness:
         self.db = db
     def create_session(self, user:User):
         session_id = str(uuid.uuid4())
-        session = Session(session_id=session_id, user_id=user.id, username=user.username, full_name=user.full_name)
+        session = Session(session_id=session_id, user_id=user.id, username=user.username, full_name=user.full_name, created=datetime.now(), updated=datetime.now())
         self.db.add(session)
         self.db.commit()
         self.db.refresh(session)

+ 109 - 0
agent/libs/user_data_relation.py

@@ -0,0 +1,109 @@
+from datetime import datetime
+from sqlalchemy.orm import Session
+from ..models.db.graph import DbUserDataRelation
+
+class UserDataRelationBusiness:
+    def __init__(self, db: Session):
+        self.db = db
+    
+    def create_relation(self, user_id: int, data_category: str, data_id: int, role_id: int):
+        """
+        创建用户数据关联
+        :param user_id: 用户ID
+        :param data_category: 数据类别(表名)
+        :param data_id: 数据ID
+        :param role_id: 角色ID
+        :return: DbUserDataRelation对象
+        """
+        relation = DbUserDataRelation(
+            user_id=user_id,
+            data_category=data_category,
+            data_id=data_id,
+            role_id=role_id,
+            created=datetime.now(),
+            updated=datetime.now()
+        )
+        self.db.add(relation)
+        self.db.commit()
+        self.db.refresh(relation)
+        return relation
+    
+    def get_relation(self, relation_id: int):
+        """
+        获取关联关系
+        :param relation_id: 关联ID
+        :return: DbUserDataRelation对象
+        """
+        return self.db.query(DbUserDataRelation).filter(DbUserDataRelation.id == relation_id).first()
+    
+    def get_user_relations(self, user_id: int):
+        """
+        获取用户的所有关联关系
+        :param user_id: 用户ID
+        :return: DbUserDataRelation对象列表
+        """
+        return self.db.query(DbUserDataRelation).filter(DbUserDataRelation.user_id == user_id).all()
+    
+    def get_data_relations(self, data_category: str, data_id: int):
+        """
+        获取特定数据的所有关联关系
+        :param data_category: 数据类别(表名)
+        :param data_id: 数据ID
+        :return: DbUserDataRelation对象列表
+        """
+        return self.db.query(DbUserDataRelation).filter(
+            DbUserDataRelation.data_category == data_category,
+            DbUserDataRelation.data_id == data_id
+        ).all()
+    
+    def update_relation(self, relation_id: int, role_id: int):
+        """
+        更新关联关系的角色ID
+        :param relation_id: 关联ID
+        :param role_id: 新的角色ID
+        :return: 更新后的DbUserDataRelation对象
+        """
+        relation = self.get_relation(relation_id)
+        if relation:
+            relation.role_id = role_id
+            relation.updated = datetime.now()
+            self.db.commit()
+            self.db.refresh(relation)
+        return relation
+    
+    def delete_relation(self, relation_id: int):
+        """
+        删除关联关系
+        :param relation_id: 关联ID
+        :return: 是否删除成功
+        """
+        relation = self.get_relation(relation_id)
+        if relation:
+            self.db.delete(relation)
+            self.db.commit()
+            return True
+        return False
+    
+    def delete_user_relations(self, user_id: int):
+        """
+        删除用户的所有关联关系
+        :param user_id: 用户ID
+        :return: 删除的记录数
+        """
+        count = self.db.query(DbUserDataRelation).filter(DbUserDataRelation.user_id == user_id).delete()
+        self.db.commit()
+        return count
+    
+    def delete_data_relations(self, data_category: str, data_id: int):
+        """
+        删除特定数据的所有关联关系
+        :param data_category: 数据类别(表名)
+        :param data_id: 数据ID
+        :return: 删除的记录数
+        """
+        count = self.db.query(DbUserDataRelation).filter(
+            DbUserDataRelation.data_category == data_category,
+            DbUserDataRelation.data_id == data_id
+        ).delete()
+        self.db.commit()
+        return count

+ 173 - 0
agent/router/kb_router.py

@@ -6,11 +6,13 @@ from config.site import SiteConfig
 from fastapi import APIRouter, Depends, Query
 from db.database import get_db
 from sqlalchemy.orm import Session
+from sqlalchemy import text
 from agent.models.web.response import StandardResponse,FAILED,SUCCESS
 from agent.models.web.request import BasicRequest
 from agent.libs.graph import GraphBusiness
 from agent.libs.auth import verify_session_id, SessionValues
 import logging
+from typing import Optional, List
 
 router = APIRouter(prefix="/kb", tags=["knowledge build interface"])
 logger = logging.getLogger(__name__)
@@ -100,7 +102,178 @@ def nodes_func(request:BasicRequest, db: Session = Depends(get_db), sess:Session
             
             return StandardResponse(code=SUCCESS, message="nodes found", records=nodes_all)
     return StandardResponse(code=FAILED, message="invalid action")
+
+
+async def get_node_properties(node_id: int, db: Session) -> dict:
+    """
+    查询节点的属性
+    :param node_id: 节点ID
+    :param db: 数据库会话
+    :return: 属性字典
+    """
+    prop_sql = text("SELECT prop_title, prop_value FROM kg_props WHERE ref_id = :node_id")
+    result = db.execute(prop_sql, {'node_id': node_id}).fetchall()
+    properties = {}
+    for row in result:
+        properties[row._mapping['prop_title']] = row._mapping['prop_value']
+    return properties
+
+@router.get("/graph_data", response_model=StandardResponse)
+async def get_graph_data(
+    label_name: Optional[str] = Query(None),
+    input_str: Optional[str] = Query(None),
+    db: Session = Depends(get_db),
+    sess:SessionValues = Depends(verify_session_id)
+):
+    """
+    获取用户关联的图谱数据
+    - 从session_id获取user_id
+    - 查询DbUserDataRelation获取用户关联的数据
+    - 返回与Java端一致的数据结构
+    """
+    try:
+        # 1. 从session获取user_id
+        user_id = sess.user_id
+        if not user_id:
+            return StandardResponse(code=FAILED, message="user not found", records=[])
+        
+        # 2. 使用JOIN查询用户关联的图谱数据
+        sql = text("""
+        WITH RankedRelations AS (
+            SELECT 
+                e.name as rType, 
+                m.id as target_id, 
+                m.name as target_name,
+                m.category as target_label,
+                (SELECT COUNT(*) FROM kg_edges WHERE src_id = m.id) as pCount,
+                ROW_NUMBER() OVER(PARTITION BY e.name ORDER BY m.id) as rn
+            FROM user_data_relations udr
+            JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
+            JOIN kg_edges e ON n.id = e.src_id
+            JOIN kg_nodes m ON e.dest_id = m.id
+            WHERE udr.user_id = :user_id
+            AND n.category = :label_name
+            AND n.name = :input_str
+            AND n.status = '0'
+        )
+        SELECT rType, target_id, target_name, target_label, pCount
+        FROM RankedRelations
+        WHERE rn <= 50
+        ORDER BY rType
+        """)
         
+        # 3. 组装返回数据结构
+        categories = ["中心词", "关系"]
+        nodes = []
+        links = []
+        
+        # 查询中心节点
+        center_sql = text("""
+        SELECT n.id, n.name, n.category
+        FROM user_data_relations udr
+        JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
+        WHERE udr.user_id = :user_id
+        AND n.category = :label_name
+        AND n.name = :input_str
+        AND n.status = '0'
+        """)
+        
+        # 执行查询并处理结果
+        center_node = None
+        rtype_map = {}
+         
+        # 1. 查询中心节点
+        center_result = db.execute(center_sql, {
+             'user_id': user_id,
+             'label_name': label_name,
+             'input_str': input_str
+        }).fetchall()
+         
+        if center_result:
+             for row in center_result:
+                 center_node = {
+                     "id": 0,
+                     "name": row._mapping['name'],
+                     "label": row._mapping['category'],
+                     "symbolSize": 50,
+                     "symbol": "circle",
+                     "properties": await get_node_properties(row._mapping['id'], db)
+                 }
+                 nodes.append(center_node)
+                 break
+         
+        # 2. 查询关联的边和目标节点
+        relation_result = db.execute(sql, {
+             'user_id': user_id,
+             'label_name': label_name,
+             'input_str': input_str
+        }).fetchall()
+         
+        if relation_result:
+             for row in relation_result:
+                 r_type = row._mapping['rtype']
+                 target_node = {
+                     "id": row._mapping['target_id'],
+                     "name": row._mapping['target_name'],
+                     "label": row._mapping['target_label'],
+                     "symbolSize": 28,
+                     "symbol": "circle",
+                     "properties": await get_node_properties(row._mapping['target_id'], db)
+                 }
+                 
+                 if r_type not in rtype_map:
+                     rtype_map[r_type] = []
+                 rtype_map[r_type].append(target_node)
+                 
+                 if r_type not in categories:
+                     categories.append(r_type)
+         
+         # 3. 组装返回结果
+        if center_node:
+             for r_type, targets in rtype_map.items():
+                 # 添加关系节点
+                 relation_node = {
+                     "id": len(nodes),
+                     "name": "",
+                     "label": center_node['label'],
+                     "symbolSize": 10,
+                     "symbol": "diamond",
+                     "properties": {}
+                 }
+                 nodes.append(relation_node)
+                 
+                 # 添加中心节点到关系节点的链接
+                 links.append({
+                     "source": center_node['name'],
+                     "target": "",
+                     "name": r_type,
+                     "category": r_type
+                 })
+                 
+                 # 添加关系节点到目标节点的链接
+                 for target in targets:
+                     links.append({
+                         "source": "",
+                         "target": target['name'],
+                         "name": "",
+                         "category": r_type
+                     })
+         
+        final_data={
+                "categories": categories,
+                "nodes": nodes,
+                "links": links
+        }
+        return StandardResponse(
+            records=[{"records": final_data}],
+            message="Graph data retrieved"
+        )
+    except Exception as e:
+        return StandardResponse(
+            code=500,
+            message=str(e)
+        )
+
         
         
 kb_router = router

+ 31 - 2
executor/job_script/standard_kb_build.py

@@ -51,7 +51,30 @@ def parse_json(data):
                 entity.append("")
                 
 def import_entities(graph_id, entities_list, relations_list):
-
+    from agent.libs.user_data_relation import UserDataRelationBusiness
+    from agent.models.db.user import User, Role
+    from agent.libs.user import UserBusiness
+    
+    # 获取job信息
+    job = graphBiz.db.query(graphBiz.DbJob).filter(graphBiz.DbJob.id == graph_id).first()
+    if not job:
+        logger.error(f"Job not found with id: {graph_id}")
+        return entities
+        
+    # 从job_creator中提取user_id
+    user_id = int(job.job_creator.split('/')[1])
+    
+    # 获取用户角色
+    user_biz = UserBusiness(graphBiz.db)
+    user = user_biz.get_user(user_id)
+    if not user or not user.roles:
+        logger.error(f"User {user_id} has no roles assigned")
+        return entities
+    role_id = user.roles[0].id
+    
+    # 创建用户数据关系业务对象
+    relation_biz = UserDataRelationBusiness(graphBiz.db)
+    
     for text, ent in entities_list.items():
         id = ent['id']
         name = ent['name']
@@ -63,6 +86,9 @@ def import_entities(graph_id, entities_list, relations_list):
         node = graphBiz.create_node(graph_id=graph_id, name=name, category=type[0], props={'types':",".join(type),'full_name':full_name})
         if node:
             ent["db_id"] = node.id
+            # 创建节点数据关联
+            relation_biz.create_relation(user_id, 'DbKgNode', node.id, role_id)
+            
     for text, relations in relations_list.items():
         source_name = relations['source_name']
         source_type = relations['source_type']
@@ -71,7 +97,7 @@ def import_entities(graph_id, entities_list, relations_list):
         relation_type = relations['type']
         source_db_id = entities_list[source_name]['db_id']
         target_db_id = entities_list[target_name]['db_id']
-        graphBiz.create_edge(graph_id=graph_id, 
+        edge = graphBiz.create_edge(graph_id=graph_id, 
                              src_id=source_db_id, 
                              dest_id=target_db_id, 
                              name=relation_type, 
@@ -81,6 +107,9 @@ def import_entities(graph_id, entities_list, relations_list):
                                  "dest_type":target_type,
                              })
         logger.info(f"create edge: {source_db_id}->{target_db_id}")
+        # 创建边数据关联
+        if edge:
+            relation_biz.create_relation(user_id, 'DbKgEdge', edge.id, role_id)
         
     return entities
 if __name__ == "__main__":

+ 1 - 1
executor/main.py

@@ -46,7 +46,7 @@ SCRIPT_CONFIG = {
         'command': "python",  # 脚本路径
        'script':'standard_txt_chunk.py',
         'args': [],  # 脚本参数
-       'success': { 'queue_category': 'SYSTEM', 'queue_name':'CHUNKS'},
+       'success': { 'queue_category': 'SYSTEM', 'queue_name':'KB_EXTRACT'},
         'failed': { 'queue_category': 'SYSTEM', 'queue_name': 'CHUNKS'},
         'error': { 'queue_category': 'SYSTEM', 'queue_name': 'CHUNKS'}
         },

+ 17 - 1
graph/db/models.py

@@ -177,6 +177,22 @@ class DbKgDataset(Base):
     updated = Column(DateTime, nullable=False)
     status = Column(Integer, default=0)
 
+class DbUserDataRelation(Base):
+    __tablename__ = "user_data_relations"
+    
+    id = Column(Integer, primary_key=True, index=True)
+    user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
+    data_category = Column(String(255), nullable=False)  # 表名
+    data_id = Column(Integer, nullable=False)  # 表中数据ID
+    role_id = Column(Integer, nullable=False)  # 用户角色ID
+    created = Column(DateTime, nullable=False)
+    updated = Column(DateTime, nullable=False)
+    
+    user = relationship("DbUsers", back_populates="data_relations")
+
+# 在DbUsers中添加反向关系
+DbUsers.data_relations = relationship("DbUserDataRelation", back_populates="user")
+
 __all__=['DbKgEdge','DbKgNode','DbKgProp','DbDictICD','DbDictDRG',
          'DbDictDrug','DbKgSchemas','DbKgSubGraph','DbKgModels',
-         'DbKgGraphs', 'DbKgDataset']
+         'DbKgGraphs', 'DbKgDataset', 'DbUserDataRelation']

+ 40 - 0
openapi.yaml

@@ -482,6 +482,46 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+  /kb/graph_data:
+    get:
+      tags:
+      - knowledge build interface
+      summary: Get Graph Data
+      description: "\u83B7\u53D6\u7528\u6237\u5173\u8054\u7684\u56FE\u8C31\u6570\u636E\
+        \n- \u4ECEsession_id\u83B7\u53D6user_id\n- \u67E5\u8BE2DbUserDataRelation\u83B7\
+        \u53D6\u7528\u6237\u5173\u8054\u7684\u6570\u636E\n- \u8FD4\u56DE\u4E0EJava\u7AEF\
+        \u4E00\u81F4\u7684\u6570\u636E\u7ED3\u6784"
+      operationId: get_graph_data_kb_graph_data_get
+      parameters:
+      - name: label_name
+        in: query
+        required: false
+        schema:
+          anyOf:
+          - type: string
+          - type: 'null'
+          title: Label Name
+      - name: input_str
+        in: query
+        required: false
+        schema:
+          anyOf:
+          - type: string
+          - type: 'null'
+          title: Input Str
+      responses:
+        '200':
+          description: Successful Response
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/StandardResponse'
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
   /knowledge-base/:
     post:
       tags: