#!/usr/bin/env python
"""Unified test script to validate both BGE-M3 and BGE-Reranker training pipelines.

Each ``test_*`` function writes a tiny JSONL fixture to a temporary file,
builds the tokenizer, model, dataset, dataloader and trainer for one
pipeline, computes the loss on a single batch under ``torch.no_grad()`` and
returns ``True``/``False``. ``main`` runs both and returns a process exit code.

NOTE: the ``test_*`` functions report failure through their return value (they
catch exceptions), so they are meant to be driven by ``main`` rather than
collected directly by pytest.
"""
import contextlib
import json
import sys
import tempfile
import traceback
from pathlib import Path

import torch

# Add parent directory to path so the project packages imported below resolve
# when this script is executed directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

from config.m3_config import BGEM3Config
from config.reranker_config import BGERerankerConfig
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset, collate_embedding_batch, collate_reranker_batch
from training.m3_trainer import BGEM3Trainer
from training.reranker_trainer import BGERerankerTrainer
from transformers import AutoTokenizer
from torch.utils.data import DataLoader


def _write_jsonl(records):
    """Write ``records`` to a new temporary ``.jsonl`` file and return its path.

    The file is created with ``delete=False`` so it outlives the ``with``
    block; the caller is responsible for removing it.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
    ) as f:
        for record in records:
            f.write(json.dumps(record) + "\n")
    return f.name


def _remove_file(path):
    """Delete ``path``; an already-missing file is not an error.

    Used from ``finally`` blocks, where a cleanup exception would otherwise
    replace the test's real outcome.
    """
    with contextlib.suppress(FileNotFoundError):
        Path(path).unlink()


def _select_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _first_batch(dataloader, label):
    """Return the first batch produced by ``dataloader``.

    Raises:
        RuntimeError: if the dataloader yields nothing. Without this check an
            empty dataset would let a test "pass" without computing any loss.
    """
    batch = next(iter(dataloader), None)
    if batch is None:
        raise RuntimeError(f"{label} dataloader produced no batches")
    return batch


def create_embedding_test_data():
    """Create test data for embedding model (triplets format).

    Returns:
        Path of a temporary JSONL file; each line has ``query``, ``pos`` (list)
        and ``neg`` (list) fields. The caller must delete the file.
    """
    data = [
        {
            "query": "What is machine learning?",
            "pos": ["Machine learning is a subset of artificial intelligence that enables systems to learn from data."],
            "neg": ["The weather today is sunny and warm."],
        },
        {
            "query": "How does deep learning work?",
            "pos": ["Deep learning uses neural networks with multiple layers to learn patterns."],
            "neg": ["Pizza is a popular Italian dish."],
        },
    ]
    return _write_jsonl(data)


def create_reranker_test_data():
    """Create test data for reranker model (pairs format).

    Returns:
        Path of a temporary JSONL file; each line is one query/passage pair
        with ``label``, ``score``, ``qid`` and ``pid`` fields. The caller must
        delete the file.
    """
    data = [
        {
            "query": "What is machine learning?",
            "passage": "Machine learning is a subset of artificial intelligence that enables systems to learn from data.",
            "label": 1,
            "score": 0.95,
            "qid": "q1",
            "pid": "p1",
        },
        {
            "query": "What is machine learning?",
            "passage": "The weather today is sunny and warm.",
            "label": 0,
            "score": 0.15,
            "qid": "q1",
            "pid": "p2",
        },
        {
            "query": "How does deep learning work?",
            "passage": "Deep learning uses neural networks with multiple layers to learn patterns.",
            "label": 1,
            "score": 0.92,
            "qid": "q2",
            "pid": "p3",
        },
        {
            "query": "How does deep learning work?",
            "passage": "Pizza is a popular Italian dish.",
            "label": 0,
            "score": 0.12,
            "qid": "q2",
            "pid": "p4",
        },
    ]
    return _write_jsonl(data)


def test_bge_m3():
    """Test BGE-M3 embedding model on one batch of the triplet fixture.

    Returns:
        True if a loss was computed for the first batch, False on any error
        (the traceback is printed). The temporary fixture is always removed.
    """
    print("\n🎯 Testing BGE-M3 Embedding Model...")

    # Create test data
    train_file = create_embedding_test_data()
    print(f"✅ Created M3 test data: {train_file}")

    try:
        # Built inside ``try`` so an invalid config is reported as a failure
        # and the temp file is still cleaned up by ``finally``.
        config = BGEM3Config(
            model_name_or_path="./models/bge-m3",
            train_data_path=train_file,
            output_dir="output/test-m3",
            num_train_epochs=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=1,
            learning_rate=1e-5,
            logging_steps=1,
            save_steps=10,
            fp16=False,
            use_amp=False,
            dataloader_num_workers=0,
            overwrite_output_dir=True,
            use_self_distill=False,  # Disable for testing
            use_hard_negatives=False,  # Disable for testing
        )

        # Load model and tokenizer
        print("📦 Loading M3 model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGEM3Model.from_pretrained(config.model_name_or_path)
        device = _select_device()
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGEM3Dataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True,
        )
        print(f"✅ M3 Dataset size: {len(dataset)}")

        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=config.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collate_embedding_batch,
        )

        # Initialize trainer
        trainer = BGEM3Trainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device,
        )

        # Test one batch (raises if the dataloader is empty)
        batch = _first_batch(dataloader, "M3")
        print(f"M3 Batch keys: {list(batch.keys())}")
        print(f"Query shape: {batch['query_input_ids'].shape}")
        print(f"Passage shape: {batch['passage_input_ids'].shape}")
        print(f"Labels shape: {batch['labels'].shape}")

        # Move batch to device
        batch = trainer._prepare_batch(batch)

        # Compute loss
        with torch.no_grad():
            loss = trainer.compute_loss(batch)
        print(f"✅ M3 Loss computed successfully: {loss.item():.4f}")

        print("✅ BGE-M3 test passed!")
        return True

    except Exception as e:
        print(f"❌ BGE-M3 test failed: {e}")
        traceback.print_exc()
        return False
    finally:
        # Clean up
        _remove_file(train_file)


def test_bge_reranker():
    """Test BGE-Reranker model on one batch of the pairwise fixture.

    Returns:
        True if a loss was computed for the first batch, False on any error
        (the traceback is printed). The temporary fixture is always removed.
    """
    print("\n🎯 Testing BGE-Reranker Model...")

    # Create test data
    train_file = create_reranker_test_data()
    print(f"✅ Created Reranker test data: {train_file}")

    try:
        # Built inside ``try`` so an invalid config is reported as a failure
        # and the temp file is still cleaned up by ``finally``.
        config = BGERerankerConfig(
            model_name_or_path="./models/bge-reranker-base",
            train_data_path=train_file,
            output_dir="output/test-reranker",
            num_train_epochs=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=1,
            learning_rate=1e-5,
            logging_steps=1,
            save_steps=10,
            fp16=False,
            use_amp=False,
            dataloader_num_workers=0,
            overwrite_output_dir=True,
            loss_type="binary_cross_entropy",
        )

        # Load model and tokenizer
        print("📦 Loading Reranker model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGERerankerModel(
            model_name_or_path=config.model_name_or_path,
            num_labels=1,
            use_fp16=config.fp16,
            gradient_checkpointing=config.gradient_checkpointing,
        )
        device = _select_device()
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGERerankerDataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True,
        )
        print(f"✅ Reranker Dataset size: {len(dataset)}")

        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=config.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collate_reranker_batch,
        )

        # Initialize trainer
        trainer = BGERerankerTrainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device,
        )

        # Test one batch (raises if the dataloader is empty)
        batch = _first_batch(dataloader, "Reranker")
        print(f"Reranker Batch keys: {list(batch.keys())}")
        print(f"Input IDs shape: {batch['input_ids'].shape}")
        print(f"Attention mask shape: {batch['attention_mask'].shape}")
        print(f"Labels shape: {batch['labels'].shape}")
        print(f"Labels values: {batch['labels']}")

        # Move batch to device
        batch = trainer._prepare_batch(batch)

        # Compute loss
        with torch.no_grad():
            loss = trainer.compute_loss(batch)
        print(f"✅ Reranker Loss computed successfully: {loss.item():.4f}")

        print("✅ BGE-Reranker test passed!")
        return True

    except Exception as e:
        print(f"❌ BGE-Reranker test failed: {e}")
        traceback.print_exc()
        return False
    finally:
        # Clean up
        _remove_file(train_file)


def main():
    """Run all tests and return a process exit code (0 = all passed, 1 = any failed)."""
    print("🚀 Starting comprehensive BGE training pipeline tests...")

    m3_success = test_bge_m3()
    reranker_success = test_bge_reranker()

    print("\n📊 Test Results Summary:")
    print(f"   BGE-M3: {'✅ PASS' if m3_success else '❌ FAIL'}")
    print(f"   BGE-Reranker: {'✅ PASS' if reranker_success else '❌ FAIL'}")

    if m3_success and reranker_success:
        print("\n🎉 All tests passed! Both models are working correctly.")
        return 0

    print("\n⚠️ Some tests failed. Please check the error messages above.")
    return 1


if __name__ == "__main__":
    sys.exit(main())