user_router.py 16 KB


  1. import sys,os
  2. import uuid
  3. current_path = os.getcwd()
  4. sys.path.append(current_path)
  5. from config.site import SiteConfig
  6. from fastapi import APIRouter, Depends, Query
  7. from db.database import get_db
  8. from sqlalchemy.orm import Session
  9. from agent.models.web.response import StandardResponse,FAILED,SUCCESS
  10. from agent.models.web.request import BasicRequest
  11. from agent.libs.user import UserBusiness,SessionBusiness, UserRoleBusiness, RoleBusiness, PermissionBusiness
  12. import logging
  13. from pydantic import BaseModel
  14. from typing import Optional
  15. router = APIRouter(prefix="/user", tags=["agent job interface"])
  16. logger = logging.getLogger(__name__)
  17. config = SiteConfig()
  18. @router.post("/session", response_model=StandardResponse)
  19. def register(request: BasicRequest, db: Session = Depends(get_db)):
  20. if request.action == 'register':
  21. biz = UserBusiness(db)
  22. request_username = request.get_param("username", "")
  23. request_password = request.get_param("password", "")
  24. user = biz.get_user_by_username(request_username)
  25. if user is not None:
  26. return StandardResponse(code=FAILED, message="user already exists")
  27. user = biz.create_user(request_username, request_password)
  28. if user is None:
  29. return StandardResponse(code=FAILED, message="create user failed")
  30. return StandardResponse(code=SUCCESS, message="create user success")
  31. elif request.action =='login':
  32. request_username = request.get_param("username", "")
  33. request_password = request.get_param("password", "")
  34. logger.info(f"login: {request_username} {request_password}")
  35. biz = UserBusiness(db)
  36. user = biz.get_user_by_username(request_username)
  37. if user is None:
  38. return StandardResponse(code=FAILED, message="user not exists")
  39. if not biz.verify_password(request_password, user.hashed_password):
  40. return StandardResponse(code=FAILED, message="password error")
  41. session = SessionBusiness(db)
  42. old_session = session.get_session_by_user_id(user.id)
  43. if old_session is not None:
  44. logger.info("delete old session")
  45. session.delete_session(old_session.session_id)
  46. logger.info("create new session")
  47. new_session = session.create_session(user)
  48. # Get user roles and permissions
  49. user_role_biz = UserRoleBusiness(db)
  50. user_roles = user_role_biz.get_user_roles(user.id)
  51. user_menu_permissions = user_role_biz.get_user_menu_permissions(user.id)
  52. # Prepare roles and permissions for response
  53. roles_data = [{
  54. "id": role.id,
  55. "name": role.name,
  56. "description": role.description
  57. } for role in user_roles]
  58. # permissions_data = [{
  59. # "id": perm.id,
  60. # "name": perm.name,
  61. # "description": perm.description,
  62. # "menu_name": perm.menu_name,
  63. # "menu_route": perm.menu_route,
  64. # "menu_icon": perm.menu_icon,
  65. # "parent_id": perm.parent_id
  66. # } for perm in user_menu_permissions]
  67. # 构建权限字典,方便通过ID查找
  68. permission_map = {p.id: {
  69. "id": p.id, "name": p.name, "description": p.description,
  70. "menu_name": p.menu_name, "menu_route": p.menu_route,
  71. "menu_icon": p.menu_icon, "parent_id": p.parent_id,
  72. "children": []
  73. } for p in user_menu_permissions}
  74. # 构建树形结构
  75. tree = []
  76. for p_id, p_data in permission_map.items():
  77. parent_id = p_data["parent_id"]
  78. if parent_id and parent_id in permission_map:
  79. permission_map[parent_id]["children"].append(p_data)
  80. else:
  81. tree.append(p_data)
  82. session_data = {
  83. "session_id": new_session.session_id,
  84. "user_id": new_session.user_id,
  85. "username": new_session.username,
  86. "full_name": new_session.full_name
  87. }
  88. return StandardResponse(code=SUCCESS, message="login success", records=[{"session": session_data, "roles": roles_data, "menu_permissions": tree}])
  89. elif request.action == "login_session":
  90. session_id = request.get_param("session_id", "")
  91. session = SessionBusiness(db)
  92. old_session = session.get_session(session_id)
  93. if old_session is None:
  94. return StandardResponse(code=FAILED, message="session not exists")
  95. old_session_data = {
  96. "session_id": old_session.session_id,
  97. "user_id": old_session.user_id,
  98. "username": old_session.username,
  99. "full_name": old_session.full_name
  100. }
  101. return StandardResponse(code=SUCCESS, message="login success", records=[old_session_data])
  102. elif request.action == "logout":
  103. session_id = request.get_param("session_id", "")
  104. session = SessionBusiness(db)
  105. session.delete_session(session_id)
  106. return StandardResponse(code=SUCCESS, message="logout success")
  107. @router.get("/logout/{session_id}", response_model=StandardResponse)
  108. def logout(session_id: str, db: Session = Depends(get_db)):
  109. session = SessionBusiness(db)
  110. session.delete_session(session_id)
  111. return StandardResponse(code=SUCCESS, message="logout success")
  112. @router.post("/signin", response_model=StandardResponse)
  113. def signin(request: BasicRequest, db: Session = Depends(get_db)):
  114. if request.action == 'signin':
  115. biz = UserBusiness(db)
  116. request_username = request.get_param("username", "")
  117. request_password = request.get_param("password", "")
  118. request_fullname = request.get_param("full_name", request_username) # 如果 full_name 未提供,则使用 username
  119. request_email = request.get_param("email", "")
  120. # 确保提供了用户名和密码
  121. if not request_username or not request_password:
  122. return StandardResponse(code=FAILED, message="Username and password are required")
  123. user = biz.get_user_by_username(request_username)
  124. if user is not None:
  125. return StandardResponse(code=FAILED, message="用户名已存在")
  126. user = biz.create_user(username=request_username, password=request_password, fullname=request_fullname, email=request_email)
  127. if user is None:
  128. return StandardResponse(code=FAILED, message="创建用户失败")
  129. # 分配角色
  130. user_role_biz = UserRoleBusiness(db)
  131. assigned = user_role_biz.assign_role_to_user(user.id, 10) # 分配角色ID为10
  132. if not assigned:
  133. logger.warning(f"Failed to assign role 10 to user {user.id} during signin")
  134. # 即使角色分配失败,也认为注册成功
  135. return StandardResponse(code=SUCCESS, data=user.to_dict(), message="成功创建用户,请继续登录")
  136. return StandardResponse(code=FAILED, message="invalid action")
