"""Tests for the Marchwarden Research Contract v1 models.""" import json import uuid import pytest from pydantic import ValidationError from researchers.web.models import ( Citation, ConfidenceFactors, CostMetadata, DiscoveryEvent, Gap, GapCategory, ResearchConstraints, ResearchResult, ) # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- def make_citation(**overrides) -> Citation: defaults = { "source": "web", "locator": "https://example.com/article", "title": "Example Article", "snippet": "Relevant summary of the content.", "raw_excerpt": "Verbatim text copied directly from the source document.", "confidence": 0.85, } defaults.update(overrides) return Citation(**defaults) def make_gap(**overrides) -> Gap: defaults = { "topic": "pest management", "category": GapCategory.SOURCE_NOT_FOUND, "detail": "No pest data found in general web sources.", } defaults.update(overrides) return Gap(**defaults) def make_discovery_event(**overrides) -> DiscoveryEvent: defaults = { "type": "related_research", "suggested_researcher": "arxiv", "query": "soil salinity studies Utah 2024-2026", "reason": "Multiple web sources reference USU study data", "source_locator": "https://example.com/reference", } defaults.update(overrides) return DiscoveryEvent(**defaults) def make_confidence_factors(**overrides) -> ConfidenceFactors: defaults = { "num_corroborating_sources": 3, "source_authority": "high", "contradiction_detected": False, "query_specificity_match": 0.85, "budget_exhausted": False, "recency": "current", } defaults.update(overrides) return ConfidenceFactors(**defaults) def make_cost_metadata(**overrides) -> CostMetadata: defaults = { "tokens_used": 8452, "iterations_run": 3, "wall_time_sec": 42.5, "budget_exhausted": False, "model_id": "claude-sonnet-4-6", } defaults.update(overrides) return CostMetadata(**defaults) def make_research_result(**overrides) -> 
ResearchResult: defaults = { "answer": "Utah is ideal for cool-season crops at high elevation.", "citations": [make_citation()], "gaps": [make_gap()], "discovery_events": [make_discovery_event()], "confidence": 0.82, "confidence_factors": make_confidence_factors(), "cost_metadata": make_cost_metadata(), "trace_id": str(uuid.uuid4()), } defaults.update(overrides) return ResearchResult(**defaults) # --------------------------------------------------------------------------- # ResearchConstraints # --------------------------------------------------------------------------- class TestResearchConstraints: def test_defaults(self): c = ResearchConstraints() assert c.max_iterations == 5 assert c.token_budget == 20_000 assert c.max_sources == 10 assert c.source_filter is None def test_custom_values(self): c = ResearchConstraints( max_iterations=3, token_budget=5000, max_sources=5 ) assert c.max_iterations == 3 assert c.token_budget == 5000 assert c.max_sources == 5 def test_invalid_iterations(self): with pytest.raises(ValidationError): ResearchConstraints(max_iterations=0) def test_invalid_token_budget(self): with pytest.raises(ValidationError): ResearchConstraints(token_budget=500) def test_serialization_roundtrip(self): c = ResearchConstraints(max_iterations=3, token_budget=10000) data = c.model_dump() c2 = ResearchConstraints(**data) assert c == c2 # --------------------------------------------------------------------------- # Citation # --------------------------------------------------------------------------- class TestCitation: def test_full_citation(self): c = make_citation() assert c.source == "web" assert c.raw_excerpt.startswith("Verbatim") assert 0.0 <= c.confidence <= 1.0 def test_minimal_citation(self): c = Citation( source="web", locator="https://example.com", raw_excerpt="Some text.", confidence=0.5, ) assert c.title is None assert c.snippet is None def test_confidence_bounds(self): with pytest.raises(ValidationError): make_citation(confidence=1.5) with 
pytest.raises(ValidationError): make_citation(confidence=-0.1) def test_raw_excerpt_required(self): with pytest.raises(ValidationError): Citation(source="web", locator="https://example.com", confidence=0.5) def test_serialization_roundtrip(self): c = make_citation() data = c.model_dump() c2 = Citation(**data) assert c == c2 # --------------------------------------------------------------------------- # GapCategory # --------------------------------------------------------------------------- class TestGapCategory: def test_all_categories_exist(self): expected = { "source_not_found", "access_denied", "budget_exhausted", "contradictory_sources", "scope_exceeded", } actual = {cat.value for cat in GapCategory} assert actual == expected def test_string_enum(self): assert GapCategory.SOURCE_NOT_FOUND == "source_not_found" assert isinstance(GapCategory.ACCESS_DENIED, str) # --------------------------------------------------------------------------- # Gap # --------------------------------------------------------------------------- class TestGap: def test_gap_creation(self): g = make_gap() assert g.category == GapCategory.SOURCE_NOT_FOUND assert g.topic == "pest management" def test_all_categories_accepted(self): for cat in GapCategory: g = make_gap(category=cat) assert g.category == cat def test_serialization_roundtrip(self): g = make_gap() data = g.model_dump() g2 = Gap(**data) assert g == g2 def test_json_uses_string_category(self): g = make_gap(category=GapCategory.BUDGET_EXHAUSTED) data = json.loads(g.model_dump_json()) assert data["category"] == "budget_exhausted" # --------------------------------------------------------------------------- # DiscoveryEvent # --------------------------------------------------------------------------- class TestDiscoveryEvent: def test_full_event(self): e = make_discovery_event() assert e.type == "related_research" assert e.suggested_researcher == "arxiv" def test_minimal_event(self): e = DiscoveryEvent( type="contradiction", 
query="conflicting data on topic X", reason="Two sources disagree", ) assert e.suggested_researcher is None assert e.source_locator is None def test_serialization_roundtrip(self): e = make_discovery_event() data = e.model_dump() e2 = DiscoveryEvent(**data) assert e == e2 # --------------------------------------------------------------------------- # ConfidenceFactors # --------------------------------------------------------------------------- class TestConfidenceFactors: def test_creation(self): cf = make_confidence_factors() assert cf.num_corroborating_sources == 3 assert cf.source_authority == "high" assert cf.contradiction_detected is False assert cf.recency == "current" def test_recency_none(self): cf = make_confidence_factors(recency=None) assert cf.recency is None def test_query_specificity_bounds(self): with pytest.raises(ValidationError): make_confidence_factors(query_specificity_match=1.5) with pytest.raises(ValidationError): make_confidence_factors(query_specificity_match=-0.1) def test_serialization_roundtrip(self): cf = make_confidence_factors() data = cf.model_dump() cf2 = ConfidenceFactors(**data) assert cf == cf2 # --------------------------------------------------------------------------- # CostMetadata # --------------------------------------------------------------------------- class TestCostMetadata: def test_creation(self): cm = make_cost_metadata() assert cm.tokens_used == 8452 assert cm.model_id == "claude-sonnet-4-6" def test_model_id_required(self): with pytest.raises(ValidationError): CostMetadata( tokens_used=100, iterations_run=1, wall_time_sec=1.0, budget_exhausted=False, ) def test_non_negative_constraints(self): with pytest.raises(ValidationError): make_cost_metadata(tokens_used=-1) with pytest.raises(ValidationError): make_cost_metadata(wall_time_sec=-0.5) def test_serialization_roundtrip(self): cm = make_cost_metadata() data = cm.model_dump() cm2 = CostMetadata(**data) assert cm == cm2 # 
# ---------------------------------------------------------------------------
# ResearchResult (full contract)
# ---------------------------------------------------------------------------


