init commit
This commit is contained in:
687
data/preprocessing.py
Normal file
687
data/preprocessing.py
Normal file
@@ -0,0 +1,687 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import List, Dict, Optional, Tuple, Union
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
from transformers import AutoTokenizer
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
import csv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataPreprocessor:
    """
    Comprehensive data preprocessor for various input formats.

    Supports JSON, JSONL, CSV, TSV, XML, and custom formats.

    Notes:
    - All parameters (tokenizer, max_query_length, max_passage_length,
      min_passage_length, etc.) should be passed from config-driven
      training scripts.
    - This class does not load config directly; pass config values from
      your main script for consistency.
    """

    def __init__(
        self,
        # Quoted so the hint is kept without being evaluated when the
        # function is defined.
        tokenizer: Optional["AutoTokenizer"] = None,
        max_query_length: int = 512,
        max_passage_length: int = 8192,
        # For production, set min_passage_length to 10
        min_passage_length: int = 3
    ):
        """
        Store the tokenizer and the length limits used during cleaning.

        Args:
            tokenizer: Optional tokenizer; stored as-is.
            max_query_length: Upper bound applied to queries.
            max_passage_length: Upper bound applied to passages.
            min_passage_length: Lower bound applied to passages.

        Raises:
            ValueError: If a maximum is smaller than 1, if
                min_passage_length is negative, or if min_passage_length
                exceeds max_passage_length.
        """
        # Reject limits that could never keep any text instead of silently
        # producing an empty dataset later on.
        if max_query_length < 1:
            raise ValueError(
                f"max_query_length must be >= 1, got {max_query_length}"
            )
        if max_passage_length < 1:
            raise ValueError(
                f"max_passage_length must be >= 1, got {max_passage_length}"
            )
        if min_passage_length < 0:
            raise ValueError(
                f"min_passage_length must be >= 0, got {min_passage_length}"
            )
        if min_passage_length > max_passage_length:
            raise ValueError(
                "min_passage_length must not exceed max_passage_length "
                f"({min_passage_length} > {max_passage_length})"
            )

        self.tokenizer = tokenizer
        self.max_query_length = max_query_length
        self.max_passage_length = max_passage_length
        self.min_passage_length = min_passage_length
|
||||
|
||||
def preprocess_file(
    self,
    input_path: str,
    output_path: str,
    file_format: str = "auto",
    validation_split: float = 0.1
) -> Tuple[str, Optional[str]]:
    """
    Preprocess input file and convert to standard JSONL format.

    Args:
        input_path: Path to input file
        output_path: Path to output JSONL file
        file_format: Input file format (auto, json, jsonl, csv, tsv, xml,
            dureader, msmarco, beir)
        validation_split: Fraction of data to use for validation. Values
            <= 0 disable the split and write everything to output_path.

    Returns:
        Tuple of (train_path, val_path). val_path is None when no
        validation split is written.

    Raises:
        ValueError: If validation_split is >= 1 or the format is not
            supported.
    """
    # A fraction of 1 or more would leave no training data (or produce a
    # negative split index), so reject it up front.
    if validation_split >= 1:
        raise ValueError(
            f"validation_split must be smaller than 1, got {validation_split}"
        )

    # Auto-detect format
    if file_format == "auto":
        file_format = self._detect_format(input_path)

    logger.info(f"Processing {input_path} as {file_format} format")

    # Load data based on format
    loaders = {
        "jsonl": self._load_jsonl,
        "json": self._load_json,
        "csv": lambda path: self._load_csv(path, delimiter=","),
        "tsv": lambda path: self._load_csv(path, delimiter="\t"),
        "xml": self._load_xml,
        "dureader": self._load_dureader,
        "msmarco": self._load_msmarco,
        "beir": self._load_beir,
    }
    loader = loaders.get(file_format)
    if loader is None:
        raise ValueError(f"Unsupported format: {file_format}")
    data = loader(input_path)

    # Validate and clean data
    data = self._validate_and_clean(data)

    # Split into train/val
    if validation_split > 0:
        train_data, val_data = self._split_data(data, validation_split)

        # Derive the two file names from the stem so they are always
        # distinct, even when output_path does not end in ".jsonl", and so
        # that ".jsonl" inside a directory name is left alone.
        stem, extension = os.path.splitext(output_path)
        extension = extension or ".jsonl"
        train_path = f"{stem}_train{extension}"
        val_path = f"{stem}_val{extension}"

        self._save_jsonl(train_data, train_path)
        self._save_jsonl(val_data, val_path)

        logger.info(f"Saved {len(train_data)} training and {len(val_data)} validation examples")
        return train_path, val_path

    self._save_jsonl(data, output_path)
    logger.info(f"Saved {len(data)} examples to {output_path}")
    return output_path, None
