"""Repository: get_or_create_user, get_duties, insert_duty, get_current_duty, group_duty_pins.""" import hashlib import hmac import secrets from datetime import datetime, timedelta, timezone from sqlalchemy.orm import Session from duty_teller.db.models import User, Duty, GroupDutyPin, CalendarSubscriptionToken def get_user_by_telegram_id(session: Session, telegram_user_id: int) -> User | None: """Find user by telegram_user_id. Returns None if not found (no creation).""" return session.query(User).filter(User.telegram_user_id == telegram_user_id).first() def get_or_create_user( session: Session, telegram_user_id: int, full_name: str, username: str | None = None, first_name: str | None = None, last_name: str | None = None, ) -> User: user = get_user_by_telegram_id(session, telegram_user_id) if user: user.full_name = full_name user.username = username user.first_name = first_name user.last_name = last_name session.commit() session.refresh(user) return user user = User( telegram_user_id=telegram_user_id, full_name=full_name, username=username, first_name=first_name, last_name=last_name, ) session.add(user) session.commit() session.refresh(user) return user def get_or_create_user_by_full_name(session: Session, full_name: str) -> User: """Find user by exact full_name or create one with telegram_user_id=None (for duty-schedule import).""" user = session.query(User).filter(User.full_name == full_name).first() if user: return user user = User( telegram_user_id=None, full_name=full_name, username=None, first_name=None, last_name=None, ) session.add(user) session.commit() session.refresh(user) return user def delete_duties_in_range( session: Session, user_id: int, from_date: str, to_date: str, ) -> int: """Delete all duties of the user that overlap [from_date, to_date] (YYYY-MM-DD). Returns count deleted.""" to_next = ( datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1) ).strftime("%Y-%m-%d") q = session.query(Duty).filter( Duty.user_id == user_id, Duty.start_at < to_next, Duty.end_at >= from_date, ) count = q.count() q.delete(synchronize_session=False) session.commit() return count def get_duties( session: Session, from_date: str, to_date: str, ) -> list[tuple[Duty, str]]: """Return list of (Duty, full_name) overlapping the given date range.""" to_date_next = ( datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1) ).strftime("%Y-%m-%d") q = ( session.query(Duty, User.full_name) .join(User, Duty.user_id == User.id) .filter(Duty.start_at < to_date_next, Duty.end_at >= from_date) ) return list(q.all()) def get_duties_for_user( session: Session, user_id: int, from_date: str, to_date: str, ) -> list[tuple[Duty, str]]: """Return list of (Duty, full_name) for the given user overlapping the date range.""" to_date_next = ( datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1) ).strftime("%Y-%m-%d") q = ( session.query(Duty, User.full_name) .join(User, Duty.user_id == User.id) .filter( Duty.user_id == user_id, Duty.start_at < to_date_next, Duty.end_at >= from_date, ) ) return list(q.all()) def _token_hash(token: str) -> str: """Return SHA256 hex digest of the token (constant-time comparison via hmac).""" return hashlib.sha256(token.encode()).hexdigest() def create_calendar_token(session: Session, user_id: int) -> str: """ Create a new calendar subscription token for the user. Removes any existing tokens for this user. Returns the raw token string. """ session.query(CalendarSubscriptionToken).filter( CalendarSubscriptionToken.user_id == user_id ).delete(synchronize_session=False) raw_token = secrets.token_urlsafe(32) token_hash_val = _token_hash(raw_token) now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") record = CalendarSubscriptionToken( user_id=user_id, token_hash=token_hash_val, created_at=now_iso, ) session.add(record) session.commit() return raw_token def get_user_by_calendar_token(session: Session, token: str) -> User | None: """ Find user by calendar subscription token. Uses constant-time comparison. Returns None if token is invalid or not found. """ token_hash_val = _token_hash(token) row = ( session.query(CalendarSubscriptionToken, User) .join(User, CalendarSubscriptionToken.user_id == User.id) .filter(CalendarSubscriptionToken.token_hash == token_hash_val) .first() ) if row is None: return None # Constant-time compare to avoid timing leaks (token_hash is already hashed). if not hmac.compare_digest(row[0].token_hash, token_hash_val): return None return row[1] def insert_duty( session: Session, user_id: int, start_at: str, end_at: str, event_type: str = "duty", ) -> Duty: """Create a duty. start_at and end_at must be UTC, ISO 8601 with Z.""" duty = Duty( user_id=user_id, start_at=start_at, end_at=end_at, event_type=event_type, ) session.add(duty) session.commit() session.refresh(duty) return duty def get_current_duty(session: Session, at_utc: datetime) -> tuple[Duty, User] | None: """Return the duty (and user) for which start_at <= at_utc < end_at, event_type='duty'.""" from datetime import timezone if at_utc.tzinfo is not None: at_utc = at_utc.astimezone(timezone.utc) now_iso = at_utc.strftime("%Y-%m-%dT%H:%M:%S") + "Z" row = ( session.query(Duty, User) .join(User, Duty.user_id == User.id) .filter( Duty.event_type == "duty", Duty.start_at <= now_iso, Duty.end_at > now_iso, ) .first() ) if row is None: return None return (row[0], row[1]) def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None: """Return the end_at of the current duty or of the next duty. Naive UTC.""" from datetime import timezone if after_utc.tzinfo is not None: after_utc = after_utc.astimezone(timezone.utc) after_iso = after_utc.strftime("%Y-%m-%dT%H:%M:%S") + "Z" current = ( session.query(Duty) .filter( Duty.event_type == "duty", Duty.start_at <= after_iso, Duty.end_at > after_iso, ) .first() ) if current: return datetime.fromisoformat(current.end_at.replace("Z", "+00:00")).replace( tzinfo=None ) next_duty = ( session.query(Duty) .filter(Duty.event_type == "duty", Duty.start_at > after_iso) .order_by(Duty.start_at) .first() ) if next_duty: return datetime.fromisoformat(next_duty.end_at.replace("Z", "+00:00")).replace( tzinfo=None ) return None def get_group_duty_pin(session: Session, chat_id: int) -> GroupDutyPin | None: """Get the pinned message record for a chat, if any.""" return session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).first() def save_group_duty_pin( session: Session, chat_id: int, message_id: int ) -> GroupDutyPin: """Save or update the pinned message for a chat.""" pin = session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).first() if pin: pin.message_id = message_id else: pin = GroupDutyPin(chat_id=chat_id, message_id=message_id) session.add(pin) session.commit() session.refresh(pin) return pin def delete_group_duty_pin(session: Session, chat_id: int) -> None: """Remove the pinned message record when the bot leaves the group.""" session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).delete() session.commit() def get_all_group_duty_pin_chat_ids(session: Session) -> list[int]: """Return all chat_ids that have a pinned duty message (for restoring jobs on startup).""" rows = session.query(GroupDutyPin.chat_id).all() return [r[0] for r in rows] def set_user_phone( session: Session, telegram_user_id: int, phone: str | None ) -> User | None: """Set phone for user by telegram_user_id. Returns User or None if not found.""" user = session.query(User).filter(User.telegram_user_id == telegram_user_id).first() if not user: return None user.phone = phone session.commit() session.refresh(user) return user