"""
|
|
Unified BGE dataset pipeline with comprehensive format support
|
|
- Triplets: Query-triplets format (query, pos, neg) for embedding training
|
|
- Pairs: Flat pairs format (query, passage, label) for reranker training
|
|
- Candidates: Unified format (query, candidates) with threshold-based label assignment
|
|
- Legacy: Supports nested formats with automatic conversion to flat pairs
|
|
- Enhanced: Score-weighted sampling, hard negative mining, multiple positives
|
|
"""

import json
import logging
import random
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer
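
# Illustrative JSONL rows for each supported format (hedged examples; the
# field values are invented for demonstration):
#   triplets:   {"query": "q", "pos": ["relevant passage"], "neg": ["hard negative"]}
#   pairs:      {"query": "q", "passage": "some passage", "label": 1}
#   candidates: {"query": "q", "candidates": [{"text": "some passage", "score": 0.83}]}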

logger = logging.getLogger(__name__)


class BGEDataset(Dataset):
    """
    Unified BGE dataset with automatic format detection and enhanced features.

    Supports the triplets, pairs, candidates, and legacy nested formats.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        query_max_length: int = 64,
        passage_max_length: int = 512,
        is_train: bool = True,
        cache_dir: Optional[str] = None,
        legacy_support: bool = True,
        format_type: str = "auto",  # "triplets", "pairs", "candidates", "legacy_nested", "auto"
        candidates_threshold: float = 0.5,  # threshold for candidates-format label assignment
        cache_in_memory: bool = False,  # cache tokenized data in memory for small datasets
        max_cache_size: int = 100000  # maximum number of samples to cache in memory
    ):
        self.tokenizer = tokenizer
        self.query_max_length = query_max_length
        self.passage_max_length = passage_max_length
        self.is_train = is_train
        self.cache_dir = cache_dir  # stored so subclasses and future on-disk caching can use it
        self.legacy_support = legacy_support
        self.format_type = format_type
        self.candidates_threshold = candidates_threshold
        self.cache_in_memory = cache_in_memory
        self.max_cache_size = max_cache_size
        self._tokenized_cache = {}

        logger.info(f"Loading data from {data_path}")
        self.data, self.detected_format = self._load_and_detect_format(data_path)
        logger.info(f"Loaded {len(self.data)} examples in {self.detected_format} format")

        # Pre-tokenize if caching is enabled and the dataset is small enough
        if self.cache_in_memory and len(self.data) <= self.max_cache_size:
            logger.info(f"Pre-tokenizing {len(self.data)} samples for in-memory cache...")
            self._pretokenize_all()
            logger.info("Pre-tokenization complete")

    def _load_and_detect_format(self, data_path: str) -> Tuple[List[Dict], str]:
        """Load data and automatically detect its format."""
        data = []
        format_type = "unknown"
        total_lines = 0

        with open(data_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(tqdm(f, desc="Loading and detecting format"), 1):
                if line.strip():
                    total_lines += 1
                    try:
                        item = json.loads(line)

                        # Detect the format on the first valid item
                        if format_type == "unknown":
                            format_type = self._detect_schema_type(item)
                            logger.info(f"Detected schema: {format_type}")

                        # Validate and process the item
                        if self._validate_item(item, format_type):
                            processed_items = self._process_item(item, format_type)
                            # _process_item may return multiple items
                            # (candidates and legacy formats explode into pairs)
                            if isinstance(processed_items, list):
                                data.extend(processed_items)
                            else:
                                data.append(processed_items)
                        else:
                            logger.warning(f"Invalid item at line {line_num}: {item}")

                    except json.JSONDecodeError as e:
                        logger.warning(f"Invalid JSON at line {line_num}: {e}")
                        continue

        if not data:
            raise ValueError(f"No valid data found in {data_path}")

        # Sanity check: warn when almost every line was rejected
        if total_lines > 10 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(f"Few valid data lines loaded from {data_path}: {len(data)} out of {total_lines}")

        return data, format_type

    def _detect_schema_type(self, item: Dict) -> str:
        """
        Detect the data format from key presence and structure:
        - triplets: has 'pos' and 'neg' keys (for embedding training)
        - pairs: has 'passage' and 'label' keys (for reranker training)
        - candidates: has a 'candidates' key with a list of candidate objects
        - legacy_nested: has 'pos' but no 'label' (backward compatibility)
        """
        try:
            # Use the explicitly set format if provided
            if self.format_type != "auto":
                return self.format_type

            # Ensure the item is a dictionary
            if not isinstance(item, dict):
                raise ValueError(f"Expected dict but got {type(item)}")

            # Extract available keys for debugging
            available_keys = set(item.keys())

            # Check for the candidates format first (most specific)
            if 'candidates' in item:
                # Validate the candidates structure
                candidates = item.get('candidates', [])
                if isinstance(candidates, list) and len(candidates) > 0:
                    # Check whether the first candidate has the expected structure
                    first_candidate = candidates[0]
                    if isinstance(first_candidate, dict) and 'text' in first_candidate:
                        return "candidates"
                    # Also support a simple list of strings as candidates
                    elif isinstance(first_candidate, str):
                        logger.info("Detected candidates format with string list - will convert to dict format")
                        return "candidates"

            # Check for the pairs format (reranker)
            if 'passage' in item and 'label' in item:
                # Additional check for the query field
                if 'query' in item:
                    return "pairs"
                else:
                    logger.warning(f"Found 'passage' and 'label' but missing 'query'. Available keys: {available_keys}")

            # Check for the triplets format (embedding)
            if 'pos' in item:
                # The standard triplets format has both pos and neg
                if 'neg' in item:
                    return "triplets"
                # May be triplets with only positives
                elif 'query' in item and not ('label' in item or 'passage' in item):
                    logger.info("Detected triplets format with only positive samples")
                    return "triplets"
                # Legacy nested format
                elif self.legacy_support and 'query' in item:
                    logger.warning("Detected legacy nested format. Converting to flat pairs for reranker training.")
                    return "legacy_nested"

            # Try to provide a helpful error message
            common_keys = available_keys & {'query', 'pos', 'neg', 'passage', 'label', 'candidates'}
            error_msg = f"Unknown data format. Available keys: {available_keys}"
            if common_keys:
                error_msg += f". Common keys found: {common_keys}"
            else:
                error_msg += ". No common keys found. Expected 'query' with 'pos'/'neg' (triplets), 'passage'/'label' (pairs), or 'candidates' (candidates format)"

            raise ValueError(error_msg)

        except Exception as e:
            logger.error(f"Format detection failed: {e}")
            raise

    def _validate_item(self, item: Dict, format_type: str) -> bool:
        """
        Validate an item against its detected format.

        Returns:
            bool: True if the item is valid, False otherwise
        """
        try:
            # Basic type check
            if not isinstance(item, dict):
                return False

            # Format-specific validation
            if format_type == "triplets":
                # Check required fields
                if not all(field in item for field in ['query', 'pos']):
                    return False
                # Validate the query
                if not isinstance(item['query'], str) or not item['query'].strip():
                    return False
                # Validate the positives
                if not isinstance(item.get('pos', []), list) or len(item['pos']) == 0:
                    return False
                # Check that each positive is a non-empty string
                for pos in item['pos']:
                    if not isinstance(pos, str) or not pos.strip():
                        return False
                # Validate the negatives if present
                if 'neg' in item:
                    if not isinstance(item['neg'], list):
                        return False
                    for neg in item['neg']:
                        if not isinstance(neg, str) or not neg.strip():
                            return False
                return True

            elif format_type == "pairs":
                # Check required fields
                if not all(field in item for field in ['query', 'passage', 'label']):
                    return False
                # Validate the query and passage
                if not isinstance(item['query'], str) or not item['query'].strip():
                    return False
                if not isinstance(item['passage'], str) or not item['passage'].strip():
                    return False
                # Validate the label (numeric 0/1, or a float between 0 and 1)
                label = item['label']
                if isinstance(label, (int, float)):
                    if not (0 <= label <= 1):
                        return False
                else:
                    return False
                return True

            elif format_type == "candidates":
                # Check required fields
                if not ('query' in item and 'candidates' in item):
                    return False
                # Validate the query
                if not isinstance(item['query'], str) or not item['query'].strip():
                    return False
                # Validate the candidates
                if not isinstance(item['candidates'], list) or len(item['candidates']) == 0:
                    return False
                # Check each candidate: plain strings are allowed as well as
                # dicts with 'text', mirroring what _detect_schema_type accepts
                for candidate in item['candidates']:
                    if isinstance(candidate, str):
                        if not candidate.strip():
                            return False
                        continue
                    if not isinstance(candidate, dict):
                        return False
                    if 'text' not in candidate:
                        return False
                    if not isinstance(candidate['text'], str) or not candidate['text'].strip():
                        return False
                    # Validate the score if present
                    if 'score' in candidate:
                        if not isinstance(candidate['score'], (int, float)):
                            return False
                return True

            elif format_type == "legacy_nested":
                # Basic validation for the legacy format
                if not ('query' in item and 'pos' in item):
                    return False
                if not isinstance(item['query'], str) or not item['query'].strip():
                    return False
                return True

            else:
                return False

        except Exception as e:
            logger.debug(f"Validation error: {e}")
            return False

    def _process_item(self, item: Dict, format_type: str) -> Union[Dict, List[Dict]]:
        """Process an item based on its format type; may return multiple items."""
        if format_type == "triplets":
            return self._process_triplets_item(item)
        elif format_type == "pairs":
            return self._process_pairs_item(item)
        elif format_type == "candidates":
            return self._process_candidates_item(item)
        elif format_type == "legacy_nested":
            return self._process_legacy_nested_item(item)
        else:
            raise ValueError(f"Unknown format type: {format_type}")

    def _process_triplets_item(self, item: Dict) -> Dict:
        """Process a triplets-format item for embedding training."""
        return {
            'query': item['query'],
            'pos': item.get('pos', []),
            'neg': item.get('neg', []),
            'pos_scores': item.get('pos_scores', [1.0] * len(item.get('pos', []))),
            'neg_scores': item.get('neg_scores', [0.0] * len(item.get('neg', []))),
            'prompt': item.get('prompt', ''),
            'type': item.get('type', 'normal'),
            'format': 'triplets'
        }

    def _process_pairs_item(self, item: Dict) -> Dict:
        """Process a pairs-format item for reranker training."""
        return {
            'query': item['query'],
            'passage': item['passage'],
            'label': item['label'],
            'score': item.get('score', 1.0 if item['label'] > 0 else 0.0),
            'qid': item.get('qid', ''),
            'pid': item.get('pid', ''),
            'format': 'pairs'
        }

    def _process_candidates_item(self, item: Dict) -> List[Dict]:
        """
        Process a candidates-format item by exploding it into multiple pairs.
        Each candidate becomes a separate (query, passage, label) pair, with
        the label assigned by comparing the score against the threshold.
        """
        query = item['query']
        candidates = item['candidates']
        qid = item.get('qid', '')

        pairs = []
        for i, candidate in enumerate(candidates):
            if isinstance(candidate, dict):
                text = candidate.get('text', '')
                score = candidate.get('score', 0.0)
                pid = candidate.get('pid', f'cand_{i}')
            else:
                # Handle simple string candidates
                text = str(candidate)
                score = 1.0  # default to positive when no score is given
                pid = f'cand_{i}'

            # Assign the label based on the threshold
            label = 1 if score >= self.candidates_threshold else 0

            pairs.append({
                'query': query,
                'passage': text,
                'label': label,
                'score': score,
                'qid': qid,
                'pid': pid,
                'format': 'pairs',      # converted to pairs format
                'source': 'candidates'  # track the original source
            })

        return pairs
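
    # For example (hedged; the values are invented), with candidates_threshold=0.5
    # one input row
    #   {"query": "q", "candidates": [{"text": "a", "score": 0.9},
    #                                 {"text": "b", "score": 0.2}]}
    # explodes into two flat pairs:
    #   {"query": "q", "passage": "a", "label": 1, "score": 0.9, ...}
    #   {"query": "q", "passage": "b", "label": 0, "score": 0.2, ...}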

    def _process_legacy_nested_item(self, item: Dict) -> List[Dict]:
        """
        Process a legacy nested item by converting it to flat pairs for
        reranker training: each pos[i] becomes (query, passage, label=1) and
        each neg[j] becomes (query, passage, label=0).
        """
        query = item['query']
        positives = item.get('pos', [])
        negatives = item.get('neg', [])
        pos_scores = item.get('pos_scores', [1.0] * len(positives))
        neg_scores = item.get('neg_scores', [0.0] * len(negatives))
        qid = item.get('qid', '')

        pairs = []

        # Convert positives
        for i, passage in enumerate(positives):
            pairs.append({
                'query': query,
                'passage': passage,
                'label': 1,
                'score': pos_scores[i] if i < len(pos_scores) else 1.0,
                'qid': qid,
                'pid': f'pos_{i}',
                'format': 'pairs',         # converted to pairs format
                'source': 'legacy_nested'  # track the original source
            })

        # Convert negatives
        for i, passage in enumerate(negatives):
            pairs.append({
                'query': query,
                'passage': passage,
                'label': 0,
                'score': neg_scores[i] if i < len(neg_scores) else 0.0,
                'qid': qid,
                'pid': f'neg_{i}',
                'format': 'pairs',
                'source': 'legacy_nested'
            })

        return pairs

    def _pretokenize_all(self):
        """Pre-tokenize all samples for the in-memory cache."""
        for idx in tqdm(range(len(self.data)), desc="Pre-tokenizing"):
            item = self.data[idx]
            if item['format'] == 'triplets':
                # Pre-tokenize the query (passages are tokenized lazily because
                # they are sampled per epoch during training)
                query_encoding = self.tokenizer(
                    item['query'],
                    max_length=self.query_max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                self._tokenized_cache[f"{idx}_query"] = {
                    'input_ids': query_encoding.input_ids.squeeze(0),
                    'attention_mask': query_encoding.attention_mask.squeeze(0)
                }
            elif item['format'] == 'pairs':
                # Pre-tokenize the query-passage pair using the tokenizer's own
                # pair encoding with the same max_length as _get_pairs_item, so
                # cached and uncached samples are tokenized identically
                encoding = self.tokenizer(
                    item['query'],
                    item['passage'],
                    max_length=self.passage_max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                self._tokenized_cache[idx] = {
                    'input_ids': encoding.input_ids.squeeze(0),
                    'attention_mask': encoding.attention_mask.squeeze(0)
                }

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict:
        """Get an item based on its detected format."""
        item = self.data[idx]

        if item['format'] == 'triplets':
            return self._get_triplets_item(item, idx)
        elif item['format'] == 'pairs':
            return self._get_pairs_item(item, idx)
        else:
            raise ValueError(f"Unknown format: {item['format']}")

    def _get_triplets_item(self, item: Dict, idx: int) -> Dict:
        """Get a triplets training example for the embedding model (contrastive learning)."""
        query = item['query']
        positives = item['pos']
        negatives = item['neg']
        pos_scores = item['pos_scores']
        neg_scores = item['neg_scores']
        prompt = item['prompt']

        # Apply the prompt if provided
        if prompt and self.is_train:
            query = f"{prompt}{query}"

        # Use the cached query tokenization only when no prompt was applied,
        # since the cache was built from the raw query text
        if not (prompt and self.is_train) and f"{idx}_query" in self._tokenized_cache:
            query_encodings = self._tokenized_cache[f"{idx}_query"]
        else:
            # Tokenize the query
            encoding = self.tokenizer(
                query,
                max_length=self.query_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            query_encodings = {
                'input_ids': encoding.input_ids.squeeze(0),
                'attention_mask': encoding.attention_mask.squeeze(0)
            }

        # Enhanced sampling for contrastive learning
        if self.is_train:
            # Sample multiple positives for richer contrastive signals,
            # sampling indices so scores stay aligned with their passages
            num_positives = min(2, len(positives))
            pos_indices = random.sample(range(len(positives)), num_positives) if positives else []
            sampled_positives = [positives[i] for i in pos_indices]
            sampled_pos_scores = [pos_scores[i] if i < len(pos_scores) else 1.0 for i in pos_indices]

            # Hard negative mining with score-based sampling
            num_negatives = min(6, len(negatives))
            if negatives and neg_scores:
                # Weight by scores for hard negative selection; fall back to
                # uniform probabilities when the scores sum to zero, since
                # np.random.choice requires p to sum to 1
                neg_weights = np.clip(np.array(neg_scores, dtype=np.float64), 0, None)
                if neg_weights.sum() > 0:
                    neg_weights = neg_weights / neg_weights.sum()
                else:
                    neg_weights = np.full(len(negatives), 1.0 / len(negatives))
                sampled_neg_indices = np.random.choice(
                    len(negatives), size=min(num_negatives, len(negatives)),
                    replace=False, p=neg_weights
                )
                sampled_negatives = [negatives[i] for i in sampled_neg_indices]
                sampled_neg_scores = [neg_scores[i] for i in sampled_neg_indices]
            else:
                sampled_negatives = random.sample(negatives, num_negatives) if negatives else []
                sampled_neg_scores = [0.0] * len(sampled_negatives)

            passages = sampled_positives + sampled_negatives
            labels = [1] * len(sampled_positives) + [0] * len(sampled_negatives)
            scores = sampled_pos_scores + sampled_neg_scores
        else:
            # Use everything for evaluation
            passages = positives + negatives
            labels = [1] * len(positives) + [0] * len(negatives)
            scores = pos_scores + neg_scores

        # Keep a fixed group size (max 8 passages per item) so batches stack cleanly
        max_passages = 8
        passages = passages[:max_passages]
        labels = labels[:max_passages]
        scores = scores[:max_passages]

        # Tokenize the passages
        passage_encodings = []
        for passage in passages:
            encoding = self.tokenizer(
                passage,
                max_length=self.passage_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            passage_encodings.append({
                'input_ids': encoding.input_ids.squeeze(0),
                'attention_mask': encoding.attention_mask.squeeze(0)
            })

        # Pad to a consistent size with empty passages (label -1)
        while len(passage_encodings) < max_passages:
            passage_encodings.append({
                'input_ids': torch.zeros(self.passage_max_length, dtype=torch.long),
                'attention_mask': torch.zeros(self.passage_max_length, dtype=torch.long)
            })
            labels.append(-1)
            scores.append(0.0)

        # Stack the tensors
        passage_input_ids = torch.stack([p['input_ids'] for p in passage_encodings])
        passage_attention_mask = torch.stack([p['attention_mask'] for p in passage_encodings])

        return {
            'query_input_ids': query_encodings['input_ids'],
            'query_attention_mask': query_encodings['attention_mask'],
            'passage_input_ids': passage_input_ids,
            'passage_attention_mask': passage_attention_mask,
            'labels': torch.tensor(labels, dtype=torch.long),
            'scores': torch.tensor(scores, dtype=torch.float),
            'query_text': query,
            'passage_texts': passages + [''] * (max_passages - len(passages)),
            'item_idx': idx,
            'format': 'triplets',
            # Backward compatibility
            'positives': positives,
            'negatives': negatives
        }
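
    # Score-weighted hard-negative sampling in miniature (a hedged, standalone
    # sketch of the pattern used above; the scores are invented):
    #
    #   neg_scores = [0.9, 0.4, 0.1]     # higher score = harder negative
    #   w = np.array(neg_scores)
    #   w = w / w.sum()                  # -> [0.643, 0.286, 0.071]
    #   picked = np.random.choice(3, size=2, replace=False, p=w)
    #   # indices 0 and 1 are the most likely draws, biasing each batch
    #   # toward the hardest negatives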

    def _get_pairs_item(self, item: Dict, idx: int) -> Dict:
        """Get a pairs training example for the reranker (cross-encoder training)."""
        query = item['query']
        passage = item['passage']
        label = item['label']
        score = item['score']

        # Use the cached tokenization if available
        if idx in self._tokenized_cache:
            cached = self._tokenized_cache[idx]
            input_ids = cached['input_ids']
            attention_mask = cached['attention_mask']
        else:
            # Tokenize the query-passage pair
            encoding = self.tokenizer(
                query,
                passage,
                max_length=self.passage_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            input_ids = encoding.input_ids.squeeze(0)
            attention_mask = encoding.attention_mask.squeeze(0)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(label, dtype=torch.long),
            'scores': torch.tensor(score, dtype=torch.float),
            'query_text': query,
            'passage_text': passage,
            'qid': item.get('qid', ''),
            'pid': item.get('pid', ''),
            'item_idx': idx,
            'format': 'pairs',
            'source': item.get('source', 'original')
        }


# Specialized dataset classes for backward compatibility
class BGEM3Dataset(BGEDataset):
    """BGE-M3 specialized dataset (embedding model) with enhanced contrastive learning."""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        query_max_length: int = 64,
        passage_max_length: int = 512,
        is_train: bool = True,
        train_group_size: int = 8,
        cache_dir: Optional[str] = None
    ):
        # Force the embedding (triplets) format
        super().__init__(
            data_path=data_path,
            tokenizer=tokenizer,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            is_train=is_train,
            cache_dir=cache_dir,
            format_type="triplets"
        )
        self.train_group_size = train_group_size

        # Warn if the wrong format was detected
        if self.detected_format != 'triplets':
            logger.warning(f"Expected triplets format for BGEM3Dataset, got {self.detected_format}")


class BGERerankerDataset(BGEDataset):
    """BGE-Reranker specialized dataset (cross-encoder model) with flat pairs support."""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        query_max_length: int = 64,
        passage_max_length: int = 448,
        is_train: bool = True,
        train_group_size: int = 16,
        cache_dir: Optional[str] = None,
        legacy_support: bool = True
    ):
        # Force the pairs format (or allow legacy auto-detection)
        format_type = "auto" if legacy_support else "pairs"
        super().__init__(
            data_path=data_path,
            tokenizer=tokenizer,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            is_train=is_train,
            cache_dir=cache_dir,
            legacy_support=legacy_support,
            format_type=format_type
        )
        self.train_group_size = train_group_size

        # Warn if the wrong format was detected
        if self.detected_format not in ('pairs', 'legacy_nested'):
            logger.warning(f"Expected pairs format for BGERerankerDataset, got {self.detected_format}")


class JointDataset(Dataset):
    """
    Dataset for joint training of a retriever and a reranker.

    Both views are built from the same file; __len__ follows the retriever
    view, and the reranker item is looked up by the same index.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        retriever_config: Dict,
        reranker_config: Dict,
        is_train: bool = True,
        cache_dir: Optional[str] = None
    ):
        self.tokenizer = tokenizer
        self.is_train = is_train

        # Create both datasets
        self.retriever_dataset = BGEM3Dataset(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=is_train,
            cache_dir=cache_dir,
            **retriever_config
        )

        self.reranker_dataset = BGERerankerDataset(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=is_train,
            cache_dir=cache_dir,
            **reranker_config
        )

    def __len__(self) -> int:
        return len(self.retriever_dataset)

    def __getitem__(self, idx: int) -> Dict:
        """Get the example for both the retriever and the reranker."""
        retriever_item = self.retriever_dataset[idx]
        reranker_item = self.reranker_dataset[idx]

        return {
            'retriever': retriever_item,
            'reranker': reranker_item,
            'item_idx': idx
        }
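
# A hedged usage sketch (the config keys are assumptions drawn from the
# constructor signatures above, not a fixed contract):
#
#   joint = JointDataset(
#       "train.jsonl", tokenizer,
#       retriever_config={"passage_max_length": 512},
#       reranker_config={"passage_max_length": 448},
#   )
#   sample = joint[0]
#   sample['retriever']  # triplets-style tensors for the embedding model
#   sample['reranker']   # reranker-view tensors for the cross-encoder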


# Factory functions for easy dataset creation
def create_embedding_dataset(
    data_path: str,
    tokenizer: PreTrainedTokenizer,
    query_max_length: int = 64,
    passage_max_length: int = 512,
    is_train: bool = True
) -> BGEDataset:
    """Create a dataset specifically for embedding model training."""
    dataset = BGEDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        query_max_length=query_max_length,
        passage_max_length=passage_max_length,
        is_train=is_train,
        format_type="triplets"
    )

    if dataset.detected_format != 'triplets':
        logger.warning(f"Expected triplets format, got {dataset.detected_format}")

    return dataset


def create_reranker_dataset(
    data_path: str,
    tokenizer: PreTrainedTokenizer,
    query_max_length: int = 64,
    passage_max_length: int = 448,
    is_train: bool = True,
    legacy_support: bool = True
) -> BGEDataset:
    """Create a dataset specifically for reranker model training."""
    dataset = BGEDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        query_max_length=query_max_length,
        passage_max_length=passage_max_length,
        is_train=is_train,
        legacy_support=legacy_support,
        format_type="auto" if legacy_support else "pairs"
    )

    if dataset.detected_format not in ('pairs', 'legacy_nested'):
        logger.warning(f"Expected pairs format, got {dataset.detected_format}")

    return dataset


def migrate_nested_to_flat(input_file: str, output_file: str) -> str:
    """
    Migrate legacy nested reranker JSONL to the flat pairs format.
    Explodes each (pos[i], label=1) and (neg[j], label=0) into separate lines.
    """
    logger.info(f"Migrating nested reranker format: {input_file} -> {output_file}")

    flat_pairs = []

    with open(input_file, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(tqdm(f, desc="Converting to flat pairs"), 1):
            if line.strip():
                try:
                    item = json.loads(line)
                    query = item['query']
                    positives = item.get('pos', [])
                    negatives = item.get('neg', [])
                    pos_scores = item.get('pos_scores', [1.0] * len(positives))
                    neg_scores = item.get('neg_scores', [0.0] * len(negatives))

                    # Create flat pairs for the positives
                    for i, passage in enumerate(positives):
                        flat_pairs.append({
                            'query': query,
                            'passage': passage,
                            'label': 1,
                            'score': pos_scores[i] if i < len(pos_scores) else 1.0,
                            'qid': f"Q{line_no}",
                            'pid': f"P{line_no}_{i}"
                        })

                    # Create flat pairs for the negatives
                    for i, passage in enumerate(negatives):
                        flat_pairs.append({
                            'query': query,
                            'passage': passage,
                            'label': 0,
                            'score': neg_scores[i] if i < len(neg_scores) else 0.0,
                            'qid': f"Q{line_no}",
                            'pid': f"N{line_no}_{i}"
                        })

                except json.JSONDecodeError as e:
                    logger.warning(f"Skipping invalid JSON on line {line_no}: {e}")
                    continue

    # Save the flat pairs
    with open(output_file, 'w', encoding='utf-8') as f:
        for pair in flat_pairs:
            f.write(json.dumps(pair, ensure_ascii=False) + '\n')

    logger.info(f"✅ Migrated {len(flat_pairs)} flat pairs to {output_file}")
    return output_file
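
# Example (hedged; the file names are placeholders). One nested line such as
#   {"query": "q", "pos": ["p1"], "neg": ["n1"]}
# becomes two flat lines in the output file:
#   {"query": "q", "passage": "p1", "label": 1, "score": 1.0, "qid": "Q1", "pid": "P1_0"}
#   {"query": "q", "passage": "n1", "label": 0, "score": 0.0, "qid": "Q1", "pid": "N1_0"}
#
#   migrate_nested_to_flat("reranker_nested.jsonl", "reranker_pairs.jsonl")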


# Collate functions for the different formats
def collate_embedding_batch(batch: List[Dict]) -> Dict:
    """Collate function for embedding model batches."""
    return {
        'query_input_ids': torch.stack([item['query_input_ids'] for item in batch]),
        'query_attention_mask': torch.stack([item['query_attention_mask'] for item in batch]),
        'passage_input_ids': torch.stack([item['passage_input_ids'] for item in batch]),
        'passage_attention_mask': torch.stack([item['passage_attention_mask'] for item in batch]),
        'labels': torch.stack([item['labels'] for item in batch]),
        'scores': torch.stack([item['scores'] for item in batch]),
        'query_text': [item['query_text'] for item in batch],
        'passage_texts': [item['passage_texts'] for item in batch],
        'item_idx': [item['item_idx'] for item in batch]
    }


def collate_reranker_batch(batch: List[Dict]) -> Dict:
    """Collate function for reranker model batches."""
    if batch[0]['format'] == 'pairs':
        # Flat pairs format (includes converted candidates and legacy data)
        return {
            'input_ids': torch.stack([item['input_ids'] for item in batch]),
            'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
            'labels': torch.stack([item['labels'] for item in batch]),
            'scores': torch.stack([item['scores'] for item in batch]),
            'query_text': [item['query_text'] for item in batch],
            'passage_text': [item['passage_text'] for item in batch],
            'qid': [item['qid'] for item in batch],
            'pid': [item['pid'] for item in batch],
            'item_idx': [item['item_idx'] for item in batch]
        }
    else:
        # Fallback for non-pairs items, kept for backward compatibility; it
        # assumes the items expose flat 'input_ids'/'attention_mask' tensors
        return {
            'input_ids': torch.stack([item['input_ids'] for item in batch]),
            'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
            'labels': torch.stack([item['labels'] for item in batch]),
            'query_text': [item['query_text'] for item in batch],
            'passage_texts': [item['passage_texts'] for item in batch],
            'item_idx': [item['item_idx'] for item in batch]
        }
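

# A minimal end-to-end sketch (hedged: the checkpoint name and file path are
# placeholders, not part of this module's contract):
#
#   from torch.utils.data import DataLoader
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
#   train_set = create_reranker_dataset("train_pairs.jsonl", tokenizer)
#   loader = DataLoader(train_set, batch_size=16, shuffle=True,
#                       collate_fn=collate_reranker_batch)
#   batch = next(iter(loader))  # input_ids, attention_mask, labels, scores, ...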