deps.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import time
  2. from collections.abc import AsyncGenerator
  3. import jwt
  4. from fastapi import Depends, HTTPException, status
  5. from fastapi.security import OAuth2PasswordBearer
  6. from sqlalchemy import select
  7. from sqlalchemy.ext.asyncio import AsyncSession
  8. from app.core import config, security
  9. from app.core.session import async_session
  10. from app.models import User
  11. reusable_oauth2 = OAuth2PasswordBearer(tokenUrl="auth/access-token")
  12. async def get_session() -> AsyncGenerator[AsyncSession, None]:
  13. async with async_session() as session:
  14. yield session
  15. async def get_current_user(
  16. session: AsyncSession = Depends(get_session), token: str = Depends(reusable_oauth2)
  17. ) -> User:
  18. try:
  19. payload = jwt.decode(
  20. token, config.settings.SECRET_KEY, algorithms=[security.JWT_ALGORITHM]
  21. )
  22. except jwt.DecodeError:
  23. raise HTTPException(
  24. status_code=status.HTTP_403_FORBIDDEN,
  25. detail="Could not validate credentials.",
  26. )
  27. # JWT guarantees payload will be unchanged (and thus valid), no errors here
  28. token_data = security.JWTTokenPayload(**payload)
  29. if token_data.refresh:
  30. raise HTTPException(
  31. status_code=status.HTTP_403_FORBIDDEN,
  32. detail="Could not validate credentials, cannot use refresh token",
  33. )
  34. now = int(time.time())
  35. if now < token_data.issued_at or now > token_data.expires_at:
  36. raise HTTPException(
  37. status_code=status.HTTP_403_FORBIDDEN,
  38. detail="Could not validate credentials, token expired or not yet valid",
  39. )
  40. result = await session.execute(select(User).where(User.id == token_data.sub))
  41. user = result.scalars().first()
  42. if not user:
  43. raise HTTPException(status_code=404, detail="User not found.")
  44. return user