auth.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import time
  2. import jwt
  3. from fastapi import APIRouter, Depends, HTTPException, status
  4. from fastapi.security import OAuth2PasswordRequestForm
  5. from pydantic import ValidationError
  6. from sqlalchemy import select
  7. from sqlalchemy.ext.asyncio import AsyncSession
  8. from app.api import deps
  9. from app.core import config, security
  10. from app.models import User
  11. from app.schemas.requests import RefreshTokenRequest
  12. from app.schemas.responses import AccessTokenResponse
  13. router = APIRouter()
  14. @router.post("/access-token", response_model=AccessTokenResponse)
  15. async def login_access_token(
  16. session: AsyncSession = Depends(deps.get_session),
  17. form_data: OAuth2PasswordRequestForm = Depends(),
  18. ):
  19. """OAuth2 compatible token, get an access token for future requests using username and password"""
  20. result = await session.execute(select(User).where(User.email == form_data.username))
  21. user = result.scalars().first()
  22. if user is None:
  23. raise HTTPException(status_code=400, detail="Incorrect email or password")
  24. if not security.verify_password(form_data.password, user.hashed_password):
  25. raise HTTPException(status_code=400, detail="Incorrect email or password")
  26. return security.generate_access_token_response(str(user.id))
  27. @router.post("/refresh-token", response_model=AccessTokenResponse)
  28. async def refresh_token(
  29. input: RefreshTokenRequest,
  30. session: AsyncSession = Depends(deps.get_session),
  31. ):
  32. """OAuth2 compatible token, get an access token for future requests using refresh token"""
  33. try:
  34. payload = jwt.decode(
  35. input.refresh_token,
  36. config.settings.SECRET_KEY,
  37. algorithms=[security.JWT_ALGORITHM],
  38. )
  39. except (jwt.DecodeError, ValidationError):
  40. raise HTTPException(
  41. status_code=status.HTTP_403_FORBIDDEN,
  42. detail="Could not validate credentials, unknown error",
  43. )
  44. # JWT guarantees payload will be unchanged (and thus valid), no errors here
  45. token_data = security.JWTTokenPayload(**payload)
  46. if not token_data.refresh:
  47. raise HTTPException(
  48. status_code=status.HTTP_403_FORBIDDEN,
  49. detail="Could not validate credentials, cannot use access token",
  50. )
  51. now = int(time.time())
  52. if now < token_data.issued_at or now > token_data.expires_at:
  53. raise HTTPException(
  54. status_code=status.HTTP_403_FORBIDDEN,
  55. detail="Could not validate credentials, token expired or not yet valid",
  56. )
  57. result = await session.execute(select(User).where(User.id == token_data.sub))
  58. user = result.scalars().first()
  59. if user is None:
  60. raise HTTPException(status_code=404, detail="User not found")
  61. return security.generate_access_token_response(str(user.id))