#!/usr/bin/env python3
"""
Dataset Splitting Script for BGE Fine-tuning

Splits large datasets into train/validation/test sets using a seeded
random shuffle, so the splits are reproducible for evaluation.

Usage:
    python scripts/split_datasets.py \
        --input_dir "data/datasets/三国演义" \
        --output_dir "data/datasets/三国演义/splits" \
        --train_ratio 0.8 \
        --val_ratio 0.1 \
        --test_ratio 0.1
"""

import argparse
import json
import logging
import os
import random
import sys
from pathlib import Path
from typing import Dict, List, Tuple

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def load_jsonl(file_path: str) -> List[Dict]:
    """Load a JSONL file, skipping blank and malformed lines."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:  # ignore blank lines instead of warning on them
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                logger.warning(f"Skipping invalid JSON at line {line_num}: {e}")
    return data


def save_jsonl(data: List[Dict], file_path: str):
    """Save data to a JSONL file."""
    # os.path.dirname() returns '' for a bare filename; guard before makedirs
    dir_name = os.path.dirname(file_path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')


def split_dataset(data: List[Dict], train_ratio: float, val_ratio: float,
                  test_ratio: float, seed: int = 42
                  ) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split a dataset into train/validation/test sets."""
    # Validate ratios
    total_ratio = train_ratio + val_ratio + test_ratio
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {total_ratio}")

    # Shuffle a copy so the caller's list is left untouched
    random.seed(seed)
    data_shuffled = data.copy()
    random.shuffle(data_shuffled)

    # Calculate split indices; int() truncates, so the final (test) slice
    # absorbs any rounding remainder
    total_size = len(data_shuffled)
    train_size = int(total_size * train_ratio)
    val_size = int(total_size * val_ratio)

    train_data = data_shuffled[:train_size]
    val_data = data_shuffled[train_size:train_size + val_size]
    test_data = data_shuffled[train_size + val_size:]

    return train_data, val_data, test_data


def main():
    parser = argparse.ArgumentParser(description="Split BGE datasets for training")
    parser.add_argument("--input_dir", type=str, required=True,
                        help="Directory containing m3.jsonl and reranker.jsonl")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Output directory for split datasets")
    parser.add_argument("--train_ratio", type=float, default=0.8,
                        help="Training set ratio")
    parser.add_argument("--val_ratio", type=float, default=0.1,
                        help="Validation set ratio")
    parser.add_argument("--test_ratio", type=float, default=0.1,
                        help="Test set ratio")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")

    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"🔄 Splitting datasets from {input_dir}")
    logger.info(f"📁 Output directory: {output_dir}")
    logger.info(f"📊 Split ratios - Train: {args.train_ratio}, "
                f"Val: {args.val_ratio}, Test: {args.test_ratio}")

    # Process BGE-M3 dataset
    m3_file = input_dir / "m3.jsonl"
    if m3_file.exists():
        logger.info("Processing BGE-M3 dataset...")
        m3_data = load_jsonl(str(m3_file))
        logger.info(f"Loaded {len(m3_data)} BGE-M3 samples")

        train_m3, val_m3, test_m3 = split_dataset(
            m3_data, args.train_ratio, args.val_ratio, args.test_ratio, args.seed
        )

        # Save splits
        save_jsonl(train_m3, str(output_dir / "m3_train.jsonl"))
        save_jsonl(val_m3, str(output_dir / "m3_val.jsonl"))
        save_jsonl(test_m3, str(output_dir / "m3_test.jsonl"))

        logger.info(f"✅ BGE-M3 splits saved: Train={len(train_m3)}, "
                    f"Val={len(val_m3)}, Test={len(test_m3)}")
    else:
        logger.warning(f"BGE-M3 file not found: {m3_file}")

    # Process Reranker dataset
    reranker_file = input_dir / "reranker.jsonl"
    if reranker_file.exists():
        logger.info("Processing Reranker dataset...")
        reranker_data = load_jsonl(str(reranker_file))
        logger.info(f"Loaded {len(reranker_data)} reranker samples")

        train_reranker, val_reranker, test_reranker = split_dataset(
            reranker_data, args.train_ratio, args.val_ratio, args.test_ratio,
            args.seed
        )

        # Save splits
        save_jsonl(train_reranker, str(output_dir / "reranker_train.jsonl"))
        save_jsonl(val_reranker, str(output_dir / "reranker_val.jsonl"))
        save_jsonl(test_reranker, str(output_dir / "reranker_test.jsonl"))

        logger.info(f"✅ Reranker splits saved: Train={len(train_reranker)}, "
                    f"Val={len(val_reranker)}, Test={len(test_reranker)}")
    else:
        logger.warning(f"Reranker file not found: {reranker_file}")

    # Create summary
    summary = {
        "input_dir": str(input_dir),
        "output_dir": str(output_dir),
        "split_ratios": {
            "train": args.train_ratio,
            "validation": args.val_ratio,
            "test": args.test_ratio
        },
        "seed": args.seed,
        "datasets": {}
    }

    if m3_file.exists():
        summary["datasets"]["bge_m3"] = {
            "total_samples": len(m3_data),
            "train_samples": len(train_m3),
            "val_samples": len(val_m3),
            "test_samples": len(test_m3)
        }

    if reranker_file.exists():
        summary["datasets"]["reranker"] = {
            "total_samples": len(reranker_data),
            "train_samples": len(train_reranker),
            "val_samples": len(val_reranker),
            "test_samples": len(test_reranker)
        }

    # Save summary
    with open(output_dir / "split_summary.json", 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    logger.info(f"📋 Split summary saved to {output_dir / 'split_summary.json'}")
    logger.info("🎉 Dataset splitting completed successfully!")


if __name__ == "__main__":
    main()
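# --- Usage sketch (illustrative comments only, not executed) ---
# The helpers above can also be reused programmatically; `samples` below is
# a hypothetical list of dicts loaded elsewhere:
#
#   train, val, test = split_dataset(samples, 0.8, 0.1, 0.1, seed=42)
#   save_jsonl(train, "splits/train.jsonl")
#
# With 1003 samples and 0.8/0.1/0.1 ratios this yields 802/100/101 items:
# int() truncation sends the rounding remainder to the test set.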