marchwarden/tests/test_tools.py

246 lines
8.6 KiB
Python
Raw Permalink Normal View History

"""Tests for web researcher search and fetch tools."""
import hashlib
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from researchers.web.tools import (
FetchResult,
SearchResult,
_extract_text,
_sha256,
fetch_url,
tavily_search,
)
# ---------------------------------------------------------------------------
# _sha256
# ---------------------------------------------------------------------------
class TestSha256:
    """Contract tests for ``_sha256``: deterministic, prefixed, hashlib-compatible."""

    PREFIX = "sha256:"

    def test_consistent_hash(self):
        """Hashing the same text twice yields the same tagged digest."""
        first = _sha256("hello")
        second = _sha256("hello")
        assert first == second

    def test_different_input_different_hash(self):
        """Distinct inputs produce distinct digests."""
        hello_digest, world_digest = _sha256("hello"), _sha256("world")
        assert hello_digest != world_digest

    def test_format(self):
        """Output is the ``sha256:`` tag followed by a full-length hex digest."""
        tagged = _sha256("test")
        assert tagged.startswith(self.PREFIX)
        # 32 digest bytes rendered as two hex characters each -> 64 characters.
        hex_part = tagged.split(":")[1]
        assert len(hex_part) == 64

    def test_matches_hashlib(self):
        """The digest agrees with ``hashlib.sha256`` over the UTF-8 bytes."""
        sample = "marchwarden test"
        reference = hashlib.sha256(sample.encode("utf-8")).hexdigest()
        assert _sha256(sample) == self.PREFIX + reference
# ---------------------------------------------------------------------------
# _extract_text
# ---------------------------------------------------------------------------
class TestExtractText:
    """Tests for ``_extract_text``, the HTML-to-plain-text helper."""

    def test_strips_tags(self):
        """Markup tags are removed, leaving only their text content."""
        assert _extract_text("<p>hello</p>") == "hello"

    def test_strips_script(self):
        """``<script>`` bodies are dropped entirely, not just their tags."""
        html = "<script>var x=1;</script><p>content</p>"
        text = _extract_text(html)
        assert "var x" not in text
        assert "content" in text

    def test_strips_style(self):
        """``<style>`` bodies are dropped entirely, not just their tags."""
        html = "<style>.foo{color:red}</style><p>visible</p>"
        text = _extract_text(html)
        assert "color" not in text
        assert "visible" in text

    def test_collapses_whitespace(self):
        """Text from adjacent elements is joined by a single space."""
        html = "<p>hello</p> <p>world</p>"
        assert _extract_text(html) == "hello world"

    def test_decodes_entities(self):
        """Named and numeric character references are decoded."""
        html = "&amp; &lt; &gt; &quot; &#39; &nbsp;"
        result = _extract_text(html)
        for decoded in ("&", "<", ">", '"', "'"):
            assert decoded in result, decoded
        # The "&" membership check above is vacuous on its own: the undecoded
        # input "&amp;" already contains "&". Also require that no entity
        # reference survives verbatim.
        for entity in ("&amp;", "&lt;", "&gt;", "&quot;", "&#39;"):
            assert entity not in result, entity
        # NOTE(review): &nbsp; is present in the input but deliberately not
        # asserted. Whether it becomes U+00A0 or a collapsed space depends on
        # the implementation; confirm and pin it if it matters.

    def test_empty_input(self):
        """Empty HTML yields empty text."""
        assert _extract_text("") == ""

    def test_plain_text_passthrough(self):
        """Input without markup is returned unchanged."""
        assert _extract_text("just plain text") == "just plain text"
