M1.1: Search and fetch tools with tests
- tavily_search(): Tavily API wrapper returning SearchResult dataclasses with content hashing (raw_content preferred, falls back to summary)
- fetch_url(): async URL fetch with HTML text extraction, content hashing, and graceful error handling (timeout, HTTP errors, connection errors)
- _extract_text(): simple HTML → clean text (strip scripts/styles/tags, decode entities, collapse whitespace)
- _sha256(): SHA-256 content hashing with "sha256:" prefix for traces

18 tests: hashing, HTML extraction, mocked Tavily search, mocked async fetch (success, timeout, HTTP error, hash consistency).

Refs: archeious/marchwarden#1
Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
8930f4486a
commit
a5bc93e275
2 changed files with 426 additions and 0 deletions
181
researchers/web/tools.py
Normal file
181
researchers/web/tools.py
Normal file
|
|
@ -0,0 +1,181 @@
|
||||||
|
"""Search and fetch tools for the web researcher.
|
||||||
|
|
||||||
|
Two tools:
|
||||||
|
- tavily_search: Web search via Tavily API, returns structured results
|
||||||
|
- fetch_url: Direct URL fetch with content extraction and hashing
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from tavily import TavilyClient
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SearchResult:
    """A single result from a Tavily search.

    Produced by :func:`tavily_search`; one instance per item in the
    Tavily API's ``results`` list.
    """

    url: str
    title: str
    content: str  # Short extracted summary from Tavily
    raw_content: Optional[str]  # Full page text (may be None when Tavily omits it)
    score: float  # Tavily relevance score (0.0-1.0)
    content_hash: str  # "sha256:<hex>" of the best available content (raw if present, else summary)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FetchResult:
    """Result of fetching and extracting content from a URL.

    On failure ``success`` is False, ``text`` is empty, and ``error``
    holds a short human-readable description of what went wrong.
    """

    url: str
    text: str  # Extracted clean text (empty on failure)
    content_hash: str  # "sha256:<hex>" of the fetched content (hash of "" on failure)
    content_length: int  # len(text); 0 on failure
    success: bool
    error: Optional[str] = None  # None on success
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256(text: str) -> str:
|
||||||
|
"""Compute SHA-256 hash of text content."""
|
||||||
|
return f"sha256:{hashlib.sha256(text.encode('utf-8')).hexdigest()}"
|
||||||
|
|
||||||
|
|
||||||
|
def tavily_search(
    api_key: str,
    query: str,
    max_results: int = 5,
    include_raw_content: bool = True,
) -> list[SearchResult]:
    """Search the web via the Tavily API.

    Args:
        api_key: Tavily API key.
        query: Search query string.
        max_results: Maximum number of results to return.
        include_raw_content: Whether to request full page text from Tavily.

    Returns:
        List of SearchResult objects, sorted by relevance score.
    """
    client = TavilyClient(api_key=api_key)
    response = client.search(
        query,
        max_results=max_results,
        include_raw_content=include_raw_content,
    )

    search_results: list[SearchResult] = []
    for item in response.get("results", []):
        summary = item.get("content", "")
        full_text = item.get("raw_content") or ""
        # Prefer the full page text for hashing; fall back to the summary.
        best_available = full_text if full_text else summary
        result = SearchResult(
            url=item.get("url", ""),
            title=item.get("title", ""),
            content=summary,
            raw_content=full_text if full_text else None,
            score=item.get("score", 0.0),
            content_hash=_sha256(best_available) if best_available else _sha256(""),
        )
        search_results.append(result)

    return search_results
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_url(
    url: str,
    timeout: float = 15.0,
    max_length: int = 100_000,
) -> FetchResult:
    """Fetch a URL and extract clean text content.

    Used for URLs where Tavily didn't return raw_content, or for
    URLs discovered during research that weren't in the search results.

    Args:
        url: The URL to fetch.
        timeout: Request timeout in seconds.
        max_length: Maximum response body length in characters.

    Returns:
        FetchResult with extracted text and content hash. Network and
        HTTP failures are reported via ``success=False`` rather than
        raised, so callers can treat fetching as best-effort.
    """

    def _failure(message: str) -> FetchResult:
        # All error paths share the same empty-result shape; only the
        # error message differs.
        return FetchResult(
            url=url,
            text="",
            content_hash=_sha256(""),
            content_length=0,
            success=False,
            error=message,
        )

    try:
        async with httpx.AsyncClient(
            follow_redirects=True,
            timeout=timeout,
        ) as client:
            response = await client.get(
                url,
                headers={
                    "User-Agent": "Marchwarden/0.1 (research agent)",
                },
            )
            response.raise_for_status()

            # Truncate before extraction so huge pages stay cheap to process.
            raw_text = response.text[:max_length]
            clean_text = _extract_text(raw_text)

            return FetchResult(
                url=url,
                text=clean_text,
                content_hash=_sha256(clean_text),
                content_length=len(clean_text),
                success=True,
            )
    # Order matters: TimeoutException and HTTPStatusError are subclasses
    # of HTTPError, so the generic handler must come last.
    except httpx.TimeoutException:
        return _failure(f"Timeout after {timeout}s")
    except httpx.HTTPStatusError as e:
        return _failure(f"HTTP {e.response.status_code}")
    except httpx.HTTPError as e:
        return _failure(str(e))
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text(html: str) -> str:
|
||||||
|
"""Extract readable text from HTML.
|
||||||
|
|
||||||
|
Simple extraction: strip tags, collapse whitespace. For V1 this is
|
||||||
|
sufficient. If quality is poor, swap in trafilatura or readability-lxml.
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Remove script and style blocks
|
||||||
|
text = re.sub(r"<script[^>]*>.*?</script>", " ", html, flags=re.DOTALL)
|
||||||
|
text = re.sub(r"<style[^>]*>.*?</style>", " ", text, flags=re.DOTALL)
|
||||||
|
# Remove HTML tags
|
||||||
|
text = re.sub(r"<[^>]+>", " ", text)
|
||||||
|
# Decode common HTML entities
|
||||||
|
text = text.replace("&", "&")
|
||||||
|
text = text.replace("<", "<")
|
||||||
|
text = text.replace(">", ">")
|
||||||
|
text = text.replace(""", '"')
|
||||||
|
text = text.replace("'", "'")
|
||||||
|
text = text.replace(" ", " ")
|
||||||
|
# Collapse whitespace
|
||||||
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
return text
|
||||||
245
tests/test_tools.py
Normal file
245
tests/test_tools.py
Normal file
|
|
@ -0,0 +1,245 @@
|
||||||
|
"""Tests for web researcher search and fetch tools."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from researchers.web.tools import (
|
||||||
|
FetchResult,
|
||||||
|
SearchResult,
|
||||||
|
_extract_text,
|
||||||
|
_sha256,
|
||||||
|
fetch_url,
|
||||||
|
tavily_search,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _sha256
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSha256:
    """Unit tests for the _sha256 content-hashing helper."""

    def test_consistent_hash(self):
        # Hashing is deterministic: identical input, identical digest.
        assert _sha256("hello") == _sha256("hello")

    def test_different_input_different_hash(self):
        assert _sha256("hello") != _sha256("world")

    def test_format(self):
        digest = _sha256("test")
        assert digest.startswith("sha256:")
        # The hex digest portion of a SHA-256 hash is always 64 characters.
        _, _, hex_part = digest.partition(":")
        assert len(hex_part) == 64

    def test_matches_hashlib(self):
        # The helper must agree with hashlib's reference implementation.
        sample = "marchwarden test"
        want = f"sha256:{hashlib.sha256(sample.encode('utf-8')).hexdigest()}"
        assert _sha256(sample) == want
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _extract_text
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractText:
    """Unit tests for the HTML → clean-text extraction helper."""

    def test_strips_tags(self):
        assert _extract_text("<p>hello</p>") == "hello"

    def test_strips_script(self):
        # Script bodies must be removed entirely, not just the tags.
        html = "<script>var x=1;</script><p>content</p>"
        assert "var x" not in _extract_text(html)
        assert "content" in _extract_text(html)

    def test_strips_style(self):
        html = "<style>.foo{color:red}</style><p>visible</p>"
        assert "color" not in _extract_text(html)
        assert "visible" in _extract_text(html)

    def test_collapses_whitespace(self):
        html = "<p>hello</p>   <p>world</p>"
        assert _extract_text(html) == "hello world"

    def test_decodes_entities(self):
        # Entity references must be decoded to their literal characters.
        html = "&amp; &lt; &gt; &quot; &#39; &nbsp;"
        result = _extract_text(html)
        assert "&" in result
        assert "<" in result
        assert ">" in result
        assert '"' in result
        assert "'" in result

    def test_empty_input(self):
        assert _extract_text("") == ""

    def test_plain_text_passthrough(self):
        assert _extract_text("just plain text") == "just plain text"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# tavily_search
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTavilySearch:
    """Tests for tavily_search with the Tavily client fully mocked (no network)."""

    @patch("researchers.web.tools.TavilyClient")
    def test_returns_search_results(self, mock_client_class):
        # Stub the client so construction and .search() never hit the API.
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.search.return_value = {
            "results": [
                {
                    "url": "https://example.com/page1",
                    "title": "Page One",
                    "content": "Short summary of page one.",
                    "raw_content": "Full text of page one with lots of detail.",
                    "score": 0.95,
                },
                {
                    "url": "https://example.com/page2",
                    "title": "Page Two",
                    "content": "Short summary of page two.",
                    "raw_content": "",
                    "score": 0.80,
                },
            ]
        }

        results = tavily_search("fake-key", "test query", max_results=2)

        assert len(results) == 2
        assert isinstance(results[0], SearchResult)

        # First result has raw_content
        assert results[0].url == "https://example.com/page1"
        assert results[0].title == "Page One"
        assert results[0].raw_content == "Full text of page one with lots of detail."
        assert results[0].score == 0.95
        assert results[0].content_hash.startswith("sha256:")
        # Hash should be of raw_content (best available)
        assert results[0].content_hash == _sha256(
            "Full text of page one with lots of detail."
        )

        # Second result has empty raw_content → falls back to content
        assert results[1].raw_content is None
        assert results[1].content_hash == _sha256("Short summary of page two.")

    @patch("researchers.web.tools.TavilyClient")
    def test_empty_results(self, mock_client_class):
        # No hits from the API should yield an empty list, not an error.
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.search.return_value = {"results": []}

        results = tavily_search("fake-key", "obscure query")
        assert results == []

    @patch("researchers.web.tools.TavilyClient")
    def test_passes_params_to_client(self, mock_client_class):
        # The wrapper must forward its arguments to the Tavily SDK verbatim.
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.search.return_value = {"results": []}

        tavily_search("my-key", "query", max_results=3, include_raw_content=False)

        mock_client_class.assert_called_once_with(api_key="my-key")
        mock_client.search.assert_called_once_with(
            "query", max_results=3, include_raw_content=False
        )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# fetch_url
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFetchUrl:
    """Tests for fetch_url with httpx.AsyncClient mocked (no real network)."""

    @pytest.mark.asyncio
    async def test_successful_fetch(self):
        mock_response = AsyncMock()
        mock_response.status_code = 200
        mock_response.text = "<html><body><p>Hello world</p></body></html>"
        # raise_for_status is synchronous on httpx responses, so override
        # the AsyncMock attribute with a plain MagicMock.
        mock_response.raise_for_status = MagicMock()

        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            # Wire up the async context-manager protocol by hand so
            # `async with httpx.AsyncClient(...)` yields our mock client.
            mock_client = AsyncMock()
            mock_client_class.return_value.__aenter__ = AsyncMock(
                return_value=mock_client
            )
            mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_client.get.return_value = mock_response

            result = await fetch_url("https://example.com")

            assert isinstance(result, FetchResult)
            assert result.success is True
            assert result.error is None
            assert "Hello world" in result.text
            assert result.content_hash.startswith("sha256:")
            assert result.content_length > 0
            assert result.url == "https://example.com"

    @pytest.mark.asyncio
    async def test_timeout_error(self):
        import httpx

        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            mock_client_class.return_value.__aenter__ = AsyncMock(
                return_value=mock_client
            )
            mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)
            # Simulate the request timing out at the transport level.
            mock_client.get.side_effect = httpx.TimeoutException("timed out")

            result = await fetch_url("https://slow.example.com", timeout=5.0)

            assert result.success is False
            assert "Timeout" in result.error
            assert result.text == ""

    @pytest.mark.asyncio
    async def test_http_error(self):
        import httpx

        # A 404 surfaces via raise_for_status, not via client.get itself.
        mock_response = MagicMock()
        mock_response.status_code = 404
        mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
            "Not Found",
            request=MagicMock(),
            response=mock_response,
        )

        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            mock_client_class.return_value.__aenter__ = AsyncMock(
                return_value=mock_client
            )
            mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_client.get.return_value = mock_response

            result = await fetch_url("https://example.com/missing")

            assert result.success is False
            assert "404" in result.error

    @pytest.mark.asyncio
    async def test_content_hash_consistency(self):
        """Same content → same hash."""
        mock_response = AsyncMock()
        mock_response.text = "<p>consistent content</p>"
        mock_response.raise_for_status = MagicMock()

        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            mock_client_class.return_value.__aenter__ = AsyncMock(
                return_value=mock_client
            )
            mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_client.get.return_value = mock_response

            # Two different URLs, identical bodies: hashes must match.
            r1 = await fetch_url("https://example.com/a")
            r2 = await fetch_url("https://example.com/b")

            assert r1.content_hash == r2.content_hash
|
||||||
Loading…
Reference in a new issue