"""Tests for the web researcher agent."""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import tempfile
|
||
|
|
from types import SimpleNamespace
|
||
|
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from researchers.web.agent import WebResearcher, _format_search_results
|
||
|
|
from researchers.web.models import ResearchConstraints, ResearchResult
|
||
|
|
from researchers.web.tools import SearchResult
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
def _make_anthropic_response(content_blocks, input_tokens=100, output_tokens=200):
    """Build a mock Anthropic ``messages.create`` response.

    Args:
        content_blocks: Object exposed unchanged as ``response.content``; the
            tests in this module pass a list of mock content blocks.
        input_tokens: Value reported as ``response.usage.input_tokens``.
        output_tokens: Value reported as ``response.usage.output_tokens``.

    Returns:
        A ``MagicMock`` whose ``content`` and ``usage`` attributes are set.
    """
    usage = SimpleNamespace(input_tokens=input_tokens, output_tokens=output_tokens)
    return MagicMock(content=content_blocks, usage=usage)
|
||
|
|
|
||
|
|
|
||
|
|
def _text_block(text):
    """Build a mock content block with ``type == "text"`` carrying *text*.

    Args:
        text: Value exposed as the block's ``text`` attribute.

    Returns:
        A ``MagicMock`` with ``type`` and ``text`` set.
    """
    return MagicMock(type="text", text=text)
|
||
|
|
|
||
|
|
|
||
|
|
def _tool_use_block(name, tool_input, tool_id="tool_1"):
    """Build a mock content block with ``type == "tool_use"``.

    Args:
        name: Tool name exposed as the block's ``name`` attribute.
        tool_input: Object exposed unchanged as the block's ``input``.
        tool_id: Identifier exposed as the block's ``id``.

    Returns:
        A ``MagicMock`` with ``type``, ``name``, ``input`` and ``id`` set.
    """
    block = MagicMock()
    # ``name`` is reserved by the MagicMock constructor (it names the mock
    # itself), so the attributes are applied after construction instead.
    block.configure_mock(type="tool_use", name=name, input=tool_input, id=tool_id)
    return block
