feat: enhance error handling and configuration validation
Some checks failed
CI / lint-and-test (push) Failing after 27s

- Added a global exception handler to log unhandled exceptions and return a generic 500 JSON response without exposing details to the client.
- Updated the configuration to validate the `DATABASE_URL` format, ensuring it starts with `sqlite://` or `postgresql://`, and log warnings for invalid formats.
- Introduced safe parsing for numeric environment variables (`HTTP_PORT`, `INIT_DATA_MAX_AGE_SECONDS`) with defaults on invalid values, including logging warnings for out-of-range values.
- Enhanced the duty schedule parser to enforce limits on the number of schedule rows and the length of full names and duty strings, raising appropriate errors when exceeded.
- Updated internationalization messages to include generic error responses for import failures and parsing issues, improving user experience.
- Added unit tests to verify the new error handling and configuration validation behaviors.
This commit is contained in:
2026-03-02 23:36:03 +03:00
parent 43386b15fa
commit 7ffa727832
20 changed files with 451 additions and 70 deletions

View File

@@ -8,7 +8,7 @@ import duty_teller.config as config
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from sqlalchemy.orm import Session
@@ -42,6 +42,16 @@ def _is_valid_calendar_token(token: str) -> bool:
app = FastAPI(title="Duty Teller API")
@app.exception_handler(Exception)
def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Log unhandled exceptions and return 500 without exposing details to the client."""
log.exception("Unhandled exception: %s", exc)
return JSONResponse(
status_code=500,
content={"detail": "Internal server error"},
)
@app.get("/health", summary="Health check")
def health() -> dict:
"""Return 200 when the app is up. Used by Docker HEALTHCHECK."""
@@ -106,6 +116,18 @@ class NoCacheStaticMiddleware:
app.add_middleware(NoCacheStaticMiddleware)
# Allowed values for config.js to prevent script injection.
_VALID_LANGS = frozenset({"en", "ru"})
_VALID_LOG_LEVELS = frozenset({"debug", "info", "warning", "error"})
def _safe_js_string(value: str, allowed: frozenset[str], default: str) -> str:
"""Return value if it is in allowed set, else default. Prevents injection in config.js."""
if value in allowed:
return value
return default
@app.get(
"/app/config.js",
summary="Mini App config (language, log level)",
@@ -115,8 +137,8 @@ app.add_middleware(NoCacheStaticMiddleware)
)
def app_config_js() -> Response:
"""Return JS assigning window.__DT_LANG and window.__DT_LOG_LEVEL for the webapp. No caching."""
lang = config.DEFAULT_LANGUAGE
log_level = config.LOG_LEVEL_STR.lower()
lang = _safe_js_string(config.DEFAULT_LANGUAGE, _VALID_LANGS, "en")
log_level = _safe_js_string(config.LOG_LEVEL_STR.lower(), _VALID_LOG_LEVELS, "info")
body = f'window.__DT_LANG = "{lang}";\nwindow.__DT_LOG_LEVEL = "{log_level}";'
return Response(
content=body,
@@ -183,10 +205,10 @@ def get_team_calendar_ical(
) -> Response:
"""Return ICS calendar with all duties (event_type duty only). Token validates user."""
if not _is_valid_calendar_token(token):
return Response(status_code=404, content="Not found")
return JSONResponse(status_code=404, content={"detail": "Not found"})
user = get_user_by_calendar_token(session, token)
if user is None:
return Response(status_code=404, content="Not found")
return JSONResponse(status_code=404, content={"detail": "Not found"})
cache_key = ("team_ics",)
ics_bytes, found = ics_calendar_cache.get(cache_key)
if not found:
@@ -224,10 +246,10 @@ def get_personal_calendar_ical(
No Telegram auth; access is by secret token in the URL.
"""
if not _is_valid_calendar_token(token):
return Response(status_code=404, content="Not found")
return JSONResponse(status_code=404, content={"detail": "Not found"})
user = get_user_by_calendar_token(session, token)
if user is None:
return Response(status_code=404, content="Not found")
return JSONResponse(status_code=404, content={"detail": "Not found"})
cache_key = ("personal_ics", user.id)
ics_bytes, found = ics_calendar_cache.get(cache_key)
if not found:

