"""Ingest pipeline for the arxiv-rag researcher.
Public surface:
download_pdf(arxiv_id, store) -> tuple[Path, PaperMetadata]
extract_sections(pdf_path) -> list[Section]
embed_and_store(arxiv_id, sections, store, model_name, metadata) -> int
ingest(arxiv_id, store=None, model_name=...) -> PaperRecord # one-shot
The split exists so unit tests can mock each phase independently. The
top-level ``ingest()`` is what the CLI calls.
Section detection is heuristic: we walk the PDF page by page, look for
short lines that match a small set of canonical academic headings
(introduction, methods, results, discussion, conclusion, references,
etc.), and use those as section boundaries. If nothing matches, we fall
back to one Section containing the entire paper text — citations to
that section will still be valid, just less precise.
"""
from __future__ import annotations
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
from .store import ArxivStore, PaperRecord, make_chunk_id
# ---------------------------------------------------------------------------
# Defaults
# ---------------------------------------------------------------------------
# Embedding model used when the caller does not pick one. The
# MARCHWARDEN_ARXIV_EMBED_MODEL environment variable overrides the
# built-in default; it is read once, at import time.
DEFAULT_EMBEDDING_MODEL = os.getenv(
    "MARCHWARDEN_ARXIV_EMBED_MODEL", "nomic-ai/nomic-embed-text-v1.5"
)
# Headings considered "section starters" for the heuristic. Order
# matters only for documentation; matching is case-insensitive and
# whole-line.
_SECTION_HEADINGS = [
    "abstract",
    "introduction",
    "background",
    "related work",
    "preliminaries",
    "methods",
    "method",
    "methodology",
    "approach",
    "model",
    "experiments",
    "experimental setup",
    "evaluation",
    "results",
    "discussion",
    "analysis",
    "limitations",
    "conclusion",
    "conclusions",
    "future work",
    "references",
    "acknowledgments",
    "appendix",
]

# Compiled whole-line match: optional leading section number ("3", "3.1",
# "III", optionally followed by a dot), then one of the known headings
# captured as the named group ``title``, then end of line. Because the
# pattern is anchored on both ends, overlapping alternatives such as
# "method" / "methods" / "methodology" resolve correctly by backtracking.
_HEADING_RE = re.compile(
    r"^\s*(?:[0-9IVX]+\.?[0-9.]*)?\s*(?P<title>"
    + "|".join(_SECTION_HEADINGS)
    + r")\s*$",
    re.IGNORECASE,
)
@dataclass
class Section:
    """One section of a paper.

    Instances are built by ``extract_sections``. Page numbers are
    1-based (pages are enumerated starting at 1).
    """

    # Position of this section's heading among the detected headings;
    # 0 for the whole-paper fallback section.
    index: int
    # Title-cased heading text, or "Full Paper" for the fallback.
    title: str
    # Non-blank body lines joined with newlines (heading line excluded).
    text: str
    # Page on which the heading line was found.
    page_start: int
    # Page of the last body line belonging to the section.
    page_end: int
@dataclass
class PaperMetadata:
    """Lightweight metadata extracted from arxiv at download time."""

    arxiv_id: str
    # Version digits taken from the end of the result's entry id
    # (e.g. "2" for ".../abs/2301.00001v2"); "" when there is none.
    version: str
    title: str
    authors: list[str] = field(default_factory=list)
    year: Optional[int] = None
    category: Optional[str] = None


# ---------------------------------------------------------------------------
# Phase 1 — download
# ---------------------------------------------------------------------------

# Trailing "v<digits>" of an arxiv entry id, e.g. ".../abs/2301.00001v2".
_VERSION_SUFFIX_RE = re.compile(r"v(\d+)$")


def _parse_version(entry_id: object) -> str:
    """Return the version digits at the end of *entry_id*, or ``""``.

    A plain ``"v" in entry_id`` test is not enough: every arxiv URL
    contains the letter "v" (in "arxiv"), so an unversioned id would
    otherwise yield garbage such as ".org/abs/2301.00001".
    """
    match = _VERSION_SUFFIX_RE.search(str(entry_id or "").strip())
    return match.group(1) if match else ""


