Files
LegacyHUB/app/indexing/reranker.py
Vadim Malanov 349f4ea838 perf(reranker): add benchmark harness and passage clipping
- scripts/benchmark_reranker.py exercises the configured reranker
  with synthetic queries or live OpenSearch samples and prints
  p50/p95/p99 latency, mean latency, and pairs/sec throughput.
  Supports --warmup, --candidates, --passage-length, --source, and a
  --json-only mode for CI.
- app/indexing/reranker.py clips passages to 2048 characters before
  scoring, so a runaway chunk cannot push the cross-encoder far beyond
  bge-reranker-v2-m3's 512-token training window.
- RUNBOOK.md gains a Reranker benchmark section with CPU/GPU SLO
  targets and a remediation ladder (lower top-K, raise batch size,
  switch device, disable reranker) when measured p95 exceeds budget.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 17:08:04 +03:00

82 lines
3.0 KiB
Python

"""BGE reranker - cross-encoder style scoring of (query, passage) pairs.
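Designed to degrade gracefully:
- If neither FlagEmbedding nor sentence-transformers can load the model,
  ``Reranker.available`` is False and ``Reranker.score`` returns 0.0 for every
  passage instead of raising.
- NOTE(review): the ``rerank`` entry point is described as returning inputs
  unchanged with the ``reranked`` flag set to False, so that the API can report
  the truth to clients. It is not defined in this module; presumably it checks
  ``Reranker.available``. Confirm against the caller.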
"""
from __future__ import annotations
from functools import lru_cache
from typing import Sequence
from app.config import settings
from app.logging_config import get_logger
# Module logger. ``Reranker`` uses it to report which backend loaded, or why loading failed.
logger = get_logger(__name__)
class Reranker:
    """Score (query, passage) pairs with a BGE cross-encoder reranker.

    The model is loaded eagerly in ``__init__``. FlagEmbedding's
    ``FlagReranker`` is tried first, and sentence-transformers'
    ``CrossEncoder`` is the fallback. If both fail, the instance is still
    usable: ``available`` is False and ``score`` returns neutral zeros.

    Args:
        model_name: Model identifier passed to the backend loader.
        device: Device string passed to the backend (e.g. ``"cpu"``, ``"cuda"``).
        batch_size: Batch size forwarded to the backend's scoring call.
    """

    # bge-reranker-v2-m3 is trained at 512 tokens. Passages are clipped by
    # characters as a cheap, tokenizer-free guard, so one runaway chunk cannot
    # dominate scoring cost. 2048 chars is only a rough proxy for 512 tokens
    # (about 4 chars/token). Text that tokenizes more densely may still exceed
    # the window, so the backend's own truncation presumably enforces the hard
    # limit. TODO: confirm per backend.
    _MAX_PASSAGE_CHARS = 2048

    def __init__(self, model_name: str, device: str, batch_size: int) -> None:
        self.model_name = model_name
        self.device = device
        self.batch_size = batch_size
        # Name of the backend that loaded successfully, or None when disabled.
        self._impl: str | None = None
        self._model = None
        self._load()

    @staticmethod
    def _is_cpu_device(device: str) -> bool:
        """Return True when *device* names the CPU, e.g. ``"cpu"``, ``"CPU"`` or ``"cpu:0"``."""
        return str(device).strip().lower().split(":", 1)[0] == "cpu"

    def _load(self) -> None:
        """Load the model, trying FlagEmbedding first, then sentence-transformers.

        A load failure is not raised. It is logged, and ``_impl``/``_model``
        are left as None so that ``available`` reports False.
        """
        try:
            from FlagEmbedding import FlagReranker  # type: ignore

            # fp16 is requested only for non-CPU devices. Match any spelling of
            # the CPU device, not just the exact string "cpu".
            use_fp16 = not self._is_cpu_device(self.device)
            self._model = FlagReranker(self.model_name, use_fp16=use_fp16, devices=self.device)
            self._impl = "flagembedding"
            logger.info("reranker.loaded", impl="flagembedding", model=self.model_name, device=self.device)
            return
        except Exception as exc:  # noqa: BLE001 - ImportError or any model-load failure
            logger.warning("reranker.flagembedding_failed", error=str(exc))
        try:
            from sentence_transformers import CrossEncoder

            self._model = CrossEncoder(self.model_name, device=self.device)
            self._impl = "sentence-transformers"
            logger.info("reranker.loaded", impl="sentence-transformers", model=self.model_name, device=self.device)
        except Exception as exc:  # noqa: BLE001
            logger.error("reranker.disabled", error=str(exc))
            self._impl = None
            self._model = None

    @property
    def available(self) -> bool:
        """True when a backend loaded successfully and ``score`` will use the model."""
        return self._impl is not None and self._model is not None

    def score(self, query: str, passages: Sequence[str]) -> list[float]:
        """Return relevance scores for *passages* against *query*, in input order.

        Each passage is clipped to ``_MAX_PASSAGE_CHARS`` characters before
        scoring. The FlagEmbedding backend is called with ``normalize=True``.
        The sentence-transformers backend uses ``CrossEncoder.predict`` with
        its default post-processing.

        Args:
            query: The search query.
            passages: Candidate passage texts.

        Returns:
            One float per score produced by the backend (presumably one per
            passage). A bare scalar result, as FlagEmbedding can return for a
            single pair, is wrapped in a one-element list. When no model is
            loaded, or *passages* is empty, returns ``0.0`` for each passage.
        """
        if not self.available or not passages:
            return [0.0] * len(passages)
        clipped = [p[: self._MAX_PASSAGE_CHARS] for p in passages]
        pairs = [(query, p) for p in clipped]
        if self._impl == "flagembedding":
            scores = self._model.compute_score(pairs, batch_size=self.batch_size, normalize=True)  # type: ignore[union-attr]
        else:
            scores = self._model.predict(pairs, batch_size=self.batch_size)  # type: ignore[union-attr]
        if not isinstance(scores, list):
            # numpy arrays / tuples become lists; a scalar raises TypeError.
            try:
                scores = list(scores)
            except TypeError:
                scores = [float(scores)]
        return [float(s) for s in scores]
@lru_cache(maxsize=1)
def get_reranker() -> Reranker:
    """Return the shared ``Reranker`` built from application settings.

    ``lru_cache(maxsize=1)`` on this zero-argument function means the first
    call constructs the ``Reranker``, which loads its model. Every later
    call returns that same cached instance.
    """
    model_name = settings.reranker_model
    device = settings.reranker_device
    batch_size = settings.reranker_batch_size
    return Reranker(model_name=model_name, device=device, batch_size=batch_size)