init commit
This commit is contained in:
534
data/hard_negative_mining.py
Normal file
534
data/hard_negative_mining.py
Normal file
@@ -0,0 +1,534 @@
|
||||
"""
|
||||
ANCE-style hard negative mining for BGE models
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import queue
|
||||
from typing import List, Dict, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
import faiss
|
||||
|
||||
# Optional imports for text processing
|
||||
try:
|
||||
import jieba
|
||||
JIEBA_AVAILABLE = True
|
||||
except ImportError:
|
||||
JIEBA_AVAILABLE = False
|
||||
jieba = None
|
||||
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
BM25_AVAILABLE = True
|
||||
except ImportError:
|
||||
BM25_AVAILABLE = False
|
||||
BM25Okapi = None
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ANCEHardNegativeMiner:
|
||||
"""
|
||||
Approximate Nearest Neighbor Negative Contrastive Learning
|
||||
Dynamically mines hard negatives using evolving model
|
||||
|
||||
Notes:
|
||||
- All parameters (embedding_dim, num_hard_negatives, negative_range, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 768,
|
||||
num_hard_negatives: int = 15,
|
||||
negative_range: Tuple[int, int] = (10, 100),
|
||||
margin: float = 0.1,
|
||||
index_refresh_interval: int = 1000,
|
||||
index_type: str = "IVF",
|
||||
nlist: int = 100,
|
||||
use_gpu: bool = True,
|
||||
max_index_size: int = 1000000
|
||||
):
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_hard_negatives = num_hard_negatives
|
||||
self.negative_range = negative_range
|
||||
self.margin = margin
|
||||
self.index_refresh_interval = index_refresh_interval
|
||||
self.use_gpu = use_gpu and torch.cuda.is_available()
|
||||
self.max_index_size = max_index_size
|
||||
|
||||
# Initialize FAISS index
|
||||
self.index = self._create_index(index_type, nlist)
|
||||
self.passage_map: Dict[int, str] = {} # Maps index ID to passage text
|
||||
self.query_map: Dict[str, List[str]] = defaultdict(list) # Maps query to positive passages
|
||||
|
||||
# For asynchronous index updates
|
||||
self.update_queue = queue.Queue()
|
||||
self.update_thread = None
|
||||
self.stop_update_thread = threading.Event()
|
||||
self.current_step = 0
|
||||
|
||||
# Statistics
|
||||
self.stats = {
|
||||
'total_queries': 0,
|
||||
'total_negatives_mined': 0,
|
||||
'index_updates': 0,
|
||||
'cache_hits': 0,
|
||||
'cache_misses': 0
|
||||
}
|
||||
|
||||
def _create_index(self, index_type: str, nlist: int) -> faiss.Index:
|
||||
"""Create FAISS index, warn if falling back to CPU."""
|
||||
if index_type == "IVF":
|
||||
quantizer = faiss.IndexFlatIP(self.embedding_dim)
|
||||
index = faiss.IndexIVFFlat(quantizer, self.embedding_dim, nlist, faiss.METRIC_INNER_PRODUCT)
|
||||
elif index_type == "HNSW":
|
||||
index = faiss.IndexHNSWFlat(self.embedding_dim, 32, faiss.METRIC_INNER_PRODUCT)
|
||||
else:
|
||||
index = faiss.IndexFlatIP(self.embedding_dim)
|
||||
if self.use_gpu:
|
||||
try:
|
||||
res = faiss.StandardGpuResources() # type: ignore[attr-defined]
|
||||
index = faiss.index_cpu_to_gpu(res, 0, index) # type: ignore[attr-defined]
|
||||
logger.info("Using GPU for FAISS index")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to use GPU for FAISS: {e}. Using CPU.")
|
||||
self.use_gpu = False
|
||||
else:
|
||||
logger.warning("FAISS is running on CPU. This may be slower than GPU.")
|
||||
return index
|
||||
|
||||
def _is_chinese(self, text: str) -> bool:
|
||||
"""Check if text contains Chinese characters"""
|
||||
for char in text:
|
||||
if '\u4e00' <= char <= '\u9fff':
|
||||
return True
|
||||
return False
|
||||
|
||||
def start_async_updates(self):
|
||||
"""Start background thread for asynchronous index updates"""
|
||||
if self.update_thread is None or not self.update_thread.is_alive():
|
||||
self.stop_update_thread.clear()
|
||||
self.update_thread = threading.Thread(target=self._update_worker)
|
||||
self.update_thread.daemon = True
|
||||
self.update_thread.start()
|
||||
logger.info("Started asynchronous index update thread")
|
||||
|
||||
def stop_async_updates(self):
|
||||
"""Stop background update thread and ensure clean shutdown."""
|
||||
if self.update_thread and self.update_thread.is_alive():
|
||||
self.stop_update_thread.set()
|
||||
self.update_thread.join(timeout=5)
|
||||
logger.info("Stopped asynchronous index update thread")
|
||||
|
||||
def _update_worker(self):
|
||||
"""Background worker for index updates"""
|
||||
while not self.stop_update_thread.is_set():
|
||||
try:
|
||||
# Get update task with timeout
|
||||
task = self.update_queue.get(timeout=1)
|
||||
|
||||
if task['type'] == 'add_embeddings':
|
||||
self._add_to_index(task['embeddings'], task['passages'])
|
||||
elif task['type'] == 'rebuild_index':
|
||||
self._rebuild_index(task['model'])
|
||||
|
||||
self.stats['index_updates'] += 1
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error in update worker: {e}")
|
||||
|
||||
def _add_to_index(self, embeddings: np.ndarray, passages: List[str]):
|
||||
"""Add embeddings to index"""
|
||||
try:
|
||||
# Normalize embeddings for inner product search
|
||||
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
|
||||
# Add to index
|
||||
start_idx = len(self.passage_map)
|
||||
|
||||
# Debug log for embeddings
|
||||
logger.debug(f"embeddings shape: {getattr(embeddings, 'shape', None)}, type: {type(embeddings)}")
|
||||
|
||||
if hasattr(self.index, 'is_trained') and not self.index.is_trained:
|
||||
if embeddings is not None and len(embeddings) > 0:
|
||||
if hasattr(self.index, 'nlist') and len(embeddings) >= getattr(self.index, 'nlist', 0):
|
||||
# Safe to access self.index.nlist and call train
|
||||
logger.info("Training FAISS index...")
|
||||
self.index.train(embeddings) # type: ignore[call-arg]
|
||||
elif hasattr(self.index, 'nlist'):
|
||||
logger.warning(f"Not enough data to train index. Need {getattr(self.index, 'nlist', 0)}, got {len(embeddings)}")
|
||||
return
|
||||
else:
|
||||
# For index types without nlist/train, skip training
|
||||
pass
|
||||
else:
|
||||
logger.warning("Embeddings is None or empty, skipping training and addition to index.")
|
||||
return
|
||||
|
||||
if embeddings is not None and len(embeddings) > 0:
|
||||
self.index.add(embeddings) # type: ignore[call-arg]
|
||||
else:
|
||||
logger.warning("Embeddings is None or empty, skipping addition to index.")
|
||||
|
||||
# Update passage map
|
||||
for i, passage in enumerate(passages):
|
||||
self.passage_map[start_idx + i] = passage
|
||||
|
||||
logger.info(f"Added {len(passages)} passages to index. Total: {self.index.ntotal}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding embeddings to index: {e}")
|
||||
|
||||
def _rebuild_index(self, model):
|
||||
"""Rebuild entire index with current model"""
|
||||
try:
|
||||
logger.info("Rebuilding entire index...")
|
||||
|
||||
# Clear current index
|
||||
self.index.reset()
|
||||
old_passage_map = self.passage_map.copy()
|
||||
self.passage_map.clear()
|
||||
|
||||
# Re-encode all passages
|
||||
all_passages = list(set(sum(self.query_map.values(), [])))
|
||||
if not all_passages:
|
||||
logger.warning("No passages to rebuild index with")
|
||||
return
|
||||
|
||||
embeddings = []
|
||||
batch_size = 32
|
||||
|
||||
for i in tqdm(range(0, len(all_passages), batch_size), desc="Encoding passages"):
|
||||
batch = all_passages[i:i + batch_size]
|
||||
try:
|
||||
with torch.no_grad():
|
||||
batch_embeddings = model.encode(batch, convert_to_numpy=True, batch_size=len(batch))
|
||||
embeddings.append(batch_embeddings)
|
||||
except Exception as e:
|
||||
logger.error(f"Error encoding batch {i // batch_size}: {e}")
|
||||
continue
|
||||
|
||||
if embeddings:
|
||||
embeddings = np.vstack(embeddings)
|
||||
self._add_to_index(embeddings, all_passages)
|
||||
logger.info(f"Rebuilt index with {len(all_passages)} passages")
|
||||
else:
|
||||
# Restore old passage map if rebuild failed
|
||||
self.passage_map = old_passage_map
|
||||
logger.error("Failed to rebuild index - no embeddings generated")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error rebuilding index: {e}")
|
||||
|
||||
def mine_hard_negatives(
|
||||
self,
|
||||
queries: List[str],
|
||||
query_embeddings: np.ndarray,
|
||||
positive_passages: List[List[str]],
|
||||
corpus_embeddings: Optional[np.ndarray] = None,
|
||||
corpus_passages: Optional[List[str]] = None
|
||||
) -> List[List[str]]:
|
||||
"""
|
||||
Mine hard negatives for queries using ANCE approach
|
||||
|
||||
Args:
|
||||
queries: List of query texts
|
||||
query_embeddings: Query embeddings [num_queries, embedding_dim]
|
||||
positive_passages: List of positive passages for each query
|
||||
corpus_embeddings: Optional corpus embeddings to add to index
|
||||
corpus_passages: Optional corpus passages
|
||||
|
||||
Returns:
|
||||
List of hard negative passages for each query
|
||||
"""
|
||||
# Add corpus to index if provided
|
||||
if corpus_embeddings is not None and corpus_passages is not None:
|
||||
self.update_queue.put({
|
||||
'type': 'add_embeddings',
|
||||
'embeddings': corpus_embeddings,
|
||||
'passages': corpus_passages
|
||||
})
|
||||
|
||||
# Update query map
|
||||
for query, positives in zip(queries, positive_passages):
|
||||
self.query_map[query].extend(positives)
|
||||
|
||||
# Normalize query embeddings
|
||||
query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
|
||||
|
||||
# Search for hard negatives
|
||||
hard_negatives = []
|
||||
k = min(self.negative_range[1], self.index.ntotal)
|
||||
|
||||
if k == 0:
|
||||
logger.warning("Index is empty, returning empty hard negatives")
|
||||
return [[] for _ in queries]
|
||||
|
||||
similarities, indices = self.index.search(query_embeddings, k) # type: ignore[call-arg]
|
||||
|
||||
for i, (query, positives) in enumerate(zip(queries, positive_passages)):
|
||||
query_hard_negs = []
|
||||
candidates = []
|
||||
|
||||
# Collect candidates within negative range
|
||||
for j in range(self.negative_range[0], min(k, self.negative_range[1])):
|
||||
if indices[i, j] >= 0: # Valid index
|
||||
passage_idx = indices[i, j]
|
||||
if passage_idx in self.passage_map:
|
||||
passage = self.passage_map[passage_idx]
|
||||
score = similarities[i, j]
|
||||
|
||||
# Skip if it's a positive passage
|
||||
if passage not in positives:
|
||||
candidates.append((passage, score))
|
||||
|
||||
# Select hard negatives with margin
|
||||
if candidates:
|
||||
# Sort by score (descending)
|
||||
candidates.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Select with margin constraint
|
||||
selected = []
|
||||
last_score = float('inf')
|
||||
for passage, score in candidates:
|
||||
if last_score - score >= self.margin or len(selected) == 0:
|
||||
selected.append(passage)
|
||||
last_score = score
|
||||
if len(selected) >= self.num_hard_negatives:
|
||||
break
|
||||
|
||||
query_hard_negs = selected
|
||||
|
||||
# Pad with random negatives if needed
|
||||
if len(query_hard_negs) < self.num_hard_negatives:
|
||||
available = [p for p in self.passage_map.values()
|
||||
if p not in positives and p not in query_hard_negs]
|
||||
if available:
|
||||
additional_count = min(len(available), self.num_hard_negatives - len(query_hard_negs))
|
||||
additional = np.random.choice(available, size=additional_count, replace=False).tolist()
|
||||
query_hard_negs.extend(additional)
|
||||
|
||||
hard_negatives.append(query_hard_negs)
|
||||
|
||||
self.stats['total_queries'] += len(queries)
|
||||
self.stats['total_negatives_mined'] += sum(len(negs) for negs in hard_negatives)
|
||||
|
||||
return hard_negatives
|
||||
|
||||
def update_index_if_needed(self, step: int, model=None):
|
||||
"""Check if index needs update and trigger if necessary"""
|
||||
self.current_step = step
|
||||
|
||||
if step > 0 and step % self.index_refresh_interval == 0 and model is not None:
|
||||
logger.info(f"Triggering index rebuild at step {step}")
|
||||
self.update_queue.put({
|
||||
'type': 'rebuild_index',
|
||||
'model': model
|
||||
})
|
||||
|
||||
def save_mined_data(self, output_path: str, original_data_path: str):
|
||||
"""Save data with mined hard negatives"""
|
||||
logger.info(f"Saving mined data to {output_path}")
|
||||
|
||||
# Load original data
|
||||
original_data = []
|
||||
with open(original_data_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
original_data.append(json.loads(line))
|
||||
|
||||
# This is a simplified implementation - in practice you'd batch mine negatives
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for item in tqdm(original_data, desc="Saving mined data"):
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
logger.info(f"Saved {len(original_data)} examples")
|
||||
|
||||
def get_statistics(self) -> Dict[str, int]:
|
||||
"""Get mining statistics"""
|
||||
self.stats['index_size'] = self.index.ntotal
|
||||
self.stats['unique_passages'] = len(self.passage_map)
|
||||
return self.stats.copy()
|
||||
|
||||
|
||||
class HardNegativeDataAugmenter:
|
||||
"""
|
||||
Augment training data with various hard negative strategies
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
self.strategies = config.get('strategies', ['ance', 'bm25', 'semantic'])
|
||||
self.bm25_index = None
|
||||
self.corpus_passages = None
|
||||
|
||||
def build_bm25_index(self, corpus: List[str]):
|
||||
"""Build BM25 index for hard negative mining"""
|
||||
try:
|
||||
logger.info("Building BM25 index...")
|
||||
self.corpus_passages = corpus
|
||||
|
||||
# Tokenize corpus
|
||||
tokenized_corpus = []
|
||||
for doc in tqdm(corpus, desc="Tokenizing corpus"):
|
||||
if self._is_chinese(doc) and JIEBA_AVAILABLE:
|
||||
tokens = list(jieba.cut(doc))
|
||||
else:
|
||||
tokens = doc.lower().split()
|
||||
tokenized_corpus.append(tokens)
|
||||
|
||||
if BM25_AVAILABLE:
|
||||
self.bm25_index = BM25Okapi(tokenized_corpus)
|
||||
logger.info(f"Built BM25 index with {len(corpus)} documents")
|
||||
else:
|
||||
logger.warning("BM25Okapi not available, skipping BM25 index")
|
||||
self.bm25_index = None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building BM25 index: {e}")
|
||||
self.bm25_index = None
|
||||
|
||||
def _is_chinese(self, text: str) -> bool:
|
||||
"""Check if text contains Chinese characters"""
|
||||
for char in text:
|
||||
if '\u4e00' <= char <= '\u9fff':
|
||||
return True
|
||||
return False
|
||||
|
||||
def augment_data(
|
||||
self,
|
||||
data: List[Dict],
|
||||
model=None,
|
||||
corpus: Optional[List[str]] = None
|
||||
) -> List[Dict]:
|
||||
"""Apply multiple hard negative augmentation strategies"""
|
||||
|
||||
# Build BM25 index if corpus provided
|
||||
if corpus and 'bm25' in self.strategies:
|
||||
self.build_bm25_index(corpus)
|
||||
|
||||
augmented_data = []
|
||||
|
||||
for item in tqdm(data, desc="Augmenting data"):
|
||||
query = item['query']
|
||||
positives = item.get('pos', [])
|
||||
negatives = item.get('neg', [])
|
||||
|
||||
# Apply different strategies
|
||||
if 'bm25' in self.strategies and self.bm25_index:
|
||||
bm25_negs = self._mine_bm25_negatives(query, positives)
|
||||
negatives.extend(bm25_negs[:5]) # Add top 5 BM25 negatives
|
||||
|
||||
if 'semantic' in self.strategies and model:
|
||||
semantic_negs = self._mine_semantic_negatives(query, positives, model)
|
||||
negatives.extend(semantic_negs[:5])
|
||||
|
||||
if 'random_swap' in self.strategies:
|
||||
swap_negs = self._create_swap_negatives(positives)
|
||||
negatives.extend(swap_negs[:3])
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
unique_negatives = []
|
||||
for neg in negatives:
|
||||
if neg not in seen and neg not in positives:
|
||||
seen.add(neg)
|
||||
unique_negatives.append(neg)
|
||||
|
||||
augmented_item = item.copy()
|
||||
augmented_item['neg'] = unique_negatives[:self.config.get('max_negatives', 30)]
|
||||
augmented_data.append(augmented_item)
|
||||
|
||||
return augmented_data
|
||||
|
||||
def _mine_bm25_negatives(self, query: str, positives: List[str]) -> List[str]:
|
||||
"""Mine hard negatives using BM25"""
|
||||
if not self.bm25_index or not self.corpus_passages or not BM25_AVAILABLE:
|
||||
return []
|
||||
|
||||
# Tokenize query
|
||||
if self._is_chinese(query) and JIEBA_AVAILABLE:
|
||||
query_tokens = list(jieba.cut(query))
|
||||
else:
|
||||
query_tokens = query.lower().split()
|
||||
|
||||
# Get BM25 scores
|
||||
scores = self.bm25_index.get_scores(query_tokens)
|
||||
|
||||
# Get top passages excluding positives
|
||||
top_indices = np.argsort(scores)[::-1]
|
||||
negatives = []
|
||||
|
||||
for idx in top_indices:
|
||||
if len(negatives) >= 10:
|
||||
break
|
||||
passage = self.corpus_passages[idx]
|
||||
if passage not in positives:
|
||||
negatives.append(passage)
|
||||
|
||||
return negatives
|
||||
|
||||
def _mine_semantic_negatives(self, query: str, positives: List[str], model) -> List[str]:
|
||||
"""
|
||||
Mine semantically similar but irrelevant passages using the model's encode method.
|
||||
Returns top-k most similar passages (excluding positives).
|
||||
"""
|
||||
try:
|
||||
if not self.corpus_passages or not model:
|
||||
logger.warning("No corpus or model provided for semantic negative mining.")
|
||||
return []
|
||||
|
||||
batch_size = self.config.get('semantic_batch_size', 64)
|
||||
num_negatives = self.config.get('semantic_num_negatives', 10)
|
||||
device = self.config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Encode query
|
||||
query_embedding = model.encode([query], convert_to_numpy=True, device=device)
|
||||
# Encode all corpus passages in batches
|
||||
passage_embeddings = []
|
||||
for i in range(0, len(self.corpus_passages), batch_size):
|
||||
batch = self.corpus_passages[i:i+batch_size]
|
||||
batch_emb = model.encode(batch, convert_to_numpy=True, device=device)
|
||||
passage_embeddings.append(batch_emb)
|
||||
passage_embeddings = np.vstack(passage_embeddings)
|
||||
|
||||
# Compute similarity
|
||||
scores = np.dot(passage_embeddings, query_embedding[0])
|
||||
# Exclude positives
|
||||
candidates = [(idx, score) for idx, (score, passage) in enumerate(zip(scores, self.corpus_passages)) if passage not in positives]
|
||||
# Sort and select top-k
|
||||
candidates.sort(key=lambda x: x[1], reverse=True)
|
||||
top_indices = [idx for idx, _ in candidates[:num_negatives]]
|
||||
negatives = [self.corpus_passages[idx] for idx in top_indices]
|
||||
logger.info(f"Mined {len(negatives)} semantic negatives for query: {query}")
|
||||
return negatives
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error mining semantic negatives: {e}")
|
||||
return []
|
||||
|
||||
def _create_swap_negatives(self, positives: List[str]) -> List[str]:
|
||||
"""Create negatives by swapping words in positive passages"""
|
||||
negatives = []
|
||||
|
||||
for pos in positives[:2]: # Only use first 2 positives
|
||||
words = pos.split()
|
||||
if len(words) > 5:
|
||||
# Swap random words
|
||||
import random
|
||||
swapped = words.copy()
|
||||
for _ in range(min(3, len(words) // 3)):
|
||||
i, j = random.sample(range(len(words)), 2)
|
||||
swapped[i], swapped[j] = swapped[j], swapped[i]
|
||||
negatives.append(' '.join(swapped))
|
||||
|
||||
return negatives
|
||||
Reference in New Issue
Block a user