Add configuration rules, refactor settings management, and enhance import functionality
- Introduced a new configuration file `.cursorrules` to define coding standards, error handling, testing requirements, and project-specific guidelines. - Refactored `config.py` to implement a `Settings` dataclass for better management of environment variables, improving testability and maintainability. - Updated the import duty schedule handler to utilize session management with `session_scope`, ensuring proper database session handling. - Enhanced the import service to streamline the duty schedule import process, improving code organization and readability. - Added new service layer functions to encapsulate business logic related to group duty pinning and duty schedule imports. - Updated README documentation to reflect the new configuration structure and improved import functionality.
This commit is contained in:
@@ -2,13 +2,14 @@
|
||||
|
||||
from db.models import Base, User, Duty
|
||||
from db.schemas import UserCreate, UserInDb, DutyCreate, DutyInDb, DutyWithUser
|
||||
from db.session import get_engine, get_session_factory, get_session
|
||||
from db.session import get_engine, get_session_factory, get_session, session_scope
|
||||
from db.repository import (
|
||||
delete_duties_in_range,
|
||||
get_or_create_user,
|
||||
get_or_create_user_by_full_name,
|
||||
get_duties,
|
||||
insert_duty,
|
||||
set_user_phone,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -23,11 +24,13 @@ __all__ = [
|
||||
"get_engine",
|
||||
"get_session_factory",
|
||||
"get_session",
|
||||
"session_scope",
|
||||
"delete_duties_in_range",
|
||||
"get_or_create_user",
|
||||
"get_or_create_user_by_full_name",
|
||||
"get_duties",
|
||||
"insert_duty",
|
||||
"set_user_phone",
|
||||
"init_db",
|
||||
]
|
||||
|
||||
|
||||
@@ -63,14 +63,13 @@ def delete_duties_in_range(
|
||||
) -> int:
|
||||
"""Delete all duties of the user that overlap [from_date, to_date] (YYYY-MM-DD). Returns count deleted."""
|
||||
# start_at < to_date + 1 day so duties starting on to_date are included (start_at is ISO with T)
|
||||
to_next = (datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
q = (
|
||||
session.query(Duty)
|
||||
.filter(
|
||||
Duty.user_id == user_id,
|
||||
Duty.start_at < to_next,
|
||||
Duty.end_at >= from_date,
|
||||
)
|
||||
to_next = (
|
||||
datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)
|
||||
).strftime("%Y-%m-%d")
|
||||
q = session.query(Duty).filter(
|
||||
Duty.user_id == user_id,
|
||||
Duty.start_at < to_next,
|
||||
Duty.end_at >= from_date,
|
||||
)
|
||||
count = q.count()
|
||||
q.delete(synchronize_session=False)
|
||||
@@ -89,7 +88,9 @@ def get_duties(
|
||||
in UTC (ISO 8601 with Z). Use to_date_next so duties starting on to_date are included
|
||||
(start_at like 2025-01-31T09:00:00Z is > "2025-01-31" lexicographically).
|
||||
"""
|
||||
to_date_next = (datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
to_date_next = (
|
||||
datetime.fromisoformat(to_date + "T00:00:00") + timedelta(days=1)
|
||||
).strftime("%Y-%m-%d")
|
||||
q = (
|
||||
session.query(Duty, User.full_name)
|
||||
.join(User, Duty.user_id == User.id)
|
||||
@@ -119,12 +120,11 @@ def insert_duty(
|
||||
return duty
|
||||
|
||||
|
||||
def get_current_duty(
|
||||
session: Session, at_utc: datetime
|
||||
) -> tuple[Duty, User] | None:
|
||||
def get_current_duty(session: Session, at_utc: datetime) -> tuple[Duty, User] | None:
|
||||
"""Return the duty (and user) for which start_at <= at_utc < end_at, event_type='duty'.
|
||||
at_utc is in UTC (naive or aware); comparison uses ISO strings."""
|
||||
from datetime import timezone
|
||||
|
||||
if at_utc.tzinfo is not None:
|
||||
at_utc = at_utc.astimezone(timezone.utc)
|
||||
now_iso = at_utc.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
|
||||
@@ -147,6 +147,7 @@ def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None
|
||||
"""Return the end_at of the current duty (if after_utc is inside one) or of the next duty.
|
||||
For scheduling the next pin update. Returns naive UTC datetime."""
|
||||
from datetime import timezone
|
||||
|
||||
if after_utc.tzinfo is not None:
|
||||
after_utc = after_utc.astimezone(timezone.utc)
|
||||
after_iso = after_utc.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
|
||||
@@ -161,7 +162,9 @@ def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None
|
||||
.first()
|
||||
)
|
||||
if current:
|
||||
return datetime.fromisoformat(current.end_at.replace("Z", "+00:00")).replace(tzinfo=None)
|
||||
return datetime.fromisoformat(current.end_at.replace("Z", "+00:00")).replace(
|
||||
tzinfo=None
|
||||
)
|
||||
# Next future duty: start_at > after_iso, order by start_at
|
||||
next_duty = (
|
||||
session.query(Duty)
|
||||
@@ -170,7 +173,9 @@ def get_next_shift_end(session: Session, after_utc: datetime) -> datetime | None
|
||||
.first()
|
||||
)
|
||||
if next_duty:
|
||||
return datetime.fromisoformat(next_duty.end_at.replace("Z", "+00:00")).replace(tzinfo=None)
|
||||
return datetime.fromisoformat(next_duty.end_at.replace("Z", "+00:00")).replace(
|
||||
tzinfo=None
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -206,7 +211,9 @@ def get_all_group_duty_pin_chat_ids(session: Session) -> list[int]:
|
||||
return [r[0] for r in rows]
|
||||
|
||||
|
||||
def set_user_phone(session: Session, telegram_user_id: int, phone: str | None) -> User | None:
|
||||
def set_user_phone(
|
||||
session: Session, telegram_user_id: int, phone: str | None
|
||||
) -> User | None:
|
||||
"""Set phone for user by telegram_user_id. Returns User or None if not found."""
|
||||
user = session.query(User).filter(User.telegram_user_id == telegram_user_id).first()
|
||||
if not user:
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""SQLAlchemy engine and session factory.
|
||||
|
||||
Note: Engine and session factory are cached globally per process. Only one
|
||||
DATABASE_URL is effectively used for the process lifetime. Using a different
|
||||
URL later (e.g. in tests with in-memory SQLite) would still use the first
|
||||
engine. To support multiple URLs, cache by database_url (e.g. a dict keyed by URL).
|
||||
Engine and session factory are cached globally per process. Only one DATABASE_URL
|
||||
is effectively used for the process lifetime. Using a different URL later (e.g. in
|
||||
tests with in-memory SQLite) would still use the first engine. To use a different
|
||||
URL in tests, set env (e.g. DATABASE_URL) before the first import of this module, or
|
||||
clear _engine and _SessionLocal in test fixtures. Prefer session_scope() for all
|
||||
callers so sessions are always closed and rolled back on error.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
Reference in New Issue
Block a user