Browse Source

增加压缩包上传文件功能

chenbin 1 month ago
parent
commit
d865f2b0d0
1 changed files with 145 additions and 84 deletions
  1. 145 84
      agent/router/knowledge_base_router.py

+ 145 - 84
agent/router/knowledge_base_router.py

@@ -2,6 +2,7 @@ import json
 import os
 import io
 import logging
+import tempfile
 import urllib.parse
 import time
 import glob
@@ -261,6 +262,14 @@ async def upload_files(
         db: Session = Depends(get_db),
         sess: SessionValues = Depends(verify_session_id)  # 添加session依赖
 ):
+    """
+    支持多文件上传和压缩文件解析
+    :param kb_id: 知识库ID
+    :param files: 上传文件列表(支持压缩文件)
+    :param db: 数据库会话
+    :param sess: 用户会话
+    :return: ResponseModel
+    """
     # 验证知识库是否存在
     kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0).first()
     if not kb:
@@ -275,102 +284,154 @@ async def upload_files(
         raise HTTPException(status_code=400, detail=f"单次上传文件数量不能超过{settings.MAX_FILE_COUNT}个")
 
     # 导入所需模块
-    import tempfile
+    # import tempfile
+    import zipfile
+    import shutil
 
     uploaded_files = []
     for file in files:
-        # 获取文件扩展名
-        file_ext = os.path.splitext(file.filename)[1].lower().lstrip('.')
-        original_filename = file.filename
-        converted_content = None
-
-        # 读取文件内容
-        content = await file.read()
-
-        # 处理需要转换的文件格式
-        if file_ext in ["doc", "ppt"]:
-            # 创建临时目录用于文件转换
+        # 处理压缩文件
+        if file.filename.lower().endswith(('.zip', '.rar', '.7z')):
+            # 创建临时目录用于解压
             with tempfile.TemporaryDirectory() as temp_dir:
-                # 创建临时文件
-                temp_input_path = os.path.join(temp_dir, original_filename)
-                with open(temp_input_path, "wb") as temp_file:
-                    temp_file.write(content)
-
-                # 确定目标格式
-                target_format = "docx" if file_ext == "doc" else "pptx"
-
-                # 使用FileUtils中的文件转换方法
-                converted_file_path = FileUtils.convert_office_file(temp_input_path, temp_dir, target_format)
-
-                if converted_file_path and os.path.exists(converted_file_path):
-                    # 读取转换后的文件内容
-                    with open(converted_file_path, "rb") as converted_file:
-                        converted_content = converted_file.read()
-
-                    # 更新文件名和扩展名
-                    file_ext = target_format
-                    file.filename = os.path.splitext(original_filename)[0] + f".{target_format}"
-                else:
-                    # 转换失败,使用原始文件
-                    raise HTTPException(status_code=500, detail=f"文件格式转换失败: {original_filename}")
-
-        # 检查文件格式是否支持
-        if file_ext not in settings.ALLOWED_EXTENSIONS:
-            raise HTTPException(status_code=400, detail=f"不支持的文件格式:{file_ext}")
-
-        # 使用转换后的内容或原始内容
-        file_content = converted_content if converted_content else content
-        file_size = len(file_content) / (1024 * 1024)  # 转换为MB
-
-        # 验证文件大小
-        max_size = settings.ALLOWED_EXTENSIONS[file_ext]["max_size"]
-        if file_size > max_size:
-            raise HTTPException(status_code=400, detail=f"{original_filename}超过最大允许大小{max_size}MB")
-
-        # 上传到MinIO
-        minio_url = minio_utils.upload_file(file_content, file.filename, file.content_type)
-
-        # 从文件名识别知识类型
-        knowledge_type = None
-        if '指南' in file.filename:
-            knowledge_type = '指南'
-        elif '教材' in file.filename:
-            knowledge_type = '教材'
-
-        # 创建文件记录
-        db_file = KnowledgeFile(
-            knowledge_base_id=kb_id,
-            file_name=file.filename,
-            file_size=file_size,
-            file_type=file_ext,
-            minio_url=minio_url,
-            creator=user_id,
-            knowledge_type=knowledge_type
-        )
-        db.add(db_file)
-        uploaded_files.append(db_file)
+                file_content = await file.read()
+                file_path = os.path.join(temp_dir, file.filename)
+                
+                with open(file_path, "wb") as f:
+                    f.write(file_content)
+                
+                extract_path = os.path.join(temp_dir, "extracted")
+                os.makedirs(extract_path, exist_ok=True)
+                
+                if file.filename.lower().endswith('.zip'):
+                    with zipfile.ZipFile(file_path, 'r') as zip_ref:
+                        zip_ref.extractall(extract_path)
+                    
+                    # 处理解压后的文件
+                    for extracted_file in os.listdir(extract_path):
+                        extracted_file_path = os.path.join(extract_path, extracted_file)
+                        if os.path.isfile(extracted_file_path):
+                            # 为每个解压文件创建新的UploadFile对象
+                            with open(extracted_file_path, "rb") as f:
+                                content = f.read()
+                            
+                            # 处理中文文件名编码问题
+                            try:
+                                decoded_filename = extracted_file.encode('cp437').decode('gbk')
+                            except:
+                                decoded_filename = extracted_file
+                            
+                            extracted_file_obj = UploadFile(
+                                filename=decoded_filename,
+                                file=io.BytesIO(content),
+                                size=len(content)
+                            )
+                            
+                            # 递归处理解压后的文件
+                            result = await process_single_file(extracted_file_obj, kb_id, db, user_id, user_name)
+                            uploaded_files.extend(result)
+                
+                continue
         
-        # 创建用户数据关联
-        relation_business = UserDataRelationBusiness(db)
-        relation = relation_business.create_relation(
-            user_id=user_id,
-            data_category='KnowledgeFile',
-            data_id=db_file.id,
-            user_name=user_name,
-            role_id=None,
-            role_name=None
-        )
-        
-        # 更新知识库文件计数
-        DatabaseUtils.increment_file_count(db, kb_id)
+        # 处理普通文件
+        result = await process_single_file(file, kb_id, db, user_id, user_name)
+        uploaded_files.extend(result)
 
     db.commit()
     return ResponseModel(
         code=200,
-        message="上传成功",
+        message=f"成功上传{len(uploaded_files)}个文件",
         data=[KnowledgeFileResponse.model_validate(file).model_dump() for file in uploaded_files]
     )
 
+async def process_single_file(file: UploadFile, kb_id: int, db: Session, user_id: int, user_name: str):
+    """处理单个文件的上传逻辑"""
+    # 获取文件扩展名
+    file_ext = os.path.splitext(file.filename)[1].lower().lstrip('.')
+    original_filename = file.filename
+    converted_content = None
+
+    # 读取文件内容
+    content = await file.read()
+
+    # 处理需要转换的文件格式
+    if file_ext in ["doc", "ppt"]:
+        # 创建临时目录用于文件转换
+        with tempfile.TemporaryDirectory() as temp_dir:
+            # 创建临时文件
+            temp_input_path = os.path.join(temp_dir, original_filename)
+            with open(temp_input_path, "wb") as temp_file:
+                temp_file.write(content)
+
+            # 确定目标格式
+            target_format = "docx" if file_ext == "doc" else "pptx"
+
+            # 使用FileUtils中的文件转换方法
+            converted_file_path = FileUtils.convert_office_file(temp_input_path, temp_dir, target_format)
+
+            if converted_file_path and os.path.exists(converted_file_path):
+                # 读取转换后的文件内容
+                with open(converted_file_path, "rb") as converted_file:
+                    converted_content = converted_file.read()
+
+                # 更新文件名和扩展名
+                file_ext = target_format
+                file.filename = os.path.splitext(original_filename)[0] + f".{target_format}"
+            else:
+                # 转换失败,使用原始文件
+                raise HTTPException(status_code=500, detail=f"文件格式转换失败: {original_filename}")
+
+    # 检查文件格式是否支持
+    if file_ext not in settings.ALLOWED_EXTENSIONS:
+        raise HTTPException(status_code=400, detail=f"不支持的文件格式:{file_ext}")
+
+    # 使用转换后的内容或原始内容
+    file_content = converted_content if converted_content else content
+    file_size = len(file_content) / (1024 * 1024)  # 转换为MB
+
+    # 验证文件大小
+    max_size = settings.ALLOWED_EXTENSIONS[file_ext]["max_size"]
+    if file_size > max_size:
+        raise HTTPException(status_code=400, detail=f"{original_filename}超过最大允许大小{max_size}MB")
+
+    # 上传到MinIO
+    minio_url = minio_utils.upload_file(file_content, file.filename, file.content_type)
+
+    # 从文件名识别知识类型
+    knowledge_type = None
+    if '指南' in file.filename:
+        knowledge_type = '指南'
+    elif '教材' in file.filename:
+        knowledge_type = '教材'
+
+    # 创建文件记录
+    db_file = KnowledgeFile(
+        knowledge_base_id=kb_id,
+        file_name=file.filename,
+        file_size=file_size,
+        file_type=file_ext,
+        minio_url=minio_url,
+        creator=user_id,
+        knowledge_type=knowledge_type
+    )
+    db.add(db_file)
+    
+    # 创建用户数据关联
+    relation_business = UserDataRelationBusiness(db)
+    relation = relation_business.create_relation(
+        user_id=user_id,
+        data_category='KnowledgeFile',
+        data_id=db_file.id,
+        user_name=user_name,
+        role_id=None,
+        role_name=None
+    )
+    
+    # 更新知识库文件计数
+    DatabaseUtils.increment_file_count(db, kb_id)
+    
+    return [db_file]
+
 
 @router.get("/knowledge-base/{kb_id}/files/", response_model=ResponseModel)
 def list_files(kb_id: int, pageNo: int = 1, pageSize: int = 10, file_name: Optional[str] = None,