knowledge_base_router.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. import os
  2. import io
  3. import logging
  4. import urllib.parse
  5. import time
  6. import glob
  7. import shutil
  8. import subprocess
  9. from typing import List, Optional
  10. from datetime import datetime
  11. from fastapi import APIRouter, FastAPI, Depends, HTTPException, UploadFile, File, Form
  12. from fastapi.middleware.cors import CORSMiddleware
  13. from fastapi.responses import StreamingResponse
  14. from fastapi.staticfiles import StaticFiles
  15. from fastapi.openapi.docs import (
  16. get_redoc_html,
  17. get_swagger_ui_html,
  18. get_swagger_ui_oauth2_redirect_html,
  19. )
  20. from sqlalchemy import create_engine
  21. from sqlalchemy.orm import Session, sessionmaker
  22. from sqlalchemy.ext.declarative import declarative_base
  23. from pydantic import BaseModel, ConfigDict, Field, field_serializer
  24. from agent.models.web.knowledge_base import Base, KnowledgeBase, KnowledgeFile
  25. from agent.utils import DatabaseUtils, MinioUtils, FileUtils
  26. from config.site import settings
  27. # 响应模型
  28. class ResponseModel(BaseModel):
  29. code: int
  30. message: str
  31. data: Optional[dict | list | bool | None]
  32. class KnowledgeBaseResponse(BaseModel):
  33. model_config = ConfigDict(from_attributes=True)
  34. id: int
  35. name: str
  36. description: Optional[str] = None
  37. tags: Optional[str] = None
  38. creator: Optional[str] = None
  39. file_count: int = 0
  40. created_at: datetime = Field(default_factory=datetime.utcnow)
  41. updated_at: datetime = Field(default_factory=datetime.utcnow)
  42. @field_serializer('created_at', 'updated_at')
  43. def serialize_datetime(self, dt: datetime) -> str:
  44. return dt.strftime('%Y-%m-%d')
  45. class KnowledgeFileResponse(BaseModel):
  46. model_config = ConfigDict(from_attributes=True)
  47. id: int
  48. knowledge_base_id: int
  49. file_name: str
  50. file_size: float
  51. file_type: str
  52. minio_url: str
  53. version: Optional[str] = None
  54. author: Optional[str] = None
  55. year: Optional[int] = None
  56. page_count: Optional[int] = None
  57. creator: Optional[str] = None
  58. knowledge_type: Optional[str] = None
  59. created_at: datetime = Field(default_factory=datetime.utcnow)
  60. updated_at: datetime = Field(default_factory=datetime.utcnow)
  61. @field_serializer('created_at', 'updated_at')
  62. def serialize_datetime(self, dt: datetime) -> str:
  63. return dt.strftime('%Y-%m-%d %H:%M')
  64. # 配置日志
  65. logging.basicConfig(
  66. level=logging.INFO,
  67. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  68. )
  69. logger = logging.getLogger(__name__)
  70. # 创建数据库引擎
  71. engine = create_engine(settings.DATABASE_URL)
  72. SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
  73. # 创建数据库表
  74. Base.metadata.create_all(bind=engine)
  75. router = APIRouter(tags=["knowledge base interface"])
  76. # logger = logging.getLogger(__name__)
  77. # config = SiteConfig()
  78. # 初始化MinIO工具类
  79. minio_utils = MinioUtils()
  80. # 全局异常处理
  81. # @router.exception_handler(Exception)
  82. # async def global_exception_handler(request, exc):
  83. # logger.error(f"全局异常: {exc}", exc_info=True)
  84. # return {
  85. # "code": 500,
  86. # "message": "服务器内部错误",
  87. # "data": None
  88. # }
  89. # 依赖项:获取数据库会话
  90. def get_db():
  91. db = SessionLocal()
  92. try:
  93. yield db
  94. finally:
  95. db.close()
  96. # 请求模型
  97. class KnowledgeBaseCreate(BaseModel):
  98. name: str
  99. description: Optional[str] = None
  100. tags: Optional[str] = None
  101. class KnowledgeBaseUpdate(BaseModel):
  102. name: str
  103. description: Optional[str] = None
  104. tags: Optional[str] = None
  105. class FileUpdate(BaseModel):
  106. id: int
  107. file_name: Optional[str] = None
  108. version: Optional[str] = None
  109. author: Optional[str] = None
  110. year: Optional[int] = None
  111. page_count: Optional[int] = None
  112. creator: Optional[str] = None
  113. knowledge_type: Optional[str] = None
  114. class BatchFileUpdate(BaseModel):
  115. files: List[FileUpdate]
  116. # 使用utils.py中的FileUtils类进行文件转换
  117. @router.post("/knowledge-base/", response_model=ResponseModel)
  118. def create_knowledge_base(kb: KnowledgeBaseCreate, db: Session = Depends(get_db)):
  119. kb_data = DatabaseUtils.create_knowledge_base(db, kb.name, kb.description, kb.tags)
  120. return ResponseModel(
  121. code=200,
  122. message="创建成功",
  123. data=KnowledgeBaseResponse.model_validate(kb_data).model_dump()
  124. )
  125. @router.put("/knowledge-base/{kb_id}", response_model=ResponseModel)
  126. def update_knowledge_base(kb_id: int, kb: KnowledgeBaseUpdate, db: Session = Depends(get_db)):
  127. kb_data = DatabaseUtils.update_knowledge_base(db, kb_id, kb.name, kb.description, kb.tags)
  128. return ResponseModel(
  129. code=200,
  130. message="更新成功",
  131. data=KnowledgeBaseResponse.model_validate(kb_data).model_dump()
  132. )
  133. @router.delete("/knowledge-base/{kb_id}", response_model=ResponseModel)
  134. def delete_knowledge_base(kb_id: int, db: Session = Depends(get_db)):
  135. result = DatabaseUtils.delete_knowledge_base(db, kb_id)
  136. return ResponseModel(
  137. code=200,
  138. message="删除成功",
  139. data=result
  140. )
  141. @router.get("/knowledge-base/{kb_id}", response_model=ResponseModel)
  142. def get_knowledge_base(kb_id: int, db: Session = Depends(get_db)):
  143. kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
  144. if not kb:
  145. raise HTTPException(status_code=404, detail="知识库不存在")
  146. return ResponseModel(
  147. code=200,
  148. message="查询成功",
  149. data=KnowledgeBaseResponse.model_validate(kb).model_dump()
  150. )
  151. @router.get("/knowledge-base/", response_model=ResponseModel)
  152. def list_knowledge_bases(pageNo: int = 1, pageSize: int = 10, name: Optional[str] = None,
  153. db: Session = Depends(get_db)):
  154. if pageNo < 1:
  155. raise HTTPException(status_code=400, detail="页码必须大于等于1")
  156. if pageSize < 1:
  157. raise HTTPException(status_code=400, detail="每页条数必须大于等于1")
  158. skip = (pageNo - 1) * pageSize
  159. kb_list, total = DatabaseUtils.get_knowledge_bases(db, skip, pageSize, name)
  160. return ResponseModel(
  161. code=200,
  162. message="查询成功",
  163. data={
  164. "list": [KnowledgeBaseResponse.model_validate(kb).model_dump() for kb in kb_list],
  165. "total": total
  166. }
  167. )
  168. @router.get("/knowledge-base/name/{name}", response_model=ResponseModel)
  169. def get_knowledge_base_by_name(name: str, db: Session = Depends(get_db)):
  170. kb = DatabaseUtils.get_knowledge_base_by_name(db, name)
  171. if not kb:
  172. raise HTTPException(status_code=404, detail="知识库不存在")
  173. return ResponseModel(
  174. code=200,
  175. message="查询成功",
  176. data=KnowledgeBaseResponse.model_validate(kb).model_dump()
  177. )
  178. @router.post("/knowledge-base/{kb_id}/files/", response_model=ResponseModel)
  179. async def upload_files(
  180. kb_id: int,
  181. files: List[UploadFile] = File(...),
  182. db: Session = Depends(get_db)
  183. ):
  184. # 验证知识库是否存在
  185. kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
  186. if not kb:
  187. raise HTTPException(status_code=404, detail="知识库不存在")
  188. # 验证文件数量
  189. if len(files) > settings.MAX_FILE_COUNT:
  190. raise HTTPException(status_code=400, detail=f"单次上传文件数量不能超过{settings.MAX_FILE_COUNT}个")
  191. # 导入所需模块
  192. import tempfile
  193. uploaded_files = []
  194. for file in files:
  195. # 获取文件扩展名
  196. file_ext = os.path.splitext(file.filename)[1].lower().lstrip('.')
  197. original_filename = file.filename
  198. converted_content = None
  199. # 读取文件内容
  200. content = await file.read()
  201. # 处理需要转换的文件格式
  202. if file_ext in ["doc", "ppt"]:
  203. # 创建临时目录用于文件转换
  204. with tempfile.TemporaryDirectory() as temp_dir:
  205. # 创建临时文件
  206. temp_input_path = os.path.join(temp_dir, original_filename)
  207. with open(temp_input_path, "wb") as temp_file:
  208. temp_file.write(content)
  209. # 确定目标格式
  210. target_format = "docx" if file_ext == "doc" else "pptx"
  211. # 使用FileUtils中的文件转换方法
  212. converted_file_path = FileUtils.convert_office_file(temp_input_path, temp_dir, target_format)
  213. if converted_file_path and os.path.exists(converted_file_path):
  214. # 读取转换后的文件内容
  215. with open(converted_file_path, "rb") as converted_file:
  216. converted_content = converted_file.read()
  217. # 更新文件名和扩展名
  218. file_ext = target_format
  219. file.filename = os.path.splitext(original_filename)[0] + f".{target_format}"
  220. else:
  221. # 转换失败,使用原始文件
  222. raise HTTPException(status_code=500, detail=f"文件格式转换失败: {original_filename}")
  223. # 检查文件格式是否支持
  224. if file_ext not in settings.ALLOWED_EXTENSIONS:
  225. raise HTTPException(status_code=400, detail=f"不支持的文件格式:{file_ext}")
  226. # 使用转换后的内容或原始内容
  227. file_content = converted_content if converted_content else content
  228. file_size = len(file_content) / (1024 * 1024) # 转换为MB
  229. # 验证文件大小
  230. max_size = settings.ALLOWED_EXTENSIONS[file_ext]["max_size"]
  231. if file_size > max_size:
  232. raise HTTPException(status_code=400, detail=f"{original_filename}超过最大允许大小{max_size}MB")
  233. # 上传到MinIO
  234. minio_url = minio_utils.upload_file(file_content, file.filename, file.content_type)
  235. # 从文件名识别知识类型
  236. knowledge_type = None
  237. if '指南' in file.filename:
  238. knowledge_type = '指南'
  239. elif '教材' in file.filename:
  240. knowledge_type = '教材'
  241. # 创建文件记录
  242. db_file = KnowledgeFile(
  243. knowledge_base_id=kb_id,
  244. file_name=file.filename,
  245. file_size=file_size,
  246. file_type=file_ext,
  247. minio_url=minio_url,
  248. creator=kb.creator, # 使用知识库的创建人作为文件创建人
  249. knowledge_type=knowledge_type
  250. )
  251. db.add(db_file)
  252. uploaded_files.append(db_file)
  253. # 更新知识库文件计数
  254. DatabaseUtils.increment_file_count(db, kb_id)
  255. db.commit()
  256. return ResponseModel(
  257. code=200,
  258. message="上传成功",
  259. data=[KnowledgeFileResponse.model_validate(file).model_dump() for file in uploaded_files]
  260. )
  261. @router.get("/knowledge-base/{kb_id}/files/", response_model=ResponseModel)
  262. def list_files(kb_id: int, pageNo: int = 1, pageSize: int = 10, file_name: Optional[str] = None,
  263. db: Session = Depends(get_db)):
  264. if pageNo < 1:
  265. raise HTTPException(status_code=400, detail="页码必须大于等于1")
  266. if pageSize < 1:
  267. raise HTTPException(status_code=400, detail="每页条数必须大于等于1")
  268. skip = (pageNo - 1) * pageSize
  269. query = db.query(KnowledgeFile).filter(
  270. KnowledgeFile.knowledge_base_id == kb_id,
  271. KnowledgeFile.is_deleted == 0
  272. )
  273. if file_name:
  274. query = query.filter(KnowledgeFile.file_name.ilike(f"%{file_name}%"))
  275. total = query.count()
  276. files = query.offset(skip).limit(pageSize).all()
  277. return ResponseModel(
  278. code=200,
  279. message="查询成功",
  280. data={
  281. "list": [KnowledgeFileResponse.model_validate(file).model_dump() for file in files],
  282. "total": total
  283. }
  284. )
  285. @router.get("/knowledge-base/{kb_id}/files/search/", response_model=ResponseModel)
  286. def search_files(kb_id: int, file_name: str, db: Session = Depends(get_db)):
  287. files = db.query(KnowledgeFile).filter(
  288. KnowledgeFile.knowledge_base_id == kb_id,
  289. KnowledgeFile.file_name.ilike(f"%{file_name}%"),
  290. KnowledgeFile.is_deleted == 0
  291. ).all()
  292. return ResponseModel(
  293. code=200,
  294. message="查询成功",
  295. data=[KnowledgeFileResponse.model_validate(file).model_dump() for file in files]
  296. )
  297. @router.get("/files/{file_id}/download")
  298. def download_file(file_id: int, db: Session = Depends(get_db)):
  299. # 获取文件信息
  300. file = db.query(KnowledgeFile).filter(
  301. KnowledgeFile.id == file_id,
  302. KnowledgeFile.is_deleted == 0
  303. ).first()
  304. if not file:
  305. raise HTTPException(status_code=404, detail="文件不存在")
  306. # 从MinIO下载文件
  307. object_name = file.minio_url.split("/")[-1]
  308. file_content = minio_utils.download_file(object_name)
  309. # 创建文件流
  310. file_stream = io.BytesIO(file_content)
  311. # 对文件名进行URL编码
  312. encoded_filename = urllib.parse.quote(file.file_name)
  313. return StreamingResponse(
  314. file_stream,
  315. media_type="application/octet-stream",
  316. headers={
  317. "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
  318. }
  319. )
  320. @router.delete("/files/{file_id}", response_model=dict)
  321. def delete_file(file_id: int, db: Session = Depends(get_db)):
  322. # 获取文件信息
  323. file = db.query(KnowledgeFile).filter(
  324. KnowledgeFile.id == file_id,
  325. KnowledgeFile.is_deleted == 0
  326. ).first()
  327. if not file:
  328. raise HTTPException(status_code=404, detail="文件不存在")
  329. # 从MinIO删除文件
  330. object_name = file.minio_url.split("/")[-1]
  331. minio_utils.delete_file(object_name)
  332. # 标记文件为已删除
  333. file.is_deleted = 1
  334. file.updated_at = datetime.utcnow()
  335. # 更新知识库文件计数
  336. DatabaseUtils.decrement_file_count(db, file.knowledge_base_id)
  337. db.commit()
  338. return {
  339. "code": 200,
  340. "message": "删除成功",
  341. "data": True
  342. }
  343. @router.put("/files/batch-update", response_model=ResponseModel)
  344. def batch_update_files(update_data: BatchFileUpdate, db: Session = Depends(get_db)):
  345. updated_files = []
  346. for file_update in update_data.files:
  347. # 获取文件信息
  348. file = db.query(KnowledgeFile).filter(
  349. KnowledgeFile.id == file_update.id,
  350. KnowledgeFile.is_deleted == 0
  351. ).first()
  352. if not file:
  353. raise HTTPException(status_code=404, detail=f"文件ID {file_update.id} 不存在")
  354. # 如果需要更新文件名,同时更新MinIO中的文件
  355. if file_update.file_name and file_update.file_name != file.file_name:
  356. old_object_name = file.minio_url.split("/")[-1]
  357. new_object_name = f"{datetime.utcnow().strftime('%Y%m%d%H%M%S')}_{file_update.file_name}"
  358. # 从MinIO下载文件
  359. file_content = minio_utils.download_file(old_object_name)
  360. # 上传到MinIO新的位置
  361. new_minio_url = minio_utils.upload_file(
  362. file_content,
  363. file_update.file_name,
  364. file.file_type
  365. )
  366. # 删除旧文件
  367. minio_utils.delete_file(old_object_name)
  368. # 更新数据库中的文件名和MinIO URL
  369. file.file_name = file_update.file_name
  370. file.minio_url = new_minio_url
  371. # 更新其他字段
  372. if file_update.version is not None:
  373. file.version = file_update.version
  374. if file_update.author is not None:
  375. file.author = file_update.author
  376. if file_update.year is not None:
  377. file.year = file_update.year
  378. if file_update.page_count is not None:
  379. file.page_count = file_update.page_count
  380. if file_update.creator is not None:
  381. file.creator = file_update.creator
  382. if file_update.knowledge_type is not None:
  383. file.knowledge_type = file_update.knowledge_type
  384. file.updated_at = datetime.utcnow()
  385. updated_files.append(file)
  386. db.commit()
  387. return ResponseModel(
  388. code=200,
  389. message="更新成功",
  390. data=[KnowledgeFileResponse.model_validate(file).model_dump() for file in updated_files]
  391. )
  392. knowledge_base_router = router