#!/usr/bin/env python3
"""
BGE-M3 Dataset Format Converter

Converts the nested dataset format where training data is stored inside a
"data" field to the flat format expected by the BGE-M3 training pipeline.

Input format:
{
    "query": "...",
    "data": "{\"pos\": [...], \"neg\": [...], \"pos_scores\": [...], ...}"
}

Output format:
{
    "query": "...",
    "pos": [...],
    "neg": [...],
    "pos_scores": [...],
    "neg_scores": [...],
    "prompt": "...",
    "type": "..."
}
"""

import argparse
import json
import logging
import os
import sys
from typing import Any, Dict, Optional

try:
    from tqdm import tqdm
except ImportError:  # tqdm is cosmetic; fall back to a plain iterator
    def tqdm(iterable, *args, **kwargs):
        """No-op stand-in so conversion still works without tqdm installed."""
        return iterable

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def extract_json_from_data_field(data_field: str) -> Optional[Dict[str, Any]]:
    """
    Extract JSON data from the "data" field which contains a JSON string
    possibly wrapped in markdown code blocks.

    Args:
        data_field: String containing JSON data, possibly wrapped in markdown

    Returns:
        Parsed JSON dictionary, or None if parsing fails or the payload
        is valid JSON but not an object.
    """
    try:
        cleaned_data = data_field.strip()

        # Strip markdown code-fence markers (```json ... ``` or ``` ... ```)
        if cleaned_data.startswith('```json'):
            cleaned_data = cleaned_data[7:]
        elif cleaned_data.startswith('```'):
            cleaned_data = cleaned_data[3:]
        if cleaned_data.endswith('```'):
            cleaned_data = cleaned_data[:-3]

        parsed = json.loads(cleaned_data.strip())

        # The pipeline expects a JSON object; reject arrays/scalars here
        # instead of letting them crash downstream attribute access.
        if not isinstance(parsed, dict):
            logger.warning("Parsed data field is not a JSON object")
            return None
        return parsed
    except (json.JSONDecodeError, AttributeError) as e:
        logger.warning(f"Failed to parse JSON from data field: {e}")
        logger.warning(f"Data content (first 200 chars): {data_field[:200]}")
        return None


