The artifact-upsert helper was duplicated four times (scanner.py, table_processor.py, figure_processor.py, pipeline.py) with slightly different signatures. Consolidates into a single keyword-only function keyed on (document_id, storage_key) - the identity the schema already enforces - so re-running the pipeline never creates duplicate rows. scanner / table_processor / figure_processor now import the shared helper directly. pipeline.py keeps a thin local wrapper to preserve the positional call sites at three artifact upsert points (OCR_PDF, MARKDOWN, DOCLING_JSON). Tests: 24 passed (5 health + 19 original). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
379 lines
14 KiB
Python
379 lines
14 KiB
Python
"""Per-document end-to-end pipeline: OCR -> Docling -> chunk -> persist -> index.
|
|
|
|
Called by the Celery worker. Idempotent: re-running on the same document deletes
|
|
existing chunks for that document and re-creates them, then re-indexes in
|
|
OpenSearch and Qdrant.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from sqlalchemy import delete, select
|
|
|
|
from app.config import settings
|
|
from app.db.models import (
|
|
ArtifactType,
|
|
Chunk,
|
|
Document,
|
|
DocumentArtifact,
|
|
DocumentStatus,
|
|
Page,
|
|
ProcessingEvent,
|
|
)
|
|
from app.storage.artifacts import ensure_artifact
|
|
from app.db.session import session_scope
|
|
from app.indexing import opensearch_client, qdrant_client
|
|
from app.indexing.embeddings import get_embedder
|
|
from app.ingestion.chunker import ChunkRecord, chunk_extraction
|
|
from app.ingestion.docling_extractor import ExtractionResult, extract
|
|
from app.ingestion.figure_processor import persist_figures
|
|
from app.ingestion.ocr import run_ocr
|
|
from app.ingestion.table_processor import persist_tables
|
|
from app.logging_config import get_logger
|
|
from app.storage.local_paths import (
|
|
key_docling_json,
|
|
key_markdown,
|
|
key_ocr_pdf,
|
|
work_dir_for,
|
|
)
|
|
from app.storage.minio_client import get_storage
|
|
from app.utils.language import detect_language
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def process_document_id(document_id: uuid.UUID, run_id: uuid.UUID | None = None) -> dict[str, Any]:
    """Run the full ingestion pipeline for one document (Celery entry point).

    Stage-level failures (OCR, Docling extraction, indexing) are reported by
    returning the ``_fail`` result for that stage. Any other exception - for
    example a failed upload or database write between stages - is logged, the
    document is marked ``DocumentStatus.FAILED`` on a best-effort basis, and
    the exception is re-raised so the caller still observes it.

    Args:
        document_id: Primary key of the ``Document`` row to process.
        run_id: Optional run identifier stored on every ``ProcessingEvent``.

    Returns:
        ``{"status": "missing"}`` when the document row does not exist, the
        ``_fail`` dict (``status`` and ``error`` keys) when a stage fails, or
        ``{"status": DocumentStatus.INDEXING_COMPLETED, "chunks": <count>}``
        on success.

    Raises:
        Exception: Whatever an unguarded step raised, after the failure has
            been recorded on the document.
    """
    try:
        return _run_pipeline(document_id, run_id)
    except Exception as exc:  # noqa: BLE001
        logger.exception("pipeline.unhandled_error", document_id=str(document_id))
        try:
            _fail(document_id, run_id, DocumentStatus.FAILED, f"Pipeline failed: {exc}")
        except Exception:  # noqa: BLE001
            # Recording the failure is best-effort; never mask the original error.
            logger.exception("pipeline.fail_status_not_persisted", document_id=str(document_id))
        raise


def _run_pipeline(document_id: uuid.UUID, run_id: uuid.UUID | None) -> dict[str, Any]:
    """Execute OCR -> Docling -> persist -> index for a single document."""
    import shutil  # local import: only needed to stage the source PDF

    storage = get_storage()
    storage.ensure_buckets()

    with session_scope() as db:
        doc = db.get(Document, document_id)
        if doc is None:
            logger.warning("pipeline.document_missing", document_id=str(document_id))
            return {"status": "missing"}

        source_path = Path(doc.source_path)
        sha = doc.sha256
        original_artifact = db.execute(
            select(DocumentArtifact).where(
                DocumentArtifact.document_id == doc.id,
                DocumentArtifact.artifact_type == ArtifactType.ORIGINAL_PDF,
            )
        ).scalar_one_or_none()
        # Copy plain values out while the session is open so nothing below
        # depends on reading attributes of a session-bound row.
        original_location = (
            (original_artifact.storage_bucket, original_artifact.storage_key)
            if original_artifact
            else None
        )

    work_dir = work_dir_for(document_id)
    local_pdf = work_dir / f"{sha}.pdf"
    if not local_pdf.exists():
        # Stage into a sibling ".part" file and rename once complete, so an
        # interrupted copy can never be mistaken for a usable PDF on a re-run.
        staging_pdf = local_pdf.with_name(f"{local_pdf.name}.part")
        if source_path.exists():
            # Streamed copy: avoids holding the whole PDF in memory.
            shutil.copyfile(source_path, staging_pdf)
        elif original_location is not None:
            bucket, key = original_location
            storage.get_to_path(bucket, key, staging_pdf)
        else:
            return _fail(
                document_id,
                run_id,
                DocumentStatus.OCR_FAILED,
                "Original PDF not available locally or in MinIO",
            )
        staging_pdf.replace(local_pdf)

    # ---------------- OCR ----------------
    ocr_pdf = work_dir / "ocr.pdf"
    try:
        _emit_event(document_id, run_id, DocumentStatus.OCR_STARTED, "OCR started")
        ocr_result = run_ocr(local_pdf, ocr_pdf, languages=settings.ocr_languages)
    except Exception as exc:  # noqa: BLE001
        logger.exception("pipeline.ocr_failed", document_id=str(document_id))
        return _fail(document_id, run_id, DocumentStatus.OCR_FAILED, f"OCR failed: {exc}")

    # Upload OCR PDF (even if we 'skipped' it - OCR PDF is the canonical input to Docling).
    ocr_key = key_ocr_pdf(document_id)
    storage.put_file(
        bucket=storage.derived_bucket,
        key=ocr_key,
        path=ocr_result.output_path,
        content_type="application/pdf",
    )
    with session_scope() as db:
        _ensure_artifact(db, document_id, ArtifactType.OCR_PDF, storage.derived_bucket, ocr_key)
        doc = db.get(Document, document_id)
        if doc is not None:
            doc.status = DocumentStatus.OCR_COMPLETED
        db.add(
            ProcessingEvent(
                run_id=run_id,
                document_id=document_id,
                stage=DocumentStatus.OCR_COMPLETED,
                level="INFO",
                message=f"OCR finished ({ocr_result.reason})",
                data={"skipped": ocr_result.skipped, "languages": ocr_result.languages},
            )
        )

    # ---------------- Docling ----------------
    try:
        _emit_event(document_id, run_id, DocumentStatus.EXTRACTION_STARTED, "Docling extraction started")
        extraction = extract(ocr_result.output_path)
    except Exception as exc:  # noqa: BLE001
        logger.exception("pipeline.docling_failed", document_id=str(document_id))
        return _fail(document_id, run_id, DocumentStatus.EXTRACTION_FAILED, f"Docling failed: {exc}")

    # Persist Markdown + JSON to MinIO.
    md_key = key_markdown(document_id)
    json_key = key_docling_json(document_id)
    storage.put_bytes(
        bucket=storage.derived_bucket,
        key=md_key,
        data=extraction.markdown.encode("utf-8"),
        content_type="text/markdown",
    )
    storage.put_bytes(
        bucket=storage.derived_bucket,
        key=json_key,
        data=json.dumps(extraction.json_payload, ensure_ascii=False).encode("utf-8"),
        content_type="application/json",
    )

    # ---------------- Persist pages, chunks, tables, figures ----------------
    chunk_records = chunk_extraction(extraction)
    # Language detection only looks at the text of the first three pages.
    sample_text = "\n".join(p.text for p in extraction.pages[:3] if p.text)
    lang = detect_language(sample_text)

    with session_scope() as db:
        _ensure_artifact(db, document_id, ArtifactType.MARKDOWN, storage.derived_bucket, md_key)
        _ensure_artifact(db, document_id, ArtifactType.DOCLING_JSON, storage.derived_bucket, json_key)

        doc = db.get(Document, document_id)
        if doc is None:
            return {"status": "missing"}
        doc.status = DocumentStatus.EXTRACTION_COMPLETED
        # Never overwrite a language hint that is already set on the document.
        if lang and not doc.language_hint:
            doc.language_hint = lang

        page_id_by_number = _upsert_pages(db, document_id, extraction)
        persist_tables(db, storage, document_id, extraction.tables, page_id_by_number)
        persist_figures(db, storage, document_id, extraction.figures, page_id_by_number)

        # Replace chunks idempotently: drop all and re-insert.
        db.execute(delete(Chunk).where(Chunk.document_id == document_id))
        for cr in chunk_records:
            db.add(_to_chunk_row(document_id, page_id_by_number, cr))

        doc.status = DocumentStatus.CHUNKING_COMPLETED
        db.add(
            ProcessingEvent(
                run_id=run_id,
                document_id=document_id,
                stage=DocumentStatus.CHUNKING_COMPLETED,
                level="INFO",
                message="Chunking complete",
                data={"chunks": len(chunk_records)},
            )
        )

    # ---------------- Indexing (OpenSearch + Qdrant) ----------------
    try:
        opensearch_client.ensure_index()
        qdrant_client.ensure_collection()
        # Remove this document's previous entries before writing the new ones.
        opensearch_client.delete_by_document(str(document_id))
        qdrant_client.delete_by_document(str(document_id))

        os_docs, qdrant_points = _build_index_payloads(document_id, chunk_records, extraction, lang)
        if os_docs:
            opensearch_client.index_chunks(os_docs)
        if qdrant_points:
            embedder = get_embedder()
            texts_to_embed = [text for _, text, _ in qdrant_points]
            vectors = embedder.encode(texts_to_embed)
            # strict=True: a vector-count mismatch must fail loudly, not drop chunks.
            triples = [
                (chunk_id, vec, payload)
                for (chunk_id, _text, payload), vec in zip(qdrant_points, vectors, strict=True)
            ]
            qdrant_client.upsert_chunks(triples)
    except Exception as exc:  # noqa: BLE001
        logger.exception("pipeline.indexing_failed", document_id=str(document_id))
        return _fail(document_id, run_id, DocumentStatus.FAILED, f"Indexing failed: {exc}")

    with session_scope() as db:
        doc = db.get(Document, document_id)
        if doc is not None:
            doc.status = DocumentStatus.INDEXING_COMPLETED
            # Clear any error left behind by an earlier failed run.
            doc.error_message = None
        db.add(
            ProcessingEvent(
                run_id=run_id,
                document_id=document_id,
                stage=DocumentStatus.INDEXING_COMPLETED,
                level="INFO",
                message="Indexing complete",
                data={"chunks": len(chunk_records)},
            )
        )

    return {"status": DocumentStatus.INDEXING_COMPLETED, "chunks": len(chunk_records)}
|
|
|
|
|
|
# ---------------- helpers ----------------
|
|
|
|
def _to_chunk_row(
    document_id: uuid.UUID, page_id_by_number: dict[int, uuid.UUID], cr: ChunkRecord
) -> Chunk:
    """Build an unsaved ``Chunk`` row from a chunker record.

    ``page_id`` is looked up by the record's page number and is ``None`` when
    that page number is absent from ``page_id_by_number``. ``ocr_confidence``
    is always stored as ``None`` here.
    """
    page_number = cr.page_number
    columns = {
        "document_id": document_id,
        "page_id": page_id_by_number.get(page_number),
        "page_number": page_number,
        "block_id": cr.block_id,
        "chunk_index": cr.chunk_index,
        "block_type": cr.block_type,
        "text": cr.text,
        "normalized_text": cr.normalized_text,
        "token_count": cr.token_count,
        "ocr_confidence": None,
        "quality_flags": cr.quality_flags,
        "chunk_metadata": cr.metadata,
    }
    return Chunk(**columns)
|
|
|
|
|
|
def _upsert_pages(db, document_id: uuid.UUID, extraction: ExtractionResult) -> dict[int, uuid.UUID]:
    """Create or refresh one ``Page`` row per extracted page.

    Pages already stored for the document are matched by ``page_number`` and
    updated in place; unknown page numbers get a new row. Stored pages whose
    number does not appear in ``extraction.pages`` are left untouched.

    Args:
        db: Open database session; new rows are added and flushed on it.
        document_id: Document the pages belong to.
        extraction: Extraction result whose ``pages`` are persisted.

    Returns:
        Mapping of page number to ``Page.id`` for every page in
        ``extraction.pages``.
    """
    pages_by_number = {
        page.page_number: page
        for page in db.execute(select(Page).where(Page.document_id == document_id)).scalars()
    }
    created_any = False
    for ep in extraction.pages:
        page = pages_by_number.get(ep.page_number)
        if page is None:
            page = Page(
                document_id=document_id,
                page_number=ep.page_number,
                text=ep.text,
                ocr_confidence=ep.ocr_confidence,
                has_tables=ep.has_tables,
                has_figures=ep.has_figures,
                has_handwriting=ep.has_handwriting,
            )
            db.add(page)
            # Register the new row so a repeated page number in the same
            # extraction updates it instead of inserting a duplicate.
            pages_by_number[ep.page_number] = page
            created_any = True
        else:
            page.text = ep.text
            # Refresh the confidence too; otherwise a re-run keeps the value
            # from the first extraction.
            page.ocr_confidence = ep.ocr_confidence
            page.has_tables = ep.has_tables
            page.has_figures = ep.has_figures
            page.has_handwriting = ep.has_handwriting

    if created_any:
        # One flush for all new rows so their ids are available below.
        db.flush()
    return {ep.page_number: pages_by_number[ep.page_number].id for ep in extraction.pages}
|
|
|
|
|
|
def _build_index_payloads(
    document_id: uuid.UUID,
    chunks: list[ChunkRecord],
    extraction: ExtractionResult,
    language_hint: str | None,
) -> tuple[list[dict[str, Any]], list[tuple[str, str, dict[str, Any]]]]:
    """Build the OpenSearch documents and Qdrant entries for a document.

    The payloads are derived from the ``Chunk`` rows stored for the document,
    read in a fresh session. ``chunks`` and ``extraction`` are accepted but
    not read by this function.

    Returns:
        A pair ``(os_docs, qdrant)``: ``os_docs`` is a list of index
        documents, ``qdrant`` a list of ``(chunk_id, text, payload)`` tuples.
        Both are empty when the document row no longer exists.
    """
    doc_id_str = str(document_id)
    with session_scope() as db:
        document = db.get(Document, document_id)
        if document is None:
            return [], []
        file_name = document.original_file_name
        origin = document.source_path

        stored_chunks = db.execute(select(Chunk).where(Chunk.document_id == document_id)).scalars().all()

        search_docs: list[dict[str, Any]] = []
        vector_points: list[tuple[str, str, dict[str, Any]]] = []

        for row in stored_chunks:
            chunk_id = str(row.id)
            body = row.text or ""
            # Fields carried identically by both index payloads.
            common = {
                "document_id": doc_id_str,
                "source_path": origin,
                "original_file_name": file_name,
                "page_number": row.page_number,
                "block_type": row.block_type,
                "block_id": row.block_id,
            }
            created = row.created_at or datetime.now(tz=timezone.utc)
            search_docs.append(
                {
                    "chunk_id": chunk_id,
                    **common,
                    "text": body,
                    "normalized_text": row.normalized_text,
                    "ocr_confidence": row.ocr_confidence,
                    "language_hint": language_hint,
                    "metadata": row.chunk_metadata or {},
                    "quality_flags": row.quality_flags or {},
                    "created_at": created.isoformat(),
                }
            )
            # The full text is kept as the string to embed; the payload only
            # carries its first 512 characters.
            vector_payload = {
                **common,
                "text_preview": body[:512],
                "ocr_confidence": row.ocr_confidence,
                "quality_flags": row.quality_flags or {},
                "metadata": row.chunk_metadata or {},
            }
            vector_points.append((chunk_id, body, vector_payload))
    return search_docs, vector_points
|
|
|
|
|
|
def _ensure_artifact(db, document_id: uuid.UUID, artifact_type: str, bucket: str, key: str) -> None:
    """Positional-argument adapter around the shared ``ensure_artifact``.

    Call sites in this module pass everything positionally; this forwards
    the session positionally and the remaining values by keyword.
    """
    forwarded = {
        "document_id": document_id,
        "artifact_type": artifact_type,
        "bucket": bucket,
        "key": key,
    }
    ensure_artifact(db, **forwarded)
|
|
|
|
|
|
def _emit_event(document_id: uuid.UUID, run_id: uuid.UUID | None, stage: str, message: str) -> None:
    """Record an INFO-level ``ProcessingEvent`` with empty ``data`` in its own session."""
    with session_scope() as db:
        event = ProcessingEvent(
            run_id=run_id,
            document_id=document_id,
            stage=stage,
            level="INFO",
            message=message,
            data={},
        )
        db.add(event)
|
|
|
|
|
|
def _fail(
    document_id: uuid.UUID, run_id: uuid.UUID | None, stage: str, message: str
) -> dict[str, Any]:
    """Record a failure for the document and return the failure result.

    When the document row exists, its status is set to ``stage`` and its
    ``error_message`` to the first 2000 characters of ``message``. An
    ERROR-level ``ProcessingEvent`` carrying the full message is added, and
    the failure is logged.

    Returns:
        ``{"status": stage, "error": message}``.
    """
    with session_scope() as db:
        document = db.get(Document, document_id)
        if document is not None:
            document.status = stage
            document.error_message = message[:2000]
        failure_event = ProcessingEvent(
            run_id=run_id,
            document_id=document_id,
            stage=stage,
            level="ERROR",
            message=message,
            data={},
        )
        db.add(failure_event)

    logger.error("pipeline.failed", document_id=str(document_id), stage=stage, message=message)
    result: dict[str, Any] = {"status": stage, "error": message}
    return result
|