123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481 |
import glob
import io
import logging
import mimetypes
import os
import shutil
import subprocess
import tempfile
import time
import urllib.parse
from datetime import datetime
from typing import List, Optional

from fastapi import APIRouter, Depends, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import (
    get_redoc_html,
    get_swagger_ui_html,
    get_swagger_ui_oauth2_redirect_html,
)
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker

from agent.models.web.knowledge_base import Base, KnowledgeBase, KnowledgeFile
from agent.utils import DatabaseUtils, FileUtils, MinioUtils
from config.site import settings
# Response models
class ResponseModel(BaseModel):
    """Uniform envelope returned by every endpoint: status code, message, payload."""

    code: int
    message: str
    # Required (no default), but the client may receive an explicit null.
    data: dict | list | bool | None
class KnowledgeBaseResponse(BaseModel):
    """Serialized view of a knowledge base, built from an ORM object.

    Timestamps are rendered as date-only strings (``YYYY-MM-DD``).
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    description: str | None = None
    tags: str | None = None
    creator: str | None = None
    file_count: int = 0
    # The factories only apply when the input does not provide the field at all.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    @field_serializer('created_at', 'updated_at')
    def serialize_datetime(self, dt: datetime) -> str:
        """Format a timestamp as a date-only string."""
        return f"{dt:%Y-%m-%d}"
class KnowledgeFileResponse(BaseModel):
    """Serialized view of a knowledge file, built from an ORM object.

    Timestamps are rendered with minute precision (``YYYY-MM-DD HH:MM``).
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    knowledge_base_id: int
    file_name: str
    file_size: float  # megabytes, as computed by the upload endpoint
    file_type: str
    minio_url: str
    version: str | None = None
    author: str | None = None
    year: int | None = None
    page_count: int | None = None
    creator: str | None = None
    knowledge_type: str | None = None
    # The factories only apply when the input does not provide the field at all.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    @field_serializer('created_at', 'updated_at')
    def serialize_datetime(self, dt: datetime) -> str:
        """Format a timestamp down to the minute."""
        return f"{dt:%Y-%m-%d %H:%M}"
# Logging configuration (applied once, at import time)
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)
# Database engine and session factory built from the configured DATABASE_URL.
engine = create_engine(settings.DATABASE_URL)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Create the tables declared on Base that are not present yet (runs at import time).
Base.metadata.create_all(bind=engine)

router = APIRouter(tags=["knowledge base interface"])

# Shared MinIO helper used by the upload / download / delete / rename endpoints.
minio_utils = MinioUtils()

