bge_finetune/scripts/quick_train_test.py
2025-07-23 14:54:46 +08:00

207 lines
7.9 KiB
Python

#!/usr/bin/env python3
"""
Quick Training Test Script for BGE Models
This script performs a quick training test on small data subsets to verify
that fine-tuning is working before committing to full training.
Usage:
python scripts/quick_train_test.py \
--data_dir "data/datasets/三国演义/splits" \
--output_dir "./output/quick_test" \
--samples_per_model 1000
"""
import os
import sys
import json
import argparse
import subprocess
from pathlib import Path
import logging
# Add project root to path so sibling project packages resolve when this
# script is executed directly from the scripts/ directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Timestamped INFO-level logging for every message this script emits.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def create_subset(input_file: str, output_file: str, num_samples: int) -> None:
    """Copy at most ``num_samples`` lines from ``input_file`` to ``output_file``.

    Used to carve a small JSONL subset out of a full dataset split so that a
    quick smoke-training run can be done before committing to full training.

    Args:
        input_file: Path to the source JSONL file (read as UTF-8).
        output_file: Path of the subset file to create; parent directories
            are created as needed.
        num_samples: Maximum number of lines to copy.
    """
    parent = os.path.dirname(output_file)
    # dirname() returns "" for a bare filename; makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    written = 0
    with open(input_file, 'r', encoding='utf-8') as f_in, \
            open(output_file, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            if written >= num_samples:
                break
            f_out.write(line)
            written += 1
    # Report the actual count — the source may hold fewer than num_samples lines.
    logger.info(f"Created subset: {output_file} with {written} samples")
def run_command(cmd: list, description: str) -> bool:
    """Run ``cmd`` as a subprocess and report success or failure.

    Args:
        cmd: Command and arguments, passed to ``subprocess.run`` as a list
            (``shell=False``), with stdout/stderr captured as text.
        description: Human-readable label used in log messages.

    Returns:
        True when the command exits with status 0; False on a non-zero exit
        status or when the executable cannot be found.
    """
    logger.info(f"🚀 {description}")
    logger.info(f"Command: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        logger.error(f"{description} failed:")
        # Training scripts often print their traceback to stdout, so log it too.
        if e.stdout:
            logger.error(f"Output: {e.stdout}")
        logger.error(f"Error: {e.stderr}")
        return False
    except FileNotFoundError as e:
        # Missing executable (e.g. "python" not on PATH) — report instead of
        # letting the whole quick-test run crash with an unhandled exception.
        logger.error(f"{description} failed:")
        logger.error(f"Error: {e}")
        return False
    logger.info(f"{description} completed successfully")
    return True
def main():
    """Run a quick end-to-end fine-tuning smoke test on small data subsets.

    Carves small subsets out of the BGE-M3 and reranker training/validation
    splits found in ``--data_dir``, launches each training script on them,
    runs quick validation on any model that trained successfully, and writes
    a JSON summary to ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description="Quick training test for BGE models")
    parser.add_argument("--data_dir", type=str, required=True, help="Directory with split datasets")
    parser.add_argument("--output_dir", type=str, default="./output/quick_test", help="Output directory")
    parser.add_argument("--samples_per_model", type=int, default=1000, help="Number of samples for quick test")
    parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training")
    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("🏃‍♂️ Starting Quick Training Test")
    logger.info(f"📁 Data directory: {data_dir}")
    logger.info(f"📁 Output directory: {output_dir}")
    logger.info(f"🔢 Samples per model: {args.samples_per_model}")

    # Use the interpreter running this script for all child processes so the
    # same virtualenv/conda environment is used ("python" may not be on PATH
    # or may point at a different interpreter).
    python_exe = sys.executable

    # Create subset directories
    subset_dir = output_dir / "subsets"
    subset_dir.mkdir(parents=True, exist_ok=True)

    # Create BGE-M3 subsets (validation subset is 10x smaller than training)
    m3_train_subset = subset_dir / "m3_train_subset.jsonl"
    m3_val_subset = subset_dir / "m3_val_subset.jsonl"
    if (data_dir / "m3_train.jsonl").exists():
        create_subset(str(data_dir / "m3_train.jsonl"), str(m3_train_subset), args.samples_per_model)
        create_subset(str(data_dir / "m3_val.jsonl"), str(m3_val_subset), args.samples_per_model // 10)

    # Create Reranker subsets
    reranker_train_subset = subset_dir / "reranker_train_subset.jsonl"
    reranker_val_subset = subset_dir / "reranker_val_subset.jsonl"
    if (data_dir / "reranker_train.jsonl").exists():
        create_subset(str(data_dir / "reranker_train.jsonl"), str(reranker_train_subset), args.samples_per_model)
        create_subset(str(data_dir / "reranker_val.jsonl"), str(reranker_val_subset), args.samples_per_model // 10)

    # Test BGE-M3 training
    m3_output = output_dir / "bge_m3_test"
    if m3_train_subset.exists():
        logger.info("\n" + "="*60)
        logger.info("🔍 Testing BGE-M3 Training")
        logger.info("="*60)
        m3_cmd = [
            python_exe, "scripts/train_m3.py",
            "--train_data", str(m3_train_subset),
            "--eval_data", str(m3_val_subset),
            "--output_dir", str(m3_output),
            "--num_train_epochs", str(args.epochs),
            "--per_device_train_batch_size", str(args.batch_size),
            "--per_device_eval_batch_size", str(args.batch_size * 2),
            "--learning_rate", "1e-5",
            "--save_steps", "100",
            "--eval_steps", "100",
            "--logging_steps", "50",
            "--warmup_ratio", "0.1",
            "--gradient_accumulation_steps", "2"
        ]
        m3_success = run_command(m3_cmd, "BGE-M3 Quick Training")
    else:
        m3_success = False
        logger.warning("BGE-M3 training data not found")

    # Test Reranker training
    reranker_output = output_dir / "reranker_test"
    if reranker_train_subset.exists():
        logger.info("\n" + "="*60)
        logger.info("🏆 Testing Reranker Training")
        logger.info("="*60)
        reranker_cmd = [
            python_exe, "scripts/train_reranker.py",
            "--reranker_data", str(reranker_train_subset),
            "--eval_data", str(reranker_val_subset),
            "--output_dir", str(reranker_output),
            "--num_train_epochs", str(args.epochs),
            "--per_device_train_batch_size", str(args.batch_size),
            "--per_device_eval_batch_size", str(args.batch_size * 2),
            "--learning_rate", "6e-5",
            "--save_steps", "100",
            "--eval_steps", "100",
            "--logging_steps", "50",
            "--warmup_ratio", "0.1"
        ]
        reranker_success = run_command(reranker_cmd, "Reranker Quick Training")
    else:
        reranker_success = False
        logger.warning("Reranker training data not found")

    # Quick validation (only for models whose training succeeded)
    logger.info("\n" + "="*60)
    logger.info("🧪 Running Quick Validation")
    logger.info("="*60)
    validation_results = {}
    # Validate BGE-M3 if trained
    if m3_success and (m3_output / "final_model").exists():
        m3_val_cmd = [
            python_exe, "scripts/quick_validation.py",
            "--retriever_model", str(m3_output / "final_model")
        ]
        validation_results['m3_validation'] = run_command(m3_val_cmd, "BGE-M3 Validation")
    # Validate Reranker if trained
    if reranker_success and (reranker_output / "final_model").exists():
        reranker_val_cmd = [
            python_exe, "scripts/quick_validation.py",
            "--reranker_model", str(reranker_output / "final_model")
        ]
        validation_results['reranker_validation'] = run_command(reranker_val_cmd, "Reranker Validation")

    # Summary
    logger.info("\n" + "="*60)
    logger.info("📋 Quick Training Test Summary")
    logger.info("="*60)
    summary = {
        "m3_training": m3_success,
        "reranker_training": reranker_success,
        "validation_results": validation_results,
        "recommendation": ""
    }
    if m3_success and reranker_success:
        summary["recommendation"] = "✅ Both models trained successfully! Ready for full training."
        logger.info("✅ Both models trained successfully!")
        logger.info("🎯 You can now proceed with full training using the complete datasets.")
    elif m3_success or reranker_success:
        summary["recommendation"] = "⚠️ Partial success. Check failed model configuration."
        logger.info("⚠️ Partial success. One model failed - check logs above.")
    else:
        summary["recommendation"] = "❌ Training failed. Check configuration and data."
        logger.info("❌ Training tests failed. Check your configuration and data format.")

    # Save summary (ensure_ascii=False keeps emoji/CJK text readable)
    with open(output_dir / "quick_test_summary.json", 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    logger.info(f"📋 Summary saved to {output_dir / 'quick_test_summary.json'}")
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()