"""Per-document end-to-end pipeline: OCR -> Docling -> chunk -> persist -> index.

Called by the Celery worker. Idempotent: re-running on the same document
deletes existing chunks for that document and re-creates them, then re-indexes
in OpenSearch and Qdrant.
"""

from __future__ import annotations

import json
import shutil
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from sqlalchemy import delete, select

from app.config import settings
from app.db.models import (
    ArtifactType,
    Chunk,
    Document,
    DocumentArtifact,
    DocumentStatus,
    Page,
    ProcessingEvent,
)
from app.db.session import session_scope
from app.indexing import opensearch_client, qdrant_client
from app.indexing.embeddings import get_embedder
from app.ingestion.chunker import ChunkRecord, chunk_extraction
from app.ingestion.docling_extractor import ExtractionResult, extract
from app.ingestion.figure_processor import persist_figures
from app.ingestion.ocr import run_ocr
from app.ingestion.table_processor import persist_tables
from app.logging_config import get_logger
from app.storage.local_paths import (
    key_docling_json,
    key_markdown,
    key_ocr_pdf,
    work_dir_for,
)
from app.storage.minio_client import get_storage
from app.utils.language import detect_language

logger = get_logger(__name__)


def process_document_id(document_id: uuid.UUID, run_id: uuid.UUID | None = None) -> dict[str, Any]:
    """Top-level entry called by the Celery task.

    Wraps the pipeline in error handling so the task always either succeeds
    or marks the document FAILED.

    Args:
        document_id: Primary key of the ``Document`` row to process.
        run_id: Optional run id attached to every ``ProcessingEvent`` written.

    Returns:
        ``{"status": "missing"}`` when the document row does not exist,
        ``{"status": <failure stage>, "error": <message>}`` on failure, or
        ``{"status": DocumentStatus.INDEXING_COMPLETED, "chunks": <count>}``
        on success.
    """
    try:
        return _run_pipeline(document_id, run_id)
    except Exception as exc:  # noqa: BLE001
        # OCR, Docling and indexing failures are handled inside the pipeline
        # and return via _fail(). This catches everything else (storage
        # uploads, chunking, DB persistence) so the document is never left in
        # an intermediate status without an error message.
        logger.exception("pipeline.unexpected_failure", document_id=str(document_id))
        return _fail(document_id, run_id, DocumentStatus.FAILED, f"Pipeline failed: {exc}")


