feat: implement caching for duty-related data and enhance performance
All checks were successful
CI / lint-and-test (push) Successful in 24s
Docker Build and Release / build-and-push (push) Successful in 49s
Docker Build and Release / release (push) Successful in 8s

- Added a TTLCache class for in-memory caching of duty-related data, improving performance by reducing database queries.
- Integrated caching into the group duty pin functionality, allowing for efficient retrieval of message text and next shift end times.
- Introduced new methods to invalidate caches when relevant data changes, ensuring data consistency.
- Created a new Alembic migration to add indexes on the duties table for improved query performance.
- Updated tests to cover the new caching behavior and ensure proper functionality.
This commit is contained in:
2026-02-25 13:25:34 +03:00
parent 5334a4aeac
commit 0e8d1453e2
14 changed files with 413 additions and 113 deletions

BIN
.coverage

Binary file not shown.

View File

@@ -0,0 +1,44 @@
"""Add indexes on duties table for performance.
Revision ID: 008
Revises: 007
Create Date: 2025-02-25
Indexes for get_current_duty, get_next_shift_end, get_duties, get_duties_for_user.
"""
from typing import Sequence, Union
from alembic import op
revision: str = "008"
down_revision: Union[str, None] = "007"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_index(
"ix_duties_event_type_start_at",
"duties",
["event_type", "start_at"],
unique=False,
)
op.create_index(
"ix_duties_event_type_end_at",
"duties",
["event_type", "end_at"],
unique=False,
)
op.create_index(
"ix_duties_user_id_start_at",
"duties",
["user_id", "start_at"],
unique=False,
)
def downgrade() -> None:
op.drop_index("ix_duties_user_id_start_at", table_name="duties")
op.drop_index("ix_duties_event_type_end_at", table_name="duties")
op.drop_index("ix_duties_event_type_start_at", table_name="duties")

View File

