init commit
This commit is contained in:
857
data/dataset.py
Normal file
857
data/dataset.py
Normal file
@@ -0,0 +1,857 @@
|
||||
"""
|
||||
Unified BGE dataset pipeline with comprehensive format support
|
||||
- Triplets: Query-triplets format (query, pos, neg) for embedding training
|
||||
- Pairs: Flat pairs format (query, passage, label) for reranker training
|
||||
- Candidates: Unified format (query, candidates) with threshold-based label assignment
|
||||
- Legacy: Supports nested formats with automatic conversion to flat pairs
|
||||
- Enhanced: Score-weighted sampling, hard negative mining, multiple positives
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from typing import Dict, List, Optional, Union, Tuple
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import PreTrainedTokenizer
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BGEDataset(Dataset):
|
||||
"""
|
||||
Unified BGE dataset with automatic format detection and enhanced features
|
||||
Supports: triplets, pairs, candidates, and legacy nested formats
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    data_path: str,
    tokenizer: PreTrainedTokenizer,
    query_max_length: int = 64,
    passage_max_length: int = 512,
    is_train: bool = True,
    cache_dir: Optional[str] = None,
    legacy_support: bool = True,
    format_type: str = "auto",  # "triplets", "pairs", "candidates", "legacy_nested", "auto"
    candidates_threshold: float = 0.5,  # Threshold for candidates format label assignment
    cache_in_memory: bool = False,  # Cache tokenized data in memory for small datasets
    max_cache_size: int = 100000  # Maximum number of samples to cache in memory
):
    """
    Load a JSONL file, detect (or force) its schema and optionally pre-tokenize it.

    Args:
        data_path: Path of the JSONL file to load.
        tokenizer: Tokenizer used for queries and passages.
        query_max_length: Maximum token length used when encoding queries.
        passage_max_length: Maximum token length used when encoding passages.
        is_train: Enables prompt prefixing and sampling in training mode.
        cache_dir: Accepted for interface compatibility; not used here.
        legacy_support: Allow detection of the legacy nested format.
        format_type: Explicit schema name, or "auto" to detect it.
        candidates_threshold: Score at or above which a candidate is labelled 1.
        cache_in_memory: Pre-tokenize everything when the dataset is small enough.
        max_cache_size: Largest dataset size that will be pre-tokenized.
    """
    # Tokenization settings
    self.tokenizer = tokenizer
    self.query_max_length = query_max_length
    self.passage_max_length = passage_max_length

    # Behaviour switches
    self.is_train = is_train
    self.legacy_support = legacy_support
    self.format_type = format_type
    self.candidates_threshold = candidates_threshold

    # In-memory tokenization cache (filled only by _pretokenize_all)
    self.cache_in_memory = cache_in_memory
    self.max_cache_size = max_cache_size
    self._tokenized_cache = {}

    logger.info(f"Loading data from {data_path}")
    loaded, schema = self._load_and_detect_format(data_path)
    self.data = loaded
    self.detected_format = schema
    logger.info(f"Loaded {len(self.data)} examples in {self.detected_format} format")

    # Pre-tokenize only when requested and the dataset fits the cache budget
    small_enough = len(self.data) <= self.max_cache_size
    if self.cache_in_memory and small_enough:
        logger.info(f"Pre-tokenizing {len(self.data)} samples for in-memory cache...")
        self._pretokenize_all()
        logger.info("Pre-tokenization complete")
|
||||
|
||||
def _load_and_detect_format(self, data_path: str) -> Tuple[List[Dict], str]:
    """
    Read a JSONL file, detect its schema and return the processed examples.

    The schema is decided by the first line that parses as JSON; every parsed
    line is then validated and processed against that schema.

    Args:
        data_path: Path of the JSONL file to read (UTF-8).

    Returns:
        Tuple of (processed examples, detected schema name).

    Raises:
        ValueError: If no valid example could be loaded, or if schema
            detection fails on the first parsed line.
    """
    records = []
    schema = "unknown"
    non_empty_lines = 0

    with open(data_path, 'r', encoding='utf-8') as handle:
        progress = tqdm(handle, desc="Loading and detecting format")
        for line_num, raw_line in enumerate(progress, 1):
            # Blank lines are ignored and not counted
            if not raw_line.strip():
                continue
            non_empty_lines += 1

            try:
                item = json.loads(raw_line)
            except json.JSONDecodeError as e:
                logger.warning(f"Invalid JSON at line {line_num}: {e}")
                continue

            # Schema is fixed by the first successfully parsed line
            if schema == "unknown":
                schema = self._detect_schema_type(item)
                logger.info(f"Detected schema: {schema}")

            if not self._validate_item(item, schema):
                logger.warning(f"Invalid item at line {line_num}: {item}")
                continue

            # Some formats explode one line into several examples
            processed = self._process_item(item, schema)
            if isinstance(processed, list):
                records.extend(processed)
            else:
                records.append(processed)

    if not records:
        raise ValueError(f"No valid data found in {data_path}")

    # Sanity check: warn when almost nothing survived validation
    too_few = len(records) < 10 or len(records) < 0.01 * non_empty_lines
    if non_empty_lines > 10 and too_few:
        logger.warning(f"Few valid data lines loaded from {data_path}: {len(records)} out of {non_empty_lines}")

    return records, schema
