#!/usr/bin/env python
"""Unified test script to validate both BGE-M3 and BGE-Reranker training pipelines"""
import torch
import json
import tempfile
import sys
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from config.m3_config import BGEM3Config
from config.reranker_config import BGERerankerConfig
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset, collate_embedding_batch, collate_reranker_batch
from training.m3_trainer import BGEM3Trainer
from training.reranker_trainer import BGERerankerTrainer
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
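# Both data helpers below write a small throwaway JSONL training file (one JSON
# object per line) and return its path. delete=False keeps the file on disk after
# the `with` block so the dataset classes can re-open it by name; each test removes
# the file again in its `finally` block.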
def create_embedding_test_data():
    """Create test data for the embedding model (triplet format)"""
    data = [
        {
            "query": "What is machine learning?",
            "pos": ["Machine learning is a subset of artificial intelligence that enables systems to learn from data."],
            "neg": ["The weather today is sunny and warm."]
        },
        {
            "query": "How does deep learning work?",
            "pos": ["Deep learning uses neural networks with multiple layers to learn patterns."],
            "neg": ["Pizza is a popular Italian dish."]
        }
    ]
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        for item in data:
            f.write(json.dumps(item) + '\n')
    return f.name
def create_reranker_test_data():
    """Create test data for the reranker model (query-passage pair format)"""
    data = [
        {
            "query": "What is machine learning?",
            "passage": "Machine learning is a subset of artificial intelligence that enables systems to learn from data.",
            "label": 1,
            "score": 0.95,
            "qid": "q1",
            "pid": "p1"
        },
        {
            "query": "What is machine learning?",
            "passage": "The weather today is sunny and warm.",
            "label": 0,
            "score": 0.15,
            "qid": "q1",
            "pid": "p2"
        },
        {
            "query": "How does deep learning work?",
            "passage": "Deep learning uses neural networks with multiple layers to learn patterns.",
            "label": 1,
            "score": 0.92,
            "qid": "q2",
            "pid": "p3"
        },
        {
            "query": "How does deep learning work?",
            "passage": "Pizza is a popular Italian dish.",
            "label": 0,
            "score": 0.12,
            "qid": "q2",
            "pid": "p4"
        }
    ]
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        for item in data:
            f.write(json.dumps(item) + '\n')
    return f.name
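# Optional sanity check (a minimal sketch, not wired into the tests below): re-read
# a generated JSONL file and confirm that every line parses and carries the expected
# keys, e.g. _check_jsonl(path, ["query", "pos", "neg"]) for the embedding data or
# _check_jsonl(path, ["query", "passage", "label"]) for the reranker data.
def _check_jsonl(path, required_keys):
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            record = json.loads(line)  # raises json.JSONDecodeError on malformed lines
            missing = [key for key in required_keys if key not in record]
            assert not missing, f"line {line_no} is missing keys: {missing}"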
def test_bge_m3():
    """Test BGE-M3 embedding model"""
    print("\n🎯 Testing BGE-M3 Embedding Model...")

    # Create test data
    train_file = create_embedding_test_data()
    print(f"✅ Created M3 test data: {train_file}")

    # Initialize config
    config = BGEM3Config(
        model_name_or_path="./models/bge-m3",
        train_data_path=train_file,
        output_dir="output/test-m3",
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        logging_steps=1,
        save_steps=10,
        fp16=False,
        use_amp=False,
        dataloader_num_workers=0,
        overwrite_output_dir=True,
        use_self_distill=False,  # Disable for testing
        use_hard_negatives=False  # Disable for testing
    )

    try:
        # Load model and tokenizer
        print("📦 Loading M3 model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGEM3Model.from_pretrained(config.model_name_or_path)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGEM3Dataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True
        )
        print(f"✅ M3 Dataset size: {len(dataset)}")

        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=config.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collate_embedding_batch
        )

        # Initialize trainer
        trainer = BGEM3Trainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device
        )

        # Test one batch
        for batch in dataloader:
            print(f"M3 Batch keys: {list(batch.keys())}")
            print(f"Query shape: {batch['query_input_ids'].shape}")
            print(f"Passage shape: {batch['passage_input_ids'].shape}")
            print(f"Labels shape: {batch['labels'].shape}")

            # Move batch to device
            batch = trainer._prepare_batch(batch)

            # Compute loss
            with torch.no_grad():
                loss = trainer.compute_loss(batch)
            print(f"✅ M3 Loss computed successfully: {loss.item():.4f}")
            break

        print("✅ BGE-M3 test passed!")
        return True
    except Exception as e:
        print(f"❌ BGE-M3 test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Clean up
        Path(train_file).unlink()
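# The reranker test mirrors the M3 test: build a config, load the model and tokenizer,
# wrap the pair data in a dataset and dataloader, and compute the loss for one batch.
# num_labels=1 together with loss_type="binary_cross_entropy" suggests a single-logit
# relevance head scored against the 0/1 labels.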
def test_bge_reranker():
    """Test BGE-Reranker model"""
    print("\n🎯 Testing BGE-Reranker Model...")

    # Create test data
    train_file = create_reranker_test_data()
    print(f"✅ Created Reranker test data: {train_file}")

    # Initialize config
    config = BGERerankerConfig(
        model_name_or_path="./models/bge-reranker-base",
        train_data_path=train_file,
        output_dir="output/test-reranker",
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        logging_steps=1,
        save_steps=10,
        fp16=False,
        use_amp=False,
        dataloader_num_workers=0,
        overwrite_output_dir=True,
        loss_type="binary_cross_entropy"
    )

    try:
        # Load model and tokenizer
        print("📦 Loading Reranker model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGERerankerModel(
            model_name_or_path=config.model_name_or_path,
            num_labels=1,
            use_fp16=config.fp16,
            gradient_checkpointing=config.gradient_checkpointing
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGERerankerDataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True
        )
        print(f"✅ Reranker Dataset size: {len(dataset)}")

        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=config.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collate_reranker_batch
        )

        # Initialize trainer
        trainer = BGERerankerTrainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device
        )

        # Test one batch
        for batch in dataloader:
            print(f"Reranker Batch keys: {list(batch.keys())}")
            print(f"Input IDs shape: {batch['input_ids'].shape}")
            print(f"Attention mask shape: {batch['attention_mask'].shape}")
            print(f"Labels shape: {batch['labels'].shape}")
            print(f"Labels values: {batch['labels']}")

            # Move batch to device
            batch = trainer._prepare_batch(batch)

            # Compute loss
            with torch.no_grad():
                loss = trainer.compute_loss(batch)
            print(f"✅ Reranker Loss computed successfully: {loss.item():.4f}")
            break

        print("✅ BGE-Reranker test passed!")
        return True
    except Exception as e:
        print(f"❌ BGE-Reranker test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Clean up
        Path(train_file).unlink()
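# Entry point: run both smoke tests and report a pass/fail summary. The return value
# becomes the process exit code (0 on success, 1 on failure), so the script can serve
# as a CI gate. Note that both tests assume local checkpoints at ./models/bge-m3 and
# ./models/bge-reranker-base, resolved relative to the working directory.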
def main():
    """Run all tests"""
    print("🚀 Starting comprehensive BGE training pipeline tests...")

    m3_success = test_bge_m3()
    reranker_success = test_bge_reranker()

    print("\n📊 Test Results Summary:")
    print(f" BGE-M3: {'✅ PASS' if m3_success else '❌ FAIL'}")
    print(f" BGE-Reranker: {'✅ PASS' if reranker_success else '❌ FAIL'}")

    if m3_success and reranker_success:
        print("\n🎉 All tests passed! Both models are working correctly.")
        return 0
    else:
        print("\n⚠️ Some tests failed. Please check the error messages above.")
        return 1


if __name__ == "__main__":
    sys.exit(main())