# Pydantic models for request bodies
# Body of POST /user/roles. When role_id is set, that role is updated (renamed
# and/or re-described when the values differ) and its permissions are replaced
# by permission_ids; otherwise a new role called `name` is created and granted
# permission_ids.
class RoleCreateWithPermissionsRequest(BaseModel):
    role_id: Optional[int] = None  # role to update; omit (or None) to create a new role
    name: str  # required by the model; the create path additionally rejects ""
    description: Optional[str] = None  # on update, None means "leave unchanged"
    # Complete list of permission ids to grant; on update it replaces the current set.
    # NOTE(review): pydantic presumably copies the `[]` default per instance,
    # so it is not shared between requests — confirm for the pydantic version in use.
    permission_ids: list[int] = []
# Body of POST /user/permissions. Besides name/description, a permission can
# carry menu_* fields, and parent_id places it under another permission when
# the permission tree is built.
class PermissionCreateRequest(BaseModel):
    name: str  # must be non-empty and not already used by another permission
    description: Optional[str] = None
    menu_name: Optional[str] = None
    menu_route: Optional[str] = None
    menu_icon: Optional[str] = None
    # Id of the parent permission; a falsy or unknown parent makes it a tree root.
    parent_id: Optional[int] = None
  150. # Role Management Endpoints
  151. @router.post("/roles", response_model=StandardResponse)
  152. def create_role_with_permissions_endpoint(request: RoleCreateWithPermissionsRequest, db: Session = Depends(get_db)):
  153. role_id = request.role_id
  154. role_name = request.name
  155. role_description = request.description
  156. permission_ids = request.permission_ids
  157. role_biz = RoleBusiness(db)
  158. if role_id:
  159. # 修改现有角色
  160. role = role_biz.get_role(role_id)
  161. if not role:
  162. return StandardResponse(code=FAILED, message="角色不存在")
  163. # 更新角色名称和描述(如果提供)
  164. if role_name and role_name != role.name:
  165. existing_role_by_name = role_biz.get_role_by_name(role_name)
  166. if existing_role_by_name and existing_role_by_name.id != role_id:
  167. return StandardResponse(code=FAILED, message="新角色名称已存在")
  168. role_biz.update_role(role_id, name=role_name)
  169. if role_description is not None and role_description != role.description:
  170. role_biz.update_role(role_id, description=role_description)
  171. # 撤销所有现有权限
  172. role_biz.revoke_all_permissions_from_role(role_id)
  173. # 重新分配权限
  174. success_permissions = []
  175. failed_permissions = []
  176. for permission_id in permission_ids:
  177. if role_biz.assign_permission_to_role(role.id, permission_id):
  178. success_permissions.append(permission_id)
  179. else:
  180. failed_permissions.append(permission_id)
  181. response_message = f"角色 '{role.name}' 更新成功"
  182. if success_permissions:
  183. response_message += f", 成功分配 {len(success_permissions)} 个权限"
  184. if failed_permissions:
  185. response_message += f", {len(failed_permissions)} 个权限分配失败"
  186. return StandardResponse(
  187. code=SUCCESS,
  188. message=response_message,
  189. records=[{
  190. "id": role.id,
  191. "name": role.name,
  192. "success_permissions": success_permissions,
  193. "failed_permissions": failed_permissions
  194. }]
  195. )
  196. else:
  197. # 新增角色
  198. if not role_name:
  199. return StandardResponse(code=FAILED, message="角色名称不能为空")
  200. existing_role = role_biz.get_role_by_name(role_name)
  201. if existing_role:
  202. return StandardResponse(code=FAILED, message="角色已存在")
  203. # 创建角色
  204. role = role_biz.create_role(role_name, role_description)
  205. if not role:
  206. return StandardResponse(code=FAILED, message="创建角色失败")
  207. # 分配权限
  208. success_permissions = []
  209. failed_permissions = []
  210. for permission_id in permission_ids:
  211. if role_biz.assign_permission_to_role(role.id, permission_id):
  212. success_permissions.append(permission_id)
  213. else:
  214. failed_permissions.append(permission_id)
  215. response_message = f"角色创建成功"
  216. if success_permissions:
  217. response_message += f", 成功分配 {len(success_permissions)} 个权限"
  218. if failed_permissions:
  219. response_message += f", {len(failed_permissions)} 个权限分配失败"
  220. return StandardResponse(
  221. code=SUCCESS,
  222. message=response_message,
  223. records=[{
  224. "id": role.id,
  225. "name": role.name,
  226. "success_permissions": success_permissions,
  227. "failed_permissions": failed_permissions
  228. }]
  229. )
  230. @router.get("/roles", response_model=StandardResponse)
  231. def get_roles_endpoint(db: Session = Depends(get_db)):
  232. role_biz = RoleBusiness(db)
  233. roles = role_biz.get_all_roles()
  234. roles_data = [{
  235. "id": role.id,
  236. "name": role.name,
  237. "description": role.description,
  238. "permission_ids": [perm.id for perm in role_biz.get_role_permissions(role.id)]
  239. } for role in roles]
  240. return StandardResponse(code=SUCCESS, message="角色列表获取成功", records=roles_data)
  241. # Permission Management Endpoints
  242. @router.post("/permissions", response_model=StandardResponse)
  243. def create_permission_endpoint(request: PermissionCreateRequest, db: Session = Depends(get_db)):
  244. perm_name = request.name
  245. perm_desc = request.description
  246. menu_name = request.menu_name
  247. menu_route = request.menu_route
  248. menu_icon = request.menu_icon
  249. parent_id = request.parent_id
  250. if not perm_name:
  251. return StandardResponse(code=FAILED, message="Permission name is required")
  252. perm_biz = PermissionBusiness(db)
  253. existing_perm = perm_biz.get_permission_by_name(perm_name)
  254. if existing_perm:
  255. return StandardResponse(code=FAILED, message="Permission already exists")
  256. permission = perm_biz.create_permission(perm_name, perm_desc, menu_name, menu_route, menu_icon, parent_id)
  257. if permission:
  258. return StandardResponse(code=SUCCESS, message="Permission created successfully", records=[{"id": permission.id, "name": permission.name}])
  259. return StandardResponse(code=FAILED, message="Failed to create permission")
  260. @router.get("/permissions", response_model=StandardResponse)
  261. def get_permissions_endpoint(db: Session = Depends(get_db)):
  262. perm_biz = PermissionBusiness(db)
  263. permissions = perm_biz.get_all_permissions()
  264. # 构建权限字典,方便通过ID查找
  265. permission_map = {p.id: {
  266. "id": p.id, "name": p.name, "description": p.description,
  267. "menu_name": p.menu_name, "menu_route": p.menu_route,
  268. "menu_icon": p.menu_icon, "parent_id": p.parent_id,
  269. "children": []
  270. } for p in permissions}
  271. # 构建树形结构
  272. tree = []
  273. for p_id, p_data in permission_map.items():
  274. parent_id = p_data["parent_id"]
  275. if parent_id and parent_id in permission_map:
  276. permission_map[parent_id]["children"].append(p_data)
  277. else:
  278. tree.append(p_data)
  279. return StandardResponse(code=SUCCESS, message="权限列表获取成功", records=tree)
