#!/usr/bin/env python3
|
|
"""
|
|
Quick Training Test Script for BGE Models
|
|
|
|
This script performs a quick training test on small data subsets to verify
|
|
that fine-tuning is working before committing to full training.
|
|
|
|
Usage:
|
|
python scripts/quick_train_test.py \
|
|
--data_dir "data/datasets/三国演义/splits" \
|
|
--output_dir "./output/quick_test" \
|
|
--samples_per_model 1000
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import argparse
|
|
import subprocess
|
|
from pathlib import Path
|
|
import logging
|
|
|
|
# Add project root to path so project-local imports resolve when run as a script
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Module-level logger (standard `logging.getLogger(__name__)` convention)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def create_subset(input_file: str, output_file: str, num_samples: int) -> None:
    """Copy the first *num_samples* lines of a JSONL dataset to a new file.

    Args:
        input_file: Path to the source JSONL file.
        output_file: Destination path; parent directories are created as needed.
        num_samples: Maximum number of lines to copy. If the source has fewer
            lines, everything available is copied.
    """
    # os.path.dirname() returns '' for a bare filename; makedirs('') would raise.
    parent = os.path.dirname(output_file)
    if parent:
        os.makedirs(parent, exist_ok=True)

    written = 0
    with open(input_file, 'r', encoding='utf-8') as f_in, \
            open(output_file, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            if written >= num_samples:
                break
            f_out.write(line)
            written += 1

    # Log the actual count — the source may contain fewer lines than requested.
    logger.info(f"Created subset: {output_file} with {written} samples")
|
|
|
|
|
|
def run_command(cmd: list, description: str) -> bool:
    """Run *cmd* as a subprocess and report success or failure via the logger.

    Args:
        cmd: Command and arguments in list form (no shell is involved).
        description: Human-readable label used in log messages.

    Returns:
        True when the command exits with status 0, False otherwise
        (non-zero exit or executable not found).
    """
    logger.info(f"🚀 {description}")
    logger.info(f"Command: {' '.join(cmd)}")

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        logger.error(f"❌ {description} failed:")
        # Many tools print their errors on stdout — surface both streams.
        if e.stdout:
            logger.error(f"Output: {e.stdout}")
        logger.error(f"Error: {e.stderr}")
        return False
    except FileNotFoundError as e:
        # Missing interpreter/script should fail this step, not crash the run.
        logger.error(f"❌ {description} failed: command not found ({e})")
        return False
    else:
        logger.info(f"✅ {description} completed successfully")
        return True
|
|
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse command-line options for the quick training test."""
    parser = argparse.ArgumentParser(description="Quick training test for BGE models")
    parser.add_argument("--data_dir", type=str, required=True, help="Directory with split datasets")
    parser.add_argument("--output_dir", type=str, default="./output/quick_test", help="Output directory")
    parser.add_argument("--samples_per_model", type=int, default=1000, help="Number of samples for quick test")
    parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training")
    return parser.parse_args()


def _banner(title: str) -> None:
    """Log a visually separated section header."""
    logger.info("\n" + "=" * 60)
    logger.info(title)
    logger.info("=" * 60)


