Initialize git, add Apache-2.0 LICENSE, .gitattributes (LF line endings), AGENTS.md (entry points, stack, discovery order, baseline checks), RUNBOOK.md (dev boot, prod deploy with overlay, ingestion, failures, rollback, scaling notes), .env.prod.example with rotated credential placeholders, and dev-only warnings on .env.example. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
385 lines
14 KiB
Python
385 lines
14 KiB
Python
"""Per-document end-to-end pipeline: OCR -> Docling -> chunk -> persist -> index.
|
|
|
|
Called by the Celery worker. Idempotent: re-running on the same document deletes
|
|
existing chunks for that document and re-creates them, then re-indexes in
|
|
OpenSearch and Qdrant.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from sqlalchemy import delete, select
|
|
|
|
from app.config import settings
|
|
from app.db.models import (
|
|
ArtifactType,
|
|
Chunk,
|
|
Document,
|
|
DocumentArtifact,
|
|
DocumentStatus,
|
|
Page,
|
|
ProcessingEvent,
|
|
)
|
|
from app.db.session import session_scope
|
|
from app.indexing import opensearch_client, qdrant_client
|
|
from app.indexing.embeddings import get_embedder
|
|
from app.ingestion.chunker import ChunkRecord, chunk_extraction
|
|
from app.ingestion.docling_extractor import ExtractionResult, extract
|
|
from app.ingestion.figure_processor import persist_figures
|
|
from app.ingestion.ocr import run_ocr
|
|
from app.ingestion.table_processor import persist_tables
|
|
from app.logging_config import get_logger
|
|
from app.storage.local_paths import (
|
|
key_docling_json,
|
|
key_markdown,
|
|
key_ocr_pdf,
|
|
work_dir_for,
|
|
)
|
|
from app.storage.minio_client import get_storage
|
|
from app.utils.language import detect_language
|
|
|
|
# Module-level logger; call sites in this file pass an event name plus keyword context.
logger = get_logger(__name__)
|
|
|
|
|
|
def process_document_id(document_id: uuid.UUID, run_id: uuid.UUID | None = None) -> dict[str, Any]:
    """Run the end-to-end ingestion pipeline for a single document.

    Top-level entry called by the Celery task. Stages, in order: locate the
    original PDF, OCR, Docling extraction, persistence of pages / tables /
    figures / chunks, and indexing in OpenSearch and Qdrant.

    Failures in the OCR, extraction and indexing stages are caught, logged and
    recorded on the document through ``_fail``. Exceptions raised by the
    object-storage uploads or by the page/chunk persistence transaction are
    NOT caught here and propagate to the caller.

    Args:
        document_id: Primary key of the ``Document`` row to process.
        run_id: Optional identifier attached to every ``ProcessingEvent``
            emitted during this run.

    Returns:
        ``{"status": "missing"}`` when the document row does not exist,
        ``{"status": <stage>, "error": <message>}`` when a guarded stage fails,
        or ``{"status": DocumentStatus.INDEXING_COMPLETED, "chunks": <count>}``
        on success.
    """
    storage = get_storage()
    storage.ensure_buckets()

    with session_scope() as db:
        doc = db.get(Document, document_id)
        if doc is None:
            logger.warning("pipeline.document_missing", document_id=str(document_id))
            return {"status": "missing"}

        source_path = Path(doc.source_path)
        sha = doc.sha256
        original_artifact = db.execute(
            select(DocumentArtifact).where(
                DocumentArtifact.document_id == doc.id,
                DocumentArtifact.artifact_type == ArtifactType.ORIGINAL_PDF,
            )
        ).scalar_one_or_none()
        # Copy the storage location into plain values while the session is
        # open, so the download below never touches the ORM instance.
        original_location = (
            (original_artifact.storage_bucket, original_artifact.storage_key)
            if original_artifact
            else None
        )

    work_dir = work_dir_for(document_id)
    local_pdf = work_dir / f"{sha}.pdf"
    if not local_pdf.exists():
        if source_path.exists():
            local_pdf.write_bytes(source_path.read_bytes())
        elif original_location is not None:
            original_bucket, original_key = original_location
            storage.get_to_path(original_bucket, original_key, local_pdf)
        else:
            # Use the shared status constant, like every other failure path.
            return _fail(
                document_id,
                run_id,
                DocumentStatus.OCR_FAILED,
                "Original PDF not available locally or in MinIO",
            )

    # ---------------- OCR ----------------
    ocr_pdf = work_dir / "ocr.pdf"
    try:
        _emit_event(document_id, run_id, DocumentStatus.OCR_STARTED, "OCR started")
        ocr_result = run_ocr(local_pdf, ocr_pdf, languages=settings.ocr_languages)
    except Exception as exc:  # noqa: BLE001
        logger.exception("pipeline.ocr_failed", document_id=str(document_id))
        return _fail(document_id, run_id, DocumentStatus.OCR_FAILED, f"OCR failed: {exc}")

    # Upload OCR PDF (even if we 'skipped' it - OCR PDF is the canonical input to Docling).
    ocr_key = key_ocr_pdf(document_id)
    storage.put_file(
        bucket=storage.derived_bucket,
        key=ocr_key,
        path=ocr_result.output_path,
        content_type="application/pdf",
    )
    with session_scope() as db:
        _ensure_artifact(db, document_id, ArtifactType.OCR_PDF, storage.derived_bucket, ocr_key)
        doc = db.get(Document, document_id)
        if doc is not None:
            doc.status = DocumentStatus.OCR_COMPLETED
        db.add(
            ProcessingEvent(
                run_id=run_id,
                document_id=document_id,
                stage=DocumentStatus.OCR_COMPLETED,
                level="INFO",
                message=f"OCR finished ({ocr_result.reason})",
                data={"skipped": ocr_result.skipped, "languages": ocr_result.languages},
            )
        )

    # ---------------- Docling ----------------
    try:
        _emit_event(document_id, run_id, DocumentStatus.EXTRACTION_STARTED, "Docling extraction started")
        extraction = extract(ocr_result.output_path)
    except Exception as exc:  # noqa: BLE001
        logger.exception("pipeline.docling_failed", document_id=str(document_id))
        return _fail(document_id, run_id, DocumentStatus.EXTRACTION_FAILED, f"Docling failed: {exc}")

    # Persist Markdown + JSON to MinIO.
    md_key = key_markdown(document_id)
    json_key = key_docling_json(document_id)
    storage.put_bytes(
        bucket=storage.derived_bucket,
        key=md_key,
        data=extraction.markdown.encode("utf-8"),
        content_type="text/markdown",
    )
    storage.put_bytes(
        bucket=storage.derived_bucket,
        key=json_key,
        data=json.dumps(extraction.json_payload, ensure_ascii=False).encode("utf-8"),
        content_type="application/json",
    )

    # ---------------- Persist pages, chunks, tables, figures ----------------
    chunk_records = chunk_extraction(extraction)
    # Language detection only looks at the text of the first three pages.
    sample_text = "\n".join(p.text for p in extraction.pages[:3] if p.text)
    lang = detect_language(sample_text)

    with session_scope() as db:
        _ensure_artifact(db, document_id, ArtifactType.MARKDOWN, storage.derived_bucket, md_key)
        _ensure_artifact(db, document_id, ArtifactType.DOCLING_JSON, storage.derived_bucket, json_key)

        doc = db.get(Document, document_id)
        if doc is None:
            return {"status": "missing"}
        doc.status = DocumentStatus.EXTRACTION_COMPLETED
        # Never overwrite a language hint that is already set on the document.
        if lang and not doc.language_hint:
            doc.language_hint = lang

        page_id_by_number = _upsert_pages(db, document_id, extraction)
        persist_tables(db, storage, document_id, extraction.tables, page_id_by_number)
        persist_figures(db, storage, document_id, extraction.figures, page_id_by_number)

        # Replace chunks idempotently: drop all and re-insert.
        db.execute(delete(Chunk).where(Chunk.document_id == document_id))
        for cr in chunk_records:
            db.add(_to_chunk_row(document_id, page_id_by_number, cr))

        doc.status = DocumentStatus.CHUNKING_COMPLETED
        db.add(
            ProcessingEvent(
                run_id=run_id,
                document_id=document_id,
                stage=DocumentStatus.CHUNKING_COMPLETED,
                level="INFO",
                message="Chunking complete",
                data={"chunks": len(chunk_records)},
            )
        )

    # ---------------- Indexing (OpenSearch + Qdrant) ----------------
    try:
        opensearch_client.ensure_index()
        qdrant_client.ensure_collection()
        # Drop any previously indexed entries for this document before re-indexing.
        opensearch_client.delete_by_document(str(document_id))
        qdrant_client.delete_by_document(str(document_id))

        os_docs, qdrant_points = _build_index_payloads(document_id, chunk_records, extraction, lang)
        if os_docs:
            opensearch_client.index_chunks(os_docs)
        if qdrant_points:
            embedder = get_embedder()
            texts_to_embed = [text for _, text, _ in qdrant_points]
            vectors = embedder.encode(texts_to_embed)
            # strict=True raises if the embedder returns a different number of vectors.
            triples = [
                (chunk_id, vec, payload)
                for (chunk_id, _text, payload), vec in zip(qdrant_points, vectors, strict=True)
            ]
            qdrant_client.upsert_chunks(triples)
    except Exception as exc:  # noqa: BLE001
        logger.exception("pipeline.indexing_failed", document_id=str(document_id))
        return _fail(document_id, run_id, DocumentStatus.FAILED, f"Indexing failed: {exc}")

    with session_scope() as db:
        doc = db.get(Document, document_id)
        if doc is not None:
            doc.status = DocumentStatus.INDEXING_COMPLETED
            # Clear any error recorded by an earlier failed run.
            doc.error_message = None
        db.add(
            ProcessingEvent(
                run_id=run_id,
                document_id=document_id,
                stage=DocumentStatus.INDEXING_COMPLETED,
                level="INFO",
                message="Indexing complete",
                data={"chunks": len(chunk_records)},
            )
        )

    return {"status": DocumentStatus.INDEXING_COMPLETED, "chunks": len(chunk_records)}
|
|
|
|
|
|
# ---------------- helpers ----------------
|
|
|
|
def _to_chunk_row(
    document_id: uuid.UUID, page_id_by_number: dict[int, uuid.UUID], cr: ChunkRecord
) -> Chunk:
    """Build an unsaved ``Chunk`` row from a chunker record.

    Args:
        document_id: Owning document.
        page_id_by_number: Mapping from page number to ``Page`` id; a page
            number missing from the mapping yields ``page_id=None``.
        cr: Chunk record to convert.

    Returns:
        A new ``Chunk`` instance; it is not added to any session here.
    """
    page_number = cr.page_number
    column_values = {
        "document_id": document_id,
        "page_id": page_id_by_number.get(page_number),
        "page_number": page_number,
        "block_id": cr.block_id,
        "chunk_index": cr.chunk_index,
        "block_type": cr.block_type,
        "text": cr.text,
        "normalized_text": cr.normalized_text,
        "token_count": cr.token_count,
        # Chunk rows are always created without an OCR confidence value.
        "ocr_confidence": None,
        "quality_flags": cr.quality_flags,
        "chunk_metadata": cr.metadata,
    }
    return Chunk(**column_values)
|
|
|
|
|
|
def _upsert_pages(db, document_id: uuid.UUID, extraction: ExtractionResult) -> dict[int, uuid.UUID]:
    """Create or refresh the ``Page`` rows of a document from an extraction.

    Pages are matched on ``page_number``. A missing page is inserted; an
    existing page has all of its extraction-derived fields overwritten,
    including ``ocr_confidence``, so a re-run does not leave a stale value.
    Existing rows whose page number is absent from ``extraction.pages`` are
    left untouched.

    Args:
        db: Open database session; rows are added and flushed, not committed.
        document_id: Owning document.
        extraction: Extraction result whose ``pages`` drive the upsert.

    Returns:
        Mapping from page number to the id of the corresponding ``Page`` row,
        for every page present in ``extraction.pages``.
    """
    existing = {
        p.page_number: p
        for p in db.execute(select(Page).where(Page.document_id == document_id)).scalars()
    }
    pages_by_number: dict[int, Page] = {}
    created_any = False
    for ep in extraction.pages:
        page = existing.get(ep.page_number)
        if page is None:
            page = Page(
                document_id=document_id,
                page_number=ep.page_number,
                text=ep.text,
                ocr_confidence=ep.ocr_confidence,
                has_tables=ep.has_tables,
                has_figures=ep.has_figures,
                has_handwriting=ep.has_handwriting,
            )
            db.add(page)
            created_any = True
        else:
            page.text = ep.text
            page.ocr_confidence = ep.ocr_confidence
            page.has_tables = ep.has_tables
            page.has_figures = ep.has_figures
            page.has_handwriting = ep.has_handwriting
        pages_by_number[ep.page_number] = page

    if created_any:
        # Flush once for all new rows before their ids are read below,
        # instead of once per inserted page.
        db.flush()
    return {number: page.id for number, page in pages_by_number.items()}
|
|
|
|
|
|
def _build_index_payloads(
    document_id: uuid.UUID,
    chunks: list[ChunkRecord],
    extraction: ExtractionResult,
    language_hint: str | None,
) -> tuple[list[dict[str, Any]], list[tuple[str, str, dict[str, Any]]]]:
    """Build the OpenSearch documents and Qdrant point inputs for a document.

    Payloads are built from the ``Chunk`` rows stored in the database, so the
    ids match the persisted rows. ``chunks`` and ``extraction`` are accepted
    for interface compatibility but are not read.

    Args:
        document_id: Document whose chunk rows are loaded.
        chunks: Unused.
        extraction: Unused.
        language_hint: Value copied into each OpenSearch document.

    Returns:
        ``(os_docs, qdrant_points)``: a list of OpenSearch documents and a
        list of ``(chunk_id, text, payload)`` tuples. Both are empty when the
        document row does not exist.
    """
    opensearch_docs: list[dict[str, Any]] = []
    qdrant_points: list[tuple[str, str, dict[str, Any]]] = []
    doc_id_str = str(document_id)

    with session_scope() as db:
        document = db.get(Document, document_id)
        if document is None:
            return [], []
        file_name = document.original_file_name
        src_path = document.source_path

        stored_chunks = db.execute(select(Chunk).where(Chunk.document_id == document_id)).scalars().all()

        for stored in stored_chunks:
            point_id = str(stored.id)
            body = stored.text or ""
            # Fall back to the current UTC time when the row has no timestamp.
            created = stored.created_at or datetime.now(tz=timezone.utc)

            opensearch_doc = {
                "chunk_id": point_id,
                "document_id": doc_id_str,
                "source_path": src_path,
                "original_file_name": file_name,
                "page_number": stored.page_number,
                "block_type": stored.block_type,
                "block_id": stored.block_id,
                "text": body,
                "normalized_text": stored.normalized_text,
                "ocr_confidence": stored.ocr_confidence,
                "language_hint": language_hint,
                "metadata": stored.chunk_metadata or {},
                "quality_flags": stored.quality_flags or {},
                "created_at": created.isoformat(),
            }
            opensearch_docs.append(opensearch_doc)

            # The vector payload carries only a 512-character preview; the
            # full text travels separately as the tuple's second element.
            vector_payload = {
                "document_id": doc_id_str,
                "source_path": src_path,
                "original_file_name": file_name,
                "page_number": stored.page_number,
                "block_type": stored.block_type,
                "block_id": stored.block_id,
                "text_preview": body[:512],
                "ocr_confidence": stored.ocr_confidence,
                "quality_flags": stored.quality_flags or {},
                "metadata": stored.chunk_metadata or {},
            }
            qdrant_points.append((point_id, body, vector_payload))

    return opensearch_docs, qdrant_points
|
|
|
|
|
|
def _ensure_artifact(db, document_id: uuid.UUID, artifact_type: str, bucket: str, key: str) -> None:
    """Register a ``DocumentArtifact`` unless one already exists for the key.

    Existence is checked on ``(document_id, storage_key)`` only; an existing
    row is left unchanged even if its type or bucket differ.

    Args:
        db: Open database session; the new row is added, not committed.
        document_id: Owning document.
        artifact_type: Artifact type stored on a newly created row.
        bucket: Storage bucket stored on a newly created row.
        key: Storage key used both for the lookup and for the new row.
    """
    lookup = select(DocumentArtifact).where(
        DocumentArtifact.document_id == document_id,
        DocumentArtifact.storage_key == key,
    )
    already_registered = db.execute(lookup).scalar_one_or_none()
    if not already_registered:
        artifact = DocumentArtifact(
            document_id=document_id,
            artifact_type=artifact_type,
            storage_bucket=bucket,
            storage_key=key,
        )
        db.add(artifact)
|
|
|
|
|
|
def _emit_event(document_id: uuid.UUID, run_id: uuid.UUID | None, stage: str, message: str) -> None:
    """Record an INFO-level ``ProcessingEvent`` in its own session scope.

    Args:
        document_id: Document the event belongs to.
        run_id: Optional run identifier stored on the event.
        stage: Pipeline stage stored on the event.
        message: Human-readable event message.
    """
    event = ProcessingEvent(
        run_id=run_id,
        document_id=document_id,
        stage=stage,
        level="INFO",
        message=message,
        data={},
    )
    with session_scope() as session:
        session.add(event)
|
|
|
|
|
|
def _fail(
    document_id: uuid.UUID, run_id: uuid.UUID | None, stage: str, message: str
) -> dict[str, Any]:
    """Record a pipeline failure and build the result returned to the caller.

    Sets the document's status to ``stage`` and stores the message, cut to
    2000 characters, as its error message (when the document row exists).
    Also adds an ERROR-level ``ProcessingEvent`` carrying the full message and
    logs the failure.

    Args:
        document_id: Document that failed.
        run_id: Optional run identifier stored on the event.
        stage: Failure status written to the document and the event.
        message: Failure description.

    Returns:
        ``{"status": stage, "error": message}``.
    """
    failure_event = ProcessingEvent(
        run_id=run_id,
        document_id=document_id,
        stage=stage,
        level="ERROR",
        message=message,
        data={},
    )
    with session_scope() as session:
        document = session.get(Document, document_id)
        if document is not None:
            document.status = stage
            document.error_message = message[:2000]
        session.add(failure_event)

    logger.error("pipeline.failed", document_id=str(document_id), stage=stage, message=message)
    return {"status": stage, "error": message}
|