"""
|
|
ANCE-style hard negative mining for BGE models
|
|
"""
|
|
import os
|
|
import json
|
|
import logging
|
|
import threading
|
|
import queue
|
|
from typing import List, Dict, Optional, Tuple, Union
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from tqdm import tqdm
|
|
import faiss
|
|
|
|
# Optional imports for text processing
|
|
try:
|
|
import jieba
|
|
JIEBA_AVAILABLE = True
|
|
except ImportError:
|
|
JIEBA_AVAILABLE = False
|
|
jieba = None
|
|
|
|
try:
|
|
from rank_bm25 import BM25Okapi
|
|
BM25_AVAILABLE = True
|
|
except ImportError:
|
|
BM25_AVAILABLE = False
|
|
BM25Okapi = None
|
|
|
|
from collections import defaultdict
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ANCEHardNegativeMiner:
    """
    Approximate Nearest Neighbor Negative Contrastive Learning (ANCE):
    dynamically mines hard negatives with the evolving model.

    Notes:
    - All parameters (embedding_dim, num_hard_negatives, negative_range, etc.)
      should be passed in from config-driven training scripts.
    - This class does not load config directly; pass config values from your
      main script for consistency.
    - See the integration sketch after this class for the intended call sequence.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        num_hard_negatives: int = 15,
        negative_range: Tuple[int, int] = (10, 100),
        margin: float = 0.1,
        index_refresh_interval: int = 1000,
        index_type: str = "IVF",
        nlist: int = 100,
        use_gpu: bool = True,
        max_index_size: int = 1000000
    ):
        self.embedding_dim = embedding_dim
        self.num_hard_negatives = num_hard_negatives
        self.negative_range = negative_range
        self.margin = margin
        self.index_refresh_interval = index_refresh_interval
        self.use_gpu = use_gpu and torch.cuda.is_available()
        self.max_index_size = max_index_size

        # Initialize FAISS index
        self.index = self._create_index(index_type, nlist)
        self.passage_map: Dict[int, str] = {}  # Maps index ID to passage text
        self.query_map: Dict[str, List[str]] = defaultdict(list)  # Maps query to positive passages

        # For asynchronous index updates
        self.update_queue = queue.Queue()
        self.update_thread = None
        self.stop_update_thread = threading.Event()
        self.current_step = 0

        # Statistics
        self.stats = {
            'total_queries': 0,
            'total_negatives_mined': 0,
            'index_updates': 0,
            'cache_hits': 0,
            'cache_misses': 0
        }

    def _create_index(self, index_type: str, nlist: int) -> faiss.Index:
        """Create a FAISS index; warn if falling back to CPU."""
        if index_type == "IVF":
            quantizer = faiss.IndexFlatIP(self.embedding_dim)
            index = faiss.IndexIVFFlat(quantizer, self.embedding_dim, nlist, faiss.METRIC_INNER_PRODUCT)
        elif index_type == "HNSW":
            index = faiss.IndexHNSWFlat(self.embedding_dim, 32, faiss.METRIC_INNER_PRODUCT)
        else:
            index = faiss.IndexFlatIP(self.embedding_dim)

        if self.use_gpu:
            try:
                res = faiss.StandardGpuResources()  # type: ignore[attr-defined]
                index = faiss.index_cpu_to_gpu(res, 0, index)  # type: ignore[attr-defined]
                logger.info("Using GPU for FAISS index")
            except Exception as e:
                logger.warning(f"Failed to use GPU for FAISS: {e}. Using CPU.")
                self.use_gpu = False
        else:
            logger.warning("FAISS is running on CPU. This may be slower than GPU.")
        return index

    def _is_chinese(self, text: str) -> bool:
        """Check if text contains Chinese characters."""
        return any('\u4e00' <= char <= '\u9fff' for char in text)

    def start_async_updates(self):
        """Start background thread for asynchronous index updates."""
        if self.update_thread is None or not self.update_thread.is_alive():
            self.stop_update_thread.clear()
            self.update_thread = threading.Thread(target=self._update_worker, daemon=True)
            self.update_thread.start()
            logger.info("Started asynchronous index update thread")

    def stop_async_updates(self):
        """Stop background update thread and ensure a clean shutdown."""
        if self.update_thread and self.update_thread.is_alive():
            self.stop_update_thread.set()
            self.update_thread.join(timeout=5)
            if self.update_thread.is_alive():
                logger.warning("Index update thread did not stop within 5 seconds")
            logger.info("Stopped asynchronous index update thread")

    def _update_worker(self):
        """Background worker for index updates."""
        while not self.stop_update_thread.is_set():
            try:
                # Get an update task with a timeout so the stop event is re-checked
                task = self.update_queue.get(timeout=1)

                if task['type'] == 'add_embeddings':
                    self._add_to_index(task['embeddings'], task['passages'])
                elif task['type'] == 'rebuild_index':
                    self._rebuild_index(task['model'])

                self.stats['index_updates'] += 1

            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Error in update worker: {e}")

    def _add_to_index(self, embeddings: np.ndarray, passages: List[str]):
        """Add embeddings and their passages to the FAISS index."""
        try:
            if embeddings is None or len(embeddings) == 0:
                logger.warning("Embeddings are None or empty, skipping addition to index.")
                return

            # FAISS expects contiguous float32; normalize for inner-product search
            embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / np.clip(norms, 1e-12, None)

            logger.debug(f"embeddings shape: {embeddings.shape}, type: {type(embeddings)}")

            # IVF-style indexes must be trained before the first add;
            # flat/HNSW indexes report is_trained=True and skip this block.
            if hasattr(self.index, 'is_trained') and not self.index.is_trained:
                nlist = getattr(self.index, 'nlist', 0)
                if nlist and len(embeddings) < nlist:
                    logger.warning(f"Not enough data to train index. Need {nlist}, got {len(embeddings)}")
                    return
                logger.info("Training FAISS index...")
                self.index.train(embeddings)  # type: ignore[call-arg]

            # Add vectors, then map their index IDs to passage texts so the two
            # stay aligned (index ID i <-> passage_map[i]).
            start_idx = len(self.passage_map)
            self.index.add(embeddings)  # type: ignore[call-arg]
            for i, passage in enumerate(passages):
                self.passage_map[start_idx + i] = passage

            logger.info(f"Added {len(passages)} passages to index. Total: {self.index.ntotal}")

        except Exception as e:
            logger.error(f"Error adding embeddings to index: {e}")

    def _rebuild_index(self, model):
        """Rebuild the entire index by re-encoding known passages with the current model."""
        try:
            logger.info("Rebuilding entire index...")

            # Clear the current index but keep the old mapping so it can be restored on failure
            self.index.reset()
            old_passage_map = self.passage_map.copy()
            self.passage_map.clear()

            # Re-encode every passage we know about: previous index contents plus
            # all positives seen so far
            all_passages = list(set(old_passage_map.values()) |
                                {p for positives in self.query_map.values() for p in positives})
            if not all_passages:
                logger.warning("No passages to rebuild index with")
                return

            embeddings = []
            encoded_passages = []
            batch_size = 32

            for i in tqdm(range(0, len(all_passages), batch_size), desc="Encoding passages"):
                batch = all_passages[i:i + batch_size]
                try:
                    with torch.no_grad():
                        batch_embeddings = model.encode(batch, convert_to_numpy=True, batch_size=len(batch))
                    embeddings.append(batch_embeddings)
                    encoded_passages.extend(batch)
                except Exception as e:
                    logger.error(f"Error encoding batch {i // batch_size}: {e}")
                    continue

            if embeddings:
                embeddings = np.vstack(embeddings)
                self._add_to_index(embeddings, encoded_passages)
                logger.info(f"Rebuilt index with {len(encoded_passages)} passages")
            else:
                # Restore the old passage map if the rebuild failed
                self.passage_map = old_passage_map
                logger.error("Failed to rebuild index - no embeddings generated")

        except Exception as e:
            logger.error(f"Error rebuilding index: {e}")

    def mine_hard_negatives(
        self,
        queries: List[str],
        query_embeddings: np.ndarray,
        positive_passages: List[List[str]],
        corpus_embeddings: Optional[np.ndarray] = None,
        corpus_passages: Optional[List[str]] = None
    ) -> List[List[str]]:
        """
        Mine hard negatives for queries using the ANCE approach.

        Args:
            queries: List of query texts
            query_embeddings: Query embeddings [num_queries, embedding_dim]
            positive_passages: List of positive passages for each query
            corpus_embeddings: Optional corpus embeddings to add to the index
            corpus_passages: Optional corpus passages

        Returns:
            List of hard negative passages for each query
            (see the example call sketched after this method)
        """
        # Queue the corpus for indexing if provided. The queue is consumed by the
        # background thread, so start_async_updates() must have been called for
        # these passages to become searchable.
        if corpus_embeddings is not None and corpus_passages is not None:
            self.update_queue.put({
                'type': 'add_embeddings',
                'embeddings': corpus_embeddings,
                'passages': corpus_passages
            })

        # Update query map
        for query, positives in zip(queries, positive_passages):
            self.query_map[query].extend(positives)

        # Normalize query embeddings (FAISS expects contiguous float32)
        query_embeddings = np.ascontiguousarray(query_embeddings, dtype=np.float32)
        query_embeddings = query_embeddings / np.clip(
            np.linalg.norm(query_embeddings, axis=1, keepdims=True), 1e-12, None)

        # Search for hard negatives
        hard_negatives = []
        k = min(self.negative_range[1], self.index.ntotal)

        if k == 0:
            logger.warning("Index is empty, returning empty hard negatives")
            return [[] for _ in queries]

        similarities, indices = self.index.search(query_embeddings, k)  # type: ignore[call-arg]

        for i, (query, positives) in enumerate(zip(queries, positive_passages)):
            candidates = []

            # Collect candidates within the negative rank range
            for j in range(self.negative_range[0], k):
                if indices[i, j] >= 0:  # Valid index
                    passage_idx = int(indices[i, j])
                    if passage_idx in self.passage_map:
                        passage = self.passage_map[passage_idx]
                        score = similarities[i, j]

                        # Skip if it's a positive passage
                        if passage not in positives:
                            candidates.append((passage, score))

            # Select hard negatives with a margin between consecutive picks, e.g. with
            # margin=0.1 the scores [0.92, 0.91, 0.80, 0.75] yield [0.92, 0.80]
            query_hard_negs = []
            if candidates:
                # Sort by score (descending)
                candidates.sort(key=lambda x: x[1], reverse=True)

                selected = []
                last_score = float('inf')
                for passage, score in candidates:
                    if last_score - score >= self.margin or len(selected) == 0:
                        selected.append(passage)
                        last_score = score
                    if len(selected) >= self.num_hard_negatives:
                        break

                query_hard_negs = selected

            # Pad with random negatives if needed
            if len(query_hard_negs) < self.num_hard_negatives:
                available = [p for p in self.passage_map.values()
                             if p not in positives and p not in query_hard_negs]
                if available:
                    additional_count = min(len(available), self.num_hard_negatives - len(query_hard_negs))
                    additional = np.random.choice(available, size=additional_count, replace=False).tolist()
                    query_hard_negs.extend(additional)

            hard_negatives.append(query_hard_negs)

        self.stats['total_queries'] += len(queries)
        self.stats['total_negatives_mined'] += sum(len(negs) for negs in hard_negatives)

        return hard_negatives

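    # Example call (illustrative shapes only; `model` is an assumed external encoder
    # such as a sentence-transformers/FlagEmbedding model, not defined in this module):
    #     q_emb = model.encode(queries, convert_to_numpy=True)  # [num_queries, embedding_dim]
    #     negs = miner.mine_hard_negatives(queries, q_emb, positives)
    #     # negs[i] is a list of up to `num_hard_negatives` passages for queries[i]
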
    def update_index_if_needed(self, step: int, model=None):
        """Check if the index needs a refresh and queue a rebuild if so."""
        self.current_step = step

        if step > 0 and step % self.index_refresh_interval == 0 and model is not None:
            logger.info(f"Triggering index rebuild at step {step}")
            self.update_queue.put({
                'type': 'rebuild_index',
                'model': model
            })

    def save_mined_data(self, output_path: str, original_data_path: str):
        """Save data with mined hard negatives (JSONL in, JSONL out)."""
        logger.info(f"Saving mined data to {output_path}")

        # Load original data
        original_data = []
        with open(original_data_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    original_data.append(json.loads(line))

        # Simplified implementation: items are written back unchanged; in practice
        # you would mine negatives in batches and attach them to each item here.
        with open(output_path, 'w', encoding='utf-8') as f:
            for item in tqdm(original_data, desc="Saving mined data"):
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

        logger.info(f"Saved {len(original_data)} examples")

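    # Assumed JSONL record shape (one JSON object per line), matching the
    # 'query'/'pos'/'neg' keys used elsewhere in this module:
    #     {"query": "...", "pos": ["relevant passage"], "neg": ["hard negative", ...]}
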
    def get_statistics(self) -> Dict[str, int]:
        """Get mining statistics."""
        self.stats['index_size'] = self.index.ntotal
        self.stats['unique_passages'] = len(self.passage_map)
        return self.stats.copy()


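# Integration sketch (illustrative only): how the miner is typically wired into a
# config-driven training loop. `model`, `train_batches`, and `cfg` are assumed
# placeholders from the caller's training script, not names defined in this module.
#
#     miner = ANCEHardNegativeMiner(
#         embedding_dim=cfg["embedding_dim"],
#         num_hard_negatives=cfg["num_hard_negatives"],
#         negative_range=tuple(cfg["negative_range"]),
#         index_refresh_interval=cfg["index_refresh_interval"],
#     )
#     miner.start_async_updates()
#     for step, batch in enumerate(train_batches):
#         hard_negs = miner.mine_hard_negatives(
#             batch["queries"], batch["query_embeddings"], batch["positives"],
#             corpus_embeddings=batch.get("corpus_embeddings"),
#             corpus_passages=batch.get("corpus_passages"),
#         )
#         miner.update_index_if_needed(step, model=model)
#     miner.stop_async_updates()

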
class HardNegativeDataAugmenter:
    """
    Augment training data with various hard negative strategies.
    """

    def __init__(self, config: Dict):
        self.config = config
        self.strategies = config.get('strategies', ['ance', 'bm25', 'semantic'])
        self.bm25_index = None
        self.corpus_passages = None

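    # Illustrative config (assumed shape; the keys mirror the config.get(...) lookups below):
    #     {
    #         "strategies": ["bm25", "semantic", "random_swap"],
    #         "max_negatives": 30,
    #         "semantic_batch_size": 64,
    #         "semantic_num_negatives": 10,
    #         "device": "cuda",
    #     }
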
    def build_bm25_index(self, corpus: List[str]):
        """Build a BM25 index over the corpus for lexical hard negative mining."""
        # Keep the corpus even if BM25 is unavailable; semantic mining reuses it
        self.corpus_passages = corpus

        if not BM25_AVAILABLE:
            logger.warning("BM25Okapi not available, skipping BM25 index")
            self.bm25_index = None
            return

        try:
            logger.info("Building BM25 index...")

            # Tokenize corpus (jieba for Chinese text, whitespace otherwise)
            tokenized_corpus = []
            for doc in tqdm(corpus, desc="Tokenizing corpus"):
                if self._is_chinese(doc) and JIEBA_AVAILABLE:
                    tokens = list(jieba.cut(doc))
                else:
                    tokens = doc.lower().split()
                tokenized_corpus.append(tokens)

            self.bm25_index = BM25Okapi(tokenized_corpus)
            logger.info(f"Built BM25 index with {len(corpus)} documents")

        except Exception as e:
            logger.error(f"Error building BM25 index: {e}")
            self.bm25_index = None

    def _is_chinese(self, text: str) -> bool:
        """Check if text contains Chinese characters."""
        return any('\u4e00' <= char <= '\u9fff' for char in text)

    def augment_data(
        self,
        data: List[Dict],
        model=None,
        corpus: Optional[List[str]] = None
    ) -> List[Dict]:
        """Apply multiple hard negative augmentation strategies to the training data."""

        # Keep the corpus for semantic mining; build a BM25 index if requested
        if corpus:
            self.corpus_passages = corpus
            if 'bm25' in self.strategies:
                self.build_bm25_index(corpus)

        augmented_data = []

        for item in tqdm(data, desc="Augmenting data"):
            query = item['query']
            positives = item.get('pos', [])
            # Copy so the caller's item is not mutated in place
            negatives = list(item.get('neg', []))

            # Apply different strategies
            if 'bm25' in self.strategies and self.bm25_index:
                bm25_negs = self._mine_bm25_negatives(query, positives)
                negatives.extend(bm25_negs[:5])  # Add top 5 BM25 negatives

            if 'semantic' in self.strategies and model:
                semantic_negs = self._mine_semantic_negatives(query, positives, model)
                negatives.extend(semantic_negs[:5])

            if 'random_swap' in self.strategies:
                swap_negs = self._create_swap_negatives(positives)
                negatives.extend(swap_negs[:3])

            # Remove duplicates while preserving order, and drop any positives
            seen = set()
            unique_negatives = []
            for neg in negatives:
                if neg not in seen and neg not in positives:
                    seen.add(neg)
                    unique_negatives.append(neg)

            augmented_item = item.copy()
            augmented_item['neg'] = unique_negatives[:self.config.get('max_negatives', 30)]
            augmented_data.append(augmented_item)

        return augmented_data

    def _mine_bm25_negatives(self, query: str, positives: List[str]) -> List[str]:
        """Mine hard negatives using BM25 lexical similarity."""
        if not self.bm25_index or not self.corpus_passages or not BM25_AVAILABLE:
            return []

        # Tokenize query
        if self._is_chinese(query) and JIEBA_AVAILABLE:
            query_tokens = list(jieba.cut(query))
        else:
            query_tokens = query.lower().split()

        # Get BM25 scores
        scores = self.bm25_index.get_scores(query_tokens)

        # Take the highest-scoring passages, excluding positives
        top_indices = np.argsort(scores)[::-1]
        negatives = []

        for idx in top_indices:
            if len(negatives) >= 10:
                break
            passage = self.corpus_passages[idx]
            if passage not in positives:
                negatives.append(passage)

        return negatives

    def _mine_semantic_negatives(self, query: str, positives: List[str], model) -> List[str]:
        """
        Mine semantically similar but irrelevant passages using the model's encode method.
        Returns the top-k most similar passages (excluding positives).
        """
        try:
            if not self.corpus_passages or not model:
                logger.warning("No corpus or model provided for semantic negative mining.")
                return []

            batch_size = self.config.get('semantic_batch_size', 64)
            num_negatives = self.config.get('semantic_num_negatives', 10)
            device = self.config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')

            # Encode query
            query_embedding = model.encode([query], convert_to_numpy=True, device=device)

            # Encode all corpus passages in batches.
            # NOTE: the corpus is re-encoded for every query here; for large corpora,
            # cache passage embeddings outside this method.
            passage_embeddings = []
            for i in range(0, len(self.corpus_passages), batch_size):
                batch = self.corpus_passages[i:i + batch_size]
                batch_emb = model.encode(batch, convert_to_numpy=True, device=device)
                passage_embeddings.append(batch_emb)
            passage_embeddings = np.vstack(passage_embeddings)

            # Compute dot-product similarity between each passage and the query
            scores = np.dot(passage_embeddings, query_embedding[0])

            # Exclude positives, then pick the top-k most similar passages
            candidates = [
                (idx, score)
                for idx, (score, passage) in enumerate(zip(scores, self.corpus_passages))
                if passage not in positives
            ]
            candidates.sort(key=lambda x: x[1], reverse=True)
            negatives = [self.corpus_passages[idx] for idx, _ in candidates[:num_negatives]]

            logger.info(f"Mined {len(negatives)} semantic negatives for query: {query}")
            return negatives

        except Exception as e:
            logger.error(f"Error mining semantic negatives: {e}")
            return []

    def _create_swap_negatives(self, positives: List[str]) -> List[str]:
        """Create negatives by swapping words within positive passages."""
        negatives = []

        for pos in positives[:2]:  # Only use the first 2 positives
            words = pos.split()
            if len(words) > 5:
                # Swap a few random word pairs to corrupt the passage
                swapped = words.copy()
                for _ in range(min(3, len(words) // 3)):
                    i, j = random.sample(range(len(words)), 2)
                    swapped[i], swapped[j] = swapped[j], swapped[i]
                negatives.append(' '.join(swapped))

        return negatives

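
# Minimal smoke-test sketch (illustrative only): exercises both classes with random
# embeddings and a flat index, so no trained model, GPU, or BM25/jieba install is
# required (faiss and numpy must be installed). The corpus, queries, and dimensions
# below are made up for the demo.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    rng = np.random.default_rng(0)
    dim = 8

    miner = ANCEHardNegativeMiner(
        embedding_dim=dim,
        num_hard_negatives=2,
        negative_range=(0, 5),
        index_type="Flat",
        use_gpu=False,
    )
    corpus = [f"passage number {i}" for i in range(20)]
    corpus_emb = rng.normal(size=(20, dim)).astype(np.float32)
    # Add synchronously for the demo instead of going through the async queue
    miner._add_to_index(corpus_emb, corpus)

    queries = ["query a", "query b"]
    q_emb = rng.normal(size=(2, dim)).astype(np.float32)
    negs = miner.mine_hard_negatives(queries, q_emb, [[corpus[0]], [corpus[1]]])
    print("miner stats:", miner.get_statistics())
    print("negatives per query:", [len(n) for n in negs])

    augmenter = HardNegativeDataAugmenter({"strategies": ["random_swap"], "max_negatives": 5})
    data = [{"query": "query a", "pos": ["one two three four five six seven"], "neg": []}]
    print("augmented:", augmenter.augment_data(data))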