import json
import os
import io
import logging
import tempfile
import urllib.parse
import time
import glob
import shutil
import subprocess
from typing import List, Optional
from datetime import datetime, timezone
from fastapi import APIRouter, FastAPI, Depends, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.openapi.docs import (
    get_redoc_html,
    get_swagger_ui_html,
    get_swagger_ui_oauth2_redirect_html,
)
from sqlalchemy import create_engine, and_
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from agent.models.db.graph import DbUserDataRelation as UserDataRelation
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from agent.libs.auth import SessionValues, verify_session_id
from agent.libs.user_data_relation import UserDataRelationBusiness
from agent.models.web.knowledge_base import Base, KnowledgeBase, KnowledgeFile
from agent.utils import DatabaseUtils, MinioUtils, FileUtils
from config.site import settings


def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()`` (deprecated since Python
    3.12): the value is identical and still naive, so it serializes and
    compares exactly like the timestamps ``utcnow()`` used to produce.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------
class ResponseModel(BaseModel):
    """Envelope returned by the endpoints: status code, message and payload."""

    code: int
    message: str
    data: Optional[dict | list | bool | None]


class KnowledgeBaseResponse(BaseModel):
    """Serialized view of a KnowledgeBase ORM row (built via ``from_attributes``)."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    description: Optional[str] = None
    # Stored as a JSON-encoded list string; decoded by serialize_tags below.
    tags: Optional[str] = None
    creator: Optional[str] = None
    user_name: Optional[str] = None  # newly added field
    file_count: int = 0
    created_at: datetime = Field(default_factory=_utcnow_naive)
    updated_at: datetime = Field(default_factory=_utcnow_naive)

    @field_serializer('created_at', 'updated_at')
    def serialize_datetime(self, dt: datetime) -> str:
        """Render timestamps as ``YYYY-MM-DD``."""
        return dt.strftime('%Y-%m-%d')

    @field_serializer('tags')
    def serialize_tags(self, tags: str) -> Optional[List[str]]:
        """Decode the stored JSON tag string; empty/missing tags serialize as None."""
        if tags:
            return json.loads(tags)
        return None


class KnowledgeFileResponse(BaseModel):
    """Serialized view of a KnowledgeFile ORM row (built via ``from_attributes``)."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    knowledge_base_id: int
    file_name: str
    file_size: float  # size in MB, as computed at upload time
    file_type: str
    minio_url: str
    status: bool = False
    user_name: Optional[str] = None  # uploader's user name
    version: Optional[str] = None
    author: Optional[str] = None
    year: Optional[int] = None
    page_count: Optional[int] = None
    creator: Optional[str] = None
    knowledge_type: Optional[str] = None
    created_at: datetime = Field(default_factory=_utcnow_naive)
    updated_at: datetime = Field(default_factory=_utcnow_naive)

    @field_serializer('created_at', 'updated_at')
    def serialize_datetime(self, dt: datetime) -> str:
        """Render timestamps as ``YYYY-MM-DD HH:MM``."""
        return dt.strftime('%Y-%m-%d %H:%M')


