Initialize git, add Apache-2.0 LICENSE, .gitattributes (LF line endings), AGENTS.md (entry points, stack, discovery order, baseline checks), RUNBOOK.md (dev boot, prod deploy with overlay, ingestion, failures, rollback, scaling notes), .env.prod.example with rotated credential placeholders, and dev-only warnings on .env.example. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
91 lines
3.1 KiB
Python
91 lines
3.1 KiB
Python
"""BGE-M3 dense embedder with batching and CPU/GPU support.
|
|
|
|
We prefer FlagEmbedding's ``BGEM3FlagModel`` because it is the canonical
|
|
implementation and supports dense + sparse output. We fall back to
|
|
``sentence-transformers`` for portability.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from functools import lru_cache
|
|
from typing import Sequence
|
|
|
|
import numpy as np
|
|
|
|
from app.config import settings
|
|
from app.logging_config import get_logger
|
|
|
|
# Module logger; call sites below pass an event name plus keyword context
# (e.g. ``impl=``, ``model=``, ``error=``).
logger = get_logger(__name__)
|
|
|
|
|
|
class Embedder:
    """Dense text embedder with a FlagEmbedding-first, sentence-transformers-fallback strategy.

    The model is loaded eagerly in ``__init__``: FlagEmbedding's ``BGEM3FlagModel``
    is tried first, and if importing or constructing it raises, the embedder falls
    back to ``sentence_transformers.SentenceTransformer``.

    Args:
        model_name: Model identifier or path handed to the backend loader.
        device: Device string handed to the backend (e.g. ``"cpu"``, ``"cuda:0"``).
            Half precision is requested from FlagEmbedding unless the device is a CPU.
        normalize: When true, returned vectors are L2-normalised. Honoured by both
            backends.
        batch_size: Batch size passed to the backend's ``encode``.
        max_length: Maximum token length passed to FlagEmbedding's ``encode``
            (not used by the sentence-transformers fallback). Defaults to 8192,
            the value that used to be hard-coded.
    """

    def __init__(
        self,
        model_name: str,
        device: str,
        normalize: bool,
        batch_size: int,
        max_length: int = 8192,
    ) -> None:
        self.model_name = model_name
        self.device = device
        self.normalize = normalize
        self.batch_size = batch_size
        self.max_length = max_length
        # Name of the active backend; _load() sets it to whichever backend loads.
        self._impl = "flagembedding"
        self._model = None  # BGEM3FlagModel instance on the FlagEmbedding path
        self._st_model = None  # SentenceTransformer instance on the fallback path
        self._load()

    @staticmethod
    def _wants_fp16(device: str) -> bool:
        """Return True unless *device* names a CPU.

        Case-insensitive prefix match, so ``"cpu"``, ``"CPU"`` and ``"cpu:0"``
        all disable fp16.
        """
        return not str(device).strip().lower().startswith("cpu")

    @staticmethod
    def _l2_normalize(arr: np.ndarray) -> np.ndarray:
        """Return *arr* with each row scaled to unit L2 norm; all-zero rows stay zero."""
        norms = np.linalg.norm(arr, axis=1, keepdims=True)
        norms[norms == 0] = 1.0  # avoid division by zero for all-zero rows
        return arr / norms

    def _load(self) -> None:
        """Load FlagEmbedding's BGEM3FlagModel, falling back to sentence-transformers.

        Any exception while importing or constructing the FlagEmbedding model is
        logged as a warning and triggers the fallback. Errors on the fallback path
        propagate to the caller.
        """
        try:
            from FlagEmbedding import BGEM3FlagModel  # type: ignore

            self._model = BGEM3FlagModel(
                self.model_name,
                use_fp16=self._wants_fp16(self.device),
                devices=self.device,
                # BGEM3FlagModel normalises by default; pass the setting through so
                # normalize=False is honoured here just as on the fallback path.
                normalize_embeddings=self.normalize,
            )
        except Exception as exc:  # noqa: BLE001 - deliberate best-effort fallback
            logger.warning("embedder.flagembedding_failed", error=str(exc))
        else:
            self._impl = "flagembedding"
            logger.info("embedder.loaded", impl="flagembedding", model=self.model_name, device=self.device)
            return

        from sentence_transformers import SentenceTransformer

        self._st_model = SentenceTransformer(self.model_name, device=self.device)
        self._impl = "sentence-transformers"
        logger.info("embedder.loaded", impl="sentence-transformers", model=self.model_name, device=self.device)

    def encode(self, texts: Sequence[str]) -> list[list[float]]:
        """Embed *texts* and return one dense vector per input, in input order.

        Args:
            texts: Strings to embed. Any iterable of strings is accepted; a bare
                ``str`` is rejected because iterating it would embed each character.

        Returns:
            One list of floats (computed as float32) per input; ``[]`` for empty
            input. Vectors are L2-normalised when ``normalize`` is set.

        Raises:
            TypeError: if *texts* is a single string.
        """
        if isinstance(texts, str):
            raise TypeError("encode() expects a sequence of strings; use encode_one() for a single string")
        # Materialise once: supports tuples, generators and arrays, and avoids the
        # ambiguous truthiness of numpy arrays.
        batch = list(texts)
        if not batch:
            return []

        if self._impl == "flagembedding":
            out = self._model.encode(  # type: ignore[union-attr]
                batch,
                batch_size=self.batch_size,
                max_length=self.max_length,
                return_dense=True,
                return_sparse=False,
                return_colbert_vecs=False,
            )
            # The M3 model returns a dict keyed by output kind; tolerate a bare array too.
            dense = out["dense_vecs"] if isinstance(out, dict) else out
            arr = np.asarray(dense, dtype=np.float32)
            if self.normalize:
                # Defensive: idempotent if the backend already normalised.
                arr = self._l2_normalize(arr)
        else:
            arr = np.asarray(
                self._st_model.encode(  # type: ignore[union-attr]
                    batch,
                    batch_size=self.batch_size,
                    normalize_embeddings=self.normalize,
                    convert_to_numpy=True,
                    show_progress_bar=False,
                ),
                dtype=np.float32,
            )

        return arr.tolist()

    def encode_one(self, text: str) -> list[float]:
        """Embed a single string; shorthand for ``encode([text])[0]``."""
        return self.encode([text])[0]
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_embedder() -> Embedder:
    """Build the shared ``Embedder`` from application settings.

    Memoised with ``lru_cache(maxsize=1)``: every call after the first returns
    the same instance (until ``get_embedder.cache_clear()`` is called), so the
    model constructed in ``Embedder.__init__`` is loaded only once.
    """
    model_name = settings.embedding_model
    device = settings.embedding_device
    normalize = settings.embedding_normalize
    batch_size = settings.embedding_batch_size
    return Embedder(
        model_name=model_name,
        device=device,
        normalize=normalize,
        batch_size=batch_size,
    )
|