All checks were successful
CI / lint-and-test (push) Successful in 22s
- Added support for filtering calendar events by type in the ICS generation API endpoint, allowing users to specify whether to include only duty shifts or all event types (duty, unavailable, vacation). - Updated the `get_duties_for_user` function to accept an optional `event_types` parameter, enabling more flexible data retrieval based on user preferences. - Enhanced unit tests to cover the new event type filtering functionality, ensuring correct behavior and reliability of the ICS generation process.
497 lines
14 KiB
Python
497 lines
14 KiB
Python
"""Repository: get_or_create_user, get_duties, insert_duty, get_current_duty, group_duty_pins."""
|
|
|
|
import hashlib
|
|
import secrets
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
import duty_teller.config as config
|
|
from duty_teller.db.models import User, Duty, GroupDutyPin, CalendarSubscriptionToken
|
|
|
|
|
|
def get_user_by_telegram_id(session: Session, telegram_user_id: int) -> User | None:
    """Look up a user by Telegram user ID; never creates a row.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id to match against User.telegram_user_id.

    Returns:
        The first matching User, or None when no row matches.
    """
    matching = session.query(User).filter(User.telegram_user_id == telegram_user_id)
    return matching.first()
|
|
|
|
|
|
def is_admin_for_telegram_user(session: Session, telegram_user_id: int) -> bool:
    """Tell whether a Telegram user has admin rights.

    A user counts as admin when config recognises either their username or
    their stored phone number. Unknown users are never admins.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.

    Returns:
        True if config.is_admin accepts the username (missing username is
        passed as "") or config.is_admin_by_phone accepts the stored phone;
        False if the user does not exist.
    """
    found = get_user_by_telegram_id(session, telegram_user_id)
    if not found:
        return False
    by_username = config.is_admin(found.username or "")
    return by_username or config.is_admin_by_phone(found.phone)
