Files
LegacyHUB/scripts/benchmark_reranker.py
Vadim Malanov 349f4ea838 perf(reranker): add benchmark harness and passage clipping
- scripts/benchmark_reranker.py exercises the configured reranker
  with synthetic queries or live OpenSearch samples and prints
  p50/p95/p99 latency, mean latency, and pairs/sec throughput.
  Supports --warmup, --candidates, --passage-length, --source, and a
  --json-only mode for CI.
- app/indexing/reranker.py clips passages to 2048 characters before
  scoring so a runaway chunk cannot starve the cross-encoder beyond
  bge-reranker-v2-m3's training window.
- RUNBOOK.md gains a Reranker benchmark section with CPU/GPU SLO
  targets and a remediation ladder (lower top-K, raise batch size,
  switch device, disable reranker) when measured p95 exceeds budget.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 17:08:04 +03:00

196 lines
6.8 KiB
Python

"""Reranker latency / throughput benchmark.
Measures BGE-reranker-v2-m3 (or whatever ``RERANKER_MODEL`` resolves to)
against synthetic or live corpus passages and prints the standard set of
percentiles plus throughput. Use this on staging hardware to verify whether
the configured device meets the latency budget before committing to a target
top-K.
Usage:
# 1) synthetic warm-up (no DB / OpenSearch needed)
python scripts/benchmark_reranker.py --queries 32 --candidates 40 \
--passage-length 700 --warmup 4
# 2) live corpus pull (samples real chunks from OpenSearch)
python scripts/benchmark_reranker.py --source opensearch \
--query "ГОСТ 21.501-93" \
--candidates 40
Outputs JSON to stdout and a markdown summary table.
"""
from __future__ import annotations
import argparse
import json
import statistics
import sys
import time
from dataclasses import asdict, dataclass
from app.config import settings
from app.indexing.reranker import get_reranker
from app.logging_config import configure_logging, get_logger
# Logging is configured at import time, before the module logger is obtained.
configure_logging()
# Module-level logger requested for this module's ``__name__``.
logger = get_logger(__name__)
@dataclass
class BenchResult:
    """Outcome of one benchmark run: an echo of the inputs plus the timings.

    Field order is significant: it defines both the positional constructor
    signature and the key order produced by ``dataclasses.asdict``.
    """

    # --- what was benchmarked -------------------------------------------
    model: str                  # reranker model identifier
    device: str                 # device reported by the reranker
    queries: int                # number of measured queries
    candidates_per_query: int   # passages scored per query
    passage_chars: int          # requested synthetic passage length
    warmup: int                 # requested number of warmup iterations

    # --- what was measured ----------------------------------------------
    p50_ms: float               # median per-query scoring latency (ms)
    p95_ms: float               # 95th percentile latency (ms)
    p99_ms: float               # 99th percentile latency (ms)
    mean_ms: float              # arithmetic mean latency (ms)
    pairs_per_sec: float        # scored (query, passage) pairs per second
    wall_seconds: float         # wall-clock duration of the measured loop
def percentile(values: list[float], q: float) -> float:
    """Return the ``q``-th percentile of ``values`` by nearest-index selection.

    The values are sorted and the element at index ``round(q/100 * (n - 1))``
    is returned (no interpolation). The index is clamped to the valid range,
    so ``q`` outside ``[0, 100]`` yields the minimum or maximum. An empty
    input yields ``0.0``.
    """
    if len(values) == 0:
        return 0.0

    ordered = sorted(values)
    last = len(ordered) - 1

    # round() on a float returns an int and rounds exact halves to even.
    rank = int(round((q / 100.0) * last))
    if rank > last:
        rank = last
    if rank < 0:
        rank = 0
    return ordered[rank]
