713 lines
25 KiB
Python
713 lines
25 KiB
Python
"""
Comprehensive testing script for BGE fine-tuning project
Tests all components, configurations, and integrations
"""
|
|
import os
|
|
import sys
|
|
import traceback
|
|
import tempfile
|
|
import json
|
|
from pathlib import Path
|
|
import torch
|
|
import numpy as np
|
|
|
|
# Make the directory one level above this script's own directory importable,
# so the project packages imported inside the tests can be resolved when the
# script is launched directly.
sys.path.append(
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))
    )
)
|
|
|
|
class ComprehensiveTestSuite:
    """Comprehensive test suite for BGE fine-tuning project.

    Every ``test_*`` method is a self-contained smoke check: it prints its
    progress and returns ``True`` on success or ``False`` on failure.
    ``run_all_tests`` feeds them through ``run_test``, which records the
    outcome of each check.

    Attributes:
        test_results: ``(test_name, status, error)`` tuples in execution
            order; ``status`` is ``"PASSED"`` or ``"FAILED"`` and ``error``
            is ``None`` or a short message.
        failed_tests: Names of the checks that failed, in execution order.
    """

    def __init__(self):
        self.test_results = []
        self.failed_tests = []

    @staticmethod
    def _write_temp_file(suffix, text):
        """Write ``text`` to a new, closed temporary file and return its path.

        The file is created with ``delete=False`` so that it can be reopened
        by path afterwards; the caller is responsible for removing it (see
        ``_remove_file``).
        """
        with tempfile.NamedTemporaryFile(
            mode='w', suffix=suffix, delete=False, encoding='utf-8'
        ) as handle:
            handle.write(text)
            return handle.name

    @staticmethod
    def _remove_file(path):
        """Delete ``path``; a file that is already gone is not an error."""
        try:
            os.unlink(path)
        except FileNotFoundError:
            # Cleanup is best-effort: nothing left to remove.
            pass

    def run_test(self, test_name, test_func, *args, **kwargs):
        """Run a single test and record results.

        Args:
            test_name: Human-readable label used in the printed output and in
                the recorded results.
            test_func: Callable implementing the check; a truthy return value
                counts as a pass.
            *args: Positional arguments forwarded to ``test_func``.
            **kwargs: Keyword arguments forwarded to ``test_func``.

        Returns:
            ``True`` if ``test_func`` returned a truthy value, ``False`` if it
            returned a falsy value or raised an ``Exception``.
        """
        print(f"\n{'='*60}")
        print(f"Testing: {test_name}")
        print('='*60)

        # Only the call itself is guarded, so a failure is always attributed
        # to the test function and never to the bookkeeping below.
        try:
            result = test_func(*args, **kwargs)
        except Exception as e:
            print(f"❌ {test_name} FAILED with exception: {e}")
            traceback.print_exc()
            self.test_results.append((test_name, "FAILED", str(e)))
            self.failed_tests.append(test_name)
            return False

        if result:
            print(f"✅ {test_name} PASSED")
            self.test_results.append((test_name, "PASSED", None))
            return True

        print(f"❌ {test_name} FAILED")
        self.test_results.append((test_name, "FAILED", "Test returned False"))
        self.failed_tests.append(test_name)
        return False

    def test_imports(self):
        """Test all critical imports.

        Imports each third-party dependency in turn and prints its status.
        ``torch_npu`` is optional; a missing required package makes the check
        return ``False`` immediately.

        Returns:
            ``True`` if every required dependency could be imported.
        """
        print("Testing critical imports...")

        # Core dependencies
        try:
            import torch
            print(f"✓ PyTorch {torch.__version__}")
            print(f" CUDA available: {torch.cuda.is_available()}")
            if torch.cuda.is_available():
                print(f" CUDA version: {torch.version.cuda}")
                print(f" GPU count: {torch.cuda.device_count()}")
        except ImportError as e:
            print(f"✗ PyTorch: {e}")
            return False

        # Test Ascend NPU support (optional: absence is reported, not fatal)
        try:
            import torch_npu  # noqa: F401
            if torch.npu.is_available():  # type: ignore[attr-defined]
                print(f"✓ Ascend NPU available: {torch.npu.device_count()} devices")  # type: ignore[attr-defined]
            else:
                print("! Ascend NPU library installed but no devices available")
        except ImportError:
            print("- Ascend NPU library not installed (optional)")

        # Transformers and related
        try:
            import transformers
            print(f"✓ Transformers {transformers.__version__}")
        except ImportError as e:
            print(f"✗ Transformers: {e}")
            return False

        # Scientific computing (imported only to prove they are installed)
        try:
            import numpy as np  # noqa: F401
            import pandas as pd  # noqa: F401
            import scipy  # noqa: F401
            import sklearn  # noqa: F401
            print("✓ Scientific computing libraries")
        except ImportError as e:
            print(f"✗ Scientific libraries: {e}")
            return False

        # BGE specific
        try:
            import faiss
            print(f"✓ FAISS {faiss.__version__}")
        except ImportError as e:
            print(f"✗ FAISS: {e}")
            return False

        try:
            from rank_bm25 import BM25Okapi  # noqa: F401
            print("✓ BM25")
        except ImportError as e:
            print(f"✗ BM25: {e}")
            return False

        try:
            import jieba  # noqa: F401
            print("✓ Jieba (Chinese tokenization)")
        except ImportError as e:
            print(f"✗ Jieba: {e}")
            return False

        try:
            import toml  # noqa: F401
            print("✓ TOML")
        except ImportError as e:
            print(f"✗ TOML: {e}")
            return False

        return True

    def test_project_imports(self):
        """Test project-specific imports.

        Imports the project's configuration, model, data, training,
        evaluation and utility names group by group. The first group that
        raises ``ImportError`` is reported with a traceback and ends the
        check.

        Returns:
            ``True`` if every group imported cleanly, otherwise ``False``.
        """
        print("Testing project imports...")

        # Configuration
        try:
            from utils.config_loader import get_config, load_config_for_training  # noqa: F401
            from config.base_config import BaseConfig  # noqa: F401
            from config.m3_config import BGEM3Config  # noqa: F401
            from config.reranker_config import BGERerankerConfig  # noqa: F401
            print("✓ Configuration classes")
        except ImportError as e:
            print(f"✗ Configuration imports: {e}")
            traceback.print_exc()
            return False

        # Models
        try:
            from models.bge_m3 import BGEM3Model  # noqa: F401
            from models.bge_reranker import BGERerankerModel  # noqa: F401
            from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss  # noqa: F401
            print("✓ Model classes")
        except ImportError as e:
            print(f"✗ Model imports: {e}")
            traceback.print_exc()
            return False

        # Data processing
        try:
            from data.dataset import BGEM3Dataset, BGERerankerDataset  # noqa: F401
            from data.preprocessing import DataPreprocessor  # noqa: F401
            from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter  # noqa: F401
            from data.hard_negative_mining import ANCEHardNegativeMiner  # noqa: F401
            print("✓ Data processing classes")
        except ImportError as e:
            print(f"✗ Data imports: {e}")
            traceback.print_exc()
            return False

        # Training
        try:
            from training.trainer import BaseTrainer  # noqa: F401
            from training.m3_trainer import BGEM3Trainer  # noqa: F401
            from training.reranker_trainer import BGERerankerTrainer  # noqa: F401
            from training.joint_trainer import JointTrainer  # noqa: F401
            print("✓ Training classes")
        except ImportError as e:
            print(f"✗ Training imports: {e}")
            traceback.print_exc()
            return False

        # Evaluation
        try:
            from evaluation.evaluator import RetrievalEvaluator  # noqa: F401
            from evaluation.metrics import compute_retrieval_metrics  # noqa: F401
            print("✓ Evaluation classes")
        except ImportError as e:
            print(f"✗ Evaluation imports: {e}")
            traceback.print_exc()
            return False

        # Utilities
        try:
            from utils.logging import setup_logging  # noqa: F401
            from utils.checkpoint import CheckpointManager  # noqa: F401
            from utils.distributed import setup_distributed, is_main_process  # noqa: F401
            print("✓ Utility classes")
        except ImportError as e:
            print(f"✗ Utility imports: {e}")
            traceback.print_exc()
            return False

        return True

    def test_configuration_system(self):
        """Test the centralized configuration system.

        Checks that ``get_config()`` returns the same object on repeated
        calls, that a dotted key and a whole section can be read from it, and
        that ``load_config_for_training()`` yields a mapping containing the
        ``training`` and ``model_paths`` keys.

        Returns:
            ``True`` if all checks pass; ``False`` if any step raises.
        """
        print("Testing configuration system...")

        try:
            from utils.config_loader import get_config, load_config_for_training

            # Test singleton behavior
            config1 = get_config()
            config2 = get_config()
            assert config1 is config2, "Config should be singleton"
            print("✓ Singleton pattern working")

            # Test configuration loading
            test_value = config1.get('training.default_batch_size', None)  # type: ignore[arg-type]
            print(f"✓ Config loading: batch_size = {test_value}")

            # Test section loading
            training_section = config1.get_section('training')
            assert isinstance(training_section, dict), "Should return dict"
            print("✓ Section loading working")

            # Test training configuration loading
            training_config = load_config_for_training()
            assert 'training' in training_config
            assert 'model_paths' in training_config
            print("✓ Training config loading working")

            return True

        except Exception as e:
            print(f"✗ Configuration system error: {e}")
            traceback.print_exc()
            return False

    def test_configuration_integration(self):
        """Test configuration integration in classes.

        Instantiates ``BaseConfig`` and ``BGEM3Config`` with their defaults
        and prints the batch size, learning rate and device they report.

        Returns:
            ``True`` if both objects could be built and read; ``False`` if
            any step raises.
        """
        print("Testing configuration integration...")

        try:
            from config.base_config import BaseConfig
            from config.m3_config import BGEM3Config

            # Test BaseConfig with config.toml integration
            base_config = BaseConfig()
            print(f"✓ BaseConfig created with batch_size: {base_config.per_device_train_batch_size}")

            # Test M3Config
            m3_config = BGEM3Config()
            print(f"✓ M3Config created with learning_rate: {m3_config.learning_rate}")

            # Test device detection
            print(f"✓ Device detected: {base_config.device}")

            return True

        except Exception as e:
            print(f"✗ Configuration integration error: {e}")
            traceback.print_exc()
            return False

    def test_data_processing(self):
        """Test data processing functionality.

        Writes two JSONL records to a temporary file, loads them through
        ``DataPreprocessor._load_jsonl`` and runs ``_validate_and_clean`` on
        the result, expecting two records at each stage. The temporary file
        is removed whether or not the checks succeed.

        Returns:
            ``True`` if both stages return two records; ``False`` if any step
            raises.
        """
        print("Testing data processing...")

        try:
            from data.preprocessing import DataPreprocessor

            # Create test data
            test_data = [
                {"query": "test query 1", "pos": ["positive passage 1"], "neg": ["negative passage 1"]},
                {"query": "test query 2", "pos": ["positive passage 2"], "neg": ["negative passage 2"]},
            ]
            temp_file = self._write_temp_file(
                '.jsonl',
                ''.join(json.dumps(item) + '\n' for item in test_data),
            )

            try:
                # Test preprocessing
                preprocessor = DataPreprocessor()
                data = preprocessor._load_jsonl(temp_file)
                assert len(data) == 2
                print("✓ JSONL loading working")

                # Test validation
                cleaned = preprocessor._validate_and_clean(data)
                assert len(cleaned) == 2
                print("✓ Data validation working")
            finally:
                # Cleanup must also happen when an assertion above fails.
                self._remove_file(temp_file)

            return True

        except Exception as e:
            print(f"✗ Data processing error: {e}")
            traceback.print_exc()
            return False

    def test_data_augmentation(self):
        """Test data augmentation functionality.

        Runs ``QueryAugmenter``, ``PassageAugmenter`` and
        ``HardNegativeAugmenter`` on a single Chinese-language sample. Only
        the query augmenter's output size is asserted; the other two are
        merely required not to raise.

        Returns:
            ``True`` if all three augmenters run; ``False`` if any step
            raises.
        """
        print("Testing data augmentation...")

        try:
            from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter

            # Test QueryAugmenter
            augmenter = QueryAugmenter(augmentation_methods=['paraphrase', 'synonym'])
            test_data = [
                {"query": "什么是机器学习?", "pos": ["机器学习是AI的分支"], "neg": ["深度学习不同"]}
            ]

            augmented = augmenter.augment_queries(test_data, num_augmentations=1, keep_original=True)
            assert len(augmented) >= len(test_data)
            print("✓ Query augmentation working")

            # Test PassageAugmenter (result unused: the call must simply succeed)
            passage_aug = PassageAugmenter()
            passage_aug.augment_passages(test_data, augmentation_prob=1.0)
            print("✓ Passage augmentation working")

            # Test HardNegativeAugmenter (result unused: the call must simply succeed)
            hn_aug = HardNegativeAugmenter()
            hn_aug.create_hard_negatives(test_data, num_hard_negatives=2)
            print("✓ Hard negative augmentation working")

            return True

        except Exception as e:
            print(f"✗ Data augmentation error: {e}")
            traceback.print_exc()
            return False

    def test_model_functionality(self):
        """Test basic model functionality.

        Imports the model classes and evaluates ``InfoNCELoss`` and
        ``ListwiseCrossEntropyLoss`` on small random tensors, asserting that
        each returned loss value is positive.

        Returns:
            ``True`` if both losses are positive; ``False`` if any step
            raises.
        """
        print("Testing model functionality...")

        try:
            # The model classes are imported only to prove they resolve.
            from models.bge_m3 import BGEM3Model  # noqa: F401
            from models.bge_reranker import BGERerankerModel  # noqa: F401
            from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss

            # Test loss functions
            infonce = InfoNCELoss()
            query_emb = torch.randn(2, 128)
            pos_emb = torch.randn(2, 128)
            neg_emb = torch.randn(2, 5, 128)

            loss = infonce(query_emb, pos_emb, neg_emb)
            assert loss.item() > 0
            print("✓ InfoNCE loss working")

            # Test listwise loss
            listwise = ListwiseCrossEntropyLoss()
            scores = torch.randn(2, 5)
            labels = torch.zeros(2, 5)
            labels[:, 0] = 1  # First is positive

            loss = listwise(scores, labels, [5, 5])
            assert loss.item() > 0
            print("✓ Listwise loss working")

            return True

        except Exception as e:
            print(f"✗ Model functionality error: {e}")
            traceback.print_exc()
            return False

    def test_hard_negative_mining(self):
        """Test hard negative mining.

        Builds a CPU ``ANCEHardNegativeMiner``, adds ten random 128-d float32
        embeddings to its index and mines negatives for two queries,
        asserting that one result is returned per query.

        Returns:
            ``True`` if mining returns two results; ``False`` if any step
            raises.
        """
        print("Testing hard negative mining...")

        try:
            from data.hard_negative_mining import ANCEHardNegativeMiner

            # Create miner
            miner = ANCEHardNegativeMiner(
                embedding_dim=128,
                num_hard_negatives=5,
                use_gpu=False,  # Use CPU for testing
            )

            # Add some dummy embeddings
            embeddings = np.random.rand(10, 128).astype(np.float32)
            passages = [f"passage_{i}" for i in range(10)]

            miner._add_to_index(embeddings, passages)
            print("✓ Index creation working")

            # Test mining
            queries = ["query1", "query2"]
            query_embeddings = np.random.rand(2, 128).astype(np.float32)
            positive_passages = [["passage_0"], ["passage_5"]]

            hard_negs = miner.mine_hard_negatives(
                queries, query_embeddings, positive_passages
            )

            assert len(hard_negs) == 2
            print("✓ Hard negative mining working")

            return True

        except Exception as e:
            print(f"✗ Hard negative mining error: {e}")
            traceback.print_exc()
            return False

    def test_evaluation_metrics(self):
        """Test evaluation metrics.

        Calls ``compute_retrieval_metrics`` on random embeddings with two
        labelled relevant passages, and ``compute_reranker_metrics`` on
        random scores with a fixed label vector, asserting that the expected
        metric keys are present.

        Returns:
            ``True`` if the expected keys are present; ``False`` if any step
            raises.
        """
        print("Testing evaluation metrics...")

        try:
            from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics

            # Create dummy data
            query_embs = np.random.rand(5, 128)
            passage_embs = np.random.rand(20, 128)
            labels = np.zeros((5, 20))
            labels[0, 0] = 1  # First query, first passage is relevant
            labels[1, 5] = 1  # Second query, sixth passage is relevant

            metrics = compute_retrieval_metrics(query_embs, passage_embs, labels)
            assert 'recall@1' in metrics
            assert 'precision@1' in metrics
            print("✓ Retrieval metrics working")

            # Test reranker metrics
            scores = np.random.rand(10)
            # Fixed labels containing both classes: randomly drawn labels can
            # come out all-0 or all-1, which would make this check flaky if
            # the metric implementation needs both classes (not visible here).
            reranker_labels = np.array([1, 0] * 5)
            reranker_metrics = compute_reranker_metrics(scores, reranker_labels)
            assert 'accuracy' in reranker_metrics or 'mrr' in reranker_metrics
            print("✓ Reranker metrics working")

            return True

        except Exception as e:
            print(f"✗ Evaluation metrics error: {e}")
            traceback.print_exc()
            return False

    def test_data_formats(self):
        """Test multiple data format support.

        Round-trips a two-record JSONL file through
        ``DataPreprocessor._load_jsonl`` and a two-row CSV file through
        ``DataPreprocessor._load_csv``. Each temporary file is removed even
        when its check fails.

        Returns:
            ``True`` if both formats load; ``False`` if any step raises.
        """
        print("Testing data format support...")

        try:
            from data.preprocessing import DataPreprocessor

            preprocessor = DataPreprocessor()

            # Test JSONL format
            jsonl_path = self._write_temp_file(
                '.jsonl',
                '{"query": "test", "pos": ["pos1"], "neg": ["neg1"]}\n'
                '{"query": "test2", "pos": ["pos2"], "neg": ["neg2"]}\n',
            )
            try:
                data = preprocessor._load_jsonl(jsonl_path)
                assert len(data) == 2
            finally:
                self._remove_file(jsonl_path)
            print("✓ JSONL format working")

            # Test CSV format
            import pandas as pd
            df = pd.DataFrame({
                'query': ['test1', 'test2'],
                'positive': ['pos1', 'pos2'],
                'negative': ['neg1', 'neg2'],
            })
            # Reserve a closed, empty file first: writing by name while the
            # temporary handle is still open fails on Windows.
            csv_path = self._write_temp_file('.csv', '')
            try:
                df.to_csv(csv_path, index=False)
                data = preprocessor._load_csv(csv_path)
                assert len(data) >= 1  # At least one valid record
            finally:
                self._remove_file(csv_path)
            print("✓ CSV format working")

            return True

        except Exception as e:
            print(f"✗ Data format support error: {e}")
            traceback.print_exc()
            return False

    def test_device_compatibility(self):
        """Test device compatibility (CUDA/Ascend/CPU).

        Reads the device reported by ``BaseConfig`` and runs a small matrix
        product on it: on the GPU for CUDA, on the CPU otherwise. For NPU the
        device is only reported, not exercised.

        Returns:
            ``True`` if the device-specific check passes; ``False`` if any
            step raises.
        """
        print("Testing device compatibility...")

        try:
            from config.base_config import BaseConfig

            # Test device detection
            config = BaseConfig()
            print(f"✓ Device detected: {config.device}")
            print(f"✓ Number of devices: {config.n_gpu}")

            # NOTE(review): config.device may be a plain string, an indexed
            # string such as "cuda:0", or a torch.device - confirm against
            # BaseConfig. Comparing the bare type handles all three.
            device_type = str(config.device).split(":")[0]

            # Test tensor operations
            if device_type == "cuda":
                x = torch.randn(10, 10).cuda()
                y = torch.matmul(x, x.t())
                assert y.device.type == "cuda"
                print("✓ CUDA tensor operations working")
            elif device_type == "npu":
                print("✓ Ascend NPU detected")
                # Note: Would need actual NPU to test operations
            else:
                x = torch.randn(10, 10)
                y = torch.matmul(x, x.t())
                assert y.shape == (10, 10)
                print("✓ CPU tensor operations working")

            return True

        except Exception as e:
            print(f"✗ Device compatibility error: {e}")
            traceback.print_exc()
            return False

    def test_checkpoint_functionality(self):
        """Test checkpointing functionality.

        Saves a small state dictionary through ``CheckpointManager`` into a
        temporary directory, checks that the returned path exists, then loads
        it back and checks that ``model_state_dict`` is present.

        Returns:
            ``True`` if the round trip succeeds; ``False`` if any step
            raises.
        """
        print("Testing checkpoint functionality...")

        try:
            from utils.checkpoint import CheckpointManager

            with tempfile.TemporaryDirectory() as temp_dir:
                manager = CheckpointManager(temp_dir)

                # Test saving
                state_dict = {
                    'model_state_dict': torch.randn(10, 10),
                    'optimizer_state_dict': {'lr': 0.001},
                    'step': 100,
                }

                checkpoint_path = manager.save(state_dict, step=100)
                assert os.path.exists(checkpoint_path)
                print("✓ Checkpoint saving working")

                # Test loading
                loaded_state = manager.load_checkpoint(checkpoint_path)
                assert 'model_state_dict' in loaded_state
                print("✓ Checkpoint loading working")

            return True

        except Exception as e:
            print(f"✗ Checkpoint functionality error: {e}")
            traceback.print_exc()
            return False

    def test_logging_system(self):
        """Test logging system.

        Configures file logging into a temporary directory, emits one message
        through ``get_logger`` and records one metric through
        ``MetricsLogger``. Nothing is asserted beyond the calls not raising.

        Returns:
            ``True`` if all calls complete; ``False`` if any step raises.
        """
        print("Testing logging system...")

        try:
            from utils.logging import setup_logging, get_logger, MetricsLogger

            # NOTE(review): if setup_logging keeps a file handler open inside
            # temp_dir, removing the directory may fail on Windows - confirm.
            with tempfile.TemporaryDirectory() as temp_dir:
                # Test setup
                setup_logging(temp_dir, log_to_file=True)
                logger = get_logger("test")
                logger.info("Test log message")
                print("✓ Basic logging working")

                # Test metrics logger
                metrics_logger = MetricsLogger(temp_dir)
                metrics_logger.log(step=1, metrics={'loss': 0.5}, phase='train')
                print("✓ Metrics logging working")

            return True

        except Exception as e:
            print(f"✗ Logging system error: {e}")
            traceback.print_exc()
            return False

    def test_training_script_config_integration(self):
        """Test training script configuration integration.

        Builds a stand-in for the parsed command-line arguments, with most
        options left at ``None``, passes it to
        ``scripts.train_m3.create_config_from_args`` and asserts that the
        resulting config has a positive learning rate and batch size.

        Returns:
            ``True`` if both values are positive; ``False`` if any step
            raises.
        """
        print("Testing training script configuration integration...")

        try:
            # Import the main training script functions
            from scripts.train_m3 import create_config_from_args

            # Create mock arguments
            class MockArgs:
                def __init__(self):
                    self.config_path = None
                    self.model_name_or_path = None
                    self.cache_dir = None
                    self.train_data = "dummy"
                    self.eval_data = None
                    self.max_seq_length = None
                    self.query_max_length = None
                    self.passage_max_length = None
                    self.max_length = 8192
                    # Platform temp dir instead of a hard-coded "/tmp" so the
                    # path is also meaningful on non-POSIX systems.
                    self.output_dir = os.path.join(tempfile.gettempdir(), "test")
                    self.num_train_epochs = None
                    self.per_device_train_batch_size = None
                    self.per_device_eval_batch_size = None
                    self.gradient_accumulation_steps = None
                    self.learning_rate = None
                    self.warmup_ratio = None
                    self.weight_decay = None
                    self.train_group_size = None
                    self.temperature = None
                    self.use_hard_negatives = False
                    self.num_hard_negatives = None
                    self.use_self_distill = False
                    self.use_query_instruction = False
                    self.query_instruction = None
                    self.use_dense = True
                    self.use_sparse = False
                    self.use_colbert = False
                    self.unified_finetuning = False
                    self.augment_queries = False
                    self.augmentation_methods = None
                    self.create_hard_negatives = False
                    self.fp16 = False
                    self.bf16 = False
                    self.gradient_checkpointing = False
                    self.dataloader_num_workers = None
                    self.save_steps = None
                    self.eval_steps = None
                    self.logging_steps = None
                    self.save_total_limit = None
                    self.resume_from_checkpoint = None
                    self.local_rank = -1
                    self.seed = None
                    self.overwrite_output_dir = True

            args = MockArgs()
            config = create_config_from_args(args)

            assert config.learning_rate > 0
            assert config.per_device_train_batch_size > 0
            print("✓ Training script config integration working")

            return True

        except Exception as e:
            print(f"✗ Training script config integration error: {e}")
            traceback.print_exc()
            return False

    def run_all_tests(self):
        """Run all tests and generate report.

        Executes every check in a fixed order through ``run_test``, then
        prints a summary, the list of failed checks (if any) and a
        per-check result listing.

        Returns:
            ``True`` if every check passed, otherwise ``False``.
        """
        print("\n" + "="*80)
        print("BGE FINE-TUNING COMPREHENSIVE TEST SUITE")
        print("="*80)

        tests = [
            ("Core Dependencies Import", self.test_imports),
            ("Project Components Import", self.test_project_imports),
            ("Configuration System", self.test_configuration_system),
            ("Configuration Integration", self.test_configuration_integration),
            ("Data Processing", self.test_data_processing),
            ("Data Augmentation", self.test_data_augmentation),
            ("Model Functionality", self.test_model_functionality),
            ("Hard Negative Mining", self.test_hard_negative_mining),
            ("Evaluation Metrics", self.test_evaluation_metrics),
            ("Data Format Support", self.test_data_formats),
            ("Device Compatibility", self.test_device_compatibility),
            ("Checkpoint Functionality", self.test_checkpoint_functionality),
            ("Logging System", self.test_logging_system),
            ("Training Script Integration", self.test_training_script_config_integration),
        ]

        passed = 0
        total = len(tests)

        for test_name, test_func in tests:
            if self.run_test(test_name, test_func):
                passed += 1

        # Generate report
        print("\n" + "="*80)
        print("TEST RESULTS SUMMARY")
        print("="*80)
        print(f"Tests passed: {passed}/{total}")
        print(f"Success rate: {passed/total*100:.1f}%")

        if passed == total:
            print("\n🎉 ALL TESTS PASSED! Project is ready for use.")
            print("\nNext steps:")
            print("1. Configure your API keys in config.toml if needed")
            print("2. Prepare your training data in JSONL format")
            print("3. Run training with: python scripts/train_m3.py --train_data your_data.jsonl")
            print("4. Evaluate with: python scripts/validate.py quick --retriever_model output/model")
        else:
            print(f"\n❌ {total - passed} tests failed. Please review the errors above.")
            print("\nFailed tests:")
            for test_name in self.failed_tests:
                print(f" - {test_name}")

        print("\nDetailed test results:")
        for test_name, status, error in self.test_results:
            status_emoji = "✅" if status == "PASSED" else "❌"
            print(f" {status_emoji} {test_name}: {status}")
            if error:
                print(f" Error: {error}")

        return passed == total
|
|
|
|
|
|
def main():
    """Run the comprehensive test suite and translate its outcome.

    Returns:
        ``0`` when every check passed and ``1`` otherwise, suitable for use
        as a process exit status.
    """
    if ComprehensiveTestSuite().run_all_tests():
        return 0
    return 1
|
|
|
|
|
|
# Script entry point: exit the process with the status reported by main().
if __name__ == "__main__":
    raise SystemExit(main())
|