# ---------------------------------------------------------------------------
# Logging configuration
# ---------------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Database engine / session factory
# ---------------------------------------------------------------------------
engine = create_engine(settings.DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Create any missing tables for the knowledge-base models at import time.
Base.metadata.create_all(bind=engine)

router = APIRouter(tags=["knowledge base interface"])

# Shared MinIO helper used by the upload/download/delete endpoints.
minio_utils = MinioUtils()


def get_db():
    """FastAPI dependency: yield a database session and always close it afterwards."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------
class KnowledgeBaseCreate(BaseModel):
    """Payload for creating a knowledge base."""

    name: str
    description: Optional[str] = None
    tags: Optional[List[str]] = Field(default_factory=list)


class KnowledgeBaseUpdate(BaseModel):
    """Payload for updating a knowledge base."""

    name: str
    description: Optional[str] = None
    tags: Optional[List[str]] = Field(default_factory=list)


class FileUpdate(BaseModel):
    """Per-file metadata update; fields left as None are not changed."""

    id: int
    file_name: Optional[str] = None
    version: Optional[str] = None
    author: Optional[str] = None
    year: Optional[int] = None
    page_count: Optional[int] = None
    creator: Optional[str] = None
    knowledge_type: Optional[str] = None


class BatchFileUpdate(BaseModel):
    """Payload for the batch file-metadata update endpoint."""

    files: List[FileUpdate]


# File format conversion is delegated to FileUtils from agent.utils.
@router.post("/knowledge-base/", response_model=ResponseModel)
def create_knowledge_base(kb: KnowledgeBaseCreate, db: Session = Depends(get_db), sess: SessionValues = Depends(verify_session_id)):
    """Create a knowledge base owned by the session user.

    Tags are stored as a JSON string, and a ``'KnowledgeBase'`` user-data
    relation is created linking the new record to the current user.
    """
    user_id = sess.user_id
    user_name = sess.username
    tags = json.dumps(kb.tags, ensure_ascii=False)

    kb_data = DatabaseUtils.create_knowledge_base(db, kb.name, user_id, kb.description, tags)

    UserDataRelationBusiness(db).create_relation(
        user_id=user_id,
        data_category='KnowledgeBase',
        data_id=kb_data.id,
        user_name=user_name,
        role_id=None,
        role_name=None
    )
    return ResponseModel(
        code=200,
        message="创建成功",
        data=KnowledgeBaseResponse.model_validate(kb_data).model_dump()
    )


@router.put("/knowledge-base/{kb_id}", response_model=ResponseModel)
def update_knowledge_base(kb_id: int, kb: KnowledgeBaseUpdate, db: Session = Depends(get_db)):
    """Update name, description and tags of a knowledge base.

    Raises:
        HTTPException(404): when the update helper returns no record (previously
        this surfaced as a 500 from validating ``None``).
    """
    tags = json.dumps(kb.tags, ensure_ascii=False)
    kb_data = DatabaseUtils.update_knowledge_base(db, kb_id, kb.name, kb.description, tags)
    if not kb_data:
        raise HTTPException(status_code=404, detail="知识库不存在")
    return ResponseModel(
        code=200,
        message="更新成功",
        data=KnowledgeBaseResponse.model_validate(kb_data).model_dump()
    )


@router.delete("/knowledge-base/{kb_id}", response_model=ResponseModel)
def delete_knowledge_base(kb_id: int, db: Session = Depends(get_db)):
    """Delete a knowledge base via DatabaseUtils and return the helper's result."""
    result = DatabaseUtils.delete_knowledge_base(db, kb_id)
    return ResponseModel(
        code=200,
        message="删除成功",
        data=result
    )


@router.get("/knowledge-base/{kb_id}", response_model=ResponseModel)
def get_knowledge_base(kb_id: int, db: Session = Depends(get_db)):
    """Return one non-deleted knowledge base, or 404."""
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")
    return ResponseModel(
        code=200,
        message="查询成功",
        data=KnowledgeBaseResponse.model_validate(kb).model_dump()
    )


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

# Uploaded names with these suffixes are unpacked and their contents imported.
_ARCHIVE_SUFFIXES = ('.zip', '.rar', '.7z', '.tar', '.tar.gz', '.tgz')
# Tar variants; tarfile's 'r' mode detects gzip/bz2/xz compression by itself.
_TAR_SUFFIXES = ('.tar', '.tar.gz', '.tgz')
# OS metadata files often found inside archives; never real documents.
_ARCHIVE_JUNK_NAMES = {'thumbs.db', 'desktop.ini'}


def _page_offset(pageNo: int, pageSize: int) -> int:
    """Validate 1-based paging parameters and return the row offset of the page."""
    if pageNo < 1:
        raise HTTPException(status_code=400, detail="页码必须大于等于1")
    if pageSize < 1:
        raise HTTPException(status_code=400, detail="每页条数必须大于等于1")
    return (pageNo - 1) * pageSize


def _files_with_uploader(db: Session):
    """Query ``(KnowledgeFile, uploader user_name)`` pairs.

    Outer join so that files without a ``'KnowledgeFile'`` relation row are
    still returned (with ``user_name`` None).
    """
    return db.query(KnowledgeFile, UserDataRelation.user_name).outerjoin(
        UserDataRelation,
        and_(
            UserDataRelation.data_id == KnowledgeFile.id,
            UserDataRelation.data_category == 'KnowledgeFile'
        )
    )


def _file_response(file: KnowledgeFile, user_name: Optional[str]) -> dict:
    """Serialize a KnowledgeFile row, attaching the joined uploader name."""
    response = KnowledgeFileResponse.model_validate(file)
    response.user_name = user_name
    return response.model_dump()


def _get_active_file(db: Session, file_id: int) -> KnowledgeFile:
    """Return the non-deleted KnowledgeFile with ``file_id`` or raise 404."""
    file = db.query(KnowledgeFile).filter(
        KnowledgeFile.id == file_id,
        KnowledgeFile.is_deleted == 0
    ).first()
    if not file:
        raise HTTPException(status_code=404, detail="文件不存在")
    return file


def _delete_minio_objects(object_names: List[str]) -> None:
    """Best-effort removal of MinIO objects after a DB change has committed.

    Failures are logged rather than raised: the database already reflects the
    change, so reporting an error to the client would be misleading.
    """
    for object_name in object_names:
        try:
            minio_utils.delete_file(object_name)
        except Exception:
            logger.exception("Failed to delete MinIO object %s", object_name)


def _ensure_members_within(member_names: List[str], extract_path: str) -> None:
    """Reject archives whose member paths would escape ``extract_path`` (e.g. via '..' or absolute paths)."""
    root = os.path.realpath(extract_path)
    for member_name in member_names:
        target = os.path.realpath(os.path.join(root, member_name))
        if os.path.commonpath([root, target]) != root:
            raise HTTPException(status_code=400, detail=f"压缩包包含非法路径: {member_name}")


def _extract_archive(archive_path: str, archive_name: str, extract_path: str) -> None:
    """Unpack the archive at ``archive_path`` into ``extract_path``, choosing the format from ``archive_name``.

    Raises:
        HTTPException(400): on a corrupt zip/tar or on members that would be
        written outside ``extract_path``.
    """
    import tarfile
    import zipfile

    lower_name = archive_name.lower()
    try:
        if lower_name.endswith('.zip'):
            # zipfile itself strips absolute paths and '..' components.
            with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)
        elif lower_name.endswith('.7z'):
            import py7zr
            with py7zr.SevenZipFile(archive_path, mode='r') as archive:
                _ensure_members_within(archive.getnames(), extract_path)
                archive.extractall(extract_path)
        elif lower_name.endswith('.rar'):
            import rarfile
            with rarfile.RarFile(archive_path) as archive:
                _ensure_members_within(archive.namelist(), extract_path)
                archive.extractall(extract_path)
        elif lower_name.endswith(_TAR_SUFFIXES):
            with tarfile.open(archive_path, 'r') as tar:
                if hasattr(tarfile, 'data_filter'):
                    # PEP 706 'data' filter: refuses absolute paths, '..',
                    # links pointing outside the destination and special files.
                    tar.extractall(path=extract_path, filter='data')
                else:
                    # No extraction filters on this Python: keep only plain
                    # files/directories and verify they stay inside extract_path.
                    members = [m for m in tar.getmembers() if m.isfile() or m.isdir()]
                    _ensure_members_within([m.name for m in members], extract_path)
                    tar.extractall(path=extract_path, members=members)
    except (zipfile.BadZipFile, tarfile.TarError) as exc:
        raise HTTPException(status_code=400, detail=f"压缩文件解析失败: {archive_name}") from exc


