From 0e8d1453e26043ea0b6d5555ab438bbe6626a163 Mon Sep 17 00:00:00 2001 From: Nikolay Tatarinov Date: Wed, 25 Feb 2026 13:25:34 +0300 Subject: [PATCH] feat: implement caching for duty-related data and enhance performance - Added a TTLCache class for in-memory caching of duty-related data, improving performance by reducing database queries. - Integrated caching into the group duty pin functionality, allowing for efficient retrieval of message text and next shift end times. - Introduced new methods to invalidate caches when relevant data changes, ensuring data consistency. - Created a new Alembic migration to add indexes on the duties table for improved query performance. - Updated tests to cover the new caching behavior and ensure proper functionality. --- .coverage | Bin 53248 -> 53248 bytes alembic/versions/008_duties_indexes.py | 44 ++++++ duty_teller/api/app.py | 40 +++--- duty_teller/api/calendar_ics.py | 38 ++++-- duty_teller/cache.py | 125 ++++++++++++++++++ duty_teller/db/repository.py | 22 ++- duty_teller/handlers/commands.py | 3 +- duty_teller/handlers/common.py | 19 ++- duty_teller/handlers/group_duty_pin.py | 24 ++-- .../services/group_duty_pin_service.py | 67 +++++++--- duty_teller/services/import_service.py | 73 +++++++--- tests/test_app.py | 3 + tests/test_group_duty_pin_service.py | 4 + tests/test_handlers_group_duty_pin.py | 64 +++++---- 14 files changed, 413 insertions(+), 113 deletions(-) create mode 100644 alembic/versions/008_duties_indexes.py create mode 100644 duty_teller/cache.py diff --git a/.coverage b/.coverage index 6586674399b1d9a3d04fd35db2800c73b93dafe6..215be4f7dea2bcecefd94c2f97a2da4c52593acd 100644 GIT binary patch delta 1480 zcmZ9KZERCj9LCS(-rM`qp8N0G+jf21x^|FkY_QDg*gz};vbGZ#VnD$#5hftW2y-~b zupJUUfI9UWn2T&nj3Jr`=z`!fUC}B^bw~O*US(>$xY&2S42Wp6CB^ z&drmP>}HeQ>@hCFhJ<;V5GnLy`GTB~XG?dbqta^0Azl&N#Kj^O&I^sgyq+v;+QgTJ zShPv|qN+gL*0^QH)*Xw2+UHe^`8*D_+DW-oD~;Ot0C6mJJEO9;AX=vNMP(*OyRoQP z%S0WFpQ^H$gYjvLVo}~p=0n=jSgCd?=3zXvtHvF?``=RwwC=d;^CXESW73*dJjRC>bFH}=^FpEaYgK>`Mmds*;R>Q2J}<(NL=3krYSZRFH;W`9 zF=Tqx43daSvN2RyOy&KL_L^fJ-dDqsWQ@IYu^{86Q)*o0J&&eHGLHWzw2{-_6!mF?x!Y~!=O22Yc z=~6yWwkR(tE0iL|u9&foGx!=FzQ6(Ew~OZ#S3vcE+LuxVMz!fj*+ZcR~R`+ zgwxppXw)v}8QRuyd2EQW1&E$zt@tE7DD`vgDDu=(D$KrT{qh$!@Zb}n-*&7$G- zxc)G`B-i67)hE0w*;@!j(403CcLq)!p3O-4Nhu=uj8%_ZljS8MKM}g$d{9rH>N$A# z=jkSght!JP6~mpgt5Ndx#EFRm6MB_wa+9)>J;jwxSR&rn^+A1o%Vf)tuCLd-#?A{1KQ zNUcGo*Hp^nCp%Fm8f`%BXjJOgAF^m%zau5}Cz)^p^-s32{bTn~%RLsIu8pJdl$1=} zBP9+Y#E>tL^d%o~)25;ZtC?I9GiinW4Bp_)uZ2^}< zB-J8EZ_y~0dEd8KERgo)h9WQ8c4 zK?++Qg*89{xfH$}3a_8STvrH`1+(bnDKtB|4dJCVWY@3yJ5%m&6sqAHq?gQZN|4F&s3M@pJt9d~>hGoWpaE z$U3x}tP5t5a5>MVv#2ditNH^1`LNu{d8w1Q?97*2NhnZYtZQltH|<&3xZm#~Re=&R z5-7mWh8r4oy!c$)UFRl80{KL&FmbNs`N;}_JW*j~Qb=t@6_ZRpsVL=~|8`fCoS=nq zkTtf z4X+v2^D+Jd>59}P)k?D1BxdpZgd4&!VT0)vQ<1UX$QT!C=vAsed;#pS4JBvf5OE>{jMXTdHW> zvs+5JFIItnf0+ z2xiM=Mogz7$=(HMh7-MUZSA^rf0D7LQT>p)xaZL^VSqh>O8kR5-V+}JG1r_*h3tCI zxA6Boil}O=z~QEjFJ&u}h&rRu7<)CdB*7X$i4oovpGg+0EGblE$a5{n)w_L}y~qFh zDMxQlre3bIcoz#{b2@VIDc-^Ta(aOfZ%0_t6PGypQS}weX z^|tsW?t&wc(w9E%*`toD%5C*m<^K6o>WK+vg_Sm*^0q{ou01~%seK>!4>3wqQFnb_ zvoJZkFf=iBZA?x3-=nX7&)F}6iPm0I_E_%=V_TkXuhWH$ z93!-O$j}>}kb(vl=h5tG7&z`nS7SYFL_LLOPx(=62q9#q2PF(MpM`1@e7sqt56Q-^ zy1YLZpa1!X;+)-BppIKiF(er1WuMnT@;kkxq0>#*rlT`2-@($p&(?BmB(*y1d-caW3pTdNkKx)4FQa?WyYv(ci9LQ#y5IrqdRX!G<6P_aF-2!CnZ% zPS^pL;A{92K7x1P9QYs$Y#_rexB)I0g{`m&9*4D10Nr#k--JVO5Dvfwx@g1FPCfmO zvB( None: + op.create_index( + "ix_duties_event_type_start_at", + "duties", + ["event_type", "start_at"], + unique=False, + ) + op.create_index( + "ix_duties_event_type_end_at", + "duties", + ["event_type", "end_at"], + unique=False, + ) + op.create_index( + "ix_duties_user_id_start_at", + "duties", + ["user_id", "start_at"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index("ix_duties_user_id_start_at", table_name="duties") + op.drop_index("ix_duties_event_type_end_at", table_name="duties") + op.drop_index("ix_duties_event_type_start_at", table_name="duties") diff --git a/duty_teller/api/app.py b/duty_teller/api/app.py index fcf667a..9770b16 100644 --- a/duty_teller/api/app.py +++ b/duty_teller/api/app.py @@ -3,6 +3,7 @@ import logging import re from datetime import date, timedelta + import duty_teller.config as config from fastapi import Depends, FastAPI, Request from fastapi.middleware.cors import CORSMiddleware @@ -18,6 +19,7 @@ from duty_teller.api.dependencies import ( require_miniapp_username, ) from duty_teller.api.personal_calendar_ics import build_personal_ics, build_team_ics +from duty_teller.cache import ics_calendar_cache from duty_teller.db.repository import ( get_duties, get_duties_for_user, @@ -116,14 +118,18 @@ def get_team_calendar_ical( user = get_user_by_calendar_token(session, token) if user is None: return Response(status_code=404, content="Not found") - today = date.today() - from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d") - to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d") - all_duties = get_duties(session, from_date=from_date, to_date=to_date) - duties_duty_only = [ - (d, name) for d, name in all_duties if (d.event_type or "duty") == "duty" - ] - ics_bytes = build_team_ics(duties_duty_only) + cache_key = ("team_ics",) + ics_bytes, found = ics_calendar_cache.get(cache_key) + if not found: + today = date.today() + from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d") + to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d") + all_duties = get_duties(session, from_date=from_date, to_date=to_date) + duties_duty_only = [ + (d, name) for d, name in all_duties if (d.event_type or "duty") == "duty" + ] + ics_bytes = build_team_ics(duties_duty_only) + ics_calendar_cache.set(cache_key, ics_bytes) return Response( content=ics_bytes, media_type="text/calendar; charset=utf-8", @@ -151,13 +157,17 @@ def get_personal_calendar_ical( user = get_user_by_calendar_token(session, token) if user is None: return Response(status_code=404, content="Not found") - today = date.today() - from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d") - to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d") - duties_with_name = get_duties_for_user( - session, user.id, from_date=from_date, to_date=to_date, event_types=["duty"] - ) - ics_bytes = build_personal_ics(duties_with_name) + cache_key = ("personal_ics", user.id) + ics_bytes, found = ics_calendar_cache.get(cache_key) + if not found: + today = date.today() + from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d") + to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d") + duties_with_name = get_duties_for_user( + session, user.id, from_date=from_date, to_date=to_date, event_types=["duty"] + ) + ics_bytes = build_personal_ics(duties_with_name) + ics_calendar_cache.set(cache_key, ics_bytes) return Response( content=ics_bytes, media_type="text/calendar; charset=utf-8", diff --git a/duty_teller/api/calendar_ics.py b/duty_teller/api/calendar_ics.py index 31f8320..074e532 100644 --- a/duty_teller/api/calendar_ics.py +++ b/duty_teller/api/calendar_ics.py @@ -7,12 +7,15 @@ from urllib.error import URLError from icalendar import Calendar +from duty_teller.cache import TTLCache from duty_teller.utils.http_client import safe_urlopen log = logging.getLogger(__name__) -# In-memory cache: url -> (cached_at_timestamp, raw_ics_bytes) +# Raw ICS bytes cache: url -> (cached_at_timestamp, raw_ics_bytes) _ics_cache: dict[str, tuple[float, bytes]] = {} +# Parsed events cache: url -> list of {date, summary}. TTL 7 days. +_parsed_events_cache = TTLCache(ttl_seconds=7 * 24 * 3600, max_size=100) CACHE_TTL_SECONDS = 7 * 24 * 3600 # 1 week FETCH_TIMEOUT_SECONDS = 15 @@ -68,8 +71,8 @@ def _event_date_range(component) -> tuple[date | None, date | None]: return (start_d, last_d) -def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]: - """Parse ICS bytes and return list of {date, summary} in [from_date, to_date]. One-time events only.""" +def _parse_ics_to_events(raw: bytes) -> list[dict]: + """Parse ICS bytes and return all events as list of {date, summary}. One-time events only.""" result: list[dict] = [] try: cal = Calendar.from_ical(raw) @@ -79,9 +82,6 @@ def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict] log.warning("Failed to parse ICS: %s", e) return result - from_d = date.fromisoformat(from_date) - to_d = date.fromisoformat(to_date) - for component in cal.walk(): if component.name != "VEVENT": continue @@ -95,13 +95,27 @@ def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict] d = start_d while d <= end_d: - if from_d <= d <= to_d: - result.append({"date": d.strftime("%Y-%m-%d"), "summary": summary_str}) + result.append({"date": d.strftime("%Y-%m-%d"), "summary": summary_str}) d += timedelta(days=1) return result +def _filter_events_by_range( + events: list[dict], from_date: str, to_date: str +) -> list[dict]: + """Filter events list to [from_date, to_date] range.""" + from_d = date.fromisoformat(from_date) + to_d = date.fromisoformat(to_date) + return [e for e in events if from_d <= date.fromisoformat(e["date"]) <= to_d] + + +def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]: + """Parse ICS bytes and return events in [from_date, to_date]. Wrapper for tests.""" + events = _parse_ics_to_events(raw) + return _filter_events_by_range(events, from_date, to_date) + + def get_calendar_events( url: str, from_date: str, @@ -135,4 +149,10 @@ def get_calendar_events( return [] _ics_cache[url] = (now, raw) - return _get_events_from_ics(raw, from_date, to_date) + # Use parsed events cache to avoid repeated Calendar.from_ical() + walk() + cache_key = (url,) + events, found = _parsed_events_cache.get(cache_key) + if not found: + events = _parse_ics_to_events(raw) + _parsed_events_cache.set(cache_key, events) + return _filter_events_by_range(events, from_date, to_date) diff --git a/duty_teller/cache.py b/duty_teller/cache.py new file mode 100644 index 0000000..89a5320 --- /dev/null +++ b/duty_teller/cache.py @@ -0,0 +1,125 @@ +"""Simple in-memory TTL cache. Thread-safe for get/set.""" + +import logging +from threading import Lock +from time import time + +log = logging.getLogger(__name__) + + +class TTLCache: + """Thread-safe TTL cache with optional max size and pattern invalidation.""" + + def __init__(self, ttl_seconds: float, max_size: int = 1000) -> None: + """Initialize cache with TTL and optional max size. + + Args: + ttl_seconds: Time-to-live in seconds for each entry. + max_size: Maximum number of entries (0 = unlimited). LRU eviction when exceeded. + """ + self._ttl = ttl_seconds + self._max_size = max_size + self._data: dict[tuple, tuple[float, object]] = {} # key -> (cached_at, value) + self._lock = Lock() + self._access_order: list[tuple] = [] # For LRU when max_size > 0 + + def get(self, key: tuple) -> tuple[object, bool]: + """Get value by key if present and not expired. + + Args: + key: Cache key (must be hashable, typically tuple). + + Returns: + (value, found) — found is True if valid cached value exists. + """ + with self._lock: + entry = self._data.get(key) + if entry is None: + return (None, False) + cached_at, value = entry + if time() - cached_at >= self._ttl: + del self._data[key] + if self._max_size > 0 and key in self._access_order: + self._access_order.remove(key) + return (None, False) + if self._max_size > 0 and key in self._access_order: + self._access_order.remove(key) + self._access_order.append(key) + return (value, True) + + def set(self, key: tuple, value: object) -> None: + """Store value with current timestamp. + + Args: + key: Cache key (must be hashable). + value: Value to cache. + """ + with self._lock: + now = time() + if ( + self._max_size > 0 + and len(self._data) >= self._max_size + and key not in self._data + ): + # Evict oldest + while self._access_order and len(self._data) >= self._max_size: + old_key = self._access_order.pop(0) + if old_key in self._data: + del self._data[old_key] + self._data[key] = (now, value) + if self._max_size > 0: + if key in self._access_order: + self._access_order.remove(key) + self._access_order.append(key) + + def invalidate(self, key: tuple) -> None: + """Remove a single key from cache. + + Args: + key: Cache key to remove. + """ + with self._lock: + if key in self._data: + del self._data[key] + if self._max_size > 0 and key in self._access_order: + self._access_order.remove(key) + + def clear(self) -> None: + """Remove all entries. Useful for tests.""" + with self._lock: + self._data.clear() + self._access_order.clear() + + def invalidate_pattern(self, key_prefix: tuple) -> None: + """Remove all keys that start with the given prefix. + + Args: + key_prefix: Prefix tuple (e.g. ("personal",) matches ("personal", 1), ("personal", 2)). + """ + with self._lock: + to_remove = [k for k in self._data if self._key_starts_with(k, key_prefix)] + for k in to_remove: + del self._data[k] + if self._max_size > 0 and k in self._access_order: + self._access_order.remove(k) + + @staticmethod + def _key_starts_with(key: tuple, prefix: tuple) -> bool: + """Check if key starts with prefix (both tuples).""" + if len(key) < len(prefix): + return False + return key[: len(prefix)] == prefix + + +# Shared caches for duty-related data. Invalidate on import. +ics_calendar_cache = TTLCache(ttl_seconds=600, max_size=500) +duty_pin_cache = TTLCache(ttl_seconds=90, max_size=100) # current_duty, next_shift_end +is_admin_cache = TTLCache(ttl_seconds=60, max_size=200) + + +def invalidate_duty_related_caches() -> None: + """Invalidate caches that depend on duties data. Call after import.""" + ics_calendar_cache.invalidate_pattern(("personal_ics",)) + ics_calendar_cache.invalidate_pattern(("team_ics",)) + duty_pin_cache.invalidate_pattern(("duty_message_text",)) + duty_pin_cache.invalidate(("next_shift_end",)) diff --git a/duty_teller/db/repository.py b/duty_teller/db/repository.py index 6c9b39c..83884e8 100644 --- a/duty_teller/db/repository.py +++ b/duty_teller/db/repository.py @@ -229,6 +229,22 @@ def get_or_create_user_by_full_name(session: Session, full_name: str) -> User: return user +def get_users_by_full_names(session: Session, full_names: list[str]) -> dict[str, User]: + """Get users by full_name. Returns dict full_name -> User. Does not create missing. + + Args: + session: DB session. + full_names: List of full names to look up. + + Returns: + Dict mapping full_name to User for found users. + """ + if not full_names: + return {} + users = session.query(User).filter(User.full_name.in_(full_names)).all() + return {u.full_name: u for u in users} + + def update_user_display_name( session: Session, telegram_user_id: int, @@ -268,6 +284,8 @@ def delete_duties_in_range( user_id: int, from_date: str, to_date: str, + *, + commit: bool = True, ) -> int: """Delete all duties of the user that overlap the given date range. @@ -276,6 +294,7 @@ def delete_duties_in_range( user_id: User id. from_date: Start date YYYY-MM-DD. to_date: End date YYYY-MM-DD. + commit: If True, commit immediately. If False, caller commits (for batch import). Returns: Number of duties deleted. @@ -288,7 +307,8 @@ def delete_duties_in_range( ) count = q.count() q.delete(synchronize_session=False) - session.commit() + if commit: + session.commit() return count diff --git a/duty_teller/handlers/commands.py b/duty_teller/handlers/commands.py index 335d8fa..6533058 100644 --- a/duty_teller/handlers/commands.py +++ b/duty_teller/handlers/commands.py @@ -19,7 +19,7 @@ from duty_teller.db.repository import ( ROLE_USER, ROLE_ADMIN, ) -from duty_teller.handlers.common import is_admin_async +from duty_teller.handlers.common import invalidate_is_admin_cache, is_admin_async from duty_teller.i18n import get_lang, t from duty_teller.utils.user import build_full_name @@ -230,6 +230,7 @@ async def set_role(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: ok = await asyncio.get_running_loop().run_in_executor(None, do_set_role) if ok: + invalidate_is_admin_cache(target_user.telegram_user_id) await update.message.reply_text( t(lang, "set_role.done", name=target_user.full_name, role=role_name) ) diff --git a/duty_teller/handlers/common.py b/duty_teller/handlers/common.py index 076a5a9..e239ff4 100644 --- a/duty_teller/handlers/common.py +++ b/duty_teller/handlers/common.py @@ -3,22 +3,35 @@ import asyncio import duty_teller.config as config +from duty_teller.cache import is_admin_cache from duty_teller.db.repository import is_admin_for_telegram_user from duty_teller.db.session import session_scope async def is_admin_async(telegram_user_id: int) -> bool: - """Check if Telegram user is admin (username or phone). Runs DB check in executor. + """Check if Telegram user is admin. Cached 60s. Invalidated on set_user_role. Args: telegram_user_id: Telegram user id. Returns: - True if user is in ADMIN_USERNAMES or their stored phone is in ADMIN_PHONES. + True if user is admin (DB role or env fallback). """ + cache_key = ("is_admin", telegram_user_id) + value, found = is_admin_cache.get(cache_key) + if found: + return value def _check() -> bool: with session_scope(config.DATABASE_URL) as session: return is_admin_for_telegram_user(session, telegram_user_id) - return await asyncio.get_running_loop().run_in_executor(None, _check) + result = await asyncio.get_running_loop().run_in_executor(None, _check) + is_admin_cache.set(cache_key, result) + return result + + +def invalidate_is_admin_cache(telegram_user_id: int | None) -> None: + """Invalidate is_admin cache for user. Call after set_user_role.""" + if telegram_user_id is not None: + is_admin_cache.invalidate(("is_admin", telegram_user_id)) diff --git a/duty_teller/handlers/group_duty_pin.py b/duty_teller/handlers/group_duty_pin.py index 4df910f..74805e5 100644 --- a/duty_teller/handlers/group_duty_pin.py +++ b/duty_teller/handlers/group_duty_pin.py @@ -15,10 +15,11 @@ from duty_teller.db.session import session_scope from duty_teller.i18n import get_lang, t from duty_teller.services.group_duty_pin_service import ( get_duty_message_text, + get_message_id, get_next_shift_end_utc, + get_pin_refresh_data, save_pin, delete_pin, - get_message_id, get_all_pin_chat_ids, ) @@ -28,6 +29,14 @@ JOB_NAME_PREFIX = "duty_pin_" RETRY_WHEN_NO_DUTY_MINUTES = 15 +def _sync_get_pin_refresh_data( + chat_id: int, lang: str = "en" +) -> tuple[int | None, str, datetime | None]: + """Get message_id, duty text, next_shift_end in one DB session.""" + with session_scope(config.DATABASE_URL) as session: + return get_pin_refresh_data(session, chat_id, config.DUTY_DISPLAY_TZ, lang) + + def _get_duty_message_text_sync(lang: str = "en") -> str: with session_scope(config.DATABASE_URL) as session: return get_duty_message_text(session, config.DUTY_DISPLAY_TZ, lang) @@ -96,26 +105,27 @@ async def _refresh_pin_for_chat( ) -> Literal["updated", "no_message", "failed"]: """Refresh pinned duty message: send new message, unpin old, pin new, save new message_id. + Uses single DB session for message_id, text, next_shift_end (consolidated). + Returns: "updated" if the message was sent, pinned and saved successfully; "no_message" if there is no pin record for this chat; "failed" if send_message or permissions failed. """ loop = asyncio.get_running_loop() - message_id = await loop.run_in_executor(None, _sync_get_message_id, chat_id) + message_id, text, next_end = await loop.run_in_executor( + None, + lambda: _sync_get_pin_refresh_data(chat_id, config.DEFAULT_LANGUAGE), + ) if message_id is None: logger.info("No pin record for chat_id=%s, skipping update", chat_id) return "no_message" - text = await loop.run_in_executor( - None, lambda: _get_duty_message_text_sync(config.DEFAULT_LANGUAGE) - ) try: msg = await context.bot.send_message(chat_id=chat_id, text=text) except (BadRequest, Forbidden) as e: logger.warning( "Failed to send duty message for pin refresh chat_id=%s: %s", chat_id, e ) - next_end = await loop.run_in_executor(None, _get_next_shift_end_sync) await _schedule_next_update(context.application, chat_id, next_end) return "failed" try: @@ -127,11 +137,9 @@ async def _refresh_pin_for_chat( ) except (BadRequest, Forbidden) as e: logger.warning("Unpin or pin after refresh failed chat_id=%s: %s", chat_id, e) - next_end = await loop.run_in_executor(None, _get_next_shift_end_sync) await _schedule_next_update(context.application, chat_id, next_end) return "failed" await loop.run_in_executor(None, _sync_save_pin, chat_id, msg.message_id) - next_end = await loop.run_in_executor(None, _get_next_shift_end_sync) await _schedule_next_update(context.application, chat_id, next_end) return "updated" diff --git a/duty_teller/services/group_duty_pin_service.py b/duty_teller/services/group_duty_pin_service.py index 07e2106..42f914f 100644 --- a/duty_teller/services/group_duty_pin_service.py +++ b/duty_teller/services/group_duty_pin_service.py @@ -5,6 +5,7 @@ from zoneinfo import ZoneInfo, ZoneInfoNotFoundError from sqlalchemy.orm import Session +from duty_teller.cache import duty_pin_cache from duty_teller.db.repository import ( get_current_duty, get_next_shift_end, @@ -17,6 +18,31 @@ from duty_teller.i18n import t from duty_teller.utils.dates import parse_utc_iso +def get_pin_refresh_data( + session: Session, chat_id: int, tz_name: str, lang: str = "en" +) -> tuple[int | None, str, datetime | None]: + """Get all data needed for pin refresh in a single DB session. + + Args: + session: DB session. + chat_id: Telegram chat id. + tz_name: Timezone name for display. + lang: Language code for i18n. + + Returns: + (message_id, duty_message_text, next_shift_end_utc). + message_id is None if no pin record. next_shift_end_utc is naive UTC or None. + """ + pin = get_group_duty_pin(session, chat_id) + message_id = pin.message_id if pin else None + if message_id is None: + return (None, t(lang, "duty.no_duty"), None) + now = datetime.now(timezone.utc) + text = get_duty_message_text(session, tz_name, lang) + next_end = get_next_shift_end(session, now) + return (message_id, text, next_end) + + def format_duty_message(duty, user, tz_name: str, lang: str = "en") -> str: """Build the text for the pinned duty message. @@ -64,34 +90,31 @@ def format_duty_message(duty, user, tz_name: str, lang: str = "en") -> str: def get_duty_message_text(session: Session, tz_name: str, lang: str = "en") -> str: - """Get current duty from DB and return formatted message text. - - Args: - session: DB session. - tz_name: Timezone name for display. - lang: Language code for i18n. - - Returns: - Formatted duty message or "No duty" if none. - """ + """Get current duty from DB and return formatted message text. Cached 90s.""" + cache_key = ("duty_message_text", tz_name, lang) + text, found = duty_pin_cache.get(cache_key) + if found: + return text now = datetime.now(timezone.utc) result = get_current_duty(session, now) if result is None: - return t(lang, "duty.no_duty") - duty, user = result - return format_duty_message(duty, user, tz_name, lang) + text = t(lang, "duty.no_duty") + else: + duty, user = result + text = format_duty_message(duty, user, tz_name, lang) + duty_pin_cache.set(cache_key, text) + return text def get_next_shift_end_utc(session: Session) -> datetime | None: - """Return next shift end as naive UTC datetime for job scheduling. - - Args: - session: DB session. - - Returns: - Next shift end (naive UTC) or None. - """ - return get_next_shift_end(session, datetime.now(timezone.utc)) + """Return next shift end as naive UTC datetime for job scheduling. Cached 90s.""" + cache_key = ("next_shift_end",) + value, found = duty_pin_cache.get(cache_key) + if found: + return value + result = get_next_shift_end(session, datetime.now(timezone.utc)) + duty_pin_cache.set(cache_key, result) + return result def save_pin(session: Session, chat_id: int, message_id: int) -> None: diff --git a/duty_teller/services/import_service.py b/duty_teller/services/import_service.py index f29f22a..1f5fb3d 100644 --- a/duty_teller/services/import_service.py +++ b/duty_teller/services/import_service.py @@ -4,10 +4,12 @@ from datetime import date, timedelta from sqlalchemy.orm import Session +from duty_teller.cache import invalidate_duty_related_caches +from duty_teller.db.models import Duty from duty_teller.db.repository import ( - get_or_create_user_by_full_name, delete_duties_in_range, - insert_duty, + get_or_create_user_by_full_name, + get_users_by_full_names, ) from duty_teller.importers.duty_schedule import DutyScheduleResult from duty_teller.utils.dates import day_start_iso, day_end_iso, duty_to_iso @@ -37,11 +39,10 @@ def run_import( hour_utc: int, minute_utc: int, ) -> tuple[int, int, int, int]: - """Run duty-schedule import: delete range per user, insert duty/unavailable/vacation. + """Run duty-schedule import: delete range per user, bulk insert duties. - For each entry: get_or_create_user_by_full_name, delete_duties_in_range for - the result date range, then insert duties (handover time in UTC), unavailable - (all-day), and vacation (consecutive ranges). + Batched: users fetched in one query, missing created; bulk_insert_mappings. + One commit at end. Args: session: DB session. @@ -55,31 +56,61 @@ def run_import( from_date_str = result.start_date.isoformat() to_date_str = result.end_date.isoformat() num_duty = num_unavailable = num_vacation = 0 + + # Batch: get all users by full_name, create missing + names = [e.full_name for e in result.entries] + users_map = get_users_by_full_names(session, names) + for name in names: + if name not in users_map: + users_map[name] = get_or_create_user_by_full_name(session, name) + + # Delete range per user (no commit) for entry in result.entries: - user = get_or_create_user_by_full_name(session, entry.full_name) - delete_duties_in_range(session, user.id, from_date_str, to_date_str) + user = users_map[entry.full_name] + delete_duties_in_range( + session, user.id, from_date_str, to_date_str, commit=False + ) + + # Build rows for bulk insert + duty_rows: list[dict] = [] + for entry in result.entries: + user = users_map[entry.full_name] for d in entry.duty_dates: start_at = duty_to_iso(d, hour_utc, minute_utc) d_next = d + timedelta(days=1) end_at = duty_to_iso(d_next, hour_utc, minute_utc) - insert_duty(session, user.id, start_at, end_at, event_type="duty") + duty_rows.append( + { + "user_id": user.id, + "start_at": start_at, + "end_at": end_at, + "event_type": "duty", + } + ) num_duty += 1 for d in entry.unavailable_dates: - insert_duty( - session, - user.id, - day_start_iso(d), - day_end_iso(d), - event_type="unavailable", + duty_rows.append( + { + "user_id": user.id, + "start_at": day_start_iso(d), + "end_at": day_end_iso(d), + "event_type": "unavailable", + } ) num_unavailable += 1 for start_d, end_d in _consecutive_date_ranges(entry.vacation_dates): - insert_duty( - session, - user.id, - day_start_iso(start_d), - day_end_iso(end_d), - event_type="vacation", + duty_rows.append( + { + "user_id": user.id, + "start_at": day_start_iso(start_d), + "end_at": day_end_iso(end_d), + "event_type": "vacation", + } ) num_vacation += 1 + + if duty_rows: + session.bulk_insert_mappings(Duty, duty_rows) + session.commit() + invalidate_duty_related_caches() return (len(result.entries), num_duty, num_unavailable, num_vacation) diff --git a/tests/test_app.py b/tests/test_app.py index eb0a11b..c479418 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -403,6 +403,9 @@ def test_calendar_ical_ignores_unknown_query_params( """Unknown query params (e.g. events=all) are ignored; response is duty-only.""" from types import SimpleNamespace + from duty_teller.cache import ics_calendar_cache + + ics_calendar_cache.invalidate(("personal_ics", 1)) mock_user = SimpleNamespace(id=1, full_name="User A") mock_get_user.return_value = mock_user duty = SimpleNamespace( diff --git a/tests/test_group_duty_pin_service.py b/tests/test_group_duty_pin_service.py index 2fe77db..7dc8b53 100644 --- a/tests/test_group_duty_pin_service.py +++ b/tests/test_group_duty_pin_service.py @@ -103,6 +103,7 @@ class TestGetDutyMessageText: """Tests for get_duty_message_text.""" def test_no_current_duty_returns_no_duty(self, session): + svc.duty_pin_cache.invalidate_pattern(("duty_message_text",)) with patch( "duty_teller.services.group_duty_pin_service.get_current_duty", return_value=None, @@ -113,6 +114,7 @@ class TestGetDutyMessageText: assert result == "No duty" def test_with_current_duty_returns_formatted(self, session, duty, user): + svc.duty_pin_cache.invalidate_pattern(("duty_message_text",)) with patch( "duty_teller.services.group_duty_pin_service.get_current_duty", return_value=(duty, user), @@ -130,6 +132,7 @@ class TestGetNextShiftEndUtc: """Tests for get_next_shift_end_utc.""" def test_no_next_shift_returns_none(self, session): + svc.duty_pin_cache.invalidate(("next_shift_end",)) with patch( "duty_teller.services.group_duty_pin_service.get_next_shift_end", return_value=None, @@ -138,6 +141,7 @@ class TestGetNextShiftEndUtc: assert result is None def test_has_next_shift_returns_naive_utc(self, session): + svc.duty_pin_cache.invalidate(("next_shift_end",)) naive = datetime(2025, 2, 21, 6, 0, 0) with patch( "duty_teller.services.group_duty_pin_service.get_next_shift_end", diff --git a/tests/test_handlers_group_duty_pin.py b/tests/test_handlers_group_duty_pin.py index 743e669..d93d14f 100644 --- a/tests/test_handlers_group_duty_pin.py +++ b/tests/test_handlers_group_duty_pin.py @@ -143,14 +143,12 @@ async def test_update_group_pin_sends_new_unpins_pins_saves_schedules_next(): context.application.job_queue.run_once = MagicMock() with patch.object(config, "DUTY_PIN_NOTIFY", True): - with patch.object(mod, "_sync_get_message_id", return_value=1): - with patch.object( - mod, "_get_duty_message_text_sync", return_value="Current duty" - ): - with patch.object(mod, "_get_next_shift_end_sync", return_value=None): - with patch.object(mod, "_schedule_next_update", AsyncMock()): - with patch.object(mod, "_sync_save_pin") as mock_save: - await mod.update_group_pin(context) + with patch.object( + mod, "_sync_get_pin_refresh_data", return_value=(1, "Current duty", None) + ): + with patch.object(mod, "_schedule_next_update", AsyncMock()): + with patch.object(mod, "_sync_save_pin") as mock_save: + await mod.update_group_pin(context) context.bot.send_message.assert_called_once_with(chat_id=123, text="Current duty") context.bot.unpin_chat_message.assert_called_once_with(chat_id=123) context.bot.pin_chat_message.assert_called_once_with( @@ -168,7 +166,9 @@ async def test_update_group_pin_no_message_id_skips(): context.bot = MagicMock() context.bot.send_message = AsyncMock() - with patch.object(mod, "_sync_get_message_id", return_value=None): + with patch.object( + mod, "_sync_get_pin_refresh_data", return_value=(None, "No duty", None) + ): await mod.update_group_pin(context) context.bot.send_message.assert_not_called() @@ -185,13 +185,11 @@ async def test_update_group_pin_send_raises_no_unpin_pin_schedule_still_called() context.bot.pin_chat_message = AsyncMock() context.application = MagicMock() - with patch.object(mod, "_sync_get_message_id", return_value=2): - with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"): - with patch.object(mod, "_get_next_shift_end_sync", return_value=None): - with patch.object( - mod, "_schedule_next_update", AsyncMock() - ) as mock_schedule: - await mod.update_group_pin(context) + with patch.object( + mod, "_sync_get_pin_refresh_data", return_value=(2, "Text", None) + ): + with patch.object(mod, "_schedule_next_update", AsyncMock()) as mock_schedule: + await mod.update_group_pin(context) context.bot.unpin_chat_message.assert_not_called() context.bot.pin_chat_message.assert_not_called() mock_schedule.assert_called_once_with(context.application, 111, None) @@ -214,15 +212,15 @@ async def test_update_group_pin_repin_raises_still_schedules_next(): context.application = MagicMock() with patch.object(config, "DUTY_PIN_NOTIFY", True): - with patch.object(mod, "_sync_get_message_id", return_value=3): - with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"): - with patch.object(mod, "_get_next_shift_end_sync", return_value=None): - with patch.object( - mod, "_schedule_next_update", AsyncMock() - ) as mock_schedule: - with patch.object(mod, "_sync_save_pin") as mock_save: - with patch.object(mod, "logger") as mock_logger: - await mod.update_group_pin(context) + with patch.object( + mod, "_sync_get_pin_refresh_data", return_value=(3, "Text", None) + ): + with patch.object( + mod, "_schedule_next_update", AsyncMock() + ) as mock_schedule: + with patch.object(mod, "_sync_save_pin") as mock_save: + with patch.object(mod, "logger") as mock_logger: + await mod.update_group_pin(context) context.bot.send_message.assert_called_once_with(chat_id=222, text="Text") mock_save.assert_not_called() mock_logger.warning.assert_called_once() @@ -245,14 +243,14 @@ async def test_update_group_pin_duty_pin_notify_false_pins_silent(): context.application = MagicMock() with patch.object(config, "DUTY_PIN_NOTIFY", False): - with patch.object(mod, "_sync_get_message_id", return_value=4): - with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"): - with patch.object(mod, "_get_next_shift_end_sync", return_value=None): - with patch.object( - mod, "_schedule_next_update", AsyncMock() - ) as mock_schedule: - with patch.object(mod, "_sync_save_pin") as mock_save: - await mod.update_group_pin(context) + with patch.object( + mod, "_sync_get_pin_refresh_data", return_value=(4, "Text", None) + ): + with patch.object( + mod, "_schedule_next_update", AsyncMock() + ) as mock_schedule: + with patch.object(mod, "_sync_save_pin") as mock_save: + await mod.update_group_pin(context) context.bot.send_message.assert_called_once_with(chat_id=333, text="Text") context.bot.unpin_chat_message.assert_called_once_with(chat_id=333) context.bot.pin_chat_message.assert_called_once_with(