Compare commits
No commits in common. "851fed6a5f66793f644772245a34d9cad3cf4721" and "8930f4486ad30d8759df48b0bca594777b0c87da" have entirely different histories.
851fed6a5f
...
8930f4486a
2 changed files with 0 additions and 426 deletions
|
|
@ -1,181 +0,0 @@
|
||||||
"""Search and fetch tools for the web researcher.
|
|
||||||
|
|
||||||
Two tools:
|
|
||||||
- tavily_search: Web search via Tavily API, returns structured results
|
|
||||||
- fetch_url: Direct URL fetch with content extraction and hashing
|
|
||||||
"""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from tavily import TavilyClient
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class SearchResult:
    """A single result from a Tavily search.

    Fields are populated directly from one item of the Tavily API response;
    see ``tavily_search`` for the mapping.
    """

    url: str
    title: str
    content: str  # Short extracted summary from Tavily
    raw_content: Optional[str]  # Full page text (may be None)
    score: float  # Tavily relevance score (0.0-1.0)
    content_hash: str  # SHA-256 of the best available content (raw if present, else summary)
|
|
||||||
@dataclass
class FetchResult:
    """Result of fetching and extracting content from a URL.

    On failure, ``success`` is False, ``error`` holds a short human-readable
    reason, and ``text`` is empty (so ``content_hash`` is the hash of "").
    """

    url: str
    text: str  # Extracted clean text
    content_hash: str  # SHA-256 of the fetched content
    content_length: int  # Length of the extracted text in characters
    success: bool
    error: Optional[str] = None  # Set only when success is False
||||||
|
|
||||||
def _sha256(text: str) -> str:
|
|
||||||
"""Compute SHA-256 hash of text content."""
|
|
||||||
return f"sha256:{hashlib.sha256(text.encode('utf-8')).hexdigest()}"
|
|
||||||
|
|
||||||
|
|
||||||
def tavily_search(
    api_key: str,
    query: str,
    max_results: int = 5,
    include_raw_content: bool = True,
) -> list[SearchResult]:
    """Search the web via Tavily API.

    Args:
        api_key: Tavily API key.
        query: Search query string.
        max_results: Maximum number of results to return.
        include_raw_content: Whether to request full page text from Tavily.

    Returns:
        List of SearchResult objects, sorted by relevance score
        (highest first).
    """
    client = TavilyClient(api_key=api_key)
    response = client.search(
        query,
        max_results=max_results,
        include_raw_content=include_raw_content,
    )

    results = []
    for item in response.get("results", []):
        raw = item.get("raw_content") or ""
        content = item.get("content", "")
        # Hash the best available content (raw if present, else summary).
        # Note: _sha256("") is well-defined, so no special empty-case branch
        # is needed here (the original had a redundant `if hashable` guard).
        hashable = raw if raw else content
        results.append(
            SearchResult(
                url=item.get("url", ""),
                title=item.get("title", ""),
                content=content,
                raw_content=raw if raw else None,
                score=item.get("score", 0.0),
                content_hash=_sha256(hashable),
            )
        )

    # Fix: the docstring promises score ordering, but the original code relied
    # on Tavily already returning sorted results. Sort explicitly; the sort is
    # stable, so ties preserve Tavily's order.
    results.sort(key=lambda r: r.score, reverse=True)
    return results
||||||
|
|
||||||
async def fetch_url(
    url: str,
    timeout: float = 15.0,
    max_length: int = 100_000,
) -> FetchResult:
    """Fetch a URL and extract clean text content.

    Used for URLs where Tavily didn't return raw_content, or for
    URLs discovered during research that weren't in the search results.

    Args:
        url: The URL to fetch.
        timeout: Request timeout in seconds.
        max_length: Maximum response body length in characters.

    Returns:
        FetchResult with extracted text and content hash. Never raises for
        network/HTTP errors: failures are reported via ``success=False``
        and ``error``.
    """

    def _failure(reason: str) -> FetchResult:
        # Shared constructor for every error path (this FetchResult literal
        # was previously duplicated across the three except branches).
        return FetchResult(
            url=url,
            text="",
            content_hash=_sha256(""),
            content_length=0,
            success=False,
            error=reason,
        )

    try:
        async with httpx.AsyncClient(
            follow_redirects=True,
            timeout=timeout,
        ) as client:
            response = await client.get(
                url,
                headers={
                    "User-Agent": "Marchwarden/0.1 (research agent)",
                },
            )
            response.raise_for_status()

            raw_text = response.text[:max_length]
            clean_text = _extract_text(raw_text)

            return FetchResult(
                url=url,
                text=clean_text,
                content_hash=_sha256(clean_text),
                content_length=len(clean_text),
                success=True,
            )
    except httpx.TimeoutException:
        # Must precede the HTTPError clause: TimeoutException subclasses it.
        return _failure(f"Timeout after {timeout}s")
    except httpx.HTTPStatusError as e:
        # Also a subclass of HTTPError; handled first for a clearer message.
        return _failure(f"HTTP {e.response.status_code}")
    except httpx.HTTPError as e:
        return _failure(str(e))