class TestResearchResult:
    """End-to-end checks of ``ResearchResult`` and its JSON shape."""

    def test_full_result(self):
        r = make_research_result()

        assert r.answer.startswith("Utah")
        assert len(r.citations) == 1
        assert len(r.gaps) == 1
        assert len(r.discovery_events) == 1
        assert 0.0 <= r.confidence <= 1.0
        assert r.cost_metadata.model_id == "claude-sonnet-4-6"

    def test_empty_lists_allowed(self):
        r = make_research_result(
            citations=[], gaps=[], discovery_events=[]
        )

        assert r.citations == []
        assert r.gaps == []
        assert r.discovery_events == []

    def test_confidence_bounds(self):
        """Reject confidence values outside the accepted range on both sides."""
        with pytest.raises(ValidationError):
            make_research_result(confidence=1.5)
        # Lower bound added to match the Citation and ConfidenceFactors bounds
        # tests in this file.
        # NOTE(review): assumes the model enforces a minimum of 0.0, in line
        # with the 0.0 <= confidence <= 1.0 assertion in test_full_result.
        with pytest.raises(ValidationError):
            make_research_result(confidence=-0.1)

    def test_full_json_roundtrip(self):
        r = make_research_result()
        json_str = r.model_dump_json()
        data = json.loads(json_str)
        r2 = ResearchResult(**data)

        assert r == r2

    def test_json_structure(self):
        """Verify the JSON output matches the contract schema."""
        r = make_research_result()
        data = json.loads(r.model_dump_json())

        # Top-level keys
        expected_keys = {
            "answer",
            "citations",
            "gaps",
            "discovery_events",
            "confidence",
            "confidence_factors",
            "cost_metadata",
            "trace_id",
        }
        assert set(data.keys()) == expected_keys

        # Citation keys
        citation_keys = {
            "source",
            "locator",
            "title",
            "snippet",
            "raw_excerpt",
            "confidence",
        }
        assert set(data["citations"][0].keys()) == citation_keys

        # Gap keys
        gap_keys = {"topic", "category", "detail"}
        assert set(data["gaps"][0].keys()) == gap_keys

        # Gap category is a string value
        assert data["gaps"][0]["category"] == "source_not_found"

        # CostMetadata includes model_id
        assert "model_id" in data["cost_metadata"]

        # ConfidenceFactors keys
        cf_keys = {
            "num_corroborating_sources",
            "source_authority",
            "contradiction_detected",
            "query_specificity_match",
            "budget_exhausted",
            "recency",
        }
        assert set(data["confidence_factors"].keys()) == cf_keys