# bge_finetune/tests/test_installation.py
"""
Comprehensive testing script for BGE fine-tuning project
Tests all components, configurations, and integrations
"""
import os
import sys
import traceback
import tempfile
import json
from pathlib import Path
import torch
import numpy as np
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class ComprehensiveTestSuite:
    """Comprehensive test suite for BGE fine-tuning project"""

    def __init__(self):
        self.test_results = []
        self.failed_tests = []

    def run_test(self, test_name, test_func, *args, **kwargs):
        """Run a single test and record results"""
        print(f"\n{'='*60}")
        print(f"Testing: {test_name}")
        print('='*60)
        try:
            result = test_func(*args, **kwargs)
            if result:
                print(f"{test_name} PASSED")
                self.test_results.append((test_name, "PASSED", None))
                return True
            else:
                print(f"{test_name} FAILED")
                self.test_results.append((test_name, "FAILED", "Test returned False"))
                self.failed_tests.append(test_name)
                return False
        except Exception as e:
            print(f"{test_name} FAILED with exception: {e}")
            traceback.print_exc()
            self.test_results.append((test_name, "FAILED", str(e)))
            self.failed_tests.append(test_name)
            return False

    def test_imports(self):
        """Test all critical imports"""
        print("Testing critical imports...")

        # Core dependencies
        try:
            import torch
            print(f"✓ PyTorch {torch.__version__}")
            print(f" CUDA available: {torch.cuda.is_available()}")
            if torch.cuda.is_available():
                print(f" CUDA version: {torch.version.cuda}")
                print(f" GPU count: {torch.cuda.device_count()}")
        except ImportError as e:
            print(f"✗ PyTorch: {e}")
            return False

        # Test Ascend NPU support
        try:
            import torch_npu
            if torch.npu.is_available():  # type: ignore[attr-defined]
                print(f"✓ Ascend NPU available: {torch.npu.device_count()} devices")  # type: ignore[attr-defined]
            else:
                print("! Ascend NPU library installed but no devices available")
        except ImportError:
            print("- Ascend NPU library not installed (optional)")

        # Transformers and related
        try:
            import transformers
            print(f"✓ Transformers {transformers.__version__}")
        except ImportError as e:
            print(f"✗ Transformers: {e}")
            return False

        # Scientific computing
        try:
            import numpy as np
            import pandas as pd
            import scipy
            import sklearn
            print("✓ Scientific computing libraries")
        except ImportError as e:
            print(f"✗ Scientific libraries: {e}")
            return False

        # BGE specific
        try:
            import faiss
            print(f"✓ FAISS {faiss.__version__}")
        except ImportError as e:
            print(f"✗ FAISS: {e}")
            return False

        try:
            from rank_bm25 import BM25Okapi
            print("✓ BM25")
        except ImportError as e:
            print(f"✗ BM25: {e}")
            return False

        try:
            import jieba
            print("✓ Jieba (Chinese tokenization)")
        except ImportError as e:
            print(f"✗ Jieba: {e}")
            return False

        try:
            import toml
            print("✓ TOML")
        except ImportError as e:
            print(f"✗ TOML: {e}")
            return False

        return True

    def test_project_imports(self):
        """Test project-specific imports"""
        print("Testing project imports...")

        # Configuration
        try:
            from utils.config_loader import get_config, load_config_for_training
            from config.base_config import BaseConfig
            from config.m3_config import BGEM3Config
            from config.reranker_config import BGERerankerConfig
            print("✓ Configuration classes")
        except ImportError as e:
            print(f"✗ Configuration imports: {e}")
            traceback.print_exc()
            return False

        # Models
        try:
            from models.bge_m3 import BGEM3Model
            from models.bge_reranker import BGERerankerModel
            from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss
            print("✓ Model classes")
        except ImportError as e:
            print(f"✗ Model imports: {e}")
            traceback.print_exc()
            return False

        # Data processing
        try:
            from data.dataset import BGEM3Dataset, BGERerankerDataset
            from data.preprocessing import DataPreprocessor
            from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter
            from data.hard_negative_mining import ANCEHardNegativeMiner
            print("✓ Data processing classes")
        except ImportError as e:
            print(f"✗ Data imports: {e}")
            traceback.print_exc()
            return False

        # Training
        try:
            from training.trainer import BaseTrainer
            from training.m3_trainer import BGEM3Trainer
            from training.reranker_trainer import BGERerankerTrainer
            from training.joint_trainer import JointTrainer
            print("✓ Training classes")
        except ImportError as e:
            print(f"✗ Training imports: {e}")
            traceback.print_exc()
            return False

        # Evaluation
        try:
            from evaluation.evaluator import RetrievalEvaluator
            from evaluation.metrics import compute_retrieval_metrics
            print("✓ Evaluation classes")
        except ImportError as e:
            print(f"✗ Evaluation imports: {e}")
            traceback.print_exc()
            return False

        # Utilities
        try:
            from utils.logging import setup_logging
            from utils.checkpoint import CheckpointManager
            from utils.distributed import setup_distributed, is_main_process
            print("✓ Utility classes")
        except ImportError as e:
            print(f"✗ Utility imports: {e}")
            traceback.print_exc()
            return False

        return True

    def test_configuration_system(self):
        """Test the centralized configuration system"""
        print("Testing configuration system...")
        try:
            from utils.config_loader import get_config, load_config_for_training

            # Test singleton behavior
            config1 = get_config()
            config2 = get_config()
            assert config1 is config2, "Config should be singleton"
            print("✓ Singleton pattern working")

            # Test configuration loading
            test_value = config1.get('training.default_batch_size', None)  # type: ignore[arg-type]
            print(f"✓ Config loading: batch_size = {test_value}")

            # Test section loading
            training_section = config1.get_section('training')
            assert isinstance(training_section, dict), "Should return dict"
            print("✓ Section loading working")

            # Test training configuration loading
            training_config = load_config_for_training()
            assert 'training' in training_config
            assert 'model_paths' in training_config
            print("✓ Training config loading working")

            return True
        except Exception as e:
            print(f"✗ Configuration system error: {e}")
            traceback.print_exc()
            return False

    def test_configuration_integration(self):
        """Test configuration integration in classes"""
        print("Testing configuration integration...")
        try:
            from config.base_config import BaseConfig
            from config.m3_config import BGEM3Config

            # Test BaseConfig with config.toml integration
            base_config = BaseConfig()
            print(f"✓ BaseConfig created with batch_size: {base_config.per_device_train_batch_size}")

            # Test M3Config
            m3_config = BGEM3Config()
            print(f"✓ M3Config created with learning_rate: {m3_config.learning_rate}")

            # Test device detection
            print(f"✓ Device detected: {base_config.device}")

            return True
        except Exception as e:
            print(f"✗ Configuration integration error: {e}")
            traceback.print_exc()
            return False

    def test_data_processing(self):
        """Test data processing functionality"""
        print("Testing data processing...")
        try:
            from data.preprocessing import DataPreprocessor

            # Create test data
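            # Fixture records use the triplet schema seen throughout these tests:
            #   {"query": str, "pos": [list of str], "neg": [list of str]}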
            with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
                test_data = [
                    {"query": "test query 1", "pos": ["positive passage 1"], "neg": ["negative passage 1"]},
                    {"query": "test query 2", "pos": ["positive passage 2"], "neg": ["negative passage 2"]}
                ]
                for item in test_data:
                    f.write(json.dumps(item) + '\n')
                temp_file = f.name

            # Test preprocessing
            preprocessor = DataPreprocessor()
            data = preprocessor._load_jsonl(temp_file)
            assert len(data) == 2
            print("✓ JSONL loading working")

            # Test validation
            cleaned = preprocessor._validate_and_clean(data)
            assert len(cleaned) == 2
            print("✓ Data validation working")

            # Cleanup
            os.unlink(temp_file)

            return True
        except Exception as e:
            print(f"✗ Data processing error: {e}")
            traceback.print_exc()
            return False

    def test_data_augmentation(self):
        """Test data augmentation functionality"""
        print("Testing data augmentation...")
        try:
            from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter

            # Test QueryAugmenter
            augmenter = QueryAugmenter(augmentation_methods=['paraphrase', 'synonym'])
            test_data = [
                {"query": "什么是机器学习?", "pos": ["机器学习是AI的分支"], "neg": ["深度学习不同"]}
            ]
            augmented = augmenter.augment_queries(test_data, num_augmentations=1, keep_original=True)
            assert len(augmented) >= len(test_data)
            print("✓ Query augmentation working")

            # Test PassageAugmenter
            passage_aug = PassageAugmenter()
            passage_augmented = passage_aug.augment_passages(test_data, augmentation_prob=1.0)
            print("✓ Passage augmentation working")

            # Test HardNegativeAugmenter
            hn_aug = HardNegativeAugmenter()
            hn_augmented = hn_aug.create_hard_negatives(test_data, num_hard_negatives=2)
            print("✓ Hard negative augmentation working")

            return True
        except Exception as e:
            print(f"✗ Data augmentation error: {e}")
            traceback.print_exc()
            return False

    def test_model_functionality(self):
        """Test basic model functionality"""
        print("Testing model functionality...")
        try:
            from models.bge_m3 import BGEM3Model
            from models.bge_reranker import BGERerankerModel
            from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss

            # Test loss functions
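            # Shapes below: (batch, dim) for query/positive embeddings and
            # (batch, num_negatives, dim) for negatives, as constructed here.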
            infonce = InfoNCELoss()
            query_emb = torch.randn(2, 128)
            pos_emb = torch.randn(2, 128)
            neg_emb = torch.randn(2, 5, 128)
            loss = infonce(query_emb, pos_emb, neg_emb)
            assert loss.item() > 0
            print("✓ InfoNCE loss working")

            # Test listwise loss
            listwise = ListwiseCrossEntropyLoss()
            scores = torch.randn(2, 5)
            labels = torch.zeros(2, 5)
            labels[:, 0] = 1  # First is positive
            loss = listwise(scores, labels, [5, 5])
            assert loss.item() > 0
            print("✓ Listwise loss working")

            return True
        except Exception as e:
            print(f"✗ Model functionality error: {e}")
            traceback.print_exc()
            return False

    def test_hard_negative_mining(self):
        """Test hard negative mining"""
        print("Testing hard negative mining...")
        try:
            from data.hard_negative_mining import ANCEHardNegativeMiner

            # Create miner
            miner = ANCEHardNegativeMiner(
                embedding_dim=128,
                num_hard_negatives=5,
                use_gpu=False  # Use CPU for testing
            )

            # Add some dummy embeddings
            embeddings = np.random.rand(10, 128).astype(np.float32)
            passages = [f"passage_{i}" for i in range(10)]
            miner._add_to_index(embeddings, passages)
            print("✓ Index creation working")

            # Test mining
            queries = ["query1", "query2"]
            query_embeddings = np.random.rand(2, 128).astype(np.float32)
            positive_passages = [["passage_0"], ["passage_5"]]
            hard_negs = miner.mine_hard_negatives(
                queries, query_embeddings, positive_passages
            )
            assert len(hard_negs) == 2
            print("✓ Hard negative mining working")

            return True
        except Exception as e:
            print(f"✗ Hard negative mining error: {e}")
            traceback.print_exc()
            return False

    def test_evaluation_metrics(self):
        """Test evaluation metrics"""
        print("Testing evaluation metrics...")
        try:
            from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics

            # Create dummy data
            query_embs = np.random.rand(5, 128)
            passage_embs = np.random.rand(20, 128)
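            # labels is a (num_queries, num_passages) binary relevance matrix.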
            labels = np.zeros((5, 20))
            labels[0, 0] = 1  # First query, first passage is relevant
            labels[1, 5] = 1  # Second query, sixth passage is relevant

            metrics = compute_retrieval_metrics(query_embs, passage_embs, labels)
            assert 'recall@1' in metrics
            assert 'precision@1' in metrics
            print("✓ Retrieval metrics working")

            # Test reranker metrics
            scores = np.random.rand(10)
            labels = np.random.randint(0, 2, 10)
            reranker_metrics = compute_reranker_metrics(scores, labels)
            assert 'accuracy' in reranker_metrics or 'mrr' in reranker_metrics
            print("✓ Reranker metrics working")

            return True
        except Exception as e:
            print(f"✗ Evaluation metrics error: {e}")
            traceback.print_exc()
            return False

    def test_data_formats(self):
        """Test multiple data format support"""
        print("Testing data format support...")
        try:
            from data.preprocessing import DataPreprocessor
            preprocessor = DataPreprocessor()

            # Test JSONL format
            with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
                f.write('{"query": "test", "pos": ["pos1"], "neg": ["neg1"]}\n')
                f.write('{"query": "test2", "pos": ["pos2"], "neg": ["neg2"]}\n')
                temp_file = f.name

            data = preprocessor._load_jsonl(temp_file)
            assert len(data) == 2
            os.unlink(temp_file)
            print("✓ JSONL format working")

            # Test CSV format
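            # The CSV fixture uses 'query'/'positive'/'negative' column names;
            # presumably _load_csv maps them onto the same triplet schema.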
            import pandas as pd
            with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
                df = pd.DataFrame({
                    'query': ['test1', 'test2'],
                    'positive': ['pos1', 'pos2'],
                    'negative': ['neg1', 'neg2']
                })
                df.to_csv(f.name, index=False)
                temp_file = f.name

            data = preprocessor._load_csv(temp_file)
            assert len(data) >= 1  # At least one valid record
            os.unlink(temp_file)
            print("✓ CSV format working")

            return True
        except Exception as e:
            print(f"✗ Data format support error: {e}")
            traceback.print_exc()
            return False

    def test_device_compatibility(self):
        """Test device compatibility (CUDA/Ascend/CPU)"""
        print("Testing device compatibility...")
        try:
            from config.base_config import BaseConfig

            # Test device detection
            config = BaseConfig()
            print(f"✓ Device detected: {config.device}")
            print(f"✓ Number of devices: {config.n_gpu}")

            # Test tensor operations
            if config.device == "cuda":
                x = torch.randn(10, 10).cuda()
                y = torch.matmul(x, x.t())
                assert y.device.type == "cuda"
                print("✓ CUDA tensor operations working")
            elif config.device == "npu":
                print("✓ Ascend NPU detected")
                # Note: Would need actual NPU to test operations
            else:
                x = torch.randn(10, 10)
                y = torch.matmul(x, x.t())
                print("✓ CPU tensor operations working")

            return True
        except Exception as e:
            print(f"✗ Device compatibility error: {e}")
            traceback.print_exc()
            return False

    def test_checkpoint_functionality(self):
        """Test checkpointing functionality"""
        print("Testing checkpoint functionality...")
        try:
            from utils.checkpoint import CheckpointManager

            with tempfile.TemporaryDirectory() as temp_dir:
                manager = CheckpointManager(temp_dir)

                # Test saving
                state_dict = {
                    'model_state_dict': torch.randn(10, 10),
                    'optimizer_state_dict': {'lr': 0.001},
                    'step': 100
                }
                checkpoint_path = manager.save(state_dict, step=100)
                assert os.path.exists(checkpoint_path)
                print("✓ Checkpoint saving working")

                # Test loading
                loaded_state = manager.load_checkpoint(checkpoint_path)
                assert 'model_state_dict' in loaded_state
                print("✓ Checkpoint loading working")

            return True
        except Exception as e:
            print(f"✗ Checkpoint functionality error: {e}")
            traceback.print_exc()
            return False

    def test_logging_system(self):
        """Test logging system"""
        print("Testing logging system...")
        try:
            from utils.logging import setup_logging, get_logger, MetricsLogger

            with tempfile.TemporaryDirectory() as temp_dir:
                # Test setup
                setup_logging(temp_dir, log_to_file=True)
                logger = get_logger("test")
                logger.info("Test log message")
                print("✓ Basic logging working")

                # Test metrics logger
                metrics_logger = MetricsLogger(temp_dir)
                metrics_logger.log(step=1, metrics={'loss': 0.5}, phase='train')
                print("✓ Metrics logging working")

            return True
        except Exception as e:
            print(f"✗ Logging system error: {e}")
            traceback.print_exc()
            return False

    def test_training_script_config_integration(self):
        """Test training script configuration integration"""
        print("Testing training script configuration integration...")
        try:
            # Import the main training script functions
            from scripts.train_m3 import create_config_from_args

            # Create mock arguments
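            # MockArgs mirrors the argparse namespace of scripts/train_m3.py;
            # fields left as None are expected to fall back to config.toml
            # defaults (an assumption, implied by the asserts further down).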
            class MockArgs:
                def __init__(self):
                    self.config_path = None
                    self.model_name_or_path = None
                    self.cache_dir = None
                    self.train_data = "dummy"
                    self.eval_data = None
                    self.max_seq_length = None
                    self.query_max_length = None
                    self.passage_max_length = None
                    self.max_length = 8192
                    self.output_dir = "/tmp/test"
                    self.num_train_epochs = None
                    self.per_device_train_batch_size = None
                    self.per_device_eval_batch_size = None
                    self.gradient_accumulation_steps = None
                    self.learning_rate = None
                    self.warmup_ratio = None
                    self.weight_decay = None
                    self.train_group_size = None
                    self.temperature = None
                    self.use_hard_negatives = False
                    self.num_hard_negatives = None
                    self.use_self_distill = False
                    self.use_query_instruction = False
                    self.query_instruction = None
                    self.use_dense = True
                    self.use_sparse = False
                    self.use_colbert = False
                    self.unified_finetuning = False
                    self.augment_queries = False
                    self.augmentation_methods = None
                    self.create_hard_negatives = False
                    self.fp16 = False
                    self.bf16 = False
                    self.gradient_checkpointing = False
                    self.dataloader_num_workers = None
                    self.save_steps = None
                    self.eval_steps = None
                    self.logging_steps = None
                    self.save_total_limit = None
                    self.resume_from_checkpoint = None
                    self.local_rank = -1
                    self.seed = None
                    self.overwrite_output_dir = True

            args = MockArgs()
            config = create_config_from_args(args)
            assert config.learning_rate > 0
            assert config.per_device_train_batch_size > 0
            print("✓ Training script config integration working")

            return True
        except Exception as e:
            print(f"✗ Training script config integration error: {e}")
            traceback.print_exc()
            return False

    def run_all_tests(self):
        """Run all tests and generate report"""
        print("\n" + "="*80)
        print("BGE FINE-TUNING COMPREHENSIVE TEST SUITE")
        print("="*80)

        tests = [
            ("Core Dependencies Import", self.test_imports),
            ("Project Components Import", self.test_project_imports),
            ("Configuration System", self.test_configuration_system),
            ("Configuration Integration", self.test_configuration_integration),
            ("Data Processing", self.test_data_processing),
            ("Data Augmentation", self.test_data_augmentation),
            ("Model Functionality", self.test_model_functionality),
            ("Hard Negative Mining", self.test_hard_negative_mining),
            ("Evaluation Metrics", self.test_evaluation_metrics),
            ("Data Format Support", self.test_data_formats),
            ("Device Compatibility", self.test_device_compatibility),
            ("Checkpoint Functionality", self.test_checkpoint_functionality),
            ("Logging System", self.test_logging_system),
            ("Training Script Integration", self.test_training_script_config_integration),
        ]

        passed = 0
        total = len(tests)
        for test_name, test_func in tests:
            if self.run_test(test_name, test_func):
                passed += 1

        # Generate report
        print("\n" + "="*80)
        print("TEST RESULTS SUMMARY")
        print("="*80)
        print(f"Tests passed: {passed}/{total}")
        print(f"Success rate: {passed/total*100:.1f}%")

        if passed == total:
            print("\n🎉 ALL TESTS PASSED! Project is ready for use.")
            print("\nNext steps:")
            print("1. Configure your API keys in config.toml if needed")
            print("2. Prepare your training data in JSONL format")
            print("3. Run training with: python scripts/train_m3.py --train_data your_data.jsonl")
            print("4. Evaluate with: python scripts/evaluate.py --retriever_model output/model")
        else:
            print(f"\n{total - passed} tests failed. Please review the errors above.")
            print("\nFailed tests:")
            for test_name in self.failed_tests:
                print(f" - {test_name}")

        print("\nDetailed test results:")
        for test_name, status, error in self.test_results:
            status_emoji = "✓" if status == "PASSED" else "✗"
            print(f" {status_emoji} {test_name}: {status}")
            if error:
                print(f" Error: {error}")

        return passed == total
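

# A minimal, ad-hoc way to run a single check (a sketch using the class above):
#   suite = ComprehensiveTestSuite()
#   suite.run_test("Core Dependencies Import", suite.test_imports)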


def main():
    """Run comprehensive test suite"""
    test_suite = ComprehensiveTestSuite()
    success = test_suite.run_all_tests()
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())