||||||
|
|
||||||
|
|
||||||
def _extract_text(html: str) -> str:
|
|
||||||
"""Extract readable text from HTML.
|
|
||||||
|
|
||||||
Simple extraction: strip tags, collapse whitespace. For V1 this is
|
|
||||||
sufficient. If quality is poor, swap in trafilatura or readability-lxml.
|
|
||||||
"""
|
|
||||||
import re
|
|
||||||
|
|
||||||
# Remove script and style blocks
|
|
||||||
text = re.sub(r"<script[^>]*>.*?</script>", " ", html, flags=re.DOTALL)
|
|
||||||
text = re.sub(r"<style[^>]*>.*?</style>", " ", text, flags=re.DOTALL)
|
|
||||||
# Remove HTML tags
|
|
||||||
text = re.sub(r"<[^>]+>", " ", text)
|
|
||||||
# Decode common HTML entities
|
|
||||||
text = text.replace("&", "&")
|
|
||||||
text = text.replace("<", "<")
|
|
||||||
text = text.replace(">", ">")
|
|
||||||
text = text.replace(""", '"')
|
|
||||||
text = text.replace("'", "'")
|
|
||||||
text = text.replace(" ", " ")
|
|
||||||
# Collapse whitespace
|
|
||||||
text = re.sub(r"\s+", " ", text).strip()
|
|
||||||
return text
|
|
||||||
|
|
@ -1,245 +0,0 @@
|
||||||
"""Tests for web researcher search and fetch tools."""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from researchers.web.tools import (
|
|
||||||
FetchResult,
|
|
||||||
SearchResult,
|
|
||||||
_extract_text,
|
|
||||||
_sha256,
|
|
||||||
fetch_url,
|
|
||||||
tavily_search,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _sha256
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestSha256:
    """Unit tests for the _sha256 helper."""

    def test_consistent_hash(self):
        first, second = _sha256("hello"), _sha256("hello")
        assert first == second

    def test_different_input_different_hash(self):
        assert _sha256("hello") != _sha256("world")

    def test_format(self):
        digest = _sha256("test")
        assert digest.startswith("sha256:")
        # A SHA-256 hex digest is always 64 characters long.
        assert len(digest.split(":")[1]) == 64

    def test_matches_hashlib(self):
        sample = "marchwarden test"
        want = "sha256:" + hashlib.sha256(sample.encode("utf-8")).hexdigest()
        assert _sha256(sample) == want
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _extract_text
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestExtractText:
    """Unit tests for the _extract_text HTML-to-text helper."""

    def test_strips_tags(self):
        assert _extract_text("<p>hello</p>") == "hello"

    def test_strips_script(self):
        html = "<script>var x=1;</script><p>content</p>"
        assert "var x" not in _extract_text(html)
        assert "content" in _extract_text(html)

    def test_strips_style(self):
        html = "<style>.foo{color:red}</style><p>visible</p>"
        assert "color" not in _extract_text(html)
        assert "visible" in _extract_text(html)

    def test_collapses_whitespace(self):
        html = "<p>hello</p> <p>world</p>"
        assert _extract_text(html) == "hello world"

    def test_decodes_entities(self):
        # Fix: the original literal was corrupted (the entities had already
        # been decoded in the source, leaving an unterminated string). Use
        # the escaped entity forms so decoding is actually exercised.
        html = "&amp; &lt; &gt; &quot; &#39; &nbsp;"
        result = _extract_text(html)
        assert "&" in result
        assert "<" in result
        assert ">" in result
        assert '"' in result
        assert "'" in result

    def test_empty_input(self):
        assert _extract_text("") == ""

    def test_plain_text_passthrough(self):
        assert _extract_text("just plain text") == "just plain text"
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# tavily_search
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestTavilySearch:
    """Tests for tavily_search, with TavilyClient patched out."""

    @patch("researchers.web.tools.TavilyClient")
    def test_returns_search_results(self, client_cls):
        client = MagicMock()
        client_cls.return_value = client
        client.search.return_value = {
            "results": [
                {
                    "url": "https://example.com/page1",
                    "title": "Page One",
                    "content": "Short summary of page one.",
                    "raw_content": "Full text of page one with lots of detail.",
                    "score": 0.95,
                },
                {
                    "url": "https://example.com/page2",
                    "title": "Page Two",
                    "content": "Short summary of page two.",
                    "raw_content": "",
                    "score": 0.80,
                },
            ]
        }

        results = tavily_search("fake-key", "test query", max_results=2)

        assert len(results) == 2
        assert isinstance(results[0], SearchResult)

        first = results[0]
        # First result carries raw_content, so its hash covers the full text.
        assert first.url == "https://example.com/page1"
        assert first.title == "Page One"
        assert first.raw_content == "Full text of page one with lots of detail."
        assert first.score == 0.95
        assert first.content_hash.startswith("sha256:")
        assert first.content_hash == _sha256(
            "Full text of page one with lots of detail."
        )

        # Second result has empty raw_content -> falls back to the summary.
        assert results[1].raw_content is None
        assert results[1].content_hash == _sha256("Short summary of page two.")

    @patch("researchers.web.tools.TavilyClient")
    def test_empty_results(self, client_cls):
        client = MagicMock()
        client_cls.return_value = client
        client.search.return_value = {"results": []}

        assert tavily_search("fake-key", "obscure query") == []

    @patch("researchers.web.tools.TavilyClient")
    def test_passes_params_to_client(self, client_cls):
        client = MagicMock()
        client_cls.return_value = client
        client.search.return_value = {"results": []}

        tavily_search("my-key", "query", max_results=3, include_raw_content=False)

        client_cls.assert_called_once_with(api_key="my-key")
        client.search.assert_called_once_with(
            "query", max_results=3, include_raw_content=False
        )
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# fetch_url
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestFetchUrl:
    """Tests for fetch_url, with httpx.AsyncClient patched out."""

    @staticmethod
    def _wire_client(mock_client_class, mock_client):
        """Make the patched AsyncClient's async context manager yield mock_client.

        This wiring was previously duplicated verbatim in every test.
        """
        mock_client_class.return_value.__aenter__ = AsyncMock(
            return_value=mock_client
        )
        mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)

    @pytest.mark.asyncio
    async def test_successful_fetch(self):
        mock_response = AsyncMock()
        mock_response.status_code = 200
        mock_response.text = "<html><body><p>Hello world</p></body></html>"
        # raise_for_status is synchronous on httpx responses, so replace the
        # auto-created AsyncMock attribute with a plain MagicMock.
        mock_response.raise_for_status = MagicMock()

        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            self._wire_client(mock_client_class, mock_client)
            mock_client.get.return_value = mock_response

            result = await fetch_url("https://example.com")

            assert isinstance(result, FetchResult)
            assert result.success is True
            assert result.error is None
            assert "Hello world" in result.text
            assert result.content_hash.startswith("sha256:")
            assert result.content_length > 0
            assert result.url == "https://example.com"

    @pytest.mark.asyncio
    async def test_timeout_error(self):
        import httpx

        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            self._wire_client(mock_client_class, mock_client)
            mock_client.get.side_effect = httpx.TimeoutException("timed out")

            result = await fetch_url("https://slow.example.com", timeout=5.0)

            assert result.success is False
            assert "Timeout" in result.error
            assert result.text == ""

    @pytest.mark.asyncio
    async def test_http_error(self):
        import httpx

        mock_response = MagicMock()
        mock_response.status_code = 404
        mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
            "Not Found",
            request=MagicMock(),
            response=mock_response,
        )

        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            self._wire_client(mock_client_class, mock_client)
            mock_client.get.return_value = mock_response

            result = await fetch_url("https://example.com/missing")

            assert result.success is False
            assert "404" in result.error

    @pytest.mark.asyncio
    async def test_content_hash_consistency(self):
        """Same content -> same hash, regardless of URL."""
        mock_response = AsyncMock()
        mock_response.text = "<p>consistent content</p>"
        mock_response.raise_for_status = MagicMock()

        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            self._wire_client(mock_client_class, mock_client)
            mock_client.get.return_value = mock_response

            r1 = await fetch_url("https://example.com/a")
            r2 = await fetch_url("https://example.com/b")

            assert r1.content_hash == r2.content_hash
|
|
||||||
Loading…
Reference in a new issue