def _make_subsets(data_dir: Path, subset_dir: Path, prefix: str, num_samples: int):
    """Create train/val subset files for *prefix* ('m3' or 'reranker').

    Returns:
        (train_subset_path, val_subset_path) — the files are only created
        when the corresponding source split exists in *data_dir*.
    """
    train_subset = subset_dir / f"{prefix}_train_subset.jsonl"
    val_subset = subset_dir / f"{prefix}_val_subset.jsonl"
    train_src = data_dir / f"{prefix}_train.jsonl"
    val_src = data_dir / f"{prefix}_val.jsonl"

    if train_src.exists():
        create_subset(str(train_src), str(train_subset), num_samples)
        # Original code assumed the val split exists whenever the train split
        # does, and crashed otherwise — guard it explicitly.
        if val_src.exists():
            create_subset(str(val_src), str(val_subset), num_samples // 10)
        else:
            logger.warning(f"Validation split not found: {val_src}")
    return train_subset, val_subset


def main():
    """Run quick subset training for BGE-M3 and the reranker, validate the
    resulting models, and write a JSON summary to the output directory."""
    args = _parse_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("🏃‍♂️ Starting Quick Training Test")
    logger.info(f"📁 Data directory: {data_dir}")
    logger.info(f"📁 Output directory: {output_dir}")
    logger.info(f"🔢 Samples per model: {args.samples_per_model}")

    # Create subset directories and data subsets for both models.
    subset_dir = output_dir / "subsets"
    subset_dir.mkdir(parents=True, exist_ok=True)

    m3_train_subset, m3_val_subset = _make_subsets(
        data_dir, subset_dir, "m3", args.samples_per_model)
    reranker_train_subset, reranker_val_subset = _make_subsets(
        data_dir, subset_dir, "reranker", args.samples_per_model)

    # Bind output paths unconditionally so later references do not depend on
    # which training branches actually ran.
    m3_output = output_dir / "bge_m3_test"
    reranker_output = output_dir / "reranker_test"

    # Test BGE-M3 training
    if m3_train_subset.exists():
        _banner("🔍 Testing BGE-M3 Training")
        m3_cmd = [
            "python", "scripts/train_m3.py",
            "--train_data", str(m3_train_subset),
            "--eval_data", str(m3_val_subset),
            "--output_dir", str(m3_output),
            "--num_train_epochs", str(args.epochs),
            "--per_device_train_batch_size", str(args.batch_size),
            "--per_device_eval_batch_size", str(args.batch_size * 2),
            "--learning_rate", "1e-5",
            "--save_steps", "100",
            "--eval_steps", "100",
            "--logging_steps", "50",
            "--warmup_ratio", "0.1",
            "--gradient_accumulation_steps", "2",
        ]
        m3_success = run_command(m3_cmd, "BGE-M3 Quick Training")
    else:
        m3_success = False
        logger.warning("BGE-M3 training data not found")

    # Test Reranker training
    if reranker_train_subset.exists():
        _banner("🏆 Testing Reranker Training")
        reranker_cmd = [
            "python", "scripts/train_reranker.py",
            "--reranker_data", str(reranker_train_subset),
            "--eval_data", str(reranker_val_subset),
            "--output_dir", str(reranker_output),
            "--num_train_epochs", str(args.epochs),
            "--per_device_train_batch_size", str(args.batch_size),
            "--per_device_eval_batch_size", str(args.batch_size * 2),
            "--learning_rate", "6e-5",
            "--save_steps", "100",
            "--eval_steps", "100",
            "--logging_steps", "50",
            "--warmup_ratio", "0.1",
        ]
        reranker_success = run_command(reranker_cmd, "Reranker Quick Training")
    else:
        reranker_success = False
        logger.warning("Reranker training data not found")

    # Quick validation of whatever trained successfully.
    _banner("🧪 Running Quick Validation")
    validation_results = {}

    if m3_success and (m3_output / "final_model").exists():
        m3_val_cmd = [
            "python", "scripts/quick_validation.py",
            "--retriever_model", str(m3_output / "final_model"),
        ]
        validation_results['m3_validation'] = run_command(m3_val_cmd, "BGE-M3 Validation")

    if reranker_success and (reranker_output / "final_model").exists():
        reranker_val_cmd = [
            "python", "scripts/quick_validation.py",
            "--reranker_model", str(reranker_output / "final_model"),
        ]
        validation_results['reranker_validation'] = run_command(reranker_val_cmd, "Reranker Validation")

    # Summary
    _banner("📋 Quick Training Test Summary")
    summary = {
        "m3_training": m3_success,
        "reranker_training": reranker_success,
        "validation_results": validation_results,
        "recommendation": "",
    }

    if m3_success and reranker_success:
        summary["recommendation"] = "✅ Both models trained successfully! Ready for full training."
        logger.info("✅ Both models trained successfully!")
        logger.info("🎯 You can now proceed with full training using the complete datasets.")
    elif m3_success or reranker_success:
        summary["recommendation"] = "⚠️ Partial success. Check failed model configuration."
        logger.info("⚠️ Partial success. One model failed - check logs above.")
    else:
        summary["recommendation"] = "❌ Training failed. Check configuration and data."
        logger.info("❌ Training tests failed. Check your configuration and data format.")

    # Persist summary (ensure_ascii=False keeps emoji/CJK readable in the file).
    with open(output_dir / "quick_test_summary.json", 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    logger.info(f"📋 Summary saved to {output_dir / 'quick_test_summary.json'}")
|
|
|
|
|
|
# Script entry point — run the quick training test only when executed directly.
if __name__ == "__main__":
    main()