deps.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import time
  2. from typing import Generator
  3. import jwt
  4. from fastapi import Depends, HTTPException, status
  5. from fastapi.security import OAuth2PasswordBearer
  6. from sqlalchemy import select
  7. from sqlalchemy.orm import Session
  8. from app.core import config, security
  9. from app.core.session import session
  10. from app.models import User
  11. reusable_oauth2 = OAuth2PasswordBearer(tokenUrl="auth/access-token")
  12. def get_session() -> Generator[Session, None, None]:
  13. with session() as db:
  14. yield db
  15. async def get_current_user(
  16. session: Session = Depends(get_session), token: str = Depends(reusable_oauth2)
  17. ) -> User:
  18. try:
  19. payload = jwt.decode(
  20. token, config.settings.SECRET_KEY, algorithms=[security.JWT_ALGORITHM]
  21. )
  22. except jwt.DecodeError:
  23. raise HTTPException(
  24. status_code=status.HTTP_403_FORBIDDEN,
  25. detail="Could not validate credentials.",
  26. )
  27. # JWT guarantees payload will be unchanged (and thus valid), no errors here
  28. token_data = security.JWTTokenPayload(**payload)
  29. if token_data.refresh:
  30. raise HTTPException(
  31. status_code=status.HTTP_403_FORBIDDEN,
  32. detail="Could not validate credentials, cannot use refresh token",
  33. )
  34. now = int(time.time())
  35. if now < token_data.issued_at or now > token_data.expires_at:
  36. raise HTTPException(
  37. status_code=status.HTTP_403_FORBIDDEN,
  38. detail="Could not validate credentials, token expired or not yet valid",
  39. )
  40. result = session.execute(select(User).where(User.id == token_data.sub))
  41. user = result.scalars().first()
  42. if not user:
  43. raise HTTPException(status_code=404, detail="User not found.")
  44. return user