"""Social endpoints mounted under ``/social``: comments and ratings."""

from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session, select, func

from database import get_session
from models import Comment, Rating, User
from schemas import CommentCreate, CommentRead, RatingCreate, RatingRead
from auth import get_current_user

router = APIRouter(prefix="/social", tags=["social"])


# --- Comments ---


@router.post("/comments", response_model=CommentRead)
def create_comment(
    comment: CommentCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Create a comment authored by the authenticated user.

    The ``user_id`` is always taken from the authenticated user, never from
    the request body. Returns the persisted comment (refreshed so that
    database-generated fields are populated).
    """
    db_comment = Comment.model_validate(comment)
    db_comment.user_id = current_user.id
    session.add(db_comment)
    session.commit()
    session.refresh(db_comment)
    # TODO: notify the parent comment's author on replies once the Comment
    # model has a parent_id; replies are not modeled yet, so nothing is sent.
    return db_comment


@router.get("/comments", response_model=List[CommentRead])
def read_comments(
    show_id: Optional[int] = None,
    venue_id: Optional[int] = None,
    song_id: Optional[int] = None,
    # ge=0 on both: a negative LIMIT means "no upper bound" on some
    # databases (e.g. SQLite), which would bypass the le=100 cap.
    offset: int = Query(default=0, ge=0),
    limit: int = Query(default=50, ge=0, le=100),
    session: Session = Depends(get_session),
):
    """List comments, newest first, optionally filtered by target.

    Each provided id narrows the result (filters are combined with AND).
    ``offset``/``limit`` paginate the result; ``limit`` is capped at 100.
    """
    query = select(Comment)
    # Compare against None so that an id of 0 is still treated as a filter.
    if show_id is not None:
        query = query.where(Comment.show_id == show_id)
    if venue_id is not None:
        query = query.where(Comment.venue_id == venue_id)
    if song_id is not None:
        query = query.where(Comment.song_id == song_id)

    query = query.order_by(Comment.created_at.desc()).offset(offset).limit(limit)
    return session.exec(query).all()


# --- Ratings ---


def _rating_target_clause(
    show_id: Optional[int],
    song_id: Optional[int],
    performance_id: Optional[int],
    missing_detail: str,
):
    """Return the WHERE clause selecting ratings of a single target entity.

    Only the first provided id is used, in priority order show, song,
    performance; the others are ignored. Raises ``HTTPException(400)`` with
    ``missing_detail`` when no id is provided.
    """
    if show_id is not None:
        return Rating.show_id == show_id
    if song_id is not None:
        return Rating.song_id == song_id
    if performance_id is not None:
        return Rating.performance_id == performance_id
    raise HTTPException(status_code=400, detail=missing_detail)


@router.post("/ratings", response_model=RatingRead)
def create_rating(
    rating: RatingCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Create or update the authenticated user's rating of one entity.

    If the user already has a rating matching the target (show, song, or
    performance, chosen in that priority order), only its ``score`` is
    overwritten; otherwise a new rating is inserted.

    Raises:
        HTTPException: 400 if the payload names no show, song, or performance.
    """
    # Resolve the target first so an invalid payload fails before any query.
    target = _rating_target_clause(
        rating.show_id,
        rating.song_id,
        rating.performance_id,
        "Must rate a show, song, or performance",
    )

    query = select(Rating).where(Rating.user_id == current_user.id).where(target)
    existing_rating = session.exec(query).first()

    if existing_rating:
        # Update in place rather than creating a duplicate rating.
        existing_rating.score = rating.score
        session.add(existing_rating)
        session.commit()
        session.refresh(existing_rating)
        return existing_rating

    db_rating = Rating.model_validate(rating)
    db_rating.user_id = current_user.id
    session.add(db_rating)
    session.commit()
    session.refresh(db_rating)
    return db_rating


@router.get("/ratings/average", response_model=float)
def get_average_rating(
    show_id: Optional[int] = None,
    song_id: Optional[int] = None,
    performance_id: Optional[int] = None,
    session: Session = Depends(get_session),
):
    """Return the mean ``score`` of all ratings for one entity.

    The target is chosen by the first provided id in priority order show,
    song, performance. Returns 0.0 when the entity has no ratings (the SQL
    AVG is NULL in that case).

    Raises:
        HTTPException: 400 if no show_id, song_id, or performance_id is given.
    """
    target = _rating_target_clause(
        show_id,
        song_id,
        performance_id,
        "Must specify show_id, song_id, or performance_id",
    )
    avg = session.exec(select(func.avg(Rating.score)).where(target)).first()
    return float(avg) if avg is not None else 0.0