diff --git a/researchers/web/trace.py b/researchers/web/trace.py index 1c8e5fb..97c81ae 100644 --- a/researchers/web/trace.py +++ b/researchers/web/trace.py @@ -35,6 +35,22 @@ _INFO_ACTIONS = frozenset( _log = get_logger("marchwarden.researcher.trace") +# Action pairings for duration tracking. When a starter action fires +# we record a monotonic start time keyed by the starter name. When the +# matching completer fires we compute the elapsed duration and attach +# it as a field on the completer's entry, then clear the start. +# +# Synthesis has two possible completers (success or error), both +# pointing back to synthesis_start. +_DURATION_PAIRS: dict[str, str] = { + "web_search_complete": "web_search", + "fetch_url_complete": "fetch_url", + "synthesis_complete": "synthesis_start", + "synthesis_error": "synthesis_start", + "complete": "start", +} +_STARTER_ACTIONS = frozenset(_DURATION_PAIRS.values()) + class TraceLogger: """Logs research steps to a JSONL file. @@ -65,6 +81,9 @@ class TraceLogger: self.file_path = self.trace_dir / f"{self.trace_id}.jsonl" self._step_counter = 0 self._file = None + # action_name -> monotonic start time, populated by starter + # actions and consumed by their matching completer (Issue #35). + self._pending_starts: dict[str, float] = {} @property def _writer(self): @@ -93,6 +112,24 @@ class TraceLogger: } entry.update(kwargs) + # Duration tracking (Issue #35). Record start times for starter + # actions; when the matching completer fires, attach elapsed time + # to both the trace entry and the operational log line. + now = time.monotonic() + if action in _STARTER_ACTIONS: + self._pending_starts[action] = now + duration_extras: dict[str, Any] = {} + if action in _DURATION_PAIRS: + starter = _DURATION_PAIRS[action] + start = self._pending_starts.pop(starter, None) + if start is not None: + elapsed = now - start + if action == "complete": + duration_extras["total_duration_sec"] = round(elapsed, 3) + else: + duration_extras["duration_ms"] = int(elapsed * 1000) + entry.update(duration_extras) + self._writer.write(json.dumps(entry, default=str) + "\n") self._writer.flush() @@ -101,7 +138,13 @@ class TraceLogger: # already bound in contextvars by WebResearcher.research, so # they automatically appear on every line. log_method = _log.info if action in _INFO_ACTIONS else _log.debug - log_method(action, step=self._step_counter, decision=decision, **kwargs) + log_method( + action, + step=self._step_counter, + decision=decision, + **kwargs, + **duration_extras, + ) return entry diff --git a/tests/test_trace.py b/tests/test_trace.py index 8a42c39..3c77699 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -143,3 +143,74 @@ class TestTraceLogger: assert a.read_entries()[0]["query"] == "a" assert b.read_entries()[0]["query"] == "b" assert a.file_path != b.file_path + + +# --------------------------------------------------------------------------- +# Step duration tracking (Issue #35) +# --------------------------------------------------------------------------- + + +import time as _time + + +class TestStepDurations: + def test_web_search_pair_records_duration_ms(self, tmp_path): + logger = TraceLogger(trace_dir=str(tmp_path)) + logger.log_step("web_search", query="utah crops") + _time.sleep(0.02) + entry = logger.log_step("web_search_complete", result_count=5) + logger.close() + assert "duration_ms" in entry + assert entry["duration_ms"] >= 15 + + def test_fetch_url_pair_records_duration_ms(self, tmp_path): + logger = TraceLogger(trace_dir=str(tmp_path)) + logger.log_step("fetch_url", url="https://example.com") + _time.sleep(0.01) + entry = logger.log_step("fetch_url_complete", success=True) + logger.close() + assert "duration_ms" in entry + + def test_synthesis_complete_records_duration_ms(self, tmp_path): + logger = TraceLogger(trace_dir=str(tmp_path)) + logger.log_step("synthesis_start") + _time.sleep(0.01) + entry = logger.log_step("synthesis_complete") + logger.close() + assert "duration_ms" in entry + + def test_synthesis_error_also_records_duration_ms(self, tmp_path): + logger = TraceLogger(trace_dir=str(tmp_path)) + logger.log_step("synthesis_start") + _time.sleep(0.01) + entry = logger.log_step("synthesis_error", parse_error="boom") + logger.close() + assert "duration_ms" in entry + + def test_complete_records_total_duration_sec(self, tmp_path): + logger = TraceLogger(trace_dir=str(tmp_path)) + logger.log_step("start", question="q") + _time.sleep(0.02) + entry = logger.log_step("complete", confidence=0.9) + logger.close() + assert "total_duration_sec" in entry + assert entry["total_duration_sec"] >= 0.015 + # Sec precision, not ms + assert "duration_ms" not in entry + + def test_unpaired_completer_does_not_crash(self, tmp_path): + logger = TraceLogger(trace_dir=str(tmp_path)) + # No matching web_search starter + entry = logger.log_step("web_search_complete", result_count=0) + logger.close() + assert "duration_ms" not in entry + + def test_existing_fields_preserved(self, tmp_path): + logger = TraceLogger(trace_dir=str(tmp_path)) + logger.log_step("web_search", query="x") + entry = logger.log_step("web_search_complete", result_count=3, urls=["u1"]) + logger.close() + assert entry["result_count"] == 3 + assert entry["urls"] == ["u1"] + assert "step" in entry + assert "timestamp" in entry