|
||||
|
||||
def _detect_format(self, file_path: str) -> str:
|
||||
"""Auto-detect file format based on extension and content"""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
# First try by extension
|
||||
if ext == ".jsonl":
|
||||
return "jsonl"
|
||||
elif ext == ".json":
|
||||
return "json"
|
||||
elif ext == ".csv":
|
||||
return "csv"
|
||||
elif ext == ".tsv":
|
||||
return "tsv"
|
||||
elif ext == ".xml":
|
||||
return "xml"
|
||||
|
||||
# Try to detect by content
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
first_lines = [f.readline().strip() for _ in range(5)]
|
||||
first_content = '\n'.join(first_lines)
|
||||
|
||||
# Check for JSON Lines
|
||||
if all(line.startswith("{") and line.endswith("}") for line in first_lines if line):
|
||||
return "jsonl"
|
||||
|
||||
# Check for JSON
|
||||
if first_content.strip().startswith("[") or first_content.strip().startswith("{"):
|
||||
return "json"
|
||||
|
||||
# Check for XML
|
||||
if first_content.strip().startswith("<?xml") or first_content.strip().startswith("<"):
|
||||
return "xml"
|
||||
|
||||
# Check for TSV (tabs)
|
||||
if any("\t" in line for line in first_lines):
|
||||
return "tsv"
|
||||
|
||||
# Check for CSV (commas)
|
||||
if any("," in line for line in first_lines):
|
||||
return "csv"
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not auto-detect format for {file_path}: {e}")
|
||||
|
||||
# Default fallback
|
||||
return "jsonl"
|
||||
|
||||
def _load_jsonl(self, file_path: str) -> List[Dict]:
    """
    Read a JSON Lines file, skipping lines that cannot be parsed.

    Blank lines are ignored. Every unparseable line is reported with its
    line number and the first 100 characters of its content. Errors while
    opening or reading the file are logged and re-raised. A warning is
    logged when fewer than 10 records, or fewer than 1% of the non-blank
    lines, were loaded.
    """
    data = []
    total_lines = 0
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            for line_num, raw in enumerate(tqdm(handle, desc="Loading JSONL"), 1):
                line = raw.strip()
                if not line:
                    continue
                total_lines += 1
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning(f"Skipping invalid JSON at {file_path}:{line_num}: {e} | Content: {line[:100]}")
                except Exception as e:
                    logger.warning(f"Skipping malformed line at {file_path}:{line_num}: {e} | Content: {line[:100]}")
                else:
                    data.append(record)
    except Exception as e:
        logger.error(f"Error loading JSONL file {file_path}: {e}")
        raise

    too_few = len(data) < 10 or len(data) < 0.01 * total_lines
    if total_lines > 0 and too_few:
        logger.warning(
            f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
        )
    return data
