Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0e8d1453e2 |
44
alembic/versions/008_duties_indexes.py
Normal file
44
alembic/versions/008_duties_indexes.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Add indexes on duties table for performance.
|
||||
|
||||
Revision ID: 008
|
||||
Revises: 007
|
||||
Create Date: 2025-02-25
|
||||
|
||||
Indexes for get_current_duty, get_next_shift_end, get_duties, get_duties_for_user.
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "008"
|
||||
down_revision: Union[str, None] = "007"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
"ix_duties_event_type_start_at",
|
||||
"duties",
|
||||
["event_type", "start_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_duties_event_type_end_at",
|
||||
"duties",
|
||||
["event_type", "end_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_duties_user_id_start_at",
|
||||
"duties",
|
||||
["user_id", "start_at"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_duties_user_id_start_at", table_name="duties")
|
||||
op.drop_index("ix_duties_event_type_end_at", table_name="duties")
|
||||
op.drop_index("ix_duties_event_type_start_at", table_name="duties")
|
||||
@@ -3,6 +3,7 @@
|
||||
import logging
|
||||
import re
|
||||
from datetime import date, timedelta
|
||||
|
||||
import duty_teller.config as config
|
||||
from fastapi import Depends, FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -18,6 +19,7 @@ from duty_teller.api.dependencies import (
|
||||
require_miniapp_username,
|
||||
)
|
||||
from duty_teller.api.personal_calendar_ics import build_personal_ics, build_team_ics
|
||||
from duty_teller.cache import ics_calendar_cache
|
||||
from duty_teller.db.repository import (
|
||||
get_duties,
|
||||
get_duties_for_user,
|
||||
@@ -116,14 +118,18 @@ def get_team_calendar_ical(
|
||||
user = get_user_by_calendar_token(session, token)
|
||||
if user is None:
|
||||
return Response(status_code=404, content="Not found")
|
||||
today = date.today()
|
||||
from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d")
|
||||
all_duties = get_duties(session, from_date=from_date, to_date=to_date)
|
||||
duties_duty_only = [
|
||||
(d, name) for d, name in all_duties if (d.event_type or "duty") == "duty"
|
||||
]
|
||||
ics_bytes = build_team_ics(duties_duty_only)
|
||||
cache_key = ("team_ics",)
|
||||
ics_bytes, found = ics_calendar_cache.get(cache_key)
|
||||
if not found:
|
||||
today = date.today()
|
||||
from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d")
|
||||
all_duties = get_duties(session, from_date=from_date, to_date=to_date)
|
||||
duties_duty_only = [
|
||||
(d, name) for d, name in all_duties if (d.event_type or "duty") == "duty"
|
||||
]
|
||||
ics_bytes = build_team_ics(duties_duty_only)
|
||||
ics_calendar_cache.set(cache_key, ics_bytes)
|
||||
return Response(
|
||||
content=ics_bytes,
|
||||
media_type="text/calendar; charset=utf-8",
|
||||
@@ -151,13 +157,17 @@ def get_personal_calendar_ical(
|
||||
user = get_user_by_calendar_token(session, token)
|
||||
if user is None:
|
||||
return Response(status_code=404, content="Not found")
|
||||
today = date.today()
|
||||
from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d")
|
||||
duties_with_name = get_duties_for_user(
|
||||
session, user.id, from_date=from_date, to_date=to_date, event_types=["duty"]
|
||||
)
|
||||
ics_bytes = build_personal_ics(duties_with_name)
|
||||
cache_key = ("personal_ics", user.id)
|
||||
ics_bytes, found = ics_calendar_cache.get(cache_key)
|
||||
if not found:
|
||||
today = date.today()
|
||||
from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d")
|
||||
duties_with_name = get_duties_for_user(
|
||||
session, user.id, from_date=from_date, to_date=to_date, event_types=["duty"]
|
||||
)
|
||||
ics_bytes = build_personal_ics(duties_with_name)
|
||||
ics_calendar_cache.set(cache_key, ics_bytes)
|
||||
return Response(
|
||||
content=ics_bytes,
|
||||
media_type="text/calendar; charset=utf-8",
|
||||
|
||||
@@ -7,12 +7,15 @@ from urllib.error import URLError
|
||||
|
||||
from icalendar import Calendar
|
||||
|
||||
from duty_teller.cache import TTLCache
|
||||
from duty_teller.utils.http_client import safe_urlopen
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# In-memory cache: url -> (cached_at_timestamp, raw_ics_bytes)
|
||||
# Raw ICS bytes cache: url -> (cached_at_timestamp, raw_ics_bytes)
|
||||
_ics_cache: dict[str, tuple[float, bytes]] = {}
|
||||
# Parsed events cache: url -> list of {date, summary}. TTL 7 days.
|
||||
_parsed_events_cache = TTLCache(ttl_seconds=7 * 24 * 3600, max_size=100)
|
||||
CACHE_TTL_SECONDS = 7 * 24 * 3600 # 1 week
|
||||
FETCH_TIMEOUT_SECONDS = 15
|
||||
|
||||
@@ -68,8 +71,8 @@ def _event_date_range(component) -> tuple[date | None, date | None]:
|
||||
return (start_d, last_d)
|
||||
|
||||
|
||||
def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]:
|
||||
"""Parse ICS bytes and return list of {date, summary} in [from_date, to_date]. One-time events only."""
|
||||
def _parse_ics_to_events(raw: bytes) -> list[dict]:
|
||||
"""Parse ICS bytes and return all events as list of {date, summary}. One-time events only."""
|
||||
result: list[dict] = []
|
||||
try:
|
||||
cal = Calendar.from_ical(raw)
|
||||
@@ -79,9 +82,6 @@ def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]
|
||||
log.warning("Failed to parse ICS: %s", e)
|
||||
return result
|
||||
|
||||
from_d = date.fromisoformat(from_date)
|
||||
to_d = date.fromisoformat(to_date)
|
||||
|
||||
for component in cal.walk():
|
||||
if component.name != "VEVENT":
|
||||
continue
|
||||
@@ -95,13 +95,27 @@ def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]
|
||||
|
||||
d = start_d
|
||||
while d <= end_d:
|
||||
if from_d <= d <= to_d:
|
||||
result.append({"date": d.strftime("%Y-%m-%d"), "summary": summary_str})
|
||||
result.append({"date": d.strftime("%Y-%m-%d"), "summary": summary_str})
|
||||
d += timedelta(days=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _filter_events_by_range(
|
||||
events: list[dict], from_date: str, to_date: str
|
||||
) -> list[dict]:
|
||||
"""Filter events list to [from_date, to_date] range."""
|
||||
from_d = date.fromisoformat(from_date)
|
||||
to_d = date.fromisoformat(to_date)
|
||||
return [e for e in events if from_d <= date.fromisoformat(e["date"]) <= to_d]
|
||||
|
||||
|
||||
def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]:
    """Parse ICS bytes and return events in [from_date, to_date]. Wrapper for tests."""
    # Parse-then-filter, composed in one expression; no caching at this layer.
    return _filter_events_by_range(_parse_ics_to_events(raw), from_date, to_date)
|
||||
|
||||
|
||||
def get_calendar_events(
|
||||
url: str,
|
||||
from_date: str,
|
||||
@@ -135,4 +149,10 @@ def get_calendar_events(
|
||||
return []
|
||||
_ics_cache[url] = (now, raw)
|
||||
|
||||
return _get_events_from_ics(raw, from_date, to_date)
|
||||
# Use parsed events cache to avoid repeated Calendar.from_ical() + walk()
|
||||
cache_key = (url,)
|
||||
events, found = _parsed_events_cache.get(cache_key)
|
||||
if not found:
|
||||
events = _parse_ics_to_events(raw)
|
||||
_parsed_events_cache.set(cache_key, events)
|
||||
return _filter_events_by_range(events, from_date, to_date)
|
||||
|
||||
125
duty_teller/cache.py
Normal file
125
duty_teller/cache.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Simple in-memory TTL cache. Thread-safe for get/set."""
|
||||
|
||||
import logging
|
||||
from threading import Lock
|
||||
from time import time
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TTLCache:
    """Thread-safe TTL cache with optional max size and pattern invalidation.

    Entries expire ``ttl_seconds`` after they were stored.  When
    ``max_size`` > 0 the least-recently-used entry is evicted to make room
    for a new key.  Recency is tracked through dict insertion order (Python
    3.7+ guarantee): a hit or an overwrite re-inserts the key at the end, so
    the first key in ``_data`` is always the LRU candidate.  This replaces a
    separate access-order list whose ``in``/``remove`` bookkeeping cost O(n)
    per get/set.
    """

    def __init__(self, ttl_seconds: float, max_size: int = 1000) -> None:
        """Initialize cache with TTL and optional max size.

        Args:
            ttl_seconds: Time-to-live in seconds for each entry.
            max_size: Maximum number of entries (0 = unlimited). LRU eviction when exceeded.
        """
        self._ttl = ttl_seconds
        self._max_size = max_size
        # key -> (cached_at, value); dict insertion order doubles as LRU order.
        self._data: dict[tuple, tuple[float, object]] = {}
        self._lock = Lock()

    def get(self, key: tuple) -> tuple[object, bool]:
        """Get value by key if present and not expired.

        Args:
            key: Cache key (must be hashable, typically tuple).

        Returns:
            (value, found) — found is True if valid cached value exists.
        """
        with self._lock:
            entry = self._data.get(key)
            if entry is None:
                return (None, False)
            cached_at, value = entry
            if time() - cached_at >= self._ttl:
                # Expired: drop lazily on access.
                del self._data[key]
                return (None, False)
            if self._max_size > 0:
                # Refresh recency: move key to the end of insertion order.
                del self._data[key]
                self._data[key] = entry
            return (value, True)

    def set(self, key: tuple, value: object) -> None:
        """Store value with current timestamp.

        Args:
            key: Cache key (must be hashable).
            value: Value to cache.
        """
        with self._lock:
            if key in self._data:
                # Overwrite counts as a use: re-insert at the end.
                del self._data[key]
            elif self._max_size > 0:
                # Evict least-recently-used entries (front of the dict)
                # until there is room for the new key.
                while len(self._data) >= self._max_size:
                    oldest_key = next(iter(self._data))
                    del self._data[oldest_key]
            self._data[key] = (time(), value)

    def invalidate(self, key: tuple) -> None:
        """Remove a single key from cache.

        Args:
            key: Cache key to remove.
        """
        with self._lock:
            self._data.pop(key, None)

    def clear(self) -> None:
        """Remove all entries. Useful for tests."""
        with self._lock:
            self._data.clear()

    def invalidate_pattern(self, key_prefix: tuple) -> None:
        """Remove all keys that start with the given prefix.

        Args:
            key_prefix: Prefix tuple (e.g. ("personal",) matches ("personal", 1), ("personal", 2)).
        """
        with self._lock:
            # Materialize matches first: cannot delete while iterating the dict.
            to_remove = [k for k in self._data if self._key_starts_with(k, key_prefix)]
            for k in to_remove:
                del self._data[k]

    @staticmethod
    def _key_starts_with(key: tuple, prefix: tuple) -> bool:
        """Check if key starts with prefix (both tuples)."""
        if len(key) < len(prefix):
            return False
        return key[: len(prefix)] == prefix
|
||||
|
||||
|
||||
# Shared caches for duty-related data. Invalidate on import.
|
||||
ics_calendar_cache = TTLCache(ttl_seconds=600, max_size=500)
|
||||
duty_pin_cache = TTLCache(ttl_seconds=90, max_size=100) # current_duty, next_shift_end
|
||||
is_admin_cache = TTLCache(ttl_seconds=60, max_size=200)
|
||||
|
||||
|
||||
def invalidate_duty_related_caches() -> None:
    """Invalidate caches that depend on duties data. Call after import."""
    # Both ICS feeds are derived from duties; drop every cached variant.
    for prefix in (("personal_ics",), ("team_ics",)):
        ics_calendar_cache.invalidate_pattern(prefix)
    # Pinned-message text exists per (tz, lang) — invalidate by prefix;
    # next_shift_end is a single key.
    duty_pin_cache.invalidate_pattern(("duty_message_text",))
    duty_pin_cache.invalidate(("next_shift_end",))
|
||||
@@ -229,6 +229,22 @@ def get_or_create_user_by_full_name(session: Session, full_name: str) -> User:
|
||||
return user
|
||||
|
||||
|
||||
def get_users_by_full_names(session: Session, full_names: list[str]) -> dict[str, User]:
    """Get users by full_name. Returns dict full_name -> User. Does not create missing.

    Args:
        session: DB session.
        full_names: List of full names to look up.

    Returns:
        Dict mapping full_name to User for found users.
    """
    if not full_names:
        # Avoid issuing IN () against the DB for an empty lookup.
        return {}
    matches = session.query(User).filter(User.full_name.in_(full_names)).all()
    return {user.full_name: user for user in matches}
|
||||
|
||||
|
||||
def update_user_display_name(
|
||||
session: Session,
|
||||
telegram_user_id: int,
|
||||
@@ -268,6 +284,8 @@ def delete_duties_in_range(
|
||||
user_id: int,
|
||||
from_date: str,
|
||||
to_date: str,
|
||||
*,
|
||||
commit: bool = True,
|
||||
) -> int:
|
||||
"""Delete all duties of the user that overlap the given date range.
|
||||
|
||||
@@ -276,6 +294,7 @@ def delete_duties_in_range(
|
||||
user_id: User id.
|
||||
from_date: Start date YYYY-MM-DD.
|
||||
to_date: End date YYYY-MM-DD.
|
||||
commit: If True, commit immediately. If False, caller commits (for batch import).
|
||||
|
||||
Returns:
|
||||
Number of duties deleted.
|
||||
@@ -288,7 +307,8 @@ def delete_duties_in_range(
|
||||
)
|
||||
count = q.count()
|
||||
q.delete(synchronize_session=False)
|
||||
session.commit()
|
||||
if commit:
|
||||
session.commit()
|
||||
return count
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from duty_teller.db.repository import (
|
||||
ROLE_USER,
|
||||
ROLE_ADMIN,
|
||||
)
|
||||
from duty_teller.handlers.common import is_admin_async
|
||||
from duty_teller.handlers.common import invalidate_is_admin_cache, is_admin_async
|
||||
from duty_teller.i18n import get_lang, t
|
||||
from duty_teller.utils.user import build_full_name
|
||||
|
||||
@@ -230,6 +230,7 @@ async def set_role(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
|
||||
ok = await asyncio.get_running_loop().run_in_executor(None, do_set_role)
|
||||
if ok:
|
||||
invalidate_is_admin_cache(target_user.telegram_user_id)
|
||||
await update.message.reply_text(
|
||||
t(lang, "set_role.done", name=target_user.full_name, role=role_name)
|
||||
)
|
||||
|
||||
@@ -3,22 +3,35 @@
|
||||
import asyncio
|
||||
|
||||
import duty_teller.config as config
|
||||
from duty_teller.cache import is_admin_cache
|
||||
from duty_teller.db.repository import is_admin_for_telegram_user
|
||||
from duty_teller.db.session import session_scope
|
||||
|
||||
|
||||
async def is_admin_async(telegram_user_id: int) -> bool:
|
||||
"""Check if Telegram user is admin (username or phone). Runs DB check in executor.
|
||||
"""Check if Telegram user is admin. Cached 60s. Invalidated on set_user_role.
|
||||
|
||||
Args:
|
||||
telegram_user_id: Telegram user id.
|
||||
|
||||
Returns:
|
||||
True if user is in ADMIN_USERNAMES or their stored phone is in ADMIN_PHONES.
|
||||
True if user is admin (DB role or env fallback).
|
||||
"""
|
||||
cache_key = ("is_admin", telegram_user_id)
|
||||
value, found = is_admin_cache.get(cache_key)
|
||||
if found:
|
||||
return value
|
||||
|
||||
def _check() -> bool:
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
return is_admin_for_telegram_user(session, telegram_user_id)
|
||||
|
||||
return await asyncio.get_running_loop().run_in_executor(None, _check)
|
||||
result = await asyncio.get_running_loop().run_in_executor(None, _check)
|
||||
is_admin_cache.set(cache_key, result)
|
||||
return result
|
||||
|
||||
|
||||
def invalidate_is_admin_cache(telegram_user_id: int | None) -> None:
    """Invalidate is_admin cache for user. Call after set_user_role."""
    if telegram_user_id is None:
        # Nothing to invalidate without a concrete Telegram id.
        return
    is_admin_cache.invalidate(("is_admin", telegram_user_id))
|
||||
|
||||
@@ -15,10 +15,11 @@ from duty_teller.db.session import session_scope
|
||||
from duty_teller.i18n import get_lang, t
|
||||
from duty_teller.services.group_duty_pin_service import (
|
||||
get_duty_message_text,
|
||||
get_message_id,
|
||||
get_next_shift_end_utc,
|
||||
get_pin_refresh_data,
|
||||
save_pin,
|
||||
delete_pin,
|
||||
get_message_id,
|
||||
get_all_pin_chat_ids,
|
||||
)
|
||||
|
||||
@@ -28,6 +29,14 @@ JOB_NAME_PREFIX = "duty_pin_"
|
||||
RETRY_WHEN_NO_DUTY_MINUTES = 15
|
||||
|
||||
|
||||
def _sync_get_pin_refresh_data(
    chat_id: int, lang: str = "en"
) -> tuple[int | None, str, datetime | None]:
    """Get message_id, duty text, next_shift_end in one DB session."""
    # Single session for all three values; runs in an executor thread.
    with session_scope(config.DATABASE_URL) as db:
        return get_pin_refresh_data(db, chat_id, config.DUTY_DISPLAY_TZ, lang)
||||
|
||||
|
||||
def _get_duty_message_text_sync(lang: str = "en") -> str:
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
return get_duty_message_text(session, config.DUTY_DISPLAY_TZ, lang)
|
||||
@@ -96,26 +105,27 @@ async def _refresh_pin_for_chat(
|
||||
) -> Literal["updated", "no_message", "failed"]:
|
||||
"""Refresh pinned duty message: send new message, unpin old, pin new, save new message_id.
|
||||
|
||||
Uses single DB session for message_id, text, next_shift_end (consolidated).
|
||||
|
||||
Returns:
|
||||
"updated" if the message was sent, pinned and saved successfully;
|
||||
"no_message" if there is no pin record for this chat;
|
||||
"failed" if send_message or permissions failed.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
message_id = await loop.run_in_executor(None, _sync_get_message_id, chat_id)
|
||||
message_id, text, next_end = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: _sync_get_pin_refresh_data(chat_id, config.DEFAULT_LANGUAGE),
|
||||
)
|
||||
if message_id is None:
|
||||
logger.info("No pin record for chat_id=%s, skipping update", chat_id)
|
||||
return "no_message"
|
||||
text = await loop.run_in_executor(
|
||||
None, lambda: _get_duty_message_text_sync(config.DEFAULT_LANGUAGE)
|
||||
)
|
||||
try:
|
||||
msg = await context.bot.send_message(chat_id=chat_id, text=text)
|
||||
except (BadRequest, Forbidden) as e:
|
||||
logger.warning(
|
||||
"Failed to send duty message for pin refresh chat_id=%s: %s", chat_id, e
|
||||
)
|
||||
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
|
||||
await _schedule_next_update(context.application, chat_id, next_end)
|
||||
return "failed"
|
||||
try:
|
||||
@@ -127,11 +137,9 @@ async def _refresh_pin_for_chat(
|
||||
)
|
||||
except (BadRequest, Forbidden) as e:
|
||||
logger.warning("Unpin or pin after refresh failed chat_id=%s: %s", chat_id, e)
|
||||
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
|
||||
await _schedule_next_update(context.application, chat_id, next_end)
|
||||
return "failed"
|
||||
await loop.run_in_executor(None, _sync_save_pin, chat_id, msg.message_id)
|
||||
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
|
||||
await _schedule_next_update(context.application, chat_id, next_end)
|
||||
return "updated"
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from duty_teller.cache import duty_pin_cache
|
||||
from duty_teller.db.repository import (
|
||||
get_current_duty,
|
||||
get_next_shift_end,
|
||||
@@ -17,6 +18,31 @@ from duty_teller.i18n import t
|
||||
from duty_teller.utils.dates import parse_utc_iso
|
||||
|
||||
|
||||
def get_pin_refresh_data(
    session: Session, chat_id: int, tz_name: str, lang: str = "en"
) -> tuple[int | None, str, datetime | None]:
    """Get all data needed for pin refresh in a single DB session.

    Args:
        session: DB session.
        chat_id: Telegram chat id.
        tz_name: Timezone name for display.
        lang: Language code for i18n.

    Returns:
        (message_id, duty_message_text, next_shift_end_utc).
        message_id is None if no pin record. next_shift_end_utc is naive UTC or None.
    """
    pin = get_group_duty_pin(session, chat_id)
    if pin is None or pin.message_id is None:
        # No pin record for this chat — nothing to refresh.
        return (None, t(lang, "duty.no_duty"), None)
    text = get_duty_message_text(session, tz_name, lang)
    next_end = get_next_shift_end(session, datetime.now(timezone.utc))
    return (pin.message_id, text, next_end)
|
||||
|
||||
|
||||
def format_duty_message(duty, user, tz_name: str, lang: str = "en") -> str:
|
||||
"""Build the text for the pinned duty message.
|
||||
|
||||
@@ -64,34 +90,31 @@ def format_duty_message(duty, user, tz_name: str, lang: str = "en") -> str:
|
||||
|
||||
|
||||
def get_duty_message_text(session: Session, tz_name: str, lang: str = "en") -> str:
|
||||
"""Get current duty from DB and return formatted message text.
|
||||
|
||||
Args:
|
||||
session: DB session.
|
||||
tz_name: Timezone name for display.
|
||||
lang: Language code for i18n.
|
||||
|
||||
Returns:
|
||||
Formatted duty message or "No duty" if none.
|
||||
"""
|
||||
"""Get current duty from DB and return formatted message text. Cached 90s."""
|
||||
cache_key = ("duty_message_text", tz_name, lang)
|
||||
text, found = duty_pin_cache.get(cache_key)
|
||||
if found:
|
||||
return text
|
||||
now = datetime.now(timezone.utc)
|
||||
result = get_current_duty(session, now)
|
||||
if result is None:
|
||||
return t(lang, "duty.no_duty")
|
||||
duty, user = result
|
||||
return format_duty_message(duty, user, tz_name, lang)
|
||||
text = t(lang, "duty.no_duty")
|
||||
else:
|
||||
duty, user = result
|
||||
text = format_duty_message(duty, user, tz_name, lang)
|
||||
duty_pin_cache.set(cache_key, text)
|
||||
return text
|
||||
|
||||
|
||||
def get_next_shift_end_utc(session: Session) -> datetime | None:
|
||||
"""Return next shift end as naive UTC datetime for job scheduling.
|
||||
|
||||
Args:
|
||||
session: DB session.
|
||||
|
||||
Returns:
|
||||
Next shift end (naive UTC) or None.
|
||||
"""
|
||||
return get_next_shift_end(session, datetime.now(timezone.utc))
|
||||
"""Return next shift end as naive UTC datetime for job scheduling. Cached 90s."""
|
||||
cache_key = ("next_shift_end",)
|
||||
value, found = duty_pin_cache.get(cache_key)
|
||||
if found:
|
||||
return value
|
||||
result = get_next_shift_end(session, datetime.now(timezone.utc))
|
||||
duty_pin_cache.set(cache_key, result)
|
||||
return result
|
||||
|
||||
|
||||
def save_pin(session: Session, chat_id: int, message_id: int) -> None:
|
||||
|
||||
@@ -4,10 +4,12 @@ from datetime import date, timedelta
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from duty_teller.cache import invalidate_duty_related_caches
|
||||
from duty_teller.db.models import Duty
|
||||
from duty_teller.db.repository import (
|
||||
get_or_create_user_by_full_name,
|
||||
delete_duties_in_range,
|
||||
insert_duty,
|
||||
get_or_create_user_by_full_name,
|
||||
get_users_by_full_names,
|
||||
)
|
||||
from duty_teller.importers.duty_schedule import DutyScheduleResult
|
||||
from duty_teller.utils.dates import day_start_iso, day_end_iso, duty_to_iso
|
||||
@@ -37,11 +39,10 @@ def run_import(
|
||||
hour_utc: int,
|
||||
minute_utc: int,
|
||||
) -> tuple[int, int, int, int]:
|
||||
"""Run duty-schedule import: delete range per user, insert duty/unavailable/vacation.
|
||||
"""Run duty-schedule import: delete range per user, bulk insert duties.
|
||||
|
||||
For each entry: get_or_create_user_by_full_name, delete_duties_in_range for
|
||||
the result date range, then insert duties (handover time in UTC), unavailable
|
||||
(all-day), and vacation (consecutive ranges).
|
||||
Batched: users fetched in one query, missing created; bulk_insert_mappings.
|
||||
One commit at end.
|
||||
|
||||
Args:
|
||||
session: DB session.
|
||||
@@ -55,31 +56,61 @@ def run_import(
|
||||
from_date_str = result.start_date.isoformat()
|
||||
to_date_str = result.end_date.isoformat()
|
||||
num_duty = num_unavailable = num_vacation = 0
|
||||
|
||||
# Batch: get all users by full_name, create missing
|
||||
names = [e.full_name for e in result.entries]
|
||||
users_map = get_users_by_full_names(session, names)
|
||||
for name in names:
|
||||
if name not in users_map:
|
||||
users_map[name] = get_or_create_user_by_full_name(session, name)
|
||||
|
||||
# Delete range per user (no commit)
|
||||
for entry in result.entries:
|
||||
user = get_or_create_user_by_full_name(session, entry.full_name)
|
||||
delete_duties_in_range(session, user.id, from_date_str, to_date_str)
|
||||
user = users_map[entry.full_name]
|
||||
delete_duties_in_range(
|
||||
session, user.id, from_date_str, to_date_str, commit=False
|
||||
)
|
||||
|
||||
# Build rows for bulk insert
|
||||
duty_rows: list[dict] = []
|
||||
for entry in result.entries:
|
||||
user = users_map[entry.full_name]
|
||||
for d in entry.duty_dates:
|
||||
start_at = duty_to_iso(d, hour_utc, minute_utc)
|
||||
d_next = d + timedelta(days=1)
|
||||
end_at = duty_to_iso(d_next, hour_utc, minute_utc)
|
||||
insert_duty(session, user.id, start_at, end_at, event_type="duty")
|
||||
duty_rows.append(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"start_at": start_at,
|
||||
"end_at": end_at,
|
||||
"event_type": "duty",
|
||||
}
|
||||
)
|
||||
num_duty += 1
|
||||
for d in entry.unavailable_dates:
|
||||
insert_duty(
|
||||
session,
|
||||
user.id,
|
||||
day_start_iso(d),
|
||||
day_end_iso(d),
|
||||
event_type="unavailable",
|
||||
duty_rows.append(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"start_at": day_start_iso(d),
|
||||
"end_at": day_end_iso(d),
|
||||
"event_type": "unavailable",
|
||||
}
|
||||
)
|
||||
num_unavailable += 1
|
||||
for start_d, end_d in _consecutive_date_ranges(entry.vacation_dates):
|
||||
insert_duty(
|
||||
session,
|
||||
user.id,
|
||||
day_start_iso(start_d),
|
||||
day_end_iso(end_d),
|
||||
event_type="vacation",
|
||||
duty_rows.append(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"start_at": day_start_iso(start_d),
|
||||
"end_at": day_end_iso(end_d),
|
||||
"event_type": "vacation",
|
||||
}
|
||||
)
|
||||
num_vacation += 1
|
||||
|
||||
if duty_rows:
|
||||
session.bulk_insert_mappings(Duty, duty_rows)
|
||||
session.commit()
|
||||
invalidate_duty_related_caches()
|
||||
return (len(result.entries), num_duty, num_unavailable, num_vacation)
|
||||
|
||||
@@ -403,6 +403,9 @@ def test_calendar_ical_ignores_unknown_query_params(
|
||||
"""Unknown query params (e.g. events=all) are ignored; response is duty-only."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from duty_teller.cache import ics_calendar_cache
|
||||
|
||||
ics_calendar_cache.invalidate(("personal_ics", 1))
|
||||
mock_user = SimpleNamespace(id=1, full_name="User A")
|
||||
mock_get_user.return_value = mock_user
|
||||
duty = SimpleNamespace(
|
||||
|
||||
@@ -103,6 +103,7 @@ class TestGetDutyMessageText:
|
||||
"""Tests for get_duty_message_text."""
|
||||
|
||||
def test_no_current_duty_returns_no_duty(self, session):
|
||||
svc.duty_pin_cache.invalidate_pattern(("duty_message_text",))
|
||||
with patch(
|
||||
"duty_teller.services.group_duty_pin_service.get_current_duty",
|
||||
return_value=None,
|
||||
@@ -113,6 +114,7 @@ class TestGetDutyMessageText:
|
||||
assert result == "No duty"
|
||||
|
||||
def test_with_current_duty_returns_formatted(self, session, duty, user):
|
||||
svc.duty_pin_cache.invalidate_pattern(("duty_message_text",))
|
||||
with patch(
|
||||
"duty_teller.services.group_duty_pin_service.get_current_duty",
|
||||
return_value=(duty, user),
|
||||
@@ -130,6 +132,7 @@ class TestGetNextShiftEndUtc:
|
||||
"""Tests for get_next_shift_end_utc."""
|
||||
|
||||
def test_no_next_shift_returns_none(self, session):
|
||||
svc.duty_pin_cache.invalidate(("next_shift_end",))
|
||||
with patch(
|
||||
"duty_teller.services.group_duty_pin_service.get_next_shift_end",
|
||||
return_value=None,
|
||||
@@ -138,6 +141,7 @@ class TestGetNextShiftEndUtc:
|
||||
assert result is None
|
||||
|
||||
def test_has_next_shift_returns_naive_utc(self, session):
|
||||
svc.duty_pin_cache.invalidate(("next_shift_end",))
|
||||
naive = datetime(2025, 2, 21, 6, 0, 0)
|
||||
with patch(
|
||||
"duty_teller.services.group_duty_pin_service.get_next_shift_end",
|
||||
|
||||
@@ -143,14 +143,12 @@ async def test_update_group_pin_sends_new_unpins_pins_saves_schedules_next():
|
||||
context.application.job_queue.run_once = MagicMock()
|
||||
|
||||
with patch.object(config, "DUTY_PIN_NOTIFY", True):
|
||||
with patch.object(mod, "_sync_get_message_id", return_value=1):
|
||||
with patch.object(
|
||||
mod, "_get_duty_message_text_sync", return_value="Current duty"
|
||||
):
|
||||
with patch.object(mod, "_get_next_shift_end_sync", return_value=None):
|
||||
with patch.object(mod, "_schedule_next_update", AsyncMock()):
|
||||
with patch.object(mod, "_sync_save_pin") as mock_save:
|
||||
await mod.update_group_pin(context)
|
||||
with patch.object(
|
||||
mod, "_sync_get_pin_refresh_data", return_value=(1, "Current duty", None)
|
||||
):
|
||||
with patch.object(mod, "_schedule_next_update", AsyncMock()):
|
||||
with patch.object(mod, "_sync_save_pin") as mock_save:
|
||||
await mod.update_group_pin(context)
|
||||
context.bot.send_message.assert_called_once_with(chat_id=123, text="Current duty")
|
||||
context.bot.unpin_chat_message.assert_called_once_with(chat_id=123)
|
||||
context.bot.pin_chat_message.assert_called_once_with(
|
||||
@@ -168,7 +166,9 @@ async def test_update_group_pin_no_message_id_skips():
|
||||
context.bot = MagicMock()
|
||||
context.bot.send_message = AsyncMock()
|
||||
|
||||
with patch.object(mod, "_sync_get_message_id", return_value=None):
|
||||
with patch.object(
|
||||
mod, "_sync_get_pin_refresh_data", return_value=(None, "No duty", None)
|
||||
):
|
||||
await mod.update_group_pin(context)
|
||||
context.bot.send_message.assert_not_called()
|
||||
|
||||
@@ -185,13 +185,11 @@ async def test_update_group_pin_send_raises_no_unpin_pin_schedule_still_called()
|
||||
context.bot.pin_chat_message = AsyncMock()
|
||||
context.application = MagicMock()
|
||||
|
||||
with patch.object(mod, "_sync_get_message_id", return_value=2):
|
||||
with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"):
|
||||
with patch.object(mod, "_get_next_shift_end_sync", return_value=None):
|
||||
with patch.object(
|
||||
mod, "_schedule_next_update", AsyncMock()
|
||||
) as mock_schedule:
|
||||
await mod.update_group_pin(context)
|
||||
with patch.object(
|
||||
mod, "_sync_get_pin_refresh_data", return_value=(2, "Text", None)
|
||||
):
|
||||
with patch.object(mod, "_schedule_next_update", AsyncMock()) as mock_schedule:
|
||||
await mod.update_group_pin(context)
|
||||
context.bot.unpin_chat_message.assert_not_called()
|
||||
context.bot.pin_chat_message.assert_not_called()
|
||||
mock_schedule.assert_called_once_with(context.application, 111, None)
|
||||
@@ -214,15 +212,15 @@ async def test_update_group_pin_repin_raises_still_schedules_next():
|
||||
context.application = MagicMock()
|
||||
|
||||
with patch.object(config, "DUTY_PIN_NOTIFY", True):
|
||||
with patch.object(mod, "_sync_get_message_id", return_value=3):
|
||||
with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"):
|
||||
with patch.object(mod, "_get_next_shift_end_sync", return_value=None):
|
||||
with patch.object(
|
||||
mod, "_schedule_next_update", AsyncMock()
|
||||
) as mock_schedule:
|
||||
with patch.object(mod, "_sync_save_pin") as mock_save:
|
||||
with patch.object(mod, "logger") as mock_logger:
|
||||
await mod.update_group_pin(context)
|
||||
with patch.object(
|
||||
mod, "_sync_get_pin_refresh_data", return_value=(3, "Text", None)
|
||||
):
|
||||
with patch.object(
|
||||
mod, "_schedule_next_update", AsyncMock()
|
||||
) as mock_schedule:
|
||||
with patch.object(mod, "_sync_save_pin") as mock_save:
|
||||
with patch.object(mod, "logger") as mock_logger:
|
||||
await mod.update_group_pin(context)
|
||||
context.bot.send_message.assert_called_once_with(chat_id=222, text="Text")
|
||||
mock_save.assert_not_called()
|
||||
mock_logger.warning.assert_called_once()
|
||||
@@ -245,14 +243,14 @@ async def test_update_group_pin_duty_pin_notify_false_pins_silent():
|
||||
context.application = MagicMock()
|
||||
|
||||
with patch.object(config, "DUTY_PIN_NOTIFY", False):
|
||||
with patch.object(mod, "_sync_get_message_id", return_value=4):
|
||||
with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"):
|
||||
with patch.object(mod, "_get_next_shift_end_sync", return_value=None):
|
||||
with patch.object(
|
||||
mod, "_schedule_next_update", AsyncMock()
|
||||
) as mock_schedule:
|
||||
with patch.object(mod, "_sync_save_pin") as mock_save:
|
||||
await mod.update_group_pin(context)
|
||||
with patch.object(
|
||||
mod, "_sync_get_pin_refresh_data", return_value=(4, "Text", None)
|
||||
):
|
||||
with patch.object(
|
||||
mod, "_schedule_next_update", AsyncMock()
|
||||
) as mock_schedule:
|
||||
with patch.object(mod, "_sync_save_pin") as mock_save:
|
||||
await mod.update_group_pin(context)
|
||||
context.bot.send_message.assert_called_once_with(chat_id=333, text="Text")
|
||||
context.bot.unpin_chat_message.assert_called_once_with(chat_id=333)
|
||||
context.bot.pin_chat_message.assert_called_once_with(
|
||||
|
||||
Reference in New Issue
Block a user