"""Hybrid search: lexical (OpenSearch BM25) + semantic (Qdrant) + RRF + reranker. Always returns ``SearchResponse`` (never throws on missing index/collection - empty results are valid). """ from __future__ import annotations import uuid from collections import defaultdict from dataclasses import dataclass from typing import Any from qdrant_client.http import models as qm from app.api.schemas import ( Citation, SearchFilters, SearchHit, SearchMode, SearchRequest, SearchResponse, ) from app.config import settings from app.indexing.embeddings import get_embedder from app.indexing.opensearch_client import get_opensearch from app.indexing.qdrant_client import DENSE_VECTOR_NAME, get_qdrant from app.indexing.reranker import get_reranker from app.logging_config import get_logger from app.utils.text_cleaning import normalize_for_search logger = get_logger(__name__) @dataclass class _Candidate: chunk_id: str document_id: str page_number: int block_type: str block_id: str | None text: str source_path: str original_file_name: str quality_flags: dict[str, Any] metadata: dict[str, Any] bm25_score: float | None = None bm25_rank: int | None = None dense_score: float | None = None dense_rank: int | None = None def run_search(req: SearchRequest) -> SearchResponse: mode: SearchMode = req.search_mode filters = req.filters lexical: list[_Candidate] = [] semantic: list[_Candidate] = [] if mode in ("lexical", "hybrid"): try: lexical = _lexical_search(req.query, filters, settings.hybrid_opensearch_top_k) except Exception as exc: # noqa: BLE001 logger.warning("search.lexical_failed", error=str(exc)) if mode in ("semantic", "hybrid"): try: semantic = _semantic_search(req.query, filters, settings.hybrid_qdrant_top_k) except Exception as exc: # noqa: BLE001 logger.warning("search.semantic_failed", error=str(exc)) merged = _merge(lexical, semantic, mode) candidates = merged[: settings.rerank_candidates] reranker = get_reranker() reranked_flag = False if settings.reranker_enabled and reranker.available and candidates: scores = reranker.score(req.query, [c.text for c in candidates]) for c, s in zip(candidates, scores, strict=True): c.dense_score = s candidates.sort(key=lambda c: (c.dense_score or 0.0), reverse=True) reranked_flag = True final = candidates[: req.limit] hits: list[SearchHit] = [] for rank, c in enumerate(final, start=1): score = ( c.dense_score if reranked_flag else (c.dense_score if mode == "semantic" else c.bm25_score) or 0.0 ) hits.append( SearchHit( rank=rank, score=float(score), document_id=uuid.UUID(c.document_id), chunk_id=uuid.UUID(c.chunk_id), original_file_name=c.original_file_name, source_path=c.source_path, page_number=c.page_number, block_type=c.block_type, text=c.text, citation=Citation( pdf=c.original_file_name, page=c.page_number, block_id=c.block_id, table_id=str(c.metadata.get("table_index")) if c.metadata.get("table_index") is not None else None, figure_id=str(c.metadata.get("figure_index")) if c.metadata.get("figure_index") is not None else None, ), quality_flags=c.quality_flags, metadata=c.metadata, ) ) return SearchResponse( query=req.query, mode=mode, total_candidates=len(merged), reranked=reranked_flag, results=hits, ) # ---------------- lexical ---------------- def _lexical_search(query: str, filters: SearchFilters, top_k: int) -> list[_Candidate]: client = get_opensearch() if not client.indices.exists(index=settings.opensearch_index_chunks): return [] must = [ { "multi_match": { "query": query, "fields": ["text^1.0", "text.ru^1.5", "text.en^1.5", "normalized_text^0.7"], "type": "best_fields", "operator": "or", } } ] norm = normalize_for_search(query) if norm and norm != query.lower(): must.append({"match": {"normalized_text": {"query": norm, "boost": 0.5}}}) filter_clauses = _opensearch_filters(filters) body = { "size": top_k, "query": {"bool": {"must": must, "filter": filter_clauses}}, "_source": [ "chunk_id", "document_id", "source_path", "original_file_name", "page_number", "block_type", "block_id", "text", "quality_flags", "metadata", ], } res = client.search(index=settings.opensearch_index_chunks, body=body, request_timeout=30) out: list[_Candidate] = [] for rank, hit in enumerate(res.get("hits", {}).get("hits", []), start=1): s = hit.get("_source", {}) out.append( _Candidate( chunk_id=s["chunk_id"], document_id=s["document_id"], page_number=int(s.get("page_number", 0)), block_type=s.get("block_type", "paragraph"), block_id=s.get("block_id"), text=s.get("text", ""), source_path=s.get("source_path", ""), original_file_name=s.get("original_file_name", ""), quality_flags=s.get("quality_flags") or {}, metadata=s.get("metadata") or {}, bm25_score=float(hit.get("_score") or 0.0), bm25_rank=rank, ) ) return out def _opensearch_filters(filters: SearchFilters) -> list[dict[str, Any]]: clauses: list[dict[str, Any]] = [] if filters.document_id: clauses.append({"term": {"document_id": str(filters.document_id)}}) if filters.source_path: clauses.append({"term": {"source_path": filters.source_path}}) if filters.block_type: clauses.append({"term": {"block_type": filters.block_type}}) if filters.min_ocr_confidence is not None: clauses.append({"range": {"ocr_confidence": {"gte": filters.min_ocr_confidence}}}) return clauses # ---------------- semantic ---------------- def _semantic_search(query: str, filters: SearchFilters, top_k: int) -> list[_Candidate]: embedder = get_embedder() vector = embedder.encode_one(query) qf = _qdrant_filter(filters) client = get_qdrant() try: results = client.query_points( collection_name=settings.qdrant_collection_chunks, query=vector, using=DENSE_VECTOR_NAME, limit=top_k, with_payload=True, query_filter=qf, ).points except Exception as exc: # noqa: BLE001 logger.debug("qdrant.query_points_fallback", error=str(exc)) results = client.search( collection_name=settings.qdrant_collection_chunks, query_vector=(DENSE_VECTOR_NAME, vector), query_filter=qf, limit=top_k, with_payload=True, ) out: list[_Candidate] = [] for rank, p in enumerate(results, start=1): payload = p.payload or {} chunk_id = payload.get("chunk_id") or str(p.id) out.append( _Candidate( chunk_id=str(chunk_id), document_id=str(payload.get("document_id", "")), page_number=int(payload.get("page_number") or 0), block_type=payload.get("block_type", "paragraph"), block_id=payload.get("block_id"), text=payload.get("text_preview", ""), source_path=payload.get("source_path", ""), original_file_name=payload.get("original_file_name", ""), quality_flags=payload.get("quality_flags") or {}, metadata=payload.get("metadata") or {}, dense_score=float(p.score or 0.0), dense_rank=rank, ) ) return out def _qdrant_filter(filters: SearchFilters) -> qm.Filter | None: must: list[qm.FieldCondition | qm.Range] = [] if filters.document_id: must.append(qm.FieldCondition(key="document_id", match=qm.MatchValue(value=str(filters.document_id)))) if filters.source_path: must.append(qm.FieldCondition(key="source_path", match=qm.MatchValue(value=filters.source_path))) if filters.block_type: must.append(qm.FieldCondition(key="block_type", match=qm.MatchValue(value=filters.block_type))) if filters.min_ocr_confidence is not None: must.append(qm.FieldCondition(key="ocr_confidence", range=qm.Range(gte=filters.min_ocr_confidence))) if not must: return None return qm.Filter(must=must) # ---------------- merge ---------------- def _merge(lexical: list[_Candidate], semantic: list[_Candidate], mode: SearchMode) -> list[_Candidate]: if mode == "lexical": return lexical if mode == "semantic": return _hydrate_semantic_text(semantic) by_id: dict[str, _Candidate] = {} for c in lexical: by_id[c.chunk_id] = c for c in semantic: if c.chunk_id in by_id: by_id[c.chunk_id].dense_score = c.dense_score by_id[c.chunk_id].dense_rank = c.dense_rank if not by_id[c.chunk_id].text: by_id[c.chunk_id].text = c.text else: by_id[c.chunk_id] = c rrf: dict[str, float] = defaultdict(float) k = settings.hybrid_rrf_k for c in lexical: if c.bm25_rank is not None: rrf[c.chunk_id] += 1.0 / (k + c.bm25_rank) for c in semantic: if c.dense_rank is not None: rrf[c.chunk_id] += 1.0 / (k + c.dense_rank) items = sorted(by_id.values(), key=lambda c: rrf.get(c.chunk_id, 0.0), reverse=True) return _hydrate_full_text(items) def _hydrate_full_text(candidates: list[_Candidate]) -> list[_Candidate]: """For candidates whose text came only from Qdrant payload (preview), pull the full chunk text from OpenSearch by id so the reranker sees full content. """ missing = [c for c in candidates if len(c.text) <= 512] if not missing: return candidates client = get_opensearch() ids = [c.chunk_id for c in missing] try: res = client.mget(index=settings.opensearch_index_chunks, body={"ids": ids}) except Exception: return candidates by_id = {d["_id"]: d.get("_source", {}) for d in res.get("docs", []) if d.get("found")} for c in missing: s = by_id.get(c.chunk_id) if s and s.get("text"): c.text = s["text"] if not c.original_file_name: c.original_file_name = s.get("original_file_name", "") if not c.source_path: c.source_path = s.get("source_path", "") if not c.metadata: c.metadata = s.get("metadata") or {} if not c.quality_flags: c.quality_flags = s.get("quality_flags") or {} return candidates def _hydrate_semantic_text(candidates: list[_Candidate]) -> list[_Candidate]: return _hydrate_full_text(candidates)