View File

@@ -42,7 +42,12 @@ def _validate_duty_dates(from_date: str, to_date: str, lang: str) -> None:
try:
validate_date_range(from_date, to_date)
except DateRangeValidationError as e:
key = "dates.bad_format" if e.kind == "bad_format" else "dates.from_after_to"
key_map = {
"bad_format": "dates.bad_format",
"from_after_to": "dates.from_after_to",
"range_too_large": "dates.range_too_large",
}
key = key_map.get(e.kind, "dates.bad_format")
raise HTTPException(status_code=400, detail=t(lang, key)) from e
except ValueError as e:
# Backward compatibility if something else raises ValueError.

View File

@@ -1,7 +1,8 @@
"""Load configuration from environment (e.g. .env via python-dotenv).
BOT_TOKEN is not validated on import; call require_bot_token() in the entry point
when running the bot.
when running the bot. Numeric env vars (HTTP_PORT, INIT_DATA_MAX_AGE_SECONDS) use
safe parsing with defaults on invalid values.
"""
import logging
@@ -16,6 +17,11 @@ from duty_teller.i18n.lang import normalize_lang
load_dotenv()
logger = logging.getLogger(__name__)
# Valid port range for HTTP_PORT.
HTTP_PORT_MIN, HTTP_PORT_MAX = 1, 65535
# Project root (parent of duty_teller package). Used for webapp path, etc.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
@@ -55,6 +61,48 @@ def _normalize_log_level(raw: str) -> str:
return "INFO"
def _parse_int_env(
name: str, default: int, min_val: int | None = None, max_val: int | None = None
) -> int:
"""Parse an integer from os.environ; use default on invalid or out-of-range. Log on fallback."""
raw = os.getenv(name)
if raw is None or raw == "":
return default
try:
value = int(raw.strip())
except ValueError:
logger.warning(
"Invalid %s=%r (expected integer); using default %s",
name,
raw,
default,
)
return default
if min_val is not None and value < min_val:
logger.warning(
"%s=%s is below minimum %s; using %s", name, value, min_val, min_val
)
return min_val
if max_val is not None and value > max_val:
logger.warning(
"%s=%s is above maximum %s; using %s", name, value, max_val, max_val
)
return max_val
return value
def _validate_database_url(url: str) -> bool:
"""Return True if URL looks like a supported SQLAlchemy URL (sqlite or postgres)."""
if not url or not isinstance(url, str):
return False
u = url.strip().split("?", 1)[0].lower()
return (
u.startswith("sqlite://")
or u.startswith("postgresql://")
or u.startswith("postgres://")
)
@dataclass(frozen=True)
class Settings:
"""Injectable settings built from environment. Used in tests or when env is overridden."""
@@ -105,20 +153,30 @@ class Settings:
raw_host = (os.getenv("HTTP_HOST") or "127.0.0.1").strip()
http_host = raw_host if raw_host else "127.0.0.1"
bot_username = (os.getenv("BOT_USERNAME", "") or "").strip().lstrip("@").lower()
database_url = os.getenv("DATABASE_URL", "sqlite:///data/duty_teller.db")
if not _validate_database_url(database_url):
logger.warning(
"DATABASE_URL does not look like a supported URL (sqlite:// or postgresql://); "
"DB connection may fail."
)
http_port = _parse_int_env(
"HTTP_PORT", 8080, min_val=HTTP_PORT_MIN, max_val=HTTP_PORT_MAX
)
init_data_max_age = _parse_int_env("INIT_DATA_MAX_AGE_SECONDS", 0, min_val=0)
return cls(
bot_token=bot_token,
database_url=os.getenv("DATABASE_URL", "sqlite:///data/duty_teller.db"),
database_url=database_url,
bot_username=bot_username,
mini_app_base_url=os.getenv("MINI_APP_BASE_URL", "").rstrip("/"),
http_host=http_host,
http_port=int(os.getenv("HTTP_PORT", "8080")),
http_port=http_port,
allowed_usernames=allowed,
admin_usernames=admin,
allowed_phones=allowed_phones,
admin_phones=admin_phones,
mini_app_skip_auth=os.getenv("MINI_APP_SKIP_AUTH", "").strip()
in ("1", "true", "yes"),
init_data_max_age_seconds=int(os.getenv("INIT_DATA_MAX_AGE_SECONDS", "0")),
init_data_max_age_seconds=init_data_max_age,
cors_origins=cors,
external_calendar_ics_url=os.getenv(
"EXTERNAL_CALENDAR_ICS_URL", ""

View File

@@ -7,6 +7,7 @@ from datetime import datetime, timezone
from sqlalchemy.orm import Session
import duty_teller.config as config
from duty_teller.db.schemas import DUTY_EVENT_TYPES
from duty_teller.db.models import (
User,
Duty,
@@ -201,14 +202,19 @@ def get_or_create_user(
return user
def get_or_create_user_by_full_name(session: Session, full_name: str) -> User:
def get_or_create_user_by_full_name(
session: Session, full_name: str, *, commit: bool = True
) -> User:
"""Find user by exact full_name or create one (for duty-schedule import).
New users have telegram_user_id=None and name_manually_edited=True.
When commit=False, caller is responsible for committing (e.g. single commit
per import in run_import).
Args:
session: DB session.
full_name: Exact full name to match or set.
commit: If True, commit immediately. If False, caller commits.
Returns:
User instance (existing or newly created).
@@ -225,8 +231,11 @@ def get_or_create_user_by_full_name(session: Session, full_name: str) -> User:
name_manually_edited=True,
)
session.add(user)
session.commit()
session.refresh(user)
if commit:
session.commit()
session.refresh(user)
else:
session.flush() # Assign id so caller can use user.id before commit
return user
@@ -447,11 +456,13 @@ def insert_duty(
user_id: User id.
start_at: Start time UTC, ISO 8601 with Z (e.g. 2025-01-15T09:00:00Z).
end_at: End time UTC, ISO 8601 with Z.
event_type: One of "duty", "unavailable", "vacation". Default "duty".
event_type: One of "duty", "unavailable", "vacation". Invalid values are stored as "duty".
Returns:
Created Duty instance.
"""
if event_type not in DUTY_EVENT_TYPES:
event_type = "duty"
duty = Duty(
user_id=user_id,
start_at=start_at,

View File

@@ -67,7 +67,8 @@ async def set_phone(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
phone = " ".join(args).strip() if args else None
telegram_user_id = update.effective_user.id
def do_set_phone() -> str | None:
def do_set_phone() -> tuple[str, str | None]:
"""Returns (status, display_phone). status is 'error'|'saved'|'cleared'. display_phone for 'saved'."""
with session_scope(config.DATABASE_URL) as session:
full_name = build_full_name(
update.effective_user.first_name, update.effective_user.last_name
@@ -82,16 +83,20 @@ async def set_phone(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
)
user = set_user_phone(session, telegram_user_id, phone or None)
if user is None:
return "error"
return ("error", None)
if phone:
return "saved"
return "cleared"
return ("saved", user.phone or config.normalize_phone(phone))
return ("cleared", None)
result = await asyncio.get_running_loop().run_in_executor(None, do_set_phone)
result, display_phone = await asyncio.get_running_loop().run_in_executor(
None, do_set_phone
)
if result == "error":
await update.message.reply_text(t(lang, "set_phone.error"))
elif result == "saved":
await update.message.reply_text(t(lang, "set_phone.saved", phone=phone or ""))
await update.message.reply_text(
t(lang, "set_phone.saved", phone=display_phone or "")
)
else:
await update.message.reply_text(t(lang, "set_phone.cleared"))

View File

@@ -1,6 +1,7 @@
"""Import duty-schedule: /import_duty_schedule (admin only). Two steps: handover time -> JSON file."""
import asyncio
import logging
import duty_teller.config as config
from telegram import Update
@@ -16,6 +17,8 @@ from duty_teller.importers.duty_schedule import (
from duty_teller.services.import_service import run_import
from duty_teller.utils.handover import parse_handover_time
logger = logging.getLogger(__name__)
async def import_duty_schedule_cmd(
update: Update, context: ContextTypes.DEFAULT_TYPE
@@ -80,9 +83,10 @@ async def handle_duty_schedule_document(
try:
result = parse_duty_schedule(raw)
except DutyScheduleParseError as e:
logger.warning("Duty schedule parse error: %s", e, exc_info=True)
context.user_data.pop("awaiting_duty_schedule_file", None)
context.user_data.pop("handover_utc_time", None)
await update.message.reply_text(t(lang, "import.parse_error", error=str(e)))
await update.message.reply_text(t(lang, "import.parse_error_generic"))
return
def run_import_with_scope():
@@ -95,7 +99,8 @@ async def handle_duty_schedule_document(
None, run_import_with_scope
)
except Exception as e:
await update.message.reply_text(t(lang, "import.import_error", error=str(e)))
logger.exception("Import failed: %s", e)
await update.message.reply_text(t(lang, "import.import_error_generic"))
else:
total = num_duty + num_unavailable + num_vacation
unavailable_suffix = (

View File

@@ -72,7 +72,9 @@ MESSAGES: dict[str, dict[str, str]] = {
"import.send_json": "Send the duty-schedule file (JSON).",
"import.need_json": "File must have .json extension.",
"import.parse_error": "File parse error: {error}",
"import.parse_error_generic": "The file could not be parsed. Check the format and try again.",
"import.import_error": "Import error: {error}",
"import.import_error_generic": "Import failed. Please try again or contact an administrator.",
"import.done": (
"Import done: {users} users, {duties} duties{unavailable}{vacation} "
"({total} events total)."
@@ -88,6 +90,7 @@ MESSAGES: dict[str, dict[str, str]] = {
"api.access_denied": "Access denied",
"dates.bad_format": "Parameters from and to must be in YYYY-MM-DD format",
"dates.from_after_to": "from date must not be after to",
"dates.range_too_large": "Date range is too large. Request a shorter period.",
"contact.show": "Contacts",
"contact.back": "Back",
"current_duty.title": "Current Duty",
@@ -160,7 +163,9 @@ MESSAGES: dict[str, dict[str, str]] = {
"import.send_json": "Отправьте файл в формате duty-schedule (JSON).",
"import.need_json": "Нужен файл с расширением .json",
"import.parse_error": "Ошибка разбора файла: {error}",
"import.parse_error_generic": "Не удалось разобрать файл. Проверьте формат и попробуйте снова.",
"import.import_error": "Ошибка импорта: {error}",
"import.import_error_generic": "Импорт не выполнен. Попробуйте снова или обратитесь к администратору.",
"import.done": "Импорт выполнен: {users} пользователей, {duties} дежурств{unavailable}{vacation} (всего {total} событий).",
"import.done_unavailable": ", {count} недоступностей",
"import.done_vacation": ", {count} отпусков",
@@ -171,6 +176,7 @@ MESSAGES: dict[str, dict[str, str]] = {
"api.access_denied": "Доступ запрещён",
"dates.bad_format": "Параметры from и to должны быть в формате YYYY-MM-DD",
"dates.from_after_to": "Дата from не должна быть позже to",
"dates.range_too_large": "Диапазон дат слишком большой. Запросите более короткий период.",
"contact.show": "Контакты",
"contact.back": "Назад",
"current_duty.title": "Текущее дежурство",

View File

@@ -9,6 +9,11 @@ DUTY_MARKERS = frozenset({"б", "Б", "в", "В"})
UNAVAILABLE_MARKER = "Н"
VACATION_MARKER = "О"
# Limits to avoid abuse and unreasonable input.
MAX_SCHEDULE_ROWS = 500
MAX_FULL_NAME_LENGTH = 200
MAX_DUTY_STRING_LENGTH = 10000
@dataclass
class DutyScheduleEntry:
@@ -69,10 +74,24 @@ def parse_duty_schedule(raw_bytes: bytes) -> DutyScheduleResult:
except ValueError as e:
raise DutyScheduleParseError(f"Invalid meta.start_date: {start_str}") from e
# Reject dates outside current year ± 1.
today = date.today()
min_year = today.year - 1
max_year = today.year + 1
if not (min_year <= start_date.year <= max_year):
raise DutyScheduleParseError(
f"meta.start_date year must be between {min_year} and {max_year}"
)
schedule = data.get("schedule")
if not isinstance(schedule, list):
raise DutyScheduleParseError("Missing or invalid 'schedule' (must be array)")
if len(schedule) > MAX_SCHEDULE_ROWS:
raise DutyScheduleParseError(
f"schedule has too many rows (max {MAX_SCHEDULE_ROWS})"
)
max_days = 0
entries: list[DutyScheduleEntry] = []
@@ -85,12 +104,20 @@ def parse_duty_schedule(raw_bytes: bytes) -> DutyScheduleResult:
full_name = name.strip()
if not full_name:
raise DutyScheduleParseError("schedule item 'name' cannot be empty")
if len(full_name) > MAX_FULL_NAME_LENGTH:
raise DutyScheduleParseError(
f"schedule item 'name' must not exceed {MAX_FULL_NAME_LENGTH} characters"
)
duty_str = row.get("duty")
if duty_str is None:
duty_str = ""
if not isinstance(duty_str, str):
raise DutyScheduleParseError("schedule item 'duty' must be string")
if len(duty_str) > MAX_DUTY_STRING_LENGTH:
raise DutyScheduleParseError(
f"schedule item 'duty' must not exceed {MAX_DUTY_STRING_LENGTH} characters"
)
cells = [c.strip() for c in duty_str.split(";")]
max_days = max(max_days, len(cells))
@@ -120,4 +147,9 @@ def parse_duty_schedule(raw_bytes: bytes) -> DutyScheduleResult:
else:
end_date = start_date + timedelta(days=max_days - 1)
if not (min_year <= end_date.year <= max_year):
raise DutyScheduleParseError(
f"Computed end_date year must be between {min_year} and {max_year}"
)
return DutyScheduleResult(start_date=start_date, end_date=end_date, entries=entries)

View File

@@ -3,7 +3,9 @@
import asyncio
import json
import logging
import sys
import threading
import time
import urllib.request
from telegram.ext import ApplicationBuilder
@@ -13,6 +15,9 @@ from duty_teller.config import require_bot_token
from duty_teller.handlers import group_duty_pin, register_handlers
from duty_teller.utils.http_client import safe_urlopen
# Seconds to wait for HTTP server to bind before health check.
_HTTP_STARTUP_WAIT_SEC = 3
async def _resolve_bot_username(application) -> None:
"""If BOT_USERNAME is not set from env, resolve it via get_me()."""
@@ -69,6 +74,25 @@ def _run_uvicorn(web_app, port: int) -> None:
loop.run_until_complete(server.serve())
def _wait_for_http_ready(port: int) -> bool:
"""Return True if /health responds successfully within _HTTP_STARTUP_WAIT_SEC."""
host = config.HTTP_HOST
if host == "0.0.0.0":
host = "127.0.0.1"
url = f"http://{host}:{port}/health"
deadline = time.monotonic() + _HTTP_STARTUP_WAIT_SEC
while time.monotonic() < deadline:
try:
req = urllib.request.Request(url)
with safe_urlopen(req, timeout=2) as resp:
if resp.status == 200:
return True
except Exception as e:
logger.debug("Health check not ready yet: %s", e)
time.sleep(0.5)
return False
def main() -> None:
"""Build the bot and FastAPI, start uvicorn in a thread, run polling."""
require_bot_token()
@@ -85,16 +109,30 @@ def main() -> None:
from duty_teller.api.app import app as web_app
t = threading.Thread(
target=_run_uvicorn,
args=(web_app, config.HTTP_PORT),
daemon=True,
)
t.start()
if config.MINI_APP_SKIP_AUTH:
logger.warning(
"MINI_APP_SKIP_AUTH is set — API auth disabled (insecure); use only for dev"
)
if config.HTTP_HOST not in ("127.0.0.1", "localhost", ""):
print(
"ERROR: MINI_APP_SKIP_AUTH must not be used in production (non-localhost).",
file=sys.stderr,
)
sys.exit(1)
t = threading.Thread(
target=_run_uvicorn,
args=(web_app, config.HTTP_PORT),
daemon=False,
)
t.start()
if not _wait_for_http_ready(config.HTTP_PORT):
logger.error(
"HTTP server did not become ready on port %s within %s s; check port and permissions.",
config.HTTP_PORT,
_HTTP_STARTUP_WAIT_SEC,
)
sys.exit(1)
logger.info("Bot starting (polling)... HTTP API on port %s", config.HTTP_PORT)
app.run_polling(allowed_updates=["message", "my_chat_member"])

View File

@@ -1,5 +1,6 @@
"""Import duty schedule: delete range, insert duties/unavailable/vacation. Accepts session."""
import logging
from datetime import date, timedelta
from sqlalchemy.orm import Session
@@ -14,6 +15,8 @@ from duty_teller.db.repository import (
from duty_teller.importers.duty_schedule import DutyScheduleResult
from duty_teller.utils.dates import day_start_iso, day_end_iso, duty_to_iso
logger = logging.getLogger(__name__)
def _consecutive_date_ranges(dates: list[date]) -> list[tuple[date, date]]:
"""Sort dates and merge consecutive ones into (first, last) ranges. Empty list -> []."""
@@ -53,16 +56,24 @@ def run_import(
Returns:
Tuple (num_users, num_duty, num_unavailable, num_vacation).
"""
logger.info(
"Import started: range %s..%s, %d entries",
result.start_date,
result.end_date,
len(result.entries),
)
from_date_str = result.start_date.isoformat()
to_date_str = result.end_date.isoformat()
num_duty = num_unavailable = num_vacation = 0
# Batch: get all users by full_name, create missing
# Batch: get all users by full_name, create missing (no commit until end)
names = [e.full_name for e in result.entries]
users_map = get_users_by_full_names(session, names)
for name in names:
if name not in users_map:
users_map[name] = get_or_create_user_by_full_name(session, name)
users_map[name] = get_or_create_user_by_full_name(
session, name, commit=False
)
# Delete range per user (no commit)
for entry in result.entries:
@@ -113,4 +124,11 @@ def run_import(
session.bulk_insert_mappings(Duty, duty_rows)
session.commit()
invalidate_duty_related_caches()
logger.info(
"Import done: %d users, %d duty, %d unavailable, %d vacation",
len(result.entries),
num_duty,
num_unavailable,
num_vacation,
)
return (len(result.entries), num_duty, num_unavailable, num_vacation)

View File

@@ -24,10 +24,17 @@ def duty_to_iso(d: date, hour_utc: int, minute_utc: int) -> str:
_ISO_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
# Maximum allowed date range in days (e.g. 731 = 2 years).
MAX_DATE_RANGE_DAYS = 731
class DateRangeValidationError(ValueError):
"""Raised when from_date/to_date validation fails. API uses kind for i18n key."""
def __init__(self, kind: Literal["bad_format", "from_after_to"]) -> None:
def __init__(
self,
kind: Literal["bad_format", "from_after_to", "range_too_large"],
) -> None:
self.kind = kind
super().__init__(kind)
@@ -86,12 +93,20 @@ def parse_iso_date(s: str) -> date | None:
def validate_date_range(from_date: str, to_date: str) -> None:
"""Validate from_date and to_date are YYYY-MM-DD and from_date <= to_date.
"""Validate from_date and to_date are YYYY-MM-DD, from_date <= to_date, and range <= MAX_DATE_RANGE_DAYS.
Raises:
DateRangeValidationError: bad_format if format invalid, from_after_to if from > to.
DateRangeValidationError: bad_format if format invalid, from_after_to if from > to,
range_too_large if (to_date - from_date) > MAX_DATE_RANGE_DAYS.
"""
if not _ISO_DATE_RE.match(from_date or "") or not _ISO_DATE_RE.match(to_date or ""):
raise DateRangeValidationError("bad_format")
if from_date > to_date:
raise DateRangeValidationError("from_after_to")
try:
from_d = date.fromisoformat(from_date)
to_d = date.fromisoformat(to_date)
except ValueError:
raise DateRangeValidationError("bad_format") from None
if (to_d - from_d).days > MAX_DATE_RANGE_DAYS:
raise DateRangeValidationError("range_too_large")