Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0e8d1453e2 |
44
alembic/versions/008_duties_indexes.py
Normal file
44
alembic/versions/008_duties_indexes.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Add indexes on duties table for performance.
|
||||||
|
|
||||||
|
Revision ID: 008
|
||||||
|
Revises: 007
|
||||||
|
Create Date: 2025-02-25
|
||||||
|
|
||||||
|
Indexes for get_current_duty, get_next_shift_end, get_duties, get_duties_for_user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "008"
|
||||||
|
down_revision: Union[str, None] = "007"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_index(
|
||||||
|
"ix_duties_event_type_start_at",
|
||||||
|
"duties",
|
||||||
|
["event_type", "start_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_duties_event_type_end_at",
|
||||||
|
"duties",
|
||||||
|
["event_type", "end_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_duties_user_id_start_at",
|
||||||
|
"duties",
|
||||||
|
["user_id", "start_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_duties_user_id_start_at", table_name="duties")
|
||||||
|
op.drop_index("ix_duties_event_type_end_at", table_name="duties")
|
||||||
|
op.drop_index("ix_duties_event_type_start_at", table_name="duties")
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import date, timedelta
|
from datetime import date, timedelta
|
||||||
|
|
||||||
import duty_teller.config as config
|
import duty_teller.config as config
|
||||||
from fastapi import Depends, FastAPI, Request
|
from fastapi import Depends, FastAPI, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@@ -18,6 +19,7 @@ from duty_teller.api.dependencies import (
|
|||||||
require_miniapp_username,
|
require_miniapp_username,
|
||||||
)
|
)
|
||||||
from duty_teller.api.personal_calendar_ics import build_personal_ics, build_team_ics
|
from duty_teller.api.personal_calendar_ics import build_personal_ics, build_team_ics
|
||||||
|
from duty_teller.cache import ics_calendar_cache
|
||||||
from duty_teller.db.repository import (
|
from duty_teller.db.repository import (
|
||||||
get_duties,
|
get_duties,
|
||||||
get_duties_for_user,
|
get_duties_for_user,
|
||||||
@@ -116,14 +118,18 @@ def get_team_calendar_ical(
|
|||||||
user = get_user_by_calendar_token(session, token)
|
user = get_user_by_calendar_token(session, token)
|
||||||
if user is None:
|
if user is None:
|
||||||
return Response(status_code=404, content="Not found")
|
return Response(status_code=404, content="Not found")
|
||||||
today = date.today()
|
cache_key = ("team_ics",)
|
||||||
from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d")
|
ics_bytes, found = ics_calendar_cache.get(cache_key)
|
||||||
to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d")
|
if not found:
|
||||||
all_duties = get_duties(session, from_date=from_date, to_date=to_date)
|
today = date.today()
|
||||||
duties_duty_only = [
|
from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||||
(d, name) for d, name in all_duties if (d.event_type or "duty") == "duty"
|
to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d")
|
||||||
]
|
all_duties = get_duties(session, from_date=from_date, to_date=to_date)
|
||||||
ics_bytes = build_team_ics(duties_duty_only)
|
duties_duty_only = [
|
||||||
|
(d, name) for d, name in all_duties if (d.event_type or "duty") == "duty"
|
||||||
|
]
|
||||||
|
ics_bytes = build_team_ics(duties_duty_only)
|
||||||
|
ics_calendar_cache.set(cache_key, ics_bytes)
|
||||||
return Response(
|
return Response(
|
||||||
content=ics_bytes,
|
content=ics_bytes,
|
||||||
media_type="text/calendar; charset=utf-8",
|
media_type="text/calendar; charset=utf-8",
|
||||||
@@ -151,13 +157,17 @@ def get_personal_calendar_ical(
|
|||||||
user = get_user_by_calendar_token(session, token)
|
user = get_user_by_calendar_token(session, token)
|
||||||
if user is None:
|
if user is None:
|
||||||
return Response(status_code=404, content="Not found")
|
return Response(status_code=404, content="Not found")
|
||||||
today = date.today()
|
cache_key = ("personal_ics", user.id)
|
||||||
from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d")
|
ics_bytes, found = ics_calendar_cache.get(cache_key)
|
||||||
to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d")
|
if not found:
|
||||||
duties_with_name = get_duties_for_user(
|
today = date.today()
|
||||||
session, user.id, from_date=from_date, to_date=to_date, event_types=["duty"]
|
from_date = (today - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||||
)
|
to_date = (today + timedelta(days=365 * 2)).strftime("%Y-%m-%d")
|
||||||
ics_bytes = build_personal_ics(duties_with_name)
|
duties_with_name = get_duties_for_user(
|
||||||
|
session, user.id, from_date=from_date, to_date=to_date, event_types=["duty"]
|
||||||
|
)
|
||||||
|
ics_bytes = build_personal_ics(duties_with_name)
|
||||||
|
ics_calendar_cache.set(cache_key, ics_bytes)
|
||||||
return Response(
|
return Response(
|
||||||
content=ics_bytes,
|
content=ics_bytes,
|
||||||
media_type="text/calendar; charset=utf-8",
|
media_type="text/calendar; charset=utf-8",
|
||||||
|
|||||||
@@ -7,12 +7,15 @@ from urllib.error import URLError
|
|||||||
|
|
||||||
from icalendar import Calendar
|
from icalendar import Calendar
|
||||||
|
|
||||||
|
from duty_teller.cache import TTLCache
|
||||||
from duty_teller.utils.http_client import safe_urlopen
|
from duty_teller.utils.http_client import safe_urlopen
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
# In-memory cache: url -> (cached_at_timestamp, raw_ics_bytes)
|
# Raw ICS bytes cache: url -> (cached_at_timestamp, raw_ics_bytes)
|
||||||
_ics_cache: dict[str, tuple[float, bytes]] = {}
|
_ics_cache: dict[str, tuple[float, bytes]] = {}
|
||||||
|
# Parsed events cache: url -> list of {date, summary}. TTL 7 days.
|
||||||
|
_parsed_events_cache = TTLCache(ttl_seconds=7 * 24 * 3600, max_size=100)
|
||||||
CACHE_TTL_SECONDS = 7 * 24 * 3600 # 1 week
|
CACHE_TTL_SECONDS = 7 * 24 * 3600 # 1 week
|
||||||
FETCH_TIMEOUT_SECONDS = 15
|
FETCH_TIMEOUT_SECONDS = 15
|
||||||
|
|
||||||
@@ -68,8 +71,8 @@ def _event_date_range(component) -> tuple[date | None, date | None]:
|
|||||||
return (start_d, last_d)
|
return (start_d, last_d)
|
||||||
|
|
||||||
|
|
||||||
def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]:
|
def _parse_ics_to_events(raw: bytes) -> list[dict]:
|
||||||
"""Parse ICS bytes and return list of {date, summary} in [from_date, to_date]. One-time events only."""
|
"""Parse ICS bytes and return all events as list of {date, summary}. One-time events only."""
|
||||||
result: list[dict] = []
|
result: list[dict] = []
|
||||||
try:
|
try:
|
||||||
cal = Calendar.from_ical(raw)
|
cal = Calendar.from_ical(raw)
|
||||||
@@ -79,9 +82,6 @@ def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]
|
|||||||
log.warning("Failed to parse ICS: %s", e)
|
log.warning("Failed to parse ICS: %s", e)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
from_d = date.fromisoformat(from_date)
|
|
||||||
to_d = date.fromisoformat(to_date)
|
|
||||||
|
|
||||||
for component in cal.walk():
|
for component in cal.walk():
|
||||||
if component.name != "VEVENT":
|
if component.name != "VEVENT":
|
||||||
continue
|
continue
|
||||||
@@ -95,13 +95,27 @@ def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]
|
|||||||
|
|
||||||
d = start_d
|
d = start_d
|
||||||
while d <= end_d:
|
while d <= end_d:
|
||||||
if from_d <= d <= to_d:
|
result.append({"date": d.strftime("%Y-%m-%d"), "summary": summary_str})
|
||||||
result.append({"date": d.strftime("%Y-%m-%d"), "summary": summary_str})
|
|
||||||
d += timedelta(days=1)
|
d += timedelta(days=1)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_events_by_range(
|
||||||
|
events: list[dict], from_date: str, to_date: str
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Filter events list to [from_date, to_date] range."""
|
||||||
|
from_d = date.fromisoformat(from_date)
|
||||||
|
to_d = date.fromisoformat(to_date)
|
||||||
|
return [e for e in events if from_d <= date.fromisoformat(e["date"]) <= to_d]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]:
|
||||||
|
"""Parse ICS bytes and return events in [from_date, to_date]. Wrapper for tests."""
|
||||||
|
events = _parse_ics_to_events(raw)
|
||||||
|
return _filter_events_by_range(events, from_date, to_date)
|
||||||
|
|
||||||
|
|
||||||
def get_calendar_events(
|
def get_calendar_events(
|
||||||
url: str,
|
url: str,
|
||||||
from_date: str,
|
from_date: str,
|
||||||
@@ -135,4 +149,10 @@ def get_calendar_events(
|
|||||||
return []
|
return []
|
||||||
_ics_cache[url] = (now, raw)
|
_ics_cache[url] = (now, raw)
|
||||||
|
|
||||||
return _get_events_from_ics(raw, from_date, to_date)
|
# Use parsed events cache to avoid repeated Calendar.from_ical() + walk()
|
||||||
|
cache_key = (url,)
|
||||||
|
events, found = _parsed_events_cache.get(cache_key)
|
||||||
|
if not found:
|
||||||
|
events = _parse_ics_to_events(raw)
|
||||||
|
_parsed_events_cache.set(cache_key, events)
|
||||||
|
return _filter_events_by_range(events, from_date, to_date)
|
||||||
|
|||||||
125
duty_teller/cache.py
Normal file
125
duty_teller/cache.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Simple in-memory TTL cache. Thread-safe for get/set."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from threading import Lock
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TTLCache:
|
||||||
|
"""Thread-safe TTL cache with optional max size and pattern invalidation."""
|
||||||
|
|
||||||
|
def __init__(self, ttl_seconds: float, max_size: int = 1000) -> None:
|
||||||
|
"""Initialize cache with TTL and optional max size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ttl_seconds: Time-to-live in seconds for each entry.
|
||||||
|
max_size: Maximum number of entries (0 = unlimited). LRU eviction when exceeded.
|
||||||
|
"""
|
||||||
|
self._ttl = ttl_seconds
|
||||||
|
self._max_size = max_size
|
||||||
|
self._data: dict[tuple, tuple[float, object]] = {} # key -> (cached_at, value)
|
||||||
|
self._lock = Lock()
|
||||||
|
self._access_order: list[tuple] = [] # For LRU when max_size > 0
|
||||||
|
|
||||||
|
def get(self, key: tuple) -> tuple[object, bool]:
|
||||||
|
"""Get value by key if present and not expired.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key (must be hashable, typically tuple).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(value, found) — found is True if valid cached value exists.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
entry = self._data.get(key)
|
||||||
|
if entry is None:
|
||||||
|
return (None, False)
|
||||||
|
cached_at, value = entry
|
||||||
|
if time() - cached_at >= self._ttl:
|
||||||
|
del self._data[key]
|
||||||
|
if self._max_size > 0 and key in self._access_order:
|
||||||
|
self._access_order.remove(key)
|
||||||
|
return (None, False)
|
||||||
|
if self._max_size > 0 and key in self._access_order:
|
||||||
|
self._access_order.remove(key)
|
||||||
|
self._access_order.append(key)
|
||||||
|
return (value, True)
|
||||||
|
|
||||||
|
def set(self, key: tuple, value: object) -> None:
|
||||||
|
"""Store value with current timestamp.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key (must be hashable).
|
||||||
|
value: Value to cache.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
now = time()
|
||||||
|
if (
|
||||||
|
self._max_size > 0
|
||||||
|
and len(self._data) >= self._max_size
|
||||||
|
and key not in self._data
|
||||||
|
):
|
||||||
|
# Evict oldest
|
||||||
|
while self._access_order and len(self._data) >= self._max_size:
|
||||||
|
old_key = self._access_order.pop(0)
|
||||||
|
if old_key in self._data:
|
||||||
|
del self._data[old_key]
|
||||||
|
self._data[key] = (now, value)
|
||||||
|
if self._max_size > 0:
|
||||||
|
if key in self._access_order:
|
||||||
|
self._access_order.remove(key)
|
||||||
|
self._access_order.append(key)
|
||||||
|
|
||||||
|
def invalidate(self, key: tuple) -> None:
|
||||||
|
"""Remove a single key from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key to remove.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if key in self._data:
|
||||||
|
del self._data[key]
|
||||||
|
if self._max_size > 0 and key in self._access_order:
|
||||||
|
self._access_order.remove(key)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Remove all entries. Useful for tests."""
|
||||||
|
with self._lock:
|
||||||
|
self._data.clear()
|
||||||
|
self._access_order.clear()
|
||||||
|
|
||||||
|
def invalidate_pattern(self, key_prefix: tuple) -> None:
|
||||||
|
"""Remove all keys that start with the given prefix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_prefix: Prefix tuple (e.g. ("personal",) matches ("personal", 1), ("personal", 2)).
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
to_remove = [k for k in self._data if self._key_starts_with(k, key_prefix)]
|
||||||
|
for k in to_remove:
|
||||||
|
del self._data[k]
|
||||||
|
if self._max_size > 0 and k in self._access_order:
|
||||||
|
self._access_order.remove(k)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _key_starts_with(key: tuple, prefix: tuple) -> bool:
|
||||||
|
"""Check if key starts with prefix (both tuples)."""
|
||||||
|
if len(key) < len(prefix):
|
||||||
|
return False
|
||||||
|
return key[: len(prefix)] == prefix
|
||||||
|
|
||||||
|
|
||||||
|
# Shared caches for duty-related data. Invalidate on import.
|
||||||
|
ics_calendar_cache = TTLCache(ttl_seconds=600, max_size=500)
|
||||||
|
duty_pin_cache = TTLCache(ttl_seconds=90, max_size=100) # current_duty, next_shift_end
|
||||||
|
is_admin_cache = TTLCache(ttl_seconds=60, max_size=200)
|
||||||
|
|
||||||
|
|
||||||
|
def invalidate_duty_related_caches() -> None:
|
||||||
|
"""Invalidate caches that depend on duties data. Call after import."""
|
||||||
|
ics_calendar_cache.invalidate_pattern(("personal_ics",))
|
||||||
|
ics_calendar_cache.invalidate_pattern(("team_ics",))
|
||||||
|
duty_pin_cache.invalidate_pattern(("duty_message_text",))
|
||||||
|
duty_pin_cache.invalidate(("next_shift_end",))
|
||||||
@@ -229,6 +229,22 @@ def get_or_create_user_by_full_name(session: Session, full_name: str) -> User:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def get_users_by_full_names(session: Session, full_names: list[str]) -> dict[str, User]:
|
||||||
|
"""Get users by full_name. Returns dict full_name -> User. Does not create missing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: DB session.
|
||||||
|
full_names: List of full names to look up.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping full_name to User for found users.
|
||||||
|
"""
|
||||||
|
if not full_names:
|
||||||
|
return {}
|
||||||
|
users = session.query(User).filter(User.full_name.in_(full_names)).all()
|
||||||
|
return {u.full_name: u for u in users}
|
||||||
|
|
||||||
|
|
||||||
def update_user_display_name(
|
def update_user_display_name(
|
||||||
session: Session,
|
session: Session,
|
||||||
telegram_user_id: int,
|
telegram_user_id: int,
|
||||||
@@ -268,6 +284,8 @@ def delete_duties_in_range(
|
|||||||
user_id: int,
|
user_id: int,
|
||||||
from_date: str,
|
from_date: str,
|
||||||
to_date: str,
|
to_date: str,
|
||||||
|
*,
|
||||||
|
commit: bool = True,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Delete all duties of the user that overlap the given date range.
|
"""Delete all duties of the user that overlap the given date range.
|
||||||
|
|
||||||
@@ -276,6 +294,7 @@ def delete_duties_in_range(
|
|||||||
user_id: User id.
|
user_id: User id.
|
||||||
from_date: Start date YYYY-MM-DD.
|
from_date: Start date YYYY-MM-DD.
|
||||||
to_date: End date YYYY-MM-DD.
|
to_date: End date YYYY-MM-DD.
|
||||||
|
commit: If True, commit immediately. If False, caller commits (for batch import).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of duties deleted.
|
Number of duties deleted.
|
||||||
@@ -288,7 +307,8 @@ def delete_duties_in_range(
|
|||||||
)
|
)
|
||||||
count = q.count()
|
count = q.count()
|
||||||
q.delete(synchronize_session=False)
|
q.delete(synchronize_session=False)
|
||||||
session.commit()
|
if commit:
|
||||||
|
session.commit()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from duty_teller.db.repository import (
|
|||||||
ROLE_USER,
|
ROLE_USER,
|
||||||
ROLE_ADMIN,
|
ROLE_ADMIN,
|
||||||
)
|
)
|
||||||
from duty_teller.handlers.common import is_admin_async
|
from duty_teller.handlers.common import invalidate_is_admin_cache, is_admin_async
|
||||||
from duty_teller.i18n import get_lang, t
|
from duty_teller.i18n import get_lang, t
|
||||||
from duty_teller.utils.user import build_full_name
|
from duty_teller.utils.user import build_full_name
|
||||||
|
|
||||||
@@ -230,6 +230,7 @@ async def set_role(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
|||||||
|
|
||||||
ok = await asyncio.get_running_loop().run_in_executor(None, do_set_role)
|
ok = await asyncio.get_running_loop().run_in_executor(None, do_set_role)
|
||||||
if ok:
|
if ok:
|
||||||
|
invalidate_is_admin_cache(target_user.telegram_user_id)
|
||||||
await update.message.reply_text(
|
await update.message.reply_text(
|
||||||
t(lang, "set_role.done", name=target_user.full_name, role=role_name)
|
t(lang, "set_role.done", name=target_user.full_name, role=role_name)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,22 +3,35 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import duty_teller.config as config
|
import duty_teller.config as config
|
||||||
|
from duty_teller.cache import is_admin_cache
|
||||||
from duty_teller.db.repository import is_admin_for_telegram_user
|
from duty_teller.db.repository import is_admin_for_telegram_user
|
||||||
from duty_teller.db.session import session_scope
|
from duty_teller.db.session import session_scope
|
||||||
|
|
||||||
|
|
||||||
async def is_admin_async(telegram_user_id: int) -> bool:
|
async def is_admin_async(telegram_user_id: int) -> bool:
|
||||||
"""Check if Telegram user is admin (username or phone). Runs DB check in executor.
|
"""Check if Telegram user is admin. Cached 60s. Invalidated on set_user_role.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
telegram_user_id: Telegram user id.
|
telegram_user_id: Telegram user id.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if user is in ADMIN_USERNAMES or their stored phone is in ADMIN_PHONES.
|
True if user is admin (DB role or env fallback).
|
||||||
"""
|
"""
|
||||||
|
cache_key = ("is_admin", telegram_user_id)
|
||||||
|
value, found = is_admin_cache.get(cache_key)
|
||||||
|
if found:
|
||||||
|
return value
|
||||||
|
|
||||||
def _check() -> bool:
|
def _check() -> bool:
|
||||||
with session_scope(config.DATABASE_URL) as session:
|
with session_scope(config.DATABASE_URL) as session:
|
||||||
return is_admin_for_telegram_user(session, telegram_user_id)
|
return is_admin_for_telegram_user(session, telegram_user_id)
|
||||||
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(None, _check)
|
result = await asyncio.get_running_loop().run_in_executor(None, _check)
|
||||||
|
is_admin_cache.set(cache_key, result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def invalidate_is_admin_cache(telegram_user_id: int | None) -> None:
|
||||||
|
"""Invalidate is_admin cache for user. Call after set_user_role."""
|
||||||
|
if telegram_user_id is not None:
|
||||||
|
is_admin_cache.invalidate(("is_admin", telegram_user_id))
|
||||||
|
|||||||
@@ -15,10 +15,11 @@ from duty_teller.db.session import session_scope
|
|||||||
from duty_teller.i18n import get_lang, t
|
from duty_teller.i18n import get_lang, t
|
||||||
from duty_teller.services.group_duty_pin_service import (
|
from duty_teller.services.group_duty_pin_service import (
|
||||||
get_duty_message_text,
|
get_duty_message_text,
|
||||||
|
get_message_id,
|
||||||
get_next_shift_end_utc,
|
get_next_shift_end_utc,
|
||||||
|
get_pin_refresh_data,
|
||||||
save_pin,
|
save_pin,
|
||||||
delete_pin,
|
delete_pin,
|
||||||
get_message_id,
|
|
||||||
get_all_pin_chat_ids,
|
get_all_pin_chat_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,6 +29,14 @@ JOB_NAME_PREFIX = "duty_pin_"
|
|||||||
RETRY_WHEN_NO_DUTY_MINUTES = 15
|
RETRY_WHEN_NO_DUTY_MINUTES = 15
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_get_pin_refresh_data(
|
||||||
|
chat_id: int, lang: str = "en"
|
||||||
|
) -> tuple[int | None, str, datetime | None]:
|
||||||
|
"""Get message_id, duty text, next_shift_end in one DB session."""
|
||||||
|
with session_scope(config.DATABASE_URL) as session:
|
||||||
|
return get_pin_refresh_data(session, chat_id, config.DUTY_DISPLAY_TZ, lang)
|
||||||
|
|
||||||
|
|
||||||
def _get_duty_message_text_sync(lang: str = "en") -> str:
|
def _get_duty_message_text_sync(lang: str = "en") -> str:
|
||||||
with session_scope(config.DATABASE_URL) as session:
|
with session_scope(config.DATABASE_URL) as session:
|
||||||
return get_duty_message_text(session, config.DUTY_DISPLAY_TZ, lang)
|
return get_duty_message_text(session, config.DUTY_DISPLAY_TZ, lang)
|
||||||
@@ -96,26 +105,27 @@ async def _refresh_pin_for_chat(
|
|||||||
) -> Literal["updated", "no_message", "failed"]:
|
) -> Literal["updated", "no_message", "failed"]:
|
||||||
"""Refresh pinned duty message: send new message, unpin old, pin new, save new message_id.
|
"""Refresh pinned duty message: send new message, unpin old, pin new, save new message_id.
|
||||||
|
|
||||||
|
Uses single DB session for message_id, text, next_shift_end (consolidated).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
"updated" if the message was sent, pinned and saved successfully;
|
"updated" if the message was sent, pinned and saved successfully;
|
||||||
"no_message" if there is no pin record for this chat;
|
"no_message" if there is no pin record for this chat;
|
||||||
"failed" if send_message or permissions failed.
|
"failed" if send_message or permissions failed.
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
message_id = await loop.run_in_executor(None, _sync_get_message_id, chat_id)
|
message_id, text, next_end = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: _sync_get_pin_refresh_data(chat_id, config.DEFAULT_LANGUAGE),
|
||||||
|
)
|
||||||
if message_id is None:
|
if message_id is None:
|
||||||
logger.info("No pin record for chat_id=%s, skipping update", chat_id)
|
logger.info("No pin record for chat_id=%s, skipping update", chat_id)
|
||||||
return "no_message"
|
return "no_message"
|
||||||
text = await loop.run_in_executor(
|
|
||||||
None, lambda: _get_duty_message_text_sync(config.DEFAULT_LANGUAGE)
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
msg = await context.bot.send_message(chat_id=chat_id, text=text)
|
msg = await context.bot.send_message(chat_id=chat_id, text=text)
|
||||||
except (BadRequest, Forbidden) as e:
|
except (BadRequest, Forbidden) as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to send duty message for pin refresh chat_id=%s: %s", chat_id, e
|
"Failed to send duty message for pin refresh chat_id=%s: %s", chat_id, e
|
||||||
)
|
)
|
||||||
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
|
|
||||||
await _schedule_next_update(context.application, chat_id, next_end)
|
await _schedule_next_update(context.application, chat_id, next_end)
|
||||||
return "failed"
|
return "failed"
|
||||||
try:
|
try:
|
||||||
@@ -127,11 +137,9 @@ async def _refresh_pin_for_chat(
|
|||||||
)
|
)
|
||||||
except (BadRequest, Forbidden) as e:
|
except (BadRequest, Forbidden) as e:
|
||||||
logger.warning("Unpin or pin after refresh failed chat_id=%s: %s", chat_id, e)
|
logger.warning("Unpin or pin after refresh failed chat_id=%s: %s", chat_id, e)
|
||||||
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
|
|
||||||
await _schedule_next_update(context.application, chat_id, next_end)
|
await _schedule_next_update(context.application, chat_id, next_end)
|
||||||
return "failed"
|
return "failed"
|
||||||
await loop.run_in_executor(None, _sync_save_pin, chat_id, msg.message_id)
|
await loop.run_in_executor(None, _sync_save_pin, chat_id, msg.message_id)
|
||||||
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
|
|
||||||
await _schedule_next_update(context.application, chat_id, next_end)
|
await _schedule_next_update(context.application, chat_id, next_end)
|
||||||
return "updated"
|
return "updated"
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from duty_teller.cache import duty_pin_cache
|
||||||
from duty_teller.db.repository import (
|
from duty_teller.db.repository import (
|
||||||
get_current_duty,
|
get_current_duty,
|
||||||
get_next_shift_end,
|
get_next_shift_end,
|
||||||
@@ -17,6 +18,31 @@ from duty_teller.i18n import t
|
|||||||
from duty_teller.utils.dates import parse_utc_iso
|
from duty_teller.utils.dates import parse_utc_iso
|
||||||
|
|
||||||
|
|
||||||
|
def get_pin_refresh_data(
|
||||||
|
session: Session, chat_id: int, tz_name: str, lang: str = "en"
|
||||||
|
) -> tuple[int | None, str, datetime | None]:
|
||||||
|
"""Get all data needed for pin refresh in a single DB session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: DB session.
|
||||||
|
chat_id: Telegram chat id.
|
||||||
|
tz_name: Timezone name for display.
|
||||||
|
lang: Language code for i18n.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(message_id, duty_message_text, next_shift_end_utc).
|
||||||
|
message_id is None if no pin record. next_shift_end_utc is naive UTC or None.
|
||||||
|
"""
|
||||||
|
pin = get_group_duty_pin(session, chat_id)
|
||||||
|
message_id = pin.message_id if pin else None
|
||||||
|
if message_id is None:
|
||||||
|
return (None, t(lang, "duty.no_duty"), None)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
text = get_duty_message_text(session, tz_name, lang)
|
||||||
|
next_end = get_next_shift_end(session, now)
|
||||||
|
return (message_id, text, next_end)
|
||||||
|
|
||||||
|
|
||||||
def format_duty_message(duty, user, tz_name: str, lang: str = "en") -> str:
|
def format_duty_message(duty, user, tz_name: str, lang: str = "en") -> str:
|
||||||
"""Build the text for the pinned duty message.
|
"""Build the text for the pinned duty message.
|
||||||
|
|
||||||
@@ -64,34 +90,31 @@ def format_duty_message(duty, user, tz_name: str, lang: str = "en") -> str:
|
|||||||
|
|
||||||
|
|
||||||
def get_duty_message_text(session: Session, tz_name: str, lang: str = "en") -> str:
|
def get_duty_message_text(session: Session, tz_name: str, lang: str = "en") -> str:
|
||||||
"""Get current duty from DB and return formatted message text.
|
"""Get current duty from DB and return formatted message text. Cached 90s."""
|
||||||
|
cache_key = ("duty_message_text", tz_name, lang)
|
||||||
Args:
|
text, found = duty_pin_cache.get(cache_key)
|
||||||
session: DB session.
|
if found:
|
||||||
tz_name: Timezone name for display.
|
return text
|
||||||
lang: Language code for i18n.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted duty message or "No duty" if none.
|
|
||||||
"""
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
result = get_current_duty(session, now)
|
result = get_current_duty(session, now)
|
||||||
if result is None:
|
if result is None:
|
||||||
return t(lang, "duty.no_duty")
|
text = t(lang, "duty.no_duty")
|
||||||
duty, user = result
|
else:
|
||||||
return format_duty_message(duty, user, tz_name, lang)
|
duty, user = result
|
||||||
|
text = format_duty_message(duty, user, tz_name, lang)
|
||||||
|
duty_pin_cache.set(cache_key, text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
def get_next_shift_end_utc(session: Session) -> datetime | None:
|
def get_next_shift_end_utc(session: Session) -> datetime | None:
|
||||||
"""Return next shift end as naive UTC datetime for job scheduling.
|
"""Return next shift end as naive UTC datetime for job scheduling. Cached 90s."""
|
||||||
|
cache_key = ("next_shift_end",)
|
||||||
Args:
|
value, found = duty_pin_cache.get(cache_key)
|
||||||
session: DB session.
|
if found:
|
||||||
|
return value
|
||||||
Returns:
|
result = get_next_shift_end(session, datetime.now(timezone.utc))
|
||||||
Next shift end (naive UTC) or None.
|
duty_pin_cache.set(cache_key, result)
|
||||||
"""
|
return result
|
||||||
return get_next_shift_end(session, datetime.now(timezone.utc))
|
|
||||||
|
|
||||||
|
|
||||||
def save_pin(session: Session, chat_id: int, message_id: int) -> None:
|
def save_pin(session: Session, chat_id: int, message_id: int) -> None:
|
||||||
|
|||||||
@@ -4,10 +4,12 @@ from datetime import date, timedelta
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from duty_teller.cache import invalidate_duty_related_caches
|
||||||
|
from duty_teller.db.models import Duty
|
||||||
from duty_teller.db.repository import (
|
from duty_teller.db.repository import (
|
||||||
get_or_create_user_by_full_name,
|
|
||||||
delete_duties_in_range,
|
delete_duties_in_range,
|
||||||
insert_duty,
|
get_or_create_user_by_full_name,
|
||||||
|
get_users_by_full_names,
|
||||||
)
|
)
|
||||||
from duty_teller.importers.duty_schedule import DutyScheduleResult
|
from duty_teller.importers.duty_schedule import DutyScheduleResult
|
||||||
from duty_teller.utils.dates import day_start_iso, day_end_iso, duty_to_iso
|
from duty_teller.utils.dates import day_start_iso, day_end_iso, duty_to_iso
|
||||||
@@ -37,11 +39,10 @@ def run_import(
|
|||||||
hour_utc: int,
|
hour_utc: int,
|
||||||
minute_utc: int,
|
minute_utc: int,
|
||||||
) -> tuple[int, int, int, int]:
|
) -> tuple[int, int, int, int]:
|
||||||
"""Run duty-schedule import: delete range per user, insert duty/unavailable/vacation.
|
"""Run duty-schedule import: delete range per user, bulk insert duties.
|
||||||
|
|
||||||
For each entry: get_or_create_user_by_full_name, delete_duties_in_range for
|
Batched: users fetched in one query, missing created; bulk_insert_mappings.
|
||||||
the result date range, then insert duties (handover time in UTC), unavailable
|
One commit at end.
|
||||||
(all-day), and vacation (consecutive ranges).
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB session.
|
session: DB session.
|
||||||
@@ -55,31 +56,61 @@ def run_import(
|
|||||||
from_date_str = result.start_date.isoformat()
|
from_date_str = result.start_date.isoformat()
|
||||||
to_date_str = result.end_date.isoformat()
|
to_date_str = result.end_date.isoformat()
|
||||||
num_duty = num_unavailable = num_vacation = 0
|
num_duty = num_unavailable = num_vacation = 0
|
||||||
|
|
||||||
|
# Batch: get all users by full_name, create missing
|
||||||
|
names = [e.full_name for e in result.entries]
|
||||||
|
users_map = get_users_by_full_names(session, names)
|
||||||
|
for name in names:
|
||||||
|
if name not in users_map:
|
||||||
|
users_map[name] = get_or_create_user_by_full_name(session, name)
|
||||||
|
|
||||||
|
# Delete range per user (no commit)
|
||||||
for entry in result.entries:
|
for entry in result.entries:
|
||||||
user = get_or_create_user_by_full_name(session, entry.full_name)
|
user = users_map[entry.full_name]
|
||||||
delete_duties_in_range(session, user.id, from_date_str, to_date_str)
|
delete_duties_in_range(
|
||||||
|
session, user.id, from_date_str, to_date_str, commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build rows for bulk insert
|
||||||
|
duty_rows: list[dict] = []
|
||||||
|
for entry in result.entries:
|
||||||
|
user = users_map[entry.full_name]
|
||||||
for d in entry.duty_dates:
|
for d in entry.duty_dates:
|
||||||
start_at = duty_to_iso(d, hour_utc, minute_utc)
|
start_at = duty_to_iso(d, hour_utc, minute_utc)
|
||||||
d_next = d + timedelta(days=1)
|
d_next = d + timedelta(days=1)
|
||||||
end_at = duty_to_iso(d_next, hour_utc, minute_utc)
|
end_at = duty_to_iso(d_next, hour_utc, minute_utc)
|
||||||
insert_duty(session, user.id, start_at, end_at, event_type="duty")
|
duty_rows.append(
|
||||||
|
{
|
||||||
|
"user_id": user.id,
|
||||||
|
"start_at": start_at,
|
||||||
|
"end_at": end_at,
|
||||||
|
"event_type": "duty",
|
||||||
|
}
|
||||||
|
)
|
||||||
num_duty += 1
|
num_duty += 1
|
||||||
for d in entry.unavailable_dates:
|
for d in entry.unavailable_dates:
|
||||||
insert_duty(
|
duty_rows.append(
|
||||||
session,
|
{
|
||||||
user.id,
|
"user_id": user.id,
|
||||||
day_start_iso(d),
|
"start_at": day_start_iso(d),
|
||||||
day_end_iso(d),
|
"end_at": day_end_iso(d),
|
||||||
event_type="unavailable",
|
"event_type": "unavailable",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
num_unavailable += 1
|
num_unavailable += 1
|
||||||
for start_d, end_d in _consecutive_date_ranges(entry.vacation_dates):
|
for start_d, end_d in _consecutive_date_ranges(entry.vacation_dates):
|
||||||
insert_duty(
|
duty_rows.append(
|
||||||
session,
|
{
|
||||||
user.id,
|
"user_id": user.id,
|
||||||
day_start_iso(start_d),
|
"start_at": day_start_iso(start_d),
|
||||||
day_end_iso(end_d),
|
"end_at": day_end_iso(end_d),
|
||||||
event_type="vacation",
|
"event_type": "vacation",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
num_vacation += 1
|
num_vacation += 1
|
||||||
|
|
||||||
|
if duty_rows:
|
||||||
|
session.bulk_insert_mappings(Duty, duty_rows)
|
||||||
|
session.commit()
|
||||||
|
invalidate_duty_related_caches()
|
||||||
return (len(result.entries), num_duty, num_unavailable, num_vacation)
|
return (len(result.entries), num_duty, num_unavailable, num_vacation)
|
||||||
|
|||||||
@@ -403,6 +403,9 @@ def test_calendar_ical_ignores_unknown_query_params(
|
|||||||
"""Unknown query params (e.g. events=all) are ignored; response is duty-only."""
|
"""Unknown query params (e.g. events=all) are ignored; response is duty-only."""
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from duty_teller.cache import ics_calendar_cache
|
||||||
|
|
||||||
|
ics_calendar_cache.invalidate(("personal_ics", 1))
|
||||||
mock_user = SimpleNamespace(id=1, full_name="User A")
|
mock_user = SimpleNamespace(id=1, full_name="User A")
|
||||||
mock_get_user.return_value = mock_user
|
mock_get_user.return_value = mock_user
|
||||||
duty = SimpleNamespace(
|
duty = SimpleNamespace(
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class TestGetDutyMessageText:
|
|||||||
"""Tests for get_duty_message_text."""
|
"""Tests for get_duty_message_text."""
|
||||||
|
|
||||||
def test_no_current_duty_returns_no_duty(self, session):
|
def test_no_current_duty_returns_no_duty(self, session):
|
||||||
|
svc.duty_pin_cache.invalidate_pattern(("duty_message_text",))
|
||||||
with patch(
|
with patch(
|
||||||
"duty_teller.services.group_duty_pin_service.get_current_duty",
|
"duty_teller.services.group_duty_pin_service.get_current_duty",
|
||||||
return_value=None,
|
return_value=None,
|
||||||
@@ -113,6 +114,7 @@ class TestGetDutyMessageText:
|
|||||||
assert result == "No duty"
|
assert result == "No duty"
|
||||||
|
|
||||||
def test_with_current_duty_returns_formatted(self, session, duty, user):
|
def test_with_current_duty_returns_formatted(self, session, duty, user):
|
||||||
|
svc.duty_pin_cache.invalidate_pattern(("duty_message_text",))
|
||||||
with patch(
|
with patch(
|
||||||
"duty_teller.services.group_duty_pin_service.get_current_duty",
|
"duty_teller.services.group_duty_pin_service.get_current_duty",
|
||||||
return_value=(duty, user),
|
return_value=(duty, user),
|
||||||
@@ -130,6 +132,7 @@ class TestGetNextShiftEndUtc:
|
|||||||
"""Tests for get_next_shift_end_utc."""
|
"""Tests for get_next_shift_end_utc."""
|
||||||
|
|
||||||
def test_no_next_shift_returns_none(self, session):
|
def test_no_next_shift_returns_none(self, session):
|
||||||
|
svc.duty_pin_cache.invalidate(("next_shift_end",))
|
||||||
with patch(
|
with patch(
|
||||||
"duty_teller.services.group_duty_pin_service.get_next_shift_end",
|
"duty_teller.services.group_duty_pin_service.get_next_shift_end",
|
||||||
return_value=None,
|
return_value=None,
|
||||||
@@ -138,6 +141,7 @@ class TestGetNextShiftEndUtc:
|
|||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
def test_has_next_shift_returns_naive_utc(self, session):
|
def test_has_next_shift_returns_naive_utc(self, session):
|
||||||
|
svc.duty_pin_cache.invalidate(("next_shift_end",))
|
||||||
naive = datetime(2025, 2, 21, 6, 0, 0)
|
naive = datetime(2025, 2, 21, 6, 0, 0)
|
||||||
with patch(
|
with patch(
|
||||||
"duty_teller.services.group_duty_pin_service.get_next_shift_end",
|
"duty_teller.services.group_duty_pin_service.get_next_shift_end",
|
||||||
|
|||||||
@@ -143,14 +143,12 @@ async def test_update_group_pin_sends_new_unpins_pins_saves_schedules_next():
|
|||||||
context.application.job_queue.run_once = MagicMock()
|
context.application.job_queue.run_once = MagicMock()
|
||||||
|
|
||||||
with patch.object(config, "DUTY_PIN_NOTIFY", True):
|
with patch.object(config, "DUTY_PIN_NOTIFY", True):
|
||||||
with patch.object(mod, "_sync_get_message_id", return_value=1):
|
with patch.object(
|
||||||
with patch.object(
|
mod, "_sync_get_pin_refresh_data", return_value=(1, "Current duty", None)
|
||||||
mod, "_get_duty_message_text_sync", return_value="Current duty"
|
):
|
||||||
):
|
with patch.object(mod, "_schedule_next_update", AsyncMock()):
|
||||||
with patch.object(mod, "_get_next_shift_end_sync", return_value=None):
|
with patch.object(mod, "_sync_save_pin") as mock_save:
|
||||||
with patch.object(mod, "_schedule_next_update", AsyncMock()):
|
await mod.update_group_pin(context)
|
||||||
with patch.object(mod, "_sync_save_pin") as mock_save:
|
|
||||||
await mod.update_group_pin(context)
|
|
||||||
context.bot.send_message.assert_called_once_with(chat_id=123, text="Current duty")
|
context.bot.send_message.assert_called_once_with(chat_id=123, text="Current duty")
|
||||||
context.bot.unpin_chat_message.assert_called_once_with(chat_id=123)
|
context.bot.unpin_chat_message.assert_called_once_with(chat_id=123)
|
||||||
context.bot.pin_chat_message.assert_called_once_with(
|
context.bot.pin_chat_message.assert_called_once_with(
|
||||||
@@ -168,7 +166,9 @@ async def test_update_group_pin_no_message_id_skips():
|
|||||||
context.bot = MagicMock()
|
context.bot = MagicMock()
|
||||||
context.bot.send_message = AsyncMock()
|
context.bot.send_message = AsyncMock()
|
||||||
|
|
||||||
with patch.object(mod, "_sync_get_message_id", return_value=None):
|
with patch.object(
|
||||||
|
mod, "_sync_get_pin_refresh_data", return_value=(None, "No duty", None)
|
||||||
|
):
|
||||||
await mod.update_group_pin(context)
|
await mod.update_group_pin(context)
|
||||||
context.bot.send_message.assert_not_called()
|
context.bot.send_message.assert_not_called()
|
||||||
|
|
||||||
@@ -185,13 +185,11 @@ async def test_update_group_pin_send_raises_no_unpin_pin_schedule_still_called()
|
|||||||
context.bot.pin_chat_message = AsyncMock()
|
context.bot.pin_chat_message = AsyncMock()
|
||||||
context.application = MagicMock()
|
context.application = MagicMock()
|
||||||
|
|
||||||
with patch.object(mod, "_sync_get_message_id", return_value=2):
|
with patch.object(
|
||||||
with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"):
|
mod, "_sync_get_pin_refresh_data", return_value=(2, "Text", None)
|
||||||
with patch.object(mod, "_get_next_shift_end_sync", return_value=None):
|
):
|
||||||
with patch.object(
|
with patch.object(mod, "_schedule_next_update", AsyncMock()) as mock_schedule:
|
||||||
mod, "_schedule_next_update", AsyncMock()
|
await mod.update_group_pin(context)
|
||||||
) as mock_schedule:
|
|
||||||
await mod.update_group_pin(context)
|
|
||||||
context.bot.unpin_chat_message.assert_not_called()
|
context.bot.unpin_chat_message.assert_not_called()
|
||||||
context.bot.pin_chat_message.assert_not_called()
|
context.bot.pin_chat_message.assert_not_called()
|
||||||
mock_schedule.assert_called_once_with(context.application, 111, None)
|
mock_schedule.assert_called_once_with(context.application, 111, None)
|
||||||
@@ -214,15 +212,15 @@ async def test_update_group_pin_repin_raises_still_schedules_next():
|
|||||||
context.application = MagicMock()
|
context.application = MagicMock()
|
||||||
|
|
||||||
with patch.object(config, "DUTY_PIN_NOTIFY", True):
|
with patch.object(config, "DUTY_PIN_NOTIFY", True):
|
||||||
with patch.object(mod, "_sync_get_message_id", return_value=3):
|
with patch.object(
|
||||||
with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"):
|
mod, "_sync_get_pin_refresh_data", return_value=(3, "Text", None)
|
||||||
with patch.object(mod, "_get_next_shift_end_sync", return_value=None):
|
):
|
||||||
with patch.object(
|
with patch.object(
|
||||||
mod, "_schedule_next_update", AsyncMock()
|
mod, "_schedule_next_update", AsyncMock()
|
||||||
) as mock_schedule:
|
) as mock_schedule:
|
||||||
with patch.object(mod, "_sync_save_pin") as mock_save:
|
with patch.object(mod, "_sync_save_pin") as mock_save:
|
||||||
with patch.object(mod, "logger") as mock_logger:
|
with patch.object(mod, "logger") as mock_logger:
|
||||||
await mod.update_group_pin(context)
|
await mod.update_group_pin(context)
|
||||||
context.bot.send_message.assert_called_once_with(chat_id=222, text="Text")
|
context.bot.send_message.assert_called_once_with(chat_id=222, text="Text")
|
||||||
mock_save.assert_not_called()
|
mock_save.assert_not_called()
|
||||||
mock_logger.warning.assert_called_once()
|
mock_logger.warning.assert_called_once()
|
||||||
@@ -245,14 +243,14 @@ async def test_update_group_pin_duty_pin_notify_false_pins_silent():
|
|||||||
context.application = MagicMock()
|
context.application = MagicMock()
|
||||||
|
|
||||||
with patch.object(config, "DUTY_PIN_NOTIFY", False):
|
with patch.object(config, "DUTY_PIN_NOTIFY", False):
|
||||||
with patch.object(mod, "_sync_get_message_id", return_value=4):
|
with patch.object(
|
||||||
with patch.object(mod, "_get_duty_message_text_sync", return_value="Text"):
|
mod, "_sync_get_pin_refresh_data", return_value=(4, "Text", None)
|
||||||
with patch.object(mod, "_get_next_shift_end_sync", return_value=None):
|
):
|
||||||
with patch.object(
|
with patch.object(
|
||||||
mod, "_schedule_next_update", AsyncMock()
|
mod, "_schedule_next_update", AsyncMock()
|
||||||
) as mock_schedule:
|
) as mock_schedule:
|
||||||
with patch.object(mod, "_sync_save_pin") as mock_save:
|
with patch.object(mod, "_sync_save_pin") as mock_save:
|
||||||
await mod.update_group_pin(context)
|
await mod.update_group_pin(context)
|
||||||
context.bot.send_message.assert_called_once_with(chat_id=333, text="Text")
|
context.bot.send_message.assert_called_once_with(chat_id=333, text="Text")
|
||||||
context.bot.unpin_chat_message.assert_called_once_with(chat_id=333)
|
context.bot.unpin_chat_message.assert_called_once_with(chat_id=333)
|
||||||
context.bot.pin_chat_message.assert_called_once_with(
|
context.bot.pin_chat_message.assert_called_once_with(
|
||||||
|
|||||||
Reference in New Issue
Block a user