"""Repository: users, duties (get_duties, insert_duty, get_current_duty, ...) and group duty pins.

Duty.start_at / Duty.end_at are stored in UTC as ISO 8601 strings ending in "Z"
(e.g. "2025-01-31T09:00:00Z"); range filters below compare those strings
lexicographically against YYYY-MM-DD or same-format timestamp strings.
"""
from datetime import datetime, timedelta, timezone

from sqlalchemy.orm import Session

from db.models import User, Duty, GroupDutyPin


def _next_day(day: str) -> str:
    """Return the calendar day after ``day``; both are YYYY-MM-DD strings.

    Used as an exclusive upper bound for start_at: "2025-01-31T09:00:00Z" sorts
    after "2025-01-31" but before "2025-02-01", so ``start_at < _next_day(day)``
    keeps duties that start on ``day`` itself.
    """
    return (datetime.fromisoformat(day + "T00:00:00") + timedelta(days=1)).strftime("%Y-%m-%d")


def _to_utc_iso(moment: datetime) -> str:
    """Format ``moment`` as "YYYY-MM-DDTHH:MM:SSZ" in UTC.

    Aware datetimes are converted to UTC first; naive ones are treated as UTC
    already. Sub-second precision is dropped.
    """
    if moment.tzinfo is not None:
        moment = moment.astimezone(timezone.utc)
    return moment.strftime("%Y-%m-%dT%H:%M:%S") + "Z"


def _parse_utc_iso(value: str) -> datetime:
    """Parse a stored ISO 8601 timestamp into a naive UTC datetime."""
    # "Z" is rewritten because datetime.fromisoformat only accepts it from Python 3.11.
    parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    if parsed.tzinfo is not None:
        # Convert before dropping tzinfo so a non-UTC offset is not
        # silently reinterpreted as UTC.
        parsed = parsed.astimezone(timezone.utc)
    return parsed.replace(tzinfo=None)


def get_or_create_user(
    session: Session,
    telegram_user_id: int,
    full_name: str,
    username: str | None = None,
    first_name: str | None = None,
    last_name: str | None = None,
) -> User:
    """Return the user with ``telegram_user_id``, creating it if it does not exist.

    For an existing user, full_name/username/first_name/last_name are
    overwritten with the given values (``None`` included). Commits the session
    and returns the refreshed instance.
    """
    user = session.query(User).filter(User.telegram_user_id == telegram_user_id).first()
    if user is None:
        user = User(
            telegram_user_id=telegram_user_id,
            full_name=full_name,
            username=username,
            first_name=first_name,
            last_name=last_name,
        )
        session.add(user)
    else:
        user.full_name = full_name
        user.username = username
        user.first_name = first_name
        user.last_name = last_name
    session.commit()
    session.refresh(user)
    return user


def get_or_create_user_by_full_name(session: Session, full_name: str) -> User:
    """Find user by exact full_name or create one with telegram_user_id=None (for duty-schedule import)."""
    user = session.query(User).filter(User.full_name == full_name).first()
    if user:
        return user
    user = User(
        telegram_user_id=None,
        full_name=full_name,
        username=None,
        first_name=None,
        last_name=None,
    )
    session.add(user)
    session.commit()
    session.refresh(user)
    return user


def delete_duties_in_range(
    session: Session,
    user_id: int,
    from_date: str,
    to_date: str,
) -> int:
    """Delete all duties of the user that overlap [from_date, to_date] (YYYY-MM-DD).

    Returns the number of duties deleted.
    """
    q = session.query(Duty).filter(
        Duty.user_id == user_id,
        # Exclusive next-day bound so duties starting on to_date are included.
        Duty.start_at < _next_day(to_date),
        Duty.end_at >= from_date,
    )
    count = q.count()
    # Bulk DELETE; in-session Duty objects are not synchronized.
    q.delete(synchronize_session=False)
    session.commit()
    return count


