Browse Source

查询/保存知识库及其文件时,增加用户关联信息。

chenbin 1 month ago
parent
commit
b2a7272680

+ 8 - 2
agent/libs/user_data_relation.py

@@ -6,13 +6,15 @@ class UserDataRelationBusiness:
     def __init__(self, db: Session):
         self.db = db
     
-    def create_relation(self, user_id: int, data_category: str, data_id: int, role_id: int):
+    def create_relation(self, user_id: int, data_category: str, data_id: int, role_id: int, user_name: str, role_name: str):
         """
         创建用户数据关联
         :param user_id: 用户ID
         :param data_category: 数据类别(表名)
         :param data_id: 数据ID
         :param role_id: 角色ID
+        :param user_name: 用户名称
+        :param role_name: 角色名称
         :return: DbUserDataRelation对象
         """
         relation = DbUserDataRelation(
@@ -20,6 +22,8 @@ class UserDataRelationBusiness:
             data_category=data_category,
             data_id=data_id,
             role_id=role_id,
+            user_name=user_name,
+            role_name=role_name,
             created=datetime.now(),
             updated=datetime.now()
         )
@@ -56,16 +60,18 @@ class UserDataRelationBusiness:
             DbUserDataRelation.data_id == data_id
         ).all()
     
-    def update_relation(self, relation_id: int, role_id: int):
+    def update_relation(self, relation_id: int, role_id: int, role_name: str):
         """
         更新关联关系的角色ID
         :param relation_id: 关联ID
         :param role_id: 新的角色ID
+        :param role_name: 新的角色名
         :return: 更新后的DbUserDataRelation对象
         """
         relation = self.get_relation(relation_id)
         if relation:
             relation.role_id = role_id
+            relation.role_name = role_name
             relation.updated = datetime.now()
             self.db.commit()
             self.db.refresh(relation)

+ 15 - 1
agent/models/db/graph.py

@@ -193,6 +193,20 @@ class DbKgDataset(Base):
     updated = Column(DateTime, nullable=False)
     status = Column(Integer, default=0)
 
+class DbUserDataRelation(Base):
+    __tablename__ = "user_data_relations"
+    
+    id = Column(Integer, primary_key=True, index=True)
+    user_id = Column(Integer, nullable=False)
+    data_category = Column(String(64), nullable=False)
+    data_id = Column(Integer, nullable=False)
+    role_id = Column(Integer)
+    user_name = Column(String(64))
+    role_name = Column(String(64))
+    created = Column(DateTime, nullable=False)
+    updated = Column(DateTime, nullable=False)
+
 __all__=['DbKgEdge','DbKgNode','DbKgProp','DbKgEdgeProp','DbDictICD','DbDictDRG',
          'DbDictDrug','DbKgSchemas','DbKgSubGraph','DbKgModels',
-         'DbKgGraphs', 'DbKgDataset']
+         'DbKgGraphs', 'DbKgDataset', 'DbUserDataRelation']
+

+ 81 - 14
agent/router/knowledge_base_router.py

@@ -17,10 +17,14 @@ from fastapi.openapi.docs import (
     get_swagger_ui_html,
     get_swagger_ui_oauth2_redirect_html,
 )
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, and_
 from sqlalchemy.orm import Session, sessionmaker
 from sqlalchemy.ext.declarative import declarative_base
+from agent.models.db.graph import DbUserDataRelation as UserDataRelation
 from pydantic import BaseModel, ConfigDict, Field, field_serializer
+
+from agent.libs.auth import SessionValues, verify_session_id
+from agent.libs.user_data_relation import UserDataRelationBusiness
 from agent.models.web.knowledge_base import Base, KnowledgeBase, KnowledgeFile
 from agent.utils import DatabaseUtils, MinioUtils, FileUtils
 from config.site import settings
@@ -41,6 +45,7 @@ class KnowledgeBaseResponse(BaseModel):
     description: Optional[str] = None
     tags: Optional[str] = None
     creator: Optional[str] = None
+    user_name: Optional[str] = None  # 新增字段
     file_count: int = 0
     created_at: datetime = Field(default_factory=datetime.utcnow)
     updated_at: datetime = Field(default_factory=datetime.utcnow)
@@ -59,6 +64,7 @@ class KnowledgeFileResponse(BaseModel):
     file_size: float
     file_type: str
     minio_url: str
+    user_name: Optional[str] = None  # 用户名
     version: Optional[str] = None
     author: Optional[str] = None
     year: Optional[int] = None
