from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session, select, func
from pydantic import BaseModel

from database import get_session
from models import (
    User,
    Review,
    Attendance,
    Group,
    GroupMember,
    Show,
    UserPreferences,
    Profile,
)
from schemas import UserRead, ReviewRead, ShowRead, GroupRead, UserPreferencesUpdate
from auth import get_current_user

router = APIRouter(prefix="/users", tags=["users"])


class UserProfileUpdate(BaseModel):
    """Request body for ``PATCH /users/me``; omitted fields are left unchanged."""

    bio: Optional[str] = None
    avatar: Optional[str] = None
    username: Optional[str] = None
    display_name: Optional[str] = None


def _ensure_username_available(session: Session, username: str) -> None:
    """Raise HTTP 400 if any Profile already uses ``username``.

    This is a check-then-write (naive) test: two concurrent requests can both
    pass it. NOTE(review): a unique constraint on Profile.username is presumably
    the real guard -- confirm in the model definition.
    """
    taken = session.exec(
        select(Profile).where(Profile.username == username)
    ).first()
    if taken:
        raise HTTPException(status_code=400, detail="Username taken")


@router.get("/{user_id}", response_model=UserRead)
def get_user_public(user_id: int, session: Session = Depends(get_session)):
    """Get public user profile"""
    user = session.get(User, user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user


@router.patch("/me", response_model=UserRead)
def update_my_profile(
    update: UserProfileUpdate,
    current_user: User = Depends(get_current_user),
    session: Session = Depends(get_session)
):
    """Update current user's bio, avatar, and primary profile"""
    if update.bio is not None:
        current_user.bio = update.bio
    if update.avatar is not None:
        current_user.avatar = update.avatar

    # Empty strings are falsy here, so "" never triggers a profile change.
    if update.username or update.display_name:
        # Find or create the profile attached to this user.
        profile = session.exec(
            select(Profile).where(Profile.user_id == current_user.id)
        ).first()

        if not profile:
            # A brand-new profile cannot be created from a display name alone.
            if not update.username:
                raise HTTPException(
                    status_code=400,
                    detail="Username required for new profile",
                )
            _ensure_username_available(session, update.username)
            profile = Profile(
                user_id=current_user.id,
                username=update.username,
                display_name=update.display_name or update.username,
            )
        else:
            # Only hit the uniqueness check when the username actually changes,
            # otherwise the user's own profile would count as a collision.
            if update.username and update.username != profile.username:
                _ensure_username_available(session, update.username)
                profile.username = update.username
            if update.display_name:
                profile.display_name = update.display_name
        session.add(profile)

    session.add(current_user)
    session.commit()
    session.refresh(current_user)
    return current_user


@router.patch("/me/preferences", response_model=UserPreferencesUpdate)
def update_preferences(
    prefs: UserPreferencesUpdate,
    current_user: User = Depends(get_current_user),
    session: Session = Depends(get_session)
):
    """Create or update the current user's preferences and return them.

    Fields left as ``None`` in the request keep their stored value; when no
    preferences row exists yet they fall back to wiki_mode=False,
    show_ratings=True, show_comments=True.
    """
    db_prefs = current_user.preferences
    if not db_prefs:
        # Build exactly ONE preferences object. Previously a second, empty
        # instance was also assigned to ``current_user.preferences``; if the
        # user is attached to this session that instance can be persisted via
        # the relationship as well, yielding two rows for the same user.
        db_prefs = UserPreferences(
            user_id=current_user.id,
            wiki_mode=prefs.wiki_mode if prefs.wiki_mode is not None else False,
            show_ratings=prefs.show_ratings if prefs.show_ratings is not None else True,
            show_comments=prefs.show_comments if prefs.show_comments is not None else True,
        )
    else:
        if prefs.wiki_mode is not None:
            db_prefs.wiki_mode = prefs.wiki_mode
        if prefs.show_ratings is not None:
            db_prefs.show_ratings = prefs.show_ratings
        if prefs.show_comments is not None:
            db_prefs.show_comments = prefs.show_comments

    session.add(db_prefs)
    session.commit()
    session.refresh(db_prefs)
    return db_prefs


# --- User Stats ---

@router.get("/{user_id}/stats")
def get_user_stats(user_id: int, session: Session = Depends(get_session)):
    """Return attendance, review and group-membership counts for a user.

    Responds with 404 when the user does not exist.
    """
    user = session.get(User, user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    attendance_count = session.exec(
        select(func.count(Attendance.id)).where(Attendance.user_id == user_id)
    ).one()
    review_count = session.exec(
        select(func.count(Review.id)).where(Review.user_id == user_id)
    ).one()
    group_count = session.exec(
        select(func.count(GroupMember.id)).where(GroupMember.user_id == user_id)
    ).one()

    return {
        "attendance_count": attendance_count,
        "review_count": review_count,
        "group_count": group_count
    }


# --- User Data Lists ---
#
# Pagination note for the three list endpoints below: ``offset`` and ``limit``
# are bounded below by 0. Without that bound a negative LIMIT reaches the
# database, where it is either an error or (on SQLite) means "no limit",
# which would defeat the ``le=100`` cap.

@router.get("/{user_id}/attendance", response_model=List[ShowRead])
def get_user_attendance(
    user_id: int,
    offset: int = Query(default=0, ge=0),
    limit: int = Query(default=50, ge=0, le=100),
    session: Session = Depends(get_session)
):
    """List shows the user attended, newest show date first (paginated)."""
    shows = session.exec(
        select(Show)
        .join(Attendance, Show.id == Attendance.show_id)
        .where(Attendance.user_id == user_id)
        # Show.id breaks ties between equal dates so pages stay stable.
        .order_by(Show.date.desc(), Show.id.desc())
        .offset(offset)
        .limit(limit)
    ).all()
    return shows


@router.get("/{user_id}/reviews", response_model=List[ReviewRead])
def get_user_reviews(
    user_id: int,
    offset: int = Query(default=0, ge=0),
    limit: int = Query(default=50, ge=0, le=100),
    session: Session = Depends(get_session)
):
    """List the user's reviews, most recently created first (paginated)."""
    reviews = session.exec(
        select(Review)
        .where(Review.user_id == user_id)
        # Review.id breaks ties between equal timestamps so pages stay stable.
        .order_by(Review.created_at.desc(), Review.id.desc())
        .offset(offset)
        .limit(limit)
    ).all()
    return reviews


@router.get("/{user_id}/groups", response_model=List[GroupRead])
def get_user_groups(
    user_id: int,
    offset: int = Query(default=0, ge=0),
    limit: int = Query(default=50, ge=0, le=100),
    session: Session = Depends(get_session)
):
    """List the groups the user is a member of (paginated)."""
    groups = session.exec(
        select(Group)
        .join(GroupMember, Group.id == GroupMember.group_id)
        .where(GroupMember.user_id == user_id)
        # OFFSET/LIMIT without ORDER BY has no guaranteed row order, so
        # consecutive pages could repeat or skip groups.
        .order_by(Group.id)
        .offset(offset)
        .limit(limit)
    ).all()
    return groups