|
||||
|
||||
def _detect_schema_type(self, item: Dict) -> str:
|
||||
"""
|
||||
Enhanced format detection based on key presence and structure:
|
||||
- triplets: has 'pos' and 'neg' keys (for embedding training)
|
||||
- pairs: has 'passage' and 'label' keys (for reranker training)
|
||||
- candidates: has 'candidates' key with list of candidate objects
|
||||
- legacy_nested: has 'pos' but no 'label' (backward compatibility)
|
||||
"""
|
||||
try:
|
||||
# Use explicitly set format if provided
|
||||
if self.format_type != "auto":
|
||||
return self.format_type
|
||||
|
||||
# Ensure item is a dictionary
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError(f"Expected dict but got {type(item)}")
|
||||
|
||||
# Extract available keys for debugging
|
||||
available_keys = set(item.keys())
|
||||
|
||||
# Check for candidates format first (most specific)
|
||||
if 'candidates' in item:
|
||||
# Validate candidates structure
|
||||
candidates = item.get('candidates', [])
|
||||
if isinstance(candidates, list) and len(candidates) > 0:
|
||||
# Check if first candidate has expected structure
|
||||
first_candidate = candidates[0]
|
||||
if isinstance(first_candidate, dict) and 'text' in first_candidate:
|
||||
return "candidates"
|
||||
# Also support simple string list as candidates
|
||||
elif isinstance(first_candidate, str):
|
||||
logger.info("Detected candidates format with string list - will convert to dict format")
|
||||
return "candidates"
|
||||
|
||||
# Check for pairs format (reranker)
|
||||
if 'passage' in item and 'label' in item:
|
||||
# Additional check for query field
|
||||
if 'query' in item:
|
||||
return "pairs"
|
||||
else:
|
||||
logger.warning(f"Found 'passage' and 'label' but missing 'query'. Available keys: {available_keys}")
|
||||
|
||||
# Check for triplets format (embedding)
|
||||
if 'pos' in item:
|
||||
# Standard triplets format has both pos and neg
|
||||
if 'neg' in item:
|
||||
return "triplets"
|
||||
# Might be triplets with only positives
|
||||
elif 'query' in item and not ('label' in item or 'passage' in item):
|
||||
logger.info("Detected triplets format with only positive samples")
|
||||
return "triplets"
|
||||
# Legacy nested format
|
||||
elif self.legacy_support and 'query' in item:
|
||||
logger.warning("Detected legacy nested format. Converting to flat pairs for reranker training.")
|
||||
return "legacy_nested"
|
||||
|
||||
# Try to provide helpful error message
|
||||
common_keys = available_keys & {'query', 'pos', 'neg', 'passage', 'label', 'candidates'}
|
||||
error_msg = f"Unknown data format. Available keys: {available_keys}"
|
||||
if common_keys:
|
||||
error_msg += f", Common keys found: {common_keys}"
|
||||
else:
|
||||
error_msg += ". No common keys found. Expected 'query' with 'pos'/'neg' (triplets), 'passage'/'label' (pairs), or 'candidates' (candidates format)"
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Format detection failed: {e}")
|
||||
raise
|
||||
|
||||
def _validate_item(self, item: Dict, format_type: str) -> bool:
|
||||
"""
|
||||
Validate item based on detected format with comprehensive checks
|
||||
|
||||
Returns:
|
||||
bool: True if item is valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Basic type check
|
||||
if not isinstance(item, dict):
|
||||
return False
|
||||
|
||||
# Format-specific validation
|
||||
if format_type == "triplets":
|
||||
# Check required fields
|
||||
if not all(field in item for field in ['query', 'pos']):
|
||||
return False
|
||||
# Validate query
|
||||
if not isinstance(item['query'], str) or not item['query'].strip():
|
||||
return False
|
||||
# Validate positives
|
||||
if not isinstance(item.get('pos', []), list) or len(item['pos']) == 0:
|
||||
return False
|
||||
# Check each positive is a non-empty string
|
||||
for pos in item['pos']:
|
||||
if not isinstance(pos, str) or not pos.strip():
|
||||
return False
|
||||
# Validate negatives if present
|
||||
if 'neg' in item:
|
||||
if not isinstance(item['neg'], list):
|
||||
return False
|
||||
for neg in item['neg']:
|
||||
if not isinstance(neg, str) or not neg.strip():
|
||||
return False
|
||||
return True
|
||||
|
||||
elif format_type == "pairs":
|
||||
# Check required fields
|
||||
if not all(field in item for field in ['query', 'passage', 'label']):
|
||||
return False
|
||||
# Validate query and passage
|
||||
if not isinstance(item['query'], str) or not item['query'].strip():
|
||||
return False
|
||||
if not isinstance(item['passage'], str) or not item['passage'].strip():
|
||||
return False
|
||||
# Validate label (should be numeric 0 or 1, or float between 0 and 1)
|
||||
label = item['label']
|
||||
if isinstance(label, (int, float)):
|
||||
if not (0 <= label <= 1):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
elif format_type == "candidates":
|
||||
# Check required fields
|
||||
if not ('query' in item and 'candidates' in item):
|
||||
return False
|
||||
# Validate query
|
||||
if not isinstance(item['query'], str) or not item['query'].strip():
|
||||
return False
|
||||
# Validate candidates
|
||||
if not isinstance(item['candidates'], list) or len(item['candidates']) == 0:
|
||||
return False
|
||||
# Check each candidate
|
||||
for i, candidate in enumerate(item['candidates']):
|
||||
if not isinstance(candidate, dict):
|
||||
return False
|
||||
if 'text' not in candidate:
|
||||
return False
|
||||
if not isinstance(candidate['text'], str) or not candidate['text'].strip():
|
||||
return False
|
||||
# Validate score if present
|
||||
if 'score' in candidate:
|
||||
score = candidate['score']
|
||||
if not isinstance(score, (int, float)):
|
||||
return False
|
||||
return True
|
||||
|
||||
elif format_type == "legacy_nested":
|
||||
# Basic validation for legacy format
|
||||
if not ('query' in item and 'pos' in item):
|
||||
return False
|
||||
if not isinstance(item['query'], str) or not item['query'].strip():
|
||||
return False
|
||||
return True
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Validation error: {e}")
|
||||
return False
|
||||
|
||||
def _process_item(self, item: Dict, format_type: str) -> Union[Dict, List[Dict]]:
|
||||
"""Process item based on format type - may return multiple items for candidates"""
|
||||
if format_type == "triplets":
|
||||
return self._process_triplets_item(item)
|
||||
elif format_type == "pairs":
|
||||
return self._process_pairs_item(item)
|
||||
elif format_type == "candidates":
|
||||
return self._process_candidates_item(item)
|
||||
elif format_type == "legacy_nested":
|
||||
return self._process_legacy_nested_item(item)
|
||||
else:
|
||||
raise ValueError(f"Unknown format type: {format_type}")
|
||||
|
||||
def _process_triplets_item(self, item: Dict) -> Dict:
|
||||
"""Process triplets format item for embedding training"""
|
||||
processed = {
|
||||
'query': item['query'],
|
||||
'pos': item.get('pos', []),
|
||||
'neg': item.get('neg', []),
|
||||
'pos_scores': item.get('pos_scores', [1.0] * len(item.get('pos', []))),
|
||||
'neg_scores': item.get('neg_scores', [0.0] * len(item.get('neg', []))),
|
||||
'prompt': item.get('prompt', ''),
|
||||
'type': item.get('type', 'normal'),
|
||||
'format': 'triplets'
|
||||
}
|
||||
return processed
|
||||
|
||||
def _process_pairs_item(self, item: Dict) -> Dict:
|
||||
"""Process pairs format item for reranker training"""
|
||||
processed = {
|
||||
'query': item['query'],
|
||||
'passage': item['passage'],
|
||||
'label': item['label'],
|
||||
'score': item.get('score', 1.0 if item['label'] > 0 else 0.0),
|
||||
'qid': item.get('qid', ''),
|
||||
'pid': item.get('pid', ''),
|
||||
'format': 'pairs'
|
||||
}
|
||||
return processed
|
||||
|
||||
def _process_candidates_item(self, item: Dict) -> List[Dict]:
|
||||
"""
|
||||
Process candidates format item - explode into multiple pairs
|
||||
Each candidate becomes a separate (query, passage, label) pair
|
||||
Label is assigned based on score threshold
|
||||
"""
|
||||
query = item['query']
|
||||
candidates = item['candidates']
|
||||
qid = item.get('qid', '')
|
||||
|
||||
pairs = []
|
||||
for i, candidate in enumerate(candidates):
|
||||
if isinstance(candidate, dict):
|
||||
text = candidate.get('text', '')
|
||||
score = candidate.get('score', 0.0)
|
||||
pid = candidate.get('pid', f'cand_{i}')
|
||||
else:
|
||||
# Handle simple string candidates
|
||||
text = str(candidate)
|
||||
score = 1.0 # Default to positive if no score
|
||||
pid = f'cand_{i}'
|
||||
|
||||
# Assign label based on threshold
|
||||
label = 1 if score >= self.candidates_threshold else 0
|
||||
|
||||
pair = {
|
||||
'query': query,
|
||||
'passage': text,
|
||||
'label': label,
|
||||
'score': score,
|
||||
'qid': qid,
|
||||
'pid': pid,
|
||||
'format': 'pairs', # Convert to pairs format
|
||||
'source': 'candidates' # Track original source
|
||||
}
|
||||
pairs.append(pair)
|
||||
|
||||
return pairs
|
||||
|
||||
def _process_legacy_nested_item(self, item: Dict) -> List[Dict]:
|
||||
"""
|
||||
Process legacy nested format - convert to flat pairs for reranker training
|
||||
Each pos[i] becomes (query, passage, label=1)
|
||||
Each neg[j] becomes (query, passage, label=0)
|
||||
"""
|
||||
query = item['query']
|
||||
positives = item.get('pos', [])
|
||||
negatives = item.get('neg', [])
|
||||
pos_scores = item.get('pos_scores', [1.0] * len(positives))
|
||||
neg_scores = item.get('neg_scores', [0.0] * len(negatives))
|
||||
qid = item.get('qid', '')
|
||||
|
||||
pairs = []
|
||||
|
||||
# Convert positives
|
||||
for i, passage in enumerate(positives):
|
||||
score = pos_scores[i] if i < len(pos_scores) else 1.0
|
||||
pair = {
|
||||
'query': query,
|
||||
'passage': passage,
|
||||
'label': 1,
|
||||
'score': score,
|
||||
'qid': qid,
|
||||
'pid': f'pos_{i}',
|
||||
'format': 'pairs', # Convert to pairs format
|
||||
'source': 'legacy_nested' # Track original source
|
||||
}
|
||||
pairs.append(pair)
|
||||
|
||||
# Convert negatives
|
||||
for i, passage in enumerate(negatives):
|
||||
score = neg_scores[i] if i < len(neg_scores) else 0.0
|
||||
pair = {
|
||||
'query': query,
|
||||
'passage': passage,
|
||||
'label': 0,
|
||||
'score': score,
|
||||
'qid': qid,
|
||||
'pid': f'neg_{i}',
|
||||
'format': 'pairs', # Convert to pairs format
|
||||
'source': 'legacy_nested' # Track original source
|
||||
}
|
||||
pairs.append(pair)
|
||||
|
||||
return pairs
|
||||
|
||||
def _pretokenize_all(self):
    """
    Pre-tokenize all samples for in-memory caching.

    The cached encodings are built exactly like the on-the-fly ones so that
    enabling the cache never changes the model inputs:

    - triplets: the query is cached under ``"<idx>_query"`` with the
      training prompt already applied (same rule as ``_get_triplets_item``:
      only when a prompt exists and ``self.is_train`` is set)
    - pairs: query and passage are encoded as a text pair truncated to
      ``self.passage_max_length`` and cached under ``idx``, matching
      ``_get_pairs_item``

    Relies on the module-level ``tqdm`` import for the progress bar.
    """
    for idx, item in enumerate(tqdm(self.data, desc="Pre-tokenizing")):
        item_format = item['format']

        if item_format == 'triplets':
            query = item['query']
            prompt = item.get('prompt', '')
            # Cache the text that will actually be fed to the model
            if prompt and self.is_train:
                query = f"{prompt}{query}"

            query_encoding = self.tokenizer(
                query,
                max_length=self.query_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            self._tokenized_cache[f"{idx}_query"] = {
                'input_ids': query_encoding.input_ids.squeeze(0),
                'attention_mask': query_encoding.attention_mask.squeeze(0)
            }

        elif item_format == 'pairs':
            # Encode as a real text pair so the tokenizer inserts its own
            # separator tokens, with the same length as the uncached path
            encoding = self.tokenizer(
                item['query'],
                item['passage'],
                max_length=self.passage_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            self._tokenized_cache[idx] = {
                'input_ids': encoding.input_ids.squeeze(0),
                'attention_mask': encoding.attention_mask.squeeze(0)
            }
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict:
|
||||
"""Get item based on detected format"""
|
||||
item = self.data[idx]
|
||||
|
||||
if item['format'] == 'triplets':
|
||||
return self._get_triplets_item(item, idx)
|
||||
elif item['format'] == 'pairs':
|
||||
return self._get_pairs_item(item, idx)
|
||||
else:
|
||||
raise ValueError(f"Unknown format: {item['format']}")
|
||||
|
||||
def _get_triplets_item(self, item: Dict, idx: int) -> Dict:
|
||||
"""Get triplets training example for embedding model (contrastive learning)"""
|
||||
query = item['query']
|
||||
positives = item['pos']
|
||||
negatives = item['neg']
|
||||
pos_scores = item['pos_scores']
|
||||
neg_scores = item['neg_scores']
|
||||
prompt = item['prompt']
|
||||
|
||||
# Apply prompt if provided
|
||||
if prompt and self.is_train:
|
||||
query = f"{prompt}{query}"
|
||||
|
||||
# Check if we have cached tokenized query
|
||||
if f"{idx}_query" in self._tokenized_cache:
|
||||
query_encodings = self._tokenized_cache[f"{idx}_query"]
|
||||
else:
|
||||
# Tokenize query
|
||||
query_encodings = self.tokenizer(
|
||||
query,
|
||||
max_length=self.query_max_length,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
query_encodings = {
|
||||
'input_ids': query_encodings.input_ids.squeeze(0),
|
||||
'attention_mask': query_encodings.attention_mask.squeeze(0)
|
||||
}
|
||||
|
||||
# Enhanced sampling for contrastive learning
|
||||
if self.is_train:
|
||||
# Sample multiple positives for richer contrastive signals
|
||||
num_positives = min(2, len(positives))
|
||||
sampled_positives = random.sample(positives, num_positives) if positives else []
|
||||
|
||||
# Hard negative mining with score-based sampling
|
||||
num_negatives = min(6, len(negatives))
|
||||
if negatives and neg_scores:
|
||||
# Weight by scores for hard negative selection
|
||||
neg_weights = np.array(neg_scores)
|
||||
neg_weights = neg_weights / neg_weights.sum() if neg_weights.sum() > 0 else np.ones_like(neg_weights)
|
||||
sampled_neg_indices = np.random.choice(
|
||||
len(negatives), size=min(num_negatives, len(negatives)),
|
||||
replace=False, p=neg_weights
|
||||
)
|
||||
sampled_negatives = [negatives[i] for i in sampled_neg_indices]
|
||||
sampled_neg_scores = [neg_scores[i] for i in sampled_neg_indices]
|
||||
else:
|
||||
sampled_negatives = random.sample(negatives, num_negatives) if negatives else []
|
||||
sampled_neg_scores = [0.0] * len(sampled_negatives)
|
||||
|
||||
passages = sampled_positives + sampled_negatives
|
||||
labels = [1] * len(sampled_positives) + [0] * len(sampled_negatives)
|
||||
scores = pos_scores[:len(sampled_positives)] + sampled_neg_scores
|
||||
else:
|
||||
# Use all for evaluation
|
||||
passages = positives + negatives
|
||||
labels = [1] * len(positives) + [0] * len(negatives)
|
||||
scores = pos_scores + neg_scores
|
||||
|
||||
# Tokenize passages
|
||||
passage_encodings = []
|
||||
for passage in passages:
|
||||
encoding = self.tokenizer(
|
||||
passage,
|
||||
max_length=self.passage_max_length,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
passage_encodings.append({
|
||||
'input_ids': torch.as_tensor(encoding['input_ids']).squeeze(0),
|
||||
'attention_mask': torch.as_tensor(encoding['attention_mask']).squeeze(0)
|
||||
})
|
||||
|
||||
# Pad to consistent size (max 8 passages for triplets)
|
||||
max_passages = 8
|
||||
while len(passage_encodings) < max_passages:
|
||||
passage_encodings.append({
|
||||
'input_ids': torch.zeros(self.passage_max_length, dtype=torch.long),
|
||||
'attention_mask': torch.zeros(self.passage_max_length, dtype=torch.long)
|
||||
})
|
||||
labels.append(-1)
|
||||
scores.append(0.0)
|
||||
|
||||
# Stack tensors
|
||||
passage_input_ids = torch.stack([p['input_ids'] for p in passage_encodings])
|
||||
passage_attention_mask = torch.stack([p['attention_mask'] for p in passage_encodings])
|
||||
|
||||
return {
|
||||
'query_input_ids': query_encodings['input_ids'],
|
||||
'query_attention_mask': query_encodings['attention_mask'],
|
||||
'passage_input_ids': passage_input_ids,
|
||||
'passage_attention_mask': passage_attention_mask,
|
||||
'labels': torch.tensor(labels, dtype=torch.long),
|
||||
'scores': torch.tensor(scores, dtype=torch.float),
|
||||
'query_text': query,
|
||||
'passage_texts': passages + [''] * (max_passages - len(passages)),
|
||||
'item_idx': idx,
|
||||
'format': 'triplets',
|
||||
# Backward compatibility
|
||||
'positives': positives,
|
||||
'negatives': negatives
|
||||
}
|
||||
|
||||
def _get_pairs_item(self, item: Dict, idx: int) -> Dict:
|
||||
"""Get pairs training example for reranker (cross-encoder training)"""
|
||||
query = item['query']
|
||||
passage = item['passage']
|
||||
label = item['label']
|
||||
score = item['score']
|
||||
|
||||
# Check if we have cached tokenized data
|
||||
if idx in self._tokenized_cache:
|
||||
encoding = self._tokenized_cache[idx]
|
||||
return {
|
||||
'input_ids': encoding['input_ids'],
|
||||
'attention_mask': encoding['attention_mask'],
|
||||
'labels': torch.tensor(label, dtype=torch.long),
|
||||
'scores': torch.tensor(score, dtype=torch.float),
|
||||
'query_text': query,
|
||||
'passage_text': passage,
|
||||
'qid': item.get('qid', ''),
|
||||
'pid': item.get('pid', ''),
|
||||
'item_idx': idx,
|
||||
'format': 'pairs',
|
||||
'source': item.get('source', 'original')
|
||||
}
|
||||
|
||||
# Tokenize query-passage pair if not cached
|
||||
encoding = self.tokenizer(
|
||||
query,
|
||||
passage,
|
||||
max_length=self.passage_max_length,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
return {
|
||||
'input_ids': torch.as_tensor(encoding['input_ids']).squeeze(0),
|
||||
'attention_mask': torch.as_tensor(encoding['attention_mask']).squeeze(0),
|
||||
'labels': torch.tensor(label, dtype=torch.long),
|
||||
'scores': torch.tensor(score, dtype=torch.float),
|
||||
'query_text': query,
|
||||
'passage_text': passage,
|
||||
'qid': item['qid'],
|
||||
'pid': item['pid'],
|
||||
'item_idx': idx,
|
||||
'format': 'pairs',
|
||||
'source': item.get('source', 'original')
|
||||
}
|
||||
|
||||
|
||||
# Specialized dataset classes for backward compatibility
|
||||
class BGEM3Dataset(BGEDataset):
    """
    Embedding-model (BGE-M3) dataset that always loads data as triplets.

    ``train_group_size`` is stored on the instance; this class does not read
    it itself.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        query_max_length: int = 64,
        passage_max_length: int = 512,
        is_train: bool = True,
        train_group_size: int = 8,
        cache_dir: Optional[str] = None
    ):
        # The base class is told the format explicitly: no auto-detection
        base_kwargs = dict(
            data_path=data_path,
            tokenizer=tokenizer,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            is_train=is_train,
            cache_dir=cache_dir,
            format_type="triplets",
        )
        super().__init__(**base_kwargs)
        self.train_group_size = train_group_size

        # Warn if wrong format detected
        if self.detected_format != 'triplets':
            logger.warning(f"Expected triplets format for BGEM3Dataset, got {self.detected_format}")