# Body of POST /user/users/roles: the complete set of roles the user should
# hold after the call. Listed roles the user lacks are assigned, and held
# roles that are not listed are revoked.
class UserRoleAssignmentRequest(BaseModel):
    user_id: int  # user whose roles are being set
    role_ids: list[int]  # desired role ids; duplicates collapse (the endpoint builds a set)
  283. @router.post("/users/roles", response_model=StandardResponse)
  284. def assign_roles_to_user_endpoint(request: UserRoleAssignmentRequest, db: Session = Depends(get_db)):
  285. user_role_biz = UserRoleBusiness(db)
  286. user_id = request.user_id
  287. new_role_ids = set(request.role_ids)
  288. # 获取用户当前的角色ID
  289. current_roles = user_role_biz.get_user_roles(user_id)
  290. current_role_ids = {role.id for role in current_roles} if current_roles else set()
  291. # 需要添加的角色
  292. roles_to_add = list(new_role_ids - current_role_ids)
  293. # 需要移除的角色
  294. roles_to_remove = list(current_role_ids - new_role_ids)
  295. success_add_count = 0
  296. failed_add_roles = []
  297. for role_id in roles_to_add:
  298. if user_role_biz.assign_role_to_user(user_id, role_id):
  299. success_add_count += 1
  300. else:
  301. failed_add_roles.append(role_id)
  302. success_remove_count = 0
  303. failed_remove_roles = []
  304. for role_id in roles_to_remove:
  305. if user_role_biz.revoke_role_from_user(user_id, role_id):
  306. success_remove_count += 1
  307. else:
  308. failed_remove_roles.append(role_id)
  309. message = f"角色分配更新完成。成功添加 {success_add_count} 个角色,成功移除 {success_remove_count} 个角色。"
  310. if failed_add_roles:
  311. message += f" 添加失败的角色ID: {failed_add_roles}."
  312. if failed_remove_roles:
  313. message += f" 移除失败的角色ID: {failed_remove_roles}."
  314. if not failed_add_roles and not failed_remove_roles:
  315. return StandardResponse(code=SUCCESS, message="用户角色更新成功")
  316. else:
  317. return StandardResponse(code=FAILED, message=message)
  318. @router.get("/users", response_model=StandardResponse)
  319. def get_users_endpoint(
  320. db: Session = Depends(get_db),
  321. userName: Optional[str] = Query(None, description="用户名,用于模糊查询"),
  322. pageNo: int = Query(1, ge=1, description="页码,从1开始"),
  323. pageSize: int = Query(10, ge=1, le=100, description="每页数量,最大100")
  324. ):
  325. user_biz = UserBusiness(db)
  326. user_role_biz = UserRoleBusiness(db)
  327. paginated_users, total_count = user_biz.get_users_paginated(userName, pageNo, pageSize)
  328. users_data = []
  329. for user in paginated_users:
  330. roles = user_role_biz.get_user_roles(user.id)
  331. role_ids = [role.id for role in roles] if roles else []
  332. users_data.append({
  333. "id": user.id,
  334. "username": user.username,
  335. "full_name": user.full_name,
  336. "email": user.email,
  337. "role_ids": role_ids
  338. })
  339. return StandardResponse(code=SUCCESS, message="用户列表获取成功", records=users_data, total=total_count)
# Public alias under which this module's router is exported.
# NOTE(review): presumably imported by the application to mount these routes — confirm.
user_router: APIRouter = router