"""
|
|
Enhanced BGE-M3 embedding model training script with integrated dataset pipeline
|
|
Supports both legacy nested format and new embedding schema with automatic detection
|
|
"""

import os
import sys
import argparse

# Set tokenizer parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import logging
import json
from datetime import datetime

import torch
from transformers import AutoTokenizer, set_seed
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers.optimization import get_linear_schedule_with_warmup

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
    from config.m3_config import BGEM3Config
    from models.bge_m3 import BGEM3Model
    from data.dataset import (
        BGEM3Dataset,
        create_embedding_dataset,
        collate_embedding_batch,
        migrate_nested_to_flat,
    )
    from data.preprocessing import DataPreprocessor
    from data.augmentation import QueryAugmenter, HardNegativeAugmenter
    from data.hard_negative_mining import ANCEHardNegativeMiner
    from training.m3_trainer import BGEM3Trainer
    from evaluation.evaluator import RetrievalEvaluator
    from utils.logging import setup_logging, get_logger
    from utils.distributed import setup_distributed, is_main_process, set_random_seed
    from utils.config_loader import get_config, load_config_for_training, resolve_model_path
except ImportError as e:
    print(f"Import error: {e}")
    print("Please ensure you're running from the project root directory")
    sys.exit(1)

logger = logging.getLogger(__name__)

def parse_args():
    parser = argparse.ArgumentParser(description="Enhanced BGE-M3 Embedding Model Training")

    # Model arguments
    parser.add_argument("--model_name_or_path", type=str, default=None,
                        help="Path to pretrained model or model identifier")
    parser.add_argument("--cache_dir", type=str, default="./cache/data",
                        help="Directory to cache downloaded models")
    parser.add_argument("--config_path", type=str, default="./config.toml",
                        help="Path to configuration file")

    # Data arguments - support both legacy and new schemas
    parser.add_argument("--train_data", type=str, required=True,
                        help="Path to training data file (supports both embedding and legacy nested formats)")
    parser.add_argument("--eval_data", type=str, default=None,
                        help="Path to evaluation data file")
    parser.add_argument("--data_format", type=str, default="auto",
                        choices=["auto", "embedding", "legacy_nested", "jsonl", "json", "csv", "tsv"],
                        help="Format of input data (auto-detected by default)")
    parser.add_argument("--legacy_data", type=str, default=None,
                        help="Path to additional legacy nested data (for backward compatibility)")
    parser.add_argument("--migrate_legacy", action="store_true",
                        help="Migrate legacy nested format to flat format")

    # Sequence length arguments
    parser.add_argument("--max_seq_length", type=int, default=None,
                        help="Maximum sequence length")
    parser.add_argument("--query_max_length", type=int, default=None,
                        help="Maximum query length")
    parser.add_argument("--passage_max_length", type=int, default=None,
                        help="Maximum passage length")
    parser.add_argument("--max_length", type=int, default=8192,
                        help="Maximum total input length (multi-granularity)")

    # Training arguments
    parser.add_argument("--output_dir", type=str, default="./output/bge-m3-enhanced",
                        help="Directory to save the model")
    parser.add_argument("--num_train_epochs", type=int, default=None,
                        help="Number of training epochs")
    parser.add_argument("--per_device_train_batch_size", type=int, default=None,
                        help="Batch size per GPU/CPU for training")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=None,
                        help="Batch size per GPU/CPU for evaluation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
                        help="Number of update steps to accumulate before performing a backward/update pass")
    parser.add_argument("--learning_rate", type=float, default=None,
                        help="Initial learning rate")
    parser.add_argument("--warmup_ratio", type=float, default=None,
                        help="Warmup steps as a ratio of total steps")
    parser.add_argument("--weight_decay", type=float, default=None,
                        help="Weight decay")

    # Model-specific arguments - enhanced features
    parser.add_argument("--train_group_size", type=int, default=8,
                        help="Number of passages per query in training")
    parser.add_argument("--temperature", type=float, default=0.1,
                        help="Temperature for InfoNCE loss")
    parser.add_argument("--use_hard_negatives", action="store_true",
                        help="Enable hard negative mining")
    parser.add_argument("--use_self_distill", action="store_true",
                        help="Enable self-knowledge distillation")
    parser.add_argument("--use_score_weighted_sampling", action="store_true", default=True,
                        help="Enable score-weighted positive sampling")
    parser.add_argument("--query_instruction", type=str, default="",
                        help="Query instruction prefix")

    # Enhanced sampling features
    parser.add_argument("--multiple_positives", action="store_true", default=True,
                        help="Use multiple positives per batch for richer contrastive signals")
    parser.add_argument("--hard_negative_ratio", type=float, default=0.7,
                        help="Ratio of hard to random negatives")
    parser.add_argument("--difficulty_threshold", type=float, default=0.1,
                        help="Minimum difficulty for hard negatives")

    # Optimization and performance
    parser.add_argument("--fp16", action="store_true",
                        help="Use mixed precision training")
    parser.add_argument("--gradient_checkpointing", action="store_true",
                        help="Use gradient checkpointing to save memory")
    parser.add_argument("--dataloader_num_workers", type=int, default=4,
                        help="Number of workers for data loading")

    # Checkpointing
    parser.add_argument("--save_steps", type=int, default=500,
                        help="Save checkpoint every N steps")
    parser.add_argument("--eval_steps", type=int, default=500,
                        help="Evaluate every N steps")
    parser.add_argument("--logging_steps", type=int, default=100,
                        help="Log every N steps")
    parser.add_argument("--save_total_limit", type=int, default=3,
                        help="Maximum number of checkpoints to keep")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Path to checkpoint to resume from")

    # Distributed training
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training")

    # Other
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--overwrite_output_dir", action="store_true",
                        help="Overwrite the output directory")
    parser.add_argument("--log_level", type=str, default="INFO",
                        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        help="Logging level")

    # Backward compatibility flags
    parser.add_argument("--legacy_support", action="store_true", default=True,
                        help="Enable legacy nested format support")
    parser.add_argument("--use_new_dataset_api", action="store_true", default=True,
                        help="Use new integrated dataset API with format detection")

    args = parser.parse_args()

    # Validate arguments
    if os.path.exists(args.output_dir) and not args.overwrite_output_dir:
        raise ValueError(
            f"Output directory {args.output_dir} already exists. "
            "Use --overwrite_output_dir to overwrite."
        )

    return args
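
# Example invocation (the script path and every value below are placeholders):
#
#   python <this_script>.py \
#       --train_data ./data/train.jsonl \
#       --eval_data ./data/eval.jsonl \
#       --output_dir ./output/bge-m3-enhanced \
#       --use_hard_negatives --use_self_distill --fp16 \
#       --overwrite_output_dir
#
# Arguments left unset (e.g. --learning_rate) fall back to config.toml; see
# create_config_from_args() below.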


def create_config_from_args(args):
    """Create configuration from command line arguments and the config file.

    Parameter precedence (highest to lowest):
    1. Command-line arguments (CLI)
    2. Model-specific config ([m3], [reranker])
    3. [hardware] section in config.toml
    4. [training] section in config.toml
    5. Hardcoded defaults in code

    Note: for the [m3] fields resolved below, a value present in config.toml wins over
    the CLI value, because those CLI flags always carry non-None defaults.
    """
    # Load base configuration from TOML file
    config_dict = load_config_for_training(
        config_path=args.config_path,
        override_params={}
    )

    # Get the centralized config instance
    central_config = get_config()

    # Create override parameters from command line arguments
    override_params = {}

    # Model configuration - use new resolve_model_path function
    resolved_model_path, resolved_cache_dir = resolve_model_path(
        model_key='bge_m3',
        config_instance=central_config,
        model_name_or_path=args.model_name_or_path,
        cache_dir=args.cache_dir
    )

    args.model_name_or_path = resolved_model_path
    args.cache_dir = resolved_cache_dir

    # Data configuration
    data_config = config_dict.get('data', {})
    if args.max_seq_length is None:
        args.max_seq_length = data_config.get('max_seq_length', 512)
    if args.query_max_length is None:
        args.query_max_length = data_config.get('max_query_length', 64)
    if args.passage_max_length is None:
        args.passage_max_length = data_config.get('max_passage_length', 512)

    # Get validation split from config
    validation_split = data_config.get('validation_split', 0.1)

    # Hardware configuration (read before [training] so its batch sizes take precedence)
    hardware_config = config_dict.get('hardware', {})

    # Training configuration
    training_config = config_dict.get('training', {})
    if args.num_train_epochs is None:
        args.num_train_epochs = training_config.get('default_num_epochs', 3)
    if args.per_device_train_batch_size is None:
        args.per_device_train_batch_size = hardware_config.get(
            'per_device_train_batch_size',
            training_config.get('default_batch_size', 4)
        )
    if args.per_device_eval_batch_size is None:
        args.per_device_eval_batch_size = hardware_config.get(
            'per_device_eval_batch_size',
            args.per_device_train_batch_size * 2
        )
    if args.learning_rate is None:
        args.learning_rate = training_config.get('default_learning_rate', 1e-5)
    if args.warmup_ratio is None:
        args.warmup_ratio = training_config.get('default_warmup_ratio', 0.1)
    if args.weight_decay is None:
        args.weight_decay = training_config.get('default_weight_decay', 0.01)

    # Model-specific (M3) configuration
    m3_config = config_dict.get('m3', {})
    train_group_size = m3_config.get('train_group_size', args.train_group_size)
    temperature = m3_config.get('temperature', args.temperature)
    use_hard_negatives = m3_config.get('use_hard_negatives', args.use_hard_negatives)
    use_self_distill = m3_config.get('use_self_distill', args.use_self_distill)
    query_instruction_for_retrieval = m3_config.get('query_instruction_for_retrieval', args.query_instruction)

    # Create BGE-M3 configuration
    config = BGEM3Config(
        model_name_or_path=args.model_name_or_path,
        cache_dir=args.cache_dir,
        train_data_path=args.train_data,
        eval_data_path=args.eval_data,
        max_seq_length=args.max_seq_length,
        query_max_length=args.query_max_length,
        passage_max_length=args.passage_max_length,
        max_length=args.max_length,
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        train_group_size=train_group_size,
        temperature=temperature,
        use_hard_negatives=use_hard_negatives,
        use_self_distill=use_self_distill,
        query_instruction_for_retrieval=query_instruction_for_retrieval,
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        dataloader_num_workers=args.dataloader_num_workers,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        save_total_limit=args.save_total_limit,
        seed=args.seed,
        local_rank=args.local_rank,
        overwrite_output_dir=args.overwrite_output_dir
    )

    return config
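
# Precedence sketch (values below are hypothetical): if config.toml sets
# [training].default_learning_rate = 2e-5 and no --learning_rate is passed, training
# uses 2e-5; passing --learning_rate 1e-5 wins, because config values are only filled
# in for arguments that are still None. Batch sizes additionally prefer
# [hardware].per_device_train_batch_size over [training].default_batch_size.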


def setup_model_and_tokenizer(args):
    """Setup BGE-M3 model and tokenizer"""
    logger.info(f"Loading BGE-M3 model from {args.model_name_or_path}")

    # Detect offline mode if the path is local or HF_HUB_OFFLINE is set
    local_files_only = os.path.isdir(args.model_name_or_path) or bool(os.environ.get("HF_HUB_OFFLINE"))

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        cache_dir=args.cache_dir,
        trust_remote_code=True,
        local_files_only=local_files_only
    )

    # Load model
    try:
        model = BGEM3Model(
            model_name_or_path=args.model_name_or_path,
            cache_dir=args.cache_dir
        )
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

    return model, tokenizer
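
# Offline usage (illustrative; the script path is a placeholder): pointing
# --model_name_or_path at a local directory, or exporting HF_HUB_OFFLINE=1, makes the
# loaders above pass local_files_only=True so nothing is fetched from the Hugging Face Hub:
#
#   HF_HUB_OFFLINE=1 python <this_script>.py \
#       --model_name_or_path ./models/bge-m3 \
#       --train_data ./data/train.jsonl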


def setup_datasets(args, tokenizer):
    """Setup datasets with enhanced format detection and migration support"""
    logger.info("Setting up datasets with enhanced format detection...")

    # Handle legacy data migration if needed
    if args.legacy_data and args.migrate_legacy:
        logger.info(f"Migrating legacy data from {args.legacy_data}")
        legacy_base = os.path.splitext(args.legacy_data)[0]
        migrated_file = f"{legacy_base}_migrated_embedding.jsonl"

        # For embedding, we don't need to migrate to flat pairs, but we can optimize the format
        # Copy the legacy data to work with if needed
        import shutil
        shutil.copy2(args.legacy_data, migrated_file)
        logger.info(f"Legacy data prepared for training: {migrated_file}")

    # Setup training dataset
    if args.use_new_dataset_api:
        # Use new integrated dataset API with format detection
        train_dataset = create_embedding_dataset(
            data_path=args.train_data,
            tokenizer=tokenizer,
            query_max_length=args.query_max_length,
            passage_max_length=args.passage_max_length,
            is_train=True
        )

        logger.info(f"Loaded {len(train_dataset)} training examples from {args.train_data}")
        logger.info(f"Detected format: {train_dataset.detected_format}")

        # Add legacy data if provided
        if args.legacy_data and args.legacy_support:
            logger.info(f"Loading additional legacy data from {args.legacy_data}")
            legacy_dataset = create_embedding_dataset(
                data_path=args.legacy_data,
                tokenizer=tokenizer,
                query_max_length=args.query_max_length,
                passage_max_length=args.passage_max_length,
                is_train=True
            )
            logger.info(f"Loaded {len(legacy_dataset)} legacy examples")

            # Combine datasets
            train_dataset.data.extend(legacy_dataset.data)
            logger.info(f"Total training examples: {len(train_dataset)}")
    else:
        # Use original dataset API for backward compatibility
        train_dataset = BGEM3Dataset(
            data_path=args.train_data,
            tokenizer=tokenizer,
            query_max_length=args.query_max_length,
            passage_max_length=args.passage_max_length,
            train_group_size=args.train_group_size,
            is_train=True
        )
        logger.info(f"Loaded {len(train_dataset)} training examples (legacy API)")

    # Setup evaluation dataset
    eval_dataset = None
    if args.eval_data:
        if args.use_new_dataset_api:
            eval_dataset = create_embedding_dataset(
                data_path=args.eval_data,
                tokenizer=tokenizer,
                query_max_length=args.query_max_length,
                passage_max_length=args.passage_max_length,
                is_train=False
            )
        else:
            eval_dataset = BGEM3Dataset(
                data_path=args.eval_data,
                tokenizer=tokenizer,
                query_max_length=args.query_max_length,
                passage_max_length=args.passage_max_length,
                train_group_size=args.train_group_size,
                is_train=False
            )
        logger.info(f"Loaded {len(eval_dataset)} evaluation examples")

    return train_dataset, eval_dataset
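
# Alternative to the in-place train_dataset.data.extend(...) above (a sketch, assuming the
# trainer only needs an iterable dataset and not BGEM3Dataset-specific attributes):
#
#   from torch.utils.data import ConcatDataset
#   combined = ConcatDataset([train_dataset, legacy_dataset])
#
# The extend() approach is kept because later code reads attributes such as
# train_dataset.detected_format directly from the dataset object.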


def setup_data_loaders(args, train_dataset, eval_dataset):
    """Setup data loaders with appropriate collate functions"""

    # Select the collate function (both formats currently share the same one)
    if hasattr(train_dataset, 'format') and train_dataset.format == 'triplets':
        # This might be a generic BGE dataset in triplets format
        collate_fn = collate_embedding_batch
        logger.info("Using embedding collate function for triplets format")
    else:
        collate_fn = collate_embedding_batch
        logger.info("Using standard embedding collate function")

    # Note: pickle-backed datasets (e.g. FastM3Dataset) may not work well with
    # multiprocessing workers
    workers = args.dataloader_num_workers

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.per_device_train_batch_size,
        shuffle=True,
        num_workers=workers,
        collate_fn=collate_fn,
        pin_memory=torch.cuda.is_available()
    )

    # Evaluation data loader
    eval_dataloader = None
    if eval_dataset:
        eval_dataloader = DataLoader(
            eval_dataset,
            batch_size=args.per_device_eval_batch_size,
            shuffle=False,
            num_workers=args.dataloader_num_workers,
            collate_fn=collate_fn,
            pin_memory=torch.cuda.is_available()
        )

    return train_dataloader, eval_dataloader
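
# Quick sanity check for a configured loader (illustrative; assumes collate_embedding_batch
# returns a dict of tensors, which may differ in this project):
#
#   batch = next(iter(train_dataloader))
#   print({k: getattr(v, "shape", type(v)) for k, v in batch.items()})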


def main():
    # Parse arguments
    args = parse_args()

    # Create configuration
    config = create_config_from_args(args)

    # Setup logging
    setup_logging(
        output_dir=args.output_dir,
        log_level=getattr(logging, args.log_level),
        log_to_console=True,
        log_to_file=True
    )

    logger = get_logger(__name__)
    logger.info("🚀 Starting Enhanced BGE-M3 Fine-tuning")
    logger.info(f"Arguments: {args}")
    logger.info(f"Configuration: {config.to_dict()}")

    # Set random seed
    set_random_seed(args.seed, deterministic=False)

    # Setup distributed training if needed
    device = None
    if args.local_rank != -1:
        setup_distributed()
        # Device selection from config (get_config is already imported at module level)
        config_instance = get_config()
        device_type = config_instance.get('ascend.device_type', 'cuda')
        if device_type == 'npu':
            try:
                import torch_npu  # type: ignore
                torch_npu.npu.set_device(args.local_rank)
                device = torch.device("npu", args.local_rank)
                logger.info(f"Distributed training enabled: local_rank={args.local_rank}, "
                            f"device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
            except ImportError:
                logger.warning("torch_npu not installed, falling back to CPU")
                device = torch.device("cpu")
        elif torch.cuda.is_available():
            torch.cuda.set_device(args.local_rank)
            device = torch.device("cuda", args.local_rank)
            logger.info(f"Distributed training enabled: local_rank={args.local_rank}, "
                        f"device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
        else:
            device = torch.device("cpu")
            logger.info("Distributed training enabled: using CPU")
    else:
        # Non-distributed device selection
        config_instance = get_config()
        device_type = config_instance.get('ascend.device_type', 'cuda')
        if device_type == 'npu':
            try:
                import torch_npu  # type: ignore
                device = torch.device("npu")
                logger.info("Using Ascend NPU")
            except ImportError:
                logger.warning("torch_npu not installed, falling back to CPU")
                device = torch.device("cpu")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
            logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
        else:
            device = torch.device("cpu")
            logger.info("Using CPU")

    # Update config with device
    config.device = str(device)

    # Save configuration
    if is_main_process():
        os.makedirs(args.output_dir, exist_ok=True)
        config.save(os.path.join(args.output_dir, "config.json"))

    # Setup model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(args)

    # Setup datasets with enhanced format detection
    train_dataset, eval_dataset = setup_datasets(args, tokenizer)

    # Setup data loaders
    train_dataloader, eval_dataloader = setup_data_loaders(args, train_dataset, eval_dataset)

    # Initialize trainer
    logger.info("Initializing enhanced trainer...")
    try:
        trainer = BGEM3Trainer(
            config=config,
            model=model,
            tokenizer=tokenizer
        )
        logger.info("Trainer initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize trainer: {e}")
        raise

    # Resume from checkpoint if specified
    if args.resume_from_checkpoint:
        logger.info(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
        try:
            trainer.load_checkpoint(args.resume_from_checkpoint)
            logger.info("Checkpoint loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise

    # Train
    logger.info("🎯 Starting enhanced training...")
    try:
        # Handle type compatibility for the trainer
        from typing import cast
        trainer_train_dataset = cast(BGEM3Dataset, train_dataset) if train_dataset else None
        trainer_eval_dataset = cast(BGEM3Dataset, eval_dataset) if eval_dataset else None

        if not trainer_train_dataset:
            raise ValueError("Training dataset is required but not provided")

        trainer.train(
            train_dataset=trainer_train_dataset,
            eval_dataset=trainer_eval_dataset,
            resume_from_checkpoint=args.resume_from_checkpoint
        )

        # Save final model
        if is_main_process():
            logger.info("Saving final model...")
            final_path = os.path.join(args.output_dir, "final_model")
            os.makedirs(final_path, exist_ok=True)

            # Save model and tokenizer
            model.save_pretrained(final_path)
            tokenizer.save_pretrained(final_path)

            # Save training summary
            summary = {
                "training_args": vars(args),
                "model_config": config.to_dict(),
                "final_metrics": getattr(trainer, 'best_metric', None),
                "total_steps": getattr(trainer, 'global_step', 0),
                "training_completed": datetime.now().isoformat(),
                "device_used": str(device),
                "dataset_format": getattr(train_dataset, 'detected_format', 'unknown'),
                "enhanced_features": {
                    "use_new_dataset_api": args.use_new_dataset_api,
                    "use_hard_negatives": args.use_hard_negatives,
                    "use_self_distill": args.use_self_distill,
                    "use_score_weighted_sampling": args.use_score_weighted_sampling,
                    "multiple_positives": args.multiple_positives,
                    "legacy_support": args.legacy_support
                }
            }

            with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
                json.dump(summary, f, indent=2)

            logger.info(f"✅ Enhanced training completed! Model saved to {final_path}")

            # Show training summary
            logger.info("📊 Training Summary:")
            logger.info(f"  Total examples: {len(train_dataset)}")
            if hasattr(train_dataset, 'detected_format'):
                logger.info(f"  Format detected: {train_dataset.detected_format}")  # type: ignore[attr-defined]
            logger.info(f"  Epochs: {args.num_train_epochs}")
            logger.info(f"  Learning rate: {args.learning_rate}")
            logger.info(f"  Batch size: {args.per_device_train_batch_size}")
            logger.info(f"  Temperature: {args.temperature}")
            logger.info(f"  Enhanced features enabled: {args.use_new_dataset_api}")

    except Exception as e:
        logger.error(f"Training failed with error: {e}", exc_info=True)
        raise

    finally:
        # Cleanup
        if hasattr(trainer, 'hard_negative_miner') and trainer.hard_negative_miner:
            trainer.hard_negative_miner.stop_async_updates()

        if args.local_rank != -1:
            from utils.distributed import cleanup_distributed
            cleanup_distributed()


if __name__ == "__main__":
    main()