|
||||
|
||||
|
||||
class BGERerankerDataset(BGEDataset):
    """
    Cross-encoder (BGE-Reranker) dataset built on flat pairs.

    With ``legacy_support`` enabled the schema is auto-detected; otherwise
    the "pairs" schema is forced. ``train_group_size`` is stored on the
    instance; this class does not read it itself.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        query_max_length: int = 64,
        passage_max_length: int = 448,
        is_train: bool = True,
        train_group_size: int = 16,
        cache_dir: Optional[str] = None,
        legacy_support: bool = True
    ):
        # Force reranker format (or allow legacy)
        if legacy_support:
            format_type = "auto"
        else:
            format_type = "pairs"

        base_kwargs = dict(
            data_path=data_path,
            tokenizer=tokenizer,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            is_train=is_train,
            cache_dir=cache_dir,
            legacy_support=legacy_support,
            format_type=format_type,
        )
        super().__init__(**base_kwargs)
        self.train_group_size = train_group_size

        # Warn if wrong format detected
        accepted_formats = ('pairs', 'legacy_nested')
        if self.detected_format not in accepted_formats:
            logger.warning(f"Expected pairs format for BGERerankerDataset, got {self.detected_format}")
|
||||
|
||||
|
||||
class JointDataset(Dataset):
    """
    Dataset for joint training of retriever and reranker.

    Wraps one ``BGEM3Dataset`` and one ``BGERerankerDataset`` built from the
    same file. The length is that of the retriever dataset, and the same
    index is used to read from both wrapped datasets.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        retriever_config: Dict,
        reranker_config: Dict,
        is_train: bool = True,
        cache_dir: Optional[str] = None
    ):
        self.tokenizer = tokenizer
        self.is_train = is_train

        # Arguments common to both wrapped datasets
        shared_kwargs = dict(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=is_train,
            cache_dir=cache_dir,
        )

        # Create both datasets
        self.retriever_dataset = BGEM3Dataset(**shared_kwargs, **retriever_config)
        self.reranker_dataset = BGERerankerDataset(**shared_kwargs, **reranker_config)

    def __len__(self) -> int:
        """Return the number of examples in the retriever dataset."""
        return len(self.retriever_dataset)

    def __getitem__(self, idx: int) -> Dict:
        """Return the retriever and reranker examples stored at ``idx``."""
        sample = {'retriever': self.retriever_dataset[idx]}
        sample['reranker'] = self.reranker_dataset[idx]
        sample['item_idx'] = idx
        return sample
