authentication working

parent 03de49a5
Showing with 61 additions and 55 deletions
...@@ -3,85 +3,85 @@ ...@@ -3,85 +3,85 @@
# #
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
from requests import Session
from fastapi import HTTPException, Depends from fastapi import HTTPException, Depends
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
import jwt import jwt
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import BaseModel from pydantic import BaseModel
from user.schemas import UserBase
from user.crud import get_user_by_email from user.crud import get_user_by_email
# #
# This is the secret key used to hash the password # This is the secret key used to hash the token
# to get a string like this run: # to get a string like this run:
# openssl rand -hex 32 # openssl rand -hex 32
# #
SECRET_KEY = "" SECRET_KEY = "67c82b6b6b49e47fff1a8b51915ad0daf262c4cb4a69795af9ac90f03ecae10b"
# This is the password hashing algorithm
ALGORITHM = "HS256" ALGORITHM = "HS256"
# This is the expiration time of the token in minutes
ACCESS_TOKEN_EXPIRE_MINUTES = 30 ACCESS_TOKEN_EXPIRE_MINUTES = 30
# # DTO for token
# DTOs for token and credentials
#
class Token(BaseModel): class Token(BaseModel):
access_token: str access_token: str
token_type: str token_type: str
class TokenData(BaseModel):
username: str
role: str
class UserCredentials(BaseModel):
email: str
password: str
# Password hashing # Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
return pwd_context.hash(password)
# Token creation # Token creation
def create_access_token(data: dict, expires_delta: timedelta | None = None): def create_access_token(data: dict, expires_delta: timedelta):
to_encode = data.copy() to_encode = data.copy()
if expires_delta: expire = datetime.now(timezone.utc) + expires_delta
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
to_encode.update({"exp": expire}) to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt
#
# Authentication # Authentication
# def login_user(db: Session, email: str, password: str):
def authenticate_user(db: Session, email: str, password: str):
user = get_user_by_email(db, email) user = get_user_by_email(db, email)
if not user: if user and pwd_context.verify(password, user.password):
return False access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
if not verify_password(password, user.password): access_token = create_access_token(
return False data={"email": user.email, "role": user.role, "id": user.id},
return user expires_delta=access_token_expires
)
def login_user(db, email: str, password: str): return Token(access_token=access_token, token_type="bearer")
user = authenticate_user(db, email, password) raise HTTPException(status_code=401, detail="Incorrect username or password")
if not user:
raise HTTPException(status_code=401, detail="Incorrect username or password") # Authorization
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) from typing import Annotated
access_token = create_access_token( from fastapi import HTTPException, status
data={"sub": user.email, "role": user.role, "id": user.id}, from fastapi.security import OAuth2PasswordBearer
expires_delta=access_token_expires from user.schemas import User
from database import get_db
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
async def is_valid_user(
db: Annotated[Session, Depends(get_db)],
token: Annotated[str, Depends(oauth2_scheme)]
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
) )
return Token(access_token=access_token, token_type="bearer") try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_info = payload
if user_info.get('email') is None or user_info.get('role') is None:
raise credentials_exception
except jwt.InvalidTokenError:
raise credentials_exception
user = get_user_by_email(db, user_info.get('email'))
if user is None:
raise credentials_exception
return User(**user.__dict__)
async def is_admin_user(
current_user: Annotated[User, Depends(is_valid_user)],
):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Not enough permissions")
return current_user
from fastapi import FastAPI, Depends from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from datetime import date from datetime import date
from typing import List from typing import List
...@@ -8,6 +8,7 @@ from user import crud as user_crud, schemas as user_schemas ...@@ -8,6 +8,7 @@ from user import crud as user_crud, schemas as user_schemas
from book import crud as book_crud, schemas as book_schemas from book import crud as book_crud, schemas as book_schemas
from loan import crud as loan_crud, schemas as loan_schemas from loan import crud as loan_crud, schemas as loan_schemas
from book.schemas import Book # Import Book schema from book.schemas import Book # Import Book schema
from auth import is_admin_user, is_valid_user
app = FastAPI() app = FastAPI()
...@@ -31,11 +32,16 @@ def create_user(user: user_schemas.UserCreate, db: Session = Depends(get_db)): ...@@ -31,11 +32,16 @@ def create_user(user: user_schemas.UserCreate, db: Session = Depends(get_db)):
return user_crud.create_user(db, user) return user_crud.create_user(db, user)
@app.get("/user/", response_model=List[user_schemas.User]) @app.get("/user/", response_model=List[user_schemas.User])
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): async def read_users(skip: int = 0, limit: int = 100,
db: Session = Depends(get_db),
_: user_schemas.User = Depends(is_admin_user)):
return user_crud.get_users(db, skip, limit) return user_crud.get_users(db, skip, limit)
@app.get("/user/{user_id}", response_model=user_schemas.User) @app.get("/user/{user_id}", response_model=user_schemas.User)
def read_user(user_id: int, db: Session = Depends(get_db)): def read_user(user_id: int, db: Session = Depends(get_db),
current_user: user_schemas.User = Depends(is_valid_user)):
if current_user.id != user_id:
raise HTTPException(status_code=403, detail="Forbidden")
return user_crud.get_user(db, user_id) return user_crud.get_user(db, user_id)
@app.put("/user/{user_id}", response_model=user_schemas.User) @app.put("/user/{user_id}", response_model=user_schemas.User)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment