Files
duty-teller/duty_teller/db/repository.py
Nikolay Tatarinov aa89494bd5
All checks were successful
CI / lint-and-test (push) Successful in 22s
feat: enhance calendar ICS generation with event type filtering
- Added support for filtering calendar events by type in the ICS generation API endpoint, allowing users to specify whether to include only duty shifts or all event types (duty, unavailable, vacation).
- Updated the `get_duties_for_user` function to accept an optional `event_types` parameter, enabling more flexible data retrieval based on user preferences.
- Enhanced unit tests to cover the new event type filtering functionality, ensuring correct behavior and reliability of the ICS generation process.
2026-02-20 17:47:52 +03:00

497 lines
14 KiB
Python

"""Repository: get_or_create_user, get_duties, insert_duty, get_current_duty, group_duty_pins."""
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
import duty_teller.config as config
from duty_teller.db.models import User, Duty, GroupDutyPin, CalendarSubscriptionToken
def get_user_by_telegram_id(session: Session, telegram_user_id: int) -> User | None:
    """Look up a user by Telegram user ID; never creates a row.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id matched against User.telegram_user_id.

    Returns:
        The first matching User, or None when no row matches.
    """
    matches = session.query(User).filter(User.telegram_user_id == telegram_user_id)
    return matches.first()
def is_admin_for_telegram_user(session: Session, telegram_user_id: int) -> bool:
    """Tell whether a Telegram user has admin rights.

    A user is admin when config.is_admin accepts their username, or else
    when config.is_admin_by_phone accepts their stored phone. Unknown users
    are never admin.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.

    Returns:
        False for unknown users, otherwise the result of the config checks.
    """
    found = get_user_by_telegram_id(session, telegram_user_id)
    if not found:
        return False
    # A missing username is checked as "" so config.is_admin always gets a str.
    by_username = config.is_admin(found.username or "")
    return by_username or config.is_admin_by_phone(found.phone)
def get_or_create_user(
    session: Session,
    telegram_user_id: int,
    full_name: str,
    username: str | None = None,
    first_name: str | None = None,
    last_name: str | None = None,
) -> User:
    """Fetch the user for a Telegram ID, creating it when absent.

    New users take all name fields from Telegram and start with
    name_manually_edited=False. For existing users the username is always
    overwritten, while full_name/first_name/last_name are only refreshed
    when name_manually_edited is False, so manual renames survive.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        full_name: Display full name.
        username: Telegram username (optional).
        first_name: Telegram first name (optional).
        last_name: Telegram last name (optional).

    Returns:
        The created or updated User, refreshed after commit.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        user = User(
            telegram_user_id=telegram_user_id,
            full_name=full_name,
            username=username,
            first_name=first_name,
            last_name=last_name,
            name_manually_edited=False,
        )
        session.add(user)
    else:
        user.username = username
        if not user.name_manually_edited:
            user.full_name = full_name
            user.first_name = first_name
            user.last_name = last_name
    # Both branches persist and reload the row the same way.
    session.commit()
    session.refresh(user)
    return user
def get_or_create_user_by_full_name(session: Session, full_name: str) -> User:
    """Return the user whose full_name matches exactly, creating one if none does.

    Intended for duty-schedule import. Created users have no Telegram
    identity (telegram_user_id=None, no username/first/last name) and are
    flagged name_manually_edited=True.

    Args:
        session: DB session.
        full_name: Exact full name to match or assign.

    Returns:
        The existing or newly created User.
    """
    match = session.query(User).filter(User.full_name == full_name).first()
    if match:
        return match
    # Flagging the name as manually edited means that, if this row is later
    # matched by get_or_create_user, its imported name is kept.
    created = User(
        telegram_user_id=None,
        full_name=full_name,
        username=None,
        first_name=None,
        last_name=None,
        name_manually_edited=True,
    )
    session.add(created)
    session.commit()
    session.refresh(created)
    return created
def update_user_display_name(
    session: Session,
    telegram_user_id: int,
    full_name: str,
    first_name: str | None = None,
    last_name: str | None = None,
) -> User | None:
    """Manually rename a user and lock the name against Telegram syncs.

    Sets name_manually_edited=True, so later get_or_create_user calls leave
    full_name/first_name/last_name untouched.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        full_name: New full name.
        first_name: New first name (optional).
        last_name: New last name (optional).

    Returns:
        The updated User, or None when no user has that Telegram id.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        return None
    user.full_name, user.first_name, user.last_name = full_name, first_name, last_name
    user.name_manually_edited = True
    session.commit()
    session.refresh(user)
    return user
