knowledge_base_router.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. import os
  2. import io
  3. import logging
  4. import urllib.parse
  5. import time
  6. import glob
  7. import shutil
  8. import subprocess
  9. from typing import List, Optional
  10. from datetime import datetime
  11. from fastapi import APIRouter, FastAPI, Depends, HTTPException, UploadFile, File, Form
  12. from fastapi.middleware.cors import CORSMiddleware
  13. from fastapi.responses import StreamingResponse
  14. from fastapi.staticfiles import StaticFiles
  15. from fastapi.openapi.docs import (
  16. get_redoc_html,
  17. get_swagger_ui_html,
  18. get_swagger_ui_oauth2_redirect_html,
  19. )
  20. from sqlalchemy import create_engine, and_
  21. from sqlalchemy.orm import Session, sessionmaker
  22. from sqlalchemy.ext.declarative import declarative_base
  23. from agent.models.db.graph import DbUserDataRelation as UserDataRelation
  24. from pydantic import BaseModel, ConfigDict, Field, field_serializer
  25. from agent.libs.auth import SessionValues, verify_session_id
  26. from agent.libs.user_data_relation import UserDataRelationBusiness
  27. from agent.models.web.knowledge_base import Base, KnowledgeBase, KnowledgeFile
  28. from agent.utils import DatabaseUtils, MinioUtils, FileUtils
  29. from config.site import settings
  # Response models
  class ResponseModel(BaseModel):
      """Uniform envelope returned by the endpoints of this router.

      ``data`` has no default value, so it must always be supplied
      explicitly; ``None`` is an accepted value.
      """

      code: int  # application-level status code (the endpoints here send 200)
      message: str  # human-readable result message
      data: dict | list | bool | None  # endpoint-specific payload
  class KnowledgeBaseResponse(BaseModel):
      """Serialized view of a knowledge base record.

      ``from_attributes=True`` allows ``model_validate`` to read the fields
      from attribute-bearing objects; the endpoints pass ORM rows.
      """

      model_config = ConfigDict(from_attributes=True)

      id: int
      name: str
      description: str | None = None
      tags: str | None = None
      creator: str | None = None
      user_name: str | None = None  # newly added field
      file_count: int = 0
      # Timestamps default to the current (naive) UTC time when absent.
      created_at: datetime = Field(default_factory=datetime.utcnow)
      updated_at: datetime = Field(default_factory=datetime.utcnow)

      @field_serializer('created_at', 'updated_at')
      def serialize_datetime(self, dt: datetime) -> str:
          """Render both timestamps as a date-only ``YYYY-MM-DD`` string."""
          return dt.strftime('%Y-%m-%d')
  class KnowledgeFileResponse(BaseModel):
      """Serialized view of a file that belongs to a knowledge base.

      ``from_attributes=True`` allows ``model_validate`` to read the fields
      from attribute-bearing objects as well as from plain dicts.
      """

      model_config = ConfigDict(from_attributes=True)

      id: int
      knowledge_base_id: int
      file_name: str
      file_size: float  # upload_files stores this value in megabytes
      file_type: str  # lower-case extension without the leading dot
      minio_url: str
      user_name: str | None = None  # user name
      version: str | None = None
      author: str | None = None
      year: int | None = None
      page_count: int | None = None
      creator: str | None = None
      knowledge_type: str | None = None
      # Timestamps default to the current (naive) UTC time when absent.
      created_at: datetime = Field(default_factory=datetime.utcnow)
      updated_at: datetime = Field(default_factory=datetime.utcnow)

      @field_serializer('created_at', 'updated_at')
      def serialize_datetime(self, dt: datetime) -> str:
          """Render both timestamps with minute precision."""
          return dt.strftime('%Y-%m-%d %H:%M')
  # Logging configuration
  logging.basicConfig(
      format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
      level=logging.INFO,
  )
  logger = logging.getLogger(__name__)

  # Database engine and session factory
  engine = create_engine(settings.DATABASE_URL)
  SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

  # Create the tables declared on Base
  Base.metadata.create_all(bind=engine)

  router = APIRouter(tags=["knowledge base interface"])

  # MinIO helper shared by every endpoint in this module
  minio_utils = MinioUtils()

  # Global exception handling (currently disabled).
  # NOTE(review): presumably disabled because exception handlers are
  # registered on the FastAPI application rather than on an APIRouter --
  # confirm before re-enabling.
  # @router.exception_handler(Exception)
  # async def global_exception_handler(request, exc):
  #     logger.error(f"全局异常: {exc}", exc_info=True)
  #     return {
  #         "code": 500,
  #         "message": "服务器内部错误",
  #         "data": None
  #     }
  # Dependency: provide a database session
  def get_db():
      """Yield a session created by ``SessionLocal`` and close it afterwards.

      Written as a generator so it can be plugged into ``Depends``; the
      ``finally`` clause closes the session even when the consumer raises.
      """
      session = SessionLocal()
      try:
          yield session
      finally:
          session.close()
  # Request models
  class KnowledgeBaseCreate(BaseModel):
      """Request body for creating a knowledge base; only ``name`` is required."""

      name: str
      description: str | None = None
      tags: str | None = None
  class KnowledgeBaseUpdate(BaseModel):
      """Request body for updating a knowledge base; only ``name`` is required."""

      name: str
      description: str | None = None
      tags: str | None = None
  class FileUpdate(BaseModel):
      """One entry of a batch file update.

      ``id`` selects the file. Every other field is optional, and
      ``batch_update_files`` leaves a column untouched when its field is
      ``None`` (or, for ``file_name``, empty).
      """

      id: int
      file_name: str | None = None
      version: str | None = None
      author: str | None = None
      year: int | None = None
      page_count: int | None = None
      creator: str | None = None
      knowledge_type: str | None = None
  class BatchFileUpdate(BaseModel):
      """Request body of the batch-update endpoint: a list of per-file updates."""

      files: List[FileUpdate]
  121. # 使用utils.py中的FileUtils类进行文件转换
  @router.post("/knowledge-base/", response_model=ResponseModel)
  def create_knowledge_base(kb: KnowledgeBaseCreate, db: Session = Depends(get_db),
                            sess: SessionValues = Depends(verify_session_id)):
      """Create a knowledge base and link it to the authenticated user.

      Returns a ``ResponseModel`` whose ``data`` is the serialized new
      knowledge base.
      """
      # The owner is taken from the verified session, not from the request body.
      owner_id, owner_name = sess.user_id, sess.username

      # Persist the knowledge base itself.
      created = DatabaseUtils.create_knowledge_base(
          db, kb.name, owner_id, kb.description, kb.tags
      )

      # Record which user the new knowledge base belongs to.
      UserDataRelationBusiness(db).create_relation(
          user_id=owner_id,
          data_category='KnowledgeBase',
          data_id=created.id,
          user_name=owner_name,
          role_id=None,
          role_name=None,
      )

      payload = KnowledgeBaseResponse.model_validate(created).model_dump()
      return ResponseModel(code=200, message="创建成功", data=payload)
  @router.put("/knowledge-base/{kb_id}", response_model=ResponseModel)
  def update_knowledge_base(kb_id: int, kb: KnowledgeBaseUpdate, db: Session = Depends(get_db)):
      """Update name, description and tags of a knowledge base.

      Returns a ``ResponseModel`` whose ``data`` is the serialized, updated
      knowledge base.

      Raises:
          HTTPException: 404 when the update helper yields no record.
      """
      kb_data = DatabaseUtils.update_knowledge_base(db, kb_id, kb.name, kb.description, kb.tags)
      # NOTE(review): DatabaseUtils.update_knowledge_base is defined elsewhere;
      # this assumes it returns a falsy value for an unknown id, as
      # get_knowledge_base_by_name does -- confirm. Without the guard, a
      # missing record would reach model_validate and surface as a
      # validation failure instead of the 404 used by the sibling endpoints.
      if not kb_data:
          raise HTTPException(status_code=404, detail="知识库不存在")
      return ResponseModel(
          code=200,
          message="更新成功",
          data=KnowledgeBaseResponse.model_validate(kb_data).model_dump()
      )
  @router.delete("/knowledge-base/{kb_id}", response_model=ResponseModel)
  def delete_knowledge_base(kb_id: int, db: Session = Depends(get_db)):
      """Delete a knowledge base through ``DatabaseUtils.delete_knowledge_base``.

      Whatever the helper returns is passed back unchanged as ``data``.
      """
      outcome = DatabaseUtils.delete_knowledge_base(db, kb_id)
      return ResponseModel(code=200, message="删除成功", data=outcome)
  @router.get("/knowledge-base/{kb_id}", response_model=ResponseModel)
  def get_knowledge_base(kb_id: int, db: Session = Depends(get_db)):
      """Fetch a single knowledge base by id.

      Raises:
          HTTPException: 404 when no row with this id has ``is_deleted == 0``.
      """
      lookup = db.query(KnowledgeBase).filter(
          KnowledgeBase.id == kb_id,
          KnowledgeBase.is_deleted == 0,
      )
      record = lookup.first()
      if not record:
          raise HTTPException(status_code=404, detail="知识库不存在")

      serialized = KnowledgeBaseResponse.model_validate(record).model_dump()
      return ResponseModel(code=200, message="查询成功", data=serialized)
  @router.get("/knowledge-base/", response_model=ResponseModel)
  def list_knowledge_bases(pageNo: int = 1, pageSize: int = 10, name: Optional[str] = None,
                           db: Session = Depends(get_db)):
      """Return one page of knowledge bases, optionally filtered by ``name``.

      ``data`` holds ``{"list": [...], "total": <count from the helper>}``.

      Raises:
          HTTPException: 400 when ``pageNo`` or ``pageSize`` is below 1.
      """
      # Reject invalid paging parameters before touching the database.
      if pageNo < 1:
          raise HTTPException(status_code=400, detail="页码必须大于等于1")
      if pageSize < 1:
          raise HTTPException(status_code=400, detail="每页条数必须大于等于1")

      # Pages are 1-based; convert to a row offset.
      offset = pageSize * (pageNo - 1)
      records, total = DatabaseUtils.get_knowledge_bases(db, offset, pageSize, name)

      items = []
      for record in records:
          items.append(KnowledgeBaseResponse.model_validate(record).model_dump())

      return ResponseModel(
          code=200,
          message="查询成功",
          data={"list": items, "total": total},
      )
  @router.get("/knowledge-base/name/{name}", response_model=ResponseModel)
  def get_knowledge_base_by_name(name: str, db: Session = Depends(get_db)):
      """Fetch a knowledge base by name.

      Raises:
          HTTPException: 404 when the lookup helper returns a falsy result.
      """
      record = DatabaseUtils.get_knowledge_base_by_name(db, name)
      if not record:
          raise HTTPException(status_code=404, detail="知识库不存在")

      serialized = KnowledgeBaseResponse.model_validate(record).model_dump()
      return ResponseModel(code=200, message="查询成功", data=serialized)
  @router.post("/knowledge-base/{kb_id}/files/", response_model=ResponseModel)
  async def upload_files(
      kb_id: int,
      files: List[UploadFile] = File(...),
      db: Session = Depends(get_db),
      sess: SessionValues = Depends(verify_session_id)  # session dependency
  ):
      """Upload a batch of files into a knowledge base.

      Each file is read into memory, converted when it is a legacy ``.doc`` /
      ``.ppt`` document, validated against ``settings.ALLOWED_EXTENSIONS``
      (extension and per-extension ``max_size`` in MB), stored through
      ``minio_utils.upload_file`` and recorded as a ``KnowledgeFile`` row that
      is linked to the uploading user. Everything is committed once, after
      the whole batch has been processed.

      Raises:
          HTTPException: 404 if the knowledge base does not exist or is
              soft-deleted; 400 for too many files, a missing file name, an
              unsupported extension or an oversized file; 500 when the
              office-format conversion fails.
      """
      # Local import: tempfile is only needed by this endpoint.
      import tempfile

      # Make sure the target knowledge base exists.
      kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
      if not kb:
          raise HTTPException(status_code=404, detail="知识库不存在")

      # Current user, taken from the verified session.
      user_id = sess.user_id
      user_name = sess.username

      # Enforce the per-request file limit.
      if len(files) > settings.MAX_FILE_COUNT:
          raise HTTPException(status_code=400, detail=f"单次上传文件数量不能超过{settings.MAX_FILE_COUNT}个")

      relation_business = UserDataRelationBusiness(db)
      uploaded_files = []
      for file in files:
          # The file name is client-controlled: keep only its last path
          # component so that it can never escape the temporary directory
          # below (e.g. "../../x.doc") or smuggle separators into the
          # object name.
          original_filename = os.path.basename((file.filename or "").replace("\\", "/"))
          if not original_filename:
              raise HTTPException(status_code=400, detail="文件名不能为空")
          # Name under which the file is stored; changes if it is converted.
          stored_filename = original_filename
          file_ext = os.path.splitext(original_filename)[1].lower().lstrip('.')
          converted_content = None

          # Read the whole upload into memory.
          content = await file.read()

          # Legacy office formats are converted to their OOXML counterpart.
          if file_ext in ["doc", "ppt"]:
              with tempfile.TemporaryDirectory() as temp_dir:
                  temp_input_path = os.path.join(temp_dir, original_filename)
                  with open(temp_input_path, "wb") as temp_file:
                      temp_file.write(content)

                  target_format = "docx" if file_ext == "doc" else "pptx"
                  converted_file_path = FileUtils.convert_office_file(temp_input_path, temp_dir, target_format)

                  if converted_file_path and os.path.exists(converted_file_path):
                      # Must be read before the temporary directory is removed.
                      with open(converted_file_path, "rb") as converted_file:
                          converted_content = converted_file.read()
                      file_ext = target_format
                      stored_filename = os.path.splitext(original_filename)[0] + f".{target_format}"
                  else:
                      # Conversion failed: abort the request.
                      raise HTTPException(status_code=500, detail=f"文件格式转换失败: {original_filename}")

          # Check the (possibly converted) extension against the allow-list.
          if file_ext not in settings.ALLOWED_EXTENSIONS:
              raise HTTPException(status_code=400, detail=f"不支持的文件格式:{file_ext}")

          # Use the converted bytes when a conversion happened.
          file_content = converted_content if converted_content else content
          file_size = len(file_content) / (1024 * 1024)  # bytes -> MB

          # Enforce the per-extension size limit.
          max_size = settings.ALLOWED_EXTENSIONS[file_ext]["max_size"]
          if file_size > max_size:
              raise HTTPException(status_code=400, detail=f"{original_filename}超过最大允许大小{max_size}MB")

          # Store the bytes in MinIO.
          minio_url = minio_utils.upload_file(file_content, stored_filename, file.content_type)

          # Derive the knowledge type from keywords in the file name.
          knowledge_type = None
          if '指南' in stored_filename:
              knowledge_type = '指南'
          elif '教材' in stored_filename:
              knowledge_type = '教材'

          # Create the file record.
          db_file = KnowledgeFile(
              knowledge_base_id=kb_id,
              file_name=stored_filename,
              file_size=file_size,
              file_type=file_ext,
              minio_url=minio_url,
              creator=user_id,
              knowledge_type=knowledge_type
          )
          db.add(db_file)
          # The session factory is configured with autoflush=False and a
          # pending object only receives its primary key on flush; without
          # this the relation below would be created with data_id=None.
          db.flush()
          uploaded_files.append(db_file)

          # Link the new file to the uploading user.
          relation_business.create_relation(
              user_id=user_id,
              data_category='KnowledgeFile',
              data_id=db_file.id,
              user_name=user_name,
              role_id=None,
              role_name=None
          )

          # Keep the knowledge base's file counter in sync.
          DatabaseUtils.increment_file_count(db, kb_id)

      db.commit()
      return ResponseModel(
          code=200,
          message="上传成功",
          data=[KnowledgeFileResponse.model_validate(file).model_dump() for file in uploaded_files]
      )
  @router.get("/knowledge-base/{kb_id}/files/", response_model=ResponseModel)
  def list_files(kb_id: int, pageNo: int = 1, pageSize: int = 10, file_name: Optional[str] = None,
                 db: Session = Depends(get_db)):
      """Return one page of the files in a knowledge base.

      Files are outer-joined with their ``UserDataRelation`` row so that a
      file without a relation is still listed (with ``user_name`` set to
      ``None``). ``file_name``, when given, is a case-insensitive substring
      filter. ``data`` holds ``{"list": [...], "total": <matching rows>}``.

      Raises:
          HTTPException: 400 when ``pageNo`` or ``pageSize`` is below 1.
      """
      if pageNo < 1:
          raise HTTPException(status_code=400, detail="页码必须大于等于1")
      if pageSize < 1:
          raise HTTPException(status_code=400, detail="每页条数必须大于等于1")

      # Pages are 1-based; convert to a row offset.
      skip = (pageNo - 1) * pageSize

      query = (
          db.query(KnowledgeFile, UserDataRelation.user_name)
          .outerjoin(
              UserDataRelation,
              and_(
                  UserDataRelation.data_id == KnowledgeFile.id,
                  UserDataRelation.data_category == 'KnowledgeFile'
              ),
          )
          .filter(
              KnowledgeFile.knowledge_base_id == kb_id,
              KnowledgeFile.is_deleted == 0
          )
      )
      if file_name:
          query = query.filter(KnowledgeFile.file_name.ilike(f"%{file_name}%"))

      total = query.count()
      # OFFSET/LIMIT without ORDER BY lets the database return rows in any
      # order, so pages could overlap or skip rows between requests. The
      # ordering is added after the count so the count query stays simple.
      rows = query.order_by(KnowledgeFile.id).offset(skip).limit(pageSize).all()

      return ResponseModel(
          code=200,
          message="查询成功",
          data={
              "list": [
                  KnowledgeFileResponse.model_validate(
                      {
                          **kb_file.__dict__,
                          "user_name": owner_name
                      }
                  ).model_dump()
                  for kb_file, owner_name in rows
              ],
              "total": total
          }
      )
  @router.get("/knowledge-base/{kb_id}/files/search/", response_model=ResponseModel)
  def search_files(kb_id: int, file_name: str, db: Session = Depends(get_db)):
      """Return every non-deleted file of a knowledge base whose name contains
      ``file_name`` (case-insensitive), together with the owning user's name.
      """
      # Outer join, matching list_files: a file without a UserDataRelation
      # row must still be searchable; its user_name is simply None. With an
      # inner join such files were listed but could never be found.
      files = (
          db.query(KnowledgeFile, UserDataRelation.user_name)
          .outerjoin(
              UserDataRelation,
              and_(
                  UserDataRelation.data_id == KnowledgeFile.id,
                  UserDataRelation.data_category == 'KnowledgeFile'
              ),
          )
          .filter(
              KnowledgeFile.knowledge_base_id == kb_id,
              KnowledgeFile.file_name.ilike(f"%{file_name}%"),
              KnowledgeFile.is_deleted == 0
          )
          .all()
      )

      result = []
      for file, user_name in files:
          file_response = KnowledgeFileResponse.model_validate(file)
          file_response.user_name = user_name
          result.append(file_response.model_dump())

      return ResponseModel(
          code=200,
          message="查询成功",
          data=result
      )
  @router.get("/files/{file_id}/download")
  def download_file(file_id: int, db: Session = Depends(get_db)):
      """Send a stored file back to the client as an attachment.

      Raises:
          HTTPException: 404 when no non-deleted file has this id.
      """
      record = (
          db.query(KnowledgeFile)
          .filter(KnowledgeFile.id == file_id, KnowledgeFile.is_deleted == 0)
          .first()
      )
      if not record:
          raise HTTPException(status_code=404, detail="文件不存在")

      # The object name is the last path segment of the stored MinIO URL.
      object_name = record.minio_url.split("/")[-1]
      payload = minio_utils.download_file(object_name)

      # Percent-encode the name so non-ASCII characters survive in the header.
      encoded_filename = urllib.parse.quote(record.file_name)
      response_headers = {
          "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
      }
      return StreamingResponse(
          io.BytesIO(payload),
          media_type="application/octet-stream",
          headers=response_headers,
      )
  @router.delete("/files/{file_id}", response_model=dict)
  def delete_file(file_id: int, db: Session = Depends(get_db)):
      """Remove a file's object from MinIO and soft-delete its database record.

      Raises:
          HTTPException: 404 when no non-deleted file has this id.
      """
      target = (
          db.query(KnowledgeFile)
          .filter(KnowledgeFile.id == file_id, KnowledgeFile.is_deleted == 0)
          .first()
      )
      if not target:
          raise HTTPException(status_code=404, detail="文件不存在")

      # The object name is the last path segment of the stored MinIO URL.
      minio_utils.delete_file(target.minio_url.split("/")[-1])

      # Soft delete: the row is kept and only flagged.
      target.is_deleted = 1
      target.updated_at = datetime.utcnow()

      # Keep the knowledge base's file counter in sync.
      DatabaseUtils.decrement_file_count(db, target.knowledge_base_id)
      db.commit()

      return {
          "code": 200,
          "message": "删除成功",
          "data": True
      }
  @router.put("/files/batch-update", response_model=ResponseModel)
  def batch_update_files(update_data: BatchFileUpdate, db: Session = Depends(get_db)):
      """Update the metadata of several files in one request.

      A changed ``file_name`` also re-stores the object in MinIO under the
      new name. Fields that are ``None`` are left untouched.

      The batch is all-or-nothing with respect to lookups: every id is
      resolved before any storage operation, and the old MinIO objects are
      removed only after the database commit has succeeded, so a failure
      can never leave a committed record pointing at a deleted object.

      Raises:
          HTTPException: 404 when any id does not match a non-deleted file.
      """
      # Local import: only this endpoint needs to guess a content type.
      import mimetypes

      # Pass 1: resolve every id before touching storage.
      pending = []
      for file_update in update_data.files:
          file = db.query(KnowledgeFile).filter(
              KnowledgeFile.id == file_update.id,
              KnowledgeFile.is_deleted == 0
          ).first()
          if not file:
              raise HTTPException(status_code=404, detail=f"文件ID {file_update.id} 不存在")
          pending.append((file_update, file))

      # Pass 2: apply the changes.
      obsolete_objects = []
      updated_files = []
      for file_update, file in pending:
          # A rename also moves the object in MinIO.
          if file_update.file_name and file_update.file_name != file.file_name:
              old_object_name = file.minio_url.split("/")[-1]
              file_content = minio_utils.download_file(old_object_name)

              # upload_files passes the upload's MIME content type as the
              # third argument, so derive one here instead of handing over
              # the bare extension stored in file_type.
              content_type = (
                  mimetypes.guess_type(file_update.file_name)[0]
                  or mimetypes.guess_type(f"file.{file.file_type}")[0]
                  or "application/octet-stream"
              )
              new_minio_url = minio_utils.upload_file(
                  file_content,
                  file_update.file_name,
                  content_type
              )

              # Deletion of the old object is deferred until after commit.
              obsolete_objects.append(old_object_name)
              file.file_name = file_update.file_name
              file.minio_url = new_minio_url

          # Remaining fields: only overwrite what the client actually sent.
          if file_update.version is not None:
              file.version = file_update.version
          if file_update.author is not None:
              file.author = file_update.author
          if file_update.year is not None:
              file.year = file_update.year
          if file_update.page_count is not None:
              file.page_count = file_update.page_count
          if file_update.creator is not None:
              file.creator = file_update.creator
          if file_update.knowledge_type is not None:
              file.knowledge_type = file_update.knowledge_type

          file.updated_at = datetime.utcnow()
          updated_files.append(file)

      db.commit()

      # The records now reference the new objects; the old ones are garbage.
      # A failure here only leaves an orphaned object behind, so it is logged
      # rather than turned into an error for an update that already succeeded.
      for object_name in obsolete_objects:
          try:
              minio_utils.delete_file(object_name)
          except Exception:
              logger.warning("Failed to delete obsolete MinIO object %s", object_name, exc_info=True)

      return ResponseModel(
          code=200,
          message="更新成功",
          data=[KnowledgeFileResponse.model_validate(file).model_dump() for file in updated_files]
      )
  # Alias under which the router defined in this module is also exposed.
  knowledge_base_router = router