"""Simple in-memory TTL cache. Thread-safe for get/set.""" import logging from threading import Lock from time import time log = logging.getLogger(__name__) class TTLCache: """Thread-safe TTL cache with optional max size and pattern invalidation.""" def __init__(self, ttl_seconds: float, max_size: int = 1000) -> None: """Initialize cache with TTL and optional max size. Args: ttl_seconds: Time-to-live in seconds for each entry. max_size: Maximum number of entries (0 = unlimited). LRU eviction when exceeded. """ self._ttl = ttl_seconds self._max_size = max_size self._data: dict[tuple, tuple[float, object]] = {} # key -> (cached_at, value) self._lock = Lock() self._access_order: list[tuple] = [] # For LRU when max_size > 0 def get(self, key: tuple) -> tuple[object, bool]: """Get value by key if present and not expired. Args: key: Cache key (must be hashable, typically tuple). Returns: (value, found) — found is True if valid cached value exists. """ with self._lock: entry = self._data.get(key) if entry is None: return (None, False) cached_at, value = entry if time() - cached_at >= self._ttl: del self._data[key] if self._max_size > 0 and key in self._access_order: self._access_order.remove(key) return (None, False) if self._max_size > 0 and key in self._access_order: self._access_order.remove(key) self._access_order.append(key) return (value, True) def set(self, key: tuple, value: object) -> None: """Store value with current timestamp. Args: key: Cache key (must be hashable). value: Value to cache. """ with self._lock: now = time() if ( self._max_size > 0 and len(self._data) >= self._max_size and key not in self._data ): # Evict oldest while self._access_order and len(self._data) >= self._max_size: old_key = self._access_order.pop(0) if old_key in self._data: del self._data[old_key] self._data[key] = (now, value) if self._max_size > 0: if key in self._access_order: self._access_order.remove(key) self._access_order.append(key) def invalidate(self, key: tuple) -> None: """Remove a single key from cache. Args: key: Cache key to remove. """ with self._lock: if key in self._data: del self._data[key] if self._max_size > 0 and key in self._access_order: self._access_order.remove(key) def clear(self) -> None: """Remove all entries. Useful for tests.""" with self._lock: self._data.clear() self._access_order.clear() def invalidate_pattern(self, key_prefix: tuple) -> None: """Remove all keys that start with the given prefix. Args: key_prefix: Prefix tuple (e.g. ("personal",) matches ("personal", 1), ("personal", 2)). """ with self._lock: to_remove = [k for k in self._data if self._key_starts_with(k, key_prefix)] for k in to_remove: del self._data[k] if self._max_size > 0 and k in self._access_order: self._access_order.remove(k) @staticmethod def _key_starts_with(key: tuple, prefix: tuple) -> bool: """Check if key starts with prefix (both tuples).""" if len(key) < len(prefix): return False return key[: len(prefix)] == prefix # Shared caches for duty-related data. Invalidate on import. ics_calendar_cache = TTLCache(ttl_seconds=600, max_size=500) duty_pin_cache = TTLCache(ttl_seconds=90, max_size=100) # current_duty, next_shift_end is_admin_cache = TTLCache(ttl_seconds=60, max_size=200) def invalidate_duty_related_caches() -> None: """Invalidate caches that depend on duties data. Call after import.""" ics_calendar_cache.invalidate_pattern(("personal_ics",)) ics_calendar_cache.invalidate_pattern(("team_ics",)) duty_pin_cache.invalidate_pattern(("duty_message_text",)) duty_pin_cache.invalidate(("next_shift_end",))