|
||||
|
||||
def _load_json(self, file_path: str) -> List[Dict]:
|
||||
"""Load JSON file"""
|
||||
total_lines = 0
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Convert to list format if needed
|
||||
if isinstance(data, dict):
|
||||
if 'data' in data:
|
||||
data = data['data']
|
||||
else:
|
||||
data = [data]
|
||||
elif not isinstance(data, list):
|
||||
data = [data]
|
||||
total_lines = len(data)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Invalid JSON in file {file_path}: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading JSON file {file_path}: {e}")
|
||||
raise
|
||||
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
|
||||
logger.warning(
|
||||
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
|
||||
)
|
||||
return data
|
||||
|
||||
def _load_csv(self, file_path: str, delimiter: str = ",") -> List[Dict]:
    """
    Load a CSV/TSV file and map its columns onto query/pos/neg records.

    The file is read with pandas, trying several encodings in turn.
    Column names are matched case-insensitively against lists of common
    aliases. A cell may hold several passages separated by '|||'.

    When a row has both a 'passage' and a 'label' column and the label
    parses to 0 or 1, the passage is routed by the label alone: 1 makes it
    a positive, 0 a negative. Otherwise 'passage' is treated like any other
    positive alias.

    Args:
        file_path: Path to the delimited file.
        delimiter: Field separator passed to pandas.

    Returns:
        Records that have a query and at least one positive or negative.

    Raises:
        ValueError: If the file cannot be read with any encoding (logged
            first, like every other loading error).
    """
    def split_passages(cell) -> List[str]:
        """Split a cell on '|||' into stripped, non-empty passages."""
        return [part.strip() for part in str(cell).split('|||') if part.strip()]

    query_fields = ['query', 'question', 'q', 'text', 'sentence']
    pos_fields = ['positive', 'pos', 'passage', 'document', 'answer', 'response']
    neg_fields = ['negative', 'neg', 'hard_negative']
    score_fields = ['score', 'relevance', 'rating']

    data = []
    total_lines = 0
    try:
        # Try different encodings
        encodings = ['utf-8', 'gb2312', 'gbk', 'big5', 'latin-1']
        df = None

        for encoding in encodings:
            try:
                df = pd.read_csv(file_path, delimiter=delimiter, encoding=encoding)
                logger.info(f"Successfully loaded CSV with {encoding} encoding")
                break
            except UnicodeDecodeError:
                continue
            except Exception as e:
                logger.warning(f"Error with encoding {encoding} in {file_path}: {e}")
                continue

        if df is None:
            raise ValueError(f"Could not load CSV file {file_path} with any encoding")

        # Process each row
        for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading CSV/TSV"):
            total_lines += 1
            try:
                item = {}
                # Normalise column names and drop missing cells
                row_dict = {str(k).lower().strip(): v for k, v in row.items() if pd.notna(v)}

                # Extract query
                for field in query_fields:
                    if field in row_dict:
                        item['query'] = str(row_dict[field]).strip()
                        break

                # Parse the label once; it decides where 'passage' goes.
                label = None
                if 'passage' in row_dict and 'label' in row_dict:
                    try:
                        label = int(float(row_dict['label']))
                    except (ValueError, TypeError, OverflowError):
                        label = None
                passage_routed_by_label = label in (0, 1)

                # Extract positive passages
                positives = []
                for field in pos_fields:
                    if field not in row_dict:
                        continue
                    if field == 'passage' and passage_routed_by_label:
                        # Handled below, so a label-0 passage is never a
                        # positive and a label-1 passage is not duplicated.
                        continue
                    positives.extend(split_passages(row_dict[field]))
                    break
                if label == 1:
                    positives.extend(split_passages(row_dict['passage']))
                if positives:
                    item['pos'] = positives

                # Extract negative passages
                negatives = []
                for field in neg_fields:
                    if field in row_dict:
                        negatives.extend(split_passages(row_dict[field]))
                        break
                if label == 0:
                    negatives.extend(split_passages(row_dict['passage']))
                if negatives:
                    item['neg'] = negatives

                # Extract scores if available
                for field in score_fields:
                    if field in row_dict:
                        try:
                            score = float(row_dict[field])
                            if 'pos' in item:
                                item['pos_scores'] = [score] * len(item['pos'])
                        except (ValueError, TypeError):
                            pass

                # Only add if we have minimum required fields
                if 'query' in item and ('pos' in item or 'neg' in item):
                    data.append(item)
            except Exception as e:
                logger.warning(f"Skipping malformed row in {file_path} at index {idx}: {e} | Row: {row}")
    except Exception as e:
        logger.error(f"Error loading CSV file {file_path}: {e}")
        raise

    if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
        logger.warning(
            f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
        )
    return data
