# marchwarden/tests/test_models.py
"""Tests for the Marchwarden Research Contract v1 models."""
import json
import uuid
import pytest
from pydantic import ValidationError
from researchers.web.models import (
Citation,
ConfidenceFactors,
CostMetadata,
DiscoveryEvent,
Gap,
GapCategory,
ResearchConstraints,
ResearchResult,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def make_citation(**overrides) -> Citation:
    """Build a fully-populated, valid Citation; keyword overrides win."""
    defaults = {
        "source": "web",
        "locator": "https://example.com/article",
        "title": "Example Article",
        "snippet": "Relevant summary of the content.",
        "raw_excerpt": "Verbatim text copied directly from the source document.",
        "confidence": 0.85,
    }
    # Merge so any override replaces the matching default field.
    return Citation(**{**defaults, **overrides})
def make_gap(**overrides) -> Gap:
    """Build a valid Gap (source_not_found by default); keyword overrides win."""
    defaults = {
        "topic": "pest management",
        "category": GapCategory.SOURCE_NOT_FOUND,
        "detail": "No pest data found in general web sources.",
    }
    return Gap(**{**defaults, **overrides})
def make_discovery_event(**overrides) -> DiscoveryEvent:
    """Build a fully-populated DiscoveryEvent; keyword overrides win."""
    defaults = {
        "type": "related_research",
        "suggested_researcher": "arxiv",
        "query": "soil salinity studies Utah 2024-2026",
        "reason": "Multiple web sources reference USU study data",
        "source_locator": "https://example.com/reference",
    }
    return DiscoveryEvent(**{**defaults, **overrides})
def make_confidence_factors(**overrides) -> ConfidenceFactors:
    """Build a valid ConfidenceFactors record; keyword overrides win."""
    defaults = {
        "num_corroborating_sources": 3,
        "source_authority": "high",
        "contradiction_detected": False,
        "query_specificity_match": 0.85,
        "budget_exhausted": False,
        "recency": "current",
    }
    return ConfidenceFactors(**{**defaults, **overrides})
def make_cost_metadata(**overrides) -> CostMetadata:
    """Build a valid CostMetadata record; keyword overrides win."""
    defaults = {
        "tokens_used": 8452,
        "iterations_run": 3,
        "wall_time_sec": 42.5,
        "budget_exhausted": False,
        "model_id": "claude-sonnet-4-6",
    }
    return CostMetadata(**{**defaults, **overrides})
def make_research_result(**overrides) -> ResearchResult:
    """Build a complete, valid ResearchResult using the other factories.

    Keyword overrides replace the matching default field (e.g. pass
    ``citations=[]`` to test empty-list handling).
    """
    defaults = {
        "answer": "Utah is ideal for cool-season crops at high elevation.",
        "citations": [make_citation()],
        "gaps": [make_gap()],
        "discovery_events": [make_discovery_event()],
        "confidence": 0.82,
        "confidence_factors": make_confidence_factors(),
        "cost_metadata": make_cost_metadata(),
        # Fresh UUID per call so trace IDs never collide across tests.
        "trace_id": str(uuid.uuid4()),
    }
    return ResearchResult(**{**defaults, **overrides})
# ---------------------------------------------------------------------------
# ResearchConstraints
# ---------------------------------------------------------------------------
class TestResearchConstraints:
    """Defaults, bounds validation, and round-tripping of ResearchConstraints."""

    def test_defaults(self):
        """A bare ResearchConstraints carries the documented defaults."""
        constraints = ResearchConstraints()
        assert constraints.max_iterations == 5
        assert constraints.token_budget == 20_000
        assert constraints.max_sources == 10
        assert constraints.source_filter is None

    def test_custom_values(self):
        """Explicit values override every default."""
        constraints = ResearchConstraints(
            max_iterations=3, token_budget=5000, max_sources=5
        )
        assert constraints.max_iterations == 3
        assert constraints.token_budget == 5000
        assert constraints.max_sources == 5

    def test_invalid_iterations(self):
        """Zero iterations is rejected by validation."""
        with pytest.raises(ValidationError):
            ResearchConstraints(max_iterations=0)

    def test_invalid_token_budget(self):
        """A token budget below the model minimum is rejected."""
        with pytest.raises(ValidationError):
            ResearchConstraints(token_budget=500)

    def test_serialization_roundtrip(self):
        """dump -> rebuild yields an equal model."""
        original = ResearchConstraints(max_iterations=3, token_budget=10000)
        rebuilt = ResearchConstraints(**original.model_dump())
        assert original == rebuilt
# ---------------------------------------------------------------------------
# Citation
# ---------------------------------------------------------------------------
class TestCitation:
    """Required/optional fields, bounds, and round-tripping of Citation."""

    def test_full_citation(self):
        """The factory default is a fully-populated, in-bounds citation."""
        citation = make_citation()
        assert citation.source == "web"
        assert citation.raw_excerpt.startswith("Verbatim")
        assert 0.0 <= citation.confidence <= 1.0

    def test_minimal_citation(self):
        """title and snippet are optional and default to None."""
        citation = Citation(
            source="web",
            locator="https://example.com",
            raw_excerpt="Some text.",
            confidence=0.5,
        )
        assert citation.title is None
        assert citation.snippet is None

    def test_confidence_bounds(self):
        """Confidence outside [0, 1] is rejected on either side."""
        with pytest.raises(ValidationError):
            make_citation(confidence=1.5)
        with pytest.raises(ValidationError):
            make_citation(confidence=-0.1)

    def test_raw_excerpt_required(self):
        """Omitting raw_excerpt fails validation."""
        with pytest.raises(ValidationError):
            Citation(source="web", locator="https://example.com", confidence=0.5)

    def test_serialization_roundtrip(self):
        """dump -> rebuild yields an equal model."""
        original = make_citation()
        rebuilt = Citation(**original.model_dump())
        assert original == rebuilt
# ---------------------------------------------------------------------------
# GapCategory
# ---------------------------------------------------------------------------
class TestGapCategory:
    """The GapCategory enum's member set and string behaviour."""

    def test_all_categories_exist(self):
        """The enum covers exactly the contract's five categories."""
        expected = {
            "source_not_found",
            "access_denied",
            "budget_exhausted",
            "contradictory_sources",
            "scope_exceeded",
        }
        assert {member.value for member in GapCategory} == expected

    def test_string_enum(self):
        """Members compare equal to, and are instances of, str."""
        assert GapCategory.SOURCE_NOT_FOUND == "source_not_found"
        assert isinstance(GapCategory.ACCESS_DENIED, str)
# ---------------------------------------------------------------------------
# Gap
# ---------------------------------------------------------------------------
class TestGap:
    """Construction, category acceptance, and serialization of Gap."""

    def test_gap_creation(self):
        """The factory default carries the expected topic and category."""
        gap = make_gap()
        assert gap.category == GapCategory.SOURCE_NOT_FOUND
        assert gap.topic == "pest management"

    def test_all_categories_accepted(self):
        """Every enum member is a valid Gap category."""
        for category in GapCategory:
            assert make_gap(category=category).category == category

    def test_serialization_roundtrip(self):
        """dump -> rebuild yields an equal model."""
        original = make_gap()
        rebuilt = Gap(**original.model_dump())
        assert original == rebuilt

    def test_json_uses_string_category(self):
        """JSON output serializes the category as its string value."""
        gap = make_gap(category=GapCategory.BUDGET_EXHAUSTED)
        payload = json.loads(gap.model_dump_json())
        assert payload["category"] == "budget_exhausted"
# ---------------------------------------------------------------------------
# DiscoveryEvent
# ---------------------------------------------------------------------------
class TestDiscoveryEvent:
    """Required/optional fields and round-tripping of DiscoveryEvent."""

    def test_full_event(self):
        """The factory default is a fully-populated related_research event."""
        event = make_discovery_event()
        assert event.type == "related_research"
        assert event.suggested_researcher == "arxiv"

    def test_minimal_event(self):
        """suggested_researcher and source_locator are optional."""
        event = DiscoveryEvent(
            type="contradiction",
            query="conflicting data on topic X",
            reason="Two sources disagree",
        )
        assert event.suggested_researcher is None
        assert event.source_locator is None

    def test_serialization_roundtrip(self):
        """dump -> rebuild yields an equal model."""
        original = make_discovery_event()
        rebuilt = DiscoveryEvent(**original.model_dump())
        assert original == rebuilt
# ---------------------------------------------------------------------------
# ConfidenceFactors
# ---------------------------------------------------------------------------
class TestConfidenceFactors:
    """Field values, optional recency, bounds, and round-tripping."""

    def test_creation(self):
        """The factory default carries the expected factor values."""
        factors = make_confidence_factors()
        assert factors.num_corroborating_sources == 3
        assert factors.source_authority == "high"
        assert factors.contradiction_detected is False
        assert factors.recency == "current"

    def test_recency_none(self):
        """recency may be explicitly None."""
        factors = make_confidence_factors(recency=None)
        assert factors.recency is None

    def test_query_specificity_bounds(self):
        """query_specificity_match outside [0, 1] is rejected."""
        with pytest.raises(ValidationError):
            make_confidence_factors(query_specificity_match=1.5)
        with pytest.raises(ValidationError):
            make_confidence_factors(query_specificity_match=-0.1)

    def test_serialization_roundtrip(self):
        """dump -> rebuild yields an equal model."""
        original = make_confidence_factors()
        rebuilt = ConfidenceFactors(**original.model_dump())
        assert original == rebuilt
# ---------------------------------------------------------------------------
# CostMetadata
# ---------------------------------------------------------------------------
class TestCostMetadata:
    """Required fields, non-negativity, and round-tripping of CostMetadata."""

    def test_creation(self):
        """The factory default carries the expected token count and model id."""
        metadata = make_cost_metadata()
        assert metadata.tokens_used == 8452
        assert metadata.model_id == "claude-sonnet-4-6"

    def test_model_id_required(self):
        """Omitting model_id fails validation."""
        with pytest.raises(ValidationError):
            CostMetadata(
                tokens_used=100,
                iterations_run=1,
                wall_time_sec=1.0,
                budget_exhausted=False,
            )

    def test_non_negative_constraints(self):
        """Negative token counts and wall times are rejected."""
        with pytest.raises(ValidationError):
            make_cost_metadata(tokens_used=-1)
        with pytest.raises(ValidationError):
            make_cost_metadata(wall_time_sec=-0.5)

    def test_serialization_roundtrip(self):
        """dump -> rebuild yields an equal model."""
        original = make_cost_metadata()
        rebuilt = CostMetadata(**original.model_dump())
        assert original == rebuilt
# ---------------------------------------------------------------------------
# ResearchResult (full contract)
# ---------------------------------------------------------------------------
class TestResearchResult:
    """End-to-end checks of the full ResearchResult contract."""

    def test_full_result(self):
        """A fully-populated result holds one of each nested item."""
        result = make_research_result()
        assert result.answer.startswith("Utah")
        assert len(result.citations) == 1
        assert len(result.gaps) == 1
        assert len(result.discovery_events) == 1
        assert 0.0 <= result.confidence <= 1.0
        assert result.cost_metadata.model_id == "claude-sonnet-4-6"

    def test_empty_lists_allowed(self):
        """All three list fields accept empty lists."""
        result = make_research_result(
            citations=[], gaps=[], discovery_events=[]
        )
        assert result.citations == []
        assert result.gaps == []
        assert result.discovery_events == []

    def test_confidence_bounds(self):
        """Confidence above 1.0 is rejected."""
        with pytest.raises(ValidationError):
            make_research_result(confidence=1.5)

    def test_full_json_roundtrip(self):
        """Serializing to JSON and re-parsing yields an equal model."""
        original = make_research_result()
        payload = json.loads(original.model_dump_json())
        assert original == ResearchResult(**payload)

    def test_json_structure(self):
        """Verify the JSON output matches the contract schema."""
        result = make_research_result()
        payload = json.loads(result.model_dump_json())

        # Top-level keys
        expected_keys = {
            "answer",
            "citations",
            "gaps",
            "discovery_events",
            "confidence",
            "confidence_factors",
            "cost_metadata",
            "trace_id",
        }
        assert set(payload.keys()) == expected_keys

        # Citation keys
        citation_keys = {
            "source",
            "locator",
            "title",
            "snippet",
            "raw_excerpt",
            "confidence",
        }
        assert set(payload["citations"][0].keys()) == citation_keys

        # Gap keys
        gap_keys = {"topic", "category", "detail"}
        assert set(payload["gaps"][0].keys()) == gap_keys
        # Gap category is a string value
        assert payload["gaps"][0]["category"] == "source_not_found"

        # CostMetadata includes model_id
        assert "model_id" in payload["cost_metadata"]

        # ConfidenceFactors keys
        cf_keys = {
            "num_corroborating_sources",
            "source_authority",
            "contradiction_detected",
            "query_specificity_match",
            "budget_exhausted",
            "recency",
        }
        assert set(payload["confidence_factors"].keys()) == cf_keys