#!/usr/bin/env python
"""Unified test script to validate both BGE-M3 and BGE-Reranker training pipelines.

For each model this writes a tiny JSONL dataset to a temporary file, loads the
model and tokenizer from a local checkpoint directory, builds the dataset,
dataloader and trainer, and computes the loss for a single batch. The process
exit code is 0 when both smoke tests pass and 1 otherwise.
"""

import json
import sys
import tempfile
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Make the project root (the parent of this script's directory) importable so the
# local ``config`` / ``models`` / ``data`` / ``training`` packages resolve.
# resolve() turns a relative script path into an absolute one, so this does not
# depend on how the script was invoked.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from config.m3_config import BGEM3Config  # noqa: E402
from config.reranker_config import BGERerankerConfig  # noqa: E402
from data.dataset import (  # noqa: E402
    BGEM3Dataset,
    BGERerankerDataset,
    collate_embedding_batch,
    collate_reranker_batch,
)
from models.bge_m3 import BGEM3Model  # noqa: E402
from models.bge_reranker import BGERerankerModel  # noqa: E402
from training.m3_trainer import BGEM3Trainer  # noqa: E402
from training.reranker_trainer import BGERerankerTrainer  # noqa: E402


def create_embedding_test_data():
    """Write a tiny query/positive/negative dataset to a temporary JSONL file.

    Each line is one JSON object with a ``query`` string plus single-element
    ``pos`` and ``neg`` lists of passages. The file is created with
    ``delete=False`` so it survives this call; the caller must remove it.

    Returns:
        str: Path of the temporary ``.jsonl`` file.
    """
    samples = (
        (
            "What is machine learning?",
            "Machine learning is a subset of artificial intelligence that enables systems to learn from data.",
            "The weather today is sunny and warm.",
        ),
        (
            "How does deep learning work?",
            "Deep learning uses neural networks with multiple layers to learn patterns.",
            "Pizza is a popular Italian dish.",
        ),
    )
    records = [
        {"query": query, "pos": [positive], "neg": [negative]}
        for query, positive, negative in samples
    ]

    # Serialise everything up front, one JSON object per line, and write once.
    contents = "".join(json.dumps(record) + "\n" for record in records)
    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as handle:
        handle.write(contents)
        path = handle.name
    return path


def create_reranker_test_data():
    """Write a tiny labelled query/passage-pair dataset to a temporary JSONL file.

    Each line holds one pair with the keys ``query``, ``passage``, ``label``
    (1 or 0), ``score``, ``qid`` and ``pid``. Two queries are covered, each with
    one pair labelled 1 and one labelled 0. The file is kept on disk
    (``delete=False``); the caller must remove it.

    Returns:
        str: Path of the temporary ``.jsonl`` file.
    """
    ml_query = "What is machine learning?"
    dl_query = "How does deep learning work?"
    field_names = ("query", "passage", "label", "score", "qid", "pid")

    # One tuple per pair, in the order the lines are written.
    rows = [
        (ml_query, "Machine learning is a subset of artificial intelligence that enables systems to learn from data.", 1, 0.95, "q1", "p1"),
        (ml_query, "The weather today is sunny and warm.", 0, 0.15, "q1", "p2"),
        (dl_query, "Deep learning uses neural networks with multiple layers to learn patterns.", 1, 0.92, "q2", "p3"),
        (dl_query, "Pizza is a popular Italian dish.", 0, 0.12, "q2", "p4"),
    ]

    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as sink:
        sink.writelines(json.dumps(dict(zip(field_names, row))) + "\n" for row in rows)
        out_path = sink.name
    return out_path


def test_bge_m3():
    """Smoke-test the BGE-M3 embedding training pipeline on a single batch.

    Writes a tiny triplet dataset to a temporary file, loads the model and
    tokenizer from ``./models/bge-m3``, builds the dataset, dataloader and
    ``BGEM3Trainer``, and computes the loss for one batch under ``torch.no_grad()``.

    Returns:
        bool: True if every step succeeded; False if any step raised (the
        traceback is printed). The temporary data file is removed either way.

    NOTE(review): the name matches pytest's ``test_*`` pattern, but pytest
    ignores return values, so a False return would not fail a pytest run.
    """
    print("\n🎯 Testing BGE-M3 Embedding Model...")

    # Create test data
    train_file = create_embedding_test_data()
    print(f"✅ Created M3 test data: {train_file}")

    try:
        # The config is built inside the try so that a config error is reported
        # as a test failure and the temp file is still cleaned up in `finally`.
        config = BGEM3Config(
            model_name_or_path="./models/bge-m3",
            train_data_path=train_file,
            output_dir="output/test-m3",
            num_train_epochs=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=1,
            learning_rate=1e-5,
            logging_steps=1,
            save_steps=10,
            fp16=False,
            use_amp=False,
            dataloader_num_workers=0,
            overwrite_output_dir=True,
            use_self_distill=False,  # Disable for testing
            use_hard_negatives=False,  # Disable for testing
        )

        # Load model and tokenizer
        print("📦 Loading M3 model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGEM3Model.from_pretrained(config.model_name_or_path)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGEM3Dataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True,
        )
        print(f"✅ M3 Dataset size: {len(dataset)}")

        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=config.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collate_embedding_batch,
        )

        # Initialize trainer
        trainer = BGEM3Trainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device,
        )

        # Test exactly one batch. An empty dataloader must count as a failure
        # instead of silently skipping the loss check and reporting success.
        batch = next(iter(dataloader), None)
        if batch is None:
            raise RuntimeError("M3 dataloader produced no batches")

        print(f"M3 Batch keys: {list(batch.keys())}")
        print(f"Query shape: {batch['query_input_ids'].shape}")
        print(f"Passage shape: {batch['passage_input_ids'].shape}")
        print(f"Labels shape: {batch['labels'].shape}")

        # Move batch to device
        batch = trainer._prepare_batch(batch)

        # Compute loss (no gradients needed for a smoke test)
        with torch.no_grad():
            loss = trainer.compute_loss(batch)

        print(f"✅ M3 Loss computed successfully: {loss.item():.4f}")

        print("✅ BGE-M3 test passed!")
        return True

    except Exception as e:
        print(f"❌ BGE-M3 test failed: {e}")
        import traceback

        traceback.print_exc()
        return False

    finally:
        # Clean up the temporary dataset file.
        Path(train_file).unlink()


