浏览代码

1、session时间问题修复
2、自构建图谱查询接口

yuchengwei 1 月之前
父节点
当前提交
a5587d2075
共有 8 个文件被更改,包括 390 次插入23 次删除
  1. 18 18
      agent/libs/auth.py
  2. 1 1
      agent/libs/user.py
  3. 109 0
      agent/libs/user_data_relation.py
  4. 173 0
      agent/router/kb_router.py
  5. 31 2
      executor/job_script/standard_kb_build.py
  6. 1 1
      executor/main.py
  7. 17 1
      graph/db/models.py
  8. 40 0
      openapi.yaml

+ 18 - 18
agent/libs/auth.py

@@ -35,25 +35,25 @@ def verify_session_id(request: Request)-> SessionValues:
         # 提取 session_id
         # 提取 session_id
         session_user_id = auth_header.split(" ")[1]
         session_user_id = auth_header.split(" ")[1]
         session_id = auth_header.split(" ")[2]
         session_id = auth_header.split(" ")[2]
-        return SessionValues(session_id, '', session_user_id, '')
+        # return SessionValues(session_id, '', session_user_id, '')
         # 在这里添加你的 session_id 校验逻辑
         # 在这里添加你的 session_id 校验逻辑
         # 例如,检查 session_id 是否在数据库中存在
         # 例如,检查 session_id 是否在数据库中存在
-        # if not session_business.validate_session(session_user_id, session_id):
-        #     print("Invalid session_id", session_user_id, session_id)
-        #     raise HTTPException(
-        #         status_code=status.HTTP_401_UNAUTHORIZED,
-        #         detail="Invalid session_id",
-        #         headers={"WWW-Authenticate": "Beaver"}
-        #     )
-        #
-        # user = user_business.get_user_by_username(session_user_id)
-        # if user is None:
-        #     print("Invalid user_id", session_user_id)
-        #     raise HTTPException(
-        #         status_code=status.HTTP_401_UNAUTHORIZED,
-        #         detail="Invalid username",
-        #         headers={"WWW-Authenticate": "Beaver"}
-        #     )
-        # return SessionValues(session_id, user.id, user.username, user.full_name)
+        if not session_business.validate_session(session_user_id, session_id):
+            print("Invalid session_id", session_user_id, session_id)
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Invalid session_id",
+                headers={"WWW-Authenticate": "Beaver"}
+            )
+
+        user = user_business.get_user_by_username(session_user_id)
+        if user is None:
+            print("Invalid user_id", session_user_id)
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Invalid username",
+                headers={"WWW-Authenticate": "Beaver"}
+            )
+        return SessionValues(session_id, user.id, user.username, user.full_name)
     # 如果校验通过,返回 session_id 或其他需要的信息
     # 如果校验通过,返回 session_id 或其他需要的信息
     return None
     return None

+ 1 - 1
agent/libs/user.py

@@ -66,7 +66,7 @@ class SessionBusiness:
         self.db = db
         self.db = db
     def create_session(self, user:User):
     def create_session(self, user:User):
         session_id = str(uuid.uuid4())
         session_id = str(uuid.uuid4())
-        session = Session(session_id=session_id, user_id=user.id, username=user.username, full_name=user.full_name)
+        session = Session(session_id=session_id, user_id=user.id, username=user.username, full_name=user.full_name, created=datetime.now(), updated=datetime.now())
         self.db.add(session)
         self.db.add(session)
         self.db.commit()
         self.db.commit()
         self.db.refresh(session)
         self.db.refresh(session)

+ 109 - 0
agent/libs/user_data_relation.py