def _decode_member_name(name: str, from_zip: bool) -> str:
    """Repair Chinese zip member names that zipfile decoded as cp437 but were really GBK.

    Only zip archives have this problem; other names, and names that do not
    round-trip through cp437/GBK, are returned unchanged.
    """
    if not from_zip:
        return name
    try:
        return name.encode('cp437').decode('gbk')
    except (UnicodeEncodeError, UnicodeDecodeError):
        return name


async def _import_archive(archive: UploadFile, kb_id: int, db: Session, user_id: int, user_name: str) -> list:
    """Extract an uploaded archive and import every regular file in it (including sub-folders)."""
    archive_name = archive.filename or ""
    from_zip = archive_name.lower().endswith('.zip')
    imported = []
    with tempfile.TemporaryDirectory() as temp_dir:
        # Fixed on-disk name: the client-supplied filename may contain path separators.
        archive_path = os.path.join(temp_dir, "upload_archive")
        with open(archive_path, "wb") as f:
            f.write(await archive.read())
        extract_path = os.path.join(temp_dir, "extracted")
        os.makedirs(extract_path, exist_ok=True)
        _extract_archive(archive_path, archive_name, extract_path)

        for dir_path, dir_names, file_names in os.walk(extract_path):
            # Prune macOS resource-fork folders; sort for a deterministic import order.
            dir_names[:] = sorted(d for d in dir_names if d != '__MACOSX')
            for member in sorted(file_names):
                if member.startswith('.') or member.lower() in _ARCHIVE_JUNK_NAMES:
                    continue
                member_path = os.path.join(dir_path, member)
                # Never follow symlinks: a crafted archive could point them at host files.
                if os.path.islink(member_path) or not os.path.isfile(member_path):
                    continue
                with open(member_path, "rb") as f:
                    content = f.read()
                member_file = UploadFile(
                    filename=_decode_member_name(member, from_zip),
                    file=io.BytesIO(content),
                    size=len(content)
                )
                imported.extend(await process_single_file(member_file, kb_id, db, user_id, user_name))
    return imported


