"""
|
|
Joint training of BGE-M3 and BGE-Reranker using RocketQAv2 approach
|
|
"""

import os
import sys
import argparse

# Set tokenizer parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import logging
import json
from datetime import datetime

import torch
from transformers import AutoTokenizer

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config.m3_config import BGEM3Config
from config.reranker_config import BGERerankerConfig
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset, JointDataset
from data.preprocessing import DataPreprocessor
from training.joint_trainer import JointTrainer
from utils.logging import setup_logging, get_logger
from utils.distributed import setup_distributed, is_main_process, set_random_seed

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Joint training of BGE-M3 and BGE-Reranker")

    # Model arguments
    parser.add_argument("--retriever_model", type=str, default="BAAI/bge-m3",
                        help="Path to retriever model")
    parser.add_argument("--reranker_model", type=str, default="BAAI/bge-reranker-base",
                        help="Path to reranker model")
    parser.add_argument("--cache_dir", type=str, default="./cache/data",
                        help="Directory to cache downloaded models")

    # Data arguments
    parser.add_argument("--train_data", type=str, required=True,
                        help="Path to training data file")
    parser.add_argument("--eval_data", type=str, default=None,
                        help="Path to evaluation data file")
    parser.add_argument("--data_format", type=str, default="auto",
                        choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
                        help="Format of input data")

    # Training arguments
    parser.add_argument("--output_dir", type=str, default="./output/joint-training",
                        help="Directory to save the models")
    parser.add_argument("--num_train_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--per_device_train_batch_size", type=int, default=4,
                        help="Batch size per GPU/CPU for training")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=8,
                        help="Batch size per GPU/CPU for evaluation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4,
                        help="Number of update steps to accumulate before each optimizer step")
    parser.add_argument("--learning_rate", type=float, default=1e-5,
                        help="Initial learning rate for retriever")
    parser.add_argument("--reranker_learning_rate", type=float, default=6e-5,
                        help="Initial learning rate for reranker")
    parser.add_argument("--warmup_ratio", type=float, default=0.1,
                        help="Ratio of warmup steps")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay")

    # Joint training specific
    parser.add_argument("--joint_loss_weight", type=float, default=0.3,
                        help="Weight for joint distillation loss")
    parser.add_argument("--update_retriever_steps", type=int, default=100,
                        help="Update retriever every N steps")
    parser.add_argument("--use_dynamic_distillation", action="store_true",
                        help="Use dynamic listwise distillation")

    # Model configuration
    parser.add_argument("--retriever_train_group_size", type=int, default=8,
                        help="Number of passages per query for retriever")
    parser.add_argument("--reranker_train_group_size", type=int, default=16,
                        help="Number of passages per query for reranker")
    parser.add_argument("--temperature", type=float, default=0.02,
                        help="Temperature for InfoNCE loss")

    # BGE-M3 specific
    # Note: with action="store_true" and default=True, --use_dense is effectively
    # always enabled; dense embeddings cannot be disabled from the command line.
    parser.add_argument("--use_dense", action="store_true", default=True,
                        help="Use dense embeddings")
    parser.add_argument("--use_sparse", action="store_true",
                        help="Use sparse embeddings")
    parser.add_argument("--use_colbert", action="store_true",
                        help="Use ColBERT embeddings")
    parser.add_argument("--query_max_length", type=int, default=512,
                        help="Maximum query length")
    parser.add_argument("--passage_max_length", type=int, default=8192,
                        help="Maximum passage length")

    # Reranker specific
    parser.add_argument("--reranker_max_length", type=int, default=512,
                        help="Maximum sequence length for reranker")

    # Performance
    parser.add_argument("--fp16", action="store_true",
                        help="Use mixed precision training")
    parser.add_argument("--gradient_checkpointing", action="store_true",
                        help="Use gradient checkpointing")
    parser.add_argument("--dataloader_num_workers", type=int, default=4,
                        help="Number of workers for data loading")

    # Checkpointing
    parser.add_argument("--save_steps", type=int, default=1000,
                        help="Save checkpoint every N steps")
    parser.add_argument("--eval_steps", type=int, default=1000,
                        help="Evaluate every N steps")
    parser.add_argument("--logging_steps", type=int, default=100,
                        help="Log every N steps")
    parser.add_argument("--save_total_limit", type=int, default=3,
                        help="Maximum number of checkpoints to keep")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Path to checkpoint to resume from")

    # Distributed training
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training")

    # Other
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--overwrite_output_dir", action="store_true",
                        help="Overwrite the output directory")

    args = parser.parse_args()

    # Validate arguments
    if os.path.exists(args.output_dir) and not args.overwrite_output_dir:
        raise ValueError(
            f"Output directory {args.output_dir} already exists. "
            "Use --overwrite_output_dir to overwrite."
        )

    return args


def main():
    # Parse arguments
    args = parse_args()

    # Setup logging
    setup_logging(
        output_dir=args.output_dir,
        log_level=logging.INFO,
        log_to_console=True,
        log_to_file=True
    )

    logger = get_logger(__name__)
    logger.info("Starting joint training of BGE-M3 and BGE-Reranker")
    logger.info(f"Arguments: {args}")

    # Set random seed
    set_random_seed(args.seed, deterministic=False)

    # Setup distributed training if needed
    device = None
    if args.local_rank != -1:
        setup_distributed()
        from utils.config_loader import get_config
        config = get_config()
        device_type = config.get('ascend.device_type', None)  # type: ignore[attr-defined]
        if device_type == 'npu':
            try:
                import torch_npu
                torch_npu.npu.set_device(args.local_rank)
                device = torch.device("npu", args.local_rank)
                logger.info(
                    f"Distributed training enabled: local_rank={args.local_rank}, "
                    f"device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}"
                )
            except ImportError:
                logger.warning("torch_npu not installed, falling back to CPU")
                device = torch.device("cpu")
        elif torch.cuda.is_available():
            torch.cuda.set_device(args.local_rank)
            device = torch.device("cuda", args.local_rank)
            logger.info(
                f"Distributed training enabled: local_rank={args.local_rank}, "
                f"device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}"
            )
        else:
            device = torch.device("cpu")
            logger.info("Distributed training enabled: using CPU")
    else:
        from utils.config_loader import get_config
        config = get_config()
        device_type = config.get('ascend.device_type', None)  # type: ignore[attr-defined]
        if device_type == 'npu':
            try:
                import torch_npu  # noqa: F401 -- importing registers the NPU backend
                device = torch.device("npu")
                logger.info("Using Ascend NPU")
            except ImportError:
                logger.warning("torch_npu not installed, falling back to CPU")
                device = torch.device("cpu")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
            logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
        else:
            device = torch.device("cpu")
            logger.info("Using CPU")

    # Create configurations
    retriever_config = BGEM3Config(
        model_name_or_path=args.retriever_model,
        cache_dir=args.cache_dir,
        train_data_path=args.train_data,
        eval_data_path=args.eval_data,
        query_max_length=args.query_max_length,
        passage_max_length=args.passage_max_length,
        output_dir=os.path.join(args.output_dir, "retriever"),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        train_group_size=args.retriever_train_group_size,
        temperature=args.temperature,
        use_dense=args.use_dense,
        use_sparse=args.use_sparse,
        use_colbert=args.use_colbert,
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        dataloader_num_workers=args.dataloader_num_workers,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        save_total_limit=args.save_total_limit,
        seed=args.seed,
        local_rank=args.local_rank,
        device=str(device)
    )

    reranker_config = BGERerankerConfig(
        model_name_or_path=args.reranker_model,
        cache_dir=args.cache_dir,
        train_data_path=args.train_data,
        eval_data_path=args.eval_data,
        max_seq_length=args.reranker_max_length,
        output_dir=os.path.join(args.output_dir, "reranker"),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.reranker_learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        train_group_size=args.reranker_train_group_size,
        joint_training=True,  # type: ignore[call-arg]
        joint_loss_weight=args.joint_loss_weight,  # type: ignore[call-arg]
        use_dynamic_listwise_distillation=args.use_dynamic_distillation,  # type: ignore[call-arg]
        update_retriever_steps=args.update_retriever_steps,  # type: ignore[call-arg]
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        dataloader_num_workers=args.dataloader_num_workers,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        save_total_limit=args.save_total_limit,
        seed=args.seed,
        local_rank=args.local_rank,
        device=str(device)
    )

    # Create joint config (deep-copy the retriever config so the base config is
    # not mutated when the joint-specific fields below are set)
    import copy
    joint_config = copy.deepcopy(retriever_config)
    joint_config.output_dir = args.output_dir
    joint_config.joint_loss_weight = args.joint_loss_weight  # type: ignore[attr-defined]
    joint_config.use_dynamic_listwise_distillation = args.use_dynamic_distillation  # type: ignore[attr-defined]
    joint_config.update_retriever_steps = args.update_retriever_steps  # type: ignore[attr-defined]

    # Save configurations
    if is_main_process():
        os.makedirs(args.output_dir, exist_ok=True)
        joint_config.save(os.path.join(args.output_dir, "joint_config.json"))
        retriever_config.save(os.path.join(args.output_dir, "retriever_config.json"))
        reranker_config.save(os.path.join(args.output_dir, "reranker_config.json"))

    # Load tokenizers
    logger.info("Loading tokenizers...")
    retriever_tokenizer = AutoTokenizer.from_pretrained(
        args.retriever_model,
        cache_dir=args.cache_dir
    )
    reranker_tokenizer = AutoTokenizer.from_pretrained(
        args.reranker_model,
        cache_dir=args.cache_dir
    )

    # Prepare data
    logger.info("Preparing data...")
    preprocessor = DataPreprocessor(
        tokenizer=retriever_tokenizer,  # use the retriever tokenizer for preprocessing
        max_query_length=args.query_max_length,
        max_passage_length=args.passage_max_length
    )

    # Preprocess training data; hold out 10% for validation when no eval file is given
    train_path, val_path = preprocessor.preprocess_file(
        input_path=args.train_data,
        output_path=os.path.join(args.output_dir, "processed_train.jsonl"),
        file_format=args.data_format,
        validation_split=0.1 if args.eval_data is None else 0.0
    )

    # Use the provided eval data or the held-out validation split
    eval_path = args.eval_data if args.eval_data else val_path
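
    # A typical processed record (assumed shape, following the common BGE
    # fine-tuning format; the authoritative schema is whatever
    # data/preprocessing.py actually emits):
    #   {"query": "...", "pos": ["relevant passage", ...], "neg": ["hard negative", ...]}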

    # Create datasets
    logger.info("Creating datasets...")

    # Retriever dataset
    retriever_train_dataset = BGEM3Dataset(
        data_path=train_path,
        tokenizer=retriever_tokenizer,
        query_max_length=args.query_max_length,
        passage_max_length=args.passage_max_length,
        train_group_size=args.retriever_train_group_size,
        is_train=True
    )

    # Reranker dataset
    reranker_train_dataset = BGERerankerDataset(
        data_path=train_path,
        tokenizer=reranker_tokenizer,
        max_seq_length=args.reranker_max_length,  # type: ignore[call-arg]
        train_group_size=args.reranker_train_group_size,
        is_train=True
    )

    # Joint dataset pairing retriever and reranker views of the same data
    joint_train_dataset = JointDataset(  # type: ignore[call-arg]
        retriever_dataset=retriever_train_dataset,
        reranker_dataset=reranker_train_dataset
    )

    # Evaluation datasets
    joint_eval_dataset = None
    if eval_path:
        retriever_eval_dataset = BGEM3Dataset(
            data_path=eval_path,
            tokenizer=retriever_tokenizer,
            query_max_length=args.query_max_length,
            passage_max_length=args.passage_max_length,
            train_group_size=args.retriever_train_group_size,
            is_train=False
        )

        reranker_eval_dataset = BGERerankerDataset(
            data_path=eval_path,
            tokenizer=reranker_tokenizer,
            max_seq_length=args.reranker_max_length,  # type: ignore[call-arg]
            train_group_size=args.reranker_train_group_size,
            is_train=False
        )

        joint_eval_dataset = JointDataset(  # type: ignore[call-arg]
            retriever_dataset=retriever_eval_dataset,
            reranker_dataset=reranker_eval_dataset
        )

    # Initialize models
    logger.info("Loading models...")
    retriever = BGEM3Model(
        model_name_or_path=args.retriever_model,
        use_dense=args.use_dense,
        use_sparse=args.use_sparse,
        use_colbert=args.use_colbert,
        cache_dir=args.cache_dir
    )
    reranker = BGERerankerModel(
        model_name_or_path=args.reranker_model,
        cache_dir=args.cache_dir,
        use_fp16=args.fp16
    )
    from utils.config_loader import get_config
    model_source = get_config().get('model_paths.source', 'huggingface')
    logger.info("Retriever model loaded successfully (source: %s)", model_source)
    logger.info("Reranker model loaded successfully (source: %s)", model_source)

    # Initialize optimizers
    from torch.optim import AdamW

    retriever_optimizer = AdamW(
        retriever.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )

    reranker_optimizer = AdamW(
        reranker.parameters(),
        lr=args.reranker_learning_rate,
        weight_decay=args.weight_decay
    )

    # Initialize joint trainer
    logger.info("Initializing joint trainer...")
    trainer = JointTrainer(
        config=joint_config,
        retriever=retriever,
        reranker=reranker,
        retriever_optimizer=retriever_optimizer,
        reranker_optimizer=reranker_optimizer,
        device=device
    )

    # Create data loaders
    from torch.utils.data import DataLoader
    from utils.distributed import create_distributed_dataloader

    if args.local_rank != -1:
        train_dataloader = create_distributed_dataloader(
            joint_train_dataset,
            batch_size=args.per_device_train_batch_size,
            shuffle=True,
            num_workers=args.dataloader_num_workers,
            seed=args.seed
        )

        eval_dataloader = None
        if joint_eval_dataset:
            eval_dataloader = create_distributed_dataloader(
                joint_eval_dataset,
                batch_size=args.per_device_eval_batch_size,
                shuffle=False,
                num_workers=args.dataloader_num_workers,
                seed=args.seed
            )
    else:
        train_dataloader = DataLoader(
            joint_train_dataset,
            batch_size=args.per_device_train_batch_size,
            shuffle=True,
            num_workers=args.dataloader_num_workers
        )

        eval_dataloader = None
        if joint_eval_dataset:
            eval_dataloader = DataLoader(
                joint_eval_dataset,
                batch_size=args.per_device_eval_batch_size,
                num_workers=args.dataloader_num_workers
            )
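
    # Note on scale: the effective global batch size per optimizer step is
    # per_device_train_batch_size * gradient_accumulation_steps * world_size,
    # e.g. 4 * 4 * 8 processes = 128 query groups with the defaults on 8 devices.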

    # Resume from checkpoint if specified
    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    # Train
    logger.info("Starting joint training...")
    try:
        trainer.train(train_dataloader, eval_dataloader, args.num_train_epochs)

        # Save final models
        if is_main_process():
            logger.info("Saving final models...")
            final_dir = os.path.join(args.output_dir, "final_models")
            trainer.save_models(final_dir)

            # Save individual models
            retriever_path = os.path.join(args.output_dir, "final_retriever")
            reranker_path = os.path.join(args.output_dir, "final_reranker")

            os.makedirs(retriever_path, exist_ok=True)
            os.makedirs(reranker_path, exist_ok=True)

            retriever.save_pretrained(retriever_path)
            retriever_tokenizer.save_pretrained(retriever_path)

            reranker.model.save_pretrained(reranker_path)
            reranker_tokenizer.save_pretrained(reranker_path)

            # Save training summary
            summary = {
                "training_args": vars(args),
                "retriever_config": retriever_config.to_dict(),
                "reranker_config": reranker_config.to_dict(),
                "joint_config": joint_config.to_dict(),
                "final_metrics": getattr(trainer, 'best_metric', None),
                "total_steps": getattr(trainer, 'global_step', 0),
                "training_completed": datetime.now().isoformat()
            }

            with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
                json.dump(summary, f, indent=2)

            logger.info(f"Joint training completed! Models saved to {args.output_dir}")

    except Exception as e:
        logger.error(f"Joint training failed with error: {e}", exc_info=True)
        raise


if __name__ == "__main__":
    main()