|
|
|
|
|
|
def get_or_create_user(
    session: Session,
    telegram_user_id: int,
    full_name: str,
    username: str | None = None,
    first_name: str | None = None,
    last_name: str | None = None,
) -> User:
    """Fetch the user for a Telegram ID, creating or refreshing it as needed.

    A new user takes all name fields from Telegram and starts with
    name_manually_edited=False. For an existing user, username is always
    overwritten. full_name, first_name and last_name are overwritten only
    while name_manually_edited is False, so a manually set display name is
    preserved.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        full_name: Display full name.
        username: Telegram username (optional).
        first_name: Telegram first name (optional).
        last_name: Telegram last name (optional).

    Returns:
        The created or updated User, committed and refreshed.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        user = User(
            telegram_user_id=telegram_user_id,
            full_name=full_name,
            username=username,
            first_name=first_name,
            last_name=last_name,
            name_manually_edited=False,
        )
        session.add(user)
    else:
        # Username always mirrors Telegram.
        user.username = username
        # Names edited via API/admin must not be clobbered by Telegram data.
        if not user.name_manually_edited:
            user.full_name = full_name
            user.first_name = first_name
            user.last_name = last_name
    session.commit()
    session.refresh(user)
    return user
|
|
|
|
|
|
def get_or_create_user_by_full_name(session: Session, full_name: str) -> User:
    """Return the user with exactly this full_name, creating one if absent.

    Intended for duty-schedule import, where only a name is known. A newly
    created user has no Telegram identity (telegram_user_id, username,
    first_name, last_name all None) and name_manually_edited=True.

    Args:
        session: DB session.
        full_name: Exact full name to match or set.

    Returns:
        The first existing User with that full_name, or the newly created one.
    """
    user = session.query(User).filter(User.full_name == full_name).first()
    if not user:
        user = User(
            telegram_user_id=None,
            full_name=full_name,
            username=None,
            first_name=None,
            last_name=None,
            name_manually_edited=True,
        )
        session.add(user)
        session.commit()
        session.refresh(user)
    return user
|
|
|
|
|
|
def update_user_display_name(
    session: Session,
    telegram_user_id: int,
    full_name: str,
    first_name: str | None = None,
    last_name: str | None = None,
) -> User | None:
    """Update display name and set name_manually_edited=True.

    Use from API or admin when the name is changed manually. Once the flag is
    set, later get_or_create_user calls stop overwriting full_name,
    first_name and last_name.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        full_name: New full name.
        first_name: New first name (optional; None clears it).
        last_name: New last name (optional; None clears it).

    Returns:
        Updated User, or None if no user has that Telegram id.
    """
    # Reuse the shared lookup instead of duplicating the query.
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        return None
    user.full_name = full_name
    user.first_name = first_name
    user.last_name = last_name
    user.name_manually_edited = True
    session.commit()
    session.refresh(user)
    return user
|
|
|
|
|
|
def delete_duties_in_range(
    session: Session,
    user_id: int,
    from_date: str,
    to_date: str,
) -> int:
    """Delete all duties of the user that overlap the given date range.

    The range covers whole days, with both ends inclusive. A duty is deleted
    when start_at < (to_date + 1 day) and end_at >= from_date. Bounds are
    compared as strings against the stored start_at/end_at values. This
    assumes those are ISO 8601 UTC strings, which order lexicographically.

    Args:
        session: DB session.
        user_id: User id.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD.

    Returns:
        Number of duties deleted, as reported by the database.

    Raises:
        ValueError: If to_date is not a parseable ISO date.
    """
    to_next = (
        datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)
    ).strftime("%Y-%m-%d")
    # Query.delete returns the affected row count. Using it avoids a separate
    # COUNT query that could disagree with what was actually deleted.
    deleted = (
        session.query(Duty)
        .filter(
            Duty.user_id == user_id,
            Duty.start_at < to_next,
            Duty.end_at >= from_date,
        )
        .delete(synchronize_session=False)
    )
    session.commit()
    return deleted
|
|
|
|
|
|
def get_duties(
    session: Session,
    from_date: str,
    to_date: str,
) -> list[tuple[Duty, str]]:
    """List every duty overlapping a whole-day date range, with owner's name.

    Overlap means start_at is before the day after to_date and end_at is on
    or after from_date (string comparison on stored values). All event types
    are included.

    Args:
        session: DB session.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD (inclusive).

    Returns:
        List of (Duty, full_name) tuples.
    """
    day_after_end = datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)
    upper_bound = day_after_end.strftime("%Y-%m-%d")
    rows = (
        session.query(Duty, User.full_name)
        .join(User, Duty.user_id == User.id)
        .filter(Duty.start_at < upper_bound)
        .filter(Duty.end_at >= from_date)
        .all()
    )
    return list(rows)
|
|
|
|
|
|
def get_duties_for_user(
    session: Session,
    user_id: int,
    from_date: str,
    to_date: str,
    event_types: list[str] | None = None,
) -> list[tuple[Duty, str]]:
    """List one user's duties that overlap a whole-day date range.

    Results can be narrowed by event_type (e.g. "duty", "unavailable",
    "vacation"). Passing event_types=None applies no type filter.

    Args:
        session: DB session.
        user_id: User id.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD (inclusive).
        event_types: If not None, keep only duties whose event_type is listed.

    Returns:
        List of (Duty, full_name) tuples.
    """
    upper_bound = (
        datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)
    ).strftime("%Y-%m-%d")
    query = (
        session.query(Duty, User.full_name)
        .join(User, Duty.user_id == User.id)
        .filter(
            Duty.user_id == user_id,
            Duty.start_at < upper_bound,
            Duty.end_at >= from_date,
        )
    )
    if event_types is not None:
        query = query.filter(Duty.event_type.in_(event_types))
    return list(query.all())
|
|
|
|
|
|
def _token_hash(token: str) -> str:
    """Return the SHA-256 hex digest of a raw calendar token (UTF-8 encoded).

    Only this digest is stored for calendar subscription tokens, never the
    raw token. Lookups match digests by equality in the database. This
    function does not perform any hmac or constant-time comparison itself.
    """
    return hashlib.sha256(token.encode()).hexdigest()
|
|
|
|
|
|
def create_calendar_token(session: Session, user_id: int) -> str:
    """Issue a fresh calendar subscription token for the user.

    All of the user's previous tokens are deleted first, so only the new one
    stays valid. Only the SHA-256 digest is persisted. The raw token is
    returned once and cannot be recovered later.

    Args:
        session: DB session.
        user_id: User id.

    Returns:
        Raw token string (e.g. for URL /api/calendar/ical/{token}.ics).
    """
    stale = session.query(CalendarSubscriptionToken).filter(
        CalendarSubscriptionToken.user_id == user_id
    )
    stale.delete(synchronize_session=False)

    raw = secrets.token_urlsafe(32)
    digest = _token_hash(raw)
    issued_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    session.add(
        CalendarSubscriptionToken(
            user_id=user_id,
            token_hash=digest,
            created_at=issued_at,
        )
    )
    session.commit()
    return raw
|
|
|
|
|
|
def get_user_by_calendar_token(session: Session, token: str) -> User | None:
    """Find the user owning a calendar subscription token.

    The raw token is hashed with _token_hash, and the digest is matched by
    SQL equality against the stored token_hash. The raw token itself is
    never compared or stored. No explicit constant-time comparison is
    performed here.

    Args:
        session: DB session.
        token: Raw token from URL.

    Returns:
        User or None if token is invalid or not found.
    """
    digest = _token_hash(token)
    # Select only the User entity; the token row itself is not needed.
    return (
        session.query(User)
        .join(CalendarSubscriptionToken, CalendarSubscriptionToken.user_id == User.id)
        .filter(CalendarSubscriptionToken.token_hash == digest)
        .first()
    )
|
|
|
|
|
|
def insert_duty(
    session: Session,
    user_id: int,
    start_at: str,
    end_at: str,
    event_type: str = "duty",
) -> Duty:
    """Persist a new duty record and return it refreshed from the DB.

    Args:
        session: DB session.
        user_id: User id.
        start_at: Start time UTC, ISO 8601 with Z (e.g. 2025-01-15T09:00:00Z).
        end_at: End time UTC, ISO 8601 with Z.
        event_type: One of "duty", "unavailable", "vacation". Default "duty".
            Not validated here.

    Returns:
        Created Duty instance.
    """
    record = Duty(user_id=user_id, start_at=start_at, end_at=end_at, event_type=event_type)
    session.add(record)
    session.commit()
    session.refresh(record)
    return record
|
|
|
|
|
|
def get_current_duty(session: Session, at_utc: datetime) -> tuple[Duty, User] | None:
    """Return the duty and user active at the given UTC time (event_type='duty').

    A duty is active when start_at <= at_utc < end_at. Seconds are kept and
    microseconds are dropped. If several duties are active at once, which one
    is returned is unspecified (the query has no ORDER BY).

    Args:
        session: DB session.
        at_utc: Point in time. Aware values are converted to UTC; naive
            values are assumed to already be UTC.

    Returns:
        (Duty, User) or None if no duty at that time.
    """
    # `timezone` comes from the module-level datetime import; no local re-import needed.
    if at_utc.tzinfo is not None:
        at_utc = at_utc.astimezone(timezone.utc)
    # Same "YYYY-MM-DDTHH:MM:SSZ" shape insert_duty documents for stored
    # values, so the SQL string comparisons below order chronologically.
    now_iso = at_utc.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
    row = (
        session.query(Duty, User)
        .join(User, Duty.user_id == User.id)
        .filter(
            Duty.event_type == "duty",
            Duty.start_at <= now_iso,
            Duty.end_at > now_iso,
        )
        .first()
    )
    if row is None:
        return None
    duty, user = row
    return duty, user
|
|
|
|
|
|
def _iso_to_naive_utc(value: str) -> datetime:
    """Parse a stored ISO 8601 timestamp into a naive UTC datetime.

    A trailing "Z" is rewritten to "+00:00" because datetime.fromisoformat
    only accepts "Z" natively from Python 3.11. Values with an explicit
    offset are converted to UTC before tzinfo is dropped. Offset-less
    values are returned unchanged, assumed to already be UTC.
    """
    parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(timezone.utc)
    return parsed.replace(tzinfo=None)


def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None:
    """Return the end_at of the current or next duty (event_type='duty').

    If a duty is in progress at after_utc (start_at <= after_utc < end_at),
    its end is returned. If several are in progress, which one is used is
    unspecified (no ORDER BY). Otherwise the end of the earliest-starting
    future duty is returned.

    Args:
        session: DB session.
        after_utc: Point in time. Aware values are converted to UTC; naive
            values are assumed to already be UTC.

    Returns:
        End datetime (naive UTC) or None if no current or future duty.
    """
    # `timezone` comes from the module-level datetime import; no local re-import needed.
    if after_utc.tzinfo is not None:
        after_utc = after_utc.astimezone(timezone.utc)
    after_iso = after_utc.strftime("%Y-%m-%dT%H:%M:%S") + "Z"

    current = (
        session.query(Duty)
        .filter(
            Duty.event_type == "duty",
            Duty.start_at <= after_iso,
            Duty.end_at > after_iso,
        )
        .first()
    )
    if current:
        return _iso_to_naive_utc(current.end_at)

    next_duty = (
        session.query(Duty)
        .filter(Duty.event_type == "duty", Duty.start_at > after_iso)
        .order_by(Duty.start_at)
        .first()
    )
    if next_duty:
        return _iso_to_naive_utc(next_duty.end_at)
    return None
|
|
|
|
|
|
def get_group_duty_pin(session: Session, chat_id: int) -> GroupDutyPin | None:
    """Fetch the stored pinned-duty-message record for a chat, if any.

    Args:
        session: DB session.
        chat_id: Telegram chat id.

    Returns:
        The first GroupDutyPin for the chat, or None.
    """
    pins = session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id)
    return pins.first()
|
|
|
|
|
|
def save_group_duty_pin(
    session: Session, chat_id: int, message_id: int
) -> GroupDutyPin:
    """Upsert the pinned duty message id for a chat.

    Args:
        session: DB session.
        chat_id: Telegram chat id.
        message_id: Message id to pin/update.

    Returns:
        The created or updated GroupDutyPin, committed and refreshed.
    """
    pin = get_group_duty_pin(session, chat_id)
    if not pin:
        pin = GroupDutyPin(chat_id=chat_id, message_id=message_id)
        session.add(pin)
    else:
        pin.message_id = message_id
    session.commit()
    session.refresh(pin)
    return pin
|
|
|
|
|
|
def delete_group_duty_pin(session: Session, chat_id: int) -> None:
    """Drop the pinned duty message record(s) for a chat and commit.

    Typically used when the bot leaves the group. Deleting when no record
    exists is a no-op apart from the commit.

    Args:
        session: DB session.
        chat_id: Telegram chat id.
    """
    matching = session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id)
    matching.delete()
    session.commit()
|
|
|
|
|
|
def get_all_group_duty_pin_chat_ids(session: Session) -> list[int]:
    """List chat ids of every chat with a stored pinned duty message.

    Used on bot startup to restore the per-chat update jobs.

    Args:
        session: DB session.

    Returns:
        List of chat ids, one per GroupDutyPin row.
    """
    return [chat_id for (chat_id,) in session.query(GroupDutyPin.chat_id).all()]
|
|
|
|
|
|
def set_user_phone(
    session: Session, telegram_user_id: int, phone: str | None
) -> User | None:
    """Set or clear phone for user by Telegram user id.

    None or a blank/whitespace-only string clears the phone. Any other value
    is stored after config.normalize_phone.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        phone: Phone string, or None/blank to clear.

    Returns:
        Updated User or None if not found.
    """
    # Reuse the shared lookup instead of duplicating the query.
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        return None
    if phone is None or (isinstance(phone, str) and not phone.strip()):
        user.phone = None
    else:
        user.phone = config.normalize_phone(phone)
    session.commit()
    session.refresh(user)
    return user
|