@@ -0,0 +1,109 @@
+from datetime import datetime
+from sqlalchemy.orm import Session
+from ..models.db.graph import DbUserDataRelation
+
+class UserDataRelationBusiness:
+    def __init__(self, db: Session):
+        self.db = db
+    
+    def create_relation(self, user_id: int, data_category: str, data_id: int, role_id: int):
+        """
+        创建用户数据关联
+        :param user_id: 用户ID
+        :param data_category: 数据类别(表名)
+        :param data_id: 数据ID
+        :param role_id: 角色ID
+        :return: DbUserDataRelation对象
+        """
+        relation = DbUserDataRelation(
+            user_id=user_id,
+            data_category=data_category,
+            data_id=data_id,
+            role_id=role_id,
+            created=datetime.now(),
+            updated=datetime.now()
+        )
+        self.db.add(relation)
+        self.db.commit()
+        self.db.refresh(relation)
+        return relation
+    
+    def get_relation(self, relation_id: int):
+        """
+        获取关联关系
+        :param relation_id: 关联ID
+        :return: DbUserDataRelation对象
+        """
+        return self.db.query(DbUserDataRelation).filter(DbUserDataRelation.id == relation_id).first()
+    
+    def get_user_relations(self, user_id: int):
+        """
+        获取用户的所有关联关系
+        :param user_id: 用户ID
+        :return: DbUserDataRelation对象列表
+        """
+        return self.db.query(DbUserDataRelation).filter(DbUserDataRelation.user_id == user_id).all()
+    
+    def get_data_relations(self, data_category: str, data_id: int):
+        """
+        获取特定数据的所有关联关系
+        :param data_category: 数据类别(表名)
+        :param data_id: 数据ID
+        :return: DbUserDataRelation对象列表
+        """
+        return self.db.query(DbUserDataRelation).filter(
+            DbUserDataRelation.data_category == data_category,
+            DbUserDataRelation.data_id == data_id
+        ).all()
+    
+    def update_relation(self, relation_id: int, role_id: int):
+        """
+        更新关联关系的角色ID
+        :param relation_id: 关联ID
+        :param role_id: 新的角色ID
+        :return: 更新后的DbUserDataRelation对象
+        """
+        relation = self.get_relation(relation_id)
+        if relation:
+            relation.role_id = role_id
+            relation.updated = datetime.now()
+            self.db.commit()
+            self.db.refresh(relation)
+        return relation
+    
+    def delete_relation(self, relation_id: int):
+        """
+        删除关联关系
+        :param relation_id: 关联ID
+        :return: 是否删除成功
+        """
+        relation = self.get_relation(relation_id)
+        if relation:
+            self.db.delete(relation)
+            self.db.commit()
+            return True
+        return False
+    
+    def delete_user_relations(self, user_id: int):
+        """
+        删除用户的所有关联关系
+        :param user_id: 用户ID
+        :return: 删除的记录数
+        """
+        count = self.db.query(DbUserDataRelation).filter(DbUserDataRelation.user_id == user_id).delete()
+        self.db.commit()
+        return count
+    
+    def delete_data_relations(self, data_category: str, data_id: int):
+        """
+        删除特定数据的所有关联关系
+        :param data_category: 数据类别(表名)
+        :param data_id: 数据ID
+        :return: 删除的记录数
+        """
+        count = self.db.query(DbUserDataRelation).filter(
+            DbUserDataRelation.data_category == data_category,
+            DbUserDataRelation.data_id == data_id
+        ).delete()
+        self.db.commit()
+        return count

+ 173 - 0
agent/router/kb_router.py

@@ -6,11 +6,13 @@ from config.site import SiteConfig
 from fastapi import APIRouter, Depends, Query
 from fastapi import APIRouter, Depends, Query
 from db.database import get_db
 from db.database import get_db
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
+from sqlalchemy import text
 from agent.models.web.response import StandardResponse,FAILED,SUCCESS
 from agent.models.web.response import StandardResponse,FAILED,SUCCESS
 from agent.models.web.request import BasicRequest
 from agent.models.web.request import BasicRequest
 from agent.libs.graph import GraphBusiness
 from agent.libs.graph import GraphBusiness
 from agent.libs.auth import verify_session_id, SessionValues
 from agent.libs.auth import verify_session_id, SessionValues
 import logging
 import logging
