|
@@ -0,0 +1,481 @@
|
|
|
+import os
|
|
|
+import io
|
|
|
+import logging
|
|
|
+import urllib.parse
|
|
|
+import time
|
|
|
+import glob
|
|
|
+import shutil
|
|
|
+import subprocess
|
|
|
+from typing import List, Optional
|
|
|
+from datetime import datetime
|
|
|
+from fastapi import APIRouter, FastAPI, Depends, HTTPException, UploadFile, File, Form
|
|
|
+from fastapi.middleware.cors import CORSMiddleware
|
|
|
+from fastapi.responses import StreamingResponse
|
|
|
+from fastapi.staticfiles import StaticFiles
|
|
|
+from fastapi.openapi.docs import (
|
|
|
+ get_redoc_html,
|
|
|
+ get_swagger_ui_html,
|
|
|
+ get_swagger_ui_oauth2_redirect_html,
|
|
|
+)
|
|
|
+from sqlalchemy import create_engine
|
|
|
+from sqlalchemy.orm import Session, sessionmaker
|
|
|
+from sqlalchemy.ext.declarative import declarative_base
|
|
|
+from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
|
|
+from agent.models.web.knowledge_base import Base, KnowledgeBase, KnowledgeFile
|
|
|
+from agent.utils import DatabaseUtils, MinioUtils, FileUtils
|
|
|
+from config.site import settings
|
|
|
+
|
|
|
+
|
|
|
+# 响应模型
|
|
|
class ResponseModel(BaseModel):
    """Uniform envelope returned by the endpoints in this module."""

    # Application-level status code; the handlers in this module set 200.
    code: int
    # Human-readable status text.
    message: str
    # Payload. No default is declared, so the field must always be supplied;
    # None is an accepted value.
    data: dict | list | bool | None
|
|
|
+
|
|
|
+
|
|
|
class KnowledgeBaseResponse(BaseModel):
    """Serializable view of a knowledge-base record.

    ``from_attributes=True`` allows ``model_validate`` to read the fields
    from attribute-bearing objects (for example the rows returned by
    ``db.query(KnowledgeBase)``) instead of only from dicts.
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    description: str | None = None
    tags: str | None = None
    creator: str | None = None
    file_count: int = 0
    # Timestamps fall back to the current naive UTC time when absent.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    @field_serializer('created_at', 'updated_at')
    def serialize_datetime(self, value: datetime) -> str:
        """Render a timestamp with day precision (``YYYY-MM-DD``)."""
        return datetime.strftime(value, '%Y-%m-%d')
|
|
|
+
|
|
|
+
|
|
|
class KnowledgeFileResponse(BaseModel):
    """Serializable view of a file stored in a knowledge base.

    ``from_attributes=True`` allows ``model_validate`` to read the fields
    from attribute-bearing objects (for example ``KnowledgeFile`` rows).
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    knowledge_base_id: int
    file_name: str
    # The upload handler stores this value in megabytes.
    file_size: float
    # The upload handler stores the lower-cased extension without the dot.
    file_type: str
    minio_url: str
    version: str | None = None
    author: str | None = None
    year: int | None = None
    page_count: int | None = None
    creator: str | None = None
    knowledge_type: str | None = None
    # Timestamps fall back to the current naive UTC time when absent.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    @field_serializer('created_at', 'updated_at')
    def serialize_datetime(self, value: datetime) -> str:
        """Render a timestamp with minute precision (``YYYY-MM-DD HH:MM``)."""
        return datetime.strftime(value, '%Y-%m-%d %H:%M')
|
|
|
+
|
|
|
+
|
|
|
# Configure logging: INFO level with timestamp, logger name and level in every record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Database engine and session factory; sessions neither autocommit nor autoflush.
engine = create_engine(settings.DATABASE_URL)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Create the tables registered on Base's metadata (runs at import time).
Base.metadata.create_all(bind=engine)

# Router that collects every endpoint defined in this module.
router = APIRouter(tags=["knowledge base interface"])

