Closes #38. First sub-milestone of M5.1 (Researcher #2: arxiv-rag). New package researchers/arxiv/ with three modules: - store.py — ArxivStore wraps a persistent chromadb collection at ~/.marchwarden/arxiv-rag/chroma/ plus a papers.json manifest. Chunk ids are deterministic and embedding-model-scoped (per ArxivRagProposal decision 4) so re-ingesting with a different embedder doesn't collide with prior chunks. - ingest.py — three-phase pipeline: download_pdf (arxiv API), extract_sections (pymupdf with heuristic heading detection + whole-paper fallback), and embed_and_store (sentence-transformers, configurable via MARCHWARDEN_ARXIV_EMBED_MODEL). Top-level ingest() chains them and upserts the manifest entry. Re-ingest is idempotent — chunks for the same paper are dropped before re-adding. - CLI subgroup `marchwarden arxiv add|list|info|remove`. Lazy-imports the heavy chromadb / torch deps so non-arxiv commands stay fast. The heavy ML deps (pymupdf, chromadb, sentence-transformers, arxiv) are gated behind an optional `[arxiv]` extra so the base install stays slim for users who only want the web researcher. Tests: 14 added (141 total passing). Real pymupdf against synthetic PDFs generated at test time covers extract_sections; chromadb and the embedder are stubbed via dependency injection so the tests stay fast, deterministic, and network-free. End-to-end ingest() is exercised with a mocked arxiv.Search that produces synthetic PDFs. Out of scope for #38 (covered by later sub-milestones): - Retrieval / search API (#39) - ArxivResearcher agent loop (#40) - MCP server (#41) - ask --researcher arxiv flag (#42) - Cost ledger embedding_calls field (#43) Notes: - pip install pulled in CUDA torch wheel (~2GB nvidia libs); harmless on CPU-only WSL but a future optimization would pin the CPU torch index. - Live smoke against a real arxiv id deferred so we don't block the M3.3 collection runner currently using the venv. 
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
214 lines
7.6 KiB
Python
214 lines
7.6 KiB
Python
"""Chromadb wrapper for the arxiv-rag researcher.

By default the store lives at ``~/.marchwarden/arxiv-rag/`` and contains:

- ``papers.json`` — manifest mapping arxiv_id -> metadata
- ``pdfs/<id>.pdf`` — cached PDFs
- ``chroma/`` — chromadb persistent collection of embedded chunks

This module is intentionally narrow: it exposes the persistent state and
the operations the ingest + retrieval layers need (add/delete chunks, read
and write the manifest, list papers). The retrieval layer (#39) will add a
query API on top of the same collection.

Chunk IDs are deterministic and embed a short hash of the embedding model
name, so re-ingesting a paper with a different embedder creates a new ID
space rather than silently overwriting old citations
(ArxivRagProposal §1, decision 4).
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any, Iterable, Optional
|
|
|
|
|
|
# Default on-disk home of the store: PDF cache, papers.json manifest and
# chromadb data all live underneath this directory.
DEFAULT_ROOT = Path(os.path.expanduser("~")).joinpath(".marchwarden", "arxiv-rag")
# chromadb collection that holds every embedded chunk.
DEFAULT_COLLECTION = "arxiv_chunks"
|
|
|
|
|
|
@dataclass
class PaperRecord:
    """Manifest entry for one indexed paper.

    ``arxiv_id`` is the key the manifest is indexed by, so it is not
    repeated inside the dict produced by :meth:`to_dict`; :meth:`from_dict`
    receives it as a separate argument instead.
    """

    arxiv_id: str
    version: str
    title: str
    authors: list[str]
    year: Optional[int]
    category: Optional[str]
    chunks_indexed: int
    embedding_model: str
    # UTC timestamp at seconds precision with a trailing "Z",
    # e.g. "2024-01-31T12:00:00Z". ISO strings in this form sort
    # chronologically as plain strings.
    added_at: str = field(
        default_factory=lambda: datetime.now(timezone.utc)
        .isoformat(timespec="seconds")
        .replace("+00:00", "Z")
    )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the JSON-ready manifest value (without ``arxiv_id``).

        ``authors`` is copied so the returned dict does not alias this record.
        """
        return {
            "version": self.version,
            "title": self.title,
            "authors": list(self.authors),
            "year": self.year,
            "category": self.category,
            "chunks_indexed": self.chunks_indexed,
            "embedding_model": self.embedding_model,
            "added_at": self.added_at,
        }

    @classmethod
    def from_dict(cls, arxiv_id: str, data: dict[str, Any]) -> "PaperRecord":
        """Rebuild a record from one manifest entry.

        Args:
            arxiv_id: The manifest key for this entry.
            data: The manifest value, as written by :meth:`to_dict`.

        Missing fields *and* fields explicitly set to ``null`` fall back to
        empty defaults. A single hand-edited or older entry such as
        ``"authors": null`` therefore cannot make the whole manifest
        unloadable; previously it raised from ``list(None)`` / ``int(None)``.
        """
        return cls(
            arxiv_id=arxiv_id,
            version=data.get("version") or "",
            title=data.get("title") or "",
            authors=list(data.get("authors") or []),
            year=data.get("year"),
            category=data.get("category"),
            chunks_indexed=int(data.get("chunks_indexed") or 0),
            embedding_model=data.get("embedding_model") or "",
            added_at=data.get("added_at") or "",
        )