|
||||
|
||||
def _load_tsv(self, file_path: str) -> List[Dict]:
|
||||
"""Load TSV file (wrapper around _load_csv)."""
|
||||
return self._load_csv(file_path, delimiter="\t")
|
||||
|
||||
def _load_xml(self, file_path: str) -> List[Dict]:
    """
    Load XML file (e.g., TREC format).

    When the root element is <questions> or <queries>, every nested
    <question> (or, if there are none, <query>) element becomes a record.
    Its query is the first of the <text>, <title>, <query> children that
    has non-blank text, and its positives are the texts of the nested
    <document> (or <passage>) elements.

    For any other root, every element whose text is longer than 20
    characters becomes a record, with the first 100 characters as the
    query and the full text as the only positive.

    Raises:
        Exception: Any parsing or I/O error is logged and re-raised.
    """
    data = []
    total_lines = 0
    try:
        root = ET.parse(file_path).getroot()

        # Handle different XML structures
        if root.tag in ('questions', 'queries'):
            # TREC-style format
            entries = root.findall('.//question') or root.findall('.//query')
            for question in tqdm(entries, desc="Loading XML"):
                total_lines += 1
                item = {}

                # Extract query. Candidates are compared against None
                # explicitly: an Element without children is falsy, so
                # chaining find() calls with `or` would skip exactly the
                # leaf elements that carry the text.
                for tag in ('text', 'title', 'query'):
                    candidate = question.find(tag)
                    if candidate is None or candidate.text is None:
                        continue
                    query_text = candidate.text.strip()
                    if query_text:
                        item['query'] = query_text
                        break

                # Extract passages/documents
                docs = question.findall('.//document') or question.findall('.//passage')
                passages = [doc.text.strip() for doc in docs if doc.text and doc.text.strip()]
                if passages:
                    item['pos'] = passages

                if 'query' in item:
                    data.append(item)
        else:
            # Generic XML - try to extract text content
            for elem in tqdm(root.iter(), desc="Processing XML"):
                text = elem.text.strip() if elem.text else ''
                if not text:
                    continue
                total_lines += 1
                if len(text) > 20:  # Minimum length check
                    data.append({
                        'query': text[:100],  # First part as query
                        'pos': [text]  # Full text as passage
                    })
    except Exception as e:
        logger.error(f"Error loading XML file {file_path}: {e}")
        raise

    # total_lines counts candidate entries, so this check can fire.
    if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
        logger.warning(
            f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
        )
    return data
|
||||
|
||||
def _load_dureader(self, file_path: str) -> List[Dict]:
    """
    Load DuReader format data (one JSON object per line).

    For each record, 'question' becomes the query, the texts found in
    'answers' become positives, and the paragraphs/content of 'documents'
    that are not already positives become negatives. Records without a
    query or without any positive are dropped.

    A line that is not valid JSON, or a record whose fields have unexpected
    types, is skipped with a warning instead of aborting the whole load.

    Raises:
        Exception: Errors while opening or reading the file are logged and
            re-raised.
    """
    data = []
    total_lines = 0
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(tqdm(f, desc="Loading DuReader"), 1):
                if not line.strip():
                    continue
                total_lines += 1

                try:
                    item = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning(f"Skipping invalid JSON line in DuReader: {e}")
                    continue

                try:
                    # Convert DuReader format
                    processed = {
                        'query': item.get('question', '').strip(),
                        'pos': [],
                        'neg': []
                    }

                    # Extract positive passages from answers
                    answers = item.get('answers') or []
                    if isinstance(answers, str):
                        # A bare string would otherwise be iterated
                        # character by character.
                        answers = [answers]
                    for answer in answers:
                        if isinstance(answer, dict) and 'text' in answer:
                            answer_text = answer['text'].strip()
                        elif isinstance(answer, str):
                            answer_text = answer.strip()
                        else:
                            continue
                        if answer_text:
                            processed['pos'].append(answer_text)

                    # Extract documents as potential negatives
                    known_positives = set(processed['pos'])
                    for doc in item.get('documents') or []:
                        if not isinstance(doc, dict):
                            continue
                        if 'paragraphs' in doc:
                            candidates = [
                                para if isinstance(para, str) else str(para)
                                for para in doc['paragraphs']
                            ]
                        elif 'content' in doc:
                            candidates = [doc['content']]
                        else:
                            candidates = []
                        for candidate in candidates:
                            candidate = candidate.strip()
                            if candidate and candidate not in known_positives:
                                processed['neg'].append(candidate)
                except (AttributeError, TypeError, KeyError) as e:
                    logger.warning(f"Skipping malformed DuReader record at {file_path}:{line_num}: {e}")
                    continue

                if processed['query'] and processed['pos']:
                    data.append(processed)
    except Exception as e:
        logger.error(f"Error loading DuReader file {file_path}: {e}")
        raise

    if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
        logger.warning(
            f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
        )
    return data