# MinIO helper shared by the file endpoints.
minio_utils = MinioUtils()
|
|
|
+
|
|
|
+
|
|
|
+# 全局异常处理
|
|
|
+# @router.exception_handler(Exception)
|
|
|
+# async def global_exception_handler(request, exc):
|
|
|
+# logger.error(f"全局异常: {exc}", exc_info=True)
|
|
|
+# return {
|
|
|
+# "code": 500,
|
|
|
+# "message": "服务器内部错误",
|
|
|
+# "data": None
|
|
|
+# }
|
|
|
+
|
|
|
+
|
|
|
+# 依赖项:获取数据库会话
|
|
|
def get_db():
    """Yield a database session and close it once the consumer is done.

    Intended for use as a FastAPI dependency (``Depends(get_db)``). The
    ``finally`` clause closes the session even when the consumer raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
+# 请求模型
|
|
|
class KnowledgeBaseCreate(BaseModel):
    """Request body for creating a knowledge base."""

    name: str
    description: str | None = None
    tags: str | None = None
|
|
|
+
|
|
|
+
|
|
|
class KnowledgeBaseUpdate(BaseModel):
    """Request body for updating a knowledge base; ``name`` is mandatory."""

    name: str
    description: str | None = None
    tags: str | None = None
|
|
|
+
|
|
|
+
|
|
|
class FileUpdate(BaseModel):
    """One entry of a batch file update.

    ``id`` selects the file; every other field is optional and defaults to
    ``None``.
    """

    id: int
    file_name: str | None = None
    version: str | None = None
    author: str | None = None
    year: int | None = None
    page_count: int | None = None
    creator: str | None = None
    knowledge_type: str | None = None
|
|
|
+
|
|
|
+
|
|
|
class BatchFileUpdate(BaseModel):
    """Request body of the batch-update endpoint: a list of per-file changes."""

    files: list[FileUpdate]
|
|
|
+
|
|
|
+
|
|
|
+# 使用utils.py中的FileUtils类进行文件转换
|
|
|
+
|
|
|
@router.post("/knowledge-base/", response_model=ResponseModel)
def create_knowledge_base(kb: KnowledgeBaseCreate, db: Session = Depends(get_db)):
    """Create a knowledge base and return it inside the standard envelope.

    Persistence is delegated to ``DatabaseUtils.create_knowledge_base``; the
    returned object is serialized through ``KnowledgeBaseResponse``.
    """
    created = DatabaseUtils.create_knowledge_base(db, kb.name, kb.description, kb.tags)
    payload = KnowledgeBaseResponse.model_validate(created).model_dump()
    return ResponseModel(code=200, message="创建成功", data=payload)
|
|
|
+
|
|
|
+
|
|
|
@router.put("/knowledge-base/{kb_id}", response_model=ResponseModel)
def update_knowledge_base(kb_id: int, kb: KnowledgeBaseUpdate, db: Session = Depends(get_db)):
    """Update a knowledge base's name, description and tags.

    Persistence is delegated to ``DatabaseUtils.update_knowledge_base``; the
    returned object is serialized through ``KnowledgeBaseResponse``.
    """
    updated = DatabaseUtils.update_knowledge_base(db, kb_id, kb.name, kb.description, kb.tags)
    payload = KnowledgeBaseResponse.model_validate(updated).model_dump()
    return ResponseModel(code=200, message="更新成功", data=payload)
|
|
|
+
|
|
|
+
|
|
|
@router.delete("/knowledge-base/{kb_id}", response_model=ResponseModel)
def delete_knowledge_base(kb_id: int, db: Session = Depends(get_db)):
    """Delete a knowledge base via ``DatabaseUtils.delete_knowledge_base``.

    Whatever the helper returns is forwarded unchanged as ``data``.
    """
    outcome = DatabaseUtils.delete_knowledge_base(db, kb_id)
    return ResponseModel(code=200, message="删除成功", data=outcome)
|
|
|
+
|
|
|
+
|
|
|
@router.get("/knowledge-base/{kb_id}", response_model=ResponseModel)
def get_knowledge_base(kb_id: int, db: Session = Depends(get_db)):
    """Return one knowledge base by id.

    Rows whose ``is_deleted`` flag is not 0 are treated as missing.

    Raises:
        HTTPException: 404 when no matching row exists.
    """
    active_kbs = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.is_deleted == 0,
    )
    kb = active_kbs.first()
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")

    payload = KnowledgeBaseResponse.model_validate(kb).model_dump()
    return ResponseModel(code=200, message="查询成功", data=payload)
|
|
|
+
|
|
|
+
|
|
|
@router.get("/knowledge-base/", response_model=ResponseModel)
def list_knowledge_bases(pageNo: int = 1, pageSize: int = 10, name: Optional[str] = None,
                         db: Session = Depends(get_db)):
    """Return one page of knowledge bases plus the total count.

    Args:
        pageNo: 1-based page number.
        pageSize: Number of rows per page.
        name: Optional name filter, forwarded to
            ``DatabaseUtils.get_knowledge_bases``.

    Raises:
        HTTPException: 400 when ``pageNo`` or ``pageSize`` is below 1.
    """
    # The page number is validated before the page size.
    if pageNo < 1:
        raise HTTPException(status_code=400, detail="页码必须大于等于1")
    if pageSize < 1:
        raise HTTPException(status_code=400, detail="每页条数必须大于等于1")

    offset = pageSize * (pageNo - 1)
    rows, total = DatabaseUtils.get_knowledge_bases(db, offset, pageSize, name)

    items = []
    for row in rows:
        items.append(KnowledgeBaseResponse.model_validate(row).model_dump())
    return ResponseModel(code=200, message="查询成功", data={"list": items, "total": total})
|
|
|
+
|
|
|
+
|
|
|
@router.get("/knowledge-base/name/{name}", response_model=ResponseModel)
def get_knowledge_base_by_name(name: str, db: Session = Depends(get_db)):
    """Return the knowledge base found by ``DatabaseUtils.get_knowledge_base_by_name``.

    Raises:
        HTTPException: 404 when the lookup yields a falsy result.
    """
    match = DatabaseUtils.get_knowledge_base_by_name(db, name)
    if not match:
        raise HTTPException(status_code=404, detail="知识库不存在")

    payload = KnowledgeBaseResponse.model_validate(match).model_dump()
    return ResponseModel(code=200, message="查询成功", data=payload)
|
|
|
+
|
|
|
+
|
|
|
def _detect_knowledge_type(filename: str) -> Optional[str]:
    """Return the first of the keywords '指南' / '教材' found in ``filename``, else None."""
    for keyword in ('指南', '教材'):
        if keyword in filename:
            return keyword
    return None


def _convert_legacy_office_file(content: bytes, filename: str, target_format: str) -> Optional[bytes]:
    """Convert legacy Office bytes with ``FileUtils.convert_office_file``.

    The input is written into a temporary directory, converted there, and the
    converted bytes are read back before the directory is removed.

    Returns:
        The converted bytes, or ``None`` when the converter reports no output
        path or that path does not exist.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_input_path = os.path.join(temp_dir, filename)
        with open(temp_input_path, "wb") as temp_file:
            temp_file.write(content)

        converted_file_path = FileUtils.convert_office_file(temp_input_path, temp_dir, target_format)
        if converted_file_path and os.path.exists(converted_file_path):
            with open(converted_file_path, "rb") as converted_file:
                return converted_file.read()
    return None


