#!/usr/bin/env python3
"""
Dataset Splitting Script for BGE Fine-tuning

This script splits large datasets into train/validation/test sets using a
seeded random shuffle, so the splits are reproducible and remain representative
of the overall data distribution for evaluation.

Usage:
    python scripts/split_datasets.py \
        --input_dir "data/datasets/三国演义" \
        --output_dir "data/datasets/三国演义/splits" \
        --train_ratio 0.8 \
        --val_ratio 0.1 \
        --test_ratio 0.1
"""

import os
import sys
import json
import argparse
import random
from pathlib import Path
from typing import List, Dict, Tuple
import logging

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def load_jsonl(file_path: str) -> List[Dict]:
    """Load a JSONL file, skipping blank and malformed lines."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                # Skip blank lines instead of logging them as JSON errors
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                logger.warning(f"Skipping invalid JSON at line {line_num}: {e}")
    return data


def save_jsonl(data: List[Dict], file_path: str):
    """Save data to a JSONL file, creating parent directories as needed."""
    parent = os.path.dirname(file_path)
    if parent:
        # os.makedirs('') raises FileNotFoundError for bare filenames
        os.makedirs(parent, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')


def split_dataset(data: List[Dict], train_ratio: float, val_ratio: float, test_ratio: float, seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split dataset into train/validation/test sets."""
    # Validate ratios
    total_ratio = train_ratio + val_ratio + test_ratio
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {total_ratio}")

    # Shuffle a copy so the caller's list is left untouched
    random.seed(seed)
    data_shuffled = data.copy()
    random.shuffle(data_shuffled)

    # Calculate split indices
    total_size = len(data_shuffled)
    train_size = int(total_size * train_ratio)
    val_size = int(total_size * val_ratio)

    # Split data; the test slice takes everything past train + val
    train_data = data_shuffled[:train_size]
    val_data = data_shuffled[train_size:train_size + val_size]
    test_data = data_shuffled[train_size + val_size:]

    return train_data, val_data, test_data
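
# Rounding behavior (illustrative toy example, not part of the script's flow):
# int() truncates the train and val sizes, and the test slice absorbs the
# remainder, so 10 records at ratios 0.8/0.1/0.1 split into 8/1/1:
#   train, val, test = split_dataset([{"id": i} for i in range(10)], 0.8, 0.1, 0.1)
#   assert (len(train), len(val), len(test)) == (8, 1, 1)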


def main():
    parser = argparse.ArgumentParser(description="Split BGE datasets for training")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory containing m3.jsonl and reranker.jsonl")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory for split datasets")
    parser.add_argument("--train_ratio", type=float, default=0.8, help="Training set ratio")
    parser.add_argument("--val_ratio", type=float, default=0.1, help="Validation set ratio")
    parser.add_argument("--test_ratio", type=float, default=0.1, help="Test set ratio")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")

    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"🔄 Splitting datasets from {input_dir}")
    logger.info(f"📁 Output directory: {output_dir}")
    logger.info(f"📊 Split ratios - Train: {args.train_ratio}, Val: {args.val_ratio}, Test: {args.test_ratio}")

    # Process BGE-M3 dataset
    m3_file = input_dir / "m3.jsonl"
    if m3_file.exists():
        logger.info("Processing BGE-M3 dataset...")
        m3_data = load_jsonl(str(m3_file))
        logger.info(f"Loaded {len(m3_data)} BGE-M3 samples")

        train_m3, val_m3, test_m3 = split_dataset(
            m3_data, args.train_ratio, args.val_ratio, args.test_ratio, args.seed
        )

        # Save splits
        save_jsonl(train_m3, str(output_dir / "m3_train.jsonl"))
        save_jsonl(val_m3, str(output_dir / "m3_val.jsonl"))
        save_jsonl(test_m3, str(output_dir / "m3_test.jsonl"))

        logger.info(f"✅ BGE-M3 splits saved: Train={len(train_m3)}, Val={len(val_m3)}, Test={len(test_m3)}")
    else:
        logger.warning(f"BGE-M3 file not found: {m3_file}")

    # Process Reranker dataset
    reranker_file = input_dir / "reranker.jsonl"
    if reranker_file.exists():
        logger.info("Processing Reranker dataset...")
        reranker_data = load_jsonl(str(reranker_file))
        logger.info(f"Loaded {len(reranker_data)} reranker samples")

        train_reranker, val_reranker, test_reranker = split_dataset(
            reranker_data, args.train_ratio, args.val_ratio, args.test_ratio, args.seed
        )

        # Save splits
        save_jsonl(train_reranker, str(output_dir / "reranker_train.jsonl"))
        save_jsonl(val_reranker, str(output_dir / "reranker_val.jsonl"))
        save_jsonl(test_reranker, str(output_dir / "reranker_test.jsonl"))

        logger.info(f"✅ Reranker splits saved: Train={len(train_reranker)}, Val={len(val_reranker)}, Test={len(test_reranker)}")
    else:
        logger.warning(f"Reranker file not found: {reranker_file}")

    # Create summary; the per-dataset entries reuse the variables bound above,
    # so they are only added when the corresponding input file was processed
    summary = {
        "input_dir": str(input_dir),
        "output_dir": str(output_dir),
        "split_ratios": {
            "train": args.train_ratio,
            "validation": args.val_ratio,
            "test": args.test_ratio
        },
        "seed": args.seed,
        "datasets": {}
    }

    if m3_file.exists():
        summary["datasets"]["bge_m3"] = {
            "total_samples": len(m3_data),
            "train_samples": len(train_m3),
            "val_samples": len(val_m3),
            "test_samples": len(test_m3)
        }

    if reranker_file.exists():
        summary["datasets"]["reranker"] = {
            "total_samples": len(reranker_data),
            "train_samples": len(train_reranker),
            "val_samples": len(val_reranker),
            "test_samples": len(test_reranker)
        }

    # Save summary
    with open(output_dir / "split_summary.json", 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    logger.info(f"📋 Split summary saved to {output_dir / 'split_summary.json'}")
    logger.info("🎉 Dataset splitting completed successfully!")


if __name__ == "__main__":
    main()