"""Repository helpers for users, duties, calendar tokens and group duty pins.

Duty timestamps are written as UTC ISO 8601 strings with a trailing ``Z``
(e.g. ``2025-01-15T09:00:00Z``) and the SQL filters below compare them as
strings against values built in the same format, so every timestamp produced
here must use exactly that layout.
"""

import hashlib
import secrets
from datetime import datetime, timedelta, timezone

from sqlalchemy.orm import Session

import duty_teller.config as config
from duty_teller.db.models import User, Duty, GroupDutyPin, CalendarSubscriptionToken

# strftime layout of stored UTC timestamps ("Z" is emitted literally).
_UTC_ISO_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
# strftime layout of plain calendar dates used as range bounds.
_DATE_FORMAT = "%Y-%m-%d"


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _next_day(date_str: str) -> str:
    """Return the calendar day after ``date_str`` (both ``YYYY-MM-DD``).

    Used as an exclusive upper bound so that a range ending on ``date_str``
    covers that whole day: ``start_at < next_day`` matches any time on it.
    """
    day = datetime.fromisoformat(date_str + "T00:00:00")
    return (day + timedelta(days=1)).strftime(_DATE_FORMAT)


def _to_utc_iso(moment: datetime) -> str:
    """Format ``moment`` as ``YYYY-MM-DDTHH:MM:SSZ`` in UTC.

    Timezone-aware datetimes are converted to UTC first; naive datetimes are
    assumed to already be in UTC.
    """
    if moment.tzinfo is not None:
        moment = moment.astimezone(timezone.utc)
    return moment.strftime(_UTC_ISO_FORMAT)


def _parse_utc_naive(value: str) -> datetime:
    """Parse a stored ISO 8601 timestamp into a naive UTC datetime.

    A trailing ``Z`` is rewritten to ``+00:00`` because
    ``datetime.fromisoformat`` only accepts ``Z`` from Python 3.11 on. A value
    carrying some other offset is converted to UTC before its tzinfo is
    dropped, rather than having the offset silently discarded.
    """
    parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(timezone.utc)
    return parsed.replace(tzinfo=None)


def _duties_with_names_query(session: Session, from_date: str, to_date: str):
    """Build a query of ``(Duty, User.full_name)`` overlapping a date range.

    A duty overlaps ``[from_date, to_date]`` when it starts before the day
    after ``to_date`` and ends on or after ``from_date``.
    """
    return (
        session.query(Duty, User.full_name)
        .join(User, Duty.user_id == User.id)
        .filter(Duty.start_at < _next_day(to_date), Duty.end_at >= from_date)
    )


def _token_hash(token: str) -> str:
    """Return the SHA-256 hex digest of a raw calendar token.

    Only this digest is persisted, so a leaked database does not expose usable
    subscription URLs.
    """
    return hashlib.sha256(token.encode()).hexdigest()


# ---------------------------------------------------------------------------
# Users
# ---------------------------------------------------------------------------


def get_user_by_telegram_id(session: Session, telegram_user_id: int) -> User | None:
    """Find user by Telegram user ID.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.

    Returns:
        User or None if not found. Does not create a user.
    """
    return (
        session.query(User).filter(User.telegram_user_id == telegram_user_id).first()
    )


def is_admin_for_telegram_user(session: Session, telegram_user_id: int) -> bool:
    """Check if the Telegram user is admin (by username or by stored phone).

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.

    Returns:
        False if the user is unknown; otherwise True when ``config.is_admin``
        accepts their username or ``config.is_admin_by_phone`` accepts their
        stored phone.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        return False
    return config.is_admin(user.username or "") or config.is_admin_by_phone(user.phone)


def get_or_create_user(
    session: Session,
    telegram_user_id: int,
    full_name: str,
    username: str | None = None,
    first_name: str | None = None,
    last_name: str | None = None,
) -> User:
    """Get or create user by Telegram user ID.

    On create, name fields come from Telegram. On update: username is always
    synced; full_name, first_name, last_name are updated only if
    name_manually_edited is False (otherwise the existing display name is
    kept). Commits the session in both cases.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        full_name: Display full name.
        username: Telegram username (optional).
        first_name: Telegram first name (optional).
        last_name: Telegram last name (optional).

    Returns:
        User instance (created or updated), refreshed from the DB.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if user is None:
        user = User(
            telegram_user_id=telegram_user_id,
            full_name=full_name,
            username=username,
            first_name=first_name,
            last_name=last_name,
            name_manually_edited=False,
        )
        session.add(user)
    else:
        user.username = username
        # A manually edited display name must survive Telegram profile syncs.
        if not user.name_manually_edited:
            user.full_name = full_name
            user.first_name = first_name
            user.last_name = last_name
    session.commit()
    session.refresh(user)
    return user