def _run_pipeline(document_id: uuid.UUID, run_id: uuid.UUID | None) -> dict[str, Any]:
    """Run every pipeline stage in order; stage-specific failures return via ``_fail``."""
    storage = get_storage()
    storage.ensure_buckets()

    with session_scope() as db:
        doc = db.get(Document, document_id)
        if doc is None:
            logger.warning("pipeline.document_missing", document_id=str(document_id))
            return {"status": "missing"}
        source_path = Path(doc.source_path)
        sha = doc.sha256
        original_artifact = db.execute(
            select(DocumentArtifact).where(
                DocumentArtifact.document_id == doc.id,
                DocumentArtifact.artifact_type == ArtifactType.ORIGINAL_PDF,
            )
        ).scalar_one_or_none()
        # Copy the location out while the session is still open. After
        # session_scope() exits, the ORM instance may be expired/detached,
        # and attribute access would raise.
        original_location = (
            (original_artifact.storage_bucket, original_artifact.storage_key)
            if original_artifact is not None
            else None
        )

    work_dir = work_dir_for(document_id)
    local_pdf = work_dir / f"{sha}.pdf"
    if not local_pdf.exists():
        if not _materialize_source_pdf(storage, source_path, original_location, local_pdf):
            return _fail(
                document_id,
                run_id,
                DocumentStatus.OCR_FAILED,
                "Original PDF not available locally or in MinIO",
            )

    # ---------------- OCR ----------------
    ocr_pdf = work_dir / "ocr.pdf"
    try:
        _emit_event(document_id, run_id, DocumentStatus.OCR_STARTED, "OCR started")
        ocr_result = run_ocr(local_pdf, ocr_pdf, languages=settings.ocr_languages)
    except Exception as exc:  # noqa: BLE001
        logger.exception("pipeline.ocr_failed", document_id=str(document_id))
        return _fail(document_id, run_id, DocumentStatus.OCR_FAILED, f"OCR failed: {exc}")

    # Upload OCR PDF (even if we 'skipped' it - OCR PDF is the canonical input to Docling).
    ocr_key = key_ocr_pdf(document_id)
    storage.put_file(
        bucket=storage.derived_bucket,
        key=ocr_key,
        path=ocr_result.output_path,
        content_type="application/pdf",
    )
    with session_scope() as db:
        doc = db.get(Document, document_id)
        if doc is None:
            # Document deleted mid-run: don't write rows that reference it.
            return {"status": "missing"}
        _ensure_artifact(db, document_id, ArtifactType.OCR_PDF, storage.derived_bucket, ocr_key)
        doc.status = DocumentStatus.OCR_COMPLETED
        db.add(
            ProcessingEvent(
                run_id=run_id,
                document_id=document_id,
                stage=DocumentStatus.OCR_COMPLETED,
                level="INFO",
                message=f"OCR finished ({ocr_result.reason})",
                data={"skipped": ocr_result.skipped, "languages": ocr_result.languages},
            )
        )

    # ---------------- Docling ----------------
    try:
        _emit_event(document_id, run_id, DocumentStatus.EXTRACTION_STARTED, "Docling extraction started")
        extraction = extract(ocr_result.output_path)
    except Exception as exc:  # noqa: BLE001
        logger.exception("pipeline.docling_failed", document_id=str(document_id))
        return _fail(document_id, run_id, DocumentStatus.EXTRACTION_FAILED, f"Docling failed: {exc}")

    # Persist Markdown + JSON to MinIO.
    md_key = key_markdown(document_id)
    json_key = key_docling_json(document_id)
    storage.put_bytes(
        bucket=storage.derived_bucket,
        key=md_key,
        data=extraction.markdown.encode("utf-8"),
        content_type="text/markdown",
    )
    storage.put_bytes(
        bucket=storage.derived_bucket,
        key=json_key,
        data=json.dumps(extraction.json_payload, ensure_ascii=False).encode("utf-8"),
        content_type="application/json",
    )

    # ---------------- Persist pages, chunks, tables, figures ----------------
    chunk_records = chunk_extraction(extraction)
    # Language detection only samples the first three pages that have text.
    sample_text = "\n".join(p.text for p in extraction.pages[:3] if p.text)
    lang = detect_language(sample_text)

    with session_scope() as db:
        doc = db.get(Document, document_id)
        if doc is None:
            return {"status": "missing"}
        _ensure_artifact(db, document_id, ArtifactType.MARKDOWN, storage.derived_bucket, md_key)
        _ensure_artifact(db, document_id, ArtifactType.DOCLING_JSON, storage.derived_bucket, json_key)
        doc.status = DocumentStatus.EXTRACTION_COMPLETED
        # Never overwrite an existing hint with the detected language.
        if lang and not doc.language_hint:
            doc.language_hint = lang

        page_id_by_number = _upsert_pages(db, document_id, extraction)
        persist_tables(db, storage, document_id, extraction.tables, page_id_by_number)
        persist_figures(db, storage, document_id, extraction.figures, page_id_by_number)

        # Replace chunks idempotently: drop all and re-insert.
        db.execute(delete(Chunk).where(Chunk.document_id == document_id))
        for cr in chunk_records:
            db.add(_to_chunk_row(document_id, page_id_by_number, cr))

        doc.status = DocumentStatus.CHUNKING_COMPLETED
        db.add(
            ProcessingEvent(
                run_id=run_id,
                document_id=document_id,
                stage=DocumentStatus.CHUNKING_COMPLETED,
                level="INFO",
                message="Chunking complete",
                data={"chunks": len(chunk_records)},
            )
        )

    # ---------------- Indexing (OpenSearch + Qdrant) ----------------
    try:
        opensearch_client.ensure_index()
        qdrant_client.ensure_collection()
        # Clear this document's previous entries so re-runs don't leave stale chunks.
        opensearch_client.delete_by_document(str(document_id))
        qdrant_client.delete_by_document(str(document_id))

        os_docs, qdrant_points = _build_index_payloads(document_id, chunk_records, extraction, lang)
        if os_docs:
            opensearch_client.index_chunks(os_docs)

        if qdrant_points:
            embedder = get_embedder()
            texts_to_embed = [text for _, text, _ in qdrant_points]
            vectors = embedder.encode(texts_to_embed)
            triples = [
                (chunk_id, vec, payload)
                for (chunk_id, _text, payload), vec in zip(qdrant_points, vectors, strict=True)
            ]
            qdrant_client.upsert_chunks(triples)
    except Exception as exc:  # noqa: BLE001
        logger.exception("pipeline.indexing_failed", document_id=str(document_id))
        return _fail(document_id, run_id, DocumentStatus.FAILED, f"Indexing failed: {exc}")

    with session_scope() as db:
        doc = db.get(Document, document_id)
        if doc is not None:
            doc.status = DocumentStatus.INDEXING_COMPLETED
            doc.error_message = None
            db.add(
                ProcessingEvent(
                    run_id=run_id,
                    document_id=document_id,
                    stage=DocumentStatus.INDEXING_COMPLETED,
                    level="INFO",
                    message="Indexing complete",
                    data={"chunks": len(chunk_records)},
                )
            )

    return {"status": DocumentStatus.INDEXING_COMPLETED, "chunks": len(chunk_records)}


