"""Chromadb wrapper for the arxiv-rag researcher.

The store lives at ``~/.marchwarden/arxiv-rag/`` and contains:

- ``papers.json`` — manifest mapping arxiv_id -> metadata
- ``pdfs/<arxiv_id>.pdf`` — cached PDFs
- ``chroma/`` — chromadb persistent collection of embedded chunks

This module is intentionally narrow: it exposes the persistent state and
the operations the ingest + retrieval layers need (add chunks, fetch
manifest, list papers). The retrieval layer (#39) will add a query API on
top of the same collection.

Chunk IDs are deterministic and include a hash of the embedding model name,
so re-ingesting a paper with a different embedder creates a new ID space
rather than silently overwriting old citations (ArxivRagProposal §1,
decision 4).
"""

from __future__ import annotations

import hashlib
import json
import logging
import os
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable, Optional

logger = logging.getLogger(__name__)

DEFAULT_ROOT = Path(os.path.expanduser("~/.marchwarden/arxiv-rag"))
DEFAULT_COLLECTION = "arxiv_chunks"


@dataclass
class PaperRecord:
    """Manifest entry for one indexed paper.

    ``arxiv_id`` is the manifest key; it is not repeated inside the dict
    produced by :meth:`to_dict` and is passed separately to :meth:`from_dict`.
    """

    arxiv_id: str
    version: str
    title: str
    authors: list[str]
    year: Optional[int]
    category: Optional[str]
    chunks_indexed: int
    embedding_model: str
    # UTC timestamp, second precision, "Z" suffix (e.g. "2024-01-01T12:00:00Z").
    # ISO-8601 in this fixed form sorts chronologically as a plain string.
    added_at: str = field(
        default_factory=lambda: datetime.now(timezone.utc)
        .isoformat(timespec="seconds")
        .replace("+00:00", "Z")
    )

    def to_dict(self) -> dict:
        """Serialize to the manifest's per-paper JSON object (without arxiv_id)."""
        return {
            "version": self.version,
            "title": self.title,
            "authors": list(self.authors),
            "year": self.year,
            "category": self.category,
            "chunks_indexed": self.chunks_indexed,
            "embedding_model": self.embedding_model,
            "added_at": self.added_at,
        }

    @classmethod
    def from_dict(cls, arxiv_id: str, data: dict) -> "PaperRecord":
        """Rebuild a record from its manifest entry.

        Missing fields fall back to empty/zero/None defaults. An explicit
        ``null`` for ``authors`` or ``chunks_indexed`` is treated the same as
        a missing field, so a hand-edited manifest does not crash loading.
        """
        return cls(
            arxiv_id=arxiv_id,
            version=data.get("version", ""),
            title=data.get("title", ""),
            authors=list(data.get("authors") or []),
            year=data.get("year"),
            category=data.get("category"),
            chunks_indexed=int(data.get("chunks_indexed") or 0),
            embedding_model=data.get("embedding_model", ""),
            added_at=data.get("added_at", ""),
        )


def make_chunk_id(arxiv_id: str, section_index: int, embedding_model: str) -> str:
    """Deterministic chunk id, scoped by embedding model.

    Format: ``<arxiv_id>::<section_index:04d>::<model_hash>``, where
    ``model_hash`` is the first 8 hex characters of the SHA-1 of
    ``embedding_model``. The short hash keeps the id readable while making it
    unique across embedding models. See ArxivRagProposal §1 decision 4 —
    re-ingesting with a different model must not collide with prior chunks.
    """
    model_hash = hashlib.sha1(embedding_model.encode("utf-8")).hexdigest()[:8]
    return f"{arxiv_id}::{section_index:04d}::{model_hash}"


