# Source: bge_finetune/scripts/convert_m3_dataset.py
# Last modified: 2025-07-22 16:55:25 +08:00 (291 lines, 10 KiB, Python)

#!/usr/bin/env python3
"""
BGE-M3 Dataset Format Converter
Converts the nested dataset format where training data is stored inside a "data" field
to the flat format expected by the BGE-M3 training pipeline.
Input format:
{
"query": "...",
"data": "{\"pos\": [...], \"neg\": [...], \"pos_scores\": [...], ...}"
}
Output format:
{
"query": "...",
"pos": [...],
"neg": [...],
"pos_scores": [...],
"neg_scores": [...],
"prompt": "...",
"type": "..."
}
"""
import json
import argparse
import logging
import os
from tqdm import tqdm
from typing import Dict, Any, Optional
import re
# Configure the root logger once for this script: INFO and above,
# each record prefixed with a timestamp and its level name.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
# Module-scoped logger shared by every function below.
logger = logging.getLogger(__name__)
def extract_json_from_data_field(data_field: Any) -> Optional[Dict[str, Any]]:
    """
    Parse the payload stored in a record's "data" field.

    The payload is normally a JSON object serialized as a string, optionally
    wrapped in a markdown code fence (```json ... ``` or ``` ... ```). Any
    language tag right after the opening fence (json, JSON, jsonl, ...) is
    ignored. A payload that is already a dict (the "data" field held a JSON
    object rather than a string) is returned unchanged.

    Args:
        data_field: Value of the "data" field.

    Returns:
        The parsed dictionary, or None if the value is neither a string nor a
        dict, is not valid JSON, or does not decode to a JSON object.
    """
    if isinstance(data_field, dict):
        return data_field
    if not isinstance(data_field, str):
        # Previously this crashed inside the error handler when slicing a
        # non-string (e.g. None); report it and let the caller skip the line.
        logger.warning(f"Unsupported data field type: {type(data_field).__name__}")
        return None

    cleaned_data = data_field.strip()
    # Strip an opening fence together with its optional language tag.
    fence = re.match(r'```[A-Za-z0-9_+-]*', cleaned_data)
    if fence:
        cleaned_data = cleaned_data[fence.end():]
    if cleaned_data.endswith('```'):
        cleaned_data = cleaned_data[:-3]  # Remove trailing ```
    cleaned_data = cleaned_data.strip()

    try:
        parsed = json.loads(cleaned_data)
    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse JSON from data field: {e}")
        logger.warning(f"Data content (first 200 chars): {data_field[:200]}")
        return None

    # Valid JSON that is not an object (list, number, ...) cannot be flattened.
    if not isinstance(parsed, dict):
        logger.warning(f"Data field decoded to {type(parsed).__name__}, expected a JSON object")
        return None
    return parsed
