"""Repository helpers for users, roles, duties, calendar tokens, group duty pins and trusted groups.

All functions take an explicit SQLAlchemy ``Session``. Mutating helpers commit
immediately unless documented otherwise (see ``delete_duties_in_range``).
Timestamps are stored as UTC ISO 8601 strings with a trailing ``Z``
(e.g. ``2025-01-15T09:00:00Z``), so range filters compare strings lexically.
"""

import hashlib
import secrets
from datetime import datetime, timezone

from sqlalchemy import func
from sqlalchemy.orm import Session

import duty_teller.config as config
from duty_teller.db.models import (
    User,
    Duty,
    GroupDutyPin,
    TrustedGroup,
    CalendarSubscriptionToken,
    Role,
)
from duty_teller.utils.dates import parse_utc_iso_naive, to_date_exclusive_iso

# Role names stored in DB (table roles).
ROLE_USER = "user"
ROLE_ADMIN = "admin"

# Format used for every UTC timestamp string written to / compared against the DB.
_ISO_UTC_FORMAT = "%Y-%m-%dT%H:%M:%SZ"


def _to_utc_iso(moment: datetime) -> str:
    """Format a datetime as a UTC ISO string with trailing Z (seconds precision).

    Aware datetimes are converted to UTC first; naive datetimes are assumed
    to already be UTC.
    """
    if moment.tzinfo is not None:
        moment = moment.astimezone(timezone.utc)
    return moment.strftime(_ISO_UTC_FORMAT)


def _utc_now_iso() -> str:
    """Return the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``."""
    return _to_utc_iso(datetime.now(timezone.utc))


def _is_admin_by_env(user: User) -> bool:
    """Env fallback for users without a DB role: listed in ADMIN_USERNAMES or ADMIN_PHONES."""
    return config.is_admin(user.username or "") or config.is_admin_by_phone(user.phone)


def get_user_by_telegram_id(session: Session, telegram_user_id: int) -> User | None:
    """Find user by Telegram user ID.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.

    Returns:
        User or None if not found. Does not create a user.
    """
    return session.query(User).filter(User.telegram_user_id == telegram_user_id).first()


def get_user_by_username(session: Session, username: str) -> User | None:
    """Find user by Telegram username (case-insensitive, optional @ prefix).

    Args:
        session: DB session.
        username: Telegram username with or without @.

    Returns:
        User or None if not found (or if username is empty/None).
    """
    name = (username or "").strip().lstrip("@").lower()
    if not name:
        return None
    return session.query(User).filter(func.lower(User.username) == name).first()


def get_user_role(session: Session, user_id: int) -> str | None:
    """Return role name for user by internal user id, or None if no role.

    Args:
        session: DB session.
        user_id: Internal user id (users.id).

    Returns:
        Role name ('user' or 'admin') or None.
    """
    user = session.get(User, user_id)
    if not user or not user.role:
        return None
    return user.role.name


def is_admin_for_telegram_user(session: Session, telegram_user_id: int) -> bool:
    """Check if the Telegram user is admin.

    If user has a role in DB, returns True only for role 'admin'.
    If user has no role in DB, fallback: True if in ADMIN_USERNAMES or ADMIN_PHONES.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.

    Returns:
        True if admin (by DB role or env fallback). False if the user is not in DB.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        return False
    if user.role is not None:
        return user.role.name == ROLE_ADMIN
    return _is_admin_by_env(user)


def can_access_miniapp_for_telegram_user(
    session: Session, telegram_user_id: int
) -> bool:
    """Check if Telegram user can access the calendar miniapp.

    Access if: user has role 'user' or 'admin' in DB, or (no role in DB and
    env fallback: in ADMIN_USERNAMES or ADMIN_PHONES). No user in DB -> no access.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.

    Returns:
        True if user may open the miniapp.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    return can_access_miniapp_for_user(session, user) if user else False


def can_access_miniapp_for_user(session: Session, user: User | None) -> bool:
    """Check if user (already loaded) can access the calendar miniapp.

    Access if: user has role 'user' or 'admin' in DB, or (no role in DB and
    env fallback: in ADMIN_USERNAMES or ADMIN_PHONES).

    Args:
        session: DB session (unused; kept for API consistency).
        user: User instance or None.

    Returns:
        True if user may open the miniapp.
    """
    if not user:
        return False
    if user.role is not None:
        return user.role.name in (ROLE_USER, ROLE_ADMIN)
    return _is_admin_by_env(user)


