init commit
This commit is contained in:
299
tests/test_training_pipeline.py
Normal file
299
tests/test_training_pipeline.py
Normal file
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python
|
||||
"""Unified test script to validate both BGE-M3 and BGE-Reranker training pipelines"""
|
||||
|
||||
import torch
|
||||
import json
|
||||
import tempfile
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from config.m3_config import BGEM3Config
|
||||
from config.reranker_config import BGERerankerConfig
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from data.dataset import BGEM3Dataset, BGERerankerDataset, collate_embedding_batch, collate_reranker_batch
|
||||
from training.m3_trainer import BGEM3Trainer
|
||||
from training.reranker_trainer import BGERerankerTrainer
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def create_embedding_test_data():
    """Write a two-example JSONL triplet file (query / pos / neg) for the
    embedding model and return the path of the temporary file.

    The caller owns the file and is responsible for deleting it.
    """
    examples = [
        {
            "query": "What is machine learning?",
            "pos": ["Machine learning is a subset of artificial intelligence that enables systems to learn from data."],
            "neg": ["The weather today is sunny and warm."]
        },
        {
            "query": "How does deep learning work?",
            "pos": ["Deep learning uses neural networks with multiple layers to learn patterns."],
            "neg": ["Pizza is a popular Italian dish."]
        }
    ]

    # delete=False so the path stays valid after the handle closes.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as handle:
        handle.writelines(json.dumps(example) + '\n' for example in examples)
    return handle.name
|
||||
|
||||
def create_reranker_test_data():
    """Write a four-example JSONL pairs file for the reranker model and
    return the path of the temporary file.

    Each record carries query/passage text plus label, score, qid, pid.
    The caller owns the file and is responsible for deleting it.
    """
    raw_rows = [
        ("What is machine learning?",
         "Machine learning is a subset of artificial intelligence that enables systems to learn from data.",
         1, 0.95, "q1", "p1"),
        ("What is machine learning?",
         "The weather today is sunny and warm.",
         0, 0.15, "q1", "p2"),
        ("How does deep learning work?",
         "Deep learning uses neural networks with multiple layers to learn patterns.",
         1, 0.92, "q2", "p3"),
        ("How does deep learning work?",
         "Pizza is a popular Italian dish.",
         0, 0.12, "q2", "p4"),
    ]
    records = [
        {
            "query": query,
            "passage": passage,
            "label": label,
            "score": score,
            "qid": qid,
            "pid": pid
        }
        for query, passage, label, score, qid, pid in raw_rows
    ]

    # delete=False so the path stays valid after the handle closes.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as handle:
        handle.writelines(json.dumps(record) + '\n' for record in records)
    return handle.name
