init commit

scripts/benchmark.py (new file, 144 lines)
@@ -0,0 +1,144 @@
import os
import time
import argparse
import logging
import torch
import psutil
import tracemalloc
import numpy as np
from utils.config_loader import get_config
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("benchmark")

def parse_args():
    """Parse command-line arguments for benchmarking."""
    parser = argparse.ArgumentParser(description="Benchmark BGE fine-tuned models.")
    parser.add_argument('--model_type', type=str, choices=['retriever', 'reranker'], required=True, help='Model type to benchmark')
    parser.add_argument('--model_path', type=str, required=True, help='Path to fine-tuned model')
    parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
    parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
    parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, cpu, npu); defaults to the config value')
    parser.add_argument('--output', type=str, default='benchmark_results.txt', help='File to save benchmark results')
    return parser.parse_args()

def measure_memory():
    """Measure the current memory usage of the process."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)  # MB

def benchmark_retriever(model_path, data_path, batch_size, max_samples, device):
    """Benchmark the BGE retriever model."""
    logger.info(f"Loading retriever model from {model_path}")
    model = BGEM3Model(model_name_or_path=model_path, device=device)
    tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
    dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
    logger.info(f"Loaded {len(dataset)} samples for benchmarking")
    dataloader = DataLoader(dataset, batch_size=batch_size)

    model.eval()
    model.to(device)

    n_samples = 0
    times = []
    tracemalloc.start()
    with torch.no_grad():
        for batch in dataloader:
            if n_samples >= max_samples:
                break
            start = time.time()
            # Only dense embedding for speed
            _ = model(
                batch['query_input_ids'].to(device),
                batch['query_attention_mask'].to(device),
                return_dense=True, return_sparse=False, return_colbert=False
            )
            if device.startswith('cuda'):
                torch.cuda.synchronize()
            end = time.time()
            times.append(end - start)
            n_samples += batch['query_input_ids'].shape[0]
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    throughput = n_samples / sum(times)
    # Per-sample latency from the actual sample count (the last batch may be smaller than batch_size)
    latency = sum(times) / n_samples
    mem_mb = peak / (1024 * 1024)
    logger.info(f"Retriever throughput: {throughput:.2f} samples/sec, latency: {latency*1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
    return throughput, latency, mem_mb

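# Worked example of the reported numbers (illustrative values, not measured):
# if 1,000 samples take a summed 12.5 s, throughput = 1000 / 12.5 = 80 samples/sec
# and latency = 12.5 / 1000 = 0.0125 s, logged as 12.50 ms/sample.
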
def benchmark_reranker(model_path, data_path, batch_size, max_samples, device):
    """Benchmark the BGE reranker model."""
    logger.info(f"Loading reranker model from {model_path}")
    model = BGERerankerModel(model_name_or_path=model_path, device=device)
    tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
    dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
    logger.info(f"Loaded {len(dataset)} samples for benchmarking")
    dataloader = DataLoader(dataset, batch_size=batch_size)

    model.eval()
    model.to(device)

    n_samples = 0
    times = []
    tracemalloc.start()
    with torch.no_grad():
        for batch in dataloader:
            if n_samples >= max_samples:
                break
            start = time.time()
            _ = model.model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device)
            )
            if device.startswith('cuda'):
                torch.cuda.synchronize()
            end = time.time()
            times.append(end - start)
            n_samples += batch['input_ids'].shape[0]
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    throughput = n_samples / sum(times)
    # Per-sample latency from the actual sample count (the last batch may be smaller than batch_size)
    latency = sum(times) / n_samples
    mem_mb = peak / (1024 * 1024)
    logger.info(f"Reranker throughput: {throughput:.2f} samples/sec, latency: {latency*1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
    return throughput, latency, mem_mb

def main():
    """Run the benchmark with config-driven defaults and robust error handling."""
    args = parse_args()

    # Fall back to the config value when --device is not given
    config = get_config()
    device = args.device or config.get('hardware', {}).get('device', 'cuda')

    try:
        if args.model_type == 'retriever':
            throughput, latency, mem_mb = benchmark_retriever(
                args.model_path, args.data_path, args.batch_size, args.max_samples, device)
        else:
            throughput, latency, mem_mb = benchmark_reranker(
                args.model_path, args.data_path, args.batch_size, args.max_samples, device)

        # Save results
        with open(args.output, 'w') as f:
            f.write(f"Model: {args.model_path}\n")
            f.write(f"Data: {args.data_path}\n")
            f.write(f"Batch size: {args.batch_size}\n")
            f.write(f"Throughput: {throughput:.2f} samples/sec\n")
            f.write(f"Latency: {latency*1000:.2f} ms/sample\n")
            f.write(f"Peak memory: {mem_mb:.2f} MB\n")

        logger.info(f"Benchmark results saved to {args.output}")
        logger.info(f"Benchmark completed on device: {device}")
    except Exception as e:
        logger.error(f"Benchmark failed: {e}")
        raise


if __name__ == '__main__':
    main()

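# Example invocation (hypothetical paths), calling the helper directly instead of the CLI:
#     from scripts.benchmark import benchmark_retriever
#     throughput, latency, mem_mb = benchmark_retriever(
#         './output/bge-m3-enhanced/final_model', 'data/datasets/examples/embedding_data.jsonl',
#         batch_size=16, max_samples=1000, device='cuda')
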
scripts/compare_reranker.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import os
import time
import argparse
import logging
import torch
import psutil
import tracemalloc
import numpy as np
from utils.config_loader import get_config
from models.bge_reranker import BGERerankerModel
from data.dataset import BGERerankerDataset
from transformers import AutoTokenizer
from evaluation.metrics import compute_reranker_metrics
from torch.utils.data import DataLoader

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("compare_reranker")

def parse_args():
    """Parse command-line arguments for the reranker comparison."""
    parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline reranker models.")
    parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned reranker model')
    parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline reranker model')
    parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
    parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
    parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, cpu, npu); defaults to the config value')
    parser.add_argument('--output', type=str, default='compare_reranker_results.txt', help='File to save comparison results')
    parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10], help='K values for metrics@k')
    return parser.parse_args()

def measure_memory():
    """Measure the current memory usage of the process."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)  # MB

def run_reranker(model_path, data_path, batch_size, max_samples, device):
    """
    Run inference on a reranker model and measure throughput, latency, and memory.

    Args:
        model_path (str): Path to the reranker model.
        data_path (str): Path to the evaluation data.
        batch_size (int): Batch size for inference.
        max_samples (int): Maximum number of samples to process.
        device (str): Device to use (cuda, cpu, npu).

    Returns:
        tuple: (throughput, latency, mem_mb, all_scores, all_labels)
    """
    model = BGERerankerModel(model_name_or_path=model_path, device=device)
    tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
    dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    model.eval()
    model.to(device)

    n_samples = 0
    times = []
    all_scores = []
    all_labels = []
    tracemalloc.start()
    with torch.no_grad():
        for batch in dataloader:
            if n_samples >= max_samples:
                break
            start = time.time()
            outputs = model.model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device)
            )
            if device.startswith('cuda'):
                torch.cuda.synchronize()
            end = time.time()
            times.append(end - start)
            n = batch['input_ids'].shape[0]
            n_samples += n
            # For metrics: collect scores and labels
            scores = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
            all_scores.append(scores.cpu().numpy())
            all_labels.append(batch['labels'].cpu().numpy())
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    throughput = n_samples / sum(times)
    # Per-sample latency from the actual sample count (the last batch may be smaller than batch_size)
    latency = sum(times) / n_samples
    mem_mb = peak / (1024 * 1024)

    # Prepare for metrics
    all_scores = np.concatenate(all_scores, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    return throughput, latency, mem_mb, all_scores, all_labels

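# Example (illustrative paths): the scores and labels returned above feed
# compute_reranker_metrics directly, mirroring main() below:
#     _, _, _, scores, labels = run_reranker(
#         './output/bge-reranker/final_model', 'data/datasets/examples/reranker_data.jsonl', 16, 1000, 'cuda')
#     metrics = compute_reranker_metrics(scores, labels, k_values=[1, 5, 10])
#     # metrics is expected to contain keys such as 'accuracy@1', 'map@5', 'mrr@10'
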
def main():
    """Run the reranker comparison with config-driven defaults and robust error handling."""
    args = parse_args()

    # Fall back to the config value when --device is not given
    config = get_config()
    device = args.device or config.get('hardware', {}).get('device', 'cuda')

    try:
        # Fine-tuned model
        logger.info("Running fine-tuned reranker...")
        ft_throughput, ft_latency, ft_mem, ft_scores, ft_labels = run_reranker(
            args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
        ft_metrics = compute_reranker_metrics(ft_scores, ft_labels, k_values=args.k_values)

        # Baseline model
        logger.info("Running baseline reranker...")
        bl_throughput, bl_latency, bl_mem, bl_scores, bl_labels = run_reranker(
            args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
        bl_metrics = compute_reranker_metrics(bl_scores, bl_labels, k_values=args.k_values)

        # Output comparison
        with open(args.output, 'w') as f:
            f.write("Metric\tBaseline\tFine-tuned\tDelta\n")
            for k in args.k_values:
                for metric in [f'accuracy@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
                    bl = bl_metrics.get(metric, 0)
                    ft = ft_metrics.get(metric, 0)
                    delta = ft - bl
                    f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
            for metric in ['map', 'mrr']:
                bl = bl_metrics.get(metric, 0)
                ft = ft_metrics.get(metric, 0)
                delta = ft - bl
                f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
            f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput-bl_throughput:+.2f}\n")
            f.write(f"Latency (ms/sample)\t{bl_latency*1000:.2f}\t{ft_latency*1000:.2f}\t{(ft_latency-bl_latency)*1000:+.2f}\n")
            f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem-bl_mem:+.2f}\n")

        logger.info(f"Comparison results saved to {args.output}")
        print(f"\nComparison complete. See {args.output} for details.")
    except Exception as e:
        logger.error(f"Reranker comparison failed: {e}")
        raise


if __name__ == '__main__':
    main()

scripts/compare_retriever.py (new file, 144 lines)
@@ -0,0 +1,144 @@
import os
import time
import argparse
import logging
import torch
import psutil
import tracemalloc
import numpy as np
from utils.config_loader import get_config
from models.bge_m3 import BGEM3Model
from data.dataset import BGEM3Dataset
from transformers import AutoTokenizer
from evaluation.metrics import compute_retrieval_metrics
from torch.utils.data import DataLoader

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("compare_retriever")

def parse_args():
    """Parse command-line arguments for the retriever comparison."""
    parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline retriever models.")
    parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned retriever model')
    parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline retriever model')
    parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
    parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
    parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, cpu, npu); defaults to the config value')
    parser.add_argument('--output', type=str, default='compare_retriever_results.txt', help='File to save comparison results')
    parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10, 20, 100], help='K values for metrics@k')
    return parser.parse_args()

def measure_memory():
    """Measure the current process memory usage in MB."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)

def run_retriever(model_path, data_path, batch_size, max_samples, device):
    """
    Run inference on a retriever model and measure throughput, latency, and memory.

    Args:
        model_path (str): Path to the retriever model.
        data_path (str): Path to the evaluation data.
        batch_size (int): Batch size for inference.
        max_samples (int): Maximum number of samples to process.
        device (str): Device to use (cuda, cpu, npu).

    Returns:
        tuple: throughput (samples/sec), latency (seconds/sample), peak memory (MB),
               query_embeds (np.ndarray), passage_embeds (np.ndarray), labels (np.ndarray).
    """
    model = BGEM3Model(model_name_or_path=model_path, device=device)
    tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
    dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    model.eval()
    model.to(device)

    n_samples = 0
    times = []
    query_embeds = []
    passage_embeds = []
    labels = []
    tracemalloc.start()
    with torch.no_grad():
        for batch in dataloader:
            if n_samples >= max_samples:
                break
            start = time.time()
            out = model(
                batch['query_input_ids'].to(device),
                batch['query_attention_mask'].to(device),
                return_dense=True, return_sparse=False, return_colbert=False
            )
            if device.startswith('cuda'):
                torch.cuda.synchronize()
            end = time.time()
            times.append(end - start)
            n = batch['query_input_ids'].shape[0]
            n_samples += n
            # For metrics: collect embeddings and labels
            # (assumes the model output dict exposes 'query_embeds' and 'passage_embeds' for this batch)
            query_embeds.append(out['query_embeds'].cpu().numpy())
            passage_embeds.append(out['passage_embeds'].cpu().numpy())
            labels.append(batch['labels'].cpu().numpy())
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    throughput = n_samples / sum(times)
    # Per-sample latency from the actual sample count (the last batch may be smaller than batch_size)
    latency = sum(times) / n_samples
    mem_mb = peak / (1024 * 1024)

    # Prepare for metrics
    query_embeds = np.concatenate(query_embeds, axis=0)
    passage_embeds = np.concatenate(passage_embeds, axis=0)
    labels = np.concatenate(labels, axis=0)
    return throughput, latency, mem_mb, query_embeds, passage_embeds, labels

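# Example (illustrative paths): the embeddings and labels returned above feed
# compute_retrieval_metrics directly, mirroring main() below:
#     _, _, _, qe, pe, y = run_retriever(
#         './output/bge-m3-enhanced/final_model', 'data/datasets/examples/embedding_data.jsonl', 16, 1000, 'cuda')
#     metrics = compute_retrieval_metrics(qe, pe, y, k_values=[1, 5, 10, 20, 100])
#     # metrics is expected to contain keys such as 'recall@10', 'precision@10', 'ndcg@10'
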
def main():
    """Run the retriever comparison with config-driven defaults and robust error handling."""
    args = parse_args()

    # Fall back to the config value when --device is not given
    config = get_config()
    device = args.device or config.get('hardware', {}).get('device', 'cuda')

    try:
        # Fine-tuned model
        logger.info("Running fine-tuned retriever...")
        ft_throughput, ft_latency, ft_mem, ft_qe, ft_pe, ft_labels = run_retriever(
            args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
        ft_metrics = compute_retrieval_metrics(ft_qe, ft_pe, ft_labels, k_values=args.k_values)

        # Baseline model
        logger.info("Running baseline retriever...")
        bl_throughput, bl_latency, bl_mem, bl_qe, bl_pe, bl_labels = run_retriever(
            args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
        bl_metrics = compute_retrieval_metrics(bl_qe, bl_pe, bl_labels, k_values=args.k_values)

        # Output comparison
        with open(args.output, 'w') as f:
            f.write("Metric\tBaseline\tFine-tuned\tDelta\n")
            for k in args.k_values:
                for metric in [f'recall@{k}', f'precision@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
                    bl = bl_metrics.get(metric, 0)
                    ft = ft_metrics.get(metric, 0)
                    delta = ft - bl
                    f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
            for metric in ['map', 'mrr']:
                bl = bl_metrics.get(metric, 0)
                ft = ft_metrics.get(metric, 0)
                delta = ft - bl
                f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
            f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput-bl_throughput:+.2f}\n")
            f.write(f"Latency (ms/sample)\t{bl_latency*1000:.2f}\t{ft_latency*1000:.2f}\t{(ft_latency-bl_latency)*1000:+.2f}\n")
            f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem-bl_mem:+.2f}\n")

        logger.info(f"Comparison results saved to {args.output}")
        print(f"\nComparison complete. See {args.output} for details.")
    except Exception as e:
        logger.error(f"Retriever comparison failed: {e}")
        raise


if __name__ == '__main__':
    main()

scripts/comprehensive_validation.py (new file, 990 lines)
@@ -0,0 +1,990 @@
#!/usr/bin/env python3
"""
Comprehensive Validation Script for BGE Fine-tuned Models

This script provides a complete validation framework to determine if fine-tuned models
actually perform better than baseline models across multiple datasets and metrics.

Features:
- Tests both M3 retriever and reranker models
- Compares against baseline models on multiple datasets
- Provides statistical significance testing
- Tests pipeline performance (retrieval + reranking)
- Generates comprehensive reports
- Includes performance degradation detection
- Supports cross-dataset validation

Usage:
    python scripts/comprehensive_validation.py \\
        --retriever_finetuned ./output/bge-m3-enhanced/final_model \\
        --reranker_finetuned ./output/bge-reranker/final_model \\
        --output_dir ./validation_results
"""

import os
import sys
import json
import logging
import argparse
import pandas as pd
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from scipy import stats
import torch
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from evaluation.metrics import (
    compute_retrieval_metrics, compute_reranker_metrics,
    compute_retrieval_metrics_comprehensive
)
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
from data.preprocessing import DataPreprocessor
from utils.logging import setup_logging, get_logger
from utils.config_loader import get_config

@dataclass
class ValidationConfig:
    """Configuration for comprehensive validation"""
    # Model paths
    retriever_baseline: str = "BAAI/bge-m3"
    reranker_baseline: str = "BAAI/bge-reranker-base"
    retriever_finetuned: Optional[str] = None
    reranker_finetuned: Optional[str] = None

    # Test datasets
    test_datasets: Optional[List[str]] = None

    # Evaluation settings
    batch_size: int = 32
    max_length: int = 512
    k_values: Optional[List[int]] = None
    significance_level: float = 0.05
    min_samples_for_significance: int = 30

    # Pipeline settings
    retrieval_top_k: int = 100
    rerank_top_k: int = 10

    # Output settings
    output_dir: str = "./validation_results"
    save_embeddings: bool = False
    create_report: bool = True

    def __post_init__(self):
        if self.test_datasets is None:
            self.test_datasets = [
                "data/datasets/examples/embedding_data.jsonl",
                "data/datasets/examples/reranker_data.jsonl",
                "data/datasets/Reranker_AFQMC/dev.json",
                "data/datasets/三国演义/reranker_train.jsonl"
            ]
        if self.k_values is None:
            self.k_values = [1, 3, 5, 10, 20]

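# A minimal usage sketch (example paths, mirroring main() at the bottom of this script):
#     config = ValidationConfig(
#         retriever_finetuned='./output/bge-m3-enhanced/final_model',
#         reranker_finetuned='./output/bge-reranker/final_model',
#         output_dir='./validation_results')
#     validator = ComprehensiveValidator(config)
#     results = validator.validate_all()
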
class ComprehensiveValidator:
    """Comprehensive validation framework for BGE models"""

    def __init__(self, config: ValidationConfig):
        self.config = config
        self.logger = get_logger(__name__)
        self.results = {}
        self.device = self._get_device()

        # Create output directory
        os.makedirs(config.output_dir, exist_ok=True)

        # Setup logging
        setup_logging(
            output_dir=config.output_dir,
            log_level=logging.INFO,
            log_to_console=True,
            log_to_file=True
        )

    def _get_device(self) -> torch.device:
        """Get the optimal device for evaluation"""
        try:
            config = get_config()
            device_type = config.get('hardware', {}).get('device', 'cuda')

            # NPU selection depends only on torch_npu being importable, not on CUDA availability
            if device_type == 'npu':
                try:
                    import torch_npu  # noqa: F401
                    device = torch.device("npu:0")
                    self.logger.info("Using NPU device")
                    return device
                except ImportError:
                    pass

            if torch.cuda.is_available():
                device = torch.device("cuda")
                self.logger.info("Using CUDA device")
                return device

        except Exception as e:
            self.logger.warning(f"Error detecting optimal device: {e}")

        device = torch.device("cpu")
        self.logger.info("Using CPU device")
        return device

def validate_all(self) -> Dict[str, Any]:
|
||||
"""Run comprehensive validation on all models and datasets"""
|
||||
self.logger.info("🚀 Starting Comprehensive BGE Model Validation")
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
# 1. Validate retriever models
|
||||
if self.config.retriever_finetuned:
|
||||
self.logger.info("📊 Validating Retriever Models...")
|
||||
retriever_results = self._validate_retrievers()
|
||||
self.results['retriever'] = retriever_results
|
||||
|
||||
# 2. Validate reranker models
|
||||
if self.config.reranker_finetuned:
|
||||
self.logger.info("📊 Validating Reranker Models...")
|
||||
reranker_results = self._validate_rerankers()
|
||||
self.results['reranker'] = reranker_results
|
||||
|
||||
# 3. Validate pipeline (if both models available)
|
||||
if self.config.retriever_finetuned and self.config.reranker_finetuned:
|
||||
self.logger.info("📊 Validating Pipeline Performance...")
|
||||
pipeline_results = self._validate_pipeline()
|
||||
self.results['pipeline'] = pipeline_results
|
||||
|
||||
# 4. Generate comprehensive report
|
||||
if self.config.create_report:
|
||||
self._generate_report()
|
||||
|
||||
# 5. Statistical analysis
|
||||
self._perform_statistical_analysis()
|
||||
|
||||
end_time = datetime.now()
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
self.logger.info("✅ Comprehensive validation completed!")
|
||||
self.logger.info(f"⏱️ Total validation time: {duration:.2f} seconds")
|
||||
|
||||
return self.results
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"❌ Validation failed: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _validate_retrievers(self) -> Dict[str, Any]:
|
||||
"""Validate retriever models on all suitable datasets"""
|
||||
results = {
|
||||
'baseline_metrics': {},
|
||||
'finetuned_metrics': {},
|
||||
'comparisons': {},
|
||||
'datasets_tested': []
|
||||
}
|
||||
|
||||
for dataset_path in self.config.test_datasets:
|
||||
if not os.path.exists(dataset_path):
|
||||
self.logger.warning(f"Dataset not found: {dataset_path}")
|
||||
continue
|
||||
|
||||
# Skip reranker-only datasets for retriever testing
|
||||
if 'reranker' in dataset_path.lower() and 'AFQMC' in dataset_path:
|
||||
continue
|
||||
|
||||
dataset_name = Path(dataset_path).stem
|
||||
self.logger.info(f"Testing retriever on dataset: {dataset_name}")
|
||||
|
||||
try:
|
||||
# Test baseline model
|
||||
baseline_metrics = self._test_retriever_on_dataset(
|
||||
model_path=self.config.retriever_baseline,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=True
|
||||
)
|
||||
|
||||
# Test fine-tuned model
|
||||
finetuned_metrics = self._test_retriever_on_dataset(
|
||||
model_path=self.config.retriever_finetuned,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=False
|
||||
)
|
||||
|
||||
# Compare results
|
||||
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "retriever")
|
||||
|
||||
results['baseline_metrics'][dataset_name] = baseline_metrics
|
||||
results['finetuned_metrics'][dataset_name] = finetuned_metrics
|
||||
results['comparisons'][dataset_name] = comparison
|
||||
results['datasets_tested'].append(dataset_name)
|
||||
|
||||
self.logger.info(f"✅ Retriever validation complete for {dataset_name}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"❌ Failed to test retriever on {dataset_name}: {e}")
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _validate_rerankers(self) -> Dict[str, Any]:
|
||||
"""Validate reranker models on all suitable datasets"""
|
||||
results = {
|
||||
'baseline_metrics': {},
|
||||
'finetuned_metrics': {},
|
||||
'comparisons': {},
|
||||
'datasets_tested': []
|
||||
}
|
||||
|
||||
for dataset_path in self.config.test_datasets:
|
||||
if not os.path.exists(dataset_path):
|
||||
self.logger.warning(f"Dataset not found: {dataset_path}")
|
||||
continue
|
||||
|
||||
dataset_name = Path(dataset_path).stem
|
||||
self.logger.info(f"Testing reranker on dataset: {dataset_name}")
|
||||
|
||||
try:
|
||||
# Test baseline model
|
||||
baseline_metrics = self._test_reranker_on_dataset(
|
||||
model_path=self.config.reranker_baseline,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=True
|
||||
)
|
||||
|
||||
# Test fine-tuned model
|
||||
finetuned_metrics = self._test_reranker_on_dataset(
|
||||
model_path=self.config.reranker_finetuned,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=False
|
||||
)
|
||||
|
||||
# Compare results
|
||||
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "reranker")
|
||||
|
||||
results['baseline_metrics'][dataset_name] = baseline_metrics
|
||||
results['finetuned_metrics'][dataset_name] = finetuned_metrics
|
||||
results['comparisons'][dataset_name] = comparison
|
||||
results['datasets_tested'].append(dataset_name)
|
||||
|
||||
self.logger.info(f"✅ Reranker validation complete for {dataset_name}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"❌ Failed to test reranker on {dataset_name}: {e}")
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _validate_pipeline(self) -> Dict[str, Any]:
|
||||
"""Validate end-to-end retrieval + reranking pipeline"""
|
||||
results = {
|
||||
'baseline_pipeline': {},
|
||||
'finetuned_pipeline': {},
|
||||
'comparisons': {},
|
||||
'datasets_tested': []
|
||||
}
|
||||
|
||||
for dataset_path in self.config.test_datasets:
|
||||
if not os.path.exists(dataset_path):
|
||||
continue
|
||||
|
||||
# Only test on datasets suitable for pipeline evaluation
|
||||
if 'embedding_data' not in dataset_path:
|
||||
continue
|
||||
|
||||
dataset_name = Path(dataset_path).stem
|
||||
self.logger.info(f"Testing pipeline on dataset: {dataset_name}")
|
||||
|
||||
try:
|
||||
# Test baseline pipeline
|
||||
baseline_metrics = self._test_pipeline_on_dataset(
|
||||
retriever_path=self.config.retriever_baseline,
|
||||
reranker_path=self.config.reranker_baseline,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=True
|
||||
)
|
||||
|
||||
# Test fine-tuned pipeline
|
||||
finetuned_metrics = self._test_pipeline_on_dataset(
|
||||
retriever_path=self.config.retriever_finetuned,
|
||||
reranker_path=self.config.reranker_finetuned,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=False
|
||||
)
|
||||
|
||||
# Compare results
|
||||
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "pipeline")
|
||||
|
||||
results['baseline_pipeline'][dataset_name] = baseline_metrics
|
||||
results['finetuned_pipeline'][dataset_name] = finetuned_metrics
|
||||
results['comparisons'][dataset_name] = comparison
|
||||
results['datasets_tested'].append(dataset_name)
|
||||
|
||||
self.logger.info(f"✅ Pipeline validation complete for {dataset_name}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"❌ Failed to test pipeline on {dataset_name}: {e}")
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _test_retriever_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
|
||||
"""Test a retriever model on a specific dataset"""
|
||||
try:
|
||||
# Load model and tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGEM3Model.from_pretrained(model_path)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
# Create dataset
|
||||
dataset = BGEM3Dataset(
|
||||
data_path=dataset_path,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=self.config.max_length,
|
||||
passage_max_length=self.config.max_length,
|
||||
is_train=False
|
||||
)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=self.config.batch_size,
|
||||
shuffle=False
|
||||
)
|
||||
|
||||
# Create evaluator
|
||||
evaluator = RetrievalEvaluator(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
device=self.device,
|
||||
k_values=self.config.k_values
|
||||
)
|
||||
|
||||
# Run evaluation
|
||||
metrics = evaluator.evaluate(
|
||||
dataloader=dataloader,
|
||||
desc=f"{'Baseline' if is_baseline else 'Fine-tuned'} Retriever"
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error testing retriever: {e}")
|
||||
return {}
|
||||
|
||||
def _test_reranker_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
|
||||
"""Test a reranker model on a specific dataset"""
|
||||
try:
|
||||
# Load model and tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGERerankerModel.from_pretrained(model_path)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
# Create dataset
|
||||
dataset = BGERerankerDataset(
|
||||
data_path=dataset_path,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=self.config.max_length,
|
||||
passage_max_length=self.config.max_length,
|
||||
is_train=False
|
||||
)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=self.config.batch_size,
|
||||
shuffle=False
|
||||
)
|
||||
|
||||
# Collect predictions and labels
|
||||
all_scores = []
|
||||
all_labels = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
# Move to device
|
||||
input_ids = batch['input_ids'].to(self.device)
|
||||
attention_mask = batch['attention_mask'].to(self.device)
|
||||
labels = batch['labels'].cpu().numpy()
|
||||
|
||||
# Get predictions
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
scores = outputs.logits.cpu().numpy()
|
||||
|
||||
all_scores.append(scores)
|
||||
all_labels.append(labels)
|
||||
|
||||
# Concatenate results
|
||||
all_scores = np.concatenate(all_scores, axis=0)
|
||||
all_labels = np.concatenate(all_labels, axis=0)
|
||||
|
||||
# Compute metrics
|
||||
metrics = compute_reranker_metrics(
|
||||
scores=all_scores,
|
||||
labels=all_labels,
|
||||
k_values=self.config.k_values
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error testing reranker: {e}")
|
||||
return {}
|
||||
|
||||
def _test_pipeline_on_dataset(self, retriever_path: str, reranker_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
|
||||
"""Test end-to-end pipeline on a dataset"""
|
||||
try:
|
||||
# Load models
|
||||
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
|
||||
retriever = BGEM3Model.from_pretrained(retriever_path)
|
||||
retriever.to(self.device)
|
||||
retriever.eval()
|
||||
|
||||
reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
|
||||
reranker = BGERerankerModel.from_pretrained(reranker_path)
|
||||
reranker.to(self.device)
|
||||
reranker.eval()
|
||||
|
||||
# Load and preprocess data
|
||||
data = []
|
||||
with open(dataset_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
item = json.loads(line.strip())
|
||||
if 'query' in item and 'pos' in item and 'neg' in item:
|
||||
data.append(item)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
# Test pipeline performance
|
||||
total_queries = 0
|
||||
pipeline_metrics = {f'recall@{k}': [] for k in self.config.k_values}
|
||||
pipeline_metrics.update({f'mrr@{k}': [] for k in self.config.k_values})
|
||||
|
||||
for item in data[:100]: # Limit to 100 queries for speed
|
||||
query = item['query']
|
||||
pos_docs = item['pos'][:2] # Take first 2 positive docs
|
||||
neg_docs = item['neg'][:8] # Take first 8 negative docs
|
||||
|
||||
# Create candidate pool
|
||||
candidates = pos_docs + neg_docs
|
||||
labels = [1] * len(pos_docs) + [0] * len(neg_docs)
|
||||
|
||||
if len(candidates) < 3: # Need minimum candidates
|
||||
continue
|
||||
|
||||
try:
|
||||
# Step 1: Retrieval (encode query and candidates)
|
||||
query_inputs = retriever_tokenizer(
|
||||
query,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.config.max_length
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
query_outputs = retriever(**query_inputs, return_dense=True)
|
||||
query_emb = query_outputs['dense']
|
||||
|
||||
# Encode candidates
|
||||
candidate_embs = []
|
||||
for candidate in candidates:
|
||||
cand_inputs = retriever_tokenizer(
|
||||
candidate,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.config.max_length
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
cand_outputs = retriever(**cand_inputs, return_dense=True)
|
||||
candidate_embs.append(cand_outputs['dense'])
|
||||
|
||||
candidate_embs = torch.cat(candidate_embs, dim=0)
|
||||
|
||||
# Compute similarities
|
||||
similarities = torch.cosine_similarity(query_emb, candidate_embs, dim=1)
|
||||
|
||||
# Get top-k for retrieval
|
||||
top_k_retrieval = min(self.config.retrieval_top_k, len(candidates))
|
||||
retrieval_indices = torch.argsort(similarities, descending=True)[:top_k_retrieval]
|
||||
|
||||
# Step 2: Reranking
|
||||
rerank_candidates = [candidates[i] for i in retrieval_indices]
|
||||
rerank_labels = [labels[i] for i in retrieval_indices]
|
||||
|
||||
# Score pairs with reranker
|
||||
rerank_scores = []
|
||||
for candidate in rerank_candidates:
|
||||
pair_text = f"{query} [SEP] {candidate}"
|
||||
pair_inputs = reranker_tokenizer(
|
||||
pair_text,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.config.max_length
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pair_outputs = reranker(**pair_inputs)
|
||||
score = pair_outputs.logits.item()
|
||||
rerank_scores.append(score)
|
||||
|
||||
# Final ranking
|
||||
final_indices = np.argsort(rerank_scores)[::-1]
|
||||
final_labels = [rerank_labels[i] for i in final_indices]
|
||||
|
||||
# Compute metrics
|
||||
for k in self.config.k_values:
|
||||
if k <= len(final_labels):
|
||||
# Recall@k
|
||||
relevant_in_topk = sum(final_labels[:k])
|
||||
total_relevant = sum(rerank_labels)
|
||||
recall = relevant_in_topk / total_relevant if total_relevant > 0 else 0
|
||||
pipeline_metrics[f'recall@{k}'].append(recall)
|
||||
|
||||
# MRR@k
|
||||
for i, label in enumerate(final_labels[:k]):
|
||||
if label == 1:
|
||||
pipeline_metrics[f'mrr@{k}'].append(1.0 / (i + 1))
|
||||
break
|
||||
else:
|
||||
pipeline_metrics[f'mrr@{k}'].append(0.0)
|
||||
|
||||
total_queries += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error processing query: {e}")
|
||||
continue
|
||||
|
||||
# Average metrics
|
||||
final_metrics = {}
|
||||
for metric, values in pipeline_metrics.items():
|
||||
if values:
|
||||
final_metrics[metric] = np.mean(values)
|
||||
else:
|
||||
final_metrics[metric] = 0.0
|
||||
|
||||
final_metrics['total_queries'] = total_queries
|
||||
|
||||
return final_metrics
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error testing pipeline: {e}")
|
||||
return {}
|
||||
|
||||
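# Worked example for the pipeline metrics above (illustrative, not measured):
# with final_labels = [0, 1, 0, 1] and two relevant candidates in rerank_labels,
# recall@2 = 1/2 = 0.5 and MRR@2 = 1/2 = 0.5 (first relevant item at rank 2).
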
def _compare_metrics(self, baseline: Dict[str, float], finetuned: Dict[str, float], model_type: str) -> Dict[str, Any]:
|
||||
"""Compare baseline vs fine-tuned metrics with statistical analysis"""
|
||||
comparison = {
|
||||
'improvements': {},
|
||||
'degradations': {},
|
||||
'statistical_significance': {},
|
||||
'summary': {}
|
||||
}
|
||||
|
||||
# Compare common metrics
|
||||
common_metrics = set(baseline.keys()) & set(finetuned.keys())
|
||||
|
||||
improvements = 0
|
||||
degradations = 0
|
||||
significant_improvements = 0
|
||||
|
||||
for metric in common_metrics:
|
||||
if isinstance(baseline[metric], (int, float)) and isinstance(finetuned[metric], (int, float)):
|
||||
baseline_val = baseline[metric]
|
||||
finetuned_val = finetuned[metric]
|
||||
|
||||
# Calculate improvement
|
||||
if baseline_val != 0:
|
||||
improvement_pct = ((finetuned_val - baseline_val) / baseline_val) * 100
|
||||
else:
|
||||
improvement_pct = float('inf') if finetuned_val > 0 else 0
|
||||
|
||||
# Classify improvement/degradation
|
||||
if finetuned_val > baseline_val:
|
||||
comparison['improvements'][metric] = {
|
||||
'baseline': baseline_val,
|
||||
'finetuned': finetuned_val,
|
||||
'improvement_pct': improvement_pct,
|
||||
'absolute_improvement': finetuned_val - baseline_val
|
||||
}
|
||||
improvements += 1
|
||||
elif finetuned_val < baseline_val:
|
||||
comparison['degradations'][metric] = {
|
||||
'baseline': baseline_val,
|
||||
'finetuned': finetuned_val,
|
||||
'degradation_pct': -improvement_pct,
|
||||
'absolute_degradation': baseline_val - finetuned_val
|
||||
}
|
||||
degradations += 1
|
||||
|
||||
# Summary statistics
|
||||
comparison['summary'] = {
|
||||
'total_metrics': len(common_metrics),
|
||||
'improvements': improvements,
|
||||
'degradations': degradations,
|
||||
'unchanged': len(common_metrics) - improvements - degradations,
|
||||
'improvement_ratio': improvements / len(common_metrics) if common_metrics else 0,
|
||||
'net_improvement': improvements - degradations
|
||||
}
|
||||
|
||||
return comparison
|
||||
|
||||
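# Worked example for improvement_pct above (illustrative): baseline ndcg@10 = 0.50 and
# fine-tuned ndcg@10 = 0.55 gives ((0.55 - 0.50) / 0.50) * 100 = +10.0%, counted as an improvement.
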
def _perform_statistical_analysis(self):
|
||||
"""Perform statistical analysis on validation results"""
|
||||
self.logger.info("📈 Performing Statistical Analysis...")
|
||||
|
||||
analysis = {
|
||||
'overall_summary': {},
|
||||
'significance_tests': {},
|
||||
'confidence_intervals': {},
|
||||
'effect_sizes': {}
|
||||
}
|
||||
|
||||
# Analyze each model type
|
||||
for model_type in ['retriever', 'reranker', 'pipeline']:
|
||||
if model_type not in self.results:
|
||||
continue
|
||||
|
||||
model_results = self.results[model_type]
|
||||
|
||||
# Collect all improvements across datasets
|
||||
all_improvements = []
|
||||
significant_improvements = 0
|
||||
|
||||
for dataset, comparison in model_results.get('comparisons', {}).items():
|
||||
for metric, improvement_data in comparison.get('improvements', {}).items():
|
||||
improvement_pct = improvement_data.get('improvement_pct', 0)
|
||||
if improvement_pct > 0:
|
||||
all_improvements.append(improvement_pct)
|
||||
|
||||
# Simple significance test (can be enhanced)
|
||||
if improvement_pct > 5.0: # 5% improvement threshold
|
||||
significant_improvements += 1
|
||||
|
||||
if all_improvements:
|
||||
analysis[model_type] = {
|
||||
'mean_improvement_pct': np.mean(all_improvements),
|
||||
'median_improvement_pct': np.median(all_improvements),
|
||||
'std_improvement_pct': np.std(all_improvements),
|
||||
'min_improvement_pct': np.min(all_improvements),
|
||||
'max_improvement_pct': np.max(all_improvements),
|
||||
'significant_improvements': significant_improvements,
|
||||
'total_improvements': len(all_improvements)
|
||||
}
|
||||
|
||||
self.results['statistical_analysis'] = analysis
|
||||
|
||||
def _generate_report(self):
|
||||
"""Generate comprehensive validation report"""
|
||||
self.logger.info("📝 Generating Comprehensive Report...")
|
||||
|
||||
report_path = os.path.join(self.config.output_dir, "validation_report.html")
|
||||
|
||||
html_content = self._create_html_report()
|
||||
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
f.write(html_content)
|
||||
|
||||
# Also save JSON results
|
||||
json_path = os.path.join(self.config.output_dir, "validation_results.json")
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.results, f, indent=2, default=str)
|
||||
|
||||
self.logger.info(f"📊 Report saved to: {report_path}")
|
||||
self.logger.info(f"📊 JSON results saved to: {json_path}")
|
||||
|
||||
def _create_html_report(self) -> str:
|
||||
"""Create HTML report with comprehensive validation results"""
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>BGE Model Validation Report</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 20px; }}
|
||||
.header {{ background: #f0f8ff; padding: 20px; border-radius: 10px; }}
|
||||
.section {{ margin: 20px 0; }}
|
||||
.metrics-table {{ border-collapse: collapse; width: 100%; }}
|
||||
.metrics-table th, .metrics-table td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
|
||||
.metrics-table th {{ background-color: #f2f2f2; }}
|
||||
.improvement {{ color: green; font-weight: bold; }}
|
||||
.degradation {{ color: red; font-weight: bold; }}
|
||||
.summary-box {{ background: #f9f9f9; padding: 15px; border-left: 4px solid #4CAF50; margin: 10px 0; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>🚀 BGE Model Comprehensive Validation Report</h1>
|
||||
<p><strong>Generated:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||
<p><strong>Configuration:</strong></p>
|
||||
<ul>
|
||||
<li>Retriever Baseline: {self.config.retriever_baseline}</li>
|
||||
<li>Retriever Fine-tuned: {self.config.retriever_finetuned or 'Not tested'}</li>
|
||||
<li>Reranker Baseline: {self.config.reranker_baseline}</li>
|
||||
<li>Reranker Fine-tuned: {self.config.reranker_finetuned or 'Not tested'}</li>
|
||||
</ul>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Add executive summary
|
||||
html += self._create_executive_summary()
|
||||
|
||||
# Add detailed results for each model type
|
||||
for model_type in ['retriever', 'reranker', 'pipeline']:
|
||||
if model_type in self.results:
|
||||
html += self._create_model_section(model_type)
|
||||
|
||||
html += """
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return html
|
||||
|
||||
def _create_executive_summary(self) -> str:
|
||||
"""Create executive summary section"""
|
||||
html = """
|
||||
<div class="section">
|
||||
<h2>📊 Executive Summary</h2>
|
||||
"""
|
||||
|
||||
# Overall verdict
|
||||
total_improvements = 0
|
||||
total_degradations = 0
|
||||
|
||||
for model_type in ['retriever', 'reranker', 'pipeline']:
|
||||
if model_type in self.results and 'comparisons' in self.results[model_type]:
|
||||
for dataset, comparison in self.results[model_type]['comparisons'].items():
|
||||
total_improvements += len(comparison.get('improvements', {}))
|
||||
total_degradations += len(comparison.get('degradations', {}))
|
||||
|
||||
if total_improvements > total_degradations:
|
||||
verdict = "✅ <strong>FINE-TUNING SUCCESSFUL</strong> - Models show overall improvement"
|
||||
verdict_color = "green"
|
||||
elif total_improvements == total_degradations:
|
||||
verdict = "⚠️ <strong>MIXED RESULTS</strong> - Equal improvements and degradations"
|
||||
verdict_color = "orange"
|
||||
else:
|
||||
verdict = "❌ <strong>FINE-TUNING ISSUES</strong> - More degradations than improvements"
|
||||
verdict_color = "red"
|
||||
|
||||
html += f"""
|
||||
<div class="summary-box" style="border-left-color: {verdict_color};">
|
||||
<h3>{verdict}</h3>
|
||||
<ul>
|
||||
<li>Total Metric Improvements: {total_improvements}</li>
|
||||
<li>Total Metric Degradations: {total_degradations}</li>
|
||||
<li>Net Improvement: {total_improvements - total_degradations}</li>
|
||||
</ul>
|
||||
</div>
|
||||
"""
|
||||
|
||||
html += """
|
||||
</div>
|
||||
"""
|
||||
return html
|
||||
|
||||
def _create_model_section(self, model_type: str) -> str:
|
||||
"""Create detailed section for a model type"""
|
||||
html = f"""
|
||||
<div class="section">
|
||||
<h2>📈 {model_type.title()} Model Results</h2>
|
||||
"""
|
||||
|
||||
model_results = self.results[model_type]
|
||||
|
||||
for dataset in model_results.get('datasets_tested', []):
|
||||
if dataset in model_results['comparisons']:
|
||||
comparison = model_results['comparisons'][dataset]
|
||||
|
||||
html += f"""
|
||||
<h3>Dataset: {dataset}</h3>
|
||||
<table class="metrics-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Metric</th>
|
||||
<th>Baseline</th>
|
||||
<th>Fine-tuned</th>
|
||||
<th>Change</th>
|
||||
<th>Status</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
|
||||
# Add improvements
|
||||
for metric, data in comparison.get('improvements', {}).items():
|
||||
html += f"""
|
||||
<tr>
|
||||
<td>{metric}</td>
|
||||
<td>{data['baseline']:.4f}</td>
|
||||
<td>{data['finetuned']:.4f}</td>
|
||||
<td class="improvement">+{data['improvement_pct']:.2f}%</td>
|
||||
<td class="improvement">✅ Improved</td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
# Add degradations
|
||||
for metric, data in comparison.get('degradations', {}).items():
|
||||
html += f"""
|
||||
<tr>
|
||||
<td>{metric}</td>
|
||||
<td>{data['baseline']:.4f}</td>
|
||||
<td>{data['finetuned']:.4f}</td>
|
||||
<td class="degradation">-{data['degradation_pct']:.2f}%</td>
|
||||
<td class="degradation">❌ Degraded</td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
html += """
|
||||
</tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
html += """
|
||||
</div>
|
||||
"""
|
||||
return html
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Comprehensive validation of BGE fine-tuned models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Validate both retriever and reranker
|
||||
python scripts/comprehensive_validation.py \\
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model \\
|
||||
--reranker_finetuned ./output/bge-reranker/final_model
|
||||
|
||||
# Validate only retriever
|
||||
python scripts/comprehensive_validation.py \\
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model
|
||||
|
||||
# Custom datasets and output
|
||||
python scripts/comprehensive_validation.py \\
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model \\
|
||||
--test_datasets data/my_test1.jsonl data/my_test2.jsonl \\
|
||||
--output_dir ./my_validation_results
|
||||
"""
|
||||
)
|
||||
|
||||
# Model paths
|
||||
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
|
||||
help="Path to baseline retriever model")
|
||||
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Path to baseline reranker model")
|
||||
parser.add_argument("--retriever_finetuned", type=str, default=None,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_finetuned", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
|
||||
# Test data
|
||||
parser.add_argument("--test_datasets", type=str, nargs="+", default=None,
|
||||
help="Paths to test datasets (JSONL format)")
|
||||
|
||||
# Evaluation settings
|
||||
parser.add_argument("--batch_size", type=int, default=32,
|
||||
help="Batch size for evaluation")
|
||||
parser.add_argument("--max_length", type=int, default=512,
|
||||
help="Maximum sequence length")
|
||||
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 3, 5, 10, 20],
|
||||
help="K values for evaluation metrics")
|
||||
|
||||
# Output settings
|
||||
parser.add_argument("--output_dir", type=str, default="./validation_results",
|
||||
help="Directory to save validation results")
|
||||
parser.add_argument("--save_embeddings", action="store_true",
|
||||
help="Save embeddings for further analysis")
|
||||
parser.add_argument("--no_report", action="store_true",
|
||||
help="Skip HTML report generation")
|
||||
|
||||
# Pipeline settings
|
||||
parser.add_argument("--retrieval_top_k", type=int, default=100,
|
||||
help="Top-k for retrieval stage")
|
||||
parser.add_argument("--rerank_top_k", type=int, default=10,
|
||||
help="Top-k for reranking stage")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main validation function"""
|
||||
args = parse_args()
|
||||
|
||||
# Validate arguments
|
||||
if not args.retriever_finetuned and not args.reranker_finetuned:
|
||||
print("❌ Error: At least one fine-tuned model must be specified")
|
||||
print(" Use --retriever_finetuned or --reranker_finetuned")
|
||||
return 1
|
||||
|
||||
# Create configuration
|
||||
config = ValidationConfig(
|
||||
retriever_baseline=args.retriever_baseline,
|
||||
reranker_baseline=args.reranker_baseline,
|
||||
retriever_finetuned=args.retriever_finetuned,
|
||||
reranker_finetuned=args.reranker_finetuned,
|
||||
test_datasets=args.test_datasets,
|
||||
batch_size=args.batch_size,
|
||||
max_length=args.max_length,
|
||||
k_values=args.k_values,
|
||||
output_dir=args.output_dir,
|
||||
save_embeddings=args.save_embeddings,
|
||||
create_report=not args.no_report,
|
||||
retrieval_top_k=args.retrieval_top_k,
|
||||
rerank_top_k=args.rerank_top_k
|
||||
)
|
||||
|
||||
# Run validation
|
||||
try:
|
||||
validator = ComprehensiveValidator(config)
|
||||
results = validator.validate_all()
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("🎉 COMPREHENSIVE VALIDATION COMPLETED!")
|
||||
print("=" * 80)
|
||||
print(f"📊 Results saved to: {config.output_dir}")
|
||||
print("\n📈 Quick Summary:")
|
||||
|
||||
# Print quick summary
|
||||
for model_type in ['retriever', 'reranker', 'pipeline']:
|
||||
if model_type in results:
|
||||
model_results = results[model_type]
|
||||
total_datasets = len(model_results.get('datasets_tested', []))
|
||||
total_improvements = sum(len(comp.get('improvements', {}))
|
||||
for comp in model_results.get('comparisons', {}).values())
|
||||
total_degradations = sum(len(comp.get('degradations', {}))
|
||||
for comp in model_results.get('comparisons', {}).values())
|
||||
|
||||
status = "✅" if total_improvements > total_degradations else "⚠️" if total_improvements == total_degradations else "❌"
|
||||
print(f" {status} {model_type.title()}: {total_improvements} improvements, {total_degradations} degradations across {total_datasets} datasets")
|
||||
|
||||
return 0
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n❌ Validation interrupted by user")
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"\n❌ Validation failed: {e}")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
291
scripts/convert_m3_dataset.py
Normal file
@@ -0,0 +1,291 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
BGE-M3 Dataset Format Converter
|
||||
|
||||
Converts the nested dataset format where training data is stored inside a "data" field
|
||||
to the flat format expected by the BGE-M3 training pipeline.
|
||||
|
||||
Input format:
|
||||
{
|
||||
"query": "...",
|
||||
"data": "{\"pos\": [...], \"neg\": [...], \"pos_scores\": [...], ...}"
|
||||
}
|
||||
|
||||
Output format:
|
||||
{
|
||||
"query": "...",
|
||||
"pos": [...],
|
||||
"neg": [...],
|
||||
"pos_scores": [...],
|
||||
"neg_scores": [...],
|
||||
"prompt": "...",
|
||||
"type": "..."
|
||||
}
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_json_from_data_field(data_field: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract JSON data from the "data" field which contains a JSON string
|
||||
wrapped in markdown code blocks.
|
||||
|
||||
Args:
|
||||
data_field: String containing JSON data, possibly wrapped in markdown
|
||||
|
||||
Returns:
|
||||
Parsed JSON dictionary or None if parsing fails
|
||||
"""
|
||||
try:
|
||||
# Remove markdown code block markers if present
|
||||
cleaned_data = data_field.strip()
|
||||
|
||||
# Remove ```json and ``` markers
|
||||
if cleaned_data.startswith('```json'):
|
||||
cleaned_data = cleaned_data[7:] # Remove ```json
|
||||
elif cleaned_data.startswith('```'):
|
||||
cleaned_data = cleaned_data[3:] # Remove ```
|
||||
|
||||
if cleaned_data.endswith('```'):
|
||||
cleaned_data = cleaned_data[:-3] # Remove trailing ```
|
||||
|
||||
# Clean up any remaining whitespace
|
||||
cleaned_data = cleaned_data.strip()
|
||||
|
||||
# Parse the JSON
|
||||
return json.loads(cleaned_data)
|
||||
|
||||
except (json.JSONDecodeError, AttributeError) as e:
|
||||
logger.warning(f"Failed to parse JSON from data field: {e}")
|
||||
logger.warning(f"Data content (first 200 chars): {data_field[:200]}")
|
||||
return None
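# Quick sanity check (hypothetical strings, safe to delete): the helper strips the markdown fence before
# parsing, so both of the calls below should return {'pos': ['a'], 'neg': ['b']}, while malformed input
# returns None instead of raising:
#   extract_json_from_data_field('```json\n{"pos": ["a"], "neg": ["b"]}\n```')
#   extract_json_from_data_field('{"pos": ["a"], "neg": ["b"]}')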
|
||||
|
||||
|
||||
def convert_dataset_format(input_file: str, output_file: str) -> None:
|
||||
"""
|
||||
Convert dataset from nested format to flat format.
|
||||
|
||||
Args:
|
||||
input_file: Path to input JSONL file with nested format
|
||||
output_file: Path to output JSONL file with flat format
|
||||
"""
|
||||
logger.info(f"Converting dataset from {input_file} to {output_file}")
|
||||
|
||||
# Count total lines for progress bar
|
||||
total_lines = 0
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for _ in f:
|
||||
total_lines += 1
|
||||
|
||||
converted_count = 0
|
||||
error_count = 0
|
||||
|
||||
with open(input_file, 'r', encoding='utf-8') as infile, \
|
||||
open(output_file, 'w', encoding='utf-8') as outfile:
|
||||
|
||||
for line_num, line in enumerate(tqdm(infile, total=total_lines, desc="Converting"), 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Parse the input JSON line
|
||||
item = json.loads(line)
|
||||
|
||||
# Validate required fields
|
||||
if 'query' not in item:
|
||||
logger.warning(f"Line {line_num}: Missing 'query' field")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
if 'data' not in item:
|
||||
logger.warning(f"Line {line_num}: Missing 'data' field")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
# Extract data from the nested field
|
||||
data_content = extract_json_from_data_field(item['data'])
|
||||
if data_content is None:
|
||||
logger.warning(f"Line {line_num}: Failed to parse data field")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
# Create the flattened format
|
||||
flattened_item = {
|
||||
'query': item['query']
|
||||
}
|
||||
|
||||
# Add all fields from the data content
|
||||
expected_fields = ['pos', 'neg', 'pos_scores', 'neg_scores', 'prompt', 'type']
|
||||
for field in expected_fields:
|
||||
if field in data_content:
|
||||
flattened_item[field] = data_content[field]
|
||||
else:
|
||||
# Provide reasonable defaults
|
||||
if field == 'pos':
|
||||
flattened_item[field] = []
|
||||
elif field == 'neg':
|
||||
flattened_item[field] = []
|
||||
elif field == 'pos_scores':
|
||||
flattened_item[field] = [1.0] * len(data_content.get('pos', []))
|
||||
elif field == 'neg_scores':
|
||||
flattened_item[field] = [0.0] * len(data_content.get('neg', []))
|
||||
elif field == 'prompt':
|
||||
flattened_item[field] = '为此查询生成表示:'
|
||||
elif field == 'type':
|
||||
flattened_item[field] = 'normal'
|
||||
|
||||
# Validate the converted item
|
||||
if not flattened_item['pos'] and not flattened_item['neg']:
|
||||
logger.warning(f"Line {line_num}: No positive or negative passages found")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
# Ensure scores arrays match passage arrays
|
||||
if len(flattened_item['pos_scores']) != len(flattened_item['pos']):
|
||||
flattened_item['pos_scores'] = [1.0] * len(flattened_item['pos'])
|
||||
|
||||
if len(flattened_item['neg_scores']) != len(flattened_item['neg']):
|
||||
flattened_item['neg_scores'] = [0.0] * len(flattened_item['neg'])
|
||||
|
||||
# Write the flattened item
|
||||
outfile.write(json.dumps(flattened_item, ensure_ascii=False) + '\n')
|
||||
converted_count += 1
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Line {line_num}: Invalid JSON - {e}")
|
||||
error_count += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"Line {line_num}: Unexpected error - {e}")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
logger.info(f"Conversion complete!")
|
||||
logger.info(f"Successfully converted: {converted_count} items")
|
||||
logger.info(f"Errors encountered: {error_count} items")
|
||||
logger.info(f"Output saved to: {output_file}")
|
||||
|
||||
|
||||
def validate_converted_file(file_path: str, sample_size: int = 5) -> None:
|
||||
"""
|
||||
Validate the converted file by showing a few sample entries.
|
||||
|
||||
Args:
|
||||
file_path: Path to the converted JSONL file
|
||||
sample_size: Number of sample entries to display
|
||||
"""
|
||||
logger.info(f"Validating converted file: {file_path}")
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
total_lines = len(lines)
|
||||
logger.info(f"Total converted entries: {total_lines}")
|
||||
|
||||
if total_lines == 0:
|
||||
logger.warning("No entries found in converted file!")
|
||||
return
|
||||
|
||||
# Show sample entries
|
||||
sample_indices = list(range(min(sample_size, total_lines)))
|
||||
|
||||
for i, idx in enumerate(sample_indices):
|
||||
try:
|
||||
item = json.loads(lines[idx])
|
||||
logger.info(f"\nSample {i+1} (line {idx+1}):")
|
||||
logger.info(f" Query: {item['query'][:100]}...")
|
||||
logger.info(f" Positives: {len(item.get('pos', []))}")
|
||||
logger.info(f" Negatives: {len(item.get('neg', []))}")
|
||||
logger.info(f" Type: {item.get('type', 'N/A')}")
|
||||
logger.info(f" Prompt: {item.get('prompt', 'N/A')}")
|
||||
|
||||
# Show first positive and negative if available
|
||||
if item.get('pos'):
|
||||
logger.info(f" First positive: {item['pos'][0][:150]}...")
|
||||
if item.get('neg'):
|
||||
logger.info(f" First negative: {item['neg'][0][:150]}...")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Invalid JSON at line {idx+1}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating file: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert BGE-M3 dataset from nested to flat format",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python convert_m3_dataset.py data/datasets/三国演义/m3_train.jsonl data/datasets/三国演义/m3_train_converted.jsonl
|
||||
python convert_m3_dataset.py input.jsonl output.jsonl --validate
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'input_file',
|
||||
help='Path to input JSONL file with nested format'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'output_file',
|
||||
help='Path to output JSONL file with flat format'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--validate',
|
||||
action='store_true',
|
||||
help='Validate the converted file and show samples'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--sample-size',
|
||||
type=int,
|
||||
default=5,
|
||||
help='Number of sample entries to show during validation (default: 5)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Check if input file exists
|
||||
if not os.path.exists(args.input_file):
|
||||
logger.error(f"Input file not found: {args.input_file}")
|
||||
return 1
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
output_dir = os.path.dirname(args.output_file)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
logger.info(f"Created output directory: {output_dir}")
|
||||
|
||||
try:
|
||||
# Convert the dataset
|
||||
convert_dataset_format(args.input_file, args.output_file)
|
||||
|
||||
# Validate if requested
|
||||
if args.validate:
|
||||
validate_converted_file(args.output_file, args.sample_size)
|
||||
|
||||
logger.info("Dataset conversion completed successfully!")
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Dataset conversion failed: {e}")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
exit(main())
|
||||
584
scripts/evaluate.py
Normal file
@@ -0,0 +1,584 @@
|
||||
"""
|
||||
Evaluation script for fine-tuned BGE models
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
import numpy as np
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
|
||||
from data.preprocessing import DataPreprocessor
|
||||
from utils.logging import setup_logging, get_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments for evaluation."""
|
||||
parser = argparse.ArgumentParser(description="Evaluate fine-tuned BGE models")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--retriever_model", type=str, required=True,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
|
||||
help="Base retriever model for comparison")
|
||||
parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Base reranker model for comparison")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument("--eval_data", type=str, required=True,
|
||||
help="Path to evaluation data file")
|
||||
parser.add_argument("--corpus_data", type=str, default=None,
|
||||
help="Path to corpus data for retrieval evaluation")
|
||||
parser.add_argument("--data_format", type=str, default="auto",
|
||||
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
|
||||
help="Format of input data")
|
||||
|
||||
# Evaluation arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./evaluation_results",
|
||||
help="Directory to save evaluation results")
|
||||
parser.add_argument("--batch_size", type=int, default=32,
|
||||
help="Batch size for evaluation")
|
||||
parser.add_argument("--max_length", type=int, default=512,
|
||||
help="Maximum sequence length")
|
||||
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
|
||||
help="K values for evaluation metrics")
|
||||
|
||||
# Evaluation modes
|
||||
parser.add_argument("--eval_retriever", action="store_true",
|
||||
help="Evaluate retriever model")
|
||||
parser.add_argument("--eval_reranker", action="store_true",
|
||||
help="Evaluate reranker model")
|
||||
parser.add_argument("--eval_pipeline", action="store_true",
|
||||
help="Evaluate full retrieval pipeline")
|
||||
parser.add_argument("--compare_baseline", action="store_true",
|
||||
help="Compare with baseline models")
|
||||
parser.add_argument("--interactive", action="store_true",
|
||||
help="Run interactive evaluation")
|
||||
|
||||
# Pipeline settings
|
||||
parser.add_argument("--retrieval_top_k", type=int, default=100,
|
||||
help="Top-k for initial retrieval")
|
||||
parser.add_argument("--rerank_top_k", type=int, default=10,
|
||||
help="Top-k for reranking")
|
||||
|
||||
# Other
|
||||
parser.add_argument("--device", type=str, default="auto",
|
||||
help="Device to use (auto, cpu, cuda)")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validation
|
||||
if not any([args.eval_retriever, args.eval_reranker, args.eval_pipeline, args.interactive]):
|
||||
args.eval_pipeline = True # Default to pipeline evaluation
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def load_evaluation_data(data_path: str, data_format: str):
|
||||
"""Load evaluation data from a file."""
|
||||
preprocessor = DataPreprocessor()
|
||||
|
||||
if data_format == "auto":
|
||||
data_format = preprocessor._detect_format(data_path)
|
||||
|
||||
if data_format == "jsonl":
|
||||
data = preprocessor._load_jsonl(data_path)
|
||||
elif data_format == "json":
|
||||
data = preprocessor._load_json(data_path)
|
||||
elif data_format in ["csv", "tsv"]:
|
||||
data = preprocessor._load_csv(data_path, delimiter="," if data_format == "csv" else "\t")
|
||||
elif data_format == "dureader":
|
||||
data = preprocessor._load_dureader(data_path)
|
||||
elif data_format == "msmarco":
|
||||
data = preprocessor._load_msmarco(data_path)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {data_format}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def evaluate_retriever(args, retriever_model, tokenizer, eval_data):
|
||||
"""Evaluate the retriever model."""
|
||||
logger.info("Evaluating retriever model...")
|
||||
|
||||
evaluator = RetrievalEvaluator(
|
||||
model=retriever_model,
|
||||
tokenizer=tokenizer,
|
||||
device=args.device,
|
||||
k_values=args.k_values
|
||||
)
|
||||
|
||||
# Prepare data for evaluation
|
||||
queries = [item['query'] for item in eval_data]
|
||||
|
||||
# If we have passage data, create passage corpus
|
||||
all_passages = set()
|
||||
for item in eval_data:
|
||||
if 'pos' in item:
|
||||
all_passages.update(item['pos'])
|
||||
if 'neg' in item:
|
||||
all_passages.update(item['neg'])
|
||||
|
||||
passages = list(all_passages)
|
||||
|
||||
# Create relevance mapping
|
||||
relevant_docs = {}
|
||||
for i, item in enumerate(eval_data):
|
||||
query = item['query']
|
||||
if 'pos' in item:
|
||||
relevant_indices = []
|
||||
for pos in item['pos']:
|
||||
if pos in passages:
|
||||
relevant_indices.append(passages.index(pos))
|
||||
if relevant_indices:
|
||||
relevant_docs[query] = relevant_indices
|
||||
|
||||
# Evaluate
|
||||
if passages and relevant_docs:
|
||||
metrics = evaluator.evaluate_retrieval_pipeline(
|
||||
queries=queries,
|
||||
corpus=passages,
|
||||
relevant_docs=relevant_docs,
|
||||
batch_size=args.batch_size
|
||||
)
|
||||
else:
|
||||
logger.warning("No valid corpus or relevance data found. Running basic evaluation.")
|
||||
# Create dummy evaluation
|
||||
metrics = {"note": "Limited evaluation due to data format"}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def evaluate_reranker(args, reranker_model, tokenizer, eval_data):
|
||||
"""Evaluate the reranker model."""
|
||||
logger.info("Evaluating reranker model...")
|
||||
|
||||
evaluator = RetrievalEvaluator(
|
||||
model=reranker_model,
|
||||
tokenizer=tokenizer,
|
||||
device=args.device,
|
||||
k_values=args.k_values
|
||||
)
|
||||
|
||||
# Prepare data for reranker evaluation
|
||||
total_queries = 0
|
||||
total_correct = 0
|
||||
total_mrr = 0
|
||||
|
||||
for item in eval_data:
|
||||
if 'pos' not in item or 'neg' not in item:
|
||||
continue
|
||||
|
||||
query = item['query']
|
||||
positives = item['pos']
|
||||
negatives = item['neg']
|
||||
|
||||
if not positives or not negatives:
|
||||
continue
|
||||
|
||||
# Create candidates (1 positive + negatives)
|
||||
candidates = positives[:1] + negatives
|
||||
|
||||
# Score all candidates
|
||||
pairs = [(query, candidate) for candidate in candidates]
|
||||
scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
|
||||
|
||||
# Check if positive is ranked first
|
||||
best_idx = np.argmax(scores)
|
||||
if best_idx == 0: # First is positive
|
||||
total_correct += 1
|
||||
total_mrr += 1.0
|
||||
else:
|
||||
# Find rank of positive (it's always at index 0)
|
||||
sorted_indices = np.argsort(scores)[::-1]
|
||||
positive_rank = np.where(sorted_indices == 0)[0][0] + 1
|
||||
total_mrr += 1.0 / positive_rank
|
||||
|
||||
total_queries += 1
|
||||
|
||||
if total_queries > 0:
|
||||
metrics = {
|
||||
'accuracy@1': total_correct / total_queries,
|
||||
'mrr': total_mrr / total_queries,
|
||||
'total_queries': total_queries
|
||||
}
|
||||
else:
|
||||
metrics = {"note": "No valid query-passage pairs found for reranker evaluation"}
|
||||
|
||||
return metrics
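# Worked example for the metrics above (hypothetical scores): with candidates = [pos, neg1, neg2]
# scored as [0.2, 0.9, 0.1], the positive ranks second, so this query contributes 0 to accuracy@1
# and 1/2 = 0.5 to the MRR sum; a query whose positive scores highest contributes 1 to both.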
|
||||
|
||||
|
||||
def evaluate_pipeline(args, retriever_model, reranker_model, tokenizer, eval_data):
|
||||
"""Evaluate the full retrieval + reranking pipeline."""
|
||||
logger.info("Evaluating full pipeline...")
|
||||
|
||||
# Prepare corpus
|
||||
all_passages = set()
|
||||
for item in eval_data:
|
||||
if 'pos' in item:
|
||||
all_passages.update(item['pos'])
|
||||
if 'neg' in item:
|
||||
all_passages.update(item['neg'])
|
||||
|
||||
passages = list(all_passages)
|
||||
|
||||
if not passages:
|
||||
return {"note": "No passages found for pipeline evaluation"}
|
||||
|
||||
# Pre-encode corpus with retriever
|
||||
logger.info("Encoding corpus...")
|
||||
corpus_embeddings = retriever_model.encode(
|
||||
passages,
|
||||
batch_size=args.batch_size,
|
||||
return_numpy=True,
|
||||
device=args.device
|
||||
)
|
||||
|
||||
total_queries = 0
|
||||
pipeline_metrics = {
|
||||
'retrieval_recall@10': [],
|
||||
'rerank_accuracy@1': [],
|
||||
'rerank_mrr': []
|
||||
}
|
||||
|
||||
for item in eval_data:
|
||||
if 'pos' not in item or not item['pos']:
|
||||
continue
|
||||
|
||||
query = item['query']
|
||||
relevant_passages = item['pos']
|
||||
|
||||
# Step 1: Retrieval
|
||||
query_embedding = retriever_model.encode(
|
||||
[query],
|
||||
return_numpy=True,
|
||||
device=args.device
|
||||
)
|
||||
|
||||
# Compute similarities
|
||||
similarities = np.dot(query_embedding, corpus_embeddings.T).squeeze()
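        # Note: the dot product is used as the similarity score here; this equals cosine similarity
        # only if BGEM3Model.encode returns L2-normalized embeddings, which is assumed but not enforced.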
|
||||
|
||||
# Get top-k for retrieval
|
||||
top_k_indices = np.argsort(similarities)[::-1][:args.retrieval_top_k]
|
||||
retrieved_passages = [passages[i] for i in top_k_indices]
|
||||
|
||||
        # Recall@10: fraction of the relevant passages recovered in the top-10 retrieved (denominator capped at 10)
|
||||
retrieved_relevant = sum(1 for p in retrieved_passages[:10] if p in relevant_passages)
|
||||
retrieval_recall = retrieved_relevant / min(len(relevant_passages), 10)
|
||||
pipeline_metrics['retrieval_recall@10'].append(retrieval_recall)
|
||||
|
||||
# Step 2: Reranking (if reranker available)
|
||||
if reranker_model is not None:
|
||||
# Rerank top candidates
|
||||
rerank_candidates = retrieved_passages[:args.rerank_top_k]
|
||||
pairs = [(query, candidate) for candidate in rerank_candidates]
|
||||
|
||||
if pairs:
|
||||
rerank_scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
|
||||
reranked_indices = np.argsort(rerank_scores)[::-1]
|
||||
reranked_passages = [rerank_candidates[i] for i in reranked_indices]
|
||||
|
||||
# Check reranking performance
|
||||
if reranked_passages[0] in relevant_passages:
|
||||
pipeline_metrics['rerank_accuracy@1'].append(1.0)
|
||||
pipeline_metrics['rerank_mrr'].append(1.0)
|
||||
else:
|
||||
pipeline_metrics['rerank_accuracy@1'].append(0.0)
|
||||
# Find MRR
|
||||
for rank, passage in enumerate(reranked_passages, 1):
|
||||
if passage in relevant_passages:
|
||||
pipeline_metrics['rerank_mrr'].append(1.0 / rank)
|
||||
break
|
||||
else:
|
||||
pipeline_metrics['rerank_mrr'].append(0.0)
|
||||
|
||||
total_queries += 1
|
||||
|
||||
# Average metrics
|
||||
final_metrics = {}
|
||||
for metric, values in pipeline_metrics.items():
|
||||
if values:
|
||||
final_metrics[metric] = np.mean(values)
|
||||
|
||||
final_metrics['total_queries'] = total_queries
|
||||
|
||||
return final_metrics
|
||||
|
||||
|
||||
def run_interactive_evaluation(args, retriever_model, reranker_model, tokenizer, eval_data):
|
||||
"""Run interactive evaluation."""
|
||||
logger.info("Starting interactive evaluation...")
|
||||
|
||||
# Prepare corpus
|
||||
all_passages = set()
|
||||
for item in eval_data:
|
||||
if 'pos' in item:
|
||||
all_passages.update(item['pos'])
|
||||
if 'neg' in item:
|
||||
all_passages.update(item['neg'])
|
||||
|
||||
passages = list(all_passages)
|
||||
|
||||
if not passages:
|
||||
print("No passages found for interactive evaluation")
|
||||
return
|
||||
|
||||
# Create interactive evaluator
|
||||
evaluator = InteractiveEvaluator(
|
||||
retriever=retriever_model,
|
||||
reranker=reranker_model,
|
||||
tokenizer=tokenizer,
|
||||
corpus=passages,
|
||||
device=args.device
|
||||
)
|
||||
|
||||
print("\n=== Interactive Evaluation ===")
|
||||
print("Enter queries to test the retrieval system.")
|
||||
print("Type 'quit' to exit, 'sample' to test with sample queries.")
|
||||
|
||||
# Load some sample queries
|
||||
sample_queries = [item['query'] for item in eval_data[:5]]
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\nEnter query: ").strip()
|
||||
|
||||
if user_input.lower() in ['quit', 'exit', 'q']:
|
||||
break
|
||||
elif user_input.lower() == 'sample':
|
||||
print("\nSample queries:")
|
||||
for i, query in enumerate(sample_queries):
|
||||
print(f"{i + 1}. {query}")
|
||||
continue
|
||||
elif user_input.isdigit() and 1 <= int(user_input) <= len(sample_queries):
|
||||
query = sample_queries[int(user_input) - 1]
|
||||
else:
|
||||
query = user_input
|
||||
|
||||
if not query:
|
||||
continue
|
||||
|
||||
# Search
|
||||
results = evaluator.search(query, top_k=10, rerank=reranker_model is not None)
|
||||
|
||||
print(f"\nResults for: {query}")
|
||||
print("-" * 50)
|
||||
for i, (idx, score, text) in enumerate(results):
|
||||
print(f"{i + 1}. (Score: {score:.3f}) {text[:100]}...")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting interactive evaluation...")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run evaluation with robust error handling and config-driven defaults."""
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Evaluate BGE models.")
|
||||
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--retriever_model", type=str, required=True,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
|
||||
help="Base retriever model for comparison")
|
||||
parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Base reranker model for comparison")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument("--eval_data", type=str, required=True,
|
||||
help="Path to evaluation data file")
|
||||
parser.add_argument("--corpus_data", type=str, default=None,
|
||||
help="Path to corpus data for retrieval evaluation")
|
||||
parser.add_argument("--data_format", type=str, default="auto",
|
||||
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
|
||||
help="Format of input data")
|
||||
|
||||
# Evaluation arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./evaluation_results",
|
||||
help="Directory to save evaluation results")
|
||||
parser.add_argument("--batch_size", type=int, default=32,
|
||||
help="Batch size for evaluation")
|
||||
parser.add_argument("--max_length", type=int, default=512,
|
||||
help="Maximum sequence length")
|
||||
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
|
||||
help="K values for evaluation metrics")
|
||||
|
||||
# Evaluation modes
|
||||
parser.add_argument("--eval_retriever", action="store_true",
|
||||
help="Evaluate retriever model")
|
||||
parser.add_argument("--eval_reranker", action="store_true",
|
||||
help="Evaluate reranker model")
|
||||
parser.add_argument("--eval_pipeline", action="store_true",
|
||||
help="Evaluate full retrieval pipeline")
|
||||
parser.add_argument("--compare_baseline", action="store_true",
|
||||
help="Compare with baseline models")
|
||||
parser.add_argument("--interactive", action="store_true",
|
||||
help="Run interactive evaluation")
|
||||
|
||||
# Pipeline settings
|
||||
parser.add_argument("--retrieval_top_k", type=int, default=100,
|
||||
help="Top-k for initial retrieval")
|
||||
parser.add_argument("--rerank_top_k", type=int, default=10,
|
||||
help="Top-k for reranking")
|
||||
|
||||
# Other
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load config and set defaults
|
||||
from utils.config_loader import get_config
|
||||
config = get_config()
|
||||
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
|
||||
|
||||
# Setup logging
|
||||
setup_logging(
|
||||
output_dir=args.output_dir,
|
||||
log_level=logging.INFO,
|
||||
log_to_console=True,
|
||||
log_to_file=True
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("Starting model evaluation")
|
||||
logger.info(f"Arguments: {args}")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Set device
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
elif device == "npu":
|
||||
npu = getattr(torch, "npu", None)
|
||||
if npu is not None and hasattr(npu, "is_available") and npu.is_available():
|
||||
device = "npu"
|
||||
else:
|
||||
logger.warning("NPU (Ascend) is not available. Falling back to CPU.")
|
||||
device = "cpu"
|
||||
|
||||
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Load models
|
||||
logger.info("Loading models...")
|
||||
|
||||
model_source = get_config().get('model_paths.source', 'huggingface')
|
||||
# Load retriever
|
||||
try:
|
||||
retriever_model = BGEM3Model.from_pretrained(args.retriever_model)
|
||||
retriever_tokenizer = None # Already loaded inside BGEM3Model if needed
|
||||
retriever_model.to(device)
|
||||
retriever_model.eval()
|
||||
logger.info(f"Loaded retriever from {args.retriever_model} (source: {model_source})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load retriever: {e}")
|
||||
return
|
||||
|
||||
# Load reranker (optional)
|
||||
reranker_model = None
|
||||
reranker_tokenizer = None
|
||||
if args.reranker_model:
|
||||
try:
|
||||
reranker_model = BGERerankerModel(args.reranker_model)
|
||||
reranker_tokenizer = None # Already loaded inside BGERerankerModel if needed
|
||||
reranker_model.to(device)
|
||||
reranker_model.eval()
|
||||
logger.info(f"Loaded reranker from {args.reranker_model} (source: {model_source})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load reranker: {e}")
|
||||
|
||||
# Load evaluation data
|
||||
logger.info("Loading evaluation data...")
|
||||
eval_data = load_evaluation_data(args.eval_data, args.data_format)
|
||||
logger.info(f"Loaded {len(eval_data)} evaluation examples")
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Run evaluations
|
||||
all_results = {}
|
||||
|
||||
if args.eval_retriever:
|
||||
retriever_metrics = evaluate_retriever(args, retriever_model, retriever_tokenizer, eval_data)
|
||||
all_results['retriever'] = retriever_metrics
|
||||
logger.info(f"Retriever metrics: {retriever_metrics}")
|
||||
|
||||
if args.eval_reranker and reranker_model:
|
||||
reranker_metrics = evaluate_reranker(args, reranker_model, reranker_tokenizer, eval_data)
|
||||
all_results['reranker'] = reranker_metrics
|
||||
logger.info(f"Reranker metrics: {reranker_metrics}")
|
||||
|
||||
if args.eval_pipeline:
|
||||
pipeline_metrics = evaluate_pipeline(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
|
||||
all_results['pipeline'] = pipeline_metrics
|
||||
logger.info(f"Pipeline metrics: {pipeline_metrics}")
|
||||
|
||||
# Compare with baseline if requested
|
||||
if args.compare_baseline:
|
||||
logger.info("Loading baseline models for comparison...")
|
||||
try:
|
||||
base_retriever = BGEM3Model(args.base_retriever)
|
||||
base_retriever.to(device)
|
||||
base_retriever.eval()
|
||||
|
||||
base_retriever_metrics = evaluate_retriever(args, base_retriever, retriever_tokenizer, eval_data)
|
||||
all_results['baseline_retriever'] = base_retriever_metrics
|
||||
logger.info(f"Baseline retriever metrics: {base_retriever_metrics}")
|
||||
|
||||
if reranker_model:
|
||||
base_reranker = BGERerankerModel(args.base_reranker)
|
||||
base_reranker.to(device)
|
||||
base_reranker.eval()
|
||||
|
||||
base_reranker_metrics = evaluate_reranker(args, base_reranker, reranker_tokenizer, eval_data)
|
||||
all_results['baseline_reranker'] = base_reranker_metrics
|
||||
logger.info(f"Baseline reranker metrics: {base_reranker_metrics}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load baseline models: {e}")
|
||||
|
||||
# Save results
|
||||
results_file = os.path.join(args.output_dir, "evaluation_results.json")
|
||||
with open(results_file, 'w') as f:
|
||||
json.dump(all_results, f, indent=2)
|
||||
|
||||
logger.info(f"Evaluation results saved to {results_file}")
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print("EVALUATION SUMMARY")
|
||||
print("=" * 50)
|
||||
for model_type, metrics in all_results.items():
|
||||
print(f"\n{model_type.upper()}:")
|
||||
for metric, value in metrics.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.4f}")
|
||||
else:
|
||||
print(f" {metric}: {value}")
|
||||
|
||||
# Interactive evaluation
|
||||
if args.interactive:
|
||||
run_interactive_evaluation(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
624
scripts/quick_validation.py
Normal file
@@ -0,0 +1,624 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick Validation Script for BGE Fine-tuned Models
|
||||
|
||||
This script provides fast validation to quickly check if fine-tuned models
|
||||
are working correctly and performing better than baselines.
|
||||
|
||||
Usage:
|
||||
# Quick test both models
|
||||
python scripts/quick_validation.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \\
|
||||
--reranker_model ./output/bge-reranker/final_model
|
||||
|
||||
# Test only retriever
|
||||
python scripts/quick_validation.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from utils.config_loader import get_config
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuickValidator:
|
||||
"""Quick validation for BGE models"""
|
||||
|
||||
def __init__(self):
|
||||
self.device = self._get_device()
|
||||
self.test_queries = self._get_test_queries()
|
||||
|
||||
def _get_device(self):
|
||||
"""Get optimal device"""
|
||||
try:
|
||||
config = get_config()
|
||||
device_type = config.get('hardware', {}).get('device', 'cuda')
|
||||
|
||||
            if device_type == 'npu':
|
||||
try:
|
||||
import torch_npu
|
||||
return torch.device("npu:0")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
|
||||
        except Exception:
|
||||
pass
|
||||
|
||||
return torch.device("cpu")
|
||||
|
||||
def _get_test_queries(self):
|
||||
"""Get standard test queries for validation"""
|
||||
return [
|
||||
{
|
||||
"query": "什么是深度学习?",
|
||||
"relevant": [
|
||||
"深度学习是机器学习的一个子领域,使用多层神经网络来学习数据表示",
|
||||
"深度学习通过神经网络的多个层次来学习数据的抽象特征"
|
||||
],
|
||||
"irrelevant": [
|
||||
"今天天气很好,阳光明媚",
|
||||
"我喜欢吃苹果和香蕉"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "如何制作红烧肉?",
|
||||
"relevant": [
|
||||
"红烧肉的制作需要五花肉、生抽、老抽、冰糖等材料,先炒糖色再炖煮",
|
||||
"制作红烧肉的关键是掌握火候和糖色的调制"
|
||||
],
|
||||
"irrelevant": [
|
||||
"Python是一种编程语言",
|
||||
"机器学习需要大量的数据"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Python编程入门",
|
||||
"relevant": [
|
||||
"Python是一种易学易用的编程语言,适合初学者入门",
|
||||
"学习Python需要掌握基本语法、数据类型和控制结构"
|
||||
],
|
||||
"irrelevant": [
|
||||
"红烧肉是一道经典的中式菜肴",
|
||||
"深度学习使用神经网络进行训练"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
def validate_retriever(self, model_path, baseline_path="BAAI/bge-m3"):
|
||||
"""Quick validation of retriever model"""
|
||||
print(f"\n🔍 Validating Retriever Model")
|
||||
print("=" * 50)
|
||||
|
||||
results = {
|
||||
'model_path': model_path,
|
||||
'baseline_path': baseline_path,
|
||||
'test_results': {},
|
||||
'summary': {}
|
||||
}
|
||||
|
||||
try:
|
||||
# Test fine-tuned model
|
||||
print("Testing fine-tuned model...")
|
||||
ft_scores = self._test_retriever_model(model_path, "Fine-tuned")
|
||||
|
||||
# Test baseline model
|
||||
print("Testing baseline model...")
|
||||
baseline_scores = self._test_retriever_model(baseline_path, "Baseline")
|
||||
|
||||
# Compare results
|
||||
comparison = self._compare_retriever_results(baseline_scores, ft_scores)
|
||||
|
||||
results['test_results'] = {
|
||||
'finetuned': ft_scores,
|
||||
'baseline': baseline_scores,
|
||||
'comparison': comparison
|
||||
}
|
||||
|
||||
# Print summary
|
||||
print(f"\n📊 Retriever Validation Summary:")
|
||||
print(f" Baseline Avg Relevance Score: {baseline_scores['avg_relevant_score']:.4f}")
|
||||
print(f" Fine-tuned Avg Relevance Score: {ft_scores['avg_relevant_score']:.4f}")
|
||||
print(f" Improvement: {comparison['relevance_improvement']:.2f}%")
|
||||
|
||||
if comparison['relevance_improvement'] > 0:
|
||||
print(" ✅ Fine-tuned model shows improvement!")
|
||||
else:
|
||||
print(" ❌ Fine-tuned model shows degradation!")
|
||||
|
||||
results['summary'] = {
|
||||
'status': 'improved' if comparison['relevance_improvement'] > 0 else 'degraded',
|
||||
'improvement_pct': comparison['relevance_improvement']
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Retriever validation failed: {e}")
|
||||
return None
|
||||
|
||||
def validate_reranker(self, model_path, baseline_path="BAAI/bge-reranker-base"):
|
||||
"""Quick validation of reranker model"""
|
||||
print(f"\n🏆 Validating Reranker Model")
|
||||
print("=" * 50)
|
||||
|
||||
results = {
|
||||
'model_path': model_path,
|
||||
'baseline_path': baseline_path,
|
||||
'test_results': {},
|
||||
'summary': {}
|
||||
}
|
||||
|
||||
try:
|
||||
# Test fine-tuned model
|
||||
print("Testing fine-tuned reranker...")
|
||||
ft_scores = self._test_reranker_model(model_path, "Fine-tuned")
|
||||
|
||||
# Test baseline model
|
||||
print("Testing baseline reranker...")
|
||||
baseline_scores = self._test_reranker_model(baseline_path, "Baseline")
|
||||
|
||||
# Compare results
|
||||
comparison = self._compare_reranker_results(baseline_scores, ft_scores)
|
||||
|
||||
results['test_results'] = {
|
||||
'finetuned': ft_scores,
|
||||
'baseline': baseline_scores,
|
||||
'comparison': comparison
|
||||
}
|
||||
|
||||
# Print summary
|
||||
print(f"\n📊 Reranker Validation Summary:")
|
||||
print(f" Baseline Accuracy: {baseline_scores['accuracy']:.4f}")
|
||||
print(f" Fine-tuned Accuracy: {ft_scores['accuracy']:.4f}")
|
||||
print(f" Improvement: {comparison['accuracy_improvement']:.2f}%")
|
||||
|
||||
if comparison['accuracy_improvement'] > 0:
|
||||
print(" ✅ Fine-tuned reranker shows improvement!")
|
||||
else:
|
||||
print(" ❌ Fine-tuned reranker shows degradation!")
|
||||
|
||||
results['summary'] = {
|
||||
'status': 'improved' if comparison['accuracy_improvement'] > 0 else 'degraded',
|
||||
'improvement_pct': comparison['accuracy_improvement']
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Reranker validation failed: {e}")
|
||||
return None
|
||||
|
||||
def _test_retriever_model(self, model_path, model_type):
|
||||
"""Test retriever model with predefined queries"""
|
||||
try:
|
||||
# Load model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGEM3Model.from_pretrained(model_path)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
print(f" ✅ {model_type} model loaded successfully")
|
||||
|
||||
relevant_scores = []
|
||||
irrelevant_scores = []
|
||||
ranking_accuracy = []
|
||||
|
||||
with torch.no_grad():
|
||||
for test_case in self.test_queries:
|
||||
query = test_case["query"]
|
||||
relevant_docs = test_case["relevant"]
|
||||
irrelevant_docs = test_case["irrelevant"]
|
||||
|
||||
# Encode query
|
||||
query_inputs = tokenizer(
|
||||
query,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
).to(self.device)
|
||||
|
||||
query_outputs = model(**query_inputs, return_dense=True)
|
||||
query_emb = query_outputs['dense']
|
||||
|
||||
# Test relevant documents
|
||||
for doc in relevant_docs:
|
||||
doc_inputs = tokenizer(
|
||||
doc,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
).to(self.device)
|
||||
|
||||
doc_outputs = model(**doc_inputs, return_dense=True)
|
||||
doc_emb = doc_outputs['dense']
|
||||
|
||||
# Compute similarity
|
||||
similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
|
||||
relevant_scores.append(similarity)
|
||||
|
||||
# Test irrelevant documents
|
||||
for doc in irrelevant_docs:
|
||||
doc_inputs = tokenizer(
|
||||
doc,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
).to(self.device)
|
||||
|
||||
doc_outputs = model(**doc_inputs, return_dense=True)
|
||||
doc_emb = doc_outputs['dense']
|
||||
|
||||
similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
|
||||
irrelevant_scores.append(similarity)
|
||||
|
||||
                    # Check ranking accuracy (relevant should score higher than irrelevant),
                    # reusing the similarities already computed above instead of re-encoding every document
                    query_relevant_avg = np.mean(relevant_scores[-len(relevant_docs):])
                    query_irrelevant_avg = np.mean(irrelevant_scores[-len(irrelevant_docs):])
                    ranking_accuracy.append(1.0 if query_relevant_avg > query_irrelevant_avg else 0.0)
|
||||
|
||||
return {
|
||||
'avg_relevant_score': np.mean(relevant_scores),
|
||||
'avg_irrelevant_score': np.mean(irrelevant_scores),
|
||||
'ranking_accuracy': np.mean(ranking_accuracy),
|
||||
'score_separation': np.mean(relevant_scores) - np.mean(irrelevant_scores)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing {model_type} retriever: {e}")
|
||||
return None
|
||||
|
||||
def _test_reranker_model(self, model_path, model_type):
|
||||
"""Test reranker model with predefined query-document pairs"""
|
||||
try:
|
||||
# Load model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGERerankerModel.from_pretrained(model_path)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
print(f" ✅ {model_type} reranker loaded successfully")
|
||||
|
||||
correct_rankings = 0
|
||||
total_pairs = 0
|
||||
all_scores = []
|
||||
all_labels = []
|
||||
|
||||
with torch.no_grad():
|
||||
for test_case in self.test_queries:
|
||||
query = test_case["query"]
|
||||
relevant_docs = test_case["relevant"]
|
||||
irrelevant_docs = test_case["irrelevant"]
|
||||
|
||||
# Test relevant documents (positive examples)
|
||||
for doc in relevant_docs:
|
||||
pair_text = f"{query} [SEP] {doc}"
|
||||
inputs = tokenizer(
|
||||
pair_text,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
).to(self.device)
|
||||
|
||||
outputs = model(**inputs)
|
||||
score = outputs.logits.item()
|
||||
all_scores.append(score)
|
||||
all_labels.append(1) # Relevant
|
||||
|
||||
# Check if score indicates relevance (positive score)
|
||||
if score > 0:
|
||||
correct_rankings += 1
|
||||
total_pairs += 1
|
||||
|
||||
# Test irrelevant documents (negative examples)
|
||||
for doc in irrelevant_docs:
|
||||
pair_text = f"{query} [SEP] {doc}"
|
||||
inputs = tokenizer(
|
||||
pair_text,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
).to(self.device)
|
||||
|
||||
outputs = model(**inputs)
|
||||
score = outputs.logits.item()
|
||||
all_scores.append(score)
|
||||
all_labels.append(0) # Irrelevant
|
||||
|
||||
# Check if score indicates irrelevance (negative score)
|
||||
if score <= 0:
|
||||
correct_rankings += 1
|
||||
total_pairs += 1
|
||||
|
||||
return {
|
||||
'accuracy': correct_rankings / total_pairs if total_pairs > 0 else 0,
|
||||
'avg_positive_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 1]),
|
||||
'avg_negative_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 0]),
|
||||
'total_pairs': total_pairs
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing {model_type} reranker: {e}")
|
||||
return None
|
||||
|
||||
def _compare_retriever_results(self, baseline, finetuned):
|
||||
"""Compare retriever results"""
|
||||
if not baseline or not finetuned:
|
||||
return {}
|
||||
|
||||
relevance_improvement = ((finetuned['avg_relevant_score'] - baseline['avg_relevant_score']) /
|
||||
baseline['avg_relevant_score'] * 100) if baseline['avg_relevant_score'] != 0 else 0
|
||||
|
||||
separation_improvement = ((finetuned['score_separation'] - baseline['score_separation']) /
|
||||
baseline['score_separation'] * 100) if baseline['score_separation'] != 0 else 0
|
||||
|
||||
ranking_improvement = ((finetuned['ranking_accuracy'] - baseline['ranking_accuracy']) /
|
||||
baseline['ranking_accuracy'] * 100) if baseline['ranking_accuracy'] != 0 else 0
|
||||
|
||||
return {
|
||||
'relevance_improvement': relevance_improvement,
|
||||
'separation_improvement': separation_improvement,
|
||||
'ranking_improvement': ranking_improvement
|
||||
}
|
||||
|
||||
def _compare_reranker_results(self, baseline, finetuned):
|
||||
"""Compare reranker results"""
|
||||
if not baseline or not finetuned:
|
||||
return {}
|
||||
|
||||
accuracy_improvement = ((finetuned['accuracy'] - baseline['accuracy']) /
|
||||
baseline['accuracy'] * 100) if baseline['accuracy'] != 0 else 0
|
||||
|
||||
return {
|
||||
'accuracy_improvement': accuracy_improvement
|
||||
}
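    # Example of the relative improvement above (hypothetical numbers): a baseline accuracy of 0.80 and
    # a fine-tuned accuracy of 0.88 yield (0.88 - 0.80) / 0.80 * 100 = +10.0%.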
|
||||
|
||||
def validate_pipeline(self, retriever_path, reranker_path):
|
||||
"""Quick validation of full pipeline"""
|
||||
print(f"\n🚀 Validating Full Pipeline")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
# Load models
|
||||
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
|
||||
retriever = BGEM3Model.from_pretrained(retriever_path)
|
||||
retriever.to(self.device)
|
||||
retriever.eval()
|
||||
|
||||
reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
|
||||
reranker = BGERerankerModel.from_pretrained(reranker_path)
|
||||
reranker.to(self.device)
|
||||
reranker.eval()
|
||||
|
||||
print(" ✅ Pipeline models loaded successfully")
|
||||
|
||||
total_queries = 0
|
||||
pipeline_accuracy = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for test_case in self.test_queries:
|
||||
query = test_case["query"]
|
||||
relevant_docs = test_case["relevant"]
|
||||
irrelevant_docs = test_case["irrelevant"]
|
||||
|
||||
# Create candidate pool
|
||||
all_docs = relevant_docs + irrelevant_docs
|
||||
doc_labels = [1] * len(relevant_docs) + [0] * len(irrelevant_docs)
|
||||
|
||||
# Step 1: Retrieval
|
||||
query_inputs = retriever_tokenizer(
|
||||
query, return_tensors='pt', padding=True,
|
||||
truncation=True, max_length=512
|
||||
).to(self.device)
|
||||
query_emb = retriever(**query_inputs, return_dense=True)['dense']
|
||||
|
||||
# Score all documents with retriever
|
||||
doc_scores = []
|
||||
for doc in all_docs:
|
||||
doc_inputs = retriever_tokenizer(
|
||||
doc, return_tensors='pt', padding=True,
|
||||
truncation=True, max_length=512
|
||||
).to(self.device)
|
||||
doc_emb = retriever(**doc_inputs, return_dense=True)['dense']
|
||||
score = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
|
||||
doc_scores.append(score)
|
||||
|
||||
# Get top-k for reranking (all docs in this case)
|
||||
retrieval_indices = sorted(range(len(doc_scores)),
|
||||
key=lambda i: doc_scores[i], reverse=True)
|
||||
|
||||
# Step 2: Reranking
|
||||
rerank_scores = []
|
||||
rerank_labels = []
|
||||
|
||||
for idx in retrieval_indices:
|
||||
doc = all_docs[idx]
|
||||
label = doc_labels[idx]
|
||||
|
||||
pair_text = f"{query} [SEP] {doc}"
|
||||
pair_inputs = reranker_tokenizer(
|
||||
pair_text, return_tensors='pt', padding=True,
|
||||
truncation=True, max_length=512
|
||||
).to(self.device)
|
||||
|
||||
rerank_output = reranker(**pair_inputs)
|
||||
rerank_score = rerank_output.logits.item()
|
||||
|
||||
rerank_scores.append(rerank_score)
|
||||
rerank_labels.append(label)
|
||||
|
||||
# Check if top-ranked document is relevant
|
||||
best_idx = np.argmax(rerank_scores)
|
||||
if rerank_labels[best_idx] == 1:
|
||||
pipeline_accuracy += 1
|
||||
|
||||
total_queries += 1
|
||||
|
||||
pipeline_acc = pipeline_accuracy / total_queries if total_queries > 0 else 0
|
||||
|
||||
print(f"\n📊 Pipeline Validation Summary:")
|
||||
print(f" Pipeline Accuracy: {pipeline_acc:.4f}")
|
||||
print(f" Queries Tested: {total_queries}")
|
||||
|
||||
if pipeline_acc > 0.5:
|
||||
print(" ✅ Pipeline performing well!")
|
||||
else:
|
||||
print(" ❌ Pipeline needs improvement!")
|
||||
|
||||
return {
|
||||
'accuracy': pipeline_acc,
|
||||
'total_queries': total_queries,
|
||||
'status': 'good' if pipeline_acc > 0.5 else 'needs_improvement'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline validation failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Quick validation of BGE fine-tuned models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument("--retriever_model", type=str, default=None,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
|
||||
help="Path to baseline retriever model")
|
||||
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Path to baseline reranker model")
|
||||
parser.add_argument("--test_pipeline", action="store_true",
|
||||
help="Test full retrieval + reranking pipeline")
|
||||
parser.add_argument("--output_file", type=str, default=None,
|
||||
help="Save results to JSON file")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main validation function"""
|
||||
args = parse_args()
|
||||
|
||||
if not args.retriever_model and not args.reranker_model:
|
||||
print("❌ Error: At least one model must be specified")
|
||||
print(" Use --retriever_model or --reranker_model")
|
||||
return 1
|
||||
|
||||
print("🚀 BGE Quick Validation")
|
||||
print("=" * 60)
|
||||
print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
validator = QuickValidator()
|
||||
results = {}
|
||||
|
||||
# Validate retriever
|
||||
if args.retriever_model:
|
||||
if os.path.exists(args.retriever_model):
|
||||
retriever_results = validator.validate_retriever(
|
||||
args.retriever_model, args.retriever_baseline
|
||||
)
|
||||
if retriever_results:
|
||||
results['retriever'] = retriever_results
|
||||
else:
|
||||
print(f"❌ Retriever model not found: {args.retriever_model}")
|
||||
|
||||
# Validate reranker
|
||||
if args.reranker_model:
|
||||
if os.path.exists(args.reranker_model):
|
||||
reranker_results = validator.validate_reranker(
|
||||
args.reranker_model, args.reranker_baseline
|
||||
)
|
||||
if reranker_results:
|
||||
results['reranker'] = reranker_results
|
||||
else:
|
||||
print(f"❌ Reranker model not found: {args.reranker_model}")
|
||||
|
||||
# Validate pipeline
|
||||
if (args.test_pipeline and args.retriever_model and args.reranker_model and
|
||||
os.path.exists(args.retriever_model) and os.path.exists(args.reranker_model)):
|
||||
pipeline_results = validator.validate_pipeline(
|
||||
args.retriever_model, args.reranker_model
|
||||
)
|
||||
if pipeline_results:
|
||||
results['pipeline'] = pipeline_results
|
||||
|
||||
# Print final summary
|
||||
print(f"\n{'='*60}")
|
||||
print("🎯 FINAL VALIDATION SUMMARY")
|
||||
print(f"{'='*60}")
|
||||
|
||||
overall_status = "✅ SUCCESS"
|
||||
issues = []
|
||||
|
||||
for model_type, model_results in results.items():
|
||||
if 'summary' in model_results:
|
||||
status = model_results['summary'].get('status', 'unknown')
|
||||
improvement = model_results['summary'].get('improvement_pct', 0)
|
||||
|
||||
if status == 'improved':
|
||||
print(f"✅ {model_type.title()}: Improved by {improvement:.2f}%")
|
||||
elif status == 'degraded':
|
||||
print(f"❌ {model_type.title()}: Degraded by {abs(improvement):.2f}%")
|
||||
issues.append(model_type)
|
||||
overall_status = "⚠️ ISSUES DETECTED"
|
||||
elif model_type == 'pipeline' and 'status' in model_results:
|
||||
if model_results['status'] == 'good':
|
||||
print(f"✅ Pipeline: Working well (accuracy: {model_results['accuracy']:.3f})")
|
||||
else:
|
||||
print(f"❌ Pipeline: Needs improvement (accuracy: {model_results['accuracy']:.3f})")
|
||||
issues.append('pipeline')
|
||||
overall_status = "⚠️ ISSUES DETECTED"
|
||||
|
||||
print(f"\n{overall_status}")
|
||||
if issues:
|
||||
print(f"⚠️ Issues detected in: {', '.join(issues)}")
|
||||
print("💡 Consider running comprehensive validation for detailed analysis")
|
||||
|
||||
# Save results if requested
|
||||
if args.output_file:
|
||||
results['timestamp'] = datetime.now().isoformat()
|
||||
results['overall_status'] = overall_status
|
||||
results['issues'] = issues
|
||||
|
||||
with open(args.output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
print(f"📊 Results saved to: {args.output_file}")
|
||||
|
||||
return 0 if overall_status == "✅ SUCCESS" else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
696
scripts/run_validation_suite.py
Normal file
@@ -0,0 +1,696 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
BGE Validation Suite Runner
|
||||
|
||||
This script runs a complete validation suite for BGE fine-tuned models,
|
||||
including multiple validation approaches and generating comprehensive reports.
|
||||
|
||||
Usage:
|
||||
# Auto-discover models and run full suite
|
||||
python scripts/run_validation_suite.py --auto-discover
|
||||
|
||||
# Run with specific models
|
||||
python scripts/run_validation_suite.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \\
|
||||
--reranker_model ./output/bge-reranker/final_model
|
||||
|
||||
# Run only specific validation types
|
||||
python scripts/run_validation_suite.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \\
|
||||
--validation_types quick comprehensive comparison
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import subprocess
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from scripts.validation_utils import ModelDiscovery, TestDataManager, ValidationReportAnalyzer
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationSuiteRunner:
|
||||
"""Orchestrates comprehensive validation of BGE models"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.results_dir = Path(config.get('output_dir', './validation_suite_results'))
|
||||
self.results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize components
|
||||
self.model_discovery = ModelDiscovery(config.get('workspace_root', '.'))
|
||||
self.data_manager = TestDataManager(config.get('data_dir', 'data/datasets'))
|
||||
|
||||
# Suite results
|
||||
self.suite_results = {
|
||||
'start_time': datetime.now().isoformat(),
|
||||
'config': config,
|
||||
'validation_results': {},
|
||||
'summary': {},
|
||||
'issues': [],
|
||||
'recommendations': []
|
||||
}
|
||||
|
||||
def run_suite(self) -> Dict[str, Any]:
|
||||
"""Run the complete validation suite"""
|
||||
logger.info("🚀 Starting BGE Validation Suite")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
# Phase 1: Discovery and Setup
|
||||
self._phase_discovery()
|
||||
|
||||
# Phase 2: Quick Validation
|
||||
if 'quick' in self.config.get('validation_types', ['quick']):
|
||||
self._phase_quick_validation()
|
||||
|
||||
# Phase 3: Comprehensive Validation
|
||||
if 'comprehensive' in self.config.get('validation_types', ['comprehensive']):
|
||||
self._phase_comprehensive_validation()
|
||||
|
||||
# Phase 4: Comparison Benchmarks
|
||||
if 'comparison' in self.config.get('validation_types', ['comparison']):
|
||||
self._phase_comparison_benchmarks()
|
||||
|
||||
# Phase 5: Analysis and Reporting
|
||||
self._phase_analysis()
|
||||
|
||||
# Phase 6: Generate Final Report
|
||||
self._generate_final_report()
|
||||
|
||||
self.suite_results['end_time'] = datetime.now().isoformat()
|
||||
self.suite_results['status'] = 'completed'
|
||||
|
||||
logger.info("✅ Validation Suite Completed Successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Validation Suite Failed: {e}")
|
||||
self.suite_results['status'] = 'failed'
|
||||
self.suite_results['error'] = str(e)
|
||||
raise
|
||||
|
||||
return self.suite_results
|
||||
|
||||
def _phase_discovery(self):
|
||||
"""Phase 1: Discover models and datasets"""
|
||||
logger.info("📡 Phase 1: Discovery and Setup")
|
||||
|
||||
# Auto-discover models if requested
|
||||
if self.config.get('auto_discover', False):
|
||||
discovered_models = self.model_discovery.find_trained_models()
|
||||
|
||||
# Set models from discovery if not provided
|
||||
if not self.config.get('retriever_model'):
|
||||
if discovered_models['retrievers']:
|
||||
model_name = list(discovered_models['retrievers'].keys())[0]
|
||||
self.config['retriever_model'] = discovered_models['retrievers'][model_name]['path']
|
||||
logger.info(f" 📊 Auto-discovered retriever: {model_name}")
|
||||
|
||||
if not self.config.get('reranker_model'):
|
||||
if discovered_models['rerankers']:
|
||||
model_name = list(discovered_models['rerankers'].keys())[0]
|
||||
self.config['reranker_model'] = discovered_models['rerankers'][model_name]['path']
|
||||
logger.info(f" 🏆 Auto-discovered reranker: {model_name}")
|
||||
|
||||
self.suite_results['discovered_models'] = discovered_models
|
||||
|
||||
# Discover test datasets
|
||||
discovered_datasets = self.data_manager.find_test_datasets()
|
||||
self.suite_results['available_datasets'] = discovered_datasets
|
||||
|
||||
# Validate models exist
|
||||
issues = []
|
||||
if self.config.get('retriever_model'):
|
||||
if not os.path.exists(self.config['retriever_model']):
|
||||
issues.append(f"Retriever model not found: {self.config['retriever_model']}")
|
||||
|
||||
if self.config.get('reranker_model'):
|
||||
if not os.path.exists(self.config['reranker_model']):
|
||||
issues.append(f"Reranker model not found: {self.config['reranker_model']}")
|
||||
|
||||
if issues:
|
||||
for issue in issues:
|
||||
logger.error(f" ❌ {issue}")
|
||||
self.suite_results['issues'].extend(issues)
|
||||
return
|
||||
|
||||
logger.info(" ✅ Discovery phase completed")
|
||||
|
||||
def _phase_quick_validation(self):
|
||||
"""Phase 2: Quick validation tests"""
|
||||
logger.info("⚡ Phase 2: Quick Validation")
|
||||
|
||||
try:
|
||||
# Run quick validation script
|
||||
cmd = ['python', 'scripts/quick_validation.py']
|
||||
|
||||
if self.config.get('retriever_model'):
|
||||
cmd.extend(['--retriever_model', self.config['retriever_model']])
|
||||
|
||||
if self.config.get('reranker_model'):
|
||||
cmd.extend(['--reranker_model', self.config['reranker_model']])
|
||||
|
||||
if self.config.get('retriever_model') and self.config.get('reranker_model'):
|
||||
cmd.append('--test_pipeline')
|
||||
|
||||
# Save results to file
|
||||
quick_results_file = self.results_dir / 'quick_validation_results.json'
|
||||
cmd.extend(['--output_file', str(quick_results_file)])
|
||||
|
||||
logger.info(f" Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(" ✅ Quick validation completed")
|
||||
|
||||
# Load results
|
||||
if quick_results_file.exists():
|
||||
with open(quick_results_file, 'r') as f:
|
||||
quick_results = json.load(f)
|
||||
self.suite_results['validation_results']['quick'] = quick_results
|
||||
else:
|
||||
logger.error(f" ❌ Quick validation failed: {result.stderr}")
|
||||
self.suite_results['issues'].append(f"Quick validation failed: {result.stderr}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ Error in quick validation: {e}")
|
||||
self.suite_results['issues'].append(f"Quick validation error: {e}")
|
||||
|
||||
def _phase_comprehensive_validation(self):
|
||||
"""Phase 3: Comprehensive validation"""
|
||||
logger.info("🔬 Phase 3: Comprehensive Validation")
|
||||
|
||||
try:
|
||||
# Run comprehensive validation script
|
||||
cmd = ['python', 'scripts/comprehensive_validation.py']
|
||||
|
||||
if self.config.get('retriever_model'):
|
||||
cmd.extend(['--retriever_finetuned', self.config['retriever_model']])
|
||||
|
||||
if self.config.get('reranker_model'):
|
||||
cmd.extend(['--reranker_finetuned', self.config['reranker_model']])
|
||||
|
||||
# Use specific datasets if provided
|
||||
if self.config.get('test_datasets'):
|
||||
cmd.extend(['--test_datasets'] + self.config['test_datasets'])
|
||||
|
||||
# Set output directory
|
||||
comprehensive_dir = self.results_dir / 'comprehensive'
|
||||
cmd.extend(['--output_dir', str(comprehensive_dir)])
|
||||
|
||||
logger.info(f" Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(" ✅ Comprehensive validation completed")
|
||||
|
||||
# Load results
|
||||
results_file = comprehensive_dir / 'validation_results.json'
|
||||
if results_file.exists():
|
||||
with open(results_file, 'r') as f:
|
||||
comp_results = json.load(f)
|
||||
self.suite_results['validation_results']['comprehensive'] = comp_results
|
||||
else:
|
||||
logger.error(f" ❌ Comprehensive validation failed: {result.stderr}")
|
||||
self.suite_results['issues'].append(f"Comprehensive validation failed: {result.stderr}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ Error in comprehensive validation: {e}")
|
||||
self.suite_results['issues'].append(f"Comprehensive validation error: {e}")
|
||||
|
||||
def _phase_comparison_benchmarks(self):
|
||||
"""Phase 4: Run comparison benchmarks"""
|
||||
logger.info("📊 Phase 4: Comparison Benchmarks")
|
||||
|
||||
# Find suitable datasets for comparison
|
||||
datasets = self.data_manager.find_test_datasets()
|
||||
|
||||
try:
|
||||
# Run retriever comparison if model exists
|
||||
if self.config.get('retriever_model'):
|
||||
self._run_retriever_comparison(datasets)
|
||||
|
||||
# Run reranker comparison if model exists
|
||||
if self.config.get('reranker_model'):
|
||||
self._run_reranker_comparison(datasets)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ Error in comparison benchmarks: {e}")
|
||||
self.suite_results['issues'].append(f"Comparison benchmarks error: {e}")
|
||||
|
||||
def _run_retriever_comparison(self, datasets: Dict[str, List[str]]):
|
||||
"""Run retriever comparison benchmarks"""
|
||||
logger.info(" 🔍 Running retriever comparison...")
|
||||
|
||||
# Find suitable datasets for retriever testing
|
||||
test_datasets = datasets.get('retriever_datasets', []) + datasets.get('general_datasets', [])
|
||||
|
||||
if not test_datasets:
|
||||
logger.warning(" ⚠️ No suitable datasets found for retriever comparison")
|
||||
return
|
||||
|
||||
for dataset_path in test_datasets[:2]: # Limit to 2 datasets for speed
|
||||
if not os.path.exists(dataset_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
'python', 'scripts/compare_retriever.py',
|
||||
'--finetuned_model_path', self.config['retriever_model'],
|
||||
'--baseline_model_path', self.config.get('retriever_baseline', 'BAAI/bge-m3'),
|
||||
'--data_path', dataset_path,
|
||||
'--batch_size', str(self.config.get('batch_size', 16)),
|
||||
'--max_samples', str(self.config.get('max_samples', 500)),
|
||||
'--output', str(self.results_dir / f'retriever_comparison_{Path(dataset_path).stem}.txt')
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f" ✅ Retriever comparison on {Path(dataset_path).name}")
|
||||
else:
|
||||
logger.warning(f" ⚠️ Retriever comparison failed on {Path(dataset_path).name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f" ⚠️ Error comparing retriever on {dataset_path}: {e}")
|
||||
|
||||
def _run_reranker_comparison(self, datasets: Dict[str, List[str]]):
|
||||
"""Run reranker comparison benchmarks"""
|
||||
logger.info(" 🏆 Running reranker comparison...")
|
||||
|
||||
# Find suitable datasets for reranker testing
|
||||
test_datasets = datasets.get('reranker_datasets', []) + datasets.get('general_datasets', [])
|
||||
|
||||
if not test_datasets:
|
||||
logger.warning(" ⚠️ No suitable datasets found for reranker comparison")
|
||||
return
|
||||
|
||||
for dataset_path in test_datasets[:2]: # Limit to 2 datasets for speed
|
||||
if not os.path.exists(dataset_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
'python', 'scripts/compare_reranker.py',
|
||||
'--finetuned_model_path', self.config['reranker_model'],
|
||||
'--baseline_model_path', self.config.get('reranker_baseline', 'BAAI/bge-reranker-base'),
|
||||
'--data_path', dataset_path,
|
||||
'--batch_size', str(self.config.get('batch_size', 16)),
|
||||
'--max_samples', str(self.config.get('max_samples', 500)),
|
||||
'--output', str(self.results_dir / f'reranker_comparison_{Path(dataset_path).stem}.txt')
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f" ✅ Reranker comparison on {Path(dataset_path).name}")
|
||||
else:
|
||||
logger.warning(f" ⚠️ Reranker comparison failed on {Path(dataset_path).name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f" ⚠️ Error comparing reranker on {dataset_path}: {e}")
|
||||
|
||||
def _phase_analysis(self):
|
||||
"""Phase 5: Analyze all results"""
|
||||
logger.info("📈 Phase 5: Results Analysis")
|
||||
|
||||
try:
|
||||
# Analyze comprehensive results if available
|
||||
comprehensive_dir = self.results_dir / 'comprehensive'
|
||||
if comprehensive_dir.exists():
|
||||
analyzer = ValidationReportAnalyzer(str(comprehensive_dir))
|
||||
analysis = analyzer.analyze_results()
|
||||
self.suite_results['analysis'] = analysis
|
||||
|
||||
logger.info(" ✅ Results analysis completed")
|
||||
else:
|
||||
logger.warning(" ⚠️ No comprehensive results found for analysis")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ Error in results analysis: {e}")
|
||||
self.suite_results['issues'].append(f"Analysis error: {e}")
|
||||
|
||||
def _generate_final_report(self):
|
||||
"""Generate final comprehensive report"""
|
||||
logger.info("📝 Phase 6: Generating Final Report")
|
||||
|
||||
try:
|
||||
# Create summary
|
||||
self._create_suite_summary()
|
||||
|
||||
# Save suite results
|
||||
suite_results_file = self.results_dir / 'validation_suite_results.json'
|
||||
with open(suite_results_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.suite_results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Generate HTML report
|
||||
self._generate_html_report()
|
||||
|
||||
logger.info(f" 📊 Reports saved to: {self.results_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ Error generating final report: {e}")
|
||||
self.suite_results['issues'].append(f"Report generation error: {e}")
|
||||
|
||||
def _create_suite_summary(self):
|
||||
"""Create suite summary"""
|
||||
summary = {
|
||||
'total_validations': 0,
|
||||
'successful_validations': 0,
|
||||
'failed_validations': 0,
|
||||
'overall_verdict': 'unknown',
|
||||
'key_findings': [],
|
||||
'critical_issues': []
|
||||
}
|
||||
|
||||
# Count validations
|
||||
for validation_type, results in self.suite_results.get('validation_results', {}).items():
|
||||
summary['total_validations'] += 1
|
||||
if results:
|
||||
summary['successful_validations'] += 1
|
||||
else:
|
||||
summary['failed_validations'] += 1
|
||||
|
||||
# Determine overall verdict
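        # Mapping: every validation succeeded -> excellent/good/fair depending on the
        # analysis 'overall_status'; more successes than failures -> partial; otherwise poor.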
|
||||
if summary['successful_validations'] > 0 and summary['failed_validations'] == 0:
|
||||
if self.suite_results.get('analysis', {}).get('overall_status') == 'excellent':
|
||||
summary['overall_verdict'] = 'excellent'
|
||||
elif self.suite_results.get('analysis', {}).get('overall_status') in ['good', 'mixed']:
|
||||
summary['overall_verdict'] = 'good'
|
||||
else:
|
||||
summary['overall_verdict'] = 'fair'
|
||||
elif summary['successful_validations'] > summary['failed_validations']:
|
||||
summary['overall_verdict'] = 'partial'
|
||||
else:
|
||||
summary['overall_verdict'] = 'poor'
|
||||
|
||||
# Extract key findings from analysis
|
||||
if 'analysis' in self.suite_results:
|
||||
analysis = self.suite_results['analysis']
|
||||
if 'recommendations' in analysis:
|
||||
summary['key_findings'] = analysis['recommendations'][:5] # Top 5 recommendations
|
||||
|
||||
# Extract critical issues
|
||||
summary['critical_issues'] = self.suite_results.get('issues', [])
|
||||
|
||||
self.suite_results['summary'] = summary
|
||||
|
||||
def _generate_html_report(self):
|
||||
"""Generate HTML summary report"""
|
||||
html_file = self.results_dir / 'validation_suite_report.html'
|
||||
|
||||
summary = self.suite_results.get('summary', {})
|
||||
|
||||
verdict_colors = {
|
||||
'excellent': '#28a745',
|
||||
'good': '#17a2b8',
|
||||
'fair': '#ffc107',
|
||||
'partial': '#fd7e14',
|
||||
'poor': '#dc3545',
|
||||
'unknown': '#6c757d'
|
||||
}
|
||||
|
||||
verdict_color = verdict_colors.get(summary.get('overall_verdict', 'unknown'), '#6c757d')
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>BGE Validation Suite Report</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 20px; line-height: 1.6; }}
|
||||
.header {{ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white; padding: 30px; border-radius: 10px; text-align: center; }}
|
||||
.verdict {{ font-size: 24px; font-weight: bold; color: {verdict_color};
|
||||
margin: 20px 0; text-align: center; padding: 15px;
|
||||
border: 3px solid {verdict_color}; border-radius: 10px; }}
|
||||
.section {{ margin: 30px 0; }}
|
||||
.metrics {{ display: flex; justify-content: space-around; margin: 20px 0; }}
|
||||
.metric {{ text-align: center; padding: 15px; background: #f8f9fa;
|
||||
border-radius: 8px; min-width: 120px; }}
|
||||
.metric-value {{ font-size: 2em; font-weight: bold; color: #495057; }}
|
||||
.metric-label {{ color: #6c757d; }}
|
||||
.findings {{ background: #e3f2fd; padding: 20px; border-radius: 8px;
|
||||
border-left: 4px solid #2196f3; }}
|
||||
.issues {{ background: #ffebee; padding: 20px; border-radius: 8px;
|
||||
border-left: 4px solid #f44336; }}
|
||||
.timeline {{ margin: 20px 0; }}
|
||||
.timeline-item {{ margin: 10px 0; padding: 10px; background: #f8f9fa;
|
||||
border-left: 3px solid #007bff; border-radius: 0 5px 5px 0; }}
|
||||
ul {{ padding-left: 20px; }}
|
||||
li {{ margin: 8px 0; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>🚀 BGE Validation Suite Report</h1>
|
||||
<p>Comprehensive Model Performance Analysis</p>
|
||||
<p>Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||
</div>
|
||||
|
||||
<div class="verdict">
|
||||
Overall Verdict: {summary.get('overall_verdict', 'UNKNOWN').upper()}
|
||||
</div>
|
||||
|
||||
<div class="metrics">
|
||||
<div class="metric">
|
||||
<div class="metric-value">{summary.get('total_validations', 0)}</div>
|
||||
<div class="metric-label">Total Validations</div>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<div class="metric-value">{summary.get('successful_validations', 0)}</div>
|
||||
<div class="metric-label">Successful</div>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<div class="metric-value">{summary.get('failed_validations', 0)}</div>
|
||||
<div class="metric-label">Failed</div>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Add key findings
|
||||
if summary.get('key_findings'):
|
||||
html_content += """
|
||||
<div class="section">
|
||||
<h2>🔍 Key Findings</h2>
|
||||
<div class="findings">
|
||||
<ul>
|
||||
"""
|
||||
for finding in summary['key_findings']:
|
||||
html_content += f" <li>{finding}</li>\n"
|
||||
|
||||
html_content += """
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Add critical issues if any
|
||||
if summary.get('critical_issues'):
|
||||
html_content += """
|
||||
<div class="section">
|
||||
<h2>⚠️ Critical Issues</h2>
|
||||
<div class="issues">
|
||||
<ul>
|
||||
"""
|
||||
for issue in summary['critical_issues']:
|
||||
html_content += f" <li>{issue}</li>\n"
|
||||
|
||||
html_content += """
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Add validation timeline
|
||||
html_content += f"""
|
||||
<div class="section">
|
||||
<h2>⏱️ Validation Timeline</h2>
|
||||
<div class="timeline">
|
||||
<div class="timeline-item">
|
||||
<strong>Suite Started:</strong> {self.suite_results.get('start_time', 'Unknown')}
|
||||
</div>
|
||||
"""
|
||||
|
||||
for validation_type in self.suite_results.get('validation_results', {}).keys():
|
||||
html_content += f"""
|
||||
<div class="timeline-item">
|
||||
<strong>{validation_type.title()} Validation:</strong> Completed
|
||||
</div>
|
||||
"""
|
||||
|
||||
html_content += f"""
|
||||
<div class="timeline-item">
|
||||
<strong>Suite Completed:</strong> {self.suite_results.get('end_time', 'In Progress')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>📁 Detailed Results</h2>
|
||||
<p>For detailed validation results, check the following files in <code>{self.results_dir}</code>:</p>
|
||||
<ul>
|
||||
<li><code>validation_suite_results.json</code> - Complete suite results</li>
|
||||
<li><code>quick_validation_results.json</code> - Quick validation results</li>
|
||||
<li><code>comprehensive/</code> - Comprehensive validation reports</li>
|
||||
<li><code>*_comparison_*.txt</code> - Model comparison results</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
with open(html_file, 'w', encoding='utf-8') as f:
|
||||
f.write(html_content)
|
||||
|
||||
logger.info(f" 📊 HTML report: {html_file}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run comprehensive BGE validation suite",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Auto-discover models and run full suite
|
||||
  python scripts/run_validation_suite.py --auto_discover
|
||||
|
||||
# Run with specific models
|
||||
python scripts/run_validation_suite.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \\
|
||||
--reranker_model ./output/bge-reranker/final_model
|
||||
|
||||
# Run only specific validation types
|
||||
python scripts/run_validation_suite.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \\
|
||||
--validation_types quick comprehensive
|
||||
"""
|
||||
)
|
||||
|
||||
# Model paths
|
||||
parser.add_argument("--retriever_model", type=str, default=None,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
parser.add_argument("--auto_discover", action="store_true",
|
||||
help="Auto-discover trained models in workspace")
|
||||
|
||||
# Baseline models
|
||||
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
|
||||
help="Baseline retriever model")
|
||||
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Baseline reranker model")
|
||||
|
||||
# Validation configuration
|
||||
parser.add_argument("--validation_types", type=str, nargs="+",
|
||||
default=["quick", "comprehensive", "comparison"],
|
||||
choices=["quick", "comprehensive", "comparison"],
|
||||
help="Types of validation to run")
|
||||
parser.add_argument("--test_datasets", type=str, nargs="+", default=None,
|
||||
help="Specific test datasets to use")
|
||||
|
||||
# Performance settings
|
||||
parser.add_argument("--batch_size", type=int, default=16,
|
||||
help="Batch size for evaluation")
|
||||
parser.add_argument("--max_samples", type=int, default=1000,
|
||||
help="Maximum samples per dataset for speed")
|
||||
|
||||
# Directories
|
||||
parser.add_argument("--workspace_root", type=str, default=".",
|
||||
help="Workspace root directory")
|
||||
parser.add_argument("--data_dir", type=str, default="data/datasets",
|
||||
help="Test datasets directory")
|
||||
parser.add_argument("--output_dir", type=str, default="./validation_suite_results",
|
||||
help="Output directory for all results")
|
||||
|
||||
return parser.parse_args()
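

# Example (sketch, not exercised by this script): ValidationSuiteRunner can also be
# driven programmatically; the model path below is the same example path used in the
# usage docstring and must exist in your workspace.
#
#   runner = ValidationSuiteRunner({
#       'retriever_model': './output/bge-m3-enhanced/final_model',
#       'validation_types': ['quick'],
#       'output_dir': './validation_suite_results',
#   })
#   results = runner.run_suite()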
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
args = parse_args()
|
||||
|
||||
# Validate input
|
||||
if not args.auto_discover and not args.retriever_model and not args.reranker_model:
|
||||
print("❌ Error: Must specify models or use --auto_discover")
|
||||
print(" Use --retriever_model, --reranker_model, or --auto_discover")
|
||||
return 1
|
||||
|
||||
# Create configuration
|
||||
config = {
|
||||
'retriever_model': args.retriever_model,
|
||||
'reranker_model': args.reranker_model,
|
||||
'auto_discover': args.auto_discover,
|
||||
'retriever_baseline': args.retriever_baseline,
|
||||
'reranker_baseline': args.reranker_baseline,
|
||||
'validation_types': args.validation_types,
|
||||
'test_datasets': args.test_datasets,
|
||||
'batch_size': args.batch_size,
|
||||
'max_samples': args.max_samples,
|
||||
'workspace_root': args.workspace_root,
|
||||
'data_dir': args.data_dir,
|
||||
'output_dir': args.output_dir
|
||||
}
|
||||
|
||||
# Run validation suite
|
||||
try:
|
||||
runner = ValidationSuiteRunner(config)
|
||||
results = runner.run_suite()
|
||||
|
||||
# Print final summary
|
||||
summary = results.get('summary', {})
|
||||
verdict = summary.get('overall_verdict', 'unknown')
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("🎯 VALIDATION SUITE SUMMARY")
|
||||
print("=" * 80)
|
||||
|
||||
verdict_emojis = {
|
||||
'excellent': '🌟',
|
||||
'good': '✅',
|
||||
'fair': '👌',
|
||||
'partial': '⚠️',
|
||||
'poor': '❌',
|
||||
'unknown': '❓'
|
||||
}
|
||||
|
||||
print(f"{verdict_emojis.get(verdict, '❓')} Overall Verdict: {verdict.upper()}")
|
||||
print(f"📊 Validations: {summary.get('successful_validations', 0)}/{summary.get('total_validations', 0)} successful")
|
||||
|
||||
if summary.get('key_findings'):
|
||||
print(f"\n🔍 Key Findings:")
|
||||
for i, finding in enumerate(summary['key_findings'][:3], 1):
|
||||
print(f" {i}. {finding}")
|
||||
|
||||
if summary.get('critical_issues'):
|
||||
print(f"\n⚠️ Critical Issues:")
|
||||
for issue in summary['critical_issues'][:3]:
|
||||
print(f" • {issue}")
|
||||
|
||||
print(f"\n📁 Detailed results: {config['output_dir']}")
|
||||
print("🌐 Open validation_suite_report.html in your browser for full report")
|
||||
|
||||
return 0 if verdict in ['excellent', 'good'] else 1
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n❌ Validation suite interrupted by user")
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"\n❌ Validation suite failed: {e}")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
401
scripts/setup.py
Normal file
401
scripts/setup.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
Setup script for BGE fine-tuning environment
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import argparse
|
||||
import shutil
|
||||
import urllib.request
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging

# Make the project root importable when this script is run directly
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils.config_loader import get_config
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_command(command, check=True):
|
||||
"""Run a shell command"""
|
||||
logger.info(f"Running: {command}")
|
||||
result = subprocess.run(command, shell=True, capture_output=True, text=True)
|
||||
|
||||
if check and result.returncode != 0:
|
||||
logger.error(f"Command failed: {command}")
|
||||
logger.error(f"Error: {result.stderr}")
|
||||
sys.exit(1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def check_python_version():
|
||||
"""Check Python version"""
|
||||
if sys.version_info < (3, 8):
|
||||
logger.error("Python 3.8 or higher is required")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info(f"Python version: {sys.version}")
|
||||
|
||||
|
||||
def check_gpu():
|
||||
"""Check GPU availability"""
|
||||
try:
|
||||
result = run_command("nvidia-smi", check=False)
|
||||
if result.returncode == 0:
|
||||
logger.info("NVIDIA GPU detected")
|
||||
return "cuda"
|
||||
else:
|
||||
logger.info("No NVIDIA GPU detected")
|
||||
return "cpu"
|
||||
    except Exception:
|
||||
logger.info("nvidia-smi not available")
|
||||
return "cpu"
|
||||
|
||||
|
||||
def install_pytorch(device_type="cuda"):
|
||||
"""Install PyTorch"""
|
||||
logger.info("Installing PyTorch...")
|
||||
|
||||
if device_type == "cuda":
|
||||
# Install CUDA version for RTX 5090
|
||||
pip_command = "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128"
|
||||
else:
|
||||
# Install CPU version
|
||||
pip_command = "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
||||
run_command(pip_command)
|
||||
|
||||
# Verify installation
|
||||
try:
|
||||
import torch
|
||||
logger.info(f"PyTorch version: {torch.__version__}")
|
||||
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
logger.info(f"CUDA version: {torch.version.cuda}")
|
||||
logger.info(f"GPU count: {torch.cuda.device_count()}")
|
||||
for i in range(torch.cuda.device_count()):
|
||||
logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
|
||||
except ImportError:
|
||||
logger.error("Failed to import PyTorch")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def install_requirements():
|
||||
"""Install requirements"""
|
||||
logger.info("Installing requirements...")
|
||||
|
||||
requirements_file = Path(__file__).parent.parent / "requirements.txt"
|
||||
|
||||
if requirements_file.exists():
|
||||
run_command(f"pip install -r {requirements_file}")
|
||||
else:
|
||||
logger.warning("requirements.txt not found, installing core packages...")
|
||||
|
||||
# Core packages
|
||||
packages = [
|
||||
"transformers>=4.36.0",
|
||||
"datasets>=2.14.0",
|
||||
"tokenizers>=0.15.0",
|
||||
"numpy>=1.24.0",
|
||||
"scipy>=1.10.0",
|
||||
"scikit-learn>=1.3.0",
|
||||
"pandas>=2.0.0",
|
||||
"FlagEmbedding>=1.2.0",
|
||||
"faiss-gpu>=1.7.4",
|
||||
"rank-bm25>=0.2.2",
|
||||
"accelerate>=0.25.0",
|
||||
"tensorboard>=2.14.0",
|
||||
"tqdm>=4.66.0",
|
||||
"pyyaml>=6.0",
|
||||
"jsonlines>=3.1.0",
|
||||
"huggingface-hub>=0.19.0",
|
||||
"jieba>=0.42.1",
|
||||
"toml>=0.10.2",
|
||||
"requests>=2.28.0"
|
||||
]
|
||||
|
||||
for package in packages:
|
||||
run_command(f"pip install {package}")
|
||||
|
||||
|
||||
def setup_directories():
|
||||
"""Setup project directories"""
|
||||
logger.info("Setting up directories...")
|
||||
|
||||
base_dir = Path(__file__).parent.parent
|
||||
directories = [
|
||||
"models",
|
||||
"models/bge-m3",
|
||||
"models/bge-reranker-base",
|
||||
"data/raw",
|
||||
"data/processed",
|
||||
"output",
|
||||
"logs",
|
||||
"cache/data"
|
||||
]
|
||||
|
||||
for directory in directories:
|
||||
dir_path = base_dir / directory
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Created directory: {dir_path}")
|
||||
|
||||
|
||||
def download_models(download_models_flag=False):
|
||||
"""Download base models from HuggingFace or ModelScope based on config.toml"""
|
||||
if not download_models_flag:
|
||||
logger.info("Skipping model download. Use --download-models to download base models.")
|
||||
return
|
||||
|
||||
logger.info("Downloading base models...")
|
||||
base_dir = Path(__file__).parent.parent
|
||||
config = get_config()
|
||||
model_source = config.get('model_paths.source', 'huggingface')
|
||||
|
||||
if model_source == 'huggingface':
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
# Download BGE-M3
|
||||
logger.info("Downloading BGE-M3 from HuggingFace...")
|
||||
snapshot_download(
|
||||
repo_id="BAAI/bge-m3",
|
||||
local_dir=base_dir / "models/bge-m3",
|
||||
local_dir_use_symlinks=False
|
||||
)
|
||||
# Download BGE-Reranker
|
||||
logger.info("Downloading BGE-Reranker from HuggingFace...")
|
||||
snapshot_download(
|
||||
repo_id="BAAI/bge-reranker-base",
|
||||
local_dir=base_dir / "models/bge-reranker-base",
|
||||
local_dir_use_symlinks=False
|
||||
)
|
||||
logger.info("Models downloaded successfully from HuggingFace!")
|
||||
except ImportError:
|
||||
logger.error("huggingface_hub not available. Install it first: pip install huggingface_hub")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download models from HuggingFace: {e}")
|
||||
logger.info("You can download models manually or use the Hugging Face CLI:")
|
||||
logger.info(" huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3")
|
||||
logger.info(" huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base")
|
||||
elif model_source == 'modelscope':
|
||||
try:
|
||||
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
|
||||
# Download BGE-M3
|
||||
logger.info("Downloading BGE-M3 from ModelScope...")
|
||||
ms_snapshot_download(
|
||||
model_id="damo/nlp_corom_sentence-embedding_chinese-base",
|
||||
cache_dir=str(base_dir / "models/bge-m3")
|
||||
)
|
||||
# Download BGE-Reranker
|
||||
logger.info("Downloading BGE-Reranker from ModelScope...")
|
||||
ms_snapshot_download(
|
||||
model_id="damo/nlp_corom_cross-encoder_chinese-base",
|
||||
cache_dir=str(base_dir / "models/bge-reranker-base")
|
||||
)
|
||||
logger.info("Models downloaded successfully from ModelScope!")
|
||||
except ImportError:
|
||||
logger.error("modelscope not available. Install it first: pip install modelscope")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download models from ModelScope: {e}")
|
||||
logger.info("You can download models manually from ModelScope website or CLI.")
|
||||
else:
|
||||
logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
|
||||
|
||||
|
||||
def create_sample_config():
|
||||
"""Create sample configuration files"""
|
||||
logger.info("Creating sample configuration...")
|
||||
|
||||
base_dir = Path(__file__).parent.parent
|
||||
config_file = base_dir / "config.toml"
|
||||
|
||||
if not config_file.exists():
|
||||
# Create default config (this would match our config.toml artifact)
|
||||
sample_config = """# Configuration file for BGE fine-tuning project
|
||||
|
||||
[api]
|
||||
# OpenAI-compatible API configuration (e.g., DeepSeek, OpenAI, etc.)
|
||||
base_url = "https://api.deepseek.com/v1"
|
||||
api_key = "your-api-key-here" # Replace with your actual API key
|
||||
model = "deepseek-chat"
|
||||
timeout = 30
|
||||
|
||||
[translation]
|
||||
# Translation settings
|
||||
enable_api = true
|
||||
target_languages = ["en", "zh", "ja", "fr"]
|
||||
max_retries = 3
|
||||
fallback_to_simulation = true
|
||||
|
||||
[training]
|
||||
# Training defaults
|
||||
default_batch_size = 4
|
||||
default_learning_rate = 1e-5
|
||||
default_epochs = 3
|
||||
default_warmup_ratio = 0.1
|
||||
|
||||
[model_paths]
|
||||
# Default model paths
|
||||
bge_m3 = "./models/bge-m3"
|
||||
bge_reranker = "./models/bge-reranker-base"
|
||||
|
||||
[data]
|
||||
# Data processing settings
|
||||
max_query_length = 512
|
||||
max_passage_length = 8192
|
||||
min_passage_length = 10
|
||||
validation_split = 0.1
|
||||
"""
|
||||
|
||||
with open(config_file, 'w') as f:
|
||||
f.write(sample_config)
|
||||
|
||||
logger.info(f"Created sample config: {config_file}")
|
||||
logger.info("Please edit config.toml to add your API keys if needed.")
|
||||
|
||||
|
||||
def create_sample_data():
|
||||
"""Create sample data files"""
|
||||
logger.info("Creating sample data...")
|
||||
|
||||
base_dir = Path(__file__).parent.parent
|
||||
sample_file = base_dir / "data/raw/sample_data.jsonl"
|
||||
|
||||
if not sample_file.exists():
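        # Each JSONL record pairs a query with positive passages ('pos') and
        # negative passages ('neg'); queries and passages may be in any language.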
|
||||
sample_data = [
|
||||
{
|
||||
"query": "什么是机器学习?",
|
||||
"pos": ["机器学习是人工智能的一个分支,它使用算法让计算机能够从数据中学习和做出决策。"],
|
||||
"neg": ["深度学习是神经网络的一种形式。", "自然语言处理是计算机科学的一个领域。"]
|
||||
},
|
||||
{
|
||||
"query": "How does neural network work?",
|
||||
"pos": [
|
||||
"Neural networks are computing systems inspired by biological neural networks that constitute the brain."],
|
||||
"neg": ["Machine learning algorithms can be supervised or unsupervised.",
|
||||
"Data preprocessing is an important step in ML."]
|
||||
},
|
||||
{
|
||||
"query": "Python编程语言的特点",
|
||||
"pos": ["Python是一种高级编程语言,具有简洁的语法和强大的功能,广泛用于数据科学和AI开发。"],
|
||||
"neg": ["JavaScript是一种脚本语言。", "Java是一种面向对象的编程语言。"]
|
||||
}
|
||||
]
|
||||
|
||||
with open(sample_file, 'w', encoding='utf-8') as f:
|
||||
for item in sample_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
logger.info(f"Created sample data: {sample_file}")
|
||||
|
||||
|
||||
def verify_installation():
|
||||
"""Verify installation"""
|
||||
logger.info("Verifying installation...")
|
||||
|
||||
try:
|
||||
# Test imports
|
||||
import torch
|
||||
import transformers
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
logger.info("✓ Core packages imported successfully")
|
||||
|
||||
# Test GPU
|
||||
if torch.cuda.is_available():
|
||||
logger.info("✓ CUDA is available")
|
||||
device = torch.device("cuda")
|
||||
x = torch.randn(10, 10).to(device)
|
||||
logger.info("✓ GPU tensor operations work")
|
||||
else:
|
||||
logger.info("✓ CPU-only setup")
|
||||
|
||||
# Test model loading (if models exist)
|
||||
base_dir = Path(__file__).parent.parent
|
||||
if (base_dir / "models/bge-m3").exists():
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(str(base_dir / "models/bge-m3"))
|
||||
logger.info("✓ BGE-M3 model files are accessible")
|
||||
except Exception as e:
|
||||
logger.warning(f"BGE-M3 model check failed: {e}")
|
||||
|
||||
logger.info("Installation verification completed!")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Import error: {e}")
|
||||
logger.error("Installation may be incomplete")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def print_usage_info():
|
||||
"""Print usage information"""
|
||||
print("\n" + "=" * 60)
|
||||
print("BGE FINE-TUNING SETUP COMPLETED!")
|
||||
print("=" * 60)
|
||||
print("\nQuick Start:")
|
||||
print("1. Edit config.toml to add your API keys (optional)")
|
||||
print("2. Prepare your training data in JSONL format")
|
||||
print("3. Run training:")
|
||||
print(" python scripts/train_m3.py --train_data data/raw/your_data.jsonl")
|
||||
print(" python scripts/train_reranker.py --train_data data/raw/your_data.jsonl")
|
||||
print("4. Evaluate models:")
|
||||
print(" python scripts/evaluate.py --retriever_model output/model --eval_data data/raw/test.jsonl")
|
||||
print("\nSample data created in: data/raw/sample_data.jsonl")
|
||||
print("Configuration file: config.toml")
|
||||
print("\nFor more information, see the README.md file.")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup BGE fine-tuning environment")
|
||||
parser.add_argument("--download-models", action="store_true",
|
||||
help="Download base models from Hugging Face")
|
||||
parser.add_argument("--cpu-only", action="store_true",
|
||||
help="Install CPU-only version of PyTorch")
|
||||
parser.add_argument("--skip-pytorch", action="store_true",
|
||||
help="Skip PyTorch installation")
|
||||
parser.add_argument("--skip-requirements", action="store_true",
|
||||
help="Skip requirements installation")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info("Starting BGE fine-tuning environment setup...")
|
||||
|
||||
# Check Python version
|
||||
check_python_version()
|
||||
|
||||
# Detect device type
|
||||
if args.cpu_only:
|
||||
device_type = "cpu"
|
||||
else:
|
||||
device_type = check_gpu()
|
||||
|
||||
# Install PyTorch
|
||||
if not args.skip_pytorch:
|
||||
install_pytorch(device_type)
|
||||
|
||||
# Install requirements
|
||||
if not args.skip_requirements:
|
||||
install_requirements()
|
||||
|
||||
# Setup directories
|
||||
setup_directories()
|
||||
|
||||
# Download models
|
||||
download_models(args.download_models)
|
||||
|
||||
# Create sample files
|
||||
create_sample_config()
|
||||
create_sample_data()
|
||||
|
||||
# Verify installation
|
||||
verify_installation()
|
||||
|
||||
# Print usage info
|
||||
print_usage_info()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
91
scripts/test_converted_dataset.py
Normal file
91
scripts/test_converted_dataset.py
Normal file
@@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify the converted dataset can be loaded by BGE training system
|
||||
"""
|
||||
|
||||
import logging
import os
import sys

# Make the project root importable when this script is run directly
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from transformers import AutoTokenizer

from data.dataset import BGEDataset
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_dataset_loading(dataset_path: str):
|
||||
"""Test that the dataset can be loaded properly"""
|
||||
|
||||
# Check if dataset file exists
|
||||
if not os.path.exists(dataset_path):
|
||||
logger.error(f"Dataset file not found: {dataset_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Load tokenizer
|
||||
logger.info("Loading tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained('models/bge-m3')
|
||||
|
||||
# Test loading the converted dataset
|
||||
logger.info("Loading dataset...")
|
||||
dataset = BGEDataset(
|
||||
data_path=dataset_path,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=64,
|
||||
passage_max_length=512,
|
||||
format_type='triplets'
|
||||
)
|
||||
|
||||
logger.info(f"✅ Dataset loaded successfully!")
|
||||
logger.info(f"📊 Total samples: {len(dataset)}")
|
||||
logger.info(f"🔍 Detected format: {dataset.detected_format}")
|
||||
|
||||
# Test getting a sample
|
||||
logger.info("Testing sample retrieval...")
|
||||
sample = dataset[0]
|
||||
|
||||
logger.info(f"🎯 Sample keys: {list(sample.keys())}")
|
||||
logger.info(f"🔤 Query shape: {sample['query_input_ids'].shape}")
|
||||
logger.info(f"📄 Passages shape: {sample['passage_input_ids'].shape}")
|
||||
logger.info(f"🏷️ Labels shape: {sample['labels'].shape}")
|
||||
logger.info(f"📊 Scores shape: {sample['scores'].shape}")
|
||||
|
||||
# Show sample content
|
||||
logger.info(f"🔍 Query text: {sample['query_text'][:100]}...")
|
||||
logger.info(f"📝 First passage: {sample['passage_texts'][0][:100]}...")
|
||||
logger.info(f"🏷️ Labels: {sample['labels']}")
|
||||
logger.info(f"📊 Scores: {sample['scores']}")
|
||||
|
||||
# Test a few more samples
|
||||
logger.info("Testing multiple samples...")
|
||||
for i in [1, 10, 100]:
|
||||
if i < len(dataset):
|
||||
sample_i = dataset[i]
|
||||
logger.info(f"Sample {i}: Query length={len(sample_i['query_text'])}, Passages={len([p for p in sample_i['passage_texts'] if p])}")
|
||||
|
||||
logger.info("✅ All tests passed! Dataset is ready for BGE-M3 training.")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Dataset loading failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def main():
|
||||
dataset_path = "data/datasets/三国演义/m3_train_converted.jsonl"
|
||||
|
||||
logger.info("🧪 Testing converted BGE-M3 dataset...")
|
||||
logger.info(f"📁 Dataset path: {dataset_path}")
|
||||
|
||||
success = test_dataset_loading(dataset_path)
|
||||
|
||||
if success:
|
||||
logger.info("🎉 Dataset is ready for training!")
|
||||
logger.info("💡 You can now use this dataset with:")
|
||||
logger.info(" python scripts/train_m3.py --data_path data/datasets/三国演义/m3_train_converted.jsonl")
|
||||
else:
|
||||
logger.error("💥 Dataset testing failed!")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
497
scripts/train_joint.py
Normal file
497
scripts/train_joint.py
Normal file
@@ -0,0 +1,497 @@
|
||||
"""
|
||||
Joint training of BGE-M3 and BGE-Reranker using RocketQAv2 approach
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
# Set tokenizer parallelism to avoid warnings
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime
|
||||
import torch
|
||||
from transformers import AutoTokenizer, set_seed
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from config.m3_config import BGEM3Config
|
||||
from config.reranker_config import BGERerankerConfig
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from data.dataset import BGEM3Dataset, BGERerankerDataset, JointDataset
|
||||
from data.preprocessing import DataPreprocessor
|
||||
from training.joint_trainer import JointTrainer
|
||||
from utils.logging import setup_logging, get_logger
|
||||
from utils.distributed import setup_distributed, is_main_process, set_random_seed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Joint training of BGE-M3 and BGE-Reranker")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--retriever_model", type=str, default="BAAI/bge-m3",
|
||||
help="Path to retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Path to reranker model")
|
||||
parser.add_argument("--cache_dir", type=str, default="./cache/data",
|
||||
help="Directory to cache downloaded models")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument("--train_data", type=str, required=True,
|
||||
help="Path to training data file")
|
||||
parser.add_argument("--eval_data", type=str, default=None,
|
||||
help="Path to evaluation data file")
|
||||
parser.add_argument("--data_format", type=str, default="auto",
|
||||
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
|
||||
help="Format of input data")
|
||||
|
||||
# Training arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./output/joint-training",
|
||||
help="Directory to save the models")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=3,
|
||||
help="Number of training epochs")
|
||||
parser.add_argument("--per_device_train_batch_size", type=int, default=4,
|
||||
help="Batch size per GPU/CPU for training")
|
||||
parser.add_argument("--per_device_eval_batch_size", type=int, default=8,
|
||||
help="Batch size per GPU/CPU for evaluation")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=4,
|
||||
help="Number of updates steps to accumulate")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-5,
|
||||
help="Initial learning rate for retriever")
|
||||
parser.add_argument("--reranker_learning_rate", type=float, default=6e-5,
|
||||
help="Initial learning rate for reranker")
|
||||
parser.add_argument("--warmup_ratio", type=float, default=0.1,
|
||||
help="Ratio of warmup steps")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01,
|
||||
help="Weight decay")
|
||||
|
||||
# Joint training specific
|
||||
parser.add_argument("--joint_loss_weight", type=float, default=0.3,
|
||||
help="Weight for joint distillation loss")
|
||||
parser.add_argument("--update_retriever_steps", type=int, default=100,
|
||||
help="Update retriever every N steps")
|
||||
parser.add_argument("--use_dynamic_distillation", action="store_true",
|
||||
help="Use dynamic listwise distillation")
|
||||
|
||||
# Model configuration
|
||||
parser.add_argument("--retriever_train_group_size", type=int, default=8,
|
||||
help="Number of passages per query for retriever")
|
||||
parser.add_argument("--reranker_train_group_size", type=int, default=16,
|
||||
help="Number of passages per query for reranker")
|
||||
parser.add_argument("--temperature", type=float, default=0.02,
|
||||
help="Temperature for InfoNCE loss")
|
||||
|
||||
# BGE-M3 specific
|
||||
parser.add_argument("--use_dense", action="store_true", default=True,
|
||||
help="Use dense embeddings")
|
||||
parser.add_argument("--use_sparse", action="store_true",
|
||||
help="Use sparse embeddings")
|
||||
parser.add_argument("--use_colbert", action="store_true",
|
||||
help="Use ColBERT embeddings")
|
||||
parser.add_argument("--query_max_length", type=int, default=512,
|
||||
help="Maximum query length")
|
||||
parser.add_argument("--passage_max_length", type=int, default=8192,
|
||||
help="Maximum passage length")
|
||||
|
||||
# Reranker specific
|
||||
parser.add_argument("--reranker_max_length", type=int, default=512,
|
||||
help="Maximum sequence length for reranker")
|
||||
|
||||
# Performance
|
||||
parser.add_argument("--fp16", action="store_true",
|
||||
help="Use mixed precision training")
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||
help="Use gradient checkpointing")
|
||||
parser.add_argument("--dataloader_num_workers", type=int, default=4,
|
||||
help="Number of workers for data loading")
|
||||
|
||||
# Checkpointing
|
||||
parser.add_argument("--save_steps", type=int, default=1000,
|
||||
help="Save checkpoint every N steps")
|
||||
parser.add_argument("--eval_steps", type=int, default=1000,
|
||||
help="Evaluate every N steps")
|
||||
parser.add_argument("--logging_steps", type=int, default=100,
|
||||
help="Log every N steps")
|
||||
parser.add_argument("--save_total_limit", type=int, default=3,
|
||||
help="Maximum number of checkpoints to keep")
|
||||
parser.add_argument("--resume_from_checkpoint", type=str, default=None,
|
||||
help="Path to checkpoint to resume from")
|
||||
|
||||
# Distributed training
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="Local rank for distributed training")
|
||||
|
||||
# Other
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
parser.add_argument("--overwrite_output_dir", action="store_true",
|
||||
help="Overwrite the output directory")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate arguments
|
||||
if os.path.exists(args.output_dir) and not args.overwrite_output_dir:
|
||||
raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.")
|
||||
|
||||
return args
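

# Example launch (sketch): the script reads --local_rank, so it matches the legacy
# torch.distributed.launch entry point; adjust process count and paths to your setup.
#
#   python -m torch.distributed.launch --nproc_per_node=2 scripts/train_joint.py \
#       --train_data data/raw/sample_data.jsonl \
#       --output_dir ./output/joint-training --fp16 --overwrite_output_dir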
|
||||
|
||||
|
||||
def main():
|
||||
# Parse arguments
|
||||
args = parse_args()
|
||||
|
||||
# Setup logging
|
||||
setup_logging(
|
||||
output_dir=args.output_dir,
|
||||
log_level=logging.INFO,
|
||||
log_to_console=True,
|
||||
log_to_file=True
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("Starting joint training of BGE-M3 and BGE-Reranker")
|
||||
logger.info(f"Arguments: {args}")
|
||||
|
||||
# Set random seed
|
||||
set_random_seed(args.seed, deterministic=False)
|
||||
|
||||
# Setup distributed training if needed
|
||||
device = None
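    # Device resolution below: Ascend NPU when config.toml sets ascend.device_type = "npu"
    # and torch_npu is importable, otherwise CUDA when available, otherwise CPU.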
|
||||
if args.local_rank != -1:
|
||||
setup_distributed()
|
||||
from utils.config_loader import get_config
|
||||
config = get_config()
|
||||
device_type = config.get('ascend.device_type', None) # type: ignore[attr-defined]
|
||||
if device_type == 'npu':
|
||||
try:
|
||||
import torch_npu
|
||||
torch_npu.npu.set_device(args.local_rank)
|
||||
device = torch.device("npu", args.local_rank)
|
||||
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
|
||||
except ImportError:
|
||||
logger.warning("torch_npu not installed, falling back to CPU")
|
||||
device = torch.device("cpu")
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Distributed training enabled: using CPU")
|
||||
else:
|
||||
from utils.config_loader import get_config
|
||||
config = get_config()
|
||||
device_type = config.get('ascend.device_type', None) # type: ignore[attr-defined]
|
||||
if device_type == 'npu':
|
||||
try:
|
||||
import torch_npu
|
||||
device = torch.device("npu")
|
||||
logger.info("Using Ascend NPU")
|
||||
except ImportError:
|
||||
logger.warning("torch_npu not installed, falling back to CPU")
|
||||
device = torch.device("cpu")
|
||||
elif torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Using CPU")
|
||||
|
||||
# Create configurations
|
||||
retriever_config = BGEM3Config(
|
||||
model_name_or_path=args.retriever_model,
|
||||
cache_dir=args.cache_dir,
|
||||
train_data_path=args.train_data,
|
||||
eval_data_path=args.eval_data,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
output_dir=os.path.join(args.output_dir, "retriever"),
|
||||
num_train_epochs=args.num_train_epochs,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
learning_rate=args.learning_rate,
|
||||
warmup_ratio=args.warmup_ratio,
|
||||
weight_decay=args.weight_decay,
|
||||
train_group_size=args.retriever_train_group_size,
|
||||
temperature=args.temperature,
|
||||
use_dense=args.use_dense,
|
||||
use_sparse=args.use_sparse,
|
||||
use_colbert=args.use_colbert,
|
||||
fp16=args.fp16,
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
dataloader_num_workers=args.dataloader_num_workers,
|
||||
save_steps=args.save_steps,
|
||||
eval_steps=args.eval_steps,
|
||||
logging_steps=args.logging_steps,
|
||||
save_total_limit=args.save_total_limit,
|
||||
seed=args.seed,
|
||||
local_rank=args.local_rank,
|
||||
device=str(device)
|
||||
)
|
||||
|
||||
reranker_config = BGERerankerConfig(
|
||||
model_name_or_path=args.reranker_model,
|
||||
cache_dir=args.cache_dir,
|
||||
train_data_path=args.train_data,
|
||||
eval_data_path=args.eval_data,
|
||||
max_seq_length=args.reranker_max_length,
|
||||
output_dir=os.path.join(args.output_dir, "reranker"),
|
||||
num_train_epochs=args.num_train_epochs,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
learning_rate=args.reranker_learning_rate,
|
||||
warmup_ratio=args.warmup_ratio,
|
||||
weight_decay=args.weight_decay,
|
||||
train_group_size=args.reranker_train_group_size,
|
||||
joint_training=True, # type: ignore[call-arg]
|
||||
joint_loss_weight=args.joint_loss_weight, # type: ignore[call-arg]
|
||||
use_dynamic_listwise_distillation=args.use_dynamic_distillation, # type: ignore[call-arg]
|
||||
update_retriever_steps=args.update_retriever_steps, # type: ignore[call-arg]
|
||||
fp16=args.fp16,
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
dataloader_num_workers=args.dataloader_num_workers,
|
||||
save_steps=args.save_steps,
|
||||
eval_steps=args.eval_steps,
|
||||
logging_steps=args.logging_steps,
|
||||
save_total_limit=args.save_total_limit,
|
||||
seed=args.seed,
|
||||
local_rank=args.local_rank,
|
||||
device=str(device)
|
||||
)
|
||||
|
||||
# Create joint config (use retriever config as base)
|
||||
joint_config = retriever_config
|
||||
joint_config.output_dir = args.output_dir
|
||||
joint_config.joint_loss_weight = args.joint_loss_weight # type: ignore[attr-defined]
|
||||
joint_config.use_dynamic_listwise_distillation = args.use_dynamic_distillation # type: ignore[attr-defined]
|
||||
joint_config.update_retriever_steps = args.update_retriever_steps # type: ignore[attr-defined]
|
||||
|
||||
# Save configurations
|
||||
if is_main_process():
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
joint_config.save(os.path.join(args.output_dir, "joint_config.json"))
|
||||
retriever_config.save(os.path.join(args.output_dir, "retriever_config.json"))
|
||||
reranker_config.save(os.path.join(args.output_dir, "reranker_config.json"))
|
||||
|
||||
# Load tokenizers
|
||||
logger.info("Loading tokenizers...")
|
||||
retriever_tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.retriever_model,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
reranker_tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.reranker_model,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
|
||||
# Prepare data
|
||||
logger.info("Preparing data...")
|
||||
preprocessor = DataPreprocessor(
|
||||
tokenizer=retriever_tokenizer, # Use retriever tokenizer for preprocessing
|
||||
max_query_length=args.query_max_length,
|
||||
max_passage_length=args.passage_max_length
|
||||
)
|
||||
|
||||
# Preprocess training data
|
||||
train_path, val_path = preprocessor.preprocess_file(
|
||||
input_path=args.train_data,
|
||||
output_path=os.path.join(args.output_dir, "processed_train.jsonl"),
|
||||
file_format=args.data_format,
|
||||
validation_split=0.1 if args.eval_data is None else 0.0
|
||||
)
|
||||
|
||||
# Use provided eval data or split validation
|
||||
eval_path = args.eval_data if args.eval_data else val_path
|
||||
|
||||
# Create datasets
|
||||
logger.info("Creating datasets...")
|
||||
|
||||
# Retriever dataset
|
||||
retriever_train_dataset = BGEM3Dataset(
|
||||
data_path=train_path,
|
||||
tokenizer=retriever_tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
train_group_size=args.retriever_train_group_size,
|
||||
is_train=True
|
||||
)
|
||||
|
||||
# Reranker dataset
|
||||
reranker_train_dataset = BGERerankerDataset(
|
||||
data_path=train_path,
|
||||
tokenizer=reranker_tokenizer,
|
||||
max_seq_length=args.reranker_max_length, # type: ignore[call-arg]
|
||||
train_group_size=args.reranker_train_group_size,
|
||||
is_train=True
|
||||
)
|
||||
|
||||
# Joint dataset
|
||||
joint_train_dataset = JointDataset( # type: ignore[call-arg]
|
||||
retriever_dataset=retriever_train_dataset, # type: ignore[call-arg]
|
||||
reranker_dataset=reranker_train_dataset # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
# Evaluation datasets
|
||||
joint_eval_dataset = None
|
||||
if eval_path:
|
||||
retriever_eval_dataset = BGEM3Dataset(
|
||||
data_path=eval_path,
|
||||
tokenizer=retriever_tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
train_group_size=args.retriever_train_group_size,
|
||||
is_train=False
|
||||
)
|
||||
|
||||
reranker_eval_dataset = BGERerankerDataset(
|
||||
data_path=eval_path,
|
||||
tokenizer=reranker_tokenizer,
|
||||
max_seq_length=args.reranker_max_length, # type: ignore[call-arg]
|
||||
train_group_size=args.reranker_train_group_size,
|
||||
is_train=False
|
||||
)
|
||||
|
||||
joint_eval_dataset = JointDataset( # type: ignore[call-arg]
|
||||
retriever_dataset=retriever_eval_dataset, # type: ignore[call-arg]
|
||||
reranker_dataset=reranker_eval_dataset # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
# Initialize models
|
||||
logger.info("Loading models...")
|
||||
retriever = BGEM3Model(
|
||||
model_name_or_path=args.retriever_model,
|
||||
use_dense=args.use_dense,
|
||||
use_sparse=args.use_sparse,
|
||||
use_colbert=args.use_colbert,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
reranker = BGERerankerModel(
|
||||
model_name_or_path=args.reranker_model,
|
||||
cache_dir=args.cache_dir,
|
||||
use_fp16=args.fp16
|
||||
)
|
||||
from utils.config_loader import get_config
|
||||
logger.info("Retriever model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface'))
|
||||
logger.info("Reranker model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface'))
|
||||
|
||||
# Initialize optimizers
|
||||
from torch.optim import AdamW
|
||||
|
||||
retriever_optimizer = AdamW(
|
||||
retriever.parameters(),
|
||||
lr=args.learning_rate,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
reranker_optimizer = AdamW(
|
||||
reranker.parameters(),
|
||||
lr=args.reranker_learning_rate,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
# Initialize joint trainer
|
||||
logger.info("Initializing joint trainer...")
|
||||
trainer = JointTrainer(
|
||||
config=joint_config,
|
||||
retriever=retriever,
|
||||
reranker=reranker,
|
||||
retriever_optimizer=retriever_optimizer,
|
||||
reranker_optimizer=reranker_optimizer,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Create data loaders
|
||||
from torch.utils.data import DataLoader
|
||||
from utils.distributed import create_distributed_dataloader
|
||||
|
||||
if args.local_rank != -1:
|
||||
train_dataloader = create_distributed_dataloader(
|
||||
joint_train_dataset,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
seed=args.seed
|
||||
)
|
||||
|
||||
eval_dataloader = None
|
||||
if joint_eval_dataset:
|
||||
eval_dataloader = create_distributed_dataloader(
|
||||
joint_eval_dataset,
|
||||
batch_size=args.per_device_eval_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
seed=args.seed
|
||||
)
|
||||
else:
|
||||
train_dataloader = DataLoader(
|
||||
joint_train_dataset,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
eval_dataloader = None
|
||||
if joint_eval_dataset:
|
||||
eval_dataloader = DataLoader(
|
||||
joint_eval_dataset,
|
||||
batch_size=args.per_device_eval_batch_size,
|
||||
num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
# Resume from checkpoint if specified
|
||||
if args.resume_from_checkpoint:
|
||||
trainer.load_checkpoint(args.resume_from_checkpoint)
|
||||
|
||||
# Train
|
||||
logger.info("Starting joint training...")
|
||||
try:
|
||||
trainer.train(train_dataloader, eval_dataloader, args.num_train_epochs)
|
||||
|
||||
# Save final models
|
||||
if is_main_process():
|
||||
logger.info("Saving final models...")
|
||||
final_dir = os.path.join(args.output_dir, "final_models")
|
||||
trainer.save_models(final_dir)
|
||||
|
||||
# Save individual models
|
||||
retriever_path = os.path.join(args.output_dir, "final_retriever")
|
||||
reranker_path = os.path.join(args.output_dir, "final_reranker")
|
||||
|
||||
os.makedirs(retriever_path, exist_ok=True)
|
||||
os.makedirs(reranker_path, exist_ok=True)
|
||||
|
||||
retriever.save_pretrained(retriever_path)
|
||||
retriever_tokenizer.save_pretrained(retriever_path)
|
||||
|
||||
reranker.model.save_pretrained(reranker_path)
|
||||
reranker_tokenizer.save_pretrained(reranker_path)
|
||||
|
||||
# Save training summary
|
||||
summary = {
|
||||
"training_args": vars(args),
|
||||
"retriever_config": retriever_config.to_dict(),
|
||||
"reranker_config": reranker_config.to_dict(),
|
||||
"joint_config": joint_config.to_dict(),
|
||||
"final_metrics": getattr(trainer, 'best_metric', None),
|
||||
"total_steps": getattr(trainer, 'global_step', 0),
|
||||
"training_completed": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
logger.info(f"Joint training completed! Models saved to {args.output_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Joint training failed with error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
614
scripts/train_m3.py
Normal file
@@ -0,0 +1,614 @@
|
||||
"""
|
||||
Enhanced BGE-M3 embedding model training script with integrated dataset pipeline
|
||||
Supports both legacy nested format and new embedding schema with automatic detection
|
||||
"""
|
||||
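# Illustrative invocation (a sketch, not the only supported combination; the
# data path below is a placeholder and every flag is defined in parse_args):
#
#   python scripts/train_m3.py \
#       --train_data ./data/datasets/embedding_data.jsonl \
#       --output_dir ./output/bge-m3-enhanced \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 4 \
#       --use_hard_negatives \
#       --fp16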
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
# Set tokenizer parallelism to avoid warnings
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime
|
||||
import torch
|
||||
from transformers import AutoTokenizer, set_seed
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import AdamW
|
||||
from transformers.optimization import get_linear_schedule_with_warmup
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
try:
|
||||
from config.m3_config import BGEM3Config
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from data.dataset import BGEM3Dataset, create_embedding_dataset, collate_embedding_batch, migrate_nested_to_flat
|
||||
from data.preprocessing import DataPreprocessor
|
||||
from data.augmentation import QueryAugmenter, HardNegativeAugmenter
|
||||
from data.hard_negative_mining import ANCEHardNegativeMiner
|
||||
from training.m3_trainer import BGEM3Trainer
|
||||
from evaluation.evaluator import RetrievalEvaluator
|
||||
from utils.logging import setup_logging, get_logger
|
||||
from utils.distributed import setup_distributed, is_main_process, set_random_seed
|
||||
from utils.config_loader import get_config, load_config_for_training, resolve_model_path
|
||||
except ImportError as e:
|
||||
print(f"Import error: {e}")
|
||||
print("Please ensure you're running from the project root directory")
|
||||
sys.exit(1)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Enhanced BGE-M3 Embedding Model Training")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--model_name_or_path", type=str, default=None,
|
||||
help="Path to pretrained model or model identifier")
|
||||
parser.add_argument("--cache_dir", type=str, default="./cache/data",
|
||||
help="Directory to cache downloaded models")
|
||||
parser.add_argument("--config_path", type=str, default="./config.toml",
|
||||
help="Path to configuration file")
|
||||
|
||||
# Data arguments - support both legacy and new schemas
|
||||
parser.add_argument("--train_data", type=str, required=True,
|
||||
help="Path to training data file (supports both embedding and legacy nested formats)")
|
||||
parser.add_argument("--eval_data", type=str, default=None,
|
||||
help="Path to evaluation data file")
|
||||
parser.add_argument("--data_format", type=str, default="auto",
|
||||
choices=["auto", "embedding", "legacy_nested", "jsonl", "json", "csv", "tsv"],
|
||||
help="Format of input data (auto-detects by default)")
|
||||
parser.add_argument("--legacy_data", type=str, default=None,
|
||||
help="Path to additional legacy nested data (for backward compatibility)")
|
||||
parser.add_argument("--migrate_legacy", action="store_true",
|
||||
help="Migrate legacy nested format to flat format")
|
||||
|
||||
# Sequence length arguments
|
||||
parser.add_argument("--max_seq_length", type=int, default=None,
|
||||
help="Maximum sequence length")
|
||||
parser.add_argument("--query_max_length", type=int, default=None,
|
||||
help="Maximum query length")
|
||||
parser.add_argument("--passage_max_length", type=int, default=None,
|
||||
help="Maximum passage length")
|
||||
parser.add_argument("--max_length", type=int, default=8192,
|
||||
help="Maximum total input length (multi-granularity)")
|
||||
|
||||
# Training arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./output/bge-m3-enhanced",
|
||||
help="Directory to save the model")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=None,
|
||||
help="Number of training epochs")
|
||||
parser.add_argument("--per_device_train_batch_size", type=int, default=None,
|
||||
help="Batch size per GPU/CPU for training")
|
||||
parser.add_argument("--per_device_eval_batch_size", type=int, default=None,
|
||||
help="Batch size per GPU/CPU for evaluation")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass")
|
||||
parser.add_argument("--learning_rate", type=float, default=None,
|
||||
help="Initial learning rate")
|
||||
parser.add_argument("--warmup_ratio", type=float, default=None,
|
||||
help="Warmup steps as a ratio of total steps")
|
||||
parser.add_argument("--weight_decay", type=float, default=None,
|
||||
help="Weight decay")
|
||||
|
||||
# Model specific arguments - enhanced features
|
||||
parser.add_argument("--train_group_size", type=int, default=8,
|
||||
help="Number of passages per query in training")
|
||||
parser.add_argument("--temperature", type=float, default=0.1,
|
||||
help="Temperature for InfoNCE loss")
|
||||
parser.add_argument("--use_hard_negatives", action="store_true",
|
||||
help="Enable hard negative mining")
|
||||
parser.add_argument("--use_self_distill", action="store_true",
|
||||
help="Enable self-knowledge distillation")
|
||||
parser.add_argument("--use_score_weighted_sampling", action="store_true", default=True,
|
||||
help="Enable score-weighted positive sampling")
|
||||
parser.add_argument("--query_instruction", type=str, default="",
|
||||
help="Query instruction prefix")
|
||||
|
||||
# Enhanced sampling features
|
||||
parser.add_argument("--multiple_positives", action="store_true", default=True,
|
||||
help="Use multiple positives per batch for richer contrastive signals")
|
||||
parser.add_argument("--hard_negative_ratio", type=float, default=0.7,
|
||||
help="Ratio of hard to random negatives")
|
||||
parser.add_argument("--difficulty_threshold", type=float, default=0.1,
|
||||
help="Minimum difficulty for hard negatives")
|
||||
|
||||
# Optimization and performance
|
||||
parser.add_argument("--fp16", action="store_true",
|
||||
help="Use mixed precision training")
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||
help="Use gradient checkpointing to save memory")
|
||||
parser.add_argument("--dataloader_num_workers", type=int, default=4,
|
||||
help="Number of workers for data loading")
|
||||
|
||||
# Checkpointing
|
||||
parser.add_argument("--save_steps", type=int, default=500,
|
||||
help="Save checkpoint every N steps")
|
||||
parser.add_argument("--eval_steps", type=int, default=500,
|
||||
help="Evaluate every N steps")
|
||||
parser.add_argument("--logging_steps", type=int, default=100,
|
||||
help="Log every N steps")
|
||||
parser.add_argument("--save_total_limit", type=int, default=3,
|
||||
help="Maximum number of checkpoints to keep")
|
||||
parser.add_argument("--resume_from_checkpoint", type=str, default=None,
|
||||
help="Path to checkpoint to resume from")
|
||||
|
||||
# Distributed training
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="Local rank for distributed training")
|
||||
|
||||
# Other
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
parser.add_argument("--overwrite_output_dir", action="store_true",
|
||||
help="Overwrite the output directory")
|
||||
parser.add_argument("--log_level", type=str, default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
help="Logging level")
|
||||
|
||||
# Backward compatibility flags
|
||||
parser.add_argument("--legacy_support", action="store_true", default=True,
|
||||
help="Enable legacy nested format support")
|
||||
parser.add_argument("--use_new_dataset_api", action="store_true", default=True,
|
||||
help="Use new integrated dataset API with format detection")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate arguments
|
||||
if os.path.exists(args.output_dir) and not args.overwrite_output_dir:
|
||||
raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def create_config_from_args(args):
|
||||
"""Create configuration from command line arguments and config file
|
||||
|
||||
Parameter Precedence (highest to lowest):
|
||||
1. Command-line arguments (CLI)
|
||||
2. Model-specific config ([m3], [reranker])
|
||||
3. [hardware] section in config.toml
|
||||
4. [training] section in config.toml
|
||||
5. Hardcoded defaults in code
|
||||
"""
|
||||
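# Illustrative config.toml excerpt (hypothetical values; the key names match
# the lookups performed below) showing where precedence levels 2-4 come from:
#
#   [training]
#   default_num_epochs = 3
#   default_batch_size = 4
#   default_learning_rate = 1e-5
#
#   [hardware]
#   per_device_train_batch_size = 8
#
#   [m3]
#   train_group_size = 8
#   temperature = 0.1
#   use_hard_negatives = true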
# Load base configuration from TOML file
|
||||
config_dict = load_config_for_training(
|
||||
config_path=args.config_path,
|
||||
override_params={}
|
||||
)
|
||||
|
||||
# Get the centralized config instance
|
||||
central_config = get_config()
|
||||
|
||||
# Create override parameters from command line arguments
|
||||
override_params = {}
|
||||
|
||||
# Model configuration - use new resolve_model_path function
|
||||
resolved_model_path, resolved_cache_dir = resolve_model_path(
|
||||
model_key='bge_m3',
|
||||
config_instance=central_config,
|
||||
model_name_or_path=args.model_name_or_path,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
|
||||
args.model_name_or_path = resolved_model_path
|
||||
args.cache_dir = resolved_cache_dir
|
||||
|
||||
# Data configuration
|
||||
data_config = config_dict.get('data', {})
|
||||
if args.max_seq_length is None:
|
||||
args.max_seq_length = data_config.get('max_seq_length', 512)
|
||||
if args.query_max_length is None:
|
||||
args.query_max_length = data_config.get('max_query_length', 64)
|
||||
if args.passage_max_length is None:
|
||||
args.passage_max_length = data_config.get('max_passage_length', 512)
|
||||
|
||||
# Get validation split from config
|
||||
validation_split = data_config.get('validation_split', 0.1)
|
||||
|
||||
# Training configuration
|
||||
training_config = config_dict.get('training', {})
|
||||
if args.num_train_epochs is None:
|
||||
args.num_train_epochs = training_config.get('default_num_epochs', 3)
|
||||
if args.per_device_train_batch_size is None:
|
||||
args.per_device_train_batch_size = training_config.get('default_batch_size', 4)
|
||||
if args.per_device_eval_batch_size is None:
|
||||
args.per_device_eval_batch_size = args.per_device_train_batch_size * 2
|
||||
if args.learning_rate is None:
|
||||
args.learning_rate = training_config.get('default_learning_rate', 1e-5)
|
||||
if args.warmup_ratio is None:
|
||||
args.warmup_ratio = training_config.get('default_warmup_ratio', 0.1)
|
||||
if args.weight_decay is None:
|
||||
args.weight_decay = training_config.get('default_weight_decay', 0.01)
|
||||
|
||||
# Model-specific (M3) configuration - highest precedence
|
||||
m3_config = config_dict.get('m3', {})
|
||||
train_group_size = m3_config.get('train_group_size', args.train_group_size)
|
||||
temperature = m3_config.get('temperature', args.temperature)
|
||||
use_hard_negatives = m3_config.get('use_hard_negatives', args.use_hard_negatives)
|
||||
use_self_distill = m3_config.get('use_self_distill', args.use_self_distill)
|
||||
query_instruction_for_retrieval = m3_config.get('query_instruction_for_retrieval', args.query_instruction)
|
||||
|
||||
# Hardware configuration
|
||||
hardware_config = config_dict.get('hardware', {})
|
||||
if args.per_device_train_batch_size is None:
|
||||
args.per_device_train_batch_size = hardware_config.get('per_device_train_batch_size', args.per_device_train_batch_size)
|
||||
if args.per_device_eval_batch_size is None:
|
||||
args.per_device_eval_batch_size = hardware_config.get('per_device_eval_batch_size', args.per_device_eval_batch_size)
|
||||
|
||||
# Create BGE-M3 configuration
|
||||
config = BGEM3Config(
|
||||
model_name_or_path=args.model_name_or_path,
|
||||
cache_dir=args.cache_dir,
|
||||
train_data_path=args.train_data,
|
||||
eval_data_path=args.eval_data,
|
||||
max_seq_length=args.max_seq_length,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
max_length=args.max_length,
|
||||
output_dir=args.output_dir,
|
||||
num_train_epochs=args.num_train_epochs,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
learning_rate=args.learning_rate,
|
||||
warmup_ratio=args.warmup_ratio,
|
||||
weight_decay=args.weight_decay,
|
||||
train_group_size=train_group_size,
|
||||
temperature=temperature,
|
||||
use_hard_negatives=use_hard_negatives,
|
||||
use_self_distill=use_self_distill,
|
||||
query_instruction_for_retrieval=query_instruction_for_retrieval,
|
||||
fp16=args.fp16,
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
dataloader_num_workers=args.dataloader_num_workers,
|
||||
save_steps=args.save_steps,
|
||||
eval_steps=args.eval_steps,
|
||||
logging_steps=args.logging_steps,
|
||||
save_total_limit=args.save_total_limit,
|
||||
seed=args.seed,
|
||||
local_rank=args.local_rank,
|
||||
overwrite_output_dir=args.overwrite_output_dir
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def setup_model_and_tokenizer(args):
|
||||
"""Setup BGE-M3 model and tokenizer"""
|
||||
logger.info(f"Loading BGE-M3 model from {args.model_name_or_path}")
|
||||
|
||||
# Detect offline mode if the path is local or HF_HUB_OFFLINE is set
|
||||
local_files_only = os.path.isdir(args.model_name_or_path) or bool(os.environ.get("HF_HUB_OFFLINE"))
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
cache_dir=args.cache_dir,
|
||||
trust_remote_code=True,
|
||||
local_files_only=local_files_only
|
||||
)
|
||||
|
||||
# Load model
|
||||
try:
|
||||
model = BGEM3Model(
|
||||
model_name_or_path=args.model_name_or_path,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
logger.info("Model loaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def setup_datasets(args, tokenizer):
|
||||
"""Setup datasets with enhanced format detection and migration support"""
|
||||
logger.info("Setting up datasets with enhanced format detection...")
|
||||
|
||||
# Handle legacy data migration if needed
|
||||
if args.legacy_data and args.migrate_legacy:
|
||||
logger.info(f"Migrating legacy data from {args.legacy_data}")
|
||||
legacy_base = os.path.splitext(args.legacy_data)[0]
|
||||
migrated_file = f"{legacy_base}_migrated_embedding.jsonl"
|
||||
|
||||
# For embedding, we don't need to migrate to flat pairs, but we can optimize the format
|
||||
# Copy the legacy data to work with if needed
|
||||
import shutil
|
||||
shutil.copy2(args.legacy_data, migrated_file)
|
||||
logger.info(f"Legacy data prepared for training: {migrated_file}")
|
||||
|
||||
# Setup training dataset
|
||||
if args.use_new_dataset_api:
|
||||
# Use new integrated dataset API with format detection
|
||||
train_dataset = create_embedding_dataset(
|
||||
data_path=args.train_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=True
|
||||
)
|
||||
|
||||
logger.info(f"Loaded {len(train_dataset)} training examples from {args.train_data}")
|
||||
logger.info(f"Detected format: {train_dataset.detected_format}")
|
||||
|
||||
# Add legacy data if provided
|
||||
if args.legacy_data and args.legacy_support:
|
||||
logger.info(f"Loading additional legacy data from {args.legacy_data}")
|
||||
legacy_dataset = create_embedding_dataset(
|
||||
data_path=args.legacy_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=True
|
||||
)
|
||||
logger.info(f"Loaded {len(legacy_dataset)} legacy examples")
|
||||
|
||||
# Combine datasets
|
||||
train_dataset.data.extend(legacy_dataset.data)
|
||||
logger.info(f"Total training examples: {len(train_dataset)}")
|
||||
else:
|
||||
# Use original dataset API for backward compatibility
|
||||
train_dataset = BGEM3Dataset(
|
||||
data_path=args.train_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
train_group_size=args.train_group_size,
|
||||
is_train=True
|
||||
)
|
||||
logger.info(f"Loaded {len(train_dataset)} training examples (legacy API)")
|
||||
|
||||
# Setup evaluation dataset
|
||||
eval_dataset = None
|
||||
if args.eval_data:
|
||||
if args.use_new_dataset_api:
|
||||
eval_dataset = create_embedding_dataset(
|
||||
data_path=args.eval_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=False
|
||||
)
|
||||
else:
|
||||
eval_dataset = BGEM3Dataset(
|
||||
data_path=args.eval_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
train_group_size=args.train_group_size,
|
||||
is_train=False
|
||||
)
|
||||
logger.info(f"Loaded {len(eval_dataset)} evaluation examples")
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def setup_data_loaders(args, train_dataset, eval_dataset):
|
||||
"""Setup data loaders with appropriate collate functions"""
|
||||
|
||||
# Set up data loaders with the appropriate collate function
|
||||
if hasattr(train_dataset, 'format') and train_dataset.format == 'triplets':
|
||||
# This might be a generic BGE dataset in triplets format
|
||||
collate_fn = collate_embedding_batch
|
||||
logger.info("Using embedding collate function for triplets format")
|
||||
else:
|
||||
collate_fn = collate_embedding_batch
|
||||
logger.info("Using standard embedding collate function")
|
||||
|
||||
# FastM3Dataset doesn't work well with multiprocessing due to pickle loading
|
||||
workers = args.dataloader_num_workers
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=workers,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=torch.cuda.is_available()
|
||||
)
|
||||
|
||||
# Evaluation data loader
|
||||
eval_dataloader = None
|
||||
if eval_dataset:
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=args.per_device_eval_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=torch.cuda.is_available()
|
||||
)
|
||||
|
||||
return train_dataloader, eval_dataloader
|
||||
|
||||
|
||||
def main():
|
||||
# Parse arguments
|
||||
args = parse_args()
|
||||
|
||||
# Create configuration
|
||||
config = create_config_from_args(args)
|
||||
|
||||
# Setup logging
|
||||
setup_logging(
|
||||
output_dir=args.output_dir,
|
||||
log_level=getattr(logging, args.log_level),
|
||||
log_to_console=True,
|
||||
log_to_file=True
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("🚀 Starting Enhanced BGE-M3 Fine-tuning")
|
||||
logger.info(f"Arguments: {args}")
|
||||
logger.info(f"Configuration: {config.to_dict()}")
|
||||
|
||||
# Set random seed
|
||||
set_random_seed(args.seed, deterministic=False)
|
||||
|
||||
# Setup distributed training if needed
|
||||
device = None
|
||||
if args.local_rank != -1:
|
||||
setup_distributed()
|
||||
# Device selection from config
|
||||
from utils.config_loader import get_config
|
||||
config_instance = get_config()
|
||||
device_type = config_instance.get('ascend.device_type', 'cuda')
|
||||
if device_type == 'npu':
|
||||
try:
|
||||
import torch_npu # type: ignore
|
||||
torch_npu.npu.set_device(args.local_rank)
|
||||
device = torch.device("npu", args.local_rank)
|
||||
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
|
||||
except ImportError:
|
||||
logger.warning("torch_npu not installed, falling back to CPU")
|
||||
device = torch.device("cpu")
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Distributed training enabled: using CPU")
|
||||
else:
|
||||
# Non-distributed device selection
|
||||
from utils.config_loader import get_config
|
||||
config_instance = get_config()
|
||||
device_type = config_instance.get('ascend.device_type', 'cuda')
|
||||
if device_type == 'npu':
|
||||
try:
|
||||
import torch_npu # type: ignore
|
||||
device = torch.device("npu")
|
||||
logger.info("Using Ascend NPU")
|
||||
except ImportError:
|
||||
logger.warning("torch_npu not installed, falling back to CPU")
|
||||
device = torch.device("cpu")
|
||||
elif torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Using CPU")
|
||||
|
||||
# Update config with device
|
||||
config.device = str(device)
|
||||
|
||||
# Save configuration
|
||||
if is_main_process():
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
config.save(os.path.join(args.output_dir, "config.json"))
|
||||
|
||||
# Setup model and tokenizer
|
||||
model, tokenizer = setup_model_and_tokenizer(args)
|
||||
|
||||
# Setup datasets with enhanced format detection
|
||||
train_dataset, eval_dataset = setup_datasets(args, tokenizer)
|
||||
|
||||
# Setup data loaders
|
||||
train_dataloader, eval_dataloader = setup_data_loaders(args, train_dataset, eval_dataset)
|
||||
|
||||
# Initialize trainer
|
||||
logger.info("Initializing enhanced trainer...")
|
||||
try:
|
||||
trainer = BGEM3Trainer(
|
||||
config=config,
|
||||
model=model,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
logger.info("Trainer initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize trainer: {e}")
|
||||
raise
|
||||
|
||||
# Resume from checkpoint if specified
|
||||
if args.resume_from_checkpoint:
|
||||
logger.info(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
|
||||
try:
|
||||
trainer.load_checkpoint(args.resume_from_checkpoint)
|
||||
logger.info("Checkpoint loaded successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load checkpoint: {e}")
|
||||
raise
|
||||
|
||||
# Train
|
||||
logger.info("🎯 Starting enhanced training...")
|
||||
try:
|
||||
# Handle type compatibility for trainer
|
||||
from typing import cast
|
||||
trainer_train_dataset = cast(BGEM3Dataset, train_dataset) if train_dataset else None
|
||||
trainer_eval_dataset = cast(BGEM3Dataset, eval_dataset) if eval_dataset else None
|
||||
|
||||
if not trainer_train_dataset:
|
||||
raise ValueError("Training dataset is required but not provided")
|
||||
|
||||
trainer.train(
|
||||
train_dataset=trainer_train_dataset,
|
||||
eval_dataset=trainer_eval_dataset,
|
||||
resume_from_checkpoint=args.resume_from_checkpoint
|
||||
)
|
||||
|
||||
# Save final model
|
||||
if is_main_process():
|
||||
logger.info("Saving final model...")
|
||||
final_path = os.path.join(args.output_dir, "final_model")
|
||||
os.makedirs(final_path, exist_ok=True)
|
||||
|
||||
# Save model
|
||||
model.save_pretrained(final_path)
|
||||
tokenizer.save_pretrained(final_path)
|
||||
|
||||
# Save training summary
|
||||
summary = {
|
||||
"training_args": vars(args),
|
||||
"model_config": config.to_dict(),
|
||||
"final_metrics": getattr(trainer, 'best_metric', None),
|
||||
"total_steps": getattr(trainer, 'global_step', 0),
|
||||
"training_completed": datetime.now().isoformat(),
|
||||
"device_used": str(device),
|
||||
"dataset_format": getattr(train_dataset, 'detected_format', 'unknown'),
|
||||
"enhanced_features": {
|
||||
"use_new_dataset_api": args.use_new_dataset_api,
|
||||
"use_hard_negatives": args.use_hard_negatives,
|
||||
"use_self_distill": args.use_self_distill,
|
||||
"use_score_weighted_sampling": args.use_score_weighted_sampling,
|
||||
"multiple_positives": args.multiple_positives,
|
||||
"legacy_support": args.legacy_support
|
||||
}
|
||||
}
|
||||
|
||||
with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
logger.info(f"✅ Enhanced training completed! Model saved to {final_path}")
|
||||
|
||||
# Show training summary
|
||||
logger.info("📊 Training Summary:")
|
||||
logger.info(f" Total examples: {len(train_dataset)}")
|
||||
if hasattr(train_dataset, 'detected_format'):
|
||||
logger.info(f" Format detected: {train_dataset.detected_format}") # type: ignore[attr-defined]
|
||||
logger.info(f" Epochs: {args.num_train_epochs}")
|
||||
logger.info(f" Learning rate: {args.learning_rate}")
|
||||
logger.info(f" Batch size: {args.per_device_train_batch_size}")
|
||||
logger.info(f" Temperature: {args.temperature}")
|
||||
logger.info(f" Enhanced features enabled: {args.use_new_dataset_api}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed with error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if hasattr(trainer, 'hard_negative_miner') and trainer.hard_negative_miner:
|
||||
trainer.hard_negative_miner.stop_async_updates()
|
||||
|
||||
if args.local_rank != -1:
|
||||
from utils.distributed import cleanup_distributed
|
||||
cleanup_distributed()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
398
scripts/train_reranker.py
Normal file
@@ -0,0 +1,398 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
BGE-Reranker Cross-Encoder Model Training Script
|
||||
Uses the refactored dataset pipeline with flat pairs schema
|
||||
"""
|
||||
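# Illustrative invocation (a sketch; the data path is a placeholder and all
# flags are defined in parse_args below):
#
#   python scripts/train_reranker.py \
#       --reranker_data ./data/datasets/reranker_data.jsonl \
#       --output_dir ./output/bge-reranker \
#       --loss_type cross_entropy \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 8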
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
|
||||
from torch.optim import AdamW  # transformers.optimization.AdamW is deprecated; use the PyTorch implementation
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from data.dataset import create_reranker_dataset, collate_reranker_batch, migrate_nested_to_flat
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from training.reranker_trainer import BGERerankerTrainer
|
||||
from utils.logging import setup_logging, get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments for reranker training"""
|
||||
parser = argparse.ArgumentParser(description="Train BGE-Reranker Cross-Encoder Model")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--model_name_or_path", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Path to pretrained BGE-Reranker model")
|
||||
parser.add_argument("--cache_dir", type=str, default="./cache/data",
|
||||
help="Directory to cache downloaded models")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument("--reranker_data", type=str, required=True,
|
||||
help="Path to reranker training data (reranker_data.jsonl)")
|
||||
parser.add_argument("--eval_data", type=str, default=None,
|
||||
help="Path to evaluation data")
|
||||
parser.add_argument("--legacy_data", type=str, default=None,
|
||||
help="Path to legacy nested data (will be migrated)")
|
||||
|
||||
# Training arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./output/bge-reranker",
|
||||
help="Directory to save the trained model")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=3,
|
||||
help="Number of training epochs")
|
||||
parser.add_argument("--per_device_train_batch_size", type=int, default=8,
|
||||
help="Batch size per GPU/CPU for training")
|
||||
parser.add_argument("--per_device_eval_batch_size", type=int, default=16,
|
||||
help="Batch size per GPU/CPU for evaluation")
|
||||
parser.add_argument("--learning_rate", type=float, default=6e-5,
|
||||
help="Learning rate")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01,
|
||||
help="Weight decay")
|
||||
parser.add_argument("--warmup_ratio", type=float, default=0.1,
|
||||
help="Warmup ratio")
|
||||
|
||||
# Model specific arguments
|
||||
parser.add_argument("--query_max_length", type=int, default=64,
|
||||
help="Maximum query length")
|
||||
parser.add_argument("--passage_max_length", type=int, default=448,
|
||||
help="Maximum passage length")
|
||||
parser.add_argument("--loss_type", type=str, default="cross_entropy",
|
||||
choices=["cross_entropy", "binary_cross_entropy", "listwise_ce"],
|
||||
help="Loss function type")
|
||||
|
||||
# Enhanced features
|
||||
parser.add_argument("--use_score_weighting", action="store_true",
|
||||
help="Use score weighting in loss calculation")
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.0,
|
||||
help="Label smoothing for cross entropy loss")
|
||||
|
||||
# System arguments
|
||||
parser.add_argument("--fp16", action="store_true",
|
||||
help="Use 16-bit floating point precision")
|
||||
parser.add_argument("--dataloader_num_workers", type=int, default=4,
|
||||
help="Number of workers for data loading")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
parser.add_argument("--log_level", type=str, default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
help="Logging level")
|
||||
|
||||
# Backward compatibility
|
||||
parser.add_argument("--legacy_support", action="store_true", default=True,
|
||||
help="Enable legacy nested format support")
|
||||
parser.add_argument("--migrate_legacy", action="store_true",
|
||||
help="Migrate legacy nested format to flat format")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def setup_model_and_tokenizer(args):
|
||||
"""Setup BGE-Reranker model and tokenizer"""
|
||||
logger.info(f"Loading BGE-Reranker model from {args.model_name_or_path}")
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
cache_dir=args.cache_dir,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Load model
|
||||
model = BGERerankerModel(
|
||||
model_name_or_path=args.model_name_or_path,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def migrate_legacy_data(args):
|
||||
"""Migrate legacy nested data to flat format if needed"""
|
||||
if args.legacy_data and args.migrate_legacy:
|
||||
logger.info("Migrating legacy nested data to flat format...")
|
||||
|
||||
# Create output file name
|
||||
legacy_base = os.path.splitext(args.legacy_data)[0]
|
||||
migrated_file = f"{legacy_base}_migrated_flat.jsonl"
|
||||
|
||||
# Migrate the data
|
||||
migrate_nested_to_flat(args.legacy_data, migrated_file)
|
||||
|
||||
logger.info(f"✅ Migrated legacy data to {migrated_file}")
|
||||
return migrated_file
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def setup_datasets(args, tokenizer):
|
||||
"""Setup reranker datasets"""
|
||||
logger.info("Setting up reranker datasets...")
|
||||
|
||||
# Migrate legacy data if needed
|
||||
migrated_file = migrate_legacy_data(args)
|
||||
|
||||
# Primary reranker data
|
||||
train_dataset = create_reranker_dataset(
|
||||
data_path=args.reranker_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=True,
|
||||
legacy_support=args.legacy_support
|
||||
)
|
||||
|
||||
logger.info(f"Loaded {len(train_dataset)} training examples from {args.reranker_data}")
|
||||
logger.info(f"Detected format: {train_dataset.format_type}")
|
||||
|
||||
# Add migrated legacy data if available
|
||||
if migrated_file:
|
||||
logger.info(f"Loading migrated legacy data from {migrated_file}")
|
||||
migrated_dataset = create_reranker_dataset(
|
||||
data_path=migrated_file,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=True,
|
||||
legacy_support=False # Migrated data is in flat format
|
||||
)
|
||||
logger.info(f"Loaded {len(migrated_dataset)} migrated examples")
|
||||
|
||||
# Combine datasets
|
||||
train_dataset.data.extend(migrated_dataset.data)
|
||||
logger.info(f"Total training examples: {len(train_dataset)}")
|
||||
|
||||
# Evaluation dataset
|
||||
eval_dataset = None
|
||||
if args.eval_data:
|
||||
eval_dataset = create_reranker_dataset(
|
||||
data_path=args.eval_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=False,
|
||||
legacy_support=args.legacy_support
|
||||
)
|
||||
logger.info(f"Loaded {len(eval_dataset)} evaluation examples")
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def setup_loss_function(args):
|
||||
"""Setup cross-encoder loss function"""
|
||||
if args.loss_type == "cross_entropy":
|
||||
if args.label_smoothing > 0:
|
||||
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
else:
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
elif args.loss_type == "binary_cross_entropy":
|
||||
loss_fn = torch.nn.BCEWithLogitsLoss()
|
||||
elif args.loss_type == "listwise_ce":
|
||||
# Custom listwise cross-entropy loss
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
else:
|
||||
raise ValueError(f"Unknown loss type: {args.loss_type}")
|
||||
|
||||
logger.info(f"Using loss function: {args.loss_type}")
|
||||
return loss_fn
|
||||
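# Note: the "listwise_ce" branch above currently reuses the standard
# CrossEntropyLoss. A genuinely listwise variant could score each query's
# candidate group jointly; a minimal sketch, assuming the per-pair logits can
# be reshaped to (num_queries, group_size) with the positive at index 0:
#
#   group_logits = logits.view(-1, group_size)
#   targets = torch.zeros(group_logits.size(0), dtype=torch.long,
#                         device=group_logits.device)
#   loss = F.cross_entropy(group_logits, targets)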
|
||||
|
||||
def train_epoch(model, dataloader, optimizer, scheduler, loss_fn, device, args):
|
||||
"""Train for one epoch"""
|
||||
model.train()
|
||||
total_loss = 0
|
||||
num_batches = 0
|
||||
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
# Move batch to device
|
||||
input_ids = batch['input_ids'].to(device)
|
||||
attention_mask = batch['attention_mask'].to(device)
|
||||
labels = batch['labels'].to(device)
|
||||
|
||||
# Forward pass
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
|
||||
|
||||
# Calculate loss
|
||||
if args.loss_type == "binary_cross_entropy":
|
||||
# For binary classification
|
||||
loss = loss_fn(logits.squeeze(-1), labels.float())  # squeeze only the logit dim so a batch of size 1 keeps its shape
|
||||
else:
|
||||
# For multi-class classification
|
||||
loss = loss_fn(logits, labels)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
# Log progress
|
||||
if batch_idx % 100 == 0:
|
||||
logger.info(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
|
||||
|
||||
return total_loss / num_batches
|
||||
|
||||
|
||||
def evaluate(model, dataloader, loss_fn, device, args):
|
||||
"""Evaluate the model"""
|
||||
model.eval()
|
||||
total_loss = 0
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
input_ids = batch['input_ids'].to(device)
|
||||
attention_mask = batch['attention_mask'].to(device)
|
||||
labels = batch['labels'].to(device)
|
||||
|
||||
# Forward pass
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
|
||||
|
||||
# Calculate loss
|
||||
if args.loss_type == "binary_cross_entropy":
|
||||
loss = loss_fn(logits.squeeze(-1), labels.float())
|
||||
# Calculate accuracy for binary classification
|
||||
predictions = (torch.sigmoid(logits.squeeze(-1)) > 0.5).long()
|
||||
else:
|
||||
loss = loss_fn(logits, labels)
|
||||
# Calculate accuracy for multi-class classification
|
||||
predictions = torch.argmax(logits, dim=-1)
|
||||
|
||||
total_loss += loss.item()
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
|
||||
accuracy = correct / total if total > 0 else 0
|
||||
return total_loss / len(dataloader), accuracy
|
||||
|
||||
|
||||
def main():
|
||||
"""Main training function"""
|
||||
args = parse_args()
|
||||
|
||||
# Setup logging
|
||||
setup_logging(
|
||||
output_dir=args.output_dir,
|
||||
log_level=args.log_level,
|
||||
log_to_file=True,
|
||||
log_to_console=True
|
||||
)
|
||||
|
||||
logger.info("🚀 Starting BGE-Reranker Training")
|
||||
logger.info(f"Arguments: {vars(args)}")
|
||||
|
||||
# Set random seed
|
||||
torch.manual_seed(args.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
# Setup device
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.fp16 else "cpu")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Setup model and tokenizer
|
||||
model, tokenizer = setup_model_and_tokenizer(args)
|
||||
model.to(device)
|
||||
|
||||
# Setup datasets
|
||||
train_dataset, eval_dataset = setup_datasets(args, tokenizer)
|
||||
|
||||
# Setup data loaders
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
collate_fn=collate_reranker_batch,
|
||||
pin_memory=torch.cuda.is_available()
|
||||
)
|
||||
|
||||
eval_dataloader = None
|
||||
if eval_dataset:
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=args.per_device_eval_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
collate_fn=collate_reranker_batch,
|
||||
pin_memory=torch.cuda.is_available()
|
||||
)
|
||||
|
||||
# Setup optimizer and scheduler
|
||||
optimizer = AdamW(
|
||||
model.parameters(),
|
||||
lr=args.learning_rate,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
total_steps = len(train_dataloader) * args.num_train_epochs
|
||||
warmup_steps = int(total_steps * args.warmup_ratio)
|
||||
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=warmup_steps,
|
||||
num_training_steps=total_steps
|
||||
)
|
||||
|
||||
# Setup loss function
|
||||
loss_fn = setup_loss_function(args)
|
||||
|
||||
# Training loop
|
||||
logger.info("🎯 Starting training...")
|
||||
|
||||
best_eval_loss = float('inf')
|
||||
|
||||
for epoch in range(args.num_train_epochs):
|
||||
logger.info(f"Epoch {epoch + 1}/{args.num_train_epochs}")
|
||||
|
||||
# Train
|
||||
train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, loss_fn, device, args)
|
||||
logger.info(f"Training loss: {train_loss:.4f}")
|
||||
|
||||
# Evaluate
|
||||
if eval_dataloader:
|
||||
eval_loss, eval_accuracy = evaluate(model, eval_dataloader, loss_fn, device, args)
|
||||
logger.info(f"Evaluation loss: {eval_loss:.4f}, Accuracy: {eval_accuracy:.4f}")
|
||||
|
||||
# Save best model
|
||||
if eval_loss < best_eval_loss:
|
||||
best_eval_loss = eval_loss
|
||||
best_model_path = os.path.join(args.output_dir, "best_model")
|
||||
os.makedirs(best_model_path, exist_ok=True)
|
||||
model.save_pretrained(best_model_path)
|
||||
tokenizer.save_pretrained(best_model_path)
|
||||
logger.info(f"Saved best model to {best_model_path}")
|
||||
|
||||
# Save final model
|
||||
final_model_path = os.path.join(args.output_dir, "final_model")
|
||||
os.makedirs(final_model_path, exist_ok=True)
|
||||
model.save_pretrained(final_model_path)
|
||||
tokenizer.save_pretrained(final_model_path)
|
||||
|
||||
logger.info(f"✅ Training completed! Final model saved to {final_model_path}")
|
||||
|
||||
# Show training summary
|
||||
logger.info("📊 Training Summary:")
|
||||
logger.info(f" Total examples: {len(train_dataset)}")
|
||||
logger.info(f" Format detected: {train_dataset.format_type}")
|
||||
logger.info(f" Epochs: {args.num_train_epochs}")
|
||||
logger.info(f" Learning rate: {args.learning_rate}")
|
||||
logger.info(f" Batch size: {args.per_device_train_batch_size}")
|
||||
logger.info(f" Loss function: {args.loss_type}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
122
scripts/validate_m3.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Quick validation script for the fine-tuned BGE-M3 model
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Add project root to Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
|
||||
def test_model_embeddings(model_path, test_queries=None):
|
||||
"""Test if the model can generate embeddings"""
|
||||
|
||||
if test_queries is None:
|
||||
test_queries = [
|
||||
"什么是深度学习?", # Original training query
|
||||
"什么是机器学习?", # Related query
|
||||
"今天天气怎么样?" # Unrelated query
|
||||
]
|
||||
|
||||
print(f"🧪 Testing model from: {model_path}")
|
||||
|
||||
try:
|
||||
# Load tokenizer and model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGEM3Model.from_pretrained(model_path)
|
||||
model.eval()
|
||||
|
||||
print("✅ Model and tokenizer loaded successfully")
|
||||
|
||||
# Test embeddings generation
|
||||
embeddings = []
|
||||
|
||||
for i, query in enumerate(test_queries):
|
||||
print(f"\n📝 Testing query {i+1}: '{query}'")
|
||||
|
||||
# Tokenize
|
||||
inputs = tokenizer(
|
||||
query,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
)
|
||||
|
||||
# Get embeddings
|
||||
with torch.no_grad():
|
||||
outputs = model.forward(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask'],
|
||||
return_dense=True,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
if 'dense' in outputs:
|
||||
embedding = outputs['dense'].cpu().numpy() # type: ignore[index]
|
||||
embeddings.append(embedding)
|
||||
|
||||
print(f" ✅ Embedding shape: {embedding.shape}")
|
||||
print(f" 📊 Embedding norm: {np.linalg.norm(embedding):.4f}")
|
||||
print(f" 📈 Mean activation: {embedding.mean():.6f}")
|
||||
print(f" 📉 Std activation: {embedding.std():.6f}")
|
||||
else:
|
||||
print(f" ❌ No dense embedding returned")
|
||||
return False
|
||||
|
||||
# Compare embeddings
|
||||
if len(embeddings) >= 2:
|
||||
print(f"\n🔗 Similarity Analysis:")
|
||||
for i in range(len(embeddings)):
|
||||
for j in range(i+1, len(embeddings)):
|
||||
sim = np.dot(embeddings[i].flatten(), embeddings[j].flatten()) / (
|
||||
np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
|
||||
)
|
||||
print(f" Query {i+1} vs Query {j+1}: {sim:.4f}")
|
||||
|
||||
print(f"\n🎉 Model validation successful!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Model validation failed: {e}")
|
||||
return False
|
||||
|
||||
def compare_models(original_path, finetuned_path):
|
||||
"""Compare original vs fine-tuned model"""
|
||||
print("🔄 Comparing original vs fine-tuned model...")
|
||||
|
||||
test_query = "什么是深度学习?"
|
||||
|
||||
# Test original model
|
||||
print("\n📍 Original Model:")
|
||||
orig_success = test_model_embeddings(original_path, [test_query])
|
||||
|
||||
# Test fine-tuned model
|
||||
print("\n📍 Fine-tuned Model:")
|
||||
ft_success = test_model_embeddings(finetuned_path, [test_query])
|
||||
|
||||
return orig_success and ft_success
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 BGE-M3 Model Validation")
|
||||
print("=" * 50)
|
||||
|
||||
# Test fine-tuned model
|
||||
finetuned_model_path = "./test_m3_output/final_model"
|
||||
success = test_model_embeddings(finetuned_model_path)
|
||||
|
||||
# Optional: Compare with original model
|
||||
original_model_path = "./models/bge-m3"
|
||||
compare_models(original_model_path, finetuned_model_path)
|
||||
|
||||
if success:
|
||||
print("\n✅ Fine-tuning validation: PASSED")
|
||||
print(" The model can generate embeddings successfully!")
|
||||
else:
|
||||
print("\n❌ Fine-tuning validation: FAILED")
|
||||
print(" The model has issues generating embeddings!")
|
||||
150
scripts/validate_reranker.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Quick validation script for the fine-tuned BGE Reranker model
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Add project root to Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
|
||||
def test_reranker_scoring(model_path, test_queries=None):
|
||||
"""Test if the reranker can score query-passage pairs"""
|
||||
|
||||
if test_queries is None:
|
||||
# Test query with relevant and irrelevant passages
|
||||
test_queries = [
|
||||
{
|
||||
"query": "什么是深度学习?",
|
||||
"passages": [
|
||||
"深度学习是机器学习的一个子领域", # Relevant (should score high)
|
||||
"机器学习包含多种算法", # Somewhat relevant
|
||||
"今天天气很好", # Irrelevant (should score low)
|
||||
"深度学习使用神经网络进行学习" # Very relevant
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
print(f"🧪 Testing reranker from: {model_path}")
|
||||
|
||||
try:
|
||||
# Load tokenizer and model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGERerankerModel.from_pretrained(model_path)
|
||||
model.eval()
|
||||
|
||||
print("✅ Reranker model and tokenizer loaded successfully")
|
||||
|
||||
for query_idx, test_case in enumerate(test_queries):
|
||||
query = test_case["query"]
|
||||
passages = test_case["passages"]
|
||||
|
||||
print(f"\n📝 Test Case {query_idx + 1}")
|
||||
print(f"Query: '{query}'")
|
||||
print(f"Passages to rank:")
|
||||
|
||||
scores = []
|
||||
|
||||
for i, passage in enumerate(passages):
|
||||
print(f" {i+1}. '{passage}'")
|
||||
|
||||
# Create query-passage pair
|
||||
sep_token = tokenizer.sep_token or '[SEP]'
|
||||
pair_text = f"{query} {sep_token} {passage}"
|
||||
|
||||
# Tokenize
|
||||
inputs = tokenizer(
|
||||
pair_text,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
)
|
||||
|
||||
# Get relevance score
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Extract score (logits)
|
||||
if hasattr(outputs, 'logits'):
|
||||
score = outputs.logits.item()
|
||||
elif isinstance(outputs, torch.Tensor):
|
||||
score = outputs.item()
|
||||
else:
|
||||
score = outputs[0].item()
|
||||
|
||||
scores.append((i, passage, score))
|
||||
|
||||
# Sort by relevance score (highest first)
|
||||
scores.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
print(f"\n🏆 Ranking Results:")
|
||||
for rank, (orig_idx, passage, score) in enumerate(scores, 1):
|
||||
print(f" Rank {rank}: Score {score:.4f} - '{passage[:50]}{'...' if len(passage) > 50 else ''}'")
|
||||
|
||||
# Check if most relevant passages ranked highest
|
||||
expected_relevant = ["深度学习是机器学习的一个子领域", "深度学习使用神经网络进行学习"]
|
||||
top_passages = [item[1] for item in scores[:2]]
|
||||
|
||||
relevant_in_top = any(rel in top_passages for rel in expected_relevant)
|
||||
irrelevant_in_top = "今天天气很好" in top_passages
|
||||
|
||||
if relevant_in_top and not irrelevant_in_top:
|
||||
print(f" ✅ Good ranking - relevant passages ranked higher")
|
||||
elif relevant_in_top:
|
||||
print(f" ⚠️ Partial ranking - relevant high but irrelevant also high")
|
||||
else:
|
||||
print(f" ❌ Poor ranking - irrelevant passages ranked too high")
|
||||
|
||||
print(f"\n🎉 Reranker validation successful!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Reranker validation failed: {e}")
|
||||
return False
|
||||
|
||||
def compare_rerankers(original_path, finetuned_path):
|
||||
"""Compare original vs fine-tuned reranker"""
|
||||
print("🔄 Comparing original vs fine-tuned reranker...")
|
||||
|
||||
test_case = {
|
||||
"query": "什么是深度学习?",
|
||||
"passages": [
|
||||
"深度学习是机器学习的一个子领域",
|
||||
"今天天气很好"
|
||||
]
|
||||
}
|
||||
|
||||
# Test original model
|
||||
print("\n📍 Original Reranker:")
|
||||
orig_success = test_reranker_scoring(original_path, [test_case])
|
||||
|
||||
# Test fine-tuned model
|
||||
print("\n📍 Fine-tuned Reranker:")
|
||||
ft_success = test_reranker_scoring(finetuned_path, [test_case])
|
||||
|
||||
return orig_success and ft_success
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 BGE Reranker Model Validation")
|
||||
print("=" * 50)
|
||||
|
||||
# Test fine-tuned model
|
||||
finetuned_model_path = "./test_reranker_output/final_model"
|
||||
success = test_reranker_scoring(finetuned_model_path)
|
||||
|
||||
# Optional: Compare with original model
|
||||
original_model_path = "./models/bge-reranker-base"
|
||||
compare_rerankers(original_model_path, finetuned_model_path)
|
||||
|
||||
if success:
|
||||
print("\n✅ Reranker validation: PASSED")
|
||||
print(" The reranker can score query-passage pairs successfully!")
|
||||
else:
|
||||
print("\n❌ Reranker validation: FAILED")
|
||||
print(" The reranker has issues with scoring!")
|
||||
641
scripts/validation_utils.py
Normal file
@@ -0,0 +1,641 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Validation Utilities for BGE Fine-tuning
|
||||
|
||||
This module provides utilities for model validation including:
|
||||
- Automatic model discovery
|
||||
- Test data preparation and validation
|
||||
- Results analysis and reporting
|
||||
- Benchmark dataset creation
|
||||
|
||||
Usage:
|
||||
python scripts/validation_utils.py --discover-models
|
||||
python scripts/validation_utils.py --prepare-test-data
|
||||
python scripts/validation_utils.py --analyze-results ./validation_results
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import glob
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelDiscovery:
|
||||
"""Discover trained models in the workspace"""
|
||||
|
||||
def __init__(self, workspace_root: str = "."):
|
||||
self.workspace_root = Path(workspace_root)
|
||||
|
||||
def find_trained_models(self) -> Dict[str, Dict[str, str]]:
|
||||
"""Find all trained BGE models in the workspace"""
|
||||
models = {
|
||||
'retrievers': {},
|
||||
'rerankers': {},
|
||||
'joint': {}
|
||||
}
|
||||
|
||||
# Common output directories to search
|
||||
search_patterns = [
|
||||
"output/*/final_model",
|
||||
"output/*/best_model",
|
||||
"output/*/*/final_model",
|
||||
"output/*/*/best_model",
|
||||
"test_*_output/final_model",
|
||||
"training_output/*/final_model"
|
||||
]
|
||||
|
||||
for pattern in search_patterns:
|
||||
for model_path in glob.glob(str(self.workspace_root / pattern)):
|
||||
model_path = Path(model_path)
|
||||
|
||||
if self._is_valid_model_directory(model_path):
|
||||
model_info = self._analyze_model_directory(model_path)
|
||||
|
||||
if model_info['type'] == 'retriever':
|
||||
models['retrievers'][model_info['name']] = model_info
|
||||
elif model_info['type'] == 'reranker':
|
||||
models['rerankers'][model_info['name']] = model_info
|
||||
elif model_info['type'] == 'joint':
|
||||
models['joint'][model_info['name']] = model_info
|
||||
|
||||
return models
|
||||
|
||||
def _is_valid_model_directory(self, path: Path) -> bool:
|
||||
"""Check if directory contains a valid trained model"""
|
||||
required_files = ['config.json', 'pytorch_model.bin']
|
||||
alternative_files = ['model.safetensors', 'tokenizer.json']
|
||||
|
||||
has_required = any((path / f).exists() for f in required_files)
|
||||
has_alternative = any((path / f).exists() for f in alternative_files)
|
||||
|
||||
return has_required or has_alternative
|
||||
|
||||

    def _analyze_model_directory(self, path: Path) -> Dict[str, str]:
        """Analyze model directory to determine type and metadata"""
        info = {
            'path': str(path),
            'name': self._generate_model_name(path),
            'type': 'unknown',
            'created': 'unknown',
            'training_info': {}
        }

        # Try to determine model type from path
        path_str = str(path).lower()
        if 'm3' in path_str or 'retriever' in path_str:
            info['type'] = 'retriever'
        elif 'reranker' in path_str:
            info['type'] = 'reranker'
        elif 'joint' in path_str:
            info['type'] = 'joint'

        # Get creation time
        if path.exists():
            info['created'] = datetime.fromtimestamp(path.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S')

        # Try to read training info
        training_summary_path = path.parent / 'training_summary.json'
        if training_summary_path.exists():
            try:
                with open(training_summary_path, 'r') as f:
                    info['training_info'] = json.load(f)
            except (OSError, json.JSONDecodeError) as e:
                logger.debug(f"Could not read {training_summary_path}: {e}")

        # Try to read config
        config_path = path / 'config.json'
        if config_path.exists():
            try:
                with open(config_path, 'r') as f:
                    config = json.load(f)
                if 'model_type' in config:
                    info['type'] = config['model_type']
            except (OSError, json.JSONDecodeError) as e:
                logger.debug(f"Could not read {config_path}: {e}")

        return info
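
    # Assumed layout (not enforced): a training_summary.json next to the model directory
    # may carry fields such as 'total_steps' and 'final_metrics'; those are the keys
    # print_discovered_models() reports, and any other schema is passed through as-is.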

    def _generate_model_name(self, path: Path) -> str:
        """Generate a readable name for the model"""
        parts = path.parts

        # Try to find meaningful parts
        meaningful_parts = []
        for part in parts[-3:]:  # Look at last 3 parts
            if part not in ['final_model', 'best_model', 'output']:
                meaningful_parts.append(part)

        if meaningful_parts:
            return '_'.join(meaningful_parts)
        else:
            return f"model_{path.parent.name}"
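
    # Illustrative examples of the naming rule above (hypothetical paths):
    #   output/bge_m3_finetune/final_model  -> 'bge_m3_finetune'
    #   test_reranker_output/final_model    -> 'test_reranker_output'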

    def print_discovered_models(self):
        """Print discovered models in a formatted way"""
        models = self.find_trained_models()

        print("\n🔍 DISCOVERED TRAINED MODELS")
        print("=" * 60)

        for model_type, model_dict in models.items():
            if model_dict:
                print(f"\n📊 {model_type.upper()}:")
                for name, info in model_dict.items():
                    print(f" • {name}")
                    print(f" Path: {info['path']}")
                    print(f" Created: {info['created']}")
                    if info['training_info']:
                        training_info = info['training_info']
                        if 'total_steps' in training_info:
                            print(f" Training Steps: {training_info['total_steps']}")
                        if 'final_metrics' in training_info and training_info['final_metrics']:
                            print(f" Final Metrics: {training_info['final_metrics']}")

        if not any(models.values()):
            print("❌ No trained models found!")
            print("💡 Make sure you have trained some models first.")

        return models


class TestDataManager:
    """Manage test datasets for validation"""

    def __init__(self, data_dir: str = "data/datasets"):
        self.data_dir = Path(data_dir)

    def find_test_datasets(self) -> Dict[str, List[str]]:
        """Find available test datasets"""
        datasets = {
            'retriever_datasets': [],
            'reranker_datasets': [],
            'general_datasets': []
        }

        # Search for JSONL files in data directory
        for jsonl_file in self.data_dir.rglob("*.jsonl"):
            file_path = str(jsonl_file)

            # Categorize based on filename/path
            if 'reranker' in file_path.lower():
                datasets['reranker_datasets'].append(file_path)
            elif 'embedding' in file_path.lower() or 'm3' in file_path.lower():
                datasets['retriever_datasets'].append(file_path)
            else:
                datasets['general_datasets'].append(file_path)

        # Also search for JSON files
        for json_file in self.data_dir.rglob("*.json"):
            file_path = str(json_file)
            if 'test' in file_path.lower() or 'dev' in file_path.lower():
                if 'reranker' in file_path.lower() or 'AFQMC' in file_path:
                    datasets['reranker_datasets'].append(file_path)
                else:
                    datasets['general_datasets'].append(file_path)

        return datasets
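
    # Illustrative only: with a hypothetical layout such as
    #   data/datasets/reranker/train.jsonl     -> 'reranker_datasets'
    #   data/datasets/embedding/m3_pairs.jsonl -> 'retriever_datasets'
    #   data/datasets/misc/qa.jsonl            -> 'general_datasets'
    # categorization relies purely on substring matches in the file path.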

    def validate_dataset(self, dataset_path: str) -> Dict[str, Any]:
        """Validate a dataset file and analyze its format"""
        validation_result = {
            'path': dataset_path,
            'exists': False,
            'format': 'unknown',
            'sample_count': 0,
            'has_queries': False,
            'has_passages': False,
            'has_labels': False,
            'issues': [],
            'sample_data': {}
        }

        if not os.path.exists(dataset_path):
            validation_result['issues'].append(f"File does not exist: {dataset_path}")
            return validation_result

        validation_result['exists'] = True

        try:
            # Read first few lines to analyze format
            samples = []
            with open(dataset_path, 'r', encoding='utf-8') as f:
                for i, line in enumerate(f):
                    if i >= 5:  # Only analyze first 5 lines
                        break
                    try:
                        sample = json.loads(line.strip())
                        samples.append(sample)
                    except json.JSONDecodeError:
                        validation_result['issues'].append(f"Invalid JSON at line {i+1}")

            if samples:
                validation_result['sample_count'] = len(samples)
                validation_result['sample_data'] = samples[0] if samples else {}

                # Analyze format
                first_sample = samples[0]

                # Check for common fields
                if 'query' in first_sample:
                    validation_result['has_queries'] = True

                if any(field in first_sample for field in ['passage', 'pos', 'neg', 'candidates']):
                    validation_result['has_passages'] = True

                if any(field in first_sample for field in ['label', 'labels', 'score']):
                    validation_result['has_labels'] = True

                # Determine format
                if 'pos' in first_sample and 'neg' in first_sample:
                    validation_result['format'] = 'triplets'
                elif 'candidates' in first_sample:
                    validation_result['format'] = 'candidates'
                elif 'passage' in first_sample and 'label' in first_sample:
                    validation_result['format'] = 'pairs'
                elif 'sentence1' in first_sample and 'sentence2' in first_sample:
                    validation_result['format'] = 'sentence_pairs'
                else:
                    validation_result['format'] = 'custom'

        except Exception as e:
            validation_result['issues'].append(f"Error reading file: {e}")

        return validation_result
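
    # Illustrative example: a JSONL line like
    #   {"query": "...", "pos": ["..."], "neg": ["..."]}
    # is reported as format='triplets' with has_queries and has_passages set, while a
    #   {"query": "...", "passage": "...", "label": 0}
    # line is reported as format='pairs' with has_labels set.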

    def print_dataset_analysis(self):
        """Print analysis of available datasets"""
        datasets = self.find_test_datasets()

        print("\n📁 AVAILABLE TEST DATASETS")
        print("=" * 60)

        for category, dataset_list in datasets.items():
            if dataset_list:
                print(f"\n📊 {category.upper().replace('_', ' ')}:")
                for dataset_path in dataset_list:
                    validation = self.validate_dataset(dataset_path)
                    status = "✅" if validation['exists'] and not validation['issues'] else "⚠️"
                    print(f" {status} {Path(dataset_path).name}")
                    print(f" Path: {dataset_path}")
                    print(f" Format: {validation['format']}")
                    if validation['issues']:
                        for issue in validation['issues']:
                            print(f" ⚠️ {issue}")

        if not any(datasets.values()):
            print("❌ No test datasets found!")
            print("💡 Place your test data in the data/datasets/ directory")

        return datasets

    def create_sample_test_data(self, output_dir: str = "data/test_samples"):
        """Create sample test datasets for validation"""
        os.makedirs(output_dir, exist_ok=True)

        # Create sample retriever test data
        retriever_test = [
            {
                "query": "什么是人工智能?",
                "pos": [
                    "人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
                    "AI技术包括机器学习、深度学习、自然语言处理等多个领域"
                ],
                "neg": [
                    "今天的天气非常好,阳光明媚",
                    "我喜欢听音乐和看电影"
                ]
            },
            {
                "query": "如何学习编程?",
                "pos": [
                    "学习编程需要选择合适的语言,如Python、Java等,然后通过实践项目提升技能",
                    "编程学习的关键是多写代码、多调试、多思考算法逻辑"
                ],
                "neg": [
                    "烹饪是一门艺术,需要掌握火候和调味",
                    "运动对健康很重要,每天应该锻炼30分钟"
                ]
            }
        ]

        # Create sample reranker test data
        reranker_test = [
            {
                "query": "什么是人工智能?",
                "passage": "人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
                "label": 1
            },
            {
                "query": "什么是人工智能?",
                "passage": "今天的天气非常好,阳光明媚",
                "label": 0
            },
            {
                "query": "如何学习编程?",
                "passage": "学习编程需要选择合适的语言,如Python、Java等,然后通过实践项目提升技能",
                "label": 1
            },
            {
                "query": "如何学习编程?",
                "passage": "烹饪是一门艺术,需要掌握火候和调味",
                "label": 0
            }
        ]
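
        # Each item above is written out as one JSON object per line (JSONL), e.g.
        #   {"query": "什么是人工智能?", "passage": "...", "label": 1}
        # ensure_ascii=False below keeps the Chinese text readable in the generated files.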

        # Save retriever test data
        retriever_path = os.path.join(output_dir, "retriever_test.jsonl")
        with open(retriever_path, 'w', encoding='utf-8') as f:
            for item in retriever_test:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

        # Save reranker test data
        reranker_path = os.path.join(output_dir, "reranker_test.jsonl")
        with open(reranker_path, 'w', encoding='utf-8') as f:
            for item in reranker_test:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

        print(f"✅ Sample test datasets created in: {output_dir}")
        print(f" 📁 Retriever test: {retriever_path}")
        print(f" 📁 Reranker test: {reranker_path}")

        return {
            'retriever_test': retriever_path,
            'reranker_test': reranker_path
        }


class ValidationReportAnalyzer:
    """Analyze validation results and generate insights"""

    def __init__(self, results_dir: str):
        self.results_dir = Path(results_dir)

    def analyze_results(self) -> Dict[str, Any]:
        """Analyze validation results"""
        analysis = {
            'overall_status': 'unknown',
            'model_performance': {},
            'recommendations': [],
            'summary_stats': {}
        }

        # Look for results files
        json_results = self.results_dir / "validation_results.json"

        if not json_results.exists():
            analysis['recommendations'].append("No validation results found. Run validation first.")
            return analysis

        try:
            with open(json_results, 'r') as f:
                results = json.load(f)

            # Analyze each model type
            improvements = 0
            degradations = 0

            for model_type in ['retriever', 'reranker', 'pipeline']:
                if model_type in results and 'comparisons' in results[model_type]:
                    model_analysis = self._analyze_model_results(results[model_type])
                    analysis['model_performance'][model_type] = model_analysis

                    improvements += model_analysis['total_improvements']
                    degradations += model_analysis['total_degradations']

            # Overall status
            if improvements > degradations * 1.5:
                analysis['overall_status'] = 'excellent'
            elif improvements > degradations:
                analysis['overall_status'] = 'good'
            elif improvements == degradations:
                analysis['overall_status'] = 'mixed'
            else:
                analysis['overall_status'] = 'poor'

            # Generate recommendations
            analysis['recommendations'] = self._generate_recommendations(analysis)

            # Summary statistics
            analysis['summary_stats'] = {
                'total_improvements': improvements,
                'total_degradations': degradations,
                'net_improvement': improvements - degradations,
                'improvement_ratio': improvements / (improvements + degradations) if (improvements + degradations) > 0 else 0
            }

        except Exception as e:
            analysis['recommendations'].append(f"Error analyzing results: {e}")

        return analysis
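
    # Assumed input schema: validation_results.json is expected to look roughly like
    #   {"retriever": {"comparisons": {"<dataset>": {"improvements": {...},
    #                                                "degradations": {...}}}},
    #    "reranker": {...}, "pipeline": {...}}
    # which is the shape _analyze_model_results() below consumes.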

    def _analyze_model_results(self, model_results: Dict[str, Any]) -> Dict[str, Any]:
        """Analyze results for a specific model type"""
        analysis = {
            'total_improvements': 0,
            'total_degradations': 0,
            'best_dataset': None,
            'worst_dataset': None,
            'avg_improvement': 0,
            'issues': []
        }

        dataset_scores = {}

        for dataset_name, comparison in model_results.get('comparisons', {}).items():
            improvements = len(comparison.get('improvements', {}))
            degradations = len(comparison.get('degradations', {}))

            analysis['total_improvements'] += improvements
            analysis['total_degradations'] += degradations

            # Score dataset (improvements - degradations)
            dataset_score = improvements - degradations
            dataset_scores[dataset_name] = dataset_score

        if dataset_scores:
            analysis['best_dataset'] = max(dataset_scores.items(), key=lambda x: x[1])
            analysis['worst_dataset'] = min(dataset_scores.items(), key=lambda x: x[1])
            analysis['avg_improvement'] = sum(dataset_scores.values()) / len(dataset_scores)

        # Identify issues
        if analysis['total_degradations'] > analysis['total_improvements']:
            analysis['issues'].append("More degradations than improvements")

        if analysis['avg_improvement'] < 0:
            analysis['issues'].append("Negative average improvement")

        return analysis

    def _generate_recommendations(self, analysis: Dict[str, Any]) -> List[str]:
        """Generate recommendations based on analysis"""
        recommendations = []

        overall_status = analysis['overall_status']

        if overall_status == 'excellent':
            recommendations.append("🎉 Great job! Your fine-tuned models show significant improvements.")
            recommendations.append("Consider deploying these models to production.")

        elif overall_status == 'good':
            recommendations.append("✅ Good progress! Models show overall improvement.")
            recommendations.append("Consider further fine-tuning on specific datasets where degradation occurred.")

        elif overall_status == 'mixed':
            recommendations.append("⚠️ Mixed results. Some improvements, some degradations.")
            recommendations.append("Analyze which datasets/metrics are degrading and adjust training data.")
            recommendations.append("Consider different hyperparameters or longer training.")

        elif overall_status == 'poor':
            recommendations.append("❌ Models show more degradations than improvements.")
            recommendations.append("Review training data quality and format.")
            recommendations.append("Check hyperparameters (learning rate might be too high).")
            recommendations.append("Consider starting with different base models.")

        # Specific model recommendations
        for model_type, model_perf in analysis.get('model_performance', {}).items():
            if model_perf['issues']:
                recommendations.append(f"🔧 {model_type.title()} issues: {', '.join(model_perf['issues'])}")

            if model_perf['best_dataset']:
                best_dataset, best_score = model_perf['best_dataset']
                recommendations.append(f"📊 {model_type.title()} performs best on: {best_dataset} (score: {best_score})")

            if model_perf['worst_dataset']:
                worst_dataset, worst_score = model_perf['worst_dataset']
                recommendations.append(f"📉 {model_type.title()} needs improvement on: {worst_dataset} (score: {worst_score})")

        return recommendations

    def print_analysis(self):
        """Print comprehensive analysis of validation results"""
        analysis = self.analyze_results()

        print("\n📊 VALIDATION RESULTS ANALYSIS")
        print("=" * 60)

        # Overall status
        status_emoji = {
            'excellent': '🌟',
            'good': '✅',
            'mixed': '⚠️',
            'poor': '❌',
            'unknown': '❓'
        }

        print(f"{status_emoji.get(analysis['overall_status'], '❓')} Overall Status: {analysis['overall_status'].upper()}")

        # Summary stats
        if 'summary_stats' in analysis and analysis['summary_stats']:
            stats = analysis['summary_stats']
            print("\n📈 Summary Statistics:")
            print(f" Total Improvements: {stats['total_improvements']}")
            print(f" Total Degradations: {stats['total_degradations']}")
            print(f" Net Improvement: {stats['net_improvement']:+d}")
            print(f" Improvement Ratio: {stats['improvement_ratio']:.2%}")

        # Model performance
        if analysis['model_performance']:
            print("\n📊 Model Performance:")
            for model_type, perf in analysis['model_performance'].items():
                print(f" {model_type.title()}:")
                print(f" Improvements: {perf['total_improvements']}")
                print(f" Degradations: {perf['total_degradations']}")
                print(f" Average Score: {perf['avg_improvement']:.2f}")

                if perf['best_dataset']:
                    print(f" Best Dataset: {perf['best_dataset'][0]} (score: {perf['best_dataset'][1]})")

                if perf['issues']:
                    print(f" Issues: {', '.join(perf['issues'])}")

        # Recommendations
        if analysis['recommendations']:
            print("\n💡 Recommendations:")
            for i, rec in enumerate(analysis['recommendations'], 1):
                print(f" {i}. {rec}")

        return analysis


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Validation utilities for BGE models",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("--discover-models", action="store_true",
                        help="Discover trained models in workspace")
    parser.add_argument("--analyze-datasets", action="store_true",
                        help="Analyze available test datasets")
    parser.add_argument("--create-sample-data", action="store_true",
                        help="Create sample test datasets")
    parser.add_argument("--analyze-results", type=str, default=None,
                        help="Analyze validation results in specified directory")
    parser.add_argument("--workspace-root", type=str, default=".",
                        help="Root directory of workspace")
    parser.add_argument("--data-dir", type=str, default="data/datasets",
                        help="Directory containing test datasets")
    parser.add_argument("--output-dir", type=str, default="data/test_samples",
                        help="Output directory for sample data")

    return parser.parse_args()


def main():
    """Main utility function"""
    args = parse_args()

    print("🔧 BGE Validation Utilities")
    print("=" * 50)

    if args.discover_models:
        print("\n🔍 Discovering trained models...")
        discovery = ModelDiscovery(args.workspace_root)
        models = discovery.print_discovered_models()

        if any(models.values()):
            print("\n💡 Use these model paths in your validation commands:")
            for model_type, model_dict in models.items():
                for name, info in model_dict.items():
                    # Strip the plural 's' so 'retrievers' -> '--retriever_model' etc.
                    print(f" --{model_type.rstrip('s')}_model {info['path']}")

    if args.analyze_datasets:
        print("\n📁 Analyzing test datasets...")
        data_manager = TestDataManager(args.data_dir)
        datasets = data_manager.print_dataset_analysis()

    if args.create_sample_data:
        print("\n📝 Creating sample test data...")
        data_manager = TestDataManager()
        sample_paths = data_manager.create_sample_test_data(args.output_dir)

        print("\n💡 Use these sample datasets for testing:")
        print(f" --test_datasets {sample_paths['retriever_test']} {sample_paths['reranker_test']}")

    if args.analyze_results:
        print(f"\n📊 Analyzing results from: {args.analyze_results}")
        analyzer = ValidationReportAnalyzer(args.analyze_results)
        analysis = analyzer.print_analysis()

    if not any([args.discover_models, args.analyze_datasets, args.create_sample_data, args.analyze_results]):
        print("❓ No action specified. Use --help to see available options.")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())