# ---------------- helpers ----------------


def _materialize_source_pdf(
    storage: Any,
    source_path: Path,
    original_location: tuple[str, str] | None,
    dest: Path,
) -> bool:
    """Place the original PDF at ``dest``, from ``source_path`` or else from object storage.

    The bytes are first written to a sibling ``.part`` file and then renamed
    into place. As a result, a crash mid-copy cannot leave a truncated PDF at
    ``dest`` that the caller's ``dest.exists()`` check would trust on the next
    run.

    Args:
        storage: Storage client exposing ``get_to_path(bucket, key, path)``.
        source_path: Original on-disk location recorded on the document.
        original_location: ``(bucket, key)`` of the ORIGINAL_PDF artifact, or None.
        dest: Target path in the work directory.

    Returns:
        True if ``dest`` was written, False if neither source was available.
    """
    partial = dest.with_name(dest.name + ".part")
    if source_path.exists():
        # Streamed copy; avoids loading the whole PDF into memory.
        shutil.copyfile(source_path, partial)
    elif original_location is not None:
        bucket, key = original_location
        storage.get_to_path(bucket, key, partial)
    else:
        return False
    partial.replace(dest)
    return True


def _to_chunk_row(
    document_id: uuid.UUID, page_id_by_number: dict[int, uuid.UUID], cr: ChunkRecord
) -> Chunk:
    """Build an (unsaved) ``Chunk`` ORM row from a chunker record.

    ``page_id`` is None when the chunk's page number has no persisted page.
    ``ocr_confidence`` is always stored as None here.
    """
    return Chunk(
        document_id=document_id,
        page_id=page_id_by_number.get(cr.page_number),
        page_number=cr.page_number,
        block_id=cr.block_id,
        chunk_index=cr.chunk_index,
        block_type=cr.block_type,
        text=cr.text,
        normalized_text=cr.normalized_text,
        token_count=cr.token_count,
        ocr_confidence=None,
        quality_flags=cr.quality_flags,
        chunk_metadata=cr.metadata,
    )


def _upsert_pages(db, document_id: uuid.UUID, extraction: ExtractionResult) -> dict[int, uuid.UUID]:
    """Insert or update one ``Page`` row per extracted page.

    Pages already stored for the document are updated in place. New pages are
    flushed immediately so their primary keys are available.

    Returns:
        Mapping of page number -> page id, covering the pages in ``extraction``.
    """
    existing = {
        p.page_number: p
        for p in db.execute(select(Page).where(Page.document_id == document_id)).scalars()
    }
    out: dict[int, uuid.UUID] = {}
    for ep in extraction.pages:
        page = existing.get(ep.page_number)
        if page is None:
            page = Page(
                document_id=document_id,
                page_number=ep.page_number,
                text=ep.text,
                ocr_confidence=ep.ocr_confidence,
                has_tables=ep.has_tables,
                has_figures=ep.has_figures,
                has_handwriting=ep.has_handwriting,
            )
            db.add(page)
            db.flush()  # assigns page.id
        else:
            page.text = ep.text
            # Refresh confidence too; otherwise re-runs keep the stale value.
            page.ocr_confidence = ep.ocr_confidence
            page.has_tables = ep.has_tables
            page.has_figures = ep.has_figures
            page.has_handwriting = ep.has_handwriting
        out[ep.page_number] = page.id
    return out