def get_or_create_user_by_full_name(session: Session, full_name: str) -> User:
    """Find user by exact full_name or create one (for duty-schedule import).

    New users have telegram_user_id=None and name_manually_edited=True.

    Args:
        session: DB session.
        full_name: Exact full name to match or set.

    Returns:
        User instance (existing or newly created).
    """
    user = session.query(User).filter(User.full_name == full_name).first()
    if user:
        return user
    user = User(
        telegram_user_id=None,
        full_name=full_name,
        username=None,
        first_name=None,
        last_name=None,
        # Imported names are authoritative; later Telegram syncs keep them.
        name_manually_edited=True,
    )
    session.add(user)
    session.commit()
    session.refresh(user)
    return user


def update_user_display_name(
    session: Session,
    telegram_user_id: int,
    full_name: str,
    first_name: str | None = None,
    last_name: str | None = None,
) -> User | None:
    """Update display name and set name_manually_edited=True.

    Use from API or admin when the name is changed manually; subsequent
    get_or_create_user calls will not overwrite these fields.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        full_name: New full name.
        first_name: New first name (optional).
        last_name: New last name (optional).

    Returns:
        Updated User or None if not found.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        return None
    user.full_name = full_name
    user.first_name = first_name
    user.last_name = last_name
    user.name_manually_edited = True
    session.commit()
    session.refresh(user)
    return user


def set_user_phone(
    session: Session, telegram_user_id: int, phone: str | None
) -> User | None:
    """Set or clear phone for user by Telegram user id.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        phone: Phone string, or None / blank string to clear. Non-blank values
            are stored as returned by ``config.normalize_phone``.

    Returns:
        Updated User or None if not found.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        return None
    if phone is None or (isinstance(phone, str) and not phone.strip()):
        user.phone = None
    else:
        user.phone = config.normalize_phone(phone)
    session.commit()
    session.refresh(user)
    return user


# ---------------------------------------------------------------------------
# Duties
# ---------------------------------------------------------------------------


def delete_duties_in_range(
    session: Session,
    user_id: int,
    from_date: str,
    to_date: str,
) -> int:
    """Delete all duties of the user that overlap the given date range.

    Args:
        session: DB session.
        user_id: User id.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD (inclusive).

    Returns:
        Number of duties deleted, as reported by the DELETE statement.
    """
    # Use the DELETE's own row count instead of a separate COUNT query, which
    # costs an extra round trip and may disagree with what is actually deleted.
    deleted = (
        session.query(Duty)
        .filter(
            Duty.user_id == user_id,
            Duty.start_at < _next_day(to_date),
            Duty.end_at >= from_date,
        )
        .delete(synchronize_session=False)
    )
    session.commit()
    return deleted


def get_duties(
    session: Session,
    from_date: str,
    to_date: str,
) -> list[tuple[Duty, str]]:
    """Return duties overlapping the given date range with user full_name.

    Args:
        session: DB session.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD (inclusive).

    Returns:
        List of (Duty, full_name) tuples.
    """
    return list(_duties_with_names_query(session, from_date, to_date).all())


def get_duties_for_user(
    session: Session,
    user_id: int,
    from_date: str,
    to_date: str,
) -> list[tuple[Duty, str]]:
    """Return duties for one user overlapping the date range.

    Args:
        session: DB session.
        user_id: User id.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD (inclusive).

    Returns:
        List of (Duty, full_name) tuples.
    """
    query = _duties_with_names_query(session, from_date, to_date).filter(
        Duty.user_id == user_id
    )
    return list(query.all())


def insert_duty(
    session: Session,
    user_id: int,
    start_at: str,
    end_at: str,
    event_type: str = "duty",
) -> Duty:
    """Create a duty record.

    Args:
        session: DB session.
        user_id: User id.
        start_at: Start time UTC, ISO 8601 with Z (e.g. 2025-01-15T09:00:00Z).
        end_at: End time UTC, ISO 8601 with Z.
        event_type: One of "duty", "unavailable", "vacation". Default "duty".

    Returns:
        Created Duty instance.
    """
    duty = Duty(
        user_id=user_id,
        start_at=start_at,
        end_at=end_at,
        event_type=event_type,
    )
    session.add(duty)
    session.commit()
    session.refresh(duty)
    return duty