|
||
|
|
|
||
|
|
|
||
|
|
# JSON text returned by the mocked synthesis call in the tests below. The
# assertions in TestWebResearcher read these exact values back from the
# parsed ResearchResult (one citation, one gap, one discovery event,
# confidence 0.82, three corroborating sources).
VALID_SYNTHESIS_JSON = json.dumps(
    {
        "answer": "Utah is ideal for cool-season crops at high elevation.",
        "citations": [
            {
                "source": "web",
                "locator": "https://example.com/utah-crops",
                "title": "Utah Crop Guide",
                "snippet": "Cool-season crops thrive above 7000 ft.",
                "raw_excerpt": "In Utah's high-elevation gardens, cool-season vegetables such as peas, lettuce, and potatoes consistently outperform warm-season crops.",
                "confidence": 0.9,
            }
        ],
        "gaps": [
            {
                "topic": "pest management",
                "category": "source_not_found",
                "detail": "No pest data found.",
            }
        ],
        "discovery_events": [
            {
                "type": "related_research",
                "suggested_researcher": "database",
                "query": "Utah soil salinity data",
                "reason": "Multiple sources reference USU studies",
                "source_locator": "https://example.com/ref",
            }
        ],
        "confidence": 0.82,
        "confidence_factors": {
            "num_corroborating_sources": 3,
            "source_authority": "high",
            "contradiction_detected": False,
            "query_specificity_match": 0.85,
            "budget_exhausted": False,
            "recency": "current",
        },
    }
)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# _format_search_results
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestFormatSearchResults:
    """Checks on the text produced by ``_format_search_results``."""

    @staticmethod
    def _result(content, raw_content, score):
        """Return a ``SearchResult`` using the URL, title and hash shared by these tests."""
        return SearchResult(
            url="https://example.com",
            title="Test",
            content=content,
            raw_content=raw_content,
            score=score,
            content_hash="sha256:abc",
        )

    def test_formats_results(self):
        """Title, URL, score and raw content all appear in the output."""
        results = [self._result("Short summary", "Full text here", 0.95)]
        text = _format_search_results(results)
        assert "Test" in text
        assert "https://example.com" in text
        assert "0.95" in text
        assert "Full text here" in text

    def test_prefers_raw_content(self):
        """``raw_content`` is included when it is present."""
        results = [self._result("Short", "Much longer raw content", 0.9)]
        text = _format_search_results(results)
        assert "Much longer raw content" in text

    def test_falls_back_to_content(self):
        """``content`` is included when ``raw_content`` is ``None``."""
        results = [self._result("Only short content", None, 0.9)]
        text = _format_search_results(results)
        assert "Only short content" in text

    def test_empty_results(self):
        """An empty result list yields text containing "No results"."""
        assert "No results" in _format_search_results([])
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# WebResearcher — mocked tool loop
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestWebResearcher:
    """Drive ``WebResearcher.research`` against a scripted LLM client.

    Each test replaces ``researcher.client.messages.create`` with a
    ``MagicMock`` whose ``side_effect`` yields canned responses in order, and
    patches whichever tool function in ``researchers.web.agent`` the scripted
    responses ask for.
    """

    @staticmethod
    def _make_researcher(trace_dir):
        """Return a ``WebResearcher`` with placeholder keys that traces into *trace_dir*."""
        return WebResearcher(
            anthropic_api_key="fake",
            tavily_api_key="fake",
            model_id="claude-test",
            trace_dir=trace_dir,
        )

    @staticmethod
    def _script_llm(researcher, *responses):
        """Make ``messages.create`` return *responses* one per call, in order."""
        researcher.client.messages.create = MagicMock(side_effect=list(responses))

    @pytest.mark.asyncio
    async def test_simple_research_loop(self):
        """Test a complete loop: one search → LLM stops → synthesis."""
        with tempfile.TemporaryDirectory() as tmp:
            researcher = self._make_researcher(tmp)

            # First call: LLM requests a web_search
            search_response = _make_anthropic_response(
                [_tool_use_block("web_search", {"query": "Utah crops"})],
            )
            # Second call: LLM is done (text only, no tools)
            done_response = _make_anthropic_response(
                [_text_block("I have enough information.")],
            )
            # Third call: synthesis
            synthesis_response = _make_anthropic_response(
                [_text_block(VALID_SYNTHESIS_JSON)],
            )
            self._script_llm(
                researcher, search_response, done_response, synthesis_response
            )

            with patch("researchers.web.agent.tavily_search") as mock_search:
                mock_search.return_value = [
                    SearchResult(
                        url="https://example.com/utah",
                        title="Utah Gardening",
                        content="Cool-season crops work well.",
                        raw_content="Full content about Utah gardening.",
                        score=0.95,
                        content_hash="sha256:abc123",
                    )
                ]

                result = await researcher.research(
                    "What are ideal crops for Utah?",
                    constraints=ResearchConstraints(max_iterations=3),
                )

            # Expected values come from VALID_SYNTHESIS_JSON, not from the
            # patched search result.
            assert isinstance(result, ResearchResult)
            assert "Utah" in result.answer
            assert len(result.citations) == 1
            assert result.citations[0].locator == "https://example.com/utah-crops"
            assert result.citations[0].raw_excerpt.startswith("In Utah")
            assert len(result.gaps) == 1
            assert result.gaps[0].category == "source_not_found"
            assert len(result.discovery_events) == 1
            assert result.confidence == 0.82
            assert result.confidence_factors.num_corroborating_sources == 3
            assert result.cost_metadata.model_id == "claude-test"
            assert result.cost_metadata.tokens_used > 0
            assert result.trace_id is not None

    @pytest.mark.asyncio
    async def test_budget_exhaustion(self):
        """Test that the loop stops when token budget is reached."""
        with tempfile.TemporaryDirectory() as tmp:
            researcher = self._make_researcher(tmp)

            # Each response uses 600 tokens — budget is 1000
            search_response = _make_anthropic_response(
                [_tool_use_block("web_search", {"query": "test"}, "t1")],
                input_tokens=400,
                output_tokens=200,
            )
            # Second search pushes over budget (600 + 600 = 1200 > 1000)
            search_response_2 = _make_anthropic_response(
                [_tool_use_block("web_search", {"query": "test2"}, "t2")],
                input_tokens=400,
                output_tokens=200,
            )
            synthesis_response = _make_anthropic_response(
                [_text_block(VALID_SYNTHESIS_JSON)],
                input_tokens=200,
                output_tokens=100,
            )
            self._script_llm(
                researcher, search_response, search_response_2, synthesis_response
            )

            with patch("researchers.web.agent.tavily_search") as mock_search:
                mock_search.return_value = [
                    SearchResult(
                        url="https://example.com",
                        title="Test",
                        content="Content",
                        raw_content=None,
                        score=0.9,
                        content_hash="sha256:abc",
                    )
                ]

                result = await researcher.research(
                    "test question",
                    constraints=ResearchConstraints(
                        max_iterations=5,
                        token_budget=1000,
                    ),
                )

            assert result.cost_metadata.budget_exhausted is True

    @pytest.mark.asyncio
    async def test_synthesis_failure_returns_fallback(self):
        """If synthesis JSON is unparseable, return a valid fallback."""
        with tempfile.TemporaryDirectory() as tmp:
            researcher = self._make_researcher(tmp)

            # LLM immediately stops (no tools)
            done_response = _make_anthropic_response(
                [_text_block("Nothing to search.")],
            )
            # Synthesis returns garbage
            bad_synthesis = _make_anthropic_response(
                [_text_block("This is not valid JSON at all!!!")],
            )
            self._script_llm(researcher, done_response, bad_synthesis)

            result = await researcher.research("test question")

            assert isinstance(result, ResearchResult)
            assert "synthesis failed" in result.answer.lower()
            assert result.confidence == 0.1
            assert len(result.gaps) == 1

    @pytest.mark.asyncio
    async def test_trace_file_created(self):
        """Verify trace file is created and has entries."""
        with tempfile.TemporaryDirectory() as tmp:
            researcher = self._make_researcher(tmp)

            done_response = _make_anthropic_response(
                [_text_block("Done.")],
            )
            synthesis_response = _make_anthropic_response(
                [_text_block(VALID_SYNTHESIS_JSON)],
            )
            self._script_llm(researcher, done_response, synthesis_response)

            result = await researcher.research("test")

            # Re-open the trace by id while the temporary directory still
            # exists; it is removed when the ``with`` block exits.
            from researchers.web.trace import TraceLogger

            trace = TraceLogger(trace_id=result.trace_id, trace_dir=tmp)
            entries = trace.read_entries()
            # Anticipated actions: start, iteration_start, synthesis,
            # complete. Only a lower bound of three entries is enforced.
            assert len(entries) >= 3
            assert entries[0]["action"] == "start"
            actions = [e["action"] for e in entries]
            assert "complete" in actions

    @pytest.mark.asyncio
    async def test_fetch_url_tool(self):
        """Test that fetch_url tool calls work in the loop."""
        with tempfile.TemporaryDirectory() as tmp:
            researcher = self._make_researcher(tmp)

            # LLM requests fetch_url
            fetch_response = _make_anthropic_response(
                [_tool_use_block("fetch_url", {"url": "https://example.com/page"})],
            )
            done_response = _make_anthropic_response(
                [_text_block("Got it.")],
            )
            synthesis_response = _make_anthropic_response(
                [_text_block(VALID_SYNTHESIS_JSON)],
            )
            self._script_llm(
                researcher, fetch_response, done_response, synthesis_response
            )

            with patch("researchers.web.agent.fetch_url") as mock_fetch:
                from researchers.web.tools import FetchResult

                mock_fetch.return_value = FetchResult(
                    url="https://example.com/page",
                    text="Fetched page content about Utah gardening.",
                    content_hash="sha256:def456",
                    content_length=42,
                    success=True,
                )

                result = await researcher.research("test question")

            assert isinstance(result, ResearchResult)
            mock_fetch.assert_called_once_with("https://example.com/page")
|