from datetime import timedelta, datetime
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlmodel import Session, select
from pydantic import BaseModel, EmailStr

from database import get_session
from models import User, Profile
from schemas import UserCreate, Token, UserRead
from auth import (
    verify_password,
    get_password_hash,
    create_access_token,
    ACCESS_TOKEN_EXPIRE_MINUTES,
    get_current_user,
)
from email_service import (
    send_verification_email,
    send_password_reset_email,
    generate_token,
    get_verification_expiry,
    get_reset_expiry,
)

router = APIRouter(prefix="/auth", tags=["auth"])

# NOTE: every path operation below is a plain ``def`` on purpose. They all use
# a synchronous SQLModel ``Session`` (blocking I/O); FastAPI runs plain ``def``
# handlers in a worker threadpool, whereas blocking calls inside ``async def``
# would stall the event loop for every other request.


# Request/Response schemas for new endpoints


class VerifyEmailRequest(BaseModel):
    """Body of ``POST /auth/verify-email``."""

    token: str


class ForgotPasswordRequest(BaseModel):
    """Body of ``POST /auth/forgot-password``."""

    email: EmailStr


class ResetPasswordRequest(BaseModel):
    """Body of ``POST /auth/reset-password``."""

    token: str
    new_password: str


class ResendVerificationRequest(BaseModel):
    """Body of ``POST /auth/resend-verification``."""

    email: EmailStr


@router.post("/register", response_model=UserRead)
def register(
    user_in: UserCreate,
    background_tasks: BackgroundTasks,
    session: Session = Depends(get_session),
):
    """Create an unverified user plus a default profile and email a verification link.

    The user and its profile are committed in a single transaction.

    Raises:
        HTTPException(400): if a user with ``user_in.email`` already exists.
    """
    user = session.exec(select(User).where(User.email == user_in.email)).first()
    if user:
        raise HTTPException(status_code=400, detail="Email already registered")

    # Create User with verification token
    hashed_password = get_password_hash(user_in.password)
    verification_token = generate_token()
    db_user = User(
        email=user_in.email,
        hashed_password=hashed_password,
        email_verified=False,
        verification_token=verification_token,
        verification_token_expires=get_verification_expiry(),
    )
    session.add(db_user)
    # Flush rather than commit: the INSERT is sent so db_user.id is populated,
    # but nothing is committed until the profile is added too. If the profile
    # insert fails, no user row is left behind without a profile.
    session.flush()

    # Create Default Profile
    profile = Profile(
        user_id=db_user.id,
        username=user_in.username,
        display_name=user_in.username,
    )
    session.add(profile)
    session.commit()
    # Reload server-side defaults/ids; commit expires loaded attributes.
    session.refresh(db_user)

    # Send verification email in background (FastAPI runs it after the response)
    background_tasks.add_task(send_verification_email, db_user.email, verification_token)
    return db_user


@router.post("/verify-email")
def verify_email(request: VerifyEmailRequest, session: Session = Depends(get_session)):
    """Verify user's email with token.

    On success the token and its expiry are cleared so the token is single-use.

    Raises:
        HTTPException(400): if no user holds ``request.token`` or it has expired.
    """
    user = session.exec(
        select(User).where(User.verification_token == request.token)
    ).first()
    if not user:
        raise HTTPException(status_code=400, detail="Invalid verification token")

    # NOTE(review): compares against a naive UTC "now"; this assumes
    # get_verification_expiry() also produces naive UTC datetimes — confirm.
    # A missing (None) expiry is treated as "never expires".
    if user.verification_token_expires and user.verification_token_expires < datetime.utcnow():
        raise HTTPException(status_code=400, detail="Verification token expired")

    if user.email_verified:
        return {"message": "Email already verified"}

    # Mark as verified
    user.email_verified = True
    user.verification_token = None
    user.verification_token_expires = None
    session.add(user)
    session.commit()

    return {"message": "Email verified successfully"}


@router.post("/resend-verification")
def resend_verification(
    request: ResendVerificationRequest,
    background_tasks: BackgroundTasks,
    session: Session = Depends(get_session),
):
    """Issue a fresh verification token and email it, replacing any previous token.

    Unknown addresses get the same generic response as known ones.
    """
    user = session.exec(select(User).where(User.email == request.email)).first()
    if not user:
        # Don't reveal if email exists
        return {"message": "If the email exists, a verification link has been sent"}

    # NOTE(review): this distinct response does reveal that the address is
    # registered (and verified), undercutting the generic message above.
    if user.email_verified:
        return {"message": "Email already verified"}

    # Generate new token
    user.verification_token = generate_token()
    user.verification_token_expires = get_verification_expiry()
    session.add(user)
    session.commit()

    background_tasks.add_task(send_verification_email, user.email, user.verification_token)

    return {"message": "If the email exists, a verification link has been sent"}


@router.post("/forgot-password")
def forgot_password(
    request: ForgotPasswordRequest,
    background_tasks: BackgroundTasks,
    session: Session = Depends(get_session),
):
    """Request password reset email.

    Stores a new reset token (overwriting any earlier one) and emails it.
    Always returns the same message, whether or not the address is registered.
    """
    user = session.exec(select(User).where(User.email == request.email)).first()
    if not user:
        # Don't reveal if email exists
        return {"message": "If the email exists, a reset link has been sent"}

    # Generate reset token
    user.reset_token = generate_token()
    user.reset_token_expires = get_reset_expiry()
    session.add(user)
    session.commit()

    background_tasks.add_task(send_password_reset_email, user.email, user.reset_token)

    return {"message": "If the email exists, a reset link has been sent"}


@router.post("/reset-password")
def reset_password(request: ResetPasswordRequest, session: Session = Depends(get_session)):
    """Reset password with token.

    The reset token is cleared on success, so it can be used only once.

    Raises:
        HTTPException(400): if no user holds ``request.token`` or it has expired.
    """
    user = session.exec(
        select(User).where(User.reset_token == request.token)
    ).first()
    if not user:
        raise HTTPException(status_code=400, detail="Invalid reset token")

    # NOTE(review): same naive-UTC assumption as verify_email — confirm that
    # get_reset_expiry() returns naive UTC datetimes.
    if user.reset_token_expires and user.reset_token_expires < datetime.utcnow():
        raise HTTPException(status_code=400, detail="Reset token expired")

    # Update password
    user.hashed_password = get_password_hash(request.new_password)
    user.reset_token = None
    user.reset_token_expires = None
    session.add(user)
    session.commit()

    return {"message": "Password reset successfully"}


@router.post("/token", response_model=Token)
def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    session: Session = Depends(get_session),
):
    """Exchange email + password (OAuth2 password form) for a bearer access token.

    The form's ``username`` field is interpreted as the user's email, and the
    token's ``sub`` claim is set to that email.

    Raises:
        HTTPException(401): if the email is unknown or the password is wrong.
    """
    user = session.exec(select(User).where(User.email == form_data.username)).first()
    if not user or not verify_password(form_data.password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.email}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}


@router.get("/users/me", response_model=UserRead)
def read_users_me(current_user: Annotated[User, Depends(get_current_user)]):
    """Return the user resolved from the request's bearer token by ``get_current_user``."""
    return current_user