# bge_finetune/data/preprocessing.py

import json
import logging
import os
import random
import re
import xml.etree.ElementTree as ET
from typing import Dict, List, Optional, Tuple

import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer
logger = logging.getLogger(__name__)
class DataPreprocessor:
"""
Comprehensive data preprocessor for various input formats
Supports JSON, JSONL, CSV, TSV, XML, and custom formats
Notes:
- All parameters (tokenizer, max_query_length, max_passage_length, min_passage_length, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
tokenizer: Optional[AutoTokenizer] = None,
max_query_length: int = 512,
max_passage_length: int = 8192,
# Permissive default for experimentation; for production data a
# floor of around 10 tokens is recommended
min_passage_length: int = 3
):
self.tokenizer = tokenizer
self.max_query_length = max_query_length
self.max_passage_length = max_passage_length
self.min_passage_length = min_passage_length
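def _warn_low_yield(self, file_path: str, kept: int, total: int) -> None:
    """Warn when suspiciously few rows survive parsing.
    Factored out of the per-format loaders below; the thresholds
    (fewer than 10 kept rows, or under 1% of input lines) match the
    inline checks the loaders previously repeated, so behavior is unchanged.
    """
    if total > 0 and (kept < 10 or kept < 0.01 * total):
        logger.warning(
            f"Too few valid data lines loaded from {file_path}: {kept} out of {total}."
        )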
def preprocess_file(
self,
input_path: str,
output_path: str,
file_format: str = "auto",
validation_split: float = 0.1
) -> Tuple[str, Optional[str]]:
"""
Preprocess input file and convert to standard JSONL format.
Args:
input_path: Path to input file
output_path: Path to output JSONL file
file_format: Input file format (auto, json, jsonl, csv, tsv, xml, dureader, msmarco, beir)
validation_split: Fraction of data to use for validation
Returns:
Tuple of (train_path, val_path)
"""
# Auto-detect format
if file_format == "auto":
file_format = self._detect_format(input_path)
logger.info(f"Processing {input_path} as {file_format} format")
# Load data based on format
if file_format == "jsonl":
data = self._load_jsonl(input_path)
elif file_format == "json":
data = self._load_json(input_path)
elif file_format in ["csv", "tsv"]:
data = self._load_csv(input_path, delimiter="," if file_format == "csv" else "\t")
elif file_format == "xml":
data = self._load_xml(input_path)
elif file_format == "dureader":
data = self._load_dureader(input_path)
elif file_format == "msmarco":
data = self._load_msmarco(input_path)
elif file_format == "beir":
data = self._load_beir(input_path)
else:
raise ValueError(f"Unsupported format: {file_format}")
# Validate and clean data
data = self._validate_and_clean(data)
# Split into train/val
if validation_split > 0:
train_data, val_data = self._split_data(data, validation_split)
# Derive split paths from the output stem so names without a
# ".jsonl" suffix still get distinct train/val files
base, ext = os.path.splitext(output_path)
ext = ext or ".jsonl"
train_path = f"{base}_train{ext}"
self._save_jsonl(train_data, train_path)
val_path = f"{base}_val{ext}"
self._save_jsonl(val_data, val_path)
logger.info(f"Saved {len(train_data)} training and {len(val_data)} validation examples")
return train_path, val_path
else:
self._save_jsonl(data, output_path)
logger.info(f"Saved {len(data)} examples to {output_path}")
return output_path, None
def _detect_format(self, file_path: str) -> str:
"""Auto-detect file format based on extension and content"""
ext = os.path.splitext(file_path)[1].lower()
# First try by extension
if ext == ".jsonl":
return "jsonl"
elif ext == ".json":
return "json"
elif ext == ".csv":
return "csv"
elif ext == ".tsv":
return "tsv"
elif ext == ".xml":
return "xml"
# Try to detect by content
try:
with open(file_path, 'r', encoding='utf-8') as f:
first_lines = [f.readline().strip() for _ in range(5)]
first_content = '\n'.join(first_lines)
# Check for JSON Lines
if all(line.startswith("{") and line.endswith("}") for line in first_lines if line):
return "jsonl"
# Check for JSON
if first_content.strip().startswith("[") or first_content.strip().startswith("{"):
return "json"
# Check for XML
if first_content.strip().startswith("<?xml") or first_content.strip().startswith("<"):
return "xml"
# Check for TSV (tabs)
if any("\t" in line for line in first_lines):
return "tsv"
# Check for CSV (commas)
if any("," in line for line in first_lines):
return "csv"
except Exception as e:
logger.warning(f"Could not auto-detect format for {file_path}: {e}")
# Default fallback
return "jsonl"
def _load_jsonl(self, file_path: str) -> List[Dict]:
"""Load JSONL file with robust error handling and detailed logging."""
data = []
total_lines = 0
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(tqdm(f, desc="Loading JSONL"), 1):
line = line.strip()
if line:
total_lines += 1
try:
data.append(json.loads(line))
except json.JSONDecodeError as e:
logger.warning(f"Skipping invalid JSON at {file_path}:{line_num}: {e} | Content: {line[:100]}")
except Exception as e:
logger.warning(f"Skipping malformed line at {file_path}:{line_num}: {e} | Content: {line[:100]}")
except Exception as e:
logger.error(f"Error loading JSONL file {file_path}: {e}")
raise
self._warn_low_yield(file_path, len(data), total_lines)
return data
def _load_json(self, file_path: str) -> List[Dict]:
"""Load JSON file"""
total_lines = 0
try:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Convert to list format if needed
if isinstance(data, dict):
if 'data' in data:
data = data['data']
else:
data = [data]
elif not isinstance(data, list):
data = [data]
total_lines = len(data)
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON in file {file_path}: {e}")
raise
except Exception as e:
logger.error(f"Error loading JSON file {file_path}: {e}")
raise
self._warn_low_yield(file_path, len(data), total_lines)
return data
def _load_csv(self, file_path: str, delimiter: str = ",") -> List[Dict]:
"""Load CSV/TSV file with robust parsing"""
data = []
total_lines = 0
try:
# Try different encodings
encodings = ['utf-8', 'gb2312', 'gbk', 'big5', 'latin-1']
df = None
for encoding in encodings:
try:
df = pd.read_csv(file_path, delimiter=delimiter, encoding=encoding)
logger.info(f"Successfully loaded CSV with {encoding} encoding")
break
except UnicodeDecodeError:
continue
except Exception as e:
logger.warning(f"Error with encoding {encoding} in {file_path}: {e}")
continue
if df is None:
raise ValueError(f"Could not load CSV file {file_path} with any encoding")
# Process each row
for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading CSV/TSV"):
total_lines += 1
try:
item = {}
# Map common column names to standard format
row_dict = {str(k).lower().strip(): v for k, v in row.items() if pd.notna(v)}
# Extract query
query_fields = ['query', 'question', 'q', 'text', 'sentence']
for field in query_fields:
if field in row_dict:
item['query'] = str(row_dict[field]).strip()
break
# Extract positive/negative passages. When both 'passage' and 'label'
# columns exist, the label decides polarity (1 -> positive, 0 -> negative);
# otherwise fall back to the generic column names. '|||' separates
# multiple passages within a single cell.
positives = []
negatives = []
has_labeled_passage = 'passage' in row_dict and 'label' in row_dict
if has_labeled_passage:
    try:
        label = int(float(row_dict['label']))
        passage = str(row_dict['passage']).strip()
        if passage:
            (positives if label == 1 else negatives).append(passage)
    except (ValueError, TypeError):
        pass
pos_fields = ['positive', 'pos', 'passage', 'document', 'answer', 'response']
for field in pos_fields:
    if field == 'passage' and has_labeled_passage:
        continue  # already routed by the label column above
    if field in row_dict:
        pos_text = str(row_dict[field]).strip()
        if '|||' in pos_text:
            positives.extend([p.strip() for p in pos_text.split('|||') if p.strip()])
        else:
            positives.append(pos_text)
        break
if positives:
    item['pos'] = positives
neg_fields = ['negative', 'neg', 'hard_negative']
for field in neg_fields:
    if field in row_dict:
        neg_text = str(row_dict[field]).strip()
        if '|||' in neg_text:
            negatives.extend([n.strip() for n in neg_text.split('|||') if n.strip()])
        else:
            negatives.append(neg_text)
        break
if negatives:
    item['neg'] = negatives
# Extract scores if available
score_fields = ['score', 'relevance', 'rating']
for field in score_fields:
    if field in row_dict:
        try:
            score = float(row_dict[field])
            if 'pos' in item:
                item['pos_scores'] = [score] * len(item['pos'])
            break
        except (ValueError, TypeError):
            pass
# Only add if we have minimum required fields
if 'query' in item and ('pos' in item or 'neg' in item):
data.append(item)
except Exception as e:
logger.warning(f"Skipping malformed row in {file_path} at index {idx}: {e} | Row: {row}")
except Exception as e:
logger.error(f"Error loading CSV file {file_path}: {e}")
raise
self._warn_low_yield(file_path, len(data), total_lines)
return data
def _load_tsv(self, file_path: str) -> List[Dict]:
"""Load TSV file (wrapper around _load_csv)."""
return self._load_csv(file_path, delimiter="\t")
def _load_xml(self, file_path: str) -> List[Dict]:
"""Load XML file (e.g., TREC format)"""
data = []
total_lines = 0
try:
tree = ET.parse(file_path)
root = tree.getroot()
# Handle different XML structures
if root.tag in ('questions', 'queries'):
    # TREC-style format
    questions = root.findall('.//question') or root.findall('.//query')
    for question in tqdm(questions, desc="Loading XML"):
        total_lines += 1
        item = {}
        # Extract the query; avoid `find(a) or find(b)` chaining, because
        # an Element with no children is falsy even when it carries text
        query_elem = None
        for tag in ('text', 'title', 'query'):
            query_elem = question.find(tag)
            if query_elem is not None:
                break
        if query_elem is not None and query_elem.text is not None:
            item['query'] = query_elem.text.strip()
        # Extract passages/documents
        passages = []
        for doc in question.findall('.//document') or question.findall('.//passage'):
            if doc.text:
                passages.append(doc.text.strip())
        if passages:
            item['pos'] = passages
        if 'query' in item:
            data.append(item)
else:
    # Generic XML - treat any sufficiently long text node as a pseudo example
    for elem in tqdm(root.iter(), desc="Processing XML"):
        if elem.text and len(elem.text.strip()) > 20:  # Minimum length check
            total_lines += 1
            item = {
                'query': elem.text.strip()[:100],  # First part as query
                'pos': [elem.text.strip()]  # Full text as passage
            }
            data.append(item)
except Exception as e:
logger.error(f"Error loading XML file {file_path}: {e}")
raise
self._warn_low_yield(file_path, len(data), total_lines)
return data
def _load_dureader(self, file_path: str) -> List[Dict]:
"""Load DuReader format data"""
data = []
total_lines = 0
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(tqdm(f, desc="Loading DuReader"), 1):
if line.strip():
total_lines += 1
try:
item = json.loads(line)
# Convert DuReader format
processed = {
'query': item.get('question', '').strip(),
'pos': [],
'neg': []
}
# Extract positive passages from answers
if 'answers' in item:
for answer in item['answers']:
if isinstance(answer, dict) and 'text' in answer:
processed['pos'].append(answer['text'].strip())
elif isinstance(answer, str):
processed['pos'].append(answer.strip())
# Extract documents as potential negatives
if 'documents' in item:
for doc in item['documents']:
if isinstance(doc, dict):
if 'paragraphs' in doc:
for para in doc['paragraphs']:
para_text = para.strip() if isinstance(para, str) else str(para).strip()
if para_text and para_text not in processed['pos']:
processed['neg'].append(para_text)
elif 'content' in doc:
content = doc['content'].strip()
if content and content not in processed['pos']:
processed['neg'].append(content)
if processed['query'] and processed['pos']:
data.append(processed)
except json.JSONDecodeError as e:
logger.warning(f"Skipping invalid JSON line in DuReader: {e}")
except Exception as e:
logger.error(f"Error loading DuReader file {file_path}: {e}")
raise
self._warn_low_yield(file_path, len(data), total_lines)
return data
def _load_msmarco(self, file_path: str) -> List[Dict]:
"""Load MS MARCO format data"""
data = []
total_lines = 0
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(tqdm(f, desc="Loading MS MARCO"), 1):
line = line.strip()
if line:
total_lines += 1
parts = line.split('\t')
if len(parts) >= 2:
item = {
'query': parts[0].strip(),
'pos': [parts[1].strip()],
'neg': [p.strip() for p in parts[2:] if p.strip()]
}
data.append(item)
except Exception as e:
logger.error(f"Error loading MS MARCO file {file_path}: {e}")
raise
self._warn_low_yield(file_path, len(data), total_lines)
return data
def _load_beir(self, file_path: str) -> List[Dict]:
"""Load BEIR format data"""
data = []
total_lines = 0
try:
# BEIR format is typically JSONL with specific structure
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(tqdm(f, desc="Loading BEIR"), 1):
if line.strip():
total_lines += 1
try:
item = json.loads(line)
# BEIR-style records usually carry 'query' plus 'positive_passages' /
# 'negative_passages'; passages may be plain strings or
# {'title': ..., 'text': ...} dicts, so normalize both
def _text(p):
    return str(p.get('text', '')).strip() if isinstance(p, dict) else str(p).strip()
processed = {}
if 'query' in item:
    processed['query'] = item['query'].strip()
elif 'text' in item:
    processed['query'] = item['text'].strip()
if 'positive_passages' in item:
    processed['pos'] = [t for t in (_text(p) for p in item['positive_passages']) if t]
elif 'pos' in item:
    processed['pos'] = [t for t in (_text(p) for p in item['pos']) if t]
if 'negative_passages' in item:
    processed['neg'] = [t for t in (_text(n) for n in item['negative_passages']) if t]
elif 'neg' in item:
    processed['neg'] = [t for t in (_text(n) for n in item['neg']) if t]
if 'query' in processed and 'pos' in processed:
data.append(processed)
except json.JSONDecodeError as e:
logger.warning(f"Skipping invalid JSON line in BEIR: {e}")
except Exception as e:
logger.error(f"Error loading BEIR file {file_path}: {e}")
raise
self._warn_low_yield(file_path, len(data), total_lines)
return data
def _validate_and_clean(self, data: List[Dict]) -> List[Dict]:
"""Validate and clean data"""
cleaned = []
for item in tqdm(data, desc="Validating data"):
try:
# Ensure required fields
if 'query' not in item or not item['query'].strip():
continue
# Clean query
if self.tokenizer:
encoding = self.tokenizer(item['query'], add_special_tokens=False) # type: ignore[operator]
input_ids = encoding['input_ids']
if len(input_ids) > self.max_query_length:
input_ids = input_ids[:self.max_query_length]
item['query'] = self.tokenizer.decode(input_ids, skip_special_tokens=True) # type: ignore[attr-defined]
else:
if len(item['query']) > self.max_query_length:
item['query'] = item['query'][: self.max_query_length]
# Ensure at least one positive
if 'pos' not in item or not item['pos']:
continue
# Clean positives (token-length filtering applies when a tokenizer is available)
item['pos'] = self._filter_passages_by_length(item['pos'])
if not item['pos']:
    continue
# Clean negatives
item['neg'] = self._filter_passages_by_length(item.get('neg', []))
cleaned.append(item)
except Exception as e:
logger.warning(f"Error processing item: {e}")
continue
logger.info(f"Cleaned {len(cleaned)} examples from {len(data)} original")
return cleaned
def _filter_passages_by_length(self, passages: List[str]) -> List[str]:
"""Filter passages by token length"""
if not self.tokenizer:
return passages
filtered = []
for passage in passages:
try:
encoding = self.tokenizer(passage, add_special_tokens=False) # type: ignore[operator]
input_ids = encoding['input_ids']
if len(input_ids) < self.min_passage_length:
continue
if len(input_ids) > self.max_passage_length:
# Truncate to max length
input_ids = input_ids[:self.max_passage_length]
passage = self.tokenizer.decode(input_ids, skip_special_tokens=True) # type: ignore[attr-defined]
filtered.append(passage)
except Exception as e:
logger.warning(f"Error processing passage: {e}")
continue
return filtered
def _split_data(
self,
data: List[Dict],
validation_split: float
) -> Tuple[List[Dict], List[Dict]]:
"""Split data into train/validation sets"""
# Shuffle data
shuffled_data = data.copy()
random.shuffle(shuffled_data)
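# Uses the global random state; seed upstream (e.g. random.seed(...))
# when reproducible splits are required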
# Calculate split point
split_idx = int(len(shuffled_data) * (1 - validation_split))
train_data = shuffled_data[:split_idx]
val_data = shuffled_data[split_idx:]
return train_data, val_data
def _save_jsonl(self, data: List[Dict], output_path: str):
"""Save data in JSONL format"""
out_dir = os.path.dirname(output_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
try:
with open(output_path, 'w', encoding='utf-8') as f:
for item in data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
except Exception as e:
logger.error(f"Error saving JSONL file {output_path}: {e}")
raise
class DataAugmenter:
"""
Data augmentation strategies for retrieval training
"""
def __init__(self, tokenizer: Optional[AutoTokenizer] = None):
self.tokenizer = tokenizer
def augment_with_paraphrases(
self,
data: List[Dict],
num_paraphrases: int = 2
) -> List[Dict]:
"""
Augment queries with rule-based paraphrases; originals are kept, and
augmented copies are flagged with 'is_augmented'
"""
augmented = []
for item in tqdm(data, desc="Augmenting with paraphrases"):
# Add original
augmented.append(item)
# Generate paraphrases
for i in range(num_paraphrases):
para_item = item.copy()
para_query = self._generate_paraphrase(item['query'], i)
if para_query != item['query']:
para_item['query'] = para_query
para_item['is_augmented'] = True
para_item['augmentation_type'] = 'paraphrase'
augmented.append(para_item)
return augmented
def _generate_paraphrase(self, query: str, variant: int) -> str:
"""Generate paraphrase using rule-based patterns"""
if self._is_chinese(query):
# Chinese paraphrase patterns
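# Each pair rewrites a question phrase to a near-synonym, e.g.
# "什么是" -> "请解释一下" is roughly "what is" -> "please explain";
# the fallback prefixes below mean "may I ask" / "I'd like to know" /
# "could you tell me" / "(I) want to understand"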
patterns = [
("什么是", "请解释一下"),
("如何", "怎样才能"),
("为什么", "什么原因导致"),
("在哪里", "什么地方"),
("是什么", "指的是什么"),
("有什么", "存在哪些"),
("怎么办", "如何处理"),
("多少", "数量是多少")
]
# Apply patterns based on variant
if variant == 0:
for original, replacement in patterns:
if original in query:
return query.replace(original, replacement, 1)
else:
# Add question prefixes
prefixes = ["请问", "我想知道", "能否告诉我", "想了解"]
return f"{prefixes[variant % len(prefixes)]}{query}"
else:
# English paraphrase patterns
patterns = [
("what is", "what does"),
("how to", "how can I"),
("why does", "what causes"),
("where is", "what location"),
("when did", "at what time")
]
query_lower = query.lower()
for original, replacement in patterns:
    idx = query_lower.find(original)
    if idx != -1:
        # Case-insensitive match; splice the replacement into the
        # original string so the rest of the query keeps its casing
        return query[:idx] + replacement + query[idx + len(original):]
# Add a question prefix, lowercasing only the first character
prefixes = ["Please explain", "Could you tell me", "I need to know", "Help me understand"]
return f"{prefixes[variant % len(prefixes)]} {query[:1].lower()}{query[1:]}"
return query
def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters"""
return bool(re.search(r'[\u4e00-\u9fff]', text))
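if __name__ == "__main__":
    # Minimal end-to-end sketch. The file paths and tokenizer name are
    # illustrative; in real runs these values come from the training
    # config, as the class docstrings note.
    logging.basicConfig(level=logging.INFO)
    tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
    preprocessor = DataPreprocessor(tokenizer=tokenizer)
    train_path, val_path = preprocessor.preprocess_file(
        input_path="raw_data.jsonl",
        output_path="data/processed.jsonl",
        validation_split=0.1,
    )
    with open(train_path, encoding="utf-8") as f:
        train_data = [json.loads(line) for line in f if line.strip()]
    augmented = DataAugmenter(tokenizer=tokenizer).augment_with_paraphrases(
        train_data, num_paraphrases=2
    )
    print(f"train={train_path} val={val_path} augmented={len(augmented)}")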