"""BGE reranker: cross-encoder scoring of (query, passage) pairs.

Two backends are tried in order: FlagEmbedding's ``FlagReranker``, then
sentence-transformers' ``CrossEncoder``. If neither one loads, the instance
reports ``available == False`` and ``score`` yields 0.0 for every passage
instead of raising. Callers can then keep their original ordering and tell
clients that no reranking took place.
"""

from __future__ import annotations

from functools import lru_cache
from typing import Sequence

from app.config import settings
from app.logging_config import get_logger

logger = get_logger(__name__)


class Reranker:
    """Cross-encoder reranker backed by FlagEmbedding or sentence-transformers.

    The backend is chosen once, eagerly, in ``__init__``. ``_impl`` records
    which backend loaded (``"flagembedding"``, ``"sentence-transformers"``,
    or ``None`` when loading failed), and ``_model`` holds the loaded object.
    """

    # bge-reranker-v2-m3 was trained with a 512-token window. Clipping by
    # characters is only an approximation of that budget (the real token
    # count depends on the text), but it bounds passage length even when a
    # caller forgets to shorten candidate text.
    _MAX_PASSAGE_CHARS = 2048

    def __init__(self, model_name: str, device: str, batch_size: int) -> None:
        """Store the configuration and load a backend immediately.

        Args:
            model_name: Model identifier handed to whichever backend loads.
            device: Device string forwarded to the backend. Any value other
                than ``"cpu"`` turns on fp16 for the FlagEmbedding backend.
            batch_size: Forwarded as ``batch_size`` on every scoring call.
        """
        self.model_name = model_name
        self.device = device
        self.batch_size = batch_size
        self._impl: str | None = None
        self._model = None
        self._load()

    def _load(self) -> None:
        """Load the first backend that works; clear both slots if none does."""
        if self._load_flagembedding():
            return
        if not self._load_cross_encoder():
            self._impl = None
            self._model = None

    def _load_flagembedding(self) -> bool:
        """Try the FlagEmbedding backend; return True if it loaded."""
        try:
            from FlagEmbedding import FlagReranker  # type: ignore

            self._model = FlagReranker(
                self.model_name,
                use_fp16=self.device != "cpu",
                devices=self.device,
            )
            self._impl = "flagembedding"
            logger.info(
                "reranker.loaded",
                impl="flagembedding",
                model=self.model_name,
                device=self.device,
            )
            return True
        except Exception as exc:  # noqa: BLE001
            logger.warning("reranker.flagembedding_failed", error=str(exc))
            return False

    def _load_cross_encoder(self) -> bool:
        """Try the sentence-transformers backend; return True if it loaded."""
        try:
            from sentence_transformers import CrossEncoder

            self._model = CrossEncoder(self.model_name, device=self.device)
            self._impl = "sentence-transformers"
            logger.info(
                "reranker.loaded",
                impl="sentence-transformers",
                model=self.model_name,
            )
            return True
        except Exception as exc:  # noqa: BLE001
            logger.error("reranker.disabled", error=str(exc))
            return False

    @property
    def available(self) -> bool:
        """True when a backend name is recorded and its model object is held."""
        return not (self._impl is None or self._model is None)

    def score(self, query: str, passages: Sequence[str]) -> list[float]:
        """Score every passage against ``query`` with the loaded backend.

        Each passage is clipped to ``_MAX_PASSAGE_CHARS`` characters before
        scoring; the query itself is not clipped. The FlagEmbedding backend
        is asked for normalized scores (``normalize=True``). The
        sentence-transformers backend uses ``CrossEncoder.predict`` with its
        default settings.

        Returns:
            The backend's scores converted to ``float``. If no backend is
            loaded, or ``passages`` is empty, returns
            ``[0.0] * len(passages)``. Errors raised by the backend while
            scoring are not caught here.
        """
        if not self.available or not passages:
            return [0.0] * len(passages)

        limit = self._MAX_PASSAGE_CHARS
        pairs = [(query, passage[:limit]) for passage in passages]

        if self._impl == "flagembedding":
            raw = self._model.compute_score(pairs, batch_size=self.batch_size, normalize=True)  # type: ignore[union-attr]
        else:
            raw = self._model.predict(pairs, batch_size=self.batch_size)  # type: ignore[union-attr]

        return [float(value) for value in self._to_list(raw)]

    @staticmethod
    def _to_list(raw: object) -> list:
        """Coerce a backend result into a list.

        Lists pass through unchanged, and other iterables (such as arrays)
        are expanded. A non-iterable scalar result becomes a one-element
        list.
        """
        if isinstance(raw, list):
            return raw
        try:
            return list(raw)  # type: ignore[call-overload]
        except TypeError:
            return [float(raw)]  # type: ignore[arg-type]


@lru_cache(maxsize=1)
def get_reranker() -> Reranker:
    """Build the shared ``Reranker`` from settings on first call, then reuse it."""
    cfg = settings
    return Reranker(
        cfg.reranker_model,
        cfg.reranker_device,
        cfg.reranker_batch_size,
    )