""" Unified BGE dataset pipeline with comprehensive format support - Triplets: Query-triplets format (query, pos, neg) for embedding training - Pairs: Flat pairs format (query, passage, label) for reranker training - Candidates: Unified format (query, candidates) with threshold-based label assignment - Legacy: Supports nested formats with automatic conversion to flat pairs - Enhanced: Score-weighted sampling, hard negative mining, multiple positives """ import json import random from typing import Dict, List, Optional, Union, Tuple import torch from torch.utils.data import Dataset from transformers import PreTrainedTokenizer import logging from tqdm import tqdm import numpy as np logger = logging.getLogger(__name__) class BGEDataset(Dataset): """ Unified BGE dataset with automatic format detection and enhanced features Supports: triplets, pairs, candidates, and legacy nested formats """ def __init__( self, data_path: str, tokenizer: PreTrainedTokenizer, query_max_length: int = 64, passage_max_length: int = 512, is_train: bool = True, cache_dir: Optional[str] = None, legacy_support: bool = True, format_type: str = "auto", # "triplets", "pairs", "candidates", "legacy_nested", "auto" candidates_threshold: float = 0.5, # Threshold for candidates format label assignment cache_in_memory: bool = False, # Cache tokenized data in memory for small datasets max_cache_size: int = 100000 # Maximum number of samples to cache in memory ): self.tokenizer = tokenizer self.query_max_length = query_max_length self.passage_max_length = passage_max_length self.is_train = is_train self.legacy_support = legacy_support self.format_type = format_type self.candidates_threshold = candidates_threshold self.cache_in_memory = cache_in_memory self.max_cache_size = max_cache_size self._tokenized_cache = {} logger.info(f"Loading data from {data_path}") self.data, self.detected_format = self._load_and_detect_format(data_path) logger.info(f"Loaded {len(self.data)} examples in {self.detected_format} format") # Pre-tokenize data if caching is enabled and dataset is small enough if self.cache_in_memory and len(self.data) <= self.max_cache_size: logger.info(f"Pre-tokenizing {len(self.data)} samples for in-memory cache...") self._pretokenize_all() logger.info("Pre-tokenization complete") def _load_and_detect_format(self, data_path: str) -> Tuple[List[Dict], str]: """Load data and automatically detect format with improved detection logic""" data = [] format_type = "unknown" total_lines = 0 with open(data_path, 'r', encoding='utf-8') as f: for line_num, line in enumerate(tqdm(f, desc="Loading and detecting format"), 1): if line.strip(): total_lines += 1 try: item = json.loads(line) # Detect format on first valid item if format_type == "unknown": format_type = self._detect_schema_type(item) logger.info(f"Detected schema: {format_type}") # Validate and process item if self._validate_item(item, format_type): processed_items = self._process_item(item, format_type) # Process_item may return multiple items (for candidates format) if isinstance(processed_items, list): data.extend(processed_items) else: data.append(processed_items) else: logger.warning(f"Invalid item at line {line_num}: {item}") except json.JSONDecodeError as e: logger.warning(f"Invalid JSON at line {line_num}: {e}") continue if not data: raise ValueError(f"No valid data found in {data_path}") # Validation check if total_lines > 10 and (len(data) < 10 or len(data) < 0.01 * total_lines): logger.warning(f"Few valid data lines loaded from {data_path}: {len(data)} out of 
{total_lines}") return data, format_type def _detect_schema_type(self, item: Dict) -> str: """ Enhanced format detection based on key presence and structure: - triplets: has 'pos' and 'neg' keys (for embedding training) - pairs: has 'passage' and 'label' keys (for reranker training) - candidates: has 'candidates' key with list of candidate objects - legacy_nested: has 'pos' but no 'label' (backward compatibility) """ try: # Use explicitly set format if provided if self.format_type != "auto": return self.format_type # Ensure item is a dictionary if not isinstance(item, dict): raise ValueError(f"Expected dict but got {type(item)}") # Extract available keys for debugging available_keys = set(item.keys()) # Check for candidates format first (most specific) if 'candidates' in item: # Validate candidates structure candidates = item.get('candidates', []) if isinstance(candidates, list) and len(candidates) > 0: # Check if first candidate has expected structure first_candidate = candidates[0] if isinstance(first_candidate, dict) and 'text' in first_candidate: return "candidates" # Also support simple string list as candidates elif isinstance(first_candidate, str): logger.info("Detected candidates format with string list - will convert to dict format") return "candidates" # Check for pairs format (reranker) if 'passage' in item and 'label' in item: # Additional check for query field if 'query' in item: return "pairs" else: logger.warning(f"Found 'passage' and 'label' but missing 'query'. Available keys: {available_keys}") # Check for triplets format (embedding) if 'pos' in item: # Standard triplets format has both pos and neg if 'neg' in item: return "triplets" # Might be triplets with only positives elif 'query' in item and not ('label' in item or 'passage' in item): logger.info("Detected triplets format with only positive samples") return "triplets" # Legacy nested format elif self.legacy_support and 'query' in item: logger.warning("Detected legacy nested format. Converting to flat pairs for reranker training.") return "legacy_nested" # Try to provide helpful error message common_keys = available_keys & {'query', 'pos', 'neg', 'passage', 'label', 'candidates'} error_msg = f"Unknown data format. Available keys: {available_keys}" if common_keys: error_msg += f", Common keys found: {common_keys}" else: error_msg += ". No common keys found. 
    def _validate_item(self, item: Dict, format_type: str) -> bool:
        """
        Validate an item based on the detected format, with comprehensive checks.

        Returns:
            bool: True if the item is valid, False otherwise
        """
        try:
            # Basic type check
            if not isinstance(item, dict):
                return False

            # Format-specific validation
            if format_type == "triplets":
                # Check the required fields
                if not all(field in item for field in ['query', 'pos']):
                    return False
                # Validate the query
                if not isinstance(item['query'], str) or not item['query'].strip():
                    return False
                # Validate the positives
                if not isinstance(item.get('pos', []), list) or len(item['pos']) == 0:
                    return False
                # Each positive must be a non-empty string
                for pos in item['pos']:
                    if not isinstance(pos, str) or not pos.strip():
                        return False
                # Validate the negatives if present
                if 'neg' in item:
                    if not isinstance(item['neg'], list):
                        return False
                    for neg in item['neg']:
                        if not isinstance(neg, str) or not neg.strip():
                            return False
                return True

            elif format_type == "pairs":
                # Check the required fields
                if not all(field in item for field in ['query', 'passage', 'label']):
                    return False
                # Validate the query and passage
                if not isinstance(item['query'], str) or not item['query'].strip():
                    return False
                if not isinstance(item['passage'], str) or not item['passage'].strip():
                    return False
                # Validate the label (0/1 int, or a float between 0 and 1)
                label = item['label']
                if isinstance(label, (int, float)):
                    if not (0 <= label <= 1):
                        return False
                else:
                    return False
                return True

            elif format_type == "candidates":
                # Check the required fields
                if not ('query' in item and 'candidates' in item):
                    return False
                # Validate the query
                if not isinstance(item['query'], str) or not item['query'].strip():
                    return False
                # Validate the candidates
                if not isinstance(item['candidates'], list) or len(item['candidates']) == 0:
                    return False
                # Check each candidate
                for candidate in item['candidates']:
                    # Simple string candidates are allowed (detection accepts them and
                    # _process_candidates_item converts them), so validate them too
                    if isinstance(candidate, str):
                        if not candidate.strip():
                            return False
                        continue
                    if not isinstance(candidate, dict):
                        return False
                    if 'text' not in candidate:
                        return False
                    if not isinstance(candidate['text'], str) or not candidate['text'].strip():
                        return False
                    # Validate the score if present
                    if 'score' in candidate:
                        if not isinstance(candidate['score'], (int, float)):
                            return False
                return True

            elif format_type == "legacy_nested":
                # Basic validation for the legacy format
                if not ('query' in item and 'pos' in item):
                    return False
                if not isinstance(item['query'], str) or not item['query'].strip():
                    return False
                return True

            else:
                return False

        except Exception as e:
            logger.debug(f"Validation error: {e}")
            return False

    def _process_item(self, item: Dict, format_type: str) -> Union[Dict, List[Dict]]:
        """Process an item based on its format; may return multiple items (candidates/legacy)."""
        if format_type == "triplets":
            return self._process_triplets_item(item)
        elif format_type == "pairs":
            return self._process_pairs_item(item)
        elif format_type == "candidates":
            return self._process_candidates_item(item)
        elif format_type == "legacy_nested":
            return self._process_legacy_nested_item(item)
        else:
            raise ValueError(f"Unknown format type: {format_type}")

    def _process_triplets_item(self, item: Dict) -> Dict:
        """Process a triplets-format item for embedding training."""
        return {
            'query': item['query'],
            'pos': item.get('pos', []),
            'neg': item.get('neg', []),
            'pos_scores': item.get('pos_scores', [1.0] * len(item.get('pos', []))),
            'neg_scores': item.get('neg_scores', [0.0] * len(item.get('neg', []))),
            'prompt': item.get('prompt', ''),
            'type': item.get('type', 'normal'),
            'format': 'triplets',
        }
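    # The pairs processing below normalizes optional fields; e.g. a minimal line
    #   {"query": "q", "passage": "p", "label": 1}
    # becomes (score defaults to 1.0 for label > 0, else 0.0; qid/pid default to ''):
    #   {"query": "q", "passage": "p", "label": 1, "score": 1.0, "qid": "", "pid": "", "format": "pairs"}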
    def _process_pairs_item(self, item: Dict) -> Dict:
        """Process a pairs-format item for reranker training."""
        return {
            'query': item['query'],
            'passage': item['passage'],
            'label': item['label'],
            'score': item.get('score', 1.0 if item['label'] > 0 else 0.0),
            'qid': item.get('qid', ''),
            'pid': item.get('pid', ''),
            'format': 'pairs',
        }

    def _process_candidates_item(self, item: Dict) -> List[Dict]:
        """
        Process a candidates-format item by exploding it into multiple pairs.

        Each candidate becomes a separate (query, passage, label) pair; the label
        is assigned by comparing the candidate score to candidates_threshold.
        """
        query = item['query']
        candidates = item['candidates']
        qid = item.get('qid', '')

        pairs = []
        for i, candidate in enumerate(candidates):
            if isinstance(candidate, dict):
                text = candidate.get('text', '')
                score = candidate.get('score', 0.0)
                pid = candidate.get('pid', f'cand_{i}')
            else:
                # Handle simple string candidates
                text = str(candidate)
                score = 1.0  # Default to positive if no score is given
                pid = f'cand_{i}'

            # Assign the label based on the threshold
            label = 1 if score >= self.candidates_threshold else 0

            pairs.append({
                'query': query,
                'passage': text,
                'label': label,
                'score': score,
                'qid': qid,
                'pid': pid,
                'format': 'pairs',       # Converted to pairs format
                'source': 'candidates',  # Track the original source
            })
        return pairs

    def _process_legacy_nested_item(self, item: Dict) -> List[Dict]:
        """
        Process a legacy nested item by converting it to flat pairs for reranker training.

        Each pos[i] becomes (query, passage, label=1) and each neg[j] becomes
        (query, passage, label=0).
        """
        query = item['query']
        positives = item.get('pos', [])
        negatives = item.get('neg', [])
        pos_scores = item.get('pos_scores', [1.0] * len(positives))
        neg_scores = item.get('neg_scores', [0.0] * len(negatives))
        qid = item.get('qid', '')

        pairs = []

        # Convert the positives
        for i, passage in enumerate(positives):
            pairs.append({
                'query': query,
                'passage': passage,
                'label': 1,
                'score': pos_scores[i] if i < len(pos_scores) else 1.0,
                'qid': qid,
                'pid': f'pos_{i}',
                'format': 'pairs',          # Converted to pairs format
                'source': 'legacy_nested',  # Track the original source
            })

        # Convert the negatives
        for i, passage in enumerate(negatives):
            pairs.append({
                'query': query,
                'passage': passage,
                'label': 0,
                'score': neg_scores[i] if i < len(neg_scores) else 0.0,
                'qid': qid,
                'pid': f'neg_{i}',
                'format': 'pairs',          # Converted to pairs format
                'source': 'legacy_nested',  # Track the original source
            })

        return pairs
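    # Explosion example: one legacy nested line
    #   {"query": "q", "pos": ["p1", "p2"], "neg": ["n1"]}
    # yields three flat pairs:
    #   {"query": "q", "passage": "p1", "label": 1, "pid": "pos_0", ...}
    #   {"query": "q", "passage": "p2", "label": 1, "pid": "pos_1", ...}
    #   {"query": "q", "passage": "n1", "label": 0, "pid": "neg_0", ...}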
    def _pretokenize_all(self):
        """Pre-tokenize all samples for the in-memory cache."""
        for idx in tqdm(range(len(self.data)), desc="Pre-tokenizing"):
            item = self.data[idx]

            if item['format'] == 'triplets':
                # Pre-tokenize the query for triplets (passages are sampled per call,
                # so they are tokenized on the fly in _get_triplets_item)
                query_encoding = self.tokenizer(
                    item['query'],
                    max_length=self.query_max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt',
                )
                self._tokenized_cache[f"{idx}_query"] = {
                    'input_ids': query_encoding['input_ids'].squeeze(0),
                    'attention_mask': query_encoding['attention_mask'].squeeze(0),
                }

            elif item['format'] == 'pairs':
                # Pre-tokenize the query-passage pair using the same tokenizer call
                # as _get_pairs_item, so cached and on-the-fly tensors are identical
                encoding = self.tokenizer(
                    item['query'],
                    item['passage'],
                    max_length=self.passage_max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt',
                )
                self._tokenized_cache[idx] = {
                    'input_ids': encoding['input_ids'].squeeze(0),
                    'attention_mask': encoding['attention_mask'].squeeze(0),
                }

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict:
        """Get an item based on the detected format."""
        item = self.data[idx]
        if item['format'] == 'triplets':
            return self._get_triplets_item(item, idx)
        elif item['format'] == 'pairs':
            return self._get_pairs_item(item, idx)
        else:
            raise ValueError(f"Unknown format: {item['format']}")

    def _get_triplets_item(self, item: Dict, idx: int) -> Dict:
        """Get a triplets training example for the embedding model (contrastive learning)."""
        query = item['query']
        positives = item['pos']
        negatives = item['neg']
        pos_scores = item['pos_scores']
        neg_scores = item['neg_scores']
        prompt = item['prompt']

        # Apply the prompt if provided
        if prompt and self.is_train:
            query = f"{prompt}{query}"

        # Use the cached query encoding only when no prompt was applied, since the
        # cache stores the un-prompted query
        cache_key = f"{idx}_query"
        if cache_key in self._tokenized_cache and not (prompt and self.is_train):
            query_encodings = self._tokenized_cache[cache_key]
        else:
            encoding = self.tokenizer(
                query,
                max_length=self.query_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt',
            )
            query_encodings = {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
            }

        # Enhanced sampling for contrastive learning
        if self.is_train:
            # Sample multiple positives (with their scores, kept aligned by index)
            # for richer contrastive signals
            num_positives = min(2, len(positives))
            if positives:
                pos_indices = random.sample(range(len(positives)), num_positives)
                sampled_positives = [positives[i] for i in pos_indices]
                sampled_pos_scores = [
                    pos_scores[i] if i < len(pos_scores) else 1.0 for i in pos_indices
                ]
            else:
                sampled_positives, sampled_pos_scores = [], []

            # Hard negative mining with score-based sampling
            num_negatives = min(6, len(negatives))
            if negatives and neg_scores:
                # Align scores with negatives, padding missing scores with 0.0
                aligned_scores = [
                    neg_scores[i] if i < len(neg_scores) else 0.0
                    for i in range(len(negatives))
                ]
                # Clip negative scores and add a small epsilon so every passage keeps
                # a non-zero probability (np.random.choice with replace=False requires
                # at least `size` non-zero entries in p), then normalize
                neg_weights = np.clip(np.asarray(aligned_scores, dtype=np.float64), 0.0, None) + 1e-8
                neg_weights /= neg_weights.sum()
                sampled_neg_indices = np.random.choice(
                    len(negatives),
                    size=num_negatives,
                    replace=False,
                    p=neg_weights,
                )
                sampled_negatives = [negatives[i] for i in sampled_neg_indices]
                sampled_neg_scores = [aligned_scores[i] for i in sampled_neg_indices]
            else:
                sampled_negatives = random.sample(negatives, num_negatives) if negatives else []
                sampled_neg_scores = [0.0] * len(sampled_negatives)

            passages = sampled_positives + sampled_negatives
            labels = [1] * len(sampled_positives) + [0] * len(sampled_negatives)
            scores = sampled_pos_scores + sampled_neg_scores
        else:
            # Use everything for evaluation
            passages = positives + negatives
            labels = [1] * len(positives) + [0] * len(negatives)
            scores = pos_scores + neg_scores

        # Tokenize the passages
        passage_encodings = []
        for passage in passages:
            encoding = self.tokenizer(
                passage,
                max_length=self.passage_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt',
            )
            passage_encodings.append({
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
            })

        # Pad to a consistent size (at most 2 positives + 6 negatives = 8 passages)
        max_passages = 8
        while len(passage_encodings) < max_passages:
            passage_encodings.append({
                'input_ids': torch.zeros(self.passage_max_length, dtype=torch.long),
                'attention_mask': torch.zeros(self.passage_max_length, dtype=torch.long),
            })
            labels.append(-1)  # -1 marks padding slots
            scores.append(0.0)

        # Stack the tensors
        passage_input_ids = torch.stack([p['input_ids'] for p in passage_encodings])
        passage_attention_mask = torch.stack([p['attention_mask'] for p in passage_encodings])

        return {
            'query_input_ids': query_encodings['input_ids'],
            'query_attention_mask': query_encodings['attention_mask'],
            'passage_input_ids': passage_input_ids,
            'passage_attention_mask': passage_attention_mask,
            'labels': torch.tensor(labels, dtype=torch.long),
            'scores': torch.tensor(scores, dtype=torch.float),
            'query_text': query,
            'passage_texts': passages + [''] * (max_passages - len(passages)),
            'item_idx': idx,
            'format': 'triplets',
            # Backward compatibility
            'positives': positives,
            'negatives': negatives,
        }
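    # Worked example of the score-weighted negative sampling above:
    # neg_scores [0.9, 0.5, 0.1] are clipped and normalized to sampling
    # probabilities of approximately [0.600, 0.333, 0.067], so harder negatives
    # (higher retrieval scores) are drawn more often.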
    def _get_pairs_item(self, item: Dict, idx: int) -> Dict:
        """Get a pairs training example for the reranker (cross-encoder training)."""
        query = item['query']
        passage = item['passage']
        label = item['label']
        score = item['score']

        # Use the cached tokenized data if available
        if idx in self._tokenized_cache:
            encoding = self._tokenized_cache[idx]
            input_ids = encoding['input_ids']
            attention_mask = encoding['attention_mask']
        else:
            # Tokenize the query-passage pair on the fly
            encoding = self.tokenizer(
                query,
                passage,
                max_length=self.passage_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt',
            )
            input_ids = encoding['input_ids'].squeeze(0)
            attention_mask = encoding['attention_mask'].squeeze(0)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(label, dtype=torch.long),
            'scores': torch.tensor(score, dtype=torch.float),
            'query_text': query,
            'passage_text': passage,
            'qid': item.get('qid', ''),
            'pid': item.get('pid', ''),
            'item_idx': idx,
            'format': 'pairs',
            'source': item.get('source', 'original'),
        }


# Specialized dataset classes for backward compatibility

class BGEM3Dataset(BGEDataset):
    """BGE-M3 specialized dataset (embedding model) with enhanced contrastive learning."""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        query_max_length: int = 64,
        passage_max_length: int = 512,
        is_train: bool = True,
        train_group_size: int = 8,
        cache_dir: Optional[str] = None,
    ):
        # Force the embedding (triplets) format
        super().__init__(
            data_path=data_path,
            tokenizer=tokenizer,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            is_train=is_train,
            cache_dir=cache_dir,
            format_type="triplets",
        )
        self.train_group_size = train_group_size

        # Warn if the wrong format was detected
        if self.detected_format != 'triplets':
            logger.warning(f"Expected triplets format for BGEM3Dataset, got {self.detected_format}")


class BGERerankerDataset(BGEDataset):
    """BGE-Reranker specialized dataset (cross-encoder model) with flat pairs support."""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        query_max_length: int = 64,
        passage_max_length: int = 448,
        is_train: bool = True,
        train_group_size: int = 16,
        cache_dir: Optional[str] = None,
        legacy_support: bool = True,
    ):
        # Force the reranker format (or allow legacy auto-detection)
        format_type = "auto" if legacy_support else "pairs"
        super().__init__(
            data_path=data_path,
            tokenizer=tokenizer,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            is_train=is_train,
            cache_dir=cache_dir,
            legacy_support=legacy_support,
            format_type=format_type,
        )
        self.train_group_size = train_group_size

        # Warn if the wrong format was detected
        if self.detected_format not in ('pairs', 'legacy_nested'):
            logger.warning(f"Expected pairs format for BGERerankerDataset, got {self.detected_format}")
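# Hedged usage sketch for the specialized classes above (the model id and file
# names are illustrative, not values required by this module):
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("BAAI/bge-m3")
#   emb_ds = BGEM3Dataset("triplets.jsonl", tokenizer=tok)
#   rr_ds = BGERerankerDataset("pairs.jsonl", tokenizer=tok)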
class JointDataset(Dataset):
    """Dataset for joint training of retriever and reranker."""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        retriever_config: Dict,
        reranker_config: Dict,
        is_train: bool = True,
        cache_dir: Optional[str] = None,
    ):
        self.tokenizer = tokenizer
        self.is_train = is_train

        # Create both datasets
        self.retriever_dataset = BGEM3Dataset(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=is_train,
            cache_dir=cache_dir,
            **retriever_config,
        )
        self.reranker_dataset = BGERerankerDataset(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=is_train,
            cache_dir=cache_dir,
            **reranker_config,
        )

    def __len__(self) -> int:
        # The reranker dataset may be longer after pair explosion; use the shorter
        # length so indexing both datasets with the same idx cannot go out of range
        return min(len(self.retriever_dataset), len(self.reranker_dataset))

    def __getitem__(self, idx: int) -> Dict:
        """Get an example for both the retriever and the reranker."""
        return {
            'retriever': self.retriever_dataset[idx],
            'reranker': self.reranker_dataset[idx],
            'item_idx': idx,
        }


# Factory functions for easy dataset creation

def create_embedding_dataset(
    data_path: str,
    tokenizer: PreTrainedTokenizer,
    query_max_length: int = 64,
    passage_max_length: int = 512,
    is_train: bool = True,
) -> BGEDataset:
    """Create a dataset specifically for embedding model training."""
    dataset = BGEDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        query_max_length=query_max_length,
        passage_max_length=passage_max_length,
        is_train=is_train,
        format_type="triplets",
    )
    if dataset.detected_format != 'triplets':
        logger.warning(f"Expected triplets format, got {dataset.detected_format}")
    return dataset


def create_reranker_dataset(
    data_path: str,
    tokenizer: PreTrainedTokenizer,
    query_max_length: int = 64,
    passage_max_length: int = 448,
    is_train: bool = True,
    legacy_support: bool = True,
) -> BGEDataset:
    """Create a dataset specifically for reranker model training."""
    dataset = BGEDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        query_max_length=query_max_length,
        passage_max_length=passage_max_length,
        is_train=is_train,
        legacy_support=legacy_support,
        format_type="auto" if legacy_support else "pairs",
    )
    if dataset.detected_format not in ('pairs', 'legacy_nested'):
        logger.warning(f"Expected pairs format, got {dataset.detected_format}")
    return dataset


def migrate_nested_to_flat(input_file: str, output_file: str) -> str:
    """
    Migrate legacy nested reranker JSONL to the flat pairs format.

    Explodes each (pos[i], label=1) and (neg[j], label=0) into separate lines.
    """
    logger.info(f"Migrating nested reranker format: {input_file} -> {output_file}")

    flat_pairs = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(tqdm(f, desc="Converting to flat pairs"), 1):
            if not line.strip():
                continue
            try:
                item = json.loads(line)
                query = item['query']
                positives = item.get('pos', [])
                negatives = item.get('neg', [])
                pos_scores = item.get('pos_scores', [1.0] * len(positives))
                neg_scores = item.get('neg_scores', [0.0] * len(negatives))

                # Create flat pairs for the positives
                for i, passage in enumerate(positives):
                    flat_pairs.append({
                        'query': query,
                        'passage': passage,
                        'label': 1,
                        'score': pos_scores[i] if i < len(pos_scores) else 1.0,
                        'qid': f"Q{line_no}",
                        'pid': f"P{line_no}_{i}",
                    })

                # Create flat pairs for the negatives
                for i, passage in enumerate(negatives):
                    flat_pairs.append({
                        'query': query,
                        'passage': passage,
                        'label': 0,
                        'score': neg_scores[i] if i < len(neg_scores) else 0.0,
                        'qid': f"Q{line_no}",
                        'pid': f"N{line_no}_{i}",
                    })

            except json.JSONDecodeError as e:
                logger.warning(f"Skipping invalid JSON on line {line_no}: {e}")
                continue

    # Save the flat pairs
    with open(output_file, 'w', encoding='utf-8') as f:
        for pair in flat_pairs:
            f.write(json.dumps(pair, ensure_ascii=False) + '\n')

    logger.info(f"✅ Migrated {len(flat_pairs)} flat pairs to {output_file}")
    return output_file
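# Migration example (file names are hypothetical): an input line
#   {"query": "q", "pos": ["p1"], "neg": ["n1"]}
# in nested.jsonl becomes two lines in pairs.jsonl:
#   {"query": "q", "passage": "p1", "label": 1, "score": 1.0, "qid": "Q1", "pid": "P1_0"}
#   {"query": "q", "passage": "n1", "label": 0, "score": 0.0, "qid": "Q1", "pid": "N1_0"}
#
#   migrate_nested_to_flat("nested.jsonl", "pairs.jsonl")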
logger.info(f"✅ Migrated {len(flat_pairs)} flat pairs to {output_file}") return output_file # Collate functions for different formats def collate_embedding_batch(batch: List[Dict]) -> Dict: """Collate function for embedding model batches""" return { 'query_input_ids': torch.stack([item['query_input_ids'] for item in batch]), 'query_attention_mask': torch.stack([item['query_attention_mask'] for item in batch]), 'passage_input_ids': torch.stack([item['passage_input_ids'] for item in batch]), 'passage_attention_mask': torch.stack([item['passage_attention_mask'] for item in batch]), 'labels': torch.stack([item['labels'] for item in batch]), 'scores': torch.stack([item['scores'] for item in batch]), 'query_text': [item['query_text'] for item in batch], 'passage_texts': [item['passage_texts'] for item in batch], 'item_idx': [item['item_idx'] for item in batch] } def collate_reranker_batch(batch: List[Dict]) -> Dict: """Collate function for reranker model batches""" if batch[0]['format'] == 'pairs': # Flat pairs format return { 'input_ids': torch.stack([item['input_ids'] for item in batch]), 'attention_mask': torch.stack([item['attention_mask'] for item in batch]), 'labels': torch.stack([item['labels'] for item in batch]), 'scores': torch.stack([item['scores'] for item in batch]), 'query_text': [item['query_text'] for item in batch], 'passage_text': [item['passage_text'] for item in batch], 'qid': [item['qid'] for item in batch], 'pid': [item['pid'] for item in batch], 'item_idx': [item['item_idx'] for item in batch] } else: # Legacy nested format return { 'input_ids': torch.stack([item['input_ids'] for item in batch]), 'attention_mask': torch.stack([item['attention_mask'] for item in batch]), 'labels': torch.stack([item['labels'] for item in batch]), 'query_text': [item['query_text'] for item in batch], 'passage_texts': [item['passage_texts'] for item in batch], 'item_idx': [item['item_idx'] for item in batch] }