""" Enhanced BGE-M3 embedding model training script with integrated dataset pipeline Supports both legacy nested format and new embedding schema with automatic detection """ import os import sys import argparse # Set tokenizer parallelism to avoid warnings os.environ["TOKENIZERS_PARALLELISM"] = "false" import logging import json from datetime import datetime import torch from transformers import AutoTokenizer, set_seed from torch.utils.data import DataLoader from torch.optim import AdamW from transformers.optimization import get_linear_schedule_with_warmup # Add parent directory to path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) try: from config.m3_config import BGEM3Config from models.bge_m3 import BGEM3Model from data.dataset import BGEM3Dataset, create_embedding_dataset, collate_embedding_batch, migrate_nested_to_flat from data.preprocessing import DataPreprocessor from data.augmentation import QueryAugmenter, HardNegativeAugmenter from data.hard_negative_mining import ANCEHardNegativeMiner from training.m3_trainer import BGEM3Trainer from evaluation.evaluator import RetrievalEvaluator from utils.logging import setup_logging, get_logger from utils.distributed import setup_distributed, is_main_process, set_random_seed from utils.config_loader import get_config, load_config_for_training, resolve_model_path except ImportError as e: print(f"Import error: {e}") print("Please ensure you're running from the project root directory") sys.exit(1) logger = logging.getLogger(__name__) def parse_args(): parser = argparse.ArgumentParser(description="Enhanced BGE-M3 Embedding Model Training") # Model arguments parser.add_argument("--model_name_or_path", type=str, default=None, help="Path to pretrained model or model identifier") parser.add_argument("--cache_dir", type=str, default="./cache/data", help="Directory to cache downloaded models") parser.add_argument("--config_path", type=str, default="./config.toml", help="Path to configuration file") # Data arguments - support both legacy and new schemas parser.add_argument("--train_data", type=str, required=True, help="Path to training data file (supports both embedding and legacy nested formats)") parser.add_argument("--eval_data", type=str, default=None, help="Path to evaluation data file") parser.add_argument("--data_format", type=str, default="auto", choices=["auto", "embedding", "legacy_nested", "jsonl", "json", "csv", "tsv"], help="Format of input data (auto-detects by default)") parser.add_argument("--legacy_data", type=str, default=None, help="Path to additional legacy nested data (for backward compatibility)") parser.add_argument("--migrate_legacy", action="store_true", help="Migrate legacy nested format to flat format") # Sequence length arguments parser.add_argument("--max_seq_length", type=int, default=None, help="Maximum sequence length") parser.add_argument("--query_max_length", type=int, default=None, help="Maximum query length") parser.add_argument("--passage_max_length", type=int, default=None, help="Maximum passage length") parser.add_argument("--max_length", type=int, default=8192, help="Maximum total input length (multi-granularity)") # Training arguments parser.add_argument("--output_dir", type=str, default="./output/bge-m3-enhanced", help="Directory to save the model") parser.add_argument("--num_train_epochs", type=int, default=None, help="Number of training epochs") parser.add_argument("--per_device_train_batch_size", type=int, default=None, help="Batch size per GPU/CPU for training") parser.add_argument("--per_device_eval_batch_size", type=int, default=None, help="Batch size per GPU/CPU for evaluation") parser.add_argument("--gradient_accumulation_steps", type=int, default=2, help="Number of updates steps to accumulate before performing a backward/update pass") parser.add_argument("--learning_rate", type=float, default=None, help="Initial learning rate") parser.add_argument("--warmup_ratio", type=float, default=None, help="Warmup steps as a ratio of total steps") parser.add_argument("--weight_decay", type=float, default=None, help="Weight decay") # Model specific arguments - enhanced features parser.add_argument("--train_group_size", type=int, default=8, help="Number of passages per query in training") parser.add_argument("--temperature", type=float, default=0.1, help="Temperature for InfoNCE loss") parser.add_argument("--use_hard_negatives", action="store_true", help="Enable hard negative mining") parser.add_argument("--use_self_distill", action="store_true", help="Enable self-knowledge distillation") parser.add_argument("--use_score_weighted_sampling", action="store_true", default=True, help="Enable score-weighted positive sampling") parser.add_argument("--query_instruction", type=str, default="", help="Query instruction prefix") # Enhanced sampling features parser.add_argument("--multiple_positives", action="store_true", default=True, help="Use multiple positives per batch for richer contrastive signals") parser.add_argument("--hard_negative_ratio", type=float, default=0.7, help="Ratio of hard to random negatives") parser.add_argument("--difficulty_threshold", type=float, default=0.1, help="Minimum difficulty for hard negatives") # Optimization and performance parser.add_argument("--fp16", action="store_true", help="Use mixed precision training") parser.add_argument("--gradient_checkpointing", action="store_true", help="Use gradient checkpointing to save memory") parser.add_argument("--dataloader_num_workers", type=int, default=4, help="Number of workers for data loading") # Checkpointing parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every N steps") parser.add_argument("--eval_steps", type=int, default=500, help="Evaluate every N steps") parser.add_argument("--logging_steps", type=int, default=100, help="Log every N steps") parser.add_argument("--save_total_limit", type=int, default=3, help="Maximum number of checkpoints to keep") parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint to resume from") # Distributed training parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training") # Other parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--overwrite_output_dir", action="store_true", help="Overwrite the output directory") parser.add_argument("--log_level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Logging level") # Backward compatibility flags parser.add_argument("--legacy_support", action="store_true", default=True, help="Enable legacy nested format support") parser.add_argument("--use_new_dataset_api", action="store_true", default=True, help="Use new integrated dataset API with format detection") args = parser.parse_args() # Validate arguments if os.path.exists(args.output_dir) and not args.overwrite_output_dir: raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.") return args def create_config_from_args(args): """Create configuration from command line arguments and config file Parameter Precedence (highest to lowest): 1. Command-line arguments (CLI) 2. Model-specific config ([m3], [reranker]) 3. [hardware] section in config.toml 4. [training] section in config.toml 5. Hardcoded defaults in code """ # Load base configuration from TOML file config_dict = load_config_for_training( config_path=args.config_path, override_params={} ) # Get the centralized config instance central_config = get_config() # Create override parameters from command line arguments override_params = {} # Model configuration - use new resolve_model_path function resolved_model_path, resolved_cache_dir = resolve_model_path( model_key='bge_m3', config_instance=central_config, model_name_or_path=args.model_name_or_path, cache_dir=args.cache_dir ) args.model_name_or_path = resolved_model_path args.cache_dir = resolved_cache_dir # Data configuration data_config = config_dict.get('data', {}) if args.max_seq_length is None: args.max_seq_length = data_config.get('max_seq_length', 512) if args.query_max_length is None: args.query_max_length = data_config.get('max_query_length', 64) if args.passage_max_length is None: args.passage_max_length = data_config.get('max_passage_length', 512) # Get validation split from config validation_split = data_config.get('validation_split', 0.1) # Training configuration training_config = config_dict.get('training', {}) if args.num_train_epochs is None: args.num_train_epochs = training_config.get('default_num_epochs', 3) if args.per_device_train_batch_size is None: args.per_device_train_batch_size = training_config.get('default_batch_size', 4) if args.per_device_eval_batch_size is None: args.per_device_eval_batch_size = args.per_device_train_batch_size * 2 if args.learning_rate is None: args.learning_rate = training_config.get('default_learning_rate', 1e-5) if args.warmup_ratio is None: args.warmup_ratio = training_config.get('default_warmup_ratio', 0.1) if args.weight_decay is None: args.weight_decay = training_config.get('default_weight_decay', 0.01) # Model-specific (M3) configuration - highest precedence m3_config = config_dict.get('m3', {}) train_group_size = m3_config.get('train_group_size', args.train_group_size) temperature = m3_config.get('temperature', args.temperature) use_hard_negatives = m3_config.get('use_hard_negatives', args.use_hard_negatives) use_self_distill = m3_config.get('use_self_distill', args.use_self_distill) query_instruction_for_retrieval = m3_config.get('query_instruction_for_retrieval', args.query_instruction) # Hardware configuration hardware_config = config_dict.get('hardware', {}) if args.per_device_train_batch_size is None: args.per_device_train_batch_size = hardware_config.get('per_device_train_batch_size', args.per_device_train_batch_size) if args.per_device_eval_batch_size is None: args.per_device_eval_batch_size = hardware_config.get('per_device_eval_batch_size', args.per_device_eval_batch_size) # Create BGE-M3 configuration config = BGEM3Config( model_name_or_path=args.model_name_or_path, cache_dir=args.cache_dir, train_data_path=args.train_data, eval_data_path=args.eval_data, max_seq_length=args.max_seq_length, query_max_length=args.query_max_length, passage_max_length=args.passage_max_length, max_length=args.max_length, output_dir=args.output_dir, num_train_epochs=args.num_train_epochs, per_device_train_batch_size=args.per_device_train_batch_size, per_device_eval_batch_size=args.per_device_eval_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, warmup_ratio=args.warmup_ratio, weight_decay=args.weight_decay, train_group_size=train_group_size, temperature=temperature, use_hard_negatives=use_hard_negatives, use_self_distill=use_self_distill, query_instruction_for_retrieval=query_instruction_for_retrieval, fp16=args.fp16, gradient_checkpointing=args.gradient_checkpointing, dataloader_num_workers=args.dataloader_num_workers, save_steps=args.save_steps, eval_steps=args.eval_steps, logging_steps=args.logging_steps, save_total_limit=args.save_total_limit, seed=args.seed, local_rank=args.local_rank, overwrite_output_dir=args.overwrite_output_dir ) return config def setup_model_and_tokenizer(args): """Setup BGE-M3 model and tokenizer""" logger.info(f"Loading BGE-M3 model from {args.model_name_or_path}") # Detect offline mode if the path is local or HF_HUB_OFFLINE is set local_files_only = os.path.isdir(args.model_name_or_path) or bool(os.environ.get("HF_HUB_OFFLINE")) # Load tokenizer tokenizer = AutoTokenizer.from_pretrained( args.model_name_or_path, cache_dir=args.cache_dir, trust_remote_code=True, local_files_only=local_files_only ) # Load model try: model = BGEM3Model( model_name_or_path=args.model_name_or_path, cache_dir=args.cache_dir ) logger.info("Model loaded successfully") except Exception as e: logger.error(f"Failed to load model: {e}") raise return model, tokenizer def setup_datasets(args, tokenizer): """Setup datasets with enhanced format detection and migration support""" logger.info("Setting up datasets with enhanced format detection...") # Handle legacy data migration if needed if args.legacy_data and args.migrate_legacy: logger.info(f"Migrating legacy data from {args.legacy_data}") legacy_base = os.path.splitext(args.legacy_data)[0] migrated_file = f"{legacy_base}_migrated_embedding.jsonl" # For embedding, we don't need to migrate to flat pairs, but we can optimize the format # Copy the legacy data to work with if needed import shutil shutil.copy2(args.legacy_data, migrated_file) logger.info(f"Legacy data prepared for training: {migrated_file}") # Setup training dataset if args.use_new_dataset_api: # Use new integrated dataset API with format detection train_dataset = create_embedding_dataset( data_path=args.train_data, tokenizer=tokenizer, query_max_length=args.query_max_length, passage_max_length=args.passage_max_length, is_train=True ) logger.info(f"Loaded {len(train_dataset)} training examples from {args.train_data}") logger.info(f"Detected format: {train_dataset.detected_format}") # Add legacy data if provided if args.legacy_data and args.legacy_support: logger.info(f"Loading additional legacy data from {args.legacy_data}") legacy_dataset = create_embedding_dataset( data_path=args.legacy_data, tokenizer=tokenizer, query_max_length=args.query_max_length, passage_max_length=args.passage_max_length, is_train=True ) logger.info(f"Loaded {len(legacy_dataset)} legacy examples") # Combine datasets train_dataset.data.extend(legacy_dataset.data) logger.info(f"Total training examples: {len(train_dataset)}") else: # Use original dataset API for backward compatibility train_dataset = BGEM3Dataset( data_path=args.train_data, tokenizer=tokenizer, query_max_length=args.query_max_length, passage_max_length=args.passage_max_length, train_group_size=args.train_group_size, is_train=True ) logger.info(f"Loaded {len(train_dataset)} training examples (legacy API)") # Setup evaluation dataset eval_dataset = None if args.eval_data: if args.use_new_dataset_api: eval_dataset = create_embedding_dataset( data_path=args.eval_data, tokenizer=tokenizer, query_max_length=args.query_max_length, passage_max_length=args.passage_max_length, is_train=False ) else: eval_dataset = BGEM3Dataset( data_path=args.eval_data, tokenizer=tokenizer, query_max_length=args.query_max_length, passage_max_length=args.passage_max_length, train_group_size=args.train_group_size, is_train=False ) logger.info(f"Loaded {len(eval_dataset)} evaluation examples") return train_dataset, eval_dataset def setup_data_loaders(args, train_dataset, eval_dataset): """Setup data loaders with appropriate collate functions""" # 4. Set up data loaders with appropriate collate function if hasattr(train_dataset, 'format') and train_dataset.format == 'triplets': # This might be a generic BGE dataset in triplets format collate_fn = collate_embedding_batch logger.info("Using embedding collate function for triplets format") else: collate_fn = collate_embedding_batch logger.info("Using standard embedding collate function") # FastM3Dataset doesn't work well with multiprocessing due to pickle loading workers = args.dataloader_num_workers train_dataloader = DataLoader( train_dataset, batch_size=args.per_device_train_batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn, pin_memory=torch.cuda.is_available() ) # Evaluation data loader eval_dataloader = None if eval_dataset: eval_dataloader = DataLoader( eval_dataset, batch_size=args.per_device_eval_batch_size, shuffle=False, num_workers=args.dataloader_num_workers, collate_fn=collate_fn, pin_memory=torch.cuda.is_available() ) return train_dataloader, eval_dataloader def main(): # Parse arguments args = parse_args() # Create configuration config = create_config_from_args(args) # Setup logging setup_logging( output_dir=args.output_dir, log_level=getattr(logging, args.log_level), log_to_console=True, log_to_file=True ) logger = get_logger(__name__) logger.info("🚀 Starting Enhanced BGE-M3 Fine-tuning") logger.info(f"Arguments: {args}") logger.info(f"Configuration: {config.to_dict()}") # Set random seed set_random_seed(args.seed, deterministic=False) # Setup distributed training if needed device = None if args.local_rank != -1: setup_distributed() # Device selection from config from utils.config_loader import get_config config_instance = get_config() device_type = config_instance.get('ascend.device_type', 'cuda') if device_type == 'npu': try: import torch_npu # type: ignore torch_npu.npu.set_device(args.local_rank) device = torch.device("npu", args.local_rank) logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}") except ImportError: logger.warning("torch_npu not installed, falling back to CPU") device = torch.device("cpu") elif torch.cuda.is_available(): torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}") else: device = torch.device("cpu") logger.info("Distributed training enabled: using CPU") else: # Non-distributed device selection from utils.config_loader import get_config config_instance = get_config() device_type = config_instance.get('ascend.device_type', 'cuda') if device_type == 'npu': try: import torch_npu # type: ignore device = torch.device("npu") logger.info("Using Ascend NPU") except ImportError: logger.warning("torch_npu not installed, falling back to CPU") device = torch.device("cpu") elif torch.cuda.is_available(): device = torch.device("cuda") logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})") else: device = torch.device("cpu") logger.info("Using CPU") # Update config with device config.device = str(device) # Save configuration if is_main_process(): os.makedirs(args.output_dir, exist_ok=True) config.save(os.path.join(args.output_dir, "config.json")) # Setup model and tokenizer model, tokenizer = setup_model_and_tokenizer(args) # Setup datasets with enhanced format detection train_dataset, eval_dataset = setup_datasets(args, tokenizer) # Setup data loaders train_dataloader, eval_dataloader = setup_data_loaders(args, train_dataset, eval_dataset) # Initialize trainer logger.info("Initializing enhanced trainer...") try: trainer = BGEM3Trainer( config=config, model=model, tokenizer=tokenizer ) logger.info("Trainer initialized successfully") except Exception as e: logger.error(f"Failed to initialize trainer: {e}") raise # Resume from checkpoint if specified if args.resume_from_checkpoint: logger.info(f"Resuming from checkpoint: {args.resume_from_checkpoint}") try: trainer.load_checkpoint(args.resume_from_checkpoint) logger.info("Checkpoint loaded successfully.") except Exception as e: logger.error(f"Failed to load checkpoint: {e}") raise # Train logger.info("🎯 Starting enhanced training...") try: # Handle type compatibility for trainer from typing import cast trainer_train_dataset = cast(BGEM3Dataset, train_dataset) if train_dataset else None trainer_eval_dataset = cast(BGEM3Dataset, eval_dataset) if eval_dataset else None if not trainer_train_dataset: raise ValueError("Training dataset is required but not provided") trainer.train( train_dataset=trainer_train_dataset, eval_dataset=trainer_eval_dataset, resume_from_checkpoint=args.resume_from_checkpoint ) # Save final model if is_main_process(): logger.info("Saving final model...") final_path = os.path.join(args.output_dir, "final_model") os.makedirs(final_path, exist_ok=True) # Save model model.save_pretrained(final_path) tokenizer.save_pretrained(final_path) # Save training summary summary = { "training_args": vars(args), "model_config": config.to_dict(), "final_metrics": getattr(trainer, 'best_metric', None), "total_steps": getattr(trainer, 'global_step', 0), "training_completed": datetime.now().isoformat(), "device_used": str(device), "dataset_format": getattr(train_dataset, 'detected_format', 'unknown'), "enhanced_features": { "use_new_dataset_api": args.use_new_dataset_api, "use_hard_negatives": args.use_hard_negatives, "use_self_distill": args.use_self_distill, "use_score_weighted_sampling": args.use_score_weighted_sampling, "multiple_positives": args.multiple_positives, "legacy_support": args.legacy_support } } with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f: json.dump(summary, f, indent=2) logger.info(f"✅ Enhanced training completed! Model saved to {final_path}") # Show training summary logger.info("📊 Training Summary:") logger.info(f" Total examples: {len(train_dataset)}") if hasattr(train_dataset, 'detected_format'): logger.info(f" Format detected: {train_dataset.detected_format}") # type: ignore[attr-defined] logger.info(f" Epochs: {args.num_train_epochs}") logger.info(f" Learning rate: {args.learning_rate}") logger.info(f" Batch size: {args.per_device_train_batch_size}") logger.info(f" Temperature: {args.temperature}") logger.info(f" Enhanced features enabled: {args.use_new_dataset_api}") except Exception as e: logger.error(f"Training failed with error: {e}", exc_info=True) raise finally: # Cleanup if hasattr(trainer, 'hard_negative_miner') and trainer.hard_negative_miner: trainer.hard_negative_miner.stop_async_updates() if args.local_rank != -1: from utils.distributed import cleanup_distributed cleanup_distributed() if __name__ == "__main__": main()