auth.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session

from db.database import Base, engine, get_db
from db.models import DbUsers
from utils.response import resp_200, resp_400
  12. # JWT 相关配置
  13. SECRET_KEY = "kg-server"
  14. ALGORITHM = "HS256"
  15. ACCESS_TOKEN_EXPIRE_MINUTES = 30
  16. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/token")
  17. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  18. password = "secret"
  19. hashed_password = pwd_context.hash(password)
  20. print("auth.py Hashed password:", hashed_password)
  21. # 假的用户数据库
  22. # fake_users_db = {
  23. # "johndoe": {
  24. # "username": "johndoe",
  25. # "full_name": "John Doe",
  26. # "email": "johndoe@example.com",
  27. # "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36Y6QJwqmn4yZHRx70jN9nF", # 密码为 'secret'
  28. # "disabled": False,
  29. # }
  30. # }
  31. # 用户模型
  32. class Token(BaseModel):
  33. access_token: str
  34. token_type: str
  35. class User(BaseModel):
  36. username: str
  37. email: Optional[str] = None
  38. full_name: Optional[str] = None
  39. status: Optional[int] = None
  40. class UserInDB(User):
  41. hashed_password: str
  42. class Config:
  43. from_attributes = True
  44. # 密码验证
  45. def verify_password(plain_password, hashed_password):
  46. return pwd_context.verify(plain_password, hashed_password)
  47. def get_password_hash(password):
  48. return pwd_context.hash(password)
  49. def get_user(db, username: str):
  50. user = db.query(DbUsers).filter(DbUsers.username == username).first()
  51. if user:
  52. return UserInDB.model_validate(user)
  53. return None
  54. def authenticate_user(db, username: str, password: str):
  55. user = get_user(db, username)
  56. if not user:
  57. return False
  58. if not verify_password(password, user.hashed_password):
  59. return False
  60. return user
  61. # 创建访问令牌
  62. def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
  63. to_encode = data.copy()
  64. if expires_delta:
  65. expire = datetime.utcnow() + expires_delta
  66. else:
  67. expire = datetime.utcnow() + timedelta(minutes=15)
  68. to_encode.update({"exp": expire})
  69. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  70. return encoded_jwt
  71. # 路由
  72. router = APIRouter()
  73. @router.post("/api/token")
  74. async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
  75. user = authenticate_user(db, form_data.username, form_data.password)
  76. if not user:
  77. return resp_400(message="Incorrect username or password", data=[])
  78. access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
  79. access_token = create_access_token(
  80. data={"sub": user.username}, expires_delta=access_token_expires
  81. )
  82. return resp_200(data = {"access_token": access_token, "token_type": "bearer"})
  83. # 获取当前用户
  84. @router.get("/api/get-user")
  85. async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
  86. credentials_exception = HTTPException(
  87. status_code=status.HTTP_401_UNAUTHORIZED,
  88. detail="Could not validate credentials",
  89. headers={"WWW-Authenticate": "Bearer"},
  90. )
  91. try:
  92. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  93. username: str = payload.get("sub")
  94. if username is None:
  95. raise credentials_exception
  96. except JWTError:
  97. raise credentials_exception
  98. user = get_user(db, username)
  99. if user is None:
  100. raise credentials_exception
  101. return resp_200(data = user.model_dump())
  102. def verify_token(token: str):
  103. try:
  104. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  105. username: str = payload.get("sub")
  106. if username is None:
  107. return False
  108. return True
  109. except JWTError:
  110. return False
  111. auth_router = router