|
||||
|
||||
|
||||
# Factory functions for easy dataset creation
|
||||
def create_embedding_dataset(
    data_path: str,
    tokenizer: PreTrainedTokenizer,
    query_max_length: int = 64,
    passage_max_length: int = 512,
    is_train: bool = True
) -> BGEDataset:
    """
    Build a ``BGEDataset`` for embedding-model training.

    The "triplets" schema is passed explicitly, so no auto-detection runs.

    Returns:
        The constructed dataset.
    """
    settings = {
        'data_path': data_path,
        'tokenizer': tokenizer,
        'query_max_length': query_max_length,
        'passage_max_length': passage_max_length,
        'is_train': is_train,
        'format_type': "triplets",
    }
    dataset = BGEDataset(**settings)

    if dataset.detected_format != 'triplets':
        logger.warning(f"Expected triplets format, got {dataset.detected_format}")

    return dataset
|
||||
|
||||
|
||||
def create_reranker_dataset(
    data_path: str,
    tokenizer: PreTrainedTokenizer,
    query_max_length: int = 64,
    passage_max_length: int = 448,
    is_train: bool = True,
    legacy_support: bool = True
) -> BGEDataset:
    """
    Build a ``BGEDataset`` for reranker training.

    The schema is auto-detected when ``legacy_support`` is enabled and forced
    to "pairs" otherwise. A warning is logged when the resulting format is
    neither 'pairs' nor 'legacy_nested'.

    Returns:
        The constructed dataset.
    """
    if legacy_support:
        chosen_format = "auto"
    else:
        chosen_format = "pairs"

    settings = {
        'data_path': data_path,
        'tokenizer': tokenizer,
        'query_max_length': query_max_length,
        'passage_max_length': passage_max_length,
        'is_train': is_train,
        'legacy_support': legacy_support,
        'format_type': chosen_format,
    }
    dataset = BGEDataset(**settings)

    accepted_formats = ('pairs', 'legacy_nested')
    if dataset.detected_format not in accepted_formats:
        logger.warning(f"Expected pairs format, got {dataset.detected_format}")

    return dataset