def download_pdf(
    arxiv_id: str,
    store: ArxivStore,
    *,
    arxiv_search: Optional[Callable] = None,
) -> tuple[Path, PaperMetadata]:
    """Download a paper PDF and return its cached path + arxiv metadata.

    Args:
        arxiv_id: Identifier of the paper to fetch.
        store: Object exposing ``pdfs_dir`` (a ``Path``); the PDF is
            cached there as ``<arxiv_id>.pdf``.
        arxiv_search: Injectable for tests so we can avoid hitting the
            real arxiv API. Called as ``arxiv_search(arxiv_id)`` and must
            return an iterable of result objects. The default uses the
            ``arxiv`` package.

    Returns:
        ``(pdf_path, metadata)``. The metadata lookup always runs; only
        the PDF download is skipped when the file is already cached.

    Raises:
        ValueError: If the search yields no result for ``arxiv_id``.
    """
    target = store.pdfs_dir / f"{arxiv_id}.pdf"

    if arxiv_search is None:
        import arxiv as arxiv_pkg

        search = arxiv_pkg.Search(id_list=[arxiv_id])
        # ``Search.results()`` is deprecated in arxiv>=2.0 in favour of
        # ``Client().results(search)``; use the client when this version
        # of the package offers it, otherwise keep the legacy call.
        client_cls = getattr(arxiv_pkg, "Client", None)
        if client_cls is not None and hasattr(client_cls, "results"):
            results = list(client_cls().results(search))
        else:
            results = list(search.results())
    else:
        results = list(arxiv_search(arxiv_id))

    if not results:
        raise ValueError(f"arxiv id not found: {arxiv_id}")
    paper = results[0]

    # Download the PDF if we don't already have it cached.
    if not target.exists():
        # Make sure the destination directory exists. This also covers
        # old-style ids containing a slash ("hep-th/9901001"), whose
        # target lives in a sub-directory of ``pdfs_dir``.
        target.parent.mkdir(parents=True, exist_ok=True)
        # Both the real arxiv.Result and our test stub expose
        # download_pdf(dirpath, filename); try that keyword form first.
        # Stubs that only accept a single destination path reject the
        # keywords with TypeError, in which case we fall back to that.
        try:
            paper.download_pdf(
                dirpath=str(store.pdfs_dir),
                filename=f"{arxiv_id}.pdf",
            )
        except TypeError:
            paper.download_pdf(str(target))

    published = getattr(paper, "published", None)
    metadata = PaperMetadata(
        arxiv_id=arxiv_id,
        version=_parse_version(getattr(paper, "entry_id", "")),
        title=getattr(paper, "title", "") or "",
        # Authors may be objects with a ``name`` attribute or plain values.
        authors=[
            getattr(author, "name", str(author))
            for author in (getattr(paper, "authors", None) or [])
        ],
        year=published.year if published is not None else None,
        category=getattr(paper, "primary_category", None),
    )
    return target, metadata
# ---------------------------------------------------------------------------
# Phase 2 — extract sections
# ---------------------------------------------------------------------------
def _read_pdf_lines(pdf_path: Path) -> list[tuple[int, str]]:
    """Return every non-blank line of the PDF as ``(page_num, stripped_line)``."""
    import pymupdf

    lines: list[tuple[int, str]] = []
    doc = pymupdf.open(str(pdf_path))
    try:
        # Pages are numbered from 1 so they match what a reader sees.
        for page_num, page in enumerate(doc, start=1):
            text = page.get_text("text") or ""
            for raw_line in text.splitlines():
                stripped = raw_line.strip()
                if stripped:
                    lines.append((page_num, stripped))
    finally:
        doc.close()
    return lines


def _whole_paper_section(lines: list[tuple[int, str]]) -> Section:
    """Wrap all extracted lines (must be non-empty) in one catch-all Section."""
    return Section(
        index=0,
        title="Full Paper",
        text="\n".join(line for _, line in lines),
        page_start=lines[0][0],
        page_end=lines[-1][0],
    )


def extract_sections(pdf_path: Path) -> list[Section]:
    """Extract section-level chunks from a PDF.

    Heuristic: walk pages, split on short lines (at most 80 characters)
    that match a known section heading. Each section's text is the lines
    between its heading and the next one; headings with no body are
    dropped, so ``Section.index`` values may have gaps.

    Returns:
        The detected sections in document order. If no heading is
        detected, or every detected section is empty, one "Full Paper"
        Section containing the whole document. An empty list when the
        PDF yields no text at all.
    """
    lines = _read_pdf_lines(pdf_path)
    if not lines:
        return []

    # Find heading boundaries as (line_index, title, page_num).
    boundaries: list[tuple[int, str, int]] = []
    for i, (page_num, line) in enumerate(lines):
        if len(line) > 80:
            # Section headings are short. Skip likely body text.
            continue
        m = _HEADING_RE.match(line)
        if m:
            boundaries.append((i, m.group("title").strip().title(), page_num))

    # NOTE(review): lines that precede the first detected heading are not
    # part of any section when at least one heading is found — confirm
    # that dropping this front matter is intended.
    sections: list[Section] = []
    for idx, (start_line, title, page_start) in enumerate(boundaries):
        end_line = (
            boundaries[idx + 1][0] if idx + 1 < len(boundaries) else len(lines)
        )
        body_lines = lines[start_line + 1 : end_line]
        if not body_lines:
            # Heading immediately followed by another heading (or EOF).
            continue
        sections.append(
            Section(
                index=idx,
                title=title,
                # Lines are already stripped and non-blank, so the joined
                # text needs no further trimming.
                text="\n".join(line for _, line in body_lines),
                page_start=page_start,
                page_end=body_lines[-1][0],
            )
        )

    # No headings, or headings detected but every section was empty —
    # fall back to the whole paper rather than dropping the document.
    return sections or [_whole_paper_section(lines)]
