deps.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import time
  2. from collections.abc import AsyncGenerator
  3. from typing import Generator
  4. import jwt
  5. from fastapi import Depends, HTTPException, status
  6. from fastapi.security import OAuth2PasswordBearer
  7. from sqlalchemy import select
  8. from sqlalchemy.orm import Session
  9. from app.core import config, security
  10. from app.core.session import session
  11. from app.models import User
  # Bearer-token extractor consumed via Depends(...) by get_current_user;
  # tokenUrl names the endpoint clients are pointed to for obtaining a token.
  reusable_oauth2 = OAuth2PasswordBearer(
      tokenUrl="auth/access-token",
  )
  def get_session() -> Generator[Session, None, None]:
      """Dependency that yields a database session from the ``session()`` factory.

      The session is produced by entering the ``session()`` context manager and
      is yielded to the consumer; the context manager is exited when the
      generator is closed, so any commit/rollback/close behaviour is whatever
      ``session()`` implements on exit.
      """
      session_ctx = session()
      with session_ctx as db_session:
          yield db_session
  async def get_current_user(
      session: Session = Depends(get_session), token: str = Depends(reusable_oauth2)
  ) -> User:
      """Resolve the authenticated ``User`` from a bearer access token.

      The token is decoded and signature-checked with
      ``config.settings.SECRET_KEY`` restricted to ``security.JWT_ALGORITHM``.
      It is rejected if it is a refresh token or if the current time lies
      outside its ``issued_at``..``expires_at`` window. Otherwise the user whose
      ``id`` equals the token's ``sub`` is looked up.

      Args:
          session: Database session from ``get_session``. Note that this
              parameter shadows the module-level ``session`` factory inside
              the function body.
          token: Raw bearer token extracted by ``reusable_oauth2``.

      Returns:
          The ``User`` whose id matches the token subject.

      Raises:
          HTTPException: 403 if the token cannot be decoded or verified, is a
              refresh token, or is expired / not yet valid; 404 if no user with
              the token's subject id exists.
      """
      try:
          payload = jwt.decode(
              token, config.settings.SECRET_KEY, algorithms=[security.JWT_ALGORITHM]
          )
      except jwt.InvalidTokenError as err:
          # InvalidTokenError is PyJWT's common base class. Besides DecodeError
          # (malformed token / bad signature) it also covers e.g.
          # InvalidAlgorithmError and ExpiredSignatureError, which catching only
          # DecodeError would let propagate unhandled instead of returning 403.
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="Could not validate credentials.",
          ) from err
      # The signature check above means this payload was issued with our secret,
      # so its shape is trusted to match JWTTokenPayload. A validly signed token
      # in an older/different payload format would still fail validation here.
      token_data = security.JWTTokenPayload(**payload)
      if token_data.refresh:
          # Refresh tokens must only be exchanged for new tokens, never used
          # to authenticate a regular request.
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="Could not validate credentials, cannot use refresh token",
          )
      now = int(time.time())
      if now < token_data.issued_at or now > token_data.expires_at:
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="Could not validate credentials, token expired or not yet valid",
          )
      # NOTE(review): this coroutine performs a synchronous, blocking query on a
      # sync ``Session`` and therefore blocks the event loop while it runs.
      # Consider a plain ``def`` dependency (run in FastAPI's threadpool) or an
      # async session.
      result = session.execute(select(User).where(User.id == token_data.sub))
      user = result.scalars().first()
      if user is None:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND, detail="User not found."
          )
      return user