From 1b0f86399ae88517edb31c3499f45611679e441c Mon Sep 17 00:00:00 2001 From: Jeff Smith Date: Wed, 8 Apr 2026 14:00:45 -0600 Subject: [PATCH] M0.3: Implement contract v1 Pydantic models with tests All Research Contract types as Pydantic models: - ResearchConstraints (input) - Citation with raw_excerpt (output) - GapCategory enum (5 categories) - Gap with structured category (output) - DiscoveryEvent (lateral findings) - ConfidenceFactors (auditable scoring inputs) - CostMetadata with model_id (resource tracking) - ResearchResult (top-level contract) 32 tests: validation, bounds checking, serialization roundtrips, JSON structure verification against contract spec. Refs: archeious/marchwarden#1 Co-Authored-By: Claude Haiku 4.5 --- researchers/web/models.py | 231 ++++++++++++++++++++++ tests/__init__.py | 0 tests/test_models.py | 398 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 629 insertions(+) create mode 100644 researchers/web/models.py create mode 100644 tests/__init__.py create mode 100644 tests/test_models.py diff --git a/researchers/web/models.py b/researchers/web/models.py new file mode 100644 index 0000000..56f95f8 --- /dev/null +++ b/researchers/web/models.py @@ -0,0 +1,231 @@ +"""Marchwarden Research Contract v1 — Pydantic models. + +These models define the stable contract between a researcher MCP server +and its caller (PI agent or CLI shim). Changes to required fields or +types require a contract version bump. +""" + +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + + +# --------------------------------------------------------------------------- +# Input types +# --------------------------------------------------------------------------- + + +class ResearchConstraints(BaseModel): + """Fine-grained control over researcher behavior.""" + + max_iterations: int = Field( + default=5, + ge=1, + le=20, + description="Stop after N iterations, regardless of progress.", + ) + token_budget: int = Field( + default=20_000, + ge=1_000, + description="Soft limit on total tokens consumed by the research loop.", + ) + max_sources: int = Field( + default=10, + ge=1, + description="Maximum number of sources to fetch and extract.", + ) + source_filter: Optional[str] = Field( + default=None, + description="Restrict search to specific domains (V2). E.g. '.gov,.edu'.", + ) + + +# --------------------------------------------------------------------------- +# Output types — Citation +# --------------------------------------------------------------------------- + + +class Citation(BaseModel): + """A single source used by the researcher, with raw evidence.""" + + source: str = Field( + description="Source type: 'web', 'file', 'database', etc.", + ) + locator: str = Field( + description="URL, file path, row ID, or unique identifier.", + ) + title: Optional[str] = Field( + default=None, + description="Human-readable title (for web sources).", + ) + snippet: Optional[str] = Field( + default=None, + description="Researcher's summary of relevant content (50-200 chars).", + ) + raw_excerpt: str = Field( + description=( + "Verbatim text from the source (up to 500 chars). " + "Bypasses researcher synthesis to prevent the Synthesis Paradox." + ), + ) + confidence: float = Field( + ge=0.0, + le=1.0, + description="Researcher's confidence in this source's accuracy.", + ) + + +# --------------------------------------------------------------------------- +# Output types — Gap +# --------------------------------------------------------------------------- + + +class GapCategory(str, Enum): + """Categorized reason a gap exists. Drives PI decision-making.""" + + SOURCE_NOT_FOUND = "source_not_found" + ACCESS_DENIED = "access_denied" + BUDGET_EXHAUSTED = "budget_exhausted" + CONTRADICTORY_SOURCES = "contradictory_sources" + SCOPE_EXCEEDED = "scope_exceeded" + + +class Gap(BaseModel): + """An unresolved aspect of the research question.""" + + topic: str = Field( + description="What aspect wasn't resolved.", + ) + category: GapCategory = Field( + description="Structured reason category.", + ) + detail: str = Field( + description="Human-readable explanation of why this gap exists.", + ) + + +# --------------------------------------------------------------------------- +# Output types — DiscoveryEvent +# --------------------------------------------------------------------------- + + +class DiscoveryEvent(BaseModel): + """A lateral finding relevant to another researcher's domain.""" + + type: str = Field( + description="Event type: 'related_research', 'new_source', 'contradiction'.", + ) + suggested_researcher: Optional[str] = Field( + default=None, + description="Target researcher type: 'arxiv', 'database', 'legal', etc.", + ) + query: str = Field( + description="Suggested query for the target researcher.", + ) + reason: str = Field( + description="Why this is relevant to the overall investigation.", + ) + source_locator: Optional[str] = Field( + default=None, + description="Where the discovery was found (URL, DOI, etc.).", + ) + + +# --------------------------------------------------------------------------- +# Output types — Confidence +# --------------------------------------------------------------------------- + + +class ConfidenceFactors(BaseModel): + """Inputs that fed the confidence score. Enables auditability and future calibration.""" + + num_corroborating_sources: int = Field( + ge=0, + description="How many sources agree on the core claims.", + ) + source_authority: str = Field( + description="'high' (.gov, .edu, peer-reviewed), 'medium' (established orgs), 'low' (blogs, forums).", + ) + contradiction_detected: bool = Field( + description="Were conflicting claims found across sources?", + ) + query_specificity_match: float = Field( + ge=0.0, + le=1.0, + description="How well the results address the actual question (0.0-1.0).", + ) + budget_exhausted: bool = Field( + description="True if the researcher hit its iteration or token cap.", + ) + recency: Optional[str] = Field( + default=None, + description="'current' (< 1yr), 'recent' (1-3yr), 'dated' (> 3yr), None if unknown.", + ) + + +# --------------------------------------------------------------------------- +# Output types — CostMetadata +# --------------------------------------------------------------------------- + + +class CostMetadata(BaseModel): + """Resource usage for a single research call.""" + + tokens_used: int = Field( + ge=0, + description="Total tokens consumed (Claude + search API calls).", + ) + iterations_run: int = Field( + ge=0, + description="Number of inner-loop iterations completed.", + ) + wall_time_sec: float = Field( + ge=0.0, + description="Actual elapsed wall-clock time in seconds.", + ) + budget_exhausted: bool = Field( + description="True if the researcher hit its iteration or token cap.", + ) + model_id: str = Field( + description="Model used for the research loop (e.g. 'claude-sonnet-4-6').", + ) + + +# --------------------------------------------------------------------------- +# Top-level output +# --------------------------------------------------------------------------- + + +class ResearchResult(BaseModel): + """Complete result from a single research() call. This is the contract.""" + + answer: str = Field( + description="The synthesized answer. Every claim must trace to a citation.", + ) + citations: list[Citation] = Field( + default_factory=list, + description="Sources used, with raw evidence.", + ) + gaps: list[Gap] = Field( + default_factory=list, + description="What couldn't be resolved, categorized by cause.", + ) + discovery_events: list[DiscoveryEvent] = Field( + default_factory=list, + description="Lateral findings for other researchers.", + ) + confidence: float = Field( + ge=0.0, + le=1.0, + description="Overall confidence in the answer (0.0-1.0).", + ) + confidence_factors: ConfidenceFactors = Field( + description="What fed the confidence score.", + ) + cost_metadata: CostMetadata = Field( + description="Resource usage for this research call.", + ) + trace_id: str = Field( + description="UUID linking to the JSONL trace log.", + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..2580140 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,398 @@ +"""Tests for the Marchwarden Research Contract v1 models.""" + +import json +import uuid + +import pytest +from pydantic import ValidationError + +from researchers.web.models import ( + Citation, + ConfidenceFactors, + CostMetadata, + DiscoveryEvent, + Gap, + GapCategory, + ResearchConstraints, + ResearchResult, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def make_citation(**overrides) -> Citation: + defaults = { + "source": "web", + "locator": "https://example.com/article", + "title": "Example Article", + "snippet": "Relevant summary of the content.", + "raw_excerpt": "Verbatim text copied directly from the source document.", + "confidence": 0.85, + } + defaults.update(overrides) + return Citation(**defaults) + + +def make_gap(**overrides) -> Gap: + defaults = { + "topic": "pest management", + "category": GapCategory.SOURCE_NOT_FOUND, + "detail": "No pest data found in general web sources.", + } + defaults.update(overrides) + return Gap(**defaults) + + +def make_discovery_event(**overrides) -> DiscoveryEvent: + defaults = { + "type": "related_research", + "suggested_researcher": "arxiv", + "query": "soil salinity studies Utah 2024-2026", + "reason": "Multiple web sources reference USU study data", + "source_locator": "https://example.com/reference", + } + defaults.update(overrides) + return DiscoveryEvent(**defaults) + + +def make_confidence_factors(**overrides) -> ConfidenceFactors: + defaults = { + "num_corroborating_sources": 3, + "source_authority": "high", + "contradiction_detected": False, + "query_specificity_match": 0.85, + "budget_exhausted": False, + "recency": "current", + } + defaults.update(overrides) + return ConfidenceFactors(**defaults) + + +def make_cost_metadata(**overrides) -> CostMetadata: + defaults = { + "tokens_used": 8452, + "iterations_run": 3, + "wall_time_sec": 42.5, + "budget_exhausted": False, + "model_id": "claude-sonnet-4-6", + } + defaults.update(overrides) + return CostMetadata(**defaults) + + +def make_research_result(**overrides) -> ResearchResult: + defaults = { + "answer": "Utah is ideal for cool-season crops at high elevation.", + "citations": [make_citation()], + "gaps": [make_gap()], + "discovery_events": [make_discovery_event()], + "confidence": 0.82, + "confidence_factors": make_confidence_factors(), + "cost_metadata": make_cost_metadata(), + "trace_id": str(uuid.uuid4()), + } + defaults.update(overrides) + return ResearchResult(**defaults) + + +# --------------------------------------------------------------------------- +# ResearchConstraints +# --------------------------------------------------------------------------- + + +class TestResearchConstraints: + def test_defaults(self): + c = ResearchConstraints() + assert c.max_iterations == 5 + assert c.token_budget == 20_000 + assert c.max_sources == 10 + assert c.source_filter is None + + def test_custom_values(self): + c = ResearchConstraints( + max_iterations=3, token_budget=5000, max_sources=5 + ) + assert c.max_iterations == 3 + assert c.token_budget == 5000 + assert c.max_sources == 5 + + def test_invalid_iterations(self): + with pytest.raises(ValidationError): + ResearchConstraints(max_iterations=0) + + def test_invalid_token_budget(self): + with pytest.raises(ValidationError): + ResearchConstraints(token_budget=500) + + def test_serialization_roundtrip(self): + c = ResearchConstraints(max_iterations=3, token_budget=10000) + data = c.model_dump() + c2 = ResearchConstraints(**data) + assert c == c2 + + +# --------------------------------------------------------------------------- +# Citation +# --------------------------------------------------------------------------- + + +class TestCitation: + def test_full_citation(self): + c = make_citation() + assert c.source == "web" + assert c.raw_excerpt.startswith("Verbatim") + assert 0.0 <= c.confidence <= 1.0 + + def test_minimal_citation(self): + c = Citation( + source="web", + locator="https://example.com", + raw_excerpt="Some text.", + confidence=0.5, + ) + assert c.title is None + assert c.snippet is None + + def test_confidence_bounds(self): + with pytest.raises(ValidationError): + make_citation(confidence=1.5) + with pytest.raises(ValidationError): + make_citation(confidence=-0.1) + + def test_raw_excerpt_required(self): + with pytest.raises(ValidationError): + Citation(source="web", locator="https://example.com", confidence=0.5) + + def test_serialization_roundtrip(self): + c = make_citation() + data = c.model_dump() + c2 = Citation(**data) + assert c == c2 + + +# --------------------------------------------------------------------------- +# GapCategory +# --------------------------------------------------------------------------- + + +class TestGapCategory: + def test_all_categories_exist(self): + expected = { + "source_not_found", + "access_denied", + "budget_exhausted", + "contradictory_sources", + "scope_exceeded", + } + actual = {cat.value for cat in GapCategory} + assert actual == expected + + def test_string_enum(self): + assert GapCategory.SOURCE_NOT_FOUND == "source_not_found" + assert isinstance(GapCategory.ACCESS_DENIED, str) + + +# --------------------------------------------------------------------------- +# Gap +# --------------------------------------------------------------------------- + + +class TestGap: + def test_gap_creation(self): + g = make_gap() + assert g.category == GapCategory.SOURCE_NOT_FOUND + assert g.topic == "pest management" + + def test_all_categories_accepted(self): + for cat in GapCategory: + g = make_gap(category=cat) + assert g.category == cat + + def test_serialization_roundtrip(self): + g = make_gap() + data = g.model_dump() + g2 = Gap(**data) + assert g == g2 + + def test_json_uses_string_category(self): + g = make_gap(category=GapCategory.BUDGET_EXHAUSTED) + data = json.loads(g.model_dump_json()) + assert data["category"] == "budget_exhausted" + + +# --------------------------------------------------------------------------- +# DiscoveryEvent +# --------------------------------------------------------------------------- + + +class TestDiscoveryEvent: + def test_full_event(self): + e = make_discovery_event() + assert e.type == "related_research" + assert e.suggested_researcher == "arxiv" + + def test_minimal_event(self): + e = DiscoveryEvent( + type="contradiction", + query="conflicting data on topic X", + reason="Two sources disagree", + ) + assert e.suggested_researcher is None + assert e.source_locator is None + + def test_serialization_roundtrip(self): + e = make_discovery_event() + data = e.model_dump() + e2 = DiscoveryEvent(**data) + assert e == e2 + + +# --------------------------------------------------------------------------- +# ConfidenceFactors +# --------------------------------------------------------------------------- + + +class TestConfidenceFactors: + def test_creation(self): + cf = make_confidence_factors() + assert cf.num_corroborating_sources == 3 + assert cf.source_authority == "high" + assert cf.contradiction_detected is False + assert cf.recency == "current" + + def test_recency_none(self): + cf = make_confidence_factors(recency=None) + assert cf.recency is None + + def test_query_specificity_bounds(self): + with pytest.raises(ValidationError): + make_confidence_factors(query_specificity_match=1.5) + with pytest.raises(ValidationError): + make_confidence_factors(query_specificity_match=-0.1) + + def test_serialization_roundtrip(self): + cf = make_confidence_factors() + data = cf.model_dump() + cf2 = ConfidenceFactors(**data) + assert cf == cf2 + + +# --------------------------------------------------------------------------- +# CostMetadata +# --------------------------------------------------------------------------- + + +class TestCostMetadata: + def test_creation(self): + cm = make_cost_metadata() + assert cm.tokens_used == 8452 + assert cm.model_id == "claude-sonnet-4-6" + + def test_model_id_required(self): + with pytest.raises(ValidationError): + CostMetadata( + tokens_used=100, + iterations_run=1, + wall_time_sec=1.0, + budget_exhausted=False, + ) + + def test_non_negative_constraints(self): + with pytest.raises(ValidationError): + make_cost_metadata(tokens_used=-1) + with pytest.raises(ValidationError): + make_cost_metadata(wall_time_sec=-0.5) + + def test_serialization_roundtrip(self): + cm = make_cost_metadata() + data = cm.model_dump() + cm2 = CostMetadata(**data) + assert cm == cm2 + + +# --------------------------------------------------------------------------- +# ResearchResult (full contract) +# --------------------------------------------------------------------------- + + +class TestResearchResult: + def test_full_result(self): + r = make_research_result() + assert r.answer.startswith("Utah") + assert len(r.citations) == 1 + assert len(r.gaps) == 1 + assert len(r.discovery_events) == 1 + assert 0.0 <= r.confidence <= 1.0 + assert r.cost_metadata.model_id == "claude-sonnet-4-6" + + def test_empty_lists_allowed(self): + r = make_research_result( + citations=[], gaps=[], discovery_events=[] + ) + assert r.citations == [] + assert r.gaps == [] + assert r.discovery_events == [] + + def test_confidence_bounds(self): + with pytest.raises(ValidationError): + make_research_result(confidence=1.5) + + def test_full_json_roundtrip(self): + r = make_research_result() + json_str = r.model_dump_json() + data = json.loads(json_str) + r2 = ResearchResult(**data) + assert r == r2 + + def test_json_structure(self): + """Verify the JSON output matches the contract schema.""" + r = make_research_result() + data = json.loads(r.model_dump_json()) + + # Top-level keys + expected_keys = { + "answer", + "citations", + "gaps", + "discovery_events", + "confidence", + "confidence_factors", + "cost_metadata", + "trace_id", + } + assert set(data.keys()) == expected_keys + + # Citation keys + citation_keys = { + "source", + "locator", + "title", + "snippet", + "raw_excerpt", + "confidence", + } + assert set(data["citations"][0].keys()) == citation_keys + + # Gap keys + gap_keys = {"topic", "category", "detail"} + assert set(data["gaps"][0].keys()) == gap_keys + + # Gap category is a string value + assert data["gaps"][0]["category"] == "source_not_found" + + # CostMetadata includes model_id + assert "model_id" in data["cost_metadata"] + + # ConfidenceFactors keys + cf_keys = { + "num_corroborating_sources", + "source_authority", + "contradiction_detected", + "query_specificity_match", + "budget_exhausted", + "recency", + } + assert set(data["confidence_factors"].keys()) == cf_keys -- 2.45.2