# ---------------------------------------------------------------------------
# tavily_search
# ---------------------------------------------------------------------------
class TestTavilySearch:
    """Tests for ``tavily_search`` with ``TavilyClient`` patched out."""

    @staticmethod
    def _stub_client(client_cls, payload):
        """Have ``client_cls(...)`` return a mock whose ``search`` yields *payload*.

        Returns the client mock so a test can inspect how ``search`` was called.
        """
        client = MagicMock()
        client.search.return_value = payload
        client_cls.return_value = client
        return client

    @patch("researchers.web.tools.TavilyClient")
    def test_returns_search_results(self, mock_client_class):
        """Each hit becomes a ``SearchResult`` hashed over its best available text."""
        page_one_full = "Full text of page one with lots of detail."
        page_two_summary = "Short summary of page two."
        self._stub_client(
            mock_client_class,
            {
                "results": [
                    {
                        "url": "https://example.com/page1",
                        "title": "Page One",
                        "content": "Short summary of page one.",
                        "raw_content": page_one_full,
                        "score": 0.95,
                    },
                    {
                        "url": "https://example.com/page2",
                        "title": "Page Two",
                        "content": page_two_summary,
                        "raw_content": "",
                        "score": 0.80,
                    },
                ]
            },
        )

        results = tavily_search("fake-key", "test query", max_results=2)

        assert len(results) == 2
        first, second = results

        # A hit with raw_content passes its fields through and is hashed over raw_content.
        assert isinstance(first, SearchResult)
        assert first.url == "https://example.com/page1"
        assert first.title == "Page One"
        assert first.raw_content == page_one_full
        assert first.score == 0.95
        assert first.content_hash.startswith("sha256:")
        assert first.content_hash == _sha256(page_one_full)

        # An empty raw_content becomes None, and the hash falls back to the summary.
        assert second.raw_content is None
        assert second.content_hash == _sha256(page_two_summary)

    @patch("researchers.web.tools.TavilyClient")
    def test_empty_results(self, mock_client_class):
        """No hits from Tavily means an empty result list."""
        self._stub_client(mock_client_class, {"results": []})
        assert tavily_search("fake-key", "obscure query") == []

    @patch("researchers.web.tools.TavilyClient")
    def test_passes_params_to_client(self, mock_client_class):
        """The API key and search options are forwarded unchanged to the client."""
        client = self._stub_client(mock_client_class, {"results": []})

        tavily_search("my-key", "query", max_results=3, include_raw_content=False)

        mock_client_class.assert_called_once_with(api_key="my-key")
        client.search.assert_called_once_with(
            "query", max_results=3, include_raw_content=False
        )
# ---------------------------------------------------------------------------
# fetch_url
# ---------------------------------------------------------------------------
class TestFetchUrl:
    """Tests for ``fetch_url`` with ``httpx.AsyncClient`` patched out."""

    @staticmethod
    def _wire_client(mock_client_class):
        """Route ``async with httpx.AsyncClient(...) as client`` to a fresh AsyncMock.

        The patched class's instance must act as an async context manager
        whose ``__aenter__`` hands back the client mock. Returns that inner
        client so the test can configure ``get``.
        """
        mock_client = AsyncMock()
        mock_client_class.return_value.__aenter__ = AsyncMock(
            return_value=mock_client
        )
        mock_client_class.return_value.__aexit__ = AsyncMock(return_value=False)
        return mock_client

    @staticmethod
    def _html_response(text):
        """Build a 200-response stub whose body is *text*.

        ``httpx.Response`` is synchronous: ``.text`` is a property and
        ``raise_for_status()`` is a plain method. A ``MagicMock`` therefore
        models it faithfully. An ``AsyncMock`` would turn any unexpected method
        call into an un-awaited coroutine and hide misuse.
        ``raise_for_status()`` on this stub returns a mock and does not raise.
        """
        response = MagicMock()
        response.status_code = 200
        response.text = text
        return response

    @pytest.mark.asyncio
    async def test_successful_fetch(self):
        """A 200 response yields a successful, hashed, text-extracted result."""
        response = self._html_response(
            "<html><body><p>Hello world</p></body></html>"
        )
        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            self._wire_client(mock_client_class).get.return_value = response
            result = await fetch_url("https://example.com")

        assert isinstance(result, FetchResult)
        assert result.success is True
        assert result.error is None
        assert "Hello world" in result.text
        assert result.content_hash.startswith("sha256:")
        assert result.content_length > 0
        assert result.url == "https://example.com"

    @pytest.mark.asyncio
    async def test_timeout_error(self):
        """A transport timeout is reported as a failed result, not raised."""
        import httpx

        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            mock_client = self._wire_client(mock_client_class)
            mock_client.get.side_effect = httpx.TimeoutException("timed out")
            result = await fetch_url("https://slow.example.com", timeout=5.0)

        assert result.success is False
        assert "Timeout" in result.error
        assert result.text == ""

    @pytest.mark.asyncio
    async def test_http_error(self):
        """An HTTP error status is reported as a failed result carrying the code."""
        import httpx

        response = MagicMock()
        response.status_code = 404
        response.raise_for_status.side_effect = httpx.HTTPStatusError(
            "Not Found",
            request=MagicMock(),
            response=response,
        )
        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            self._wire_client(mock_client_class).get.return_value = response
            result = await fetch_url("https://example.com/missing")

        assert result.success is False
        assert "404" in result.error

    @pytest.mark.asyncio
    async def test_content_hash_consistency(self):
        """Same content → same hash."""
        response = self._html_response("<p>consistent content</p>")
        with patch("researchers.web.tools.httpx.AsyncClient") as mock_client_class:
            self._wire_client(mock_client_class).get.return_value = response
            r1 = await fetch_url("https://example.com/a")
            r2 = await fetch_url("https://example.com/b")

        # Both fetches must have succeeded. Otherwise two failed results that
        # carry identical placeholder hashes would satisfy the equality check
        # vacuously.
        assert r1.success is True
        assert r2.success is True
        assert r1.content_hash.startswith("sha256:")
        assert r1.content_hash == r2.content_hash