from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session, select

from database import get_session
from models import Attendance, User, Show
from schemas import AttendanceCreate, AttendanceRead
from auth import get_current_user

router = APIRouter(prefix="/attendance", tags=["attendance"])

# Upper bound on the page size of the per-show listing. Without it a client
# could request an arbitrarily large page; negative limits are rejected too,
# because some backends (e.g. SQLite) treat a negative LIMIT as "no limit".
MAX_PAGE_SIZE = 100


@router.post("/", response_model=AttendanceRead)
def mark_attendance(
    attendance: AttendanceCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Record that the current user attended a show.

    At most one record per (user, show) is created by this endpoint: if a
    record already exists it is returned, with its notes overwritten when the
    request carries non-empty notes.

    Raises:
        HTTPException: 404 if no Show exists for ``attendance.show_id``
            (checked only when a new record would be created).
    """
    existing = session.exec(
        select(Attendance)
        .where(Attendance.user_id == current_user.id)
        .where(Attendance.show_id == attendance.show_id)
    ).first()

    if existing:
        # Truthiness check: None or empty notes leave the stored notes as-is.
        if attendance.notes:
            existing.notes = attendance.notes
            session.add(existing)
            session.commit()
            session.refresh(existing)
        return existing

    # Refuse to create attendance rows that reference a non-existent show.
    if session.get(Show, attendance.show_id) is None:
        raise HTTPException(status_code=404, detail="Show not found")

    # NOTE(review): the lookup above and this insert are not atomic; two
    # concurrent requests could both create a row. Confirm whether Attendance
    # has a unique (user_id, show_id) constraint backing this endpoint.
    db_attendance = Attendance(**attendance.model_dump(), user_id=current_user.id)
    session.add(db_attendance)
    session.commit()
    session.refresh(db_attendance)
    return db_attendance


@router.delete("/{show_id}")
def remove_attendance(
    show_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Delete the current user's attendance record for ``show_id``.

    Returns ``{"ok": True}`` on success.

    Raises:
        HTTPException: 404 if the user has no attendance record for that show.
    """
    attendance = session.exec(
        select(Attendance)
        .where(Attendance.user_id == current_user.id)
        .where(Attendance.show_id == show_id)
    ).first()
    if not attendance:
        raise HTTPException(status_code=404, detail="Attendance not found")

    session.delete(attendance)
    session.commit()
    return {"ok": True}


@router.get("/me", response_model=List[AttendanceRead])
def get_my_attendance(
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Return every attendance record belonging to the current user."""
    return session.exec(
        select(Attendance).where(Attendance.user_id == current_user.id)
    ).all()


@router.get("/show/{show_id}", response_model=List[AttendanceRead])
def get_show_attendance(
    show_id: int,
    session: Session = Depends(get_session),
    offset: int = Query(default=0, ge=0),
    limit: int = Query(default=MAX_PAGE_SIZE, ge=0, le=MAX_PAGE_SIZE),
):
    """Return one page of attendance records for ``show_id``.

    Args:
        show_id: Show whose attendance records are listed.
        offset: Number of records to skip; must be >= 0.
        limit: Maximum number of records to return; 0..MAX_PAGE_SIZE.
    """
    # NOTE(review): this endpoint has no get_current_user dependency, so any
    # caller can list a show's attendance -- confirm that is intended.
    # NOTE(review): no ORDER BY is applied, so page boundaries may not be
    # stable across requests on some databases; consider ordering by a key.
    return session.exec(
        select(Attendance)
        .where(Attendance.show_id == show_id)
        .offset(offset)
        .limit(limit)
    ).all()