# bge_finetune/data/hard_negative_mining.py
"""
ANCE-style hard negative mining for BGE models
"""
import os
import json
import logging
import threading
import queue
from typing import List, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import faiss
# Optional imports for text processing
try:
import jieba
JIEBA_AVAILABLE = True
except ImportError:
JIEBA_AVAILABLE = False
jieba = None
try:
from rank_bm25 import BM25Okapi
BM25_AVAILABLE = True
except ImportError:
BM25_AVAILABLE = False
BM25Okapi = None
from collections import defaultdict
logger = logging.getLogger(__name__)
class ANCEHardNegativeMiner:
"""
    ANCE-style miner (Approximate Nearest Neighbor Negative Contrastive Learning).
    Dynamically mines hard negatives with the evolving model during training; see the
    illustrative usage sketch after this class.
Notes:
- All parameters (embedding_dim, num_hard_negatives, negative_range, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
embedding_dim: int = 768,
num_hard_negatives: int = 15,
negative_range: Tuple[int, int] = (10, 100),
margin: float = 0.1,
index_refresh_interval: int = 1000,
index_type: str = "IVF",
nlist: int = 100,
use_gpu: bool = True,
max_index_size: int = 1000000
):
self.embedding_dim = embedding_dim
self.num_hard_negatives = num_hard_negatives
self.negative_range = negative_range
self.margin = margin
self.index_refresh_interval = index_refresh_interval
self.use_gpu = use_gpu and torch.cuda.is_available()
self.max_index_size = max_index_size
# Initialize FAISS index
self.index = self._create_index(index_type, nlist)
self.passage_map: Dict[int, str] = {} # Maps index ID to passage text
self.query_map: Dict[str, List[str]] = defaultdict(list) # Maps query to positive passages
# For asynchronous index updates
self.update_queue = queue.Queue()
self.update_thread = None
self.stop_update_thread = threading.Event()
self.current_step = 0
# Statistics
self.stats = {
'total_queries': 0,
'total_negatives_mined': 0,
'index_updates': 0,
'cache_hits': 0,
'cache_misses': 0
}
def _create_index(self, index_type: str, nlist: int) -> faiss.Index:
"""Create FAISS index, warn if falling back to CPU."""
if index_type == "IVF":
quantizer = faiss.IndexFlatIP(self.embedding_dim)
index = faiss.IndexIVFFlat(quantizer, self.embedding_dim, nlist, faiss.METRIC_INNER_PRODUCT)
elif index_type == "HNSW":
index = faiss.IndexHNSWFlat(self.embedding_dim, 32, faiss.METRIC_INNER_PRODUCT)
else:
index = faiss.IndexFlatIP(self.embedding_dim)
if self.use_gpu:
try:
res = faiss.StandardGpuResources() # type: ignore[attr-defined]
index = faiss.index_cpu_to_gpu(res, 0, index) # type: ignore[attr-defined]
logger.info("Using GPU for FAISS index")
except Exception as e:
logger.warning(f"Failed to use GPU for FAISS: {e}. Using CPU.")
self.use_gpu = False
else:
logger.warning("FAISS is running on CPU. This may be slower than GPU.")
return index
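    # Note on index types: "IVF" must be trained on at least `nlist` vectors before
    # anything can be added (handled in _add_to_index below); "HNSW" and the flat
    # fallback need no training. All variants use inner-product similarity, so
    # embeddings must be L2-normalized for scores to behave like cosine similarity.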
def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters"""
for char in text:
if '\u4e00' <= char <= '\u9fff':
return True
return False
def start_async_updates(self):
"""Start background thread for asynchronous index updates"""
if self.update_thread is None or not self.update_thread.is_alive():
self.stop_update_thread.clear()
self.update_thread = threading.Thread(target=self._update_worker)
self.update_thread.daemon = True
self.update_thread.start()
logger.info("Started asynchronous index update thread")
def stop_async_updates(self):
"""Stop background update thread and ensure clean shutdown."""
if self.update_thread and self.update_thread.is_alive():
self.stop_update_thread.set()
self.update_thread.join(timeout=5)
logger.info("Stopped asynchronous index update thread")
def _update_worker(self):
"""Background worker for index updates"""
while not self.stop_update_thread.is_set():
try:
# Get update task with timeout
task = self.update_queue.get(timeout=1)
if task['type'] == 'add_embeddings':
self._add_to_index(task['embeddings'], task['passages'])
elif task['type'] == 'rebuild_index':
self._rebuild_index(task['model'])
self.stats['index_updates'] += 1
except queue.Empty:
continue
except Exception as e:
logger.error(f"Error in update worker: {e}")
def _add_to_index(self, embeddings: np.ndarray, passages: List[str]):
"""Add embeddings to index"""
        try:
            if embeddings is None or len(embeddings) == 0:
                logger.warning("Embeddings are None or empty; skipping addition to index.")
                return
            # FAISS expects contiguous float32 arrays; normalize for inner-product (cosine) search
            embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
            embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
            start_idx = len(self.passage_map)
            logger.debug(f"embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
            # IVF-style indexes must be trained before vectors can be added
            if hasattr(self.index, 'is_trained') and not self.index.is_trained:
                if hasattr(self.index, 'nlist'):
                    nlist = getattr(self.index, 'nlist', 0)
                    if len(embeddings) < nlist:
                        logger.warning(f"Not enough data to train index. Need {nlist}, got {len(embeddings)}")
                        return
                    logger.info("Training FAISS index...")
                    self.index.train(embeddings)  # type: ignore[call-arg]
                # Index types without nlist need no explicit training here
            self.index.add(embeddings)  # type: ignore[call-arg]
# Update passage map
for i, passage in enumerate(passages):
self.passage_map[start_idx + i] = passage
logger.info(f"Added {len(passages)} passages to index. Total: {self.index.ntotal}")
except Exception as e:
logger.error(f"Error adding embeddings to index: {e}")
def _rebuild_index(self, model):
"""Rebuild entire index with current model"""
try:
logger.info("Rebuilding entire index...")
# Clear current index
self.index.reset()
old_passage_map = self.passage_map.copy()
self.passage_map.clear()
# Re-encode all passages
            all_passages = list({p for passages in self.query_map.values() for p in passages})
if not all_passages:
logger.warning("No passages to rebuild index with")
return
embeddings = []
batch_size = 32
for i in tqdm(range(0, len(all_passages), batch_size), desc="Encoding passages"):
batch = all_passages[i:i + batch_size]
try:
with torch.no_grad():
batch_embeddings = model.encode(batch, convert_to_numpy=True, batch_size=len(batch))
embeddings.append(batch_embeddings)
except Exception as e:
logger.error(f"Error encoding batch {i // batch_size}: {e}")
continue
if embeddings:
embeddings = np.vstack(embeddings)
self._add_to_index(embeddings, all_passages)
logger.info(f"Rebuilt index with {len(all_passages)} passages")
else:
# Restore old passage map if rebuild failed
self.passage_map = old_passage_map
logger.error("Failed to rebuild index - no embeddings generated")
except Exception as e:
logger.error(f"Error rebuilding index: {e}")
def mine_hard_negatives(
self,
queries: List[str],
query_embeddings: np.ndarray,
positive_passages: List[List[str]],
corpus_embeddings: Optional[np.ndarray] = None,
corpus_passages: Optional[List[str]] = None
) -> List[List[str]]:
"""
Mine hard negatives for queries using ANCE approach
Args:
queries: List of query texts
query_embeddings: Query embeddings [num_queries, embedding_dim]
positive_passages: List of positive passages for each query
corpus_embeddings: Optional corpus embeddings to add to index
corpus_passages: Optional corpus passages
Returns:
List of hard negative passages for each query
"""
        # Add corpus to index if provided
        if corpus_embeddings is not None and corpus_passages is not None:
            if self.update_thread is not None and self.update_thread.is_alive():
                # Background worker will pick this up (the search below may not yet see it)
                self.update_queue.put({
                    'type': 'add_embeddings',
                    'embeddings': corpus_embeddings,
                    'passages': corpus_passages
                })
            else:
                # No update thread running: add synchronously so the search below can use the corpus
                self._add_to_index(corpus_embeddings, corpus_passages)
# Update query map
for query, positives in zip(queries, positive_passages):
self.query_map[query].extend(positives)
        # Normalize query embeddings (FAISS expects contiguous float32)
        query_embeddings = np.ascontiguousarray(query_embeddings, dtype=np.float32)
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
# Search for hard negatives
hard_negatives = []
k = min(self.negative_range[1], self.index.ntotal)
if k == 0:
logger.warning("Index is empty, returning empty hard negatives")
return [[] for _ in queries]
similarities, indices = self.index.search(query_embeddings, k) # type: ignore[call-arg]
for i, (query, positives) in enumerate(zip(queries, positive_passages)):
query_hard_negs = []
candidates = []
# Collect candidates within negative range
for j in range(self.negative_range[0], min(k, self.negative_range[1])):
if indices[i, j] >= 0: # Valid index
passage_idx = indices[i, j]
if passage_idx in self.passage_map:
passage = self.passage_map[passage_idx]
score = similarities[i, j]
# Skip if it's a positive passage
if passage not in positives:
candidates.append((passage, score))
# Select hard negatives with margin
if candidates:
# Sort by score (descending)
candidates.sort(key=lambda x: x[1], reverse=True)
# Select with margin constraint
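                # e.g. with margin=0.1 and sorted scores [0.92, 0.90, 0.78, 0.75, 0.60],
                # the kept candidates are those scored 0.92, 0.78 and 0.60 (each at least
                # `margin` below the previously kept score), spreading the selected
                # negatives across the similarity range instead of taking near-duplicates.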
selected = []
last_score = float('inf')
for passage, score in candidates:
if last_score - score >= self.margin or len(selected) == 0:
selected.append(passage)
last_score = score
if len(selected) >= self.num_hard_negatives:
break
query_hard_negs = selected
# Pad with random negatives if needed
if len(query_hard_negs) < self.num_hard_negatives:
available = [p for p in self.passage_map.values()
if p not in positives and p not in query_hard_negs]
if available:
additional_count = min(len(available), self.num_hard_negatives - len(query_hard_negs))
additional = np.random.choice(available, size=additional_count, replace=False).tolist()
query_hard_negs.extend(additional)
hard_negatives.append(query_hard_negs)
self.stats['total_queries'] += len(queries)
self.stats['total_negatives_mined'] += sum(len(negs) for negs in hard_negatives)
return hard_negatives
def update_index_if_needed(self, step: int, model=None):
"""Check if index needs update and trigger if necessary"""
self.current_step = step
if step > 0 and step % self.index_refresh_interval == 0 and model is not None:
logger.info(f"Triggering index rebuild at step {step}")
self.update_queue.put({
'type': 'rebuild_index',
'model': model
})
def save_mined_data(self, output_path: str, original_data_path: str):
"""Save data with mined hard negatives"""
logger.info(f"Saving mined data to {output_path}")
# Load original data
original_data = []
with open(original_data_path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
original_data.append(json.loads(line))
# This is a simplified implementation - in practice you'd batch mine negatives
with open(output_path, 'w', encoding='utf-8') as f:
for item in tqdm(original_data, desc="Saving mined data"):
f.write(json.dumps(item, ensure_ascii=False) + '\n')
logger.info(f"Saved {len(original_data)} examples")
def get_statistics(self) -> Dict[str, int]:
"""Get mining statistics"""
self.stats['index_size'] = self.index.ntotal
self.stats['unique_passages'] = len(self.passage_map)
return self.stats.copy()
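# ---------------------------------------------------------------------------
# Illustrative usage sketch (not referenced elsewhere in this module): a minimal,
# synchronous way to drive ANCEHardNegativeMiner from a training script. It assumes
# `model` exposes a SentenceTransformer-style encode(texts, convert_to_numpy=True)
# method; adapt names, batching and index settings to your own setup.
def _example_ance_mining(model, queries, positive_passages, corpus_passages):
    # A flat index avoids the IVF training requirement (>= nlist vectors) for small corpora
    miner = ANCEHardNegativeMiner(embedding_dim=768, num_hard_negatives=15, index_type="Flat")
    corpus_embeddings = model.encode(corpus_passages, convert_to_numpy=True)
    query_embeddings = model.encode(queries, convert_to_numpy=True)
    # No background update thread has been started, so mine_hard_negatives adds the
    # corpus to the index synchronously before searching for hard negatives.
    hard_negatives = miner.mine_hard_negatives(
        queries,
        query_embeddings,
        positive_passages,
        corpus_embeddings=corpus_embeddings,
        corpus_passages=corpus_passages,
    )
    # Inside a real training loop, periodically call
    # miner.update_index_if_needed(global_step, model) to trigger index rebuilds.
    return hard_negatives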
class HardNegativeDataAugmenter:
"""
Augment training data with various hard negative strategies
"""
def __init__(self, config: Dict):
self.config = config
self.strategies = config.get('strategies', ['ance', 'bm25', 'semantic'])
self.bm25_index = None
self.corpus_passages = None
def build_bm25_index(self, corpus: List[str]):
"""Build BM25 index for hard negative mining"""
try:
logger.info("Building BM25 index...")
self.corpus_passages = corpus
# Tokenize corpus
tokenized_corpus = []
for doc in tqdm(corpus, desc="Tokenizing corpus"):
if self._is_chinese(doc) and JIEBA_AVAILABLE:
tokens = list(jieba.cut(doc))
else:
tokens = doc.lower().split()
tokenized_corpus.append(tokens)
if BM25_AVAILABLE:
self.bm25_index = BM25Okapi(tokenized_corpus)
logger.info(f"Built BM25 index with {len(corpus)} documents")
else:
logger.warning("BM25Okapi not available, skipping BM25 index")
self.bm25_index = None
except Exception as e:
logger.error(f"Error building BM25 index: {e}")
self.bm25_index = None
def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters"""
for char in text:
if '\u4e00' <= char <= '\u9fff':
return True
return False
def augment_data(
self,
data: List[Dict],
model=None,
corpus: Optional[List[str]] = None
) -> List[Dict]:
"""Apply multiple hard negative augmentation strategies"""
# Build BM25 index if corpus provided
if corpus and 'bm25' in self.strategies:
self.build_bm25_index(corpus)
augmented_data = []
for item in tqdm(data, desc="Augmenting data"):
query = item['query']
positives = item.get('pos', [])
            # Copy to avoid mutating the original item's list when extending below
            negatives = list(item.get('neg', []))
# Apply different strategies
if 'bm25' in self.strategies and self.bm25_index:
bm25_negs = self._mine_bm25_negatives(query, positives)
negatives.extend(bm25_negs[:5]) # Add top 5 BM25 negatives
if 'semantic' in self.strategies and model:
semantic_negs = self._mine_semantic_negatives(query, positives, model)
negatives.extend(semantic_negs[:5])
if 'random_swap' in self.strategies:
swap_negs = self._create_swap_negatives(positives)
negatives.extend(swap_negs[:3])
# Remove duplicates while preserving order
seen = set()
unique_negatives = []
for neg in negatives:
if neg not in seen and neg not in positives:
seen.add(neg)
unique_negatives.append(neg)
augmented_item = item.copy()
augmented_item['neg'] = unique_negatives[:self.config.get('max_negatives', 30)]
augmented_data.append(augmented_item)
return augmented_data
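    # Illustrative config/usage sketch (assumed values, not an authoritative schema):
    #
    #   augmenter = HardNegativeDataAugmenter({
    #       'strategies': ['bm25', 'semantic', 'random_swap'],
    #       'max_negatives': 30,
    #       'semantic_num_negatives': 10,
    #   })
    #   augmented = augmenter.augment_data(train_items, model=embedder, corpus=corpus_passages)
    #
    # where each element of `train_items` is a dict such as
    #   {'query': '...', 'pos': ['relevant passage'], 'neg': ['optional existing negative']}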
def _mine_bm25_negatives(self, query: str, positives: List[str]) -> List[str]:
"""Mine hard negatives using BM25"""
if not self.bm25_index or not self.corpus_passages or not BM25_AVAILABLE:
return []
# Tokenize query
if self._is_chinese(query) and JIEBA_AVAILABLE:
query_tokens = list(jieba.cut(query))
else:
query_tokens = query.lower().split()
# Get BM25 scores
scores = self.bm25_index.get_scores(query_tokens)
# Get top passages excluding positives
top_indices = np.argsort(scores)[::-1]
negatives = []
for idx in top_indices:
if len(negatives) >= 10:
break
passage = self.corpus_passages[idx]
if passage not in positives:
negatives.append(passage)
return negatives
def _mine_semantic_negatives(self, query: str, positives: List[str], model) -> List[str]:
"""
Mine semantically similar but irrelevant passages using the model's encode method.
Returns top-k most similar passages (excluding positives).
"""
try:
if not self.corpus_passages or not model:
logger.warning("No corpus or model provided for semantic negative mining.")
return []
batch_size = self.config.get('semantic_batch_size', 64)
num_negatives = self.config.get('semantic_num_negatives', 10)
device = self.config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
# Encode query
query_embedding = model.encode([query], convert_to_numpy=True, device=device)
# Encode all corpus passages in batches
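            # Note: this re-encodes the full corpus for every query, which is expensive for
            # large corpora; consider caching passage embeddings or using the FAISS-based
            # ANCEHardNegativeMiner above instead.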
passage_embeddings = []
for i in range(0, len(self.corpus_passages), batch_size):
batch = self.corpus_passages[i:i+batch_size]
batch_emb = model.encode(batch, convert_to_numpy=True, device=device)
passage_embeddings.append(batch_emb)
passage_embeddings = np.vstack(passage_embeddings)
            # Cosine similarity (normalize in case the model does not return unit vectors)
            query_vec = query_embedding[0] / (np.linalg.norm(query_embedding[0]) + 1e-12)
            passage_norms = np.linalg.norm(passage_embeddings, axis=1, keepdims=True) + 1e-12
            scores = np.dot(passage_embeddings / passage_norms, query_vec)
# Exclude positives
candidates = [(idx, score) for idx, (score, passage) in enumerate(zip(scores, self.corpus_passages)) if passage not in positives]
# Sort and select top-k
candidates.sort(key=lambda x: x[1], reverse=True)
top_indices = [idx for idx, _ in candidates[:num_negatives]]
negatives = [self.corpus_passages[idx] for idx in top_indices]
logger.info(f"Mined {len(negatives)} semantic negatives for query: {query}")
return negatives
except Exception as e:
logger.error(f"Error mining semantic negatives: {e}")
return []
    def _create_swap_negatives(self, positives: List[str]) -> List[str]:
        """Create negatives by randomly swapping words within positive passages."""
        import random
        negatives = []
        for pos in positives[:2]:  # Only use the first 2 positives
            words = pos.split()
            if len(words) > 5:
                # Swap a few random word pairs to corrupt the passage
                swapped = words.copy()
                for _ in range(min(3, len(words) // 3)):
                    i, j = random.sample(range(len(words)), 2)
                    swapped[i], swapped[j] = swapped[j], swapped[i]
                negatives.append(' '.join(swapped))
        return negatives