#!/usr/bin/env python3
"""
Quick Training Test Script for BGE Models

This script performs a quick training test on small data subsets to verify
that fine-tuning is working before committing to full training.

Usage:
    python scripts/quick_train_test.py \
        --data_dir "data/datasets/δΈ‰ε›½ζΌ”δΉ‰/splits" \
        --output_dir "./output/quick_test" \
        --samples_per_model 1000
"""

import os
import sys
import json
import argparse
import subprocess
from pathlib import Path
import logging

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def create_subset(input_file: str, output_file: str, num_samples: int):
    """Create a subset of the dataset for quick testing."""
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    written = 0
    with open(input_file, 'r', encoding='utf-8') as f_in, \
         open(output_file, 'w', encoding='utf-8') as f_out:
        for i, line in enumerate(f_in):
            if i >= num_samples:
                break
            f_out.write(line)
            written += 1

    # Report the actual number of lines copied, which may be fewer than
    # requested if the source file is small.
    logger.info(f"Created subset: {output_file} with {written} samples")


def run_command(cmd: list, description: str) -> bool:
    """Run a command, logging success or the captured stderr on failure."""
    logger.info(f"πŸš€ {description}")
    logger.info(f"Command: {' '.join(cmd)}")

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
        logger.info(f"βœ… {description} completed successfully")
        return True
    except subprocess.CalledProcessError as e:
        logger.error(f"❌ {description} failed:")
        logger.error(f"Error: {e.stderr}")
        return False


def main():
    parser = argparse.ArgumentParser(description="Quick training test for BGE models")
    parser.add_argument("--data_dir", type=str, required=True,
                        help="Directory with split datasets")
    parser.add_argument("--output_dir", type=str, default="./output/quick_test",
                        help="Output directory")
    parser.add_argument("--samples_per_model", type=int, default=1000,
                        help="Number of samples for quick test")
    parser.add_argument("--epochs", type=int, default=1,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="Batch size for training")

    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("πŸƒβ€β™‚οΈ Starting Quick Training Test")
    logger.info(f"πŸ“ Data directory: {data_dir}")
    logger.info(f"πŸ“ Output directory: {output_dir}")
    logger.info(f"πŸ”’ Samples per model: {args.samples_per_model}")

    # Create subset directories
    subset_dir = output_dir / "subsets"
    subset_dir.mkdir(parents=True, exist_ok=True)

    # Create BGE-M3 subsets (validation gets a tenth of the training samples)
    m3_train_subset = subset_dir / "m3_train_subset.jsonl"
    m3_val_subset = subset_dir / "m3_val_subset.jsonl"

    if (data_dir / "m3_train.jsonl").exists():
        create_subset(str(data_dir / "m3_train.jsonl"), str(m3_train_subset),
                      args.samples_per_model)
        create_subset(str(data_dir / "m3_val.jsonl"), str(m3_val_subset),
                      args.samples_per_model // 10)

    # Create Reranker subsets
    reranker_train_subset = subset_dir / "reranker_train_subset.jsonl"
    reranker_val_subset = subset_dir / "reranker_val_subset.jsonl"

    if (data_dir / "reranker_train.jsonl").exists():
        create_subset(str(data_dir / "reranker_train.jsonl"), str(reranker_train_subset),
                      args.samples_per_model)
        create_subset(str(data_dir / "reranker_val.jsonl"), str(reranker_val_subset),
                      args.samples_per_model // 10)

    # Output directories are defined up front so the validation steps below
    # can reference them even when a training step was skipped.
    m3_output = output_dir / "bge_m3_test"
    reranker_output = output_dir / "reranker_test"

    # Test BGE-M3 training
    if m3_train_subset.exists():
        logger.info("\n" + "=" * 60)
        logger.info("πŸ” Testing BGE-M3 Training")
        logger.info("=" * 60)

        m3_cmd = [
            "python", "scripts/train_m3.py",
            "--train_data", str(m3_train_subset),
            "--eval_data", str(m3_val_subset),
            "--output_dir", str(m3_output),
            "--num_train_epochs", str(args.epochs),
            "--per_device_train_batch_size", str(args.batch_size),
            "--per_device_eval_batch_size", str(args.batch_size * 2),
            "--learning_rate", "1e-5",
            "--save_steps", "100",
            "--eval_steps", "100",
            "--logging_steps", "50",
            "--warmup_ratio", "0.1",
            "--gradient_accumulation_steps", "2"
        ]

        m3_success = run_command(m3_cmd, "BGE-M3 Quick Training")
    else:
        m3_success = False
        logger.warning("BGE-M3 training data not found")

    # Test Reranker training
    if reranker_train_subset.exists():
        logger.info("\n" + "=" * 60)
        logger.info("πŸ† Testing Reranker Training")
        logger.info("=" * 60)

        reranker_cmd = [
            "python", "scripts/train_reranker.py",
            "--reranker_data", str(reranker_train_subset),
            "--eval_data", str(reranker_val_subset),
            "--output_dir", str(reranker_output),
            "--num_train_epochs", str(args.epochs),
            "--per_device_train_batch_size", str(args.batch_size),
            "--per_device_eval_batch_size", str(args.batch_size * 2),
            "--learning_rate", "6e-5",
            "--save_steps", "100",
            "--eval_steps", "100",
            "--logging_steps", "50",
            "--warmup_ratio", "0.1"
        ]

        reranker_success = run_command(reranker_cmd, "Reranker Quick Training")
    else:
        reranker_success = False
        logger.warning("Reranker training data not found")

    # Quick validation
    logger.info("\n" + "=" * 60)
    logger.info("πŸ§ͺ Running Quick Validation")
    logger.info("=" * 60)

    validation_results = {}

    # Validate BGE-M3 if trained
    if m3_success and (m3_output / "final_model").exists():
        m3_val_cmd = [
            "python", "scripts/quick_validation.py",
            "--retriever_model", str(m3_output / "final_model")
        ]
        validation_results['m3_validation'] = run_command(m3_val_cmd, "BGE-M3 Validation")

    # Validate Reranker if trained
    if reranker_success and (reranker_output / "final_model").exists():
        reranker_val_cmd = [
            "python", "scripts/quick_validation.py",
            "--reranker_model", str(reranker_output / "final_model")
        ]
        validation_results['reranker_validation'] = run_command(reranker_val_cmd, "Reranker Validation")

    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("πŸ“‹ Quick Training Test Summary")
    logger.info("=" * 60)

    summary = {
        "m3_training": m3_success,
        "reranker_training": reranker_success,
        "validation_results": validation_results,
        "recommendation": ""
    }

    if m3_success and reranker_success:
        summary["recommendation"] = "βœ… Both models trained successfully! Ready for full training."
        logger.info("βœ… Both models trained successfully!")
        logger.info("🎯 You can now proceed with full training using the complete datasets.")
    elif m3_success or reranker_success:
        summary["recommendation"] = "⚠️ Partial success. Check failed model configuration."
        logger.info("⚠️ Partial success. One model failed - check logs above.")
    else:
        summary["recommendation"] = "❌ Training failed. Check configuration and data."
        logger.info("❌ Training tests failed. Check your configuration and data format.")

    # Save summary
    with open(output_dir / "quick_test_summary.json", 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    logger.info(f"πŸ“‹ Summary saved to {output_dir / 'quick_test_summary.json'}")


if __name__ == "__main__":
    main()