+from typing import Optional, List
 
 
 router = APIRouter(prefix="/kb", tags=["knowledge build interface"])
 router = APIRouter(prefix="/kb", tags=["knowledge build interface"])
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -100,7 +102,178 @@ def nodes_func(request:BasicRequest, db: Session = Depends(get_db), sess:Session
             
             
             return StandardResponse(code=SUCCESS, message="nodes found", records=nodes_all)
             return StandardResponse(code=SUCCESS, message="nodes found", records=nodes_all)
     return StandardResponse(code=FAILED, message="invalid action")
     return StandardResponse(code=FAILED, message="invalid action")
+
+
+async def get_node_properties(node_id: int, db: Session) -> dict:
+    """
+    查询节点的属性
+    :param node_id: 节点ID
+    :param db: 数据库会话
+    :return: 属性字典
+    """
+    prop_sql = text("SELECT prop_title, prop_value FROM kg_props WHERE ref_id = :node_id")
+    result = db.execute(prop_sql, {'node_id': node_id}).fetchall()
+    properties = {}
+    for row in result:
+        properties[row._mapping['prop_title']] = row._mapping['prop_value']
+    return properties
+
+@router.get("/graph_data", response_model=StandardResponse)
+async def get_graph_data(
+    label_name: Optional[str] = Query(None),
+    input_str: Optional[str] = Query(None),
+    db: Session = Depends(get_db),
+    sess:SessionValues = Depends(verify_session_id)
+):
+    """
+    获取用户关联的图谱数据
+    - 从session_id获取user_id
+    - 查询DbUserDataRelation获取用户关联的数据
+    - 返回与Java端一致的数据结构
+    """
+    try:
+        # 1. 从session获取user_id
+        user_id = sess.user_id
+        if not user_id:
+            return StandardResponse(code=FAILED, message="user not found", records=[])
+        
+        # 2. 使用JOIN查询用户关联的图谱数据
+        sql = text("""
+        WITH RankedRelations AS (
+            SELECT 
+                e.name as rType, 
+                m.id as target_id, 
+                m.name as target_name,
+                m.category as target_label,
+                (SELECT COUNT(*) FROM kg_edges WHERE src_id = m.id) as pCount,
+                ROW_NUMBER() OVER(PARTITION BY e.name ORDER BY m.id) as rn
+            FROM user_data_relations udr
+            JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
+            JOIN kg_edges e ON n.id = e.src_id
+            JOIN kg_nodes m ON e.dest_id = m.id
+            WHERE udr.user_id = :user_id
+            AND n.category = :label_name
+            AND n.name = :input_str
+            AND n.status = '0'
+        )
+        SELECT rType, target_id, target_name, target_label, pCount
+        FROM RankedRelations
+        WHERE rn <= 50
+        ORDER BY rType
+        """)
         
         
+        # 3. 组装返回数据结构
+        categories = ["中心词", "关系"]
+        nodes = []
+        links = []
+        
+        # 查询中心节点
+        center_sql = text("""
+        SELECT n.id, n.name, n.category
+        FROM user_data_relations udr
+        JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
+        WHERE udr.user_id = :user_id
+        AND n.category = :label_name
+        AND n.name = :input_str
+        AND n.status = '0'
+        """)
+        
+        # 执行查询并处理结果
+        center_node = None
+        rtype_map = {}
+         
+        # 1. 查询中心节点
+        center_result = db.execute(center_sql, {
+             'user_id': user_id,
+             'label_name': label_name,
+             'input_str': input_str
+        }).fetchall()
+         
+        if center_result:
+             for row in center_result:
+                 center_node = {
+                     "id": 0,
+                     "name": row._mapping['name'],
+                     "label": row._mapping['category'],
+                     "symbolSize": 50,
+                     "symbol": "circle",
+                     "properties": await get_node_properties(row._mapping['id'], db)
+                 }
+                 nodes.append(center_node)
+                 break
+         
+        # 2. 查询关联的边和目标节点
+        relation_result = db.execute(sql, {
+             'user_id': user_id,
+             'label_name': label_name,
+             'input_str': input_str
+        }).fetchall()
+         
+        if relation_result:
+             for row in relation_result:
+                 r_type = row._mapping['rtype']
+                 target_node = {
+                     "id": row._mapping['target_id'],
+                     "name": row._mapping['target_name'],
+                     "label": row._mapping['target_label'],
+                     "symbolSize": 28,
+                     "symbol": "circle",
+                     "properties": await get_node_properties(row._mapping['target_id'], db)
+                 }
+                 
+                 if r_type not in rtype_map:
+                     rtype_map[r_type] = []
+                 rtype_map[r_type].append(target_node)
+                 
+                 if r_type not in categories:
+                     categories.append(r_type)
+         
+         # 3. 组装返回结果
+        if center_node:
+             for r_type, targets in rtype_map.items():
+                 # 添加关系节点
+                 relation_node = {
+                     "id": len(nodes),
+                     "name": "",
+                     "label": center_node['label'],
+                     "symbolSize": 10,
+                     "symbol": "diamond",
+                     "properties": {}
+                 }
+                 nodes.append(relation_node)
+                 
+                 # 添加中心节点到关系节点的链接
+                 links.append({
+                     "source": center_node['name'],
+                     "target": "",
+                     "name": r_type,
+                     "category": r_type
+                 })
+                 
+                 # 添加关系节点到目标节点的链接
+                 for target in targets:
+                     links.append({
+                         "source": "",
+                         "target": target['name'],
+                         "name": "",
+                         "category": r_type
+                     })
+         
+        final_data={
+                "categories": categories,
+                "nodes": nodes,
+                "links": links
+        }
+        return StandardResponse(
+            records=[{"records": final_data}],
+            message="Graph data retrieved"
+        )
+    except Exception as e:
+        return StandardResponse(
+            code=500,
+            message=str(e)
+        )
+
         
         
         
         
 kb_router = router
 kb_router = router

+ 31 - 2
executor/job_script/standard_kb_build.py

@@ -51,7 +51,30 @@ def parse_json(data):
                 entity.append("")
                 entity.append("")
                 
                 
 def import_entities(graph_id, entities_list, relations_list):
 def import_entities(graph_id, entities_list, relations_list):
-
+    from agent.libs.user_data_relation import UserDataRelationBusiness
+    from agent.models.db.user import User, Role
+    from agent.libs.user import UserBusiness
+    
+    # 获取job信息
+    job = graphBiz.db.query(graphBiz.DbJob).filter(graphBiz.DbJob.id == graph_id).first()
+    if not job:
+        logger.error(f"Job not found with id: {graph_id}")
+        return entities
+        
+    # 从job_creator中提取user_id
+    user_id = int(job.job_creator.split('/')[1])
+    
+    # 获取用户角色
+    user_biz = UserBusiness(graphBiz.db)
+    user = user_biz.get_user(user_id)
+    if not user or not user.roles:
+        logger.error(f"User {user_id} has no roles assigned")
+        return entities
+    role_id = user.roles[0].id
+    
+    # 创建用户数据关系业务对象
+    relation_biz = UserDataRelationBusiness(graphBiz.db)
+    
     for text, ent in entities_list.items():
     for text, ent in entities_list.items():
         id = ent['id']
         id = ent['id']
         name = ent['name']
         name = ent['name']
@@ -63,6 +86,9 @@ def import_entities(graph_id, entities_list, relations_list):
         node = graphBiz.create_node(graph_id=graph_id, name=name, category=type[0], props={'types':",".join(type),'full_name':full_name})
         node = graphBiz.create_node(graph_id=graph_id, name=name, category=type[0], props={'types':",".join(type),'full_name':full_name})
         if node:
         if node:
             ent["db_id"] = node.id
             ent["db_id"] = node.id
+            # 创建节点数据关联
+            relation_biz.create_relation(user_id, 'DbKgNode', node.id, role_id)
+            
     for text, relations in relations_list.items():
     for text, relations in relations_list.items():
         source_name = relations['source_name']
         source_name = relations['source_name']
         source_type = relations['source_type']
         source_type = relations['source_type']
@@ -71,7 +97,7 @@ def import_entities(graph_id, entities_list, relations_list):
         relation_type = relations['type']
         relation_type = relations['type']
         source_db_id = entities_list[source_name]['db_id']
         source_db_id = entities_list[source_name]['db_id']
         target_db_id = entities_list[target_name]['db_id']
         target_db_id = entities_list[target_name]['db_id']
-        graphBiz.create_edge(graph_id=graph_id, 
+        edge = graphBiz.create_edge(graph_id=graph_id, 
                              src_id=source_db_id, 
                              src_id=source_db_id, 
                              dest_id=target_db_id, 
                              dest_id=target_db_id, 
                              name=relation_type, 
                              name=relation_type, 
@@ -81,6 +107,9 @@ def import_entities(graph_id, entities_list, relations_list):
                                  "dest_type":target_type,
                                  "dest_type":target_type,
                              })
                              })
         logger.info(f"create edge: {source_db_id}->{target_db_id}")
         logger.info(f"create edge: {source_db_id}->{target_db_id}")