|
|
|
|
|
|
def make_chunk_id(arxiv_id: str, section_index: int, embedding_model: str) -> str:
    """Build a stable chunk id that is namespaced by embedding model.

    Shape: ``<arxiv_id>::<section_index, zero-padded to 4 digits>::<h>``,
    where ``h`` is the first 8 hex digits of the SHA-1 of the model name.
    Folding the model into the id keeps chunks from different embedders in
    disjoint id spaces (ArxivRagProposal §1 decision 4 — a re-ingest with a
    new model must not collide with prior chunks). The short hash keeps the
    id readable.
    """
    digest = hashlib.sha1()
    digest.update(embedding_model.encode("utf-8"))
    fingerprint = digest.hexdigest()[:8]
    return "{}::{:04d}::{}".format(arxiv_id, section_index, fingerprint)
|
|
|
|
|
|
class ArxivStore:
    """File-backed manifest + chromadb collection for indexed papers.

    Layout under ``root``:

    - ``papers.json`` — manifest of :class:`PaperRecord` entries
    - ``pdfs/`` — PDF cache directory
    - ``chroma/`` — chromadb persistent client data

    The chromadb client and collection are created lazily on first access to
    :attr:`collection`. Manifest-only operations therefore never import
    chromadb.
    """

    def __init__(
        self,
        root: Optional[Path] = None,
        collection_name: str = DEFAULT_COLLECTION,
    ):
        """Resolve the store's paths and make sure its directories exist.

        Args:
            root: Store directory. A falsy value selects ``DEFAULT_ROOT``.
            collection_name: chromadb collection opened on first use.
        """
        self.root = Path(root) if root else DEFAULT_ROOT
        self.pdfs_dir = self.root / "pdfs"
        self.chroma_dir = self.root / "chroma"
        self.manifest_path = self.root / "papers.json"
        self.collection_name = collection_name

        self.root.mkdir(parents=True, exist_ok=True)
        self.pdfs_dir.mkdir(parents=True, exist_ok=True)
        self.chroma_dir.mkdir(parents=True, exist_ok=True)

        self._client = None  # lazy
        self._collection = None  # lazy

    # ------------------------------------------------------------------
    # Chroma — lazy because importing chromadb is slow
    # ------------------------------------------------------------------

    @property
    def collection(self):
        """Lazy chromadb collection handle (cosine-distance HNSW index).

        On first access, chromadb is imported and a persistent client is
        opened on ``chroma_dir``. The collection is cached on the instance
        afterwards.
        """
        if self._collection is None:
            import chromadb

            self._client = chromadb.PersistentClient(path=str(self.chroma_dir))
            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                # Cosine distance — typical for sentence-transformer
                # embeddings normalized to unit length.
                metadata={"hnsw:space": "cosine"},
            )
        return self._collection

    def add_chunks(
        self,
        ids: list[str],
        documents: list[str],
        embeddings: list[list[float]],
        metadatas: list[dict[str, Any]],
    ) -> None:
        """Upsert a batch of embedded chunks into the collection.

        The four lists are parallel: index ``i`` of each describes the same
        chunk. Because this is an upsert, writing an existing id replaces it.
        An empty batch is a no-op and does not touch chromadb.

        Raises:
            ValueError: If the four lists differ in length.
        """
        # Validate before the empty short-circuit: a caller passing e.g.
        # ids=[] alongside non-empty documents has a bug and should be told,
        # not silently ignored.
        if not (len(ids) == len(documents) == len(embeddings) == len(metadatas)):
            raise ValueError(
                "ids/documents/embeddings/metadatas must all have the same length"
            )
        if not ids:
            return
        self.collection.upsert(
            ids=ids,
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
        )

    def _chunk_ids_for(self, arxiv_id: str) -> list[str]:
        """Ids of all stored chunks whose ``arxiv_id`` metadata matches.

        Any chromadb error (including failure to import it) propagates.
        """
        res = self.collection.get(where={"arxiv_id": arxiv_id})
        return list(res.get("ids") or [])

    def chunk_count_for(self, arxiv_id: str) -> int:
        """Number of chunks currently stored for one paper.

        Best-effort: any failure while querying chromadb is reported as 0
        rather than raised. Code that must not confuse "lookup failed" with
        "no chunks" (see :meth:`delete_paper`) uses :meth:`_chunk_ids_for`.
        """
        try:
            return len(self._chunk_ids_for(arxiv_id))
        except Exception:
            return 0

    def delete_paper(self, arxiv_id: str) -> int:
        """Remove all chunks for one paper. Returns number deleted.

        Lookup errors are raised rather than reported as "0 deleted".
        Otherwise a caller that drops a paper's chunks before re-adding them,
        or before removing its manifest entry, would silently leave stale
        chunks in the collection.
        """
        ids = self._chunk_ids_for(arxiv_id)
        if not ids:
            return 0
        self.collection.delete(where={"arxiv_id": arxiv_id})
        return len(ids)

    # ------------------------------------------------------------------
    # Manifest — plain JSON, atomic write
    # ------------------------------------------------------------------

    def load_manifest(self) -> dict[str, PaperRecord]:
        """Read papers.json into ``{arxiv_id: PaperRecord}``. Returns {} if missing."""
        if not self.manifest_path.exists():
            return {}
        data = json.loads(self.manifest_path.read_text(encoding="utf-8"))
        return {
            arxiv_id: PaperRecord.from_dict(arxiv_id, entry)
            for arxiv_id, entry in data.items()
        }

    def save_manifest(self, manifest: dict[str, PaperRecord]) -> None:
        """Write papers.json atomically.

        The JSON is written to a sibling ``papers.json.tmp``, flushed and
        fsync'd, then renamed over ``papers.json``. Because both files are in
        the same directory, the rename swaps them in one step on POSIX. The
        fsync makes sure the new bytes are on disk before they become visible
        under the real name. Without it, a crash shortly after the rename can
        leave an empty or truncated manifest on some filesystems.
        """
        payload = {arxiv_id: rec.to_dict() for arxiv_id, rec in manifest.items()}
        text = json.dumps(payload, indent=2, sort_keys=True)
        tmp = self.manifest_path.with_suffix(".json.tmp")
        with open(tmp, "w", encoding="utf-8") as fh:
            fh.write(text)
            fh.flush()
            os.fsync(fh.fileno())
        tmp.replace(self.manifest_path)

    def upsert_paper(self, record: PaperRecord) -> None:
        """Insert or replace one entry in the manifest (keyed by arxiv_id)."""
        manifest = self.load_manifest()
        manifest[record.arxiv_id] = record
        self.save_manifest(manifest)

    def remove_paper(self, arxiv_id: str) -> bool:
        """Drop one entry from the manifest. Returns True if removed.

        Only the manifest is touched; stored chunks are removed separately
        via :meth:`delete_paper`.
        """
        manifest = self.load_manifest()
        if arxiv_id not in manifest:
            return False
        del manifest[arxiv_id]
        self.save_manifest(manifest)
        return True

    def list_papers(self) -> list[PaperRecord]:
        """All manifest entries, sorted by added_at descending (newest first).

        Entries with an empty ``added_at`` sort last.
        """
        manifest = self.load_manifest()
        return sorted(manifest.values(), key=lambda r: r.added_at, reverse=True)

    def get_paper(self, arxiv_id: str) -> Optional[PaperRecord]:
        """Manifest entry for one paper, or None."""
        return self.load_manifest().get(arxiv_id)
|