# ---------------------------------------------------------------------------
# Knowledge base listing / lookup
# ---------------------------------------------------------------------------
@router.get("/knowledge-base/", response_model=ResponseModel)
def list_knowledge_bases(pageNo: int = 1, pageSize: int = 10, name: Optional[str] = None, db: Session = Depends(get_db)):
    """Return one page of knowledge bases (optionally filtered by name) plus the total count."""
    skip = _page_offset(pageNo, pageSize)
    kb_list, total = DatabaseUtils.get_knowledge_bases(db, skip, pageSize, name)
    return ResponseModel(
        code=200,
        message="查询成功",
        data={
            "list": [KnowledgeBaseResponse.model_validate(kb).model_dump() for kb in kb_list],
            "total": total
        }
    )


@router.get("/knowledge-base/name/{name}", response_model=ResponseModel)
def get_knowledge_base_by_name(name: str, db: Session = Depends(get_db)):
    """Look up a knowledge base by name, or 404."""
    kb = DatabaseUtils.get_knowledge_base_by_name(db, name)
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")
    return ResponseModel(
        code=200,
        message="查询成功",
        data=KnowledgeBaseResponse.model_validate(kb).model_dump()
    )


# ---------------------------------------------------------------------------
# File upload
# ---------------------------------------------------------------------------
@router.post("/knowledge-base/{kb_id}/files/", response_model=ResponseModel)
async def upload_files(
    kb_id: int,
    files: List[UploadFile] = File(...),
    db: Session = Depends(get_db),
    sess: SessionValues = Depends(verify_session_id)
):
    """Upload one or more files into a knowledge base.

    Archives (.zip, .rar, .7z, .tar, .tar.gz, .tgz) are unpacked and every
    regular file inside them, including files in sub-folders, is imported
    individually. All new rows are committed together at the end.

    :param kb_id: knowledge base ID
    :param files: uploaded files (archives allowed)
    :param db: database session
    :param sess: user session
    :return: ResponseModel listing the created files
    """
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")

    user_id = sess.user_id
    user_name = sess.username

    if len(files) > settings.MAX_FILE_COUNT:
        raise HTTPException(status_code=400, detail=f"单次上传文件数量不能超过{settings.MAX_FILE_COUNT}个")

    uploaded_files = []
    for file in files:
        if (file.filename or "").lower().endswith(_ARCHIVE_SUFFIXES):
            uploaded_files.extend(await _import_archive(file, kb_id, db, user_id, user_name))
        else:
            uploaded_files.extend(await process_single_file(file, kb_id, db, user_id, user_name))

    db.commit()
    return ResponseModel(
        code=200,
        message=f"成功上传{len(uploaded_files)}个文件",
        data=[KnowledgeFileResponse.model_validate(file).model_dump() for file in uploaded_files]
    )


