# bge_finetune/data/dataset.py
"""
Unified BGE dataset pipeline with comprehensive format support
- Triplets: Query-triplets format (query, pos, neg) for embedding training
- Pairs: Flat pairs format (query, passage, label) for reranker training
- Candidates: Unified format (query, candidates) with threshold-based label assignment
- Legacy: Supports nested formats with automatic conversion to flat pairs
- Enhanced: Score-weighted sampling, hard negative mining, multiple positives
"""
import json
import random
from typing import Dict, List, Optional, Union, Tuple
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
import logging
from tqdm import tqdm
import numpy as np
logger = logging.getLogger(__name__)
class BGEDataset(Dataset):
"""
Unified BGE dataset with automatic format detection and enhanced features
Supports: triplets, pairs, candidates, and legacy nested formats
"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
query_max_length: int = 64,
passage_max_length: int = 512,
is_train: bool = True,
cache_dir: Optional[str] = None,
legacy_support: bool = True,
format_type: str = "auto", # "triplets", "pairs", "candidates", "legacy_nested", "auto"
candidates_threshold: float = 0.5, # Threshold for candidates format label assignment
cache_in_memory: bool = False, # Cache tokenized data in memory for small datasets
max_cache_size: int = 100000 # Maximum number of samples to cache in memory
):
self.tokenizer = tokenizer
self.query_max_length = query_max_length
self.passage_max_length = passage_max_length
self.is_train = is_train
self.legacy_support = legacy_support
self.format_type = format_type
self.candidates_threshold = candidates_threshold
self.cache_in_memory = cache_in_memory
self.max_cache_size = max_cache_size
self._tokenized_cache = {}
logger.info(f"Loading data from {data_path}")
self.data, self.detected_format = self._load_and_detect_format(data_path)
logger.info(f"Loaded {len(self.data)} examples in {self.detected_format} format")
# Pre-tokenize data if caching is enabled and dataset is small enough
if self.cache_in_memory and len(self.data) <= self.max_cache_size:
logger.info(f"Pre-tokenizing {len(self.data)} samples for in-memory cache...")
self._pretokenize_all()
logger.info("Pre-tokenization complete")
def _load_and_detect_format(self, data_path: str) -> Tuple[List[Dict], str]:
"""Load data and automatically detect format with improved detection logic"""
data = []
format_type = "unknown"
total_lines = 0
with open(data_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(tqdm(f, desc="Loading and detecting format"), 1):
if line.strip():
total_lines += 1
try:
item = json.loads(line)
# Detect format on first valid item
if format_type == "unknown":
format_type = self._detect_schema_type(item)
logger.info(f"Detected schema: {format_type}")
# Validate and process item
if self._validate_item(item, format_type):
processed_items = self._process_item(item, format_type)
                            # _process_item may return multiple items (e.g., for candidates format)
if isinstance(processed_items, list):
data.extend(processed_items)
else:
data.append(processed_items)
else:
logger.warning(f"Invalid item at line {line_num}: {item}")
except json.JSONDecodeError as e:
logger.warning(f"Invalid JSON at line {line_num}: {e}")
continue
if not data:
raise ValueError(f"No valid data found in {data_path}")
        # Sanity check: warn when only a small fraction of input lines survived
        # parsing and validation (candidates format may legitimately produce
        # more processed examples than input lines)
        if total_lines > 10 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(f"Only {len(data)} valid examples loaded from {total_lines} lines in {data_path}")
return data, format_type
def _detect_schema_type(self, item: Dict) -> str:
"""
Enhanced format detection based on key presence and structure:
- triplets: has 'pos' and 'neg' keys (for embedding training)
- pairs: has 'passage' and 'label' keys (for reranker training)
- candidates: has 'candidates' key with list of candidate objects
- legacy_nested: has 'pos' but no 'label' (backward compatibility)
"""
try:
# Use explicitly set format if provided
if self.format_type != "auto":
return self.format_type
# Ensure item is a dictionary
if not isinstance(item, dict):
raise ValueError(f"Expected dict but got {type(item)}")
# Extract available keys for debugging
available_keys = set(item.keys())
# Check for candidates format first (most specific)
if 'candidates' in item:
# Validate candidates structure
candidates = item.get('candidates', [])
if isinstance(candidates, list) and len(candidates) > 0:
# Check if first candidate has expected structure
first_candidate = candidates[0]
if isinstance(first_candidate, dict) and 'text' in first_candidate:
return "candidates"
# Also support simple string list as candidates
elif isinstance(first_candidate, str):
logger.info("Detected candidates format with string list - will convert to dict format")
return "candidates"
# Check for pairs format (reranker)
if 'passage' in item and 'label' in item:
# Additional check for query field
if 'query' in item:
return "pairs"
else:
logger.warning(f"Found 'passage' and 'label' but missing 'query'. Available keys: {available_keys}")
# Check for triplets format (embedding)
if 'pos' in item:
# Standard triplets format has both pos and neg
if 'neg' in item:
return "triplets"
# Might be triplets with only positives
elif 'query' in item and not ('label' in item or 'passage' in item):
logger.info("Detected triplets format with only positive samples")
return "triplets"
# Legacy nested format
elif self.legacy_support and 'query' in item:
logger.warning("Detected legacy nested format. Converting to flat pairs for reranker training.")
return "legacy_nested"
# Try to provide helpful error message
common_keys = available_keys & {'query', 'pos', 'neg', 'passage', 'label', 'candidates'}
error_msg = f"Unknown data format. Available keys: {available_keys}"
if common_keys:
error_msg += f", Common keys found: {common_keys}"
else:
error_msg += ". No common keys found. Expected 'query' with 'pos'/'neg' (triplets), 'passage'/'label' (pairs), or 'candidates' (candidates format)"
raise ValueError(error_msg)
except Exception as e:
logger.error(f"Format detection failed: {e}")
raise
def _validate_item(self, item: Dict, format_type: str) -> bool:
"""
Validate item based on detected format with comprehensive checks
Returns:
bool: True if item is valid, False otherwise
"""
try:
# Basic type check
if not isinstance(item, dict):
return False
# Format-specific validation
if format_type == "triplets":
# Check required fields
if not all(field in item for field in ['query', 'pos']):
return False
# Validate query
if not isinstance(item['query'], str) or not item['query'].strip():
return False
# Validate positives
if not isinstance(item.get('pos', []), list) or len(item['pos']) == 0:
return False
# Check each positive is a non-empty string
for pos in item['pos']:
if not isinstance(pos, str) or not pos.strip():
return False
# Validate negatives if present
if 'neg' in item:
if not isinstance(item['neg'], list):
return False
for neg in item['neg']:
if not isinstance(neg, str) or not neg.strip():
return False
return True
elif format_type == "pairs":
# Check required fields
if not all(field in item for field in ['query', 'passage', 'label']):
return False
# Validate query and passage
if not isinstance(item['query'], str) or not item['query'].strip():
return False
if not isinstance(item['passage'], str) or not item['passage'].strip():
return False
# Validate label (should be numeric 0 or 1, or float between 0 and 1)
label = item['label']
if isinstance(label, (int, float)):
if not (0 <= label <= 1):
return False
else:
return False
return True
elif format_type == "candidates":
# Check required fields
if not ('query' in item and 'candidates' in item):
return False
# Validate query
if not isinstance(item['query'], str) or not item['query'].strip():
return False
# Validate candidates
if not isinstance(item['candidates'], list) or len(item['candidates']) == 0:
return False
# Check each candidate
for i, candidate in enumerate(item['candidates']):
if not isinstance(candidate, dict):
return False
if 'text' not in candidate:
return False
if not isinstance(candidate['text'], str) or not candidate['text'].strip():
return False
# Validate score if present
if 'score' in candidate:
score = candidate['score']
if not isinstance(score, (int, float)):
return False
return True
elif format_type == "legacy_nested":
# Basic validation for legacy format
if not ('query' in item and 'pos' in item):
return False
if not isinstance(item['query'], str) or not item['query'].strip():
return False
return True
else:
return False
except Exception as e:
logger.debug(f"Validation error: {e}")
return False
def _process_item(self, item: Dict, format_type: str) -> Union[Dict, List[Dict]]:
"""Process item based on format type - may return multiple items for candidates"""
if format_type == "triplets":
return self._process_triplets_item(item)
elif format_type == "pairs":
return self._process_pairs_item(item)
elif format_type == "candidates":
return self._process_candidates_item(item)
elif format_type == "legacy_nested":
return self._process_legacy_nested_item(item)
else:
raise ValueError(f"Unknown format type: {format_type}")
    def _process_triplets_item(self, item: Dict) -> Dict:
        """Process triplets format item for embedding training"""
        pos = item.get('pos', [])
        neg = item.get('neg', [])
        # Pad/truncate optional score lists so they always align with pos/neg
        pos_scores = (item.get('pos_scores', []) + [1.0] * len(pos))[:len(pos)]
        neg_scores = (item.get('neg_scores', []) + [0.0] * len(neg))[:len(neg)]
        processed = {
            'query': item['query'],
            'pos': pos,
            'neg': neg,
            'pos_scores': pos_scores,
            'neg_scores': neg_scores,
            'prompt': item.get('prompt', ''),
            'type': item.get('type', 'normal'),
            'format': 'triplets'
        }
        return processed
def _process_pairs_item(self, item: Dict) -> Dict:
"""Process pairs format item for reranker training"""
processed = {
'query': item['query'],
'passage': item['passage'],
'label': item['label'],
'score': item.get('score', 1.0 if item['label'] > 0 else 0.0),
'qid': item.get('qid', ''),
'pid': item.get('pid', ''),
'format': 'pairs'
}
return processed
def _process_candidates_item(self, item: Dict) -> List[Dict]:
"""
Process candidates format item - explode into multiple pairs
Each candidate becomes a separate (query, passage, label) pair
Label is assigned based on score threshold
"""
query = item['query']
candidates = item['candidates']
qid = item.get('qid', '')
pairs = []
for i, candidate in enumerate(candidates):
if isinstance(candidate, dict):
text = candidate.get('text', '')
score = candidate.get('score', 0.0)
pid = candidate.get('pid', f'cand_{i}')
else:
# Handle simple string candidates
text = str(candidate)
score = 1.0 # Default to positive if no score
pid = f'cand_{i}'
# Assign label based on threshold
label = 1 if score >= self.candidates_threshold else 0
pair = {
'query': query,
'passage': text,
'label': label,
'score': score,
'qid': qid,
'pid': pid,
'format': 'pairs', # Convert to pairs format
'source': 'candidates' # Track original source
}
pairs.append(pair)
return pairs
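    # For example, with candidates_threshold=0.5 a hypothetical input line
    #   {"query": "q", "candidates": [{"text": "a", "score": 0.9},
    #                                 {"text": "b", "score": 0.1}]}
    # explodes into two pairs: ("q", "a", label=1) and ("q", "b", label=0).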
def _process_legacy_nested_item(self, item: Dict) -> List[Dict]:
"""
Process legacy nested format - convert to flat pairs for reranker training
Each pos[i] becomes (query, passage, label=1)
Each neg[j] becomes (query, passage, label=0)
"""
query = item['query']
positives = item.get('pos', [])
negatives = item.get('neg', [])
pos_scores = item.get('pos_scores', [1.0] * len(positives))
neg_scores = item.get('neg_scores', [0.0] * len(negatives))
qid = item.get('qid', '')
pairs = []
# Convert positives
for i, passage in enumerate(positives):
score = pos_scores[i] if i < len(pos_scores) else 1.0
pair = {
'query': query,
'passage': passage,
'label': 1,
'score': score,
'qid': qid,
'pid': f'pos_{i}',
'format': 'pairs', # Convert to pairs format
'source': 'legacy_nested' # Track original source
}
pairs.append(pair)
# Convert negatives
for i, passage in enumerate(negatives):
score = neg_scores[i] if i < len(neg_scores) else 0.0
pair = {
'query': query,
'passage': passage,
'label': 0,
'score': score,
'qid': qid,
'pid': f'neg_{i}',
'format': 'pairs', # Convert to pairs format
'source': 'legacy_nested' # Track original source
}
pairs.append(pair)
return pairs
    def _pretokenize_all(self):
        """Pre-tokenize all samples for in-memory caching"""
        for idx in tqdm(range(len(self.data)), desc="Pre-tokenizing"):
            item = self.data[idx]
            if item['format'] == 'triplets':
                # Pre-tokenize the query, applying the prompt exactly as
                # _get_triplets_item does so cached and on-the-fly encodings match
                query = item['query']
                if item.get('prompt') and self.is_train:
                    query = f"{item['prompt']}{query}"
                query_encoding = self.tokenizer(
                    query,
                    max_length=self.query_max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                self._tokenized_cache[f"{idx}_query"] = {
                    'input_ids': query_encoding.input_ids.squeeze(0),
                    'attention_mask': query_encoding.attention_mask.squeeze(0)
                }
            elif item['format'] == 'pairs':
                # Pre-tokenize the query-passage pair with the same sentence-pair
                # encoding used in _get_pairs_item, so cached and uncached
                # samples are tokenized identically
                encoding = self.tokenizer(
                    item['query'],
                    item['passage'],
                    max_length=self.passage_max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                self._tokenized_cache[idx] = {
                    'input_ids': encoding.input_ids.squeeze(0),
                    'attention_mask': encoding.attention_mask.squeeze(0)
                }
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Dict:
"""Get item based on detected format"""
item = self.data[idx]
if item['format'] == 'triplets':
return self._get_triplets_item(item, idx)
elif item['format'] == 'pairs':
return self._get_pairs_item(item, idx)
else:
raise ValueError(f"Unknown format: {item['format']}")
def _get_triplets_item(self, item: Dict, idx: int) -> Dict:
"""Get triplets training example for embedding model (contrastive learning)"""
query = item['query']
positives = item['pos']
negatives = item['neg']
pos_scores = item['pos_scores']
neg_scores = item['neg_scores']
prompt = item['prompt']
# Apply prompt if provided
if prompt and self.is_train:
query = f"{prompt}{query}"
# Check if we have cached tokenized query
if f"{idx}_query" in self._tokenized_cache:
query_encodings = self._tokenized_cache[f"{idx}_query"]
else:
# Tokenize query
query_encodings = self.tokenizer(
query,
max_length=self.query_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
query_encodings = {
'input_ids': query_encodings.input_ids.squeeze(0),
'attention_mask': query_encodings.attention_mask.squeeze(0)
}
        # Enhanced sampling for contrastive learning
        if self.is_train:
            # Sample positive indices so passages and scores stay aligned
            num_positives = min(2, len(positives))
            pos_indices = random.sample(range(len(positives)), num_positives) if positives else []
            sampled_positives = [positives[i] for i in pos_indices]
            sampled_pos_scores = [pos_scores[i] for i in pos_indices]
            # Hard negative mining with score-based sampling
            num_negatives = min(6, len(negatives))
            if negatives:
                # Weight by scores for hard negative selection; the epsilon keeps
                # every weight positive so np.random.choice accepts p
                neg_weights = np.clip(np.array(neg_scores, dtype=np.float64), 0.0, None) + 1e-8
                neg_weights = neg_weights / neg_weights.sum()
                sampled_neg_indices = np.random.choice(
                    len(negatives), size=num_negatives, replace=False, p=neg_weights
                )
                sampled_negatives = [negatives[i] for i in sampled_neg_indices]
                sampled_neg_scores = [neg_scores[i] for i in sampled_neg_indices]
            else:
                sampled_negatives = []
                sampled_neg_scores = []
            passages = sampled_positives + sampled_negatives
            labels = [1] * len(sampled_positives) + [0] * len(sampled_negatives)
            scores = sampled_pos_scores + sampled_neg_scores
else:
# Use all for evaluation
passages = positives + negatives
labels = [1] * len(positives) + [0] * len(negatives)
scores = pos_scores + neg_scores
# Tokenize passages
passage_encodings = []
for passage in passages:
encoding = self.tokenizer(
passage,
max_length=self.passage_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
passage_encodings.append({
'input_ids': torch.as_tensor(encoding['input_ids']).squeeze(0),
'attention_mask': torch.as_tensor(encoding['attention_mask']).squeeze(0)
})
        # Truncate, then pad, so every example carries exactly max_passages
        # passages (keeps stacked tensor shapes consistent across the batch)
        max_passages = 8
        passage_encodings = passage_encodings[:max_passages]
        passages = passages[:max_passages]
        labels = labels[:max_passages]
        scores = scores[:max_passages]
        while len(passage_encodings) < max_passages:
            passage_encodings.append({
                'input_ids': torch.zeros(self.passage_max_length, dtype=torch.long),
                'attention_mask': torch.zeros(self.passage_max_length, dtype=torch.long)
            })
            labels.append(-1)
            scores.append(0.0)
# Stack tensors
passage_input_ids = torch.stack([p['input_ids'] for p in passage_encodings])
passage_attention_mask = torch.stack([p['attention_mask'] for p in passage_encodings])
return {
'query_input_ids': query_encodings['input_ids'],
'query_attention_mask': query_encodings['attention_mask'],
'passage_input_ids': passage_input_ids,
'passage_attention_mask': passage_attention_mask,
'labels': torch.tensor(labels, dtype=torch.long),
'scores': torch.tensor(scores, dtype=torch.float),
'query_text': query,
'passage_texts': passages + [''] * (max_passages - len(passages)),
'item_idx': idx,
'format': 'triplets',
# Backward compatibility
'positives': positives,
'negatives': negatives
}
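    # Shape summary for _get_triplets_item: query tensors are
    # (query_max_length,); passage tensors are (8, passage_max_length);
    # labels/scores have length 8, with -1 / 0.0 marking padded slots.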
def _get_pairs_item(self, item: Dict, idx: int) -> Dict:
"""Get pairs training example for reranker (cross-encoder training)"""
query = item['query']
passage = item['passage']
label = item['label']
score = item['score']
# Check if we have cached tokenized data
if idx in self._tokenized_cache:
encoding = self._tokenized_cache[idx]
return {
'input_ids': encoding['input_ids'],
'attention_mask': encoding['attention_mask'],
'labels': torch.tensor(label, dtype=torch.long),
'scores': torch.tensor(score, dtype=torch.float),
'query_text': query,
'passage_text': passage,
'qid': item.get('qid', ''),
'pid': item.get('pid', ''),
'item_idx': idx,
'format': 'pairs',
'source': item.get('source', 'original')
}
# Tokenize query-passage pair if not cached
encoding = self.tokenizer(
query,
passage,
max_length=self.passage_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'input_ids': torch.as_tensor(encoding['input_ids']).squeeze(0),
'attention_mask': torch.as_tensor(encoding['attention_mask']).squeeze(0),
'labels': torch.tensor(label, dtype=torch.long),
'scores': torch.tensor(score, dtype=torch.float),
'query_text': query,
'passage_text': passage,
'qid': item['qid'],
'pid': item['pid'],
'item_idx': idx,
'format': 'pairs',
'source': item.get('source', 'original')
}
# Specialized dataset classes for backward compatibility
class BGEM3Dataset(BGEDataset):
"""BGE-M3 specialized dataset (embedding model) with enhanced contrastive learning"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
query_max_length: int = 64,
passage_max_length: int = 512,
is_train: bool = True,
train_group_size: int = 8,
cache_dir: Optional[str] = None
):
# Force embedding format
super().__init__(
data_path=data_path,
tokenizer=tokenizer,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
is_train=is_train,
cache_dir=cache_dir,
format_type="triplets"
)
self.train_group_size = train_group_size
# Warn if wrong format detected
        if self.detected_format != 'triplets':
logger.warning(f"Expected triplets format for BGEM3Dataset, got {self.detected_format}")
class BGERerankerDataset(BGEDataset):
"""BGE-Reranker specialized dataset (cross-encoder model) with flat pairs support"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
query_max_length: int = 64,
passage_max_length: int = 448,
is_train: bool = True,
train_group_size: int = 16,
cache_dir: Optional[str] = None,
legacy_support: bool = True
):
# Force reranker format (or allow legacy)
format_type = "auto" if legacy_support else "pairs"
super().__init__(
data_path=data_path,
tokenizer=tokenizer,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
is_train=is_train,
cache_dir=cache_dir,
legacy_support=legacy_support,
format_type=format_type
)
self.train_group_size = train_group_size
# Warn if wrong format detected
if self.detected_format not in ['pairs', 'legacy_nested']:
logger.warning(f"Expected pairs format for BGERerankerDataset, got {self.detected_format}")
class JointDataset(Dataset):
"""Dataset for joint training of retriever and reranker"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
retriever_config: Dict,
reranker_config: Dict,
is_train: bool = True,
cache_dir: Optional[str] = None
):
self.tokenizer = tokenizer
self.is_train = is_train
# Create both datasets
self.retriever_dataset = BGEM3Dataset(
data_path=data_path,
tokenizer=tokenizer,
is_train=is_train,
cache_dir=cache_dir,
**retriever_config
)
self.reranker_dataset = BGERerankerDataset(
data_path=data_path,
tokenizer=tokenizer,
is_train=is_train,
cache_dir=cache_dir,
**reranker_config
)
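        # Note: __getitem__ indexes both views with the same idx and __len__
        # reports the retriever length, so joint training assumes both datasets
        # yield the same number of examples. Formats that explode one line into
        # several pairs (candidates, legacy nested) break this alignment;
        # triplets-format input keeps the two views in sync.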
def __len__(self) -> int:
return len(self.retriever_dataset)
def __getitem__(self, idx: int) -> Dict:
"""Get example for both retriever and reranker"""
retriever_item = self.retriever_dataset[idx]
reranker_item = self.reranker_dataset[idx]
return {
'retriever': retriever_item,
'reranker': reranker_item,
'item_idx': idx
}
# Factory functions for easy dataset creation
def create_embedding_dataset(
data_path: str,
tokenizer: PreTrainedTokenizer,
query_max_length: int = 64,
passage_max_length: int = 512,
is_train: bool = True
) -> BGEDataset:
"""Create dataset specifically for embedding model training"""
dataset = BGEDataset(
data_path=data_path,
tokenizer=tokenizer,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
is_train=is_train,
format_type="triplets"
)
    if dataset.detected_format != 'triplets':
logger.warning(f"Expected triplets format, got {dataset.detected_format}")
return dataset
def create_reranker_dataset(
data_path: str,
tokenizer: PreTrainedTokenizer,
query_max_length: int = 64,
passage_max_length: int = 448,
is_train: bool = True,
legacy_support: bool = True
) -> BGEDataset:
"""Create dataset specifically for reranker model training"""
dataset = BGEDataset(
data_path=data_path,
tokenizer=tokenizer,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
is_train=is_train,
legacy_support=legacy_support,
format_type="auto" if legacy_support else "pairs"
)
if dataset.detected_format not in ['pairs', 'legacy_nested']:
logger.warning(f"Expected pairs format, got {dataset.detected_format}")
return dataset
def migrate_nested_to_flat(input_file: str, output_file: str) -> str:
"""
Migrate legacy nested reranker JSONL to flat pairs format
Explodes each (pos[i], label=1) and (neg[j], label=0) into separate lines
"""
logger.info(f"Migrating nested reranker format: {input_file} -> {output_file}")
flat_pairs = []
with open(input_file, 'r', encoding='utf-8') as f:
for line_no, line in enumerate(tqdm(f, desc="Converting to flat pairs"), 1):
if line.strip():
try:
item = json.loads(line)
query = item['query']
positives = item.get('pos', [])
negatives = item.get('neg', [])
pos_scores = item.get('pos_scores', [1.0] * len(positives))
neg_scores = item.get('neg_scores', [0.0] * len(negatives))
# Create flat pairs for positives
for i, passage in enumerate(positives):
flat_pairs.append({
'query': query,
'passage': passage,
'label': 1,
'score': pos_scores[i] if i < len(pos_scores) else 1.0,
'qid': f"Q{line_no}",
'pid': f"P{line_no}_{i}"
})
# Create flat pairs for negatives
for i, passage in enumerate(negatives):
flat_pairs.append({
'query': query,
'passage': passage,
'label': 0,
'score': neg_scores[i] if i < len(neg_scores) else 0.0,
'qid': f"Q{line_no}",
'pid': f"N{line_no}_{i}"
})
except json.JSONDecodeError as e:
logger.warning(f"Skipping invalid JSON on line {line_no}: {e}")
continue
# Save flat pairs
with open(output_file, 'w', encoding='utf-8') as f:
for pair in flat_pairs:
f.write(json.dumps(pair, ensure_ascii=False) + '\n')
logger.info(f"✅ Migrated {len(flat_pairs)} flat pairs to {output_file}")
return output_file
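# A typical migration flow (paths are hypothetical):
#   migrate_nested_to_flat("data/train_nested.jsonl", "data/train_pairs.jsonl")
#   dataset = create_reranker_dataset("data/train_pairs.jsonl", tokenizer)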
# Collate functions for different formats
def collate_embedding_batch(batch: List[Dict]) -> Dict:
"""Collate function for embedding model batches"""
return {
'query_input_ids': torch.stack([item['query_input_ids'] for item in batch]),
'query_attention_mask': torch.stack([item['query_attention_mask'] for item in batch]),
'passage_input_ids': torch.stack([item['passage_input_ids'] for item in batch]),
'passage_attention_mask': torch.stack([item['passage_attention_mask'] for item in batch]),
'labels': torch.stack([item['labels'] for item in batch]),
'scores': torch.stack([item['scores'] for item in batch]),
'query_text': [item['query_text'] for item in batch],
'passage_texts': [item['passage_texts'] for item in batch],
'item_idx': [item['item_idx'] for item in batch]
}
def collate_reranker_batch(batch: List[Dict]) -> Dict:
"""Collate function for reranker model batches"""
if batch[0]['format'] == 'pairs':
# Flat pairs format
return {
'input_ids': torch.stack([item['input_ids'] for item in batch]),
'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
'labels': torch.stack([item['labels'] for item in batch]),
'scores': torch.stack([item['scores'] for item in batch]),
'query_text': [item['query_text'] for item in batch],
'passage_text': [item['passage_text'] for item in batch],
'qid': [item['qid'] for item in batch],
'pid': [item['pid'] for item in batch],
'item_idx': [item['item_idx'] for item in batch]
}
    else:
        # Fallback for batches without flat-pair fields. BGEDataset flattens
        # legacy nested input to 'pairs' at load time, so in practice this
        # branch only serves externally constructed batches.
return {
'input_ids': torch.stack([item['input_ids'] for item in batch]),
'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
'labels': torch.stack([item['labels'] for item in batch]),
'query_text': [item['query_text'] for item in batch],
'passage_texts': [item['passage_texts'] for item in batch],
'item_idx': [item['item_idx'] for item in batch]
}
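if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming a locally available BGE checkpoint
    # and a hypothetical data/train_triplets.jsonl in the triplets format.
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
    dataset = create_embedding_dataset("data/train_triplets.jsonl", tokenizer)
    loader = DataLoader(dataset, batch_size=8, shuffle=True,
                        collate_fn=collate_embedding_batch)
    batch = next(iter(loader))
    # Query tensors: (batch, query_max_length); passage tensors:
    # (batch, 8, passage_max_length) after padding in _get_triplets_item.
    print(batch['query_input_ids'].shape, batch['passage_input_ids'].shape)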