feat: enhance calendar ICS generation with event type filtering
All checks were successful
CI / lint-and-test (push) Successful in 22s
All checks were successful
CI / lint-and-test (push) Successful in 22s
- Added support for filtering calendar events by type in the ICS generation API endpoint, allowing users to specify whether to include only duty shifts or all event types (duty, unavailable, vacation). - Updated the `get_duties_for_user` function to accept an optional `event_types` parameter, enabling more flexible data retrieval based on user preferences. - Enhanced unit tests to cover the new event type filtering functionality, ensuring correct behavior and reliability of the ICS generation process.
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import logging
|
||||
import re
|
||||
from datetime import date, timedelta
|
||||
from typing import Literal
|
||||
|
||||
import duty_teller.config as config
|
||||
from fastapi import Depends, FastAPI, Request
|
||||
@@ -91,16 +92,20 @@ def list_calendar_events(
|
||||
"/api/calendar/ical/{token}.ics",
|
||||
summary="Personal calendar ICS",
|
||||
description=(
|
||||
"Returns an ICS calendar with only the subscribing user's duties. "
|
||||
"Returns an ICS calendar with the subscribing user's events. "
|
||||
"By default only duty shifts are included; use query parameter events=all "
|
||||
"for all event types (duty, unavailable, vacation). "
|
||||
"No Telegram auth; access is by secret token in the URL."
|
||||
),
|
||||
)
|
||||
def get_personal_calendar_ical(
|
||||
token: str,
|
||||
events: Literal["duty", "all"] = "duty",
|
||||
session: Session = Depends(get_db_session),
|
||||
) -> Response:
|
||||
"""
|
||||
Return ICS calendar with only the subscribing user's duties.
|
||||
Return ICS calendar with the subscribing user's events.
|
||||
Default: only duty shifts. Use ?events=all for duty, unavailable, vacation.
|
||||
No Telegram auth; access is by secret token in the URL.
|
||||
"""
|
||||
if not _is_valid_calendar_token(token):
|
||||
@@ -111,8 +116,9 @@ def get_personal_calendar_ical(
|
||||
today = date.today()
|
||||
from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d")
|
||||
event_types = ["duty"] if events == "duty" else None
|
||||
duties_with_name = get_duties_for_user(
|
||||
session, user.id, from_date=from_date, to_date=to_date
|
||||
session, user.id, from_date=from_date, to_date=to_date, event_types=event_types
|
||||
)
|
||||
ics_bytes = build_personal_ics(duties_with_name)
|
||||
return Response(
|
||||
|
||||
@@ -213,14 +213,19 @@ def get_duties_for_user(
|
||||
user_id: int,
|
||||
from_date: str,
|
||||
to_date: str,
|
||||
event_types: list[str] | None = None,
|
||||
) -> list[tuple[Duty, str]]:
|
||||
"""Return duties for one user overlapping the date range.
|
||||
|
||||
Optionally filter by event_type (e.g. "duty", "unavailable", "vacation").
|
||||
When event_types is None, all event types are returned.
|
||||
|
||||
Args:
|
||||
session: DB session.
|
||||
user_id: User id.
|
||||
from_date: Start date YYYY-MM-DD.
|
||||
to_date: End date YYYY-MM-DD.
|
||||
event_types: If not None, only return duties whose event_type is in this list.
|
||||
|
||||
Returns:
|
||||
List of (Duty, full_name) tuples.
|
||||
@@ -228,14 +233,17 @@ def get_duties_for_user(
|
||||
to_date_next = (
|
||||
datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)
|
||||
).strftime("%Y-%m-%d")
|
||||
filters = [
|
||||
Duty.user_id == user_id,
|
||||
Duty.start_at < to_date_next,
|
||||
Duty.end_at >= from_date,
|
||||
]
|
||||
if event_types is not None:
|
||||
filters.append(Duty.event_type.in_(event_types))
|
||||
q = (
|
||||
session.query(Duty, User.full_name)
|
||||
.join(User, Duty.user_id == User.id)
|
||||
.filter(
|
||||
Duty.user_id == user_id,
|
||||
Duty.start_at < to_date_next,
|
||||
Duty.end_at >= from_date,
|
||||
)
|
||||
.filter(*filters)
|
||||
)
|
||||
return list(q.all())
|
||||
|
||||
|
||||
@@ -304,7 +304,9 @@ def test_calendar_ical_200_returns_only_that_users_duties(
|
||||
assert r.headers.get("content-type", "").startswith("text/calendar")
|
||||
assert b"BEGIN:VCALENDAR" in r.content
|
||||
mock_get_user.assert_called_once()
|
||||
mock_get_duties.assert_called_once_with(ANY, 1, from_date=ANY, to_date=ANY)
|
||||
mock_get_duties.assert_called_once_with(
|
||||
ANY, 1, from_date=ANY, to_date=ANY, event_types=["duty"]
|
||||
)
|
||||
mock_build_ics.assert_called_once()
|
||||
# Only User A's duty was passed to build_personal_ics
|
||||
duties_arg = mock_build_ics.call_args[0][0]
|
||||
@@ -313,6 +315,59 @@ def test_calendar_ical_200_returns_only_that_users_duties(
|
||||
assert duties_arg[0][1] == "User A"
|
||||
|
||||
|
||||
@patch("duty_teller.api.app.build_personal_ics")
|
||||
@patch("duty_teller.api.app.get_duties_for_user")
|
||||
@patch("duty_teller.api.app.get_user_by_calendar_token")
|
||||
def test_calendar_ical_events_all_returns_all_event_types(
|
||||
mock_get_user, mock_get_duties, mock_build_ics, client
|
||||
):
|
||||
"""GET /api/calendar/ical/{token}.ics?events=all returns ICS with duty, unavailable, vacation."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
mock_user = SimpleNamespace(id=1, full_name="User A")
|
||||
mock_get_user.return_value = mock_user
|
||||
duty = SimpleNamespace(
|
||||
id=10,
|
||||
user_id=1,
|
||||
start_at="2026-06-15T09:00:00Z",
|
||||
end_at="2026-06-15T18:00:00Z",
|
||||
event_type="duty",
|
||||
)
|
||||
unavailable = SimpleNamespace(
|
||||
id=11,
|
||||
user_id=1,
|
||||
start_at="2026-06-16T09:00:00Z",
|
||||
end_at="2026-06-16T18:00:00Z",
|
||||
event_type="unavailable",
|
||||
)
|
||||
vacation = SimpleNamespace(
|
||||
id=12,
|
||||
user_id=1,
|
||||
start_at="2026-06-17T09:00:00Z",
|
||||
end_at="2026-06-17T18:00:00Z",
|
||||
event_type="vacation",
|
||||
)
|
||||
mock_get_duties.return_value = [
|
||||
(duty, "User A"),
|
||||
(unavailable, "User A"),
|
||||
(vacation, "User A"),
|
||||
]
|
||||
mock_build_ics.return_value = b"BEGIN:VCALENDAR\r\nVEVENT\r\nEND:VCALENDAR"
|
||||
token = "y" * 43
|
||||
|
||||
r = client.get(f"/api/calendar/ical/{token}.ics", params={"events": "all"})
|
||||
assert r.status_code == 200
|
||||
assert r.headers.get("content-type", "").startswith("text/calendar")
|
||||
mock_get_duties.assert_called_once_with(
|
||||
ANY, 1, from_date=ANY, to_date=ANY, event_types=None
|
||||
)
|
||||
duties_arg = mock_build_ics.call_args[0][0]
|
||||
assert len(duties_arg) == 3
|
||||
assert duties_arg[0][0].event_type == "duty"
|
||||
assert duties_arg[1][0].event_type == "unavailable"
|
||||
assert duties_arg[2][0].event_type == "vacation"
|
||||
|
||||
|
||||
# --- /api/calendar-events ---
|
||||
|
||||
|
||||
@@ -330,7 +385,10 @@ def test_calendar_events_empty_url_returns_empty_list(client):
|
||||
mock_get.assert_not_called()
|
||||
|
||||
|
||||
@patch("duty_teller.api.app.config.EXTERNAL_CALENDAR_ICS_URL", "https://example.com/cal.ics")
|
||||
@patch(
|
||||
"duty_teller.api.app.config.EXTERNAL_CALENDAR_ICS_URL",
|
||||
"https://example.com/cal.ics",
|
||||
)
|
||||
@patch("duty_teller.api.app.config.MINI_APP_SKIP_AUTH", True)
|
||||
def test_calendar_events_200_returns_list_with_date_summary(client):
|
||||
"""GET /api/calendar-events with auth and URL set returns list of {date, summary}."""
|
||||
|
||||
@@ -106,7 +106,9 @@ class TestGetEventsFromIcs:
|
||||
assert result == []
|
||||
|
||||
def test_broken_ics_returns_empty_list(self):
|
||||
result = mod._get_events_from_ics(b"not ical at all", "2025-01-01", "2025-01-31")
|
||||
result = mod._get_events_from_ics(
|
||||
b"not ical at all", "2025-01-01", "2025-01-31"
|
||||
)
|
||||
assert result == []
|
||||
|
||||
def test_recurring_events_skipped(self):
|
||||
@@ -128,11 +130,18 @@ class TestGetCalendarEvents:
|
||||
assert mod.get_calendar_events("", "2025-01-01", "2025-01-31") == []
|
||||
|
||||
def test_from_after_to_returns_empty(self):
|
||||
assert mod.get_calendar_events("https://example.com/a.ics", "2025-02-01", "2025-01-01") == []
|
||||
assert (
|
||||
mod.get_calendar_events(
|
||||
"https://example.com/a.ics", "2025-02-01", "2025-01-01"
|
||||
)
|
||||
== []
|
||||
)
|
||||
|
||||
@patch.object(mod, "_fetch_ics", return_value=None)
|
||||
def test_fetch_returns_none_returns_empty(self, mock_fetch):
|
||||
result = mod.get_calendar_events("https://example.com/a.ics", "2025-01-01", "2025-01-31")
|
||||
result = mod.get_calendar_events(
|
||||
"https://example.com/a.ics", "2025-01-01", "2025-01-31"
|
||||
)
|
||||
assert result == []
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for duty_teller.services.group_duty_pin_service."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from duty_teller.db.models import Base, Duty, GroupDutyPin, User
|
||||
from duty_teller.db.models import Base, Duty, User
|
||||
from duty_teller.services import group_duty_pin_service as svc
|
||||
|
||||
|
||||
@@ -72,7 +72,9 @@ class TestFormatDutyMessage:
|
||||
mock_t.assert_called_with("en", "duty.no_duty")
|
||||
|
||||
def test_none_user_returns_no_duty(self):
|
||||
duty = SimpleNamespace(start_at="2025-01-15T09:00:00Z", end_at="2025-01-15T18:00:00Z")
|
||||
duty = SimpleNamespace(
|
||||
start_at="2025-01-15T09:00:00Z", end_at="2025-01-15T18:00:00Z"
|
||||
)
|
||||
with patch("duty_teller.services.group_duty_pin_service.t") as mock_t:
|
||||
mock_t.return_value = "No duty"
|
||||
result = svc.format_duty_message(duty, None, "Europe/Moscow", "en")
|
||||
@@ -89,9 +91,7 @@ class TestFormatDutyMessage:
|
||||
username="ivan",
|
||||
)
|
||||
with patch("duty_teller.services.group_duty_pin_service.t") as mock_t:
|
||||
mock_t.side_effect = lambda lang, key: (
|
||||
"Duty" if key == "duty.label" else ""
|
||||
)
|
||||
mock_t.side_effect = lambda lang, key: "Duty" if key == "duty.label" else ""
|
||||
result = svc.format_duty_message(duty, user, "Europe/Moscow", "ru")
|
||||
assert "Иван Иванов" in result
|
||||
assert "+79001234567" in result or "79001234567" in result
|
||||
|
||||
@@ -183,10 +183,14 @@ async def test_calendar_link_with_user_and_token_replies_with_url():
|
||||
"duty_teller.handlers.commands.create_calendar_token",
|
||||
return_value="abc43token",
|
||||
):
|
||||
with patch("duty_teller.handlers.commands.get_lang", return_value="en"):
|
||||
with patch(
|
||||
"duty_teller.handlers.commands.get_lang", return_value="en"
|
||||
):
|
||||
with patch("duty_teller.handlers.commands.t") as mock_t:
|
||||
mock_t.side_effect = lambda lang, key, **kw: (
|
||||
f"URL: {kw.get('url', '')}" if "success" in key else "Hint"
|
||||
f"URL: {kw.get('url', '')}"
|
||||
if "success" in key
|
||||
else "Hint"
|
||||
)
|
||||
await calendar_link(update, MagicMock())
|
||||
message.reply_text.assert_called_once()
|
||||
|
||||
@@ -10,6 +10,7 @@ from duty_teller.handlers.errors import error_handler
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handler_replies_with_generic_message():
|
||||
"""error_handler: when update has effective_message, reply_text with errors.generic."""
|
||||
|
||||
# Handler checks isinstance(update, Update); patch Update so our mock passes.
|
||||
class FakeUpdate:
|
||||
pass
|
||||
|
||||
@@ -70,15 +70,11 @@ async def test_update_group_pin_edits_message_and_schedules_next():
|
||||
context.application.job_queue.get_jobs_by_name = MagicMock(return_value=[])
|
||||
context.application.job_queue.run_once = MagicMock()
|
||||
|
||||
with patch.object(
|
||||
mod, "_sync_get_message_id", return_value=1
|
||||
):
|
||||
with patch.object(mod, "_sync_get_message_id", return_value=1):
|
||||
with patch.object(
|
||||
mod, "_get_duty_message_text_sync", return_value="Current duty"
|
||||
):
|
||||
with patch.object(
|
||||
mod, "_get_next_shift_end_sync", return_value=None
|
||||
):
|
||||
with patch.object(mod, "_get_next_shift_end_sync", return_value=None):
|
||||
with patch.object(mod, "_schedule_next_update", AsyncMock()):
|
||||
await mod.update_group_pin(context)
|
||||
context.bot.edit_message_text.assert_called_once_with(
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
|
||||
from duty_teller.services.import_service import _consecutive_date_ranges
|
||||
|
||||
|
||||
@@ -7,9 +7,10 @@ from sqlalchemy.orm import sessionmaker
|
||||
from duty_teller.db.models import Base, User
|
||||
from duty_teller.db.repository import (
|
||||
delete_duties_in_range,
|
||||
get_duties,
|
||||
get_duties_for_user,
|
||||
get_or_create_user,
|
||||
get_or_create_user_by_full_name,
|
||||
get_duties,
|
||||
insert_duty,
|
||||
update_user_display_name,
|
||||
)
|
||||
@@ -113,6 +114,54 @@ def test_get_duties_includes_duty_starting_on_last_day_of_range(session, user_a)
|
||||
assert rows[0][1] == "User A"
|
||||
|
||||
|
||||
def test_get_duties_for_user_event_types_duty_returns_only_duty(session, user_a):
|
||||
"""get_duties_for_user(..., event_types=["duty"]) returns only duty records."""
|
||||
insert_duty(
|
||||
session,
|
||||
user_a.id,
|
||||
"2026-02-01T09:00:00Z",
|
||||
"2026-02-01T18:00:00Z",
|
||||
event_type="duty",
|
||||
)
|
||||
insert_duty(
|
||||
session,
|
||||
user_a.id,
|
||||
"2026-02-02T09:00:00Z",
|
||||
"2026-02-02T18:00:00Z",
|
||||
event_type="unavailable",
|
||||
)
|
||||
rows = get_duties_for_user(
|
||||
session, user_a.id, "2026-02-01", "2026-02-28", event_types=["duty"]
|
||||
)
|
||||
assert len(rows) == 1
|
||||
assert rows[0][0].event_type == "duty"
|
||||
assert rows[0][1] == "User A"
|
||||
|
||||
|
||||
def test_get_duties_for_user_event_types_none_returns_all(session, user_a):
|
||||
"""get_duties_for_user(..., event_types=None) returns duty and unavailable."""
|
||||
insert_duty(
|
||||
session,
|
||||
user_a.id,
|
||||
"2026-02-01T09:00:00Z",
|
||||
"2026-02-01T18:00:00Z",
|
||||
event_type="duty",
|
||||
)
|
||||
insert_duty(
|
||||
session,
|
||||
user_a.id,
|
||||
"2026-02-02T09:00:00Z",
|
||||
"2026-02-02T18:00:00Z",
|
||||
event_type="unavailable",
|
||||
)
|
||||
rows = get_duties_for_user(
|
||||
session, user_a.id, "2026-02-01", "2026-02-28", event_types=None
|
||||
)
|
||||
assert len(rows) == 2
|
||||
types = {rows[0][0].event_type, rows[1][0].event_type}
|
||||
assert types == {"duty", "unavailable"}
|
||||
|
||||
|
||||
def test_get_or_create_user_overwrites_name_when_flag_false(session):
|
||||
"""When name_manually_edited is False, second get_or_create_user overwrites name."""
|
||||
u1 = get_or_create_user(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for duty_teller.run (main entry point with mocks)."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
Reference in New Issue
Block a user