| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- import time
- from typing import Generator
- import jwt
- from fastapi import Depends, HTTPException, status
- from fastapi.security import OAuth2PasswordBearer
- from sqlalchemy import select
- from sqlalchemy.orm import Session
- from app.core import config, security
- from app.core.session import session
- from app.models import User
# Shared OAuth2 bearer-token dependency; "auth/access-token" is the token URL
# it is configured with.
reusable_oauth2 = OAuth2PasswordBearer(
    tokenUrl="auth/access-token",
)
def get_session() -> Generator[Session, None, None]:
    """Yield a database session for the duration of one dependency scope.

    The session is obtained by calling ``session()`` and entering it as a
    context manager, so the context manager's exit logic runs once the
    consumer of this generator is finished with the yielded object.
    """
    with session() as db_session:
        yield db_session
async def get_current_user(
    session: Session = Depends(get_session), token: str = Depends(reusable_oauth2)
) -> User:
    """Resolve a bearer access token to the ``User`` it was issued for.

    Args:
        session: Database session provided by the ``get_session`` dependency.
        token: Encoded JWT provided by the ``reusable_oauth2`` dependency.

    Returns:
        The first ``User`` row whose ``id`` equals the token's ``sub`` field.

    Raises:
        HTTPException: 403 when the token cannot be decoded or verified, when
            it is flagged as a refresh token, or when the current time falls
            outside its ``issued_at``/``expires_at`` window; 404 when no
            matching user exists.
    """
    try:
        payload = jwt.decode(
            token, config.settings.SECRET_KEY, algorithms=[security.JWT_ALGORITHM]
        )
    except jwt.InvalidTokenError:
        # InvalidTokenError is the base class of DecodeError and of PyJWT's
        # other validation failures (for example InvalidAlgorithmError, raised
        # when the token header names an algorithm that is not allowed).
        # Catching only DecodeError let those escape as HTTP 500 responses.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials.",
        ) from None

    # The signature check above means the payload is exactly what was signed,
    # so it is expected to match the JWTTokenPayload schema.
    token_data = security.JWTTokenPayload(**payload)

    # Refresh tokens must not be accepted where an access token is required.
    if token_data.refresh:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials, cannot use refresh token",
        )

    # Validity window is carried in the custom issued_at/expires_at fields
    # and compared against the current Unix time in whole seconds.
    now = int(time.time())
    if now < token_data.issued_at or now > token_data.expires_at:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials, token expired or not yet valid",
        )

    result = session.execute(select(User).where(User.id == token_data.sub))
    user = result.scalars().first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found.")
    return user
|