def delete_duties_in_range(
    session: Session,
    user_id: int,
    from_date: str,
    to_date: str,
) -> int:
    """Delete all duties of the user that overlap the given date range.

    A duty overlaps when it starts before the day after ``to_date`` and ends
    on or after ``from_date``. The bounds are compared as strings against the
    stored ISO 8601 start_at/end_at values.

    Args:
        session: DB session.
        user_id: User id.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD (inclusive).

    Returns:
        Number of duties deleted, as reported by the bulk DELETE.

    Raises:
        ValueError: If ``to_date`` is not a valid YYYY-MM-DD date.
    """
    # Exclusive upper bound: midnight of the day after to_date.
    day_after_to = (
        datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)
    ).strftime("%Y-%m-%d")
    # Query.delete() returns the matched row count, so a separate count()
    # (an extra statement that could disagree with the delete) is unnecessary.
    deleted = (
        session.query(Duty)
        .filter(
            Duty.user_id == user_id,
            Duty.start_at < day_after_to,
            Duty.end_at >= from_date,
        )
        .delete(synchronize_session=False)
    )
    session.commit()
    return deleted
def get_duties(
    session: Session,
    from_date: str,
    to_date: str,
) -> list[tuple[Duty, str]]:
    """List every duty overlapping a date range, paired with its user's full name.

    Overlap means: starts before the day after ``to_date`` and ends on or
    after ``from_date`` (string comparison against stored timestamps).

    Args:
        session: DB session.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD (inclusive).

    Returns:
        List of (Duty, full_name) rows.
    """
    next_day = datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)
    upper_bound = next_day.strftime("%Y-%m-%d")
    query = session.query(Duty, User.full_name).join(User, Duty.user_id == User.id)
    query = query.filter(Duty.start_at < upper_bound, Duty.end_at >= from_date)
    return list(query.all())
def get_duties_for_user(
    session: Session,
    user_id: int,
    from_date: str,
    to_date: str,
    event_types: list[str] | None = None,
) -> list[tuple[Duty, str]]:
    """List one user's duties overlapping a date range, with their full name.

    When ``event_types`` is given (e.g. ["duty"] or
    ["duty", "unavailable", "vacation"]), only rows whose event_type is in it
    are returned; None means no event_type filtering.

    Args:
        session: DB session.
        user_id: User id.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD (inclusive).
        event_types: Optional allow-list of event_type values.

    Returns:
        List of (Duty, full_name) rows.
    """
    next_day = datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)
    upper_bound = next_day.strftime("%Y-%m-%d")
    criteria = (
        Duty.user_id == user_id,
        Duty.start_at < upper_bound,
        Duty.end_at >= from_date,
    )
    if event_types is not None:
        criteria += (Duty.event_type.in_(event_types),)
    rows = (
        session.query(Duty, User.full_name)
        .join(User, Duty.user_id == User.id)
        .filter(*criteria)
        .all()
    )
    return list(rows)
def _token_hash(token: str) -> str:
"""Return SHA256 hex digest of the token (constant-time comparison via hmac)."""
return hashlib.sha256(token.encode()).hexdigest()
def create_calendar_token(session: Session, user_id: int) -> str:
    """Issue a fresh calendar subscription token for the user.

    All of the user's existing tokens are deleted first. Only the SHA-256
    digest is stored, so the returned raw token cannot be recovered later.

    Args:
        session: DB session.
        user_id: User id.

    Returns:
        Raw token string (e.g. for URL /api/calendar/ical/{token}.ics).
    """
    stale = session.query(CalendarSubscriptionToken).filter(
        CalendarSubscriptionToken.user_id == user_id
    )
    stale.delete(synchronize_session=False)
    token = secrets.token_urlsafe(32)
    session.add(
        CalendarSubscriptionToken(
            user_id=user_id,
            token_hash=_token_hash(token),
            created_at=datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        )
    )
    session.commit()
    return token
def get_user_by_calendar_token(session: Session, token: str) -> User | None:
    """Find the user owning a calendar subscription token.

    The raw token is hashed with _token_hash and matched by equality against
    the stored CalendarSubscriptionToken.token_hash in the database; the
    plain token is never stored or compared directly.

    Args:
        session: DB session.
        token: Raw token from URL.

    Returns:
        User, or None if the token is empty/None or matches no stored hash.
    """
    if not token:
        # Tokens issued by create_calendar_token come from
        # secrets.token_urlsafe(32) and are non-empty, so an empty token cannot
        # match; a None token would otherwise crash inside _token_hash.
        return None
    token_hash_val = _token_hash(token)
    # Select only the User entity; the token row itself is not needed.
    return (
        session.query(User)
        .join(CalendarSubscriptionToken, CalendarSubscriptionToken.user_id == User.id)
        .filter(CalendarSubscriptionToken.token_hash == token_hash_val)
        .first()
    )
def insert_duty(
    session: Session,
    user_id: int,
    start_at: str,
    end_at: str,
    event_type: str = "duty",
) -> Duty:
    """Persist a new duty row and return it refreshed from the DB.

    Args:
        session: DB session.
        user_id: User id.
        start_at: Start time UTC, ISO 8601 with Z (e.g. 2025-01-15T09:00:00Z).
        end_at: End time UTC, ISO 8601 with Z.
        event_type: One of "duty", "unavailable", "vacation". Default "duty".

    Returns:
        The created Duty.
    """
    record = Duty(user_id=user_id, start_at=start_at, end_at=end_at, event_type=event_type)
    session.add(record)
    session.commit()
    session.refresh(record)
    return record