def get_current_duty(session: Session, at_utc: datetime) -> tuple[Duty, User] | None:
    """Return the duty and user active at the given UTC time (event_type='duty').

    A duty is active when ``start_at <= at_utc < end_at``.

    Args:
        session: DB session.
        at_utc: Point in time (timezone-aware, or naive and assumed UTC).

    Returns:
        (Duty, User) or None if no duty at that time.
    """
    at_iso = _to_utc_iso(at_utc)
    row = (
        session.query(Duty, User)
        .join(User, Duty.user_id == User.id)
        .filter(
            Duty.event_type == "duty",
            Duty.start_at <= at_iso,
            Duty.end_at > at_iso,
        )
        .first()
    )
    if row is None:
        return None
    return (row[0], row[1])


def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None:
    """Return the end_at of the current or next duty (event_type='duty').

    The duty active at ``after_utc`` wins; otherwise the earliest duty starting
    after it is used.

    Args:
        session: DB session.
        after_utc: Point in time (timezone-aware, or naive and assumed UTC).

    Returns:
        End datetime (naive UTC) or None if no current or future duty.
    """
    after_iso = _to_utc_iso(after_utc)
    current = (
        session.query(Duty)
        .filter(
            Duty.event_type == "duty",
            Duty.start_at <= after_iso,
            Duty.end_at > after_iso,
        )
        .first()
    )
    if current:
        return _parse_utc_naive(current.end_at)
    next_duty = (
        session.query(Duty)
        .filter(Duty.event_type == "duty", Duty.start_at > after_iso)
        .order_by(Duty.start_at)
        .first()
    )
    if next_duty:
        return _parse_utc_naive(next_duty.end_at)
    return None


# ---------------------------------------------------------------------------
# Calendar subscription tokens
# ---------------------------------------------------------------------------


def create_calendar_token(session: Session, user_id: int) -> str:
    """Create a new calendar subscription token for the user.

    Any existing tokens for this user are removed. Only the SHA-256 digest is
    stored; the raw token is returned once and cannot be recovered later.

    Args:
        session: DB session.
        user_id: User id.

    Returns:
        Raw token string (e.g. for URL /api/calendar/ical/{token}.ics).
    """
    # One active token per user: issuing a new one revokes the old URL.
    session.query(CalendarSubscriptionToken).filter(
        CalendarSubscriptionToken.user_id == user_id
    ).delete(synchronize_session=False)
    raw_token = secrets.token_urlsafe(32)
    record = CalendarSubscriptionToken(
        user_id=user_id,
        token_hash=_token_hash(raw_token),
        created_at=_to_utc_iso(datetime.now(timezone.utc)),
    )
    session.add(record)
    session.commit()
    return raw_token


def get_user_by_calendar_token(session: Session, token: str) -> User | None:
    """Find user by calendar subscription token.

    The raw token is hashed and looked up by digest; the raw value is never
    compared directly. Because the digest of an attacker-supplied token is
    unpredictable, any timing differences in the DB equality lookup do not
    reveal prefixes of a valid token.

    Args:
        session: DB session.
        token: Raw token from URL.

    Returns:
        User or None if token is invalid or not found.
    """
    row = (
        session.query(CalendarSubscriptionToken, User)
        .join(User, CalendarSubscriptionToken.user_id == User.id)
        .filter(CalendarSubscriptionToken.token_hash == _token_hash(token))
        .first()
    )
    if row is None:
        return None
    return row[1]


# ---------------------------------------------------------------------------
# Group duty pins
# ---------------------------------------------------------------------------


def get_group_duty_pin(session: Session, chat_id: int) -> GroupDutyPin | None:
    """Get the pinned duty message record for a chat.

    Args:
        session: DB session.
        chat_id: Telegram chat id.

    Returns:
        GroupDutyPin or None.
    """
    return session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).first()


def save_group_duty_pin(
    session: Session, chat_id: int, message_id: int
) -> GroupDutyPin:
    """Save or update the pinned duty message for a chat.

    Args:
        session: DB session.
        chat_id: Telegram chat id.
        message_id: Message id to pin/update.

    Returns:
        GroupDutyPin instance (created or updated).
    """
    pin = get_group_duty_pin(session, chat_id)
    if pin:
        pin.message_id = message_id
    else:
        pin = GroupDutyPin(chat_id=chat_id, message_id=message_id)
        session.add(pin)
    session.commit()
    session.refresh(pin)
    return pin


def delete_group_duty_pin(session: Session, chat_id: int) -> None:
    """Remove the pinned duty message record for the chat (e.g. when bot leaves group).

    Args:
        session: DB session.
        chat_id: Telegram chat id.
    """
    session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).delete()
    session.commit()


def get_all_group_duty_pin_chat_ids(session: Session) -> list[int]:
    """Return all chat_ids that have a pinned duty message.

    Used to restore update jobs on bot startup.

    Args:
        session: DB session.

    Returns:
        List of chat ids.
    """
    return [chat_id for (chat_id,) in session.query(GroupDutyPin.chat_id).all()]