Browse Source

增加压缩包上传文件功能(增加适配压缩格式) && 新增获取知识库可下载文件列表接口

luhaonan 4 weeks ago
parent
commit
708e9474e5
1 changed files with 69 additions and 24 deletions
  1. 69 24
      agent/router/knowledge_base_router.py

+ 69 - 24
agent/router/knowledge_base_router.py

@@ -286,12 +286,15 @@ async def upload_files(
     # 导入所需模块
     # import tempfile
     import zipfile
+    import py7zr
+    import rarfile
+    import tarfile
     import shutil
 
     uploaded_files = []
     for file in files:
         # 处理压缩文件
-        if file.filename.lower().endswith(('.zip', '.rar', '.7z')):
+        if file.filename.lower().endswith(('.zip', '.rar', '.tar', '.7z')):
             # 创建临时目录用于解压
             with tempfile.TemporaryDirectory() as temp_dir:
                 file_content = await file.read()
@@ -306,30 +309,39 @@ async def upload_files(
                 if file.filename.lower().endswith('.zip'):
                     with zipfile.ZipFile(file_path, 'r') as zip_ref:
                         zip_ref.extractall(extract_path)
+                elif file.filename.lower().endswith('.7z'):
+                    with py7zr.SevenZipFile(file_path, mode='r') as zip_ref:
+                        zip_ref.extractall(extract_path)
+                elif file.filename.lower().endswith('.rar'):
+                    with rarfile.RarFile(file_path) as zip_ref:
+                        zip_ref.extractall(extract_path)
+                elif file.filename.lower().endswith('.tar'):
+                    with tarfile.open(file_path, 'r') as tar:
+                        tar.extractall(path=extract_path)
                     
-                    # 处理解压后的文件
-                    for extracted_file in os.listdir(extract_path):
-                        extracted_file_path = os.path.join(extract_path, extracted_file)
-                        if os.path.isfile(extracted_file_path):
-                            # 为每个解压文件创建新的UploadFile对象
-                            with open(extracted_file_path, "rb") as f:
-                                content = f.read()
-                            
-                            # 处理中文文件名编码问题
-                            try:
-                                decoded_filename = extracted_file.encode('cp437').decode('gbk')
-                            except:
-                                decoded_filename = extracted_file
-                            
-                            extracted_file_obj = UploadFile(
-                                filename=decoded_filename,
-                                file=io.BytesIO(content),
-                                size=len(content)
-                            )
-                            
-                            # 递归处理解压后的文件
-                            result = await process_single_file(extracted_file_obj, kb_id, db, user_id, user_name)
-                            uploaded_files.extend(result)
+                # 处理解压后的文件
+                for extracted_file in os.listdir(extract_path):
+                    extracted_file_path = os.path.join(extract_path, extracted_file)
+                    if os.path.isfile(extracted_file_path):
+                        # 为每个解压文件创建新的UploadFile对象
+                        with open(extracted_file_path, "rb") as f:
+                            content = f.read()
+                        
+                        # 处理中文文件名编码问题
+                        try:
+                            decoded_filename = extracted_file.encode('cp437').decode('gbk')
+                        except:
+                            decoded_filename = extracted_file
+                        
+                        extracted_file_obj = UploadFile(
+                            filename=decoded_filename,
+                            file=io.BytesIO(content),
+                            size=len(content)
+                        )
+                        
+                        # 递归处理解压后的文件
+                        result = await process_single_file(extracted_file_obj, kb_id, db, user_id, user_name)
+                        uploaded_files.extend(result)
                 
                 continue
         
@@ -473,6 +485,39 @@ def list_files(kb_id: int, pageNo: int = 1, pageSize: int = 10, file_name: Optio
         }
     )
 
+@router.get("/knowledge-base/{kb_id}/files/enable", response_model=ResponseModel)
+def list_files_enable(kb_id: int, status: int = 0, db: Session = Depends(get_db)):
+
+    query = db.query(KnowledgeFile,UserDataRelation.user_name).\
+        outerjoin(UserDataRelation,
+             and_(
+                 UserDataRelation.data_id == KnowledgeFile.id,
+                 UserDataRelation.data_category == 'KnowledgeFile'
+             )).\
+        filter(
+        KnowledgeFile.knowledge_base_id == kb_id,
+        KnowledgeFile.status == status,
+        KnowledgeFile.is_deleted == 0
+    ).order_by(KnowledgeFile.status.desc())
+
+
+    total = query.count()
+    files = query.all()
+
+    return ResponseModel(
+        code=200,
+        message="查询成功",
+        data={
+            "list": [KnowledgeFileResponse.model_validate(
+                {
+                    **file[0].__dict__,
+                    "user_name": file[1]
+                }
+            ).model_dump() for file in files],
+            "total": total
+        }
+    )
+
 
 @router.get("/knowledge-base/{kb_id}/files/search/", response_model=ResponseModel)
 def search_files(kb_id: int, file_name: str, db: Session = Depends(get_db)):