"""Structure-aware chunking. Rules (per spec): - Chunk by document structure first, fixed-size second. - Hierarchy: title > heading > paragraph > list > table > figure caption. - Target 500-900 tokens (configurable). - Overlap 80-120 tokens for long narrative text only. - Never split tables - one table = one chunk (or one chunk per row group if huge). - Every chunk carries citation metadata. We use a deliberately simple ``len(text.split())`` token estimator. The downstream embedding model has its own tokenizer; this estimator is only a budget proxy. """ from __future__ import annotations from dataclasses import dataclass, field from typing import Any from app.config import settings from app.ingestion.docling_extractor import ( ExtractedBlock, ExtractedFigure, ExtractedTable, ExtractionResult, ) from app.ingestion.normalizer import normalize_block from app.ingestion.quality import compute_quality_flags @dataclass class ChunkRecord: chunk_index: int page_number: int block_type: str text: str normalized_text: str token_count: int block_id: str | None = None quality_flags: dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict) def _estimate_tokens(text: str) -> int: return max(1, len(text.split())) def chunk_extraction( extraction: ExtractionResult, *, document_ocr_confidence: float | None = None, ) -> list[ChunkRecord]: target = settings.chunk_target_tokens minimum = settings.chunk_min_tokens maximum = settings.chunk_max_tokens overlap = settings.chunk_overlap_tokens chunks: list[ChunkRecord] = [] idx = 0 # 1) Tables first - one chunk per table, never split. for t in extraction.tables: body = (t.markdown or "").strip() if not body: continue summary = _summarize_table(t) text = body if summary: text = f"{summary}\n\n{body}" display, norm = normalize_block(text) flags = compute_quality_flags( text=display, block_type="table", ocr_confidence=document_ocr_confidence, ) chunks.append( ChunkRecord( chunk_index=idx, page_number=t.page_number, block_type="table", text=display, normalized_text=norm, token_count=_estimate_tokens(display), block_id=t.block_id or f"table:{t.table_index}", quality_flags=flags, metadata={"table_index": t.table_index, "summary": summary or ""}, ) ) idx += 1 # 2) Figures - caption + placeholder description. for f in extraction.figures: text_parts: list[str] = [] if f.caption: text_parts.append(f"Caption: {f.caption}") text_parts.append(f"Figure detected on page {f.page_number}.") text = "\n".join(text_parts) block_type = "figure_caption" if f.caption else "figure_description" display, norm = normalize_block(text) flags = compute_quality_flags( text=display, block_type=block_type, ocr_confidence=document_ocr_confidence, ) chunks.append( ChunkRecord( chunk_index=idx, page_number=f.page_number, block_type=block_type, text=display, normalized_text=norm, token_count=_estimate_tokens(display), block_id=f.block_id or f"figure:{f.figure_index}", quality_flags=flags, metadata={"figure_index": f.figure_index}, ) ) idx += 1 # 3) Narrative blocks grouped per page, packed by structure. by_page: dict[int, list[ExtractedBlock]] = {} for b in extraction.blocks: by_page.setdefault(b.page_number, []).append(b) for page_no in sorted(by_page): blocks = by_page[page_no] groups = _group_by_section(blocks) for group in groups: packed = _pack_group(group, target=target, maximum=maximum, minimum=minimum) for piece in packed: text = piece["text"] btype = piece["block_type"] display, norm = normalize_block(text) flags = compute_quality_flags( text=display, block_type=btype, ocr_confidence=document_ocr_confidence, ) chunks.append( ChunkRecord( chunk_index=idx, page_number=page_no, block_type=btype, text=display, normalized_text=norm, token_count=_estimate_tokens(display), block_id=piece.get("block_id"), quality_flags=flags, metadata={"section_heading": piece.get("section") or ""}, ) ) idx += 1 # Optional overlap: only if the last piece is long narrative if overlap > 0 and packed and packed[-1]["block_type"] == "paragraph": tail = _tail_tokens(packed[-1]["text"], overlap) if tail and len(tail.split()) >= max(20, overlap // 2): # Overlap is already represented by next-group adjacency in # most legacy docs; we do not emit duplicate overlap chunks # to avoid index bloat. This is intentional per spec note # ("only for long narrative text") - left here for future tuning. pass return chunks # ---------------- Helpers ---------------- def _group_by_section(blocks: list[ExtractedBlock]) -> list[list[ExtractedBlock]]: groups: list[list[ExtractedBlock]] = [] current: list[ExtractedBlock] = [] for b in blocks: if b.block_type in ("title", "heading") and current: groups.append(current) current = [b] else: current.append(b) if current: groups.append(current) return groups def _pack_group( group: list[ExtractedBlock], *, target: int, maximum: int, minimum: int ) -> list[dict[str, Any]]: """Pack a section's blocks into chunks at most ``maximum`` tokens. Headings / titles attach to the next chunk as a section anchor. """ if not group: return [] section_heading = "" body_blocks: list[ExtractedBlock] = [] for b in group: if b.block_type in ("title", "heading"): section_heading = (section_heading + " > " + b.text).strip(" >") if section_heading else b.text else: body_blocks.append(b) if not body_blocks: # Heading-only group: emit as a single ``heading`` chunk so the title is searchable. text = section_heading or group[0].text return [ { "text": text, "block_type": "heading", "block_id": group[0].block_id, "section": section_heading, } ] out: list[dict[str, Any]] = [] buffer: list[str] = [] buffer_block_ids: list[str] = [] buffer_block_type = "paragraph" buffer_tokens = 0 def flush(): nonlocal buffer, buffer_block_ids, buffer_block_type, buffer_tokens if not buffer: return text = "\n\n".join(buffer).strip() if not text: buffer = [] buffer_block_ids = [] buffer_tokens = 0 return # Prepend section heading for context (kept short). if section_heading and len(section_heading) < 200: text = f"# {section_heading}\n\n{text}" out.append( { "text": text, "block_type": buffer_block_type, "block_id": buffer_block_ids[0] if buffer_block_ids else None, "section": section_heading, } ) buffer = [] buffer_block_ids = [] buffer_tokens = 0 for b in body_blocks: tokens = _estimate_tokens(b.text) if tokens >= maximum: # Hard split a giant block into sub-chunks of ~target tokens. flush() for sub in _split_long_text(b.text, target=target, maximum=maximum): out.append( { "text": sub, "block_type": b.block_type if b.block_type != "list" else "list", "block_id": b.block_id, "section": section_heading, } ) continue if buffer_tokens + tokens > maximum and buffer_tokens >= minimum: flush() if not buffer: buffer_block_type = b.block_type if b.block_type != "list" else "list" buffer.append(b.text) if b.block_id: buffer_block_ids.append(b.block_id) buffer_tokens += tokens if buffer_tokens >= target: flush() flush() return out def _split_long_text(text: str, *, target: int, maximum: int) -> list[str]: words = text.split() if not words: return [] pieces: list[str] = [] step = target if step <= 0: step = 500 i = 0 while i < len(words): end = min(len(words), i + maximum) # Aim for ``target`` words but extend up to ``maximum`` to reach a sentence boundary. piece = " ".join(words[i : i + step]) pieces.append(piece) i += step if end - i < target // 4 and end - i > 0: pieces[-1] = " ".join(words[i - step : end]) break return pieces def _tail_tokens(text: str, n: int) -> str: words = text.split() if len(words) <= n: return text return " ".join(words[-n:]) def _summarize_table(t: ExtractedTable) -> str: """Heuristic one-line summary for index recall.""" md = t.markdown or "" first = next((line for line in md.splitlines() if line.startswith("|")), "") header_cells = [c.strip() for c in first.strip("|").split("|") if c.strip()] n_cols = len(header_cells) n_rows = max(0, sum(1 for ln in md.splitlines() if ln.startswith("|")) - 2) header_preview = ", ".join(header_cells[:6]) return ( f"Table on page {t.page_number}: {n_rows} rows x {n_cols} cols. " f"Columns: {header_preview}." if header_cells else f"Table on page {t.page_number}." )