# bge_finetune/scripts/train_m3.py
"""
Enhanced BGE-M3 embedding model training script with integrated dataset pipeline
Supports both legacy nested format and new embedding schema with automatic detection
"""
import os
import sys
import argparse
# Set tokenizer parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import logging
import json
from datetime import datetime
import torch
from transformers import AutoTokenizer, set_seed
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
    from config.m3_config import BGEM3Config
    from models.bge_m3 import BGEM3Model
    from data.dataset import BGEM3Dataset, create_embedding_dataset, collate_embedding_batch, migrate_nested_to_flat
    from data.preprocessing import DataPreprocessor
    from data.augmentation import QueryAugmenter, HardNegativeAugmenter
    from data.hard_negative_mining import ANCEHardNegativeMiner
    from training.m3_trainer import BGEM3Trainer
    from evaluation.evaluator import RetrievalEvaluator
    from utils.logging import setup_logging, get_logger
    from utils.distributed import setup_distributed, is_main_process, set_random_seed
    from utils.config_loader import get_config, load_config_for_training, resolve_model_path
except ImportError as e:
    print(f"Import error: {e}")
    print("Please ensure you're running from the project root directory")
    sys.exit(1)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Enhanced BGE-M3 Embedding Model Training")

    # Model arguments
    parser.add_argument("--model_name_or_path", type=str, default=None,
                        help="Path to pretrained model or model identifier")
    parser.add_argument("--cache_dir", type=str, default="./cache/data",
                        help="Directory to cache downloaded models")
    parser.add_argument("--config_path", type=str, default="./config.toml",
                        help="Path to configuration file")

    # Data arguments - support both legacy and new schemas
    parser.add_argument("--train_data", type=str, required=True,
                        help="Path to training data file (supports both embedding and legacy nested formats)")
    parser.add_argument("--eval_data", type=str, default=None,
                        help="Path to evaluation data file")
    parser.add_argument("--data_format", type=str, default="auto",
                        choices=["auto", "embedding", "legacy_nested", "jsonl", "json", "csv", "tsv"],
                        help="Format of input data (auto-detects by default)")
    parser.add_argument("--legacy_data", type=str, default=None,
                        help="Path to additional legacy nested data (for backward compatibility)")
    parser.add_argument("--migrate_legacy", action="store_true",
                        help="Migrate legacy nested format to flat format")

    # Sequence length arguments
    parser.add_argument("--max_seq_length", type=int, default=None,
                        help="Maximum sequence length")
    parser.add_argument("--query_max_length", type=int, default=None,
                        help="Maximum query length")
    parser.add_argument("--passage_max_length", type=int, default=None,
                        help="Maximum passage length")
    parser.add_argument("--max_length", type=int, default=8192,
                        help="Maximum total input length (multi-granularity)")

    # Training arguments
    parser.add_argument("--output_dir", type=str, default="./output/bge-m3-enhanced",
                        help="Directory to save the model")
    parser.add_argument("--num_train_epochs", type=int, default=None,
                        help="Number of training epochs")
    parser.add_argument("--per_device_train_batch_size", type=int, default=None,
                        help="Batch size per GPU/CPU for training")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=None,
                        help="Batch size per GPU/CPU for evaluation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
help="Number of updates steps to accumulate before performing a backward/update pass")
parser.add_argument("--learning_rate", type=float, default=None,
help="Initial learning rate")
parser.add_argument("--warmup_ratio", type=float, default=None,
help="Warmup steps as a ratio of total steps")
parser.add_argument("--weight_decay", type=float, default=None,
help="Weight decay")
# Model specific arguments - enhanced features
parser.add_argument("--train_group_size", type=int, default=8,
help="Number of passages per query in training")
parser.add_argument("--temperature", type=float, default=0.1,
help="Temperature for InfoNCE loss")
parser.add_argument("--use_hard_negatives", action="store_true",
help="Enable hard negative mining")
parser.add_argument("--use_self_distill", action="store_true",
help="Enable self-knowledge distillation")
parser.add_argument("--use_score_weighted_sampling", action="store_true", default=True,
help="Enable score-weighted positive sampling")
parser.add_argument("--query_instruction", type=str, default="",
help="Query instruction prefix")
# Enhanced sampling features
parser.add_argument("--multiple_positives", action="store_true", default=True,
help="Use multiple positives per batch for richer contrastive signals")
parser.add_argument("--hard_negative_ratio", type=float, default=0.7,
help="Ratio of hard to random negatives")
parser.add_argument("--difficulty_threshold", type=float, default=0.1,
help="Minimum difficulty for hard negatives")
# Optimization and performance
parser.add_argument("--fp16", action="store_true",
help="Use mixed precision training")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="Use gradient checkpointing to save memory")
parser.add_argument("--dataloader_num_workers", type=int, default=4,
help="Number of workers for data loading")
# Checkpointing
parser.add_argument("--save_steps", type=int, default=500,
help="Save checkpoint every N steps")
parser.add_argument("--eval_steps", type=int, default=500,
help="Evaluate every N steps")
parser.add_argument("--logging_steps", type=int, default=100,
help="Log every N steps")
parser.add_argument("--save_total_limit", type=int, default=3,
help="Maximum number of checkpoints to keep")
parser.add_argument("--resume_from_checkpoint", type=str, default=None,
help="Path to checkpoint to resume from")
# Distributed training
parser.add_argument("--local_rank", type=int, default=-1,
help="Local rank for distributed training")
# Other
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
parser.add_argument("--overwrite_output_dir", action="store_true",
help="Overwrite the output directory")
parser.add_argument("--log_level", type=str, default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging level")
# Backward compatibility flags
parser.add_argument("--legacy_support", action="store_true", default=True,
help="Enable legacy nested format support")
parser.add_argument("--use_new_dataset_api", action="store_true", default=True,
help="Use new integrated dataset API with format detection")
args = parser.parse_args()
# Validate arguments
if os.path.exists(args.output_dir) and not args.overwrite_output_dir:
raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.")
return args
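
# Note on distributed launches: the device-selection logic in main() keys off the
# --local_rank argument above. Launchers that pass it on the command line (e.g. the
# legacy torch.distributed.launch) activate the distributed path directly, while
# torchrun-style launchers that only export the LOCAL_RANK environment variable may
# need a thin wrapper that forwards it as --local_rank. This is an observation about
# the script, not a tested recipe.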


def create_config_from_args(args):
    """Create configuration from command line arguments and config file

    Parameter precedence (highest to lowest):
    1. Command-line arguments (CLI)
    2. Model-specific config ([m3], [reranker])
    3. [hardware] section in config.toml
    4. [training] section in config.toml
    5. Hardcoded defaults in code
    """
    # Load base configuration from TOML file
    config_dict = load_config_for_training(
        config_path=args.config_path,
        override_params={}
    )
    # Get the centralized config instance
    central_config = get_config()
    # Create override parameters from command line arguments
    override_params = {}

    # Model configuration - use new resolve_model_path function
    resolved_model_path, resolved_cache_dir = resolve_model_path(
        model_key='bge_m3',
        config_instance=central_config,
        model_name_or_path=args.model_name_or_path,
        cache_dir=args.cache_dir
    )
    args.model_name_or_path = resolved_model_path
    args.cache_dir = resolved_cache_dir

    # Data configuration
    data_config = config_dict.get('data', {})
    if args.max_seq_length is None:
        args.max_seq_length = data_config.get('max_seq_length', 512)
    if args.query_max_length is None:
        args.query_max_length = data_config.get('max_query_length', 64)
    if args.passage_max_length is None:
        args.passage_max_length = data_config.get('max_passage_length', 512)
    # Get validation split from config
    validation_split = data_config.get('validation_split', 0.1)

    # Training configuration ([hardware] values take precedence over [training] defaults,
    # matching the documented precedence order)
    training_config = config_dict.get('training', {})
    hardware_config = config_dict.get('hardware', {})
    if args.num_train_epochs is None:
        args.num_train_epochs = training_config.get('default_num_epochs', 3)
    if args.per_device_train_batch_size is None:
        args.per_device_train_batch_size = hardware_config.get(
            'per_device_train_batch_size',
            training_config.get('default_batch_size', 4))
    if args.per_device_eval_batch_size is None:
        args.per_device_eval_batch_size = hardware_config.get(
            'per_device_eval_batch_size',
            args.per_device_train_batch_size * 2)
    if args.learning_rate is None:
        args.learning_rate = training_config.get('default_learning_rate', 1e-5)
    if args.warmup_ratio is None:
        args.warmup_ratio = training_config.get('default_warmup_ratio', 0.1)
    if args.weight_decay is None:
        args.weight_decay = training_config.get('default_weight_decay', 0.01)

    # Model-specific ([m3]) configuration - overrides the CLI defaults when present
    m3_config = config_dict.get('m3', {})
    train_group_size = m3_config.get('train_group_size', args.train_group_size)
    temperature = m3_config.get('temperature', args.temperature)
    use_hard_negatives = m3_config.get('use_hard_negatives', args.use_hard_negatives)
    use_self_distill = m3_config.get('use_self_distill', args.use_self_distill)
    query_instruction_for_retrieval = m3_config.get('query_instruction_for_retrieval', args.query_instruction)

    # Create BGE-M3 configuration
    config = BGEM3Config(
        model_name_or_path=args.model_name_or_path,
        cache_dir=args.cache_dir,
        train_data_path=args.train_data,
        eval_data_path=args.eval_data,
        max_seq_length=args.max_seq_length,
        query_max_length=args.query_max_length,
        passage_max_length=args.passage_max_length,
        max_length=args.max_length,
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        train_group_size=train_group_size,
        temperature=temperature,
        use_hard_negatives=use_hard_negatives,
        use_self_distill=use_self_distill,
        query_instruction_for_retrieval=query_instruction_for_retrieval,
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        dataloader_num_workers=args.dataloader_num_workers,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        save_total_limit=args.save_total_limit,
        seed=args.seed,
        local_rank=args.local_rank,
        overwrite_output_dir=args.overwrite_output_dir
    )
    return config


def setup_model_and_tokenizer(args):
    """Setup BGE-M3 model and tokenizer"""
    logger.info(f"Loading BGE-M3 model from {args.model_name_or_path}")
    # Detect offline mode if the path is local or HF_HUB_OFFLINE is set
    local_files_only = os.path.isdir(args.model_name_or_path) or bool(os.environ.get("HF_HUB_OFFLINE"))

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        cache_dir=args.cache_dir,
        trust_remote_code=True,
        local_files_only=local_files_only
    )

    # Load model
    try:
        model = BGEM3Model(
            model_name_or_path=args.model_name_or_path,
            cache_dir=args.cache_dir
        )
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    return model, tokenizer


def setup_datasets(args, tokenizer):
    """Setup datasets with enhanced format detection and migration support"""
    logger.info("Setting up datasets with enhanced format detection...")

    # Handle legacy data migration if requested
    if args.legacy_data and args.migrate_legacy:
        logger.info(f"Migrating legacy data from {args.legacy_data}")
        legacy_base = os.path.splitext(args.legacy_data)[0]
        migrated_file = f"{legacy_base}_migrated_embedding.jsonl"
        # For embedding training the nested format does not need to be flattened into
        # pairs, so the legacy file is simply copied alongside the original.
        import shutil
        shutil.copy2(args.legacy_data, migrated_file)
        logger.info(f"Legacy data prepared for training: {migrated_file}")

    # Setup training dataset
    if args.use_new_dataset_api:
        # Use new integrated dataset API with format detection
        train_dataset = create_embedding_dataset(
            data_path=args.train_data,
            tokenizer=tokenizer,
            query_max_length=args.query_max_length,
            passage_max_length=args.passage_max_length,
            is_train=True
        )
        logger.info(f"Loaded {len(train_dataset)} training examples from {args.train_data}")
        logger.info(f"Detected format: {train_dataset.detected_format}")

        # Add legacy data if provided
        if args.legacy_data and args.legacy_support:
            logger.info(f"Loading additional legacy data from {args.legacy_data}")
            legacy_dataset = create_embedding_dataset(
                data_path=args.legacy_data,
                tokenizer=tokenizer,
                query_max_length=args.query_max_length,
                passage_max_length=args.passage_max_length,
                is_train=True
            )
            logger.info(f"Loaded {len(legacy_dataset)} legacy examples")
            # Combine datasets
            train_dataset.data.extend(legacy_dataset.data)
            logger.info(f"Total training examples: {len(train_dataset)}")
    else:
        # Use original dataset API for backward compatibility
        train_dataset = BGEM3Dataset(
            data_path=args.train_data,
            tokenizer=tokenizer,
            query_max_length=args.query_max_length,
            passage_max_length=args.passage_max_length,
            train_group_size=args.train_group_size,
            is_train=True
        )
        logger.info(f"Loaded {len(train_dataset)} training examples (legacy API)")

    # Setup evaluation dataset
    eval_dataset = None
    if args.eval_data:
        if args.use_new_dataset_api:
            eval_dataset = create_embedding_dataset(
                data_path=args.eval_data,
                tokenizer=tokenizer,
                query_max_length=args.query_max_length,
                passage_max_length=args.passage_max_length,
                is_train=False
            )
        else:
            eval_dataset = BGEM3Dataset(
                data_path=args.eval_data,
                tokenizer=tokenizer,
                query_max_length=args.query_max_length,
                passage_max_length=args.passage_max_length,
                train_group_size=args.train_group_size,
                is_train=False
            )
        logger.info(f"Loaded {len(eval_dataset)} evaluation examples")
    return train_dataset, eval_dataset


def setup_data_loaders(args, train_dataset, eval_dataset):
    """Setup data loaders with appropriate collate functions"""
    # Set up data loaders with the appropriate collate function. Both the new embedding
    # schema and triplets-style datasets are batched with the same collate function, so
    # no branching on the detected format is needed here.
    collate_fn = collate_embedding_batch
    logger.info("Using embedding collate function")

    # FastM3Dataset doesn't work well with multiprocessing due to pickle loading
    workers = args.dataloader_num_workers
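    # If multi-worker loading hits pickling errors (e.g. with the dataset noted above),
    # passing --dataloader_num_workers 0 keeps all loading in the main process; that is
    # standard torch.utils.data.DataLoader behaviour, not a project-specific switch.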
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.per_device_train_batch_size,
        shuffle=True,
        num_workers=workers,
        collate_fn=collate_fn,
        pin_memory=torch.cuda.is_available()
    )

    # Evaluation data loader
    eval_dataloader = None
    if eval_dataset:
        eval_dataloader = DataLoader(
            eval_dataset,
            batch_size=args.per_device_eval_batch_size,
            shuffle=False,
            num_workers=args.dataloader_num_workers,
            collate_fn=collate_fn,
            pin_memory=torch.cuda.is_available()
        )
    return train_dataloader, eval_dataloader


def main():
    # Parse arguments
    args = parse_args()
    # Create configuration
    config = create_config_from_args(args)

    # Setup logging
    setup_logging(
        output_dir=args.output_dir,
        log_level=getattr(logging, args.log_level),
        log_to_console=True,
        log_to_file=True
    )
    logger = get_logger(__name__)
    logger.info("🚀 Starting Enhanced BGE-M3 Fine-tuning")
    logger.info(f"Arguments: {args}")
    logger.info(f"Configuration: {config.to_dict()}")

    # Set random seed
    set_random_seed(args.seed, deterministic=False)
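
    # Device selection below is driven by the `ascend.device_type` key in config.toml:
    # 'npu' tries torch_npu and falls back to CPU if it is missing; anything else
    # prefers CUDA when available and otherwise uses the CPU.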
    # Setup distributed training if needed
    device = None
    if args.local_rank != -1:
        setup_distributed()
        # Device selection from config
        from utils.config_loader import get_config
        config_instance = get_config()
        device_type = config_instance.get('ascend.device_type', 'cuda')
        if device_type == 'npu':
            try:
                import torch_npu  # type: ignore
                torch_npu.npu.set_device(args.local_rank)
                device = torch.device("npu", args.local_rank)
                logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
            except ImportError:
                logger.warning("torch_npu not installed, falling back to CPU")
                device = torch.device("cpu")
        elif torch.cuda.is_available():
            torch.cuda.set_device(args.local_rank)
            device = torch.device("cuda", args.local_rank)
            logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
        else:
            device = torch.device("cpu")
            logger.info("Distributed training enabled: using CPU")
    else:
        # Non-distributed device selection
        from utils.config_loader import get_config
        config_instance = get_config()
        device_type = config_instance.get('ascend.device_type', 'cuda')
        if device_type == 'npu':
            try:
                import torch_npu  # type: ignore
                device = torch.device("npu")
                logger.info("Using Ascend NPU")
            except ImportError:
                logger.warning("torch_npu not installed, falling back to CPU")
                device = torch.device("cpu")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
            logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
        else:
            device = torch.device("cpu")
            logger.info("Using CPU")

    # Update config with device
    config.device = str(device)

    # Save configuration
    if is_main_process():
        os.makedirs(args.output_dir, exist_ok=True)
        config.save(os.path.join(args.output_dir, "config.json"))

    # Setup model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(args)
    # Setup datasets with enhanced format detection
    train_dataset, eval_dataset = setup_datasets(args, tokenizer)
    # Setup data loaders
    train_dataloader, eval_dataloader = setup_data_loaders(args, train_dataset, eval_dataset)

    # Initialize trainer
    logger.info("Initializing enhanced trainer...")
    try:
        trainer = BGEM3Trainer(
            config=config,
            model=model,
            tokenizer=tokenizer
        )
        logger.info("Trainer initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize trainer: {e}")
        raise

    # Resume from checkpoint if specified
    if args.resume_from_checkpoint:
        logger.info(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
        try:
            trainer.load_checkpoint(args.resume_from_checkpoint)
            logger.info("Checkpoint loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise

    # Train
    logger.info("🎯 Starting enhanced training...")
    try:
        # Handle type compatibility for trainer
        from typing import cast
        trainer_train_dataset = cast(BGEM3Dataset, train_dataset) if train_dataset else None
        trainer_eval_dataset = cast(BGEM3Dataset, eval_dataset) if eval_dataset else None
        if not trainer_train_dataset:
            raise ValueError("Training dataset is required but not provided")
        trainer.train(
            train_dataset=trainer_train_dataset,
            eval_dataset=trainer_eval_dataset,
            resume_from_checkpoint=args.resume_from_checkpoint
        )

        # Save final model
        if is_main_process():
            logger.info("Saving final model...")
            final_path = os.path.join(args.output_dir, "final_model")
            os.makedirs(final_path, exist_ok=True)
            # Save model
            model.save_pretrained(final_path)
            tokenizer.save_pretrained(final_path)

            # Save training summary
            summary = {
                "training_args": vars(args),
                "model_config": config.to_dict(),
                "final_metrics": getattr(trainer, 'best_metric', None),
                "total_steps": getattr(trainer, 'global_step', 0),
                "training_completed": datetime.now().isoformat(),
                "device_used": str(device),
                "dataset_format": getattr(train_dataset, 'detected_format', 'unknown'),
                "enhanced_features": {
                    "use_new_dataset_api": args.use_new_dataset_api,
                    "use_hard_negatives": args.use_hard_negatives,
                    "use_self_distill": args.use_self_distill,
                    "use_score_weighted_sampling": args.use_score_weighted_sampling,
                    "multiple_positives": args.multiple_positives,
                    "legacy_support": args.legacy_support
                }
            }
            with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
                json.dump(summary, f, indent=2)
            logger.info(f"✅ Enhanced training completed! Model saved to {final_path}")

            # Show training summary
            logger.info("📊 Training Summary:")
            logger.info(f"  Total examples: {len(train_dataset)}")
            if hasattr(train_dataset, 'detected_format'):
                logger.info(f"  Format detected: {train_dataset.detected_format}")  # type: ignore[attr-defined]
            logger.info(f"  Epochs: {args.num_train_epochs}")
            logger.info(f"  Learning rate: {args.learning_rate}")
            logger.info(f"  Batch size: {args.per_device_train_batch_size}")
            logger.info(f"  Temperature: {args.temperature}")
            logger.info(f"  Enhanced features enabled: {args.use_new_dataset_api}")
    except Exception as e:
        logger.error(f"Training failed with error: {e}", exc_info=True)
        raise
    finally:
        # Cleanup
        if hasattr(trainer, 'hard_negative_miner') and trainer.hard_negative_miner:
            trainer.hard_negative_miner.stop_async_updates()
        if args.local_rank != -1:
            from utils.distributed import cleanup_distributed
            cleanup_distributed()


if __name__ == "__main__":
    main()