def test_bge_reranker():
    """Smoke-test the BGE-Reranker training pipeline on a single batch.

    Writes a tiny labelled pair dataset to a temporary file, loads the tokenizer
    and a ``BGERerankerModel`` (``num_labels=1``) from ``./models/bge-reranker-base``,
    builds the dataset, dataloader and ``BGERerankerTrainer`` (binary
    cross-entropy loss), and computes the loss for one batch under
    ``torch.no_grad()``.

    Returns:
        bool: True if every step succeeded; False if any step raised (the
        traceback is printed). The temporary data file is removed either way.

    NOTE(review): the name matches pytest's ``test_*`` pattern, but pytest
    ignores return values, so a False return would not fail a pytest run.
    """
    print("\n🎯 Testing BGE-Reranker Model...")

    # Create test data
    train_file = create_reranker_test_data()
    print(f"✅ Created Reranker test data: {train_file}")

    try:
        # The config is built inside the try so that a config error is reported
        # as a test failure and the temp file is still cleaned up in `finally`.
        config = BGERerankerConfig(
            model_name_or_path="./models/bge-reranker-base",
            train_data_path=train_file,
            output_dir="output/test-reranker",
            num_train_epochs=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=1,
            learning_rate=1e-5,
            logging_steps=1,
            save_steps=10,
            fp16=False,
            use_amp=False,
            dataloader_num_workers=0,
            overwrite_output_dir=True,
            loss_type="binary_cross_entropy",
        )

        # Load model and tokenizer
        print("📦 Loading Reranker model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGERerankerModel(
            model_name_or_path=config.model_name_or_path,
            num_labels=1,
            use_fp16=config.fp16,
            gradient_checkpointing=config.gradient_checkpointing,
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGERerankerDataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True,
        )
        print(f"✅ Reranker Dataset size: {len(dataset)}")

        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=config.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collate_reranker_batch,
        )

        # Initialize trainer
        trainer = BGERerankerTrainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device,
        )

        # Test exactly one batch. An empty dataloader must count as a failure
        # instead of silently skipping the loss check and reporting success.
        batch = next(iter(dataloader), None)
        if batch is None:
            raise RuntimeError("Reranker dataloader produced no batches")

        print(f"Reranker Batch keys: {list(batch.keys())}")
        print(f"Input IDs shape: {batch['input_ids'].shape}")
        print(f"Attention mask shape: {batch['attention_mask'].shape}")
        print(f"Labels shape: {batch['labels'].shape}")
        print(f"Labels values: {batch['labels']}")

        # Move batch to device
        batch = trainer._prepare_batch(batch)

        # Compute loss (no gradients needed for a smoke test)
        with torch.no_grad():
            loss = trainer.compute_loss(batch)

        print(f"✅ Reranker Loss computed successfully: {loss.item():.4f}")

        print("✅ BGE-Reranker test passed!")
        return True

    except Exception as e:
        print(f"❌ BGE-Reranker test failed: {e}")
        import traceback

        traceback.print_exc()
        return False

    finally:
        # Clean up the temporary dataset file.
        Path(train_file).unlink()


def main():
    """Run both pipeline smoke tests and print a pass/fail summary.

    The BGE-M3 test runs first and the reranker test second; both always run,
    even when the first one fails.

    Returns:
        int: 0 when both tests pass, 1 otherwise (used as the process exit code).
    """
    print("🚀 Starting comprehensive BGE training pipeline tests...")

    # Dict displays evaluate left to right, so the tests run in this order.
    outcomes = {
        "BGE-M3": test_bge_m3(),
        "BGE-Reranker": test_bge_reranker(),
    }

    print("\n📊 Test Results Summary:")
    for label, passed in outcomes.items():
        verdict = '✅ PASS' if passed else '❌ FAIL'
        print(f" {label}: {verdict}")

    if not all(outcomes.values()):
        print("\n⚠️ Some tests failed. Please check the error messages above.")
        return 1

    print("\n🎉 All tests passed! Both models are working correctly.")
    return 0


if __name__ == "__main__":
    sys.exit(main())