diff --git a/researchers/web/tools.py b/researchers/web/tools.py
new file mode 100644
index 0000000..12e83c7
--- /dev/null
+++ b/researchers/web/tools.py
@@ -0,0 +1,181 @@
+"""Search and fetch tools for the web researcher.
+
+Two tools:
+- tavily_search: Web search via Tavily API, returns structured results
+- fetch_url: Direct URL fetch with content extraction and hashing
+"""
+
+import hashlib
+from dataclasses import dataclass
+from typing import Optional
+
+import httpx
+from tavily import TavilyClient
+
+
+@dataclass
+class SearchResult:
+    """A single result from a Tavily search."""
+
+    url: str
+    title: str
+    content: str  # Short extracted summary from Tavily
+    raw_content: Optional[str]  # Full page text (may be None)
+    score: float  # Tavily relevance score (0.0-1.0)
+    content_hash: str  # SHA-256 of the best available content
+
+
+@dataclass
+class FetchResult:
+    """Result of fetching and extracting content from a URL."""
+
+    url: str
+    text: str  # Extracted clean text
+    content_hash: str  # SHA-256 of the fetched content
+    content_length: int
+    success: bool
+    error: Optional[str] = None
+
+
+def _sha256(text: str) -> str:
+    """Compute SHA-256 hash of text content."""
+    return f"sha256:{hashlib.sha256(text.encode('utf-8')).hexdigest()}"
+
+
+def tavily_search(
+    api_key: str,
+    query: str,
+    max_results: int = 5,
+    include_raw_content: bool = True,
+) -> list[SearchResult]:
+    """Search the web via Tavily API.
+
+    Args:
+        api_key: Tavily API key.
+        query: Search query string.
+        max_results: Maximum number of results to return.
+        include_raw_content: Whether to request full page text from Tavily.
+
+    Returns:
+        List of SearchResult objects, sorted by relevance score.
+    """
+    client = TavilyClient(api_key=api_key)
+    response = client.search(
+        query,
+        max_results=max_results,
+        include_raw_content=include_raw_content,
+    )
+
+    results = []
+    for item in response.get("results", []):
+        raw = item.get("raw_content") or ""
+        content = item.get("content", "")
+        # Hash the best available content (raw if present, else summary)
+        hashable = raw if raw else content
+        results.append(
+            SearchResult(
+                url=item.get("url", ""),
+                title=item.get("title", ""),
+                content=content,
+                raw_content=raw if raw else None,
+                score=item.get("score", 0.0),
+                content_hash=_sha256(hashable) if hashable else _sha256(""),
+            )
+        )
+
+    return results
+
+
+async def fetch_url(
+    url: str,
+    timeout: float = 15.0,
+    max_length: int = 100_000,
+) -> FetchResult:
+    """Fetch a URL and extract clean text content.
+
+    Used for URLs where Tavily didn't return raw_content, or for
+    URLs discovered during research that weren't in the search results.
+
+    Args:
+        url: The URL to fetch.
+        timeout: Request timeout in seconds.
+        max_length: Maximum response body length in characters.
+
+    Returns:
+        FetchResult with extracted text and content hash.
+    """
+    try:
+        async with httpx.AsyncClient(
+            follow_redirects=True,
+            timeout=timeout,
+        ) as client:
+            response = await client.get(
+                url,
+                headers={
+                    "User-Agent": "Marchwarden/0.1 (research agent)",
+                },
+            )
+            response.raise_for_status()
+
+            raw_text = response.text[:max_length]
+            clean_text = _extract_text(raw_text)
+
+            return FetchResult(
+                url=url,
+                text=clean_text,
+                content_hash=_sha256(clean_text),
+                content_length=len(clean_text),
+                success=True,
+            )
+    except httpx.TimeoutException:
+        return FetchResult(
+            url=url,
+            text="",
+            content_hash=_sha256(""),
+            content_length=0,
+            success=False,
+            error=f"Timeout after {timeout}s",
+        )
+    except httpx.HTTPStatusError as e:
+        return FetchResult(
+            url=url,
+            text="",
+            content_hash=_sha256(""),
+            content_length=0,
+            success=False,
+            error=f"HTTP {e.response.status_code}",
+        )
+    except httpx.HTTPError as e:
+        return FetchResult(
+            url=url,
+            text="",
+            content_hash=_sha256(""),
+            content_length=0,
+            success=False,
+            error=str(e),
+        )
+
+
+def _extract_text(html: str) -> str:
+    """Extract readable text from HTML.
+
+    Simple extraction: strip tags, collapse whitespace. For V1 this is
+    sufficient. If quality is poor, swap in trafilatura or readability-lxml.
+    """
+    import re
+
+    # Remove script and style blocks
+    text = re.sub(r"<script[^>]*>.*?</script>", " ", html, flags=re.DOTALL)
+    text = re.sub(r"<style[^>]*>.*?</style>", " ", text, flags=re.DOTALL)
+    # Remove HTML tags
+    text = re.sub(r"<[^>]+>", " ", text)
+    # Decode common entities; "&amp;" last so "&amp;lt;" can't double-decode
+    text = text.replace("&lt;", "<")
+    text = text.replace("&gt;", ">")
+    text = text.replace("&quot;", '"')
+    text = text.replace("&#39;", "'")
+    text = text.replace("&nbsp;", " ")
+    text = text.replace("&amp;", "&")
+    # Collapse whitespace
+    text = re.sub(r"\s+", " ", text).strip()
+    return text
diff --git a/tests/test_tools.py b/tests/test_tools.py
new file mode 100644
index 0000000..e32328a
--- /dev/null
+++ b/tests/test_tools.py
@@ -0,0 +1,245 @@
+"""Tests for web researcher search and fetch tools."""
+
+import hashlib
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from researchers.web.tools import (
+    FetchResult,
+    SearchResult,
+    _extract_text,
+    _sha256,
+    fetch_url,
+    tavily_search,
+)
+
+
+# ---------------------------------------------------------------------------
+# _sha256
+# ---------------------------------------------------------------------------
+
+
+class TestSha256:
+    def test_consistent_hash(self):
+        assert _sha256("hello") == _sha256("hello")
+
+    def test_different_input_different_hash(self):
+        assert _sha256("hello") != _sha256("world")
+
+    def test_format(self):
+        h = _sha256("test")
+        assert h.startswith("sha256:")
+        # SHA-256 hex digest is 64 chars
+        assert len(h.split(":")[1]) == 64
+
+    def test_matches_hashlib(self):
+        text = "marchwarden test"
+        expected = "sha256:" + hashlib.sha256(text.encode("utf-8")).hexdigest()
+        assert _sha256(text) == expected
+
+
+# ---------------------------------------------------------------------------
+# _extract_text
+# ---------------------------------------------------------------------------
+
+
+class TestExtractText:
+    def test_strips_tags(self):
+        assert _extract_text("<html><body><p>hello</p></body></html>") == "hello"
+
+    def test_strips_script(self):
+        html = "<html><script>var x = 1;</script><p>content</p></html>"
+        assert "var x" not in _extract_text(html)
+        assert "content" in _extract_text(html)
+
+    def test_strips_style(self):
+        html = "<html><style>p { color: red; }</style><p>visible</p></html>"
+        assert "color" not in _extract_text(html)
+        assert "visible" in _extract_text(html)
+
+    def test_collapses_whitespace(self):
+        html = "<p>hello</p>\n\n   <p>world</p>"
+        assert _extract_text(html) == "hello world"
+
+    def test_decodes_entities(self):
+        html = "&amp; &lt; &gt; &quot; &#39; &nbsp;"
+        result = _extract_text(html)
+        assert "&" in result
+        assert "<" in result
+        assert ">" in result
+        assert '"' in result
+        assert "'" in result
+
+    def test_empty_input(self):
+        assert _extract_text("") == ""
+
+    def test_plain_text_passthrough(self):
+        assert _extract_text("just plain text") == "just plain text"
+
+
+# ---------------------------------------------------------------------------
+# tavily_search
+# ---------------------------------------------------------------------------
+
+
+class TestTavilySearch:
+    @patch("researchers.web.tools.TavilyClient")
+    def test_returns_search_results(self, mock_client_class):
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+        mock_client.search.return_value = {
+            "results": [
+                {
+                    "url": "https://example.com/page1",
+                    "title": "Page One",
+                    "content": "Short summary of page one.",
+                    "raw_content": "Full text of page one with lots of detail.",
+                    "score": 0.95,
+                },
+                {
+                    "url": "https://example.com/page2",
+                    "title": "Page Two",
+                    "content": "Short summary of page two.",
+                    "raw_content": "",
+                    "score": 0.80,
+                },
+            ]
+        }
+
+        results = tavily_search("fake-key", "test query", max_results=2)
+
+        assert len(results) == 2
+        assert isinstance(results[0], SearchResult)
+
+        # First result has raw_content
+        assert results[0].url == "https://example.com/page1"
+        assert results[0].title == "Page One"
+        assert results[0].raw_content == "Full text of page one with lots of detail."
+        assert results[0].score == 0.95
+        assert results[0].content_hash.startswith("sha256:")
+        # Hash should be of raw_content (best available)
+        assert results[0].content_hash == _sha256(
+            "Full text of page one with lots of detail."
+        )
+
+        # Second result has empty raw_content → falls back to content
+        assert results[1].raw_content is None
+        assert results[1].content_hash == _sha256("Short summary of page two.")
+
+    @patch("researchers.web.tools.TavilyClient")
+    def test_empty_results(self, mock_client_class):
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+        mock_client.search.return_value = {"results": []}
+
+        results = tavily_search("fake-key", "obscure query")
+        assert results == []
+
+    @patch("researchers.web.tools.TavilyClient")
+    def test_passes_params_to_client(self, mock_client_class):
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+        mock_client.search.return_value = {"results": []}
+
+        tavily_search("my-key", "query", max_results=3, include_raw_content=False)
+
+        mock_client_class.assert_called_once_with(api_key="my-key")
+        mock_client.search.assert_called_once_with(
+            "query", max_results=3, include_raw_content=False
+        )
+
+
+# ---------------------------------------------------------------------------
+# fetch_url
+# ---------------------------------------------------------------------------
+
+
+class TestFetchUrl:
+    @pytest.mark.asyncio
+    async def test_successful_fetch(self):
+        mock_response = AsyncMock()
+        mock_response.status_code = 200
+        mock_response.text = "<html><body>Hello world</body></html>"
+        mock_response.raise_for_status = MagicMock()
+
+        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
+            mock_client = AsyncMock()
+            mock_client_class.return_value.__aenter__ = AsyncMock(
+                return_value=mock_client
+            )
+            mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)
+            mock_client.get.return_value = mock_response
+
+            result = await fetch_url("https://example.com")
+
+        assert isinstance(result, FetchResult)
+        assert result.success is True
+        assert result.error is None
+        assert "Hello world" in result.text
+        assert result.content_hash.startswith("sha256:")
+        assert result.content_length > 0
+        assert result.url == "https://example.com"
+
+    @pytest.mark.asyncio
+    async def test_timeout_error(self):
+        import httpx
+
+        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
+            mock_client = AsyncMock()
+            mock_client_class.return_value.__aenter__ = AsyncMock(
+                return_value=mock_client
+            )
+            mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)
+            mock_client.get.side_effect = httpx.TimeoutException("timed out")
+
+            result = await fetch_url("https://slow.example.com", timeout=5.0)
+
+        assert result.success is False
+        assert "Timeout" in result.error
+        assert result.text == ""
+
+    @pytest.mark.asyncio
+    async def test_http_error(self):
+        import httpx
+
+        mock_response = MagicMock()
+        mock_response.status_code = 404
+        mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
+            "Not Found",
+            request=MagicMock(),
+            response=mock_response,
+        )
+
+        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
+            mock_client = AsyncMock()
+            mock_client_class.return_value.__aenter__ = AsyncMock(
+                return_value=mock_client
+            )
+            mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)
+            mock_client.get.return_value = mock_response
+
+            result = await fetch_url("https://example.com/missing")
+
+        assert result.success is False
+        assert "404" in result.error
+
+    @pytest.mark.asyncio
+    async def test_content_hash_consistency(self):
+        """Same content → same hash."""
+        mock_response = AsyncMock()
+        mock_response.text = "<html><body>consistent content</body></html>"
+        mock_response.raise_for_status = MagicMock()
+
+        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
+            mock_client = AsyncMock()
+            mock_client_class.return_value.__aenter__ = AsyncMock(
+                return_value=mock_client
+            )
+            mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)
+            mock_client.get.return_value = mock_response
+
+            r1 = await fetch_url("https://example.com/a")
+            r2 = await fetch_url("https://example.com/b")
+
+        assert r1.content_hash == r2.content_hash