|
||||
|
||||
|
||||
def migrate_nested_to_flat(input_file: str, output_file: str) -> str:
    """
    Migrate legacy nested reranker JSONL to flat pairs format.

    Explodes each (pos[i], label=1) and (neg[j], label=0) into separate
    lines. Ids are derived from the 1-based input line number:
    qid ``Q<line>``, pid ``P<line>_<i>`` for positives and ``N<line>_<i>``
    for negatives.

    Lines that cannot be migrated are skipped with a warning rather than
    aborting the run: invalid JSON, records that are not JSON objects, and
    records without a 'query'. ``null`` or missing 'pos' / 'neg' /
    score lists are treated as empty.

    All input is read before the output file is opened.

    Args:
        input_file: Path of the nested JSONL file.
        output_file: Path of the flat JSONL file to write.

    Returns:
        ``output_file``.
    """
    logger.info(f"Migrating nested reranker format: {input_file} -> {output_file}")

    flat_pairs = []

    with open(input_file, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(tqdm(f, desc="Converting to flat pairs"), 1):
            if not line.strip():
                continue

            try:
                item = json.loads(line)
            except json.JSONDecodeError as e:
                logger.warning(f"Skipping invalid JSON on line {line_no}: {e}")
                continue

            # One malformed record should not abort the whole migration
            if not isinstance(item, dict) or 'query' not in item:
                logger.warning(f"Skipping line {line_no}: expected a JSON object with a 'query' field")
                continue

            query = item['query']
            positives = item.get('pos') or []
            negatives = item.get('neg') or []
            pos_scores = item.get('pos_scores') or [1.0] * len(positives)
            neg_scores = item.get('neg_scores') or [0.0] * len(negatives)

            # (passages, scores, label, default score, pid prefix)
            groups = (
                (positives, pos_scores, 1, 1.0, 'P'),
                (negatives, neg_scores, 0, 0.0, 'N'),
            )
            for passages, scores, label, default_score, prefix in groups:
                for i, passage in enumerate(passages):
                    flat_pairs.append({
                        'query': query,
                        'passage': passage,
                        'label': label,
                        'score': scores[i] if i < len(scores) else default_score,
                        'qid': f"Q{line_no}",
                        'pid': f"{prefix}{line_no}_{i}"
                    })

    # Save flat pairs
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(pair, ensure_ascii=False) + '\n' for pair in flat_pairs)

    logger.info(f"✅ Migrated {len(flat_pairs)} flat pairs to {output_file}")
    return output_file
