"""Reranker latency / throughput benchmark. Measures BGE-reranker-v2-m3 (or whatever ``RERANKER_MODEL`` resolves to) against synthetic or live corpus passages and prints the standard set of percentiles plus throughput. Use this on staging hardware to verify whether the configured device meets the latency budget before committing to a target top-K. Usage: # 1) synthetic warm-up (no DB / OpenSearch needed) python scripts/benchmark_reranker.py --queries 32 --candidates 40 \ --passage-length 700 --warmup 4 # 2) live corpus pull (samples real chunks from OpenSearch) python scripts/benchmark_reranker.py --source opensearch \ --query "ГОСТ 21.501-93" \ --candidates 40 Outputs JSON to stdout and a markdown summary table. """ from __future__ import annotations import argparse import json import statistics import sys import time from dataclasses import asdict, dataclass from app.config import settings from app.indexing.reranker import get_reranker from app.logging_config import configure_logging, get_logger configure_logging() logger = get_logger(__name__) @dataclass class BenchResult: model: str device: str queries: int candidates_per_query: int passage_chars: int warmup: int p50_ms: float p95_ms: float p99_ms: float mean_ms: float pairs_per_sec: float wall_seconds: float def percentile(values: list[float], q: float) -> float: if not values: return 0.0 s = sorted(values) idx = max(0, min(len(s) - 1, int(round((q / 100.0) * (len(s) - 1))))) return s[idx] def synthetic_passages(n: int, chars: int) -> list[str]: seed = "ГОСТ 21.501-93 определяет правила выполнения архитектурно-строительных рабочих чертежей. " base = (seed * ((chars // len(seed)) + 2))[:chars] return [f"[{i}] {base}" for i in range(n)] def synthetic_queries(n: int) -> list[str]: samples = [ "ГОСТ 21.501-93 рабочие чертежи", "класс бетона B25", "журнал ремонтов узлов", "правила производства земляных работ", "схема электропитания корпус 3", "контроль качества сварных соединений", "регламент технического обслуживания", ] return [samples[i % len(samples)] for i in range(n)] def passages_from_opensearch(query: str, top_k: int) -> list[str]: from app.indexing.opensearch_client import get_opensearch res = get_opensearch().search( index=settings.opensearch_index_chunks, body={ "size": top_k, "query": {"multi_match": {"query": query, "fields": ["text", "text.ru", "text.en"]}}, "_source": ["text"], }, request_timeout=30, ) return [h["_source"]["text"] for h in res["hits"]["hits"] if h["_source"].get("text")] def run( queries: list[str], candidates_per_query: int, passage_chars: int, warmup: int, source: str, ) -> BenchResult: reranker = get_reranker() if not reranker.available: print("ERROR: reranker model failed to load", file=sys.stderr) sys.exit(2) # Warmup so JIT / weight loading does not skew p50. if warmup > 0: warm = synthetic_passages(candidates_per_query, passage_chars) for q in queries[:warmup] or [queries[0]]: reranker.score(q, warm) latencies_ms: list[float] = [] pair_count = 0 t0 = time.perf_counter() for q in queries: if source == "opensearch": passages = passages_from_opensearch(q, candidates_per_query) if len(passages) < candidates_per_query: passages += synthetic_passages(candidates_per_query - len(passages), passage_chars) else: passages = synthetic_passages(candidates_per_query, passage_chars) start = time.perf_counter() reranker.score(q, passages) latencies_ms.append((time.perf_counter() - start) * 1000.0) pair_count += len(passages) wall = time.perf_counter() - t0 return BenchResult( model=reranker.model_name, device=reranker.device, queries=len(queries), candidates_per_query=candidates_per_query, passage_chars=passage_chars, warmup=warmup, p50_ms=percentile(latencies_ms, 50), p95_ms=percentile(latencies_ms, 95), p99_ms=percentile(latencies_ms, 99), mean_ms=statistics.fmean(latencies_ms), pairs_per_sec=pair_count / wall if wall > 0 else 0.0, wall_seconds=wall, ) def main() -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--queries", type=int, default=32) parser.add_argument("--candidates", type=int, default=settings.rerank_candidates) parser.add_argument("--passage-length", type=int, default=700, help="Synthetic passage character length") parser.add_argument("--warmup", type=int, default=2) parser.add_argument("--source", choices=["synthetic", "opensearch"], default="synthetic") parser.add_argument("--query", type=str, default=None, help="Single query to use against OpenSearch (with --source opensearch)") parser.add_argument("--json-only", action="store_true") args = parser.parse_args() if args.source == "opensearch" and args.query: queries = [args.query] * args.queries else: queries = synthetic_queries(args.queries) result = run( queries=queries, candidates_per_query=args.candidates, passage_chars=args.passage_length, warmup=args.warmup, source=args.source, ) payload = asdict(result) print(json.dumps(payload, indent=2)) if not args.json_only: print() print("| Metric | Value |") print("|---------------------|-----------------|") print(f"| Model | {result.model} |") print(f"| Device | {result.device} |") print(f"| Queries | {result.queries} |") print(f"| Candidates / query | {result.candidates_per_query} |") print(f"| Passage chars | {result.passage_chars} |") print(f"| p50 latency | {result.p50_ms:.1f} ms |") print(f"| p95 latency | {result.p95_ms:.1f} ms |") print(f"| p99 latency | {result.p99_ms:.1f} ms |") print(f"| mean latency | {result.mean_ms:.1f} ms |") print(f"| Throughput | {result.pairs_per_sec:.1f} pairs/s |") print(f"| Wall time | {result.wall_seconds:.2f} s |") return 0 if __name__ == "__main__": sys.exit(main())