99 lines
3.2 KiB
Python
99 lines
3.2 KiB
Python
from typing import List, Optional
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from sqlmodel import Session, select, func
|
|
from database import get_session
|
|
from models import Comment, Rating, User
|
|
from schemas import CommentCreate, CommentRead, RatingCreate, RatingRead
|
|
from auth import get_current_user
|
|
|
|
# Router for the social endpoints (comments and ratings); every route declared
# on it is served under the "/social" path prefix and grouped under the
# "social" tag.
router = APIRouter(prefix="/social", tags=["social"])
|
|
|
|
# --- Comments ---
|
|
|
|
@router.post("/comments", response_model=CommentRead)
def create_comment(
    comment: CommentCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Store a new comment written by the authenticated user.

    Args:
        comment: Validated request payload describing the comment.
        session: Database session provided by ``get_session``.
        current_user: The caller, resolved by ``get_current_user``.

    Returns:
        The persisted ``Comment``, reloaded from the database after commit.
    """
    new_comment = Comment.model_validate(comment)
    # The author is always the authenticated caller; this assignment replaces
    # whatever user_id the payload may have carried.
    new_comment.user_id = current_user.id

    session.add(new_comment)
    session.commit()
    # Reload the row so values populated by the database are on the object.
    session.refresh(new_comment)

    # TODO: notify the parent comment's author on replies once Comment has a
    # parent_id field; no notification is sent yet.
    return new_comment
|
|
|
|
@router.get("/comments", response_model=List[CommentRead])
def read_comments(
    show_id: Optional[int] = None,
    venue_id: Optional[int] = None,
    song_id: Optional[int] = None,
    offset: int = Query(default=0, ge=0),
    limit: int = Query(default=50, ge=0, le=100),
    session: Session = Depends(get_session),
):
    """List comments, newest first, optionally filtered by target entity.

    Args:
        show_id: If given, only comments whose ``show_id`` matches.
        venue_id: If given, only comments whose ``venue_id`` matches.
        song_id: If given, only comments whose ``song_id`` matches.
        offset: Number of matching comments to skip (must be >= 0).
        limit: Maximum number of comments to return (0-100, default 50).
        session: Database session provided by ``get_session``.

    Returns:
        The matching comments ordered by ``created_at`` descending.
    """
    query = select(Comment)
    # Compare against None instead of relying on truthiness, so an id of 0
    # still applies its filter.
    if show_id is not None:
        query = query.where(Comment.show_id == show_id)
    if venue_id is not None:
        query = query.where(Comment.venue_id == venue_id)
    if song_id is not None:
        query = query.where(Comment.song_id == song_id)

    # The ge=0 bounds above matter: PostgreSQL rejects a negative OFFSET or
    # LIMIT (surfacing as a 500), and SQLite treats a negative LIMIT as
    # "no limit", which would bypass the le=100 cap.
    query = query.order_by(Comment.created_at.desc()).offset(offset).limit(limit)
    return session.exec(query).all()
|
|
|
|
# --- Ratings ---
|
|
|
|
@router.post("/ratings", response_model=RatingRead)
def create_rating(
    rating: RatingCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Create or update the authenticated user's rating of a show or song.

    If the user already has a rating for the same target, its score is
    overwritten. Otherwise a new rating is inserted, owned by the caller.

    Args:
        rating: Validated payload; must name a show, a song, or both.
        session: Database session provided by ``get_session``.
        current_user: The caller, resolved by ``get_current_user``.

    Returns:
        The created or updated ``Rating``.

    Raises:
        HTTPException: 400 if the payload names neither a show nor a song.
    """
    if rating.show_id is None and rating.song_id is None:
        raise HTTPException(status_code=400, detail="Must rate a show or song")

    # Look for this user's existing rating of the same target, filtering on
    # every id the payload supplies. Checking only show_id when both ids were
    # given matched any of the user's ratings under that show, so rating one
    # song could overwrite a different rating.
    # NOTE(review): check-then-insert is not atomic; concurrent requests can
    # still create duplicates unless the table enforces uniqueness — confirm.
    query = select(Rating).where(Rating.user_id == current_user.id)
    if rating.show_id is not None:
        query = query.where(Rating.show_id == rating.show_id)
    if rating.song_id is not None:
        query = query.where(Rating.song_id == rating.song_id)

    existing_rating = session.exec(query).first()
    if existing_rating is not None:
        # Update in place instead of inserting a second rating.
        existing_rating.score = rating.score
        session.add(existing_rating)
        session.commit()
        session.refresh(existing_rating)
        return existing_rating

    db_rating = Rating.model_validate(rating)
    # Ownership always comes from the authenticated caller.
    db_rating.user_id = current_user.id
    session.add(db_rating)
    session.commit()
    session.refresh(db_rating)
    return db_rating
|
|
|
|
@router.get("/ratings/average", response_model=float)
def get_average_rating(
    show_id: Optional[int] = None,
    song_id: Optional[int] = None,
    session: Session = Depends(get_session),
):
    """Return the mean rating score for a show or a song.

    ``show_id`` takes precedence when both are supplied.

    Args:
        show_id: Average over ratings whose ``show_id`` matches.
        song_id: Average over ratings whose ``song_id`` matches (used only
            when ``show_id`` is not given).
        session: Database session provided by ``get_session``.

    Returns:
        The average ``Rating.score`` as a float, or ``0.0`` when no ratings
        match.

    Raises:
        HTTPException: 400 if neither ``show_id`` nor ``song_id`` is given.
    """
    query = select(func.avg(Rating.score))
    # Explicit None checks so an id of 0 is not mistaken for "not provided".
    if show_id is not None:
        query = query.where(Rating.show_id == show_id)
    elif song_id is not None:
        query = query.where(Rating.song_id == song_id)
    else:
        raise HTTPException(status_code=400, detail="Must specify show_id or song_id")

    # An aggregate without GROUP BY always yields exactly one row. AVG over
    # zero rows is NULL, which is reported as 0.0.
    avg = session.exec(query).one()
    return float(avg) if avg is not None else 0.0
|