def synthetic_passages(n: int, chars: int) -> list[str]:
    """Build ``n`` synthetic passages of roughly ``chars`` characters each.

    Every passage is the same seed sentence repeated and cut to ``chars``
    characters, prefixed with ``"[i] "`` so that the passages differ from
    one another. The prefix is added on top of ``chars``.
    """
    seed = "ГОСТ 21.501-93 определяет правила выполнения архитектурно-строительных рабочих чертежей. "
    # Two extra repetitions guarantee the repeated text is longer than `chars`.
    repeats = (chars // len(seed)) + 2
    body = (seed * repeats)[:chars]

    passages = []
    for index in range(n):
        passages.append(f"[{index}] {body}")
    return passages
def synthetic_queries(n: int) -> list[str]:
    """Return ``n`` sample queries, cycling through a fixed pool of seven."""
    pool = (
        "ГОСТ 21.501-93 рабочие чертежи",
        "класс бетона B25",
        "журнал ремонтов узлов",
        "правила производства земляных работ",
        "схема электропитания корпус 3",
        "контроль качества сварных соединений",
        "регламент технического обслуживания",
    )
    pool_size = len(pool)

    picked = []
    for position in range(n):
        # Wrap around once the pool is exhausted.
        picked.append(pool[position % pool_size])
    return picked
def passages_from_opensearch(query: str, top_k: int) -> list[str]:
    """Fetch chunk texts matching ``query`` from the configured OpenSearch index.

    Issues a ``multi_match`` search over the ``text``, ``text.ru`` and
    ``text.en`` fields, asking for at most ``top_k`` hits and only the
    ``text`` source field. Hits whose ``text`` is missing or empty are
    dropped, so the result may hold fewer than ``top_k`` items.
    """
    # Imported inside the function rather than at module level — presumably so
    # the synthetic mode does not need the OpenSearch client (TODO confirm).
    from app.indexing.opensearch_client import get_opensearch

    search_body = {
        "size": top_k,
        "query": {"multi_match": {"query": query, "fields": ["text", "text.ru", "text.en"]}},
        "_source": ["text"],
    }
    response = get_opensearch().search(
        index=settings.opensearch_index_chunks,
        body=search_body,
        request_timeout=30,
    )

    texts = []
    for hit in response["hits"]["hits"]:
        if hit["_source"].get("text"):
            texts.append(hit["_source"]["text"])
    return texts
def run(
    queries: list[str],
    candidates_per_query: int,
    passage_chars: int,
    warmup: int,
    source: str,
) -> BenchResult:
    """Benchmark the configured reranker over ``queries`` and summarise timings.

    Args:
        queries: Queries to score, one measured ``reranker.score`` call each.
            May be empty, in which case all measurements are reported as 0.
        candidates_per_query: Number of passages scored per query.
        passage_chars: Character length of synthetic passages (also used to
            pad OpenSearch results that come back short).
        warmup: Number of leading queries scored once, unmeasured, before the
            timed loop. Values ``<= 0`` disable warmup.
        source: ``"opensearch"`` pulls passages via
            ``passages_from_opensearch``; any other value uses synthetic
            passages.

    Returns:
        A ``BenchResult`` with latency percentiles, mean latency, throughput
        and total wall time.

    Exits the process with status 2 (via ``sys.exit``) when the reranker
    reports that it is not available.
    """
    reranker = get_reranker()
    if not reranker.available:
        print("ERROR: reranker model failed to load", file=sys.stderr)
        sys.exit(2)

    # Warmup so JIT / weight loading does not skew p50.
    # Guarded on `queries` so an empty query list cannot raise IndexError.
    if warmup > 0 and queries:
        warm = synthetic_passages(candidates_per_query, passage_chars)
        for q in queries[:warmup]:
            reranker.score(q, warm)

    latencies_ms: list[float] = []
    pair_count = 0
    t0 = time.perf_counter()
    for q in queries:
        # Passage preparation happens outside the timed section below.
        if source == "opensearch":
            passages = passages_from_opensearch(q, candidates_per_query)
            if len(passages) < candidates_per_query:
                # Pad short result sets so every query scores the same count.
                passages += synthetic_passages(candidates_per_query - len(passages), passage_chars)
        else:
            passages = synthetic_passages(candidates_per_query, passage_chars)

        start = time.perf_counter()
        reranker.score(q, passages)
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
        pair_count += len(passages)
    wall = time.perf_counter() - t0

    # Throughput is derived from time spent inside reranker.score() only.
    # Dividing by `wall` would fold OpenSearch round-trips and passage
    # generation into the reranker's pairs/sec figure.
    scoring_seconds = sum(latencies_ms) / 1000.0

    return BenchResult(
        model=reranker.model_name,
        device=reranker.device,
        queries=len(queries),
        candidates_per_query=candidates_per_query,
        passage_chars=passage_chars,
        warmup=warmup,
        p50_ms=percentile(latencies_ms, 50),
        p95_ms=percentile(latencies_ms, 95),
        p99_ms=percentile(latencies_ms, 99),
        # statistics.fmean raises StatisticsError on empty data; report 0.0
        # instead, matching what percentile() does for an empty sample.
        mean_ms=statistics.fmean(latencies_ms) if latencies_ms else 0.0,
        pairs_per_sec=pair_count / scoring_seconds if scoring_seconds > 0 else 0.0,
        wall_seconds=wall,
    )
def main() -> int:
    """Parse command-line options, run the benchmark and print the results.

    The result is always printed as indented JSON on stdout; unless
    ``--json-only`` is given, a markdown summary table follows it.

    ``--query`` is only honoured together with ``--source opensearch``; in
    every other case the built-in synthetic queries are used.

    Returns:
        ``0`` on success. Invalid options terminate the process through
        ``argparse`` (exit status 2).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring carries multi-line usage examples; the default
        # formatter would re-wrap them into a single unreadable paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--queries", type=int, default=32)
    parser.add_argument("--candidates", type=int, default=settings.rerank_candidates)
    parser.add_argument("--passage-length", type=int, default=700,
                        help="Synthetic passage character length")
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--source", choices=["synthetic", "opensearch"], default="synthetic")
    parser.add_argument("--query", type=str, default=None,
                        help="Single query to use against OpenSearch (with --source opensearch)")
    parser.add_argument("--json-only", action="store_true")
    args = parser.parse_args()

    # Reject sizes that cannot produce a meaningful measurement up front,
    # rather than failing later with a traceback inside the benchmark loop.
    if args.queries < 1:
        parser.error("--queries must be at least 1")
    if args.candidates < 1:
        parser.error("--candidates must be at least 1")

    if args.source == "opensearch" and args.query:
        # Repeat the single user-supplied query to get a latency sample.
        queries = [args.query] * args.queries
    else:
        queries = synthetic_queries(args.queries)

    result = run(
        queries=queries,
        candidates_per_query=args.candidates,
        passage_chars=args.passage_length,
        warmup=args.warmup,
        source=args.source,
    )

    payload = asdict(result)
    print(json.dumps(payload, indent=2))

    if not args.json_only:
        print()
        print("| Metric | Value |")
        print("|---------------------|-----------------|")
        print(f"| Model | {result.model} |")
        print(f"| Device | {result.device} |")
        print(f"| Queries | {result.queries} |")
        print(f"| Candidates / query | {result.candidates_per_query} |")
        print(f"| Passage chars | {result.passage_chars} |")
        print(f"| p50 latency | {result.p50_ms:.1f} ms |")
        print(f"| p95 latency | {result.p95_ms:.1f} ms |")
        print(f"| p99 latency | {result.p99_ms:.1f} ms |")
        print(f"| mean latency | {result.mean_ms:.1f} ms |")
        print(f"| Throughput | {result.pairs_per_sec:.1f} pairs/s |")
        print(f"| Wall time | {result.wall_seconds:.2f} s |")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())