@router.post("/knowledge-base/{kb_id}/files/", response_model=ResponseModel)
async def upload_files(
    kb_id: int,
    files: List[UploadFile] = File(...),
    db: Session = Depends(get_db)
):
    """Upload one or more files into a knowledge base.

    For each file, in order: ``.doc``/``.ppt`` inputs are converted to
    ``.docx``/``.pptx``; the extension and size are checked against
    ``settings.ALLOWED_EXTENSIONS``; the bytes are uploaded to MinIO; and a
    ``KnowledgeFile`` row is added. The session is committed once, after the
    loop.

    NOTE(review): conversion and the MinIO upload are plain synchronous calls
    inside this ``async`` handler.

    Raises:
        HTTPException: 404 if the knowledge base does not exist (or is
            flagged deleted); 400 for too many files, an empty or invalid
            file name, an unsupported extension, or an oversized file; 500
            when the Office conversion fails.
    """
    # Make sure the target knowledge base exists.
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")

    # Enforce the per-request file count limit.
    if len(files) > settings.MAX_FILE_COUNT:
        raise HTTPException(status_code=400, detail=f"单次上传文件数量不能超过{settings.MAX_FILE_COUNT}个")

    uploaded_files = []
    for file in files:
        # The client controls the file name: keep only the last path component
        # (both separator styles) so it can never escape the temp directory,
        # and reject missing or purely navigational names.
        original_filename = os.path.basename((file.filename or "").replace("\\", "/"))
        if original_filename in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="文件名不能为空")

        file_ext = os.path.splitext(original_filename)[1].lower().lstrip('.')
        stored_filename = original_filename

        file_content = await file.read()

        # Legacy Office formats are converted to their XML-based successors.
        if file_ext in ("doc", "ppt"):
            target_format = "docx" if file_ext == "doc" else "pptx"
            converted_content = _convert_legacy_office_file(file_content, original_filename, target_format)
            # An empty result counts as a failure too; otherwise the original
            # bytes would be stored under the converted name.
            if not converted_content:
                raise HTTPException(status_code=500, detail=f"文件格式转换失败: {original_filename}")

            file_content = converted_content
            file_ext = target_format
            stored_filename = os.path.splitext(original_filename)[0] + f".{target_format}"

        # Reject unsupported formats (checked after any conversion).
        if file_ext not in settings.ALLOWED_EXTENSIONS:
            raise HTTPException(status_code=400, detail=f"不支持的文件格式:{file_ext}")

        file_size = len(file_content) / (1024 * 1024)  # bytes -> MB

        # Enforce the per-extension size limit.
        max_size = settings.ALLOWED_EXTENSIONS[file_ext]["max_size"]
        if file_size > max_size:
            raise HTTPException(status_code=400, detail=f"{original_filename}超过最大允许大小{max_size}MB")

        minio_url = minio_utils.upload_file(file_content, stored_filename, file.content_type)

        db_file = KnowledgeFile(
            knowledge_base_id=kb_id,
            file_name=stored_filename,
            file_size=file_size,
            file_type=file_ext,
            minio_url=minio_url,
            creator=kb.creator,  # the file inherits the knowledge base's creator
            knowledge_type=_detect_knowledge_type(stored_filename)
        )
        db.add(db_file)
        uploaded_files.append(db_file)
        # Keep the knowledge base's file counter in step.
        DatabaseUtils.increment_file_count(db, kb_id)

    db.commit()
    return ResponseModel(
        code=200,
        message="上传成功",
        data=[KnowledgeFileResponse.model_validate(file).model_dump() for file in uploaded_files]
    )
