chore: bootstrap repository with governance docs
Initialize git, add Apache-2.0 LICENSE, .gitattributes (LF line endings), AGENTS.md (entry points, stack, discovery order, baseline checks), RUNBOOK.md (dev boot, prod deploy with overlay, ingestion, failures, rollback, scaling notes), .env.prod.example with rotated credential placeholders, and dev-only warnings on .env.example. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
0
app/indexing/__init__.py
Normal file
0
app/indexing/__init__.py
Normal file
90
app/indexing/embeddings.py
Normal file
90
app/indexing/embeddings.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""BGE-M3 dense embedder with batching and CPU/GPU support.
|
||||
|
||||
We prefer FlagEmbedding's ``BGEM3FlagModel`` because it is the canonical
|
||||
implementation and supports dense + sparse output. We fall back to
|
||||
``sentence-transformers`` for portability.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.config import settings
|
||||
from app.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Embedder:
|
||||
def __init__(self, model_name: str, device: str, normalize: bool, batch_size: int) -> None:
|
||||
self.model_name = model_name
|
||||
self.device = device
|
||||
self.normalize = normalize
|
||||
self.batch_size = batch_size
|
||||
self._impl = "flagembedding"
|
||||
self._model = None
|
||||
self._st_model = None
|
||||
self._load()
|
||||
|
||||
def _load(self) -> None:
|
||||
try:
|
||||
from FlagEmbedding import BGEM3FlagModel # type: ignore
|
||||
use_fp16 = self.device != "cpu"
|
||||
self._model = BGEM3FlagModel(self.model_name, use_fp16=use_fp16, devices=self.device)
|
||||
self._impl = "flagembedding"
|
||||
logger.info("embedder.loaded", impl="flagembedding", model=self.model_name, device=self.device)
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("embedder.flagembedding_failed", error=str(exc))
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
self._st_model = SentenceTransformer(self.model_name, device=self.device)
|
||||
self._impl = "sentence-transformers"
|
||||
logger.info("embedder.loaded", impl="sentence-transformers", model=self.model_name, device=self.device)
|
||||
|
||||
def encode(self, texts: Sequence[str]) -> list[list[float]]:
|
||||
if not texts:
|
||||
return []
|
||||
if self._impl == "flagembedding":
|
||||
out = self._model.encode( # type: ignore[union-attr]
|
||||
list(texts),
|
||||
batch_size=self.batch_size,
|
||||
max_length=8192,
|
||||
return_dense=True,
|
||||
return_sparse=False,
|
||||
return_colbert_vecs=False,
|
||||
)
|
||||
dense = out["dense_vecs"] if isinstance(out, dict) else out
|
||||
arr = np.asarray(dense, dtype=np.float32)
|
||||
else:
|
||||
arr = self._st_model.encode( # type: ignore[union-attr]
|
||||
list(texts),
|
||||
batch_size=self.batch_size,
|
||||
normalize_embeddings=self.normalize,
|
||||
convert_to_numpy=True,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
arr = arr.astype(np.float32)
|
||||
|
||||
if self.normalize and self._impl == "flagembedding":
|
||||
norms = np.linalg.norm(arr, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1.0
|
||||
arr = arr / norms
|
||||
|
||||
return arr.tolist()
|
||||
|
||||
def encode_one(self, text: str) -> list[float]:
|
||||
return self.encode([text])[0]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_embedder() -> Embedder:
|
||||
return Embedder(
|
||||
model_name=settings.embedding_model,
|
||||
device=settings.embedding_device,
|
||||
normalize=settings.embedding_normalize,
|
||||
batch_size=settings.embedding_batch_size,
|
||||
)
|
||||
327
app/indexing/hybrid_search.py
Normal file
327
app/indexing/hybrid_search.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""Hybrid search: lexical (OpenSearch BM25) + semantic (Qdrant) + RRF + reranker.
|
||||
|
||||
Always returns ``SearchResponse`` (never throws on missing index/collection -
|
||||
empty results are valid).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from qdrant_client.http import models as qm
|
||||
|
||||
from app.api.schemas import (
|
||||
Citation,
|
||||
SearchFilters,
|
||||
SearchHit,
|
||||
SearchMode,
|
||||
SearchRequest,
|
||||
SearchResponse,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.indexing.embeddings import get_embedder
|
||||
from app.indexing.opensearch_client import get_opensearch
|
||||
from app.indexing.qdrant_client import DENSE_VECTOR_NAME, get_qdrant
|
||||
from app.indexing.reranker import get_reranker
|
||||
from app.logging_config import get_logger
|
||||
from app.utils.text_cleaning import normalize_for_search
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Candidate:
|
||||
chunk_id: str
|
||||
document_id: str
|
||||
page_number: int
|
||||
block_type: str
|
||||
block_id: str | None
|
||||
text: str
|
||||
source_path: str
|
||||
original_file_name: str
|
||||
quality_flags: dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
bm25_score: float | None = None
|
||||
bm25_rank: int | None = None
|
||||
dense_score: float | None = None
|
||||
dense_rank: int | None = None
|
||||
|
||||
|
||||
def run_search(req: SearchRequest) -> SearchResponse:
|
||||
mode: SearchMode = req.search_mode
|
||||
filters = req.filters
|
||||
|
||||
lexical: list[_Candidate] = []
|
||||
semantic: list[_Candidate] = []
|
||||
|
||||
if mode in ("lexical", "hybrid"):
|
||||
try:
|
||||
lexical = _lexical_search(req.query, filters, settings.hybrid_opensearch_top_k)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("search.lexical_failed", error=str(exc))
|
||||
|
||||
if mode in ("semantic", "hybrid"):
|
||||
try:
|
||||
semantic = _semantic_search(req.query, filters, settings.hybrid_qdrant_top_k)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("search.semantic_failed", error=str(exc))
|
||||
|
||||
merged = _merge(lexical, semantic, mode)
|
||||
candidates = merged[: settings.rerank_candidates]
|
||||
|
||||
reranker = get_reranker()
|
||||
reranked_flag = False
|
||||
if settings.reranker_enabled and reranker.available and candidates:
|
||||
scores = reranker.score(req.query, [c.text for c in candidates])
|
||||
for c, s in zip(candidates, scores, strict=True):
|
||||
c.dense_score = s
|
||||
candidates.sort(key=lambda c: (c.dense_score or 0.0), reverse=True)
|
||||
reranked_flag = True
|
||||
|
||||
final = candidates[: req.limit]
|
||||
|
||||
hits: list[SearchHit] = []
|
||||
for rank, c in enumerate(final, start=1):
|
||||
score = (
|
||||
c.dense_score
|
||||
if reranked_flag
|
||||
else (c.dense_score if mode == "semantic" else c.bm25_score) or 0.0
|
||||
)
|
||||
hits.append(
|
||||
SearchHit(
|
||||
rank=rank,
|
||||
score=float(score),
|
||||
document_id=uuid.UUID(c.document_id),
|
||||
chunk_id=uuid.UUID(c.chunk_id),
|
||||
original_file_name=c.original_file_name,
|
||||
source_path=c.source_path,
|
||||
page_number=c.page_number,
|
||||
block_type=c.block_type,
|
||||
text=c.text,
|
||||
citation=Citation(
|
||||
pdf=c.original_file_name,
|
||||
page=c.page_number,
|
||||
block_id=c.block_id,
|
||||
table_id=str(c.metadata.get("table_index")) if c.metadata.get("table_index") is not None else None,
|
||||
figure_id=str(c.metadata.get("figure_index")) if c.metadata.get("figure_index") is not None else None,
|
||||
),
|
||||
quality_flags=c.quality_flags,
|
||||
metadata=c.metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return SearchResponse(
|
||||
query=req.query,
|
||||
mode=mode,
|
||||
total_candidates=len(merged),
|
||||
reranked=reranked_flag,
|
||||
results=hits,
|
||||
)
|
||||
|
||||
|
||||
# ---------------- lexical ----------------
|
||||
|
||||
def _lexical_search(query: str, filters: SearchFilters, top_k: int) -> list[_Candidate]:
|
||||
client = get_opensearch()
|
||||
if not client.indices.exists(index=settings.opensearch_index_chunks):
|
||||
return []
|
||||
|
||||
must = [
|
||||
{
|
||||
"multi_match": {
|
||||
"query": query,
|
||||
"fields": ["text^1.0", "text.ru^1.5", "text.en^1.5", "normalized_text^0.7"],
|
||||
"type": "best_fields",
|
||||
"operator": "or",
|
||||
}
|
||||
}
|
||||
]
|
||||
norm = normalize_for_search(query)
|
||||
if norm and norm != query.lower():
|
||||
must.append({"match": {"normalized_text": {"query": norm, "boost": 0.5}}})
|
||||
|
||||
filter_clauses = _opensearch_filters(filters)
|
||||
body = {
|
||||
"size": top_k,
|
||||
"query": {"bool": {"must": must, "filter": filter_clauses}},
|
||||
"_source": [
|
||||
"chunk_id",
|
||||
"document_id",
|
||||
"source_path",
|
||||
"original_file_name",
|
||||
"page_number",
|
||||
"block_type",
|
||||
"block_id",
|
||||
"text",
|
||||
"quality_flags",
|
||||
"metadata",
|
||||
],
|
||||
}
|
||||
res = client.search(index=settings.opensearch_index_chunks, body=body, request_timeout=30)
|
||||
out: list[_Candidate] = []
|
||||
for rank, hit in enumerate(res.get("hits", {}).get("hits", []), start=1):
|
||||
s = hit.get("_source", {})
|
||||
out.append(
|
||||
_Candidate(
|
||||
chunk_id=s["chunk_id"],
|
||||
document_id=s["document_id"],
|
||||
page_number=int(s.get("page_number", 0)),
|
||||
block_type=s.get("block_type", "paragraph"),
|
||||
block_id=s.get("block_id"),
|
||||
text=s.get("text", ""),
|
||||
source_path=s.get("source_path", ""),
|
||||
original_file_name=s.get("original_file_name", ""),
|
||||
quality_flags=s.get("quality_flags") or {},
|
||||
metadata=s.get("metadata") or {},
|
||||
bm25_score=float(hit.get("_score") or 0.0),
|
||||
bm25_rank=rank,
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _opensearch_filters(filters: SearchFilters) -> list[dict[str, Any]]:
|
||||
clauses: list[dict[str, Any]] = []
|
||||
if filters.document_id:
|
||||
clauses.append({"term": {"document_id": str(filters.document_id)}})
|
||||
if filters.source_path:
|
||||
clauses.append({"term": {"source_path": filters.source_path}})
|
||||
if filters.block_type:
|
||||
clauses.append({"term": {"block_type": filters.block_type}})
|
||||
if filters.min_ocr_confidence is not None:
|
||||
clauses.append({"range": {"ocr_confidence": {"gte": filters.min_ocr_confidence}}})
|
||||
return clauses
|
||||
|
||||
|
||||
# ---------------- semantic ----------------
|
||||
|
||||
def _semantic_search(query: str, filters: SearchFilters, top_k: int) -> list[_Candidate]:
|
||||
embedder = get_embedder()
|
||||
vector = embedder.encode_one(query)
|
||||
qf = _qdrant_filter(filters)
|
||||
|
||||
client = get_qdrant()
|
||||
try:
|
||||
results = client.query_points(
|
||||
collection_name=settings.qdrant_collection_chunks,
|
||||
query=vector,
|
||||
using=DENSE_VECTOR_NAME,
|
||||
limit=top_k,
|
||||
with_payload=True,
|
||||
query_filter=qf,
|
||||
).points
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("qdrant.query_points_fallback", error=str(exc))
|
||||
results = client.search(
|
||||
collection_name=settings.qdrant_collection_chunks,
|
||||
query_vector=(DENSE_VECTOR_NAME, vector),
|
||||
query_filter=qf,
|
||||
limit=top_k,
|
||||
with_payload=True,
|
||||
)
|
||||
|
||||
out: list[_Candidate] = []
|
||||
for rank, p in enumerate(results, start=1):
|
||||
payload = p.payload or {}
|
||||
chunk_id = payload.get("chunk_id") or str(p.id)
|
||||
out.append(
|
||||
_Candidate(
|
||||
chunk_id=str(chunk_id),
|
||||
document_id=str(payload.get("document_id", "")),
|
||||
page_number=int(payload.get("page_number") or 0),
|
||||
block_type=payload.get("block_type", "paragraph"),
|
||||
block_id=payload.get("block_id"),
|
||||
text=payload.get("text_preview", ""),
|
||||
source_path=payload.get("source_path", ""),
|
||||
original_file_name=payload.get("original_file_name", ""),
|
||||
quality_flags=payload.get("quality_flags") or {},
|
||||
metadata=payload.get("metadata") or {},
|
||||
dense_score=float(p.score or 0.0),
|
||||
dense_rank=rank,
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _qdrant_filter(filters: SearchFilters) -> qm.Filter | None:
|
||||
must: list[qm.FieldCondition | qm.Range] = []
|
||||
if filters.document_id:
|
||||
must.append(qm.FieldCondition(key="document_id", match=qm.MatchValue(value=str(filters.document_id))))
|
||||
if filters.source_path:
|
||||
must.append(qm.FieldCondition(key="source_path", match=qm.MatchValue(value=filters.source_path)))
|
||||
if filters.block_type:
|
||||
must.append(qm.FieldCondition(key="block_type", match=qm.MatchValue(value=filters.block_type)))
|
||||
if filters.min_ocr_confidence is not None:
|
||||
must.append(qm.FieldCondition(key="ocr_confidence", range=qm.Range(gte=filters.min_ocr_confidence)))
|
||||
if not must:
|
||||
return None
|
||||
return qm.Filter(must=must)
|
||||
|
||||
|
||||
# ---------------- merge ----------------
|
||||
|
||||
def _merge(lexical: list[_Candidate], semantic: list[_Candidate], mode: SearchMode) -> list[_Candidate]:
|
||||
if mode == "lexical":
|
||||
return lexical
|
||||
if mode == "semantic":
|
||||
return _hydrate_semantic_text(semantic)
|
||||
|
||||
by_id: dict[str, _Candidate] = {}
|
||||
for c in lexical:
|
||||
by_id[c.chunk_id] = c
|
||||
for c in semantic:
|
||||
if c.chunk_id in by_id:
|
||||
by_id[c.chunk_id].dense_score = c.dense_score
|
||||
by_id[c.chunk_id].dense_rank = c.dense_rank
|
||||
if not by_id[c.chunk_id].text:
|
||||
by_id[c.chunk_id].text = c.text
|
||||
else:
|
||||
by_id[c.chunk_id] = c
|
||||
|
||||
rrf: dict[str, float] = defaultdict(float)
|
||||
k = settings.hybrid_rrf_k
|
||||
for c in lexical:
|
||||
if c.bm25_rank is not None:
|
||||
rrf[c.chunk_id] += 1.0 / (k + c.bm25_rank)
|
||||
for c in semantic:
|
||||
if c.dense_rank is not None:
|
||||
rrf[c.chunk_id] += 1.0 / (k + c.dense_rank)
|
||||
|
||||
items = sorted(by_id.values(), key=lambda c: rrf.get(c.chunk_id, 0.0), reverse=True)
|
||||
return _hydrate_full_text(items)
|
||||
|
||||
|
||||
def _hydrate_full_text(candidates: list[_Candidate]) -> list[_Candidate]:
|
||||
"""For candidates whose text came only from Qdrant payload (preview), pull
|
||||
the full chunk text from OpenSearch by id so the reranker sees full content.
|
||||
"""
|
||||
missing = [c for c in candidates if len(c.text) <= 512]
|
||||
if not missing:
|
||||
return candidates
|
||||
client = get_opensearch()
|
||||
ids = [c.chunk_id for c in missing]
|
||||
try:
|
||||
res = client.mget(index=settings.opensearch_index_chunks, body={"ids": ids})
|
||||
except Exception:
|
||||
return candidates
|
||||
by_id = {d["_id"]: d.get("_source", {}) for d in res.get("docs", []) if d.get("found")}
|
||||
for c in missing:
|
||||
s = by_id.get(c.chunk_id)
|
||||
if s and s.get("text"):
|
||||
c.text = s["text"]
|
||||
if not c.original_file_name:
|
||||
c.original_file_name = s.get("original_file_name", "")
|
||||
if not c.source_path:
|
||||
c.source_path = s.get("source_path", "")
|
||||
if not c.metadata:
|
||||
c.metadata = s.get("metadata") or {}
|
||||
if not c.quality_flags:
|
||||
c.quality_flags = s.get("quality_flags") or {}
|
||||
return candidates
|
||||
|
||||
|
||||
def _hydrate_semantic_text(candidates: list[_Candidate]) -> list[_Candidate]:
|
||||
return _hydrate_full_text(candidates)
|
||||
142
app/indexing/opensearch_client.py
Normal file
142
app/indexing/opensearch_client.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""OpenSearch client + index bootstrap + chunk indexing helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Any, Iterable
|
||||
|
||||
from opensearchpy import OpenSearch, RequestsHttpConnection
|
||||
from opensearchpy.helpers import bulk
|
||||
|
||||
from app.config import settings
|
||||
from app.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Index settings: 3 analyzers (russian, english, standard).
|
||||
# We index ``text`` with multi-fields (.ru, .en, .raw) so we can boost per language at query time.
|
||||
INDEX_SETTINGS: dict[str, Any] = {
|
||||
"settings": {
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0,
|
||||
"analysis": {
|
||||
"filter": {
|
||||
"ru_stop": {"type": "stop", "stopwords": "_russian_"},
|
||||
"ru_stemmer": {"type": "stemmer", "language": "russian"},
|
||||
"en_stop": {"type": "stop", "stopwords": "_english_"},
|
||||
"en_stemmer": {"type": "stemmer", "language": "english"},
|
||||
},
|
||||
"analyzer": {
|
||||
"ru_analyzer": {
|
||||
"type": "custom",
|
||||
"tokenizer": "standard",
|
||||
"filter": ["lowercase", "ru_stop", "ru_stemmer"],
|
||||
},
|
||||
"en_analyzer": {
|
||||
"type": "custom",
|
||||
"tokenizer": "standard",
|
||||
"filter": ["lowercase", "en_stop", "en_stemmer"],
|
||||
},
|
||||
"code_analyzer": {
|
||||
"type": "custom",
|
||||
"tokenizer": "standard",
|
||||
"filter": ["lowercase"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"mappings": {
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
"chunk_id": {"type": "keyword"},
|
||||
"document_id": {"type": "keyword"},
|
||||
"source_path": {"type": "keyword"},
|
||||
"original_file_name": {
|
||||
"type": "text",
|
||||
"fields": {"keyword": {"type": "keyword", "ignore_above": 512}},
|
||||
},
|
||||
"page_number": {"type": "integer"},
|
||||
"block_type": {"type": "keyword"},
|
||||
"block_id": {"type": "keyword"},
|
||||
"text": {
|
||||
"type": "text",
|
||||
"analyzer": "code_analyzer",
|
||||
"fields": {
|
||||
"ru": {"type": "text", "analyzer": "ru_analyzer"},
|
||||
"en": {"type": "text", "analyzer": "en_analyzer"},
|
||||
},
|
||||
},
|
||||
"normalized_text": {
|
||||
"type": "text",
|
||||
"analyzer": "code_analyzer",
|
||||
},
|
||||
"ocr_confidence": {"type": "float"},
|
||||
"language_hint": {"type": "keyword"},
|
||||
"metadata": {"type": "object", "enabled": True},
|
||||
"quality_flags": {"type": "object", "enabled": True},
|
||||
"created_at": {"type": "date"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_opensearch() -> OpenSearch:
|
||||
auth = None
|
||||
if settings.opensearch_user and settings.opensearch_password:
|
||||
auth = (settings.opensearch_user, settings.opensearch_password)
|
||||
return OpenSearch(
|
||||
hosts=[{"host": settings.opensearch_host, "port": settings.opensearch_port}],
|
||||
http_auth=auth,
|
||||
use_ssl=settings.opensearch_use_ssl,
|
||||
verify_certs=settings.opensearch_verify_certs,
|
||||
ssl_show_warn=False,
|
||||
connection_class=RequestsHttpConnection,
|
||||
timeout=30,
|
||||
max_retries=3,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
|
||||
|
||||
def ensure_index(index: str | None = None) -> None:
|
||||
name = index or settings.opensearch_index_chunks
|
||||
client = get_opensearch()
|
||||
if client.indices.exists(index=name):
|
||||
logger.debug("opensearch.index.exists", index=name)
|
||||
return
|
||||
logger.info("opensearch.index.create", index=name)
|
||||
client.indices.create(index=name, body=INDEX_SETTINGS)
|
||||
|
||||
|
||||
def index_chunks(docs: Iterable[dict[str, Any]], index: str | None = None) -> tuple[int, int]:
|
||||
"""Bulk-upsert chunks. Returns (success, errors)."""
|
||||
name = index or settings.opensearch_index_chunks
|
||||
actions: list[dict[str, Any]] = []
|
||||
for d in docs:
|
||||
actions.append(
|
||||
{
|
||||
"_op_type": "index",
|
||||
"_index": name,
|
||||
"_id": d["chunk_id"],
|
||||
"_source": d,
|
||||
}
|
||||
)
|
||||
if not actions:
|
||||
return 0, 0
|
||||
success, errors = bulk(get_opensearch(), actions, raise_on_error=False, request_timeout=120)
|
||||
if errors:
|
||||
logger.warning("opensearch.bulk.errors", count=len(errors))
|
||||
return success, len(errors) if isinstance(errors, list) else 0
|
||||
|
||||
|
||||
def delete_by_document(document_id: str, index: str | None = None) -> int:
|
||||
name = index or settings.opensearch_index_chunks
|
||||
client = get_opensearch()
|
||||
if not client.indices.exists(index=name):
|
||||
return 0
|
||||
res = client.delete_by_query(
|
||||
index=name,
|
||||
body={"query": {"term": {"document_id": document_id}}},
|
||||
refresh=True,
|
||||
)
|
||||
return int(res.get("deleted", 0))
|
||||
103
app/indexing/qdrant_client.py
Normal file
103
app/indexing/qdrant_client.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Qdrant client + collection bootstrap + chunk upsert."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Any, Sequence
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models as qm
|
||||
|
||||
from app.config import settings
|
||||
from app.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
DENSE_VECTOR_NAME = "dense"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_qdrant() -> QdrantClient:
|
||||
return QdrantClient(
|
||||
host=settings.qdrant_host,
|
||||
port=settings.qdrant_port,
|
||||
api_key=settings.qdrant_api_key or None,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
|
||||
def ensure_collection(collection: str | None = None, dim: int | None = None) -> None:
|
||||
name = collection or settings.qdrant_collection_chunks
|
||||
vector_size = dim or settings.embedding_dim
|
||||
client = get_qdrant()
|
||||
existing = {c.name for c in client.get_collections().collections}
|
||||
if name in existing:
|
||||
logger.debug("qdrant.collection.exists", collection=name)
|
||||
return
|
||||
logger.info("qdrant.collection.create", collection=name, dim=vector_size)
|
||||
client.create_collection(
|
||||
collection_name=name,
|
||||
vectors_config={
|
||||
DENSE_VECTOR_NAME: qm.VectorParams(
|
||||
size=vector_size,
|
||||
distance=qm.Distance.COSINE,
|
||||
)
|
||||
},
|
||||
optimizers_config=qm.OptimizersConfigDiff(default_segment_number=2),
|
||||
)
|
||||
# Payload indexes for filtering.
|
||||
for field in ("document_id", "source_path", "block_type"):
|
||||
client.create_payload_index(
|
||||
collection_name=name,
|
||||
field_name=field,
|
||||
field_schema=qm.PayloadSchemaType.KEYWORD,
|
||||
)
|
||||
client.create_payload_index(
|
||||
collection_name=name,
|
||||
field_name="page_number",
|
||||
field_schema=qm.PayloadSchemaType.INTEGER,
|
||||
)
|
||||
client.create_payload_index(
|
||||
collection_name=name,
|
||||
field_name="ocr_confidence",
|
||||
field_schema=qm.PayloadSchemaType.FLOAT,
|
||||
)
|
||||
|
||||
|
||||
def upsert_chunks(
|
||||
points: Sequence[tuple[str, list[float], dict[str, Any]]],
|
||||
collection: str | None = None,
|
||||
) -> int:
|
||||
"""Upsert (chunk_id, vector, payload) triples. Returns count upserted."""
|
||||
name = collection or settings.qdrant_collection_chunks
|
||||
if not points:
|
||||
return 0
|
||||
qpoints = [
|
||||
qm.PointStruct(
|
||||
id=_qid(chunk_id),
|
||||
vector={DENSE_VECTOR_NAME: vector},
|
||||
payload={**payload, "chunk_id": chunk_id},
|
||||
)
|
||||
for chunk_id, vector, payload in points
|
||||
]
|
||||
get_qdrant().upsert(collection_name=name, points=qpoints, wait=False)
|
||||
return len(qpoints)
|
||||
|
||||
|
||||
def delete_by_document(document_id: str, collection: str | None = None) -> int:
|
||||
name = collection or settings.qdrant_collection_chunks
|
||||
client = get_qdrant()
|
||||
client.delete(
|
||||
collection_name=name,
|
||||
points_selector=qm.FilterSelector(
|
||||
filter=qm.Filter(
|
||||
must=[qm.FieldCondition(key="document_id", match=qm.MatchValue(value=document_id))]
|
||||
)
|
||||
),
|
||||
)
|
||||
return 1
|
||||
|
||||
|
||||
def _qid(chunk_id: str) -> str:
|
||||
"""Qdrant accepts UUID strings or unsigned ints. Chunks are UUIDs already."""
|
||||
return chunk_id
|
||||
75
app/indexing/reranker.py
Normal file
75
app/indexing/reranker.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""BGE reranker - cross-encoder style scoring of (query, passage) pairs.
|
||||
|
||||
Designed to degrade gracefully:
|
||||
- If the model fails to load, ``rerank`` returns inputs unchanged with the
|
||||
``reranked`` flag set to False so the API can report the truth to clients.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Sequence
|
||||
|
||||
from app.config import settings
|
||||
from app.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Reranker:
|
||||
def __init__(self, model_name: str, device: str, batch_size: int) -> None:
|
||||
self.model_name = model_name
|
||||
self.device = device
|
||||
self.batch_size = batch_size
|
||||
self._impl: str | None = None
|
||||
self._model = None
|
||||
self._load()
|
||||
|
||||
def _load(self) -> None:
|
||||
try:
|
||||
from FlagEmbedding import FlagReranker # type: ignore
|
||||
use_fp16 = self.device != "cpu"
|
||||
self._model = FlagReranker(self.model_name, use_fp16=use_fp16, devices=self.device)
|
||||
self._impl = "flagembedding"
|
||||
logger.info("reranker.loaded", impl="flagembedding", model=self.model_name, device=self.device)
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("reranker.flagembedding_failed", error=str(exc))
|
||||
|
||||
try:
|
||||
from sentence_transformers import CrossEncoder
|
||||
self._model = CrossEncoder(self.model_name, device=self.device)
|
||||
self._impl = "sentence-transformers"
|
||||
logger.info("reranker.loaded", impl="sentence-transformers", model=self.model_name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("reranker.disabled", error=str(exc))
|
||||
self._impl = None
|
||||
self._model = None
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
return self._impl is not None and self._model is not None
|
||||
|
||||
def score(self, query: str, passages: Sequence[str]) -> list[float]:
|
||||
if not self.available or not passages:
|
||||
return [0.0] * len(passages)
|
||||
pairs = [(query, p) for p in passages]
|
||||
if self._impl == "flagembedding":
|
||||
scores = self._model.compute_score(pairs, batch_size=self.batch_size, normalize=True) # type: ignore[union-attr]
|
||||
else:
|
||||
scores = self._model.predict(pairs, batch_size=self.batch_size) # type: ignore[union-attr]
|
||||
if not isinstance(scores, list):
|
||||
try:
|
||||
scores = list(scores)
|
||||
except TypeError:
|
||||
scores = [float(scores)]
|
||||
return [float(s) for s in scores]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_reranker() -> Reranker:
|
||||
return Reranker(
|
||||
model_name=settings.reranker_model,
|
||||
device=settings.reranker_device,
|
||||
batch_size=settings.reranker_batch_size,
|
||||
)
|
||||
Reference in New Issue
Block a user