Browse Source

增加修改密码接口功能

chenbin 1 month ago
parent
commit
f69e258bab
3 changed files with 21 additions and 3 deletions
  1. 1 1
      agent/libs/user.py
  2. 18 0
      agent/router/user_router.py
  3. 2 2
      agent/utils.py

+ 1 - 1
agent/libs/user.py

@@ -46,7 +46,7 @@ class UserBusiness:
                 user.username = username
             if password:
                 password = hash_pwd(password)
-                user.password = password
+                user.hashed_password = password
             self.db.commit()
             self.db.refresh(user)
         return user

+ 18 - 0
agent/router/user_router.py

@@ -31,6 +31,24 @@ def register(request: BasicRequest, db: Session = Depends(get_db)):
         if user is None:
             return StandardResponse(code=FAILED, message="create user failed")
         return StandardResponse(code=SUCCESS, message="create user success")
+    elif request.action == 'modifyPWD':
+        biz = UserBusiness(db)
+        request_password = request.get_param("password", "")
+        request_new_password = request.get_param("new_password", "")
+        session_id = request.get_param("session_id", "")
+        
+        session = SessionBusiness(db)
+        user_id = session.get_session(session_id).user_id
+        user = biz.get_user(user_id)
+        
+        if user is None:
+            return StandardResponse(code=FAILED, message="modify passward failed")
+        if not biz.verify_password(request_password, user.hashed_password):
+            return StandardResponse(code=FAILED, message="password error")
+
+        biz.update_user(user_id, password=request_new_password)
+        session.delete_session(session_id)
+        return StandardResponse(code=SUCCESS, message="modify passward success")
     elif request.action =='login':
         request_username = request.get_param("username", "")
         request_password = request.get_param("password", "")

+ 2 - 2
agent/utils.py

@@ -31,14 +31,14 @@ class DatabaseUtils:
         return bool(re.match(pattern, name))
 
     @staticmethod
-    def create_knowledge_base(db: Session, name: str, description: Optional[str] = None, tags: Optional[str] = None) -> KnowledgeBase:
+    def create_knowledge_base(db: Session, name: str, creator: Optional[str] = None, description: Optional[str] = None, tags: Optional[str] = None) -> KnowledgeBase:
         if not DatabaseUtils.validate_knowledge_base_name(name):
             raise HTTPException(status_code=400, detail="知识库名称格式不正确")
         
         if description and len(description) > 400:
             raise HTTPException(status_code=400, detail="知识库备注不能超过400字")
         
-        db_kb = KnowledgeBase(name=name, description=description, tags=tags, file_count=0)
+        db_kb = KnowledgeBase(name=name, description=description, creator = creator, tags=tags, file_count=0)
         db.add(db_kb)
         db.commit()
         db.refresh(db_kb)