from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session, desc, func, select

# Local imports kept in their original order (import order can matter for
# project modules with import-time side effects or cycles).
from database import get_session
from models import Notification, User
from schemas import NotificationRead, NotificationCreate
from auth import get_current_user

router = APIRouter(prefix="/notifications", tags=["notifications"])


@router.get("/", response_model=List[NotificationRead])
def read_notifications(
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
    limit: int = Query(default=20, ge=0),
    offset: int = Query(default=0, ge=0),
):
    """Return a page of the current user's notifications, newest first.

    Args:
        limit: Maximum number of notifications to return (default 20).
        offset: Number of notifications to skip (default 0).

    Both values are validated as non-negative, so a bad query string yields a
    422 response instead of passing a negative LIMIT/OFFSET to the database
    (whose handling differs between backends).
    """
    statement = (
        select(Notification)
        .where(Notification.user_id == current_user.id)
        .order_by(desc(Notification.created_at))
        .offset(offset)
        .limit(limit)
    )
    return session.exec(statement).all()


@router.get("/unread-count")
def get_unread_count(
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Return ``{"count": n}``, n being the current user's unread notifications."""
    count = session.exec(
        select(func.count(Notification.id))
        .where(Notification.user_id == current_user.id)
        # `== False` is deliberate: on a column it builds a SQL comparison,
        # whereas `not col` / `col is False` would be evaluated in Python.
        .where(Notification.is_read == False)  # noqa: E712
    ).one()
    return {"count": count}


@router.post("/{notification_id}/read")
def mark_as_read(
    notification_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Mark a single notification owned by the current user as read.

    Raises:
        HTTPException: 404 if no notification has ``notification_id``;
            403 if it belongs to a different user.
    """
    notification = session.get(Notification, notification_id)
    if notification is None:
        raise HTTPException(status_code=404, detail="Notification not found")
    # NOTE(review): answering 403 here (instead of 404) reveals to the caller
    # that the id exists for some other user; consider 404 for both cases.
    if notification.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")
    notification.is_read = True
    session.add(notification)
    session.commit()
    return {"ok": True}


@router.post("/mark-all-read")
def mark_all_read(
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
):
    """Mark every unread notification of the current user as read."""
    unread = session.exec(
        select(Notification)
        .where(Notification.user_id == current_user.id)
        .where(Notification.is_read == False)  # noqa: E712 - SQL expression
    ).all()
    # These rows were loaded through this session and are already tracked by
    # it, so assigning the attribute is enough; one commit flushes them all.
    for notification in unread:
        notification.is_read = True
    session.commit()
    return {"ok": True}


# Helper function to create notifications (not an endpoint)
def create_notification(
    session: Session,
    user_id: int,
    type: str,  # shadows the builtin; name kept for callers passing type=...
    title: str,
    message: str,
    link: Optional[str] = None,
) -> Notification:
    """Create, commit and return a new notification for ``user_id``.

    The returned instance is refreshed after the commit, so database-generated
    fields are populated. Note that this commits ``session`` itself, so any
    other pending changes the caller made in that session are committed too.
    """
    notification = Notification(
        user_id=user_id,
        type=type,
        title=title,
        message=message,
        link=link,
    )
    session.add(notification)
    session.commit()
    session.refresh(notification)
    return notification