|
||||
|
||||
def _load_msmarco(self, file_path: str) -> List[Dict]:
    """
    Load MS MARCO format data from a tab-separated text file.

    Each non-blank line is split on tabs: the first field is the query, the
    second the positive passage, and any remaining non-empty fields are
    negatives. Lines with fewer than two fields, or with an empty query or
    positive, are skipped.

    Raises:
        Exception: Errors while opening or reading the file are logged and
            re-raised.
    """
    data = []
    total_lines = 0
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading MS MARCO"):
                if not line.strip():
                    continue
                total_lines += 1

                # Only the line terminator is removed before splitting:
                # stripping the whole line would also eat a leading tab and
                # shift every column when the query field is empty.
                parts = [part.strip() for part in line.rstrip('\r\n').split('\t')]
                if len(parts) < 2:
                    continue

                query, positive = parts[0], parts[1]
                if not query or not positive:
                    continue

                data.append({
                    'query': query,
                    'pos': [positive],
                    'neg': [part for part in parts[2:] if part]
                })
    except Exception as e:
        logger.error(f"Error loading MS MARCO file {file_path}: {e}")
        raise

    if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
        logger.warning(
            f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
        )
    return data
|
||||
|
||||
def _load_beir(self, file_path: str) -> List[Dict]:
    """
    Load BEIR format data (one JSON object per line).

    The query is read from 'query' (or 'text'), positives from
    'positive_passages' (or 'pos') and negatives from 'negative_passages'
    (or 'neg'). A passage may be a plain string or an object with a 'text'
    field; anything else is ignored. Records need a query and a positives
    key to be kept.

    A line that is not valid JSON, or a record with an unusable structure,
    is skipped with a warning instead of aborting the whole load.

    Raises:
        Exception: Errors while opening or reading the file are logged and
            re-raised.
    """
    def passage_texts(raw_passages) -> List[str]:
        """Return the stripped, non-empty texts of a passage collection."""
        if isinstance(raw_passages, (str, dict)):
            raw_passages = [raw_passages]
        texts = []
        for passage in raw_passages or []:
            if isinstance(passage, dict):
                passage = passage.get('text')
            if isinstance(passage, str) and passage.strip():
                texts.append(passage.strip())
        return texts

    data = []
    total_lines = 0
    try:
        # BEIR format is typically JSONL with specific structure
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(tqdm(f, desc="Loading BEIR"), 1):
                if not line.strip():
                    continue
                total_lines += 1

                try:
                    item = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning(f"Skipping invalid JSON line in BEIR: {e}")
                    continue

                if not isinstance(item, dict):
                    logger.warning(f"Skipping malformed BEIR record at {file_path}:{line_num}: expected a JSON object")
                    continue

                try:
                    # BEIR format usually has 'query', 'positive_passages', 'negative_passages'
                    processed = {}

                    raw_query = item['query'] if 'query' in item else item.get('text')
                    if isinstance(raw_query, str):
                        processed['query'] = raw_query.strip()

                    if 'positive_passages' in item:
                        processed['pos'] = passage_texts(item['positive_passages'])
                    elif 'pos' in item:
                        processed['pos'] = passage_texts(item['pos'])

                    if 'negative_passages' in item:
                        processed['neg'] = passage_texts(item['negative_passages'])
                    elif 'neg' in item:
                        processed['neg'] = passage_texts(item['neg'])
                except (AttributeError, TypeError) as e:
                    logger.warning(f"Skipping malformed BEIR record at {file_path}:{line_num}: {e}")
                    continue

                if 'query' in processed and 'pos' in processed:
                    data.append(processed)
    except Exception as e:
        logger.error(f"Error loading BEIR file {file_path}: {e}")
        raise

    if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
        logger.warning(
            f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
        )
    return data
