#!/usr/bin/env python3
"""
BGE-M3 Dataset Format Converter

Converts the nested dataset format, where training data is stored as a JSON
string inside a "data" field, to the flat format expected by the BGE-M3
training pipeline.

Input format:
{
    "query": "...",
    "data": "{\"pos\": [...], \"neg\": [...], \"pos_scores\": [...], ...}"
}

Output format:
{
    "query": "...",
    "pos": [...],
    "neg": [...],
    "pos_scores": [...],
    "neg_scores": [...],
    "prompt": "...",
    "type": "..."
}
"""
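
# Illustrative end-to-end example (all values hypothetical). An input line such as
#   {"query": "Who burned the fleet at Red Cliffs?",
#    "data": "```json\n{\"pos\": [\"Zhou Yu ...\"], \"neg\": [\"Liu Bei ...\"]}\n```"}
# is converted to
#   {"query": "Who burned the fleet at Red Cliffs?", "pos": ["Zhou Yu ..."],
#    "neg": ["Liu Bei ..."], "pos_scores": [1.0], "neg_scores": [0.0],
#    "prompt": "为此查询生成表示:", "type": "normal"}
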
import argparse
import json
import logging
import os
import sys
from typing import Any, Dict, Optional

from tqdm import tqdm

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def extract_json_from_data_field(data_field: str) -> Optional[Dict[str, Any]]:
    """
    Extract JSON data from the "data" field, which contains a JSON string
    wrapped in markdown code blocks.

    Args:
        data_field: String containing JSON data, possibly wrapped in markdown

    Returns:
        Parsed JSON dictionary, or None if parsing fails
    """
    try:
        # Remove markdown code block markers if present
        cleaned_data = data_field.strip()

        # Remove ```json and ``` markers
        if cleaned_data.startswith('```json'):
            cleaned_data = cleaned_data[len('```json'):]
        elif cleaned_data.startswith('```'):
            cleaned_data = cleaned_data[len('```'):]

        if cleaned_data.endswith('```'):
            cleaned_data = cleaned_data[:-len('```')]

        # Clean up any remaining whitespace, then parse the JSON
        cleaned_data = cleaned_data.strip()
        return json.loads(cleaned_data)

    except (json.JSONDecodeError, AttributeError) as e:
        logger.warning(f"Failed to parse JSON from data field: {e}")
        logger.warning(f"Data content (first 200 chars): {data_field[:200]}")
        return None
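

# A quick illustrative call (hypothetical input), showing the fence stripping:
#   extract_json_from_data_field('```json\n{"pos": ["p1"], "neg": ["n1"]}\n```')
#   -> {'pos': ['p1'], 'neg': ['n1']}
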
def convert_dataset_format(input_file: str, output_file: str) -> None:
    """
    Convert a dataset from the nested format to the flat format.

    Args:
        input_file: Path to input JSONL file with nested format
        output_file: Path to output JSONL file with flat format
    """
    logger.info(f"Converting dataset from {input_file} to {output_file}")

    # Count total lines for the progress bar
    with open(input_file, 'r', encoding='utf-8') as f:
        total_lines = sum(1 for _ in f)

    converted_count = 0
    error_count = 0

    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:

        for line_num, line in enumerate(tqdm(infile, total=total_lines, desc="Converting"), 1):
            line = line.strip()
            if not line:
                continue

            try:
                # Parse the input JSON line
                item = json.loads(line)

                # Validate required fields
                if 'query' not in item:
                    logger.warning(f"Line {line_num}: Missing 'query' field")
                    error_count += 1
                    continue

                if 'data' not in item:
                    logger.warning(f"Line {line_num}: Missing 'data' field")
                    error_count += 1
                    continue

                # Extract data from the nested field
                data_content = extract_json_from_data_field(item['data'])
                if data_content is None:
                    logger.warning(f"Line {line_num}: Failed to parse data field")
                    error_count += 1
                    continue

                # Create the flattened format
                flattened_item = {'query': item['query']}

                # Copy all expected fields from the data content, filling in
                # reasonable defaults for any that are missing
                expected_fields = ['pos', 'neg', 'pos_scores', 'neg_scores', 'prompt', 'type']
                for field in expected_fields:
                    if field in data_content:
                        flattened_item[field] = data_content[field]
                    elif field in ('pos', 'neg'):
                        flattened_item[field] = []
                    elif field == 'pos_scores':
                        flattened_item[field] = [1.0] * len(data_content.get('pos', []))
                    elif field == 'neg_scores':
                        flattened_item[field] = [0.0] * len(data_content.get('neg', []))
                    elif field == 'prompt':
                        # Default Chinese instruction prompt:
                        # "Generate a representation for this query:"
                        flattened_item[field] = '为此查询生成表示:'
                    elif field == 'type':
                        flattened_item[field] = 'normal'

                # Validate the converted item
                if not flattened_item['pos'] and not flattened_item['neg']:
                    logger.warning(f"Line {line_num}: No positive or negative passages found")
                    error_count += 1
                    continue

                # Ensure score arrays match the passage arrays; regenerate
                # defaults when the lengths disagree
                if len(flattened_item['pos_scores']) != len(flattened_item['pos']):
                    flattened_item['pos_scores'] = [1.0] * len(flattened_item['pos'])

                if len(flattened_item['neg_scores']) != len(flattened_item['neg']):
                    flattened_item['neg_scores'] = [0.0] * len(flattened_item['neg'])

                # Write the flattened item
                outfile.write(json.dumps(flattened_item, ensure_ascii=False) + '\n')
                converted_count += 1

            except json.JSONDecodeError as e:
                logger.warning(f"Line {line_num}: Invalid JSON - {e}")
                error_count += 1
            except Exception as e:
                logger.warning(f"Line {line_num}: Unexpected error - {e}")
                error_count += 1

    logger.info("Conversion complete!")
    logger.info(f"Successfully converted: {converted_count} items")
    logger.info(f"Errors encountered: {error_count} items")
    logger.info(f"Output saved to: {output_file}")
def validate_converted_file(file_path: str, sample_size: int = 5) -> None:
    """
    Validate the converted file by showing a few sample entries.

    Args:
        file_path: Path to the converted JSONL file
        sample_size: Number of sample entries to display
    """
    logger.info(f"Validating converted file: {file_path}")

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        total_lines = len(lines)
        logger.info(f"Total converted entries: {total_lines}")

        if total_lines == 0:
            logger.warning("No entries found in converted file!")
            return

        # Show the first few entries as samples
        sample_indices = range(min(sample_size, total_lines))

        for i, idx in enumerate(sample_indices):
            try:
                item = json.loads(lines[idx])
                logger.info(f"\nSample {i + 1} (line {idx + 1}):")
                logger.info(f"  Query: {item['query'][:100]}...")
                logger.info(f"  Positives: {len(item.get('pos', []))}")
                logger.info(f"  Negatives: {len(item.get('neg', []))}")
                logger.info(f"  Type: {item.get('type', 'N/A')}")
                logger.info(f"  Prompt: {item.get('prompt', 'N/A')}")

                # Show the first positive and negative passage if available
                if item.get('pos'):
                    logger.info(f"  First positive: {item['pos'][0][:150]}...")
                if item.get('neg'):
                    logger.info(f"  First negative: {item['neg'][0][:150]}...")

            except json.JSONDecodeError as e:
                logger.warning(f"Invalid JSON at line {idx + 1}: {e}")

    except Exception as e:
        logger.error(f"Error validating file: {e}")


def main():
    parser = argparse.ArgumentParser(
        description="Convert BGE-M3 dataset from nested to flat format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python convert_m3_dataset.py data/datasets/三国演义/m3_train.jsonl data/datasets/三国演义/m3_train_converted.jsonl
  python convert_m3_dataset.py input.jsonl output.jsonl --validate
"""
    )

    parser.add_argument(
        'input_file',
        help='Path to input JSONL file with nested format'
    )
    parser.add_argument(
        'output_file',
        help='Path to output JSONL file with flat format'
    )
    parser.add_argument(
        '--validate',
        action='store_true',
        help='Validate the converted file and show samples'
    )
    parser.add_argument(
        '--sample-size',
        type=int,
        default=5,
        help='Number of sample entries to show during validation (default: 5)'
    )

    args = parser.parse_args()

    # Check that the input file exists
    if not os.path.exists(args.input_file):
        logger.error(f"Input file not found: {args.input_file}")
        return 1

    # Create the output directory if it doesn't exist
    output_dir = os.path.dirname(args.output_file)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
        logger.info(f"Created output directory: {output_dir}")

    try:
        # Convert the dataset
        convert_dataset_format(args.input_file, args.output_file)

        # Validate if requested
        if args.validate:
            validate_converted_file(args.output_file, args.sample_size)

        logger.info("Dataset conversion completed successfully!")
        return 0

    except Exception as e:
        logger.error(f"Dataset conversion failed: {e}")
        return 1


if __name__ == '__main__':
    sys.exit(main())