"""Tests for the trace logger.""" import json import os import tempfile import uuid from researchers.web.trace import TraceLogger class TestTraceLogger: def _make_logger(self, tmp_dir, **kwargs): return TraceLogger(trace_dir=tmp_dir, **kwargs) def test_generates_trace_id(self): with tempfile.TemporaryDirectory() as tmp: logger = self._make_logger(tmp) # Should be a valid UUID uuid.UUID(logger.trace_id) def test_uses_provided_trace_id(self): with tempfile.TemporaryDirectory() as tmp: logger = self._make_logger(tmp, trace_id="my-custom-id") assert logger.trace_id == "my-custom-id" def test_creates_trace_dir(self): with tempfile.TemporaryDirectory() as tmp: nested = os.path.join(tmp, "deep", "nested", "dir") logger = TraceLogger(trace_dir=nested) assert os.path.isdir(nested) logger.close() def test_file_path(self): with tempfile.TemporaryDirectory() as tmp: logger = self._make_logger(tmp, trace_id="test-123") assert str(logger.file_path).endswith("test-123.jsonl") assert str(logger.result_path).endswith("test-123.result.json") def test_write_result_persists_pydantic_model(self): with tempfile.TemporaryDirectory() as tmp: logger = self._make_logger(tmp, trace_id="result-test") class Stub: def model_dump_json(self, indent=None): return '{"answer": "hi", "gaps": []}' logger.write_result(Stub()) assert logger.result_path.exists() data = json.loads(logger.result_path.read_text()) assert data["answer"] == "hi" assert data["gaps"] == [] def test_write_result_accepts_dict(self): with tempfile.TemporaryDirectory() as tmp: logger = self._make_logger(tmp, trace_id="dict-test") logger.write_result({"foo": "bar"}) assert json.loads(logger.result_path.read_text()) == {"foo": "bar"} def test_log_step_creates_file(self): with tempfile.TemporaryDirectory() as tmp: with self._make_logger(tmp) as logger: logger.log_step("search", query="test") assert logger.file_path.exists() def test_log_step_increments_counter(self): with tempfile.TemporaryDirectory() as tmp: with self._make_logger(tmp) as logger: e1 = logger.log_step("search") e2 = logger.log_step("fetch") e3 = logger.log_step("synthesize") assert e1["step"] == 1 assert e2["step"] == 2 assert e3["step"] == 3 def test_log_step_required_fields(self): with tempfile.TemporaryDirectory() as tmp: with self._make_logger(tmp) as logger: entry = logger.log_step("search", decision="relevant query") assert entry["action"] == "search" assert entry["decision"] == "relevant query" assert "timestamp" in entry assert "step" in entry def test_log_step_extra_kwargs(self): with tempfile.TemporaryDirectory() as tmp: with self._make_logger(tmp) as logger: entry = logger.log_step( "fetch_url", url="https://example.com", content_hash="sha256:abc123", content_length=5000, ) assert entry["url"] == "https://example.com" assert entry["content_hash"] == "sha256:abc123" assert entry["content_length"] == 5000 def test_entries_are_valid_jsonl(self): with tempfile.TemporaryDirectory() as tmp: with self._make_logger(tmp) as logger: logger.log_step("search", query="test") logger.log_step("fetch_url", url="https://example.com") logger.log_step("synthesize", decision="done") with open(logger.file_path) as f: lines = [l.strip() for l in f if l.strip()] assert len(lines) == 3 for line in lines: parsed = json.loads(line) assert "step" in parsed assert "action" in parsed assert "timestamp" in parsed def test_read_entries(self): with tempfile.TemporaryDirectory() as tmp: with self._make_logger(tmp) as logger: logger.log_step("search", query="q1") logger.log_step("fetch_url", url="https://example.com") entries = logger.read_entries() assert len(entries) == 2 assert entries[0]["action"] == "search" assert entries[0]["query"] == "q1" assert entries[1]["action"] == "fetch_url" def test_read_entries_empty_file(self): with tempfile.TemporaryDirectory() as tmp: logger = self._make_logger(tmp) entries = logger.read_entries() assert entries == [] def test_context_manager_closes_file(self): with tempfile.TemporaryDirectory() as tmp: with self._make_logger(tmp) as logger: logger.log_step("search") assert logger._file is None def test_timestamp_format(self): with tempfile.TemporaryDirectory() as tmp: with self._make_logger(tmp) as logger: entry = logger.log_step("search") # ISO 8601 UTC format ts = entry["timestamp"] assert ts.endswith("Z") assert "T" in ts def test_flush_after_each_write(self): """Entries should be readable immediately, not buffered.""" with tempfile.TemporaryDirectory() as tmp: logger = self._make_logger(tmp) logger.log_step("search", query="test") # Read without closing entries = logger.read_entries() assert len(entries) == 1 logger.close() def test_multiple_loggers_different_files(self): with tempfile.TemporaryDirectory() as tmp: with self._make_logger(tmp, trace_id="trace-a") as a: a.log_step("search", query="a") with self._make_logger(tmp, trace_id="trace-b") as b: b.log_step("search", query="b") assert a.read_entries()[0]["query"] == "a" assert b.read_entries()[0]["query"] == "b" assert a.file_path != b.file_path # --------------------------------------------------------------------------- # Step duration tracking (Issue #35) # --------------------------------------------------------------------------- import time as _time class TestStepDurations: def test_web_search_pair_records_duration_ms(self, tmp_path): logger = TraceLogger(trace_dir=str(tmp_path)) logger.log_step("web_search", query="utah crops") _time.sleep(0.02) entry = logger.log_step("web_search_complete", result_count=5) logger.close() assert "duration_ms" in entry assert entry["duration_ms"] >= 15 def test_fetch_url_pair_records_duration_ms(self, tmp_path): logger = TraceLogger(trace_dir=str(tmp_path)) logger.log_step("fetch_url", url="https://example.com") _time.sleep(0.01) entry = logger.log_step("fetch_url_complete", success=True) logger.close() assert "duration_ms" in entry def test_synthesis_complete_records_duration_ms(self, tmp_path): logger = TraceLogger(trace_dir=str(tmp_path)) logger.log_step("synthesis_start") _time.sleep(0.01) entry = logger.log_step("synthesis_complete") logger.close() assert "duration_ms" in entry def test_synthesis_error_also_records_duration_ms(self, tmp_path): logger = TraceLogger(trace_dir=str(tmp_path)) logger.log_step("synthesis_start") _time.sleep(0.01) entry = logger.log_step("synthesis_error", parse_error="boom") logger.close() assert "duration_ms" in entry def test_complete_records_total_duration_sec(self, tmp_path): logger = TraceLogger(trace_dir=str(tmp_path)) logger.log_step("start", question="q") _time.sleep(0.02) entry = logger.log_step("complete", confidence=0.9) logger.close() assert "total_duration_sec" in entry assert entry["total_duration_sec"] >= 0.015 # Sec precision, not ms assert "duration_ms" not in entry def test_unpaired_completer_does_not_crash(self, tmp_path): logger = TraceLogger(trace_dir=str(tmp_path)) # No matching web_search starter entry = logger.log_step("web_search_complete", result_count=0) logger.close() assert "duration_ms" not in entry def test_existing_fields_preserved(self, tmp_path): logger = TraceLogger(trace_dir=str(tmp_path)) logger.log_step("web_search", query="x") entry = logger.log_step("web_search_complete", result_count=3, urls=["u1"]) logger.close() assert entry["result_count"] == 3 assert entry["urls"] == ["u1"] assert "step" in entry assert "timestamp" in entry