def _flatten_item(query: Any, data_content: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build one flat training record from a query and its parsed "data" payload.

    Missing fields get defaults: empty ``pos``/``neg`` lists, the default
    prompt and type ``'normal'``. A score list that is missing, not a list,
    or whose length differs from its passage list is replaced by 1.0 per
    positive / 0.0 per negative passage.

    Raises:
        ValueError: if ``pos`` and ``neg`` are both empty, or are not lists.
    """
    pos = data_content.get('pos', [])
    neg = data_content.get('neg', [])
    if not pos and not neg:
        raise ValueError("No positive or negative passages found")
    if not isinstance(pos, list) or not isinstance(neg, list):
        raise ValueError("'pos' and 'neg' must be lists")

    pos_scores = data_content.get('pos_scores')
    if not isinstance(pos_scores, list) or len(pos_scores) != len(pos):
        pos_scores = [1.0] * len(pos)
    neg_scores = data_content.get('neg_scores')
    if not isinstance(neg_scores, list) or len(neg_scores) != len(neg):
        neg_scores = [0.0] * len(neg)

    # Key order matches the documented output format.
    return {
        'query': query,
        'pos': pos,
        'neg': neg,
        'pos_scores': pos_scores,
        'neg_scores': neg_scores,
        'prompt': data_content.get('prompt', '为此查询生成表示:'),
        'type': data_content.get('type', 'normal'),
    }


def _convert_line(line: str) -> Dict[str, Any]:
    """
    Convert one non-empty nested-format JSONL line into a flat record.

    Raises:
        ValueError: with a human-readable reason when the line must be skipped.
    """
    try:
        item = json.loads(line)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON - {e}") from e
    # Guards against e.g. a bare JSON string, where `'query' in item` would be
    # a substring test rather than a key lookup.
    if not isinstance(item, dict):
        raise ValueError("Expected a JSON object")
    if 'query' not in item:
        raise ValueError("Missing 'query' field")
    if 'data' not in item:
        raise ValueError("Missing 'data' field")
    data_content = extract_json_from_data_field(item['data'])
    # isinstance also rejects a payload that decoded to a non-object.
    if not isinstance(data_content, dict):
        raise ValueError("Failed to parse data field")
    return _flatten_item(item['query'], data_content)


def convert_dataset_format(input_file: str, output_file: str) -> None:
    """
    Convert dataset from nested format to flat format.

    Blank lines are ignored. Lines that cannot be converted are logged with
    their line number, counted as errors, and skipped.

    Args:
        input_file: Path to input JSONL file with nested format
        output_file: Path to output JSONL file with flat format (overwritten)

    Raises:
        ValueError: if output_file refers to the same file as input_file.
        OSError: if a file cannot be opened, or writing the output fails.
    """
    logger.info(f"Converting dataset from {input_file} to {output_file}")
    # Opening output_file with 'w' truncates it; if it is the input file the
    # data would be wiped before it is read.
    if os.path.exists(output_file) and os.path.samefile(input_file, output_file):
        raise ValueError(f"Input and output refer to the same file: {input_file}")

    # First pass only counts lines so the progress bar can show a total.
    with open(input_file, 'r', encoding='utf-8') as f:
        total_lines = sum(1 for _ in f)

    converted_count = 0
    error_count = 0
    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:
        for line_num, line in enumerate(tqdm(infile, total=total_lines, desc="Converting"), 1):
            line = line.strip()
            if not line:
                continue
            try:
                flattened_item = _convert_line(line)
            except ValueError as e:
                logger.warning(f"Line {line_num}: {e}")
                error_count += 1
                continue
            except Exception as e:  # best effort: one bad record must not abort the run
                logger.warning(f"Line {line_num}: Unexpected error - {e}")
                error_count += 1
                continue
            # Outside the try so that disk errors propagate instead of being
            # counted as bad records.
            outfile.write(json.dumps(flattened_item, ensure_ascii=False) + '\n')
            converted_count += 1

    logger.info("Conversion complete!")
    logger.info(f"Successfully converted: {converted_count} items")
    logger.info(f"Errors encountered: {error_count} items")
    logger.info(f"Output saved to: {output_file}")
def validate_converted_file(file_path: str, sample_size: int = 5) -> None:
    """
    Validate the converted file by showing a few sample entries.

    Logs the total number of entries, then a short summary (query, passage
    counts, type, prompt, first positive/negative) of the first
    ``sample_size`` lines. The file is streamed rather than loaded whole, so
    memory use stays bounded for large datasets. Problems are logged, never
    raised: an unreadable file is reported as an error, and a malformed
    sample line is reported and skipped.

    Args:
        file_path: Path to the converted JSONL file
        sample_size: Number of sample entries to display
    """
    logger.info(f"Validating converted file: {file_path}")
    samples = []
    total_lines = 0
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if total_lines < sample_size:
                    samples.append(line)
                total_lines += 1
    except (OSError, UnicodeDecodeError) as e:
        logger.error(f"Error validating file: {e}")
        return

    logger.info(f"Total converted entries: {total_lines}")
    if total_lines == 0:
        logger.warning("No entries found in converted file!")
        return

    for idx, line in enumerate(samples):
        try:
            item = json.loads(line)
        except json.JSONDecodeError as e:
            logger.warning(f"Invalid JSON at line {idx+1}: {e}")
            continue
        if not isinstance(item, dict):
            logger.warning(f"Line {idx+1}: expected a JSON object, got {type(item).__name__}")
            continue
        # Tolerate missing or non-list passage fields instead of aborting
        # the whole validation.
        pos = item.get('pos')
        pos = pos if isinstance(pos, list) else []
        neg = item.get('neg')
        neg = neg if isinstance(neg, list) else []
        logger.info(f"\nSample {idx+1} (line {idx+1}):")
        logger.info(f" Query: {str(item.get('query', ''))[:100]}...")
        logger.info(f" Positives: {len(pos)}")
        logger.info(f" Negatives: {len(neg)}")
        logger.info(f" Type: {item.get('type', 'N/A')}")
        logger.info(f" Prompt: {item.get('prompt', 'N/A')}")
        # Show first positive and negative if available
        if pos:
            logger.info(f" First positive: {str(pos[0])[:150]}...")
        if neg:
            logger.info(f" First negative: {str(neg[0])[:150]}...")
def main(argv: Optional[list] = None) -> int:
    """
    Command-line entry point: convert a dataset and optionally validate it.

    Args:
        argv: Argument list to parse. Defaults to ``sys.argv[1:]`` (argparse's
            behaviour when given None).

    Returns:
        0 on success, 1 if the input file is missing or conversion fails.
    """
    parser = argparse.ArgumentParser(
        description="Convert BGE-M3 dataset from nested to flat format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python convert_m3_dataset.py data/datasets/三国演义/m3_train.jsonl data/datasets/三国演义/m3_train_converted.jsonl
python convert_m3_dataset.py input.jsonl output.jsonl --validate
"""
    )
    parser.add_argument(
        'input_file',
        help='Path to input JSONL file with nested format'
    )
    parser.add_argument(
        'output_file',
        help='Path to output JSONL file with flat format'
    )
    parser.add_argument(
        '--validate',
        action='store_true',
        help='Validate the converted file and show samples'
    )
    parser.add_argument(
        '--sample-size',
        type=int,
        default=5,
        help='Number of sample entries to show during validation (default: 5)'
    )
    args = parser.parse_args(argv)

    # isfile (not exists) so that a directory is rejected up front.
    if not os.path.isfile(args.input_file):
        logger.error(f"Input file not found: {args.input_file}")
        return 1

    try:
        # Inside the try so a failure here yields exit code 1, not a traceback.
        output_dir = os.path.dirname(args.output_file)
        if output_dir and not os.path.exists(output_dir):
            # exist_ok avoids a race if the directory appears concurrently.
            os.makedirs(output_dir, exist_ok=True)
            logger.info(f"Created output directory: {output_dir}")

        convert_dataset_format(args.input_file, args.output_file)
        if args.validate:
            validate_converted_file(args.output_file, args.sample_size)
    except Exception as e:  # top-level boundary: report and signal failure
        logger.exception(f"Dataset conversion failed: {e}")
        return 1

    logger.info("Dataset conversion completed successfully!")
    return 0
if __name__ == '__main__':
    # sys.exit rather than the site-injected `exit`, which is absent under
    # `python -S` and in frozen executables.
    import sys
    sys.exit(main())