def set_user_role(session: Session, user_id: int, role_name: str) -> User | None:
    """Set user role by internal user id and role name.

    Args:
        session: DB session.
        user_id: Internal user id (users.id).
        role_name: 'user' or 'admin'.

    Returns:
        Updated User or None if user or role not found (nothing is changed then).
    """
    user = session.get(User, user_id)
    if not user:
        return None
    role = session.query(Role).filter(Role.name == role_name).first()
    if not role:
        return None
    user.role_id = role.id
    session.commit()
    session.refresh(user)
    return user


def get_or_create_user(
    session: Session,
    telegram_user_id: int,
    full_name: str,
    username: str | None = None,
    first_name: str | None = None,
    last_name: str | None = None,
) -> User:
    """Get or create user by Telegram user ID.

    On create, name fields come from Telegram. On update: username is always
    synced; full_name, first_name, last_name are updated only if
    name_manually_edited is False (otherwise existing display name is kept).

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        full_name: Display full name.
        username: Telegram username (optional).
        first_name: Telegram first name (optional).
        last_name: Telegram last name (optional).

    Returns:
        User instance (created or updated).
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if user:
        user.username = username
        # A manually edited display name must survive Telegram profile syncs.
        if not user.name_manually_edited:
            user.full_name = full_name
            user.first_name = first_name
            user.last_name = last_name
        session.commit()
        session.refresh(user)
        return user

    user = User(
        telegram_user_id=telegram_user_id,
        full_name=full_name,
        username=username,
        first_name=first_name,
        last_name=last_name,
        name_manually_edited=False,
    )
    session.add(user)
    session.commit()
    session.refresh(user)
    return user


def get_or_create_user_by_full_name(session: Session, full_name: str) -> User:
    """Find user by exact full_name or create one (for duty-schedule import).

    New users have telegram_user_id=None and name_manually_edited=True.

    Args:
        session: DB session.
        full_name: Exact full name to match or set.

    Returns:
        User instance (existing or newly created).
    """
    user = session.query(User).filter(User.full_name == full_name).first()
    if user:
        return user
    user = User(
        telegram_user_id=None,
        full_name=full_name,
        username=None,
        first_name=None,
        last_name=None,
        # Imported names are authoritative; don't let a later Telegram sync overwrite them.
        name_manually_edited=True,
    )
    session.add(user)
    session.commit()
    session.refresh(user)
    return user


def get_users_by_full_names(session: Session, full_names: list[str]) -> dict[str, User]:
    """Get users by full_name. Returns dict full_name -> User. Does not create missing.

    Args:
        session: DB session.
        full_names: List of full names to look up.

    Returns:
        Dict mapping full_name to User for found users (empty dict for empty input).
    """
    if not full_names:
        return {}
    users = session.query(User).filter(User.full_name.in_(full_names)).all()
    return {u.full_name: u for u in users}


def update_user_display_name(
    session: Session,
    telegram_user_id: int,
    full_name: str,
    first_name: str | None = None,
    last_name: str | None = None,
) -> User | None:
    """Update display name and set name_manually_edited=True.

    Use from API or admin when name is changed manually; subsequent
    get_or_create_user will not overwrite these fields.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        full_name: New full name.
        first_name: New first name (optional).
        last_name: New last name (optional).

    Returns:
        Updated User or None if not found.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        return None
    user.full_name = full_name
    user.first_name = first_name
    user.last_name = last_name
    user.name_manually_edited = True
    session.commit()
    session.refresh(user)
    return user


def delete_duties_in_range(
    session: Session,
    user_id: int,
    from_date: str,
    to_date: str,
    *,
    commit: bool = True,
) -> int:
    """Delete all duties of the user that overlap the given date range.

    Args:
        session: DB session.
        user_id: User id.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD (inclusive).
        commit: If True, commit immediately. If False, caller commits (for batch import).

    Returns:
        Number of duties deleted (row count reported by the DELETE statement).
    """
    to_next = to_date_exclusive_iso(to_date)
    q = session.query(Duty).filter(
        Duty.user_id == user_id,
        Duty.start_at < to_next,
        Duty.end_at >= from_date,
    )
    # Query.delete() returns the matched row count, so no separate COUNT query
    # (which could also disagree with what was actually deleted).
    count = q.delete(synchronize_session=False)
    if commit:
        session.commit()
    return count


def get_duties(
    session: Session,
    from_date: str,
    to_date: str,
) -> list[tuple[Duty, str]]:
    """Return duties overlapping the given date range with user full_name.

    Args:
        session: DB session.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD (inclusive).

    Returns:
        List of (Duty, full_name) tuples.
    """
    to_date_next = to_date_exclusive_iso(to_date)
    q = (
        session.query(Duty, User.full_name)
        .join(User, Duty.user_id == User.id)
        .filter(Duty.start_at < to_date_next, Duty.end_at >= from_date)
    )
    return list(q.all())


