"""BGE-M3 dense embedder with batching and CPU/GPU support. We prefer FlagEmbedding's ``BGEM3FlagModel`` because it is the canonical implementation and supports dense + sparse output. We fall back to ``sentence-transformers`` for portability. """ from __future__ import annotations from functools import lru_cache from typing import Sequence import numpy as np from app.config import settings from app.logging_config import get_logger logger = get_logger(__name__) class Embedder: def __init__(self, model_name: str, device: str, normalize: bool, batch_size: int) -> None: self.model_name = model_name self.device = device self.normalize = normalize self.batch_size = batch_size self._impl = "flagembedding" self._model = None self._st_model = None self._load() def _load(self) -> None: try: from FlagEmbedding import BGEM3FlagModel # type: ignore use_fp16 = self.device != "cpu" self._model = BGEM3FlagModel(self.model_name, use_fp16=use_fp16, devices=self.device) self._impl = "flagembedding" logger.info("embedder.loaded", impl="flagembedding", model=self.model_name, device=self.device) return except Exception as exc: # noqa: BLE001 logger.warning("embedder.flagembedding_failed", error=str(exc)) from sentence_transformers import SentenceTransformer self._st_model = SentenceTransformer(self.model_name, device=self.device) self._impl = "sentence-transformers" logger.info("embedder.loaded", impl="sentence-transformers", model=self.model_name, device=self.device) def encode(self, texts: Sequence[str]) -> list[list[float]]: if not texts: return [] if self._impl == "flagembedding": out = self._model.encode( # type: ignore[union-attr] list(texts), batch_size=self.batch_size, max_length=8192, return_dense=True, return_sparse=False, return_colbert_vecs=False, ) dense = out["dense_vecs"] if isinstance(out, dict) else out arr = np.asarray(dense, dtype=np.float32) else: arr = self._st_model.encode( # type: ignore[union-attr] list(texts), batch_size=self.batch_size, normalize_embeddings=self.normalize, convert_to_numpy=True, show_progress_bar=False, ) arr = arr.astype(np.float32) if self.normalize and self._impl == "flagembedding": norms = np.linalg.norm(arr, axis=1, keepdims=True) norms[norms == 0] = 1.0 arr = arr / norms return arr.tolist() def encode_one(self, text: str) -> list[float]: return self.encode([text])[0] @lru_cache(maxsize=1) def get_embedder() -> Embedder: return Embedder( model_name=settings.embedding_model, device=settings.embedding_device, normalize=settings.embedding_normalize, batch_size=settings.embedding_batch_size, )