security.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """Black-box security shortcuts to generate JWT tokens and password hashing and verifcation."""
  2. import time
  3. import jwt
  4. from passlib.context import CryptContext
  5. from pydantic import BaseModel
  6. from app.core import config
  7. from app.schemas.responses import AccessTokenResponse
  # Signing algorithm passed to jwt.encode() together with settings.SECRET_KEY.
  JWT_ALGORITHM = "HS256"

  # Lifetimes are configured in minutes; this module works in seconds.
  ACCESS_TOKEN_EXPIRE_SECS = 60 * config.settings.ACCESS_TOKEN_EXPIRE_MINUTES
  REFRESH_TOKEN_EXPIRE_SECS = 60 * config.settings.REFRESH_TOKEN_EXPIRE_MINUTES

  # Shared passlib context used by verify_password() and get_password_hash():
  # bcrypt only, with the bcrypt work factor (rounds) taken from settings.
  PWD_CONTEXT = CryptContext(
      schemes=["bcrypt"],
      deprecated="auto",
      bcrypt__rounds=config.settings.SECURITY_BCRYPT_ROUNDS,
  )
  class JWTTokenPayload(BaseModel):
      """Claims carried by the tokens that create_jwt_token() encodes.

      Field names mirror the keys of the payload dict built there.
      """

      sub: str | int  # token subject, exactly as passed to create_jwt_token()
      refresh: bool  # True for refresh tokens, False for access tokens
      issued_at: int  # creation time: int(time.time()) when the token was built
      expires_at: int  # issued_at plus the token lifetime in seconds
  def create_jwt_token(subject: str | int, exp_secs: int, refresh: bool):
      """Build and sign a JWT access or refresh token.

      Args:
          subject: anything that uniquely identifies the user (id, email, ...);
              stored in the ``sub`` claim.
          exp_secs: token lifetime in seconds, counted from the issue time.
          refresh: True for a refresh token, False for an access token;
              stored in the ``refresh`` claim.

      Returns:
          A ``(token, expires_at, issued_at)`` tuple, where both timestamps are
          Unix times truncated to whole seconds.
      """
      now = int(time.time())
      expiry = now + exp_secs
      # NOTE(review): expiry is stored in the custom ``expires_at`` claim rather
      # than the registered ``exp`` claim, so jwt.decode() will not reject
      # expired tokens by itself -- confirm the decoding side checks it.
      # NOTE(review): RFC 7519 defines ``sub`` as a string and recent PyJWT
      # releases reject a non-string ``sub`` on decode; prefer str subjects.
      claims: dict[str, int | str | bool] = {
          "issued_at": now,
          "expires_at": expiry,
          "sub": subject,
          "refresh": refresh,
      }
      token = jwt.encode(
          claims,
          key=config.settings.SECRET_KEY,
          algorithm=JWT_ALGORITHM,
      )
      return token, expiry, now
  def generate_access_token_response(subject: str | int):
      """Issue a fresh access/refresh token pair for ``subject``.

      The access token lives for ACCESS_TOKEN_EXPIRE_SECS and the refresh token
      for REFRESH_TOKEN_EXPIRE_SECS; both are built by create_jwt_token().

      Returns:
          AccessTokenResponse with token_type "Bearer" plus each token's value,
          expiry timestamp and issue timestamp.
      """
      # Each call reads the clock on its own, so the two issue timestamps can
      # differ if a second boundary falls between them.
      access_jwt, access_expiry, access_issued = create_jwt_token(
          subject, ACCESS_TOKEN_EXPIRE_SECS, refresh=False
      )
      refresh_jwt, refresh_expiry, refresh_issued = create_jwt_token(
          subject, REFRESH_TOKEN_EXPIRE_SECS, refresh=True
      )
      return AccessTokenResponse(
          token_type="Bearer",
          access_token=access_jwt,
          expires_at=access_expiry,
          issued_at=access_issued,
          refresh_token=refresh_jwt,
          refresh_token_expires_at=refresh_expiry,
          refresh_token_issued_at=refresh_issued,
      )
  def verify_password(plain_password: str, hashed_password: str) -> bool:
      """Return whether ``plain_password`` matches ``hashed_password``.

      Delegates to the module's bcrypt-only passlib context (PWD_CONTEXT).
      bcrypt is deliberately slow; the cost of a check depends on the rounds
      recorded in ``hashed_password`` (new hashes use
      settings.SECURITY_BCRYPT_ROUNDS) and on the hardware.
      """
      matches = PWD_CONTEXT.verify(plain_password, hashed_password)
      return matches
  def get_password_hash(password: str) -> str:
      """Hash ``password`` with the module's bcrypt-only passlib context.

      The bcrypt work factor comes from settings.SECURITY_BCRYPT_ROUNDS (the
      context's ``bcrypt__rounds``), so hashing is intentionally slow; actual
      timing depends on that setting and on the hardware.
      """
      # NOTE(review): bcrypt only considers the first 72 bytes of its input --
      # confirm password length is limited upstream if that matters here.
      digest = PWD_CONTEXT.hash(password)
      return digest