diff --git a/.cursorrules b/.cursorrules new file mode 100644 index 0000000..61f2508 --- /dev/null +++ b/.cursorrules @@ -0,0 +1,144 @@ +{ + "general": { + "coding_style": { + "language": "Python", + "use_strict": true, + "indentation": "4 spaces", + "max_line_length": 120, + "comments": { + "style": "# for single-line, ''' for multi-line", + "require_comments": true + } + }, + + "naming_conventions": { + "variables": "snake_case", + "functions": "snake_case", + "classes": "PascalCase", + "interfaces": "PascalCase", + "files": "snake_case" + }, + + "error_handling": { + "prefer_try_catch": true, + "log_errors": true + }, + + "testing": { + "require_tests": true, + "test_coverage": "80%", + "test_types": ["unit", "integration"] + }, + + "documentation": { + "require_docs": true, + "doc_tool": "docstrings", + "style_guide": "Google Python Style Guide" + }, + + "security": { + "require_https": true, + "sanitize_inputs": true, + "validate_inputs": true, + "use_env_vars": true + }, + + "configuration_management": { + "config_files": [".env"], + "env_management": "python-dotenv", + "secrets_management": "environment variables" + }, + + "code_review": { + "require_reviews": true, + "review_tool": "Gitea Pull Requests", + "review_criteria": ["functionality", "code quality", "security"] + }, + + "version_control": { + "system": "Git", + "branching_strategy": "Gitea Flow", + "commit_message_format": "Conventional Commits" + }, + + "logging": { + "logging_tool": "Python logging module", + "log_levels": ["debug", "info", "warn", "error"], + "log_retention_policy": "7 days" + }, + + "monitoring": { + "monitoring_tool": "Not specified", + "metrics": ["file processing time", "classification accuracy", "error rate"] + }, + + "dependency_management": { + "package_manager": "pip", + "versioning_strategy": "Semantic Versioning" + }, + + "accessibility": { + "standards": ["Not applicable"], + "testing_tools": ["Not applicable"] + }, + + "internationalization": { + "i18n_tool": "Not applicable", + "supported_languages": ["English, Russian"], + "default_language": "Russian" + }, + + "code_formatting": { + "formatter": "Black", + "linting_tool": "Pylint", + "rules": ["PEP 8", "project-specific rules"] + }, + + "architecture": { + "patterns": ["Modular design"], + "principles": ["Single Responsibility", "DRY"] + } + }, + + "project_specific": { + "use_framework": "None", + "styling": "Not applicable", + "testing_framework": "pytest", + "build_tool": "setuptools", + + "deployment": { + "environment": "Local machine", + "automation": "Not specified", + "strategy": "Manual deployment" + }, + + "performance": { + "benchmarking_tool": "Not specified", + "performance_goals": { + "response_time": "< 5 seconds per file", + "throughput": "Not specified", + "error_rate": "< 1%" + } + } + }, + + "context": { + "codebase_overview": "Python-based telegram bot, for team duty shift calendar, and group reminder", + + "coding_practices": { + "modularity": true, + "DRY_principle": true, + "performance_optimization": true + } + }, + + "behavior": { + "verbosity": { + "level": 2, + "range": [0, 3] + }, + "handle_incomplete_tasks": "Provide partial solution and explain limitations", + "ask_for_clarification": true, + "communication_tone": "Professional and concise" + } +} diff --git a/README.md b/README.md index 0ba0f6d..204b9aa 100644 --- a/README.md +++ b/README.md @@ -76,13 +76,15 @@ Ensure `.env` exists (e.g. `cp .env.example .env`) and contains `BOT_TOKEN`. ## Project layout - `main.py` – Builds the `Application`, registers handlers, runs polling and FastAPI in a thread. -- `config.py` – Loads `BOT_TOKEN`, `DATABASE_URL`, `ALLOWED_USERNAMES`, `ADMIN_USERNAMES`, `CORS_ORIGINS`, etc. from env; exits if `BOT_TOKEN` is missing. -- `api/` – FastAPI app (`/api/duties`), Telegram initData validation, static webapp mount. -- `db/` – SQLAlchemy models, session, repository, schemas. +- `config.py` – Loads `BOT_TOKEN`, `DATABASE_URL`, `ALLOWED_USERNAMES`, `ADMIN_USERNAMES`, `CORS_ORIGINS`, etc. from env; exits if `BOT_TOKEN` is missing. Optional `Settings` dataclass for tests. +- `api/` – FastAPI app (`/api/duties`, `/api/calendar-events`), auth/session Depends, static webapp mount. +- `db/` – SQLAlchemy models, session (use `session_scope` for all DB access), repository, schemas. One `DATABASE_URL` per process; set env before first import if you need a different URL in tests. +- `handlers/` – Telegram command and chat handlers; thin layer that call services and utils. +- `services/` – Business logic (group duty pin, import); accept session from caller. +- `utils/` – Shared date, user, and handover helpers. - `alembic/` – Migrations (use `config.DATABASE_URL`). -- `handlers/` – Command and error handlers; add new handlers here. - `webapp/` – Miniapp UI (calendar, duty list); served at `/app`. -- `requirements.txt` – Pinned dependencies (PTB, FastAPI, SQLAlchemy, Alembic, etc.). +- `pyproject.toml` – Installable package (`pip install -e .`); `requirements.txt` – pinned deps. To add commands, define async handlers in `handlers/commands.py` (or a new module) and register them in `handlers/__init__.py`. diff --git a/api/app.py b/api/app.py index ba1cbd3..b6cbe6d 100644 --- a/api/app.py +++ b/api/app.py @@ -1,59 +1,78 @@ """FastAPI app: /api/duties and static webapp.""" import logging -import re from pathlib import Path +from typing import Annotated, Generator import config -from fastapi import FastAPI, Header, HTTPException, Query, Request +from fastapi import Depends, FastAPI, Header, HTTPException, Query, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles +from sqlalchemy.orm import Session from db.session import session_scope from db.repository import get_duties from db.schemas import DutyWithUser, CalendarEvent from api.telegram_auth import validate_init_data_with_reason from api.calendar_ics import get_calendar_events +from utils.dates import validate_date_range log = logging.getLogger(__name__) -# ISO date YYYY-MM-DD -_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$") - def _validate_duty_dates(from_date: str, to_date: str) -> None: """Raise HTTPException 400 if dates are invalid or from_date > to_date.""" - if not _DATE_RE.match(from_date) or not _DATE_RE.match(to_date): - raise HTTPException( - status_code=400, - detail="Параметры from и to должны быть в формате YYYY-MM-DD", - ) - if from_date > to_date: - raise HTTPException( - status_code=400, - detail="Дата from не должна быть позже to", - ) + try: + validate_date_range(from_date, to_date) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e -def _fetch_duties_response(from_date: str, to_date: str) -> list[DutyWithUser]: - """Fetch duties in range and return list of DutyWithUser. Uses config.DATABASE_URL.""" +def get_validated_dates( + from_date: str = Query(..., description="ISO date YYYY-MM-DD", alias="from"), + to_date: str = Query(..., description="ISO date YYYY-MM-DD", alias="to"), +) -> tuple[str, str]: + """FastAPI dependency: validate from_date/to_date and return (from_date, to_date). Raises 400 if invalid.""" + _validate_duty_dates(from_date, to_date) + return (from_date, to_date) + + +def get_db_session() -> Generator[Session, None, None]: + """FastAPI dependency: yield a DB session from session_scope.""" with session_scope(config.DATABASE_URL) as session: - rows = get_duties(session, from_date=from_date, to_date=to_date) - return [ - DutyWithUser( - id=duty.id, - user_id=duty.user_id, - start_at=duty.start_at, - end_at=duty.end_at, - full_name=full_name, - event_type=( - duty.event_type - if duty.event_type in ("duty", "unavailable", "vacation") - else "duty" - ), - ) - for duty, full_name in rows - ] + yield session + + +def require_miniapp_username( + request: Request, + x_telegram_init_data: Annotated[ + str | None, Header(alias="X-Telegram-Init-Data") + ] = None, +) -> str: + """FastAPI dependency: return authenticated username or raise 403.""" + return get_authenticated_username(request, x_telegram_init_data) + + +def _fetch_duties_response( + session: Session, from_date: str, to_date: str +) -> list[DutyWithUser]: + """Fetch duties in range and return list of DutyWithUser.""" + rows = get_duties(session, from_date=from_date, to_date=to_date) + return [ + DutyWithUser( + id=duty.id, + user_id=duty.user_id, + start_at=duty.start_at, + end_at=duty.end_at, + full_name=full_name, + event_type=( + duty.event_type + if duty.event_type in ("duty", "unavailable", "vacation") + else "duty" + ), + ) + for duty, full_name in rows + ] def _auth_error_detail(auth_reason: str) -> str: @@ -131,33 +150,30 @@ app.add_middleware( @app.get("/api/duties", response_model=list[DutyWithUser]) def list_duties( request: Request, - from_date: str = Query(..., description="ISO date YYYY-MM-DD", alias="from"), - to_date: str = Query(..., description="ISO date YYYY-MM-DD", alias="to"), x_telegram_init_data: str | None = Header(None, alias="X-Telegram-Init-Data"), + dates: tuple[str, str] = Depends(get_validated_dates), + _username: str = Depends(require_miniapp_username), + session: Session = Depends(get_db_session), ) -> list[DutyWithUser]: - _validate_duty_dates(from_date, to_date) + from_date_val, to_date_val = dates log.info( "GET /api/duties from %s, has initData: %s", request.client.host if request.client else "?", bool((x_telegram_init_data or "").strip()), ) - get_authenticated_username(request, x_telegram_init_data) - return _fetch_duties_response(from_date, to_date) + return _fetch_duties_response(session, from_date_val, to_date_val) @app.get("/api/calendar-events", response_model=list[CalendarEvent]) def list_calendar_events( - request: Request, - from_date: str = Query(..., description="ISO date YYYY-MM-DD", alias="from"), - to_date: str = Query(..., description="ISO date YYYY-MM-DD", alias="to"), - x_telegram_init_data: str | None = Header(None, alias="X-Telegram-Init-Data"), + dates: tuple[str, str] = Depends(get_validated_dates), + _username: str = Depends(require_miniapp_username), ) -> list[CalendarEvent]: - _validate_duty_dates(from_date, to_date) - get_authenticated_username(request, x_telegram_init_data) + from_date_val, to_date_val = dates url = config.EXTERNAL_CALENDAR_ICS_URL if not url: return [] - events = get_calendar_events(url, from_date=from_date, to_date=to_date) + events = get_calendar_events(url, from_date=from_date_val, to_date=to_date_val) return [CalendarEvent(date=e["date"], summary=e["summary"]) for e in events] diff --git a/config.py b/config.py index 31bc23e..64b3788 100644 --- a/config.py +++ b/config.py @@ -1,11 +1,66 @@ """Load configuration from environment. Fail fast if BOT_TOKEN is missing.""" import os +from dataclasses import dataclass from dotenv import load_dotenv load_dotenv() + +@dataclass(frozen=True) +class Settings: + """Optional injectable settings built from env. Tests can override or build from env.""" + + bot_token: str + database_url: str + mini_app_base_url: str + http_port: int + allowed_usernames: set[str] + admin_usernames: set[str] + mini_app_skip_auth: bool + init_data_max_age_seconds: int + cors_origins: list[str] + external_calendar_ics_url: str + duty_display_tz: str + + @classmethod + def from_env(cls) -> "Settings": + """Build Settings from current environment (same logic as module-level vars).""" + bot_token = os.getenv("BOT_TOKEN") or "" + raw_allowed = os.getenv("ALLOWED_USERNAMES", "").strip() + allowed = { + s.strip().lstrip("@").lower() for s in raw_allowed.split(",") if s.strip() + } + raw_admin = os.getenv("ADMIN_USERNAMES", "").strip() + admin = { + s.strip().lstrip("@").lower() for s in raw_admin.split(",") if s.strip() + } + raw_cors = os.getenv("CORS_ORIGINS", "").strip() + cors = ( + [_o.strip() for _o in raw_cors.split(",") if _o.strip()] + if raw_cors and raw_cors != "*" + else ["*"] + ) + return cls( + bot_token=bot_token, + database_url=os.getenv("DATABASE_URL", "sqlite:///data/duty_teller.db"), + mini_app_base_url=os.getenv("MINI_APP_BASE_URL", "").rstrip("/"), + http_port=int(os.getenv("HTTP_PORT", "8080")), + allowed_usernames=allowed, + admin_usernames=admin, + mini_app_skip_auth=os.getenv("MINI_APP_SKIP_AUTH", "").strip() + in ("1", "true", "yes"), + init_data_max_age_seconds=int(os.getenv("INIT_DATA_MAX_AGE_SECONDS", "0")), + cors_origins=cors, + external_calendar_ics_url=os.getenv( + "EXTERNAL_CALENDAR_ICS_URL", "" + ).strip(), + duty_display_tz=os.getenv("DUTY_DISPLAY_TZ", "Europe/Moscow").strip() + or "Europe/Moscow", + ) + + BOT_TOKEN = os.getenv("BOT_TOKEN") if not BOT_TOKEN: raise SystemExit( @@ -45,7 +100,9 @@ CORS_ORIGINS = ( EXTERNAL_CALENDAR_ICS_URL = os.getenv("EXTERNAL_CALENDAR_ICS_URL", "").strip() # Timezone for displaying duty times in the pinned group message (e.g. Europe/Moscow). -DUTY_DISPLAY_TZ = os.getenv("DUTY_DISPLAY_TZ", "Europe/Moscow").strip() or "Europe/Moscow" +DUTY_DISPLAY_TZ = ( + os.getenv("DUTY_DISPLAY_TZ", "Europe/Moscow").strip() or "Europe/Moscow" +) def is_admin(username: str) -> bool: diff --git a/db/__init__.py b/db/__init__.py index 459863c..c1f5d6f 100644 --- a/db/__init__.py +++ b/db/__init__.py @@ -2,13 +2,14 @@ from db.models import Base, User, Duty from db.schemas import UserCreate, UserInDb, DutyCreate, DutyInDb, DutyWithUser -from db.session import get_engine, get_session_factory, get_session +from db.session import get_engine, get_session_factory, get_session, session_scope from db.repository import ( delete_duties_in_range, get_or_create_user, get_or_create_user_by_full_name, get_duties, insert_duty, + set_user_phone, ) __all__ = [ @@ -23,11 +24,13 @@ __all__ = [ "get_engine", "get_session_factory", "get_session", + "session_scope", "delete_duties_in_range", "get_or_create_user", "get_or_create_user_by_full_name", "get_duties", "insert_duty", + "set_user_phone", "init_db", ] diff --git a/db/repository.py b/db/repository.py index 5de4e06..fc1597b 100644 --- a/db/repository.py +++ b/db/repository.py @@ -63,14 +63,13 @@ def delete_duties_in_range( ) -> int: """Delete all duties of the user that overlap [from_date, to_date] (YYYY-MM-DD). Returns count deleted.""" # start_at < to_date + 1 day so duties starting on to_date are included (start_at is ISO with T) - to_next = (datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)).strftime("%Y-%m-%d") - q = ( - session.query(Duty) - .filter( - Duty.user_id == user_id, - Duty.start_at < to_next, - Duty.end_at >= from_date, - ) + to_next = ( + datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1) + ).strftime("%Y-%m-%d") + q = session.query(Duty).filter( + Duty.user_id == user_id, + Duty.start_at < to_next, + Duty.end_at >= from_date, ) count = q.count() q.delete(synchronize_session=False) @@ -89,7 +88,9 @@ def get_duties( in UTC (ISO 8601 with Z). Use to_date_next so duties starting on to_date are included (start_at like 2025-01-31T09:00:00Z is > "2025-01-31" lexicographically). """ - to_date_next = (datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)).strftime("%Y-%m-%d") + to_date_next = ( + datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1) + ).strftime("%Y-%m-%d") q = ( session.query(Duty, User.full_name) .join(User, Duty.user_id == User.id) @@ -119,12 +120,11 @@ def insert_duty( return duty -def get_current_duty( - session: Session, at_utc: datetime -) -> tuple[Duty, User] | None: +def get_current_duty(session: Session, at_utc: datetime) -> tuple[Duty, User] | None: """Return the duty (and user) for which start_at <= at_utc < end_at, event_type='duty'. at_utc is in UTC (naive or aware); comparison uses ISO strings.""" from datetime import timezone + if at_utc.tzinfo is not None: at_utc = at_utc.astimezone(timezone.utc) now_iso = at_utc.strftime("%Y-%m-%dT%H:%M:%S") + "Z" @@ -147,6 +147,7 @@ def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None """Return the end_at of the current duty (if after_utc is inside one) or of the next duty. For scheduling the next pin update. Returns naive UTC datetime.""" from datetime import timezone + if after_utc.tzinfo is not None: after_utc = after_utc.astimezone(timezone.utc) after_iso = after_utc.strftime("%Y-%m-%dT%H:%M:%S") + "Z" @@ -161,7 +162,9 @@ def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None .first() ) if current: - return datetime.fromisoformat(current.end_at.replace("Z", "+00:00")).replace(tzinfo=None) + return datetime.fromisoformat(current.end_at.replace("Z", "+00:00")).replace( + tzinfo=None + ) # Next future duty: start_at > after_iso, order by start_at next_duty = ( session.query(Duty) @@ -170,7 +173,9 @@ def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None .first() ) if next_duty: - return datetime.fromisoformat(next_duty.end_at.replace("Z", "+00:00")).replace(tzinfo=None) + return datetime.fromisoformat(next_duty.end_at.replace("Z", "+00:00")).replace( + tzinfo=None + ) return None @@ -206,7 +211,9 @@ def get_all_group_duty_pin_chat_ids(session: Session) -> list[int]: return [r[0] for r in rows] -def set_user_phone(session: Session, telegram_user_id: int, phone: str | None) -> User | None: +def set_user_phone( + session: Session, telegram_user_id: int, phone: str | None +) -> User | None: """Set phone for user by telegram_user_id. Returns User or None if not found.""" user = session.query(User).filter(User.telegram_user_id == telegram_user_id).first() if not user: diff --git a/db/session.py b/db/session.py index 5954a32..884edaf 100644 --- a/db/session.py +++ b/db/session.py @@ -1,9 +1,11 @@ """SQLAlchemy engine and session factory. -Note: Engine and session factory are cached globally per process. Only one -DATABASE_URL is effectively used for the process lifetime. Using a different -URL later (e.g. in tests with in-memory SQLite) would still use the first -engine. To support multiple URLs, cache by database_url (e.g. a dict keyed by URL). +Engine and session factory are cached globally per process. Only one DATABASE_URL +is effectively used for the process lifetime. Using a different URL later (e.g. in +tests with in-memory SQLite) would still use the first engine. To use a different +URL in tests, set env (e.g. DATABASE_URL) before the first import of this module, or +clear _engine and _SessionLocal in test fixtures. Prefer session_scope() for all +callers so sessions are always closed and rolled back on error. """ from contextlib import contextmanager diff --git a/handlers/commands.py b/handlers/commands.py index bdafe2e..2553ce4 100644 --- a/handlers/commands.py +++ b/handlers/commands.py @@ -6,8 +6,9 @@ import config from telegram import Update from telegram.ext import CommandHandler, ContextTypes -from db.session import get_session +from db.session import session_scope from db.repository import get_or_create_user, set_user_phone +from utils.user import build_full_name async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -16,18 +17,14 @@ async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: user = update.effective_user if not user: return - full_name = ( - " ".join(filter(None, [user.first_name or "", user.last_name or ""])).strip() - or "User" - ) + full_name = build_full_name(user.first_name, user.last_name) telegram_user_id = user.id username = user.username first_name = user.first_name last_name = user.last_name def do_get_or_create() -> None: - session = get_session(config.DATABASE_URL) - try: + with session_scope(config.DATABASE_URL) as session: get_or_create_user( session, telegram_user_id=telegram_user_id, @@ -36,8 +33,6 @@ async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: first_name=first_name, last_name=last_name, ) - finally: - session.close() await asyncio.get_running_loop().run_in_executor(None, do_get_or_create) @@ -53,24 +48,14 @@ async def set_phone(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: await update.message.reply_text("Команда /set_phone доступна только в личке.") return # Optional: restrict to allowed usernames; plan says "or without restrictions" - args = (context.args or []) + args = context.args or [] phone = " ".join(args).strip() if args else None telegram_user_id = update.effective_user.id def do_set_phone() -> str: - session = get_session(config.DATABASE_URL) - try: - full_name = ( - " ".join( - filter( - None, - [ - update.effective_user.first_name or "", - update.effective_user.last_name or "", - ], - ) - ).strip() - or "User" + with session_scope(config.DATABASE_URL) as session: + full_name = build_full_name( + update.effective_user.first_name, update.effective_user.last_name ) get_or_create_user( session, @@ -86,8 +71,6 @@ async def set_phone(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: if phone: return f"Телефон сохранён: {phone}" return "Телефон очищен." - finally: - session.close() result = await asyncio.get_running_loop().run_in_executor(None, do_set_phone) await update.message.reply_text(result) diff --git a/handlers/group_duty_pin.py b/handlers/group_duty_pin.py index c8b1951..b411114 100644 --- a/handlers/group_duty_pin.py +++ b/handlers/group_duty_pin.py @@ -3,7 +3,6 @@ import asyncio import logging from datetime import datetime, timezone -from zoneinfo import ZoneInfo import config from telegram import Update @@ -11,14 +10,14 @@ from telegram.constants import ChatMemberStatus from telegram.error import BadRequest, Forbidden from telegram.ext import ChatMemberHandler, CommandHandler, ContextTypes -from db.session import get_session -from db.repository import ( - get_current_duty, - get_next_shift_end, - get_group_duty_pin, - save_group_duty_pin, - delete_group_duty_pin, - get_all_group_duty_pin_chat_ids, +from db.session import session_scope +from services.group_duty_pin_service import ( + get_duty_message_text, + get_next_shift_end_utc, + save_pin, + delete_pin, + get_message_id, + get_all_pin_chat_ids, ) logger = logging.getLogger(__name__) @@ -27,83 +26,31 @@ JOB_NAME_PREFIX = "duty_pin_" RETRY_WHEN_NO_DUTY_MINUTES = 15 -def _format_duty_message(duty, user, tz_name: str) -> str: - """Build the text for the pinned message. duty, user may be None.""" - if duty is None or user is None: - return "Сейчас дежурства нет." - try: - tz = ZoneInfo(tz_name) - except Exception: - tz = ZoneInfo("Europe/Moscow") - tz_name = "Europe/Moscow" - start_dt = datetime.fromisoformat(duty.start_at.replace("Z", "+00:00")) - end_dt = datetime.fromisoformat(duty.end_at.replace("Z", "+00:00")) - start_local = start_dt.astimezone(tz) - end_local = end_dt.astimezone(tz) - # Показать смещение (UTC+3) чтобы было понятно, в каком поясе время - offset_sec = start_local.utcoffset().total_seconds() if start_local.utcoffset() else 0 - sign = "+" if offset_sec >= 0 else "-" - h, r = divmod(abs(int(offset_sec)), 3600) - m = r // 60 - tz_hint = f"UTC{sign}{h:d}:{m:02d}, {tz_name}" - time_range = f"{start_local.strftime('%d.%m.%Y %H:%M')} — {end_local.strftime('%d.%m.%Y %H:%M')} ({tz_hint})" - lines = [ - f"🕐 Дежурство: {time_range}", - f"👤 {user.full_name}", - ] - if user.phone: - lines.append(f"📞 {user.phone}") - if user.username: - lines.append(f"@{user.username}") - return "\n".join(lines) +def _get_duty_message_text_sync() -> str: + """Get current duty message (sync, for run_in_executor).""" + with session_scope(config.DATABASE_URL) as session: + return get_duty_message_text(session, config.DUTY_DISPLAY_TZ) -def _get_duty_message_text() -> str: - """Get current duty from DB and return formatted message (sync, for run_in_executor).""" - session = get_session(config.DATABASE_URL) - try: - now = datetime.now(timezone.utc) - result = get_current_duty(session, now) - if result is None: - return "Сейчас дежурства нет." - duty, user = result - return _format_duty_message(duty, user, config.DUTY_DISPLAY_TZ) - finally: - session.close() - - -def _get_next_shift_end_utc(): - """Return next shift end as naive UTC datetime for job scheduling (sync).""" - session = get_session(config.DATABASE_URL) - try: - return get_next_shift_end(session, datetime.now(timezone.utc)) - finally: - session.close() +def _get_next_shift_end_sync(): + """Return next shift end as naive UTC (sync, for run_in_executor).""" + with session_scope(config.DATABASE_URL) as session: + return get_next_shift_end_utc(session) def _sync_save_pin(chat_id: int, message_id: int) -> None: - session = get_session(config.DATABASE_URL) - try: - save_group_duty_pin(session, chat_id, message_id) - finally: - session.close() + with session_scope(config.DATABASE_URL) as session: + save_pin(session, chat_id, message_id) def _sync_delete_pin(chat_id: int) -> None: - session = get_session(config.DATABASE_URL) - try: - delete_group_duty_pin(session, chat_id) - finally: - session.close() + with session_scope(config.DATABASE_URL) as session: + delete_pin(session, chat_id) def _sync_get_message_id(chat_id: int) -> int | None: - session = get_session(config.DATABASE_URL) - try: - pin = get_group_duty_pin(session, chat_id) - return pin.message_id if pin else None - finally: - session.close() + with session_scope(config.DATABASE_URL) as session: + return get_message_id(session, chat_id) async def _schedule_next_update( @@ -131,6 +78,7 @@ async def _schedule_next_update( logger.info("Scheduled pin update for chat_id=%s at %s", chat_id, when_utc) else: from datetime import timedelta + job_queue.run_once( update_group_pin, when=timedelta(minutes=RETRY_WHEN_NO_DUTY_MINUTES), @@ -154,7 +102,7 @@ async def update_group_pin(context: ContextTypes.DEFAULT_TYPE) -> None: if message_id is None: logger.info("No pin record for chat_id=%s, skipping update", chat_id) return - text = await loop.run_in_executor(None, _get_duty_message_text) + text = await loop.run_in_executor(None, _get_duty_message_text_sync) try: await context.bot.edit_message_text( chat_id=chat_id, @@ -163,11 +111,13 @@ async def update_group_pin(context: ContextTypes.DEFAULT_TYPE) -> None: ) except (BadRequest, Forbidden) as e: logger.warning("Failed to edit pinned message chat_id=%s: %s", chat_id, e) - next_end = await loop.run_in_executor(None, _get_next_shift_end_utc) + next_end = await loop.run_in_executor(None, _get_next_shift_end_sync) await _schedule_next_update(context.application, chat_id, next_end) -async def my_chat_member_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: +async def my_chat_member_handler( + update: Update, context: ContextTypes.DEFAULT_TYPE +) -> None: """On bot added to group: send, pin, save, schedule. On removed: delete pin, cancel job.""" if not update.my_chat_member or not update.effective_user: return @@ -181,12 +131,15 @@ async def my_chat_member_handler(update: Update, context: ContextTypes.DEFAULT_T chat_id = chat.id # Bot added to group - if new.status in (ChatMemberStatus.MEMBER, ChatMemberStatus.ADMINISTRATOR) and old.status in ( + if new.status in ( + ChatMemberStatus.MEMBER, + ChatMemberStatus.ADMINISTRATOR, + ) and old.status in ( ChatMemberStatus.LEFT, ChatMemberStatus.BANNED, ): loop = asyncio.get_running_loop() - text = await loop.run_in_executor(None, _get_duty_message_text) + text = await loop.run_in_executor(None, _get_duty_message_text_sync) try: msg = await context.bot.send_message(chat_id=chat_id, text=text) except (BadRequest, Forbidden) as e: @@ -214,13 +167,15 @@ async def my_chat_member_handler(update: Update, context: ContextTypes.DEFAULT_T ) except (BadRequest, Forbidden): pass - next_end = await loop.run_in_executor(None, _get_next_shift_end_utc) + next_end = await loop.run_in_executor(None, _get_next_shift_end_sync) await _schedule_next_update(context.application, chat_id, next_end) return # Bot removed from group if new.status in (ChatMemberStatus.LEFT, ChatMemberStatus.BANNED): - await asyncio.get_running_loop().run_in_executor(None, _sync_delete_pin, chat_id) + await asyncio.get_running_loop().run_in_executor( + None, _sync_delete_pin, chat_id + ) name = f"{JOB_NAME_PREFIX}{chat_id}" if context.application.job_queue: for job in context.application.job_queue.get_jobs_by_name(name): @@ -229,11 +184,8 @@ async def my_chat_member_handler(update: Update, context: ContextTypes.DEFAULT_T def _get_all_pin_chat_ids_sync() -> list[int]: - session = get_session(config.DATABASE_URL) - try: - return get_all_group_duty_pin_chat_ids(session) - finally: - session.close() + with session_scope(config.DATABASE_URL) as session: + return get_all_pin_chat_ids(session) async def restore_group_pin_jobs(application) -> None: @@ -241,7 +193,7 @@ async def restore_group_pin_jobs(application) -> None: loop = asyncio.get_running_loop() chat_ids = await loop.run_in_executor(None, _get_all_pin_chat_ids_sync) for chat_id in chat_ids: - next_end = await loop.run_in_executor(None, _get_next_shift_end_utc) + next_end = await loop.run_in_executor(None, _get_next_shift_end_sync) await _schedule_next_update(application, chat_id, next_end) logger.info("Restored %s group pin jobs", len(chat_ids)) @@ -258,7 +210,9 @@ async def pin_duty_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No loop = asyncio.get_running_loop() message_id = await loop.run_in_executor(None, _sync_get_message_id, chat_id) if message_id is None: - await update.message.reply_text("В этом чате ещё нет сообщения о дежурстве. Добавьте бота в группу — оно создастся автоматически.") + await update.message.reply_text( + "В этом чате ещё нет сообщения о дежурстве. Добавьте бота в группу — оно создастся автоматически." + ) return try: await context.bot.pin_chat_message( diff --git a/handlers/import_duty_schedule.py b/handlers/import_duty_schedule.py index a0e52b3..906af97 100644 --- a/handlers/import_duty_schedule.py +++ b/handlers/import_duty_schedule.py @@ -1,93 +1,20 @@ """Import duty-schedule: /import_duty_schedule (admin only). Two steps: handover time -> JSON file.""" import asyncio -import re -from datetime import date, datetime, timedelta, timezone import config from telegram import Update from telegram.ext import CommandHandler, ContextTypes, MessageHandler, filters -from db.session import get_session -from db.repository import ( - get_or_create_user_by_full_name, - delete_duties_in_range, - insert_duty, -) -from importers.duty_schedule import ( - DutyScheduleParseError, - DutyScheduleResult, - parse_duty_schedule, -) - -# HH:MM or HH:MM:SS, optional space + timezone (IANA or "UTC") -HANDOVER_TIME_RE = re.compile( - r"^\s*(\d{1,2}):(\d{2})(?::(\d{2}))?\s*(?:\s+(\S+))?\s*$", re.IGNORECASE -) +from db.session import session_scope +from importers.duty_schedule import DutyScheduleParseError, parse_duty_schedule +from services.import_service import run_import +from utils.handover import parse_handover_time -def _parse_handover_time(text: str) -> tuple[int, int] | None: - """Parse handover time string to (hour_utc, minute_utc). Returns None on failure.""" - m = HANDOVER_TIME_RE.match(text) - if not m: - return None - hour = int(m.group(1)) - minute = int(m.group(2)) - # second = m.group(3) ignored - tz_str = (m.group(4) or "").strip() - if not tz_str or tz_str.upper() == "UTC": - return (hour % 24, minute) - try: - from zoneinfo import ZoneInfo - except ImportError: - try: - from backports.zoneinfo import ZoneInfo # type: ignore - except ImportError: - return None - try: - tz = ZoneInfo(tz_str) - except Exception: - return None - # Build datetime in that tz and convert to UTC - dt = datetime(2000, 1, 1, hour, minute, 0, tzinfo=tz) - utc = dt.astimezone(timezone.utc) - return (utc.hour, utc.minute) - - -def _duty_to_iso(d: date, hour_utc: int, minute_utc: int) -> str: - """ISO 8601 with Z for start of duty on date d at given UTC time.""" - dt = datetime(d.year, d.month, d.day, hour_utc, minute_utc, 0, tzinfo=timezone.utc) - return dt.strftime("%Y-%m-%dT%H:%M:%SZ") - - -def _day_start_iso(d: date) -> str: - """ISO 8601 start of calendar day UTC: YYYY-MM-DDT00:00:00Z.""" - return d.isoformat() + "T00:00:00Z" - - -def _day_end_iso(d: date) -> str: - """ISO 8601 end of calendar day UTC: YYYY-MM-DDT23:59:59Z.""" - return d.isoformat() + "T23:59:59Z" - - -def _consecutive_date_ranges(dates: list[date]) -> list[tuple[date, date]]: - """Sort dates and merge consecutive ones into (first, last) ranges. Empty list -> [].""" - if not dates: - return [] - sorted_dates = sorted(set(dates)) - ranges: list[tuple[date, date]] = [] - start_d = end_d = sorted_dates[0] - for d in sorted_dates[1:]: - if (d - end_d).days == 1: - end_d = d - else: - ranges.append((start_d, end_d)) - start_d = end_d = d - ranges.append((start_d, end_d)) - return ranges - - -async def import_duty_schedule_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: +async def import_duty_schedule_cmd( + update: Update, context: ContextTypes.DEFAULT_TYPE +) -> None: if not update.message or not update.effective_user: return if not config.is_admin(update.effective_user.username or ""): @@ -100,7 +27,9 @@ async def import_duty_schedule_cmd(update: Update, context: ContextTypes.DEFAULT ) -async def handle_handover_time_text(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: +async def handle_handover_time_text( + update: Update, context: ContextTypes.DEFAULT_TYPE +) -> None: if not update.message or not update.effective_user or not update.message.text: return if not context.user_data.get("awaiting_handover_time"): @@ -108,7 +37,7 @@ async def handle_handover_time_text(update: Update, context: ContextTypes.DEFAUL if not config.is_admin(update.effective_user.username or ""): return text = update.message.text.strip() - parsed = _parse_handover_time(text) + parsed = parse_handover_time(text) if parsed is None: await update.message.reply_text( "Не удалось разобрать время. Укажите, например: 09:00 Europe/Moscow" @@ -118,56 +47,12 @@ async def handle_handover_time_text(update: Update, context: ContextTypes.DEFAUL context.user_data["handover_utc_time"] = (hour_utc, minute_utc) context.user_data["awaiting_handover_time"] = False context.user_data["awaiting_duty_schedule_file"] = True - await update.message.reply_text( - "Отправьте файл в формате duty-schedule (JSON)." - ) + await update.message.reply_text("Отправьте файл в формате duty-schedule (JSON).") -def _run_import( - database_url: str, - result: DutyScheduleResult, - hour_utc: int, - minute_utc: int, -) -> tuple[int, int, int, int]: - """Returns (num_users, num_duty, num_unavailable, num_vacation).""" - session = get_session(database_url) - try: - from_date_str = result.start_date.isoformat() - to_date_str = result.end_date.isoformat() - num_duty = num_unavailable = num_vacation = 0 - for entry in result.entries: - user = get_or_create_user_by_full_name(session, entry.full_name) - delete_duties_in_range(session, user.id, from_date_str, to_date_str) - for d in entry.duty_dates: - start_at = _duty_to_iso(d, hour_utc, minute_utc) - d_next = d + timedelta(days=1) - end_at = _duty_to_iso(d_next, hour_utc, minute_utc) - insert_duty(session, user.id, start_at, end_at, event_type="duty") - num_duty += 1 - for d in entry.unavailable_dates: - insert_duty( - session, - user.id, - _day_start_iso(d), - _day_end_iso(d), - event_type="unavailable", - ) - num_unavailable += 1 - for start_d, end_d in _consecutive_date_ranges(entry.vacation_dates): - insert_duty( - session, - user.id, - _day_start_iso(start_d), - _day_end_iso(end_d), - event_type="vacation", - ) - num_vacation += 1 - return (len(result.entries), num_duty, num_unavailable, num_vacation) - finally: - session.close() - - -async def handle_duty_schedule_document(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: +async def handle_duty_schedule_document( + update: Update, context: ContextTypes.DEFAULT_TYPE +) -> None: if not update.message or not update.message.document or not update.effective_user: return if not context.user_data.get("awaiting_duty_schedule_file"): @@ -193,11 +78,14 @@ async def handle_duty_schedule_document(update: Update, context: ContextTypes.DE await update.message.reply_text(f"Ошибка разбора файла: {e}") return + def run_import_with_scope(): + with session_scope(config.DATABASE_URL) as session: + return run_import(session, result, hour_utc, minute_utc) + loop = asyncio.get_running_loop() try: num_users, num_duty, num_unavailable, num_vacation = await loop.run_in_executor( - None, - lambda: _run_import(config.DATABASE_URL, result, hour_utc, minute_utc), + None, run_import_with_scope ) except Exception as e: await update.message.reply_text(f"Ошибка импорта: {e}") @@ -216,7 +104,9 @@ async def handle_duty_schedule_document(update: Update, context: ContextTypes.DE context.user_data.pop("handover_utc_time", None) -import_duty_schedule_handler = CommandHandler("import_duty_schedule", import_duty_schedule_cmd) +import_duty_schedule_handler = CommandHandler( + "import_duty_schedule", import_duty_schedule_cmd +) handover_time_handler = MessageHandler( filters.TEXT & ~filters.COMMAND, handle_handover_time_text, diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d7ea13f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "duty-teller" +version = "0.1.0" +description = "Telegram bot for team duty shift calendar and group reminder" +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "python-telegram-bot[job-queue]>=22.0,<23.0", + "python-dotenv>=1.0,<2.0", + "fastapi>=0.115,<1.0", + "uvicorn[standard]>=0.32,<1.0", + "sqlalchemy>=2.0,<3.0", + "alembic>=1.14,<2.0", + "pydantic>=2.0,<3.0", + "icalendar>=5.0,<6.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0,<9.0", + "pytest-asyncio>=0.24,<1.0", + "httpx>=0.27,<1.0", +] + +[tool.setuptools.packages.find] +where = ["."] +include = ["db", "handlers", "api", "importers", "utils", "services"] + +[tool.black] +line-length = 120 +target-version = ["py311"] + +[tool.pylint.messages_control] +disable = ["C0114", "C0115", "C0116"] diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..533b18e --- /dev/null +++ b/services/__init__.py @@ -0,0 +1,27 @@ +"""Service layer: business logic and orchestration. + +Services accept a DB session from the caller (handlers open session_scope and pass session). +No Telegram or HTTP dependencies; repository handles persistence. +""" + +from services.group_duty_pin_service import ( + format_duty_message, + get_duty_message_text, + get_next_shift_end_utc, + save_pin, + delete_pin, + get_message_id, + get_all_pin_chat_ids, +) +from services.import_service import run_import + +__all__ = [ + "format_duty_message", + "get_duty_message_text", + "get_next_shift_end_utc", + "save_pin", + "delete_pin", + "get_message_id", + "get_all_pin_chat_ids", + "run_import", +] diff --git a/services/group_duty_pin_service.py b/services/group_duty_pin_service.py new file mode 100644 index 0000000..c3c9d37 --- /dev/null +++ b/services/group_duty_pin_service.py @@ -0,0 +1,86 @@ +"""Group duty pin: current duty message text, next shift end, pin CRUD. All accept session.""" + +from datetime import datetime, timezone +from zoneinfo import ZoneInfo + +from sqlalchemy.orm import Session + +from db.repository import ( + get_current_duty, + get_next_shift_end, + get_group_duty_pin, + save_group_duty_pin, + delete_group_duty_pin, + get_all_group_duty_pin_chat_ids, +) + + +def format_duty_message(duty, user, tz_name: str) -> str: + """Build the text for the pinned message. duty, user may be None.""" + if duty is None or user is None: + return "Сейчас дежурства нет." + try: + tz = ZoneInfo(tz_name) + except Exception: + tz = ZoneInfo("Europe/Moscow") + tz_name = "Europe/Moscow" + start_dt = datetime.fromisoformat(duty.start_at.replace("Z", "+00:00")) + end_dt = datetime.fromisoformat(duty.end_at.replace("Z", "+00:00")) + start_local = start_dt.astimezone(tz) + end_local = end_dt.astimezone(tz) + offset_sec = ( + start_local.utcoffset().total_seconds() if start_local.utcoffset() else 0 + ) + sign = "+" if offset_sec >= 0 else "-" + h, r = divmod(abs(int(offset_sec)), 3600) + m = r // 60 + tz_hint = f"UTC{sign}{h:d}:{m:02d}, {tz_name}" + time_range = ( + f"{start_local.strftime('%d.%m.%Y %H:%M')} — " + f"{end_local.strftime('%d.%m.%Y %H:%M')} ({tz_hint})" + ) + lines = [ + f"🕐 Дежурство: {time_range}", + f"👤 {user.full_name}", + ] + if user.phone: + lines.append(f"📞 {user.phone}") + if user.username: + lines.append(f"@{user.username}") + return "\n".join(lines) + + +def get_duty_message_text(session: Session, tz_name: str) -> str: + """Get current duty from DB and return formatted message.""" + now = datetime.now(timezone.utc) + result = get_current_duty(session, now) + if result is None: + return "Сейчас дежурства нет." + duty, user = result + return format_duty_message(duty, user, tz_name) + + +def get_next_shift_end_utc(session: Session) -> datetime | None: + """Return next shift end as naive UTC datetime for job scheduling.""" + return get_next_shift_end(session, datetime.now(timezone.utc)) + + +def save_pin(session: Session, chat_id: int, message_id: int) -> None: + """Save or update the pinned message record for a chat.""" + save_group_duty_pin(session, chat_id, message_id) + + +def delete_pin(session: Session, chat_id: int) -> None: + """Remove the pinned message record when the bot leaves the group.""" + delete_group_duty_pin(session, chat_id) + + +def get_message_id(session: Session, chat_id: int) -> int | None: + """Return message_id for the pin in this chat, or None.""" + pin = get_group_duty_pin(session, chat_id) + return pin.message_id if pin else None + + +def get_all_pin_chat_ids(session: Session) -> list[int]: + """Return all chat_ids that have a pinned duty message (for restoring jobs on startup).""" + return get_all_group_duty_pin_chat_ids(session) diff --git a/services/import_service.py b/services/import_service.py new file mode 100644 index 0000000..f6d6af5 --- /dev/null +++ b/services/import_service.py @@ -0,0 +1,70 @@ +"""Import duty schedule: delete range, insert duties/unavailable/vacation. Accepts session.""" + +from datetime import date, timedelta + +from sqlalchemy.orm import Session + +from db.repository import ( + get_or_create_user_by_full_name, + delete_duties_in_range, + insert_duty, +) +from importers.duty_schedule import DutyScheduleResult +from utils.dates import day_start_iso, day_end_iso, duty_to_iso + + +def _consecutive_date_ranges(dates: list[date]) -> list[tuple[date, date]]: + """Sort dates and merge consecutive ones into (first, last) ranges. Empty list -> [].""" + if not dates: + return [] + sorted_dates = sorted(set(dates)) + ranges: list[tuple[date, date]] = [] + start_d = end_d = sorted_dates[0] + for d in sorted_dates[1:]: + if (d - end_d).days == 1: + end_d = d + else: + ranges.append((start_d, end_d)) + start_d = end_d = d + ranges.append((start_d, end_d)) + return ranges + + +def run_import( + session: Session, + result: DutyScheduleResult, + hour_utc: int, + minute_utc: int, +) -> tuple[int, int, int, int]: + """Run import: delete range per user, insert duty/unavailable/vacation. Returns (num_users, num_duty, num_unavailable, num_vacation).""" + from_date_str = result.start_date.isoformat() + to_date_str = result.end_date.isoformat() + num_duty = num_unavailable = num_vacation = 0 + for entry in result.entries: + user = get_or_create_user_by_full_name(session, entry.full_name) + delete_duties_in_range(session, user.id, from_date_str, to_date_str) + for d in entry.duty_dates: + start_at = duty_to_iso(d, hour_utc, minute_utc) + d_next = d + timedelta(days=1) + end_at = duty_to_iso(d_next, hour_utc, minute_utc) + insert_duty(session, user.id, start_at, end_at, event_type="duty") + num_duty += 1 + for d in entry.unavailable_dates: + insert_duty( + session, + user.id, + day_start_iso(d), + day_end_iso(d), + event_type="unavailable", + ) + num_unavailable += 1 + for start_d, end_d in _consecutive_date_ranges(entry.vacation_dates): + insert_duty( + session, + user.id, + day_start_iso(start_d), + day_end_iso(end_d), + event_type="vacation", + ) + num_vacation += 1 + return (len(result.entries), num_duty, num_unavailable, num_vacation) diff --git a/tests/test_app.py b/tests/test_app.py index 6fd15b7..b669ddc 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,7 +1,7 @@ """Tests for FastAPI app /api/duties.""" import time -from unittest.mock import patch +from unittest.mock import ANY, patch import pytest from fastapi.testclient import TestClient @@ -50,7 +50,7 @@ def test_duties_200_when_skip_auth(mock_fetch, client): ) assert r.status_code == 200 assert r.json() == [] - mock_fetch.assert_called_once_with("2025-01-01", "2025-01-31") + mock_fetch.assert_called_once_with(ANY, "2025-01-01", "2025-01-31") @patch("api.app.validate_init_data_with_reason") @@ -105,7 +105,7 @@ def test_duties_200_with_allowed_user(mock_can_access, mock_validate, client): assert r.status_code == 200 assert len(r.json()) == 1 assert r.json()[0]["full_name"] == "Иван Иванов" - mock_fetch.assert_called_once_with("2025-01-01", "2025-01-31") + mock_fetch.assert_called_once_with(ANY, "2025-01-01", "2025-01-31") def test_duties_e2e_auth_real_validation(client, monkeypatch): @@ -130,7 +130,7 @@ def test_duties_e2e_auth_real_validation(client, monkeypatch): ) assert r.status_code == 200 assert r.json() == [] - mock_fetch.assert_called_once_with("2025-01-01", "2025-01-31") + mock_fetch.assert_called_once_with(ANY, "2025-01-01", "2025-01-31") @patch("api.app.config.MINI_APP_SKIP_AUTH", True) diff --git a/tests/test_duty_schedule_parser.py b/tests/test_duty_schedule_parser.py index 8fc64a7..03241ae 100644 --- a/tests/test_duty_schedule_parser.py +++ b/tests/test_duty_schedule_parser.py @@ -9,7 +9,6 @@ from importers.duty_schedule import ( UNAVAILABLE_MARKER, VACATION_MARKER, DutyScheduleParseError, - DutyScheduleEntry, parse_duty_schedule, ) @@ -32,7 +31,11 @@ def test_parse_valid_schedule(): assert "Petrov P.P." in by_name # Ivanov: only duty (б, Б, в) -> 2026-02-18, 19, 20 ivan = by_name["Ivanov I.I."] - assert sorted(ivan.duty_dates) == [date(2026, 2, 18), date(2026, 2, 19), date(2026, 2, 20)] + assert sorted(ivan.duty_dates) == [ + date(2026, 2, 18), + date(2026, 2, 19), + date(2026, 2, 20), + ] assert ivan.unavailable_dates == [] assert ivan.vacation_dates == [] # Petrov: one Н (unavailable), one О (vacation) -> 2026-02-17, 18 diff --git a/tests/test_import_duty_schedule_integration.py b/tests/test_import_duty_schedule_integration.py index 3f105b8..641b32a 100644 --- a/tests/test_import_duty_schedule_integration.py +++ b/tests/test_import_duty_schedule_integration.py @@ -6,10 +6,13 @@ import pytest from db import init_db from db.repository import get_duties -from db.session import get_session -from importers.duty_schedule import DutyScheduleEntry, DutyScheduleResult, parse_duty_schedule - -from handlers.import_duty_schedule import _run_import +from db.session import get_session, session_scope +from importers.duty_schedule import ( + DutyScheduleEntry, + DutyScheduleResult, + parse_duty_schedule, +) +from services.import_service import run_import @pytest.fixture @@ -21,6 +24,7 @@ def db_url(): def _reset_db_session(db_url): """Ensure each test uses a fresh engine for :memory: (clear global cache for test URL).""" import db.session as session_module + session_module._engine = None session_module._SessionLocal = None init_db(db_url) @@ -49,7 +53,8 @@ def test_import_creates_users_and_duties(db_url): ), ], ) - num_users, num_duty, num_unav, num_vac = _run_import(db_url, result, 6, 0) + with session_scope(db_url) as session: + num_users, num_duty, num_unav, num_vac = run_import(session, result, 6, 0) assert num_users == 2 assert num_duty == 3 assert num_unav == 0 @@ -84,7 +89,8 @@ def test_import_replaces_duties_in_range(db_url): ) ], ) - _run_import(db_url, result1, 9, 0) + with session_scope(db_url) as session: + run_import(session, result1, 9, 0) session = get_session(db_url) try: @@ -105,7 +111,8 @@ def test_import_replaces_duties_in_range(db_url): ) ], ) - _run_import(db_url, result2, 9, 0) + with session_scope(db_url) as session: + run_import(session, result2, 9, 0) session = get_session(db_url) try: @@ -123,7 +130,8 @@ def test_import_full_flow_parse_then_import(db_url): '"schedule": [{"name": "Alexey A.", "duty": "\u0431; ; \u0432"}]}' ).encode("utf-8") parsed = parse_duty_schedule(raw) - num_users, num_duty, num_unav, num_vac = _run_import(db_url, parsed, 6, 0) + with session_scope(db_url) as session: + num_users, num_duty, num_unav, num_vac = run_import(session, parsed, 6, 0) assert num_users == 1 assert num_duty == 2 assert num_unav == 0 @@ -149,11 +157,16 @@ def test_import_event_types_unavailable_vacation(db_url): full_name="Mixed User", duty_dates=[date(2026, 2, 16)], unavailable_dates=[date(2026, 2, 17)], - vacation_dates=[date(2026, 2, 18), date(2026, 2, 19), date(2026, 2, 20)], + vacation_dates=[ + date(2026, 2, 18), + date(2026, 2, 19), + date(2026, 2, 20), + ], ), ], ) - num_users, num_duty, num_unav, num_vac = _run_import(db_url, result, 6, 0) + with session_scope(db_url) as session: + num_users, num_duty, num_unav, num_vac = run_import(session, result, 6, 0) assert num_users == 1 assert num_duty == 1 and num_unav == 1 and num_vac == 1 @@ -185,11 +198,16 @@ def test_import_vacation_with_gap_two_periods(db_url): full_name="Vacation User", duty_dates=[], unavailable_dates=[], - vacation_dates=[date(2026, 2, 17), date(2026, 2, 18), date(2026, 2, 20)], + vacation_dates=[ + date(2026, 2, 17), + date(2026, 2, 18), + date(2026, 2, 20), + ], ), ], ) - num_users, num_duty, num_unav, num_vac = _run_import(db_url, result, 6, 0) + with session_scope(db_url) as session: + num_users, num_duty, num_unav, num_vac = run_import(session, result, 6, 0) assert num_users == 1 assert num_duty == 0 and num_unav == 0 and num_vac == 2 diff --git a/tests/test_repository_duty_range.py b/tests/test_repository_duty_range.py index 9d18fcb..5904707 100644 --- a/tests/test_repository_duty_range.py +++ b/tests/test_repository_duty_range.py @@ -4,7 +4,7 @@ import pytest from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from db.models import Base, User, Duty +from db.models import Base, User from db.repository import ( delete_duties_in_range, get_or_create_user_by_full_name, @@ -15,7 +15,9 @@ from db.repository import ( @pytest.fixture def session(): - engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False}) + engine = create_engine( + "sqlite:///:memory:", connect_args={"check_same_thread": False} + ) Base.metadata.create_all(engine) Session = sessionmaker(bind=engine, autocommit=False, autoflush=False) s = Session() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..bf8ca0a --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,99 @@ +"""Unit tests for utils (dates, user, handover).""" + +from datetime import date + +import pytest + +from utils.dates import ( + day_start_iso, + day_end_iso, + duty_to_iso, + parse_iso_date, + validate_date_range, +) +from utils.user import build_full_name +from utils.handover import parse_handover_time + + +# --- dates --- + + +def test_day_start_iso(): + assert day_start_iso(date(2026, 2, 18)) == "2026-02-18T00:00:00Z" + + +def test_day_end_iso(): + assert day_end_iso(date(2026, 2, 18)) == "2026-02-18T23:59:59Z" + + +def test_duty_to_iso(): + assert duty_to_iso(date(2026, 2, 18), 6, 0) == "2026-02-18T06:00:00Z" + + +def test_parse_iso_date_valid(): + assert parse_iso_date("2026-02-18") == date(2026, 2, 18) + assert parse_iso_date(" 2026-02-18 ") == date(2026, 2, 18) + + +def test_parse_iso_date_invalid(): + assert parse_iso_date("") is None + assert parse_iso_date("2026-02-31") is None # invalid day + assert parse_iso_date("18-02-2026") is None + assert parse_iso_date("not-a-date") is None + + +def test_validate_date_range_ok(): + validate_date_range("2025-01-01", "2025-01-31") # no raise + + +def test_validate_date_range_bad_format(): + with pytest.raises(ValueError, match="формате YYYY-MM-DD"): + validate_date_range("01-01-2025", "2025-01-31") + with pytest.raises(ValueError, match="формате YYYY-MM-DD"): + validate_date_range("2025-01-01", "invalid") + + +def test_validate_date_range_from_after_to(): + with pytest.raises(ValueError, match="from не должна быть позже"): + validate_date_range("2025-02-01", "2025-01-01") + + +# --- user --- + + +def test_build_full_name_both(): + assert build_full_name("John", "Doe") == "John Doe" + + +def test_build_full_name_first_only(): + assert build_full_name("John", None) == "John" + + +def test_build_full_name_last_only(): + assert build_full_name(None, "Doe") == "Doe" + + +def test_build_full_name_empty(): + assert build_full_name("", "") == "User" + assert build_full_name(None, None) == "User" + + +# --- handover --- + + +def test_parse_handover_utc(): + assert parse_handover_time("09:00") == (9, 0) + assert parse_handover_time("09:00 UTC") == (9, 0) + assert parse_handover_time(" 06:30 ") == (6, 30) + + +def test_parse_handover_with_tz(): + # Europe/Moscow UTC+3 in winter: 09:00 Moscow = 06:00 UTC + assert parse_handover_time("09:00 Europe/Moscow") == (6, 0) + + +def test_parse_handover_invalid(): + assert parse_handover_time("") is None + assert parse_handover_time("not a time") is None + # 25:00 is normalized to 1:00 by hour % 24; use non-matching string + assert parse_handover_time("12") is None diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..484905f --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,24 @@ +"""Shared utilities: date/ISO helpers, user display names, handover time parsing. + +Used by handlers, API, and services. No DB or Telegram dependencies. +""" + +from utils.dates import ( + day_end_iso, + day_start_iso, + duty_to_iso, + parse_iso_date, + validate_date_range, +) +from utils.user import build_full_name +from utils.handover import parse_handover_time + +__all__ = [ + "day_start_iso", + "day_end_iso", + "duty_to_iso", + "parse_iso_date", + "validate_date_range", + "build_full_name", + "parse_handover_time", +] diff --git a/utils/dates.py b/utils/dates.py new file mode 100644 index 0000000..33dcb1e --- /dev/null +++ b/utils/dates.py @@ -0,0 +1,46 @@ +"""Date and ISO helpers for duty ranges and API validation.""" + +import re +from datetime import date, datetime, timezone + + +def day_start_iso(d: date) -> str: + """ISO 8601 start of calendar day UTC: YYYY-MM-DDT00:00:00Z.""" + return d.isoformat() + "T00:00:00Z" + + +def day_end_iso(d: date) -> str: + """ISO 8601 end of calendar day UTC: YYYY-MM-DDT23:59:59Z.""" + return d.isoformat() + "T23:59:59Z" + + +def duty_to_iso(d: date, hour_utc: int, minute_utc: int) -> str: + """ISO 8601 with Z for start of duty on date d at given UTC time.""" + dt = datetime(d.year, d.month, d.day, hour_utc, minute_utc, 0, tzinfo=timezone.utc) + return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + + +# ISO date YYYY-MM-DD +_ISO_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$") + + +def parse_iso_date(s: str) -> date | None: + """Parse YYYY-MM-DD string to date. Returns None if invalid.""" + if not s or not _ISO_DATE_RE.match(s.strip()): + return None + try: + return date.fromisoformat(s.strip()) + except ValueError: + return None + + +def validate_date_range(from_date: str, to_date: str) -> None: + """Validate from_date and to_date are YYYY-MM-DD and from_date <= to_date. + + Raises: + ValueError: With a user-facing message if invalid. + """ + if not _ISO_DATE_RE.match(from_date or "") or not _ISO_DATE_RE.match(to_date or ""): + raise ValueError("Параметры from и to должны быть в формате YYYY-MM-DD") + if from_date > to_date: + raise ValueError("Дата from не должна быть позже to") diff --git a/utils/handover.py b/utils/handover.py new file mode 100644 index 0000000..d737f8c --- /dev/null +++ b/utils/handover.py @@ -0,0 +1,37 @@ +"""Handover time parsing for duty schedule import.""" + +import re +from datetime import datetime, timezone + +# HH:MM or HH:MM:SS, optional space + timezone (IANA or "UTC") +HANDOVER_TIME_RE = re.compile( + r"^\s*(\d{1,2}):(\d{2})(?::(\d{2}))?\s*(?:\s+(\S+))?\s*$", re.IGNORECASE +) + + +def parse_handover_time(text: str) -> tuple[int, int] | None: + """Parse handover time string to (hour_utc, minute_utc). Returns None on failure.""" + m = HANDOVER_TIME_RE.match(text) + if not m: + return None + hour = int(m.group(1)) + minute = int(m.group(2)) + # second = m.group(3) ignored + tz_str = (m.group(4) or "").strip() + if not tz_str or tz_str.upper() == "UTC": + return (hour % 24, minute) + try: + from zoneinfo import ZoneInfo + except ImportError: + try: + from backports.zoneinfo import ZoneInfo # type: ignore + except ImportError: + return None + try: + tz = ZoneInfo(tz_str) + except Exception: + return None + # Build datetime in that tz and convert to UTC + dt = datetime(2000, 1, 1, hour, minute, 0, tzinfo=tz) + utc = dt.astimezone(timezone.utc) + return (utc.hour, utc.minute) diff --git a/utils/user.py b/utils/user.py new file mode 100644 index 0000000..1ec3238 --- /dev/null +++ b/utils/user.py @@ -0,0 +1,8 @@ +"""User display name helpers.""" + + +def build_full_name(first_name: str | None, last_name: str | None) -> str: + """Build display full name from first and last name. Returns 'User' if both empty.""" + parts = [first_name or "", last_name or ""] + full = " ".join(filter(None, parts)).strip() + return full or "User" diff --git a/webapp/app.js b/webapp/app.js index b4924b5..a135ef5 100644 --- a/webapp/app.js +++ b/webapp/app.js @@ -3,6 +3,47 @@ const RETRY_DELAY_MS = 800; const RETRY_AFTER_ACCESS_DENIED_MS = 1200; + const THEME_BG = { dark: "#1a1b26", light: "#d5d6db" }; + + function getTheme() { + if (typeof window === "undefined") return "dark"; + var twa = window.Telegram?.WebApp; + if (twa?.colorScheme) { + return twa.colorScheme; + } + var cssScheme = ""; + try { + cssScheme = getComputedStyle(document.documentElement).getPropertyValue("--tg-color-scheme").trim(); + } catch (e) {} + if (cssScheme === "light" || cssScheme === "dark") { + return cssScheme; + } + if (window.matchMedia("(prefers-color-scheme: dark)").matches) { + return "dark"; + } + return "light"; + } + + function applyTheme() { + var scheme = getTheme(); + document.documentElement.dataset.theme = scheme; + var bg = THEME_BG[scheme] || THEME_BG.dark; + if (window.Telegram?.WebApp?.setBackgroundColor) { + window.Telegram.WebApp.setBackgroundColor(bg); + } + if (window.Telegram?.WebApp?.setHeaderColor) { + window.Telegram.WebApp.setHeaderColor(bg); + } + } + + applyTheme(); + if (typeof window !== "undefined" && window.Telegram?.WebApp) { + setTimeout(applyTheme, 0); + setTimeout(applyTheme, 100); + } else if (window.matchMedia) { + window.matchMedia("(prefers-color-scheme: dark)").addEventListener("change", applyTheme); + } + const MONTHS = [ "Январь", "Февраль", "Март", "Апрель", "Май", "Июнь", "Июль", "Август", "Сентябрь", "Октябрь", "Ноябрь", "Декабрь" @@ -382,6 +423,31 @@ return key.slice(8, 10) + "." + key.slice(5, 7); } + /** Format ISO date as HH:MM in local time. */ + function formatTimeLocal(isoStr) { + const d = new Date(isoStr); + return String(d.getHours()).padStart(2, "0") + ":" + String(d.getMinutes()).padStart(2, "0"); + } + + /** Build HTML for one timeline duty card: one-day "DD.MM, HH:MM – HH:MM" or multi-day "DD.MM HH:MM – DD.MM HH:MM". */ + function dutyTimelineCardHtml(d, isCurrent) { + const startLocal = localDateString(new Date(d.start_at)); + const endLocal = localDateString(new Date(d.end_at)); + const startDDMM = dateKeyToDDMM(startLocal); + const endDDMM = dateKeyToDDMM(endLocal); + const startTime = formatTimeLocal(d.start_at); + const endTime = formatTimeLocal(d.end_at); + let timeStr; + if (startLocal === endLocal) { + timeStr = startDDMM + ", " + startTime + " – " + endTime; + } else { + timeStr = startDDMM + " " + startTime + " – " + endDDMM + " " + endTime; + } + const typeLabel = isCurrent ? "Сейчас дежурит" : (EVENT_TYPE_LABELS[d.event_type] || "Дежурство"); + const extraClass = isCurrent ? " duty-item--current" : ""; + return "
В этом месяце событий нет.
"; + dutyListEl.classList.remove("duty-timeline"); + dutyListEl.innerHTML = "В этом месяце дежурств нет.
"; return; } - const grouped = {}; - duties.forEach(function (d) { - const date = localDateString(new Date(d.start_at)); - if (!grouped[date]) grouped[date] = []; - grouped[date].push(d); - }); - let dates = Object.keys(grouped).sort(); + dutyListEl.classList.add("duty-timeline"); const todayKey = localDateString(new Date()); const firstKey = localDateString(firstDayOfMonth(current)); const lastKey = localDateString(lastDayOfMonth(current)); const showTodayInMonth = todayKey >= firstKey && todayKey <= lastKey; - if (showTodayInMonth && dates.indexOf(todayKey) === -1) { - dates = [todayKey].concat(dates).sort(); - } - let html = ""; + const dateSet = new Set(); + duties.forEach(function (d) { + dateSet.add(localDateString(new Date(d.start_at))); + }); + if (showTodayInMonth) dateSet.add(todayKey); + let dates = Array.from(dateSet).sort(); + const now = new Date(); + let fullHtml = ""; dates.forEach(function (date) { const isToday = date === todayKey; - const dayBlockClass = "duty-list-day" + (isToday ? " duty-list-day--today" : ""); - const titleText = isToday ? "Сегодня, " + dateKeyToDDMM(date) : dateKeyToDDMM(date); - html += "