perf(reranker): add benchmark harness and passage clipping
- scripts/benchmark_reranker.py exercises the configured reranker with synthetic queries or live OpenSearch samples and prints p50/p95/p99 latency, mean latency, and pairs/sec throughput. Supports --warmup, --candidates, --passage-length, --source, and a --json-only mode for CI.
- app/indexing/reranker.py clips passages to 2048 characters before scoring so a runaway chunk cannot starve the cross-encoder beyond bge-reranker-v2-m3's training window.
- RUNBOOK.md gains a Reranker benchmark section with CPU/GPU SLO targets and a remediation ladder (lower top-K, raise batch size, switch device, disable reranker) when measured p95 exceeds budget.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
195
scripts/benchmark_reranker.py
Normal file
195
scripts/benchmark_reranker.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Reranker latency / throughput benchmark.
|
||||
|
||||
Measures BGE-reranker-v2-m3 (or whatever ``RERANKER_MODEL`` resolves to)
|
||||
against synthetic or live corpus passages and prints the standard set of
|
||||
percentiles plus throughput. Use this on staging hardware to verify whether
|
||||
the configured device meets the latency budget before committing to a target
|
||||
top-K.
|
||||
|
||||
Usage:
|
||||
|
||||
# 1) synthetic warm-up (no DB / OpenSearch needed)
|
||||
python scripts/benchmark_reranker.py --queries 32 --candidates 40 \
|
||||
--passage-length 700 --warmup 4
|
||||
|
||||
# 2) live corpus pull (samples real chunks from OpenSearch)
|
||||
python scripts/benchmark_reranker.py --source opensearch \
|
||||
--query "ГОСТ 21.501-93" \
|
||||
--candidates 40
|
||||
|
||||
Outputs JSON to stdout and a markdown summary table.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
from app.config import settings
|
||||
from app.indexing.reranker import get_reranker
|
||||
from app.logging_config import configure_logging, get_logger
|
||||
|
||||
# Set up the project's logging configuration at import time, then create
# the module-level logger used by this script.
configure_logging()
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BenchResult:
    """Summary record for a single reranker benchmark run.

    Field order is significant: it defines the positional constructor
    signature and the key order of ``dataclasses.asdict`` output.
    """

    # -- run configuration --
    model: str  # reranker model name as reported by the reranker object
    device: str  # device the reranker reports it runs on
    queries: int  # number of timed queries
    candidates_per_query: int  # passages scored per query
    passage_chars: int  # character length used for synthetic passages
    warmup: int  # requested number of untimed warmup queries

    # -- per-query scoring latency, in milliseconds --
    p50_ms: float
    p95_ms: float
    p99_ms: float
    mean_ms: float

    # -- aggregate figures --
    pairs_per_sec: float  # (query, passage) pairs scored per second
    wall_seconds: float  # elapsed time of the whole timed loop
