180 lines
6 KiB
Python
180 lines
6 KiB
Python
|
|
"""Tests for the obs.costs cost ledger and price table."""
|
||
|
|
|
||
|
|
import json
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from obs.costs import (
|
||
|
|
DEFAULT_PRICES_PATH,
|
||
|
|
SEED_PRICES_TOML,
|
||
|
|
CostLedger,
|
||
|
|
PriceTable,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class TestPriceTable:
    """Tests for ``PriceTable``: seeding the prices file and estimating call costs.

    Every test points the table at a file under ``tmp_path`` so the user's
    real price table is never read or written.
    """

    def test_seeds_missing_file(self, tmp_path):
        """A missing prices file is created from the seed and loaded."""
        prices_path = tmp_path / "prices.toml"
        assert not prices_path.exists()

        table = PriceTable(path=str(prices_path))

        assert prices_path.exists()
        assert "claude-sonnet-4-6" in prices_path.read_text()
        # Loaded into memory, not just written to disk. This is a parsed TOML
        # literal (no arithmetic), so an exact comparison is safe.
        assert table._data["models"]["claude-sonnet-4-6"]["input_per_mtok_usd"] == 3.00

    def test_does_not_overwrite_existing_file(self, tmp_path):
        """An existing prices file is loaded as-is and left untouched on disk."""
        prices_path = tmp_path / "prices.toml"
        original = (
            '[models."custom-model"]\n'
            'input_per_mtok_usd = 1.23\n'
            'output_per_mtok_usd = 4.56\n'
        )
        prices_path.write_text(original)

        table = PriceTable(path=str(prices_path))

        # The user's file must be byte-for-byte unchanged (no seed merged in).
        assert prices_path.read_text() == original
        assert table._data["models"]["custom-model"]["input_per_mtok_usd"] == 1.23
        assert "claude-sonnet-4-6" not in table._data.get("models", {})

    def test_estimates_known_model(self, tmp_path):
        """Token costs use the seeded per-million-token input/output prices."""
        table = PriceTable(path=str(tmp_path / "prices.toml"))
        cost = table.estimate_call_usd(
            model_id="claude-sonnet-4-6",
            tokens_input=1_000_000,
            tokens_output=1_000_000,
            tavily_searches=0,
        )
        # 1M input @ $3 + 1M output @ $15 = $18, no tavily.
        # approx: the estimate is a float sum whose rounding depends on the
        # order of operations inside estimate_call_usd.
        assert cost == pytest.approx(18.00)

    def test_estimates_with_tavily(self, tmp_path):
        """Tavily searches are billed per search on top of token costs."""
        table = PriceTable(path=str(tmp_path / "prices.toml"))
        cost = table.estimate_call_usd(
            model_id="claude-sonnet-4-6",
            tokens_input=0,
            tokens_output=0,
            tavily_searches=10,
        )
        # 10 * $0.005 = $0.05 (0.005 has no exact binary representation).
        assert cost == pytest.approx(0.05)

    def test_unknown_model_returns_none(self, tmp_path):
        """A model missing from the price table yields ``None``, not a guess."""
        table = PriceTable(path=str(tmp_path / "prices.toml"))
        cost = table.estimate_call_usd(
            model_id="some-future-model",
            tokens_input=1000,
            tokens_output=1000,
            tavily_searches=0,
        )
        assert cost is None
|
||
|
|
|
||
|
|
|
||
|
|
class TestCostLedger:
    """Tests for ``CostLedger``: JSONL persistence and per-call cost estimation.

    The ledger file and the price table both live under ``tmp_path`` so the
    tests never touch the user's real ledger or prices.
    """

    def _ledger(self, tmp_path):
        """Build a ledger whose file and price table both live in ``tmp_path``."""
        return CostLedger(
            ledger_path=str(tmp_path / "costs.jsonl"),
            price_table=PriceTable(path=str(tmp_path / "prices.toml")),
        )

    def _record(self, ledger, **overrides):
        """Call ``ledger.record`` with filler values, replaced by ``overrides``.

        Keeps each test focused on the fields it actually asserts on.
        Returns whatever ``ledger.record`` returns.
        """
        kwargs = dict(
            trace_id="abc",
            question="q",
            model_id="claude-sonnet-4-6",
            tokens_used=10,
            tokens_input=5,
            tokens_output=5,
            iterations_run=1,
            wall_time_sec=1.0,
            tavily_searches=0,
            budget_exhausted=False,
            confidence=0.5,
        )
        kwargs.update(overrides)
        return ledger.record(**kwargs)

    def test_record_writes_jsonl(self, tmp_path):
        """One ``record`` call writes one JSON line equal to the returned entry."""
        ledger = self._ledger(tmp_path)
        # Spelled out explicitly (not via _record) because every field is
        # asserted on below.
        entry = ledger.record(
            trace_id="abc-123",
            question="What grows in Utah?",
            model_id="claude-sonnet-4-6",
            tokens_used=10_000,
            tokens_input=8_000,
            tokens_output=2_000,
            iterations_run=3,
            wall_time_sec=42.5,
            tavily_searches=4,
            budget_exhausted=False,
            confidence=0.9,
        )

        # File contains one JSON line
        lines = (tmp_path / "costs.jsonl").read_text().strip().splitlines()
        assert len(lines) == 1
        on_disk = json.loads(lines[0])
        assert on_disk == entry

        # All required fields present and shaped correctly
        assert on_disk["trace_id"] == "abc-123"
        assert on_disk["question"] == "What grows in Utah?"
        assert on_disk["model_id"] == "claude-sonnet-4-6"
        assert on_disk["tokens_used"] == 10_000
        assert on_disk["tokens_input"] == 8_000
        assert on_disk["tokens_output"] == 2_000
        assert on_disk["iterations_run"] == 3
        assert on_disk["wall_time_sec"] == 42.5
        assert on_disk["tavily_searches"] == 4
        assert on_disk["budget_exhausted"] is False
        assert on_disk["confidence"] == 0.9
        assert "timestamp" in on_disk
        # 8000 input @ $3/Mtok + 2000 output @ $15/Mtok + 4 * $0.005 = $0.074
        assert on_disk["estimated_cost_usd"] == pytest.approx(0.074, abs=1e-6)

    def test_record_appends(self, tmp_path):
        """Each ``record`` call appends a line; earlier lines are kept in order."""
        ledger = self._ledger(tmp_path)
        for i in range(3):
            self._record(ledger, trace_id=f"trace-{i}", question=f"q{i}")

        lines = (tmp_path / "costs.jsonl").read_text().strip().splitlines()
        # Checks both the count and the full ordering in one assertion.
        assert [json.loads(line)["trace_id"] for line in lines] == [
            "trace-0",
            "trace-1",
            "trace-2",
        ]

    def test_unknown_model_records_null_cost(self, tmp_path):
        """An unpriced model is recorded with a null cost, in memory and on disk."""
        ledger = self._ledger(tmp_path)
        entry = self._record(
            ledger,
            model_id="some-future-model",
            tokens_used=1000,
            tokens_input=500,
            tokens_output=500,
        )
        assert entry["estimated_cost_usd"] is None

        # The null must survive the JSON round trip (present and null, not 0).
        on_disk = json.loads((tmp_path / "costs.jsonl").read_text().strip())
        assert on_disk["estimated_cost_usd"] is None

    def test_question_is_truncated(self, tmp_path):
        """Long questions are capped at 200 characters in the recorded entry."""
        ledger = self._ledger(tmp_path)
        entry = self._record(ledger, question="x" * 1000)
        assert len(entry["question"]) == 200

        # The persisted line carries the same truncated question.
        on_disk = json.loads((tmp_path / "costs.jsonl").read_text().strip())
        assert on_disk["question"] == entry["question"]

    def test_env_var_override(self, tmp_path, monkeypatch):
        """``MARCHWARDEN_COST_LEDGER`` selects the ledger file when none is passed."""
        custom = tmp_path / "custom-ledger.jsonl"
        monkeypatch.setenv("MARCHWARDEN_COST_LEDGER", str(custom))
        ledger = CostLedger(
            price_table=PriceTable(path=str(tmp_path / "prices.toml")),
        )
        assert ledger.path == custom

        # Records must actually land in the overridden file, not just have
        # the attribute set.
        self._record(ledger, trace_id="env-override")
        assert json.loads(custom.read_text().strip())["trace_id"] == "env-override"
|