@@ -3,6 +3,7 @@
import logging import logging
import re import re
from datetime import date, timedelta from datetime import date, timedelta
import duty_teller.config as config import duty_teller.config as config
from fastapi import Depends, FastAPI, Request from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@@ -18,6 +19,7 @@ from duty_teller.api.dependencies import (
require_miniapp_username, require_miniapp_username,
) )
from duty_teller.api.personal_calendar_ics import build_personal_ics, build_team_ics from duty_teller.api.personal_calendar_ics import build_personal_ics, build_team_ics
from duty_teller.cache import ics_calendar_cache
from duty_teller.db.repository import ( from duty_teller.db.repository import (
get_duties, get_duties,
get_duties_for_user, get_duties_for_user,
@@ -116,14 +118,18 @@ def get_team_calendar_ical(
user = get_user_by_calendar_token(session, token) user = get_user_by_calendar_token(session, token)
if user is None: if user is None:
return Response(status_code=404, content="Not found") return Response(status_code=404, content="Not found")
today = date.today() cache_key = ("team_ics",)
from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d") ics_bytes, found = ics_calendar_cache.get(cache_key)
to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d") if not found:
all_duties = get_duties(session, from_date=from_date, to_date=to_date) today = date.today()
duties_duty_only = [ from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d")
(d, name) for d, name in all_duties if (d.event_type or "duty") == "duty" to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d")
] all_duties = get_duties(session, from_date=from_date, to_date=to_date)
ics_bytes = build_team_ics(duties_duty_only) duties_duty_only = [
(d, name) for d, name in all_duties if (d.event_type or "duty") == "duty"
]
ics_bytes = build_team_ics(duties_duty_only)
ics_calendar_cache.set(cache_key, ics_bytes)
return Response( return Response(
content=ics_bytes, content=ics_bytes,
media_type="text/calendar; charset=utf-8", media_type="text/calendar; charset=utf-8",
@@ -151,13 +157,17 @@ def get_personal_calendar_ical(
user = get_user_by_calendar_token(session, token) user = get_user_by_calendar_token(session, token)
if user is None: if user is None:
return Response(status_code=404, content="Not found") return Response(status_code=404, content="Not found")
today = date.today() cache_key = ("personal_ics", user.id)
from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d") ics_bytes, found = ics_calendar_cache.get(cache_key)
to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d") if not found:
duties_with_name = get_duties_for_user( today = date.today()
session, user.id, from_date=from_date, to_date=to_date, event_types=["duty"] from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d")
) to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d")
ics_bytes = build_personal_ics(duties_with_name) duties_with_name = get_duties_for_user(
session, user.id, from_date=from_date, to_date=to_date, event_types=["duty"]
)
ics_bytes = build_personal_ics(duties_with_name)
ics_calendar_cache.set(cache_key, ics_bytes)
return Response( return Response(
content=ics_bytes, content=ics_bytes,
media_type="text/calendar; charset=utf-8", media_type="text/calendar; charset=utf-8",

View File

@@ -7,12 +7,15 @@ from urllib.error import URLError
from icalendar import Calendar from icalendar import Calendar
from duty_teller.cache import TTLCache
from duty_teller.utils.http_client import safe_urlopen from duty_teller.utils.http_client import safe_urlopen
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# In-memory cache: url -> (cached_at_timestamp, raw_ics_bytes) # Raw ICS bytes cache: url -> (cached_at_timestamp, raw_ics_bytes)
_ics_cache: dict[str, tuple[float, bytes]] = {} _ics_cache: dict[str, tuple[float, bytes]] = {}
# Parsed events cache: url -> list of {date, summary}. TTL 7 days.
_parsed_events_cache = TTLCache(ttl_seconds=7 * 24 * 3600, max_size=100)
CACHE_TTL_SECONDS = 7 * 24 * 3600 # 1 week CACHE_TTL_SECONDS = 7 * 24 * 3600 # 1 week
FETCH_TIMEOUT_SECONDS = 15 FETCH_TIMEOUT_SECONDS = 15
@@ -68,8 +71,8 @@ def _event_date_range(component) -> tuple[date | None, date | None]:
return (start_d, last_d) return (start_d, last_d)
def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]: def _parse_ics_to_events(raw: bytes) -> list[dict]:
"""Parse ICS bytes and return list of {date, summary} in [from_date, to_date]. One-time events only.""" """Parse ICS bytes and return all events as list of {date, summary}. One-time events only."""
result: list[dict] = [] result: list[dict] = []
try: try:
cal = Calendar.from_ical(raw) cal = Calendar.from_ical(raw)
@@ -79,9 +82,6 @@ def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]
log.warning("Failed to parse ICS: %s", e) log.warning("Failed to parse ICS: %s", e)
return result return result
from_d = date.fromisoformat(from_date)
to_d = date.fromisoformat(to_date)
for component in cal.walk(): for component in cal.walk():
if component.name != "VEVENT": if component.name != "VEVENT":
continue continue
@@ -95,13 +95,27 @@ def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]
d = start_d d = start_d
while d <= end_d: while d <= end_d:
if from_d <= d <= to_d: result.append({"date": d.strftime("%Y-%m-%d"), "summary": summary_str})
result.append({"date": d.strftime("%Y-%m-%d"), "summary": summary_str})
d += timedelta(days=1) d += timedelta(days=1)
return result return result
def _filter_events_by_range(
events: list[dict], from_date: str, to_date: str
) -> list[dict]:
"""Filter events list to [from_date, to_date] range."""
from_d = date.fromisoformat(from_date)
to_d = date.fromisoformat(to_date)
return [e for e in events if from_d <= date.fromisoformat(e["date"]) <= to_d]
def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]:
"""Parse ICS bytes and return events in [from_date, to_date]. Wrapper for tests."""
events = _parse_ics_to_events(raw)
return _filter_events_by_range(events, from_date, to_date)
def get_calendar_events( def get_calendar_events(
url: str, url: str,
from_date: str, from_date: str,
@@ -135,4 +149,10 @@ def get_calendar_events(
return [] return []
_ics_cache[url] = (now, raw) _ics_cache[url] = (now, raw)
return _get_events_from_ics(raw, from_date, to_date) # Use parsed events cache to avoid repeated Calendar.from_ical() + walk()
cache_key = (url,)
events, found = _parsed_events_cache.get(cache_key)
if not found:
events = _parse_ics_to_events(raw)
_parsed_events_cache.set(cache_key, events)
return _filter_events_by_range(events, from_date, to_date)

125
duty_teller/cache.py Normal file
View File

@@ -0,0 +1,125 @@
"""Simple in-memory TTL cache. Thread-safe for get/set."""
import logging
from threading import Lock
from time import time
log = logging.getLogger(__name__)
class TTLCache:
"""Thread-safe TTL cache with optional max size and pattern invalidation."""
def __init__(self, ttl_seconds: float, max_size: int = 1000) -> None:
"""Initialize cache with TTL and optional max size.
Args:
ttl_seconds: Time-to-live in seconds for each entry.
max_size: Maximum number of entries (0 = unlimited). LRU eviction when exceeded.
"""
self._ttl = ttl_seconds
self._max_size = max_size
self._data: dict[tuple, tuple[float, object]] = {} # key -> (cached_at, value)
self._lock = Lock()
self._access_order: list[tuple] = [] # For LRU when max_size > 0
def get(self, key: tuple) -> tuple[object, bool]:
"""Get value by key if present and not expired.
Args:
key: Cache key (must be hashable, typically tuple).
Returns:
(value, found) — found is True if valid cached value exists.
"""
with self._lock:
entry = self._data.get(key)
if entry is None:
return (None, False)
cached_at, value = entry
if time() - cached_at >= self._ttl:
del self._data[key]
if self._max_size > 0 and key in self._access_order:
self._access_order.remove(key)
return (None, False)
if self._max_size > 0 and key in self._access_order:
self._access_order.remove(key)
self._access_order.append(key)
return (value, True)
def set(self, key: tuple, value: object) -> None:
"""Store value with current timestamp.
Args:
key: Cache key (must be hashable).
value: Value to cache.
"""
with self._lock:
now = time()
if (
self._max_size > 0
and len(self._data) >= self._max_size
and key not in self._data
):
# Evict oldest
while self._access_order and len(self._data) >= self._max_size:
old_key = self._access_order.pop(0)
if old_key in self._data:
del self._data[old_key]
self._data[key] = (now, value)
if self._max_size > 0:
if key in self._access_order:
self._access_order.remove(key)
self._access_order.append(key)
def invalidate(self, key: tuple) -> None:
"""Remove a single key from cache.
Args:
key: Cache key to remove.
"""
with self._lock:
if key in self._data:
del self._data[key]
if self._max_size > 0 and key in self._access_order:
self._access_order.remove(key)
def clear(self) -> None:
"""Remove all entries. Useful for tests."""
with self._lock:
self._data.clear()
self._access_order.clear()
def invalidate_pattern(self, key_prefix: tuple) -> None:
"""Remove all keys that start with the given prefix.
Args:
key_prefix: Prefix tuple (e.g. ("personal",) matches ("personal", 1), ("personal", 2)).
"""
with self._lock:
to_remove = [k for k in self._data if self._key_starts_with(k, key_prefix)]
for k in to_remove:
del self._data[k]
if self._max_size > 0 and k in self._access_order:
self._access_order.remove(k)
@staticmethod
def _key_starts_with(key: tuple, prefix: tuple) -> bool:
"""Check if key starts with prefix (both tuples)."""
if len(key) < len(prefix):
return False
return key[: len(prefix)] == prefix
# Shared caches for duty-related data. Invalidate on import.
ics_calendar_cache = TTLCache(ttl_seconds=600, max_size=500)
duty_pin_cache = TTLCache(ttl_seconds=90, max_size=100) # current_duty, next_shift_end
is_admin_cache = TTLCache(ttl_seconds=60, max_size=200)
def invalidate_duty_related_caches() -> None:
"""Invalidate caches that depend on duties data. Call after import."""
ics_calendar_cache.invalidate_pattern(("personal_ics",))
ics_calendar_cache.invalidate_pattern(("team_ics",))
duty_pin_cache.invalidate_pattern(("duty_message_text",))
duty_pin_cache.invalidate(("next_shift_end",))

View File

@@ -229,6 +229,22 @@ def get_or_create_user_by_full_name(session: Session, full_name: str) -> User:
return user return user
def get_users_by_full_names(session: Session, full_names: list[str]) -> dict[str, User]:
"""Get users by full_name. Returns dict full_name -> User. Does not create missing.
Args:
session: DB session.
full_names: List of full names to look up.
Returns:
Dict mapping full_name to User for found users.
"""
if not full_names:
return {}
users = session.query(User).filter(User.full_name.in_(full_names)).all()
return {u.full_name: u for u in users}
def update_user_display_name( def update_user_display_name(
session: Session, session: Session,
telegram_user_id: int, telegram_user_id: int,
@@ -268,6 +284,8 @@ def delete_duties_in_range(
user_id: int, user_id: int,
from_date: str, from_date: str,
to_date: str, to_date: str,
*,
commit: bool = True,
) -> int: ) -> int:
"""Delete all duties of the user that overlap the given date range. """Delete all duties of the user that overlap the given date range.
@@ -276,6 +294,7 @@ def delete_duties_in_range(
user_id: User id. user_id: User id.
from_date: Start date YYYY-MM-DD. from_date: Start date YYYY-MM-DD.
to_date: End date YYYY-MM-DD. to_date: End date YYYY-MM-DD.
commit: If True, commit immediately. If False, caller commits (for batch import).
Returns: Returns:
Number of duties deleted. Number of duties deleted.
@@ -288,7 +307,8 @@ def delete_duties_in_range(
) )
count = q.count() count = q.count()
q.delete(synchronize_session=False) q.delete(synchronize_session=False)
session.commit() if commit:
session.commit()
return count return count

View File

@@ -19,7 +19,7 @@ from duty_teller.db.repository import (
ROLE_USER, ROLE_USER,
ROLE_ADMIN, ROLE_ADMIN,
) )
from duty_teller.handlers.common import is_admin_async from duty_teller.handlers.common import invalidate_is_admin_cache, is_admin_async
from duty_teller.i18n import get_lang, t from duty_teller.i18n import get_lang, t
from duty_teller.utils.user import build_full_name from duty_teller.utils.user import build_full_name
@@ -230,6 +230,7 @@ async def set_role(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
ok = await asyncio.get_running_loop().run_in_executor(None, do_set_role) ok = await asyncio.get_running_loop().run_in_executor(None, do_set_role)
if ok: if ok:
invalidate_is_admin_cache(target_user.telegram_user_id)
await update.message.reply_text( await update.message.reply_text(
t(lang, "set_role.done", name=target_user.full_name, role=role_name) t(lang, "set_role.done", name=target_user.full_name, role=role_name)
) )

View File

@@ -3,22 +3,35 @@
import asyncio import asyncio
import duty_teller.config as config import duty_teller.config as config
from duty_teller.cache import is_admin_cache
from duty_teller.db.repository import is_admin_for_telegram_user from duty_teller.db.repository import is_admin_for_telegram_user
from duty_teller.db.session import session_scope from duty_teller.db.session import session_scope
async def is_admin_async(telegram_user_id: int) -> bool: async def is_admin_async(telegram_user_id: int) -> bool:
"""Check if Telegram user is admin (username or phone). Runs DB check in executor. """Check if Telegram user is admin. Cached 60s. Invalidated on set_user_role.
Args: Args:
telegram_user_id: Telegram user id. telegram_user_id: Telegram user id.
Returns: Returns:
True if user is in ADMIN_USERNAMES or their stored phone is in ADMIN_PHONES. True if user is admin (DB role or env fallback).
""" """
cache_key = ("is_admin", telegram_user_id)
value, found = is_admin_cache.get(cache_key)
if found:
return value
def _check() -> bool: def _check() -> bool:
with session_scope(config.DATABASE_URL) as session: with session_scope(config.DATABASE_URL) as session:
return is_admin_for_telegram_user(session, telegram_user_id) return is_admin_for_telegram_user(session, telegram_user_id)
return await asyncio.get_running_loop().run_in_executor(None, _check) result = await asyncio.get_running_loop().run_in_executor(None, _check)
is_admin_cache.set(cache_key, result)
return result
def invalidate_is_admin_cache(telegram_user_id: int | None) -> None:
"""Invalidate is_admin cache for user. Call after set_user_role."""
if telegram_user_id is not None:
is_admin_cache.invalidate(("is_admin", telegram_user_id))

View File

@@ -15,10 +15,11 @@ from duty_teller.db.session import session_scope
from duty_teller.i18n import get_lang, t from duty_teller.i18n import get_lang, t
from duty_teller.services.group_duty_pin_service import ( from duty_teller.services.group_duty_pin_service import (
get_duty_message_text, get_duty_message_text,
get_message_id,
get_next_shift_end_utc, get_next_shift_end_utc,
get_pin_refresh_data,
save_pin, save_pin,
delete_pin, delete_pin,
get_message_id,
get_all_pin_chat_ids, get_all_pin_chat_ids,
) )
@@ -28,6 +29,14 @@ JOB_NAME_PREFIX = "duty_pin_"
RETRY_WHEN_NO_DUTY_MINUTES = 15 RETRY_WHEN_NO_DUTY_MINUTES = 15
def _sync_get_pin_refresh_data(
chat_id: int, lang: str = "en"
) -> tuple[int | None, str, datetime | None]:
"""Get message_id, duty text, next_shift_end in one DB session."""
with session_scope(config.DATABASE_URL) as session:
return get_pin_refresh_data(session, chat_id, config.DUTY_DISPLAY_TZ, lang)
def _get_duty_message_text_sync(lang: str = "en") -> str: def _get_duty_message_text_sync(lang: str = "en") -> str:
with session_scope(config.DATABASE_URL) as session: with session_scope(config.DATABASE_URL) as session:
return get_duty_message_text(session, config.DUTY_DISPLAY_TZ, lang) return get_duty_message_text(session, config.DUTY_DISPLAY_TZ, lang)
@@ -96,26 +105,27 @@ async def _refresh_pin_for_chat(
) -> Literal["updated", "no_message", "failed"]: ) -> Literal["updated", "no_message", "failed"]:
"""Refresh pinned duty message: send new message, unpin old, pin new, save new message_id. """Refresh pinned duty message: send new message, unpin old, pin new, save new message_id.
Uses single DB session for message_id, text, next_shift_end (consolidated).
Returns: Returns:
"updated" if the message was sent, pinned and saved successfully; "updated" if the message was sent, pinned and saved successfully;
"no_message" if there is no pin record for this chat; "no_message" if there is no pin record for this chat;
"failed" if send_message or permissions failed. "failed" if send_message or permissions failed.
""" """
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
message_id = await loop.run_in_executor(None, _sync_get_message_id, chat_id) message_id, text, next_end = await loop.run_in_executor(
None,
lambda: _sync_get_pin_refresh_data(chat_id, config.DEFAULT_LANGUAGE),
)
if message_id is None: if message_id is None:
logger.info("No pin record for chat_id=%s, skipping update", chat_id) logger.info("No pin record for chat_id=%s, skipping update", chat_id)
return "no_message" return "no_message"
text = await loop.run_in_executor(
None, lambda: _get_duty_message_text_sync(config.DEFAULT_LANGUAGE)
)
try: try:
msg = await context.bot.send_message(chat_id=chat_id, text=text) msg = await context.bot.send_message(chat_id=chat_id, text=text)
except (BadRequest, Forbidden) as e: except (BadRequest, Forbidden) as e:
logger.warning( logger.warning(
"Failed to send duty message for pin refresh chat_id=%s: %s", chat_id, e "Failed to send duty message for pin refresh chat_id=%s: %s", chat_id, e
) )
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
await _schedule_next_update(context.application, chat_id, next_end) await _schedule_next_update(context.application, chat_id, next_end)
return "failed" return "failed"
try: try:
@@ -127,11 +137,9 @@ async def _refresh_pin_for_chat(
) )
except (BadRequest, Forbidden) as e: except (BadRequest, Forbidden) as e:
logger.warning("Unpin or pin after refresh failed chat_id=%s: %s", chat_id, e) logger.warning("Unpin or pin after refresh failed chat_id=%s: %s", chat_id, e)
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
await _schedule_next_update(context.application, chat_id, next_end) await _schedule_next_update(context.application, chat_id, next_end)
return "failed" return "failed"
await loop.run_in_executor(None, _sync_save_pin, chat_id, msg.message_id) await loop.run_in_executor(None, _sync_save_pin, chat_id, msg.message_id)
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
await _schedule_next_update(context.application, chat_id, next_end) await _schedule_next_update(context.application, chat_id, next_end)
return "updated" return "updated"

View File

@@ -5,6 +5,7 @@ from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from duty_teller.cache import duty_pin_cache
from duty_teller.db.repository import ( from duty_teller.db.repository import (
get_current_duty, get_current_duty,
get_next_shift_end, get_next_shift_end,
@@ -17,6 +18,31 @@ from duty_teller.i18n import t
from duty_teller.utils.dates import parse_utc_iso from duty_teller.utils.dates import parse_utc_iso
def get_pin_refresh_data(
session: Session, chat_id: int, tz_name: str, lang: str = "en"
) -> tuple[int | None, str, datetime | None]:
"""Get all data needed for pin refresh in a single DB session.
Args:
session: DB session.
chat_id: Telegram chat id.
tz_name: Timezone name for display.
lang: Language code for i18n.
Returns:
(message_id, duty_message_text, next_shift_end_utc).
message_id is None if no pin record. next_shift_end_utc is naive UTC or None.
"""
pin = get_group_duty_pin(session, chat_id)
message_id = pin.message_id if pin else None
if message_id is None:
return (None, t(lang, "duty.no_duty"), None)
now = datetime.now(timezone.utc)
text = get_duty_message_text(session, tz_name, lang)
next_end = get_next_shift_end(session, now)
return (message_id, text, next_end)
def format_duty_message(duty, user, tz_name: str, lang: str = "en") -> str: def format_duty_message(duty, user, tz_name: str, lang: str = "en") -> str:
"""Build the text for the pinned duty message. """Build the text for the pinned duty message.
@@ -64,34 +90,31 @@ def format_duty_message(duty, user, tz_name: str, lang: str = "en") -> str:
def get_duty_message_text(session: Session, tz_name: str, lang: str = "en") -> str: def get_duty_message_text(session: Session, tz_name: str, lang: str = "en") -> str:
"""Get current duty from DB and return formatted message text. """Get current duty from DB and return formatted message text. Cached 90s."""
cache_key = ("duty_message_text", tz_name, lang)
Args: text, found = duty_pin_cache.get(cache_key)
session: DB session. if found:
tz_name: Timezone name for display. return text
lang: Language code for i18n.
Returns:
Formatted duty message or "No duty" if none.
"""
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
result = get_current_duty(session, now) result = get_current_duty(session, now)
if result is None: if result is None:
return t(lang, "duty.no_duty") text = t(lang, "duty.no_duty")
duty, user = result else:
return format_duty_message(duty, user, tz_name, lang) duty, user = result
text = format_duty_message(duty, user, tz_name, lang)
duty_pin_cache.set(cache_key, text)
return text
def get_next_shift_end_utc(session: Session) -> datetime | None: def get_next_shift_end_utc(session: Session) -> datetime | None:
"""Return next shift end as naive UTC datetime for job scheduling. """Return next shift end as naive UTC datetime for job scheduling. Cached 90s."""
cache_key = ("next_shift_end",)
Args: value, found = duty_pin_cache.get(cache_key)
session: DB session. if found:
return value
Returns: result = get_next_shift_end(session, datetime.now(timezone.utc))
Next shift end (naive UTC) or None. duty_pin_cache.set(cache_key, result)
""" return result
return get_next_shift_end(session, datetime.now(timezone.utc))
def save_pin(session: Session, chat_id: int, message_id: int) -> None: def save_pin(session: Session, chat_id: int, message_id: int) -> None:

View File

@@ -4,10 +4,12 @@ from datetime import date, timedelta
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from duty_teller.cache import invalidate_duty_related_caches
from duty_teller.db.models import Duty
from duty_teller.db.repository import ( from duty_teller.db.repository import (
get_or_create_user_by_full_name,
delete_duties_in_range, delete_duties_in_range,
insert_duty, get_or_create_user_by_full_name,
get_users_by_full_names,
) )
from duty_teller.importers.duty_schedule import DutyScheduleResult from duty_teller.importers.duty_schedule import DutyScheduleResult
from duty_teller.utils.dates import day_start_iso, day_end_iso, duty_to_iso from duty_teller.utils.dates import day_start_iso, day_end_iso, duty_to_iso
@@ -37,11 +39,10 @@ def run_import(
hour_utc: int, hour_utc: int,
minute_utc: int, minute_utc: int,
) -> tuple[int, int, int, int]: ) -> tuple[int, int, int, int]:
"""Run duty-schedule import: delete range per user, insert duty/unavailable/vacation. """Run duty-schedule import: delete range per user, bulk insert duties.
For each entry: get_or_create_user_by_full_name, delete_duties_in_range for Batched: users fetched in one query, missing created; bulk_insert_mappings.
the result date range, then insert duties (handover time in UTC), unavailable One commit at end.
(all-day), and vacation (consecutive ranges).
Args: Args:
session: DB session. session: DB session.
@@ -55,31 +56,61 @@ def run_import(
from_date_str = result.start_date.isoformat() from_date_str = result.start_date.isoformat()
to_date_str = result.end_date.isoformat() to_date_str = result.end_date.isoformat()
num_duty = num_unavailable = num_vacation = 0 num_duty = num_unavailable = num_vacation = 0
# Batch: get all users by full_name, create missing
names = [e.full_name for e in result.entries]
users_map = get_users_by_full_names(session, names)
for name in names:
if name not in users_map:
users_map[name] = get_or_create_user_by_full_name(session, name)
# Delete range per user (no commit)
for entry in result.entries: for entry in result.entries:
user = get_or_create_user_by_full_name(session, entry.full_name) user = users_map[entry.full_name]
delete_duties_in_range(session, user.id, from_date_str, to_date_str) delete_duties_in_range(
session, user.id, from_date_str, to_date_str, commit=False
)
# Build rows for bulk insert
duty_rows: list[dict] = []
for entry in result.entries:
user = users_map[entry.full_name]
for d in entry.duty_dates: for d in entry.duty_dates:
start_at = duty_to_iso(d, hour_utc, minute_utc) start_at = duty_to_iso(d, hour_utc, minute_utc)
d_next = d + timedelta(days=1) d_next = d + timedelta(days=1)
end_at = duty_to_iso(d_next, hour_utc, minute_utc) end_at = duty_to_iso(d_next, hour_utc, minute_utc)
insert_duty(session, user.id, start_at, end_at, event_type="duty") duty_rows.append(
{
"user_id": user.id,
"start_at": start_at,
"end_at": end_at,
"event_type": "duty",
}
)
num_duty += 1 num_duty += 1
for d in entry.unavailable_dates: for d in entry.unavailable_dates:
insert_duty( duty_rows.append(
session, {
user.id, "user_id": user.id,
day_start_iso(d), "start_at": day_start_iso(d),
day_end_iso(d), "end_at": day_end_iso(d),
event_type="unavailable", "event_type": "unavailable",
}
) )
num_unavailable += 1 num_unavailable += 1
for start_d, end_d in _consecutive_date_ranges(entry.vacation_dates): for start_d, end_d in _consecutive_date_ranges(entry.vacation_dates):
insert_duty( duty_rows.append(
session, {
user.id, "user_id": user.id,
day_start_iso(start_d), "start_at": day_start_iso(start_d),
day_end_iso(end_d), "end_at": day_end_iso(end_d),
event_type="vacation", "event_type": "vacation",
}
) )
num_vacation += 1 num_vacation += 1
if duty_rows:
session.bulk_insert_mappings(Duty, duty_rows)
session.commit()
invalidate_duty_related_caches()
return (len(result.entries), num_duty, num_unavailable, num_vacation) return (len(result.entries), num_duty, num_unavailable, num_vacation)

View File

@@ -403,6 +403,9 @@ def test_calendar_ical_ignores_unknown_query_params(
"""Unknown query params (e.g. events=all) are ignored; response is duty-only.""" """Unknown query params (e.g. events=all) are ignored; response is duty-only."""
from types import SimpleNamespace from types import SimpleNamespace
from duty_teller.cache import ics_calendar_cache
ics_calendar_cache.invalidate(("personal_ics", 1))
mock_user = SimpleNamespace(id=1, full_name="User A") mock_user = SimpleNamespace(id=1, full_name="User A")
mock_get_user.return_value = mock_user mock_get_user.return_value = mock_user
duty = SimpleNamespace( duty = SimpleNamespace(

View File

@@ -103,6 +103,7 @@ class TestGetDutyMessageText:
"""Tests for get_duty_message_text.""" """Tests for get_duty_message_text."""
def test_no_current_duty_returns_no_duty(self, session): def test_no_current_duty_returns_no_duty(self, session):
svc.duty_pin_cache.invalidate_pattern(("duty_message_text",))
with patch( with patch(
"duty_teller.services.group_duty_pin_service.get_current_duty", "duty_teller.services.group_duty_pin_service.get_current_duty",
return_value=None, return_value=None,
@@ -113,6 +114,7 @@ class TestGetDutyMessageText:
assert result == "No duty" assert result == "No duty"
def test_with_current_duty_returns_formatted(self, session, duty, user): def test_with_current_duty_returns_formatted(self, session, duty, user):
svc.duty_pin_cache.invalidate_pattern(("duty_message_text",))
with patch( with patch(
"duty_teller.services.group_duty_pin_service.get_current_duty", "duty_teller.services.group_duty_pin_service.get_current_duty",
return_value=(duty, user), return_value=(duty, user),
@@ -130,6 +132,7 @@ class TestGetNextShiftEndUtc:
"""Tests for get_next_shift_end_utc.""" """Tests for get_next_shift_end_utc."""
def test_no_next_shift_returns_none(self, session): def test_no_next_shift_returns_none(self, session):
svc.duty_pin_cache.invalidate(("next_shift_end",))
with patch( with patch(
"duty_teller.services.group_duty_pin_service.get_next_shift_end", "duty_teller.services.group_duty_pin_service.get_next_shift_end",
return_value=None, return_value=None,
@@ -138,6 +141,7 @@ class TestGetNextShiftEndUtc:
assert result is None assert result is None
def test_has_next_shift_returns_naive_utc(self, session): def test_has_next_shift_returns_naive_utc(self, session):
svc.duty_pin_cache.invalidate(("next_shift_end",))
naive = datetime(2025, 2, 21, 6, 0, 0) naive = datetime(2025, 2, 21, 6, 0, 0)
with patch( with patch(
"duty_teller.services.group_duty_pin_service.get_next_shift_end", "duty_teller.services.group_duty_pin_service.get_next_shift_end",

View File

@@ -143,14 +143,12 @@ async def test_update_group_pin_sends_new_unpins_pins_saves_schedules_next():
context.application.job_queue.run_once = MagicMock() context.application.job_queue.run_once = MagicMock()
with patch.object(config, "DUTY_PIN_NOTIFY", True): with patch.object(config, "DUTY_PIN_NOTIFY", True):
with patch.object(mod, "_sync_get_message_id", return_value=1): with patch.object(
with patch.object( mod, "_sync_get_pin_refresh_data", return_value=(1, "Current duty", None)
mod, "_get_duty_message_text_sync", return_value="Current duty" ):
): with patch.object(mod, "_schedule_next_update", AsyncMock()):
with patch.object(mod, "_get_next_shift_end_sync", return_value=None): with patch.object(mod, "_sync_save_pin") as mock_save:
with patch.object(mod, "_schedule_next_update", AsyncMock()): await mod.update_group_pin(context)
with patch.object(mod, "_sync_save_pin") as mock_save:
await mod.update_group_pin(context)
context.bot.send_message.assert_called_once_with(chat_id=123, text="Current duty") context.bot.send_message.assert_called_once_with(chat_id=123, text="Current duty")
context.bot.unpin_chat_message.assert_called_once_with(chat_id=123) context.bot.unpin_chat_message.assert_called_once_with(chat_id=123)
context.bot.pin_chat_message.assert_called_once_with( context.bot.pin_chat_message.assert_called_once_with(
@@ -168,7 +166,9 @@ async def test_update_group_pin_no_message_id_skips():
context.bot = MagicMock() context.bot = MagicMock()
context.bot.send_message = AsyncMock() context.bot.send_message = AsyncMock()
with patch.object(mod, "_sync_get_message_id", return_value=None): with patch.object(
mod, "_sync_get_pin_refresh_data", return_value=(None, "No duty", None)
):
await mod.update_group_pin(context) await mod.update_group_pin(context)
context.bot.send_message.assert_not_called() context.bot.send_message.assert_not_called()
@@ -185,13 +185,11 @@ async def test_update_group_pin_send_raises_no_unpin_pin_schedule_still_called()
context.bot.pin_chat_message = AsyncMock() context.bot.pin_chat_message = AsyncMock()
context.application = MagicMock() context.application = MagicMock()
with patch.object(mod, "_sync_get_message_id", return_value=2): with patch.object(
with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"): mod, "_sync_get_pin_refresh_data", return_value=(2, "Text", None)
with patch.object(mod, "_get_next_shift_end_sync", return_value=None): ):
with patch.object( with patch.object(mod, "_schedule_next_update", AsyncMock()) as mock_schedule:
mod, "_schedule_next_update", AsyncMock() await mod.update_group_pin(context)
) as mock_schedule:
await mod.update_group_pin(context)
context.bot.unpin_chat_message.assert_not_called() context.bot.unpin_chat_message.assert_not_called()
context.bot.pin_chat_message.assert_not_called() context.bot.pin_chat_message.assert_not_called()
mock_schedule.assert_called_once_with(context.application, 111, None) mock_schedule.assert_called_once_with(context.application, 111, None)
@@ -214,15 +212,15 @@ async def test_update_group_pin_repin_raises_still_schedules_next():
context.application = MagicMock() context.application = MagicMock()
with patch.object(config, "DUTY_PIN_NOTIFY", True): with patch.object(config, "DUTY_PIN_NOTIFY", True):
with patch.object(mod, "_sync_get_message_id", return_value=3): with patch.object(
with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"): mod, "_sync_get_pin_refresh_data", return_value=(3, "Text", None)
with patch.object(mod, "_get_next_shift_end_sync", return_value=None): ):
with patch.object( with patch.object(
mod, "_schedule_next_update", AsyncMock() mod, "_schedule_next_update", AsyncMock()
) as mock_schedule: ) as mock_schedule:
with patch.object(mod, "_sync_save_pin") as mock_save: with patch.object(mod, "_sync_save_pin") as mock_save:
with patch.object(mod, "logger") as mock_logger: with patch.object(mod, "logger") as mock_logger:
await mod.update_group_pin(context) await mod.update_group_pin(context)
context.bot.send_message.assert_called_once_with(chat_id=222, text="Text") context.bot.send_message.assert_called_once_with(chat_id=222, text="Text")
mock_save.assert_not_called() mock_save.assert_not_called()
mock_logger.warning.assert_called_once() mock_logger.warning.assert_called_once()
@@ -245,14 +243,14 @@ async def test_update_group_pin_duty_pin_notify_false_pins_silent():
context.application = MagicMock() context.application = MagicMock()
with patch.object(config, "DUTY_PIN_NOTIFY", False): with patch.object(config, "DUTY_PIN_NOTIFY", False):
with patch.object(mod, "_sync_get_message_id", return_value=4): with patch.object(
with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"): mod, "_sync_get_pin_refresh_data", return_value=(4, "Text", None)
with patch.object(mod, "_get_next_shift_end_sync", return_value=None): ):
with patch.object( with patch.object(
mod, "_schedule_next_update", AsyncMock() mod, "_schedule_next_update", AsyncMock()
) as mock_schedule: ) as mock_schedule:
with patch.object(mod, "_sync_save_pin") as mock_save: with patch.object(mod, "_sync_save_pin") as mock_save:
await mod.update_group_pin(context) await mod.update_group_pin(context)
context.bot.send_message.assert_called_once_with(chat_id=333, text="Text") context.bot.send_message.assert_called_once_with(chat_id=333, text="Text")
context.bot.unpin_chat_message.assert_called_once_with(chat_id=333) context.bot.unpin_chat_message.assert_called_once_with(chat_id=333)
context.bot.pin_chat_message.assert_called_once_with( context.bot.pin_chat_message.assert_called_once_with(