|
|
|
+
|
|
|
+
|
|
|
@router.get("/knowledge-base/{kb_id}/files/", response_model=ResponseModel)
def list_files(kb_id: int, pageNo: int = 1, pageSize: int = 10, file_name: Optional[str] = None,
               db: Session = Depends(get_db)):
    """Return one page of a knowledge base's non-deleted files plus the total count.

    Args:
        kb_id: Knowledge base whose files are listed.
        pageNo: 1-based page number.
        pageSize: Number of rows per page.
        file_name: Optional case-insensitive substring filter on the file name.

    Raises:
        HTTPException: 400 when ``pageNo`` or ``pageSize`` is below 1.
    """
    # The page number is validated before the page size.
    if pageNo < 1:
        raise HTTPException(status_code=400, detail="页码必须大于等于1")
    if pageSize < 1:
        raise HTTPException(status_code=400, detail="每页条数必须大于等于1")

    conditions = [
        KnowledgeFile.knowledge_base_id == kb_id,
        KnowledgeFile.is_deleted == 0,
    ]
    if file_name:
        conditions.append(KnowledgeFile.file_name.ilike(f"%{file_name}%"))
    query = db.query(KnowledgeFile).filter(*conditions)

    # Count the whole filtered set before slicing out the requested page.
    total = query.count()
    offset = pageSize * (pageNo - 1)
    page = query.offset(offset).limit(pageSize).all()

    items = []
    for row in page:
        items.append(KnowledgeFileResponse.model_validate(row).model_dump())
    return ResponseModel(code=200, message="查询成功", data={"list": items, "total": total})
|
|
|
+
|
|
|
+
|
|
|
@router.get("/knowledge-base/{kb_id}/files/search/", response_model=ResponseModel)
def search_files(kb_id: int, file_name: str, db: Session = Depends(get_db)):
    """Return every non-deleted file of a knowledge base whose name contains ``file_name``.

    The match is a case-insensitive substring test and the result is not paginated.
    """
    pattern = f"%{file_name}%"
    matches = (
        db.query(KnowledgeFile)
        .filter(
            KnowledgeFile.knowledge_base_id == kb_id,
            KnowledgeFile.file_name.ilike(pattern),
            KnowledgeFile.is_deleted == 0,
        )
        .all()
    )

    payload = []
    for row in matches:
        payload.append(KnowledgeFileResponse.model_validate(row).model_dump())
    return ResponseModel(code=200, message="查询成功", data=payload)
|
|
|
+
|
|
|
+
|
|
|
@router.get("/files/{file_id}/download")
def download_file(file_id: int, db: Session = Depends(get_db)):
    """Stream a stored file back to the client as an attachment.

    The MinIO object name is taken to be the last ``/``-separated segment of
    the row's ``minio_url``. The file name is percent-encoded and sent through
    the ``filename*=UTF-8''`` form of ``Content-Disposition``.

    Raises:
        HTTPException: 404 when the file does not exist or is flagged deleted.
    """
    record = (
        db.query(KnowledgeFile)
        .filter(KnowledgeFile.id == file_id, KnowledgeFile.is_deleted == 0)
        .first()
    )
    if not record:
        raise HTTPException(status_code=404, detail="文件不存在")

    # Fetch the whole object from MinIO and wrap it in an in-memory stream.
    object_name = record.minio_url.rsplit("/", 1)[-1]
    body = io.BytesIO(minio_utils.download_file(object_name))

    # Percent-encode the name so non-ASCII characters survive the header.
    quoted_name = urllib.parse.quote(record.file_name)
    response_headers = {
        "Content-Disposition": f"attachment; filename*=UTF-8''{quoted_name}"
    }
    return StreamingResponse(body, media_type="application/octet-stream", headers=response_headers)
