- scripts/benchmark_reranker.py exercises the configured reranker with synthetic queries or live OpenSearch samples and prints p50/p95/p99 latency, mean latency, and pairs/sec throughput. Supports --warmup, --candidates, --passage-length, --source, and a --json-only mode for CI. - app/indexing/reranker.py clips passages to 2048 characters before scoring so a runaway chunk cannot starve the cross-encoder beyond bge-reranker-v2-m3's training window. - RUNBOOK.md gains a Reranker benchmark section with CPU/GPU SLO targets and a remediation ladder (lower top-K, raise batch size, switch device, disable reranker) when measured p95 exceeds budget. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
82 lines
3.0 KiB
Python
82 lines
3.0 KiB
Python
"""BGE reranker - cross-encoder style scoring of (query, passage) pairs.
|
|
|
|
Designed to degrade gracefully:
|
|
- If the model fails to load, ``rerank`` returns inputs unchanged with the
|
|
``reranked`` flag set to False so the API can report the truth to clients.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from functools import lru_cache
|
|
from typing import Sequence
|
|
|
|
from app.config import settings
|
|
from app.logging_config import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class Reranker:
    """Score (query, passage) pairs with a cross-encoder reranking model.

    Loading is attempted with two backends, in order:

    1. ``FlagEmbedding.FlagReranker``
    2. ``sentence_transformers.CrossEncoder``

    If both fail, construction still succeeds: ``available`` is ``False`` and
    ``score`` returns ``0.0`` for every passage instead of raising.
    """

    # bge-reranker-v2-m3 is trained at 512 tokens; we truncate by chars so the
    # reranker stays inside its budget even when callers forget to limit the
    # candidate text length.
    _MAX_PASSAGE_CHARS = 2048

    def __init__(self, model_name: str, device: str, batch_size: int) -> None:
        """Store the configuration and eagerly try to load a backend.

        Args:
            model_name: Model identifier handed to the backend constructor.
            device: Device string handed to the backend; any value other than
                ``"cpu"`` enables fp16 for the FlagEmbedding backend.
            batch_size: Batch size forwarded to the backend on every
                ``score`` call.
        """
        self.model_name = model_name
        self.device = device
        self.batch_size = batch_size
        # Name of the backend that loaded ("flagembedding" or
        # "sentence-transformers"); None while no backend is usable.
        self._impl: str | None = None
        self._model = None
        self._load()

    def _load(self) -> None:
        """Try each backend in turn; leave the reranker disabled if all fail.

        Imports are done lazily inside the ``try`` blocks so that a missing
        optional dependency is handled the same way as a failed model load.
        """
        try:
            from FlagEmbedding import FlagReranker  # type: ignore

            use_fp16 = self.device != "cpu"
            self._model = FlagReranker(self.model_name, use_fp16=use_fp16, devices=self.device)
            self._impl = "flagembedding"
            logger.info("reranker.loaded", impl="flagembedding", model=self.model_name, device=self.device)
            return
        except Exception as exc:  # noqa: BLE001
            logger.warning("reranker.flagembedding_failed", error=str(exc))

        try:
            from sentence_transformers import CrossEncoder

            self._model = CrossEncoder(self.model_name, device=self.device)
            self._impl = "sentence-transformers"
            # Log the device here too so both backends emit the same fields.
            logger.info(
                "reranker.loaded",
                impl="sentence-transformers",
                model=self.model_name,
                device=self.device,
            )
        except Exception as exc:  # noqa: BLE001
            logger.error("reranker.disabled", error=str(exc))
            # Reset both fields so a half-initialised backend is never used.
            self._impl = None
            self._model = None

    @property
    def available(self) -> bool:
        """Whether a backend loaded successfully and can be used by ``score``."""
        return self._impl is not None and self._model is not None

    def score(self, query: str, passages: Sequence[str]) -> list[float]:
        """Return one relevance score per passage, in input order.

        Args:
            query: The query text paired with every passage.
            passages: Candidate texts. Each is clipped to
                ``_MAX_PASSAGE_CHARS`` characters before scoring; a ``None``
                entry is scored as the empty string rather than raising.

        Returns:
            A list of floats. When no backend is available, or ``passages``
            is empty, the result is ``0.0`` repeated ``len(passages)`` times.
        """
        if not self.available or not passages:
            return [0.0] * len(passages)

        limit = self._MAX_PASSAGE_CHARS
        # A missing passage must not abort scoring for the whole batch, so
        # None is treated as empty text.
        pairs = [(query, ("" if text is None else text)[:limit]) for text in passages]

        if self._impl == "flagembedding":
            scores = self._model.compute_score(pairs, batch_size=self.batch_size, normalize=True)  # type: ignore[union-attr]
        else:
            scores = self._model.predict(pairs, batch_size=self.batch_size)  # type: ignore[union-attr]

        if not isinstance(scores, list):
            # Backends may hand back a non-list iterable or a bare scalar.
            try:
                scores = list(scores)
            except TypeError:
                scores = [float(scores)]
        return [float(s) for s in scores]
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_reranker() -> Reranker:
    """Return the shared ``Reranker`` configured from application settings.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    ``Reranker`` constructor runs on the first call and later calls get the
    same instance back.
    """
    reranker_kwargs = {
        "model_name": settings.reranker_model,
        "device": settings.reranker_device,
        "batch_size": settings.reranker_batch_size,
    }
    return Reranker(**reranker_kwargs)
|