125 lines
3.9 KiB
Python
125 lines
3.9 KiB
Python
#!/usr/bin/env python
|
|
"""Test script to validate full training pipeline with learning rate scheduling"""
|
|
|
|
import torch
|
|
import json
|
|
import tempfile
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Make the project root (the parent of this script's directory) importable so
# the local ``config``, ``models``, ``data`` and ``training`` packages resolve.
# ``resolve()`` makes the path absolute first: with a bare relative
# ``__file__`` such as ``"script.py"``, ``Path(...).parent.parent`` would
# collapse to ``"."`` (the script's own directory) instead of its parent.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
from config.m3_config import BGEM3Config
|
|
from models.bge_m3 import BGEM3Model
|
|
from data.dataset import BGEM3Dataset, collate_embedding_batch
|
|
from training.m3_trainer import BGEM3Trainer
|
|
from transformers import AutoTokenizer
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
def create_test_data():
    """Write a tiny three-example training set to a temporary JSONL file.

    Every record carries a ``query`` string together with one-element
    ``pos`` and ``neg`` passage lists, serialized one JSON object per line.
    The file is created with ``delete=False`` so it outlives the ``with``
    block; removing it is the caller's job.

    Returns:
        str: Path of the newly written ``.jsonl`` file.
    """
    samples = [
        {
            "query": "What is machine learning?",
            "pos": ["Machine learning is a subset of AI"],
            "neg": ["The weather is sunny today"],
        },
        {
            "query": "How does Python work?",
            "pos": ["Python is a programming language"],
            "neg": ["Cats are cute animals"],
        },
        {
            "query": "What is deep learning?",
            "pos": ["Deep learning uses neural networks"],
            "neg": ["Pizza is delicious food"],
        },
    ]

    # Serialize everything up front, then emit it with a single write.
    payload = "".join(json.dumps(sample) + "\n" for sample in samples)

    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as handle:
        handle.write(payload)
        path = handle.name

    return path
|
|
|
|
|
|
def _build_config(train_file):
    """Return the small-footprint smoke-test ``BGEM3Config`` for *train_file*.

    Uses single-example batches, disables mixed precision (fp16, bf16, AMP),
    self-distillation and hard negatives, and logs every step.
    """
    return BGEM3Config(
        # Relative path: resolved against the current working directory.
        model_name_or_path="./models/bge-m3",
        train_data_path=train_file,
        output_dir="output/test-lr-scheduler",
        num_train_epochs=3,
        per_device_train_batch_size=1,  # Smaller to ensure more steps
        gradient_accumulation_steps=2,  # More accumulation steps for stability
        learning_rate=1e-5,  # Lower learning rate for stability
        logging_steps=1,
        # NOTE(review): 0 is assumed to disable checkpointing / evaluation in
        # BGEM3Trainer -- confirm it is never used as a modulus there.
        save_steps=0,  # Disable frequent checkpointing during test
        eval_steps=0,  # Disable evaluation during test
        fp16=False,  # Disable mixed precision for stability
        bf16=False,  # Disable bfloat16
        use_amp=False,
        dataloader_num_workers=0,
        overwrite_output_dir=True,
        use_self_distill=False,
        use_hard_negatives=False,
        temperature=0.1,  # Fixed temperature
        save_total_limit=1,  # Limit checkpoint storage
    )


def main():
    """Smoke-test the BGE-M3 training pipeline on a tiny temporary dataset.

    Writes a three-example JSONL file, loads the tokenizer and model from the
    configured path, moves the model to CUDA when available (CPU otherwise),
    builds a ``BGEM3Dataset`` and runs ``BGEM3Trainer.train`` on it. The
    learning-rate schedule is whatever the trainer sets up; this script only
    checks that training finishes without raising.

    The temporary data file is always deleted, including when building the
    configuration itself fails.

    Returns:
        bool: ``True`` if training completed, ``False`` if any exception was
        raised (its traceback is printed to stderr).
    """
    print("🧪 Testing BGE-M3 training with learning rate scheduling...")

    # Create test data
    train_file = create_test_data()
    print(f"✅ Created test data: {train_file}")

    try:
        # Built inside the try so that a failing config construction is
        # reported like any other failure and the temp file is still removed.
        config = _build_config(train_file)

        # Load model and tokenizer
        print("📦 Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGEM3Model.from_pretrained(config.model_name_or_path)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGEM3Dataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True,
        )
        print(f"✅ Dataset size: {len(dataset)}")

        # Initialize trainer
        trainer = BGEM3Trainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device,
        )

        # Run training on the tiny dataset (configured for 3 epochs).
        print("\n🏋️ Starting training test...")
        trainer.train(dataset)

        print("✅ Training completed successfully!")
        return True

    except Exception as e:
        # Top-level boundary of a test script: report the failure and signal
        # it through the return value rather than propagating.
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        # Clean up: create_test_data() writes with delete=False, so the file
        # must be removed explicitly.
        Path(train_file).unlink()
|
|
|
|
|
|
if __name__ == "__main__":
    # Translate main()'s boolean outcome into a process exit status:
    # 0 when training succeeded, 1 when it failed.
    if main():
        sys.exit(0)
    sys.exit(1)
|