+        # 创建边数据关联
+        if edge:
+            relation_biz.create_relation(user_id, 'DbKgEdge', edge.id, role_id)
         
         
     return entities
     return entities
 if __name__ == "__main__":
 if __name__ == "__main__":

+ 1 - 1
executor/main.py

@@ -46,7 +46,7 @@ SCRIPT_CONFIG = {
         'command': "python",  # 脚本路径
         'command': "python",  # 脚本路径
        'script':'standard_txt_chunk.py',
        'script':'standard_txt_chunk.py',
         'args': [],  # 脚本参数
         'args': [],  # 脚本参数
-       'success': { 'queue_category': 'SYSTEM', 'queue_name':'CHUNKS'},
+       'success': { 'queue_category': 'SYSTEM', 'queue_name':'KB_EXTRACT'},
         'failed': { 'queue_category': 'SYSTEM', 'queue_name': 'CHUNKS'},
         'failed': { 'queue_category': 'SYSTEM', 'queue_name': 'CHUNKS'},
         'error': { 'queue_category': 'SYSTEM', 'queue_name': 'CHUNKS'}
         'error': { 'queue_category': 'SYSTEM', 'queue_name': 'CHUNKS'}
         },
         },

+ 17 - 1
graph/db/models.py

@@ -177,6 +177,22 @@ class DbKgDataset(Base):
     updated = Column(DateTime, nullable=False)
     updated = Column(DateTime, nullable=False)
     status = Column(Integer, default=0)
     status = Column(Integer, default=0)
 
 