def get_duties(
    session: Session,
    from_date: str,
    to_date: str,
) -> list[tuple[Duty, str]]:
    """Return (Duty, full_name) pairs for duties overlapping [from_date, to_date].

    from_date/to_date are YYYY-MM-DD (inclusive). Duty.start_at and end_at are
    stored in UTC (ISO 8601 with Z). The upper bound is the day after to_date
    because a start_at like "2025-01-31T09:00:00Z" is > "2025-01-31"
    lexicographically and would otherwise be excluded.
    """
    q = (
        session.query(Duty, User.full_name)
        .join(User, Duty.user_id == User.id)
        .filter(Duty.start_at < _next_day(to_date), Duty.end_at >= from_date)
    )
    return q.all()


def insert_duty(
    session: Session,
    user_id: int,
    start_at: str,
    end_at: str,
    event_type: str = "duty",
) -> Duty:
    """Create a duty. start_at and end_at must be UTC, ISO 8601 with Z.

    event_type: 'duty' | 'unavailable' | 'vacation'.
    """
    duty = Duty(
        user_id=user_id,
        start_at=start_at,
        end_at=end_at,
        event_type=event_type,
    )
    session.add(duty)
    session.commit()
    session.refresh(duty)
    return duty


def get_current_duty(
    session: Session, at_utc: datetime
) -> tuple[Duty, User] | None:
    """Return the duty (and user) for which start_at <= at_utc < end_at, event_type='duty'.

    at_utc is in UTC (naive or aware); comparison uses ISO strings.
    Returns None when no such duty exists.
    """
    now_iso = _to_utc_iso(at_utc)
    row = (
        session.query(Duty, User)
        .join(User, Duty.user_id == User.id)
        .filter(
            Duty.event_type == "duty",
            Duty.start_at <= now_iso,
            Duty.end_at > now_iso,
        )
        .first()
    )
    if row is None:
        return None
    duty, user = row
    return duty, user


def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None:
    """Return the end_at of the current duty (if after_utc is inside one) or of the next duty.

    For scheduling the next pin update. Only event_type='duty' rows are
    considered. Returns a naive UTC datetime, or None if there is neither a
    current nor a future duty.
    """
    after_iso = _to_utc_iso(after_utc)
    # Current duty: start_at <= after_iso < end_at -> use its end_at.
    current = (
        session.query(Duty)
        .filter(
            Duty.event_type == "duty",
            Duty.start_at <= after_iso,
            Duty.end_at > after_iso,
        )
        .first()
    )
    if current:
        return _parse_utc_iso(current.end_at)
    # Next future duty: earliest start_at > after_iso.
    next_duty = (
        session.query(Duty)
        .filter(Duty.event_type == "duty", Duty.start_at > after_iso)
        .order_by(Duty.start_at)
        .first()
    )
    if next_duty:
        # NOTE(review): between shifts this returns the *end* of the next duty,
        # not its start; if callers refresh the pin at this time, the pin may
        # stay stale for the whole next shift — confirm against the scheduler.
        return _parse_utc_iso(next_duty.end_at)
    return None


def get_group_duty_pin(session: Session, chat_id: int) -> GroupDutyPin | None:
    """Get the pinned message record for a chat, if any."""
    return session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).first()


def save_group_duty_pin(
    session: Session, chat_id: int, message_id: int
) -> GroupDutyPin:
    """Save or update the pinned message for a chat; commits and returns the refreshed record."""
    pin = session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).first()
    if pin:
        pin.message_id = message_id
    else:
        pin = GroupDutyPin(chat_id=chat_id, message_id=message_id)
        session.add(pin)
    session.commit()
    session.refresh(pin)
    return pin


def delete_group_duty_pin(session: Session, chat_id: int) -> None:
    """Remove the pinned message record when the bot leaves the group."""
    session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).delete()
    session.commit()


def get_all_group_duty_pin_chat_ids(session: Session) -> list[int]:
    """Return all chat_ids that have a pinned duty message (for restoring jobs on startup)."""
    rows = session.query(GroupDutyPin.chat_id).all()
    return [r[0] for r in rows]


def set_user_phone(session: Session, telegram_user_id: int, phone: str | None) -> User | None:
    """Set phone for user by telegram_user_id. Returns User or None if not found."""
    user = session.query(User).filter(User.telegram_user_id == telegram_user_id).first()
    if not user:
        return None
    user.phone = phone
    session.commit()
    session.refresh(user)
    return user