init commit

This commit is contained in:
ldy
2025-07-22 16:55:25 +08:00
commit 36003b83e2
67 changed files with 76613 additions and 0 deletions

View File

@@ -0,0 +1,99 @@
import os
import sys
import tempfile
import pytest
# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.preprocessing import DataPreprocessor
from data.augmentation import QueryAugmenter
@pytest.fixture
def tmp_jsonl_file():
# Create a malformed JSONL file
fd, path = tempfile.mkstemp(suffix='.jsonl')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
f.write('{"query": "valid", "pos": ["a"], "neg": ["b"]}\n')
f.write('{invalid json}\n')
f.write('')
yield path
os.remove(path)
@pytest.fixture
def tmp_csv_file():
fd, path = tempfile.mkstemp(suffix='.csv')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
f.write('query,pos,neg\n')
f.write('valid,"a","b"\n')
f.write('malformed line\n')
yield path
os.remove(path)
@pytest.fixture
def tmp_tsv_file():
fd, path = tempfile.mkstemp(suffix='.tsv')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
f.write('query\tpos\tneg\n')
f.write('valid\ta\tb\n')
f.write('malformed line\n')
yield path
os.remove(path)
def test_jsonl_malformed_lines(tmp_jsonl_file):
preprocessor = DataPreprocessor(tokenizer=None)
# Should skip malformed line and not raise
data = preprocessor._load_jsonl(tmp_jsonl_file)
assert isinstance(data, list)
assert len(data) == 1
assert data[0]['query'] == 'valid'
def test_csv_malformed_lines(tmp_csv_file):
preprocessor = DataPreprocessor(tokenizer=None)
# Should skip malformed line and not raise
data = preprocessor._load_csv(tmp_csv_file)
assert isinstance(data, list)
assert len(data) == 1
assert data[0]['query'] == 'valid'
def test_tsv_malformed_lines(tmp_tsv_file):
preprocessor = DataPreprocessor(tokenizer=None)
# Should skip malformed line and not raise
data = preprocessor._load_tsv(tmp_tsv_file)
assert isinstance(data, list)
assert len(data) == 1
assert data[0]['query'] == 'valid'
def test_augmentation_methods_config():
# Should use config-driven augmentation methods
augmenter = QueryAugmenter(augmentation_methods=None, config_path='config.toml')
assert isinstance(augmenter.augmentation_methods, list)
assert 'paraphrase' in augmenter.augmentation_methods
def test_config_list_parsing():
"""Test that config loader parses comma/semicolon-separated lists correctly."""
from utils.config_loader import ConfigLoader
loader = ConfigLoader()
loader._config = {'section': {'list_param': 'a, b; c'}}
value = loader.get('section', 'list_param', default=[])
assert isinstance(value, list)
assert set(value) == {'a', 'b', 'c'}
def test_config_param_source_logging(caplog):
"""Test that config loader logs parameter source correctly."""
from utils.config_loader import ConfigLoader
loader = ConfigLoader()
loader._config = {'section': {'param': 42}}
with caplog.at_level('INFO'):
value = loader.get('section', 'param', default=0)
assert 'source: config.toml' in caplog.text
value = loader.get('section', 'missing', default=99)
assert 'source: default' in caplog.text
def test_config_fallback_default():
"""Test that config loader falls back to default if key is missing."""
from utils.config_loader import ConfigLoader
loader = ConfigLoader()
loader._config = {'section': {}}
value = loader.get('section', 'not_found', default='fallback')
assert value == 'fallback'

121
tests/test_full_training.py Normal file
View File

