marchwarden/researchers/arxiv/store.py

215 lines
7.6 KiB
Python
Raw Normal View History

feat(arxiv): ingest pipeline (M5.1.1) Closes #38. First sub-milestone of M5.1 (Researcher #2: arxiv-rag). New package researchers/arxiv/ with three modules: - store.py — ArxivStore wraps a persistent chromadb collection at ~/.marchwarden/arxiv-rag/chroma/ plus a papers.json manifest. Chunk ids are deterministic and embedding-model-scoped (per ArxivRagProposal decision 4) so re-ingesting with a different embedder doesn't collide with prior chunks. - ingest.py — three-phase pipeline: download_pdf (arxiv API), extract_sections (pymupdf with heuristic heading detection + whole-paper fallback), and embed_and_store (sentence-transformers, configurable via MARCHWARDEN_ARXIV_EMBED_MODEL). Top-level ingest() chains them and upserts the manifest entry. Re-ingest is idempotent — chunks for the same paper are dropped before re-adding. - CLI subgroup `marchwarden arxiv add|list|info|remove`. Lazy-imports the heavy chromadb / torch deps so non-arxiv commands stay fast. The heavy ML deps (pymupdf, chromadb, sentence-transformers, arxiv) are gated behind an optional `[arxiv]` extra so the base install stays slim for users who only want the web researcher. Tests: 14 added (141 total passing). Real pymupdf against synthetic PDFs generated at test time covers extract_sections; chromadb and the embedder are stubbed via dependency injection so the tests stay fast, deterministic, and network-free. End-to-end ingest() is exercised with a mocked arxiv.Search that produces synthetic PDFs. Out of scope for #38 (covered by later sub-milestones): - Retrieval / search API (#39) - ArxivResearcher agent loop (#40) - MCP server (#41) - ask --researcher arxiv flag (#42) - Cost ledger embedding_calls field (#43) Notes: - pip install pulled in CUDA torch wheel (~2GB nvidia libs); harmless on CPU-only WSL but a future optimization would pin the CPU torch index. - Live smoke against a real arxiv id deferred so we don't block the M3.3 collection runner currently using the venv. 
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-09 02:03:42 +00:00
"""Chromadb wrapper for the arxiv-rag researcher.

The store lives at ``~/.marchwarden/arxiv-rag/`` by default and contains:

- ``papers.json`` -- manifest mapping arxiv_id -> paper metadata
- ``pdfs/<id>.pdf`` -- cached PDFs
- ``chroma/`` -- chromadb persistent collection of embedded chunks

This module is intentionally narrow: it exposes the persistent state and
the operations the ingest and retrieval layers need (add chunks, count and
delete a paper's chunks, read/write the manifest, list papers). The
retrieval layer (#39) will add a query API on top of the same collection.

Chunk IDs are deterministic and embed a short hash of the embedding model
name, so re-ingesting a paper with a different embedder creates a new ID
space rather than silently overwriting old citations
(ArxivRagProposal §1, decision 4).
"""
from __future__ import annotations
import hashlib
import json
import os
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable, Optional
# Default on-disk location of the store. Resolved once, at import time, from
# the user's home directory.
DEFAULT_ROOT = Path(os.path.expanduser("~")) / ".marchwarden" / "arxiv-rag"
# Name of the chromadb collection holding embedded chunks.
DEFAULT_COLLECTION = "arxiv_chunks"
def _utc_timestamp() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``."""
    stamp = datetime.now(timezone.utc).isoformat(timespec="seconds")
    # isoformat() renders the UTC offset as "+00:00"; the manifest uses "Z".
    return stamp.replace("+00:00", "Z")


@dataclass
class PaperRecord:
    """One paper's entry in the ``papers.json`` manifest.

    ``arxiv_id`` is the manifest key, so it is not repeated inside the
    serialized entry produced by :meth:`to_dict`.
    """

    arxiv_id: str
    version: str
    title: str
    authors: list[str]
    year: Optional[int]
    category: Optional[str]
    chunks_indexed: int
    embedding_model: str
    # Set when the record is created, unless supplied explicitly.
    added_at: str = field(default_factory=_utc_timestamp)

    def to_dict(self) -> dict:
        """Serialize to a JSON-ready dict (everything except ``arxiv_id``).

        The ``authors`` list is copied so callers cannot mutate the record
        through the returned dict.
        """
        ordered_keys = (
            "version",
            "title",
            "authors",
            "year",
            "category",
            "chunks_indexed",
            "embedding_model",
            "added_at",
        )
        serialized = {key: getattr(self, key) for key in ordered_keys}
        serialized["authors"] = list(self.authors)
        return serialized

    @classmethod
    def from_dict(cls, arxiv_id: str, data: dict) -> "PaperRecord":
        """Rebuild a record from its manifest key and serialized entry.

        Missing string fields default to ``""``, missing ``year`` and
        ``category`` to ``None``, and a missing ``chunks_indexed`` to ``0``.
        """
        fetch = data.get
        return cls(
            arxiv_id=arxiv_id,
            version=fetch("version", ""),
            title=fetch("title", ""),
            authors=list(fetch("authors", [])),
            year=fetch("year"),
            category=fetch("category"),
            chunks_indexed=int(fetch("chunks_indexed", 0)),
            embedding_model=fetch("embedding_model", ""),
            added_at=fetch("added_at", ""),
        )
def make_chunk_id(arxiv_id: str, section_index: int, embedding_model: str) -> str:
    """Build a deterministic chunk id scoped to one embedding model.

    Format: ``<arxiv_id>::<section_index>::<sha1(model)[0:8]>``, with the
    section index zero-padded to four digits. Keeping only an 8-character
    slice of the model hash keeps the id readable while still separating
    embedding models. Per ArxivRagProposal §1 decision 4, re-ingesting with
    a different model must not collide with prior chunks.
    """
    digest = hashlib.sha1(embedding_model.encode("utf-8")).hexdigest()
    short_model = digest[:8]
    return "{}::{:04d}::{}".format(arxiv_id, section_index, short_model)
class ArxivStore:
    """Persistent arxiv-rag state: a JSON manifest plus a chromadb collection.

    Layout under ``root`` (``DEFAULT_ROOT`` when ``root`` is falsy):
    ``pdfs/``, ``chroma/`` and ``papers.json``. The directories are created
    eagerly on construction; the chromadb client is only opened on first
    access to :attr:`collection`.
    """

    def __init__(
        self,
        root: Optional[Path] = None,
        collection_name: str = DEFAULT_COLLECTION,
    ):
        # Any falsy root (None, "") falls back to the default location.
        base = DEFAULT_ROOT if not root else Path(root)
        self.root = base
        self.pdfs_dir = base / "pdfs"
        self.chroma_dir = base / "chroma"
        self.manifest_path = base / "papers.json"
        self.collection_name = collection_name
        for directory in (base, self.pdfs_dir, self.chroma_dir):
            directory.mkdir(parents=True, exist_ok=True)
        # Both are populated on first access to the ``collection`` property.
        self._client = None
        self._collection = None

    # ------------------------------------------------------------------
    # Chroma — opened lazily because importing chromadb is slow
    # ------------------------------------------------------------------

    @property
    def collection(self):
        """Chromadb collection handle, created on first use and then cached."""
        if self._collection is not None:
            return self._collection
        import chromadb

        self._client = chromadb.PersistentClient(path=str(self.chroma_dir))
        self._collection = self._client.get_or_create_collection(
            name=self.collection_name,
            # Cosine distance suits sentence-transformer embeddings that are
            # normalized to unit length.
            metadata={"hnsw:space": "cosine"},
        )
        return self._collection

    def add_chunks(
        self,
        ids: list[str],
        documents: list[str],
        embeddings: list[list[float]],
        metadatas: list[dict[str, Any]],
    ) -> None:
        """Upsert a batch of embedded chunks into the collection.

        An empty ``ids`` list is a no-op. Raises ``ValueError`` when the four
        parallel lists differ in length. Because this is an upsert, a chunk
        whose id already exists is overwritten.
        """
        if not ids:
            return
        expected = len(ids)
        if not all(
            len(column) == expected
            for column in (documents, embeddings, metadatas)
        ):
            raise ValueError(
                "ids/documents/embeddings/metadatas must all have the same length"
            )
        self.collection.upsert(
            ids=ids,
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
        )

    def chunk_count_for(self, arxiv_id: str) -> int:
        """Count the chunks stored for one paper.

        Returns 0 if the lookup raises for any reason. This includes failing
        to import or open chromadb, since the ``collection`` access happens
        inside the guarded call.
        """
        try:
            found = self.collection.get(where={"arxiv_id": arxiv_id})
        except Exception:
            return 0
        else:
            return len(found.get("ids", []))

    def delete_paper(self, arxiv_id: str) -> int:
        """Delete every chunk belonging to one paper.

        Returns the number of chunks counted just before deletion. The
        delete call is skipped entirely when that count is 0.
        """
        existing = self.chunk_count_for(arxiv_id)
        if existing:
            self.collection.delete(where={"arxiv_id": arxiv_id})
        return existing

    # ------------------------------------------------------------------
    # Manifest — plain JSON, written via temp file + rename
    # ------------------------------------------------------------------

    def load_manifest(self) -> dict[str, PaperRecord]:
        """Read ``papers.json`` into records keyed by arxiv id ({} if absent)."""
        path = self.manifest_path
        if not path.exists():
            return {}
        raw = json.loads(path.read_text(encoding="utf-8"))
        records: dict[str, PaperRecord] = {}
        for key, entry in raw.items():
            records[key] = PaperRecord.from_dict(key, entry)
        return records

    def save_manifest(self, manifest: dict[str, PaperRecord]) -> None:
        """Serialize ``manifest`` to ``papers.json``.

        The JSON is written to ``papers.json.tmp`` first and then moved over
        the real file with ``Path.replace``, so the real file is never left
        half-written.
        """
        serializable = {key: record.to_dict() for key, record in manifest.items()}
        text = json.dumps(serializable, indent=2, sort_keys=True)
        staging = self.manifest_path.with_suffix(".json.tmp")
        staging.write_text(text, encoding="utf-8")
        staging.replace(self.manifest_path)

    def upsert_paper(self, record: PaperRecord) -> None:
        """Add ``record`` to the manifest, replacing any entry with its id."""
        current = self.load_manifest()
        current[record.arxiv_id] = record
        self.save_manifest(current)

    def remove_paper(self, arxiv_id: str) -> bool:
        """Drop one manifest entry; return whether anything was removed."""
        current = self.load_manifest()
        if arxiv_id in current:
            del current[arxiv_id]
            self.save_manifest(current)
            return True
        return False

    def list_papers(self) -> list[PaperRecord]:
        """Return all manifest entries, newest ``added_at`` first."""
        records = list(self.load_manifest().values())
        # ISO-8601 timestamps in a single format sort correctly as strings.
        records.sort(key=lambda record: record.added_at, reverse=True)
        return records

    def get_paper(self, arxiv_id: str) -> Optional[PaperRecord]:
        """Return the manifest entry for ``arxiv_id``, or None if absent."""
        manifest = self.load_manifest()
        return manifest.get(arxiv_id)