utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. import re
  2. import io
  3. import os
  4. import time
  5. import glob
  6. import shutil
  7. import logging
  8. import subprocess
  9. from datetime import datetime
  10. from typing import List, Optional
  11. from minio import Minio
  12. import urllib3
  13. from sqlalchemy.orm import Session
  14. from fastapi import HTTPException
  15. from agent.models.web.knowledge_base import KnowledgeBase, KnowledgeFile
  16. from config.site import settings
  17. # 配置Office文件转换日志
  18. office_logger = logging.getLogger('office_conversion')
  19. office_logger.setLevel(logging.INFO)
  20. if not office_logger.handlers:
  21. handler = logging.StreamHandler()
  22. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  23. handler.setFormatter(formatter)
  24. office_logger.addHandler(handler)
  25. class DatabaseUtils:
  26. @staticmethod
  27. def validate_knowledge_base_name(name: str) -> bool:
  28. pattern = r'^[a-zA-Z0-9\u4e00-\u9fa5_\-\.]+$'
  29. return bool(re.match(pattern, name))
  30. @staticmethod
  31. def create_knowledge_base(db: Session, name: str, description: Optional[str] = None, tags: Optional[str] = None) -> KnowledgeBase:
  32. if not DatabaseUtils.validate_knowledge_base_name(name):
  33. raise HTTPException(status_code=400, detail="知识库名称格式不正确")
  34. if description and len(description) > 400:
  35. raise HTTPException(status_code=400, detail="知识库备注不能超过400字")
  36. db_kb = KnowledgeBase(name=name, description=description, tags=tags, file_count=0)
  37. db.add(db_kb)
  38. db.commit()
  39. db.refresh(db_kb)
  40. return db_kb
  41. @staticmethod
  42. def update_knowledge_base(db: Session, kb_id: int, name: str, description: Optional[str] = None, tags: Optional[str] = None) -> KnowledgeBase:
  43. db_kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
  44. if not db_kb:
  45. raise HTTPException(status_code=404, detail="知识库不存在")
  46. if name and not DatabaseUtils.validate_knowledge_base_name(name):
  47. raise HTTPException(status_code=400, detail="知识库名称格式不正确")
  48. if description and len(description) > 400:
  49. raise HTTPException(status_code=400, detail="知识库备注不能超过400字")
  50. db_kb.name = name
  51. db_kb.description = description
  52. db_kb.tags = tags
  53. db_kb.updated_at = datetime.utcnow()
  54. db.commit()
  55. db.refresh(db_kb)
  56. return db_kb
  57. @staticmethod
  58. def delete_knowledge_base(db: Session, kb_id: int) -> bool:
  59. db_kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
  60. if not db_kb:
  61. raise HTTPException(status_code=404, detail="知识库不存在")
  62. # 删除知识库时将文件计数清零
  63. db_kb.file_count = 0
  64. db_kb.is_deleted = 1
  65. db_kb.updated_at = datetime.utcnow()
  66. db.commit()
  67. return True
  68. @staticmethod
  69. def get_knowledge_bases(db: Session, skip: int = 0, limit: int = 10, name: Optional[str] = None) -> tuple[List[KnowledgeBase], int]:
  70. query = db.query(KnowledgeBase).filter(KnowledgeBase.is_deleted == 0)
  71. if name:
  72. query = query.filter(KnowledgeBase.name.ilike(f"%{name}%"))
  73. total = query.count()
  74. knowledge_bases = query.offset(skip).limit(limit).all()
  75. return knowledge_bases, total
  76. @staticmethod
  77. def get_knowledge_base_by_name(db: Session, name: str) -> Optional[KnowledgeBase]:
  78. return db.query(KnowledgeBase).filter(KnowledgeBase.name == name, KnowledgeBase.is_deleted == 0).first()
  79. @staticmethod
  80. def increment_file_count(db: Session, kb_id: int) -> None:
  81. db_kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
  82. if db_kb:
  83. db_kb.file_count += 1
  84. db.commit()
  85. @staticmethod
  86. def decrement_file_count(db: Session, kb_id: int) -> None:
  87. db_kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
  88. if db_kb and db_kb.file_count > 0:
  89. db_kb.file_count -= 1
  90. db.commit()
  91. class MinioUtils:
  92. def __init__(self):
  93. self.client = Minio(
  94. settings.MINIO_ENDPOINT,
  95. access_key=settings.MINIO_ACCESS_KEY,
  96. secret_key=settings.MINIO_SECRET_KEY,
  97. secure=settings.MINIO_SECURE,
  98. http_client=urllib3.PoolManager(
  99. timeout=urllib3.Timeout(connect=10, read=60),
  100. maxsize=50,
  101. retries=urllib3.Retry(
  102. total=5,
  103. backoff_factor=0.5,
  104. status_forcelist=[500, 502, 503, 504]
  105. )
  106. )
  107. )
  108. self._ensure_bucket_exists()
  109. def _ensure_bucket_exists(self):
  110. if not self.client.bucket_exists(settings.MINIO_BUCKET_NAME):
  111. self.client.make_bucket(settings.MINIO_BUCKET_NAME)
  112. def upload_file(self, file_data: bytes, file_name: str, content_type: str, part_size: int = 15 * 1024 * 1024) -> str:
  113. import tempfile
  114. import os
  115. object_name = file_name
  116. try:
  117. # 创建临时文件
  118. with tempfile.NamedTemporaryFile(delete=False) as temp_file:
  119. temp_file.write(file_data)
  120. temp_file_path = temp_file.name
  121. # 使用fput_object进行上传,内部已实现分片上传
  122. self.client.fput_object(
  123. bucket_name=settings.MINIO_BUCKET_NAME,
  124. object_name=object_name,
  125. file_path=temp_file_path,
  126. content_type=content_type,
  127. part_size=part_size # 使用更大的分片大小,提高上传效率
  128. )
  129. return f"http://{settings.MINIO_ENDPOINT}/{settings.MINIO_BUCKET_NAME}/{object_name}"
  130. except Exception as e:
  131. raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}")
  132. finally:
  133. # 清理临时文件
  134. try:
  135. os.unlink(temp_file_path)
  136. except:
  137. pass
  138. def download_file(self, object_name: str) -> bytes:
  139. try:
  140. response = self.client.get_object(settings.MINIO_BUCKET_NAME, object_name)
  141. return response.read()
  142. finally:
  143. response.close()
  144. response.release_conn()
  145. def delete_file(self, object_name: str) -> bool:
  146. try:
  147. self.client.remove_object(settings.MINIO_BUCKET_NAME, object_name)
  148. return True
  149. except:
  150. return False
  151. class FileUtils:
  152. @staticmethod
  153. def convert_office_file(input_path, output_dir, target_format):
  154. """使用LibreOffice转换Office文件格式
  155. Args:
  156. input_path (str): 输入文件路径
  157. output_dir (str): 输出目录
  158. target_format (str): 目标格式,如docx、pptx等
  159. Returns:
  160. str: 转换后的文件路径,转换失败则返回None
  161. """
  162. # 检查输入文件是否存在
  163. if not os.path.exists(input_path):
  164. office_logger.error(f"输入文件不存在: {input_path}")
  165. return None
  166. # 检查输出目录是否存在,不存在则创建
  167. if not os.path.exists(output_dir):
  168. try:
  169. os.makedirs(output_dir)
  170. office_logger.info(f"创建输出目录: {output_dir}")
  171. except OSError as e:
  172. office_logger.error(f"创建输出目录失败: {e}")
  173. return None
  174. # 检查输出目录权限
  175. if not os.access(output_dir, os.W_OK):
  176. office_logger.error(f"输出目录没有写入权限: {output_dir}")
  177. return None
  178. # 检查LibreOffice是否安装
  179. libreoffice_cmd = "soffice" # Linux/macOS
  180. if os.name == 'nt': # Windows
  181. libreoffice_cmd = r"C:\Program Files\LibreOffice\program\soffice.exe"
  182. # 检查LibreOffice命令是否可用
  183. try:
  184. version_cmd = [libreoffice_cmd, "--version"]
  185. version_result = subprocess.run(version_cmd, check=True, capture_output=True, text=True)
  186. office_logger.info(f"LibreOffice版本: {version_result.stdout.strip()}")
  187. except (subprocess.SubprocessError, FileNotFoundError) as e:
  188. office_logger.error(f"LibreOffice未安装或不可用: {e}")
  189. return None
  190. # 获取输入文件的文件名(不含路径和扩展名)
  191. filename = os.path.basename(input_path)
  192. base_name = os.path.splitext(filename)[0]
  193. input_ext = os.path.splitext(filename)[1][1:].lower()
  194. office_logger.info(f"原始文件名: {filename}, 基本名称: {base_name}, 扩展名: {input_ext}")
  195. # 如果输入文件扩展名与目标格式相同,直接复制文件
  196. if input_ext == target_format.lower():
  197. office_logger.info(f"输入文件已经是目标格式,直接复制文件")
  198. final_output_path = os.path.join(output_dir, f"{base_name}.{target_format}")
  199. try:
  200. shutil.copy2(input_path, final_output_path)
  201. office_logger.info(f"复制文件到最终位置: {final_output_path}")
  202. return final_output_path
  203. except (shutil.Error, IOError) as e:
  204. office_logger.error(f"复制文件失败: {e}")
  205. return None
  206. # 创建临时工作目录,避免中文路径问题
  207. temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"temp_convert_{int(time.time())}")
  208. try:
  209. os.makedirs(temp_dir)
  210. office_logger.info(f"创建临时工作目录: {temp_dir}")
  211. except OSError as e:
  212. office_logger.error(f"创建临时工作目录失败: {e}")
  213. return None
  214. # 复制原文件到临时目录,使用英文文件名
  215. temp_input_file = os.path.join(temp_dir, f"input.{input_ext}")
  216. try:
  217. shutil.copy2(input_path, temp_input_file)
  218. office_logger.info(f"复制文件到临时目录: {temp_input_file}")
  219. except (shutil.Error, IOError) as e:
  220. office_logger.error(f"复制文件失败: {e}")
  221. shutil.rmtree(temp_dir, ignore_errors=True)
  222. return None
  223. # 记录转换前输出目录中的文件
  224. before_files = set(os.listdir(temp_dir))
  225. office_logger.debug(f"转换前临时目录内容: {before_files}")
  226. # 构建转换命令
  227. cmd = [
  228. libreoffice_cmd,
  229. "--headless"
  230. ]
  231. # 根据文件类型选择合适的转换参数
  232. if input_ext == 'doc' and target_format.lower() == 'docx':
  233. cmd.extend(["--convert-to", "docx:MS Word 2007 XML"])
  234. elif input_ext == 'ppt' and target_format.lower() == 'pptx':
  235. cmd.extend(["--convert-to", "pptx:Impress MS PowerPoint 2007 XML"])
  236. else:
  237. cmd.extend(["--convert-to", target_format])
  238. # 添加输出目录和输入文件
  239. cmd.extend(["--outdir", temp_dir, temp_input_file])
  240. office_logger.info(f"开始转换文件: {temp_input_file} -> {target_format}")
  241. office_logger.info(f"执行命令: {' '.join(cmd)}")
  242. # 切换到临时目录执行命令,避免路径问题
  243. current_dir = os.getcwd()
  244. os.chdir(temp_dir)
  245. try:
  246. result = subprocess.run(cmd, check=True, capture_output=True, text=True)
  247. office_logger.info(f"转换命令输出: {result.stdout}")
  248. if result.stderr:
  249. office_logger.warning(f"转换命令错误输出: {result.stderr}")
  250. # 切回原目录
  251. os.chdir(current_dir)
  252. # 等待一小段时间确保文件写入完成
  253. time.sleep(1)
  254. # 记录转换后输出目录中的文件
  255. after_files = set(os.listdir(temp_dir))
  256. office_logger.debug(f"转换后临时目录内容: {after_files}")
  257. # 找出新增的文件
  258. new_files = after_files - before_files
  259. office_logger.info(f"新增文件: {new_files}")
  260. # 预期的输出文件名
  261. expected_output_filename = f"input.{target_format}"
  262. # 预期的输出文件路径(在临时目录中)
  263. expected_output_path = os.path.join(temp_dir, expected_output_filename)
  264. # 最终的输出文件路径(在目标目录中)
  265. final_output_path = os.path.join(output_dir, f"{base_name}.{target_format}")
  266. # 检查预期的输出文件是否存在
  267. if os.path.exists(expected_output_path):
  268. # 复制转换后的文件到最终目标位置
  269. try:
  270. shutil.copy2(expected_output_path, final_output_path)
  271. office_logger.info(f"复制转换后的文件到最终位置: {final_output_path}")
  272. # 清理临时目录
  273. shutil.rmtree(temp_dir, ignore_errors=True)
  274. return final_output_path
  275. except (shutil.Error, IOError) as e:
  276. office_logger.error(f"复制转换后的文件失败: {e}")
  277. elif new_files:
  278. # 如果有新文件生成,使用第一个新文件
  279. new_file_path = os.path.join(temp_dir, list(new_files)[0])
  280. try:
  281. shutil.copy2(new_file_path, final_output_path)
  282. office_logger.info(f"复制新生成的文件到最终位置: {final_output_path}")
  283. # 清理临时目录
  284. shutil.rmtree(temp_dir, ignore_errors=True)
  285. return final_output_path
  286. except (shutil.Error, IOError) as e:
  287. office_logger.error(f"复制新生成的文件失败: {e}")
  288. else:
  289. # 尝试在临时目录中查找匹配的文件
  290. pattern = os.path.join(temp_dir, f"*.{target_format}")
  291. matching_files = glob.glob(pattern)
  292. office_logger.info(f"匹配的文件列表: {matching_files}")
  293. if matching_files:
  294. # 按修改时间排序,获取最新的文件
  295. newest_file = max(matching_files, key=os.path.getmtime)
  296. try:
  297. shutil.copy2(newest_file, final_output_path)
  298. office_logger.info(f"复制匹配的文件到最终位置: {final_output_path}")
  299. # 清理临时目录
  300. shutil.rmtree(temp_dir, ignore_errors=True)
  301. return final_output_path
  302. except (shutil.Error, IOError) as e:
  303. office_logger.error(f"复制匹配的文件失败: {e}")
  304. # 如果所有尝试都失败,清理临时目录并返回None
  305. office_logger.error(f"转换后的文件不存在或无法复制")
  306. shutil.rmtree(temp_dir, ignore_errors=True)
  307. return None
  308. except subprocess.CalledProcessError as e:
  309. # 切回原目录
  310. os.chdir(current_dir)
  311. office_logger.error(f"转换失败: {e.stderr if hasattr(e, 'stderr') else str(e)}")
  312. # 清理临时目录
  313. shutil.rmtree(temp_dir, ignore_errors=True)
  314. return None
  315. except Exception as e:
  316. # 切回原目录
  317. os.chdir(current_dir)
  318. office_logger.error(f"转换过程中发生未知错误: {str(e)}")
  319. # 清理临时目录
  320. shutil.rmtree(temp_dir, ignore_errors=True)
  321. return None