Refactor project structure and enhance Docker configuration
- Updated `.dockerignore` to exclude test and development artifacts, optimizing the Docker image size. - Refactored `main.py` to delegate execution to `duty_teller.run.main()`, simplifying the entry point. - Introduced a new `duty_teller` package to encapsulate core functionality, improving modularity and organization. - Enhanced `pyproject.toml` to define a script for running the application, streamlining the execution process. - Updated README documentation to reflect changes in project structure and usage instructions. - Improved Alembic environment configuration to utilize the new package structure for database migrations.
This commit is contained in:
8
duty_teller/__init__.py
Normal file
8
duty_teller/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Duty Teller: Telegram bot for team duty shift calendar and group reminder."""
|
||||
|
||||
from importlib.metadata import version, PackageNotFoundError
|
||||
|
||||
try:
|
||||
__version__ = version("duty-teller")
|
||||
except PackageNotFoundError:
|
||||
__version__ = "0.1.0"
|
||||
1
duty_teller/api/__init__.py
Normal file
1
duty_teller/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# HTTP API for Mini App
|
||||
62
duty_teller/api/app.py
Normal file
62
duty_teller/api/app.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""FastAPI app: /api/duties and static webapp."""
|
||||
|
||||
import logging
|
||||
|
||||
import duty_teller.config as config
|
||||
from fastapi import Depends, FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from duty_teller.api.calendar_ics import get_calendar_events
|
||||
from duty_teller.api.dependencies import (
|
||||
fetch_duties_response,
|
||||
get_db_session,
|
||||
get_validated_dates,
|
||||
require_miniapp_username,
|
||||
)
|
||||
from duty_teller.db.schemas import CalendarEvent, DutyWithUser
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="Duty Teller API")
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=config.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/duties", response_model=list[DutyWithUser])
|
||||
def list_duties(
|
||||
request: Request,
|
||||
dates: tuple[str, str] = Depends(get_validated_dates),
|
||||
_username: str = Depends(require_miniapp_username),
|
||||
session: Session = Depends(get_db_session),
|
||||
):
|
||||
from_date_val, to_date_val = dates
|
||||
log.info(
|
||||
"GET /api/duties from %s",
|
||||
request.client.host if request.client else "?",
|
||||
)
|
||||
return fetch_duties_response(session, from_date_val, to_date_val)
|
||||
|
||||
|
||||
@app.get("/api/calendar-events", response_model=list[CalendarEvent])
|
||||
def list_calendar_events(
|
||||
dates: tuple[str, str] = Depends(get_validated_dates),
|
||||
_username: str = Depends(require_miniapp_username),
|
||||
):
|
||||
from_date_val, to_date_val = dates
|
||||
url = config.EXTERNAL_CALENDAR_ICS_URL
|
||||
if not url:
|
||||
return []
|
||||
events = get_calendar_events(url, from_date=from_date_val, to_date=to_date_val)
|
||||
return [CalendarEvent(date=e["date"], summary=e["summary"]) for e in events]
|
||||
|
||||
|
||||
webapp_path = config.PROJECT_ROOT / "webapp"
|
||||
if webapp_path.is_dir():
|
||||
app.mount("/app", StaticFiles(directory=str(webapp_path), html=True), name="webapp")
|
||||
124
duty_teller/api/calendar_ics.py
Normal file
124
duty_teller/api/calendar_ics.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Fetch and parse external ICS calendar; in-memory cache with 7-day TTL."""
|
||||
|
||||
import logging
|
||||
from datetime import date, datetime, timedelta
|
||||
from urllib.request import Request, urlopen
|
||||
from urllib.error import URLError
|
||||
|
||||
from icalendar import Calendar
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# In-memory cache: url -> (cached_at_timestamp, raw_ics_bytes)
|
||||
_ics_cache: dict[str, tuple[float, bytes]] = {}
|
||||
CACHE_TTL_SECONDS = 7 * 24 * 3600 # 1 week
|
||||
FETCH_TIMEOUT_SECONDS = 15
|
||||
|
||||
|
||||
def _fetch_ics(url: str) -> bytes | None:
|
||||
"""GET url, return response body or None on error."""
|
||||
try:
|
||||
req = Request(url, headers={"User-Agent": "DutyTeller/1.0"})
|
||||
with urlopen(req, timeout=FETCH_TIMEOUT_SECONDS) as resp:
|
||||
return resp.read()
|
||||
except URLError as e:
|
||||
log.warning("Failed to fetch ICS from %s: %s", url, e)
|
||||
return None
|
||||
except OSError as e:
|
||||
log.warning("Error fetching ICS from %s: %s", url, e)
|
||||
return None
|
||||
|
||||
|
||||
def _to_date(dt) -> date | None:
|
||||
"""Convert icalendar DATE or DATE-TIME to date. Return None if invalid."""
|
||||
if isinstance(dt, datetime):
|
||||
return dt.date()
|
||||
if isinstance(dt, date):
|
||||
return dt
|
||||
return None
|
||||
|
||||
|
||||
def _event_date_range(component) -> tuple[date | None, date | None]:
|
||||
"""
|
||||
Get (start_date, end_date) for a VEVENT. DTEND is exclusive in iCalendar;
|
||||
last day of event = DTEND date - 1 day. Returns (None, None) if invalid.
|
||||
"""
|
||||
dtstart = component.get("dtstart")
|
||||
if not dtstart:
|
||||
return (None, None)
|
||||
start_d = _to_date(dtstart.dt)
|
||||
if not start_d:
|
||||
return (None, None)
|
||||
|
||||
dtend = component.get("dtend")
|
||||
if not dtend:
|
||||
return (start_d, start_d)
|
||||
|
||||
end_dt = dtend.dt
|
||||
end_d = _to_date(end_dt)
|
||||
if not end_d:
|
||||
return (start_d, start_d)
|
||||
# DTEND is exclusive: last day of event is end_d - 1 day
|
||||
last_d = end_d - timedelta(days=1)
|
||||
return (start_d, last_d)
|
||||
|
||||
|
||||
def _get_events_from_ics(raw: bytes, from_date: str, to_date: str) -> list[dict]:
|
||||
"""Parse ICS bytes and return list of {date, summary} in [from_date, to_date]. One-time events only."""
|
||||
result: list[dict] = []
|
||||
try:
|
||||
cal = Calendar.from_ical(raw)
|
||||
if not cal:
|
||||
return result
|
||||
except Exception as e:
|
||||
log.warning("Failed to parse ICS: %s", e)
|
||||
return result
|
||||
|
||||
from_d = date.fromisoformat(from_date)
|
||||
to_d = date.fromisoformat(to_date)
|
||||
|
||||
for component in cal.walk():
|
||||
if component.name != "VEVENT":
|
||||
continue
|
||||
if component.get("rrule"):
|
||||
continue # skip recurring in first iteration
|
||||
start_d, end_d = _event_date_range(component)
|
||||
if not start_d or not end_d:
|
||||
continue
|
||||
summary = component.get("summary")
|
||||
summary_str = str(summary) if summary else ""
|
||||
|
||||
d = start_d
|
||||
while d <= end_d:
|
||||
if from_d <= d <= to_d:
|
||||
result.append({"date": d.strftime("%Y-%m-%d"), "summary": summary_str})
|
||||
d += timedelta(days=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_calendar_events(
|
||||
url: str,
|
||||
from_date: str,
|
||||
to_date: str,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return list of {date: "YYYY-MM-DD", summary: "..."} for events in [from_date, to_date].
|
||||
Uses in-memory cache with TTL 7 days. On fetch/parse error returns [].
|
||||
"""
|
||||
if not url or from_date > to_date:
|
||||
return []
|
||||
|
||||
now = datetime.now().timestamp()
|
||||
raw: bytes | None = None
|
||||
if url in _ics_cache:
|
||||
cached_at, cached_raw = _ics_cache[url]
|
||||
if now - cached_at < CACHE_TTL_SECONDS:
|
||||
raw = cached_raw
|
||||
if raw is None:
|
||||
raw = _fetch_ics(url)
|
||||
if raw is None:
|
||||
return []
|
||||
_ics_cache[url] = (now, raw)
|
||||
|
||||
return _get_events_from_ics(raw, from_date, to_date)
|
||||
116
duty_teller/api/dependencies.py
Normal file
116
duty_teller/api/dependencies.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""FastAPI dependencies: DB session, auth, date validation."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Generator
|
||||
|
||||
from fastapi import Header, HTTPException, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import duty_teller.config as config
|
||||
from duty_teller.api.telegram_auth import validate_init_data_with_reason
|
||||
from duty_teller.db.repository import get_duties
|
||||
from duty_teller.db.schemas import DutyWithUser
|
||||
from duty_teller.db.session import session_scope
|
||||
from duty_teller.utils.dates import validate_date_range
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_duty_dates(from_date: str, to_date: str) -> None:
|
||||
try:
|
||||
validate_date_range(from_date, to_date)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
def get_validated_dates(
|
||||
from_date: str = Query(..., description="ISO date YYYY-MM-DD", alias="from"),
|
||||
to_date: str = Query(..., description="ISO date YYYY-MM-DD", alias="to"),
|
||||
) -> tuple[str, str]:
|
||||
_validate_duty_dates(from_date, to_date)
|
||||
return (from_date, to_date)
|
||||
|
||||
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def require_miniapp_username(
|
||||
request: Request,
|
||||
x_telegram_init_data: Annotated[
|
||||
str | None, Header(alias="X-Telegram-Init-Data")
|
||||
] = None,
|
||||
) -> str:
|
||||
return get_authenticated_username(request, x_telegram_init_data)
|
||||
|
||||
|
||||
def _auth_error_detail(auth_reason: str) -> str:
|
||||
if auth_reason == "hash_mismatch":
|
||||
return (
|
||||
"Неверная подпись. Убедитесь, что BOT_TOKEN на сервере совпадает с токеном бота, "
|
||||
"из которого открыт календарь (тот же бот, что в меню)."
|
||||
)
|
||||
return "Неверные данные авторизации"
|
||||
|
||||
|
||||
def _is_private_client(client_host: str | None) -> bool:
|
||||
if not client_host:
|
||||
return False
|
||||
if client_host in ("127.0.0.1", "::1"):
|
||||
return True
|
||||
parts = client_host.split(".")
|
||||
if len(parts) == 4:
|
||||
try:
|
||||
a, b, c, d = (int(x) for x in parts)
|
||||
if (a == 10) or (a == 172 and 16 <= b <= 31) or (a == 192 and b == 168):
|
||||
return True
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def get_authenticated_username(
|
||||
request: Request, x_telegram_init_data: str | None
|
||||
) -> str:
|
||||
init_data = (x_telegram_init_data or "").strip()
|
||||
if not init_data:
|
||||
client_host = request.client.host if request.client else None
|
||||
if _is_private_client(client_host) or config.MINI_APP_SKIP_AUTH:
|
||||
if config.MINI_APP_SKIP_AUTH:
|
||||
log.warning("allowing without initData (MINI_APP_SKIP_AUTH is set)")
|
||||
return ""
|
||||
log.warning("no X-Telegram-Init-Data header (client=%s)", client_host)
|
||||
raise HTTPException(status_code=403, detail="Откройте календарь из Telegram")
|
||||
max_age = config.INIT_DATA_MAX_AGE_SECONDS or None
|
||||
username, auth_reason = validate_init_data_with_reason(
|
||||
init_data, config.BOT_TOKEN, max_age_seconds=max_age
|
||||
)
|
||||
if username is None:
|
||||
log.warning("initData validation failed: %s", auth_reason)
|
||||
raise HTTPException(status_code=403, detail=_auth_error_detail(auth_reason))
|
||||
if not config.can_access_miniapp(username):
|
||||
log.warning("username not in allowlist: %s", username)
|
||||
raise HTTPException(status_code=403, detail="Доступ запрещён")
|
||||
return username
|
||||
|
||||
|
||||
def fetch_duties_response(
|
||||
session: Session, from_date: str, to_date: str
|
||||
) -> list[DutyWithUser]:
|
||||
rows = get_duties(session, from_date=from_date, to_date=to_date)
|
||||
return [
|
||||
DutyWithUser(
|
||||
id=duty.id,
|
||||
user_id=duty.user_id,
|
||||
start_at=duty.start_at,
|
||||
end_at=duty.end_at,
|
||||
full_name=full_name,
|
||||
event_type=(
|
||||
duty.event_type
|
||||
if duty.event_type in ("duty", "unavailable", "vacation")
|
||||
else "duty"
|
||||
),
|
||||
)
|
||||
for duty, full_name in rows
|
||||
]
|
||||
84
duty_teller/api/telegram_auth.py
Normal file
84
duty_teller/api/telegram_auth.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Validate Telegram Web App initData and extract user username."""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from urllib.parse import unquote
|
||||
|
||||
# Telegram algorithm: https://core.telegram.org/bots/webapps#validating-data-received-via-the-mini-app
|
||||
# Data-check string: sorted key=value with URL-decoded values, then HMAC-SHA256(WebAppData, token) as secret.
|
||||
|
||||
|
||||
def validate_init_data(
|
||||
init_data: str,
|
||||
bot_token: str,
|
||||
max_age_seconds: int | None = None,
|
||||
) -> str | None:
|
||||
"""Validate initData and return username; see validate_init_data_with_reason for failure reason."""
|
||||
username, _ = validate_init_data_with_reason(init_data, bot_token, max_age_seconds)
|
||||
return username
|
||||
|
||||
|
||||
def validate_init_data_with_reason(
|
||||
init_data: str,
|
||||
bot_token: str,
|
||||
max_age_seconds: int | None = None,
|
||||
) -> tuple[str | None, str]:
|
||||
"""
|
||||
Validate initData signature and return (username, None) or (None, reason).
|
||||
reason is one of: "ok", "empty", "no_hash", "hash_mismatch", "auth_date_expired", "no_user", "user_invalid", "no_username".
|
||||
"""
|
||||
if not init_data or not bot_token:
|
||||
return (None, "empty")
|
||||
init_data = init_data.strip()
|
||||
params = {}
|
||||
for part in init_data.split("&"):
|
||||
if "=" not in part:
|
||||
continue
|
||||
key, _, value = part.partition("=")
|
||||
if not key:
|
||||
continue
|
||||
params[key] = value
|
||||
hash_val = params.pop("hash", None)
|
||||
if not hash_val:
|
||||
return (None, "no_hash")
|
||||
data_pairs = sorted(params.items())
|
||||
# Data-check string: key=value with URL-decoded values (per Telegram example)
|
||||
data_string = "\n".join(f"{k}={unquote(v)}" for k, v in data_pairs)
|
||||
# HMAC-SHA256(key=WebAppData, message=bot_token) per reference implementations
|
||||
secret_key = hmac.new(
|
||||
b"WebAppData",
|
||||
msg=bot_token.encode(),
|
||||
digestmod=hashlib.sha256,
|
||||
).digest()
|
||||
computed = hmac.new(
|
||||
secret_key,
|
||||
msg=data_string.encode(),
|
||||
digestmod=hashlib.sha256,
|
||||
).hexdigest()
|
||||
if not hmac.compare_digest(computed.lower(), hash_val.lower()):
|
||||
return (None, "hash_mismatch")
|
||||
if max_age_seconds is not None and max_age_seconds > 0:
|
||||
auth_date_raw = params.get("auth_date")
|
||||
if not auth_date_raw:
|
||||
return (None, "auth_date_expired")
|
||||
try:
|
||||
auth_date = int(float(auth_date_raw))
|
||||
except (ValueError, TypeError):
|
||||
return (None, "auth_date_expired")
|
||||
if time.time() - auth_date > max_age_seconds:
|
||||
return (None, "auth_date_expired")
|
||||
user_raw = params.get("user")
|
||||
if not user_raw:
|
||||
return (None, "no_user")
|
||||
try:
|
||||
user = json.loads(unquote(user_raw))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return (None, "user_invalid")
|
||||
if not isinstance(user, dict):
|
||||
return (None, "user_invalid")
|
||||
username = user.get("username")
|
||||
if not username or not isinstance(username, str):
|
||||
return (None, "no_username")
|
||||
return (username.strip().lstrip("@").lower(), "ok")
|
||||
116
duty_teller/config.py
Normal file
116
duty_teller/config.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Load configuration from environment. BOT_TOKEN is not validated on import; check in main/entry point."""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Project root (parent of duty_teller package). Used for webapp path, etc.
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Settings:
|
||||
"""Optional injectable settings built from env. Tests can override or build from env."""
|
||||
|
||||
bot_token: str
|
||||
database_url: str
|
||||
mini_app_base_url: str
|
||||
http_port: int
|
||||
allowed_usernames: set[str]
|
||||
admin_usernames: set[str]
|
||||
mini_app_skip_auth: bool
|
||||
init_data_max_age_seconds: int
|
||||
cors_origins: list[str]
|
||||
external_calendar_ics_url: str
|
||||
duty_display_tz: str
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "Settings":
|
||||
"""Build Settings from current environment (same logic as module-level vars)."""
|
||||
bot_token = os.getenv("BOT_TOKEN") or ""
|
||||
raw_allowed = os.getenv("ALLOWED_USERNAMES", "").strip()
|
||||
allowed = {
|
||||
s.strip().lstrip("@").lower() for s in raw_allowed.split(",") if s.strip()
|
||||
}
|
||||
raw_admin = os.getenv("ADMIN_USERNAMES", "").strip()
|
||||
admin = {
|
||||
s.strip().lstrip("@").lower() for s in raw_admin.split(",") if s.strip()
|
||||
}
|
||||
raw_cors = os.getenv("CORS_ORIGINS", "").strip()
|
||||
cors = (
|
||||
[_o.strip() for _o in raw_cors.split(",") if _o.strip()]
|
||||
if raw_cors and raw_cors != "*"
|
||||
else ["*"]
|
||||
)
|
||||
return cls(
|
||||
bot_token=bot_token,
|
||||
database_url=os.getenv("DATABASE_URL", "sqlite:///data/duty_teller.db"),
|
||||
mini_app_base_url=os.getenv("MINI_APP_BASE_URL", "").rstrip("/"),
|
||||
http_port=int(os.getenv("HTTP_PORT", "8080")),
|
||||
allowed_usernames=allowed,
|
||||
admin_usernames=admin,
|
||||
mini_app_skip_auth=os.getenv("MINI_APP_SKIP_AUTH", "").strip()
|
||||
in ("1", "true", "yes"),
|
||||
init_data_max_age_seconds=int(os.getenv("INIT_DATA_MAX_AGE_SECONDS", "0")),
|
||||
cors_origins=cors,
|
||||
external_calendar_ics_url=os.getenv(
|
||||
"EXTERNAL_CALENDAR_ICS_URL", ""
|
||||
).strip(),
|
||||
duty_display_tz=os.getenv("DUTY_DISPLAY_TZ", "Europe/Moscow").strip()
|
||||
or "Europe/Moscow",
|
||||
)
|
||||
|
||||
|
||||
# Module-level vars: no validation on import; entry point must check BOT_TOKEN when needed.
|
||||
BOT_TOKEN = os.getenv("BOT_TOKEN") or ""
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///data/duty_teller.db")
|
||||
MINI_APP_BASE_URL = os.getenv("MINI_APP_BASE_URL", "").rstrip("/")
|
||||
HTTP_PORT = int(os.getenv("HTTP_PORT", "8080"))
|
||||
|
||||
_raw_allowed = os.getenv("ALLOWED_USERNAMES", "").strip()
|
||||
ALLOWED_USERNAMES = {
|
||||
s.strip().lstrip("@").lower() for s in _raw_allowed.split(",") if s.strip()
|
||||
}
|
||||
|
||||
_raw_admin = os.getenv("ADMIN_USERNAMES", "").strip()
|
||||
ADMIN_USERNAMES = {
|
||||
s.strip().lstrip("@").lower() for s in _raw_admin.split(",") if s.strip()
|
||||
}
|
||||
|
||||
MINI_APP_SKIP_AUTH = os.getenv("MINI_APP_SKIP_AUTH", "").strip() in ("1", "true", "yes")
|
||||
INIT_DATA_MAX_AGE_SECONDS = int(os.getenv("INIT_DATA_MAX_AGE_SECONDS", "0"))
|
||||
|
||||
_raw_cors = os.getenv("CORS_ORIGINS", "").strip()
|
||||
CORS_ORIGINS = (
|
||||
[_o.strip() for _o in _raw_cors.split(",") if _o.strip()]
|
||||
if _raw_cors and _raw_cors != "*"
|
||||
else ["*"]
|
||||
)
|
||||
|
||||
EXTERNAL_CALENDAR_ICS_URL = os.getenv("EXTERNAL_CALENDAR_ICS_URL", "").strip()
|
||||
DUTY_DISPLAY_TZ = (
|
||||
os.getenv("DUTY_DISPLAY_TZ", "Europe/Moscow").strip() or "Europe/Moscow"
|
||||
)
|
||||
|
||||
|
||||
def is_admin(username: str) -> bool:
|
||||
"""True if the given Telegram username (no @, any case) is in ADMIN_USERNAMES."""
|
||||
return (username or "").strip().lower() in ADMIN_USERNAMES
|
||||
|
||||
|
||||
def can_access_miniapp(username: str) -> bool:
|
||||
"""True if username is in ALLOWED_USERNAMES or ADMIN_USERNAMES."""
|
||||
u = (username or "").strip().lower()
|
||||
return u in ALLOWED_USERNAMES or u in ADMIN_USERNAMES
|
||||
|
||||
|
||||
def require_bot_token() -> None:
|
||||
"""Raise SystemExit with a clear message if BOT_TOKEN is not set. Call from entry point."""
|
||||
if not BOT_TOKEN:
|
||||
raise SystemExit(
|
||||
"BOT_TOKEN is not set. Copy .env.example to .env and set your token from @BotFather."
|
||||
)
|
||||
52
duty_teller/db/__init__.py
Normal file
52
duty_teller/db/__init__.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Database layer: SQLAlchemy models, Pydantic schemas, repository, init."""
|
||||
|
||||
from duty_teller.db.models import Base, User, Duty
|
||||
from duty_teller.db.schemas import (
|
||||
UserCreate,
|
||||
UserInDb,
|
||||
DutyCreate,
|
||||
DutyInDb,
|
||||
DutyWithUser,
|
||||
)
|
||||
from duty_teller.db.session import (
|
||||
get_engine,
|
||||
get_session_factory,
|
||||
get_session,
|
||||
session_scope,
|
||||
)
|
||||
from duty_teller.db.repository import (
|
||||
delete_duties_in_range,
|
||||
get_or_create_user,
|
||||
get_or_create_user_by_full_name,
|
||||
get_duties,
|
||||
insert_duty,
|
||||
set_user_phone,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"User",
|
||||
"Duty",
|
||||
"UserCreate",
|
||||
"UserInDb",
|
||||
"DutyCreate",
|
||||
"DutyInDb",
|
||||
"DutyWithUser",
|
||||
"get_engine",
|
||||
"get_session_factory",
|
||||
"get_session",
|
||||
"session_scope",
|
||||
"delete_duties_in_range",
|
||||
"get_or_create_user",
|
||||
"get_or_create_user_by_full_name",
|
||||
"get_duties",
|
||||
"insert_duty",
|
||||
"set_user_phone",
|
||||
"init_db",
|
||||
]
|
||||
|
||||
|
||||
def init_db(database_url: str) -> None:
|
||||
"""Create tables from metadata (Alembic migrations handle schema in production)."""
|
||||
engine = get_engine(database_url)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
51
duty_teller/db/models.py
Normal file
51
duty_teller/db/models.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""SQLAlchemy ORM models for users and duties."""
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, BigInteger, Text
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Declarative base for all models."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
telegram_user_id: Mapped[int | None] = mapped_column(
|
||||
BigInteger, unique=True, nullable=True
|
||||
)
|
||||
full_name: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
username: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
first_name: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
last_name: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
phone: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
duties: Mapped[list["Duty"]] = relationship("Duty", back_populates="user")
|
||||
|
||||
|
||||
class Duty(Base):
|
||||
__tablename__ = "duties"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=False
|
||||
)
|
||||
# UTC, ISO 8601 with Z suffix (e.g. 2025-01-15T09:00:00Z)
|
||||
start_at: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
end_at: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
# duty | unavailable | vacation
|
||||
event_type: Mapped[str] = mapped_column(Text, nullable=False, server_default="duty")
|
||||
|
||||
user: Mapped["User"] = relationship("User", back_populates="duties")
|
||||
|
||||
|
||||
class GroupDutyPin(Base):
|
||||
"""Stores which message to update in each group for the pinned duty notice."""
|
||||
|
||||
__tablename__ = "group_duty_pins"
|
||||
|
||||
chat_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||
message_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
213
duty_teller/db/repository.py
Normal file
213
duty_teller/db/repository.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Repository: get_or_create_user, get_duties, insert_duty, get_current_duty, group_duty_pins."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from duty_teller.db.models import User, Duty, GroupDutyPin
|
||||
|
||||
|
||||
def get_or_create_user(
|
||||
session: Session,
|
||||
telegram_user_id: int,
|
||||
full_name: str,
|
||||
username: str | None = None,
|
||||
first_name: str | None = None,
|
||||
last_name: str | None = None,
|
||||
) -> User:
|
||||
user = session.query(User).filter(User.telegram_user_id == telegram_user_id).first()
|
||||
if user:
|
||||
user.full_name = full_name
|
||||
user.username = username
|
||||
user.first_name = first_name
|
||||
user.last_name = last_name
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
user = User(
|
||||
telegram_user_id=telegram_user_id,
|
||||
full_name=full_name,
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def get_or_create_user_by_full_name(session: Session, full_name: str) -> User:
|
||||
"""Find user by exact full_name or create one with telegram_user_id=None (for duty-schedule import)."""
|
||||
user = session.query(User).filter(User.full_name == full_name).first()
|
||||
if user:
|
||||
return user
|
||||
user = User(
|
||||
telegram_user_id=None,
|
||||
full_name=full_name,
|
||||
username=None,
|
||||
first_name=None,
|
||||
last_name=None,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def delete_duties_in_range(
|
||||
session: Session,
|
||||
user_id: int,
|
||||
from_date: str,
|
||||
to_date: str,
|
||||
) -> int:
|
||||
"""Delete all duties of the user that overlap [from_date, to_date] (YYYY-MM-DD). Returns count deleted."""
|
||||
to_next = (
|
||||
datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)
|
||||
).strftime("%Y-%m-%d")
|
||||
q = session.query(Duty).filter(
|
||||
Duty.user_id == user_id,
|
||||
Duty.start_at < to_next,
|
||||
Duty.end_at >= from_date,
|
||||
)
|
||||
count = q.count()
|
||||
q.delete(synchronize_session=False)
|
||||
session.commit()
|
||||
return count
|
||||
|
||||
|
||||
def get_duties(
|
||||
session: Session,
|
||||
from_date: str,
|
||||
to_date: str,
|
||||
) -> list[tuple[Duty, str]]:
|
||||
"""Return list of (Duty, full_name) overlapping the given date range."""
|
||||
to_date_next = (
|
||||
datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)
|
||||
).strftime("%Y-%m-%d")
|
||||
q = (
|
||||
session.query(Duty, User.full_name)
|
||||
.join(User, Duty.user_id == User.id)
|
||||
.filter(Duty.start_at < to_date_next, Duty.end_at >= from_date)
|
||||
)
|
||||
return list(q.all())
|
||||
|
||||
|
||||
def insert_duty(
|
||||
session: Session,
|
||||
user_id: int,
|
||||
start_at: str,
|
||||
end_at: str,
|
||||
event_type: str = "duty",
|
||||
) -> Duty:
|
||||
"""Create a duty. start_at and end_at must be UTC, ISO 8601 with Z."""
|
||||
duty = Duty(
|
||||
user_id=user_id,
|
||||
start_at=start_at,
|
||||
end_at=end_at,
|
||||
event_type=event_type,
|
||||
)
|
||||
session.add(duty)
|
||||
session.commit()
|
||||
session.refresh(duty)
|
||||
return duty
|
||||
|
||||
|
||||
def get_current_duty(session: Session, at_utc: datetime) -> tuple[Duty, User] | None:
|
||||
"""Return the duty (and user) for which start_at <= at_utc < end_at, event_type='duty'."""
|
||||
from datetime import timezone
|
||||
|
||||
if at_utc.tzinfo is not None:
|
||||
at_utc = at_utc.astimezone(timezone.utc)
|
||||
now_iso = at_utc.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
|
||||
row = (
|
||||
session.query(Duty, User)
|
||||
.join(User, Duty.user_id == User.id)
|
||||
.filter(
|
||||
Duty.event_type == "duty",
|
||||
Duty.start_at <= now_iso,
|
||||
Duty.end_at > now_iso,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return (row[0], row[1])
|
||||
|
||||
|
||||
def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None:
|
||||
"""Return the end_at of the current duty or of the next duty. Naive UTC."""
|
||||
from datetime import timezone
|
||||
|
||||
if after_utc.tzinfo is not None:
|
||||
after_utc = after_utc.astimezone(timezone.utc)
|
||||
after_iso = after_utc.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
|
||||
current = (
|
||||
session.query(Duty)
|
||||
.filter(
|
||||
Duty.event_type == "duty",
|
||||
Duty.start_at <= after_iso,
|
||||
Duty.end_at > after_iso,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if current:
|
||||
return datetime.fromisoformat(current.end_at.replace("Z", "+00:00")).replace(
|
||||
tzinfo=None
|
||||
)
|
||||
next_duty = (
|
||||
session.query(Duty)
|
||||
.filter(Duty.event_type == "duty", Duty.start_at > after_iso)
|
||||
.order_by(Duty.start_at)
|
||||
.first()
|
||||
)
|
||||
if next_duty:
|
||||
return datetime.fromisoformat(next_duty.end_at.replace("Z", "+00:00")).replace(
|
||||
tzinfo=None
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_group_duty_pin(session: Session, chat_id: int) -> GroupDutyPin | None:
|
||||
"""Get the pinned message record for a chat, if any."""
|
||||
return session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).first()
|
||||
|
||||
|
||||
def save_group_duty_pin(
|
||||
session: Session, chat_id: int, message_id: int
|
||||
) -> GroupDutyPin:
|
||||
"""Save or update the pinned message for a chat."""
|
||||
pin = session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).first()
|
||||
if pin:
|
||||
pin.message_id = message_id
|
||||
else:
|
||||
pin = GroupDutyPin(chat_id=chat_id, message_id=message_id)
|
||||
session.add(pin)
|
||||
session.commit()
|
||||
session.refresh(pin)
|
||||
return pin
|
||||
|
||||
|
||||
def delete_group_duty_pin(session: Session, chat_id: int) -> None:
|
||||
"""Remove the pinned message record when the bot leaves the group."""
|
||||
session.query(GroupDutyPin).filter(GroupDutyPin.chat_id == chat_id).delete()
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_all_group_duty_pin_chat_ids(session: Session) -> list[int]:
|
||||
"""Return all chat_ids that have a pinned duty message (for restoring jobs on startup)."""
|
||||
rows = session.query(GroupDutyPin.chat_id).all()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
|
||||
def set_user_phone(
|
||||
session: Session, telegram_user_id: int, phone: str | None
|
||||
) -> User | None:
|
||||
"""Set phone for user by telegram_user_id. Returns User or None if not found."""
|
||||
user = session.query(User).filter(User.telegram_user_id == telegram_user_id).first()
|
||||
if not user:
|
||||
return None
|
||||
user.phone = phone
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
58
duty_teller/db/schemas.py
Normal file
58
duty_teller/db/schemas.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Pydantic schemas for API and validation."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
full_name: str
|
||||
username: str | None = None
|
||||
first_name: str | None = None
|
||||
last_name: str | None = None
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
telegram_user_id: int
|
||||
|
||||
|
||||
class UserInDb(UserBase):
|
||||
id: int
|
||||
telegram_user_id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class DutyBase(BaseModel):
|
||||
user_id: int
|
||||
start_at: str # UTC, ISO 8601 with Z
|
||||
end_at: str # UTC, ISO 8601 with Z
|
||||
|
||||
|
||||
class DutyCreate(DutyBase):
|
||||
pass
|
||||
|
||||
|
||||
class DutyInDb(DutyBase):
|
||||
id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class DutyWithUser(DutyInDb):
|
||||
"""Duty with full_name and event_type for calendar display.
|
||||
|
||||
event_type: only these values are returned; unknown DB values are mapped to "duty" in the API.
|
||||
"""
|
||||
|
||||
full_name: str
|
||||
event_type: Literal["duty", "unavailable", "vacation"] = "duty"
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
"""External calendar event (e.g. holiday) for a single day."""
|
||||
|
||||
date: str # YYYY-MM-DD
|
||||
summary: str
|
||||
57
duty_teller/db/session.py
Normal file
57
duty_teller/db/session.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""SQLAlchemy engine and session factory.
|
||||
|
||||
Engine and session factory are cached globally per process. Only one DATABASE_URL
|
||||
is effectively used for the process lifetime. Using a different URL later (e.g. in
|
||||
tests with in-memory SQLite) would still use the first engine. To use a different
|
||||
URL in tests, set env (e.g. DATABASE_URL) before the first import of this module, or
|
||||
clear _engine and _SessionLocal in test fixtures. Prefer session_scope() for all
|
||||
callers so sessions are always closed and rolled back on error.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_scope(database_url: str) -> Generator[Session, None, None]:
|
||||
"""Context manager: yields a session, rolls back on exception, closes on exit."""
|
||||
session = get_session(database_url)
|
||||
try:
|
||||
yield session
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def get_engine(database_url: str):
|
||||
global _engine
|
||||
if _engine is None:
|
||||
_engine = create_engine(
|
||||
database_url,
|
||||
connect_args={"check_same_thread": False}
|
||||
if "sqlite" in database_url
|
||||
else {},
|
||||
echo=False,
|
||||
)
|
||||
return _engine
|
||||
|
||||
|
||||
def get_session_factory(database_url: str) -> sessionmaker[Session]:
|
||||
global _SessionLocal
|
||||
if _SessionLocal is None:
|
||||
engine = get_engine(database_url)
|
||||
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
return _SessionLocal
|
||||
|
||||
|
||||
def get_session(database_url: str) -> Session:
|
||||
return get_session_factory(database_url)()
|
||||
17
duty_teller/handlers/__init__.py
Normal file
17
duty_teller/handlers/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Expose a single register_handlers(app) that registers all handlers."""
|
||||
|
||||
from telegram.ext import Application
|
||||
|
||||
from . import commands, errors, group_duty_pin, import_duty_schedule
|
||||
|
||||
|
||||
def register_handlers(app: Application) -> None:
|
||||
app.add_handler(commands.start_handler)
|
||||
app.add_handler(commands.help_handler)
|
||||
app.add_handler(commands.set_phone_handler)
|
||||
app.add_handler(import_duty_schedule.import_duty_schedule_handler)
|
||||
app.add_handler(import_duty_schedule.handover_time_handler)
|
||||
app.add_handler(import_duty_schedule.duty_schedule_document_handler)
|
||||
app.add_handler(group_duty_pin.group_duty_pin_handler)
|
||||
app.add_handler(group_duty_pin.pin_duty_handler)
|
||||
app.add_error_handler(errors.error_handler)
|
||||
94
duty_teller/handlers/commands.py
Normal file
94
duty_teller/handlers/commands.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Command handlers: /start, /help; /start registers user."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import duty_teller.config as config
|
||||
from telegram import Update
|
||||
from telegram.ext import CommandHandler, ContextTypes
|
||||
|
||||
from duty_teller.db.session import session_scope
|
||||
from duty_teller.db.repository import get_or_create_user, set_user_phone
|
||||
from duty_teller.utils.user import build_full_name
|
||||
|
||||
|
||||
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
if not update.message:
|
||||
return
|
||||
user = update.effective_user
|
||||
if not user:
|
||||
return
|
||||
full_name = build_full_name(user.first_name, user.last_name)
|
||||
telegram_user_id = user.id
|
||||
username = user.username
|
||||
first_name = user.first_name
|
||||
last_name = user.last_name
|
||||
|
||||
def do_get_or_create() -> None:
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
get_or_create_user(
|
||||
session,
|
||||
telegram_user_id=telegram_user_id,
|
||||
full_name=full_name,
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
)
|
||||
|
||||
await asyncio.get_running_loop().run_in_executor(None, do_get_or_create)
|
||||
|
||||
text = "Привет! Я бот календаря дежурств. Используй /help для списка команд."
|
||||
await update.message.reply_text(text)
|
||||
|
||||
|
||||
async def set_phone(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
if update.effective_chat and update.effective_chat.type != "private":
|
||||
await update.message.reply_text("Команда /set_phone доступна только в личке.")
|
||||
return
|
||||
args = context.args or []
|
||||
phone = " ".join(args).strip() if args else None
|
||||
telegram_user_id = update.effective_user.id
|
||||
|
||||
def do_set_phone() -> str:
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
full_name = build_full_name(
|
||||
update.effective_user.first_name, update.effective_user.last_name
|
||||
)
|
||||
get_or_create_user(
|
||||
session,
|
||||
telegram_user_id=telegram_user_id,
|
||||
full_name=full_name,
|
||||
username=update.effective_user.username,
|
||||
first_name=update.effective_user.first_name,
|
||||
last_name=update.effective_user.last_name,
|
||||
)
|
||||
user = set_user_phone(session, telegram_user_id, phone or None)
|
||||
if user is None:
|
||||
return "Ошибка сохранения."
|
||||
if phone:
|
||||
return f"Телефон сохранён: {phone}"
|
||||
return "Телефон очищен."
|
||||
|
||||
result = await asyncio.get_running_loop().run_in_executor(None, do_set_phone)
|
||||
await update.message.reply_text(result)
|
||||
|
||||
|
||||
async def help_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
lines = [
|
||||
"Доступные команды:",
|
||||
"/start — Начать",
|
||||
"/help — Показать эту справку",
|
||||
"/set_phone — Указать или очистить телефон для отображения в дежурстве",
|
||||
"/pin_duty — В группе: закрепить сообщение о дежурстве (нужны права админа у бота)",
|
||||
]
|
||||
if config.is_admin(update.effective_user.username or ""):
|
||||
lines.append("/import_duty_schedule — Импорт расписания дежурств (JSON)")
|
||||
await update.message.reply_text("\n".join(lines))
|
||||
|
||||
|
||||
start_handler = CommandHandler("start", start)
|
||||
help_handler = CommandHandler("help", help_cmd)
|
||||
set_phone_handler = CommandHandler("set_phone", set_phone)
|
||||
16
duty_teller/handlers/errors.py
Normal file
16
duty_teller/handlers/errors.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Global error handler: log exception and notify user."""
|
||||
|
||||
import logging
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import ContextTypes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def error_handler(
|
||||
update: Update | None, context: ContextTypes.DEFAULT_TYPE
|
||||
) -> None:
|
||||
logger.exception("Exception while handling an update")
|
||||
if isinstance(update, Update) and update.effective_message:
|
||||
await update.effective_message.reply_text("Произошла ошибка. Попробуйте позже.")
|
||||
225
duty_teller/handlers/group_duty_pin.py
Normal file
225
duty_teller/handlers/group_duty_pin.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Pinned duty message in groups: handle bot add/remove, schedule updates at shift end."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import duty_teller.config as config
|
||||
from telegram import Update
|
||||
from telegram.constants import ChatMemberStatus
|
||||
from telegram.error import BadRequest, Forbidden
|
||||
from telegram.ext import ChatMemberHandler, CommandHandler, ContextTypes
|
||||
|
||||
from duty_teller.db.session import session_scope
|
||||
from duty_teller.services.group_duty_pin_service import (
|
||||
get_duty_message_text,
|
||||
get_next_shift_end_utc,
|
||||
save_pin,
|
||||
delete_pin,
|
||||
get_message_id,
|
||||
get_all_pin_chat_ids,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
JOB_NAME_PREFIX = "duty_pin_"
|
||||
RETRY_WHEN_NO_DUTY_MINUTES = 15
|
||||
|
||||
|
||||
def _get_duty_message_text_sync() -> str:
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
return get_duty_message_text(session, config.DUTY_DISPLAY_TZ)
|
||||
|
||||
|
||||
def _get_next_shift_end_sync():
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
return get_next_shift_end_utc(session)
|
||||
|
||||
|
||||
def _sync_save_pin(chat_id: int, message_id: int) -> None:
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
save_pin(session, chat_id, message_id)
|
||||
|
||||
|
||||
def _sync_delete_pin(chat_id: int) -> None:
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
delete_pin(session, chat_id)
|
||||
|
||||
|
||||
def _sync_get_message_id(chat_id: int) -> int | None:
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
return get_message_id(session, chat_id)
|
||||
|
||||
|
||||
async def _schedule_next_update(
|
||||
application, chat_id: int, when_utc: datetime | None
|
||||
) -> None:
|
||||
job_queue = application.job_queue
|
||||
if job_queue is None:
|
||||
logger.warning("Job queue not available, cannot schedule pin update")
|
||||
return
|
||||
name = f"{JOB_NAME_PREFIX}{chat_id}"
|
||||
for job in job_queue.get_jobs_by_name(name):
|
||||
job.schedule_removal()
|
||||
if when_utc is not None:
|
||||
now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
delay = when_utc - now_utc
|
||||
if delay.total_seconds() < 1:
|
||||
delay = 1
|
||||
job_queue.run_once(
|
||||
update_group_pin,
|
||||
when=delay,
|
||||
data={"chat_id": chat_id},
|
||||
name=name,
|
||||
)
|
||||
logger.info("Scheduled pin update for chat_id=%s at %s", chat_id, when_utc)
|
||||
else:
|
||||
from datetime import timedelta
|
||||
|
||||
job_queue.run_once(
|
||||
update_group_pin,
|
||||
when=timedelta(minutes=RETRY_WHEN_NO_DUTY_MINUTES),
|
||||
data={"chat_id": chat_id},
|
||||
name=name,
|
||||
)
|
||||
logger.info(
|
||||
"No next shift for chat_id=%s; scheduled retry in %s min",
|
||||
chat_id,
|
||||
RETRY_WHEN_NO_DUTY_MINUTES,
|
||||
)
|
||||
|
||||
|
||||
async def update_group_pin(context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
chat_id = context.job.data.get("chat_id")
|
||||
if chat_id is None:
|
||||
return
|
||||
loop = asyncio.get_running_loop()
|
||||
message_id = await loop.run_in_executor(None, _sync_get_message_id, chat_id)
|
||||
if message_id is None:
|
||||
logger.info("No pin record for chat_id=%s, skipping update", chat_id)
|
||||
return
|
||||
text = await loop.run_in_executor(None, _get_duty_message_text_sync)
|
||||
try:
|
||||
await context.bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
text=text,
|
||||
)
|
||||
except (BadRequest, Forbidden) as e:
|
||||
logger.warning("Failed to edit pinned message chat_id=%s: %s", chat_id, e)
|
||||
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
|
||||
await _schedule_next_update(context.application, chat_id, next_end)
|
||||
|
||||
|
||||
async def my_chat_member_handler(
|
||||
update: Update, context: ContextTypes.DEFAULT_TYPE
|
||||
) -> None:
|
||||
if not update.my_chat_member or not update.effective_user:
|
||||
return
|
||||
old = update.my_chat_member.old_chat_member
|
||||
new = update.my_chat_member.new_chat_member
|
||||
chat = update.effective_chat
|
||||
if not chat or chat.type not in ("group", "supergroup"):
|
||||
return
|
||||
if new.user.id != context.bot.id:
|
||||
return
|
||||
chat_id = chat.id
|
||||
|
||||
if new.status in (
|
||||
ChatMemberStatus.MEMBER,
|
||||
ChatMemberStatus.ADMINISTRATOR,
|
||||
) and old.status in (
|
||||
ChatMemberStatus.LEFT,
|
||||
ChatMemberStatus.BANNED,
|
||||
):
|
||||
loop = asyncio.get_running_loop()
|
||||
text = await loop.run_in_executor(None, _get_duty_message_text_sync)
|
||||
try:
|
||||
msg = await context.bot.send_message(chat_id=chat_id, text=text)
|
||||
except (BadRequest, Forbidden) as e:
|
||||
logger.warning("Failed to send duty message in chat_id=%s: %s", chat_id, e)
|
||||
return
|
||||
pinned = False
|
||||
try:
|
||||
await context.bot.pin_chat_message(
|
||||
chat_id=chat_id,
|
||||
message_id=msg.message_id,
|
||||
disable_notification=True,
|
||||
)
|
||||
pinned = True
|
||||
except (BadRequest, Forbidden) as e:
|
||||
logger.warning("Failed to pin message in chat_id=%s: %s", chat_id, e)
|
||||
await loop.run_in_executor(None, _sync_save_pin, chat_id, msg.message_id)
|
||||
if not pinned:
|
||||
try:
|
||||
await context.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text="Сообщение о дежурстве отправлено, но закрепить его не удалось. "
|
||||
"Сделайте бота администратором с правом «Закреплять сообщения» (Pin messages), "
|
||||
"затем отправьте в чат команду /pin_duty — текущее сообщение будет закреплено.",
|
||||
)
|
||||
except (BadRequest, Forbidden):
|
||||
pass
|
||||
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
|
||||
await _schedule_next_update(context.application, chat_id, next_end)
|
||||
return
|
||||
|
||||
if new.status in (ChatMemberStatus.LEFT, ChatMemberStatus.BANNED):
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
None, _sync_delete_pin, chat_id
|
||||
)
|
||||
name = f"{JOB_NAME_PREFIX}{chat_id}"
|
||||
if context.application.job_queue:
|
||||
for job in context.application.job_queue.get_jobs_by_name(name):
|
||||
job.schedule_removal()
|
||||
logger.info("Bot left chat_id=%s, removed pin record and jobs", chat_id)
|
||||
|
||||
|
||||
def _get_all_pin_chat_ids_sync() -> list[int]:
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
return get_all_pin_chat_ids(session)
|
||||
|
||||
|
||||
async def restore_group_pin_jobs(application) -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
chat_ids = await loop.run_in_executor(None, _get_all_pin_chat_ids_sync)
|
||||
for chat_id in chat_ids:
|
||||
next_end = await loop.run_in_executor(None, _get_next_shift_end_sync)
|
||||
await _schedule_next_update(application, chat_id, next_end)
|
||||
logger.info("Restored %s group pin jobs", len(chat_ids))
|
||||
|
||||
|
||||
async def pin_duty_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
if not update.message or not update.effective_chat:
|
||||
return
|
||||
chat = update.effective_chat
|
||||
if chat.type not in ("group", "supergroup"):
|
||||
await update.message.reply_text("Команда /pin_duty работает только в группах.")
|
||||
return
|
||||
chat_id = chat.id
|
||||
loop = asyncio.get_running_loop()
|
||||
message_id = await loop.run_in_executor(None, _sync_get_message_id, chat_id)
|
||||
if message_id is None:
|
||||
await update.message.reply_text(
|
||||
"В этом чате ещё нет сообщения о дежурстве. Добавьте бота в группу — оно создастся автоматически."
|
||||
)
|
||||
return
|
||||
try:
|
||||
await context.bot.pin_chat_message(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
disable_notification=True,
|
||||
)
|
||||
await update.message.reply_text("Сообщение о дежурстве закреплено.")
|
||||
except (BadRequest, Forbidden) as e:
|
||||
logger.warning("pin_duty failed chat_id=%s: %s", chat_id, e)
|
||||
await update.message.reply_text(
|
||||
"Не удалось закрепить. Убедитесь, что бот — администратор с правом «Закреплять сообщения»."
|
||||
)
|
||||
|
||||
|
||||
group_duty_pin_handler = ChatMemberHandler(
|
||||
my_chat_member_handler,
|
||||
ChatMemberHandler.MY_CHAT_MEMBER,
|
||||
)
|
||||
pin_duty_handler = CommandHandler("pin_duty", pin_duty_cmd)
|
||||
119
duty_teller/handlers/import_duty_schedule.py
Normal file
119
duty_teller/handlers/import_duty_schedule.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Import duty-schedule: /import_duty_schedule (admin only). Two steps: handover time -> JSON file."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import duty_teller.config as config
|
||||
from telegram import Update
|
||||
from telegram.ext import CommandHandler, ContextTypes, MessageHandler, filters
|
||||
|
||||
from duty_teller.db.session import session_scope
|
||||
from duty_teller.importers.duty_schedule import (
|
||||
DutyScheduleParseError,
|
||||
parse_duty_schedule,
|
||||
)
|
||||
from duty_teller.services.import_service import run_import
|
||||
from duty_teller.utils.handover import parse_handover_time
|
||||
|
||||
|
||||
async def import_duty_schedule_cmd(
|
||||
update: Update, context: ContextTypes.DEFAULT_TYPE
|
||||
) -> None:
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
if not config.is_admin(update.effective_user.username or ""):
|
||||
await update.message.reply_text("Доступ только для администраторов.")
|
||||
return
|
||||
context.user_data["awaiting_handover_time"] = True
|
||||
await update.message.reply_text(
|
||||
"Укажите время пересменки в формате ЧЧ:ММ и часовой пояс, "
|
||||
"например 09:00 Europe/Moscow или 06:00 UTC."
|
||||
)
|
||||
|
||||
|
||||
async def handle_handover_time_text(
|
||||
update: Update, context: ContextTypes.DEFAULT_TYPE
|
||||
) -> None:
|
||||
if not update.message or not update.effective_user or not update.message.text:
|
||||
return
|
||||
if not context.user_data.get("awaiting_handover_time"):
|
||||
return
|
||||
if not config.is_admin(update.effective_user.username or ""):
|
||||
return
|
||||
text = update.message.text.strip()
|
||||
parsed = parse_handover_time(text)
|
||||
if parsed is None:
|
||||
await update.message.reply_text(
|
||||
"Не удалось разобрать время. Укажите, например: 09:00 Europe/Moscow"
|
||||
)
|
||||
return
|
||||
hour_utc, minute_utc = parsed
|
||||
context.user_data["handover_utc_time"] = (hour_utc, minute_utc)
|
||||
context.user_data["awaiting_handover_time"] = False
|
||||
context.user_data["awaiting_duty_schedule_file"] = True
|
||||
await update.message.reply_text("Отправьте файл в формате duty-schedule (JSON).")
|
||||
|
||||
|
||||
async def handle_duty_schedule_document(
|
||||
update: Update, context: ContextTypes.DEFAULT_TYPE
|
||||
) -> None:
|
||||
if not update.message or not update.message.document or not update.effective_user:
|
||||
return
|
||||
if not context.user_data.get("awaiting_duty_schedule_file"):
|
||||
return
|
||||
handover = context.user_data.get("handover_utc_time")
|
||||
if not handover or not config.is_admin(update.effective_user.username or ""):
|
||||
return
|
||||
if not (update.message.document.file_name or "").lower().endswith(".json"):
|
||||
await update.message.reply_text("Нужен файл с расширением .json")
|
||||
return
|
||||
|
||||
hour_utc, minute_utc = handover
|
||||
file_id = update.message.document.file_id
|
||||
|
||||
file = await context.bot.get_file(file_id)
|
||||
raw = bytes(await file.download_as_bytearray())
|
||||
try:
|
||||
result = parse_duty_schedule(raw)
|
||||
except DutyScheduleParseError as e:
|
||||
context.user_data.pop("awaiting_duty_schedule_file", None)
|
||||
context.user_data.pop("handover_utc_time", None)
|
||||
await update.message.reply_text(f"Ошибка разбора файла: {e}")
|
||||
return
|
||||
|
||||
def run_import_with_scope():
|
||||
with session_scope(config.DATABASE_URL) as session:
|
||||
return run_import(session, result, hour_utc, minute_utc)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
num_users, num_duty, num_unavailable, num_vacation = await loop.run_in_executor(
|
||||
None, run_import_with_scope
|
||||
)
|
||||
except Exception as e:
|
||||
await update.message.reply_text(f"Ошибка импорта: {e}")
|
||||
else:
|
||||
total = num_duty + num_unavailable + num_vacation
|
||||
parts = [f"{num_users} пользователей", f"{num_duty} дежурств"]
|
||||
if num_unavailable:
|
||||
parts.append(f"{num_unavailable} недоступностей")
|
||||
if num_vacation:
|
||||
parts.append(f"{num_vacation} отпусков")
|
||||
await update.message.reply_text(
|
||||
"Импорт выполнен: " + ", ".join(parts) + f" (всего {total} событий)."
|
||||
)
|
||||
finally:
|
||||
context.user_data.pop("awaiting_duty_schedule_file", None)
|
||||
context.user_data.pop("handover_utc_time", None)
|
||||
|
||||
|
||||
import_duty_schedule_handler = CommandHandler(
|
||||
"import_duty_schedule", import_duty_schedule_cmd
|
||||
)
|
||||
handover_time_handler = MessageHandler(
|
||||
filters.TEXT & ~filters.COMMAND,
|
||||
handle_handover_time_text,
|
||||
)
|
||||
duty_schedule_document_handler = MessageHandler(
|
||||
filters.Document.FileExtension("json"),
|
||||
handle_duty_schedule_document,
|
||||
)
|
||||
1
duty_teller/importers/__init__.py
Normal file
1
duty_teller/importers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Importers for duty data (e.g. duty-schedule JSON)."""
|
||||
114
duty_teller/importers/duty_schedule.py
Normal file
114
duty_teller/importers/duty_schedule.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Parser for duty-schedule JSON format. No DB access."""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, timedelta
|
||||
|
||||
# Символы дежурства в ячейке duty (CSV с разделителем ;)
|
||||
DUTY_MARKERS = frozenset({"б", "Б", "в", "В"})
|
||||
UNAVAILABLE_MARKER = "Н"
|
||||
VACATION_MARKER = "О"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DutyScheduleEntry:
|
||||
"""One person's schedule: full_name and three lists of dates by event type."""
|
||||
|
||||
full_name: str
|
||||
duty_dates: list[date]
|
||||
unavailable_dates: list[date]
|
||||
vacation_dates: list[date]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DutyScheduleResult:
|
||||
"""Parsed duty schedule: start_date, end_date, and per-person entries."""
|
||||
|
||||
start_date: date
|
||||
end_date: date
|
||||
entries: list[DutyScheduleEntry]
|
||||
|
||||
|
||||
class DutyScheduleParseError(Exception):
|
||||
"""Invalid or missing fields in duty-schedule JSON."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def parse_duty_schedule(raw_bytes: bytes) -> DutyScheduleResult:
|
||||
"""Parse duty-schedule JSON. Returns start_date, end_date, and list of DutyScheduleEntry.
|
||||
|
||||
- meta.start_date (YYYY-MM-DD) and schedule (array) required.
|
||||
- meta.weeks optional; number of days from max duty string length (split by ';').
|
||||
- For each schedule item: name (required), duty = CSV with ';'; index i = start_date + i days.
|
||||
- Cell value after strip: в/В/б/Б => duty, Н => unavailable, О => vacation; rest ignored.
|
||||
"""
|
||||
try:
|
||||
data = json.loads(raw_bytes.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
raise DutyScheduleParseError(f"Invalid JSON or encoding: {e}") from e
|
||||
|
||||
meta = data.get("meta")
|
||||
if not meta or not isinstance(meta, dict):
|
||||
raise DutyScheduleParseError("Missing or invalid 'meta'")
|
||||
|
||||
start_str = meta.get("start_date")
|
||||
if not start_str or not isinstance(start_str, str):
|
||||
raise DutyScheduleParseError("Missing or invalid meta.start_date")
|
||||
try:
|
||||
start_date = date.fromisoformat(start_str.strip())
|
||||
except ValueError as e:
|
||||
raise DutyScheduleParseError(f"Invalid meta.start_date: {start_str}") from e
|
||||
|
||||
schedule = data.get("schedule")
|
||||
if not isinstance(schedule, list):
|
||||
raise DutyScheduleParseError("Missing or invalid 'schedule' (must be array)")
|
||||
|
||||
max_days = 0
|
||||
entries: list[DutyScheduleEntry] = []
|
||||
|
||||
for row in schedule:
|
||||
if not isinstance(row, dict):
|
||||
raise DutyScheduleParseError("schedule item must be an object")
|
||||
name = row.get("name")
|
||||
if name is None or not isinstance(name, str):
|
||||
raise DutyScheduleParseError("schedule item must have 'name' (string)")
|
||||
full_name = name.strip()
|
||||
if not full_name:
|
||||
raise DutyScheduleParseError("schedule item 'name' cannot be empty")
|
||||
|
||||
duty_str = row.get("duty")
|
||||
if duty_str is None:
|
||||
duty_str = ""
|
||||
if not isinstance(duty_str, str):
|
||||
raise DutyScheduleParseError("schedule item 'duty' must be string")
|
||||
|
||||
cells = [c.strip() for c in duty_str.split(";")]
|
||||
max_days = max(max_days, len(cells))
|
||||
|
||||
duty_dates: list[date] = []
|
||||
unavailable_dates: list[date] = []
|
||||
vacation_dates: list[date] = []
|
||||
for i, cell in enumerate(cells):
|
||||
d = start_date + timedelta(days=i)
|
||||
if cell in DUTY_MARKERS:
|
||||
duty_dates.append(d)
|
||||
elif cell == UNAVAILABLE_MARKER:
|
||||
unavailable_dates.append(d)
|
||||
elif cell == VACATION_MARKER:
|
||||
vacation_dates.append(d)
|
||||
entries.append(
|
||||
DutyScheduleEntry(
|
||||
full_name=full_name,
|
||||
duty_dates=duty_dates,
|
||||
unavailable_dates=unavailable_dates,
|
||||
vacation_dates=vacation_dates,
|
||||
)
|
||||
)
|
||||
|
||||
if max_days == 0:
|
||||
end_date = start_date
|
||||
else:
|
||||
end_date = start_date + timedelta(days=max_days - 1)
|
||||
|
||||
return DutyScheduleResult(start_date=start_date, end_date=end_date, entries=entries)
|
||||
84
duty_teller/run.py
Normal file
84
duty_teller/run.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Application entry point: build bot Application, run HTTP server + polling."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import urllib.request
|
||||
|
||||
from telegram.ext import ApplicationBuilder
|
||||
|
||||
from duty_teller import config
|
||||
from duty_teller.config import require_bot_token
|
||||
from duty_teller.handlers import group_duty_pin, register_handlers
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _set_default_menu_button_webapp() -> None:
|
||||
if not (config.MINI_APP_BASE_URL and config.BOT_TOKEN):
|
||||
return
|
||||
menu_url = (config.MINI_APP_BASE_URL.rstrip("/") + "/app/").strip()
|
||||
if not menu_url.startswith("https://"):
|
||||
return
|
||||
payload = {
|
||||
"menu_button": {
|
||||
"type": "web_app",
|
||||
"text": "Календарь",
|
||||
"web_app": {"url": menu_url},
|
||||
}
|
||||
}
|
||||
req = urllib.request.Request(
|
||||
f"https://api.telegram.org/bot{config.BOT_TOKEN}/setChatMenuButton",
|
||||
data=json.dumps(payload).encode(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
if resp.status == 200:
|
||||
logger.info("Default menu button set to Web App: %s", menu_url)
|
||||
else:
|
||||
logger.warning("setChatMenuButton returned %s", resp.status)
|
||||
except Exception as e:
|
||||
logger.warning("Could not set menu button: %s", e)
|
||||
|
||||
|
||||
def _run_uvicorn(web_app, port: int) -> None:
|
||||
import uvicorn
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
server = uvicorn.Server(
|
||||
uvicorn.Config(web_app, host="0.0.0.0", port=port, log_level="info"),
|
||||
)
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Build the bot and FastAPI, start uvicorn in a thread, run polling."""
|
||||
require_bot_token()
|
||||
# _set_default_menu_button_webapp()
|
||||
app = (
|
||||
ApplicationBuilder()
|
||||
.token(config.BOT_TOKEN)
|
||||
.post_init(group_duty_pin.restore_group_pin_jobs)
|
||||
.build()
|
||||
)
|
||||
register_handlers(app)
|
||||
|
||||
from duty_teller.api.app import app as web_app
|
||||
|
||||
t = threading.Thread(
|
||||
target=_run_uvicorn,
|
||||
args=(web_app, config.HTTP_PORT),
|
||||
daemon=True,
|
||||
)
|
||||
t.start()
|
||||
|
||||
logger.info("Bot starting (polling)... HTTP API on port %s", config.HTTP_PORT)
|
||||
app.run_polling(allowed_updates=["message", "my_chat_member"])
|
||||
23
duty_teller/services/__init__.py
Normal file
23
duty_teller/services/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Service layer: business logic and orchestration."""
|
||||
|
||||
from duty_teller.services.group_duty_pin_service import (
|
||||
format_duty_message,
|
||||
get_duty_message_text,
|
||||
get_next_shift_end_utc,
|
||||
save_pin,
|
||||
delete_pin,
|
||||
get_message_id,
|
||||
get_all_pin_chat_ids,
|
||||
)
|
||||
from duty_teller.services.import_service import run_import
|
||||
|
||||
__all__ = [
|
||||
"format_duty_message",
|
||||
"get_duty_message_text",
|
||||
"get_next_shift_end_utc",
|
||||
"save_pin",
|
||||
"delete_pin",
|
||||
"get_message_id",
|
||||
"get_all_pin_chat_ids",
|
||||
"run_import",
|
||||
]
|
||||
86
duty_teller/services/group_duty_pin_service.py
Normal file
86
duty_teller/services/group_duty_pin_service.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Group duty pin: current duty message text, next shift end, pin CRUD. All accept session."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from duty_teller.db.repository import (
|
||||
get_current_duty,
|
||||
get_next_shift_end,
|
||||
get_group_duty_pin,
|
||||
save_group_duty_pin,
|
||||
delete_group_duty_pin,
|
||||
get_all_group_duty_pin_chat_ids,
|
||||
)
|
||||
|
||||
|
||||
def format_duty_message(duty, user, tz_name: str) -> str:
|
||||
"""Build the text for the pinned message. duty, user may be None."""
|
||||
if duty is None or user is None:
|
||||
return "Сейчас дежурства нет."
|
||||
try:
|
||||
tz = ZoneInfo(tz_name)
|
||||
except Exception:
|
||||
tz = ZoneInfo("Europe/Moscow")
|
||||
tz_name = "Europe/Moscow"
|
||||
start_dt = datetime.fromisoformat(duty.start_at.replace("Z", "+00:00"))
|
||||
end_dt = datetime.fromisoformat(duty.end_at.replace("Z", "+00:00"))
|
||||
start_local = start_dt.astimezone(tz)
|
||||
end_local = end_dt.astimezone(tz)
|
||||
offset_sec = (
|
||||
start_local.utcoffset().total_seconds() if start_local.utcoffset() else 0
|
||||
)
|
||||
sign = "+" if offset_sec >= 0 else "-"
|
||||
h, r = divmod(abs(int(offset_sec)), 3600)
|
||||
m = r // 60
|
||||
tz_hint = f"UTC{sign}{h:d}:{m:02d}, {tz_name}"
|
||||
time_range = (
|
||||
f"{start_local.strftime('%d.%m.%Y %H:%M')} — "
|
||||
f"{end_local.strftime('%d.%m.%Y %H:%M')} ({tz_hint})"
|
||||
)
|
||||
lines = [
|
||||
f"🕐 Дежурство: {time_range}",
|
||||
f"👤 {user.full_name}",
|
||||
]
|
||||
if user.phone:
|
||||
lines.append(f"📞 {user.phone}")
|
||||
if user.username:
|
||||
lines.append(f"@{user.username}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def get_duty_message_text(session: Session, tz_name: str) -> str:
|
||||
"""Get current duty from DB and return formatted message."""
|
||||
now = datetime.now(timezone.utc)
|
||||
result = get_current_duty(session, now)
|
||||
if result is None:
|
||||
return "Сейчас дежурства нет."
|
||||
duty, user = result
|
||||
return format_duty_message(duty, user, tz_name)
|
||||
|
||||
|
||||
def get_next_shift_end_utc(session: Session) -> datetime | None:
|
||||
"""Return next shift end as naive UTC datetime for job scheduling."""
|
||||
return get_next_shift_end(session, datetime.now(timezone.utc))
|
||||
|
||||
|
||||
def save_pin(session: Session, chat_id: int, message_id: int) -> None:
|
||||
"""Save or update the pinned message record for a chat."""
|
||||
save_group_duty_pin(session, chat_id, message_id)
|
||||
|
||||
|
||||
def delete_pin(session: Session, chat_id: int) -> None:
|
||||
"""Remove the pinned message record when the bot leaves the group."""
|
||||
delete_group_duty_pin(session, chat_id)
|
||||
|
||||
|
||||
def get_message_id(session: Session, chat_id: int) -> int | None:
|
||||
"""Return message_id for the pin in this chat, or None."""
|
||||
pin = get_group_duty_pin(session, chat_id)
|
||||
return pin.message_id if pin else None
|
||||
|
||||
|
||||
def get_all_pin_chat_ids(session: Session) -> list[int]:
|
||||
"""Return all chat_ids that have a pinned duty message (for restoring jobs on startup)."""
|
||||
return get_all_group_duty_pin_chat_ids(session)
|
||||
70
duty_teller/services/import_service.py
Normal file
70
duty_teller/services/import_service.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Import duty schedule: delete range, insert duties/unavailable/vacation. Accepts session."""
|
||||
|
||||
from datetime import date, timedelta
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from duty_teller.db.repository import (
|
||||
get_or_create_user_by_full_name,
|
||||
delete_duties_in_range,
|
||||
insert_duty,
|
||||
)
|
||||
from duty_teller.importers.duty_schedule import DutyScheduleResult
|
||||
from duty_teller.utils.dates import day_start_iso, day_end_iso, duty_to_iso
|
||||
|
||||
|
||||
def _consecutive_date_ranges(dates: list[date]) -> list[tuple[date, date]]:
|
||||
"""Sort dates and merge consecutive ones into (first, last) ranges. Empty list -> []."""
|
||||
if not dates:
|
||||
return []
|
||||
sorted_dates = sorted(set(dates))
|
||||
ranges: list[tuple[date, date]] = []
|
||||
start_d = end_d = sorted_dates[0]
|
||||
for d in sorted_dates[1:]:
|
||||
if (d - end_d).days == 1:
|
||||
end_d = d
|
||||
else:
|
||||
ranges.append((start_d, end_d))
|
||||
start_d = end_d = d
|
||||
ranges.append((start_d, end_d))
|
||||
return ranges
|
||||
|
||||
|
||||
def run_import(
|
||||
session: Session,
|
||||
result: DutyScheduleResult,
|
||||
hour_utc: int,
|
||||
minute_utc: int,
|
||||
) -> tuple[int, int, int, int]:
|
||||
"""Run import: delete range per user, insert duty/unavailable/vacation. Returns (num_users, num_duty, num_unavailable, num_vacation)."""
|
||||
from_date_str = result.start_date.isoformat()
|
||||
to_date_str = result.end_date.isoformat()
|
||||
num_duty = num_unavailable = num_vacation = 0
|
||||
for entry in result.entries:
|
||||
user = get_or_create_user_by_full_name(session, entry.full_name)
|
||||
delete_duties_in_range(session, user.id, from_date_str, to_date_str)
|
||||
for d in entry.duty_dates:
|
||||
start_at = duty_to_iso(d, hour_utc, minute_utc)
|
||||
d_next = d + timedelta(days=1)
|
||||
end_at = duty_to_iso(d_next, hour_utc, minute_utc)
|
||||
insert_duty(session, user.id, start_at, end_at, event_type="duty")
|
||||
num_duty += 1
|
||||
for d in entry.unavailable_dates:
|
||||
insert_duty(
|
||||
session,
|
||||
user.id,
|
||||
day_start_iso(d),
|
||||
day_end_iso(d),
|
||||
event_type="unavailable",
|
||||
)
|
||||
num_unavailable += 1
|
||||
for start_d, end_d in _consecutive_date_ranges(entry.vacation_dates):
|
||||
insert_duty(
|
||||
session,
|
||||
user.id,
|
||||
day_start_iso(start_d),
|
||||
day_end_iso(end_d),
|
||||
event_type="vacation",
|
||||
)
|
||||
num_vacation += 1
|
||||
return (len(result.entries), num_duty, num_unavailable, num_vacation)
|
||||
21
duty_teller/utils/__init__.py
Normal file
21
duty_teller/utils/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Shared utilities: date/ISO helpers, user display names, handover time parsing."""
|
||||
|
||||
from duty_teller.utils.dates import (
|
||||
day_end_iso,
|
||||
day_start_iso,
|
||||
duty_to_iso,
|
||||
parse_iso_date,
|
||||
validate_date_range,
|
||||
)
|
||||
from duty_teller.utils.user import build_full_name
|
||||
from duty_teller.utils.handover import parse_handover_time
|
||||
|
||||
__all__ = [
|
||||
"day_start_iso",
|
||||
"day_end_iso",
|
||||
"duty_to_iso",
|
||||
"parse_iso_date",
|
||||
"validate_date_range",
|
||||
"build_full_name",
|
||||
"parse_handover_time",
|
||||
]
|
||||
41
duty_teller/utils/dates.py
Normal file
41
duty_teller/utils/dates.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Date and ISO helpers for duty ranges and API validation."""
|
||||
|
||||
import re
|
||||
from datetime import date, datetime, timezone
|
||||
|
||||
|
||||
def day_start_iso(d: date) -> str:
|
||||
"""ISO 8601 start of calendar day UTC: YYYY-MM-DDT00:00:00Z."""
|
||||
return d.isoformat() + "T00:00:00Z"
|
||||
|
||||
|
||||
def day_end_iso(d: date) -> str:
|
||||
"""ISO 8601 end of calendar day UTC: YYYY-MM-DDT23:59:59Z."""
|
||||
return d.isoformat() + "T23:59:59Z"
|
||||
|
||||
|
||||
def duty_to_iso(d: date, hour_utc: int, minute_utc: int) -> str:
|
||||
"""ISO 8601 with Z for start of duty on date d at given UTC time."""
|
||||
dt = datetime(d.year, d.month, d.day, hour_utc, minute_utc, 0, tzinfo=timezone.utc)
|
||||
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
_ISO_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
|
||||
|
||||
|
||||
def parse_iso_date(s: str) -> date | None:
|
||||
"""Parse YYYY-MM-DD string to date. Returns None if invalid."""
|
||||
if not s or not _ISO_DATE_RE.match(s.strip()):
|
||||
return None
|
||||
try:
|
||||
return date.fromisoformat(s.strip())
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def validate_date_range(from_date: str, to_date: str) -> None:
|
||||
"""Validate from_date and to_date are YYYY-MM-DD and from_date <= to_date. Raises ValueError if invalid."""
|
||||
if not _ISO_DATE_RE.match(from_date or "") or not _ISO_DATE_RE.match(to_date or ""):
|
||||
raise ValueError("Параметры from и to должны быть в формате YYYY-MM-DD")
|
||||
if from_date > to_date:
|
||||
raise ValueError("Дата from не должна быть позже to")
|
||||
37
duty_teller/utils/handover.py
Normal file
37
duty_teller/utils/handover.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Handover time parsing for duty schedule import."""
|
||||
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# HH:MM or HH:MM:SS, optional space + timezone (IANA or "UTC")
|
||||
HANDOVER_TIME_RE = re.compile(
|
||||
r"^\s*(\d{1,2}):(\d{2})(?::(\d{2}))?\s*(?:\s+(\S+))?\s*$", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
def parse_handover_time(text: str) -> tuple[int, int] | None:
|
||||
"""Parse handover time string to (hour_utc, minute_utc). Returns None on failure."""
|
||||
m = HANDOVER_TIME_RE.match(text)
|
||||
if not m:
|
||||
return None
|
||||
hour = int(m.group(1))
|
||||
minute = int(m.group(2))
|
||||
# second = m.group(3) ignored
|
||||
tz_str = (m.group(4) or "").strip()
|
||||
if not tz_str or tz_str.upper() == "UTC":
|
||||
return (hour % 24, minute)
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
except ImportError:
|
||||
try:
|
||||
from backports.zoneinfo import ZoneInfo # type: ignore
|
||||
except ImportError:
|
||||
return None
|
||||
try:
|
||||
tz = ZoneInfo(tz_str)
|
||||
except Exception:
|
||||
return None
|
||||
# Build datetime in that tz and convert to UTC
|
||||
dt = datetime(2000, 1, 1, hour, minute, 0, tzinfo=tz)
|
||||
utc = dt.astimezone(timezone.utc)
|
||||
return (utc.hour, utc.minute)
|
||||
8
duty_teller/utils/user.py
Normal file
8
duty_teller/utils/user.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""User display name helpers."""
|
||||
|
||||
|
||||
def build_full_name(first_name: str | None, last_name: str | None) -> str:
|
||||
"""Build display full name from first and last name. Returns 'User' if both empty."""
|
||||
parts = [first_name or "", last_name or ""]
|
||||
full = " ".join(filter(None, parts)).strip()
|
||||
return full or "User"
|
||||
Reference in New Issue
Block a user