# ---------------------------------------------------------------------------
# Phase 3 — embed and store
# ---------------------------------------------------------------------------
def _load_embedder(model_name: str):
    """Return the sentence-transformers model for ``model_name``.

    Loaded models are memoised in ``_load_embedder._cache`` (a dict keyed
    by model name) so repeated ingests in one process reuse the same
    instance instead of re-downloading / re-loading it.
    """
    loaded = _load_embedder._cache  # type: ignore[attr-defined]
    try:
        return loaded[model_name]
    except KeyError:
        pass

    # Imported lazily so the dependency is only needed on a cache miss.
    from sentence_transformers import SentenceTransformer

    # NOTE: trust_remote_code=True allows the model repository to supply
    # custom code that runs on load; only use model names you trust.
    model = SentenceTransformer(model_name, trust_remote_code=True)
    loaded[model_name] = model
    return model


_load_embedder._cache = {}  # type: ignore[attr-defined]
def embed_and_store(
    arxiv_id: str,
    sections: list[Section],
    store: ArxivStore,
    model_name: str,
    metadata: PaperMetadata,
    *,
    embedder: Optional[object] = None,
) -> int:
    """Embed every section's text and write the chunks to ``store``.

    Args:
        arxiv_id: Paper the sections belong to.
        sections: Sections to embed; an empty value is a no-op.
        store: Receives ``delete_paper(arxiv_id)`` followed by one
            ``add_chunks(...)`` call.
        model_name: Used to load the default embedder, to build chunk
            ids, and recorded in each chunk's metadata.
        metadata: Only ``metadata.title`` is read.
        embedder: Injectable for tests so sentence-transformers need not
            be loaded. Must expose
            ``encode(list[str]) -> list[list[float]]``.

    Returns:
        The number of chunks written (0 when ``sections`` is empty, in
        which case the store is not touched).
    """
    if not sections:
        return 0

    encoder = _load_embedder(model_name) if embedder is None else embedder

    # NOTE(review): texts are embedded as-is. Some models (reportedly the
    # default nomic model) expect a task prefix on inputs — confirm this
    # matches how queries are embedded.
    texts = [section.text for section in sections]
    raw_vectors = encoder.encode(texts)

    # sentence-transformers hands back numpy rows while chromadb wants
    # plain lists; accept either numpy-like rows or ordinary iterables.
    embeddings: list[list[float]] = [
        row.tolist() if hasattr(row, "tolist") else list(row)
        for row in raw_vectors
    ]

    ids = [
        make_chunk_id(arxiv_id, section.index, model_name)
        for section in sections
    ]
    metadatas = []
    for section in sections:
        metadatas.append(
            {
                "arxiv_id": arxiv_id,
                "section_index": section.index,
                "section_title": section.title,
                "page_start": section.page_start,
                "page_end": section.page_end,
                "title": metadata.title,
                "embedding_model": model_name,
            }
        )

    # Drop whatever was stored for this paper before re-adding, so a
    # re-ingest does not accumulate duplicates.
    # NOTE(review): delete_paper receives only the arxiv id — whether it
    # is scoped to ``model_name`` depends on the store; verify.
    store.delete_paper(arxiv_id)
    store.add_chunks(
        ids=ids,
        documents=texts,
        embeddings=embeddings,
        metadatas=metadatas,
    )
    return len(ids)
# ---------------------------------------------------------------------------
# Top-level orchestrator
# ---------------------------------------------------------------------------
def ingest(
    arxiv_id: str,
    store: Optional[ArxivStore] = None,
    *,
    model_name: str = DEFAULT_EMBEDDING_MODEL,
    arxiv_search: Optional[Callable] = None,
    embedder: Optional[object] = None,
) -> PaperRecord:
    """End-to-end ingest: download → extract → embed → store → manifest.

    Args:
        arxiv_id: Paper to ingest.
        store: Destination store; a default ``ArxivStore()`` is created
            only when this is ``None``.
        model_name: Embedding model name, forwarded to
            ``embed_and_store`` and recorded on the returned record.
        arxiv_search: Forwarded to ``download_pdf`` (test injection).
        embedder: Forwarded to ``embed_and_store`` (test injection).

    Returns:
        The ``PaperRecord`` that was upserted into the store. Its
        ``chunks_indexed`` is 0 when no text could be extracted.

    Raises:
        ValueError: Propagated from ``download_pdf`` when the id is not
            found.
    """
    # Compare against None explicitly: ``store or ArxivStore()`` would
    # silently swap out a caller-supplied store that evaluates as falsy
    # (for example an empty store that defines ``__len__``).
    if store is None:
        store = ArxivStore()

    pdf_path, metadata = download_pdf(arxiv_id, store, arxiv_search=arxiv_search)
    sections = extract_sections(pdf_path)
    chunk_count = embed_and_store(
        arxiv_id=arxiv_id,
        sections=sections,
        store=store,
        model_name=model_name,
        metadata=metadata,
        embedder=embedder,
    )

    record = PaperRecord(
        arxiv_id=arxiv_id,
        version=metadata.version,
        title=metadata.title,
        authors=metadata.authors,
        year=metadata.year,
        category=metadata.category,
        chunks_indexed=chunk_count,
        embedding_model=model_name,
    )
    store.upsert_paper(record)
    return record