class ArxivStore:
    """File-backed manifest + chromadb collection for indexed papers."""

    def __init__(
        self,
        root: Optional[Path] = None,
        collection_name: str = DEFAULT_COLLECTION,
    ):
        """Open (and create if needed) the store rooted at ``root``.

        Falls back to ``DEFAULT_ROOT`` when ``root`` is falsy. Creates the
        root, ``pdfs/`` and ``chroma/`` directories eagerly; the chromadb
        client is only opened on first access to :attr:`collection`.
        """
        self.root = Path(root) if root else DEFAULT_ROOT
        self.pdfs_dir = self.root / "pdfs"
        self.chroma_dir = self.root / "chroma"
        self.manifest_path = self.root / "papers.json"
        self.collection_name = collection_name

        self.root.mkdir(parents=True, exist_ok=True)
        self.pdfs_dir.mkdir(parents=True, exist_ok=True)
        self.chroma_dir.mkdir(parents=True, exist_ok=True)

        self._client = None  # lazy
        self._collection = None  # lazy

    # ------------------------------------------------------------------
    # Chroma — lazy because importing chromadb is slow
    # ------------------------------------------------------------------

    @property
    def collection(self):
        """Lazy chromadb collection handle (created on first access)."""
        if self._collection is None:
            import chromadb

            self._client = chromadb.PersistentClient(path=str(self.chroma_dir))
            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                # Cosine distance — typical for sentence-transformer
                # embeddings normalized to unit length.
                metadata={"hnsw:space": "cosine"},
            )
        return self._collection

    def add_chunks(
        self,
        ids: list[str],
        documents: list[str],
        embeddings: list[list[float]],
        metadatas: list[dict[str, Any]],
    ) -> None:
        """Add a batch of embedded chunks to the collection.

        Uses ``upsert``, so re-adding an existing id replaces it. All four
        lists must be parallel (same length).

        Raises:
            ValueError: if the list lengths differ.
        """
        if not ids:
            return
        if not (len(ids) == len(documents) == len(embeddings) == len(metadatas)):
            raise ValueError(
                "ids/documents/embeddings/metadatas must all have the same length"
            )
        self.collection.upsert(
            ids=ids,
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
        )

    def chunk_count_for(self, arxiv_id: str) -> int:
        """Number of chunks currently stored for one paper.

        Best-effort: if the chromadb lookup fails, the error is logged and 0
        is returned instead of raising.
        """
        try:
            # include=[] asks chromadb for ids only — we just need the count,
            # not the documents/metadatas the default get() would return.
            res = self.collection.get(where={"arxiv_id": arxiv_id}, include=[])
        except Exception:
            logger.warning(
                "chunk count lookup failed for %s", arxiv_id, exc_info=True
            )
            return 0
        return len(res.get("ids", []))

    def delete_paper(self, arxiv_id: str) -> int:
        """Remove all chunks for one paper. Returns number deleted.

        The count is taken just before deletion; if it is 0 (including when
        the count lookup failed), nothing is deleted.
        """
        before = self.chunk_count_for(arxiv_id)
        if before == 0:
            return 0
        self.collection.delete(where={"arxiv_id": arxiv_id})
        return before

    # ------------------------------------------------------------------
    # Manifest — plain JSON, atomic write
    # ------------------------------------------------------------------

    def load_manifest(self) -> dict[str, PaperRecord]:
        """Read papers.json. Returns {} if missing."""
        if not self.manifest_path.exists():
            return {}
        data = json.loads(self.manifest_path.read_text(encoding="utf-8"))
        return {
            arxiv_id: PaperRecord.from_dict(arxiv_id, entry)
            for arxiv_id, entry in data.items()
        }

    def save_manifest(self, manifest: dict[str, PaperRecord]) -> None:
        """Write papers.json atomically.

        The JSON is written to a sibling ``papers.json.tmp``, flushed and
        fsync'd, then renamed over ``papers.json``, so the real file is
        replaced in a single rename rather than rewritten in place. On any
        failure the temp file is removed and the exception re-raised.
        """
        payload = {arxiv_id: rec.to_dict() for arxiv_id, rec in manifest.items()}
        text = json.dumps(payload, indent=2, sort_keys=True)
        tmp = self.manifest_path.with_suffix(".json.tmp")
        try:
            with open(tmp, "w", encoding="utf-8") as fh:
                fh.write(text)
                fh.flush()
                # Push the bytes to disk before the rename; without this a
                # crash right after the rename can leave an empty or
                # truncated papers.json on some filesystems.
                os.fsync(fh.fileno())
            tmp.replace(self.manifest_path)
        except BaseException:
            # Don't leave a stale temp file behind on failure.
            tmp.unlink(missing_ok=True)
            raise

    def upsert_paper(self, record: PaperRecord) -> None:
        """Insert or replace one entry in the manifest (read-modify-write)."""
        manifest = self.load_manifest()
        manifest[record.arxiv_id] = record
        self.save_manifest(manifest)

    def remove_paper(self, arxiv_id: str) -> bool:
        """Drop one entry from the manifest. Returns True if removed."""
        manifest = self.load_manifest()
        if arxiv_id not in manifest:
            return False
        del manifest[arxiv_id]
        self.save_manifest(manifest)
        return True

    def list_papers(self) -> list[PaperRecord]:
        """All manifest entries, sorted by added_at descending (newest first)."""
        manifest = self.load_manifest()
        return sorted(manifest.values(), key=lambda r: r.added_at, reverse=True)

    def get_paper(self, arxiv_id: str) -> Optional[PaperRecord]:
        """Manifest entry for one paper, or None."""
        return self.load_manifest().get(arxiv_id)