123 lines
4.2 KiB
Python
123 lines
4.2 KiB
Python
import re
|
|
from typing import List, Optional
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from sqlmodel import Session, select, func
|
|
from database import get_session
|
|
from models import Comment, Rating, User, Profile
|
|
from schemas import CommentCreate, CommentRead, RatingCreate, RatingRead
|
|
from auth import get_current_user
|
|
from helpers import create_notification
|
|
|
|
# Router collecting the comment and rating endpoints defined in this module;
# every route registered on it is served under the "/social" prefix.
router = APIRouter(prefix="/social", tags=["social"])
|
|
|
|
# --- Comments ---
|
|
|
|
# Matches "@username" only when the "@" is not preceded by a word character,
# so e-mail addresses such as "alice@example.com" are not treated as mentions.
_MENTION_RE = re.compile(r"(?<!\w)@(\w+)")


@router.post("/comments", response_model=CommentRead)
def create_comment(
    comment: CommentCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Store a new comment for the current user and notify mentioned users.

    The comment is validated into a ``Comment`` row, stamped with the id of
    ``current_user``, committed and refreshed. Afterwards the content is
    scanned for ``@username`` mentions; each profile whose username matches
    (other than the author's own) gets a "mention" notification via
    ``create_notification``.

    Args:
        comment: Payload describing the comment to create.
        session: Database session supplied by the ``get_session`` dependency.
        current_user: User supplied by the ``get_current_user`` dependency.

    Returns:
        The persisted ``Comment`` row.
    """
    db_comment = Comment.model_validate(comment)
    db_comment.user_id = current_user.id
    session.add(db_comment)
    session.commit()
    session.refresh(db_comment)

    # TODO: notify the parent comment's author on replies once the Comment
    # model has a parent_id. Nothing is done for replies at the moment.

    # Handle mentions. Repeated mentions of the same name are collapsed
    # (order preserved) so the IN (...) clause carries each username once.
    usernames = list(dict.fromkeys(_MENTION_RE.findall(db_comment.content)))
    if usernames:
        # Find the profiles that own the mentioned usernames.
        mentioned_profiles = session.exec(
            select(Profile).where(Profile.username.in_(usernames))
        ).all()
        for profile in mentioned_profiles:
            # Authors are not notified about mentioning themselves.
            if profile.user_id == current_user.id:
                continue
            create_notification(
                session,
                user_id=profile.user_id,
                title="You were mentioned!",
                message="Someone mentioned you in a comment.",
                type="mention",
                link="/activity",  # Generic link for now
            )

    return db_comment
|
|
|
|
@router.get("/comments", response_model=List[CommentRead])
def read_comments(
    show_id: Optional[int] = None,
    venue_id: Optional[int] = None,
    song_id: Optional[int] = None,
    offset: int = Query(default=0, ge=0),
    limit: int = Query(default=50, ge=0, le=100),
    session: Session = Depends(get_session),
):
    """List comments, newest first, optionally filtered by target entity.

    Every filter that is supplied is applied (they are combined with AND).
    A filter counts as supplied whenever it is not ``None``, so an explicit
    id of ``0`` filters on 0 rather than being silently ignored.

    Args:
        show_id: Only return comments whose ``show_id`` equals this value.
        venue_id: Only return comments whose ``venue_id`` equals this value.
        song_id: Only return comments whose ``song_id`` equals this value.
        offset: Number of rows to skip; must be >= 0.
        limit: Maximum number of rows to return; between 0 and 100.
        session: Database session supplied by the ``get_session`` dependency.

    Returns:
        The matching ``Comment`` rows ordered by ``created_at`` descending.
    """
    query = select(Comment)
    if show_id is not None:
        query = query.where(Comment.show_id == show_id)
    if venue_id is not None:
        query = query.where(Comment.venue_id == venue_id)
    if song_id is not None:
        query = query.where(Comment.song_id == song_id)

    # Negative offset/limit are rejected by the Query bounds above, so the
    # pagination values handed to the database are always non-negative.
    query = query.order_by(Comment.created_at.desc()).offset(offset).limit(limit)
    return session.exec(query).all()
|
|
|
|
# --- Ratings ---
|
|
|
|
@router.post("/ratings", response_model=RatingRead)
def create_rating(
    rating: RatingCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Create the current user's rating for an entity, or update its score.

    The target is the first truthy id among ``show_id``, ``song_id`` and
    ``performance_id`` (checked in that order). If the user already has a
    rating for that target, only its ``score`` is overwritten; otherwise a
    new ``Rating`` row is built from the payload and assigned to the user.

    Args:
        rating: Payload carrying the score and the id of the rated entity.
        session: Database session supplied by the ``get_session`` dependency.
        current_user: User supplied by the ``get_current_user`` dependency.

    Returns:
        The updated or newly created ``Rating`` row, refreshed after commit.

    Raises:
        HTTPException: 400 when none of the three target ids is set.
    """
    # Start from "ratings by this user", then narrow to the rated entity.
    lookup = select(Rating).where(Rating.user_id == current_user.id)
    if rating.show_id:
        lookup = lookup.where(Rating.show_id == rating.show_id)
    elif rating.song_id:
        lookup = lookup.where(Rating.song_id == rating.song_id)
    elif rating.performance_id:
        lookup = lookup.where(Rating.performance_id == rating.performance_id)
    else:
        raise HTTPException(status_code=400, detail="Must rate a show, song, or performance")

    record = session.exec(lookup).first()
    if record:
        # The user rated this entity before: only the score changes.
        record.score = rating.score
    else:
        record = Rating.model_validate(rating)
        record.user_id = current_user.id

    # Both paths persist the row the same way.
    session.add(record)
    session.commit()
    session.refresh(record)
    return record
|
|
|
|
@router.get("/ratings/average", response_model=float)
def get_average_rating(
    show_id: Optional[int] = None,
    song_id: Optional[int] = None,
    performance_id: Optional[int] = None,
    session: Session = Depends(get_session),
):
    """Return the average rating score for one show, song or performance.

    Exactly one filter is used: the first truthy id among ``show_id``,
    ``song_id`` and ``performance_id``, in that order of precedence.

    Args:
        show_id: Average the ratings whose ``show_id`` equals this value.
        song_id: Average the ratings whose ``song_id`` equals this value.
        performance_id: Average the ratings whose ``performance_id`` equals
            this value.
        session: Database session supplied by the ``get_session`` dependency.

    Returns:
        The average score as a float, or ``0.0`` when the aggregate comes
        back falsy (for example ``None``).

    Raises:
        HTTPException: 400 when none of the three ids is truthy.
    """
    # Ordered by precedence; the first pair with a truthy id wins.
    candidates = (
        (Rating.show_id, show_id),
        (Rating.song_id, song_id),
        (Rating.performance_id, performance_id),
    )
    chosen = next(((column, value) for column, value in candidates if value), None)
    if chosen is None:
        raise HTTPException(status_code=400, detail="Must specify show_id, song_id, or performance_id")

    column, value = chosen
    statement = select(func.avg(Rating.score)).where(column == value)

    mean_score = session.exec(statement).first()
    if not mean_score:
        return 0.0
    return float(mean_score)
|