|
|
|
+
|
|
|
+
|
|
|
@router.delete("/files/{file_id}", response_model=ResponseModel)
def delete_file(file_id: int, db: Session = Depends(get_db)):
    """Delete a file: remove its MinIO object and flag the row as deleted.

    The response uses the same ``ResponseModel`` envelope as the other
    endpoints in this module; the JSON payload (``code``, ``message``,
    ``data``) is unchanged.

    Raises:
        HTTPException: 404 when the file does not exist or is already
            flagged deleted.
    """
    file = db.query(KnowledgeFile).filter(
        KnowledgeFile.id == file_id,
        KnowledgeFile.is_deleted == 0
    ).first()
    if not file:
        raise HTTPException(status_code=404, detail="文件不存在")

    # The object name is the last path segment of the stored URL.
    object_name = file.minio_url.split("/")[-1]
    minio_utils.delete_file(object_name)

    # Soft-delete the row and keep the knowledge base's file counter in step.
    file.is_deleted = 1
    file.updated_at = datetime.utcnow()
    DatabaseUtils.decrement_file_count(db, file.knowledge_base_id)
    db.commit()

    return ResponseModel(code=200, message="删除成功", data=True)
|
|
|
+
|
|
|
+
|
|
|
@router.put("/files/batch-update", response_model=ResponseModel)
def batch_update_files(update_data: BatchFileUpdate, db: Session = Depends(get_db)):
    """Apply metadata changes, and optional renames, to several files at once.

    Every requested id is resolved before anything is modified, so an unknown
    id rejects the whole batch without leaving earlier files half-renamed in
    MinIO. A rename downloads the object, uploads it under the new name,
    deletes the old object, and updates ``file_name`` / ``minio_url``. The
    other fields are only overwritten when the request supplies a non-``None``
    value. The session is committed once, after all files are processed.

    Raises:
        HTTPException: 404 when any requested file id does not exist or is
            flagged deleted.
    """
    import mimetypes

    # Pass 1: resolve every target first so a bad id fails before any MinIO change.
    targets = []
    for file_update in update_data.files:
        file = db.query(KnowledgeFile).filter(
            KnowledgeFile.id == file_update.id,
            KnowledgeFile.is_deleted == 0
        ).first()
        if not file:
            raise HTTPException(status_code=404, detail=f"文件ID {file_update.id} 不存在")
        targets.append((file_update, file))

    # Pass 2: apply the changes.
    optional_fields = ("version", "author", "year", "page_count", "creator", "knowledge_type")
    updated_files = []
    for file_update, file in targets:
        # A changed file name also requires moving the object in MinIO.
        if file_update.file_name and file_update.file_name != file.file_name:
            old_object_name = file.minio_url.split("/")[-1]
            file_content = minio_utils.download_file(old_object_name)

            # The upload endpoint passes a MIME content type as the third
            # argument, so derive one from the new name instead of passing the
            # bare extension stored in file_type.
            # NOTE(review): confirm against MinioUtils.upload_file's signature.
            content_type = mimetypes.guess_type(file_update.file_name)[0] or "application/octet-stream"
            new_minio_url = minio_utils.upload_file(
                file_content,
                file_update.file_name,
                content_type
            )

            # Remove the old object only after the new upload has returned.
            minio_utils.delete_file(old_object_name)

            file.file_name = file_update.file_name
            file.minio_url = new_minio_url

        # Copy across only the fields the request actually provided.
        for field_name in optional_fields:
            new_value = getattr(file_update, field_name)
            if new_value is not None:
                setattr(file, field_name, new_value)

        file.updated_at = datetime.utcnow()
        updated_files.append(file)

    db.commit()
    return ResponseModel(
        code=200,
        message="更新成功",
        data=[KnowledgeFileResponse.model_validate(file).model_dump() for file in updated_files]
    )
|
|
|
+
|
|
|
# Alias under which this module's router is exposed to importers.
knowledge_base_router = router
|