async def process_single_file(file: UploadFile, kb_id: int, db: Session, user_id: int, user_name: str):
    """Validate, optionally convert, store and register a single uploaded file.

    ``.doc``/``.ppt`` are converted to ``.docx``/``.pptx`` via FileUtils. The
    result is checked against ``settings.ALLOWED_EXTENSIONS`` (extension and
    max size in MB), uploaded to MinIO, and recorded as a KnowledgeFile with a
    ``'KnowledgeFile'`` user relation; the knowledge base's file count is
    incremented. This function does not call ``db.commit()`` itself
    (``upload_files`` commits after all files are processed).

    Returns:
        A one-element list containing the new KnowledgeFile.

    Raises:
        HTTPException(500): office-format conversion failed.
        HTTPException(400): unsupported extension or file too large.
    """
    original_filename = file.filename or ""
    file_ext = os.path.splitext(original_filename)[1].lower().lstrip('.')
    content = await file.read()
    converted_content = None

    if file_ext in ("doc", "ppt"):
        target_format = "docx" if file_ext == "doc" else "pptx"
        with tempfile.TemporaryDirectory() as temp_dir:
            # basename() keeps a client-supplied name such as "../x.doc" inside temp_dir.
            temp_input_path = os.path.join(temp_dir, os.path.basename(original_filename))
            with open(temp_input_path, "wb") as temp_file:
                temp_file.write(content)

            converted_file_path = FileUtils.convert_office_file(temp_input_path, temp_dir, target_format)
            # A failed conversion aborts the request; the legacy file is not stored.
            if not converted_file_path or not os.path.exists(converted_file_path):
                raise HTTPException(status_code=500, detail=f"文件格式转换失败: {original_filename}")
            with open(converted_file_path, "rb") as converted_file:
                converted_content = converted_file.read()

        file_ext = target_format
        file.filename = os.path.splitext(original_filename)[0] + f".{target_format}"

    if file_ext not in settings.ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail=f"不支持的文件格式:{file_ext}")

    file_content = converted_content if converted_content else content
    file_size = len(file_content) / (1024 * 1024)  # bytes -> MB

    max_size = settings.ALLOWED_EXTENSIONS[file_ext]["max_size"]
    if file_size > max_size:
        raise HTTPException(status_code=400, detail=f"{original_filename}超过最大允许大小{max_size}MB")

    minio_url = minio_utils.upload_file(file_content, file.filename, file.content_type)

    # Infer the knowledge type from keywords in the file name.
    knowledge_type = None
    if '指南' in file.filename:
        knowledge_type = '指南'
    elif '教材' in file.filename:
        knowledge_type = '教材'

    db_file = KnowledgeFile(
        knowledge_base_id=kb_id,
        file_name=file.filename,
        file_size=file_size,
        file_type=file_ext,
        minio_url=minio_url,
        creator=user_id,
        knowledge_type=knowledge_type
    )
    db.add(db_file)
    # Reading db_file.id on a pending object does not flush the session, so it
    # would still be None here; flush so the database assigns the primary key
    # before the relation below references it.
    db.flush()

    UserDataRelationBusiness(db).create_relation(
        user_id=user_id,
        data_category='KnowledgeFile',
        data_id=db_file.id,
        user_name=user_name,
        role_id=None,
        role_name=None
    )

    DatabaseUtils.increment_file_count(db, kb_id)
    return [db_file]


# ---------------------------------------------------------------------------
# File listing / search
# ---------------------------------------------------------------------------
@router.get("/knowledge-base/{kb_id}/files/", response_model=ResponseModel)
def list_files(kb_id: int, pageNo: int = 1, pageSize: int = 10, file_name: Optional[str] = None, db: Session = Depends(get_db)):
    """Return one page of a knowledge base's files (enabled first) with uploader names."""
    skip = _page_offset(pageNo, pageSize)

    query = _files_with_uploader(db).filter(
        KnowledgeFile.knowledge_base_id == kb_id,
        KnowledgeFile.is_deleted == 0
    )
    if file_name:
        query = query.filter(KnowledgeFile.file_name.ilike(f"%{file_name}%"))
    # The id tie-breaker keeps LIMIT/OFFSET pages stable among rows with equal status.
    query = query.order_by(KnowledgeFile.status.desc(), KnowledgeFile.id)

    total = query.count()
    rows = query.offset(skip).limit(pageSize).all()
    return ResponseModel(
        code=200,
        message="查询成功",
        data={
            "list": [_file_response(file, uploader) for file, uploader in rows],
            "total": total
        }
    )


@router.get("/knowledge-base/{kb_id}/files/enable", response_model=ResponseModel)
def list_files_enable(kb_id: int, status: int = 0, db: Session = Depends(get_db)):
    """Return all non-deleted files of a knowledge base whose status equals ``status``."""
    query = _files_with_uploader(db).filter(
        KnowledgeFile.knowledge_base_id == kb_id,
        KnowledgeFile.status == status,
        KnowledgeFile.is_deleted == 0
    ).order_by(KnowledgeFile.status.desc(), KnowledgeFile.id)

    total = query.count()
    rows = query.all()
    return ResponseModel(
        code=200,
        message="查询成功",
        data={
            "list": [_file_response(file, uploader) for file, uploader in rows],
            "total": total
        }
    )