|
||||
|
||||
def _validate_and_clean(self, data: List[Dict]) -> List[Dict]:
    """
    Validate and clean data.

    A record is kept when it has a non-blank query and at least one
    positive passage left after length filtering. Queries longer than
    max_query_length are truncated, measured in tokens when a tokenizer is
    set and in characters otherwise. Every kept record gets a 'neg' list,
    possibly empty.

    Records are copied before being modified, so the dicts in `data` are
    left untouched. Records that raise during processing are logged and
    skipped.

    Args:
        data: Raw records as produced by the loaders.

    Returns:
        The cleaned records.
    """
    cleaned = []

    for raw_item in tqdm(data, desc="Validating data"):
        try:
            if not isinstance(raw_item, dict):
                logger.warning(f"Error processing item: expected a dict, got {type(raw_item).__name__}")
                continue

            # Ensure required fields
            if 'query' not in raw_item or not raw_item['query'].strip():
                continue

            # Work on a shallow copy so the caller's record is not mutated
            item = dict(raw_item)

            # Clean query
            if self.tokenizer:
                encoding = self.tokenizer(item['query'], add_special_tokens=False)  # type: ignore[operator]
                input_ids = encoding['input_ids']
                if len(input_ids) > self.max_query_length:
                    input_ids = input_ids[:self.max_query_length]
                    item['query'] = self.tokenizer.decode(input_ids, skip_special_tokens=True)  # type: ignore[attr-defined]
            elif len(item['query']) > self.max_query_length:
                item['query'] = item['query'][: self.max_query_length]

            # Ensure at least one positive
            positives = item.get('pos')
            if not positives:
                continue
            if isinstance(positives, str):
                # A bare string would otherwise be iterated per character
                positives = [positives]

            # Clean positives. Passages are length-filtered exactly once;
            # a second pass would only repeat the tokenization.
            item['pos'] = self._filter_passages_by_length(positives)
            if not item['pos']:
                continue

            # Clean negatives
            negatives = item.get('neg') or []
            if isinstance(negatives, str):
                negatives = [negatives]
            item['neg'] = self._filter_passages_by_length(negatives)

            cleaned.append(item)

        except Exception as e:
            logger.warning(f"Error processing item: {e}")
            continue

    logger.info(f"Cleaned {len(cleaned)} examples from {len(data)} original")
    return cleaned
|
||||
|
||||
def _filter_passages_by_length(self, passages: List[str]) -> List[str]:
|
||||
"""Filter passages by token length"""
|
||||
if not self.tokenizer:
|
||||
return passages
|
||||
|
||||
filtered = []
|
||||
|
||||
for passage in passages:
|
||||
try:
|
||||
encoding = self.tokenizer(passage, add_special_tokens=False) # type: ignore[operator]
|
||||
input_ids = encoding['input_ids']
|
||||
|
||||
if len(input_ids) < self.min_passage_length:
|
||||
continue
|
||||
|
||||
if len(input_ids) > self.max_passage_length:
|
||||
# Truncate to max length
|
||||
input_ids = input_ids[:self.max_passage_length]
|
||||
passage = self.tokenizer.decode(input_ids, skip_special_tokens=True) # type: ignore[attr-defined]
|
||||
|
||||
filtered.append(passage)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing passage: {e}")
|
||||
continue
|
||||
|
||||
return filtered
|
||||
|
||||
def _split_data(
|
||||
self,
|
||||
data: List[Dict],
|
||||
validation_split: float
|
||||
) -> Tuple[List[Dict], List[Dict]]:
|
||||
"""Split data into train/validation sets"""
|
||||
# Shuffle data
|
||||
shuffled_data = data.copy()
|
||||
random.shuffle(shuffled_data)
|
||||
|
||||
# Calculate split point
|
||||
split_idx = int(len(shuffled_data) * (1 - validation_split))
|
||||
|
||||
train_data = shuffled_data[:split_idx]
|
||||
val_data = shuffled_data[split_idx:]
|
||||
|
||||
return train_data, val_data
|
||||
|
||||
def _save_jsonl(self, data: List[Dict], output_path: str):
|
||||
"""Save data in JSONL format"""
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
try:
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for item in data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving JSONL file {output_path}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class DataAugmenter:
    """
    Data augmentation strategies for retrieval training.
    """

    def __init__(self, tokenizer: Optional[AutoTokenizer] = None):
        """Keep a reference to the optional tokenizer on the instance."""
        self.tokenizer = tokenizer
