diff --git a/src/quartermaster/month_service.py b/src/quartermaster/month_service.py new file mode 100644 index 0000000..7c96387 --- /dev/null +++ b/src/quartermaster/month_service.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from datetime import date +from decimal import Decimal +from enum import Enum + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from quartermaster.models import ( + DebtTarget, + Entry, + Month, + MonthDebtTarget, + MonthEntry, + Section, +) + +YEAR_MONTH_RE = re.compile(r"^\d{4}-(0[1-9]|1[0-2])$") + + +class DeviationState(str, Enum): + unchanged = "unchanged" + edited = "edited" + new_in_month = "new_in_month" + + +@dataclass(frozen=True) +class MonthRow: + entry: MonthEntry + state: DeviationState + + +@dataclass(frozen=True) +class MonthSectionView: + section: Section + label: str + rows: list[MonthRow] + total_planned: Decimal + total_applied: Decimal + + +def valid_year_month(year_month: str) -> bool: + return bool(YEAR_MONTH_RE.match(year_month)) + + +def current_year_month() -> str: + today = date.today() + return f"{today.year:04d}-{today.month:02d}" + + +def shift_year_month(year_month: str, delta: int) -> str: + year, month = (int(part) for part in year_month.split("-")) + index = (year * 12 + (month - 1)) + delta + new_year, new_month0 = divmod(index, 12) + return f"{new_year:04d}-{new_month0 + 1:02d}" + + +def get_month(db: Session, year_month: str) -> Month | None: + stmt = select(Month).where(Month.year_month == year_month) + return db.scalar(stmt) + + +def list_months(db: Session) -> list[Month]: + stmt = select(Month).order_by(Month.year_month) + return list(db.scalars(stmt)) + + +def create_month(db: Session, year_month: str) -> Month: + if not valid_year_month(year_month): + raise ValueError("year_month must be formatted as YYYY-MM") + existing = get_month(db, year_month) + if existing is not None: + return existing + + month = Month(year_month=year_month) + db.add(month) + db.flush() + + budget_entries = list(db.scalars(select(Entry).order_by(Entry.id))) + source_to_month_entry: dict[int, MonthEntry] = {} + for e in budget_entries: + month_entry = MonthEntry( + month_id=month.id, + section=e.section, + name=e.name, + planned=e.amount, + applied=Decimal("0.00"), + origin_name=e.name, + origin_planned=e.amount, + source_entry_id=e.id, + ) + db.add(month_entry) + source_to_month_entry[e.id] = month_entry + + db.flush() + + budget_target = db.get(DebtTarget, 1) + target_entry_id: int | None = None + if budget_target is not None and budget_target.debt_minimum_id is not None: + mapped = source_to_month_entry.get(budget_target.debt_minimum_id) + if mapped is not None: + target_entry_id = mapped.id + + db.add( + MonthDebtTarget(month_id=month.id, month_entry_id=target_entry_id) + ) + db.commit() + db.refresh(month) + return month + + +def deviation_state(entry: MonthEntry) -> DeviationState: + if entry.origin_name is None or entry.origin_planned is None: + return DeviationState.new_in_month + if entry.name != entry.origin_name or entry.planned != entry.origin_planned: + return DeviationState.edited + return DeviationState.unchanged + + +def _rows(entries: list[MonthEntry]) -> list[MonthRow]: + return [MonthRow(entry=e, state=deviation_state(e)) for e in entries] + + +def section_view(month: Month, section: Section, label: str) -> MonthSectionView: + entries = [e for e in month.entries if e.section == section] + entries.sort(key=lambda e: e.id) + rows = _rows(entries) + total_planned = sum((e.planned for e in entries), Decimal("0")) + total_applied = sum((e.applied for e in entries), Decimal("0")) + return MonthSectionView( + section=section, + label=label, + rows=rows, + total_planned=total_planned, + total_applied=total_applied, + ) + + +def add_month_entry( + db: Session, + month: Month, + section: Section, + name: str, + planned: Decimal, +) -> MonthEntry: + entry = MonthEntry( + month_id=month.id, + section=section, + name=name.strip(), + planned=planned, + applied=Decimal("0.00"), + origin_name=None, + origin_planned=None, + source_entry_id=None, + ) + db.add(entry) + db.commit() + db.refresh(entry) + return entry + + +def get_month_entry(db: Session, month: Month, entry_id: int) -> MonthEntry | None: + entry = db.get(MonthEntry, entry_id) + if entry is None or entry.month_id != month.id: + return None + return entry + + +def delete_month_entry(db: Session, month: Month, entry_id: int) -> Section | None: + entry = get_month_entry(db, month, entry_id) + if entry is None: + return None + section = entry.section + db.delete(entry) + db.commit() + return section + + +def update_month_entry( + db: Session, + month: Month, + entry_id: int, + *, + name: str | None = None, + planned: Decimal | None = None, + applied: Decimal | None = None, +) -> MonthEntry | None: + entry = get_month_entry(db, month, entry_id) + if entry is None: + return None + if name is not None: + entry.name = name.strip() + if planned is not None: + entry.planned = planned + if applied is not None: + entry.applied = applied + db.commit() + db.refresh(entry) + return entry + + +def get_month_target(db: Session, month: Month) -> MonthDebtTarget: + if month.target is not None: + return month.target + target = MonthDebtTarget(month_id=month.id, month_entry_id=None) + db.add(target) + db.commit() + db.refresh(target) + return target + + +def set_month_target( + db: Session, month: Month, month_entry_id: int | None +) -> MonthDebtTarget: + target = get_month_target(db, month) + if month_entry_id is not None: + candidate = get_month_entry(db, month, month_entry_id) + if candidate is None or candidate.section != Section.debt_minimum: + raise ValueError( + "month_entry_id must reference a debt minimum entry in this month" + ) + target.month_entry_id = month_entry_id + db.commit() + db.refresh(target) + return target