""" ANCE-style hard negative mining for BGE models """ import os import json import logging import threading import queue from typing import List, Dict, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from tqdm import tqdm import faiss # Optional imports for text processing try: import jieba JIEBA_AVAILABLE = True except ImportError: JIEBA_AVAILABLE = False jieba = None try: from rank_bm25 import BM25Okapi BM25_AVAILABLE = True except ImportError: BM25_AVAILABLE = False BM25Okapi = None from collections import defaultdict logger = logging.getLogger(__name__) class ANCEHardNegativeMiner: """ Approximate Nearest Neighbor Negative Contrastive Learning Dynamically mines hard negatives using evolving model Notes: - All parameters (embedding_dim, num_hard_negatives, negative_range, etc.) should be passed from config-driven training scripts. - This class does not load config directly; pass config values from your main script for consistency. """ def __init__( self, embedding_dim: int = 768, num_hard_negatives: int = 15, negative_range: Tuple[int, int] = (10, 100), margin: float = 0.1, index_refresh_interval: int = 1000, index_type: str = "IVF", nlist: int = 100, use_gpu: bool = True, max_index_size: int = 1000000 ): self.embedding_dim = embedding_dim self.num_hard_negatives = num_hard_negatives self.negative_range = negative_range self.margin = margin self.index_refresh_interval = index_refresh_interval self.use_gpu = use_gpu and torch.cuda.is_available() self.max_index_size = max_index_size # Initialize FAISS index self.index = self._create_index(index_type, nlist) self.passage_map: Dict[int, str] = {} # Maps index ID to passage text self.query_map: Dict[str, List[str]] = defaultdict(list) # Maps query to positive passages # For asynchronous index updates self.update_queue = queue.Queue() self.update_thread = None self.stop_update_thread = threading.Event() self.current_step = 0 # Statistics self.stats = { 'total_queries': 0, 'total_negatives_mined': 0, 'index_updates': 0, 'cache_hits': 0, 'cache_misses': 0 } def _create_index(self, index_type: str, nlist: int) -> faiss.Index: """Create FAISS index, warn if falling back to CPU.""" if index_type == "IVF": quantizer = faiss.IndexFlatIP(self.embedding_dim) index = faiss.IndexIVFFlat(quantizer, self.embedding_dim, nlist, faiss.METRIC_INNER_PRODUCT) elif index_type == "HNSW": index = faiss.IndexHNSWFlat(self.embedding_dim, 32, faiss.METRIC_INNER_PRODUCT) else: index = faiss.IndexFlatIP(self.embedding_dim) if self.use_gpu: try: res = faiss.StandardGpuResources() # type: ignore[attr-defined] index = faiss.index_cpu_to_gpu(res, 0, index) # type: ignore[attr-defined] logger.info("Using GPU for FAISS index") except Exception as e: logger.warning(f"Failed to use GPU for FAISS: {e}. Using CPU.") self.use_gpu = False else: logger.warning("FAISS is running on CPU. 
This may be slower than GPU.") return index def _is_chinese(self, text: str) -> bool: """Check if text contains Chinese characters""" for char in text: if '\u4e00' <= char <= '\u9fff': return True return False def start_async_updates(self): """Start background thread for asynchronous index updates""" if self.update_thread is None or not self.update_thread.is_alive(): self.stop_update_thread.clear() self.update_thread = threading.Thread(target=self._update_worker) self.update_thread.daemon = True self.update_thread.start() logger.info("Started asynchronous index update thread") def stop_async_updates(self): """Stop background update thread and ensure clean shutdown.""" if self.update_thread and self.update_thread.is_alive(): self.stop_update_thread.set() self.update_thread.join(timeout=5) logger.info("Stopped asynchronous index update thread") def _update_worker(self): """Background worker for index updates""" while not self.stop_update_thread.is_set(): try: # Get update task with timeout task = self.update_queue.get(timeout=1) if task['type'] == 'add_embeddings': self._add_to_index(task['embeddings'], task['passages']) elif task['type'] == 'rebuild_index': self._rebuild_index(task['model']) self.stats['index_updates'] += 1 except queue.Empty: continue except Exception as e: logger.error(f"Error in update worker: {e}") def _add_to_index(self, embeddings: np.ndarray, passages: List[str]): """Add embeddings to index""" try: # Normalize embeddings for inner product search embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True) # Add to index start_idx = len(self.passage_map) # Debug log for embeddings logger.debug(f"embeddings shape: {getattr(embeddings, 'shape', None)}, type: {type(embeddings)}") if hasattr(self.index, 'is_trained') and not self.index.is_trained: if embeddings is not None and len(embeddings) > 0: if hasattr(self.index, 'nlist') and len(embeddings) >= getattr(self.index, 'nlist', 0): # Safe to access self.index.nlist and call train logger.info("Training FAISS index...") self.index.train(embeddings) # type: ignore[call-arg] elif hasattr(self.index, 'nlist'): logger.warning(f"Not enough data to train index. Need {getattr(self.index, 'nlist', 0)}, got {len(embeddings)}") return else: # For index types without nlist/train, skip training pass else: logger.warning("Embeddings is None or empty, skipping training and addition to index.") return if embeddings is not None and len(embeddings) > 0: self.index.add(embeddings) # type: ignore[call-arg] else: logger.warning("Embeddings is None or empty, skipping addition to index.") # Update passage map for i, passage in enumerate(passages): self.passage_map[start_idx + i] = passage logger.info(f"Added {len(passages)} passages to index. 
Total: {self.index.ntotal}") except Exception as e: logger.error(f"Error adding embeddings to index: {e}") def _rebuild_index(self, model): """Rebuild entire index with current model""" try: logger.info("Rebuilding entire index...") # Clear current index self.index.reset() old_passage_map = self.passage_map.copy() self.passage_map.clear() # Re-encode all passages all_passages = list(set(sum(self.query_map.values(), []))) if not all_passages: logger.warning("No passages to rebuild index with") return embeddings = [] batch_size = 32 for i in tqdm(range(0, len(all_passages), batch_size), desc="Encoding passages"): batch = all_passages[i:i + batch_size] try: with torch.no_grad(): batch_embeddings = model.encode(batch, convert_to_numpy=True, batch_size=len(batch)) embeddings.append(batch_embeddings) except Exception as e: logger.error(f"Error encoding batch {i // batch_size}: {e}") continue if embeddings: embeddings = np.vstack(embeddings) self._add_to_index(embeddings, all_passages) logger.info(f"Rebuilt index with {len(all_passages)} passages") else: # Restore old passage map if rebuild failed self.passage_map = old_passage_map logger.error("Failed to rebuild index - no embeddings generated") except Exception as e: logger.error(f"Error rebuilding index: {e}") def mine_hard_negatives( self, queries: List[str], query_embeddings: np.ndarray, positive_passages: List[List[str]], corpus_embeddings: Optional[np.ndarray] = None, corpus_passages: Optional[List[str]] = None ) -> List[List[str]]: """ Mine hard negatives for queries using ANCE approach Args: queries: List of query texts query_embeddings: Query embeddings [num_queries, embedding_dim] positive_passages: List of positive passages for each query corpus_embeddings: Optional corpus embeddings to add to index corpus_passages: Optional corpus passages Returns: List of hard negative passages for each query """ # Add corpus to index if provided if corpus_embeddings is not None and corpus_passages is not None: self.update_queue.put({ 'type': 'add_embeddings', 'embeddings': corpus_embeddings, 'passages': corpus_passages }) # Update query map for query, positives in zip(queries, positive_passages): self.query_map[query].extend(positives) # Normalize query embeddings query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True) # Search for hard negatives hard_negatives = [] k = min(self.negative_range[1], self.index.ntotal) if k == 0: logger.warning("Index is empty, returning empty hard negatives") return [[] for _ in queries] similarities, indices = self.index.search(query_embeddings, k) # type: ignore[call-arg] for i, (query, positives) in enumerate(zip(queries, positive_passages)): query_hard_negs = [] candidates = [] # Collect candidates within negative range for j in range(self.negative_range[0], min(k, self.negative_range[1])): if indices[i, j] >= 0: # Valid index passage_idx = indices[i, j] if passage_idx in self.passage_map: passage = self.passage_map[passage_idx] score = similarities[i, j] # Skip if it's a positive passage if passage not in positives: candidates.append((passage, score)) # Select hard negatives with margin if candidates: # Sort by score (descending) candidates.sort(key=lambda x: x[1], reverse=True) # Select with margin constraint selected = [] last_score = float('inf') for passage, score in candidates: if last_score - score >= self.margin or len(selected) == 0: selected.append(passage) last_score = score if len(selected) >= self.num_hard_negatives: break query_hard_negs = selected # Pad with 
random negatives if needed if len(query_hard_negs) < self.num_hard_negatives: available = [p for p in self.passage_map.values() if p not in positives and p not in query_hard_negs] if available: additional_count = min(len(available), self.num_hard_negatives - len(query_hard_negs)) additional = np.random.choice(available, size=additional_count, replace=False).tolist() query_hard_negs.extend(additional) hard_negatives.append(query_hard_negs) self.stats['total_queries'] += len(queries) self.stats['total_negatives_mined'] += sum(len(negs) for negs in hard_negatives) return hard_negatives def update_index_if_needed(self, step: int, model=None): """Check if index needs update and trigger if necessary""" self.current_step = step if step > 0 and step % self.index_refresh_interval == 0 and model is not None: logger.info(f"Triggering index rebuild at step {step}") self.update_queue.put({ 'type': 'rebuild_index', 'model': model }) def save_mined_data(self, output_path: str, original_data_path: str): """Save data with mined hard negatives""" logger.info(f"Saving mined data to {output_path}") # Load original data original_data = [] with open(original_data_path, 'r', encoding='utf-8') as f: for line in f: if line.strip(): original_data.append(json.loads(line)) # This is a simplified implementation - in practice you'd batch mine negatives with open(output_path, 'w', encoding='utf-8') as f: for item in tqdm(original_data, desc="Saving mined data"): f.write(json.dumps(item, ensure_ascii=False) + '\n') logger.info(f"Saved {len(original_data)} examples") def get_statistics(self) -> Dict[str, int]: """Get mining statistics""" self.stats['index_size'] = self.index.ntotal self.stats['unique_passages'] = len(self.passage_map) return self.stats.copy() class HardNegativeDataAugmenter: """ Augment training data with various hard negative strategies """ def __init__(self, config: Dict): self.config = config self.strategies = config.get('strategies', ['ance', 'bm25', 'semantic']) self.bm25_index = None self.corpus_passages = None def build_bm25_index(self, corpus: List[str]): """Build BM25 index for hard negative mining""" try: logger.info("Building BM25 index...") self.corpus_passages = corpus # Tokenize corpus tokenized_corpus = [] for doc in tqdm(corpus, desc="Tokenizing corpus"): if self._is_chinese(doc) and JIEBA_AVAILABLE: tokens = list(jieba.cut(doc)) else: tokens = doc.lower().split() tokenized_corpus.append(tokens) if BM25_AVAILABLE: self.bm25_index = BM25Okapi(tokenized_corpus) logger.info(f"Built BM25 index with {len(corpus)} documents") else: logger.warning("BM25Okapi not available, skipping BM25 index") self.bm25_index = None except Exception as e: logger.error(f"Error building BM25 index: {e}") self.bm25_index = None def _is_chinese(self, text: str) -> bool: """Check if text contains Chinese characters""" for char in text: if '\u4e00' <= char <= '\u9fff': return True return False def augment_data( self, data: List[Dict], model=None, corpus: Optional[List[str]] = None ) -> List[Dict]: """Apply multiple hard negative augmentation strategies""" # Build BM25 index if corpus provided if corpus and 'bm25' in self.strategies: self.build_bm25_index(corpus) augmented_data = [] for item in tqdm(data, desc="Augmenting data"): query = item['query'] positives = item.get('pos', []) negatives = item.get('neg', []) # Apply different strategies if 'bm25' in self.strategies and self.bm25_index: bm25_negs = self._mine_bm25_negatives(query, positives) negatives.extend(bm25_negs[:5]) # Add top 5 BM25 negatives if 'semantic' 
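

# The function below is a minimal, illustrative sketch of how the miner is meant to be
# driven from a training script; it is not called anywhere in this module. It assumes a
# SentenceTransformer-style ``model`` exposing ``encode(texts, convert_to_numpy=True)``,
# and the function and argument names are placeholders rather than part of the pipeline API.
def _example_ance_mining(model, queries, positive_passages, corpus_passages):
    """Illustrative-only sketch of a single ANCE mining round."""
    # embedding_dim must match the encoder's output dimension; "Flat" avoids IVF training
    miner = ANCEHardNegativeMiner(embedding_dim=768, num_hard_negatives=15, index_type="Flat")

    # Populate the index synchronously so the search below can see the corpus; in a real
    # training loop the queue-based path (start_async_updates + mine_hard_negatives) is
    # used instead, and the index fills up in the background.
    corpus_embeddings = model.encode(corpus_passages, convert_to_numpy=True)
    miner._add_to_index(corpus_embeddings, corpus_passages)

    query_embeddings = model.encode(queries, convert_to_numpy=True)
    hard_negatives = miner.mine_hard_negatives(queries, query_embeddings, positive_passages)

    logger.info(f"Mining stats: {miner.get_statistics()}")
    return hard_negatives
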

class HardNegativeDataAugmenter:
    """
    Augment training data with various hard negative strategies.
    """

    def __init__(self, config: Dict):
        self.config = config
        self.strategies = config.get('strategies', ['ance', 'bm25', 'semantic'])
        self.bm25_index = None
        self.corpus_passages = None

    def build_bm25_index(self, corpus: List[str]):
        """Build a BM25 index for hard negative mining."""
        try:
            logger.info("Building BM25 index...")
            self.corpus_passages = corpus

            # Tokenize corpus
            tokenized_corpus = []
            for doc in tqdm(corpus, desc="Tokenizing corpus"):
                if self._is_chinese(doc) and JIEBA_AVAILABLE:
                    tokens = list(jieba.cut(doc))
                else:
                    tokens = doc.lower().split()
                tokenized_corpus.append(tokens)

            if BM25_AVAILABLE:
                self.bm25_index = BM25Okapi(tokenized_corpus)
                logger.info(f"Built BM25 index with {len(corpus)} documents")
            else:
                logger.warning("BM25Okapi not available, skipping BM25 index")
                self.bm25_index = None

        except Exception as e:
            logger.error(f"Error building BM25 index: {e}")
            self.bm25_index = None

    def _is_chinese(self, text: str) -> bool:
        """Check whether the text contains Chinese characters."""
        for char in text:
            if '\u4e00' <= char <= '\u9fff':
                return True
        return False

    def augment_data(
        self,
        data: List[Dict],
        model=None,
        corpus: Optional[List[str]] = None
    ) -> List[Dict]:
        """Apply multiple hard negative augmentation strategies."""
        # Build BM25 index if a corpus is provided
        if corpus and 'bm25' in self.strategies:
            self.build_bm25_index(corpus)

        augmented_data = []

        for item in tqdm(data, desc="Augmenting data"):
            query = item['query']
            positives = item.get('pos', [])
            negatives = item.get('neg', [])

            # Apply the configured strategies
            if 'bm25' in self.strategies and self.bm25_index:
                bm25_negs = self._mine_bm25_negatives(query, positives)
                negatives.extend(bm25_negs[:5])  # Add top 5 BM25 negatives

            if 'semantic' in self.strategies and model:
                semantic_negs = self._mine_semantic_negatives(query, positives, model)
                negatives.extend(semantic_negs[:5])

            if 'random_swap' in self.strategies:
                swap_negs = self._create_swap_negatives(positives)
                negatives.extend(swap_negs[:3])

            # Remove duplicates while preserving order
            seen = set()
            unique_negatives = []
            for neg in negatives:
                if neg not in seen and neg not in positives:
                    seen.add(neg)
                    unique_negatives.append(neg)

            augmented_item = item.copy()
            augmented_item['neg'] = unique_negatives[:self.config.get('max_negatives', 30)]
            augmented_data.append(augmented_item)

        return augmented_data

    def _mine_bm25_negatives(self, query: str, positives: List[str]) -> List[str]:
        """Mine hard negatives using BM25."""
        if not self.bm25_index or not self.corpus_passages or not BM25_AVAILABLE:
            return []

        # Tokenize query
        if self._is_chinese(query) and JIEBA_AVAILABLE:
            query_tokens = list(jieba.cut(query))
        else:
            query_tokens = query.lower().split()

        # Get BM25 scores
        scores = self.bm25_index.get_scores(query_tokens)

        # Take the top-scoring passages, excluding positives
        top_indices = np.argsort(scores)[::-1]
        negatives = []

        for idx in top_indices:
            if len(negatives) >= 10:
                break
            passage = self.corpus_passages[idx]
            if passage not in positives:
                negatives.append(passage)

        return negatives

    def _mine_semantic_negatives(self, query: str, positives: List[str], model) -> List[str]:
        """
        Mine semantically similar but irrelevant passages using the model's encode method.

        Returns the top-k most similar passages (excluding positives). Note that the corpus
        is re-encoded for every query; for large corpora, cache the passage embeddings.
        """
        try:
            if not self.corpus_passages or not model:
                logger.warning("No corpus or model provided for semantic negative mining.")
                return []

            batch_size = self.config.get('semantic_batch_size', 64)
            num_negatives = self.config.get('semantic_num_negatives', 10)
            device = self.config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')

            # Encode query
            query_embedding = model.encode([query], convert_to_numpy=True, device=device)

            # Encode all corpus passages in batches
            passage_embeddings = []
            for i in range(0, len(self.corpus_passages), batch_size):
                batch = self.corpus_passages[i:i + batch_size]
                batch_emb = model.encode(batch, convert_to_numpy=True, device=device)
                passage_embeddings.append(batch_emb)
            passage_embeddings = np.vstack(passage_embeddings)

            # Compute similarity
            scores = np.dot(passage_embeddings, query_embedding[0])

            # Exclude positives
            candidates = [(idx, score) for idx, (score, passage)
                          in enumerate(zip(scores, self.corpus_passages))
                          if passage not in positives]

            # Sort and select top-k
            candidates.sort(key=lambda x: x[1], reverse=True)
            top_indices = [idx for idx, _ in candidates[:num_negatives]]
            negatives = [self.corpus_passages[idx] for idx in top_indices]

            logger.info(f"Mined {len(negatives)} semantic negatives for query: {query}")
            return negatives

        except Exception as e:
            logger.error(f"Error mining semantic negatives: {e}")
            return []

    def _create_swap_negatives(self, positives: List[str]) -> List[str]:
        """Create negatives by swapping words within positive passages."""
        negatives = []

        for pos in positives[:2]:  # Only use the first 2 positives
            words = pos.split()
            if len(words) > 5:
                # Swap a few random word pairs
                swapped = words.copy()
                for _ in range(min(3, len(words) // 3)):
                    i, j = random.sample(range(len(words)), 2)
                    swapped[i], swapped[j] = swapped[j], swapped[i]
                negatives.append(' '.join(swapped))

        return negatives
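

# Minimal smoke-test sketch, run only when this file is executed directly. The tiny corpus
# and training item below are made-up illustrations; the ANCE and semantic strategies are
# skipped here because they require a trained encoder.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    example_corpus = [
        "BGE models are trained with contrastive learning.",
        "Hard negatives make the contrastive objective more informative.",
        "FAISS provides approximate nearest neighbor search.",
        "BM25 is a classical lexical retrieval baseline.",
    ]
    example_data = [
        {
            "query": "how are BGE embedding models trained",
            "pos": ["BGE models are trained with contrastive learning."],
            "neg": [],
        }
    ]

    augmenter = HardNegativeDataAugmenter({"strategies": ["bm25", "random_swap"], "max_negatives": 10})
    augmented = augmenter.augment_data(example_data, corpus=example_corpus)
    print(json.dumps(augmented, ensure_ascii=False, indent=2))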