|
@@ -1,3 +1,4 @@
|
|
|
+import json
|
|
|
import os
|
|
|
import io
|
|
|
import logging
|
|
@@ -126,13 +127,13 @@ def get_db():
|
|
|
class KnowledgeBaseCreate(BaseModel):
|
|
|
name: str
|
|
|
description: Optional[str] = None
|
|
|
- tags: Optional[str] = None
|
|
|
+ tags: Optional[List[str]] = Field(default_factory=list)
|
|
|
|
|
|
|
|
|
class KnowledgeBaseUpdate(BaseModel):
|
|
|
name: str
|
|
|
description: Optional[str] = None
|
|
|
- tags: Optional[str] = None
|
|
|
+ tags: Optional[List[str]] = Field(default_factory=list)
|
|
|
|
|
|
|
|
|
class FileUpdate(BaseModel):
|
|
@@ -157,10 +158,11 @@ def create_knowledge_base(kb: KnowledgeBaseCreate, db: Session = Depends(get_db)
|
|
|
sess:SessionValues = Depends(verify_session_id)):
|
|
|
# 1. 从session获取user_id
|
|
|
user_id = sess.user_id
|
|
|
- user_name = sess.username
|
|
|
+ user_name = sess.username
|
|
|
+ tags = json.dumps(kb.tags)
|
|
|
|
|
|
# 2. 创建知识库
|
|
|
- kb_data = DatabaseUtils.create_knowledge_base(db, kb.name, user_id, kb.description, kb.tags)
|
|
|
+ kb_data = DatabaseUtils.create_knowledge_base(db, kb.name, user_id, kb.description, tags)
|
|
|
|
|
|
# 3. 创建用户数据关联
|
|
|
relation_business = UserDataRelationBusiness(db)
|
|
@@ -182,7 +184,8 @@ def create_knowledge_base(kb: KnowledgeBaseCreate, db: Session = Depends(get_db)
|
|
|
|
|
|
@router.put("/knowledge-base/{kb_id}", response_model=ResponseModel)
|
|
|
def update_knowledge_base(kb_id: int, kb: KnowledgeBaseUpdate, db: Session = Depends(get_db)):
|
|
|
- kb_data = DatabaseUtils.update_knowledge_base(db, kb_id, kb.name, kb.description, kb.tags)
|
|
|
+ tags = json.dumps(kb.tags)
|
|
|
+ kb_data = DatabaseUtils.update_knowledge_base(db, kb_id, kb.name, kb.description, tags)
|
|
|
return ResponseModel(
|
|
|
code=200,
|
|
|
message="更新成功",
|