def _build_index_payloads(
    document_id: uuid.UUID,
    chunks: list[ChunkRecord],
    extraction: ExtractionResult,
    language_hint: str | None,
) -> tuple[list[dict[str, Any]], list[tuple[str, str, dict[str, Any]]]]:
    """Build OpenSearch documents and Qdrant ``(chunk_id, text, payload)`` triples.

    Payloads are built from the persisted ``Chunk`` rows, not from ``chunks``,
    so that index ids equal database chunk ids. ``chunks`` and ``extraction``
    are currently unused and kept only for signature compatibility.

    Returns:
        ``([], [])`` if the document no longer exists.
    """
    os_docs: list[dict[str, Any]] = []
    qdrant: list[tuple[str, str, dict[str, Any]]] = []
    with session_scope() as db:
        doc = db.get(Document, document_id)
        if doc is None:
            return [], []
        original_file_name = doc.original_file_name
        source_path = doc.source_path
        chunk_rows = (
            db.execute(select(Chunk).where(Chunk.document_id == document_id))
            .scalars()
            .all()
        )

        # Read row attributes while the session is open; the rows may be
        # expired/detached once session_scope() exits.
        for row in chunk_rows:
            chunk_id = str(row.id)
            text = row.text or ""
            os_docs.append(
                {
                    "chunk_id": chunk_id,
                    "document_id": str(document_id),
                    "source_path": source_path,
                    "original_file_name": original_file_name,
                    "page_number": row.page_number,
                    "block_type": row.block_type,
                    "block_id": row.block_id,
                    "text": text,
                    "normalized_text": row.normalized_text,
                    "ocr_confidence": row.ocr_confidence,
                    "language_hint": language_hint,
                    "metadata": row.chunk_metadata or {},
                    "quality_flags": row.quality_flags or {},
                    "created_at": (row.created_at or datetime.now(tz=timezone.utc)).isoformat(),
                }
            )
            text_preview = text[:512]
            qdrant.append(
                (
                    chunk_id,
                    text,
                    {
                        "document_id": str(document_id),
                        "source_path": source_path,
                        "original_file_name": original_file_name,
                        "page_number": row.page_number,
                        "block_type": row.block_type,
                        "block_id": row.block_id,
                        "text_preview": text_preview,
                        "ocr_confidence": row.ocr_confidence,
                        "quality_flags": row.quality_flags or {},
                        "metadata": row.chunk_metadata or {},
                    },
                )
            )
    return os_docs, qdrant


def _ensure_artifact(db, document_id: uuid.UUID, artifact_type: str, bucket: str, key: str) -> None:
    """Add a ``DocumentArtifact`` row unless one with the same document and storage key exists."""
    existing = db.execute(
        select(DocumentArtifact).where(
            DocumentArtifact.document_id == document_id,
            DocumentArtifact.storage_key == key,
        )
    ).scalar_one_or_none()
    if existing:
        return
    db.add(
        DocumentArtifact(
            document_id=document_id,
            artifact_type=artifact_type,
            storage_bucket=bucket,
            storage_key=key,
        )
    )


def _emit_event(document_id: uuid.UUID, run_id: uuid.UUID | None, stage: str, message: str) -> None:
    """Record an INFO ``ProcessingEvent`` in its own session; document status is not changed."""
    with session_scope() as db:
        db.add(
            ProcessingEvent(
                run_id=run_id,
                document_id=document_id,
                stage=stage,
                level="INFO",
                message=message,
                data={},
            )
        )


def _fail(
    document_id: uuid.UUID, run_id: uuid.UUID | None, stage: str, message: str
) -> dict[str, Any]:
    """Mark the document as failed at ``stage`` and record an ERROR event.

    ``error_message`` on the document is truncated to 2000 characters; the
    event and the log line carry the full message.

    Returns:
        ``{"status": stage, "error": message}``.
    """
    with session_scope() as db:
        doc = db.get(Document, document_id)
        if doc is not None:
            doc.status = stage
            doc.error_message = message[:2000]
        db.add(
            ProcessingEvent(
                run_id=run_id,
                document_id=document_id,
                stage=stage,
                level="ERROR",
                message=message,
                data={},
            )
        )
    logger.error("pipeline.failed", document_id=str(document_id), stage=stage, message=message)
    return {"status": stage, "error": message}