@router.get("/knowledge-base/{kb_id}/files/search/", response_model=ResponseModel)
def search_files(kb_id: int, file_name: str, db: Session = Depends(get_db)):
    """Case-insensitive substring search on file names within a knowledge base.

    Uses the same outer join as the list endpoints, so files lacking a
    relation row are no longer silently dropped from the results.
    """
    rows = _files_with_uploader(db).filter(
        KnowledgeFile.knowledge_base_id == kb_id,
        KnowledgeFile.file_name.ilike(f"%{file_name}%"),
        KnowledgeFile.is_deleted == 0
    ).all()
    return ResponseModel(
        code=200,
        message="查询成功",
        data=[_file_response(file, uploader) for file, uploader in rows]
    )


# ---------------------------------------------------------------------------
# Single-file operations
# ---------------------------------------------------------------------------
@router.get("/files/{file_id}/download")
def download_file(file_id: int, db: Session = Depends(get_db)):
    """Stream a stored file back to the client as an attachment."""
    file = _get_active_file(db, file_id)

    # The MinIO object name is taken to be the last path segment of the stored URL.
    object_name = file.minio_url.split("/")[-1]
    file_content = minio_utils.download_file(object_name)

    # RFC 5987 encoding so non-ASCII (e.g. Chinese) file names survive the header.
    encoded_filename = urllib.parse.quote(file.file_name)
    return StreamingResponse(
        io.BytesIO(file_content),
        media_type="application/octet-stream",
        headers={
            "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
        }
    )


@router.delete("/files/{file_id}", response_model=dict)
def delete_file(file_id: int, db: Session = Depends(get_db)):
    """Soft-delete a file, decrement the knowledge base's file count, then remove the MinIO object."""
    file = _get_active_file(db, file_id)
    object_name = file.minio_url.split("/")[-1]

    file.is_deleted = 1
    file.updated_at = datetime.utcnow()
    DatabaseUtils.decrement_file_count(db, file.knowledge_base_id)
    db.commit()

    # Remove storage only after the commit so a failed commit can never leave a
    # live record pointing at an already-deleted object.
    _delete_minio_objects([object_name])
    return {
        "code": 200,
        "message": "删除成功",
        "data": True
    }


@router.get("/files/{file_id}/changeStatus", response_model=dict)
def change_file_status(file_id: int, status: bool, db: Session = Depends(get_db)):
    """Enable (status=True -> 1) or disable (status=False -> 0) a file."""
    file = _get_active_file(db, file_id)
    file.status = 1 if status else 0
    file.updated_at = datetime.utcnow()
    db.commit()
    return {
        "code": 200,
        "message": "修改成功",
        "data": True
    }


@router.put("/files/batch-update", response_model=ResponseModel)
def batch_update_files(update_data: BatchFileUpdate, db: Session = Depends(get_db)):
    """Update metadata for several files in one transaction.

    A rename copies the MinIO object under the new name. Every file ID is
    validated before any storage is touched, and the old objects are deleted
    only after the database commit. A 404 part-way through therefore no longer
    leaves committed rows pointing at objects that were already removed.
    """
    import mimetypes

    # Phase 1: resolve every target so a missing ID aborts before any side effect.
    targets = []
    for file_update in update_data.files:
        file = db.query(KnowledgeFile).filter(
            KnowledgeFile.id == file_update.id,
            KnowledgeFile.is_deleted == 0
        ).first()
        if not file:
            raise HTTPException(status_code=404, detail=f"文件ID {file_update.id} 不存在")
        targets.append((file, file_update))

    # Phase 2: apply changes; old objects are only queued for deletion.
    stale_objects = []
    updated_files = []
    for file, file_update in targets:
        if file_update.file_name and file_update.file_name != file.file_name:
            old_object_name = file.minio_url.split("/")[-1]
            file_content = minio_utils.download_file(old_object_name)
            # upload_file's third argument is a MIME content type (process_single_file
            # passes UploadFile.content_type), not the bare extension stored in file_type.
            content_type = mimetypes.guess_type(file_update.file_name)[0] or "application/octet-stream"
            file.minio_url = minio_utils.upload_file(file_content, file_update.file_name, content_type)
            file.file_name = file_update.file_name
            stale_objects.append(old_object_name)

        # Only fields explicitly provided (not None) are overwritten.
        for field_name in ("version", "author", "year", "page_count", "creator", "knowledge_type"):
            value = getattr(file_update, field_name)
            if value is not None:
                setattr(file, field_name, value)

        file.updated_at = datetime.utcnow()
        updated_files.append(file)

    db.commit()
    _delete_minio_objects(stale_objects)

    return ResponseModel(
        code=200,
        message="更新成功",
        data=[KnowledgeFileResponse.model_validate(file).model_dump() for file in updated_files]
    )


knowledge_base_router = router