bge_finetune/tests/test_full_training.py

#!/usr/bin/env python
"""Test script to validate the full training pipeline with learning rate scheduling."""
import json
import sys
import tempfile
from pathlib import Path

import torch
from transformers import AutoTokenizer

# Add the parent directory to the path so the sibling packages are importable
sys.path.insert(0, str(Path(__file__).parent.parent))

from config.m3_config import BGEM3Config
from data.dataset import BGEM3Dataset
from models.bge_m3 import BGEM3Model
from training.m3_trainer import BGEM3Trainer
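
# NOTE: the script expects a local BGE-M3 checkpoint at ./models/bge-m3 (the
# model_name_or_path used below); point it at the Hugging Face hub ID
# "BAAI/bge-m3" instead if no local copy is available.
# Typical invocation from the repo root: python tests/test_full_training.py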


def create_test_data():
    """Write a minimal JSONL training file and return its path."""
    data = [
        {
            "query": "What is machine learning?",
            "pos": ["Machine learning is a subset of AI"],
            "neg": ["The weather is sunny today"],
        },
        {
            "query": "How does Python work?",
            "pos": ["Python is a programming language"],
            "neg": ["Cats are cute animals"],
        },
        {
            "query": "What is deep learning?",
            "pos": ["Deep learning uses neural networks"],
            "neg": ["Pizza is delicious food"],
        },
    ]
    # delete=False keeps the file on disk after the context manager exits;
    # main() removes it in its finally block.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        for item in data:
            f.write(json.dumps(item) + '\n')
    return f.name
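
# Each record follows the query/pos/neg triplet layout that BGEM3Dataset
# consumes: one query string plus lists of positive and negative passages.
# Production data would normally carry many more (mined) hard negatives.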


def main():
    print("🧪 Testing BGE-M3 training with learning rate scheduling...")

    # Create test data
    train_file = create_test_data()
    print(f"✅ Created test data: {train_file}")

    # Initialize config with settings suited to a tiny dataset
    config = BGEM3Config(
        model_name_or_path="./models/bge-m3",
        train_data_path=train_file,
        output_dir="output/test-lr-scheduler",
        num_train_epochs=3,
        per_device_train_batch_size=1,   # Small batch to ensure more steps
        gradient_accumulation_steps=2,   # More accumulation steps for stability
        learning_rate=1e-5,              # Lower learning rate for stability
        logging_steps=1,
        save_steps=0,                    # Disable checkpointing during the test
        eval_steps=0,                    # Disable evaluation during the test
        fp16=False,                      # Disable mixed precision for stability
        bf16=False,                      # Disable bfloat16
        use_amp=False,
        dataloader_num_workers=0,
        overwrite_output_dir=True,
        use_self_distill=False,
        use_hard_negatives=False,
        temperature=0.1,                 # Fixed temperature
        save_total_limit=1,              # Limit checkpoint storage
    )

    try:
        # Load model and tokenizer
        print("📦 Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGEM3Model.from_pretrained(config.model_name_or_path)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGEM3Dataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True,
        )
        print(f"✅ Dataset size: {len(dataset)}")

        # Initialize trainer
        trainer = BGEM3Trainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device,
        )

        # Run a short training pass
        print("\n🏋️ Starting training test...")
        trainer.train(dataset)
        print("✅ Training completed successfully!")
        return True
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Clean up the temporary training file
        Path(train_file).unlink(missing_ok=True)


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)