def get_current_duty(session: Session, at_utc: datetime) -> tuple[Duty, User] | None:
    """Return the duty (event_type='duty') and its user active at a UTC instant.

    A duty is active when start_at <= instant < end_at, compared as ISO 8601
    "YYYY-MM-DDTHH:MM:SSZ" strings.

    Args:
        session: DB session.
        at_utc: Point in time. Aware values are converted to UTC; naive
            values are assumed to already be UTC.

    Returns:
        (Duty, User), or None if no duty covers that time.
    """
    # Uses the module-level `timezone` import; the former function-local
    # re-import was redundant.
    if at_utc.tzinfo is not None:
        at_utc = at_utc.astimezone(timezone.utc)
    now_iso = at_utc.strftime("%Y-%m-%dT%H:%M:%SZ")
    row = (
        session.query(Duty, User)
        .join(User, Duty.user_id == User.id)
        .filter(
            Duty.event_type == "duty",
            Duty.start_at <= now_iso,
            Duty.end_at > now_iso,
        )
        .first()
    )
    if row is None:
        return None
    duty, user = row
    return duty, user
def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None:
    """Return when the current or next duty (event_type='duty') ends.

    A duty in progress at ``after_utc`` (start_at <= t < end_at) wins.
    Otherwise the earliest duty starting after ``after_utc`` is used.

    Args:
        session: DB session.
        after_utc: Point in time. Aware values are converted to UTC; naive
            values are assumed to already be UTC.

    Returns:
        End datetime as naive UTC, or None if there is no current or future duty.
    """
    if after_utc.tzinfo is not None:
        after_utc = after_utc.astimezone(timezone.utc)
    after_iso = after_utc.strftime("%Y-%m-%dT%H:%M:%SZ")
    duty = (
        session.query(Duty)
        .filter(
            Duty.event_type == "duty",
            Duty.start_at <= after_iso,
            Duty.end_at > after_iso,
        )
        .first()
    )
    if duty is None:
        # No shift in progress: fall back to the earliest one that starts later.
        duty = (
            session.query(Duty)
            .filter(Duty.event_type == "duty", Duty.start_at > after_iso)
            .order_by(Duty.start_at)
            .first()
        )
    if duty is None:
        return None
    # Replace the "Z" suffix with an explicit offset so fromisoformat accepts it
    # on Python < 3.11 as well.
    end = datetime.fromisoformat(duty.end_at.replace("Z", "+00:00"))
    if end.tzinfo is not None:
        # Convert before dropping tzinfo so a non-UTC offset is not silently
        # reinterpreted as UTC.
        end = end.astimezone(timezone.utc)
    return end.replace(tzinfo=None)
def get_group_duty_pin(session: Session, chat_id: int) -> GroupDutyPin | None:
    """Fetch the stored pinned-duty-message record for a chat, if any.

    Args:
        session: DB session.
        chat_id: Telegram chat id.

    Returns:
        The first GroupDutyPin for that chat, or None.
    """
    pins = session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id)
    return pins.first()
def save_group_duty_pin(
    session: Session, chat_id: int, message_id: int
) -> GroupDutyPin:
    """Upsert the pinned duty message id for a chat.

    Args:
        session: DB session.
        chat_id: Telegram chat id.
        message_id: Message id to pin/update.

    Returns:
        The created or updated GroupDutyPin, refreshed after commit.
    """
    pin = get_group_duty_pin(session, chat_id)
    if not pin:
        pin = GroupDutyPin(chat_id=chat_id, message_id=message_id)
        session.add(pin)
    else:
        pin.message_id = message_id
    session.commit()
    session.refresh(pin)
    return pin
def delete_group_duty_pin(session: Session, chat_id: int) -> None:
    """Drop any pinned-duty-message record for the chat and commit.

    Used e.g. when the bot leaves the group.

    Args:
        session: DB session.
        chat_id: Telegram chat id.
    """
    matching = session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id)
    matching.delete()
    session.commit()
def get_all_group_duty_pin_chat_ids(session: Session) -> list[int]:
    """Collect the chat ids of every stored pinned duty message.

    Used to restore update jobs on bot startup.

    Args:
        session: DB session.

    Returns:
        List of chat ids.
    """
    return [chat_id for (chat_id,) in session.query(GroupDutyPin.chat_id).all()]
def set_user_phone(
    session: Session, telegram_user_id: int, phone: str | None
) -> User | None:
    """Store a normalized phone for a user, or clear it.

    None or a whitespace-only string clears the phone; any other value is
    passed through config.normalize_phone.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        phone: Phone string, or None/blank to clear.

    Returns:
        The updated User, or None when no user has that Telegram id.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        return None
    blank = phone is None or (isinstance(phone, str) and phone.strip() == "")
    user.phone = None if blank else config.normalize_phone(phone)
    session.commit()
    session.refresh(user)
    return user