#!/usr/bin/env python3
"""
Dataset Splitting Script for BGE Fine-tuning
This script splits large datasets into train/validation/test sets while preserving
data distribution and enabling proper evaluation.
Usage:
python scripts/split_datasets.py \
--input_dir "data/datasets/三国演义" \
--output_dir "data/datasets/三国演义/splits" \
--train_ratio 0.8 \
--val_ratio 0.1 \
--test_ratio 0.1
"""
import os
import sys
import json
import argparse
import random
from pathlib import Path
from typing import List, Dict, Tuple
import logging
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def load_jsonl(file_path: str) -> List[Dict]:
    """Load a JSONL file, skipping blank and malformed lines."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                # Skip blank lines instead of logging them as JSON errors
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                logger.warning(f"Skipping invalid JSON at line {line_num}: {e}")
    return data
def save_jsonl(data: List[Dict], file_path: str):
    """Save data to a JSONL file, creating the parent directory if needed."""
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
def split_dataset(data: List[Dict], train_ratio: float, val_ratio: float,
                  test_ratio: float, seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split a dataset into train/validation/test sets."""
    # Validate ratios
    total_ratio = train_ratio + val_ratio + test_ratio
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {total_ratio}")
    # Shuffle a copy of the data deterministically using the given seed
    random.seed(seed)
    data_shuffled = data.copy()
    random.shuffle(data_shuffled)
    # Calculate split indices
    total_size = len(data_shuffled)
    train_size = int(total_size * train_ratio)
    val_size = int(total_size * val_ratio)
    # Split data; the test set takes the remainder so rounding never drops samples
    train_data = data_shuffled[:train_size]
    val_data = data_shuffled[train_size:train_size + val_size]
    test_data = data_shuffled[train_size + val_size:]
    return train_data, val_data, test_data
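# Worked example of the slicing above (hypothetical numbers, not from a real run):
# with 1,000 records and ratios 0.8/0.1/0.1, int() truncation gives
# train_size = 800 and val_size = 100, so the slices yield 800/100/100 samples;
# with 1,003 records they give 802/100/101, the leftover records landing in test.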
def main():
    parser = argparse.ArgumentParser(description="Split BGE datasets for training")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory containing m3.jsonl and reranker.jsonl")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory for split datasets")
    parser.add_argument("--train_ratio", type=float, default=0.8, help="Training set ratio")
    parser.add_argument("--val_ratio", type=float, default=0.1, help="Validation set ratio")
    parser.add_argument("--test_ratio", type=float, default=0.1, help="Test set ratio")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"🔄 Splitting datasets from {input_dir}")
    logger.info(f"📁 Output directory: {output_dir}")
    logger.info(f"📊 Split ratios - Train: {args.train_ratio}, Val: {args.val_ratio}, Test: {args.test_ratio}")
    # Process BGE-M3 dataset
    m3_file = input_dir / "m3.jsonl"
    if m3_file.exists():
        logger.info("Processing BGE-M3 dataset...")
        m3_data = load_jsonl(str(m3_file))
        logger.info(f"Loaded {len(m3_data)} BGE-M3 samples")
        train_m3, val_m3, test_m3 = split_dataset(
            m3_data, args.train_ratio, args.val_ratio, args.test_ratio, args.seed
        )
        # Save splits
        save_jsonl(train_m3, str(output_dir / "m3_train.jsonl"))
        save_jsonl(val_m3, str(output_dir / "m3_val.jsonl"))
        save_jsonl(test_m3, str(output_dir / "m3_test.jsonl"))
        logger.info(f"✅ BGE-M3 splits saved: Train={len(train_m3)}, Val={len(val_m3)}, Test={len(test_m3)}")
    else:
        logger.warning(f"BGE-M3 file not found: {m3_file}")
    # Process Reranker dataset
    reranker_file = input_dir / "reranker.jsonl"
    if reranker_file.exists():
        logger.info("Processing Reranker dataset...")
        reranker_data = load_jsonl(str(reranker_file))
        logger.info(f"Loaded {len(reranker_data)} reranker samples")
        train_reranker, val_reranker, test_reranker = split_dataset(
            reranker_data, args.train_ratio, args.val_ratio, args.test_ratio, args.seed
        )
        # Save splits
        save_jsonl(train_reranker, str(output_dir / "reranker_train.jsonl"))
        save_jsonl(val_reranker, str(output_dir / "reranker_val.jsonl"))
        save_jsonl(test_reranker, str(output_dir / "reranker_test.jsonl"))
        logger.info(f"✅ Reranker splits saved: Train={len(train_reranker)}, Val={len(val_reranker)}, Test={len(test_reranker)}")
    else:
        logger.warning(f"Reranker file not found: {reranker_file}")
    # Create summary
    summary = {
        "input_dir": str(input_dir),
        "output_dir": str(output_dir),
        "split_ratios": {
            "train": args.train_ratio,
            "validation": args.val_ratio,
            "test": args.test_ratio
        },
        "seed": args.seed,
        "datasets": {}
    }
    if m3_file.exists():
        summary["datasets"]["bge_m3"] = {
            "total_samples": len(m3_data),
            "train_samples": len(train_m3),
            "val_samples": len(val_m3),
            "test_samples": len(test_m3)
        }
    if reranker_file.exists():
        summary["datasets"]["reranker"] = {
            "total_samples": len(reranker_data),
            "train_samples": len(train_reranker),
            "val_samples": len(val_reranker),
            "test_samples": len(test_reranker)
        }

    # Save summary
    with open(output_dir / "split_summary.json", 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    logger.info(f"📋 Split summary saved to {output_dir / 'split_summary.json'}")
    logger.info("🎉 Dataset splitting completed successfully!")
if __name__ == "__main__":
    main()
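# Minimal programmatic usage sketch (hypothetical paths, and assuming the
# scripts directory is importable as a package; the CLI above is the intended
# entry point):
#
#   from scripts.split_datasets import load_jsonl, split_dataset, save_jsonl
#   data = load_jsonl("data/datasets/example/m3.jsonl")
#   train, val, test = split_dataset(data, 0.8, 0.1, 0.1, seed=42)
#   save_jsonl(train, "data/datasets/example/splits/m3_train.jsonl")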