@@ -146,8 +152,26 @@ class BatchFileUpdate(BaseModel):
 # 使用utils.py中的FileUtils类进行文件转换
 
 @router.post("/knowledge-base/", response_model=ResponseModel)
-def create_knowledge_base(kb: KnowledgeBaseCreate, db: Session = Depends(get_db)):
-    kb_data = DatabaseUtils.create_knowledge_base(db, kb.name, kb.description, kb.tags)
+def create_knowledge_base(kb: KnowledgeBaseCreate, db: Session = Depends(get_db),
+    sess:SessionValues = Depends(verify_session_id)):
+    # 1. 从session获取user_id
+    user_id = sess.user_id
+    user_name = sess.username 
+    
+    # 2. 创建知识库
+    kb_data = DatabaseUtils.create_knowledge_base(db, kb.name, user_id, kb.description, kb.tags)
+    
+    # 3. 创建用户数据关联
+    relation_business = UserDataRelationBusiness(db)
+    relation = relation_business.create_relation(
+        user_id=user_id,
+        data_category='KnowledgeBase',
+        data_id=kb_data.id,
+        user_name=user_name,
+        role_id=None,
+        role_name=None
+    )
+    
     return ResponseModel(
         code=200,
         message="创建成功",
@@ -180,10 +204,12 @@ def get_knowledge_base(kb_id: int, db: Session = Depends(get_db)):
     kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
     if not kb:
         raise HTTPException(status_code=404, detail="知识库不存在")
+    
+    kb_data = KnowledgeBaseResponse.model_validate(kb).model_dump()
     return ResponseModel(
         code=200,
         message="查询成功",
-        data=KnowledgeBaseResponse.model_validate(kb).model_dump()
+        data=kb_data
     )
 
 
@@ -222,13 +248,18 @@ def get_knowledge_base_by_name(name: str, db: Session = Depends(get_db)):
 async def upload_files(
         kb_id: int,
         files: List[UploadFile] = File(...),
-        db: Session = Depends(get_db)
+        db: Session = Depends(get_db),
+        sess: SessionValues = Depends(verify_session_id)  # 添加session依赖
 ):
     # 验证知识库是否存在
     kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
     if not kb:
         raise HTTPException(status_code=404, detail="知识库不存在")
 
+    # 获取当前用户信息
+    user_id = sess.user_id
+    user_name = sess.username
+    
     # 验证文件数量
     if len(files) > settings.MAX_FILE_COUNT:
         raise HTTPException(status_code=400, detail=f"单次上传文件数量不能超过{settings.MAX_FILE_COUNT}个")
@@ -303,11 +334,23 @@ async def upload_files(
             file_size=file_size,
             file_type=file_ext,
             minio_url=minio_url,
-            creator=kb.creator,  # 使用知识库的创建人作为文件创建人
+            creator=user_id,
             knowledge_type=knowledge_type
         )
         db.add(db_file)
         uploaded_files.append(db_file)
+        
+        # 创建用户数据关联
+        relation_business = UserDataRelationBusiness(db)
+        relation = relation_business.create_relation(
+            user_id=user_id,
+            data_category='KnowledgeFile',
+            data_id=db_file.id,
+            user_name=user_name,
+            role_id=None,
+            role_name=None
+        )
+        
         # 更新知识库文件计数
         DatabaseUtils.increment_file_count(db, kb_id)
 
@@ -328,7 +371,13 @@ def list_files(kb_id: int, pageNo: int = 1, pageSize: int = 10, file_name: Optio
         raise HTTPException(status_code=400, detail="每页条数必须大于等于1")
 
     skip = (pageNo - 1) * pageSize
-    query = db.query(KnowledgeFile).filter(
+    query = db.query(KnowledgeFile,UserDataRelation.user_name).\
+        outerjoin(UserDataRelation,
+             and_(
+                 UserDataRelation.data_id == KnowledgeFile.id,
+                 UserDataRelation.data_category == 'KnowledgeFile'
+             )).\
+        filter(
         KnowledgeFile.knowledge_base_id == kb_id,
         KnowledgeFile.is_deleted == 0
     )
@@ -343,7 +392,12 @@ def list_files(kb_id: int, pageNo: int = 1, pageSize: int = 10, file_name: Optio
         code=200,
         message="查询成功",
         data={
-            "list": [KnowledgeFileResponse.model_validate(file).model_dump() for file in files],
+            "list": [KnowledgeFileResponse.model_validate(
+                {
+                    **file[0].__dict__,
+                    "user_name": file[1]
+                }
+            ).model_dump() for file in files],
             "total": total
         }
     )
@@ -351,15 +405,28 @@ def list_files(kb_id: int, pageNo: int = 1, pageSize: int = 10, file_name: Optio
 
 @router.get("/knowledge-base/{kb_id}/files/search/", response_model=ResponseModel)
 def search_files(kb_id: int, file_name: str, db: Session = Depends(get_db)):
-    files = db.query(KnowledgeFile).filter(
-        KnowledgeFile.knowledge_base_id == kb_id,
-        KnowledgeFile.file_name.ilike(f"%{file_name}%"),
-        KnowledgeFile.is_deleted == 0
-    ).all()
+    files = db.query(KnowledgeFile, UserDataRelation.user_name).\
+        join(UserDataRelation, 
+             and_(
+                 UserDataRelation.data_id == KnowledgeFile.id,
+                 UserDataRelation.data_category == 'KnowledgeFile'
+             )).\
+        filter(
+            KnowledgeFile.knowledge_base_id == kb_id,
+            KnowledgeFile.file_name.ilike(f"%{file_name}%"),
+            KnowledgeFile.is_deleted == 0
+        ).all()
+    
+    result = []
+    for file, user_name in files:
+        file_response = KnowledgeFileResponse.model_validate(file)
+        file_response.user_name = user_name
+        result.append(file_response.model_dump())
+    
     return ResponseModel(
         code=200,
         message="查询成功",
-        data=[KnowledgeFileResponse.model_validate(file).model_dump() for file in files]
+        data=result
     )
 
 

+ 22 - 2
agent/utils.py

@@ -11,8 +11,10 @@ from typing import List, Optional
 from minio import Minio
 import urllib3
 from sqlalchemy.orm import Session
+from sqlalchemy import and_
 from fastapi import HTTPException
 from agent.models.web.knowledge_base import KnowledgeBase, KnowledgeFile
+from agent.models.db.graph import DbUserDataRelation as UserDataRelation
 from config.site import settings
 
 # 配置Office文件转换日志
@@ -80,11 +82,29 @@ class DatabaseUtils:
 
     @staticmethod
     def get_knowledge_bases(db: Session, skip: int = 0, limit: int = 10, name: Optional[str] = None) -> tuple[List[KnowledgeBase], int]:
-        query = db.query(KnowledgeBase).filter(KnowledgeBase.is_deleted == 0)
+        query = db.query(
+            KnowledgeBase,
+            UserDataRelation.user_name
+        ).outerjoin(
+            UserDataRelation,
+            and_(
+                UserDataRelation.data_category == 'KnowledgeBase',
+                UserDataRelation.data_id == KnowledgeBase.id
+            )
+        ).filter(KnowledgeBase.is_deleted == 0)
+        
         if name:
             query = query.filter(KnowledgeBase.name.ilike(f"%{name}%"))
+            
         total = query.count()
-        knowledge_bases = query.offset(skip).limit(limit).all()
+        results = query.offset(skip).limit(limit).all()
+        
+        # 将user_name赋值给KnowledgeBase对象
+        knowledge_bases = []
+        for kb, user_name in results:
+            kb.user_name = user_name
+            knowledge_bases.append(kb)
+            
         return knowledge_bases, total
 
     @staticmethod

+ 4 - 2
executor/job_script/standard_kb_build.py

@@ -71,6 +71,8 @@ def import_entities(graph_id, entities_list, relations_list):
         logger.error(f"User {user_id} has no roles assigned")
         return entities
     role_id = user.roles[0].id
+    user_name = user.name
+    role_name = user.roles[0].name
     
     # 创建用户数据关系业务对象
     relation_biz = UserDataRelationBusiness(graphBiz.db)
@@ -87,7 +89,7 @@ def import_entities(graph_id, entities_list, relations_list):
         if node:
             ent["db_id"] = node.id
             # 创建节点数据关联
-            relation_biz.create_relation(user_id, 'DbKgNode', node.id, role_id)
+            relation_biz.create_relation(user_id, 'DbKgNode', node.id, role_id, user_name, role_name)
             
     for text, relations in relations_list.items():
         source_name = relations['source_name']
@@ -109,7 +111,7 @@ def import_entities(graph_id, entities_list, relations_list):
         logger.info(f"create edge: {source_db_id}->{target_db_id}")
         # 创建边数据关联
         if edge:
-            relation_biz.create_relation(user_id, 'DbKgEdge', edge.id, role_id)
+            relation_biz.create_relation(user_id, 'DbKgEdge', edge.id, role_id, user_name, role_name)
         
     return entities
 if __name__ == "__main__":