merge: feat/issue-44-context-budget (#44)

This commit is contained in:
Jeff Smith 2026-04-06 22:49:44 -06:00
commit 40af515fb2
2 changed files with 80 additions and 5 deletions

View file

@ -31,8 +31,12 @@ from luminos_lib.tree import build_tree, render_tree
MODEL = "claude-sonnet-4-20250514" MODEL = "claude-sonnet-4-20250514"
# Context budget: trigger early exit at 70% of Sonnet's context window. # Context budget: trigger early exit when a single API call's input_tokens
MAX_CONTEXT = 180_000 # (the actual size of the context window in use, NOT the cumulative sum
# across turns) approaches the model's real context limit. Sonnet 4 has
# a 200k context window; we leave a 30% safety margin for the response
# and any tool result we're about to append.
MAX_CONTEXT = 200_000
CONTEXT_BUDGET = int(MAX_CONTEXT * 0.70) CONTEXT_BUDGET = int(MAX_CONTEXT * 0.70)
# Pricing per 1M tokens (Claude Sonnet). # Pricing per 1M tokens (Claude Sonnet).
@ -88,13 +92,25 @@ def _should_skip_dir(name):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class _TokenTracker: class _TokenTracker:
"""Track cumulative token usage across API calls.""" """Track token usage across API calls.
Two distinct quantities are tracked:
- cumulative totals (total_*, loop_*) for cost reporting
- last_input the size of the context window on the most recent
call, used to detect approaching the model's context limit
Cumulative input is NOT a meaningful proxy for context size: each
turn's input_tokens already includes the full message history, so
summing across turns double-counts everything. Use last_input for
budget decisions, totals for billing. (See #44.)
"""
def __init__(self): def __init__(self):
self.total_input = 0 self.total_input = 0
self.total_output = 0 self.total_output = 0
self.loop_input = 0 self.loop_input = 0
self.loop_output = 0 self.loop_output = 0
self.last_input = 0
def record(self, usage): def record(self, usage):
"""Record usage from a single API call.""" """Record usage from a single API call."""
@ -104,18 +120,21 @@ class _TokenTracker:
self.total_output += out self.total_output += out
self.loop_input += inp self.loop_input += inp
self.loop_output += out self.loop_output += out
self.last_input = inp
def reset_loop(self): def reset_loop(self):
"""Reset per-loop counters (called between directory loops).""" """Reset per-loop counters (called between directory loops)."""
self.loop_input = 0 self.loop_input = 0
self.loop_output = 0 self.loop_output = 0
self.last_input = 0
@property @property
def loop_total(self): def loop_total(self):
return self.loop_input + self.loop_output return self.loop_input + self.loop_output
def budget_exceeded(self): def budget_exceeded(self):
return self.loop_total > CONTEXT_BUDGET """True when the most recent call's context exceeded the budget."""
return self.last_input > CONTEXT_BUDGET
def summary(self): def summary(self):
cost_in = self.total_input * INPUT_PRICE_PER_M / 1_000_000 cost_in = self.total_input * INPUT_PRICE_PER_M / 1_000_000
@ -862,7 +881,10 @@ def _run_dir_loop(client, target, cache, tracker, dir_path, max_turns=14,
# Check context budget # Check context budget
if tracker.budget_exceeded(): if tracker.budget_exceeded():
print(f" [AI] Context budget reached — exiting early " print(f" [AI] Context budget reached — exiting early "
f"({tracker.loop_total:,} tokens used)", file=sys.stderr) f"(context size {tracker.last_input:,} > "
f"{CONTEXT_BUDGET:,} budget; "
f"loop spend {tracker.loop_total:,} tokens)",
file=sys.stderr)
# Flush a partial directory summary from cached file entries # Flush a partial directory summary from cached file entries
if not cache.has_entry("dir", dir_path): if not cache.has_entry("dir", dir_path):
dir_real = os.path.realpath(dir_path) dir_real = os.path.realpath(dir_path)

View file

@ -108,6 +108,59 @@ class FormatSurveyBlockTests(unittest.TestCase):
self.assertNotIn("Skip tools", block) self.assertNotIn("Skip tools", block)
class TokenTrackerTests(unittest.TestCase):
    """Checks for ``ai._TokenTracker``.

    Covers the split between cumulative counters (``total_*``, ``loop_*``)
    and ``last_input``, and verifies that ``budget_exceeded()`` follows
    ``last_input`` rather than the running sum.
    """

    def _usage(self, inp, out):
        """Return a mock usage object carrying the given token counts."""
        return MagicMock(input_tokens=inp, output_tokens=out)

    def test_record_updates_cumulative_and_last(self):
        # Two recorded calls: sums accumulate, last_input keeps the latest.
        tracker = ai._TokenTracker()
        for inp, out in ((100, 20), (200, 30)):
            tracker.record(self._usage(inp, out))

        self.assertEqual(tracker.total_input, 300)
        self.assertEqual(tracker.total_output, 50)
        self.assertEqual(tracker.loop_input, 300)
        self.assertEqual(tracker.loop_output, 50)
        self.assertEqual(tracker.last_input, 200)  # last call only

    def test_budget_uses_last_input_not_sum(self):
        # Many small calls whose sum exceeds the budget but whose
        # last input is well under the budget should NOT trip.
        tracker = ai._TokenTracker()
        calls = 20
        while calls:
            tracker.record(self._usage(10_000, 100))
            calls -= 1

        self.assertGreater(tracker.loop_input, ai.CONTEXT_BUDGET)
        self.assertLess(tracker.last_input, ai.CONTEXT_BUDGET)
        self.assertFalse(tracker.budget_exceeded())

    def test_budget_trips_when_last_input_over_threshold(self):
        # A single call one token above the budget must trip the check.
        tracker = ai._TokenTracker()
        over_budget = self._usage(ai.CONTEXT_BUDGET + 1, 100)
        tracker.record(over_budget)
        self.assertTrue(tracker.budget_exceeded())

    def test_reset_loop_clears_loop_and_last(self):
        tracker = ai._TokenTracker()
        tracker.record(self._usage(500, 50))
        tracker.reset_loop()

        # Per-loop state is zeroed...
        self.assertEqual(tracker.loop_input, 0)
        self.assertEqual(tracker.loop_output, 0)
        self.assertEqual(tracker.last_input, 0)
        # Cumulative totals are NOT reset
        self.assertEqual(tracker.total_input, 500)
        self.assertEqual(tracker.total_output, 50)

    def test_loop_total_property_still_works(self):
        # loop_total is input + output across the recorded calls:
        # (100 + 25) + (200 + 50) == 375.
        tracker = ai._TokenTracker()
        for inp, out in ((100, 25), (200, 50)):
            tracker.record(self._usage(inp, out))
        self.assertEqual(tracker.loop_total, 375)

    def test_max_context_is_sonnet_real_window(self):
        # Pins the configured window size used to derive CONTEXT_BUDGET.
        self.assertEqual(ai.MAX_CONTEXT, 200_000)
class DefaultSurveyTests(unittest.TestCase): class DefaultSurveyTests(unittest.TestCase):
def test_has_all_required_keys(self): def test_has_all_required_keys(self):
survey = ai._default_survey() survey = ai._default_survey()