def get_duties_for_user(
    session: Session,
    user_id: int,
    from_date: str,
    to_date: str,
    event_types: list[str] | None = None,
) -> list[tuple[Duty, str]]:
    """Return duties for one user overlapping the date range.

    Optionally filter by event_type (e.g. "duty", "unavailable", "vacation").
    When event_types is None, all event types are returned.

    Args:
        session: DB session.
        user_id: User id.
        from_date: Start date YYYY-MM-DD.
        to_date: End date YYYY-MM-DD (inclusive).
        event_types: If not None, only return duties whose event_type is in this list.

    Returns:
        List of (Duty, full_name) tuples.
    """
    to_date_next = to_date_exclusive_iso(to_date)
    filters = [
        Duty.user_id == user_id,
        Duty.start_at < to_date_next,
        Duty.end_at >= from_date,
    ]
    if event_types is not None:
        filters.append(Duty.event_type.in_(event_types))
    q = (
        session.query(Duty, User.full_name)
        .join(User, Duty.user_id == User.id)
        .filter(*filters)
    )
    return list(q.all())


def _token_hash(token: str) -> str:
    """Return the SHA-256 hex digest of the raw token.

    Only this digest is persisted, so a leaked DB does not expose usable tokens.
    """
    return hashlib.sha256(token.encode()).hexdigest()


def create_calendar_token(session: Session, user_id: int) -> str:
    """Create a new calendar subscription token for the user.

    Any existing tokens for this user are removed. The raw token is returned
    only once (only its SHA-256 digest is stored).

    Args:
        session: DB session.
        user_id: User id.

    Returns:
        Raw token string (e.g. for URL /api/calendar/ical/{token}.ics).
    """
    session.query(CalendarSubscriptionToken).filter(
        CalendarSubscriptionToken.user_id == user_id
    ).delete(synchronize_session=False)
    raw_token = secrets.token_urlsafe(32)
    record = CalendarSubscriptionToken(
        user_id=user_id,
        token_hash=_token_hash(raw_token),
        created_at=_utc_now_iso(),
    )
    session.add(record)
    session.commit()
    return raw_token


def get_user_by_calendar_token(session: Session, token: str) -> User | None:
    """Find user by calendar subscription token.

    The raw token is hashed and matched against the stored digest in the DB;
    the raw token itself is never stored or compared.

    Args:
        session: DB session.
        token: Raw token from URL.

    Returns:
        User or None if token is invalid or not found.
    """
    return (
        session.query(User)
        .join(CalendarSubscriptionToken, CalendarSubscriptionToken.user_id == User.id)
        .filter(CalendarSubscriptionToken.token_hash == _token_hash(token))
        .first()
    )


def insert_duty(
    session: Session,
    user_id: int,
    start_at: str,
    end_at: str,
    event_type: str = "duty",
) -> Duty:
    """Create a duty record.

    Args:
        session: DB session.
        user_id: User id.
        start_at: Start time UTC, ISO 8601 with Z (e.g. 2025-01-15T09:00:00Z).
        end_at: End time UTC, ISO 8601 with Z.
        event_type: One of "duty", "unavailable", "vacation". Default "duty".

    Returns:
        Created Duty instance.
    """
    duty = Duty(
        user_id=user_id,
        start_at=start_at,
        end_at=end_at,
        event_type=event_type,
    )
    session.add(duty)
    session.commit()
    session.refresh(duty)
    return duty


def get_current_duty(session: Session, at_utc: datetime) -> tuple[Duty, User] | None:
    """Return the duty and user active at the given UTC time (event_type='duty').

    A duty is active when start_at <= at_utc < end_at.

    Args:
        session: DB session.
        at_utc: Point in time (timezone-aware, or naive treated as UTC).

    Returns:
        (Duty, User) or None if no duty at that time.
    """
    now_iso = _to_utc_iso(at_utc)
    row = (
        session.query(Duty, User)
        .join(User, Duty.user_id == User.id)
        .filter(
            Duty.event_type == "duty",
            Duty.start_at <= now_iso,
            Duty.end_at > now_iso,
        )
        .first()
    )
    if row is None:
        return None
    duty, user = row
    return duty, user


