"""
Comprehensive testing script for BGE fine-tuning project

Tests all components, configurations, and integrations
"""

import os
import sys
import traceback
import tempfile
import json
from pathlib import Path

import torch
import numpy as np

# Add parent directory to path so the project packages (utils, config, models,
# data, training, evaluation, scripts) can be imported when run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class ComprehensiveTestSuite:
    """Comprehensive test suite for BGE fine-tuning project.

    Each ``test_*`` method prints progress and returns ``True`` on success or
    ``False`` on failure; ``run_test`` records the outcome in
    ``test_results`` (list of ``(name, status, error)`` tuples) and the names
    of failures in ``failed_tests``.
    """

    def __init__(self):
        self.test_results = []
        self.failed_tests = []

    @staticmethod
    def _write_temp_file(content, suffix):
        """Write *content* to a new named temporary file and return its path.

        The handle is closed before returning so the path can be reopened by
        other code on every platform (on Windows a file still held open by
        ``NamedTemporaryFile`` cannot be opened a second time). The caller
        owns the file and is responsible for deleting it.
        """
        with tempfile.NamedTemporaryFile(
            mode='w', suffix=suffix, delete=False, encoding='utf-8'
        ) as handle:
            handle.write(content)
            return handle.name

    @staticmethod
    def _remove_quietly(path):
        """Delete *path*, ignoring the case where it no longer exists."""
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass

    def run_test(self, test_name, test_func, *args, **kwargs):
        """Run a single test and record results.

        Args:
            test_name: Label used in the printed output and the report.
            test_func: Callable returning a truthy value on success.
            *args, **kwargs: Forwarded to ``test_func``.

        Returns:
            True if ``test_func`` returned a truthy value, False if it returned
            a falsy value or raised an exception.
        """
        print(f"\n{'='*60}")
        print(f"Testing: {test_name}")
        print('='*60)

        try:
            result = test_func(*args, **kwargs)
            if result:
                print(f"✅ {test_name} PASSED")
                self.test_results.append((test_name, "PASSED", None))
                return True
            else:
                print(f"❌ {test_name} FAILED")
                self.test_results.append((test_name, "FAILED", "Test returned False"))
                self.failed_tests.append(test_name)
                return False
        except Exception as e:
            print(f"❌ {test_name} FAILED with exception: {e}")
            traceback.print_exc()
            self.test_results.append((test_name, "FAILED", str(e)))
            self.failed_tests.append(test_name)
            return False

    def test_imports(self):
        """Test all critical imports.

        Required: torch, transformers, numpy/pandas/scipy/sklearn, faiss,
        rank_bm25, jieba, toml. ``torch_npu`` (Ascend) is optional and only
        reported on.
        """
        print("Testing critical imports...")

        # Core dependencies
        try:
            import torch
            print(f"✓ PyTorch {torch.__version__}")
            print(f"  CUDA available: {torch.cuda.is_available()}")
            if torch.cuda.is_available():
                print(f"  CUDA version: {torch.version.cuda}")
                print(f"  GPU count: {torch.cuda.device_count()}")
        except ImportError as e:
            print(f"✗ PyTorch: {e}")
            return False

        # Test Ascend NPU support (optional; importing torch_npu registers torch.npu)
        try:
            import torch_npu
            if torch.npu.is_available():  # type: ignore[attr-defined]
                print(f"✓ Ascend NPU available: {torch.npu.device_count()} devices")  # type: ignore[attr-defined]
            else:
                print("! Ascend NPU library installed but no devices available")
        except ImportError:
            print("- Ascend NPU library not installed (optional)")

        # Transformers and related
        try:
            import transformers
            print(f"✓ Transformers {transformers.__version__}")
        except ImportError as e:
            print(f"✗ Transformers: {e}")
            return False

        # Scientific computing
        try:
            import numpy as np
            import pandas as pd
            import scipy
            import sklearn
            print("✓ Scientific computing libraries")
        except ImportError as e:
            print(f"✗ Scientific libraries: {e}")
            return False

        # BGE specific
        try:
            import faiss
            print(f"✓ FAISS {faiss.__version__}")
        except ImportError as e:
            print(f"✗ FAISS: {e}")
            return False

        try:
            from rank_bm25 import BM25Okapi
            print("✓ BM25")
        except ImportError as e:
            print(f"✗ BM25: {e}")
            return False

        try:
            import jieba
            print("✓ Jieba (Chinese tokenization)")
        except ImportError as e:
            print(f"✗ Jieba: {e}")
            return False

        try:
            import toml
            print("✓ TOML")
        except ImportError as e:
            print(f"✗ TOML: {e}")
            return False

        return True

    def test_project_imports(self):
        """Test project-specific imports, grouped by package."""
        print("Testing project imports...")

        # Configuration
        try:
            from utils.config_loader import get_config, load_config_for_training
            from config.base_config import BaseConfig
            from config.m3_config import BGEM3Config
            from config.reranker_config import BGERerankerConfig
            print("✓ Configuration classes")
        except ImportError as e:
            print(f"✗ Configuration imports: {e}")
            traceback.print_exc()
            return False

        # Models
        try:
            from models.bge_m3 import BGEM3Model
            from models.bge_reranker import BGERerankerModel
            from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss
            print("✓ Model classes")
        except ImportError as e:
            print(f"✗ Model imports: {e}")
            traceback.print_exc()
            return False

        # Data processing
        try:
            from data.dataset import BGEM3Dataset, BGERerankerDataset
            from data.preprocessing import DataPreprocessor
            from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter
            from data.hard_negative_mining import ANCEHardNegativeMiner
            print("✓ Data processing classes")
        except ImportError as e:
            print(f"✗ Data imports: {e}")
            traceback.print_exc()
            return False

        # Training
        try:
            from training.trainer import BaseTrainer
            from training.m3_trainer import BGEM3Trainer
            from training.reranker_trainer import BGERerankerTrainer
            from training.joint_trainer import JointTrainer
            print("✓ Training classes")
        except ImportError as e:
            print(f"✗ Training imports: {e}")
            traceback.print_exc()
            return False

        # Evaluation
        try:
            from evaluation.evaluator import RetrievalEvaluator
            from evaluation.metrics import compute_retrieval_metrics
            print("✓ Evaluation classes")
        except ImportError as e:
            print(f"✗ Evaluation imports: {e}")
            traceback.print_exc()
            return False

        # Utilities
        try:
            from utils.logging import setup_logging
            from utils.checkpoint import CheckpointManager
            from utils.distributed import setup_distributed, is_main_process
            print("✓ Utility classes")
        except ImportError as e:
            print(f"✗ Utility imports: {e}")
            traceback.print_exc()
            return False

        return True

    def test_configuration_system(self):
        """Test the centralized configuration system.

        Checks that ``get_config`` returns the same object on repeated calls,
        that ``get``/``get_section`` work, and that
        ``load_config_for_training`` exposes ``training`` and ``model_paths``.
        """
        print("Testing configuration system...")

        try:
            from utils.config_loader import get_config, load_config_for_training

            # Test singleton behavior
            config1 = get_config()
            config2 = get_config()
            assert config1 is config2, "Config should be singleton"
            print("✓ Singleton pattern working")

            # Test configuration loading
            test_value = config1.get('training.default_batch_size', None)  # type: ignore[arg-type]
            print(f"✓ Config loading: batch_size = {test_value}")

            # Test section loading
            training_section = config1.get_section('training')
            assert isinstance(training_section, dict), "Should return dict"
            print("✓ Section loading working")

            # Test training configuration loading
            training_config = load_config_for_training()
            assert 'training' in training_config
            assert 'model_paths' in training_config
            print("✓ Training config loading working")

            return True

        except Exception as e:
            print(f"✗ Configuration system error: {e}")
            traceback.print_exc()
            return False

    def test_configuration_integration(self):
        """Test configuration integration in classes (BaseConfig, BGEM3Config)."""
        print("Testing configuration integration...")

        try:
            from config.base_config import BaseConfig
            from config.m3_config import BGEM3Config

            # Test BaseConfig with config.toml integration
            base_config = BaseConfig()
            print(f"✓ BaseConfig created with batch_size: {base_config.per_device_train_batch_size}")

            # Test M3Config
            m3_config = BGEM3Config()
            print(f"✓ M3Config created with learning_rate: {m3_config.learning_rate}")

            # Test device detection
            print(f"✓ Device detected: {base_config.device}")

            return True

        except Exception as e:
            print(f"✗ Configuration integration error: {e}")
            traceback.print_exc()
            return False

    def test_data_processing(self):
        """Test DataPreprocessor JSONL loading and validation on two records."""
        print("Testing data processing...")

        temp_file = None
        try:
            from data.preprocessing import DataPreprocessor

            # Create test data
            test_data = [
                {"query": "test query 1", "pos": ["positive passage 1"], "neg": ["negative passage 1"]},
                {"query": "test query 2", "pos": ["positive passage 2"], "neg": ["negative passage 2"]}
            ]
            temp_file = self._write_temp_file(
                ''.join(json.dumps(item) + '\n' for item in test_data), '.jsonl'
            )

            # Test preprocessing
            preprocessor = DataPreprocessor()
            data = preprocessor._load_jsonl(temp_file)
            assert len(data) == 2
            print("✓ JSONL loading working")

            # Test validation
            cleaned = preprocessor._validate_and_clean(data)
            assert len(cleaned) == 2
            print("✓ Data validation working")

            return True

        except Exception as e:
            print(f"✗ Data processing error: {e}")
            traceback.print_exc()
            return False
        finally:
            # Cleanup runs on failure too, so a failing check cannot leak the file.
            if temp_file is not None:
                self._remove_quietly(temp_file)

    def test_data_augmentation(self):
        """Test query, passage, and hard-negative augmenters on one record."""
        print("Testing data augmentation...")

        try:
            from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter

            # Test QueryAugmenter
            augmenter = QueryAugmenter(augmentation_methods=['paraphrase', 'synonym'])
            test_data = [
                {"query": "什么是机器学习?", "pos": ["机器学习是AI的分支"], "neg": ["深度学习不同"]}
            ]

            augmented = augmenter.augment_queries(test_data, num_augmentations=1, keep_original=True)
            # keep_original=True, so the output can never be shorter than the input
            assert len(augmented) >= len(test_data)
            print("✓ Query augmentation working")

            # Test PassageAugmenter (only checks that the call does not raise)
            passage_aug = PassageAugmenter()
            passage_aug.augment_passages(test_data, augmentation_prob=1.0)
            print("✓ Passage augmentation working")

            # Test HardNegativeAugmenter (only checks that the call does not raise)
            hn_aug = HardNegativeAugmenter()
            hn_aug.create_hard_negatives(test_data, num_hard_negatives=2)
            print("✓ Hard negative augmentation working")

            return True

        except Exception as e:
            print(f"✗ Data augmentation error: {e}")
            traceback.print_exc()
            return False

    def test_model_functionality(self):
        """Test basic model functionality (InfoNCE and listwise losses)."""
        print("Testing model functionality...")

        try:
            # Model classes are imported only to verify they are importable.
            from models.bge_m3 import BGEM3Model
            from models.bge_reranker import BGERerankerModel
            from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss

            # Test loss functions
            infonce = InfoNCELoss()
            query_emb = torch.randn(2, 128)
            pos_emb = torch.randn(2, 128)
            neg_emb = torch.randn(2, 5, 128)

            loss = infonce(query_emb, pos_emb, neg_emb)
            assert loss.item() > 0
            print("✓ InfoNCE loss working")

            # Test listwise loss
            listwise = ListwiseCrossEntropyLoss()
            scores = torch.randn(2, 5)
            labels = torch.zeros(2, 5)
            labels[:, 0] = 1  # First is positive

            loss = listwise(scores, labels, [5, 5])
            assert loss.item() > 0
            print("✓ Listwise loss working")

            return True

        except Exception as e:
            print(f"✗ Model functionality error: {e}")
            traceback.print_exc()
            return False

    def test_hard_negative_mining(self):
        """Test ANCE hard negative mining on a small CPU index."""
        print("Testing hard negative mining...")

        try:
            from data.hard_negative_mining import ANCEHardNegativeMiner

            # Create miner
            miner = ANCEHardNegativeMiner(
                embedding_dim=128,
                num_hard_negatives=5,
                use_gpu=False  # Use CPU for testing
            )

            # Add some dummy embeddings
            embeddings = np.random.rand(10, 128).astype(np.float32)
            passages = [f"passage_{i}" for i in range(10)]

            miner._add_to_index(embeddings, passages)
            print("✓ Index creation working")

            # Test mining
            queries = ["query1", "query2"]
            query_embeddings = np.random.rand(2, 128).astype(np.float32)
            positive_passages = [["passage_0"], ["passage_5"]]

            hard_negs = miner.mine_hard_negatives(
                queries, query_embeddings, positive_passages
            )

            # One result entry is expected per query
            assert len(hard_negs) == 2
            print("✓ Hard negative mining working")

            return True

        except Exception as e:
            print(f"✗ Hard negative mining error: {e}")
            traceback.print_exc()
            return False

    def test_evaluation_metrics(self):
        """Test retrieval and reranker metric computation on dummy data."""
        print("Testing evaluation metrics...")

        try:
            from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics

            # Seeded generator keeps the inputs reproducible between runs.
            rng = np.random.default_rng(0)

            # Create dummy data
            query_embs = rng.random((5, 128))
            passage_embs = rng.random((20, 128))
            labels = np.zeros((5, 20))
            labels[0, 0] = 1  # First query, first passage is relevant
            labels[1, 5] = 1  # Second query, sixth passage is relevant

            metrics = compute_retrieval_metrics(query_embs, passage_embs, labels)
            assert 'recall@1' in metrics
            assert 'precision@1' in metrics
            print("✓ Retrieval metrics working")

            # Test reranker metrics
            scores = rng.random(10)
            reranker_labels = rng.integers(0, 2, 10)
            # Guarantee both classes are present; all-0/all-1 labels are degenerate.
            reranker_labels[:2] = [1, 0]

            reranker_metrics = compute_reranker_metrics(scores, reranker_labels)
            assert 'accuracy' in reranker_metrics or 'mrr' in reranker_metrics
            print("✓ Reranker metrics working")

            return True

        except Exception as e:
            print(f"✗ Evaluation metrics error: {e}")
            traceback.print_exc()
            return False

    def test_data_formats(self):
        """Test DataPreprocessor support for JSONL and CSV inputs."""
        print("Testing data format support...")

        temp_files = []
        try:
            from data.preprocessing import DataPreprocessor

            preprocessor = DataPreprocessor()

            # Test JSONL format
            jsonl_file = self._write_temp_file(
                '{"query": "test", "pos": ["pos1"], "neg": ["neg1"]}\n'
                '{"query": "test2", "pos": ["pos2"], "neg": ["neg2"]}\n',
                '.jsonl',
            )
            temp_files.append(jsonl_file)

            data = preprocessor._load_jsonl(jsonl_file)
            assert len(data) == 2
            print("✓ JSONL format working")

            # Test CSV format
            import pandas as pd

            # Reserve a path first; the temp handle is already closed, so
            # pandas can reopen the path for writing on every platform.
            csv_file = self._write_temp_file('', '.csv')
            temp_files.append(csv_file)
            df = pd.DataFrame({
                'query': ['test1', 'test2'],
                'positive': ['pos1', 'pos2'],
                'negative': ['neg1', 'neg2']
            })
            df.to_csv(csv_file, index=False)

            data = preprocessor._load_csv(csv_file)
            assert len(data) >= 1  # At least one valid record
            print("✓ CSV format working")

            return True

        except Exception as e:
            print(f"✗ Data format support error: {e}")
            traceback.print_exc()
            return False
        finally:
            for path in temp_files:
                self._remove_quietly(path)

    def test_device_compatibility(self):
        """Test device compatibility (CUDA/Ascend/CPU) as detected by BaseConfig."""
        print("Testing device compatibility...")

        try:
            from config.base_config import BaseConfig

            # Test device detection
            config = BaseConfig()
            print(f"✓ Device detected: {config.device}")
            print(f"✓ Number of devices: {config.n_gpu}")

            # Test tensor operations
            if config.device == "cuda":
                x = torch.randn(10, 10).cuda()
                y = torch.matmul(x, x.t())
                assert y.device.type == "cuda"
                print("✓ CUDA tensor operations working")
            elif config.device == "npu":
                print("✓ Ascend NPU detected")
                # Note: Would need actual NPU to test operations
            else:
                x = torch.randn(10, 10)
                y = torch.matmul(x, x.t())
                print("✓ CPU tensor operations working")

            return True

        except Exception as e:
            print(f"✗ Device compatibility error: {e}")
            traceback.print_exc()
            return False

    def test_checkpoint_functionality(self):
        """Test CheckpointManager save/load round trip in a temporary directory."""
        print("Testing checkpoint functionality...")

        try:
            from utils.checkpoint import CheckpointManager

            with tempfile.TemporaryDirectory() as temp_dir:
                manager = CheckpointManager(temp_dir)

                # Test saving
                state_dict = {
                    'model_state_dict': {
                        'layer1.weight': torch.randn(10, 5),
                        'layer1.bias': torch.randn(10),
                        'layer2.weight': torch.randn(3, 10),
                        'layer2.bias': torch.randn(3)
                    },
                    'optimizer_state_dict': {'lr': 0.001},
                    'step': 100
                }

                checkpoint_path = manager.save(state_dict, step=100)
                assert os.path.exists(checkpoint_path)
                print("✓ Checkpoint saving working")

                # Test loading
                loaded_state = manager.load_checkpoint(checkpoint_path)
                assert 'model_state_dict' in loaded_state
                print("✓ Checkpoint loading working")

            return True

        except Exception as e:
            print(f"✗ Checkpoint functionality error: {e}")
            traceback.print_exc()
            return False

    def test_logging_system(self):
        """Test basic logging setup and MetricsLogger in a temporary directory."""
        print("Testing logging system...")

        try:
            from utils.logging import setup_logging, get_logger, MetricsLogger

            with tempfile.TemporaryDirectory() as temp_dir:
                # Test setup
                setup_logging(temp_dir, log_to_file=True)
                logger = get_logger("test")
                logger.info("Test log message")
                print("✓ Basic logging working")

                # Test metrics logger
                metrics_logger = MetricsLogger(temp_dir)
                metrics_logger.log(step=1, metrics={'loss': 0.5}, phase='train')
                print("✓ Metrics logging working")

            return True

        except Exception as e:
            print(f"✗ Logging system error: {e}")
            traceback.print_exc()
            return False

    def test_training_script_config_integration(self):
        """Test that scripts.train_m3.create_config_from_args fills in defaults.

        Most mock arguments are ``None``; the test checks that the resulting
        config still has a positive learning rate and batch size.
        """
        print("Testing training script configuration integration...")

        try:
            # Import the main training script functions
            from scripts.train_m3 import create_config_from_args

            # Create mock arguments
            class MockArgs:
                def __init__(self):
                    self.config_path = None
                    self.model_name_or_path = None
                    self.cache_dir = None
                    self.train_data = "dummy"
                    self.eval_data = None
                    self.max_seq_length = None
                    self.query_max_length = None
                    self.passage_max_length = None
                    self.max_length = 8192
                    # Platform temp dir instead of a hard-coded POSIX /tmp path.
                    self.output_dir = os.path.join(tempfile.gettempdir(), "test")
                    self.num_train_epochs = None
                    self.per_device_train_batch_size = None
                    self.per_device_eval_batch_size = None
                    self.gradient_accumulation_steps = None
                    self.learning_rate = None
                    self.warmup_ratio = None
                    self.weight_decay = None
                    self.train_group_size = None
                    self.temperature = None
                    self.use_hard_negatives = False
                    self.num_hard_negatives = None
                    self.use_self_distill = False
                    self.use_query_instruction = False
                    self.query_instruction = None
                    self.use_dense = True
                    self.use_sparse = False
                    self.use_colbert = False
                    self.unified_finetuning = False
                    self.augment_queries = False
                    self.augmentation_methods = None
                    self.create_hard_negatives = False
                    self.fp16 = False
                    self.bf16 = False
                    self.gradient_checkpointing = False
                    self.dataloader_num_workers = None
                    self.save_steps = None
                    self.eval_steps = None
                    self.logging_steps = None
                    self.save_total_limit = None
                    self.resume_from_checkpoint = None
                    self.local_rank = -1
                    self.seed = None
                    self.overwrite_output_dir = True

            args = MockArgs()
            config = create_config_from_args(args)

            assert config.learning_rate > 0
            assert config.per_device_train_batch_size > 0
            print("✓ Training script config integration working")

            return True

        except Exception as e:
            print(f"✗ Training script config integration error: {e}")
            traceback.print_exc()
            return False

    def run_all_tests(self):
        """Run all tests and generate report.

        Returns:
            True if every test passed, False otherwise.
        """
        # Start from a clean slate so repeated runs do not duplicate results.
        self.test_results = []
        self.failed_tests = []

        print("\n" + "="*80)
        print("BGE FINE-TUNING COMPREHENSIVE TEST SUITE")
        print("="*80)

        tests = [
            ("Core Dependencies Import", self.test_imports),
            ("Project Components Import", self.test_project_imports),
            ("Configuration System", self.test_configuration_system),
            ("Configuration Integration", self.test_configuration_integration),
            ("Data Processing", self.test_data_processing),
            ("Data Augmentation", self.test_data_augmentation),
            ("Model Functionality", self.test_model_functionality),
            ("Hard Negative Mining", self.test_hard_negative_mining),
            ("Evaluation Metrics", self.test_evaluation_metrics),
            ("Data Format Support", self.test_data_formats),
            ("Device Compatibility", self.test_device_compatibility),
            ("Checkpoint Functionality", self.test_checkpoint_functionality),
            ("Logging System", self.test_logging_system),
            ("Training Script Integration", self.test_training_script_config_integration),
        ]

        passed = 0
        total = len(tests)

        for test_name, test_func in tests:
            if self.run_test(test_name, test_func):
                passed += 1

        # Generate report
        print("\n" + "="*80)
        print("TEST RESULTS SUMMARY")
        print("="*80)
        print(f"Tests passed: {passed}/{total}")
        print(f"Success rate: {passed/total*100:.1f}%")

        if passed == total:
            print("\n🎉 ALL TESTS PASSED! Project is ready for use.")
            print("\nNext steps:")
            print("1. Configure your API keys in config.toml if needed")
            print("2. Prepare your training data in JSONL format")
            print("3. Run training with: python scripts/train_m3.py --train_data your_data.jsonl")
            print("4. Evaluate with: python scripts/validate.py quick --retriever_model output/model")
        else:
            print(f"\n❌ {total - passed} tests failed. Please review the errors above.")
            print("\nFailed tests:")
            for test_name in self.failed_tests:
                print(f"  - {test_name}")

        print("\nDetailed test results:")
        for test_name, status, error in self.test_results:
            status_emoji = "✅" if status == "PASSED" else "❌"
            print(f"  {status_emoji} {test_name}: {status}")
            if error:
                print(f"    Error: {error}")

        return passed == total


def main():
    """Run comprehensive test suite; return a process exit code (0 = success)."""
    test_suite = ComprehensiveTestSuite()
    success = test_suite.run_all_tests()
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())