"""Tests for the MCP server."""
import json
from unittest.mock import AsyncMock, patch, MagicMock
import pytest
from researchers.web.server import _read_secret, research
# ---------------------------------------------------------------------------
# _read_secret
# ---------------------------------------------------------------------------
class TestReadSecret:
|
||
|
|
def test_reads_key(self, tmp_path):
|
||
|
|
secrets = tmp_path / "secrets"
|
||
|
|
secrets.write_text("FOO=bar\nBAZ=qux\n")
|
||
|
|
with patch("researchers.web.server.os.path.expanduser", return_value=str(secrets)):
|
||
|
|
assert _read_secret("FOO") == "bar"
|
||
|
|
assert _read_secret("BAZ") == "qux"
|
||
|
|
|
||
|
|
def test_missing_key_raises(self, tmp_path):
|
||
|
|
secrets = tmp_path / "secrets"
|
||
|
|
secrets.write_text("FOO=bar\n")
|
||
|
|
with patch("researchers.web.server.os.path.expanduser", return_value=str(secrets)):
|
||
|
|
with pytest.raises(ValueError, match="MISSING"):
|
||
|
|
_read_secret("MISSING")
# ---------------------------------------------------------------------------
# research tool
# ---------------------------------------------------------------------------
class TestResearchTool:
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_returns_valid_json(self):
|
||
|
|
"""The research tool should return a JSON string with all contract fields."""
|
||
|
|
from researchers.web.models import (
|
||
|
|
ResearchResult,
|
||
|
|
ConfidenceFactors,
|
||
|
|
CostMetadata,
|
||
|
|
)
|
||
|
|
|
||
|
|
mock_result = ResearchResult(
|
||
|
|
answer="Test answer.",
|
||
|
|
citations=[],
|
||
|
|
gaps=[],
|
||
|
|
discovery_events=[],
|
||
|
|
open_questions=[],
|
||
|
|
confidence=0.8,
|
||
|
|
confidence_factors=ConfidenceFactors(
|
||
|
|
num_corroborating_sources=1,
|
||
|
|
source_authority="medium",
|
||
|
|
contradiction_detected=False,
|
||
|
|
query_specificity_match=0.7,
|
||
|
|
budget_exhausted=False,
|
||
|
|
recency="current",
|
||
|
|
),
|
||
|
|
cost_metadata=CostMetadata(
|
||
|
|
tokens_used=500,
|
||
|
|
iterations_run=1,
|
||
|
|
wall_time_sec=5.0,
|
||
|
|
budget_exhausted=False,
|
||
|
|
model_id="claude-test",
|
||
|
|
),
|
||
|
|
trace_id="test-trace-id",
|
||
|
|
)
|
||
|
|
|
||
|
|
with patch("researchers.web.server._get_researcher") as mock_get:
|
||
|
|
mock_researcher = AsyncMock()
|
||
|
|
mock_researcher.research.return_value = mock_result
|
||
|
|
mock_get.return_value = mock_researcher
|
||
|
|
|
||
|
|
result_json = await research(
|
||
|
|
question="test question",
|
||
|
|
context="some context",
|
||
|
|
depth="shallow",
|
||
|
|
max_iterations=2,
|
||
|
|
token_budget=5000,
|
||
|
|
)
|
||
|
|
|
||
|
|
data = json.loads(result_json)
|
||
|
|
assert data["answer"] == "Test answer."
|
||
|
|
assert data["confidence"] == 0.8
|
||
|
|
assert data["trace_id"] == "test-trace-id"
|
||
|
|
assert "citations" in data
|
||
|
|
assert "gaps" in data
|
||
|
|
assert "discovery_events" in data
|
||
|
|
assert "open_questions" in data
|
||
|
|
assert "confidence_factors" in data
|
||
|
|
assert "cost_metadata" in data
|
||
|
|
|
||
|
|
# Verify researcher was called with correct args
|
||
|
|
mock_researcher.research.assert_called_once()
|
||
|
|
call_kwargs = mock_researcher.research.call_args[1]
|
||
|
|
assert call_kwargs["question"] == "test question"
|
||
|
|
assert call_kwargs["context"] == "some context"
|
||
|
|
assert call_kwargs["depth"] == "shallow"
|
||
|
|
assert call_kwargs["constraints"].max_iterations == 2
|
||
|
|
assert call_kwargs["constraints"].token_budget == 5000
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_defaults(self):
|
||
|
|
"""Test that defaults work when optional args are omitted."""
|
||
|
|
from researchers.web.models import (
|
||
|
|
ResearchResult,
|
||
|
|
ConfidenceFactors,
|
||
|
|
CostMetadata,
|
||
|
|
)
|
||
|
|
|
||
|
|
mock_result = ResearchResult(
|
||
|
|
answer="Default test.",
|
||
|
|
citations=[],
|
||
|
|
gaps=[],
|
||
|
|
discovery_events=[],
|
||
|
|
open_questions=[],
|
||
|
|
confidence=0.5,
|
||
|
|
confidence_factors=ConfidenceFactors(
|
||
|
|
num_corroborating_sources=0,
|
||
|
|
source_authority="low",
|
||
|
|
contradiction_detected=False,
|
||
|
|
query_specificity_match=0.5,
|
||
|
|
budget_exhausted=False,
|
||
|
|
),
|
||
|
|
cost_metadata=CostMetadata(
|
||
|
|
tokens_used=100,
|
||
|
|
iterations_run=1,
|
||
|
|
wall_time_sec=1.0,
|
||
|
|
budget_exhausted=False,
|
||
|
|
model_id="claude-test",
|
||
|
|
),
|
||
|
|
trace_id="test-id",
|
||
|
|
)
|
||
|
|
|
||
|
|
with patch("researchers.web.server._get_researcher") as mock_get:
|
||
|
|
mock_researcher = AsyncMock()
|
||
|
|
mock_researcher.research.return_value = mock_result
|
||
|
|
mock_get.return_value = mock_researcher
|
||
|
|
|
||
|
|
result_json = await research(question="just a question")
|
||
|
|
|
||
|
|
data = json.loads(result_json)
|
||
|
|
assert data["answer"] == "Default test."
|
||
|
|
|
||
|
|
call_kwargs = mock_researcher.research.call_args[1]
|
||
|
|
assert call_kwargs["context"] is None
|
||
|
|
assert call_kwargs["depth"] == "balanced"
|
||
|
|
assert call_kwargs["constraints"].max_iterations == 5
|
||
|
|
assert call_kwargs["constraints"].token_budget == 20000
|