init commit

tests/test_data_edge_cases.py (Normal file, 99 lines)
@@ -0,0 +1,99 @@
import os
import sys
import tempfile
import pytest

# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.preprocessing import DataPreprocessor
from data.augmentation import QueryAugmenter


@pytest.fixture
def tmp_jsonl_file():
    # Create a malformed JSONL file
    fd, path = tempfile.mkstemp(suffix='.jsonl')
    with os.fdopen(fd, 'w', encoding='utf-8') as f:
        f.write('{"query": "valid", "pos": ["a"], "neg": ["b"]}\n')
        f.write('{invalid json}\n')
        f.write('')
    yield path
    os.remove(path)


@pytest.fixture
def tmp_csv_file():
    fd, path = tempfile.mkstemp(suffix='.csv')
    with os.fdopen(fd, 'w', encoding='utf-8') as f:
        f.write('query,pos,neg\n')
        f.write('valid,"a","b"\n')
        f.write('malformed line\n')
    yield path
    os.remove(path)


@pytest.fixture
def tmp_tsv_file():
    fd, path = tempfile.mkstemp(suffix='.tsv')
    with os.fdopen(fd, 'w', encoding='utf-8') as f:
        f.write('query\tpos\tneg\n')
        f.write('valid\ta\tb\n')
        f.write('malformed line\n')
    yield path
    os.remove(path)


def test_jsonl_malformed_lines(tmp_jsonl_file):
    preprocessor = DataPreprocessor(tokenizer=None)
    # Should skip malformed line and not raise
    data = preprocessor._load_jsonl(tmp_jsonl_file)
    assert isinstance(data, list)
    assert len(data) == 1
    assert data[0]['query'] == 'valid'


def test_csv_malformed_lines(tmp_csv_file):
    preprocessor = DataPreprocessor(tokenizer=None)
    # Should skip malformed line and not raise
    data = preprocessor._load_csv(tmp_csv_file)
    assert isinstance(data, list)
    assert len(data) == 1
    assert data[0]['query'] == 'valid'


def test_tsv_malformed_lines(tmp_tsv_file):
    preprocessor = DataPreprocessor(tokenizer=None)
    # Should skip malformed line and not raise
    data = preprocessor._load_tsv(tmp_tsv_file)
    assert isinstance(data, list)
    assert len(data) == 1
    assert data[0]['query'] == 'valid'
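
# The three loader tests above pin down one contract: blank or malformed
# rows are skipped rather than raised on. A minimal reference version of
# that contract for the JSONL case might look like the sketch below
# (hypothetical helper, for illustration only; the real loader lives in
# data/preprocessing.py).
import json

def _reference_load_jsonl(path):
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                continue  # skip malformed lines instead of raising
    return records
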
def test_augmentation_methods_config():
    # Should use config-driven augmentation methods
    augmenter = QueryAugmenter(augmentation_methods=None, config_path='config.toml')
    assert isinstance(augmenter.augmentation_methods, list)
    assert 'paraphrase' in augmenter.augmentation_methods


def test_config_list_parsing():
    """Test that the config loader parses comma/semicolon-separated lists correctly."""
    from utils.config_loader import ConfigLoader
    loader = ConfigLoader()
    loader._config = {'section': {'list_param': 'a, b; c'}}
    value = loader.get('section', 'list_param', default=[])
    assert isinstance(value, list)
    assert set(value) == {'a', 'b', 'c'}


def test_config_param_source_logging(caplog):
    """Test that the config loader logs the parameter source correctly."""
    from utils.config_loader import ConfigLoader
    loader = ConfigLoader()
    loader._config = {'section': {'param': 42}}
    with caplog.at_level('INFO'):
        value = loader.get('section', 'param', default=0)
        assert 'source: config.toml' in caplog.text
        value = loader.get('section', 'missing', default=99)
        assert 'source: default' in caplog.text


def test_config_fallback_default():
    """Test that the config loader falls back to the default if a key is missing."""
    from utils.config_loader import ConfigLoader
    loader = ConfigLoader()
    loader._config = {'section': {}}
    value = loader.get('section', 'not_found', default='fallback')
    assert value == 'fallback'
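
The list-parsing tests above assume that ConfigLoader.get normalizes
comma- and semicolon-separated string values into lists. A minimal sketch
of that behaviour (assumed; the real logic lives in utils/config_loader.py,
and parse_list_value is a hypothetical name):

import re

def parse_list_value(raw: str) -> list:
    # Split on ',' or ';', trim whitespace, and drop empty fragments.
    return [part.strip() for part in re.split(r'[,;]', raw) if part.strip()]

assert parse_list_value('a, b; c') == ['a', 'b', 'c']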

tests/test_full_training.py (Normal file, 121 lines)
@@ -0,0 +1,121 @@
#!/usr/bin/env python
"""Test script to validate the full training pipeline with learning rate scheduling"""

import torch
import json
import tempfile
import sys
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from config.m3_config import BGEM3Config
from models.bge_m3 import BGEM3Model
from data.dataset import BGEM3Dataset, collate_embedding_batch
from training.m3_trainer import BGEM3Trainer
from transformers import AutoTokenizer
from torch.utils.data import DataLoader


def create_test_data():
    """Create minimal test data"""
    data = [
        {
            "query": "What is machine learning?",
            "pos": ["Machine learning is a subset of AI"],
            "neg": ["The weather is sunny today"]
        },
        {
            "query": "How does Python work?",
            "pos": ["Python is a programming language"],
            "neg": ["Cats are cute animals"]
        },
        {
            "query": "What is deep learning?",
            "pos": ["Deep learning uses neural networks"],
            "neg": ["Pizza is delicious food"]
        }
    ]

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        for item in data:
            f.write(json.dumps(item) + '\n')
    return f.name


def main():
    print("🧪 Testing BGE-M3 training with learning rate scheduling...")

    # Create test data
    train_file = create_test_data()
    print(f"✅ Created test data: {train_file}")

    # Initialize config with small-dataset settings
    config = BGEM3Config(
        model_name_or_path="./models/bge-m3",
        train_data_path=train_file,
        output_dir="output/test-lr-scheduler",
        num_train_epochs=2,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,  # Small for testing
        learning_rate=1e-4,  # Higher for testing
        logging_steps=1,
        save_steps=10,
        fp16=False,
        use_amp=False,
        dataloader_num_workers=0,
        overwrite_output_dir=True,
        use_self_distill=False,
        use_hard_negatives=False,
        temperature=0.1  # Fixed temperature
    )

    try:
        # Load model and tokenizer
        print("📦 Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGEM3Model.from_pretrained(config.model_name_or_path)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGEM3Dataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True
        )
        print(f"✅ Dataset size: {len(dataset)}")

        # Initialize trainer
        trainer = BGEM3Trainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device
        )

        # Test training for a few steps
        print("\n🏋️ Starting training test...")
        trainer.train(dataset)

        print("✅ Training completed successfully!")
        return True

    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Clean up
        Path(train_file).unlink()


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
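
For reference, the warmup-plus-decay behaviour this script exercises can be
reproduced in isolation. A minimal sketch, assuming the trainer uses the
linear schedule with warmup from transformers (the actual scheduler setup
lives in training/m3_trainer.py):

import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(8, 8)
optimizer = AdamW(model.parameters(), lr=1e-4)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=2,    # LR ramps up linearly over the first 2 steps
    num_training_steps=10  # then decays linearly toward 0 by step 10
)

for step in range(10):
    optimizer.step()
    scheduler.step()
    print(f"step {step}: lr = {scheduler.get_last_lr()[0]:.2e}")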

tests/test_installation.py (Normal file, 712 lines)
@@ -0,0 +1,712 @@
"""
Comprehensive testing script for BGE fine-tuning project
Tests all components, configurations, and integrations
"""
import os
import sys
import traceback
import tempfile
import json
from pathlib import Path
import torch
import numpy as np

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class ComprehensiveTestSuite:
    """Comprehensive test suite for BGE fine-tuning project"""

    def __init__(self):
        self.test_results = []
        self.failed_tests = []

    def run_test(self, test_name, test_func, *args, **kwargs):
        """Run a single test and record results"""
        print(f"\n{'='*60}")
        print(f"Testing: {test_name}")
        print('='*60)

        try:
            result = test_func(*args, **kwargs)
            if result:
                print(f"✅ {test_name} PASSED")
                self.test_results.append((test_name, "PASSED", None))
                return True
            else:
                print(f"❌ {test_name} FAILED")
                self.test_results.append((test_name, "FAILED", "Test returned False"))
                self.failed_tests.append(test_name)
                return False
        except Exception as e:
            print(f"❌ {test_name} FAILED with exception: {e}")
            traceback.print_exc()
            self.test_results.append((test_name, "FAILED", str(e)))
            self.failed_tests.append(test_name)
            return False
    def test_imports(self):
        """Test all critical imports"""
        print("Testing critical imports...")

        # Core dependencies
        try:
            import torch
            print(f"✓ PyTorch {torch.__version__}")
            print(f"  CUDA available: {torch.cuda.is_available()}")
            if torch.cuda.is_available():
                print(f"  CUDA version: {torch.version.cuda}")
                print(f"  GPU count: {torch.cuda.device_count()}")
        except ImportError as e:
            print(f"✗ PyTorch: {e}")
            return False

        # Test Ascend NPU support
        try:
            import torch_npu
            if torch.npu.is_available():  # type: ignore[attr-defined]
                print(f"✓ Ascend NPU available: {torch.npu.device_count()} devices")  # type: ignore[attr-defined]
            else:
                print("! Ascend NPU library installed but no devices available")
        except ImportError:
            print("- Ascend NPU library not installed (optional)")

        # Transformers and related
        try:
            import transformers
            print(f"✓ Transformers {transformers.__version__}")
        except ImportError as e:
            print(f"✗ Transformers: {e}")
            return False

        # Scientific computing
        try:
            import numpy as np
            import pandas as pd
            import scipy
            import sklearn
            print("✓ Scientific computing libraries")
        except ImportError as e:
            print(f"✗ Scientific libraries: {e}")
            return False

        # BGE specific
        try:
            import faiss
            print(f"✓ FAISS {faiss.__version__}")
        except ImportError as e:
            print(f"✗ FAISS: {e}")
            return False

        try:
            from rank_bm25 import BM25Okapi
            print("✓ BM25")
        except ImportError as e:
            print(f"✗ BM25: {e}")
            return False

        try:
            import jieba
            print("✓ Jieba (Chinese tokenization)")
        except ImportError as e:
            print(f"✗ Jieba: {e}")
            return False

        try:
            import toml
            print("✓ TOML")
        except ImportError as e:
            print(f"✗ TOML: {e}")
            return False

        return True
    def test_project_imports(self):
        """Test project-specific imports"""
        print("Testing project imports...")

        # Configuration
        try:
            from utils.config_loader import get_config, load_config_for_training
            from config.base_config import BaseConfig
            from config.m3_config import BGEM3Config
            from config.reranker_config import BGERerankerConfig
            print("✓ Configuration classes")
        except ImportError as e:
            print(f"✗ Configuration imports: {e}")
            traceback.print_exc()
            return False

        # Models
        try:
            from models.bge_m3 import BGEM3Model
            from models.bge_reranker import BGERerankerModel
            from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss
            print("✓ Model classes")
        except ImportError as e:
            print(f"✗ Model imports: {e}")
            traceback.print_exc()
            return False

        # Data processing
        try:
            from data.dataset import BGEM3Dataset, BGERerankerDataset
            from data.preprocessing import DataPreprocessor
            from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter
            from data.hard_negative_mining import ANCEHardNegativeMiner
            print("✓ Data processing classes")
        except ImportError as e:
            print(f"✗ Data imports: {e}")
            traceback.print_exc()
            return False

        # Training
        try:
            from training.trainer import BaseTrainer
            from training.m3_trainer import BGEM3Trainer
            from training.reranker_trainer import BGERerankerTrainer
            from training.joint_trainer import JointTrainer
            print("✓ Training classes")
        except ImportError as e:
            print(f"✗ Training imports: {e}")
            traceback.print_exc()
            return False

        # Evaluation
        try:
            from evaluation.evaluator import RetrievalEvaluator
            from evaluation.metrics import compute_retrieval_metrics
            print("✓ Evaluation classes")
        except ImportError as e:
            print(f"✗ Evaluation imports: {e}")
            traceback.print_exc()
            return False

        # Utilities
        try:
            from utils.logging import setup_logging
            from utils.checkpoint import CheckpointManager
            from utils.distributed import setup_distributed, is_main_process
            print("✓ Utility classes")
        except ImportError as e:
            print(f"✗ Utility imports: {e}")
            traceback.print_exc()
            return False

        return True
    def test_configuration_system(self):
        """Test the centralized configuration system"""
        print("Testing configuration system...")

        try:
            from utils.config_loader import get_config, load_config_for_training

            # Test singleton behavior
            config1 = get_config()
            config2 = get_config()
            assert config1 is config2, "Config should be singleton"
            print("✓ Singleton pattern working")

            # Test configuration loading
            test_value = config1.get('training.default_batch_size', None)  # type: ignore[arg-type]
            print(f"✓ Config loading: batch_size = {test_value}")

            # Test section loading
            training_section = config1.get_section('training')
            assert isinstance(training_section, dict), "Should return dict"
            print("✓ Section loading working")

            # Test training configuration loading
            training_config = load_config_for_training()
            assert 'training' in training_config
            assert 'model_paths' in training_config
            print("✓ Training config loading working")

            return True

        except Exception as e:
            print(f"✗ Configuration system error: {e}")
            traceback.print_exc()
            return False

    def test_configuration_integration(self):
        """Test configuration integration in classes"""
        print("Testing configuration integration...")

        try:
            from config.base_config import BaseConfig
            from config.m3_config import BGEM3Config

            # Test BaseConfig with config.toml integration
            base_config = BaseConfig()
            print(f"✓ BaseConfig created with batch_size: {base_config.per_device_train_batch_size}")

            # Test M3Config
            m3_config = BGEM3Config()
            print(f"✓ M3Config created with learning_rate: {m3_config.learning_rate}")

            # Test device detection
            print(f"✓ Device detected: {base_config.device}")

            return True

        except Exception as e:
            print(f"✗ Configuration integration error: {e}")
            traceback.print_exc()
            return False
    def test_data_processing(self):
        """Test data processing functionality"""
        print("Testing data processing...")

        try:
            from data.preprocessing import DataPreprocessor

            # Create test data
            with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
                test_data = [
                    {"query": "test query 1", "pos": ["positive passage 1"], "neg": ["negative passage 1"]},
                    {"query": "test query 2", "pos": ["positive passage 2"], "neg": ["negative passage 2"]}
                ]
                for item in test_data:
                    f.write(json.dumps(item) + '\n')
                temp_file = f.name

            # Test preprocessing
            preprocessor = DataPreprocessor()
            data = preprocessor._load_jsonl(temp_file)
            assert len(data) == 2
            print("✓ JSONL loading working")

            # Test validation
            cleaned = preprocessor._validate_and_clean(data)
            assert len(cleaned) == 2
            print("✓ Data validation working")

            # Cleanup
            os.unlink(temp_file)

            return True

        except Exception as e:
            print(f"✗ Data processing error: {e}")
            traceback.print_exc()
            return False

    def test_data_augmentation(self):
        """Test data augmentation functionality"""
        print("Testing data augmentation...")

        try:
            from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter

            # Test QueryAugmenter
            augmenter = QueryAugmenter(augmentation_methods=['paraphrase', 'synonym'])
            test_data = [
                {"query": "什么是机器学习?", "pos": ["机器学习是AI的分支"], "neg": ["深度学习不同"]}
            ]

            augmented = augmenter.augment_queries(test_data, num_augmentations=1, keep_original=True)
            assert len(augmented) >= len(test_data)
            print("✓ Query augmentation working")

            # Test PassageAugmenter
            passage_aug = PassageAugmenter()
            passage_augmented = passage_aug.augment_passages(test_data, augmentation_prob=1.0)
            print("✓ Passage augmentation working")

            # Test HardNegativeAugmenter
            hn_aug = HardNegativeAugmenter()
            hn_augmented = hn_aug.create_hard_negatives(test_data, num_hard_negatives=2)
            print("✓ Hard negative augmentation working")

            return True

        except Exception as e:
            print(f"✗ Data augmentation error: {e}")
            traceback.print_exc()
            return False
    def test_model_functionality(self):
        """Test basic model functionality"""
        print("Testing model functionality...")

        try:
            from models.bge_m3 import BGEM3Model
            from models.bge_reranker import BGERerankerModel
            from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss

            # Test loss functions
            infonce = InfoNCELoss()
            query_emb = torch.randn(2, 128)
            pos_emb = torch.randn(2, 128)
            neg_emb = torch.randn(2, 5, 128)

            loss = infonce(query_emb, pos_emb, neg_emb)
            assert loss.item() > 0
            print("✓ InfoNCE loss working")

            # Test listwise loss
            listwise = ListwiseCrossEntropyLoss()
            scores = torch.randn(2, 5)
            labels = torch.zeros(2, 5)
            labels[:, 0] = 1  # First is positive

            loss = listwise(scores, labels, [5, 5])
            assert loss.item() > 0
            print("✓ Listwise loss working")

            return True

        except Exception as e:
            print(f"✗ Model functionality error: {e}")
            traceback.print_exc()
            return False
    def test_hard_negative_mining(self):
        """Test hard negative mining"""
        print("Testing hard negative mining...")

        try:
            from data.hard_negative_mining import ANCEHardNegativeMiner

            # Create miner
            miner = ANCEHardNegativeMiner(
                embedding_dim=128,
                num_hard_negatives=5,
                use_gpu=False  # Use CPU for testing
            )

            # Add some dummy embeddings
            embeddings = np.random.rand(10, 128).astype(np.float32)
            passages = [f"passage_{i}" for i in range(10)]

            miner._add_to_index(embeddings, passages)
            print("✓ Index creation working")

            # Test mining
            queries = ["query1", "query2"]
            query_embeddings = np.random.rand(2, 128).astype(np.float32)
            positive_passages = [["passage_0"], ["passage_5"]]

            hard_negs = miner.mine_hard_negatives(
                queries, query_embeddings, positive_passages
            )

            assert len(hard_negs) == 2
            print("✓ Hard negative mining working")

            return True

        except Exception as e:
            print(f"✗ Hard negative mining error: {e}")
            traceback.print_exc()
            return False

    def test_evaluation_metrics(self):
        """Test evaluation metrics"""
        print("Testing evaluation metrics...")

        try:
            from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics

            # Create dummy data
            query_embs = np.random.rand(5, 128)
            passage_embs = np.random.rand(20, 128)
            labels = np.zeros((5, 20))
            labels[0, 0] = 1  # First query: first passage is relevant
            labels[1, 5] = 1  # Second query: sixth passage is relevant

            metrics = compute_retrieval_metrics(query_embs, passage_embs, labels)
            assert 'recall@1' in metrics
            assert 'precision@1' in metrics
            print("✓ Retrieval metrics working")

            # Test reranker metrics
            scores = np.random.rand(10)
            labels = np.random.randint(0, 2, 10)
            reranker_metrics = compute_reranker_metrics(scores, labels)
            assert 'accuracy' in reranker_metrics or 'mrr' in reranker_metrics
            print("✓ Reranker metrics working")

            return True

        except Exception as e:
            print(f"✗ Evaluation metrics error: {e}")
            traceback.print_exc()
            return False
    def test_data_formats(self):
        """Test multiple data format support"""
        print("Testing data format support...")

        try:
            from data.preprocessing import DataPreprocessor

            preprocessor = DataPreprocessor()

            # Test JSONL format
            with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
                f.write('{"query": "test", "pos": ["pos1"], "neg": ["neg1"]}\n')
                f.write('{"query": "test2", "pos": ["pos2"], "neg": ["neg2"]}\n')
                temp_file = f.name

            data = preprocessor._load_jsonl(temp_file)
            assert len(data) == 2
            os.unlink(temp_file)
            print("✓ JSONL format working")

            # Test CSV format
            import pandas as pd
            with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
                df = pd.DataFrame({
                    'query': ['test1', 'test2'],
                    'positive': ['pos1', 'pos2'],
                    'negative': ['neg1', 'neg2']
                })
                df.to_csv(f.name, index=False)
                temp_file = f.name

            data = preprocessor._load_csv(temp_file)
            assert len(data) >= 1  # At least one valid record
            os.unlink(temp_file)
            print("✓ CSV format working")

            return True

        except Exception as e:
            print(f"✗ Data format support error: {e}")
            traceback.print_exc()
            return False

    def test_device_compatibility(self):
        """Test device compatibility (CUDA/Ascend/CPU)"""
        print("Testing device compatibility...")

        try:
            from config.base_config import BaseConfig

            # Test device detection
            config = BaseConfig()
            print(f"✓ Device detected: {config.device}")
            print(f"✓ Number of devices: {config.n_gpu}")

            # Test tensor operations
            if config.device == "cuda":
                x = torch.randn(10, 10).cuda()
                y = torch.matmul(x, x.t())
                assert y.device.type == "cuda"
                print("✓ CUDA tensor operations working")
            elif config.device == "npu":
                print("✓ Ascend NPU detected")
                # Note: would need an actual NPU to test operations
            else:
                x = torch.randn(10, 10)
                y = torch.matmul(x, x.t())
                print("✓ CPU tensor operations working")

            return True

        except Exception as e:
            print(f"✗ Device compatibility error: {e}")
            traceback.print_exc()
            return False
    def test_checkpoint_functionality(self):
        """Test checkpointing functionality"""
        print("Testing checkpoint functionality...")

        try:
            from utils.checkpoint import CheckpointManager

            with tempfile.TemporaryDirectory() as temp_dir:
                manager = CheckpointManager(temp_dir)

                # Test saving
                state_dict = {
                    'model_state_dict': torch.randn(10, 10),
                    'optimizer_state_dict': {'lr': 0.001},
                    'step': 100
                }

                checkpoint_path = manager.save(state_dict, step=100)
                assert os.path.exists(checkpoint_path)
                print("✓ Checkpoint saving working")

                # Test loading
                loaded_state = manager.load_checkpoint(checkpoint_path)
                assert 'model_state_dict' in loaded_state
                print("✓ Checkpoint loading working")

            return True

        except Exception as e:
            print(f"✗ Checkpoint functionality error: {e}")
            traceback.print_exc()
            return False

    def test_logging_system(self):
        """Test logging system"""
        print("Testing logging system...")

        try:
            from utils.logging import setup_logging, get_logger, MetricsLogger

            with tempfile.TemporaryDirectory() as temp_dir:
                # Test setup
                setup_logging(temp_dir, log_to_file=True)
                logger = get_logger("test")
                logger.info("Test log message")
                print("✓ Basic logging working")

                # Test metrics logger
                metrics_logger = MetricsLogger(temp_dir)
                metrics_logger.log(step=1, metrics={'loss': 0.5}, phase='train')
                print("✓ Metrics logging working")

            return True

        except Exception as e:
            print(f"✗ Logging system error: {e}")
            traceback.print_exc()
            return False
    def test_training_script_config_integration(self):
        """Test training script configuration integration"""
        print("Testing training script configuration integration...")

        try:
            # Import the main training script functions
            from scripts.train_m3 import create_config_from_args

            # Create mock arguments
            class MockArgs:
                def __init__(self):
                    self.config_path = None
                    self.model_name_or_path = None
                    self.cache_dir = None
                    self.train_data = "dummy"
                    self.eval_data = None
                    self.max_seq_length = None
                    self.query_max_length = None
                    self.passage_max_length = None
                    self.max_length = 8192
                    self.output_dir = "/tmp/test"
                    self.num_train_epochs = None
                    self.per_device_train_batch_size = None
                    self.per_device_eval_batch_size = None
                    self.gradient_accumulation_steps = None
                    self.learning_rate = None
                    self.warmup_ratio = None
                    self.weight_decay = None
                    self.train_group_size = None
                    self.temperature = None
                    self.use_hard_negatives = False
                    self.num_hard_negatives = None
                    self.use_self_distill = False
                    self.use_query_instruction = False
                    self.query_instruction = None
                    self.use_dense = True
                    self.use_sparse = False
                    self.use_colbert = False
                    self.unified_finetuning = False
                    self.augment_queries = False
                    self.augmentation_methods = None
                    self.create_hard_negatives = False
                    self.fp16 = False
                    self.bf16 = False
                    self.gradient_checkpointing = False
                    self.dataloader_num_workers = None
                    self.save_steps = None
                    self.eval_steps = None
                    self.logging_steps = None
                    self.save_total_limit = None
                    self.resume_from_checkpoint = None
                    self.local_rank = -1
                    self.seed = None
                    self.overwrite_output_dir = True

            args = MockArgs()
            config = create_config_from_args(args)

            assert config.learning_rate > 0
            assert config.per_device_train_batch_size > 0
            print("✓ Training script config integration working")

            return True

        except Exception as e:
            print(f"✗ Training script config integration error: {e}")
            traceback.print_exc()
            return False
    def run_all_tests(self):
        """Run all tests and generate report"""
        print("\n" + "="*80)
        print("BGE FINE-TUNING COMPREHENSIVE TEST SUITE")
        print("="*80)

        tests = [
            ("Core Dependencies Import", self.test_imports),
            ("Project Components Import", self.test_project_imports),
            ("Configuration System", self.test_configuration_system),
            ("Configuration Integration", self.test_configuration_integration),
            ("Data Processing", self.test_data_processing),
            ("Data Augmentation", self.test_data_augmentation),
            ("Model Functionality", self.test_model_functionality),
            ("Hard Negative Mining", self.test_hard_negative_mining),
            ("Evaluation Metrics", self.test_evaluation_metrics),
            ("Data Format Support", self.test_data_formats),
            ("Device Compatibility", self.test_device_compatibility),
            ("Checkpoint Functionality", self.test_checkpoint_functionality),
            ("Logging System", self.test_logging_system),
            ("Training Script Integration", self.test_training_script_config_integration),
        ]

        passed = 0
        total = len(tests)

        for test_name, test_func in tests:
            if self.run_test(test_name, test_func):
                passed += 1

        # Generate report
        print("\n" + "="*80)
        print("TEST RESULTS SUMMARY")
        print("="*80)
        print(f"Tests passed: {passed}/{total}")
        print(f"Success rate: {passed/total*100:.1f}%")

        if passed == total:
            print("\n🎉 ALL TESTS PASSED! Project is ready for use.")
            print("\nNext steps:")
            print("1. Configure your API keys in config.toml if needed")
            print("2. Prepare your training data in JSONL format")
            print("3. Run training with: python scripts/train_m3.py --train_data your_data.jsonl")
            print("4. Evaluate with: python scripts/evaluate.py --retriever_model output/model")
        else:
            print(f"\n❌ {total - passed} tests failed. Please review the errors above.")
            print("\nFailed tests:")
            for test_name in self.failed_tests:
                print(f"  - {test_name}")

        print("\nDetailed test results:")
        for test_name, status, error in self.test_results:
            status_emoji = "✅" if status == "PASSED" else "❌"
            print(f"  {status_emoji} {test_name}: {status}")
            if error:
                print(f"    Error: {error}")

        return passed == total


def main():
    """Run comprehensive test suite"""
    test_suite = ComprehensiveTestSuite()
    success = test_suite.run_all_tests()
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())
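
The shapes used in test_model_functionality imply a per-query InfoNCE over
one positive and five negatives. A minimal sketch of such a loss under those
assumptions (illustrative only; the project's actual implementation is
models/losses.py):

import torch
import torch.nn.functional as F

def info_nce(query, pos, neg, temperature=0.1):
    q = F.normalize(query, dim=-1)                  # (B, D)
    p = F.normalize(pos, dim=-1)                    # (B, D)
    n = F.normalize(neg, dim=-1)                    # (B, N, D)
    pos_score = (q * p).sum(-1, keepdim=True)       # (B, 1)
    neg_score = torch.einsum('bd,bnd->bn', q, n)    # (B, N)
    logits = torch.cat([pos_score, neg_score], dim=1) / temperature
    labels = torch.zeros(len(q), dtype=torch.long)  # positive sits at index 0
    return F.cross_entropy(logits, labels)

loss = info_nce(torch.randn(2, 128), torch.randn(2, 128), torch.randn(2, 5, 128))
print(f"InfoNCE: {loss.item():.4f}")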

tests/test_training_pipeline.py (Normal file, 299 lines)
@@ -0,0 +1,299 @@
#!/usr/bin/env python
"""Unified test script to validate both BGE-M3 and BGE-Reranker training pipelines"""

import torch
import json
import tempfile
import sys
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from config.m3_config import BGEM3Config
from config.reranker_config import BGERerankerConfig
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset, collate_embedding_batch, collate_reranker_batch
from training.m3_trainer import BGEM3Trainer
from training.reranker_trainer import BGERerankerTrainer
from transformers import AutoTokenizer
from torch.utils.data import DataLoader


def create_embedding_test_data():
    """Create test data for the embedding model (triplet format)"""
    data = [
        {
            "query": "What is machine learning?",
            "pos": ["Machine learning is a subset of artificial intelligence that enables systems to learn from data."],
            "neg": ["The weather today is sunny and warm."]
        },
        {
            "query": "How does deep learning work?",
            "pos": ["Deep learning uses neural networks with multiple layers to learn patterns."],
            "neg": ["Pizza is a popular Italian dish."]
        }
    ]

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        for item in data:
            f.write(json.dumps(item) + '\n')
    return f.name


def create_reranker_test_data():
    """Create test data for the reranker model (pair format)"""
    data = [
        {
            "query": "What is machine learning?",
            "passage": "Machine learning is a subset of artificial intelligence that enables systems to learn from data.",
            "label": 1,
            "score": 0.95,
            "qid": "q1",
            "pid": "p1"
        },
        {
            "query": "What is machine learning?",
            "passage": "The weather today is sunny and warm.",
            "label": 0,
            "score": 0.15,
            "qid": "q1",
            "pid": "p2"
        },
        {
            "query": "How does deep learning work?",
            "passage": "Deep learning uses neural networks with multiple layers to learn patterns.",
            "label": 1,
            "score": 0.92,
            "qid": "q2",
            "pid": "p3"
        },
        {
            "query": "How does deep learning work?",
            "passage": "Pizza is a popular Italian dish.",
            "label": 0,
            "score": 0.12,
            "qid": "q2",
            "pid": "p4"
        }
    ]

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        for item in data:
            f.write(json.dumps(item) + '\n')
    return f.name
def test_bge_m3():
    """Test the BGE-M3 embedding model"""
    print("\n🎯 Testing BGE-M3 Embedding Model...")

    # Create test data
    train_file = create_embedding_test_data()
    print(f"✅ Created M3 test data: {train_file}")

    # Initialize config
    config = BGEM3Config(
        model_name_or_path="./models/bge-m3",
        train_data_path=train_file,
        output_dir="output/test-m3",
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        logging_steps=1,
        save_steps=10,
        fp16=False,
        use_amp=False,
        dataloader_num_workers=0,
        overwrite_output_dir=True,
        use_self_distill=False,  # Disable for testing
        use_hard_negatives=False  # Disable for testing
    )

    try:
        # Load model and tokenizer
        print("📦 Loading M3 model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGEM3Model.from_pretrained(config.model_name_or_path)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGEM3Dataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True
        )
        print(f"✅ M3 Dataset size: {len(dataset)}")

        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=config.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collate_embedding_batch
        )

        # Initialize trainer
        trainer = BGEM3Trainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device
        )

        # Test one batch
        for batch in dataloader:
            print(f"M3 Batch keys: {list(batch.keys())}")
            print(f"Query shape: {batch['query_input_ids'].shape}")
            print(f"Passage shape: {batch['passage_input_ids'].shape}")
            print(f"Labels shape: {batch['labels'].shape}")

            # Move batch to device
            batch = trainer._prepare_batch(batch)

            # Compute loss
            with torch.no_grad():
                loss = trainer.compute_loss(batch)

            print(f"✅ M3 Loss computed successfully: {loss.item():.4f}")
            break

        print("✅ BGE-M3 test passed!")
        return True

    except Exception as e:
        print(f"❌ BGE-M3 test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Clean up
        Path(train_file).unlink()
def test_bge_reranker():
    """Test the BGE-Reranker model"""
    print("\n🎯 Testing BGE-Reranker Model...")

    # Create test data
    train_file = create_reranker_test_data()
    print(f"✅ Created Reranker test data: {train_file}")

    # Initialize config
    config = BGERerankerConfig(
        model_name_or_path="./models/bge-reranker-base",
        train_data_path=train_file,
        output_dir="output/test-reranker",
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        logging_steps=1,
        save_steps=10,
        fp16=False,
        use_amp=False,
        dataloader_num_workers=0,
        overwrite_output_dir=True,
        loss_type="binary_cross_entropy"
    )

    try:
        # Load model and tokenizer
        print("📦 Loading Reranker model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        model = BGERerankerModel(
            model_name_or_path=config.model_name_or_path,
            num_labels=1,
            use_fp16=config.fp16,
            gradient_checkpointing=config.gradient_checkpointing
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"🖥️ Using device: {device}")

        # Create dataset
        dataset = BGERerankerDataset(
            data_path=train_file,
            tokenizer=tokenizer,
            query_max_length=config.query_max_length,
            passage_max_length=config.passage_max_length,
            is_train=True
        )
        print(f"✅ Reranker Dataset size: {len(dataset)}")

        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=config.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collate_reranker_batch
        )

        # Initialize trainer
        trainer = BGERerankerTrainer(
            config=config,
            model=model,
            tokenizer=tokenizer,
            device=device
        )

        # Test one batch
        for batch in dataloader:
            print(f"Reranker Batch keys: {list(batch.keys())}")
            print(f"Input IDs shape: {batch['input_ids'].shape}")
            print(f"Attention mask shape: {batch['attention_mask'].shape}")
            print(f"Labels shape: {batch['labels'].shape}")
            print(f"Labels values: {batch['labels']}")

            # Move batch to device
            batch = trainer._prepare_batch(batch)

            # Compute loss
            with torch.no_grad():
                loss = trainer.compute_loss(batch)

            print(f"✅ Reranker Loss computed successfully: {loss.item():.4f}")
            break

        print("✅ BGE-Reranker test passed!")
        return True

    except Exception as e:
        print(f"❌ BGE-Reranker test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Clean up
        Path(train_file).unlink()
def main():
    """Run all tests"""
    print("🚀 Starting comprehensive BGE training pipeline tests...")

    m3_success = test_bge_m3()
    reranker_success = test_bge_reranker()

    print("\n📊 Test Results Summary:")
    print(f"  BGE-M3: {'✅ PASS' if m3_success else '❌ FAIL'}")
    print(f"  BGE-Reranker: {'✅ PASS' if reranker_success else '❌ FAIL'}")

    if m3_success and reranker_success:
        print("\n🎉 All tests passed! Both models are working correctly.")
        return 0
    else:
        print("\n⚠️ Some tests failed. Please check the error messages above.")
        return 1


if __name__ == "__main__":
    sys.exit(main())
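
The reranker path above is configured with loss_type="binary_cross_entropy",
i.e. a pointwise relevance objective over (query, passage) pairs. A minimal
sketch of that objective (assumed; the project's actual loss lives with the
reranker trainer and models/losses.py):

import torch
import torch.nn.functional as F

logits = torch.randn(4)                    # one relevance logit per (query, passage) pair
labels = torch.tensor([1., 0., 1., 0.])    # 1 = relevant, 0 = irrelevant, as in the test data
loss = F.binary_cross_entropy_with_logits(logits, labels)
print(f"BCE loss: {loss.item():.4f}")

Both pipeline scripts run standalone (python tests/test_training_pipeline.py)
and exit nonzero on failure.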