|
@@ -286,12 +286,15 @@ async def upload_files(
|
|
|
# 导入所需模块
|
|
|
# import tempfile
|
|
|
import zipfile
|
|
|
+ import py7zr
|
|
|
+ import rarfile
|
|
|
+ import tarfile
|
|
|
import shutil
|
|
|
|
|
|
uploaded_files = []
|
|
|
for file in files:
|
|
|
# 处理压缩文件
|
|
|
- if file.filename.lower().endswith(('.zip', '.rar', '.7z')):
|
|
|
+ if file.filename.lower().endswith(('.zip', '.rar', '.tar', '.7z')):
|
|
|
# 创建临时目录用于解压
|
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
|
file_content = await file.read()
|
|
@@ -306,30 +309,39 @@ async def upload_files(
|
|
|
if file.filename.lower().endswith('.zip'):
|
|
|
with zipfile.ZipFile(file_path, 'r') as zip_ref:
|
|
|
zip_ref.extractall(extract_path)
|
|
|
+ elif file.filename.lower().endswith('.7z'):
|
|
|
+ with py7zr.SevenZipFile(file_path, mode='r') as zip_ref:
|
|
|
+ zip_ref.extractall(extract_path)
|
|
|
+ elif file.filename.lower().endswith('.rar'):
|
|
|
+ with rarfile.RarFile(file_path) as zip_ref:
|
|
|
+ zip_ref.extractall(extract_path)
|
|
|
+ elif file.filename.lower().endswith('.tar'):
|
|
|
+ with tarfile.open(file_path, 'r') as tar:
|
|
|
+ tar.extractall(path=extract_path)
|
|
|
|
|
|
- # 处理解压后的文件
|
|
|
- for extracted_file in os.listdir(extract_path):
|
|
|
- extracted_file_path = os.path.join(extract_path, extracted_file)
|
|
|
- if os.path.isfile(extracted_file_path):
|
|
|
- # 为每个解压文件创建新的UploadFile对象
|
|
|
- with open(extracted_file_path, "rb") as f:
|
|
|
- content = f.read()
|
|
|
-
|
|
|
- # 处理中文文件名编码问题
|
|
|
- try:
|
|
|
- decoded_filename = extracted_file.encode('cp437').decode('gbk')
|
|
|
- except:
|
|
|
- decoded_filename = extracted_file
|
|
|
-
|
|
|
- extracted_file_obj = UploadFile(
|
|
|
- filename=decoded_filename,
|
|
|
- file=io.BytesIO(content),
|
|
|
- size=len(content)
|
|
|
- )
|
|
|
-
|
|
|
- # 递归处理解压后的文件
|
|
|
- result = await process_single_file(extracted_file_obj, kb_id, db, user_id, user_name)
|
|
|
- uploaded_files.extend(result)
|
|
|
+ # 处理解压后的文件
|
|
|
+ for extracted_file in os.listdir(extract_path):
|
|
|
+ extracted_file_path = os.path.join(extract_path, extracted_file)
|
|
|
+ if os.path.isfile(extracted_file_path):
|
|
|
+ # 为每个解压文件创建新的UploadFile对象
|
|
|
+ with open(extracted_file_path, "rb") as f:
|
|
|
+ content = f.read()
|
|
|
+
|
|
|
+ # 处理中文文件名编码问题
|
|
|
+ try:
|
|
|
+ decoded_filename = extracted_file.encode('cp437').decode('gbk')
|
|
|
+ except:
|
|
|
+ decoded_filename = extracted_file
|
|
|
+
|
|
|
+ extracted_file_obj = UploadFile(
|
|
|
+ filename=decoded_filename,
|
|
|
+ file=io.BytesIO(content),
|
|
|
+ size=len(content)
|
|
|
+ )
|
|
|
+
|
|
|
+ # 递归处理解压后的文件
|
|
|
+ result = await process_single_file(extracted_file_obj, kb_id, db, user_id, user_name)
|
|
|
+ uploaded_files.extend(result)
|
|
|
|
|
|
continue
|
|
|
|
|
@@ -473,6 +485,39 @@ def list_files(kb_id: int, pageNo: int = 1, pageSize: int = 10, file_name: Optio
|
|
|
}
|
|
|
)
|
|
|
|
|
|
+@router.get("/knowledge-base/{kb_id}/files/enable", response_model=ResponseModel)
|
|
|
+def list_files_enable(kb_id: int, status: int = 0, db: Session = Depends(get_db)):
|
|
|
+
|
|
|
+ query = db.query(KnowledgeFile,UserDataRelation.user_name).\
|
|
|
+ outerjoin(UserDataRelation,
|
|
|
+ and_(
|
|
|
+ UserDataRelation.data_id == KnowledgeFile.id,
|
|
|
+ UserDataRelation.data_category == 'KnowledgeFile'
|
|
|
+ )).\
|
|
|
+ filter(
|
|
|
+ KnowledgeFile.knowledge_base_id == kb_id,
|
|
|
+ KnowledgeFile.status == status,
|
|
|
+ KnowledgeFile.is_deleted == 0
|
|
|
+ ).order_by(KnowledgeFile.status.desc())
|
|
|
+
|
|
|
+
|
|
|
+ total = query.count()
|
|
|
+ files = query.all()
|
|
|
+
|
|
|
+ return ResponseModel(
|
|
|
+ code=200,
|
|
|
+ message="查询成功",
|
|
|
+ data={
|
|
|
+ "list": [KnowledgeFileResponse.model_validate(
|
|
|
+ {
|
|
|
+ **file[0].__dict__,
|
|
|
+ "user_name": file[1]
|
|
|
+ }
|
|
|
+ ).model_dump() for file in files],
|
|
|
+ "total": total
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
|
|
|
@router.get("/knowledge-base/{kb_id}/files/search/", response_model=ResponseModel)
|
|
|
def search_files(kb_id: int, file_name: str, db: Session = Depends(get_db)):
|