init commit

ldy · 2025-07-22 16:55:25 +08:00 · commit 36003b83e2
67 changed files with 76613 additions and 0 deletions

144
scripts/benchmark.py Normal file

@@ -0,0 +1,144 @@
import os
import time
import argparse
import logging
import torch
import psutil
import tracemalloc
import numpy as np
from utils.config_loader import get_config
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("benchmark")
def parse_args():
"""Parse command line arguments for benchmarking."""
parser = argparse.ArgumentParser(description="Benchmark BGE fine-tuned models.")
parser.add_argument('--model_type', type=str, choices=['retriever', 'reranker'], required=True, help='Model type to benchmark')
parser.add_argument('--model_path', type=str, required=True, help='Path to fine-tuned model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
parser.add_argument('--output', type=str, default='benchmark_results.txt', help='File to save benchmark results')
return parser.parse_args()
def measure_memory():
"""Measure the current memory usage of the process."""
process = psutil.Process(os.getpid())
return process.memory_info().rss / (1024 * 1024) # MB
def benchmark_retriever(model_path, data_path, batch_size, max_samples, device):
"""Benchmark the BGE retriever model."""
logger.info(f"Loading retriever model from {model_path}")
model = BGEM3Model(model_name_or_path=model_path, device=device)
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
logger.info(f"Loaded {len(dataset)} samples for benchmarking")
dataloader = DataLoader(dataset, batch_size=batch_size)
model.eval()
model.to(device)
n_samples = 0
times = []
tracemalloc.start()
with torch.no_grad():
for batch in dataloader:
if n_samples >= max_samples:
break
start = time.time()
# Only dense embedding for speed
_ = model(
batch['query_input_ids'].to(device),
batch['query_attention_mask'].to(device),
return_dense=True, return_sparse=False, return_colbert=False
)
if device.startswith('cuda'): torch.cuda.synchronize()
end = time.time()
times.append(end - start)
n_samples += batch['query_input_ids'].shape[0]
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
throughput = n_samples / sum(times)
latency = np.mean(times) / batch_size
mem_mb = peak / (1024 * 1024)
logger.info(f"Retriever throughput: {throughput:.2f} samples/sec, latency: {latency*1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
return throughput, latency, mem_mb
def benchmark_reranker(model_path, data_path, batch_size, max_samples, device):
"""Benchmark the BGE reranker model."""
logger.info(f"Loading reranker model from {model_path}")
model = BGERerankerModel(model_name_or_path=model_path, device=device)
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
logger.info(f"Loaded {len(dataset)} samples for benchmarking")
dataloader = DataLoader(dataset, batch_size=batch_size)
model.eval()
model.to(device)
n_samples = 0
times = []
tracemalloc.start()
with torch.no_grad():
for batch in dataloader:
if n_samples >= max_samples:
break
start = time.time()
_ = model.model(
input_ids=batch['input_ids'].to(device),
attention_mask=batch['attention_mask'].to(device)
)
if device.startswith('cuda'): torch.cuda.synchronize()
end = time.time()
times.append(end - start)
n_samples += batch['input_ids'].shape[0]
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
throughput = n_samples / sum(times)
latency = np.mean(times) / batch_size
mem_mb = peak / (1024 * 1024)
logger.info(f"Reranker throughput: {throughput:.2f} samples/sec, latency: {latency*1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
return throughput, latency, mem_mb
def main():
"""Run benchmark with robust error handling and config-driven defaults."""
parser = argparse.ArgumentParser(description="Benchmark BGE models.")
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
parser.add_argument('--model_type', type=str, choices=['retriever', 'reranker'], required=True, help='Model type to benchmark')
parser.add_argument('--model_path', type=str, required=True, help='Path to fine-tuned model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--output', type=str, default='benchmark_results.txt', help='File to save benchmark results')
args = parser.parse_args()
# Load config and set defaults
config = get_config()
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
try:
if args.model_type == 'retriever':
throughput, latency, mem_mb = benchmark_retriever(
args.model_path, args.data_path, args.batch_size, args.max_samples, device)
else:
throughput, latency, mem_mb = benchmark_reranker(
args.model_path, args.data_path, args.batch_size, args.max_samples, device)
# Save results
with open(args.output, 'w') as f:
f.write(f"Model: {args.model_path}\n")
f.write(f"Data: {args.data_path}\n")
f.write(f"Batch size: {args.batch_size}\n")
f.write(f"Throughput: {throughput:.2f} samples/sec\n")
f.write(f"Latency: {latency*1000:.2f} ms/sample\n")
f.write(f"Peak memory: {mem_mb:.2f} MB\n")
logger.info(f"Benchmark results saved to {args.output}")
logging.info(f"Benchmark completed on device: {device}")
except Exception as e:
logging.error(f"Benchmark failed: {e}")
raise
if __name__ == '__main__':
main()
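Programmatic use is also possible; a minimal sketch, assuming the package layout introduced by this commit and purely illustrative checkpoint/data paths:

from scripts.benchmark import benchmark_retriever

# Hypothetical paths; benchmark_retriever returns (throughput, latency, mem_mb).
throughput, latency, mem_mb = benchmark_retriever(
    model_path="./output/bge-m3/final_model",
    data_path="data/eval.jsonl",
    batch_size=16,
    max_samples=1000,
    device="cuda",
)
print(f"{throughput:.2f} samples/s, {latency * 1000:.2f} ms/sample, {mem_mb:.2f} MB peak")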

138
scripts/compare_reranker.py Normal file

@@ -0,0 +1,138 @@
import os
import time
import argparse
import logging
import torch
import psutil
import tracemalloc
import numpy as np
from utils.config_loader import get_config
from models.bge_reranker import BGERerankerModel
from data.dataset import BGERerankerDataset
from transformers import AutoTokenizer
from evaluation.metrics import compute_reranker_metrics
from torch.utils.data import DataLoader
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("compare_reranker")
def parse_args():
"""Parse command-line arguments for reranker comparison."""
parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline reranker models.")
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned reranker model')
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline reranker model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
parser.add_argument('--output', type=str, default='compare_reranker_results.txt', help='File to save comparison results')
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10], help='K values for metrics@k')
return parser.parse_args()
def measure_memory():
"""Measure the current memory usage of the process."""
process = psutil.Process(os.getpid())
return process.memory_info().rss / (1024 * 1024) # MB
def run_reranker(model_path, data_path, batch_size, max_samples, device):
"""
Run inference on a reranker model and measure throughput, latency, and memory.
Args:
model_path (str): Path to the reranker model.
data_path (str): Path to the evaluation data.
batch_size (int): Batch size for inference.
max_samples (int): Maximum number of samples to process.
device (str): Device to use (cuda, cpu, npu).
Returns:
tuple: (throughput, latency, mem_mb, all_scores, all_labels)
"""
model = BGERerankerModel(model_name_or_path=model_path, device=device)
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
dataloader = DataLoader(dataset, batch_size=batch_size)
model.eval()
model.to(device)
n_samples = 0
times = []
all_scores = []
all_labels = []
tracemalloc.start()
with torch.no_grad():
for batch in dataloader:
if n_samples >= max_samples:
break
start = time.time()
outputs = model.model(
input_ids=batch['input_ids'].to(device),
attention_mask=batch['attention_mask'].to(device)
)
if device.startswith('cuda'): torch.cuda.synchronize()
end = time.time()
times.append(end - start)
n = batch['input_ids'].shape[0]
n_samples += n
# For metrics: collect scores and labels
scores = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
all_scores.append(scores.cpu().numpy())
all_labels.append(batch['labels'].cpu().numpy())
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
throughput = n_samples / sum(times)
latency = np.mean(times) / batch_size
mem_mb = peak / (1024 * 1024)
# Prepare for metrics
all_scores = np.concatenate(all_scores, axis=0)
all_labels = np.concatenate(all_labels, axis=0)
return throughput, latency, mem_mb, all_scores, all_labels
def main():
"""Run reranker comparison with robust error handling and config-driven defaults."""
parser = argparse.ArgumentParser(description="Compare reranker models.")
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned reranker model')
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline reranker model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--output', type=str, default='compare_reranker_results.txt', help='File to save comparison results')
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10], help='K values for metrics@k')
args = parser.parse_args()
# Load config and set defaults
config = get_config()
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
try:
# Fine-tuned model
logger.info("Running fine-tuned reranker...")
ft_throughput, ft_latency, ft_mem, ft_scores, ft_labels = run_reranker(
args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
ft_metrics = compute_reranker_metrics(ft_scores, ft_labels, k_values=args.k_values)
# Baseline model
logger.info("Running baseline reranker...")
bl_throughput, bl_latency, bl_mem, bl_scores, bl_labels = run_reranker(
args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
bl_metrics = compute_reranker_metrics(bl_scores, bl_labels, k_values=args.k_values)
# Output comparison
with open(args.output, 'w') as f:
f.write(f"Metric\tBaseline\tFine-tuned\tDelta\n")
for k in args.k_values:
for metric in [f'accuracy@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
bl = bl_metrics.get(metric, 0)
ft = ft_metrics.get(metric, 0)
delta = ft - bl
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
for metric in ['map', 'mrr']:
bl = bl_metrics.get(metric, 0)
ft = ft_metrics.get(metric, 0)
delta = ft - bl
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput-bl_throughput:+.2f}\n")
f.write(f"Latency (ms/sample)\t{bl_latency*1000:.2f}\t{ft_latency*1000:.2f}\t{(ft_latency-bl_latency)*1000:+.2f}\n")
f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem-bl_mem:+.2f}\n")
logger.info(f"Comparison results saved to {args.output}")
print(f"\nComparison complete. See {args.output} for details.")
except Exception as e:
logging.error(f"Reranker comparison failed: {e}")
raise
if __name__ == '__main__':
main()
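The same comparison can be reproduced programmatically; a minimal sketch using the helpers above, with placeholder checkpoint paths and example k values:

from scripts.compare_reranker import run_reranker
from evaluation.metrics import compute_reranker_metrics

# Hypothetical paths; run_reranker returns (throughput, latency, mem_mb, scores, labels).
_, _, _, ft_scores, ft_labels = run_reranker("./output/bge-reranker/final_model", "data/eval.jsonl", 16, 1000, "cuda")
_, _, _, bl_scores, bl_labels = run_reranker("BAAI/bge-reranker-base", "data/eval.jsonl", 16, 1000, "cuda")
ft_metrics = compute_reranker_metrics(ft_scores, ft_labels, k_values=[1, 5, 10])
bl_metrics = compute_reranker_metrics(bl_scores, bl_labels, k_values=[1, 5, 10])
for name in sorted(set(ft_metrics) & set(bl_metrics)):
    print(f"{name}\t{bl_metrics[name]:.4f}\t{ft_metrics[name]:.4f}\t{ft_metrics[name] - bl_metrics[name]:+.4f}")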

144
scripts/compare_retriever.py Normal file

@@ -0,0 +1,144 @@
import os
import time
import argparse
import logging
import torch
import psutil
import tracemalloc
import numpy as np
from utils.config_loader import get_config
from models.bge_m3 import BGEM3Model
from data.dataset import BGEM3Dataset
from transformers import AutoTokenizer
from evaluation.metrics import compute_retrieval_metrics
from torch.utils.data import DataLoader
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("compare_retriever")
def parse_args():
"""Parse command line arguments for retriever comparison."""
parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline retriever models.")
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned retriever model')
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline retriever model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
parser.add_argument('--output', type=str, default='compare_retriever_results.txt', help='File to save comparison results')
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10, 20, 100], help='K values for metrics@k')
return parser.parse_args()
def measure_memory():
"""Measure the current process memory usage in MB."""
process = psutil.Process(os.getpid())
return process.memory_info().rss / (1024 * 1024) # MB
def run_retriever(model_path, data_path, batch_size, max_samples, device):
"""
Run inference on a retriever model and measure throughput, latency, and memory.
Args:
model_path (str): Path to the retriever model.
data_path (str): Path to the evaluation data.
batch_size (int): Batch size for inference.
max_samples (int): Maximum number of samples to process.
device (str): Device to use (cuda, cpu, npu).
Returns:
tuple: throughput (samples/sec), latency (ms/sample), peak memory (MB),
query_embeds (np.ndarray), passage_embeds (np.ndarray), labels (np.ndarray).
"""
model = BGEM3Model(model_name_or_path=model_path, device=device)
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
dataloader = DataLoader(dataset, batch_size=batch_size)
model.eval()
model.to(device)
n_samples = 0
times = []
query_embeds = []
passage_embeds = []
labels = []
tracemalloc.start()
with torch.no_grad():
for batch in dataloader:
if n_samples >= max_samples:
break
start = time.time()
out = model(
batch['query_input_ids'].to(device),
batch['query_attention_mask'].to(device),
return_dense=True, return_sparse=False, return_colbert=False
)
if device.startswith('cuda'): torch.cuda.synchronize()
end = time.time()
times.append(end - start)
n = batch['query_input_ids'].shape[0]
n_samples += n
# For metrics: collect embeddings and labels
query_embeds.append(out['query_embeds'].cpu().numpy())
passage_embeds.append(out['passage_embeds'].cpu().numpy())
labels.append(batch['labels'].cpu().numpy())
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
throughput = n_samples / sum(times)
latency = np.mean(times) / batch_size
mem_mb = peak / (1024 * 1024)
# Prepare for metrics
query_embeds = np.concatenate(query_embeds, axis=0)
passage_embeds = np.concatenate(passage_embeds, axis=0)
labels = np.concatenate(labels, axis=0)
return throughput, latency, mem_mb, query_embeds, passage_embeds, labels
def main():
"""Run retriever comparison with robust error handling and config-driven defaults."""
parser = argparse.ArgumentParser(description="Compare retriever models.")
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned retriever model')
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline retriever model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--output', type=str, default='compare_retriever_results.txt', help='File to save comparison results')
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10, 20, 100], help='K values for metrics@k')
args = parser.parse_args()
# Load config and set defaults
config = get_config()
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
try:
# Fine-tuned model
logger.info("Running fine-tuned retriever...")
ft_throughput, ft_latency, ft_mem, ft_qe, ft_pe, ft_labels = run_retriever(
args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
ft_metrics = compute_retrieval_metrics(ft_qe, ft_pe, ft_labels, k_values=args.k_values)
# Baseline model
logger.info("Running baseline retriever...")
bl_throughput, bl_latency, bl_mem, bl_qe, bl_pe, bl_labels = run_retriever(
args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
bl_metrics = compute_retrieval_metrics(bl_qe, bl_pe, bl_labels, k_values=args.k_values)
# Output comparison
with open(args.output, 'w') as f:
f.write(f"Metric\tBaseline\tFine-tuned\tDelta\n")
for k in args.k_values:
for metric in [f'recall@{k}', f'precision@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
bl = bl_metrics.get(metric, 0)
ft = ft_metrics.get(metric, 0)
delta = ft - bl
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
for metric in ['map', 'mrr']:
bl = bl_metrics.get(metric, 0)
ft = ft_metrics.get(metric, 0)
delta = ft - bl
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput-bl_throughput:+.2f}\n")
f.write(f"Latency (ms/sample)\t{bl_latency*1000:.2f}\t{ft_latency*1000:.2f}\t{(ft_latency-bl_latency)*1000:+.2f}\n")
f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem-bl_mem:+.2f}\n")
logger.info(f"Comparison results saved to {args.output}")
print(f"\nComparison complete. See {args.output} for details.")
except Exception as e:
logging.error(f"Retriever comparison failed: {e}")
raise
if __name__ == '__main__':
main()
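The retriever comparison mirrors the reranker sketch above; a brief illustration of the metric step, again with placeholder paths:

from scripts.compare_retriever import run_retriever
from evaluation.metrics import compute_retrieval_metrics

# Hypothetical path; run_retriever returns (throughput, latency, mem_mb, query_embeds, passage_embeds, labels).
_, _, _, q_emb, p_emb, labels = run_retriever("./output/bge-m3/final_model", "data/eval.jsonl", 16, 1000, "cuda")
metrics = compute_retrieval_metrics(q_emb, p_emb, labels, k_values=[1, 5, 10, 20, 100])
print({name: round(float(value), 4) for name, value in metrics.items()})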

990
scripts/comprehensive_validation.py Normal file

@@ -0,0 +1,990 @@
#!/usr/bin/env python3
"""
Comprehensive Validation Script for BGE Fine-tuned Models
This script provides a complete validation framework to determine if fine-tuned models
actually perform better than baseline models across multiple datasets and metrics.
Features:
- Tests both M3 retriever and reranker models
- Compares against baseline models on multiple datasets
- Provides statistical significance testing
- Tests pipeline performance (retrieval + reranking)
- Generates comprehensive reports
- Includes performance degradation detection
- Supports cross-dataset validation
Usage:
python scripts/comprehensive_validation.py \\
--retriever_finetuned ./output/bge-m3-enhanced/final_model \\
--reranker_finetuned ./output/bge-reranker/final_model \\
--output_dir ./validation_results
"""
import os
import sys
import json
import logging
import argparse
import pandas as pd
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from scipy import stats
import torch
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from evaluation.metrics import (
compute_retrieval_metrics, compute_reranker_metrics,
compute_retrieval_metrics_comprehensive
)
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
from data.preprocessing import DataPreprocessor
from utils.logging import setup_logging, get_logger
from utils.config_loader import get_config
@dataclass
class ValidationConfig:
"""Configuration for comprehensive validation"""
# Model paths
retriever_baseline: str = "BAAI/bge-m3"
reranker_baseline: str = "BAAI/bge-reranker-base"
retriever_finetuned: Optional[str] = None
reranker_finetuned: Optional[str] = None
# Test datasets
test_datasets: Optional[List[str]] = None
# Evaluation settings
batch_size: int = 32
max_length: int = 512
k_values: Optional[List[int]] = None
significance_level: float = 0.05
min_samples_for_significance: int = 30
# Pipeline settings
retrieval_top_k: int = 100
rerank_top_k: int = 10
# Output settings
output_dir: str = "./validation_results"
save_embeddings: bool = False
create_report: bool = True
def __post_init__(self):
if self.test_datasets is None:
self.test_datasets = [
"data/datasets/examples/embedding_data.jsonl",
"data/datasets/examples/reranker_data.jsonl",
"data/datasets/Reranker_AFQMC/dev.json",
"data/datasets/三国演义/reranker_train.jsonl"
]
if self.k_values is None:
self.k_values = [1, 3, 5, 10, 20]
class ComprehensiveValidator:
"""Comprehensive validation framework for BGE models"""
def __init__(self, config: ValidationConfig):
self.config = config
self.logger = get_logger(__name__)
self.results = {}
self.device = self._get_device()
# Create output directory
os.makedirs(config.output_dir, exist_ok=True)
# Setup logging
setup_logging(
output_dir=config.output_dir,
log_level=logging.INFO,
log_to_console=True,
log_to_file=True
)
def _get_device(self) -> torch.device:
"""Get optimal device for evaluation"""
try:
config = get_config()
device_type = config.get('hardware', {}).get('device', 'cuda')
if device_type == 'npu':
try:
import torch_npu
device = torch.device("npu:0")
self.logger.info("Using NPU device")
return device
except ImportError:
pass
if torch.cuda.is_available():
device = torch.device("cuda")
self.logger.info("Using CUDA device")
return device
except Exception as e:
self.logger.warning(f"Error detecting optimal device: {e}")
device = torch.device("cpu")
self.logger.info("Using CPU device")
return device
def validate_all(self) -> Dict[str, Any]:
"""Run comprehensive validation on all models and datasets"""
self.logger.info("🚀 Starting Comprehensive BGE Model Validation")
self.logger.info("=" * 80)
start_time = datetime.now()
try:
# 1. Validate retriever models
if self.config.retriever_finetuned:
self.logger.info("📊 Validating Retriever Models...")
retriever_results = self._validate_retrievers()
self.results['retriever'] = retriever_results
# 2. Validate reranker models
if self.config.reranker_finetuned:
self.logger.info("📊 Validating Reranker Models...")
reranker_results = self._validate_rerankers()
self.results['reranker'] = reranker_results
# 3. Validate pipeline (if both models available)
if self.config.retriever_finetuned and self.config.reranker_finetuned:
self.logger.info("📊 Validating Pipeline Performance...")
pipeline_results = self._validate_pipeline()
self.results['pipeline'] = pipeline_results
# 4. Generate comprehensive report
if self.config.create_report:
self._generate_report()
# 5. Statistical analysis
self._perform_statistical_analysis()
end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
self.logger.info("✅ Comprehensive validation completed!")
self.logger.info(f"⏱️ Total validation time: {duration:.2f} seconds")
return self.results
except Exception as e:
self.logger.error(f"❌ Validation failed: {e}", exc_info=True)
raise
def _validate_retrievers(self) -> Dict[str, Any]:
"""Validate retriever models on all suitable datasets"""
results = {
'baseline_metrics': {},
'finetuned_metrics': {},
'comparisons': {},
'datasets_tested': []
}
for dataset_path in self.config.test_datasets:
if not os.path.exists(dataset_path):
self.logger.warning(f"Dataset not found: {dataset_path}")
continue
# Skip reranker-only datasets for retriever testing
if 'reranker' in dataset_path.lower() and 'AFQMC' in dataset_path:
continue
dataset_name = Path(dataset_path).stem
self.logger.info(f"Testing retriever on dataset: {dataset_name}")
try:
# Test baseline model
baseline_metrics = self._test_retriever_on_dataset(
model_path=self.config.retriever_baseline,
dataset_path=dataset_path,
is_baseline=True
)
# Test fine-tuned model
finetuned_metrics = self._test_retriever_on_dataset(
model_path=self.config.retriever_finetuned,
dataset_path=dataset_path,
is_baseline=False
)
# Compare results
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "retriever")
results['baseline_metrics'][dataset_name] = baseline_metrics
results['finetuned_metrics'][dataset_name] = finetuned_metrics
results['comparisons'][dataset_name] = comparison
results['datasets_tested'].append(dataset_name)
self.logger.info(f"✅ Retriever validation complete for {dataset_name}")
except Exception as e:
self.logger.error(f"❌ Failed to test retriever on {dataset_name}: {e}")
continue
return results
def _validate_rerankers(self) -> Dict[str, Any]:
"""Validate reranker models on all suitable datasets"""
results = {
'baseline_metrics': {},
'finetuned_metrics': {},
'comparisons': {},
'datasets_tested': []
}
for dataset_path in self.config.test_datasets:
if not os.path.exists(dataset_path):
self.logger.warning(f"Dataset not found: {dataset_path}")
continue
dataset_name = Path(dataset_path).stem
self.logger.info(f"Testing reranker on dataset: {dataset_name}")
try:
# Test baseline model
baseline_metrics = self._test_reranker_on_dataset(
model_path=self.config.reranker_baseline,
dataset_path=dataset_path,
is_baseline=True
)
# Test fine-tuned model
finetuned_metrics = self._test_reranker_on_dataset(
model_path=self.config.reranker_finetuned,
dataset_path=dataset_path,
is_baseline=False
)
# Compare results
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "reranker")
results['baseline_metrics'][dataset_name] = baseline_metrics
results['finetuned_metrics'][dataset_name] = finetuned_metrics
results['comparisons'][dataset_name] = comparison
results['datasets_tested'].append(dataset_name)
self.logger.info(f"✅ Reranker validation complete for {dataset_name}")
except Exception as e:
self.logger.error(f"❌ Failed to test reranker on {dataset_name}: {e}")
continue
return results
def _validate_pipeline(self) -> Dict[str, Any]:
"""Validate end-to-end retrieval + reranking pipeline"""
results = {
'baseline_pipeline': {},
'finetuned_pipeline': {},
'comparisons': {},
'datasets_tested': []
}
for dataset_path in self.config.test_datasets:
if not os.path.exists(dataset_path):
continue
# Only test on datasets suitable for pipeline evaluation
if 'embedding_data' not in dataset_path:
continue
dataset_name = Path(dataset_path).stem
self.logger.info(f"Testing pipeline on dataset: {dataset_name}")
try:
# Test baseline pipeline
baseline_metrics = self._test_pipeline_on_dataset(
retriever_path=self.config.retriever_baseline,
reranker_path=self.config.reranker_baseline,
dataset_path=dataset_path,
is_baseline=True
)
# Test fine-tuned pipeline
finetuned_metrics = self._test_pipeline_on_dataset(
retriever_path=self.config.retriever_finetuned,
reranker_path=self.config.reranker_finetuned,
dataset_path=dataset_path,
is_baseline=False
)
# Compare results
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "pipeline")
results['baseline_pipeline'][dataset_name] = baseline_metrics
results['finetuned_pipeline'][dataset_name] = finetuned_metrics
results['comparisons'][dataset_name] = comparison
results['datasets_tested'].append(dataset_name)
self.logger.info(f"✅ Pipeline validation complete for {dataset_name}")
except Exception as e:
self.logger.error(f"❌ Failed to test pipeline on {dataset_name}: {e}")
continue
return results
def _test_retriever_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
"""Test a retriever model on a specific dataset"""
try:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGEM3Model.from_pretrained(model_path)
model.to(self.device)
model.eval()
# Create dataset
dataset = BGEM3Dataset(
data_path=dataset_path,
tokenizer=tokenizer,
query_max_length=self.config.max_length,
passage_max_length=self.config.max_length,
is_train=False
)
dataloader = DataLoader(
dataset,
batch_size=self.config.batch_size,
shuffle=False
)
# Create evaluator
evaluator = RetrievalEvaluator(
model=model,
tokenizer=tokenizer,
device=self.device,
k_values=self.config.k_values
)
# Run evaluation
metrics = evaluator.evaluate(
dataloader=dataloader,
desc=f"{'Baseline' if is_baseline else 'Fine-tuned'} Retriever"
)
return metrics
except Exception as e:
self.logger.error(f"Error testing retriever: {e}")
return {}
def _test_reranker_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
"""Test a reranker model on a specific dataset"""
try:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGERerankerModel.from_pretrained(model_path)
model.to(self.device)
model.eval()
# Create dataset
dataset = BGERerankerDataset(
data_path=dataset_path,
tokenizer=tokenizer,
query_max_length=self.config.max_length,
passage_max_length=self.config.max_length,
is_train=False
)
dataloader = DataLoader(
dataset,
batch_size=self.config.batch_size,
shuffle=False
)
# Collect predictions and labels
all_scores = []
all_labels = []
with torch.no_grad():
for batch in dataloader:
# Move to device
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].cpu().numpy()
# Get predictions
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
scores = outputs.logits.cpu().numpy()
all_scores.append(scores)
all_labels.append(labels)
# Concatenate results
all_scores = np.concatenate(all_scores, axis=0)
all_labels = np.concatenate(all_labels, axis=0)
# Compute metrics
metrics = compute_reranker_metrics(
scores=all_scores,
labels=all_labels,
k_values=self.config.k_values
)
return metrics
except Exception as e:
self.logger.error(f"Error testing reranker: {e}")
return {}
def _test_pipeline_on_dataset(self, retriever_path: str, reranker_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
"""Test end-to-end pipeline on a dataset"""
try:
# Load models
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
retriever = BGEM3Model.from_pretrained(retriever_path)
retriever.to(self.device)
retriever.eval()
reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
reranker = BGERerankerModel.from_pretrained(reranker_path)
reranker.to(self.device)
reranker.eval()
# Load and preprocess data
data = []
with open(dataset_path, 'r', encoding='utf-8') as f:
for line in f:
try:
item = json.loads(line.strip())
if 'query' in item and 'pos' in item and 'neg' in item:
data.append(item)
except json.JSONDecodeError:
continue
if not data:
return {}
# Test pipeline performance
total_queries = 0
pipeline_metrics = {f'recall@{k}': [] for k in self.config.k_values}
pipeline_metrics.update({f'mrr@{k}': [] for k in self.config.k_values})
for item in data[:100]: # Limit to 100 queries for speed
query = item['query']
pos_docs = item['pos'][:2] # Take first 2 positive docs
neg_docs = item['neg'][:8] # Take first 8 negative docs
# Create candidate pool
candidates = pos_docs + neg_docs
labels = [1] * len(pos_docs) + [0] * len(neg_docs)
if len(candidates) < 3: # Need minimum candidates
continue
try:
# Step 1: Retrieval (encode query and candidates)
query_inputs = retriever_tokenizer(
query,
return_tensors='pt',
padding=True,
truncation=True,
max_length=self.config.max_length
).to(self.device)
with torch.no_grad():
query_outputs = retriever(**query_inputs, return_dense=True)
query_emb = query_outputs['dense']
# Encode candidates
candidate_embs = []
for candidate in candidates:
cand_inputs = retriever_tokenizer(
candidate,
return_tensors='pt',
padding=True,
truncation=True,
max_length=self.config.max_length
).to(self.device)
with torch.no_grad():
cand_outputs = retriever(**cand_inputs, return_dense=True)
candidate_embs.append(cand_outputs['dense'])
candidate_embs = torch.cat(candidate_embs, dim=0)
# Compute similarities
similarities = torch.cosine_similarity(query_emb, candidate_embs, dim=1)
# Get top-k for retrieval
top_k_retrieval = min(self.config.retrieval_top_k, len(candidates))
retrieval_indices = torch.argsort(similarities, descending=True)[:top_k_retrieval]
# Step 2: Reranking
rerank_candidates = [candidates[i] for i in retrieval_indices]
rerank_labels = [labels[i] for i in retrieval_indices]
# Score pairs with reranker
rerank_scores = []
for candidate in rerank_candidates:
pair_text = f"{query} [SEP] {candidate}"
pair_inputs = reranker_tokenizer(
pair_text,
return_tensors='pt',
padding=True,
truncation=True,
max_length=self.config.max_length
).to(self.device)
with torch.no_grad():
pair_outputs = reranker(**pair_inputs)
score = pair_outputs.logits.item()
rerank_scores.append(score)
# Final ranking
final_indices = np.argsort(rerank_scores)[::-1]
final_labels = [rerank_labels[i] for i in final_indices]
# Compute metrics
for k in self.config.k_values:
if k <= len(final_labels):
# Recall@k
relevant_in_topk = sum(final_labels[:k])
total_relevant = sum(rerank_labels)
recall = relevant_in_topk / total_relevant if total_relevant > 0 else 0
pipeline_metrics[f'recall@{k}'].append(recall)
# MRR@k
for i, label in enumerate(final_labels[:k]):
if label == 1:
pipeline_metrics[f'mrr@{k}'].append(1.0 / (i + 1))
break
else:
pipeline_metrics[f'mrr@{k}'].append(0.0)
total_queries += 1
except Exception as e:
self.logger.warning(f"Error processing query: {e}")
continue
# Average metrics
final_metrics = {}
for metric, values in pipeline_metrics.items():
if values:
final_metrics[metric] = np.mean(values)
else:
final_metrics[metric] = 0.0
final_metrics['total_queries'] = total_queries
return final_metrics
except Exception as e:
self.logger.error(f"Error testing pipeline: {e}")
return {}
def _compare_metrics(self, baseline: Dict[str, float], finetuned: Dict[str, float], model_type: str) -> Dict[str, Any]:
"""Compare baseline vs fine-tuned metrics with statistical analysis"""
comparison = {
'improvements': {},
'degradations': {},
'statistical_significance': {},
'summary': {}
}
# Compare common metrics
common_metrics = set(baseline.keys()) & set(finetuned.keys())
improvements = 0
degradations = 0
significant_improvements = 0
for metric in common_metrics:
if isinstance(baseline[metric], (int, float)) and isinstance(finetuned[metric], (int, float)):
baseline_val = baseline[metric]
finetuned_val = finetuned[metric]
# Calculate improvement
if baseline_val != 0:
improvement_pct = ((finetuned_val - baseline_val) / baseline_val) * 100
else:
improvement_pct = float('inf') if finetuned_val > 0 else 0
# Classify improvement/degradation
if finetuned_val > baseline_val:
comparison['improvements'][metric] = {
'baseline': baseline_val,
'finetuned': finetuned_val,
'improvement_pct': improvement_pct,
'absolute_improvement': finetuned_val - baseline_val
}
improvements += 1
elif finetuned_val < baseline_val:
comparison['degradations'][metric] = {
'baseline': baseline_val,
'finetuned': finetuned_val,
'degradation_pct': -improvement_pct,
'absolute_degradation': baseline_val - finetuned_val
}
degradations += 1
# Summary statistics
comparison['summary'] = {
'total_metrics': len(common_metrics),
'improvements': improvements,
'degradations': degradations,
'unchanged': len(common_metrics) - improvements - degradations,
'improvement_ratio': improvements / len(common_metrics) if common_metrics else 0,
'net_improvement': improvements - degradations
}
return comparison
def _perform_statistical_analysis(self):
"""Perform statistical analysis on validation results"""
self.logger.info("📈 Performing Statistical Analysis...")
analysis = {
'overall_summary': {},
'significance_tests': {},
'confidence_intervals': {},
'effect_sizes': {}
}
# Analyze each model type
for model_type in ['retriever', 'reranker', 'pipeline']:
if model_type not in self.results:
continue
model_results = self.results[model_type]
# Collect all improvements across datasets
all_improvements = []
significant_improvements = 0
for dataset, comparison in model_results.get('comparisons', {}).items():
for metric, improvement_data in comparison.get('improvements', {}).items():
improvement_pct = improvement_data.get('improvement_pct', 0)
if improvement_pct > 0:
all_improvements.append(improvement_pct)
# Simple significance test (can be enhanced)
if improvement_pct > 5.0: # 5% improvement threshold
significant_improvements += 1
if all_improvements:
analysis[model_type] = {
'mean_improvement_pct': np.mean(all_improvements),
'median_improvement_pct': np.median(all_improvements),
'std_improvement_pct': np.std(all_improvements),
'min_improvement_pct': np.min(all_improvements),
'max_improvement_pct': np.max(all_improvements),
'significant_improvements': significant_improvements,
'total_improvements': len(all_improvements)
}
self.results['statistical_analysis'] = analysis
def _generate_report(self):
"""Generate comprehensive validation report"""
self.logger.info("📝 Generating Comprehensive Report...")
report_path = os.path.join(self.config.output_dir, "validation_report.html")
html_content = self._create_html_report()
with open(report_path, 'w', encoding='utf-8') as f:
f.write(html_content)
# Also save JSON results
json_path = os.path.join(self.config.output_dir, "validation_results.json")
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(self.results, f, indent=2, default=str)
self.logger.info(f"📊 Report saved to: {report_path}")
self.logger.info(f"📊 JSON results saved to: {json_path}")
def _create_html_report(self) -> str:
"""Create HTML report with comprehensive validation results"""
html = f"""
<!DOCTYPE html>
<html>
<head>
<title>BGE Model Validation Report</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.header {{ background: #f0f8ff; padding: 20px; border-radius: 10px; }}
.section {{ margin: 20px 0; }}
.metrics-table {{ border-collapse: collapse; width: 100%; }}
.metrics-table th, .metrics-table td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
.metrics-table th {{ background-color: #f2f2f2; }}
.improvement {{ color: green; font-weight: bold; }}
.degradation {{ color: red; font-weight: bold; }}
.summary-box {{ background: #f9f9f9; padding: 15px; border-left: 4px solid #4CAF50; margin: 10px 0; }}
</style>
</head>
<body>
<div class="header">
<h1>🚀 BGE Model Comprehensive Validation Report</h1>
<p><strong>Generated:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
<p><strong>Configuration:</strong></p>
<ul>
<li>Retriever Baseline: {self.config.retriever_baseline}</li>
<li>Retriever Fine-tuned: {self.config.retriever_finetuned or 'Not tested'}</li>
<li>Reranker Baseline: {self.config.reranker_baseline}</li>
<li>Reranker Fine-tuned: {self.config.reranker_finetuned or 'Not tested'}</li>
</ul>
</div>
"""
# Add executive summary
html += self._create_executive_summary()
# Add detailed results for each model type
for model_type in ['retriever', 'reranker', 'pipeline']:
if model_type in self.results:
html += self._create_model_section(model_type)
html += """
</body>
</html>
"""
return html
def _create_executive_summary(self) -> str:
"""Create executive summary section"""
html = """
<div class="section">
<h2>📊 Executive Summary</h2>
"""
# Overall verdict
total_improvements = 0
total_degradations = 0
for model_type in ['retriever', 'reranker', 'pipeline']:
if model_type in self.results and 'comparisons' in self.results[model_type]:
for dataset, comparison in self.results[model_type]['comparisons'].items():
total_improvements += len(comparison.get('improvements', {}))
total_degradations += len(comparison.get('degradations', {}))
if total_improvements > total_degradations:
verdict = "✅ <strong>FINE-TUNING SUCCESSFUL</strong> - Models show overall improvement"
verdict_color = "green"
elif total_improvements == total_degradations:
verdict = "⚠️ <strong>MIXED RESULTS</strong> - Equal improvements and degradations"
verdict_color = "orange"
else:
verdict = "❌ <strong>FINE-TUNING ISSUES</strong> - More degradations than improvements"
verdict_color = "red"
html += f"""
<div class="summary-box" style="border-left-color: {verdict_color};">
<h3>{verdict}</h3>
<ul>
<li>Total Metric Improvements: {total_improvements}</li>
<li>Total Metric Degradations: {total_degradations}</li>
<li>Net Improvement: {total_improvements - total_degradations}</li>
</ul>
</div>
"""
html += """
</div>
"""
return html
def _create_model_section(self, model_type: str) -> str:
"""Create detailed section for a model type"""
html = f"""
<div class="section">
<h2>📈 {model_type.title()} Model Results</h2>
"""
model_results = self.results[model_type]
for dataset in model_results.get('datasets_tested', []):
if dataset in model_results['comparisons']:
comparison = model_results['comparisons'][dataset]
html += f"""
<h3>Dataset: {dataset}</h3>
<table class="metrics-table">
<thead>
<tr>
<th>Metric</th>
<th>Baseline</th>
<th>Fine-tuned</th>
<th>Change</th>
<th>Status</th>
</tr>
</thead>
<tbody>
"""
# Add improvements
for metric, data in comparison.get('improvements', {}).items():
html += f"""
<tr>
<td>{metric}</td>
<td>{data['baseline']:.4f}</td>
<td>{data['finetuned']:.4f}</td>
<td class="improvement">+{data['improvement_pct']:.2f}%</td>
<td class="improvement">✅ Improved</td>
</tr>
"""
# Add degradations
for metric, data in comparison.get('degradations', {}).items():
html += f"""
<tr>
<td>{metric}</td>
<td>{data['baseline']:.4f}</td>
<td>{data['finetuned']:.4f}</td>
<td class="degradation">-{data['degradation_pct']:.2f}%</td>
<td class="degradation">❌ Degraded</td>
</tr>
"""
html += """
</tbody>
</table>
"""
html += """
</div>
"""
return html
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Comprehensive validation of BGE fine-tuned models",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Validate both retriever and reranker
python scripts/comprehensive_validation.py \\
--retriever_finetuned ./output/bge-m3-enhanced/final_model \\
--reranker_finetuned ./output/bge-reranker/final_model
# Validate only retriever
python scripts/comprehensive_validation.py \\
--retriever_finetuned ./output/bge-m3-enhanced/final_model
# Custom datasets and output
python scripts/comprehensive_validation.py \\
--retriever_finetuned ./output/bge-m3-enhanced/final_model \\
--test_datasets data/my_test1.jsonl data/my_test2.jsonl \\
--output_dir ./my_validation_results
"""
)
# Model paths
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
help="Path to baseline retriever model")
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
help="Path to baseline reranker model")
parser.add_argument("--retriever_finetuned", type=str, default=None,
help="Path to fine-tuned retriever model")
parser.add_argument("--reranker_finetuned", type=str, default=None,
help="Path to fine-tuned reranker model")
# Test data
parser.add_argument("--test_datasets", type=str, nargs="+", default=None,
help="Paths to test datasets (JSONL format)")
# Evaluation settings
parser.add_argument("--batch_size", type=int, default=32,
help="Batch size for evaluation")
parser.add_argument("--max_length", type=int, default=512,
help="Maximum sequence length")
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 3, 5, 10, 20],
help="K values for evaluation metrics")
# Output settings
parser.add_argument("--output_dir", type=str, default="./validation_results",
help="Directory to save validation results")
parser.add_argument("--save_embeddings", action="store_true",
help="Save embeddings for further analysis")
parser.add_argument("--no_report", action="store_true",
help="Skip HTML report generation")
# Pipeline settings
parser.add_argument("--retrieval_top_k", type=int, default=100,
help="Top-k for retrieval stage")
parser.add_argument("--rerank_top_k", type=int, default=10,
help="Top-k for reranking stage")
return parser.parse_args()
def main():
"""Main validation function"""
args = parse_args()
# Validate arguments
if not args.retriever_finetuned and not args.reranker_finetuned:
print("❌ Error: At least one fine-tuned model must be specified")
print(" Use --retriever_finetuned or --reranker_finetuned")
return 1
# Create configuration
config = ValidationConfig(
retriever_baseline=args.retriever_baseline,
reranker_baseline=args.reranker_baseline,
retriever_finetuned=args.retriever_finetuned,
reranker_finetuned=args.reranker_finetuned,
test_datasets=args.test_datasets,
batch_size=args.batch_size,
max_length=args.max_length,
k_values=args.k_values,
output_dir=args.output_dir,
save_embeddings=args.save_embeddings,
create_report=not args.no_report,
retrieval_top_k=args.retrieval_top_k,
rerank_top_k=args.rerank_top_k
)
# Run validation
try:
validator = ComprehensiveValidator(config)
results = validator.validate_all()
print("\n" + "=" * 80)
print("🎉 COMPREHENSIVE VALIDATION COMPLETED!")
print("=" * 80)
print(f"📊 Results saved to: {config.output_dir}")
print("\n📈 Quick Summary:")
# Print quick summary
for model_type in ['retriever', 'reranker', 'pipeline']:
if model_type in results:
model_results = results[model_type]
total_datasets = len(model_results.get('datasets_tested', []))
total_improvements = sum(len(comp.get('improvements', {}))
for comp in model_results.get('comparisons', {}).values())
total_degradations = sum(len(comp.get('degradations', {}))
for comp in model_results.get('comparisons', {}).values())
status = "" if total_improvements > total_degradations else "⚠️" if total_improvements == total_degradations else ""
print(f" {status} {model_type.title()}: {total_improvements} improvements, {total_degradations} degradations across {total_datasets} datasets")
return 0
except KeyboardInterrupt:
print("\n❌ Validation interrupted by user")
return 1
except Exception as e:
print(f"\n❌ Validation failed: {e}")
return 1
if __name__ == "__main__":
exit(main())

291
scripts/convert_m3_dataset.py Normal file

@@ -0,0 +1,291 @@
#!/usr/bin/env python3
"""
BGE-M3 Dataset Format Converter
Converts the nested dataset format where training data is stored inside a "data" field
to the flat format expected by the BGE-M3 training pipeline.
Input format:
{
"query": "...",
"data": "{\"pos\": [...], \"neg\": [...], \"pos_scores\": [...], ...}"
}
Output format:
{
"query": "...",
"pos": [...],
"neg": [...],
"pos_scores": [...],
"neg_scores": [...],
"prompt": "...",
"type": "..."
}
"""
import json
import argparse
import logging
import os
from tqdm import tqdm
from typing import Dict, Any, Optional
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def extract_json_from_data_field(data_field: str) -> Optional[Dict[str, Any]]:
"""
Extract JSON data from the "data" field which contains a JSON string
wrapped in markdown code blocks.
Args:
data_field: String containing JSON data, possibly wrapped in markdown
Returns:
Parsed JSON dictionary or None if parsing fails
"""
try:
# Remove markdown code block markers if present
cleaned_data = data_field.strip()
# Remove ```json and ``` markers
if cleaned_data.startswith('```json'):
cleaned_data = cleaned_data[7:] # Remove ```json
elif cleaned_data.startswith('```'):
cleaned_data = cleaned_data[3:] # Remove ```
if cleaned_data.endswith('```'):
cleaned_data = cleaned_data[:-3] # Remove trailing ```
# Clean up any remaining whitespace
cleaned_data = cleaned_data.strip()
# Parse the JSON
return json.loads(cleaned_data)
except (json.JSONDecodeError, AttributeError) as e:
logger.warning(f"Failed to parse JSON from data field: {e}")
logger.warning(f"Data content (first 200 chars): {data_field[:200]}")
return None
def convert_dataset_format(input_file: str, output_file: str) -> None:
"""
Convert dataset from nested format to flat format.
Args:
input_file: Path to input JSONL file with nested format
output_file: Path to output JSONL file with flat format
"""
logger.info(f"Converting dataset from {input_file} to {output_file}")
# Count total lines for progress bar
total_lines = 0
with open(input_file, 'r', encoding='utf-8') as f:
for _ in f:
total_lines += 1
converted_count = 0
error_count = 0
with open(input_file, 'r', encoding='utf-8') as infile, \
open(output_file, 'w', encoding='utf-8') as outfile:
for line_num, line in enumerate(tqdm(infile, total=total_lines, desc="Converting"), 1):
line = line.strip()
if not line:
continue
try:
# Parse the input JSON line
item = json.loads(line)
# Validate required fields
if 'query' not in item:
logger.warning(f"Line {line_num}: Missing 'query' field")
error_count += 1
continue
if 'data' not in item:
logger.warning(f"Line {line_num}: Missing 'data' field")
error_count += 1
continue
# Extract data from the nested field
data_content = extract_json_from_data_field(item['data'])
if data_content is None:
logger.warning(f"Line {line_num}: Failed to parse data field")
error_count += 1
continue
# Create the flattened format
flattened_item = {
'query': item['query']
}
# Add all fields from the data content
expected_fields = ['pos', 'neg', 'pos_scores', 'neg_scores', 'prompt', 'type']
for field in expected_fields:
if field in data_content:
flattened_item[field] = data_content[field]
else:
# Provide reasonable defaults
if field == 'pos':
flattened_item[field] = []
elif field == 'neg':
flattened_item[field] = []
elif field == 'pos_scores':
flattened_item[field] = [1.0] * len(data_content.get('pos', []))
elif field == 'neg_scores':
flattened_item[field] = [0.0] * len(data_content.get('neg', []))
elif field == 'prompt':
flattened_item[field] = '为此查询生成表示:'
elif field == 'type':
flattened_item[field] = 'normal'
# Validate the converted item
if not flattened_item['pos'] and not flattened_item['neg']:
logger.warning(f"Line {line_num}: No positive or negative passages found")
error_count += 1
continue
# Ensure scores arrays match passage arrays
if len(flattened_item['pos_scores']) != len(flattened_item['pos']):
flattened_item['pos_scores'] = [1.0] * len(flattened_item['pos'])
if len(flattened_item['neg_scores']) != len(flattened_item['neg']):
flattened_item['neg_scores'] = [0.0] * len(flattened_item['neg'])
# Write the flattened item
outfile.write(json.dumps(flattened_item, ensure_ascii=False) + '\n')
converted_count += 1
except json.JSONDecodeError as e:
logger.warning(f"Line {line_num}: Invalid JSON - {e}")
error_count += 1
continue
except Exception as e:
logger.warning(f"Line {line_num}: Unexpected error - {e}")
error_count += 1
continue
logger.info(f"Conversion complete!")
logger.info(f"Successfully converted: {converted_count} items")
logger.info(f"Errors encountered: {error_count} items")
logger.info(f"Output saved to: {output_file}")
def validate_converted_file(file_path: str, sample_size: int = 5) -> None:
"""
Validate the converted file by showing a few sample entries.
Args:
file_path: Path to the converted JSONL file
sample_size: Number of sample entries to display
"""
logger.info(f"Validating converted file: {file_path}")
try:
with open(file_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
total_lines = len(lines)
logger.info(f"Total converted entries: {total_lines}")
if total_lines == 0:
logger.warning("No entries found in converted file!")
return
# Show sample entries
sample_indices = list(range(min(sample_size, total_lines)))
for i, idx in enumerate(sample_indices):
try:
item = json.loads(lines[idx])
logger.info(f"\nSample {i+1} (line {idx+1}):")
logger.info(f" Query: {item['query'][:100]}...")
logger.info(f" Positives: {len(item.get('pos', []))}")
logger.info(f" Negatives: {len(item.get('neg', []))}")
logger.info(f" Type: {item.get('type', 'N/A')}")
logger.info(f" Prompt: {item.get('prompt', 'N/A')}")
# Show first positive and negative if available
if item.get('pos'):
logger.info(f" First positive: {item['pos'][0][:150]}...")
if item.get('neg'):
logger.info(f" First negative: {item['neg'][0][:150]}...")
except json.JSONDecodeError as e:
logger.warning(f"Invalid JSON at line {idx+1}: {e}")
except Exception as e:
logger.error(f"Error validating file: {e}")
def main():
parser = argparse.ArgumentParser(
description="Convert BGE-M3 dataset from nested to flat format",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python convert_m3_dataset.py data/datasets/三国演义/m3_train.jsonl data/datasets/三国演义/m3_train_converted.jsonl
python convert_m3_dataset.py input.jsonl output.jsonl --validate
"""
)
parser.add_argument(
'input_file',
help='Path to input JSONL file with nested format'
)
parser.add_argument(
'output_file',
help='Path to output JSONL file with flat format'
)
parser.add_argument(
'--validate',
action='store_true',
help='Validate the converted file and show samples'
)
parser.add_argument(
'--sample-size',
type=int,
default=5,
help='Number of sample entries to show during validation (default: 5)'
)
args = parser.parse_args()
# Check if input file exists
if not os.path.exists(args.input_file):
logger.error(f"Input file not found: {args.input_file}")
return 1
# Create output directory if it doesn't exist
output_dir = os.path.dirname(args.output_file)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
logger.info(f"Created output directory: {output_dir}")
try:
# Convert the dataset
convert_dataset_format(args.input_file, args.output_file)
# Validate if requested
if args.validate:
validate_converted_file(args.output_file, args.sample_size)
logger.info("Dataset conversion completed successfully!")
return 0
except Exception as e:
logger.error(f"Dataset conversion failed: {e}")
return 1
if __name__ == '__main__':
exit(main())

584
scripts/evaluate.py Normal file

@@ -0,0 +1,584 @@
"""
Evaluation script for fine-tuned BGE models
"""
import os
import sys
import argparse
import logging
import json
import torch
from transformers import AutoTokenizer
import numpy as np
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
from data.preprocessing import DataPreprocessor
from utils.logging import setup_logging, get_logger
logger = logging.getLogger(__name__)
def parse_args():
"""Parse command line arguments for evaluation."""
parser = argparse.ArgumentParser(description="Evaluate fine-tuned BGE models")
# Model arguments
parser.add_argument("--retriever_model", type=str, required=True,
help="Path to fine-tuned retriever model")
parser.add_argument("--reranker_model", type=str, default=None,
help="Path to fine-tuned reranker model")
parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
help="Base retriever model for comparison")
parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
help="Base reranker model for comparison")
# Data arguments
parser.add_argument("--eval_data", type=str, required=True,
help="Path to evaluation data file")
parser.add_argument("--corpus_data", type=str, default=None,
help="Path to corpus data for retrieval evaluation")
parser.add_argument("--data_format", type=str, default="auto",
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
help="Format of input data")
# Evaluation arguments
parser.add_argument("--output_dir", type=str, default="./evaluation_results",
help="Directory to save evaluation results")
parser.add_argument("--batch_size", type=int, default=32,
help="Batch size for evaluation")
parser.add_argument("--max_length", type=int, default=512,
help="Maximum sequence length")
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
help="K values for evaluation metrics")
# Evaluation modes
parser.add_argument("--eval_retriever", action="store_true",
help="Evaluate retriever model")
parser.add_argument("--eval_reranker", action="store_true",
help="Evaluate reranker model")
parser.add_argument("--eval_pipeline", action="store_true",
help="Evaluate full retrieval pipeline")
parser.add_argument("--compare_baseline", action="store_true",
help="Compare with baseline models")
parser.add_argument("--interactive", action="store_true",
help="Run interactive evaluation")
# Pipeline settings
parser.add_argument("--retrieval_top_k", type=int, default=100,
help="Top-k for initial retrieval")
parser.add_argument("--rerank_top_k", type=int, default=10,
help="Top-k for reranking")
# Other
parser.add_argument("--device", type=str, default="auto",
help="Device to use (auto, cpu, cuda)")
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
args = parser.parse_args()
# Validation
if not any([args.eval_retriever, args.eval_reranker, args.eval_pipeline, args.interactive]):
args.eval_pipeline = True # Default to pipeline evaluation
return args
def load_evaluation_data(data_path: str, data_format: str):
"""Load evaluation data from a file."""
preprocessor = DataPreprocessor()
if data_format == "auto":
data_format = preprocessor._detect_format(data_path)
if data_format == "jsonl":
data = preprocessor._load_jsonl(data_path)
elif data_format == "json":
data = preprocessor._load_json(data_path)
elif data_format in ["csv", "tsv"]:
data = preprocessor._load_csv(data_path, delimiter="," if data_format == "csv" else "\t")
elif data_format == "dureader":
data = preprocessor._load_dureader(data_path)
elif data_format == "msmarco":
data = preprocessor._load_msmarco(data_path)
else:
raise ValueError(f"Unsupported format: {data_format}")
return data
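# Minimal sketch of the per-example layout the evaluation functions below expect (derived from
# how they access the loaded items; only 'query' is mandatory, while 'pos'/'neg' enable the
# reranker and pipeline evaluations):
#   {"query": "...", "pos": ["relevant passage 1", "relevant passage 2"], "neg": ["irrelevant passage"]}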
def evaluate_retriever(args, retriever_model, tokenizer, eval_data):
"""Evaluate the retriever model."""
logger.info("Evaluating retriever model...")
evaluator = RetrievalEvaluator(
model=retriever_model,
tokenizer=tokenizer,
device=args.device,
k_values=args.k_values
)
# Prepare data for evaluation
queries = [item['query'] for item in eval_data]
# If we have passage data, create passage corpus
all_passages = set()
for item in eval_data:
if 'pos' in item:
all_passages.update(item['pos'])
if 'neg' in item:
all_passages.update(item['neg'])
passages = list(all_passages)
# Create relevance mapping
relevant_docs = {}
for i, item in enumerate(eval_data):
query = item['query']
if 'pos' in item:
relevant_indices = []
for pos in item['pos']:
if pos in passages:
relevant_indices.append(passages.index(pos))
if relevant_indices:
relevant_docs[query] = relevant_indices
# Evaluate
if passages and relevant_docs:
metrics = evaluator.evaluate_retrieval_pipeline(
queries=queries,
corpus=passages,
relevant_docs=relevant_docs,
batch_size=args.batch_size
)
else:
logger.warning("No valid corpus or relevance data found. Running basic evaluation.")
# Create dummy evaluation
metrics = {"note": "Limited evaluation due to data format"}
return metrics
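# Worked example of the corpus/relevance construction above: given the two items
#   {"query": "q1", "pos": ["p1"], "neg": ["p2"]} and {"query": "q2", "pos": ["p2"]}
# the deduplicated corpus could be passages = ["p1", "p2"] (set order is arbitrary), and
# relevant_docs maps each query to corpus indices, e.g. {"q1": [0], "q2": [1]}.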
def evaluate_reranker(args, reranker_model, tokenizer, eval_data):
"""Evaluate the reranker model."""
logger.info("Evaluating reranker model...")
evaluator = RetrievalEvaluator(
model=reranker_model,
tokenizer=tokenizer,
device=args.device,
k_values=args.k_values
)
# Prepare data for reranker evaluation
total_queries = 0
total_correct = 0
total_mrr = 0
for item in eval_data:
if 'pos' not in item or 'neg' not in item:
continue
query = item['query']
positives = item['pos']
negatives = item['neg']
if not positives or not negatives:
continue
# Create candidates (1 positive + negatives)
candidates = positives[:1] + negatives
# Score all candidates
pairs = [(query, candidate) for candidate in candidates]
scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
# Check if positive is ranked first
best_idx = np.argmax(scores)
if best_idx == 0: # First is positive
total_correct += 1
total_mrr += 1.0
else:
# Find rank of positive (it's always at index 0)
sorted_indices = np.argsort(scores)[::-1]
positive_rank = np.where(sorted_indices == 0)[0][0] + 1
total_mrr += 1.0 / positive_rank
total_queries += 1
if total_queries > 0:
metrics = {
'accuracy@1': total_correct / total_queries,
'mrr': total_mrr / total_queries,
'total_queries': total_queries
}
else:
metrics = {"note": "No valid query-passage pairs found for reranker evaluation"}
return metrics
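# Worked example of the accuracy@1 / MRR bookkeeping above: the positive passage is always
# candidate index 0, so with scores [0.2, 0.9, 0.4] the descending order is [1, 2, 0], the
# positive lands at rank 3, and the query contributes 0 to accuracy@1 and 1/3 to the MRR sum.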
def evaluate_pipeline(args, retriever_model, reranker_model, tokenizer, eval_data):
"""Evaluate the full retrieval + reranking pipeline."""
logger.info("Evaluating full pipeline...")
# Prepare corpus
all_passages = set()
for item in eval_data:
if 'pos' in item:
all_passages.update(item['pos'])
if 'neg' in item:
all_passages.update(item['neg'])
passages = list(all_passages)
if not passages:
return {"note": "No passages found for pipeline evaluation"}
# Pre-encode corpus with retriever
logger.info("Encoding corpus...")
corpus_embeddings = retriever_model.encode(
passages,
batch_size=args.batch_size,
return_numpy=True,
device=args.device
)
total_queries = 0
pipeline_metrics = {
'retrieval_recall@10': [],
'rerank_accuracy@1': [],
'rerank_mrr': []
}
for item in eval_data:
if 'pos' not in item or not item['pos']:
continue
query = item['query']
relevant_passages = item['pos']
# Step 1: Retrieval
query_embedding = retriever_model.encode(
[query],
return_numpy=True,
device=args.device
)
# Compute similarities
similarities = np.dot(query_embedding, corpus_embeddings.T).squeeze()
# Get top-k for retrieval
top_k_indices = np.argsort(similarities)[::-1][:args.retrieval_top_k]
retrieved_passages = [passages[i] for i in top_k_indices]
# Check retrieval recall
retrieved_relevant = sum(1 for p in retrieved_passages[:10] if p in relevant_passages)
retrieval_recall = retrieved_relevant / min(len(relevant_passages), 10)
pipeline_metrics['retrieval_recall@10'].append(retrieval_recall)
# Step 2: Reranking (if reranker available)
if reranker_model is not None:
# Rerank top candidates
rerank_candidates = retrieved_passages[:args.rerank_top_k]
pairs = [(query, candidate) for candidate in rerank_candidates]
if pairs:
rerank_scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
reranked_indices = np.argsort(rerank_scores)[::-1]
reranked_passages = [rerank_candidates[i] for i in reranked_indices]
# Check reranking performance
if reranked_passages[0] in relevant_passages:
pipeline_metrics['rerank_accuracy@1'].append(1.0)
pipeline_metrics['rerank_mrr'].append(1.0)
else:
pipeline_metrics['rerank_accuracy@1'].append(0.0)
# Find MRR
for rank, passage in enumerate(reranked_passages, 1):
if passage in relevant_passages:
pipeline_metrics['rerank_mrr'].append(1.0 / rank)
break
else:
pipeline_metrics['rerank_mrr'].append(0.0)
total_queries += 1
# Average metrics
final_metrics = {}
for metric, values in pipeline_metrics.items():
if values:
final_metrics[metric] = np.mean(values)
final_metrics['total_queries'] = total_queries
return final_metrics
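# Note on the retrieval step above: np.dot(query_embedding, corpus_embeddings.T) only equals
# cosine similarity when encode() returns L2-normalized vectors, which BGE-M3 dense embeddings
# typically are. If a model variant returned unnormalized embeddings, an explicit normalization
# step would be needed, e.g. (sketch):
#   q = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
#   c = corpus_embeddings / np.linalg.norm(corpus_embeddings, axis=-1, keepdims=True)
#   similarities = np.dot(q, c.T).squeeze()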
def run_interactive_evaluation(args, retriever_model, reranker_model, tokenizer, eval_data):
"""Run interactive evaluation."""
logger.info("Starting interactive evaluation...")
# Prepare corpus
all_passages = set()
for item in eval_data:
if 'pos' in item:
all_passages.update(item['pos'])
if 'neg' in item:
all_passages.update(item['neg'])
passages = list(all_passages)
if not passages:
print("No passages found for interactive evaluation")
return
# Create interactive evaluator
evaluator = InteractiveEvaluator(
retriever=retriever_model,
reranker=reranker_model,
tokenizer=tokenizer,
corpus=passages,
device=args.device
)
print("\n=== Interactive Evaluation ===")
print("Enter queries to test the retrieval system.")
print("Type 'quit' to exit, 'sample' to test with sample queries.")
# Load some sample queries
sample_queries = [item['query'] for item in eval_data[:5]]
while True:
try:
user_input = input("\nEnter query: ").strip()
if user_input.lower() in ['quit', 'exit', 'q']:
break
elif user_input.lower() == 'sample':
print("\nSample queries:")
for i, query in enumerate(sample_queries):
print(f"{i + 1}. {query}")
continue
elif user_input.isdigit() and 1 <= int(user_input) <= len(sample_queries):
query = sample_queries[int(user_input) - 1]
else:
query = user_input
if not query:
continue
# Search
results = evaluator.search(query, top_k=10, rerank=reranker_model is not None)
print(f"\nResults for: {query}")
print("-" * 50)
for i, (idx, score, text) in enumerate(results):
print(f"{i + 1}. (Score: {score:.3f}) {text[:100]}...")
except KeyboardInterrupt:
print("\nExiting interactive evaluation...")
break
except Exception as e:
print(f"Error: {e}")
def main():
"""Run evaluation with robust error handling and config-driven defaults."""
import argparse
parser = argparse.ArgumentParser(description="Evaluate BGE models.")
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
# Model arguments
parser.add_argument("--retriever_model", type=str, required=True,
help="Path to fine-tuned retriever model")
parser.add_argument("--reranker_model", type=str, default=None,
help="Path to fine-tuned reranker model")
parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
help="Base retriever model for comparison")
parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
help="Base reranker model for comparison")
# Data arguments
parser.add_argument("--eval_data", type=str, required=True,
help="Path to evaluation data file")
parser.add_argument("--corpus_data", type=str, default=None,
help="Path to corpus data for retrieval evaluation")
parser.add_argument("--data_format", type=str, default="auto",
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
help="Format of input data")
# Evaluation arguments
parser.add_argument("--output_dir", type=str, default="./evaluation_results",
help="Directory to save evaluation results")
parser.add_argument("--batch_size", type=int, default=32,
help="Batch size for evaluation")
parser.add_argument("--max_length", type=int, default=512,
help="Maximum sequence length")
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
help="K values for evaluation metrics")
# Evaluation modes
parser.add_argument("--eval_retriever", action="store_true",
help="Evaluate retriever model")
parser.add_argument("--eval_reranker", action="store_true",
help="Evaluate reranker model")
parser.add_argument("--eval_pipeline", action="store_true",
help="Evaluate full retrieval pipeline")
parser.add_argument("--compare_baseline", action="store_true",
help="Compare with baseline models")
parser.add_argument("--interactive", action="store_true",
help="Run interactive evaluation")
# Pipeline settings
parser.add_argument("--retrieval_top_k", type=int, default=100,
help="Top-k for initial retrieval")
parser.add_argument("--rerank_top_k", type=int, default=10,
help="Top-k for reranking")
# Other
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
args = parser.parse_args()
# Load config and set defaults
from utils.config_loader import get_config
config = get_config()
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
# Setup logging
setup_logging(
output_dir=args.output_dir,
log_level=logging.INFO,
log_to_console=True,
log_to_file=True
)
logger = get_logger(__name__)
logger.info("Starting model evaluation")
logger.info(f"Arguments: {args}")
logger.info(f"Using device: {device}")
# Set device
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
elif device == "npu":
npu = getattr(torch, "npu", None)
if npu is not None and hasattr(npu, "is_available") and npu.is_available():
device = "npu"
else:
logger.warning("NPU (Ascend) is not available. Falling back to CPU.")
device = "cpu"
logger.info(f"Using device: {device}")
# Load models
logger.info("Loading models...")
    model_source = config.get('model_paths.source', 'huggingface')
# Load retriever
try:
retriever_model = BGEM3Model.from_pretrained(args.retriever_model)
retriever_tokenizer = None # Already loaded inside BGEM3Model if needed
retriever_model.to(device)
retriever_model.eval()
logger.info(f"Loaded retriever from {args.retriever_model} (source: {model_source})")
except Exception as e:
logger.error(f"Failed to load retriever: {e}")
return
# Load reranker (optional)
reranker_model = None
reranker_tokenizer = None
if args.reranker_model:
try:
reranker_model = BGERerankerModel(args.reranker_model)
reranker_tokenizer = None # Already loaded inside BGERerankerModel if needed
reranker_model.to(device)
reranker_model.eval()
logger.info(f"Loaded reranker from {args.reranker_model} (source: {model_source})")
except Exception as e:
logger.warning(f"Failed to load reranker: {e}")
# Load evaluation data
logger.info("Loading evaluation data...")
eval_data = load_evaluation_data(args.eval_data, args.data_format)
logger.info(f"Loaded {len(eval_data)} evaluation examples")
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Run evaluations
all_results = {}
if args.eval_retriever:
retriever_metrics = evaluate_retriever(args, retriever_model, retriever_tokenizer, eval_data)
all_results['retriever'] = retriever_metrics
logger.info(f"Retriever metrics: {retriever_metrics}")
if args.eval_reranker and reranker_model:
reranker_metrics = evaluate_reranker(args, reranker_model, reranker_tokenizer, eval_data)
all_results['reranker'] = reranker_metrics
logger.info(f"Reranker metrics: {reranker_metrics}")
if args.eval_pipeline:
pipeline_metrics = evaluate_pipeline(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
all_results['pipeline'] = pipeline_metrics
logger.info(f"Pipeline metrics: {pipeline_metrics}")
# Compare with baseline if requested
if args.compare_baseline:
logger.info("Loading baseline models for comparison...")
try:
base_retriever = BGEM3Model(args.base_retriever)
base_retriever.to(device)
base_retriever.eval()
base_retriever_metrics = evaluate_retriever(args, base_retriever, retriever_tokenizer, eval_data)
all_results['baseline_retriever'] = base_retriever_metrics
logger.info(f"Baseline retriever metrics: {base_retriever_metrics}")
if reranker_model:
base_reranker = BGERerankerModel(args.base_reranker)
base_reranker.to(device)
base_reranker.eval()
base_reranker_metrics = evaluate_reranker(args, base_reranker, reranker_tokenizer, eval_data)
all_results['baseline_reranker'] = base_reranker_metrics
logger.info(f"Baseline reranker metrics: {base_reranker_metrics}")
except Exception as e:
logger.warning(f"Failed to load baseline models: {e}")
# Save results
results_file = os.path.join(args.output_dir, "evaluation_results.json")
with open(results_file, 'w') as f:
json.dump(all_results, f, indent=2)
logger.info(f"Evaluation results saved to {results_file}")
# Print summary
print("\n" + "=" * 50)
print("EVALUATION SUMMARY")
print("=" * 50)
for model_type, metrics in all_results.items():
print(f"\n{model_type.upper()}:")
for metric, value in metrics.items():
if isinstance(value, (int, float)):
print(f" {metric}: {value:.4f}")
else:
print(f" {metric}: {value}")
# Interactive evaluation
if args.interactive:
run_interactive_evaluation(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
if __name__ == "__main__":
main()
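# Example invocation (illustrative; model and data paths are placeholders):
#   python scripts/evaluate.py \
#       --retriever_model ./output/bge-m3-enhanced/final_model \
#       --reranker_model ./output/bge-reranker/final_model \
#       --eval_data data/datasets/eval.jsonl \
#       --eval_pipeline --compare_baseline \
#       --output_dir ./evaluation_results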

624
scripts/quick_validation.py Normal file
View File

@@ -0,0 +1,624 @@
#!/usr/bin/env python3
"""
Quick Validation Script for BGE Fine-tuned Models
This script provides fast validation to quickly check if fine-tuned models
are working correctly and performing better than baselines.
Usage:
# Quick test both models
python scripts/quick_validation.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--reranker_model ./output/bge-reranker/final_model
# Test only retriever
python scripts/quick_validation.py \\
--retriever_model ./output/bge-m3-enhanced/final_model
"""
import os
import sys
import json
import argparse
import logging
import torch
import numpy as np
from transformers import AutoTokenizer
from datetime import datetime
from pathlib import Path
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from utils.config_loader import get_config
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
class QuickValidator:
"""Quick validation for BGE models"""
def __init__(self):
self.device = self._get_device()
self.test_queries = self._get_test_queries()
def _get_device(self):
"""Get optimal device"""
try:
config = get_config()
device_type = config.get('hardware', {}).get('device', 'cuda')
            if device_type == 'npu':
                try:
                    import torch_npu  # noqa: F401  (registers the Ascend NPU backend)
                    if torch.npu.is_available():
                        return torch.device("npu:0")
                except ImportError:
                    pass
if torch.cuda.is_available():
return torch.device("cuda")
        except Exception:
pass
return torch.device("cpu")
def _get_test_queries(self):
"""Get standard test queries for validation"""
return [
{
"query": "什么是深度学习?",
"relevant": [
"深度学习是机器学习的一个子领域,使用多层神经网络来学习数据表示",
"深度学习通过神经网络的多个层次来学习数据的抽象特征"
],
"irrelevant": [
"今天天气很好,阳光明媚",
"我喜欢吃苹果和香蕉"
]
},
{
"query": "如何制作红烧肉?",
"relevant": [
"红烧肉的制作需要五花肉、生抽、老抽、冰糖等材料,先炒糖色再炖煮",
"制作红烧肉的关键是掌握火候和糖色的调制"
],
"irrelevant": [
"Python是一种编程语言",
"机器学习需要大量的数据"
]
},
{
"query": "Python编程入门",
"relevant": [
"Python是一种易学易用的编程语言适合初学者入门",
"学习Python需要掌握基本语法、数据类型和控制结构"
],
"irrelevant": [
"红烧肉是一道经典的中式菜肴",
"深度学习使用神经网络进行训练"
]
}
]
def validate_retriever(self, model_path, baseline_path="BAAI/bge-m3"):
"""Quick validation of retriever model"""
print(f"\n🔍 Validating Retriever Model")
print("=" * 50)
results = {
'model_path': model_path,
'baseline_path': baseline_path,
'test_results': {},
'summary': {}
}
try:
# Test fine-tuned model
print("Testing fine-tuned model...")
ft_scores = self._test_retriever_model(model_path, "Fine-tuned")
# Test baseline model
print("Testing baseline model...")
baseline_scores = self._test_retriever_model(baseline_path, "Baseline")
# Compare results
comparison = self._compare_retriever_results(baseline_scores, ft_scores)
results['test_results'] = {
'finetuned': ft_scores,
'baseline': baseline_scores,
'comparison': comparison
}
# Print summary
print(f"\n📊 Retriever Validation Summary:")
print(f" Baseline Avg Relevance Score: {baseline_scores['avg_relevant_score']:.4f}")
print(f" Fine-tuned Avg Relevance Score: {ft_scores['avg_relevant_score']:.4f}")
print(f" Improvement: {comparison['relevance_improvement']:.2f}%")
if comparison['relevance_improvement'] > 0:
print(" ✅ Fine-tuned model shows improvement!")
else:
print(" ❌ Fine-tuned model shows degradation!")
results['summary'] = {
'status': 'improved' if comparison['relevance_improvement'] > 0 else 'degraded',
'improvement_pct': comparison['relevance_improvement']
}
return results
except Exception as e:
logger.error(f"Retriever validation failed: {e}")
return None
def validate_reranker(self, model_path, baseline_path="BAAI/bge-reranker-base"):
"""Quick validation of reranker model"""
print(f"\n🏆 Validating Reranker Model")
print("=" * 50)
results = {
'model_path': model_path,
'baseline_path': baseline_path,
'test_results': {},
'summary': {}
}
try:
# Test fine-tuned model
print("Testing fine-tuned reranker...")
ft_scores = self._test_reranker_model(model_path, "Fine-tuned")
# Test baseline model
print("Testing baseline reranker...")
baseline_scores = self._test_reranker_model(baseline_path, "Baseline")
# Compare results
comparison = self._compare_reranker_results(baseline_scores, ft_scores)
results['test_results'] = {
'finetuned': ft_scores,
'baseline': baseline_scores,
'comparison': comparison
}
# Print summary
print(f"\n📊 Reranker Validation Summary:")
print(f" Baseline Accuracy: {baseline_scores['accuracy']:.4f}")
print(f" Fine-tuned Accuracy: {ft_scores['accuracy']:.4f}")
print(f" Improvement: {comparison['accuracy_improvement']:.2f}%")
if comparison['accuracy_improvement'] > 0:
print(" ✅ Fine-tuned reranker shows improvement!")
else:
print(" ❌ Fine-tuned reranker shows degradation!")
results['summary'] = {
'status': 'improved' if comparison['accuracy_improvement'] > 0 else 'degraded',
'improvement_pct': comparison['accuracy_improvement']
}
return results
except Exception as e:
logger.error(f"Reranker validation failed: {e}")
return None
def _test_retriever_model(self, model_path, model_type):
"""Test retriever model with predefined queries"""
try:
# Load model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGEM3Model.from_pretrained(model_path)
model.to(self.device)
model.eval()
print(f"{model_type} model loaded successfully")
relevant_scores = []
irrelevant_scores = []
ranking_accuracy = []
with torch.no_grad():
for test_case in self.test_queries:
query = test_case["query"]
relevant_docs = test_case["relevant"]
irrelevant_docs = test_case["irrelevant"]
# Encode query
query_inputs = tokenizer(
query,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
).to(self.device)
query_outputs = model(**query_inputs, return_dense=True)
query_emb = query_outputs['dense']
                    # Score relevant documents against the query
                    query_relevant_scores = []
                    for doc in relevant_docs:
                        doc_inputs = tokenizer(
                            doc,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)
                        doc_outputs = model(**doc_inputs, return_dense=True)
                        doc_emb = doc_outputs['dense']
                        # Cosine similarity between query and document embeddings
                        similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        query_relevant_scores.append(similarity)
                    relevant_scores.extend(query_relevant_scores)
                    # Score irrelevant documents against the query
                    query_irrelevant_scores = []
                    for doc in irrelevant_docs:
                        doc_inputs = tokenizer(
                            doc,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)
                        doc_outputs = model(**doc_inputs, return_dense=True)
                        doc_emb = doc_outputs['dense']
                        similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        query_irrelevant_scores.append(similarity)
                    irrelevant_scores.extend(query_irrelevant_scores)
                    # Ranking accuracy: relevant docs should, on average, outscore irrelevant ones.
                    # Reuse the per-query scores computed above instead of re-encoding every document.
                    ranking_accuracy.append(
                        1.0 if np.mean(query_relevant_scores) > np.mean(query_irrelevant_scores) else 0.0
                    )
return {
'avg_relevant_score': np.mean(relevant_scores),
'avg_irrelevant_score': np.mean(irrelevant_scores),
'ranking_accuracy': np.mean(ranking_accuracy),
'score_separation': np.mean(relevant_scores) - np.mean(irrelevant_scores)
}
except Exception as e:
logger.error(f"Error testing {model_type} retriever: {e}")
return None
def _test_reranker_model(self, model_path, model_type):
"""Test reranker model with predefined query-document pairs"""
try:
# Load model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGERerankerModel.from_pretrained(model_path)
model.to(self.device)
model.eval()
print(f"{model_type} reranker loaded successfully")
correct_rankings = 0
total_pairs = 0
all_scores = []
all_labels = []
with torch.no_grad():
for test_case in self.test_queries:
query = test_case["query"]
relevant_docs = test_case["relevant"]
irrelevant_docs = test_case["irrelevant"]
# Test relevant documents (positive examples)
for doc in relevant_docs:
pair_text = f"{query} [SEP] {doc}"
inputs = tokenizer(
pair_text,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
).to(self.device)
outputs = model(**inputs)
score = outputs.logits.item()
all_scores.append(score)
all_labels.append(1) # Relevant
# Check if score indicates relevance (positive score)
if score > 0:
correct_rankings += 1
total_pairs += 1
# Test irrelevant documents (negative examples)
for doc in irrelevant_docs:
pair_text = f"{query} [SEP] {doc}"
inputs = tokenizer(
pair_text,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
).to(self.device)
outputs = model(**inputs)
score = outputs.logits.item()
all_scores.append(score)
all_labels.append(0) # Irrelevant
# Check if score indicates irrelevance (negative score)
if score <= 0:
correct_rankings += 1
total_pairs += 1
return {
'accuracy': correct_rankings / total_pairs if total_pairs > 0 else 0,
'avg_positive_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 1]),
'avg_negative_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 0]),
'total_pairs': total_pairs
}
except Exception as e:
logger.error(f"Error testing {model_type} reranker: {e}")
return None
def _compare_retriever_results(self, baseline, finetuned):
"""Compare retriever results"""
if not baseline or not finetuned:
return {}
relevance_improvement = ((finetuned['avg_relevant_score'] - baseline['avg_relevant_score']) /
baseline['avg_relevant_score'] * 100) if baseline['avg_relevant_score'] != 0 else 0
separation_improvement = ((finetuned['score_separation'] - baseline['score_separation']) /
baseline['score_separation'] * 100) if baseline['score_separation'] != 0 else 0
ranking_improvement = ((finetuned['ranking_accuracy'] - baseline['ranking_accuracy']) /
baseline['ranking_accuracy'] * 100) if baseline['ranking_accuracy'] != 0 else 0
return {
'relevance_improvement': relevance_improvement,
'separation_improvement': separation_improvement,
'ranking_improvement': ranking_improvement
}
def _compare_reranker_results(self, baseline, finetuned):
"""Compare reranker results"""
if not baseline or not finetuned:
return {}
accuracy_improvement = ((finetuned['accuracy'] - baseline['accuracy']) /
baseline['accuracy'] * 100) if baseline['accuracy'] != 0 else 0
return {
'accuracy_improvement': accuracy_improvement
}
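    # Worked example of the relative-improvement formula used by both _compare_* helpers:
    # a baseline accuracy of 0.75 and a fine-tuned accuracy of 0.90 give
    # (0.90 - 0.75) / 0.75 * 100 = 20.0, i.e. a 20% relative improvement; a zero baseline
    # is reported as 0 to avoid division by zero.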
def validate_pipeline(self, retriever_path, reranker_path):
"""Quick validation of full pipeline"""
print(f"\n🚀 Validating Full Pipeline")
print("=" * 50)
try:
# Load models
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
retriever = BGEM3Model.from_pretrained(retriever_path)
retriever.to(self.device)
retriever.eval()
reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
reranker = BGERerankerModel.from_pretrained(reranker_path)
reranker.to(self.device)
reranker.eval()
print(" ✅ Pipeline models loaded successfully")
total_queries = 0
pipeline_accuracy = 0
with torch.no_grad():
for test_case in self.test_queries:
query = test_case["query"]
relevant_docs = test_case["relevant"]
irrelevant_docs = test_case["irrelevant"]
# Create candidate pool
all_docs = relevant_docs + irrelevant_docs
doc_labels = [1] * len(relevant_docs) + [0] * len(irrelevant_docs)
# Step 1: Retrieval
query_inputs = retriever_tokenizer(
query, return_tensors='pt', padding=True,
truncation=True, max_length=512
).to(self.device)
query_emb = retriever(**query_inputs, return_dense=True)['dense']
# Score all documents with retriever
doc_scores = []
for doc in all_docs:
doc_inputs = retriever_tokenizer(
doc, return_tensors='pt', padding=True,
truncation=True, max_length=512
).to(self.device)
doc_emb = retriever(**doc_inputs, return_dense=True)['dense']
score = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
doc_scores.append(score)
# Get top-k for reranking (all docs in this case)
retrieval_indices = sorted(range(len(doc_scores)),
key=lambda i: doc_scores[i], reverse=True)
# Step 2: Reranking
rerank_scores = []
rerank_labels = []
for idx in retrieval_indices:
doc = all_docs[idx]
label = doc_labels[idx]
pair_text = f"{query} [SEP] {doc}"
pair_inputs = reranker_tokenizer(
pair_text, return_tensors='pt', padding=True,
truncation=True, max_length=512
).to(self.device)
rerank_output = reranker(**pair_inputs)
rerank_score = rerank_output.logits.item()
rerank_scores.append(rerank_score)
rerank_labels.append(label)
# Check if top-ranked document is relevant
best_idx = np.argmax(rerank_scores)
if rerank_labels[best_idx] == 1:
pipeline_accuracy += 1
total_queries += 1
pipeline_acc = pipeline_accuracy / total_queries if total_queries > 0 else 0
print(f"\n📊 Pipeline Validation Summary:")
print(f" Pipeline Accuracy: {pipeline_acc:.4f}")
print(f" Queries Tested: {total_queries}")
if pipeline_acc > 0.5:
print(" ✅ Pipeline performing well!")
else:
print(" ❌ Pipeline needs improvement!")
return {
'accuracy': pipeline_acc,
'total_queries': total_queries,
'status': 'good' if pipeline_acc > 0.5 else 'needs_improvement'
}
except Exception as e:
logger.error(f"Pipeline validation failed: {e}")
return None
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Quick validation of BGE fine-tuned models",
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("--retriever_model", type=str, default=None,
help="Path to fine-tuned retriever model")
parser.add_argument("--reranker_model", type=str, default=None,
help="Path to fine-tuned reranker model")
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
help="Path to baseline retriever model")
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
help="Path to baseline reranker model")
parser.add_argument("--test_pipeline", action="store_true",
help="Test full retrieval + reranking pipeline")
parser.add_argument("--output_file", type=str, default=None,
help="Save results to JSON file")
return parser.parse_args()
def main():
"""Main validation function"""
args = parse_args()
if not args.retriever_model and not args.reranker_model:
print("❌ Error: At least one model must be specified")
print(" Use --retriever_model or --reranker_model")
return 1
print("🚀 BGE Quick Validation")
print("=" * 60)
print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
validator = QuickValidator()
results = {}
# Validate retriever
if args.retriever_model:
if os.path.exists(args.retriever_model):
retriever_results = validator.validate_retriever(
args.retriever_model, args.retriever_baseline
)
if retriever_results:
results['retriever'] = retriever_results
else:
print(f"❌ Retriever model not found: {args.retriever_model}")
# Validate reranker
if args.reranker_model:
if os.path.exists(args.reranker_model):
reranker_results = validator.validate_reranker(
args.reranker_model, args.reranker_baseline
)
if reranker_results:
results['reranker'] = reranker_results
else:
print(f"❌ Reranker model not found: {args.reranker_model}")
# Validate pipeline
if (args.test_pipeline and args.retriever_model and args.reranker_model and
os.path.exists(args.retriever_model) and os.path.exists(args.reranker_model)):
pipeline_results = validator.validate_pipeline(
args.retriever_model, args.reranker_model
)
if pipeline_results:
results['pipeline'] = pipeline_results
# Print final summary
print(f"\n{'='*60}")
print("🎯 FINAL VALIDATION SUMMARY")
print(f"{'='*60}")
overall_status = "✅ SUCCESS"
issues = []
for model_type, model_results in results.items():
if 'summary' in model_results:
status = model_results['summary'].get('status', 'unknown')
improvement = model_results['summary'].get('improvement_pct', 0)
if status == 'improved':
print(f"{model_type.title()}: Improved by {improvement:.2f}%")
elif status == 'degraded':
print(f"{model_type.title()}: Degraded by {abs(improvement):.2f}%")
issues.append(model_type)
overall_status = "⚠️ ISSUES DETECTED"
elif model_type == 'pipeline' and 'status' in model_results:
if model_results['status'] == 'good':
print(f"✅ Pipeline: Working well (accuracy: {model_results['accuracy']:.3f})")
else:
print(f"❌ Pipeline: Needs improvement (accuracy: {model_results['accuracy']:.3f})")
issues.append('pipeline')
overall_status = "⚠️ ISSUES DETECTED"
print(f"\n{overall_status}")
if issues:
print(f"⚠️ Issues detected in: {', '.join(issues)}")
print("💡 Consider running comprehensive validation for detailed analysis")
# Save results if requested
if args.output_file:
results['timestamp'] = datetime.now().isoformat()
results['overall_status'] = overall_status
results['issues'] = issues
with open(args.output_file, 'w', encoding='utf-8') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"📊 Results saved to: {args.output_file}")
return 0 if overall_status == "✅ SUCCESS" else 1
if __name__ == "__main__":
exit(main())

696
scripts/run_validation_suite.py Normal file
View File

@@ -0,0 +1,696 @@
#!/usr/bin/env python3
"""
BGE Validation Suite Runner
This script runs a complete validation suite for BGE fine-tuned models,
including multiple validation approaches and generating comprehensive reports.
Usage:
# Auto-discover models and run full suite
python scripts/run_validation_suite.py --auto-discover
# Run with specific models
python scripts/run_validation_suite.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--reranker_model ./output/bge-reranker/final_model
# Run only specific validation types
python scripts/run_validation_suite.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--validation_types quick comprehensive comparison
"""
import os
import sys
import json
import argparse
import subprocess
import logging
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Optional, Any
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from scripts.validation_utils import ModelDiscovery, TestDataManager, ValidationReportAnalyzer
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class ValidationSuiteRunner:
"""Orchestrates comprehensive validation of BGE models"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.results_dir = Path(config.get('output_dir', './validation_suite_results'))
self.results_dir.mkdir(parents=True, exist_ok=True)
# Initialize components
self.model_discovery = ModelDiscovery(config.get('workspace_root', '.'))
self.data_manager = TestDataManager(config.get('data_dir', 'data/datasets'))
# Suite results
self.suite_results = {
'start_time': datetime.now().isoformat(),
'config': config,
'validation_results': {},
'summary': {},
'issues': [],
'recommendations': []
}
def run_suite(self) -> Dict[str, Any]:
"""Run the complete validation suite"""
logger.info("🚀 Starting BGE Validation Suite")
logger.info("=" * 60)
try:
# Phase 1: Discovery and Setup
self._phase_discovery()
# Phase 2: Quick Validation
if 'quick' in self.config.get('validation_types', ['quick']):
self._phase_quick_validation()
# Phase 3: Comprehensive Validation
if 'comprehensive' in self.config.get('validation_types', ['comprehensive']):
self._phase_comprehensive_validation()
# Phase 4: Comparison Benchmarks
if 'comparison' in self.config.get('validation_types', ['comparison']):
self._phase_comparison_benchmarks()
# Phase 5: Analysis and Reporting
self._phase_analysis()
# Phase 6: Generate Final Report
self._generate_final_report()
self.suite_results['end_time'] = datetime.now().isoformat()
self.suite_results['status'] = 'completed'
logger.info("✅ Validation Suite Completed Successfully!")
except Exception as e:
logger.error(f"❌ Validation Suite Failed: {e}")
self.suite_results['status'] = 'failed'
self.suite_results['error'] = str(e)
raise
return self.suite_results
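    # Minimal usage sketch (illustrative; the config keys mirror those built in main() below,
    # and the model path is a placeholder):
    #   runner = ValidationSuiteRunner({
    #       'retriever_model': './output/bge-m3-enhanced/final_model',
    #       'validation_types': ['quick'],
    #       'output_dir': './validation_suite_results',
    #   })
    #   results = runner.run_suite()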
def _phase_discovery(self):
"""Phase 1: Discover models and datasets"""
logger.info("📡 Phase 1: Discovery and Setup")
# Auto-discover models if requested
if self.config.get('auto_discover', False):
discovered_models = self.model_discovery.find_trained_models()
# Set models from discovery if not provided
if not self.config.get('retriever_model'):
if discovered_models['retrievers']:
model_name = list(discovered_models['retrievers'].keys())[0]
self.config['retriever_model'] = discovered_models['retrievers'][model_name]['path']
logger.info(f" 📊 Auto-discovered retriever: {model_name}")
if not self.config.get('reranker_model'):
if discovered_models['rerankers']:
model_name = list(discovered_models['rerankers'].keys())[0]
self.config['reranker_model'] = discovered_models['rerankers'][model_name]['path']
logger.info(f" 🏆 Auto-discovered reranker: {model_name}")
self.suite_results['discovered_models'] = discovered_models
# Discover test datasets
discovered_datasets = self.data_manager.find_test_datasets()
self.suite_results['available_datasets'] = discovered_datasets
# Validate models exist
issues = []
if self.config.get('retriever_model'):
if not os.path.exists(self.config['retriever_model']):
issues.append(f"Retriever model not found: {self.config['retriever_model']}")
if self.config.get('reranker_model'):
if not os.path.exists(self.config['reranker_model']):
issues.append(f"Reranker model not found: {self.config['reranker_model']}")
if issues:
for issue in issues:
logger.error(f"{issue}")
self.suite_results['issues'].extend(issues)
return
logger.info(" ✅ Discovery phase completed")
def _phase_quick_validation(self):
"""Phase 2: Quick validation tests"""
logger.info("⚡ Phase 2: Quick Validation")
try:
# Run quick validation script
cmd = ['python', 'scripts/quick_validation.py']
if self.config.get('retriever_model'):
cmd.extend(['--retriever_model', self.config['retriever_model']])
if self.config.get('reranker_model'):
cmd.extend(['--reranker_model', self.config['reranker_model']])
if self.config.get('retriever_model') and self.config.get('reranker_model'):
cmd.append('--test_pipeline')
# Save results to file
quick_results_file = self.results_dir / 'quick_validation_results.json'
cmd.extend(['--output_file', str(quick_results_file)])
logger.info(f" Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
logger.info(" ✅ Quick validation completed")
# Load results
if quick_results_file.exists():
with open(quick_results_file, 'r') as f:
quick_results = json.load(f)
self.suite_results['validation_results']['quick'] = quick_results
else:
logger.error(f" ❌ Quick validation failed: {result.stderr}")
self.suite_results['issues'].append(f"Quick validation failed: {result.stderr}")
except Exception as e:
logger.error(f" ❌ Error in quick validation: {e}")
self.suite_results['issues'].append(f"Quick validation error: {e}")
def _phase_comprehensive_validation(self):
"""Phase 3: Comprehensive validation"""
logger.info("🔬 Phase 3: Comprehensive Validation")
try:
# Run comprehensive validation script
cmd = ['python', 'scripts/comprehensive_validation.py']
if self.config.get('retriever_model'):
cmd.extend(['--retriever_finetuned', self.config['retriever_model']])
if self.config.get('reranker_model'):
cmd.extend(['--reranker_finetuned', self.config['reranker_model']])
# Use specific datasets if provided
if self.config.get('test_datasets'):
cmd.extend(['--test_datasets'] + self.config['test_datasets'])
# Set output directory
comprehensive_dir = self.results_dir / 'comprehensive'
cmd.extend(['--output_dir', str(comprehensive_dir)])
logger.info(f" Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
logger.info(" ✅ Comprehensive validation completed")
# Load results
results_file = comprehensive_dir / 'validation_results.json'
if results_file.exists():
with open(results_file, 'r') as f:
comp_results = json.load(f)
self.suite_results['validation_results']['comprehensive'] = comp_results
else:
logger.error(f" ❌ Comprehensive validation failed: {result.stderr}")
self.suite_results['issues'].append(f"Comprehensive validation failed: {result.stderr}")
except Exception as e:
logger.error(f" ❌ Error in comprehensive validation: {e}")
self.suite_results['issues'].append(f"Comprehensive validation error: {e}")
def _phase_comparison_benchmarks(self):
"""Phase 4: Run comparison benchmarks"""
logger.info("📊 Phase 4: Comparison Benchmarks")
# Find suitable datasets for comparison
datasets = self.data_manager.find_test_datasets()
try:
# Run retriever comparison if model exists
if self.config.get('retriever_model'):
self._run_retriever_comparison(datasets)
# Run reranker comparison if model exists
if self.config.get('reranker_model'):
self._run_reranker_comparison(datasets)
except Exception as e:
logger.error(f" ❌ Error in comparison benchmarks: {e}")
self.suite_results['issues'].append(f"Comparison benchmarks error: {e}")
def _run_retriever_comparison(self, datasets: Dict[str, List[str]]):
"""Run retriever comparison benchmarks"""
logger.info(" 🔍 Running retriever comparison...")
# Find suitable datasets for retriever testing
test_datasets = datasets.get('retriever_datasets', []) + datasets.get('general_datasets', [])
if not test_datasets:
logger.warning(" ⚠️ No suitable datasets found for retriever comparison")
return
for dataset_path in test_datasets[:2]: # Limit to 2 datasets for speed
if not os.path.exists(dataset_path):
continue
try:
cmd = [
'python', 'scripts/compare_retriever.py',
'--finetuned_model_path', self.config['retriever_model'],
'--baseline_model_path', self.config.get('retriever_baseline', 'BAAI/bge-m3'),
'--data_path', dataset_path,
'--batch_size', str(self.config.get('batch_size', 16)),
'--max_samples', str(self.config.get('max_samples', 500)),
'--output', str(self.results_dir / f'retriever_comparison_{Path(dataset_path).stem}.txt')
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
logger.info(f" ✅ Retriever comparison on {Path(dataset_path).name}")
else:
logger.warning(f" ⚠️ Retriever comparison failed on {Path(dataset_path).name}")
except Exception as e:
logger.warning(f" ⚠️ Error comparing retriever on {dataset_path}: {e}")
def _run_reranker_comparison(self, datasets: Dict[str, List[str]]):
"""Run reranker comparison benchmarks"""
logger.info(" 🏆 Running reranker comparison...")
# Find suitable datasets for reranker testing
test_datasets = datasets.get('reranker_datasets', []) + datasets.get('general_datasets', [])
if not test_datasets:
logger.warning(" ⚠️ No suitable datasets found for reranker comparison")
return
for dataset_path in test_datasets[:2]: # Limit to 2 datasets for speed
if not os.path.exists(dataset_path):
continue
try:
cmd = [
'python', 'scripts/compare_reranker.py',
'--finetuned_model_path', self.config['reranker_model'],
'--baseline_model_path', self.config.get('reranker_baseline', 'BAAI/bge-reranker-base'),
'--data_path', dataset_path,
'--batch_size', str(self.config.get('batch_size', 16)),
'--max_samples', str(self.config.get('max_samples', 500)),
'--output', str(self.results_dir / f'reranker_comparison_{Path(dataset_path).stem}.txt')
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
logger.info(f" ✅ Reranker comparison on {Path(dataset_path).name}")
else:
logger.warning(f" ⚠️ Reranker comparison failed on {Path(dataset_path).name}")
except Exception as e:
logger.warning(f" ⚠️ Error comparing reranker on {dataset_path}: {e}")
def _phase_analysis(self):
"""Phase 5: Analyze all results"""
logger.info("📈 Phase 5: Results Analysis")
try:
# Analyze comprehensive results if available
comprehensive_dir = self.results_dir / 'comprehensive'
if comprehensive_dir.exists():
analyzer = ValidationReportAnalyzer(str(comprehensive_dir))
analysis = analyzer.analyze_results()
self.suite_results['analysis'] = analysis
logger.info(" ✅ Results analysis completed")
else:
logger.warning(" ⚠️ No comprehensive results found for analysis")
except Exception as e:
logger.error(f" ❌ Error in results analysis: {e}")
self.suite_results['issues'].append(f"Analysis error: {e}")
def _generate_final_report(self):
"""Generate final comprehensive report"""
logger.info("📝 Phase 6: Generating Final Report")
try:
# Create summary
self._create_suite_summary()
# Save suite results
suite_results_file = self.results_dir / 'validation_suite_results.json'
with open(suite_results_file, 'w', encoding='utf-8') as f:
json.dump(self.suite_results, f, indent=2, ensure_ascii=False)
# Generate HTML report
self._generate_html_report()
logger.info(f" 📊 Reports saved to: {self.results_dir}")
except Exception as e:
logger.error(f" ❌ Error generating final report: {e}")
self.suite_results['issues'].append(f"Report generation error: {e}")
def _create_suite_summary(self):
"""Create suite summary"""
summary = {
'total_validations': 0,
'successful_validations': 0,
'failed_validations': 0,
'overall_verdict': 'unknown',
'key_findings': [],
'critical_issues': []
}
# Count validations
for validation_type, results in self.suite_results.get('validation_results', {}).items():
summary['total_validations'] += 1
if results:
summary['successful_validations'] += 1
else:
summary['failed_validations'] += 1
# Determine overall verdict
if summary['successful_validations'] > 0 and summary['failed_validations'] == 0:
if self.suite_results.get('analysis', {}).get('overall_status') == 'excellent':
summary['overall_verdict'] = 'excellent'
elif self.suite_results.get('analysis', {}).get('overall_status') in ['good', 'mixed']:
summary['overall_verdict'] = 'good'
else:
summary['overall_verdict'] = 'fair'
elif summary['successful_validations'] > summary['failed_validations']:
summary['overall_verdict'] = 'partial'
else:
summary['overall_verdict'] = 'poor'
# Extract key findings from analysis
if 'analysis' in self.suite_results:
analysis = self.suite_results['analysis']
if 'recommendations' in analysis:
summary['key_findings'] = analysis['recommendations'][:5] # Top 5 recommendations
# Extract critical issues
summary['critical_issues'] = self.suite_results.get('issues', [])
self.suite_results['summary'] = summary
def _generate_html_report(self):
"""Generate HTML summary report"""
html_file = self.results_dir / 'validation_suite_report.html'
summary = self.suite_results.get('summary', {})
verdict_colors = {
'excellent': '#28a745',
'good': '#17a2b8',
'fair': '#ffc107',
'partial': '#fd7e14',
'poor': '#dc3545',
'unknown': '#6c757d'
}
verdict_color = verdict_colors.get(summary.get('overall_verdict', 'unknown'), '#6c757d')
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>BGE Validation Suite Report</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; line-height: 1.6; }}
.header {{ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white; padding: 30px; border-radius: 10px; text-align: center; }}
.verdict {{ font-size: 24px; font-weight: bold; color: {verdict_color};
margin: 20px 0; text-align: center; padding: 15px;
border: 3px solid {verdict_color}; border-radius: 10px; }}
.section {{ margin: 30px 0; }}
.metrics {{ display: flex; justify-content: space-around; margin: 20px 0; }}
.metric {{ text-align: center; padding: 15px; background: #f8f9fa;
border-radius: 8px; min-width: 120px; }}
.metric-value {{ font-size: 2em; font-weight: bold; color: #495057; }}
.metric-label {{ color: #6c757d; }}
.findings {{ background: #e3f2fd; padding: 20px; border-radius: 8px;
border-left: 4px solid #2196f3; }}
.issues {{ background: #ffebee; padding: 20px; border-radius: 8px;
border-left: 4px solid #f44336; }}
.timeline {{ margin: 20px 0; }}
.timeline-item {{ margin: 10px 0; padding: 10px; background: #f8f9fa;
border-left: 3px solid #007bff; border-radius: 0 5px 5px 0; }}
ul {{ padding-left: 20px; }}
li {{ margin: 8px 0; }}
</style>
</head>
<body>
<div class="header">
<h1>🚀 BGE Validation Suite Report</h1>
<p>Comprehensive Model Performance Analysis</p>
<p>Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</div>
<div class="verdict">
Overall Verdict: {summary.get('overall_verdict', 'UNKNOWN').upper()}
</div>
<div class="metrics">
<div class="metric">
<div class="metric-value">{summary.get('total_validations', 0)}</div>
<div class="metric-label">Total Validations</div>
</div>
<div class="metric">
<div class="metric-value">{summary.get('successful_validations', 0)}</div>
<div class="metric-label">Successful</div>
</div>
<div class="metric">
<div class="metric-value">{summary.get('failed_validations', 0)}</div>
<div class="metric-label">Failed</div>
</div>
</div>
"""
# Add key findings
if summary.get('key_findings'):
html_content += """
<div class="section">
<h2>🔍 Key Findings</h2>
<div class="findings">
<ul>
"""
for finding in summary['key_findings']:
html_content += f" <li>{finding}</li>\n"
html_content += """
</ul>
</div>
</div>
"""
# Add critical issues if any
if summary.get('critical_issues'):
html_content += """
<div class="section">
<h2>⚠️ Critical Issues</h2>
<div class="issues">
<ul>
"""
for issue in summary['critical_issues']:
html_content += f" <li>{issue}</li>\n"
html_content += """
</ul>
</div>
</div>
"""
# Add validation timeline
html_content += f"""
<div class="section">
<h2>⏱️ Validation Timeline</h2>
<div class="timeline">
<div class="timeline-item">
<strong>Suite Started:</strong> {self.suite_results.get('start_time', 'Unknown')}
</div>
"""
for validation_type in self.suite_results.get('validation_results', {}).keys():
html_content += f"""
<div class="timeline-item">
<strong>{validation_type.title()} Validation:</strong> Completed
</div>
"""
html_content += f"""
<div class="timeline-item">
<strong>Suite Completed:</strong> {self.suite_results.get('end_time', 'In Progress')}
</div>
</div>
</div>
<div class="section">
<h2>📁 Detailed Results</h2>
<p>For detailed validation results, check the following files in <code>{self.results_dir}</code>:</p>
<ul>
<li><code>validation_suite_results.json</code> - Complete suite results</li>
<li><code>quick_validation_results.json</code> - Quick validation results</li>
<li><code>comprehensive/</code> - Comprehensive validation reports</li>
<li><code>*_comparison_*.txt</code> - Model comparison results</li>
</ul>
</div>
</body>
</html>
"""
with open(html_file, 'w', encoding='utf-8') as f:
f.write(html_content)
logger.info(f" 📊 HTML report: {html_file}")
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Run comprehensive BGE validation suite",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Auto-discover models and run full suite
python scripts/run_validation_suite.py --auto-discover
# Run with specific models
python scripts/run_validation_suite.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--reranker_model ./output/bge-reranker/final_model
# Run only specific validation types
python scripts/run_validation_suite.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--validation_types quick comprehensive
"""
)
# Model paths
parser.add_argument("--retriever_model", type=str, default=None,
help="Path to fine-tuned retriever model")
parser.add_argument("--reranker_model", type=str, default=None,
help="Path to fine-tuned reranker model")
parser.add_argument("--auto_discover", action="store_true",
help="Auto-discover trained models in workspace")
# Baseline models
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
help="Baseline retriever model")
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
help="Baseline reranker model")
# Validation configuration
parser.add_argument("--validation_types", type=str, nargs="+",
default=["quick", "comprehensive", "comparison"],
choices=["quick", "comprehensive", "comparison"],
help="Types of validation to run")
parser.add_argument("--test_datasets", type=str, nargs="+", default=None,
help="Specific test datasets to use")
# Performance settings
parser.add_argument("--batch_size", type=int, default=16,
help="Batch size for evaluation")
parser.add_argument("--max_samples", type=int, default=1000,
help="Maximum samples per dataset for speed")
# Directories
parser.add_argument("--workspace_root", type=str, default=".",
help="Workspace root directory")
parser.add_argument("--data_dir", type=str, default="data/datasets",
help="Test datasets directory")
parser.add_argument("--output_dir", type=str, default="./validation_suite_results",
help="Output directory for all results")
return parser.parse_args()
def main():
"""Main function"""
args = parse_args()
# Validate input
if not args.auto_discover and not args.retriever_model and not args.reranker_model:
print("❌ Error: Must specify models or use --auto_discover")
print(" Use --retriever_model, --reranker_model, or --auto_discover")
return 1
# Create configuration
config = {
'retriever_model': args.retriever_model,
'reranker_model': args.reranker_model,
'auto_discover': args.auto_discover,
'retriever_baseline': args.retriever_baseline,
'reranker_baseline': args.reranker_baseline,
'validation_types': args.validation_types,
'test_datasets': args.test_datasets,
'batch_size': args.batch_size,
'max_samples': args.max_samples,
'workspace_root': args.workspace_root,
'data_dir': args.data_dir,
'output_dir': args.output_dir
}
# Run validation suite
try:
runner = ValidationSuiteRunner(config)
results = runner.run_suite()
# Print final summary
summary = results.get('summary', {})
verdict = summary.get('overall_verdict', 'unknown')
print("\n" + "=" * 80)
print("🎯 VALIDATION SUITE SUMMARY")
print("=" * 80)
        verdict_emojis = {
            'excellent': '🌟',
            'good': '✅',
            'fair': '👌',
            'partial': '⚠️',
            'poor': '❌',
            'unknown': '❓'
        }
print(f"{verdict_emojis.get(verdict, '')} Overall Verdict: {verdict.upper()}")
print(f"📊 Validations: {summary.get('successful_validations', 0)}/{summary.get('total_validations', 0)} successful")
if summary.get('key_findings'):
print(f"\n🔍 Key Findings:")
for i, finding in enumerate(summary['key_findings'][:3], 1):
print(f" {i}. {finding}")
if summary.get('critical_issues'):
print(f"\n⚠️ Critical Issues:")
for issue in summary['critical_issues'][:3]:
print(f"{issue}")
print(f"\n📁 Detailed results: {config['output_dir']}")
print("🌐 Open validation_suite_report.html in your browser for full report")
return 0 if verdict in ['excellent', 'good'] else 1
except KeyboardInterrupt:
print("\n❌ Validation suite interrupted by user")
return 1
except Exception as e:
print(f"\n❌ Validation suite failed: {e}")
return 1
if __name__ == "__main__":
exit(main())

401
scripts/setup.py Normal file
View File

@@ -0,0 +1,401 @@
"""
Setup script for BGE fine-tuning environment
"""
import os
import sys
import subprocess
import argparse
import shutil
import urllib.request
import json
from pathlib import Path
import logging
# Make the project root importable so that `utils` resolves when the script is run directly
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.config_loader import get_config
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def run_command(command, check=True):
"""Run a shell command"""
logger.info(f"Running: {command}")
result = subprocess.run(command, shell=True, capture_output=True, text=True)
if check and result.returncode != 0:
logger.error(f"Command failed: {command}")
logger.error(f"Error: {result.stderr}")
sys.exit(1)
return result
def check_python_version():
"""Check Python version"""
if sys.version_info < (3, 8):
logger.error("Python 3.8 or higher is required")
sys.exit(1)
logger.info(f"Python version: {sys.version}")
def check_gpu():
"""Check GPU availability"""
try:
result = run_command("nvidia-smi", check=False)
if result.returncode == 0:
logger.info("NVIDIA GPU detected")
return "cuda"
else:
logger.info("No NVIDIA GPU detected")
return "cpu"
    except Exception:
logger.info("nvidia-smi not available")
return "cpu"
def install_pytorch(device_type="cuda"):
"""Install PyTorch"""
logger.info("Installing PyTorch...")
if device_type == "cuda":
# Install CUDA version for RTX 5090
pip_command = "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128"
else:
# Install CPU version
pip_command = "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu"
run_command(pip_command)
# Verify installation
try:
import torch
logger.info(f"PyTorch version: {torch.__version__}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
logger.info(f"CUDA version: {torch.version.cuda}")
logger.info(f"GPU count: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
except ImportError:
logger.error("Failed to import PyTorch")
sys.exit(1)
def install_requirements():
"""Install requirements"""
logger.info("Installing requirements...")
requirements_file = Path(__file__).parent.parent / "requirements.txt"
if requirements_file.exists():
run_command(f"pip install -r {requirements_file}")
else:
logger.warning("requirements.txt not found, installing core packages...")
# Core packages
packages = [
"transformers>=4.36.0",
"datasets>=2.14.0",
"tokenizers>=0.15.0",
"numpy>=1.24.0",
"scipy>=1.10.0",
"scikit-learn>=1.3.0",
"pandas>=2.0.0",
"FlagEmbedding>=1.2.0",
"faiss-gpu>=1.7.4",
"rank-bm25>=0.2.2",
"accelerate>=0.25.0",
"tensorboard>=2.14.0",
"tqdm>=4.66.0",
"pyyaml>=6.0",
"jsonlines>=3.1.0",
"huggingface-hub>=0.19.0",
"jieba>=0.42.1",
"toml>=0.10.2",
"requests>=2.28.0"
]
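        # NOTE: faiss-gpu assumes a CUDA-capable environment; on CPU-only machines
        # faiss-cpu is the drop-in alternative.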
for package in packages:
run_command(f"pip install {package}")
def setup_directories():
"""Setup project directories"""
logger.info("Setting up directories...")
base_dir = Path(__file__).parent.parent
directories = [
"models",
"models/bge-m3",
"models/bge-reranker-base",
"data/raw",
"data/processed",
"output",
"logs",
"cache/data"
]
for directory in directories:
dir_path = base_dir / directory
dir_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Created directory: {dir_path}")
def download_models(download_models_flag=False):
"""Download base models from HuggingFace or ModelScope based on config.toml"""
if not download_models_flag:
logger.info("Skipping model download. Use --download-models to download base models.")
return
logger.info("Downloading base models...")
base_dir = Path(__file__).parent.parent
config = get_config()
model_source = config.get('model_paths.source', 'huggingface')
if model_source == 'huggingface':
try:
from huggingface_hub import snapshot_download
# Download BGE-M3
logger.info("Downloading BGE-M3 from HuggingFace...")
snapshot_download(
repo_id="BAAI/bge-m3",
local_dir=base_dir / "models/bge-m3",
local_dir_use_symlinks=False
)
# Download BGE-Reranker
logger.info("Downloading BGE-Reranker from HuggingFace...")
snapshot_download(
repo_id="BAAI/bge-reranker-base",
local_dir=base_dir / "models/bge-reranker-base",
local_dir_use_symlinks=False
)
logger.info("Models downloaded successfully from HuggingFace!")
except ImportError:
logger.error("huggingface_hub not available. Install it first: pip install huggingface_hub")
except Exception as e:
logger.error(f"Failed to download models from HuggingFace: {e}")
logger.info("You can download models manually or use the Hugging Face CLI:")
logger.info(" huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3")
logger.info(" huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base")
elif model_source == 'modelscope':
try:
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
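            # NOTE: the ModelScope model_ids below point at CoROM sentence-embedding /
            # cross-encoder repos rather than the BGE weights themselves; swap in the
            # corresponding BGE repositories if they are mirrored on ModelScope.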
# Download BGE-M3
logger.info("Downloading BGE-M3 from ModelScope...")
ms_snapshot_download(
model_id="damo/nlp_corom_sentence-embedding_chinese-base",
cache_dir=str(base_dir / "models/bge-m3")
)
# Download BGE-Reranker
logger.info("Downloading BGE-Reranker from ModelScope...")
ms_snapshot_download(
model_id="damo/nlp_corom_cross-encoder_chinese-base",
cache_dir=str(base_dir / "models/bge-reranker-base")
)
logger.info("Models downloaded successfully from ModelScope!")
except ImportError:
logger.error("modelscope not available. Install it first: pip install modelscope")
except Exception as e:
logger.error(f"Failed to download models from ModelScope: {e}")
logger.info("You can download models manually from ModelScope website or CLI.")
else:
logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
def create_sample_config():
"""Create sample configuration files"""
logger.info("Creating sample configuration...")
base_dir = Path(__file__).parent.parent
config_file = base_dir / "config.toml"
if not config_file.exists():
# Create default config (this would match our config.toml artifact)
sample_config = """# Configuration file for BGE fine-tuning project
[api]
# OpenAI-compatible API configuration (e.g., DeepSeek, OpenAI, etc.)
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here" # Replace with your actual API key
model = "deepseek-chat"
timeout = 30
[translation]
# Translation settings
enable_api = true
target_languages = ["en", "zh", "ja", "fr"]
max_retries = 3
fallback_to_simulation = true
[training]
# Training defaults
default_batch_size = 4
default_learning_rate = 1e-5
default_epochs = 3
default_warmup_ratio = 0.1
[model_paths]
# Default model paths
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
[data]
# Data processing settings
max_query_length = 512
max_passage_length = 8192
min_passage_length = 10
validation_split = 0.1
"""
with open(config_file, 'w') as f:
f.write(sample_config)
logger.info(f"Created sample config: {config_file}")
logger.info("Please edit config.toml to add your API keys if needed.")
def create_sample_data():
"""Create sample data files"""
logger.info("Creating sample data...")
base_dir = Path(__file__).parent.parent
sample_file = base_dir / "data/raw/sample_data.jsonl"
if not sample_file.exists():
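        # Each record follows the retrieval training schema used by the BGE dataset
        # loaders in this project: {"query": str, "pos": [passage, ...], "neg": [passage, ...]}.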
sample_data = [
{
"query": "什么是机器学习?",
"pos": ["机器学习是人工智能的一个分支,它使用算法让计算机能够从数据中学习和做出决策。"],
"neg": ["深度学习是神经网络的一种形式。", "自然语言处理是计算机科学的一个领域。"]
},
{
"query": "How does neural network work?",
"pos": [
"Neural networks are computing systems inspired by biological neural networks that constitute the brain."],
"neg": ["Machine learning algorithms can be supervised or unsupervised.",
"Data preprocessing is an important step in ML."]
},
{
"query": "Python编程语言的特点",
"pos": ["Python是一种高级编程语言具有简洁的语法和强大的功能广泛用于数据科学和AI开发。"],
"neg": ["JavaScript是一种脚本语言。", "Java是一种面向对象的编程语言。"]
}
]
with open(sample_file, 'w', encoding='utf-8') as f:
for item in sample_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
logger.info(f"Created sample data: {sample_file}")
def verify_installation():
"""Verify installation"""
logger.info("Verifying installation...")
try:
# Test imports
import torch
import transformers
import numpy as np
import pandas as pd
logger.info("✓ Core packages imported successfully")
# Test GPU
if torch.cuda.is_available():
logger.info("✓ CUDA is available")
device = torch.device("cuda")
x = torch.randn(10, 10).to(device)
logger.info("✓ GPU tensor operations work")
else:
logger.info("✓ CPU-only setup")
# Test model loading (if models exist)
base_dir = Path(__file__).parent.parent
if (base_dir / "models/bge-m3").exists():
try:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(str(base_dir / "models/bge-m3"))
logger.info("✓ BGE-M3 model files are accessible")
except Exception as e:
logger.warning(f"BGE-M3 model check failed: {e}")
logger.info("Installation verification completed!")
except ImportError as e:
logger.error(f"Import error: {e}")
logger.error("Installation may be incomplete")
sys.exit(1)
def print_usage_info():
"""Print usage information"""
print("\n" + "=" * 60)
print("BGE FINE-TUNING SETUP COMPLETED!")
print("=" * 60)
print("\nQuick Start:")
print("1. Edit config.toml to add your API keys (optional)")
print("2. Prepare your training data in JSONL format")
print("3. Run training:")
print(" python scripts/train_m3.py --train_data data/raw/your_data.jsonl")
print(" python scripts/train_reranker.py --train_data data/raw/your_data.jsonl")
print("4. Evaluate models:")
print(" python scripts/evaluate.py --retriever_model output/model --eval_data data/raw/test.jsonl")
print("\nSample data created in: data/raw/sample_data.jsonl")
print("Configuration file: config.toml")
print("\nFor more information, see the README.md file.")
print("=" * 60)
def main():
parser = argparse.ArgumentParser(description="Setup BGE fine-tuning environment")
parser.add_argument("--download-models", action="store_true",
help="Download base models from Hugging Face")
parser.add_argument("--cpu-only", action="store_true",
help="Install CPU-only version of PyTorch")
parser.add_argument("--skip-pytorch", action="store_true",
help="Skip PyTorch installation")
parser.add_argument("--skip-requirements", action="store_true",
help="Skip requirements installation")
args = parser.parse_args()
logger.info("Starting BGE fine-tuning environment setup...")
# Check Python version
check_python_version()
# Detect device type
if args.cpu_only:
device_type = "cpu"
else:
device_type = check_gpu()
# Install PyTorch
if not args.skip_pytorch:
install_pytorch(device_type)
# Install requirements
if not args.skip_requirements:
install_requirements()
# Setup directories
setup_directories()
# Download models
download_models(args.download_models)
# Create sample files
create_sample_config()
create_sample_data()
# Verify installation
verify_installation()
# Print usage info
print_usage_info()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,91 @@
#!/usr/bin/env python3
"""
Test script to verify the converted dataset can be loaded by BGE training system
"""
from data.dataset import BGEDataset
from transformers import AutoTokenizer
import logging
import sys
import os
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_dataset_loading(dataset_path: str):
"""Test that the dataset can be loaded properly"""
# Check if dataset file exists
if not os.path.exists(dataset_path):
logger.error(f"Dataset file not found: {dataset_path}")
return False
try:
# Load tokenizer
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('models/bge-m3')
# Test loading the converted dataset
logger.info("Loading dataset...")
dataset = BGEDataset(
data_path=dataset_path,
tokenizer=tokenizer,
query_max_length=64,
passage_max_length=512,
format_type='triplets'
)
logger.info(f"✅ Dataset loaded successfully!")
logger.info(f"📊 Total samples: {len(dataset)}")
logger.info(f"🔍 Detected format: {dataset.detected_format}")
# Test getting a sample
logger.info("Testing sample retrieval...")
sample = dataset[0]
logger.info(f"🎯 Sample keys: {list(sample.keys())}")
logger.info(f"🔤 Query shape: {sample['query_input_ids'].shape}")
logger.info(f"📄 Passages shape: {sample['passage_input_ids'].shape}")
logger.info(f"🏷️ Labels shape: {sample['labels'].shape}")
logger.info(f"📊 Scores shape: {sample['scores'].shape}")
# Show sample content
logger.info(f"🔍 Query text: {sample['query_text'][:100]}...")
logger.info(f"📝 First passage: {sample['passage_texts'][0][:100]}...")
logger.info(f"🏷️ Labels: {sample['labels']}")
logger.info(f"📊 Scores: {sample['scores']}")
# Test a few more samples
logger.info("Testing multiple samples...")
for i in [1, 10, 100]:
if i < len(dataset):
sample_i = dataset[i]
logger.info(f"Sample {i}: Query length={len(sample_i['query_text'])}, Passages={len([p for p in sample_i['passage_texts'] if p])}")
logger.info("✅ All tests passed! Dataset is ready for BGE-M3 training.")
return True
except Exception as e:
logger.error(f"❌ Dataset loading failed: {e}")
import traceback
traceback.print_exc()
return False
def main():
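    # Project-specific dataset path used for this smoke test; point it at your own converted JSONL.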
dataset_path = "data/datasets/三国演义/m3_train_converted.jsonl"
logger.info("🧪 Testing converted BGE-M3 dataset...")
logger.info(f"📁 Dataset path: {dataset_path}")
success = test_dataset_loading(dataset_path)
if success:
logger.info("🎉 Dataset is ready for training!")
logger.info("💡 You can now use this dataset with:")
logger.info(" python scripts/train_m3.py --data_path data/datasets/三国演义/m3_train_converted.jsonl")
else:
logger.error("💥 Dataset testing failed!")
sys.exit(1)
if __name__ == '__main__':
main()

497
scripts/train_joint.py Normal file
View File

@@ -0,0 +1,497 @@
"""
Joint training of BGE-M3 and BGE-Reranker using RocketQAv2 approach
"""
import os
import sys
import argparse
# Set tokenizer parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import logging
import json
from datetime import datetime
import torch
from transformers import AutoTokenizer, set_seed
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config.m3_config import BGEM3Config
from config.reranker_config import BGERerankerConfig
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset, JointDataset
from data.preprocessing import DataPreprocessor
from training.joint_trainer import JointTrainer
from utils.logging import setup_logging, get_logger
from utils.distributed import setup_distributed, is_main_process, set_random_seed
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Joint training of BGE-M3 and BGE-Reranker")
# Model arguments
parser.add_argument("--retriever_model", type=str, default="BAAI/bge-m3",
help="Path to retriever model")
parser.add_argument("--reranker_model", type=str, default="BAAI/bge-reranker-base",
help="Path to reranker model")
parser.add_argument("--cache_dir", type=str, default="./cache/data",
help="Directory to cache downloaded models")
# Data arguments
parser.add_argument("--train_data", type=str, required=True,
help="Path to training data file")
parser.add_argument("--eval_data", type=str, default=None,
help="Path to evaluation data file")
parser.add_argument("--data_format", type=str, default="auto",
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
help="Format of input data")
# Training arguments
parser.add_argument("--output_dir", type=str, default="./output/joint-training",
help="Directory to save the models")
parser.add_argument("--num_train_epochs", type=int, default=3,
help="Number of training epochs")
parser.add_argument("--per_device_train_batch_size", type=int, default=4,
help="Batch size per GPU/CPU for training")
parser.add_argument("--per_device_eval_batch_size", type=int, default=8,
help="Batch size per GPU/CPU for evaluation")
parser.add_argument("--gradient_accumulation_steps", type=int, default=4,
help="Number of updates steps to accumulate")
parser.add_argument("--learning_rate", type=float, default=1e-5,
help="Initial learning rate for retriever")
parser.add_argument("--reranker_learning_rate", type=float, default=6e-5,
help="Initial learning rate for reranker")
parser.add_argument("--warmup_ratio", type=float, default=0.1,
help="Ratio of warmup steps")
parser.add_argument("--weight_decay", type=float, default=0.01,
help="Weight decay")
# Joint training specific
parser.add_argument("--joint_loss_weight", type=float, default=0.3,
help="Weight for joint distillation loss")
parser.add_argument("--update_retriever_steps", type=int, default=100,
help="Update retriever every N steps")
parser.add_argument("--use_dynamic_distillation", action="store_true",
help="Use dynamic listwise distillation")
# Model configuration
parser.add_argument("--retriever_train_group_size", type=int, default=8,
help="Number of passages per query for retriever")
parser.add_argument("--reranker_train_group_size", type=int, default=16,
help="Number of passages per query for reranker")
parser.add_argument("--temperature", type=float, default=0.02,
help="Temperature for InfoNCE loss")
# BGE-M3 specific
parser.add_argument("--use_dense", action="store_true", default=True,
help="Use dense embeddings")
parser.add_argument("--use_sparse", action="store_true",
help="Use sparse embeddings")
parser.add_argument("--use_colbert", action="store_true",
help="Use ColBERT embeddings")
parser.add_argument("--query_max_length", type=int, default=512,
help="Maximum query length")
parser.add_argument("--passage_max_length", type=int, default=8192,
help="Maximum passage length")
# Reranker specific
parser.add_argument("--reranker_max_length", type=int, default=512,
help="Maximum sequence length for reranker")
# Performance
parser.add_argument("--fp16", action="store_true",
help="Use mixed precision training")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="Use gradient checkpointing")
parser.add_argument("--dataloader_num_workers", type=int, default=4,
help="Number of workers for data loading")
# Checkpointing
parser.add_argument("--save_steps", type=int, default=1000,
help="Save checkpoint every N steps")
parser.add_argument("--eval_steps", type=int, default=1000,
help="Evaluate every N steps")
parser.add_argument("--logging_steps", type=int, default=100,
help="Log every N steps")
parser.add_argument("--save_total_limit", type=int, default=3,
help="Maximum number of checkpoints to keep")
parser.add_argument("--resume_from_checkpoint", type=str, default=None,
help="Path to checkpoint to resume from")
# Distributed training
parser.add_argument("--local_rank", type=int, default=-1,
help="Local rank for distributed training")
# Other
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
parser.add_argument("--overwrite_output_dir", action="store_true",
help="Overwrite the output directory")
args = parser.parse_args()
# Validate arguments
if os.path.exists(args.output_dir) and not args.overwrite_output_dir:
raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.")
return args
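# Example invocation (illustrative paths, flags as defined above):
#   python scripts/train_joint.py \
#       --train_data data/raw/sample_data.jsonl \
#       --output_dir ./output/joint-training \
#       --use_dynamic_distillation --fp16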
def main():
# Parse arguments
args = parse_args()
# Setup logging
setup_logging(
output_dir=args.output_dir,
log_level=logging.INFO,
log_to_console=True,
log_to_file=True
)
logger = get_logger(__name__)
logger.info("Starting joint training of BGE-M3 and BGE-Reranker")
logger.info(f"Arguments: {args}")
# Set random seed
set_random_seed(args.seed, deterministic=False)
# Setup distributed training if needed
device = None
if args.local_rank != -1:
setup_distributed()
from utils.config_loader import get_config
config = get_config()
device_type = config.get('ascend.device_type', None) # type: ignore[attr-defined]
if device_type == 'npu':
try:
import torch_npu
torch_npu.npu.set_device(args.local_rank)
device = torch.device("npu", args.local_rank)
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
except ImportError:
logger.warning("torch_npu not installed, falling back to CPU")
device = torch.device("cpu")
elif torch.cuda.is_available():
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
else:
device = torch.device("cpu")
logger.info("Distributed training enabled: using CPU")
else:
from utils.config_loader import get_config
config = get_config()
device_type = config.get('ascend.device_type', None) # type: ignore[attr-defined]
if device_type == 'npu':
try:
import torch_npu
device = torch.device("npu")
logger.info("Using Ascend NPU")
except ImportError:
logger.warning("torch_npu not installed, falling back to CPU")
device = torch.device("cpu")
elif torch.cuda.is_available():
device = torch.device("cuda")
logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
else:
device = torch.device("cpu")
logger.info("Using CPU")
# Create configurations
retriever_config = BGEM3Config(
model_name_or_path=args.retriever_model,
cache_dir=args.cache_dir,
train_data_path=args.train_data,
eval_data_path=args.eval_data,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
output_dir=os.path.join(args.output_dir, "retriever"),
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
warmup_ratio=args.warmup_ratio,
weight_decay=args.weight_decay,
train_group_size=args.retriever_train_group_size,
temperature=args.temperature,
use_dense=args.use_dense,
use_sparse=args.use_sparse,
use_colbert=args.use_colbert,
fp16=args.fp16,
gradient_checkpointing=args.gradient_checkpointing,
dataloader_num_workers=args.dataloader_num_workers,
save_steps=args.save_steps,
eval_steps=args.eval_steps,
logging_steps=args.logging_steps,
save_total_limit=args.save_total_limit,
seed=args.seed,
local_rank=args.local_rank,
device=str(device)
)
reranker_config = BGERerankerConfig(
model_name_or_path=args.reranker_model,
cache_dir=args.cache_dir,
train_data_path=args.train_data,
eval_data_path=args.eval_data,
max_seq_length=args.reranker_max_length,
output_dir=os.path.join(args.output_dir, "reranker"),
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.reranker_learning_rate,
warmup_ratio=args.warmup_ratio,
weight_decay=args.weight_decay,
train_group_size=args.reranker_train_group_size,
joint_training=True, # type: ignore[call-arg]
joint_loss_weight=args.joint_loss_weight, # type: ignore[call-arg]
use_dynamic_listwise_distillation=args.use_dynamic_distillation, # type: ignore[call-arg]
update_retriever_steps=args.update_retriever_steps, # type: ignore[call-arg]
fp16=args.fp16,
gradient_checkpointing=args.gradient_checkpointing,
dataloader_num_workers=args.dataloader_num_workers,
save_steps=args.save_steps,
eval_steps=args.eval_steps,
logging_steps=args.logging_steps,
save_total_limit=args.save_total_limit,
seed=args.seed,
local_rank=args.local_rank,
device=str(device)
)
# Create joint config (use retriever config as base)
joint_config = retriever_config
joint_config.output_dir = args.output_dir
joint_config.joint_loss_weight = args.joint_loss_weight # type: ignore[attr-defined]
joint_config.use_dynamic_listwise_distillation = args.use_dynamic_distillation # type: ignore[attr-defined]
joint_config.update_retriever_steps = args.update_retriever_steps # type: ignore[attr-defined]
# Save configurations
if is_main_process():
os.makedirs(args.output_dir, exist_ok=True)
joint_config.save(os.path.join(args.output_dir, "joint_config.json"))
retriever_config.save(os.path.join(args.output_dir, "retriever_config.json"))
reranker_config.save(os.path.join(args.output_dir, "reranker_config.json"))
# Load tokenizers
logger.info("Loading tokenizers...")
retriever_tokenizer = AutoTokenizer.from_pretrained(
args.retriever_model,
cache_dir=args.cache_dir
)
reranker_tokenizer = AutoTokenizer.from_pretrained(
args.reranker_model,
cache_dir=args.cache_dir
)
# Prepare data
logger.info("Preparing data...")
preprocessor = DataPreprocessor(
tokenizer=retriever_tokenizer, # Use retriever tokenizer for preprocessing
max_query_length=args.query_max_length,
max_passage_length=args.passage_max_length
)
# Preprocess training data
train_path, val_path = preprocessor.preprocess_file(
input_path=args.train_data,
output_path=os.path.join(args.output_dir, "processed_train.jsonl"),
file_format=args.data_format,
validation_split=0.1 if args.eval_data is None else 0.0
)
# Use provided eval data or split validation
eval_path = args.eval_data if args.eval_data else val_path
# Create datasets
logger.info("Creating datasets...")
# Retriever dataset
retriever_train_dataset = BGEM3Dataset(
data_path=train_path,
tokenizer=retriever_tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
train_group_size=args.retriever_train_group_size,
is_train=True
)
# Reranker dataset
reranker_train_dataset = BGERerankerDataset(
data_path=train_path,
tokenizer=reranker_tokenizer,
max_seq_length=args.reranker_max_length, # type: ignore[call-arg]
train_group_size=args.reranker_train_group_size,
is_train=True
)
# Joint dataset
joint_train_dataset = JointDataset( # type: ignore[call-arg]
retriever_dataset=retriever_train_dataset, # type: ignore[call-arg]
reranker_dataset=reranker_train_dataset # type: ignore[call-arg]
)
# Evaluation datasets
joint_eval_dataset = None
if eval_path:
retriever_eval_dataset = BGEM3Dataset(
data_path=eval_path,
tokenizer=retriever_tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
train_group_size=args.retriever_train_group_size,
is_train=False
)
reranker_eval_dataset = BGERerankerDataset(
data_path=eval_path,
tokenizer=reranker_tokenizer,
max_seq_length=args.reranker_max_length, # type: ignore[call-arg]
train_group_size=args.reranker_train_group_size,
is_train=False
)
joint_eval_dataset = JointDataset( # type: ignore[call-arg]
retriever_dataset=retriever_eval_dataset, # type: ignore[call-arg]
reranker_dataset=reranker_eval_dataset # type: ignore[call-arg]
)
# Initialize models
logger.info("Loading models...")
retriever = BGEM3Model(
model_name_or_path=args.retriever_model,
use_dense=args.use_dense,
use_sparse=args.use_sparse,
use_colbert=args.use_colbert,
cache_dir=args.cache_dir
)
reranker = BGERerankerModel(
model_name_or_path=args.reranker_model,
cache_dir=args.cache_dir,
use_fp16=args.fp16
)
from utils.config_loader import get_config
logger.info("Retriever model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface'))
logger.info("Reranker model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface'))
# Initialize optimizers
from torch.optim import AdamW
retriever_optimizer = AdamW(
retriever.parameters(),
lr=args.learning_rate,
weight_decay=args.weight_decay
)
reranker_optimizer = AdamW(
reranker.parameters(),
lr=args.reranker_learning_rate,
weight_decay=args.weight_decay
)
# Initialize joint trainer
logger.info("Initializing joint trainer...")
trainer = JointTrainer(
config=joint_config,
retriever=retriever,
reranker=reranker,
retriever_optimizer=retriever_optimizer,
reranker_optimizer=reranker_optimizer,
device=device
)
# Create data loaders
from torch.utils.data import DataLoader
from utils.distributed import create_distributed_dataloader
if args.local_rank != -1:
train_dataloader = create_distributed_dataloader(
joint_train_dataset,
batch_size=args.per_device_train_batch_size,
shuffle=True,
num_workers=args.dataloader_num_workers,
seed=args.seed
)
eval_dataloader = None
if joint_eval_dataset:
eval_dataloader = create_distributed_dataloader(
joint_eval_dataset,
batch_size=args.per_device_eval_batch_size,
shuffle=False,
num_workers=args.dataloader_num_workers,
seed=args.seed
)
else:
train_dataloader = DataLoader(
joint_train_dataset,
batch_size=args.per_device_train_batch_size,
shuffle=True,
num_workers=args.dataloader_num_workers
)
eval_dataloader = None
if joint_eval_dataset:
eval_dataloader = DataLoader(
joint_eval_dataset,
batch_size=args.per_device_eval_batch_size,
num_workers=args.dataloader_num_workers
)
# Resume from checkpoint if specified
if args.resume_from_checkpoint:
trainer.load_checkpoint(args.resume_from_checkpoint)
# Train
logger.info("Starting joint training...")
try:
trainer.train(train_dataloader, eval_dataloader, args.num_train_epochs)
# Save final models
if is_main_process():
logger.info("Saving final models...")
final_dir = os.path.join(args.output_dir, "final_models")
trainer.save_models(final_dir)
# Save individual models
retriever_path = os.path.join(args.output_dir, "final_retriever")
reranker_path = os.path.join(args.output_dir, "final_reranker")
os.makedirs(retriever_path, exist_ok=True)
os.makedirs(reranker_path, exist_ok=True)
retriever.save_pretrained(retriever_path)
retriever_tokenizer.save_pretrained(retriever_path)
reranker.model.save_pretrained(reranker_path)
reranker_tokenizer.save_pretrained(reranker_path)
# Save training summary
summary = {
"training_args": vars(args),
"retriever_config": retriever_config.to_dict(),
"reranker_config": reranker_config.to_dict(),
"joint_config": joint_config.to_dict(),
"final_metrics": getattr(trainer, 'best_metric', None),
"total_steps": getattr(trainer, 'global_step', 0),
"training_completed": datetime.now().isoformat()
}
with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
json.dump(summary, f, indent=2)
logger.info(f"Joint training completed! Models saved to {args.output_dir}")
except Exception as e:
logger.error(f"Joint training failed with error: {e}", exc_info=True)
raise
if __name__ == "__main__":
main()

614
scripts/train_m3.py Normal file
View File

@@ -0,0 +1,614 @@
"""
Enhanced BGE-M3 embedding model training script with integrated dataset pipeline
Supports both legacy nested format and new embedding schema with automatic detection
"""
import os
import sys
import argparse
# Set tokenizer parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import logging
import json
from datetime import datetime
import torch
from transformers import AutoTokenizer, set_seed
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
from config.m3_config import BGEM3Config
from models.bge_m3 import BGEM3Model
from data.dataset import BGEM3Dataset, create_embedding_dataset, collate_embedding_batch, migrate_nested_to_flat
from data.preprocessing import DataPreprocessor
from data.augmentation import QueryAugmenter, HardNegativeAugmenter
from data.hard_negative_mining import ANCEHardNegativeMiner
from training.m3_trainer import BGEM3Trainer
from evaluation.evaluator import RetrievalEvaluator
from utils.logging import setup_logging, get_logger
from utils.distributed import setup_distributed, is_main_process, set_random_seed
from utils.config_loader import get_config, load_config_for_training, resolve_model_path
except ImportError as e:
print(f"Import error: {e}")
print("Please ensure you're running from the project root directory")
sys.exit(1)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Enhanced BGE-M3 Embedding Model Training")
# Model arguments
parser.add_argument("--model_name_or_path", type=str, default=None,
help="Path to pretrained model or model identifier")
parser.add_argument("--cache_dir", type=str, default="./cache/data",
help="Directory to cache downloaded models")
parser.add_argument("--config_path", type=str, default="./config.toml",
help="Path to configuration file")
# Data arguments - support both legacy and new schemas
parser.add_argument("--train_data", type=str, required=True,
help="Path to training data file (supports both embedding and legacy nested formats)")
parser.add_argument("--eval_data", type=str, default=None,
help="Path to evaluation data file")
parser.add_argument("--data_format", type=str, default="auto",
choices=["auto", "embedding", "legacy_nested", "jsonl", "json", "csv", "tsv"],
help="Format of input data (auto-detects by default)")
parser.add_argument("--legacy_data", type=str, default=None,
help="Path to additional legacy nested data (for backward compatibility)")
parser.add_argument("--migrate_legacy", action="store_true",
help="Migrate legacy nested format to flat format")
# Sequence length arguments
parser.add_argument("--max_seq_length", type=int, default=None,
help="Maximum sequence length")
parser.add_argument("--query_max_length", type=int, default=None,
help="Maximum query length")
parser.add_argument("--passage_max_length", type=int, default=None,
help="Maximum passage length")
parser.add_argument("--max_length", type=int, default=8192,
help="Maximum total input length (multi-granularity)")
# Training arguments
parser.add_argument("--output_dir", type=str, default="./output/bge-m3-enhanced",
help="Directory to save the model")
parser.add_argument("--num_train_epochs", type=int, default=None,
help="Number of training epochs")
parser.add_argument("--per_device_train_batch_size", type=int, default=None,
help="Batch size per GPU/CPU for training")
parser.add_argument("--per_device_eval_batch_size", type=int, default=None,
help="Batch size per GPU/CPU for evaluation")
parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
help="Number of updates steps to accumulate before performing a backward/update pass")
parser.add_argument("--learning_rate", type=float, default=None,
help="Initial learning rate")
parser.add_argument("--warmup_ratio", type=float, default=None,
help="Warmup steps as a ratio of total steps")
parser.add_argument("--weight_decay", type=float, default=None,
help="Weight decay")
# Model specific arguments - enhanced features
parser.add_argument("--train_group_size", type=int, default=8,
help="Number of passages per query in training")
parser.add_argument("--temperature", type=float, default=0.1,
help="Temperature for InfoNCE loss")
parser.add_argument("--use_hard_negatives", action="store_true",
help="Enable hard negative mining")
parser.add_argument("--use_self_distill", action="store_true",
help="Enable self-knowledge distillation")
parser.add_argument("--use_score_weighted_sampling", action="store_true", default=True,
help="Enable score-weighted positive sampling")
parser.add_argument("--query_instruction", type=str, default="",
help="Query instruction prefix")
# Enhanced sampling features
parser.add_argument("--multiple_positives", action="store_true", default=True,
help="Use multiple positives per batch for richer contrastive signals")
parser.add_argument("--hard_negative_ratio", type=float, default=0.7,
help="Ratio of hard to random negatives")
parser.add_argument("--difficulty_threshold", type=float, default=0.1,
help="Minimum difficulty for hard negatives")
# Optimization and performance
parser.add_argument("--fp16", action="store_true",
help="Use mixed precision training")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="Use gradient checkpointing to save memory")
parser.add_argument("--dataloader_num_workers", type=int, default=4,
help="Number of workers for data loading")
# Checkpointing
parser.add_argument("--save_steps", type=int, default=500,
help="Save checkpoint every N steps")
parser.add_argument("--eval_steps", type=int, default=500,
help="Evaluate every N steps")
parser.add_argument("--logging_steps", type=int, default=100,
help="Log every N steps")
parser.add_argument("--save_total_limit", type=int, default=3,
help="Maximum number of checkpoints to keep")
parser.add_argument("--resume_from_checkpoint", type=str, default=None,
help="Path to checkpoint to resume from")
# Distributed training
parser.add_argument("--local_rank", type=int, default=-1,
help="Local rank for distributed training")
# Other
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
parser.add_argument("--overwrite_output_dir", action="store_true",
help="Overwrite the output directory")
parser.add_argument("--log_level", type=str, default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging level")
# Backward compatibility flags
parser.add_argument("--legacy_support", action="store_true", default=True,
help="Enable legacy nested format support")
parser.add_argument("--use_new_dataset_api", action="store_true", default=True,
help="Use new integrated dataset API with format detection")
args = parser.parse_args()
# Validate arguments
if os.path.exists(args.output_dir) and not args.overwrite_output_dir:
raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.")
return args
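# Example invocation (illustrative paths; see parse_args above for the full flag set):
#   python scripts/train_m3.py --train_data data/raw/sample_data.jsonl \
#       --output_dir ./output/bge-m3-enhanced --use_hard_negatives --fp16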
def create_config_from_args(args):
"""Create configuration from command line arguments and config file
Parameter Precedence (highest to lowest):
1. Command-line arguments (CLI)
2. Model-specific config ([m3], [reranker])
3. [hardware] section in config.toml
4. [training] section in config.toml
5. Hardcoded defaults in code
"""
# Load base configuration from TOML file
config_dict = load_config_for_training(
config_path=args.config_path,
override_params={}
)
# Get the centralized config instance
central_config = get_config()
# Create override parameters from command line arguments
override_params = {}
# Model configuration - use new resolve_model_path function
resolved_model_path, resolved_cache_dir = resolve_model_path(
model_key='bge_m3',
config_instance=central_config,
model_name_or_path=args.model_name_or_path,
cache_dir=args.cache_dir
)
args.model_name_or_path = resolved_model_path
args.cache_dir = resolved_cache_dir
# Data configuration
data_config = config_dict.get('data', {})
if args.max_seq_length is None:
args.max_seq_length = data_config.get('max_seq_length', 512)
if args.query_max_length is None:
args.query_max_length = data_config.get('max_query_length', 64)
if args.passage_max_length is None:
args.passage_max_length = data_config.get('max_passage_length', 512)
# Get validation split from config
validation_split = data_config.get('validation_split', 0.1)
    # Training configuration ([hardware] overrides the generic [training] defaults;
    # explicit CLI values always win because only unset args are filled in here)
    training_config = config_dict.get('training', {})
    hardware_config = config_dict.get('hardware', {})
    if args.num_train_epochs is None:
        args.num_train_epochs = training_config.get('default_num_epochs', 3)
    if args.per_device_train_batch_size is None:
        args.per_device_train_batch_size = hardware_config.get(
            'per_device_train_batch_size', training_config.get('default_batch_size', 4))
    if args.per_device_eval_batch_size is None:
        args.per_device_eval_batch_size = hardware_config.get(
            'per_device_eval_batch_size', args.per_device_train_batch_size * 2)
    if args.learning_rate is None:
        args.learning_rate = training_config.get('default_learning_rate', 1e-5)
    if args.warmup_ratio is None:
        args.warmup_ratio = training_config.get('default_warmup_ratio', 0.1)
    if args.weight_decay is None:
        args.weight_decay = training_config.get('default_weight_decay', 0.01)
# Model-specific (M3) configuration - highest precedence
m3_config = config_dict.get('m3', {})
train_group_size = m3_config.get('train_group_size', args.train_group_size)
temperature = m3_config.get('temperature', args.temperature)
use_hard_negatives = m3_config.get('use_hard_negatives', args.use_hard_negatives)
use_self_distill = m3_config.get('use_self_distill', args.use_self_distill)
query_instruction_for_retrieval = m3_config.get('query_instruction_for_retrieval', args.query_instruction)
    # Hardware batch-size overrides from [hardware] are applied together with the
    # [training] defaults above, so no separate handling is needed here.
# Create BGE-M3 configuration
config = BGEM3Config(
model_name_or_path=args.model_name_or_path,
cache_dir=args.cache_dir,
train_data_path=args.train_data,
eval_data_path=args.eval_data,
max_seq_length=args.max_seq_length,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
max_length=args.max_length,
output_dir=args.output_dir,
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
warmup_ratio=args.warmup_ratio,
weight_decay=args.weight_decay,
train_group_size=train_group_size,
temperature=temperature,
use_hard_negatives=use_hard_negatives,
use_self_distill=use_self_distill,
query_instruction_for_retrieval=query_instruction_for_retrieval,
fp16=args.fp16,
gradient_checkpointing=args.gradient_checkpointing,
dataloader_num_workers=args.dataloader_num_workers,
save_steps=args.save_steps,
eval_steps=args.eval_steps,
logging_steps=args.logging_steps,
save_total_limit=args.save_total_limit,
seed=args.seed,
local_rank=args.local_rank,
overwrite_output_dir=args.overwrite_output_dir
)
return config
def setup_model_and_tokenizer(args):
"""Setup BGE-M3 model and tokenizer"""
logger.info(f"Loading BGE-M3 model from {args.model_name_or_path}")
# Detect offline mode if the path is local or HF_HUB_OFFLINE is set
local_files_only = os.path.isdir(args.model_name_or_path) or bool(os.environ.get("HF_HUB_OFFLINE"))
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
cache_dir=args.cache_dir,
trust_remote_code=True,
local_files_only=local_files_only
)
# Load model
try:
model = BGEM3Model(
model_name_or_path=args.model_name_or_path,
cache_dir=args.cache_dir
)
logger.info("Model loaded successfully")
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
return model, tokenizer
def setup_datasets(args, tokenizer):
"""Setup datasets with enhanced format detection and migration support"""
logger.info("Setting up datasets with enhanced format detection...")
# Handle legacy data migration if needed
if args.legacy_data and args.migrate_legacy:
logger.info(f"Migrating legacy data from {args.legacy_data}")
legacy_base = os.path.splitext(args.legacy_data)[0]
migrated_file = f"{legacy_base}_migrated_embedding.jsonl"
# For embedding, we don't need to migrate to flat pairs, but we can optimize the format
# Copy the legacy data to work with if needed
import shutil
shutil.copy2(args.legacy_data, migrated_file)
logger.info(f"Legacy data prepared for training: {migrated_file}")
# Setup training dataset
if args.use_new_dataset_api:
# Use new integrated dataset API with format detection
train_dataset = create_embedding_dataset(
data_path=args.train_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=True
)
logger.info(f"Loaded {len(train_dataset)} training examples from {args.train_data}")
logger.info(f"Detected format: {train_dataset.detected_format}")
# Add legacy data if provided
if args.legacy_data and args.legacy_support:
logger.info(f"Loading additional legacy data from {args.legacy_data}")
legacy_dataset = create_embedding_dataset(
data_path=args.legacy_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=True
)
logger.info(f"Loaded {len(legacy_dataset)} legacy examples")
# Combine datasets
train_dataset.data.extend(legacy_dataset.data)
logger.info(f"Total training examples: {len(train_dataset)}")
else:
# Use original dataset API for backward compatibility
train_dataset = BGEM3Dataset(
data_path=args.train_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
train_group_size=args.train_group_size,
is_train=True
)
logger.info(f"Loaded {len(train_dataset)} training examples (legacy API)")
# Setup evaluation dataset
eval_dataset = None
if args.eval_data:
if args.use_new_dataset_api:
eval_dataset = create_embedding_dataset(
data_path=args.eval_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=False
)
else:
eval_dataset = BGEM3Dataset(
data_path=args.eval_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
train_group_size=args.train_group_size,
is_train=False
)
logger.info(f"Loaded {len(eval_dataset)} evaluation examples")
return train_dataset, eval_dataset
def setup_data_loaders(args, train_dataset, eval_dataset):
"""Setup data loaders with appropriate collate functions"""
    # Set up data loaders; both the new embedding schema and the legacy triplets
    # format are handled by the same embedding collate function.
    collate_fn = collate_embedding_batch
    logger.info("Using embedding collate function")
    # Pickle-backed datasets (e.g. FastM3Dataset-style caches) may not play well with
    # multiprocessing, so the worker count stays configurable via --dataloader_num_workers.
    workers = args.dataloader_num_workers
train_dataloader = DataLoader(
train_dataset,
batch_size=args.per_device_train_batch_size,
shuffle=True,
num_workers=workers,
collate_fn=collate_fn,
pin_memory=torch.cuda.is_available()
)
# Evaluation data loader
eval_dataloader = None
if eval_dataset:
eval_dataloader = DataLoader(
eval_dataset,
batch_size=args.per_device_eval_batch_size,
shuffle=False,
num_workers=args.dataloader_num_workers,
collate_fn=collate_fn,
pin_memory=torch.cuda.is_available()
)
return train_dataloader, eval_dataloader
def main():
# Parse arguments
args = parse_args()
# Create configuration
config = create_config_from_args(args)
# Setup logging
setup_logging(
output_dir=args.output_dir,
log_level=getattr(logging, args.log_level),
log_to_console=True,
log_to_file=True
)
logger = get_logger(__name__)
logger.info("🚀 Starting Enhanced BGE-M3 Fine-tuning")
logger.info(f"Arguments: {args}")
logger.info(f"Configuration: {config.to_dict()}")
# Set random seed
set_random_seed(args.seed, deterministic=False)
# Setup distributed training if needed
device = None
if args.local_rank != -1:
setup_distributed()
# Device selection from config
from utils.config_loader import get_config
config_instance = get_config()
device_type = config_instance.get('ascend.device_type', 'cuda')
if device_type == 'npu':
try:
import torch_npu # type: ignore
torch_npu.npu.set_device(args.local_rank)
device = torch.device("npu", args.local_rank)
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
except ImportError:
logger.warning("torch_npu not installed, falling back to CPU")
device = torch.device("cpu")
elif torch.cuda.is_available():
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
else:
device = torch.device("cpu")
logger.info("Distributed training enabled: using CPU")
else:
# Non-distributed device selection
from utils.config_loader import get_config
config_instance = get_config()
device_type = config_instance.get('ascend.device_type', 'cuda')
if device_type == 'npu':
try:
import torch_npu # type: ignore
device = torch.device("npu")
logger.info("Using Ascend NPU")
except ImportError:
logger.warning("torch_npu not installed, falling back to CPU")
device = torch.device("cpu")
elif torch.cuda.is_available():
device = torch.device("cuda")
logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
else:
device = torch.device("cpu")
logger.info("Using CPU")
# Update config with device
config.device = str(device)
# Save configuration
if is_main_process():
os.makedirs(args.output_dir, exist_ok=True)
config.save(os.path.join(args.output_dir, "config.json"))
# Setup model and tokenizer
model, tokenizer = setup_model_and_tokenizer(args)
# Setup datasets with enhanced format detection
train_dataset, eval_dataset = setup_datasets(args, tokenizer)
# Setup data loaders
train_dataloader, eval_dataloader = setup_data_loaders(args, train_dataset, eval_dataset)
# Initialize trainer
logger.info("Initializing enhanced trainer...")
try:
trainer = BGEM3Trainer(
config=config,
model=model,
tokenizer=tokenizer
)
logger.info("Trainer initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize trainer: {e}")
raise
# Resume from checkpoint if specified
if args.resume_from_checkpoint:
logger.info(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
try:
trainer.load_checkpoint(args.resume_from_checkpoint)
logger.info("Checkpoint loaded successfully.")
except Exception as e:
logger.error(f"Failed to load checkpoint: {e}")
raise
# Train
logger.info("🎯 Starting enhanced training...")
try:
# Handle type compatibility for trainer
from typing import cast
trainer_train_dataset = cast(BGEM3Dataset, train_dataset) if train_dataset else None
trainer_eval_dataset = cast(BGEM3Dataset, eval_dataset) if eval_dataset else None
if not trainer_train_dataset:
raise ValueError("Training dataset is required but not provided")
trainer.train(
train_dataset=trainer_train_dataset,
eval_dataset=trainer_eval_dataset,
resume_from_checkpoint=args.resume_from_checkpoint
)
# Save final model
if is_main_process():
logger.info("Saving final model...")
final_path = os.path.join(args.output_dir, "final_model")
os.makedirs(final_path, exist_ok=True)
# Save model
model.save_pretrained(final_path)
tokenizer.save_pretrained(final_path)
# Save training summary
summary = {
"training_args": vars(args),
"model_config": config.to_dict(),
"final_metrics": getattr(trainer, 'best_metric', None),
"total_steps": getattr(trainer, 'global_step', 0),
"training_completed": datetime.now().isoformat(),
"device_used": str(device),
"dataset_format": getattr(train_dataset, 'detected_format', 'unknown'),
"enhanced_features": {
"use_new_dataset_api": args.use_new_dataset_api,
"use_hard_negatives": args.use_hard_negatives,
"use_self_distill": args.use_self_distill,
"use_score_weighted_sampling": args.use_score_weighted_sampling,
"multiple_positives": args.multiple_positives,
"legacy_support": args.legacy_support
}
}
with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
json.dump(summary, f, indent=2)
logger.info(f"✅ Enhanced training completed! Model saved to {final_path}")
# Show training summary
logger.info("📊 Training Summary:")
logger.info(f" Total examples: {len(train_dataset)}")
if hasattr(train_dataset, 'detected_format'):
logger.info(f" Format detected: {train_dataset.detected_format}") # type: ignore[attr-defined]
logger.info(f" Epochs: {args.num_train_epochs}")
logger.info(f" Learning rate: {args.learning_rate}")
logger.info(f" Batch size: {args.per_device_train_batch_size}")
logger.info(f" Temperature: {args.temperature}")
logger.info(f" Enhanced features enabled: {args.use_new_dataset_api}")
except Exception as e:
logger.error(f"Training failed with error: {e}", exc_info=True)
raise
finally:
# Cleanup
if hasattr(trainer, 'hard_negative_miner') and trainer.hard_negative_miner:
trainer.hard_negative_miner.stop_async_updates()
if args.local_rank != -1:
from utils.distributed import cleanup_distributed
cleanup_distributed()
if __name__ == "__main__":
main()

398
scripts/train_reranker.py Normal file
View File

@@ -0,0 +1,398 @@
#!/usr/bin/env python3
"""
BGE-Reranker Cross-Encoder Model Training Script
Uses the refactored dataset pipeline with flat pairs schema
"""
import os
import sys
import argparse
import logging
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW
import torch.nn.functional as F
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import create_reranker_dataset, collate_reranker_batch, migrate_nested_to_flat
from models.bge_reranker import BGERerankerModel
from training.reranker_trainer import BGERerankerTrainer
from utils.logging import setup_logging, get_logger
logger = get_logger(__name__)
def parse_args():
"""Parse command line arguments for reranker training"""
parser = argparse.ArgumentParser(description="Train BGE-Reranker Cross-Encoder Model")
# Model arguments
parser.add_argument("--model_name_or_path", type=str, default="BAAI/bge-reranker-base",
help="Path to pretrained BGE-Reranker model")
parser.add_argument("--cache_dir", type=str, default="./cache/data",
help="Directory to cache downloaded models")
# Data arguments
parser.add_argument("--reranker_data", type=str, required=True,
help="Path to reranker training data (reranker_data.jsonl)")
parser.add_argument("--eval_data", type=str, default=None,
help="Path to evaluation data")
parser.add_argument("--legacy_data", type=str, default=None,
help="Path to legacy nested data (will be migrated)")
# Training arguments
parser.add_argument("--output_dir", type=str, default="./output/bge-reranker",
help="Directory to save the trained model")
parser.add_argument("--num_train_epochs", type=int, default=3,
help="Number of training epochs")
parser.add_argument("--per_device_train_batch_size", type=int, default=8,
help="Batch size per GPU/CPU for training")
parser.add_argument("--per_device_eval_batch_size", type=int, default=16,
help="Batch size per GPU/CPU for evaluation")
parser.add_argument("--learning_rate", type=float, default=6e-5,
help="Learning rate")
parser.add_argument("--weight_decay", type=float, default=0.01,
help="Weight decay")
parser.add_argument("--warmup_ratio", type=float, default=0.1,
help="Warmup ratio")
# Model specific arguments
parser.add_argument("--query_max_length", type=int, default=64,
help="Maximum query length")
parser.add_argument("--passage_max_length", type=int, default=448,
help="Maximum passage length")
parser.add_argument("--loss_type", type=str, default="cross_entropy",
choices=["cross_entropy", "binary_cross_entropy", "listwise_ce"],
help="Loss function type")
# Enhanced features
parser.add_argument("--use_score_weighting", action="store_true",
help="Use score weighting in loss calculation")
parser.add_argument("--label_smoothing", type=float, default=0.0,
help="Label smoothing for cross entropy loss")
# System arguments
parser.add_argument("--fp16", action="store_true",
help="Use 16-bit floating point precision")
parser.add_argument("--dataloader_num_workers", type=int, default=4,
help="Number of workers for data loading")
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
parser.add_argument("--log_level", type=str, default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging level")
# Backward compatibility
parser.add_argument("--legacy_support", action="store_true", default=True,
help="Enable legacy nested format support")
parser.add_argument("--migrate_legacy", action="store_true",
help="Migrate legacy nested format to flat format")
return parser.parse_args()
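# Example invocation (illustrative paths, flags as defined above):
#   python scripts/train_reranker.py --reranker_data data/processed/reranker_data.jsonl \
#       --output_dir ./output/bge-reranker --loss_type binary_cross_entropy --fp16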
def setup_model_and_tokenizer(args):
"""Setup BGE-Reranker model and tokenizer"""
logger.info(f"Loading BGE-Reranker model from {args.model_name_or_path}")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
cache_dir=args.cache_dir,
trust_remote_code=True
)
# Load model
model = BGERerankerModel(
model_name_or_path=args.model_name_or_path,
cache_dir=args.cache_dir
)
return model, tokenizer
def migrate_legacy_data(args):
"""Migrate legacy nested data to flat format if needed"""
if args.legacy_data and args.migrate_legacy:
logger.info("Migrating legacy nested data to flat format...")
# Create output file name
legacy_base = os.path.splitext(args.legacy_data)[0]
migrated_file = f"{legacy_base}_migrated_flat.jsonl"
# Migrate the data
migrate_nested_to_flat(args.legacy_data, migrated_file)
logger.info(f"✅ Migrated legacy data to {migrated_file}")
return migrated_file
return None
def setup_datasets(args, tokenizer):
"""Setup reranker datasets"""
logger.info("Setting up reranker datasets...")
# Migrate legacy data if needed
migrated_file = migrate_legacy_data(args)
# Primary reranker data
train_dataset = create_reranker_dataset(
data_path=args.reranker_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=True,
legacy_support=args.legacy_support
)
logger.info(f"Loaded {len(train_dataset)} training examples from {args.reranker_data}")
logger.info(f"Detected format: {train_dataset.format_type}")
# Add migrated legacy data if available
if migrated_file:
logger.info(f"Loading migrated legacy data from {migrated_file}")
migrated_dataset = create_reranker_dataset(
data_path=migrated_file,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=True,
legacy_support=False # Migrated data is in flat format
)
logger.info(f"Loaded {len(migrated_dataset)} migrated examples")
# Combine datasets
train_dataset.data.extend(migrated_dataset.data)
logger.info(f"Total training examples: {len(train_dataset)}")
# Evaluation dataset
eval_dataset = None
if args.eval_data:
eval_dataset = create_reranker_dataset(
data_path=args.eval_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=False,
legacy_support=args.legacy_support
)
logger.info(f"Loaded {len(eval_dataset)} evaluation examples")
return train_dataset, eval_dataset
def setup_loss_function(args):
"""Setup cross-encoder loss function"""
if args.loss_type == "cross_entropy":
if args.label_smoothing > 0:
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
else:
loss_fn = torch.nn.CrossEntropyLoss()
elif args.loss_type == "binary_cross_entropy":
loss_fn = torch.nn.BCEWithLogitsLoss()
elif args.loss_type == "listwise_ce":
        # Listwise variant: currently falls back to standard cross-entropy over each passage group
loss_fn = torch.nn.CrossEntropyLoss()
else:
raise ValueError(f"Unknown loss type: {args.loss_type}")
logger.info(f"Using loss function: {args.loss_type}")
return loss_fn
def train_epoch(model, dataloader, optimizer, scheduler, loss_fn, device, args):
"""Train for one epoch"""
model.train()
total_loss = 0
num_batches = 0
for batch_idx, batch in enumerate(dataloader):
# Move batch to device
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
# Forward pass
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
# Calculate loss
if args.loss_type == "binary_cross_entropy":
# For binary classification
loss = loss_fn(logits.squeeze(), labels.float())
else:
# For multi-class classification
loss = loss_fn(logits, labels)
# Backward pass
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
total_loss += loss.item()
num_batches += 1
# Log progress
if batch_idx % 100 == 0:
logger.info(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
return total_loss / num_batches
def evaluate(model, dataloader, loss_fn, device, args):
"""Evaluate the model"""
model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
# Forward pass
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
# Calculate loss
if args.loss_type == "binary_cross_entropy":
                loss = loss_fn(logits.squeeze(-1), labels.float())
                # Calculate accuracy for binary classification
                predictions = (torch.sigmoid(logits.squeeze(-1)) > 0.5).long()
else:
loss = loss_fn(logits, labels)
# Calculate accuracy for multi-class classification
predictions = torch.argmax(logits, dim=-1)
total_loss += loss.item()
correct += (predictions == labels).sum().item()
total += labels.size(0)
accuracy = correct / total if total > 0 else 0
return total_loss / len(dataloader), accuracy
def main():
"""Main training function"""
args = parse_args()
# Setup logging
setup_logging(
output_dir=args.output_dir,
log_level=args.log_level,
log_to_file=True,
log_to_console=True
)
logger.info("🚀 Starting BGE-Reranker Training")
logger.info(f"Arguments: {vars(args)}")
# Set random seed
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
# Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Setup model and tokenizer
model, tokenizer = setup_model_and_tokenizer(args)
model.to(device)
# Setup datasets
train_dataset, eval_dataset = setup_datasets(args, tokenizer)
# Setup data loaders
train_dataloader = DataLoader(
train_dataset,
batch_size=args.per_device_train_batch_size,
shuffle=True,
num_workers=args.dataloader_num_workers,
collate_fn=collate_reranker_batch,
pin_memory=torch.cuda.is_available()
)
eval_dataloader = None
if eval_dataset:
eval_dataloader = DataLoader(
eval_dataset,
batch_size=args.per_device_eval_batch_size,
shuffle=False,
num_workers=args.dataloader_num_workers,
collate_fn=collate_reranker_batch,
pin_memory=torch.cuda.is_available()
)
# Setup optimizer and scheduler
optimizer = AdamW(
model.parameters(),
lr=args.learning_rate,
weight_decay=args.weight_decay
)
total_steps = len(train_dataloader) * args.num_train_epochs
warmup_steps = int(total_steps * args.warmup_ratio)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps
)
# Setup loss function
loss_fn = setup_loss_function(args)
# Training loop
logger.info("🎯 Starting training...")
best_eval_loss = float('inf')
for epoch in range(args.num_train_epochs):
logger.info(f"Epoch {epoch + 1}/{args.num_train_epochs}")
# Train
train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, loss_fn, device, args)
logger.info(f"Training loss: {train_loss:.4f}")
# Evaluate
if eval_dataloader:
eval_loss, eval_accuracy = evaluate(model, eval_dataloader, loss_fn, device, args)
logger.info(f"Evaluation loss: {eval_loss:.4f}, Accuracy: {eval_accuracy:.4f}")
# Save best model
if eval_loss < best_eval_loss:
best_eval_loss = eval_loss
best_model_path = os.path.join(args.output_dir, "best_model")
os.makedirs(best_model_path, exist_ok=True)
model.save_pretrained(best_model_path)
tokenizer.save_pretrained(best_model_path)
logger.info(f"Saved best model to {best_model_path}")
# Save final model
final_model_path = os.path.join(args.output_dir, "final_model")
os.makedirs(final_model_path, exist_ok=True)
model.save_pretrained(final_model_path)
tokenizer.save_pretrained(final_model_path)
logger.info(f"✅ Training completed! Final model saved to {final_model_path}")
# Show training summary
logger.info("📊 Training Summary:")
logger.info(f" Total examples: {len(train_dataset)}")
logger.info(f" Format detected: {train_dataset.format_type}")
logger.info(f" Epochs: {args.num_train_epochs}")
logger.info(f" Learning rate: {args.learning_rate}")
logger.info(f" Batch size: {args.per_device_train_batch_size}")
logger.info(f" Loss function: {args.loss_type}")
if __name__ == "__main__":
main()
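# Example invocation (illustrative; the script path and exact flag spellings are
# assumed from the argparse attributes referenced above and may differ):
#   python scripts/train_reranker.py \
#       --reranker_data data/datasets/reranker_train.jsonl \
#       --eval_data data/datasets/reranker_dev.jsonl \
#       --output_dir output/reranker_run \
#       --num_train_epochs 3 \
#       --learning_rate 2e-5 \
#       --per_device_train_batch_size 16 \
#       --loss_type binary_cross_entropy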

122
scripts/validate_m3.py Normal file
View File

@@ -0,0 +1,122 @@
"""
Quick validation script for the fine-tuned BGE-M3 model
"""
import os
import sys
import torch
import numpy as np
from transformers import AutoTokenizer
# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_m3 import BGEM3Model
def test_model_embeddings(model_path, test_queries=None):
"""Test if the model can generate embeddings"""
if test_queries is None:
test_queries = [
"什么是深度学习?", # Original training query
"什么是机器学习?", # Related query
"今天天气怎么样?" # Unrelated query
]
print(f"🧪 Testing model from: {model_path}")
try:
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGEM3Model.from_pretrained(model_path)
model.eval()
print("✅ Model and tokenizer loaded successfully")
# Test embeddings generation
embeddings = []
for i, query in enumerate(test_queries):
print(f"\n📝 Testing query {i+1}: '{query}'")
# Tokenize
inputs = tokenizer(
query,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
)
# Get embeddings
with torch.no_grad():
outputs = model.forward(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
return_dense=True,
return_dict=True
)
if 'dense' in outputs:
embedding = outputs['dense'].cpu().numpy() # type: ignore[index]
embeddings.append(embedding)
print(f" ✅ Embedding shape: {embedding.shape}")
print(f" 📊 Embedding norm: {np.linalg.norm(embedding):.4f}")
print(f" 📈 Mean activation: {embedding.mean():.6f}")
print(f" 📉 Std activation: {embedding.std():.6f}")
else:
print(f" ❌ No dense embedding returned")
return False
# Compare embeddings
if len(embeddings) >= 2:
print(f"\n🔗 Similarity Analysis:")
for i in range(len(embeddings)):
for j in range(i+1, len(embeddings)):
sim = np.dot(embeddings[i].flatten(), embeddings[j].flatten()) / (
np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
)
print(f" Query {i+1} vs Query {j+1}: {sim:.4f}")
print(f"\n🎉 Model validation successful!")
return True
except Exception as e:
print(f"❌ Model validation failed: {e}")
return False
def compare_models(original_path, finetuned_path):
"""Compare original vs fine-tuned model"""
print("🔄 Comparing original vs fine-tuned model...")
test_query = "什么是深度学习?"
# Test original model
print("\n📍 Original Model:")
orig_success = test_model_embeddings(original_path, [test_query])
# Test fine-tuned model
print("\n📍 Fine-tuned Model:")
ft_success = test_model_embeddings(finetuned_path, [test_query])
return orig_success and ft_success
if __name__ == "__main__":
print("🚀 BGE-M3 Model Validation")
print("=" * 50)
# Test fine-tuned model
finetuned_model_path = "./test_m3_output/final_model"
success = test_model_embeddings(finetuned_model_path)
# Optional: Compare with original model
original_model_path = "./models/bge-m3"
compare_models(original_model_path, finetuned_model_path)
if success:
print("\n✅ Fine-tuning validation: PASSED")
print(" The model can generate embeddings successfully!")
else:
print("\n❌ Fine-tuning validation: FAILED")
print(" The model has issues generating embeddings!")

View File

@@ -0,0 +1,150 @@
"""
Quick validation script for the fine-tuned BGE Reranker model
"""
import os
import sys
import torch
import numpy as np
from transformers import AutoTokenizer
# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_reranker import BGERerankerModel
def test_reranker_scoring(model_path, test_queries=None):
"""Test if the reranker can score query-passage pairs"""
if test_queries is None:
# Test query with relevant and irrelevant passages
test_queries = [
{
"query": "什么是深度学习?",
"passages": [
"深度学习是机器学习的一个子领域", # Relevant (should score high)
"机器学习包含多种算法", # Somewhat relevant
"今天天气很好", # Irrelevant (should score low)
"深度学习使用神经网络进行学习" # Very relevant
]
}
]
print(f"🧪 Testing reranker from: {model_path}")
try:
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGERerankerModel.from_pretrained(model_path)
model.eval()
print("✅ Reranker model and tokenizer loaded successfully")
for query_idx, test_case in enumerate(test_queries):
query = test_case["query"]
passages = test_case["passages"]
print(f"\n📝 Test Case {query_idx + 1}")
print(f"Query: '{query}'")
print(f"Passages to rank:")
scores = []
for i, passage in enumerate(passages):
print(f" {i+1}. '{passage}'")
# Create query-passage pair
sep_token = tokenizer.sep_token or '[SEP]'
pair_text = f"{query} {sep_token} {passage}"
# Tokenize
inputs = tokenizer(
pair_text,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
)
# Get relevance score
with torch.no_grad():
outputs = model(**inputs)
# Extract score (logits)
if hasattr(outputs, 'logits'):
score = outputs.logits.item()
elif isinstance(outputs, torch.Tensor):
score = outputs.item()
else:
score = outputs[0].item()
scores.append((i, passage, score))
# Sort by relevance score (highest first)
scores.sort(key=lambda x: x[2], reverse=True)
print(f"\n🏆 Ranking Results:")
for rank, (orig_idx, passage, score) in enumerate(scores, 1):
print(f" Rank {rank}: Score {score:.4f} - '{passage[:50]}{'...' if len(passage) > 50 else ''}'")
# Check if most relevant passages ranked highest
expected_relevant = ["深度学习是机器学习的一个子领域", "深度学习使用神经网络进行学习"]
top_passages = [item[1] for item in scores[:2]]
relevant_in_top = any(rel in top_passages for rel in expected_relevant)
irrelevant_in_top = "今天天气很好" in top_passages
if relevant_in_top and not irrelevant_in_top:
print(f" ✅ Good ranking - relevant passages ranked higher")
elif relevant_in_top:
print(f" ⚠️ Partial ranking - relevant high but irrelevant also high")
else:
print(f" ❌ Poor ranking - irrelevant passages ranked too high")
print(f"\n🎉 Reranker validation successful!")
return True
except Exception as e:
print(f"❌ Reranker validation failed: {e}")
return False
def compare_rerankers(original_path, finetuned_path):
"""Compare original vs fine-tuned reranker"""
print("🔄 Comparing original vs fine-tuned reranker...")
test_case = {
"query": "什么是深度学习?",
"passages": [
"深度学习是机器学习的一个子领域",
"今天天气很好"
]
}
# Test original model
print("\n📍 Original Reranker:")
orig_success = test_reranker_scoring(original_path, [test_case])
# Test fine-tuned model
print("\n📍 Fine-tuned Reranker:")
ft_success = test_reranker_scoring(finetuned_path, [test_case])
return orig_success and ft_success
if __name__ == "__main__":
print("🚀 BGE Reranker Model Validation")
print("=" * 50)
# Test fine-tuned model
finetuned_model_path = "./test_reranker_output/final_model"
success = test_reranker_scoring(finetuned_model_path)
# Optional: Compare with original model
original_model_path = "./models/bge-reranker-base"
compare_rerankers(original_model_path, finetuned_model_path)
if success:
print("\n✅ Reranker validation: PASSED")
print(" The reranker can score query-passage pairs successfully!")
else:
print("\n❌ Reranker validation: FAILED")
print(" The reranker has issues with scoring!")

641
scripts/validation_utils.py Normal file
View File

@@ -0,0 +1,641 @@
#!/usr/bin/env python3
"""
Validation Utilities for BGE Fine-tuning
This module provides utilities for model validation including:
- Automatic model discovery
- Test data preparation and validation
- Results analysis and reporting
- Benchmark dataset creation
Usage:
python scripts/validation_utils.py --discover-models
python scripts/validation_utils.py --prepare-test-data
python scripts/validation_utils.py --analyze-results ./validation_results
"""
import os
import sys
import json
import glob
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any
from datetime import datetime
import pandas as pd
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
class ModelDiscovery:
"""Discover trained models in the workspace"""
def __init__(self, workspace_root: str = "."):
self.workspace_root = Path(workspace_root)
def find_trained_models(self) -> Dict[str, Dict[str, str]]:
"""Find all trained BGE models in the workspace"""
models = {
'retrievers': {},
'rerankers': {},
'joint': {}
}
# Common output directories to search
search_patterns = [
"output/*/final_model",
"output/*/best_model",
"output/*/*/final_model",
"output/*/*/best_model",
"test_*_output/final_model",
"training_output/*/final_model"
]
for pattern in search_patterns:
for model_path in glob.glob(str(self.workspace_root / pattern)):
model_path = Path(model_path)
if self._is_valid_model_directory(model_path):
model_info = self._analyze_model_directory(model_path)
if model_info['type'] == 'retriever':
models['retrievers'][model_info['name']] = model_info
elif model_info['type'] == 'reranker':
models['rerankers'][model_info['name']] = model_info
elif model_info['type'] == 'joint':
models['joint'][model_info['name']] = model_info
return models
def _is_valid_model_directory(self, path: Path) -> bool:
"""Check if directory contains a valid trained model"""
        # A usable checkpoint needs a config plus at least one weights file
        config_exists = (path / 'config.json').exists()
        weight_files = ['pytorch_model.bin', 'model.safetensors']
        has_weights = any((path / f).exists() for f in weight_files)
        return config_exists and has_weights
def _analyze_model_directory(self, path: Path) -> Dict[str, str]:
"""Analyze model directory to determine type and metadata"""
info = {
'path': str(path),
'name': self._generate_model_name(path),
'type': 'unknown',
'created': 'unknown',
'training_info': {}
}
# Try to determine model type from path
path_str = str(path).lower()
if 'm3' in path_str or 'retriever' in path_str:
info['type'] = 'retriever'
elif 'reranker' in path_str:
info['type'] = 'reranker'
elif 'joint' in path_str:
info['type'] = 'joint'
# Get creation time
if path.exists():
info['created'] = datetime.fromtimestamp(path.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S')
# Try to read training info
training_summary_path = path.parent / 'training_summary.json'
if training_summary_path.exists():
try:
with open(training_summary_path, 'r') as f:
info['training_info'] = json.load(f)
            except (OSError, json.JSONDecodeError):
                pass
# Try to read config
config_path = path / 'config.json'
if config_path.exists():
try:
with open(config_path, 'r') as f:
config = json.load(f)
if 'model_type' in config:
info['type'] = config['model_type']
            except (OSError, json.JSONDecodeError):
                pass
return info
def _generate_model_name(self, path: Path) -> str:
"""Generate a readable name for the model"""
parts = path.parts
# Try to find meaningful parts
meaningful_parts = []
for part in parts[-3:]: # Look at last 3 parts
if part not in ['final_model', 'best_model', 'output']:
meaningful_parts.append(part)
if meaningful_parts:
return '_'.join(meaningful_parts)
else:
return f"model_{path.parent.name}"
def print_discovered_models(self):
"""Print discovered models in a formatted way"""
models = self.find_trained_models()
print("\n🔍 DISCOVERED TRAINED MODELS")
print("=" * 60)
for model_type, model_dict in models.items():
if model_dict:
print(f"\n📊 {model_type.upper()}:")
for name, info in model_dict.items():
print(f"{name}")
print(f" Path: {info['path']}")
print(f" Created: {info['created']}")
if info['training_info']:
training_info = info['training_info']
if 'total_steps' in training_info:
print(f" Training Steps: {training_info['total_steps']}")
if 'final_metrics' in training_info and training_info['final_metrics']:
print(f" Final Metrics: {training_info['final_metrics']}")
if not any(models.values()):
print("❌ No trained models found!")
print("💡 Make sure you have trained some models first.")
return models
class TestDataManager:
"""Manage test datasets for validation"""
def __init__(self, data_dir: str = "data/datasets"):
self.data_dir = Path(data_dir)
def find_test_datasets(self) -> Dict[str, List[str]]:
"""Find available test datasets"""
datasets = {
'retriever_datasets': [],
'reranker_datasets': [],
'general_datasets': []
}
# Search for JSONL files in data directory
for jsonl_file in self.data_dir.rglob("*.jsonl"):
file_path = str(jsonl_file)
# Categorize based on filename/path
if 'reranker' in file_path.lower():
datasets['reranker_datasets'].append(file_path)
elif 'embedding' in file_path.lower() or 'm3' in file_path.lower():
datasets['retriever_datasets'].append(file_path)
else:
datasets['general_datasets'].append(file_path)
# Also search for JSON files
for json_file in self.data_dir.rglob("*.json"):
file_path = str(json_file)
if 'test' in file_path.lower() or 'dev' in file_path.lower():
if 'reranker' in file_path.lower() or 'AFQMC' in file_path:
datasets['reranker_datasets'].append(file_path)
else:
datasets['general_datasets'].append(file_path)
return datasets
def validate_dataset(self, dataset_path: str) -> Dict[str, Any]:
"""Validate a dataset file and analyze its format"""
validation_result = {
'path': dataset_path,
'exists': False,
'format': 'unknown',
'sample_count': 0,
'has_queries': False,
'has_passages': False,
'has_labels': False,
'issues': [],
'sample_data': {}
}
if not os.path.exists(dataset_path):
validation_result['issues'].append(f"File does not exist: {dataset_path}")
return validation_result
validation_result['exists'] = True
try:
# Read first few lines to analyze format
samples = []
with open(dataset_path, 'r', encoding='utf-8') as f:
for i, line in enumerate(f):
if i >= 5: # Only analyze first 5 lines
break
try:
sample = json.loads(line.strip())
samples.append(sample)
except json.JSONDecodeError:
validation_result['issues'].append(f"Invalid JSON at line {i+1}")
if samples:
validation_result['sample_count'] = len(samples)
validation_result['sample_data'] = samples[0] if samples else {}
# Analyze format
first_sample = samples[0]
# Check for common fields
if 'query' in first_sample:
validation_result['has_queries'] = True
if any(field in first_sample for field in ['passage', 'pos', 'neg', 'candidates']):
validation_result['has_passages'] = True
if any(field in first_sample for field in ['label', 'labels', 'score']):
validation_result['has_labels'] = True
# Determine format
if 'pos' in first_sample and 'neg' in first_sample:
validation_result['format'] = 'triplets'
elif 'candidates' in first_sample:
validation_result['format'] = 'candidates'
elif 'passage' in first_sample and 'label' in first_sample:
validation_result['format'] = 'pairs'
elif 'sentence1' in first_sample and 'sentence2' in first_sample:
validation_result['format'] = 'sentence_pairs'
else:
validation_result['format'] = 'custom'
except Exception as e:
validation_result['issues'].append(f"Error reading file: {e}")
return validation_result
def print_dataset_analysis(self):
"""Print analysis of available datasets"""
datasets = self.find_test_datasets()
print("\n📁 AVAILABLE TEST DATASETS")
print("=" * 60)
for category, dataset_list in datasets.items():
if dataset_list:
print(f"\n📊 {category.upper().replace('_', ' ')}:")
for dataset_path in dataset_list:
validation = self.validate_dataset(dataset_path)
status = "" if validation['exists'] and not validation['issues'] else "⚠️"
print(f" {status} {Path(dataset_path).name}")
print(f" Path: {dataset_path}")
print(f" Format: {validation['format']}")
if validation['issues']:
for issue in validation['issues']:
print(f" ⚠️ {issue}")
if not any(datasets.values()):
print("❌ No test datasets found!")
print("💡 Place your test data in the data/datasets/ directory")
return datasets
def create_sample_test_data(self, output_dir: str = "data/test_samples"):
"""Create sample test datasets for validation"""
os.makedirs(output_dir, exist_ok=True)
# Create sample retriever test data
retriever_test = [
{
"query": "什么是人工智能?",
"pos": [
"人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
"AI技术包括机器学习、深度学习、自然语言处理等多个领域"
],
"neg": [
"今天的天气非常好,阳光明媚",
"我喜欢听音乐和看电影"
]
},
{
"query": "如何学习编程?",
"pos": [
"学习编程需要选择合适的语言如Python、Java等然后通过实践项目提升技能",
"编程学习的关键是多写代码、多调试、多思考算法逻辑"
],
"neg": [
"烹饪是一门艺术,需要掌握火候和调味",
"运动对健康很重要每天应该锻炼30分钟"
]
}
]
# Create sample reranker test data
reranker_test = [
{
"query": "什么是人工智能?",
"passage": "人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
"label": 1
},
{
"query": "什么是人工智能?",
"passage": "今天的天气非常好,阳光明媚",
"label": 0
},
{
"query": "如何学习编程?",
"passage": "学习编程需要选择合适的语言如Python、Java等然后通过实践项目提升技能",
"label": 1
},
{
"query": "如何学习编程?",
"passage": "烹饪是一门艺术,需要掌握火候和调味",
"label": 0
}
]
# Save retriever test data
retriever_path = os.path.join(output_dir, "retriever_test.jsonl")
with open(retriever_path, 'w', encoding='utf-8') as f:
for item in retriever_test:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# Save reranker test data
reranker_path = os.path.join(output_dir, "reranker_test.jsonl")
with open(reranker_path, 'w', encoding='utf-8') as f:
for item in reranker_test:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
print(f"✅ Sample test datasets created in: {output_dir}")
print(f" 📁 Retriever test: {retriever_path}")
print(f" 📁 Reranker test: {reranker_path}")
return {
'retriever_test': retriever_path,
'reranker_test': reranker_path
}
class ValidationReportAnalyzer:
"""Analyze validation results and generate insights"""
def __init__(self, results_dir: str):
self.results_dir = Path(results_dir)
def analyze_results(self) -> Dict[str, Any]:
"""Analyze validation results"""
analysis = {
'overall_status': 'unknown',
'model_performance': {},
'recommendations': [],
'summary_stats': {}
}
# Look for results files
json_results = self.results_dir / "validation_results.json"
if not json_results.exists():
analysis['recommendations'].append("No validation results found. Run validation first.")
return analysis
try:
with open(json_results, 'r') as f:
results = json.load(f)
# Analyze each model type
improvements = 0
degradations = 0
for model_type in ['retriever', 'reranker', 'pipeline']:
if model_type in results and 'comparisons' in results[model_type]:
model_analysis = self._analyze_model_results(results[model_type])
analysis['model_performance'][model_type] = model_analysis
improvements += model_analysis['total_improvements']
degradations += model_analysis['total_degradations']
# Overall status
if improvements > degradations * 1.5:
analysis['overall_status'] = 'excellent'
elif improvements > degradations:
analysis['overall_status'] = 'good'
elif improvements == degradations:
analysis['overall_status'] = 'mixed'
else:
analysis['overall_status'] = 'poor'
# Generate recommendations
analysis['recommendations'] = self._generate_recommendations(analysis)
# Summary statistics
analysis['summary_stats'] = {
'total_improvements': improvements,
'total_degradations': degradations,
'net_improvement': improvements - degradations,
'improvement_ratio': improvements / (improvements + degradations) if (improvements + degradations) > 0 else 0
}
except Exception as e:
analysis['recommendations'].append(f"Error analyzing results: {e}")
return analysis
def _analyze_model_results(self, model_results: Dict[str, Any]) -> Dict[str, Any]:
"""Analyze results for a specific model type"""
analysis = {
'total_improvements': 0,
'total_degradations': 0,
'best_dataset': None,
'worst_dataset': None,
'avg_improvement': 0,
'issues': []
}
dataset_scores = {}
for dataset_name, comparison in model_results.get('comparisons', {}).items():
improvements = len(comparison.get('improvements', {}))
degradations = len(comparison.get('degradations', {}))
analysis['total_improvements'] += improvements
analysis['total_degradations'] += degradations
# Score dataset (improvements - degradations)
dataset_score = improvements - degradations
dataset_scores[dataset_name] = dataset_score
if dataset_scores:
analysis['best_dataset'] = max(dataset_scores.items(), key=lambda x: x[1])
analysis['worst_dataset'] = min(dataset_scores.items(), key=lambda x: x[1])
analysis['avg_improvement'] = sum(dataset_scores.values()) / len(dataset_scores)
# Identify issues
if analysis['total_degradations'] > analysis['total_improvements']:
analysis['issues'].append("More degradations than improvements")
if analysis['avg_improvement'] < 0:
analysis['issues'].append("Negative average improvement")
return analysis
def _generate_recommendations(self, analysis: Dict[str, Any]) -> List[str]:
"""Generate recommendations based on analysis"""
recommendations = []
overall_status = analysis['overall_status']
if overall_status == 'excellent':
recommendations.append("🎉 Great job! Your fine-tuned models show significant improvements.")
recommendations.append("Consider deploying these models to production.")
elif overall_status == 'good':
recommendations.append("✅ Good progress! Models show overall improvement.")
recommendations.append("Consider further fine-tuning on specific datasets where degradation occurred.")
elif overall_status == 'mixed':
recommendations.append("⚠️ Mixed results. Some improvements, some degradations.")
recommendations.append("Analyze which datasets/metrics are degrading and adjust training data.")
recommendations.append("Consider different hyperparameters or longer training.")
elif overall_status == 'poor':
recommendations.append("❌ Models show more degradations than improvements.")
recommendations.append("Review training data quality and format.")
recommendations.append("Check hyperparameters (learning rate might be too high).")
recommendations.append("Consider starting with different base models.")
# Specific model recommendations
for model_type, model_perf in analysis.get('model_performance', {}).items():
if model_perf['issues']:
recommendations.append(f"🔧 {model_type.title()} issues: {', '.join(model_perf['issues'])}")
if model_perf['best_dataset']:
best_dataset, best_score = model_perf['best_dataset']
recommendations.append(f"📊 {model_type.title()} performs best on: {best_dataset} (score: {best_score})")
if model_perf['worst_dataset']:
worst_dataset, worst_score = model_perf['worst_dataset']
recommendations.append(f"📉 {model_type.title()} needs improvement on: {worst_dataset} (score: {worst_score})")
return recommendations
def print_analysis(self):
"""Print comprehensive analysis of validation results"""
analysis = self.analyze_results()
print(f"\n📊 VALIDATION RESULTS ANALYSIS")
print("=" * 60)
# Overall status
        status_emoji = {
            'excellent': '🌟',
            'good': '✅',
            'mixed': '⚠️',
            'poor': '❌',
            'unknown': '❓'
        }
        print(f"{status_emoji.get(analysis['overall_status'], '❓')} Overall Status: {analysis['overall_status'].upper()}")
# Summary stats
if 'summary_stats' in analysis and analysis['summary_stats']:
stats = analysis['summary_stats']
print(f"\n📈 Summary Statistics:")
print(f" Total Improvements: {stats['total_improvements']}")
print(f" Total Degradations: {stats['total_degradations']}")
print(f" Net Improvement: {stats['net_improvement']:+d}")
print(f" Improvement Ratio: {stats['improvement_ratio']:.2%}")
# Model performance
if analysis['model_performance']:
print(f"\n📊 Model Performance:")
for model_type, perf in analysis['model_performance'].items():
print(f" {model_type.title()}:")
print(f" Improvements: {perf['total_improvements']}")
print(f" Degradations: {perf['total_degradations']}")
print(f" Average Score: {perf['avg_improvement']:.2f}")
if perf['best_dataset']:
print(f" Best Dataset: {perf['best_dataset'][0]} (score: {perf['best_dataset'][1]})")
if perf['issues']:
print(f" Issues: {', '.join(perf['issues'])}")
# Recommendations
if analysis['recommendations']:
print(f"\n💡 Recommendations:")
for i, rec in enumerate(analysis['recommendations'], 1):
print(f" {i}. {rec}")
return analysis
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Validation utilities for BGE models",
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("--discover-models", action="store_true",
help="Discover trained models in workspace")
parser.add_argument("--analyze-datasets", action="store_true",
help="Analyze available test datasets")
parser.add_argument("--create-sample-data", action="store_true",
help="Create sample test datasets")
parser.add_argument("--analyze-results", type=str, default=None,
help="Analyze validation results in specified directory")
parser.add_argument("--workspace-root", type=str, default=".",
help="Root directory of workspace")
parser.add_argument("--data-dir", type=str, default="data/datasets",
help="Directory containing test datasets")
parser.add_argument("--output-dir", type=str, default="data/test_samples",
help="Output directory for sample data")
return parser.parse_args()
def main():
"""Main utility function"""
args = parse_args()
print("🔧 BGE Validation Utilities")
print("=" * 50)
if args.discover_models:
print("\n🔍 Discovering trained models...")
discovery = ModelDiscovery(args.workspace_root)
models = discovery.print_discovered_models()
if any(models.values()):
print("\n💡 Use these model paths in your validation commands:")
for model_type, model_dict in models.items():
for name, info in model_dict.items():
print(f" --{model_type[:-1]}_model {info['path']}")
if args.analyze_datasets:
print("\n📁 Analyzing test datasets...")
data_manager = TestDataManager(args.data_dir)
datasets = data_manager.print_dataset_analysis()
if args.create_sample_data:
print("\n📝 Creating sample test data...")
data_manager = TestDataManager()
sample_paths = data_manager.create_sample_test_data(args.output_dir)
print("\n💡 Use these sample datasets for testing:")
print(f" --test_datasets {sample_paths['retriever_test']} {sample_paths['reranker_test']}")
if args.analyze_results:
print(f"\n📊 Analyzing results from: {args.analyze_results}")
analyzer = ValidationReportAnalyzer(args.analyze_results)
analysis = analyzer.print_analysis()
if not any([args.discover_models, args.analyze_datasets, args.create_sample_data, args.analyze_results]):
print("❓ No action specified. Use --help to see available options.")
return 1
return 0
if __name__ == "__main__":
exit(main())