|
||||
|
||||
|
||||
def percentile(values: list[float], q: float) -> float:
    """Return the ``q``-th percentile of ``values``.

    Uses linear interpolation between the two closest ranks. Picking a
    single rank with ``round()`` is biased by round-half-to-even: the
    median of two samples came out as the lower one while the median of
    four samples came out as the upper one.

    Args:
        values: Samples in any order; the list is not modified.
        q: Percentile in the range 0-100. Values outside that range are
            clamped, so they resolve to the minimum / maximum sample.

    Returns:
        The interpolated percentile, or ``0.0`` when ``values`` is empty.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    last = len(ordered) - 1
    # Fractional rank into the sorted samples, clamped to a valid index.
    rank = min(max(q / 100.0, 0.0), 1.0) * last
    lower = int(rank)  # rank is non-negative, so truncation is floor
    upper = min(lower + 1, last)
    fraction = rank - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * fraction
|
||||
|
||||
|
||||
def synthetic_passages(n: int, chars: int) -> list[str]:
    """Build ``n`` synthetic passages from a repeated Russian seed sentence.

    Each passage is an index tag (``"[i] "``) followed by the seed
    sentence repeated and cut to ``chars`` characters, so the tag is
    extra on top of ``chars``.
    """
    seed = "ГОСТ 21.501-93 определяет правила выполнения архитектурно-строительных рабочих чертежей. "
    # Two extra repetitions guarantee the repeated text covers ``chars``.
    repetitions = (chars // len(seed)) + 2
    body = (seed * repetitions)[:chars]

    passages: list[str] = []
    for index in range(n):
        passages.append(f"[{index}] {body}")
    return passages
|
||||
|
||||
|
||||
def synthetic_queries(n: int) -> list[str]:
    """Return ``n`` sample queries, cycling through a fixed pool in order."""
    pool = (
        "ГОСТ 21.501-93 рабочие чертежи",
        "класс бетона B25",
        "журнал ремонтов узлов",
        "правила производства земляных работ",
        "схема электропитания корпус 3",
        "контроль качества сварных соединений",
        "регламент технического обслуживания",
    )
    pool_size = len(pool)

    picked: list[str] = []
    for position in range(n):
        # Wrap around once the pool is exhausted.
        picked.append(pool[position % pool_size])
    return picked
|
||||
|
||||
|
||||
def passages_from_opensearch(query: str, top_k: int) -> list[str]:
    """Fetch up to ``top_k`` chunk texts from OpenSearch for ``query``.

    Runs a ``multi_match`` search over the ``text``, ``text.ru`` and
    ``text.en`` fields of the index named by
    ``settings.opensearch_index_chunks`` and returns the ``text`` source
    field of each hit. Hits whose text is missing or empty are skipped,
    so fewer than ``top_k`` passages may come back.
    """
    # Function-scope import: the client is only loaded when this source
    # is actually used.
    from app.indexing.opensearch_client import get_opensearch

    search_body = {
        "size": top_k,
        "query": {
            "multi_match": {
                "query": query,
                "fields": ["text", "text.ru", "text.en"],
            }
        },
        "_source": ["text"],
    }
    response = get_opensearch().search(
        index=settings.opensearch_index_chunks,
        body=search_body,
        request_timeout=30,
    )

    texts: list[str] = []
    for hit in response["hits"]["hits"]:
        text = hit["_source"].get("text")
        if text:
            texts.append(text)
    return texts
|
||||
|
||||
|
||||
def run(
    queries: list[str],
    candidates_per_query: int,
    passage_chars: int,
    warmup: int,
    source: str,
) -> BenchResult:
    """Benchmark the configured reranker and summarise the timings.

    Only the ``reranker.score`` call is timed per query; fetching or
    building the passages happens outside the timed section.

    Args:
        queries: Queries to score, one timed ``score`` call each. An
            empty list yields a result with all timing fields at zero.
        candidates_per_query: Number of passages scored per query.
        passage_chars: Character length of synthetic passages (also used
            to pad short OpenSearch result sets).
        warmup: Number of leading queries to score untimed before
            measuring; ``0`` or less disables warmup.
        source: ``"opensearch"`` pulls passages from OpenSearch for each
            query; any other value uses synthetic passages.

    Returns:
        A ``BenchResult`` with latency percentiles, mean latency,
        throughput, and total loop wall time. ``pairs_per_sec`` is based
        on time spent inside ``reranker.score`` only, so OpenSearch
        round-trips do not deflate it; ``wall_seconds`` still covers the
        whole timed loop including passage retrieval.

    Exits the process with status 2 if the reranker reports that it is
    not available.
    """
    reranker = get_reranker()
    if not reranker.available:
        print("ERROR: reranker model failed to load", file=sys.stderr)
        sys.exit(2)

    # Warmup so JIT / weight loading does not skew p50. Skipped when
    # there are no queries to warm up with.
    if warmup > 0 and queries:
        warm = synthetic_passages(candidates_per_query, passage_chars)
        for q in queries[:warmup]:
            reranker.score(q, warm)

    latencies_ms: list[float] = []
    pair_count = 0
    scoring_seconds = 0.0
    t0 = time.perf_counter()
    for q in queries:
        if source == "opensearch":
            passages = passages_from_opensearch(q, candidates_per_query)
            if len(passages) < candidates_per_query:
                # Pad with synthetic text so every query scores the same
                # number of pairs.
                passages += synthetic_passages(
                    candidates_per_query - len(passages), passage_chars
                )
        else:
            passages = synthetic_passages(candidates_per_query, passage_chars)

        start = time.perf_counter()
        reranker.score(q, passages)
        elapsed = time.perf_counter() - start

        scoring_seconds += elapsed
        latencies_ms.append(elapsed * 1000.0)
        pair_count += len(passages)
    wall = time.perf_counter() - t0

    return BenchResult(
        model=reranker.model_name,
        device=reranker.device,
        queries=len(queries),
        candidates_per_query=candidates_per_query,
        passage_chars=passage_chars,
        warmup=warmup,
        p50_ms=percentile(latencies_ms, 50),
        p95_ms=percentile(latencies_ms, 95),
        p99_ms=percentile(latencies_ms, 99),
        # statistics.fmean raises StatisticsError on empty input.
        mean_ms=statistics.fmean(latencies_ms) if latencies_ms else 0.0,
        pairs_per_sec=pair_count / scoring_seconds if scoring_seconds > 0 else 0.0,
        wall_seconds=wall,
    )
|
||||
|
||||
|
||||
def main() -> int:
    """Parse CLI options, run the benchmark, and print the results.

    The result is always printed to stdout as indented JSON. Unless
    ``--json-only`` is given, a blank line and a markdown summary table
    follow.

    With ``--source opensearch`` and ``--query``, that single query is
    repeated ``--queries`` times; otherwise the built-in synthetic
    queries are used.

    Returns:
        ``0`` on success. Invalid ``--queries`` / ``--candidates`` values
        end the process through ``parser.error`` (a usage error).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's usage examples laid out as written;
        # the default formatter re-wraps them into a single paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--queries", type=int, default=32)
    parser.add_argument("--candidates", type=int, default=settings.rerank_candidates)
    parser.add_argument("--passage-length", type=int, default=700,
                        help="Synthetic passage character length")
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--source", choices=["synthetic", "opensearch"], default="synthetic")
    parser.add_argument("--query", type=str, default=None,
                        help="Single query to use against OpenSearch (with --source opensearch)")
    parser.add_argument("--json-only", action="store_true")
    args = parser.parse_args()

    # A benchmark with nothing to score has no meaningful statistics;
    # reject it up front with a usage error.
    if args.queries < 1:
        parser.error("--queries must be at least 1")
    if args.candidates < 1:
        parser.error("--candidates must be at least 1")

    if args.source == "opensearch" and args.query:
        queries = [args.query] * args.queries
    else:
        queries = synthetic_queries(args.queries)

    result = run(
        queries=queries,
        candidates_per_query=args.candidates,
        passage_chars=args.passage_length,
        warmup=args.warmup,
        source=args.source,
    )

    payload = asdict(result)
    print(json.dumps(payload, indent=2))

    if not args.json_only:
        table = [
            "| Metric | Value |",
            "|---------------------|-----------------|",
            f"| Model | {result.model} |",
            f"| Device | {result.device} |",
            f"| Queries | {result.queries} |",
            f"| Candidates / query | {result.candidates_per_query} |",
            f"| Passage chars | {result.passage_chars} |",
            f"| p50 latency | {result.p50_ms:.1f} ms |",
            f"| p95 latency | {result.p95_ms:.1f} ms |",
            f"| p99 latency | {result.p99_ms:.1f} ms |",
            f"| mean latency | {result.mean_ms:.1f} ms |",
            f"| Throughput | {result.pairs_per_sec:.1f} pairs/s |",
            f"| Wall time | {result.wall_seconds:.2f} s |",
        ]
        print()  # blank line between the JSON payload and the table
        for row in table:
            print(row)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user