marchwarden/tests/test_models.py

399 lines
12 KiB
Python
Raw Normal View History

"""Tests for the Marchwarden Research Contract v1 models."""
import json
import uuid
import pytest
from pydantic import ValidationError
from researchers.web.models import (
Citation,
ConfidenceFactors,
CostMetadata,
DiscoveryEvent,
Gap,
GapCategory,
ResearchConstraints,
ResearchResult,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def make_citation(**overrides) -> Citation:
    """Return a fully-populated ``Citation``.

    Any keyword passed in *overrides* replaces the matching default field
    (or adds a new keyword) before the model is constructed.
    """
    base = dict(
        source="web",
        locator="https://example.com/article",
        title="Example Article",
        snippet="Relevant summary of the content.",
        raw_excerpt="Verbatim text copied directly from the source document.",
        confidence=0.85,
    )
    # Later mapping wins, so caller overrides take precedence over defaults.
    return Citation(**{**base, **overrides})
def make_gap(**overrides) -> Gap:
    """Return a ``Gap`` (SOURCE_NOT_FOUND by default); *overrides* win."""
    base = dict(
        topic="pest management",
        category=GapCategory.SOURCE_NOT_FOUND,
        detail="No pest data found in general web sources.",
    )
    return Gap(**{**base, **overrides})
def make_discovery_event(**overrides) -> DiscoveryEvent:
    """Return a ``DiscoveryEvent`` with every field set; *overrides* win."""
    base = dict(
        type="related_research",
        suggested_researcher="arxiv",
        query="soil salinity studies Utah 2024-2026",
        reason="Multiple web sources reference USU study data",
        source_locator="https://example.com/reference",
    )
    return DiscoveryEvent(**{**base, **overrides})
def make_confidence_factors(**overrides) -> ConfidenceFactors:
    """Return a ``ConfidenceFactors`` describing a well-supported answer.

    Keyword *overrides* replace the corresponding default before construction.
    """
    base = dict(
        num_corroborating_sources=3,
        source_authority="high",
        contradiction_detected=False,
        query_specificity_match=0.85,
        budget_exhausted=False,
        recency="current",
    )
    return ConfidenceFactors(**{**base, **overrides})
def make_cost_metadata(**overrides) -> CostMetadata:
    """Return a ``CostMetadata`` for a small, within-budget run; *overrides* win."""
    base = dict(
        tokens_used=8452,
        iterations_run=3,
        wall_time_sec=42.5,
        budget_exhausted=False,
        model_id="claude-sonnet-4-6",
    )
    return CostMetadata(**{**base, **overrides})
def make_research_result(**overrides) -> ResearchResult:
    """Return a complete ``ResearchResult`` assembled from the other factories.

    Each list field holds one default item, and ``trace_id`` is a fresh
    random UUID string. Keyword *overrides* replace any default field.
    Note: the nested defaults are always built, even when overridden.
    """
    base = dict(
        answer="Utah is ideal for cool-season crops at high elevation.",
        citations=[make_citation()],
        gaps=[make_gap()],
        discovery_events=[make_discovery_event()],
        confidence=0.82,
        confidence_factors=make_confidence_factors(),
        cost_metadata=make_cost_metadata(),
        trace_id=str(uuid.uuid4()),
    )
    return ResearchResult(**{**base, **overrides})
# ---------------------------------------------------------------------------
# ResearchConstraints
# ---------------------------------------------------------------------------
class TestResearchConstraints:
    """Defaults, validation, and serialization of ``ResearchConstraints``."""

    def test_defaults(self):
        """A bare instance carries the documented default budgets."""
        constraints = ResearchConstraints()
        assert (
            constraints.max_iterations,
            constraints.token_budget,
            constraints.max_sources,
        ) == (5, 20_000, 10)
        assert constraints.source_filter is None

    def test_custom_values(self):
        """Explicit keyword values are stored unchanged."""
        requested = {"max_iterations": 3, "token_budget": 5000, "max_sources": 5}
        constraints = ResearchConstraints(**requested)
        for field_name, expected in requested.items():
            assert getattr(constraints, field_name) == expected

    def test_invalid_iterations(self):
        """Zero iterations fails validation."""
        with pytest.raises(ValidationError):
            ResearchConstraints(max_iterations=0)

    def test_invalid_token_budget(self):
        """A token budget of 500 fails validation."""
        with pytest.raises(ValidationError):
            ResearchConstraints(token_budget=500)

    def test_serialization_roundtrip(self):
        """Dumping to a dict and rebuilding yields an equal model."""
        original = ResearchConstraints(max_iterations=3, token_budget=10000)
        rebuilt = ResearchConstraints(**original.model_dump())
        assert rebuilt == original
# ---------------------------------------------------------------------------
# Citation
# ---------------------------------------------------------------------------
class TestCitation:
    """Required fields, optional fields, and bounds of ``Citation``."""

    def test_full_citation(self):
        """The factory citation exposes its source, excerpt, and a sane confidence."""
        citation = make_citation()
        assert citation.source == "web"
        assert citation.raw_excerpt.startswith("Verbatim")
        assert 1.0 >= citation.confidence >= 0.0

    def test_minimal_citation(self):
        """``title`` and ``snippet`` default to None when omitted."""
        minimal_fields = dict(
            source="web",
            locator="https://example.com",
            raw_excerpt="Some text.",
            confidence=0.5,
        )
        citation = Citation(**minimal_fields)
        for optional_field in ("title", "snippet"):
            assert getattr(citation, optional_field) is None

    def test_confidence_bounds(self):
        """Confidence values above 1.0 or below 0.0 are rejected."""
        for out_of_range in (1.5, -0.1):
            with pytest.raises(ValidationError):
                make_citation(confidence=out_of_range)

    def test_raw_excerpt_required(self):
        """Omitting ``raw_excerpt`` fails validation."""
        with pytest.raises(ValidationError):
            Citation(source="web", locator="https://example.com", confidence=0.5)

    def test_serialization_roundtrip(self):
        """Dumping to a dict and rebuilding yields an equal model."""
        original = make_citation()
        rebuilt = Citation(**original.model_dump())
        assert rebuilt == original
# ---------------------------------------------------------------------------
# GapCategory
# ---------------------------------------------------------------------------
class TestGapCategory:
    """The closed set of gap categories and their string behaviour."""

    def test_all_categories_exist(self):
        """The enum defines exactly the five contract categories."""
        assert {member.value for member in GapCategory} == {
            "source_not_found",
            "access_denied",
            "budget_exhausted",
            "contradictory_sources",
            "scope_exceeded",
        }

    def test_string_enum(self):
        """Members compare equal to their string values and are ``str`` instances."""
        assert GapCategory.SOURCE_NOT_FOUND == "source_not_found"
        assert isinstance(GapCategory.ACCESS_DENIED, str)
# ---------------------------------------------------------------------------
# Gap
# ---------------------------------------------------------------------------
class TestGap:
    """Construction and serialization of ``Gap``."""

    def test_gap_creation(self):
        """The factory gap carries its default category and topic."""
        gap = make_gap()
        assert gap.category == GapCategory.SOURCE_NOT_FOUND
        assert gap.topic == "pest management"

    def test_all_categories_accepted(self):
        """Every ``GapCategory`` member is a valid ``category``."""
        for category in GapCategory:
            assert make_gap(category=category).category == category

    def test_serialization_roundtrip(self):
        """Dumping to a dict and rebuilding yields an equal model."""
        original = make_gap()
        rebuilt = Gap(**original.model_dump())
        assert rebuilt == original

    def test_json_uses_string_category(self):
        """JSON output encodes the category as its plain string value."""
        payload = json.loads(
            make_gap(category=GapCategory.BUDGET_EXHAUSTED).model_dump_json()
        )
        assert payload["category"] == "budget_exhausted"
# ---------------------------------------------------------------------------
# DiscoveryEvent
# ---------------------------------------------------------------------------
class TestDiscoveryEvent:
    """Construction and serialization of ``DiscoveryEvent``."""

    def test_full_event(self):
        """The factory event keeps its type and suggested researcher."""
        event = make_discovery_event()
        assert event.type == "related_research"
        assert event.suggested_researcher == "arxiv"

    def test_minimal_event(self):
        """``suggested_researcher`` and ``source_locator`` default to None."""
        event = DiscoveryEvent(
            type="contradiction",
            query="conflicting data on topic X",
            reason="Two sources disagree",
        )
        for optional_field in ("suggested_researcher", "source_locator"):
            assert getattr(event, optional_field) is None

    def test_serialization_roundtrip(self):
        """Dumping to a dict and rebuilding yields an equal model."""
        original = make_discovery_event()
        rebuilt = DiscoveryEvent(**original.model_dump())
        assert rebuilt == original
# ---------------------------------------------------------------------------
# ConfidenceFactors
# ---------------------------------------------------------------------------
class TestConfidenceFactors:
    """Construction, nullable recency, and bounds of ``ConfidenceFactors``."""

    def test_creation(self):
        """The factory values are stored as given."""
        factors = make_confidence_factors()
        assert factors.num_corroborating_sources == 3
        assert factors.source_authority == "high"
        assert factors.contradiction_detected is False
        assert factors.recency == "current"

    def test_recency_none(self):
        """``recency`` accepts None."""
        assert make_confidence_factors(recency=None).recency is None

    def test_query_specificity_bounds(self):
        """Specificity values above 1.0 or below 0.0 are rejected."""
        for out_of_range in (1.5, -0.1):
            with pytest.raises(ValidationError):
                make_confidence_factors(query_specificity_match=out_of_range)

    def test_serialization_roundtrip(self):
        """Dumping to a dict and rebuilding yields an equal model."""
        original = make_confidence_factors()
        rebuilt = ConfidenceFactors(**original.model_dump())
        assert rebuilt == original
# ---------------------------------------------------------------------------
# CostMetadata
# ---------------------------------------------------------------------------
class TestCostMetadata:
    """Required fields and numeric constraints of ``CostMetadata``."""

    def test_creation(self):
        """The factory values are stored as given."""
        metadata = make_cost_metadata()
        assert metadata.tokens_used == 8452
        assert metadata.model_id == "claude-sonnet-4-6"

    def test_model_id_required(self):
        """Omitting ``model_id`` fails validation."""
        without_model_id = dict(
            tokens_used=100,
            iterations_run=1,
            wall_time_sec=1.0,
            budget_exhausted=False,
        )
        with pytest.raises(ValidationError):
            CostMetadata(**without_model_id)

    def test_non_negative_constraints(self):
        """Negative token counts and wall times are rejected."""
        for field_name, negative_value in (
            ("tokens_used", -1),
            ("wall_time_sec", -0.5),
        ):
            with pytest.raises(ValidationError):
                make_cost_metadata(**{field_name: negative_value})

    def test_serialization_roundtrip(self):
        """Dumping to a dict and rebuilding yields an equal model."""
        original = make_cost_metadata()
        rebuilt = CostMetadata(**original.model_dump())
        assert rebuilt == original
# ---------------------------------------------------------------------------
# ResearchResult (full contract)
# ---------------------------------------------------------------------------
class TestResearchResult:
    """End-to-end checks on the top-level ``ResearchResult`` contract object."""

    def test_full_result(self):
        """The factory result exposes one of each list item and its metadata."""
        r = make_research_result()
        assert r.answer.startswith("Utah")
        assert len(r.citations) == 1
        assert len(r.gaps) == 1
        assert len(r.discovery_events) == 1
        assert 0.0 <= r.confidence <= 1.0
        assert r.cost_metadata.model_id == "claude-sonnet-4-6"

    def test_empty_lists_allowed(self):
        """A result with no citations, gaps, or discovery events is valid."""
        r = make_research_result(
            citations=[], gaps=[], discovery_events=[]
        )
        assert r.citations == []
        assert r.gaps == []
        assert r.discovery_events == []

    def test_confidence_bounds(self):
        """Top-level confidence above 1.0 or below 0.0 is rejected."""
        # Both sides are checked, mirroring the Citation and
        # ConfidenceFactors bound tests.
        with pytest.raises(ValidationError):
            make_research_result(confidence=1.5)
        with pytest.raises(ValidationError):
            make_research_result(confidence=-0.1)

    def test_full_json_roundtrip(self):
        """Serializing to JSON text and rebuilding from it yields an equal model."""
        r = make_research_result()
        json_str = r.model_dump_json()
        data = json.loads(json_str)
        r2 = ResearchResult(**data)
        assert r == r2

    def test_json_structure(self):
        """Verify the JSON output matches the contract schema."""
        r = make_research_result()
        data = json.loads(r.model_dump_json())

        # Top-level keys
        expected_keys = {
            "answer",
            "citations",
            "gaps",
            "discovery_events",
            "confidence",
            "confidence_factors",
            "cost_metadata",
            "trace_id",
        }
        assert set(data.keys()) == expected_keys

        # Citation keys
        citation_keys = {
            "source",
            "locator",
            "title",
            "snippet",
            "raw_excerpt",
            "confidence",
        }
        assert set(data["citations"][0].keys()) == citation_keys

        # Gap keys
        gap_keys = {"topic", "category", "detail"}
        assert set(data["gaps"][0].keys()) == gap_keys

        # Gap category is serialized as its string value, not an enum repr
        assert data["gaps"][0]["category"] == "source_not_found"

        # CostMetadata includes model_id
        assert "model_id" in data["cost_metadata"]

        # ConfidenceFactors keys
        cf_keys = {
            "num_corroborating_sources",
            "source_authority",
            "contradiction_detected",
            "query_specificity_match",
            "budget_exhausted",
            "recency",
        }
        assert set(data["confidence_factors"].keys()) == cf_keys