Compare commits
2 commits
e942ecc34a
...
b2d00dd301
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b2d00dd301 | ||
|
|
2e3d21f774 |
3 changed files with 167 additions and 2 deletions
|
|
@ -753,8 +753,54 @@ def _get_child_summaries(dir_path, cache):
|
||||||
return "\n".join(parts) if parts else "(none — this is a leaf directory)"
|
return "\n".join(parts) if parts else "(none — this is a leaf directory)"
|
||||||
|
|
||||||
|
|
||||||
|
# Minimum survey confidence at which skip_tools filtering is honored;
# below this, _filter_dir_tools returns the full toolbox unchanged.
_SURVEY_CONFIDENCE_THRESHOLD = 0.5

# Control-flow tools that must never be removed from the directory toolbox,
# regardless of what a survey's skip_tools lists.
_PROTECTED_DIR_TOOLS = {"submit_report"}
|
||||||
|
|
||||||
|
|
||||||
|
def _format_survey_block(survey):
|
||||||
|
"""Render survey output as a labeled text block for the dir prompt."""
|
||||||
|
if not survey:
|
||||||
|
return "(no survey available)"
|
||||||
|
lines = [
|
||||||
|
f"Description: {survey.get('description', '')}",
|
||||||
|
f"Approach: {survey.get('approach', '')}",
|
||||||
|
]
|
||||||
|
notes = survey.get("domain_notes", "")
|
||||||
|
if notes:
|
||||||
|
lines.append(f"Domain notes: {notes}")
|
||||||
|
relevant = survey.get("relevant_tools") or []
|
||||||
|
if relevant:
|
||||||
|
lines.append(f"Relevant tools (lean on these): {', '.join(relevant)}")
|
||||||
|
skip = survey.get("skip_tools") or []
|
||||||
|
if skip:
|
||||||
|
lines.append(f"Skip tools (already removed from your toolbox): "
|
||||||
|
f"{', '.join(skip)}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_dir_tools(survey):
    """Return _DIR_TOOLS with skip_tools removed, gated on confidence.

    - Returns full list if survey is None or confidence < threshold.
    - Always preserves control-flow tools in _PROTECTED_DIR_TOOLS.
    - Tool names in skip_tools that don't match anything are silently ignored.
    """
    full_toolbox = list(_DIR_TOOLS)
    if not survey:
        return full_toolbox

    # Coerce confidence defensively: None, missing, or garbage all become 0.0.
    raw_confidence = survey.get("confidence", 0.0) or 0.0
    try:
        confidence = float(raw_confidence)
    except (TypeError, ValueError):
        confidence = 0.0

    if confidence < _SURVEY_CONFIDENCE_THRESHOLD:
        return full_toolbox

    removable = set(survey.get("skip_tools") or [])
    removable -= _PROTECTED_DIR_TOOLS
    if not removable:
        return full_toolbox

    return [tool for tool in _DIR_TOOLS if tool["name"] not in removable]
|
||||||
|
|
||||||
|
|
||||||
def _run_dir_loop(client, target, cache, tracker, dir_path, max_turns=14,
|
def _run_dir_loop(client, target, cache, tracker, dir_path, max_turns=14,
|
||||||
verbose=False):
|
verbose=False, survey=None):
|
||||||
"""Run an isolated agent loop for a single directory."""
|
"""Run an isolated agent loop for a single directory."""
|
||||||
dir_rel = os.path.relpath(dir_path, target)
|
dir_rel = os.path.relpath(dir_path, target)
|
||||||
if dir_rel == ".":
|
if dir_rel == ".":
|
||||||
|
|
@ -762,6 +808,8 @@ def _run_dir_loop(client, target, cache, tracker, dir_path, max_turns=14,
|
||||||
|
|
||||||
context = _build_dir_context(dir_path)
|
context = _build_dir_context(dir_path)
|
||||||
child_summaries = _get_child_summaries(dir_path, cache)
|
child_summaries = _get_child_summaries(dir_path, cache)
|
||||||
|
survey_context = _format_survey_block(survey)
|
||||||
|
dir_tools = _filter_dir_tools(survey)
|
||||||
|
|
||||||
system = _DIR_SYSTEM_PROMPT.format(
|
system = _DIR_SYSTEM_PROMPT.format(
|
||||||
dir_path=dir_path,
|
dir_path=dir_path,
|
||||||
|
|
@ -769,6 +817,7 @@ def _run_dir_loop(client, target, cache, tracker, dir_path, max_turns=14,
|
||||||
max_turns=max_turns,
|
max_turns=max_turns,
|
||||||
context=context,
|
context=context,
|
||||||
child_summaries=child_summaries,
|
child_summaries=child_summaries,
|
||||||
|
survey_context=survey_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
|
|
@ -844,7 +893,7 @@ def _run_dir_loop(client, target, cache, tracker, dir_path, max_turns=14,
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content_blocks, usage = _call_api_streaming(
|
content_blocks, usage = _call_api_streaming(
|
||||||
client, system, messages, _DIR_TOOLS, tracker,
|
client, system, messages, dir_tools, tracker,
|
||||||
)
|
)
|
||||||
except anthropic.APIError as e:
|
except anthropic.APIError as e:
|
||||||
print(f" [AI] API error: {e}", file=sys.stderr)
|
print(f" [AI] API error: {e}", file=sys.stderr)
|
||||||
|
|
@ -1229,6 +1278,7 @@ def _run_investigation(client, target, report, show_hidden=False,
|
||||||
|
|
||||||
summary = _run_dir_loop(
|
summary = _run_dir_loop(
|
||||||
client, target, cache, tracker, dir_path, verbose=verbose,
|
client, target, cache, tracker, dir_path, verbose=verbose,
|
||||||
|
survey=survey,
|
||||||
)
|
)
|
||||||
|
|
||||||
if summary and not cache.has_entry("dir", dir_path):
|
if summary and not cache.has_entry("dir", dir_path):
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,9 @@ why you are uncertain (e.g. "binary file, content not readable",
|
||||||
"file truncated at max_bytes"). Do NOT set confidence_reason when
|
"file truncated at max_bytes"). Do NOT set confidence_reason when
|
||||||
confidence is 0.7 or above.
|
confidence is 0.7 or above.
|
||||||
|
|
||||||
|
## Survey Context
|
||||||
|
{survey_context}
|
||||||
|
|
||||||
## Context
|
## Context
|
||||||
{context}
|
{context}
|
||||||
|
|
||||||
|
|
|
||||||
112
tests/test_ai_filter.py
Normal file
112
tests/test_ai_filter.py
Normal file
|
|
@ -0,0 +1,112 @@
|
||||||
|
"""Tests for the pure helpers in ai.py that don't require a live API."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def _import_ai():
    """Import luminos_lib.ai with heavy/optional deps stubbed out.

    Installs MagicMock stand-ins for anthropic, magic, and
    luminos_lib.ast_parser (only when absent) so ai.py imports cleanly
    in unit tests without the real dependencies installed.
    """
    for dep in ("anthropic", "magic"):
        sys.modules.setdefault(dep, MagicMock())
    if "luminos_lib.ast_parser" not in sys.modules:
        parser_stub = MagicMock()
        parser_stub.parse_structure = MagicMock()
        sys.modules["luminos_lib.ast_parser"] = parser_stub
    from luminos_lib import ai
    return ai


# Module under test, imported once with stubbed dependencies.
ai = _import_ai()
|
||||||
|
|
||||||
|
|
||||||
|
class FilterDirToolsTests(unittest.TestCase):
    """Behavioral tests for ai._filter_dir_tools."""

    def setUp(self):
        # Baseline: names of every tool in the unfiltered toolbox.
        self.all_names = {tool["name"] for tool in ai._DIR_TOOLS}

    def _names(self, tools):
        """Collect the name field of each tool dict into a set."""
        return {tool["name"] for tool in tools}

    def test_none_survey_returns_full_list(self):
        filtered = ai._filter_dir_tools(None)
        self.assertEqual(self._names(filtered), self.all_names)

    def test_low_confidence_returns_full_list(self):
        filtered = ai._filter_dir_tools(
            {"confidence": 0.3, "skip_tools": ["run_command"]}
        )
        self.assertEqual(self._names(filtered), self.all_names)

    def test_high_confidence_drops_skip_tools(self):
        filtered = ai._filter_dir_tools(
            {"confidence": 0.9, "skip_tools": ["run_command"]}
        )
        names = self._names(filtered)
        self.assertNotIn("run_command", names)
        self.assertEqual(names, self.all_names - {"run_command"})

    def test_threshold_boundary_inclusive(self):
        # Confidence exactly at the threshold counts as confident enough.
        filtered = ai._filter_dir_tools(
            {"confidence": 0.5, "skip_tools": ["run_command"]}
        )
        self.assertNotIn("run_command", self._names(filtered))

    def test_protected_tool_never_dropped(self):
        filtered = ai._filter_dir_tools(
            {"confidence": 1.0, "skip_tools": ["submit_report", "run_command"]}
        )
        names = self._names(filtered)
        self.assertIn("submit_report", names)
        self.assertNotIn("run_command", names)

    def test_unknown_tool_in_skip_is_ignored(self):
        filtered = ai._filter_dir_tools(
            {"confidence": 0.9, "skip_tools": ["nonexistent_tool"]}
        )
        self.assertEqual(self._names(filtered), self.all_names)

    def test_empty_skip_tools_returns_full_list(self):
        filtered = ai._filter_dir_tools({"confidence": 0.9, "skip_tools": []})
        self.assertEqual(self._names(filtered), self.all_names)

    def test_missing_confidence_treated_as_zero(self):
        filtered = ai._filter_dir_tools({"skip_tools": ["run_command"]})
        self.assertEqual(self._names(filtered), self.all_names)

    def test_garbage_confidence_treated_as_zero(self):
        filtered = ai._filter_dir_tools(
            {"confidence": "not a number", "skip_tools": ["run_command"]}
        )
        self.assertEqual(self._names(filtered), self.all_names)

    def test_multiple_skip_tools(self):
        filtered = ai._filter_dir_tools({
            "confidence": 0.9,
            "skip_tools": ["run_command", "parse_structure"],
        })
        names = self._names(filtered)
        self.assertNotIn("run_command", names)
        self.assertNotIn("parse_structure", names)
|
||||||
|
|
||||||
|
|
||||||
|
class FormatSurveyBlockTests(unittest.TestCase):
    """Behavioral tests for ai._format_survey_block."""

    def test_none_returns_placeholder(self):
        rendered = ai._format_survey_block(None)
        self.assertIn("no survey", rendered.lower())

    def test_includes_description_and_approach(self):
        survey = {
            "description": "A Python lib",
            "approach": "read modules",
            "confidence": 0.9,
        }
        rendered = ai._format_survey_block(survey)
        self.assertIn("A Python lib", rendered)
        self.assertIn("read modules", rendered)

    def test_includes_skip_tools_when_present(self):
        survey = {
            "description": "x",
            "approach": "y",
            "skip_tools": ["run_command"],
            "confidence": 0.9,
        }
        self.assertIn("run_command", ai._format_survey_block(survey))

    def test_omits_empty_optional_fields(self):
        survey = {
            "description": "x",
            "approach": "y",
            "domain_notes": "",
            "relevant_tools": [],
            "skip_tools": [],
            "confidence": 0.9,
        }
        rendered = ai._format_survey_block(survey)
        # Empty optional fields must not produce labeled lines.
        for label in ("Domain notes:", "Relevant tools", "Skip tools"):
            self.assertNotIn(label, rendered)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly: python tests/test_ai_filter.py
if __name__ == "__main__":
    unittest.main()
|
||||||
Loading…
Reference in a new issue