+class DbUserDataRelation(Base):
+    __tablename__ = "user_data_relations"
+    
+    id = Column(Integer, primary_key=True, index=True)
+    user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
+    data_category = Column(String(255), nullable=False)  # 表名
+    data_id = Column(Integer, nullable=False)  # 表中数据ID
+    role_id = Column(Integer, nullable=False)  # 用户角色ID
+    created = Column(DateTime, nullable=False)
+    updated = Column(DateTime, nullable=False)
+    
+    user = relationship("DbUsers", back_populates="data_relations")
+
+# 在DbUsers中添加反向关系
+DbUsers.data_relations = relationship("DbUserDataRelation", back_populates="user")
+
 __all__=['DbKgEdge','DbKgNode','DbKgProp','DbDictICD','DbDictDRG',
 __all__=['DbKgEdge','DbKgNode','DbKgProp','DbDictICD','DbDictDRG',
          'DbDictDrug','DbKgSchemas','DbKgSubGraph','DbKgModels',
          'DbDictDrug','DbKgSchemas','DbKgSubGraph','DbKgModels',
-         'DbKgGraphs', 'DbKgDataset']
+         'DbKgGraphs', 'DbKgDataset', 'DbUserDataRelation']

+ 40 - 0
openapi.yaml

@@ -482,6 +482,46 @@ paths:
             application/json:
             application/json:
               schema:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
                 $ref: '#/components/schemas/HTTPValidationError'
+  /kb/graph_data:
+    get:
+      tags:
+      - knowledge build interface
+      summary: Get Graph Data
+      description: "\u83B7\u53D6\u7528\u6237\u5173\u8054\u7684\u56FE\u8C31\u6570\u636E\
+        \n- \u4ECEsession_id\u83B7\u53D6user_id\n- \u67E5\u8BE2DbUserDataRelation\u83B7\
+        \u53D6\u7528\u6237\u5173\u8054\u7684\u6570\u636E\n- \u8FD4\u56DE\u4E0EJava\u7AEF\
+        \u4E00\u81F4\u7684\u6570\u636E\u7ED3\u6784"
+      operationId: get_graph_data_kb_graph_data_get
+      parameters:
+      - name: label_name
+        in: query
+        required: false
+        schema:
+          anyOf:
+          - type: string
+          - type: 'null'
+          title: Label Name
+      - name: input_str
+        in: query
+        required: false
+        schema:
+          anyOf:
+          - type: string
+          - type: 'null'
+          title: Input Str
+      responses:
+        '200':
+          description: Successful Response
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/StandardResponse'
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
   /knowledge-base/:
   /knowledge-base/:
     post:
     post:
       tags:
       tags: