Revised for training
This commit is contained in:
177
scripts/split_datasets.py
Normal file
177
scripts/split_datasets.py
Normal file
@@ -0,0 +1,177 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Dataset Splitting Script for BGE Fine-tuning
|
||||
|
||||
This script splits large datasets into train/validation/test sets while preserving
|
||||
data distribution and enabling proper evaluation.
|
||||
|
||||
Usage:
|
||||
python scripts/split_datasets.py \
|
||||
--input_dir "data/datasets/三国演义" \
|
||||
--output_dir "data/datasets/三国演义/splits" \
|
||||
--train_ratio 0.8 \
|
||||
--val_ratio 0.1 \
|
||||
--test_ratio 0.1
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple
|
||||
import logging
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_jsonl(file_path: str) -> List[Dict]:
    """Load records from a JSONL file.

    Blank / whitespace-only lines are skipped silently (previously they were
    fed to ``json.loads`` and produced a spurious warning per blank line).
    Lines that are not valid JSON are logged as warnings and skipped, so one
    corrupt line does not abort the whole load.

    Args:
        file_path: Path to a UTF-8 encoded JSONL file.

    Returns:
        List of parsed JSON objects, in file order.
    """
    data: List[Dict] = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                # Tolerate empty lines (e.g. trailing blank line) quietly.
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                logger.warning(f"Skipping invalid JSON at line {line_num}: {e}")
    return data
|
||||
|
||||
|
||||
def save_jsonl(data: List[Dict], file_path: str):
    """Write records to a JSONL file (UTF-8, one JSON object per line).

    Creates the parent directory if needed. Fix: ``os.makedirs("")`` raised
    ``FileNotFoundError`` whenever ``file_path`` was a bare filename with no
    directory component; the empty dirname is now guarded.

    Args:
        data: List of JSON-serializable dicts.
        file_path: Destination path; parent directories are created on demand.
    """
    parent = os.path.dirname(file_path)
    if parent:  # dirname is "" for bare filenames — nothing to create
        os.makedirs(parent, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            # ensure_ascii=False keeps non-ASCII text (e.g. Chinese) readable.
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
|
||||
def split_dataset(data: List[Dict], train_ratio: float, val_ratio: float, test_ratio: float, seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Partition *data* into train / validation / test subsets.

    The three ratios must sum to 1.0 (within 1e-6). Shuffling is seeded, so
    identical inputs always produce identical partitions; the caller's list
    is never mutated.

    Args:
        data: Samples to split.
        train_ratio: Fraction assigned to the training set.
        val_ratio: Fraction assigned to the validation set.
        test_ratio: Fraction assigned to the test set.
        seed: Random seed used for the shuffle.

    Returns:
        Tuple of (train, validation, test) lists.

    Raises:
        ValueError: If the ratios do not sum to 1.0.
    """
    ratio_sum = train_ratio + val_ratio + test_ratio
    if abs(ratio_sum - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {ratio_sum}")

    # Shuffle a copy so the caller's list stays untouched.
    random.seed(seed)
    shuffled = list(data)
    random.shuffle(shuffled)

    n = len(shuffled)
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)

    # Test set absorbs whatever the truncated train/val counts left over.
    return (
        shuffled[:n_train],
        shuffled[n_train:n_train + n_val],
        shuffled[n_train + n_val:],
    )
|
||||
|
||||
|
||||
def _process_split(input_dir: Path, output_dir: Path, args, filename: str,
                   label: str, sample_label: str, prefix: str):
    """Load one JSONL dataset, split it, and save the three split files.

    The BGE-M3 and reranker branches of ``main`` were identical copy-paste;
    this helper is the shared implementation. Log messages are parameterized
    so the emitted text is unchanged.

    Args:
        input_dir: Directory containing the source JSONL file.
        output_dir: Directory where split files are written.
        args: Parsed CLI namespace (train_ratio, val_ratio, test_ratio, seed).
        filename: Source file name inside input_dir (e.g. "m3.jsonl").
        label: Display name used in most log messages (e.g. "BGE-M3").
        sample_label: Name used in the "Loaded N ... samples" message.
        prefix: Output file prefix (e.g. "m3" -> "m3_train.jsonl").

    Returns:
        Stats dict for the summary, or None when the file is missing
        (a warning is logged and the dataset is skipped).
    """
    data_file = input_dir / filename
    if not data_file.exists():
        logger.warning(f"{label} file not found: {data_file}")
        return None

    logger.info(f"Processing {label} dataset...")
    data = load_jsonl(str(data_file))
    logger.info(f"Loaded {len(data)} {sample_label} samples")

    train, val, test = split_dataset(
        data, args.train_ratio, args.val_ratio, args.test_ratio, args.seed
    )

    save_jsonl(train, str(output_dir / f"{prefix}_train.jsonl"))
    save_jsonl(val, str(output_dir / f"{prefix}_val.jsonl"))
    save_jsonl(test, str(output_dir / f"{prefix}_test.jsonl"))

    logger.info(f"✅ {label} splits saved: Train={len(train)}, Val={len(val)}, Test={len(test)}")
    return {
        "total_samples": len(data),
        "train_samples": len(train),
        "val_samples": len(val),
        "test_samples": len(test)
    }


def main():
    """CLI entry point: split m3.jsonl and reranker.jsonl into train/val/test
    sets under --output_dir and write a split_summary.json describing them."""
    parser = argparse.ArgumentParser(description="Split BGE datasets for training")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory containing m3.jsonl and reranker.jsonl")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory for split datasets")
    parser.add_argument("--train_ratio", type=float, default=0.8, help="Training set ratio")
    parser.add_argument("--val_ratio", type=float, default=0.1, help="Validation set ratio")
    parser.add_argument("--test_ratio", type=float, default=0.1, help="Test set ratio")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")

    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"🔄 Splitting datasets from {input_dir}")
    logger.info(f"📁 Output directory: {output_dir}")
    logger.info(f"📊 Split ratios - Train: {args.train_ratio}, Val: {args.val_ratio}, Test: {args.test_ratio}")

    summary = {
        "input_dir": str(input_dir),
        "output_dir": str(output_dir),
        "split_ratios": {
            "train": args.train_ratio,
            "validation": args.val_ratio,
            "test": args.test_ratio
        },
        "seed": args.seed,
        "datasets": {}
    }

    # Process each dataset once; previously the file existence was checked
    # twice (once to split, once to build the summary), which could race.
    m3_stats = _process_split(input_dir, output_dir, args,
                              "m3.jsonl", "BGE-M3", "BGE-M3", "m3")
    if m3_stats is not None:
        summary["datasets"]["bge_m3"] = m3_stats

    reranker_stats = _process_split(input_dir, output_dir, args,
                                    "reranker.jsonl", "Reranker", "reranker", "reranker")
    if reranker_stats is not None:
        summary["datasets"]["reranker"] = reranker_stats

    # Save summary
    with open(output_dir / "split_summary.json", 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    logger.info(f"📋 Split summary saved to {output_dir / 'split_summary.json'}")
    logger.info("🎉 Dataset splitting completed successfully!")
|
||||
|
||||
|
||||
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user