init commit
This commit is contained in:
121
tests/test_full_training.py
Normal file
121
tests/test_full_training.py
Normal file
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python
|
||||
"""Test script to validate full training pipeline with learning rate scheduling"""
|
||||
|
||||
import torch
|
||||
import json
|
||||
import tempfile
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from config.m3_config import BGEM3Config
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from data.dataset import BGEM3Dataset, collate_embedding_batch
|
||||
from training.m3_trainer import BGEM3Trainer
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def create_test_data():
|
||||
"""Create minimal test data"""
|
||||
data = [
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"pos": ["Machine learning is a subset of AI"],
|
||||
"neg": ["The weather is sunny today"]
|
||||
},
|
||||
{
|
||||
"query": "How does Python work?",
|
||||
"pos": ["Python is a programming language"],
|
||||
"neg": ["Cats are cute animals"]
|
||||
},
|
||||
{
|
||||
"query": "What is deep learning?",
|
||||
"pos": ["Deep learning uses neural networks"],
|
||||
"neg": ["Pizza is delicious food"]
|
||||
}
|
||||
]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
|
||||
for item in data:
|
||||
f.write(json.dumps(item) + '\n')
|
||||
return f.name
|
||||
|
||||
|
||||
def main():
|
||||
print("🧪 Testing BGE-M3 training with learning rate scheduling...")
|
||||
|
||||
# Create test data
|
||||
train_file = create_test_data()
|
||||
print(f"✅ Created test data: {train_file}")
|
||||
|
||||
# Initialize config with small dataset settings
|
||||
config = BGEM3Config(
|
||||
model_name_or_path="./models/bge-m3",
|
||||
train_data_path=train_file,
|
||||
output_dir="output/test-lr-scheduler",
|
||||
num_train_epochs=2,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=1, # Small for testing
|
||||
learning_rate=1e-4, # Higher for testing
|
||||
logging_steps=1,
|
||||
save_steps=10,
|
||||
fp16=False,
|
||||
use_amp=False,
|
||||
dataloader_num_workers=0,
|
||||
overwrite_output_dir=True,
|
||||
use_self_distill=False,
|
||||
use_hard_negatives=False,
|
||||
temperature=0.1 # Fixed temperature
|
||||
)
|
||||
|
||||
try:
|
||||
# Load model and tokenizer
|
||||
print("📦 Loading model and tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
|
||||
model = BGEM3Model.from_pretrained(config.model_name_or_path)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
print(f"🖥️ Using device: {device}")
|
||||
|
||||
# Create dataset
|
||||
dataset = BGEM3Dataset(
|
||||
data_path=train_file,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=config.query_max_length,
|
||||
passage_max_length=config.passage_max_length,
|
||||
is_train=True
|
||||
)
|
||||
print(f"✅ Dataset size: {len(dataset)}")
|
||||
|
||||
# Initialize trainer
|
||||
trainer = BGEM3Trainer(
|
||||
config=config,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Test training for a few steps
|
||||
print("\n🏋️ Starting training test...")
|
||||
trainer.train(dataset)
|
||||
|
||||
print("✅ Training completed successfully!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Training failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
# Clean up
|
||||
Path(train_file).unlink()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user