init commit
This commit is contained in:
291
scripts/convert_m3_dataset.py
Normal file
291
scripts/convert_m3_dataset.py
Normal file
@@ -0,0 +1,291 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
BGE-M3 Dataset Format Converter
|
||||
|
||||
Converts the nested dataset format where training data is stored inside a "data" field
|
||||
to the flat format expected by the BGE-M3 training pipeline.
|
||||
|
||||
Input format:
|
||||
{
|
||||
"query": "...",
|
||||
"data": "{\"pos\": [...], \"neg\": [...], \"pos_scores\": [...], ...}"
|
||||
}
|
||||
|
||||
Output format:
|
||||
{
|
||||
"query": "...",
|
||||
"pos": [...],
|
||||
"neg": [...],
|
||||
"pos_scores": [...],
|
||||
"neg_scores": [...],
|
||||
"prompt": "...",
|
||||
"type": "..."
|
||||
}
|
||||
"""
|
||||
|
||||
import argparse
import json
import logging
import os
import re
import sys
from typing import Any, Dict, Optional

from tqdm import tqdm
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_json_from_data_field(data_field: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract JSON data from the "data" field which contains a JSON string
|
||||
wrapped in markdown code blocks.
|
||||
|
||||
Args:
|
||||
data_field: String containing JSON data, possibly wrapped in markdown
|
||||
|
||||
Returns:
|
||||
Parsed JSON dictionary or None if parsing fails
|
||||
"""
|
||||
try:
|
||||
# Remove markdown code block markers if present
|
||||
cleaned_data = data_field.strip()
|
||||
|
||||
# Remove ```json and ``` markers
|
||||
if cleaned_data.startswith('```json'):
|
||||
cleaned_data = cleaned_data[7:] # Remove ```json
|
||||
elif cleaned_data.startswith('```'):
|
||||
cleaned_data = cleaned_data[3:] # Remove ```
|
||||
|
||||
if cleaned_data.endswith('```'):
|
||||
cleaned_data = cleaned_data[:-3] # Remove trailing ```
|
||||
|
||||
# Clean up any remaining whitespace
|
||||
cleaned_data = cleaned_data.strip()
|
||||
|
||||
# Parse the JSON
|
||||
return json.loads(cleaned_data)
|
||||
|
||||
except (json.JSONDecodeError, AttributeError) as e:
|
||||
logger.warning(f"Failed to parse JSON from data field: {e}")
|
||||
logger.warning(f"Data content (first 200 chars): {data_field[:200]}")
|
||||
return None
|
||||
|
||||
|
||||
def convert_dataset_format(input_file: str, output_file: str) -> None:
|
||||
"""
|
||||
Convert dataset from nested format to flat format.
|
||||
|
||||
Args:
|
||||
input_file: Path to input JSONL file with nested format
|
||||
output_file: Path to output JSONL file with flat format
|
||||
"""
|
||||
logger.info(f"Converting dataset from {input_file} to {output_file}")
|
||||
|
||||
# Count total lines for progress bar
|
||||
total_lines = 0
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for _ in f:
|
||||
total_lines += 1
|
||||
|
||||
converted_count = 0
|
||||
error_count = 0
|
||||
|
||||
with open(input_file, 'r', encoding='utf-8') as infile, \
|
||||
open(output_file, 'w', encoding='utf-8') as outfile:
|
||||
|
||||
for line_num, line in enumerate(tqdm(infile, total=total_lines, desc="Converting"), 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Parse the input JSON line
|
||||
item = json.loads(line)
|
||||
|
||||
# Validate required fields
|
||||
if 'query' not in item:
|
||||
logger.warning(f"Line {line_num}: Missing 'query' field")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
if 'data' not in item:
|
||||
logger.warning(f"Line {line_num}: Missing 'data' field")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
# Extract data from the nested field
|
||||
data_content = extract_json_from_data_field(item['data'])
|
||||
if data_content is None:
|
||||
logger.warning(f"Line {line_num}: Failed to parse data field")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
# Create the flattened format
|
||||
flattened_item = {
|
||||
'query': item['query']
|
||||
}
|
||||
|
||||
# Add all fields from the data content
|
||||
expected_fields = ['pos', 'neg', 'pos_scores', 'neg_scores', 'prompt', 'type']
|
||||
for field in expected_fields:
|
||||
if field in data_content:
|
||||
flattened_item[field] = data_content[field]
|
||||
else:
|
||||
# Provide reasonable defaults
|
||||
if field == 'pos':
|
||||
flattened_item[field] = []
|
||||
elif field == 'neg':
|
||||
flattened_item[field] = []
|
||||
elif field == 'pos_scores':
|
||||
flattened_item[field] = [1.0] * len(data_content.get('pos', []))
|
||||
elif field == 'neg_scores':
|
||||
flattened_item[field] = [0.0] * len(data_content.get('neg', []))
|
||||
elif field == 'prompt':
|
||||
flattened_item[field] = '为此查询生成表示:'
|
||||
elif field == 'type':
|
||||
flattened_item[field] = 'normal'
|
||||
|
||||
# Validate the converted item
|
||||
if not flattened_item['pos'] and not flattened_item['neg']:
|
||||
logger.warning(f"Line {line_num}: No positive or negative passages found")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
# Ensure scores arrays match passage arrays
|
||||
if len(flattened_item['pos_scores']) != len(flattened_item['pos']):
|
||||
flattened_item['pos_scores'] = [1.0] * len(flattened_item['pos'])
|
||||
|
||||
if len(flattened_item['neg_scores']) != len(flattened_item['neg']):
|
||||
flattened_item['neg_scores'] = [0.0] * len(flattened_item['neg'])
|
||||
|
||||
# Write the flattened item
|
||||
outfile.write(json.dumps(flattened_item, ensure_ascii=False) + '\n')
|
||||
converted_count += 1
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Line {line_num}: Invalid JSON - {e}")
|
||||
error_count += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"Line {line_num}: Unexpected error - {e}")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
logger.info(f"Conversion complete!")
|
||||
logger.info(f"Successfully converted: {converted_count} items")
|
||||
logger.info(f"Errors encountered: {error_count} items")
|
||||
logger.info(f"Output saved to: {output_file}")
|
||||
|
||||
|
||||
def validate_converted_file(file_path: str, sample_size: int = 5) -> None:
|
||||
"""
|
||||
Validate the converted file by showing a few sample entries.
|
||||
|
||||
Args:
|
||||
file_path: Path to the converted JSONL file
|
||||
sample_size: Number of sample entries to display
|
||||
"""
|
||||
logger.info(f"Validating converted file: {file_path}")
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
total_lines = len(lines)
|
||||
logger.info(f"Total converted entries: {total_lines}")
|
||||
|
||||
if total_lines == 0:
|
||||
logger.warning("No entries found in converted file!")
|
||||
return
|
||||
|
||||
# Show sample entries
|
||||
sample_indices = list(range(min(sample_size, total_lines)))
|
||||
|
||||
for i, idx in enumerate(sample_indices):
|
||||
try:
|
||||
item = json.loads(lines[idx])
|
||||
logger.info(f"\nSample {i+1} (line {idx+1}):")
|
||||
logger.info(f" Query: {item['query'][:100]}...")
|
||||
logger.info(f" Positives: {len(item.get('pos', []))}")
|
||||
logger.info(f" Negatives: {len(item.get('neg', []))}")
|
||||
logger.info(f" Type: {item.get('type', 'N/A')}")
|
||||
logger.info(f" Prompt: {item.get('prompt', 'N/A')}")
|
||||
|
||||
# Show first positive and negative if available
|
||||
if item.get('pos'):
|
||||
logger.info(f" First positive: {item['pos'][0][:150]}...")
|
||||
if item.get('neg'):
|
||||
logger.info(f" First negative: {item['neg'][0][:150]}...")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Invalid JSON at line {idx+1}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating file: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert BGE-M3 dataset from nested to flat format",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python convert_m3_dataset.py data/datasets/三国演义/m3_train.jsonl data/datasets/三国演义/m3_train_converted.jsonl
|
||||
python convert_m3_dataset.py input.jsonl output.jsonl --validate
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'input_file',
|
||||
help='Path to input JSONL file with nested format'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'output_file',
|
||||
help='Path to output JSONL file with flat format'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--validate',
|
||||
action='store_true',
|
||||
help='Validate the converted file and show samples'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--sample-size',
|
||||
type=int,
|
||||
default=5,
|
||||
help='Number of sample entries to show during validation (default: 5)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Check if input file exists
|
||||
if not os.path.exists(args.input_file):
|
||||
logger.error(f"Input file not found: {args.input_file}")
|
||||
return 1
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
output_dir = os.path.dirname(args.output_file)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
logger.info(f"Created output directory: {output_dir}")
|
||||
|
||||
try:
|
||||
# Convert the dataset
|
||||
convert_dataset_format(args.input_file, args.output_file)
|
||||
|
||||
# Validate if requested
|
||||
if args.validate:
|
||||
validate_converted_file(args.output_file, args.sample_size)
|
||||
|
||||
logger.info("Dataset conversion completed successfully!")
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Dataset conversion failed: {e}")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
exit(main())
|
||||
Reference in New Issue
Block a user