|
||||
|
||||
|
||||
# Collate functions for different formats
|
||||
def collate_embedding_batch(batch: List[Dict]) -> Dict:
    """
    Merge per-example embedding items into one batch dict.

    Tensor fields are stacked along a new leading dimension; text and index
    fields are gathered into plain lists in batch order.
    """
    stacked_fields = (
        'query_input_ids',
        'query_attention_mask',
        'passage_input_ids',
        'passage_attention_mask',
        'labels',
        'scores',
    )
    listed_fields = ('query_text', 'passage_texts', 'item_idx')

    collated = {}
    for field in stacked_fields:
        collated[field] = torch.stack([sample[field] for sample in batch])
    for field in listed_fields:
        collated[field] = [sample[field] for sample in batch]
    return collated
|
||||
|
||||
|
||||
def collate_reranker_batch(batch: List[Dict]) -> Dict:
    """
    Merge per-example reranker items into one batch dict.

    The layout is chosen from the 'format' of the first item: 'pairs' items
    contribute scores, single passage texts and ids; any other format is
    treated as the legacy nested layout with 'passage_texts' and no scores.
    Tensor fields are stacked, the remaining fields are gathered into lists.
    """
    if batch[0]['format'] == 'pairs':
        # Flat pairs format
        stacked_fields = ('input_ids', 'attention_mask', 'labels', 'scores')
        listed_fields = ('query_text', 'passage_text', 'qid', 'pid', 'item_idx')
    else:
        # Legacy nested format
        stacked_fields = ('input_ids', 'attention_mask', 'labels')
        listed_fields = ('query_text', 'passage_texts', 'item_idx')

    collated = {}
    for field in stacked_fields:
        collated[field] = torch.stack([sample[field] for sample in batch])
    for field in listed_fields:
        collated[field] = [sample[field] for sample in batch]
    return collated
|
||||
Reference in New Issue
Block a user