@@ -0,0 +1,121 @@
#!/usr/bin/env python
"""Test script to validate full training pipeline with learning rate scheduling"""
import torch
import json
import tempfile
import sys
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from config.m3_config import BGEM3Config
from models.bge_m3 import BGEM3Model
from data.dataset import BGEM3Dataset, collate_embedding_batch
from training.m3_trainer import BGEM3Trainer
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
def create_test_data():
"""Create minimal test data"""
data = [
{
"query": "What is machine learning?",
"pos": ["Machine learning is a subset of AI"],
"neg": ["The weather is sunny today"]
},
{
"query": "How does Python work?",
"pos": ["Python is a programming language"],
"neg": ["Cats are cute animals"]
},
{
"query": "What is deep learning?",
"pos": ["Deep learning uses neural networks"],
"neg": ["Pizza is delicious food"]
}
]
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
for item in data:
f.write(json.dumps(item) + '\n')
return f.name
def main():
print("🧪 Testing BGE-M3 training with learning rate scheduling...")
# Create test data
train_file = create_test_data()
print(f"✅ Created test data: {train_file}")
# Initialize config with small dataset settings
config = BGEM3Config(
model_name_or_path="./models/bge-m3",
train_data_path=train_file,
output_dir="output/test-lr-scheduler",
num_train_epochs=2,
per_device_train_batch_size=2,
gradient_accumulation_steps=1, # Small for testing
learning_rate=1e-4, # Higher for testing
logging_steps=1,
save_steps=10,
fp16=False,
use_amp=False,
dataloader_num_workers=0,
overwrite_output_dir=True,
use_self_distill=False,
use_hard_negatives=False,
temperature=0.1 # Fixed temperature
)
try:
# Load model and tokenizer
print("📦 Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
model = BGEM3Model.from_pretrained(config.model_name_or_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"🖥️ Using device: {device}")
# Create dataset
dataset = BGEM3Dataset(
data_path=train_file,
tokenizer=tokenizer,
query_max_length=config.query_max_length,
passage_max_length=config.passage_max_length,
is_train=True
)
print(f"✅ Dataset size: {len(dataset)}")
# Initialize trainer
trainer = BGEM3Trainer(
config=config,
model=model,
tokenizer=tokenizer,
device=device
)
# Test training for a few steps
print("\n🏋️ Starting training test...")
trainer.train(dataset)
print("✅ Training completed successfully!")
return True
except Exception as e:
print(f"❌ Training failed: {e}")
import traceback
traceback.print_exc()
return False
finally:
# Clean up
Path(train_file).unlink()
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

712
tests/test_installation.py Normal file
View File

@@ -0,0 +1,712 @@
"""
Comprehensive testing script for BGE fine-tuning project
Tests all components, configurations, and integrations
"""
import os
import sys
import traceback
import tempfile
import json
from pathlib import Path
import torch
import numpy as np
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class ComprehensiveTestSuite:
"""Comprehensive test suite for BGE fine-tuning project"""
def __init__(self):
self.test_results = []
self.failed_tests = []
def run_test(self, test_name, test_func, *args, **kwargs):
"""Run a single test and record results"""
print(f"\n{'='*60}")
print(f"Testing: {test_name}")
print('='*60)
try:
result = test_func(*args, **kwargs)
if result:
print(f"{test_name} PASSED")
self.test_results.append((test_name, "PASSED", None))
return True
else:
print(f"{test_name} FAILED")
self.test_results.append((test_name, "FAILED", "Test returned False"))
self.failed_tests.append(test_name)
return False
except Exception as e:
print(f"{test_name} FAILED with exception: {e}")
traceback.print_exc()
self.test_results.append((test_name, "FAILED", str(e)))
self.failed_tests.append(test_name)
return False
def test_imports(self):
"""Test all critical imports"""
print("Testing critical imports...")
# Core dependencies
try:
import torch
print(f"✓ PyTorch {torch.__version__}")
print(f" CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f" CUDA version: {torch.version.cuda}")
print(f" GPU count: {torch.cuda.device_count()}")
except ImportError as e:
print(f"✗ PyTorch: {e}")
return False
# Test Ascend NPU support
try:
import torch_npu
if torch.npu.is_available(): # type: ignore[attr-defined]
print(f"✓ Ascend NPU available: {torch.npu.device_count()} devices") # type: ignore[attr-defined]
else:
print("! Ascend NPU library installed but no devices available")
except ImportError:
print("- Ascend NPU library not installed (optional)")
# Transformers and related
try:
import transformers
print(f"✓ Transformers {transformers.__version__}")
except ImportError as e:
print(f"✗ Transformers: {e}")
return False
# Scientific computing
try:
import numpy as np
import pandas as pd
import scipy
import sklearn
print("✓ Scientific computing libraries")
except ImportError as e:
print(f"✗ Scientific libraries: {e}")
return False
# BGE specific
try:
import faiss
print(f"✓ FAISS {faiss.__version__}")
except ImportError as e:
print(f"✗ FAISS: {e}")
return False
try:
from rank_bm25 import BM25Okapi
print("✓ BM25")
except ImportError as e:
print(f"✗ BM25: {e}")
return False
try:
import jieba
print("✓ Jieba (Chinese tokenization)")
except ImportError as e:
print(f"✗ Jieba: {e}")
return False
try:
import toml
print("✓ TOML")
except ImportError as e:
print(f"✗ TOML: {e}")
return False
return True
def test_project_imports(self):
"""Test project-specific imports"""
print("Testing project imports...")
# Configuration
try:
from utils.config_loader import get_config, load_config_for_training
from config.base_config import BaseConfig
from config.m3_config import BGEM3Config
from config.reranker_config import BGERerankerConfig
print("✓ Configuration classes")
except ImportError as e:
print(f"✗ Configuration imports: {e}")
traceback.print_exc()
return False
# Models
try:
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss
print("✓ Model classes")
except ImportError as e:
print(f"✗ Model imports: {e}")
traceback.print_exc()
return False
# Data processing
try:
from data.dataset import BGEM3Dataset, BGERerankerDataset
from data.preprocessing import DataPreprocessor
from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter
from data.hard_negative_mining import ANCEHardNegativeMiner
print("✓ Data processing classes")
except ImportError as e:
print(f"✗ Data imports: {e}")
traceback.print_exc()
return False
# Training
try:
from training.trainer import BaseTrainer
from training.m3_trainer import BGEM3Trainer
from training.reranker_trainer import BGERerankerTrainer
from training.joint_trainer import JointTrainer
print("✓ Training classes")
except ImportError as e:
print(f"✗ Training imports: {e}")
traceback.print_exc()
return False
# Evaluation
try:
from evaluation.evaluator import RetrievalEvaluator
from evaluation.metrics import compute_retrieval_metrics
print("✓ Evaluation classes")
except ImportError as e:
print(f"✗ Evaluation imports: {e}")
traceback.print_exc()
return False
# Utilities
try:
from utils.logging import setup_logging
from utils.checkpoint import CheckpointManager
from utils.distributed import setup_distributed, is_main_process
print("✓ Utility classes")
except ImportError as e:
print(f"✗ Utility imports: {e}")
traceback.print_exc()
return False
return True
def test_configuration_system(self):
"""Test the centralized configuration system"""
print("Testing configuration system...")
try:
from utils.config_loader import get_config, load_config_for_training
# Test singleton behavior
config1 = get_config()
config2 = get_config()
assert config1 is config2, "Config should be singleton"
print("✓ Singleton pattern working")
# Test configuration loading
test_value = config1.get('training.default_batch_size', None) # type: ignore[arg-type]
print(f"✓ Config loading: batch_size = {test_value}")
# Test section loading
training_section = config1.get_section('training')
assert isinstance(training_section, dict), "Should return dict"
print("✓ Section loading working")
# Test training configuration loading
training_config = load_config_for_training()
assert 'training' in training_config
assert 'model_paths' in training_config
print("✓ Training config loading working")
return True
except Exception as e:
print(f"✗ Configuration system error: {e}")
traceback.print_exc()
return False
def test_configuration_integration(self):
"""Test configuration integration in classes"""
print("Testing configuration integration...")
try:
from config.base_config import BaseConfig
from config.m3_config import BGEM3Config
# Test BaseConfig with config.toml integration
base_config = BaseConfig()
print(f"✓ BaseConfig created with batch_size: {base_config.per_device_train_batch_size}")
# Test M3Config
m3_config = BGEM3Config()
print(f"✓ M3Config created with learning_rate: {m3_config.learning_rate}")
# Test device detection
print(f"✓ Device detected: {base_config.device}")
return True
except Exception as e:
print(f"✗ Configuration integration error: {e}")
traceback.print_exc()
return False
def test_data_processing(self):
"""Test data processing functionality"""
print("Testing data processing...")
try:
from data.preprocessing import DataPreprocessor
# Create test data
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
test_data = [
{"query": "test query 1", "pos": ["positive passage 1"], "neg": ["negative passage 1"]},
{"query": "test query 2", "pos": ["positive passage 2"], "neg": ["negative passage 2"]}
]
for item in test_data:
f.write(json.dumps(item) + '\n')
temp_file = f.name
# Test preprocessing
preprocessor = DataPreprocessor()
data = preprocessor._load_jsonl(temp_file)
assert len(data) == 2
print("✓ JSONL loading working")
# Test validation
cleaned = preprocessor._validate_and_clean(data)
assert len(cleaned) == 2
print("✓ Data validation working")
# Cleanup
os.unlink(temp_file)
return True
except Exception as e:
print(f"✗ Data processing error: {e}")
traceback.print_exc()
return False
def test_data_augmentation(self):
"""Test data augmentation functionality"""
print("Testing data augmentation...")
try:
from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter
# Test QueryAugmenter
augmenter = QueryAugmenter(augmentation_methods=['paraphrase', 'synonym'])
test_data = [
{"query": "什么是机器学习?", "pos": ["机器学习是AI的分支"], "neg": ["深度学习不同"]}
]
augmented = augmenter.augment_queries(test_data, num_augmentations=1, keep_original=True)
assert len(augmented) >= len(test_data)
print("✓ Query augmentation working")
# Test PassageAugmenter
passage_aug = PassageAugmenter()
passage_augmented = passage_aug.augment_passages(test_data, augmentation_prob=1.0)
print("✓ Passage augmentation working")
# Test HardNegativeAugmenter
hn_aug = HardNegativeAugmenter()
hn_augmented = hn_aug.create_hard_negatives(test_data, num_hard_negatives=2)
print("✓ Hard negative augmentation working")
return True
except Exception as e:
print(f"✗ Data augmentation error: {e}")
traceback.print_exc()
return False
def test_model_functionality(self):
"""Test basic model functionality"""
print("Testing model functionality...")
try:
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss
# Test loss functions
infonce = InfoNCELoss()
query_emb = torch.randn(2, 128)
pos_emb = torch.randn(2, 128)
neg_emb = torch.randn(2, 5, 128)
loss = infonce(query_emb, pos_emb, neg_emb)
assert loss.item() > 0
print("✓ InfoNCE loss working")
# Test listwise loss
listwise = ListwiseCrossEntropyLoss()
scores = torch.randn(2, 5)
labels = torch.zeros(2, 5)
labels[:, 0] = 1 # First is positive
loss = listwise(scores, labels, [5, 5])
assert loss.item() > 0
print("✓ Listwise loss working")
return True
except Exception as e:
print(f"✗ Model functionality error: {e}")
traceback.print_exc()
return False
def test_hard_negative_mining(self):
"""Test hard negative mining"""
print("Testing hard negative mining...")
try:
from data.hard_negative_mining import ANCEHardNegativeMiner
# Create miner
miner = ANCEHardNegativeMiner(
embedding_dim=128,
num_hard_negatives=5,
use_gpu=False # Use CPU for testing
)
# Add some dummy embeddings
embeddings = np.random.rand(10, 128).astype(np.float32)
passages = [f"passage_{i}" for i in range(10)]
miner._add_to_index(embeddings, passages)
print("✓ Index creation working")
# Test mining
queries = ["query1", "query2"]
query_embeddings = np.random.rand(2, 128).astype(np.float32)
positive_passages = [["passage_0"], ["passage_5"]]
hard_negs = miner.mine_hard_negatives(
queries, query_embeddings, positive_passages
)
assert len(hard_negs) == 2
print("✓ Hard negative mining working")
return True
except Exception as e:
print(f"✗ Hard negative mining error: {e}")
traceback.print_exc()
return False
def test_evaluation_metrics(self):
"""Test evaluation metrics"""
print("Testing evaluation metrics...")
try:
from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics
# Create dummy data
query_embs = np.random.rand(5, 128)
passage_embs = np.random.rand(20, 128)
labels = np.zeros((5, 20))
labels[0, 0] = 1 # First query, first passage is relevant
labels[1, 5] = 1 # Second query, sixth passage is relevant
metrics = compute_retrieval_metrics(query_embs, passage_embs, labels)
assert 'recall@1' in metrics
assert 'precision@1' in metrics
print("✓ Retrieval metrics working")
# Test reranker metrics
scores = np.random.rand(10)
labels = np.random.randint(0, 2, 10)
reranker_metrics = compute_reranker_metrics(scores, labels)
assert 'accuracy' in reranker_metrics or 'mrr' in reranker_metrics
print("✓ Reranker metrics working")
return True
except Exception as e:
print(f"✗ Evaluation metrics error: {e}")
traceback.print_exc()
return False
def test_data_formats(self):
"""Test multiple data format support"""
print("Testing data format support...")
try:
from data.preprocessing import DataPreprocessor
preprocessor = DataPreprocessor()
# Test JSONL format
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
f.write('{"query": "test", "pos": ["pos1"], "neg": ["neg1"]}\n')
f.write('{"query": "test2", "pos": ["pos2"], "neg": ["neg2"]}\n')
temp_file = f.name
data = preprocessor._load_jsonl(temp_file)
assert len(data) == 2
os.unlink(temp_file)
print("✓ JSONL format working")
# Test CSV format
import pandas as pd
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
df = pd.DataFrame({
'query': ['test1', 'test2'],
'positive': ['pos1', 'pos2'],
'negative': ['neg1', 'neg2']
})
df.to_csv(f.name, index=False)
temp_file = f.name
data = preprocessor._load_csv(temp_file)
assert len(data) >= 1 # At least one valid record
os.unlink(temp_file)
print("✓ CSV format working")
return True
except Exception as e:
print(f"✗ Data format support error: {e}")
traceback.print_exc()
return False
def test_device_compatibility(self):
"""Test device compatibility (CUDA/Ascend/CPU)"""
print("Testing device compatibility...")
try:
from config.base_config import BaseConfig
# Test device detection
config = BaseConfig()
print(f"✓ Device detected: {config.device}")
print(f"✓ Number of devices: {config.n_gpu}")
# Test tensor operations
if config.device == "cuda":
x = torch.randn(10, 10).cuda()
y = torch.matmul(x, x.t())
assert y.device.type == "cuda"
print("✓ CUDA tensor operations working")
elif config.device == "npu":
print("✓ Ascend NPU detected")
# Note: Would need actual NPU to test operations
else:
x = torch.randn(10, 10)
y = torch.matmul(x, x.t())
print("✓ CPU tensor operations working")
return True
except Exception as e:
print(f"✗ Device compatibility error: {e}")
traceback.print_exc()
return False
def test_checkpoint_functionality(self):
"""Test checkpointing functionality"""
print("Testing checkpoint functionality...")
try:
from utils.checkpoint import CheckpointManager
with tempfile.TemporaryDirectory() as temp_dir:
manager = CheckpointManager(temp_dir)
# Test saving
state_dict = {
'model_state_dict': torch.randn(10, 10),
'optimizer_state_dict': {'lr': 0.001},
'step': 100
}
checkpoint_path = manager.save(state_dict, step=100)
assert os.path.exists(checkpoint_path)
print("✓ Checkpoint saving working")
# Test loading
loaded_state = manager.load_checkpoint(checkpoint_path)
assert 'model_state_dict' in loaded_state
print("✓ Checkpoint loading working")
return True
except Exception as e:
print(f"✗ Checkpoint functionality error: {e}")
traceback.print_exc()
return False
def test_logging_system(self):
"""Test logging system"""
print("Testing logging system...")
try:
from utils.logging import setup_logging, get_logger, MetricsLogger
with tempfile.TemporaryDirectory() as temp_dir:
# Test setup
setup_logging(temp_dir, log_to_file=True)
logger = get_logger("test")
logger.info("Test log message")
print("✓ Basic logging working")
# Test metrics logger
metrics_logger = MetricsLogger(temp_dir)
metrics_logger.log(step=1, metrics={'loss': 0.5}, phase='train')
print("✓ Metrics logging working")
return True
except Exception as e:
print(f"✗ Logging system error: {e}")
traceback.print_exc()
return False
def test_training_script_config_integration(self):
"""Test training script configuration integration"""
print("Testing training script configuration integration...")
try:
# Import the main training script functions
from scripts.train_m3 import create_config_from_args
# Create mock arguments
class MockArgs:
def __init__(self):
self.config_path = None
self.model_name_or_path = None
self.cache_dir = None
self.train_data = "dummy"
self.eval_data = None
self.max_seq_length = None
self.query_max_length = None
self.passage_max_length = None
self.max_length = 8192
self.output_dir = "/tmp/test"
self.num_train_epochs = None
self.per_device_train_batch_size = None
self.per_device_eval_batch_size = None
self.gradient_accumulation_steps = None
self.learning_rate = None
self.warmup_ratio = None
self.weight_decay = None
self.train_group_size = None
self.temperature = None
self.use_hard_negatives = False
self.num_hard_negatives = None
self.use_self_distill = False
self.use_query_instruction = False
self.query_instruction = None
self.use_dense = True
self.use_sparse = False
self.use_colbert = False
self.unified_finetuning = False
self.augment_queries = False
self.augmentation_methods = None
self.create_hard_negatives = False
self.fp16 = False
self.bf16 = False
self.gradient_checkpointing = False
self.dataloader_num_workers = None
self.save_steps = None
self.eval_steps = None
self.logging_steps = None
self.save_total_limit = None
self.resume_from_checkpoint = None
self.local_rank = -1
self.seed = None
self.overwrite_output_dir = True
args = MockArgs()
config = create_config_from_args(args)
assert config.learning_rate > 0
assert config.per_device_train_batch_size > 0
print("✓ Training script config integration working")
return True
except Exception as e:
print(f"✗ Training script config integration error: {e}")
traceback.print_exc()
return False
def run_all_tests(self):
"""Run all tests and generate report"""
print("\n" + "="*80)
print("BGE FINE-TUNING COMPREHENSIVE TEST SUITE")
print("="*80)
tests = [
("Core Dependencies Import", self.test_imports),
("Project Components Import", self.test_project_imports),
("Configuration System", self.test_configuration_system),
("Configuration Integration", self.test_configuration_integration),
("Data Processing", self.test_data_processing),
("Data Augmentation", self.test_data_augmentation),
("Model Functionality", self.test_model_functionality),
("Hard Negative Mining", self.test_hard_negative_mining),
("Evaluation Metrics", self.test_evaluation_metrics),
("Data Format Support", self.test_data_formats),
("Device Compatibility", self.test_device_compatibility),
("Checkpoint Functionality", self.test_checkpoint_functionality),
("Logging System", self.test_logging_system),
("Training Script Integration", self.test_training_script_config_integration),
]
passed = 0
total = len(tests)
for test_name, test_func in tests:
if self.run_test(test_name, test_func):
passed += 1
# Generate report
print("\n" + "="*80)
print("TEST RESULTS SUMMARY")
print("="*80)
print(f"Tests passed: {passed}/{total}")
print(f"Success rate: {passed/total*100:.1f}%")
if passed == total:
print("\n🎉 ALL TESTS PASSED! Project is ready for use.")
print("\nNext steps:")
print("1. Configure your API keys in config.toml if needed")
print("2. Prepare your training data in JSONL format")
print("3. Run training with: python scripts/train_m3.py --train_data your_data.jsonl")
print("4. Evaluate with: python scripts/evaluate.py --retriever_model output/model")
else:
print(f"\n{total - passed} tests failed. Please review the errors above.")
print("\nFailed tests:")
for test_name in self.failed_tests:
print(f" - {test_name}")
print("\nDetailed test results:")
for test_name, status, error in self.test_results:
status_emoji = "" if status == "PASSED" else ""
print(f" {status_emoji} {test_name}: {status}")
if error:
print(f" Error: {error}")
return passed == total
def main():
"""Run comprehensive test suite"""
test_suite = ComprehensiveTestSuite()
success = test_suite.run_all_tests()
return 0 if success else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,299 @@
#!/usr/bin/env python
"""Unified test script to validate both BGE-M3 and BGE-Reranker training pipelines"""
import torch
import json
import tempfile
import sys
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from config.m3_config import BGEM3Config
from config.reranker_config import BGERerankerConfig
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset, collate_embedding_batch, collate_reranker_batch
from training.m3_trainer import BGEM3Trainer
from training.reranker_trainer import BGERerankerTrainer
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
def create_embedding_test_data():
"""Create test data for embedding model (triplets format)"""
data = [
{
"query": "What is machine learning?",
"pos": ["Machine learning is a subset of artificial intelligence that enables systems to learn from data."],
"neg": ["The weather today is sunny and warm."]
},
{
"query": "How does deep learning work?",
"pos": ["Deep learning uses neural networks with multiple layers to learn patterns."],
"neg": ["Pizza is a popular Italian dish."]
}
]
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
for item in data:
f.write(json.dumps(item) + '\n')
return f.name
def create_reranker_test_data():
"""Create test data for reranker model (pairs format)"""
data = [
{
"query": "What is machine learning?",
"passage": "Machine learning is a subset of artificial intelligence that enables systems to learn from data.",
"label": 1,
"score": 0.95,
"qid": "q1",
"pid": "p1"
},
{
"query": "What is machine learning?",
"passage": "The weather today is sunny and warm.",
"label": 0,
"score": 0.15,
"qid": "q1",
"pid": "p2"
},
{
"query": "How does deep learning work?",
"passage": "Deep learning uses neural networks with multiple layers to learn patterns.",
"label": 1,
"score": 0.92,
"qid": "q2",
"pid": "p3"
},
{
"query": "How does deep learning work?",
"passage": "Pizza is a popular Italian dish.",
"label": 0,
"score": 0.12,
"qid": "q2",
"pid": "p4"
}
]
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
for item in data:
f.write(json.dumps(item) + '\n')
return f.name
def test_bge_m3():
"""Test BGE-M3 embedding model"""
print("\n🎯 Testing BGE-M3 Embedding Model...")
# Create test data
train_file = create_embedding_test_data()
print(f"✅ Created M3 test data: {train_file}")
# Initialize config
config = BGEM3Config(
model_name_or_path="./models/bge-m3",
train_data_path=train_file,
output_dir="output/test-m3",
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=1,
learning_rate=1e-5,
logging_steps=1,
save_steps=10,
fp16=False,
use_amp=False,
dataloader_num_workers=0,
overwrite_output_dir=True,
use_self_distill=False, # Disable for testing
use_hard_negatives=False # Disable for testing
)
try:
# Load model and tokenizer
print("📦 Loading M3 model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
model = BGEM3Model.from_pretrained(config.model_name_or_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"🖥️ Using device: {device}")
# Create dataset
dataset = BGEM3Dataset(
data_path=train_file,
tokenizer=tokenizer,
query_max_length=config.query_max_length,
passage_max_length=config.passage_max_length,
is_train=True
)
print(f"✅ M3 Dataset size: {len(dataset)}")
# Create dataloader
dataloader = DataLoader(
dataset,
batch_size=config.per_device_train_batch_size,
shuffle=True,
collate_fn=collate_embedding_batch
)
# Initialize trainer
trainer = BGEM3Trainer(
config=config,
model=model,
tokenizer=tokenizer,
device=device
)
# Test one batch
for batch in dataloader:
print(f"M3 Batch keys: {list(batch.keys())}")
print(f"Query shape: {batch['query_input_ids'].shape}")
print(f"Passage shape: {batch['passage_input_ids'].shape}")
print(f"Labels shape: {batch['labels'].shape}")
# Move batch to device
batch = trainer._prepare_batch(batch)
# Compute loss
with torch.no_grad():
loss = trainer.compute_loss(batch)
print(f"✅ M3 Loss computed successfully: {loss.item():.4f}")
break
print("✅ BGE-M3 test passed!")
return True
except Exception as e:
print(f"❌ BGE-M3 test failed: {e}")
import traceback
traceback.print_exc()
return False
finally:
# Clean up
Path(train_file).unlink()
def test_bge_reranker():
"""Test BGE-Reranker model"""
print("\n🎯 Testing BGE-Reranker Model...")
# Create test data
train_file = create_reranker_test_data()
print(f"✅ Created Reranker test data: {train_file}")
# Initialize config
config = BGERerankerConfig(
model_name_or_path="./models/bge-reranker-base",
train_data_path=train_file,
output_dir="output/test-reranker",
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=1,
learning_rate=1e-5,
logging_steps=1,
save_steps=10,
fp16=False,
use_amp=False,
dataloader_num_workers=0,
overwrite_output_dir=True,
loss_type="binary_cross_entropy"
)
try:
# Load model and tokenizer
print("📦 Loading Reranker model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
model = BGERerankerModel(
model_name_or_path=config.model_name_or_path,
num_labels=1,
use_fp16=config.fp16,
gradient_checkpointing=config.gradient_checkpointing
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"🖥️ Using device: {device}")
# Create dataset
dataset = BGERerankerDataset(
data_path=train_file,
tokenizer=tokenizer,
query_max_length=config.query_max_length,
passage_max_length=config.passage_max_length,
is_train=True
)
print(f"✅ Reranker Dataset size: {len(dataset)}")
# Create dataloader
dataloader = DataLoader(
dataset,
batch_size=config.per_device_train_batch_size,
shuffle=True,
collate_fn=collate_reranker_batch
)
# Initialize trainer
trainer = BGERerankerTrainer(
config=config,
model=model,
tokenizer=tokenizer,
device=device
)
# Test one batch
for batch in dataloader:
print(f"Reranker Batch keys: {list(batch.keys())}")
print(f"Input IDs shape: {batch['input_ids'].shape}")
print(f"Attention mask shape: {batch['attention_mask'].shape}")
print(f"Labels shape: {batch['labels'].shape}")
print(f"Labels values: {batch['labels']}")
# Move batch to device
batch = trainer._prepare_batch(batch)
# Compute loss
with torch.no_grad():
loss = trainer.compute_loss(batch)
print(f"✅ Reranker Loss computed successfully: {loss.item():.4f}")
break
print("✅ BGE-Reranker test passed!")
return True
except Exception as e:
print(f"❌ BGE-Reranker test failed: {e}")
import traceback
traceback.print_exc()
return False
finally:
# Clean up
Path(train_file).unlink()
def main():
"""Run all tests"""
print("🚀 Starting comprehensive BGE training pipeline tests...")
m3_success = test_bge_m3()
reranker_success = test_bge_reranker()
print("\n📊 Test Results Summary:")
print(f" BGE-M3: {'✅ PASS' if m3_success else '❌ FAIL'}")
print(f" BGE-Reranker: {'✅ PASS' if reranker_success else '❌ FAIL'}")
if m3_success and reranker_success:
print("\n🎉 All tests passed! Both models are working correctly.")
return 0
else:
print("\n⚠️ Some tests failed. Please check the error messages above.")
return 1
if __name__ == "__main__":
sys.exit(main())