# NOTE: APIRouter has no ``exception_handler`` hook; a global exception handler
# must be registered on the FastAPI application object instead.
# Dependency: provides a database session for one request
def get_db():
    """Yield a SQLAlchemy session and close it once the request is done.

    The ``finally`` guarantees the session is released even if the endpoint raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Request models
class KnowledgeBaseCreate(BaseModel):
    """Request body for creating a knowledge base."""

    name: str
    description: str | None = None
    tags: str | None = None
class KnowledgeBaseUpdate(BaseModel):
    """Request body for updating a knowledge base; ``name`` is always required."""

    name: str
    description: str | None = None
    tags: str | None = None
class FileUpdate(BaseModel):
    """One entry of a batch file update.

    Only ``id`` is required; a field left as None is not changed on the stored file.
    """

    id: int
    file_name: str | None = None
    version: str | None = None
    author: str | None = None
    year: int | None = None
    page_count: int | None = None
    creator: str | None = None
    knowledge_type: str | None = None
class BatchFileUpdate(BaseModel):
    """Request body for ``PUT /files/batch-update``: a list of per-file updates."""

    files: list[FileUpdate]
@router.post("/knowledge-base/", response_model=ResponseModel)
def create_knowledge_base(kb: KnowledgeBaseCreate, db: Session = Depends(get_db)):
    """Create a knowledge base and return it in the response envelope.

    Persistence is delegated to ``DatabaseUtils.create_knowledge_base``.
    """
    created = DatabaseUtils.create_knowledge_base(
        db, kb.name, kb.description, kb.tags
    )
    payload = KnowledgeBaseResponse.model_validate(created).model_dump()
    return ResponseModel(code=200, message="创建成功", data=payload)
@router.put("/knowledge-base/{kb_id}", response_model=ResponseModel)
def update_knowledge_base(kb_id: int, kb: KnowledgeBaseUpdate, db: Session = Depends(get_db)):
    """Update a knowledge base's name, description and tags.

    Raises:
        HTTPException(404): if no knowledge base is returned for ``kb_id``.
    """
    kb_data = DatabaseUtils.update_knowledge_base(db, kb_id, kb.name, kb.description, kb.tags)
    # NOTE(review): DatabaseUtils' contract for an unknown id isn't visible here.
    # If it returns a falsy value, model_validate(None) would otherwise surface
    # as a 500; answer with the same 404 as get_knowledge_base instead.
    if not kb_data:
        raise HTTPException(status_code=404, detail="知识库不存在")
    return ResponseModel(
        code=200,
        message="更新成功",
        data=KnowledgeBaseResponse.model_validate(kb_data).model_dump()
    )
@router.delete("/knowledge-base/{kb_id}", response_model=ResponseModel)
def delete_knowledge_base(kb_id: int, db: Session = Depends(get_db)):
    """Delete a knowledge base via ``DatabaseUtils`` and echo its result as ``data``."""
    outcome = DatabaseUtils.delete_knowledge_base(db, kb_id)
    return ResponseModel(code=200, message="删除成功", data=outcome)
@router.get("/knowledge-base/{kb_id}", response_model=ResponseModel)
def get_knowledge_base(kb_id: int, db: Session = Depends(get_db)):
    """Fetch a single knowledge base that has not been soft-deleted.

    Raises:
        HTTPException(404): if no live knowledge base has this id.
    """
    live_kb = (
        db.query(KnowledgeBase)
        .filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0)
        .first()
    )
    if not live_kb:
        raise HTTPException(status_code=404, detail="知识库不存在")
    payload = KnowledgeBaseResponse.model_validate(live_kb).model_dump()
    return ResponseModel(code=200, message="查询成功", data=payload)
@router.get("/knowledge-base/", response_model=ResponseModel)
def list_knowledge_bases(pageNo: int = 1, pageSize: int = 10, name: Optional[str] = None,
                         db: Session = Depends(get_db)):
    """Return one page of knowledge bases, optionally filtered by ``name``.

    ``pageNo`` is 1-based. Responds with ``{"list": [...], "total": N}``.

    Raises:
        HTTPException(400): if ``pageNo`` or ``pageSize`` is below 1.
    """
    if pageNo < 1:
        raise HTTPException(status_code=400, detail="页码必须大于等于1")
    if pageSize < 1:
        raise HTTPException(status_code=400, detail="每页条数必须大于等于1")

    offset = (pageNo - 1) * pageSize
    records, total = DatabaseUtils.get_knowledge_bases(db, offset, pageSize, name)
    serialized = [
        KnowledgeBaseResponse.model_validate(record).model_dump()
        for record in records
    ]
    return ResponseModel(
        code=200,
        message="查询成功",
        data={"list": serialized, "total": total},
    )
@router.get("/knowledge-base/name/{name}", response_model=ResponseModel)
def get_knowledge_base_by_name(name: str, db: Session = Depends(get_db)):
    """Look up a knowledge base by name through ``DatabaseUtils``.

    Raises:
        HTTPException(404): if the lookup returns nothing.
    """
    found = DatabaseUtils.get_knowledge_base_by_name(db, name)
    if not found:
        raise HTTPException(status_code=404, detail="知识库不存在")
    payload = KnowledgeBaseResponse.model_validate(found).model_dump()
    return ResponseModel(code=200, message="查询成功", data=payload)
# MIME types of the OOXML formats produced when legacy .doc/.ppt uploads are converted.
_CONVERTED_CONTENT_TYPES = {
    "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
}


def _convert_legacy_office(content: bytes, file_ext: str, original_filename: str) -> tuple[bytes, str]:
    """Convert legacy ``doc``/``ppt`` bytes to ``docx``/``pptx`` via FileUtils.

    Returns ``(converted_bytes, target_ext)``.

    Raises:
        HTTPException(500): if FileUtils.convert_office_file yields no output file.
    """
    target_format = "docx" if file_ext == "doc" else "pptx"
    with tempfile.TemporaryDirectory() as temp_dir:
        # Use a fixed name inside the private temp dir. The client-supplied
        # filename may be absolute or contain "../", and os.path.join would then
        # write outside temp_dir.
        temp_input_path = os.path.join(temp_dir, f"source.{file_ext}")
        with open(temp_input_path, "wb") as temp_file:
            temp_file.write(content)
        converted_file_path = FileUtils.convert_office_file(temp_input_path, temp_dir, target_format)
        if not converted_file_path or not os.path.exists(converted_file_path):
            # Conversion failed: reject the upload rather than storing the legacy file.
            raise HTTPException(status_code=500, detail=f"文件格式转换失败: {original_filename}")
        with open(converted_file_path, "rb") as converted_file:
            return converted_file.read(), target_format


def _detect_knowledge_type(file_name: str) -> Optional[str]:
    """Infer the knowledge type from a keyword in the file name ('指南' is checked first)."""
    if '指南' in file_name:
        return '指南'
    if '教材' in file_name:
        return '教材'
    return None


@router.post("/knowledge-base/{kb_id}/files/", response_model=ResponseModel)
async def upload_files(
    kb_id: int,
    files: List[UploadFile] = File(...),
    db: Session = Depends(get_db)
):
    """Upload files into a knowledge base, storing bytes in MinIO and rows in the DB.

    Legacy ``doc``/``ppt`` files are converted to ``docx``/``pptx`` first. Each
    file row takes the knowledge base's creator as its creator, and the knowledge
    base file count is incremented per file.

    Raises:
        HTTPException(404): the knowledge base does not exist or is soft-deleted.
        HTTPException(400): too many files, a missing filename, an unsupported
            extension, or a file over its size limit.
        HTTPException(500): a legacy Office file could not be converted.
    """
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")
    if len(files) > settings.MAX_FILE_COUNT:
        raise HTTPException(status_code=400, detail=f"单次上传文件数量不能超过{settings.MAX_FILE_COUNT}个")

    # Pass 1: read, convert and validate every file before anything is written
    # to MinIO, so a bad file later in the batch cannot leave orphaned objects.
    # Trade-off: all payloads (bounded by MAX_FILE_COUNT * max_size) sit in memory.
    prepared = []
    for upload in files:
        original_filename = upload.filename
        if not original_filename:
            raise HTTPException(status_code=400, detail="文件名不能为空")
        file_name = original_filename
        file_ext = os.path.splitext(original_filename)[1].lower().lstrip('.')
        content_type = upload.content_type
        content = await upload.read()

        if file_ext in ("doc", "ppt"):
            content, file_ext = _convert_legacy_office(content, file_ext, original_filename)
            file_name = os.path.splitext(original_filename)[0] + f".{file_ext}"
            # The client's content type describes the legacy format, not the converted one.
            content_type = _CONVERTED_CONTENT_TYPES[file_ext]

        if file_ext not in settings.ALLOWED_EXTENSIONS:
            raise HTTPException(status_code=400, detail=f"不支持的文件格式:{file_ext}")
        file_size = len(content) / (1024 * 1024)  # megabytes
        max_size = settings.ALLOWED_EXTENSIONS[file_ext]["max_size"]
        if file_size > max_size:
            raise HTTPException(status_code=400, detail=f"{original_filename}超过最大允许大小{max_size}MB")
        prepared.append((file_name, file_ext, content, content_type, file_size))

    # Pass 2: store each file in MinIO and stage its DB row.
    uploaded_files = []
    for file_name, file_ext, content, content_type, file_size in prepared:
        minio_url = minio_utils.upload_file(content, file_name, content_type)
        db_file = KnowledgeFile(
            knowledge_base_id=kb_id,
            file_name=file_name,
            file_size=file_size,
            file_type=file_ext,
            minio_url=minio_url,
            creator=kb.creator,  # the knowledge base's creator owns its files
            knowledge_type=_detect_knowledge_type(file_name)
        )
        db.add(db_file)
        uploaded_files.append(db_file)
        DatabaseUtils.increment_file_count(db, kb_id)
    db.commit()

    return ResponseModel(
        code=200,
        message="上传成功",
        data=[KnowledgeFileResponse.model_validate(f).model_dump() for f in uploaded_files]
    )
@router.get("/knowledge-base/{kb_id}/files/", response_model=ResponseModel)
def list_files(kb_id: int, pageNo: int = 1, pageSize: int = 10, file_name: Optional[str] = None,
               db: Session = Depends(get_db)):
    """Return one page of a knowledge base's live files, optionally filtered by name.

    ``pageNo`` is 1-based. ``file_name`` is a case-insensitive substring match,
    and ``%``/``_`` in it are matched literally. Responds with
    ``{"list": [...], "total": N}``.

    Raises:
        HTTPException(400): if ``pageNo`` or ``pageSize`` is below 1.
    """
    if pageNo < 1:
        raise HTTPException(status_code=400, detail="页码必须大于等于1")
    if pageSize < 1:
        raise HTTPException(status_code=400, detail="每页条数必须大于等于1")
    skip = (pageNo - 1) * pageSize

    query = db.query(KnowledgeFile).filter(
        KnowledgeFile.knowledge_base_id == kb_id,
        KnowledgeFile.is_deleted == 0
    )
    if file_name:
        # Escape LIKE metacharacters so user input is matched literally.
        escaped = file_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        query = query.filter(KnowledgeFile.file_name.ilike(f"%{escaped}%", escape="\\"))

    total = query.count()
    # OFFSET/LIMIT without ORDER BY gives no stable page boundaries; order by id.
    files = query.order_by(KnowledgeFile.id).offset(skip).limit(pageSize).all()
    return ResponseModel(
        code=200,
        message="查询成功",
        data={
            "list": [KnowledgeFileResponse.model_validate(f).model_dump() for f in files],
            "total": total
        }
    )
@router.get("/knowledge-base/{kb_id}/files/search/", response_model=ResponseModel)
def search_files(kb_id: int, file_name: str, db: Session = Depends(get_db)):
    """Return every live file of a knowledge base whose name contains ``file_name``.

    The match is case-insensitive, and ``%``/``_`` in ``file_name`` are matched
    literally. Results are ordered by id.
    """
    # Escape LIKE metacharacters so user input is matched literally.
    escaped = file_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    files = db.query(KnowledgeFile).filter(
        KnowledgeFile.knowledge_base_id == kb_id,
        KnowledgeFile.file_name.ilike(f"%{escaped}%", escape="\\"),
        KnowledgeFile.is_deleted == 0
    ).order_by(KnowledgeFile.id).all()
    return ResponseModel(
        code=200,
        message="查询成功",
        data=[KnowledgeFileResponse.model_validate(f).model_dump() for f in files]
    )
@router.get("/files/{file_id}/download")
def download_file(file_id: int, db: Session = Depends(get_db)):
    """Send a stored file back to the client as an attachment.

    Raises:
        HTTPException(404): if the file does not exist or is soft-deleted.
    """
    file = db.query(KnowledgeFile).filter(
        KnowledgeFile.id == file_id,
        KnowledgeFile.is_deleted == 0
    ).first()
    if not file:
        raise HTTPException(status_code=404, detail="文件不存在")

    # The object name is taken as the last path segment of minio_url, the same
    # convention used by the delete/rename endpoints. TODO confirm against MinioUtils.
    object_name = file.minio_url.split("/")[-1]
    file_content = minio_utils.download_file(object_name)

    # RFC 5987 ext-value: percent-encode everything outside the unreserved set,
    # including "/", which quote() leaves as-is by default.
    encoded_filename = urllib.parse.quote(file.file_name, safe="")
    return StreamingResponse(
        io.BytesIO(file_content),
        media_type="application/octet-stream",
        headers={
            "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
        }
    )
@router.delete("/files/{file_id}", response_model=ResponseModel)
def delete_file(file_id: int, db: Session = Depends(get_db)):
    """Soft-delete a file row, decrement its knowledge base's file count, then remove the MinIO object.

    The DB change is committed first and the object removal is best-effort.
    If the commit failed after the object was already gone, the row would still
    be listed but undownloadable. An object that is already missing in MinIO
    must also not make the row undeletable. A MinIO failure is logged and leaves
    an orphaned object instead.

    Raises:
        HTTPException(404): if the file does not exist or is already deleted.
    """
    file = db.query(KnowledgeFile).filter(
        KnowledgeFile.id == file_id,
        KnowledgeFile.is_deleted == 0
    ).first()
    if not file:
        raise HTTPException(status_code=404, detail="文件不存在")

    object_name = file.minio_url.split("/")[-1]

    file.is_deleted = 1
    file.updated_at = datetime.utcnow()
    DatabaseUtils.decrement_file_count(db, file.knowledge_base_id)
    db.commit()

    try:
        minio_utils.delete_file(object_name)
    except Exception:
        # Best-effort cleanup; the row is already soft-deleted.
        logger.exception("Failed to delete MinIO object %s for file %s", object_name, file_id)

    return ResponseModel(code=200, message="删除成功", data=True)
# Metadata columns copied from a FileUpdate when the client supplies a non-None value.
_FILE_METADATA_FIELDS = ("version", "author", "year", "page_count", "creator", "knowledge_type")


def _discard_minio_objects(object_names: List[str]) -> None:
    """Best-effort removal of MinIO objects; failures are logged, never raised."""
    for object_name in object_names:
        try:
            minio_utils.delete_file(object_name)
        except Exception:
            logger.exception("Failed to delete MinIO object %s", object_name)


@router.put("/files/batch-update", response_model=ResponseModel)
def batch_update_files(update_data: BatchFileUpdate, db: Session = Depends(get_db)):
    """Apply metadata updates (and optional renames) to several files at once.

    A rename copies the object in MinIO under the new name. The old object is
    removed only after the DB commit succeeds. If anything fails first, the DB
    is rolled back and the newly created objects are discarded, so rows never
    point at deleted objects.

    Raises:
        HTTPException(404): if any requested id is missing or soft-deleted
            (checked before any MinIO change).
    """
    requested_ids = [item.id for item in update_data.files]
    rows = db.query(KnowledgeFile).filter(
        KnowledgeFile.id.in_(requested_ids),
        KnowledgeFile.is_deleted == 0
    ).all() if requested_ids else []
    files_by_id = {row.id: row for row in rows}
    for file_update in update_data.files:
        if file_update.id not in files_by_id:
            raise HTTPException(status_code=404, detail=f"文件ID {file_update.id} 不存在")

    created_objects = []   # new objects to discard if the batch fails
    replaced_objects = []  # old objects to discard once the batch is committed
    updated_files = []
    try:
        for file_update in update_data.files:
            file = files_by_id[file_update.id]
            if file_update.file_name and file_update.file_name != file.file_name:
                old_object_name = file.minio_url.split("/")[-1]
                file_content = minio_utils.download_file(old_object_name)
                # upload_file takes a MIME type (see upload_files), not the stored extension.
                content_type = mimetypes.guess_type(file_update.file_name)[0] or "application/octet-stream"
                new_minio_url = minio_utils.upload_file(file_content, file_update.file_name, content_type)
                created_objects.append(new_minio_url.split("/")[-1])
                replaced_objects.append(old_object_name)
                # NOTE(review): file_type is not updated when the new name has a
                # different extension. Confirm whether renames may change the type.
                file.file_name = file_update.file_name
                file.minio_url = new_minio_url

            for field in _FILE_METADATA_FIELDS:
                value = getattr(file_update, field)
                if value is not None:
                    setattr(file, field, value)
            file.updated_at = datetime.utcnow()
            updated_files.append(file)
        db.commit()
    except Exception:
        db.rollback()
        _discard_minio_objects(created_objects)
        raise

    _discard_minio_objects(replaced_objects)
    return ResponseModel(
        code=200,
        message="更新成功",
        data=[KnowledgeFileResponse.model_validate(f).model_dump() for f in updated_files]
    )
# Public alias for this module's router, presumably mounted by the app via include_router (verify).
knowledge_base_router: APIRouter = router
|