def _flatten_item(query: str, data_content: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build one flat training record from a query and its parsed data payload.

    Missing fields get defaults: empty passage lists, 1.0/0.0 scores,
    the standard retrieval prompt, and type "normal". Score lists whose
    length does not match their passage list are regenerated from defaults.
    """
    pos = data_content.get('pos', [])
    neg = data_content.get('neg', [])

    pos_scores = data_content.get('pos_scores', [1.0] * len(pos))
    neg_scores = data_content.get('neg_scores', [0.0] * len(neg))

    # Ensure scores arrays match passage arrays.
    if len(pos_scores) != len(pos):
        pos_scores = [1.0] * len(pos)
    if len(neg_scores) != len(neg):
        neg_scores = [0.0] * len(neg)

    # Key order matters only cosmetically, but keep the original layout.
    return {
        'query': query,
        'pos': pos,
        'neg': neg,
        'pos_scores': pos_scores,
        'neg_scores': neg_scores,
        'prompt': data_content.get('prompt', '为此查询生成表示:'),
        'type': data_content.get('type', 'normal'),
    }


def convert_dataset_format(input_file: str, output_file: str) -> None:
    """
    Convert dataset from nested format to flat format.

    Args:
        input_file: Path to input JSONL file with nested format
        output_file: Path to output JSONL file with flat format
    """
    logger.info(f"Converting dataset from {input_file} to {output_file}")

    # Count lines in a first pass so the progress bar has a total.
    with open(input_file, 'r', encoding='utf-8') as f:
        total_lines = sum(1 for _ in f)

    converted_count = 0
    error_count = 0

    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:

        for line_num, line in enumerate(
                tqdm(infile, total=total_lines, desc="Converting"), 1):
            line = line.strip()
            if not line:
                continue

            try:
                item = json.loads(line)

                # Validate required fields before doing any work.
                if 'query' not in item:
                    logger.warning(f"Line {line_num}: Missing 'query' field")
                    error_count += 1
                    continue
                if 'data' not in item:
                    logger.warning(f"Line {line_num}: Missing 'data' field")
                    error_count += 1
                    continue

                data_content = extract_json_from_data_field(item['data'])
                if data_content is None:
                    logger.warning(f"Line {line_num}: Failed to parse data field")
                    error_count += 1
                    continue

                flattened_item = _flatten_item(item['query'], data_content)

                # A record with no passages at all is useless for training.
                if not flattened_item['pos'] and not flattened_item['neg']:
                    logger.warning(f"Line {line_num}: No positive or negative passages found")
                    error_count += 1
                    continue

                outfile.write(json.dumps(flattened_item, ensure_ascii=False) + '\n')
                converted_count += 1

            except json.JSONDecodeError as e:
                logger.warning(f"Line {line_num}: Invalid JSON - {e}")
                error_count += 1
            except Exception as e:
                # Best-effort conversion: log and keep going on surprises.
                logger.warning(f"Line {line_num}: Unexpected error - {e}")
                error_count += 1

    logger.info("Conversion complete!")
    logger.info(f"Successfully converted: {converted_count} items")
    logger.info(f"Errors encountered: {error_count} items")
    logger.info(f"Output saved to: {output_file}")


def validate_converted_file(file_path: str, sample_size: int = 5) -> None:
    """
    Validate the converted file by showing a few sample entries.

    Args:
        file_path: Path to the converted JSONL file
        sample_size: Number of sample entries to display
    """
    logger.info(f"Validating converted file: {file_path}")

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        total_lines = len(lines)
        logger.info(f"Total converted entries: {total_lines}")

        if total_lines == 0:
            logger.warning("No entries found in converted file!")
            return

        # Show the first `sample_size` entries (or fewer if the file is short).
        for i, raw_line in enumerate(lines[:sample_size]):
            try:
                item = json.loads(raw_line)
                logger.info(f"\nSample {i+1} (line {i+1}):")
                logger.info(f"  Query: {item['query'][:100]}...")
                logger.info(f"  Positives: {len(item.get('pos', []))}")
                logger.info(f"  Negatives: {len(item.get('neg', []))}")
                logger.info(f"  Type: {item.get('type', 'N/A')}")
                logger.info(f"  Prompt: {item.get('prompt', 'N/A')}")

                # Show first positive and negative if available.
                if item.get('pos'):
                    logger.info(f"  First positive: {item['pos'][0][:150]}...")
                if item.get('neg'):
                    logger.info(f"  First negative: {item['neg'][0][:150]}...")
            except json.JSONDecodeError as e:
                logger.warning(f"Invalid JSON at line {i+1}: {e}")

    except Exception as e:
        logger.error(f"Error validating file: {e}")


def main() -> int:
    """CLI entry point. Returns a process exit code (0 success, 1 failure)."""
    parser = argparse.ArgumentParser(
        description="Convert BGE-M3 dataset from nested to flat format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python convert_m3_dataset.py data/datasets/三国演义/m3_train.jsonl data/datasets/三国演义/m3_train_converted.jsonl
  python convert_m3_dataset.py input.jsonl output.jsonl --validate
"""
    )
    parser.add_argument(
        'input_file',
        help='Path to input JSONL file with nested format'
    )
    parser.add_argument(
        'output_file',
        help='Path to output JSONL file with flat format'
    )
    parser.add_argument(
        '--validate',
        action='store_true',
        help='Validate the converted file and show samples'
    )
    parser.add_argument(
        '--sample-size',
        type=int,
        default=5,
        help='Number of sample entries to show during validation (default: 5)'
    )

    args = parser.parse_args()

    if not os.path.exists(args.input_file):
        logger.error(f"Input file not found: {args.input_file}")
        return 1

    # Create the output directory if needed; exist_ok guards the race
    # between the existence check and the creation.
    output_dir = os.path.dirname(args.output_file)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Created output directory: {output_dir}")

    try:
        convert_dataset_format(args.input_file, args.output_file)

        if args.validate:
            validate_converted_file(args.output_file, args.sample_size)

        logger.info("Dataset conversion completed successfully!")
        return 0
    except Exception as e:
        logger.error(f"Dataset conversion failed: {e}")
        return 1


if __name__ == '__main__':
    # sys.exit is always available; the site-injected exit() is not.
    sys.exit(main())