Initialize git, add Apache-2.0 LICENSE, .gitattributes (LF line endings), AGENTS.md (entry points, stack, discovery order, baseline checks), RUNBOOK.md (dev boot, prod deploy with overlay, ingestion, failures, rollback, scaling notes), .env.prod.example with rotated credential placeholders, and dev-only warnings on .env.example. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
76 lines
2.7 KiB
Python
76 lines
2.7 KiB
Python
"""BGE reranker - cross-encoder style scoring of (query, passage) pairs.

Designed to degrade gracefully:

- If the model fails to load, ``rerank`` returns inputs unchanged with the
  ``reranked`` flag set to False so the API can report the truth to clients.
- In this module, ``Reranker.score`` returns ``0.0`` for every passage when
  no backend could be loaded, instead of raising.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from functools import lru_cache
|
|
from typing import Sequence
|
|
|
|
from app.config import settings
|
|
from app.logging_config import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class Reranker:
    """Score (query, passage) pairs with a BGE cross-encoder.

    A backend is loaded once, at construction time, trying in order:

    1. ``FlagEmbedding.FlagReranker``, using fp16 whenever ``device`` is not
       ``"cpu"``; scores are requested with ``normalize=True``.
    2. ``sentence_transformers.CrossEncoder`` as a fallback.

    If neither backend loads, the instance stays unavailable and
    :meth:`score` returns ``0.0`` for every passage instead of raising.
    """

    def __init__(self, model_name: str, device: str, batch_size: int) -> None:
        self.model_name = model_name
        self.device = device
        self.batch_size = batch_size
        # Loaded backend tag: "flagembedding", "sentence-transformers", or None.
        self._impl: str | None = None
        self._model = None
        self._load()

    def _load(self) -> None:
        """Set ``_model`` / ``_impl`` from the first backend that loads."""
        try:
            from FlagEmbedding import FlagReranker  # type: ignore

            half_precision = self.device != "cpu"
            self._model = FlagReranker(
                self.model_name, use_fp16=half_precision, devices=self.device
            )
            self._impl = "flagembedding"
            logger.info(
                "reranker.loaded",
                impl="flagembedding",
                model=self.model_name,
                device=self.device,
            )
        except Exception as err:  # noqa: BLE001 - any failure means "try the fallback"
            logger.warning("reranker.flagembedding_failed", error=str(err))
        else:
            # Preferred backend is up; skip the fallback.
            return

        try:
            from sentence_transformers import CrossEncoder

            self._model = CrossEncoder(self.model_name, device=self.device)
            self._impl = "sentence-transformers"
            logger.info("reranker.loaded", impl="sentence-transformers", model=self.model_name)
        except Exception as err:  # noqa: BLE001 - no backend: run in disabled mode
            logger.error("reranker.disabled", error=str(err))
            self._impl = None
            self._model = None

    @property
    def available(self) -> bool:
        """True when a backend tag is recorded and a model object is held."""
        return not (self._impl is None or self._model is None)

    def score(self, query: str, passages: Sequence[str]) -> list[float]:
        """Return one relevance score per passage for ``query``.

        Returns ``[0.0] * len(passages)`` when no backend is available or
        ``passages`` is empty.
        """
        if not (self.available and passages):
            return [0.0] * len(passages)

        pairs = [(query, passage) for passage in passages]
        model = self._model
        if self._impl == "flagembedding":
            raw = model.compute_score(pairs, batch_size=self.batch_size, normalize=True)  # type: ignore[union-attr]
        else:
            raw = model.predict(pairs, batch_size=self.batch_size)  # type: ignore[union-attr]

        # Backends may return a list, another iterable (e.g. an array), or a
        # bare scalar (e.g. when a single pair was scored); coerce to floats.
        if isinstance(raw, list):
            values = raw
        else:
            try:
                values = list(raw)
            except TypeError:
                values = [float(raw)]
        return [float(value) for value in values]
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_reranker() -> Reranker:
    """Return the shared :class:`Reranker`, constructing it on first call.

    The constructor arguments come from ``settings.reranker_model``,
    ``settings.reranker_device`` and ``settings.reranker_batch_size``. Because
    this zero-argument function is wrapped in ``lru_cache(maxsize=1)``, the
    model is loaded once and later calls return the same instance (until the
    cache is cleared).
    """
    options = {
        "model_name": settings.reranker_model,
        "device": settings.reranker_device,
        "batch_size": settings.reranker_batch_size,
    }
    return Reranker(**options)
|