def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None:
    """Return the end_at of the current or next duty (event_type='duty').

    If a duty is active at after_utc its end is returned; otherwise the end of
    the earliest duty starting after after_utc.

    Args:
        session: DB session.
        after_utc: Point in time (timezone-aware, or naive treated as UTC).

    Returns:
        End datetime (naive UTC) or None if no current or future duty.
    """
    after_iso = _to_utc_iso(after_utc)
    current = (
        session.query(Duty)
        .filter(
            Duty.event_type == "duty",
            Duty.start_at <= after_iso,
            Duty.end_at > after_iso,
        )
        .first()
    )
    if current:
        return parse_utc_iso_naive(current.end_at)
    next_duty = (
        session.query(Duty)
        .filter(Duty.event_type == "duty", Duty.start_at > after_iso)
        .order_by(Duty.start_at)
        .first()
    )
    if next_duty:
        return parse_utc_iso_naive(next_duty.end_at)
    return None


def get_group_duty_pin(session: Session, chat_id: int) -> GroupDutyPin | None:
    """Get the pinned duty message record for a chat.

    Args:
        session: DB session.
        chat_id: Telegram chat id.

    Returns:
        GroupDutyPin or None.
    """
    return session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).first()


def save_group_duty_pin(
    session: Session, chat_id: int, message_id: int
) -> GroupDutyPin:
    """Save or update the pinned duty message for a chat.

    Args:
        session: DB session.
        chat_id: Telegram chat id.
        message_id: Message id to pin/update.

    Returns:
        GroupDutyPin instance (created or updated).
    """
    pin = get_group_duty_pin(session, chat_id)
    if pin:
        pin.message_id = message_id
    else:
        pin = GroupDutyPin(chat_id=chat_id, message_id=message_id)
        session.add(pin)
    session.commit()
    session.refresh(pin)
    return pin


def delete_group_duty_pin(session: Session, chat_id: int) -> None:
    """Remove the pinned duty message record for the chat (e.g. when bot leaves group).

    No-op (besides the commit) if the chat has no pin.

    Args:
        session: DB session.
        chat_id: Telegram chat id.
    """
    session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).delete()
    session.commit()


def get_all_group_duty_pin_chat_ids(session: Session) -> list[int]:
    """Return all chat_ids that have a pinned duty message.

    Used to restore update jobs on bot startup.

    Args:
        session: DB session.

    Returns:
        List of chat ids.
    """
    return [chat_id for (chat_id,) in session.query(GroupDutyPin.chat_id).all()]


def is_trusted_group(session: Session, chat_id: int) -> bool:
    """Check if the chat is in the trusted groups list.

    Args:
        session: DB session.
        chat_id: Telegram chat id.

    Returns:
        True if the group is trusted.
    """
    return (
        session.query(TrustedGroup).filter(TrustedGroup.chat_id == chat_id).first()
        is not None
    )


def add_trusted_group(
    session: Session, chat_id: int, added_by_user_id: int | None = None
) -> TrustedGroup:
    """Add a group to the trusted list.

    Args:
        session: DB session.
        chat_id: Telegram chat id.
        added_by_user_id: Telegram user id of the admin who added the group (optional).

    Returns:
        Created TrustedGroup instance.
    """
    record = TrustedGroup(
        chat_id=chat_id,
        added_by_user_id=added_by_user_id,
        added_at=_utc_now_iso(),
    )
    session.add(record)
    session.commit()
    session.refresh(record)
    return record


def remove_trusted_group(session: Session, chat_id: int) -> None:
    """Remove a group from the trusted list.

    No-op (besides the commit) if the group is not trusted.

    Args:
        session: DB session.
        chat_id: Telegram chat id.
    """
    session.query(TrustedGroup).filter(TrustedGroup.chat_id == chat_id).delete()
    session.commit()


def get_all_trusted_group_ids(session: Session) -> list[int]:
    """Return all chat_ids that are trusted.

    Args:
        session: DB session.

    Returns:
        List of trusted chat ids.
    """
    return [chat_id for (chat_id,) in session.query(TrustedGroup.chat_id).all()]


def set_user_phone(
    session: Session, telegram_user_id: int, phone: str | None
) -> User | None:
    """Set or clear phone for user by Telegram user id.

    None or a blank/whitespace-only string clears the phone; any other value
    is stored after config.normalize_phone.

    Args:
        session: DB session.
        telegram_user_id: Telegram user id.
        phone: Phone string or None to clear.

    Returns:
        Updated User or None if not found.
    """
    user = get_user_by_telegram_id(session, telegram_user_id)
    if not user:
        return None
    if phone is None or (isinstance(phone, str) and not phone.strip()):
        user.phone = None
    else:
        user.phone = config.normalize_phone(phone)
    session.commit()
    session.refresh(user)
    return user