88 lines
2.6 KiB
Python
88 lines
2.6 KiB
Python
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session, select, desc, update

from database import get_session
from models import Notification, User
from schemas import NotificationRead, NotificationCreate
from auth import get_current_user
|
|
|
|
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
|
|
|
@router.get("/", response_model=List[NotificationRead])
def read_notifications(
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
    limit: int = Query(20, ge=0),
    offset: int = Query(0, ge=0),
):
    """Return one page of the current user's notifications, newest first.

    Args:
        session: Database session injected via ``get_session``.
        current_user: Authenticated user injected via ``get_current_user``.
        limit: Maximum number of notifications to return (default 20,
            must be >= 0).
        offset: Number of notifications to skip (default 0, must be >= 0).

    Returns:
        The notifications whose ``user_id`` equals ``current_user.id``,
        ordered by ``created_at`` descending, after applying ``offset``
        and ``limit``.
    """
    # Negative paging values are rejected at the request boundary (FastAPI
    # answers with a validation error) instead of being forwarded to the
    # database, where their handling depends on the backend.
    # NOTE(review): no upper bound is enforced on ``limit`` so existing
    # clients keep working; consider adding ``le=...`` once a cap is agreed.
    statement = (
        select(Notification)
        .where(Notification.user_id == current_user.id)
        .order_by(desc(Notification.created_at))
        .offset(offset)
        .limit(limit)
    )
    return session.exec(statement).all()
|
|
|
|
@router.get("/unread-count")
def get_unread_count(
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user)
):
    """Count the current user's notifications that are still unread.

    Returns:
        A dict with a single key, ``"count"``, holding the result of the
        SQL ``COUNT`` over the matching notification ids.
    """
    from sqlmodel import func

    belongs_to_user = Notification.user_id == current_user.id
    # "== False" is deliberate: it builds a SQL comparison expression,
    # not a Python boolean.
    still_unread = Notification.is_read == False  # noqa: E712

    statement = select(func.count(Notification.id)).where(
        belongs_to_user,
        still_unread,
    )
    unread_total = session.exec(statement).one()
    return {"count": unread_total}
|
|
|
|
@router.post("/{notification_id}/read")
def mark_as_read(
    notification_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user)
):
    """Flag a single notification as read on behalf of its owner.

    Args:
        notification_id: Primary key of the notification, taken from the path.
        session: Database session injected via ``get_session``.
        current_user: Authenticated user injected via ``get_current_user``.

    Returns:
        ``{"ok": True}`` once the change has been committed.

    Raises:
        HTTPException: 404 when no notification has the given id;
            403 when the notification's ``user_id`` is not the current user's.
    """
    target = session.get(Notification, notification_id)

    # Guard 1: the row must exist.
    if target is None:
        raise HTTPException(status_code=404, detail="Notification not found")

    # Guard 2: only the owner may change the read flag.
    owned_by_caller = target.user_id == current_user.id
    if not owned_by_caller:
        raise HTTPException(status_code=403, detail="Not authorized")

    target.is_read = True
    session.add(target)
    session.commit()
    return {"ok": True}
|
|
|
|
@router.post("/mark-all-read")
def mark_all_read(
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user)
):
    """Mark every unread notification of the current user as read.

    Args:
        session: Database session injected via ``get_session``.
        current_user: Authenticated user injected via ``get_current_user``.

    Returns:
        ``{"ok": True}`` once the change has been committed.
    """
    # A single bulk UPDATE replaces the previous "load every unread row,
    # flip the flag object by object" loop, so the work no longer grows
    # with the number of unread notifications held in memory.
    statement = (
        update(Notification)
        .where(Notification.user_id == current_user.id)
        # "== False" builds a SQL comparison, not a Python boolean.
        .where(Notification.is_read == False)  # noqa: E712
        .values(is_read=True)
    )
    session.exec(statement)
    session.commit()
    return {"ok": True}
|
|
|
|
# Helper function to create notifications (not an endpoint)
|
|
def create_notification(
    session: Session,
    user_id: int,
    type: str,
    title: str,
    message: str,
    link: Optional[str] = None,
) -> Notification:
    """Persist a new notification and return the stored instance.

    Args:
        session: Open database session used to add and commit the row.
        user_id: Value stored in the notification's ``user_id`` field.
        type: Value stored in the notification's ``type`` field. The name
            shadows the builtin but is kept so keyword callers keep working.
        title: Value stored in the notification's ``title`` field.
        message: Value stored in the notification's ``message`` field.
        link: Optional value for the ``link`` field; ``None`` when omitted.

    Returns:
        The ``Notification`` instance after commit and refresh.
    """
    notification = Notification(
        user_id=user_id,
        type=type,
        title=title,
        message=message,
        link=link,
    )
    session.add(notification)
    session.commit()
    # Reload the instance from the database after the commit so the caller
    # gets current attribute values on the returned object.
    session.refresh(notification)
    return notification
|