|
||||
|
||||
def augment_with_paraphrases(
    self,
    data: List[Dict],
    num_paraphrases: int = 2
) -> List[Dict]:
    """
    Augment queries with paraphrases using rule-based approach.

    Every original record is kept. For each one, up to `num_paraphrases`
    variants are generated. A variant is added only when its query differs
    from the original and from the variants already added for that record,
    so the output never contains identical augmented copies.

    Augmented records are shallow copies of the original with a new
    'query', 'is_augmented' set to True and 'augmentation_type' set to
    'paraphrase'. Their other values (for example passage lists) are shared
    with the original record.

    Args:
        data: Records with a 'query' key.
        num_paraphrases: Number of variants to try per record.

    Returns:
        Originals and their augmented copies, in input order.
    """
    augmented = []

    for item in tqdm(data, desc="Augmenting with paraphrases"):
        # Add original
        augmented.append(item)

        # Generate paraphrases, skipping repeats: different variant numbers
        # can yield the same text.
        seen_queries = {item['query']}
        for variant in range(num_paraphrases):
            para_query = self._generate_paraphrase(item['query'], variant)
            if para_query in seen_queries:
                continue
            seen_queries.add(para_query)

            para_item = item.copy()
            para_item['query'] = para_query
            para_item['is_augmented'] = True
            para_item['augmentation_type'] = 'paraphrase'
            augmented.append(para_item)

    return augmented
|
||||
|
||||
def _generate_paraphrase(self, query: str, variant: int) -> str:
|
||||
"""Generate paraphrase using rule-based patterns"""
|
||||
if self._is_chinese(query):
|
||||
# Chinese paraphrase patterns
|
||||
patterns = [
|
||||
("什么是", "请解释一下"),
|
||||
("如何", "怎样才能"),
|
||||
("为什么", "什么原因导致"),
|
||||
("在哪里", "什么地方"),
|
||||
("是什么", "指的是什么"),
|
||||
("有什么", "存在哪些"),
|
||||
("怎么办", "如何处理"),
|
||||
("多少", "数量是多少")
|
||||
]
|
||||
|
||||
# Apply patterns based on variant
|
||||
if variant == 0:
|
||||
for original, replacement in patterns:
|
||||
if original in query:
|
||||
return query.replace(original, replacement, 1)
|
||||
else:
|
||||
# Add question prefixes
|
||||
prefixes = ["请问", "我想知道", "能否告诉我", "想了解"]
|
||||
return f"{prefixes[variant % len(prefixes)]}{query}"
|
||||
else:
|
||||
# English paraphrase patterns
|
||||
patterns = [
|
||||
("what is", "what does"),
|
||||
("how to", "how can I"),
|
||||
("why does", "what causes"),
|
||||
("where is", "what location"),
|
||||
("when did", "at what time")
|
||||
]
|
||||
|
||||
query_lower = query.lower()
|
||||
for original, replacement in patterns:
|
||||
if original in query_lower:
|
||||
return query_lower.replace(original, replacement, 1)
|
||||
|
||||
# Add question prefixes
|
||||
prefixes = ["Please explain", "Could you tell me", "I need to know", "Help me understand"]
|
||||
return f"{prefixes[variant % len(prefixes)]} {query.lower()}"
|
||||
|
||||
return query
|
||||
|
||||
def _is_chinese(self, text: str) -> bool:
|
||||
"""Check if text contains Chinese characters"""
|
||||
return bool(re.search(r'[\u4e00-\u9fff]', text))
|
||||
Reference in New Issue
Block a user