Files
LegacyHUB/app/indexing/reranker.py
Vadim Malanov 7f72171572 chore: bootstrap repository with governance docs
Initialize git, add Apache-2.0 LICENSE, .gitattributes (LF line
endings), AGENTS.md (entry points, stack, discovery order, baseline
checks), RUNBOOK.md (dev boot, prod deploy with overlay, ingestion,
failures, rollback, scaling notes), .env.prod.example with rotated
credential placeholders, and dev-only warnings on .env.example.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 16:41:50 +03:00

76 lines
2.7 KiB
Python

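"""BGE reranker: cross-encoder style scoring of (query, passage) pairs.

Backends are tried in order: FlagEmbedding's ``FlagReranker``, then
sentence-transformers' ``CrossEncoder``.

Designed to degrade gracefully:

- If no backend can load the model, ``Reranker.available`` is False and
  ``Reranker.score`` returns 0.0 for every passage. The calling ``rerank`` step
  is expected to return its inputs unchanged with the ``reranked`` flag set to
  False, so the API can report the truth to clients.
"""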
from __future__ import annotations
from functools import lru_cache
from typing import Sequence
from app.config import settings
from app.logging_config import get_logger
# Module logger; call sites in this file pass an event name plus keyword fields.
logger = get_logger(__name__)
class Reranker:
    """Score (query, passage) pairs with a cross-encoder reranking model.

    Construction eagerly loads the model, first through FlagEmbedding's
    ``FlagReranker`` and, if that raises, through sentence-transformers'
    ``CrossEncoder``. If both fail, the instance is still usable:
    :attr:`available` is False and :meth:`score` returns all-zero scores.
    """

    def __init__(self, model_name: str, device: str, batch_size: int) -> None:
        """Store the configuration and try to load a backend immediately.

        Args:
            model_name: Model identifier handed to whichever backend loads it.
            device: Device string passed to the backend; ``"cpu"`` disables fp16
                for the FlagEmbedding backend.
            batch_size: Batch size forwarded to the backend on every scoring call.
        """
        self.model_name, self.device, self.batch_size = model_name, device, batch_size
        # Filled in by _load(); both remain None when no backend could be loaded.
        self._impl: str | None = None
        self._model = None
        self._load()

    def _load(self) -> None:
        """Attach the first backend that loads, logging (not raising) failures.

        On total failure ``_impl`` and ``_model`` are reset to None.
        """
        try:
            from FlagEmbedding import FlagReranker  # type: ignore

            # Half precision is requested for any device other than "cpu".
            half_precision = self.device != "cpu"
            self._model = FlagReranker(self.model_name, use_fp16=half_precision, devices=self.device)
            self._impl = "flagembedding"
            logger.info("reranker.loaded", impl="flagembedding", model=self.model_name, device=self.device)
        except Exception as exc:  # noqa: BLE001
            logger.warning("reranker.flagembedding_failed", error=str(exc))
        else:
            return

        # Fallback backend, only reached when the FlagEmbedding attempt raised.
        try:
            from sentence_transformers import CrossEncoder

            self._model = CrossEncoder(self.model_name, device=self.device)
            self._impl = "sentence-transformers"
            logger.info("reranker.loaded", impl="sentence-transformers", model=self.model_name)
        except Exception as exc:  # noqa: BLE001
            logger.error("reranker.disabled", error=str(exc))
            self._model = None
            self._impl = None

    @property
    def available(self) -> bool:
        """True when a backend loaded and a model object is attached."""
        return not (self._impl is None or self._model is None)

    def score(self, query: str, passages: Sequence[str]) -> list[float]:
        """Return relevance scores for ``passages`` against ``query``.

        Args:
            query: Query text paired with every passage.
            passages: Candidate passages to score.

        Returns:
            ``[0.0] * len(passages)`` when no backend is available or
            ``passages`` is empty. Otherwise, the backend's scores as plain
            floats, in the order the backend yields them. The FlagEmbedding
            path requests ``normalize=True``; the sentence-transformers path
            returns whatever ``predict`` produces.

        Errors raised by the backend while scoring are not caught here.
        """
        if not (self.available and passages):
            return len(passages) * [0.0]

        pairs = [(query, text) for text in passages]
        model = self._model
        if self._impl == "flagembedding":
            raw = model.compute_score(pairs, batch_size=self.batch_size, normalize=True)  # type: ignore[union-attr]
        else:
            raw = model.predict(pairs, batch_size=self.batch_size)  # type: ignore[union-attr]

        if isinstance(raw, list):
            values = raw
        else:
            try:
                values = list(raw)
            except TypeError:
                # Non-iterable scalar result; presumably a backend collapsing a
                # one-pair batch to a bare number.
                values = [float(raw)]
        return list(map(float, values))
@lru_cache(maxsize=1)
def get_reranker() -> Reranker:
    """Return the shared :class:`Reranker`, constructing it on the first call.

    The model name, device and batch size are read from ``settings`` once.
    Later calls return the cached instance; ``get_reranker.cache_clear()``
    drops it so the next call rebuilds from the current settings.
    """
    name = settings.reranker_model
    device = settings.reranker_device
    batch = settings.reranker_batch_size
    return Reranker(model_name=name, device=device, batch_size=batch)