||||
|
||||
|
||||
def test_bge_m3():
    """Smoke-test the BGE-M3 embedding pipeline end to end.

    Builds throwaway JSONL triplet data, loads the local model/tokenizer,
    runs one batch through dataset -> dataloader -> trainer loss computation.

    Returns:
        bool: True when a loss is computed without raising; False otherwise
        (the traceback is printed for diagnosis).
    """
    print("\n🎯 Testing BGE-M3 Embedding Model...")

    # Create test data (removed in the finally block below)
    train_file = create_embedding_test_data()
    print(f"✅ Created M3 test data: {train_file}")

    # Minimal single-epoch config; self-distillation and hard negatives are
    # disabled so only the core loss path is exercised.
    config = BGEM3Config(
        model_name_or_path="./models/bge-m3",
        train_data_path=train_file,
        output_dir="output/test-m3",
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        logging_steps=1,
        save_steps=10,
        fp16=False,
        use_amp=False,
        dataloader_num_workers=0,
        overwrite_output_dir=True,
        use_self_distill=False,  # Disable for testing
        use_hard_negatives=False  # Disable for testing
    )

    try:
        # Load model and tokenizer
        print("📦 Loading M3 model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGEM3Model.from_pretrained(config.model_name_or_path)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGEM3Dataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True
        )
        print(f"✅ M3 Dataset size: {len(dataset)}")

        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=config.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collate_embedding_batch
        )

        # Initialize trainer
        trainer = BGEM3Trainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device
        )

        # Run exactly one batch: check collated shapes, device transfer, loss.
        for batch in dataloader:
            print(f"M3 Batch keys: {list(batch.keys())}")
            print(f"Query shape: {batch['query_input_ids'].shape}")
            print(f"Passage shape: {batch['passage_input_ids'].shape}")
            print(f"Labels shape: {batch['labels'].shape}")

            # Move batch to device
            batch = trainer._prepare_batch(batch)

            # Compute loss (forward pass only — no optimizer step needed here)
            with torch.no_grad():
                loss = trainer.compute_loss(batch)

            print(f"✅ M3 Loss computed successfully: {loss.item():.4f}")
            break

        print("✅ BGE-M3 test passed!")
        return True

    except Exception as e:
        print(f"❌ BGE-M3 test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Clean up; missing_ok prevents a FileNotFoundError here from masking
        # the real test outcome if the file was already removed.
        Path(train_file).unlink(missing_ok=True)
||||
|
||||
|
||||
def test_bge_reranker():
    """Smoke-test the BGE-Reranker pipeline end to end.

    Builds throwaway JSONL query/passage pairs, loads the local cross-encoder
    model/tokenizer, runs one batch through dataset -> dataloader -> trainer
    loss computation.

    Returns:
        bool: True when a loss is computed without raising; False otherwise
        (the traceback is printed for diagnosis).
    """
    print("\n🎯 Testing BGE-Reranker Model...")

    # Create test data (removed in the finally block below)
    train_file = create_reranker_test_data()
    print(f"✅ Created Reranker test data: {train_file}")

    # Minimal single-epoch config using plain BCE so the simplest loss
    # path is exercised.
    config = BGERerankerConfig(
        model_name_or_path="./models/bge-reranker-base",
        train_data_path=train_file,
        output_dir="output/test-reranker",
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        logging_steps=1,
        save_steps=10,
        fp16=False,
        use_amp=False,
        dataloader_num_workers=0,
        overwrite_output_dir=True,
        loss_type="binary_cross_entropy"
    )

    try:
        # Load model and tokenizer
        print("📦 Loading Reranker model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        # num_labels=1: single relevance score per query/passage pair
        model = BGERerankerModel(
            model_name_or_path=config.model_name_or_path,
            num_labels=1,
            use_fp16=config.fp16,
            gradient_checkpointing=config.gradient_checkpointing
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGERerankerDataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True
        )
        print(f"✅ Reranker Dataset size: {len(dataset)}")

        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=config.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collate_reranker_batch
        )

        # Initialize trainer
        trainer = BGERerankerTrainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device
        )

        # Run exactly one batch: check collated shapes, device transfer, loss.
        for batch in dataloader:
            print(f"Reranker Batch keys: {list(batch.keys())}")
            print(f"Input IDs shape: {batch['input_ids'].shape}")
            print(f"Attention mask shape: {batch['attention_mask'].shape}")
            print(f"Labels shape: {batch['labels'].shape}")
            print(f"Labels values: {batch['labels']}")

            # Move batch to device
            batch = trainer._prepare_batch(batch)

            # Compute loss (forward pass only — no optimizer step needed here)
            with torch.no_grad():
                loss = trainer.compute_loss(batch)

            print(f"✅ Reranker Loss computed successfully: {loss.item():.4f}")
            break

        print("✅ BGE-Reranker test passed!")
        return True

    except Exception as e:
        print(f"❌ BGE-Reranker test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Clean up; missing_ok prevents a FileNotFoundError here from masking
        # the real test outcome if the file was already removed.
        Path(train_file).unlink(missing_ok=True)
||||
|
||||
|
||||
def main():
    """Run both pipeline smoke tests and return a shell-style exit code
    (0 when everything passed, 1 when any test failed)."""
    print("🚀 Starting comprehensive BGE training pipeline tests...")

    m3_ok = test_bge_m3()
    reranker_ok = test_bge_reranker()

    print("\n📊 Test Results Summary:")
    print(f"   BGE-M3: {'✅ PASS' if m3_ok else '❌ FAIL'}")
    print(f"   BGE-Reranker: {'✅ PASS' if reranker_ok else '❌ FAIL'}")

    # Guard clause: report failure first, success is the fall-through.
    if not (m3_ok and reranker_ok):
        print("\n⚠️ Some tests failed. Please check the error messages above.")
        return 1

    print("\n🎉 All tests passed! Both models are working correctly.")
    return 0
||||
|
||||
|
||||
# Script entry point: propagate the aggregate test result as the process
# exit status (0 = all passed, 1 = failures) so CI can gate on it.
if __name__ == "__main__":
    sys.exit(main())
||||
Reference in New Issue
Block a user