import json
import os
import random
from typing import List, Dict, Optional, Tuple, Union

import pandas as pd
import numpy as np
from tqdm import tqdm
import logging
from transformers import AutoTokenizer
import torch
from collections import defaultdict
import re
import xml.etree.ElementTree as ET
import csv

logger = logging.getLogger(__name__)


class DataPreprocessor:
    """
    Comprehensive data preprocessor for various input formats.

    Supports JSON, JSONL, CSV, TSV, XML, and dataset-specific formats
    (DuReader, MS MARCO, BEIR).

    Notes:
        - All parameters (tokenizer, max_query_length, max_passage_length,
          min_passage_length, etc.) should be passed in from config-driven
          training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(
        self,
        tokenizer: Optional[AutoTokenizer] = None,
        max_query_length: int = 512,
        max_passage_length: int = 8192,
        # For production use, a stricter minimum (e.g., 10 tokens) is recommended.
        min_passage_length: int = 3
    ):
        self.tokenizer = tokenizer
        self.max_query_length = max_query_length
        self.max_passage_length = max_passage_length
        self.min_passage_length = min_passage_length
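
    # All loaders below normalize examples into one record layout per example,
    # e.g. (illustrative values only):
    #   {"query": "what is dense retrieval",
    #    "pos": ["a relevant passage", ...],
    #    "neg": ["an irrelevant passage", ...],
    #    "pos_scores": [1.0, ...]}   # optional, only when a score column exists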

    def preprocess_file(
        self,
        input_path: str,
        output_path: str,
        file_format: str = "auto",
        validation_split: float = 0.1
    ) -> Tuple[str, Optional[str]]:
        """
        Preprocess an input file and convert it to the standard JSONL format.

        Args:
            input_path: Path to the input file
            output_path: Path to the output JSONL file
            file_format: Input file format (auto, json, jsonl, csv, tsv, xml,
                dureader, msmarco, beir)
            validation_split: Fraction of the data to hold out for validation

        Returns:
            Tuple of (train_path, val_path); val_path is None when no
            validation split is requested.
        """
        # Auto-detect format
        if file_format == "auto":
            file_format = self._detect_format(input_path)

        logger.info(f"Processing {input_path} as {file_format} format")

        # Load data based on format
        if file_format == "jsonl":
            data = self._load_jsonl(input_path)
        elif file_format == "json":
            data = self._load_json(input_path)
        elif file_format in ["csv", "tsv"]:
            data = self._load_csv(input_path, delimiter="," if file_format == "csv" else "\t")
        elif file_format == "xml":
            data = self._load_xml(input_path)
        elif file_format == "dureader":
            data = self._load_dureader(input_path)
        elif file_format == "msmarco":
            data = self._load_msmarco(input_path)
        elif file_format == "beir":
            data = self._load_beir(input_path)
        else:
            raise ValueError(f"Unsupported format: {file_format}")

        # Validate and clean data
        data = self._validate_and_clean(data)

        # Split into train/val
        if validation_split > 0:
            train_data, val_data = self._split_data(data, validation_split)

            # Save training data
            train_path = output_path.replace(".jsonl", "_train.jsonl")
            self._save_jsonl(train_data, train_path)

            # Save validation data
            val_path = output_path.replace(".jsonl", "_val.jsonl")
            self._save_jsonl(val_data, val_path)

            logger.info(f"Saved {len(train_data)} training and {len(val_data)} validation examples")
            return train_path, val_path
        else:
            self._save_jsonl(data, output_path)
            logger.info(f"Saved {len(data)} examples to {output_path}")
            return output_path, None
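
    # Example usage (a minimal sketch; "data/raw.csv" and "data/clean.jsonl" are
    # hypothetical paths, not files shipped with this module):
    #   pre = DataPreprocessor(max_query_length=128, max_passage_length=512)
    #   train_path, val_path = pre.preprocess_file(
    #       "data/raw.csv", "data/clean.jsonl", file_format="auto", validation_split=0.1
    #   )
    #   # -> writes data/clean_train.jsonl and data/clean_val.jsonl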

    def _detect_format(self, file_path: str) -> str:
        """Auto-detect the file format based on extension and content."""
        ext = os.path.splitext(file_path)[1].lower()

        # First try by extension
        if ext == ".jsonl":
            return "jsonl"
        elif ext == ".json":
            return "json"
        elif ext == ".csv":
            return "csv"
        elif ext == ".tsv":
            return "tsv"
        elif ext == ".xml":
            return "xml"

        # Then try to detect by content
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                first_lines = [f.readline().strip() for _ in range(5)]
            non_empty = [line for line in first_lines if line]
            first_content = '\n'.join(first_lines)

            # Check for JSON Lines (guard against empty files)
            if non_empty and all(line.startswith("{") and line.endswith("}") for line in non_empty):
                return "jsonl"

            # Check for JSON
            if first_content.strip().startswith("[") or first_content.strip().startswith("{"):
                return "json"

            # Check for XML
            if first_content.strip().startswith("<?xml") or first_content.strip().startswith("<"):
                return "xml"

            # Check for TSV (tabs)
            if any("\t" in line for line in first_lines):
                return "tsv"

            # Check for CSV (commas)
            if any("," in line for line in first_lines):
                return "csv"

        except Exception as e:
            logger.warning(f"Could not auto-detect format for {file_path}: {e}")

        # Default fallback
        return "jsonl"

    def _load_jsonl(self, file_path: str) -> List[Dict]:
        """Load JSONL file with robust error handling and detailed logging."""
        data = []
        total_lines = 0
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(tqdm(f, desc="Loading JSONL"), 1):
                    line = line.strip()
                    if line:
                        total_lines += 1
                        try:
                            data.append(json.loads(line))
                        except json.JSONDecodeError as e:
                            logger.warning(f"Skipping invalid JSON at {file_path}:{line_num}: {e} | Content: {line[:100]}")
                        except Exception as e:
                            logger.warning(f"Skipping malformed line at {file_path}:{line_num}: {e} | Content: {line[:100]}")
        except Exception as e:
            logger.error(f"Error loading JSONL file {file_path}: {e}")
            raise
        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data

    def _load_json(self, file_path: str) -> List[Dict]:
        """Load JSON file"""
        total_lines = 0
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Convert to list format if needed
            if isinstance(data, dict):
                if 'data' in data:
                    data = data['data']
                else:
                    data = [data]
            elif not isinstance(data, list):
                data = [data]
            total_lines = len(data)

        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in file {file_path}: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading JSON file {file_path}: {e}")
            raise
        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data

    def _load_csv(self, file_path: str, delimiter: str = ",") -> List[Dict]:
        """Load CSV/TSV file with robust parsing"""
        data = []
        total_lines = 0
        try:
            # Try different encodings
            encodings = ['utf-8', 'gb2312', 'gbk', 'big5', 'latin-1']
            df = None

            for encoding in encodings:
                try:
                    df = pd.read_csv(file_path, delimiter=delimiter, encoding=encoding)
                    logger.info(f"Successfully loaded CSV with {encoding} encoding")
                    break
                except UnicodeDecodeError:
                    continue
                except Exception as e:
                    logger.warning(f"Error with encoding {encoding} in {file_path}: {e}")
                    continue

            if df is None:
                raise ValueError(f"Could not load CSV file {file_path} with any encoding")

            # Process each row
            for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading CSV/TSV"):
                total_lines += 1
                try:
                    item = {}
                    # Map common column names to the standard format
                    row_dict = {str(k).lower().strip(): v for k, v in row.items() if pd.notna(v)}

                    # Extract query
                    query_fields = ['query', 'question', 'q', 'text', 'sentence']
                    for field in query_fields:
                        if field in row_dict:
                            item['query'] = str(row_dict[field]).strip()
                            break

                    # Extract positive passages
                    pos_fields = ['positive', 'pos', 'passage', 'document', 'answer', 'response']
                    positives = []
                    for field in pos_fields:
                        if field in row_dict:
                            pos_text = str(row_dict[field]).strip()
                            if '|||' in pos_text:
                                positives.extend([p.strip() for p in pos_text.split('|||') if p.strip()])
                            else:
                                positives.append(pos_text)
                            break

                    # Handle (passage, label) style rows: label 1 marks a positive
                    if 'passage' in row_dict and 'label' in row_dict:
                        try:
                            label = int(float(row_dict['label']))
                            passage = str(row_dict['passage']).strip()
                            if label == 1 and passage:
                                positives.append(passage)
                        except (ValueError, TypeError):
                            pass
                    if positives:
                        item['pos'] = positives

                    # Extract negative passages
                    neg_fields = ['negative', 'neg', 'hard_negative']
                    negatives = []
                    for field in neg_fields:
                        if field in row_dict:
                            neg_text = str(row_dict[field]).strip()
                            if '|||' in neg_text:
                                negatives.extend([n.strip() for n in neg_text.split('|||') if n.strip()])
                            else:
                                negatives.append(neg_text)
                            break

                    # Handle (passage, label) style rows: label 0 marks a negative
                    if 'passage' in row_dict and 'label' in row_dict:
                        try:
                            label = int(float(row_dict['label']))
                            passage = str(row_dict['passage']).strip()
                            if label == 0 and passage:
                                negatives.append(passage)
                        except (ValueError, TypeError):
                            pass
                    if negatives:
                        item['neg'] = negatives

                    # Extract scores if available
                    score_fields = ['score', 'relevance', 'rating']
                    for field in score_fields:
                        if field in row_dict:
                            try:
                                score = float(row_dict[field])
                                if 'pos' in item:
                                    item['pos_scores'] = [score] * len(item['pos'])
                            except (ValueError, TypeError):
                                pass

                    # Only add the item if it has the minimum required fields
                    if 'query' in item and ('pos' in item or 'neg' in item):
                        data.append(item)
                except Exception as e:
                    logger.warning(f"Skipping malformed row in {file_path} at index {idx}: {e} | Row: {row}")
        except Exception as e:
            logger.error(f"Error loading CSV file {file_path}: {e}")
            raise
        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data

    def _load_tsv(self, file_path: str) -> List[Dict]:
        """Load TSV file (wrapper around _load_csv)."""
        return self._load_csv(file_path, delimiter="\t")
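
    # Accepted CSV/TSV layouts (illustrative rows, not a bundled dataset):
    #   query,positive,negative
    #   "what is bm25","BM25 is a ranking function...","BM25 is a motorway in..."
    # or label-style rows, where label 1/0 marks a positive/negative passage:
    #   query,passage,label
    #   "what is bm25","BM25 is a ranking function...",1
    # Multiple passages in one cell can be separated with '|||'.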

    def _load_xml(self, file_path: str) -> List[Dict]:
        """Load XML file (e.g., TREC format)"""
        data = []
        total_items = 0
        try:
            tree = ET.parse(file_path)
            root = tree.getroot()

            # Handle different XML structures
            if root.tag in ('questions', 'queries'):
                # TREC-style format
                for question in tqdm(root.findall('.//question') or root.findall('.//query'), desc="Loading XML"):
                    total_items += 1
                    item = {}

                    # Extract the query text (explicit None checks: childless Elements are falsy,
                    # so chaining find() with `or` would skip valid matches)
                    query_elem = question.find('text')
                    if query_elem is None:
                        query_elem = question.find('title')
                    if query_elem is None:
                        query_elem = question.find('query')
                    if query_elem is not None and query_elem.text is not None:
                        item['query'] = query_elem.text.strip()

                    # Extract passages/documents
                    passages = []
                    for doc in question.findall('.//document') or question.findall('.//passage'):
                        if doc.text:
                            passages.append(doc.text.strip())

                    if passages:
                        item['pos'] = passages

                    if 'query' in item:
                        data.append(item)

            else:
                # Generic XML: treat any sufficiently long text node as an example
                for elem in tqdm(root.iter(), desc="Processing XML"):
                    if elem.text and len(elem.text.strip()) > 20:  # Minimum length check
                        total_items += 1
                        item = {
                            'query': elem.text.strip()[:100],  # First part as query
                            'pos': [elem.text.strip()]  # Full text as passage
                        }
                        data.append(item)

        except Exception as e:
            logger.error(f"Error loading XML file {file_path}: {e}")
            raise
        if total_items > 0 and (len(data) < 10 or len(data) < 0.01 * total_items):
            logger.warning(
                f"Too few valid examples loaded from {file_path}: {len(data)} out of {total_items}."
            )
        return data

    def _load_dureader(self, file_path: str) -> List[Dict]:
        """Load DuReader format data"""
        data = []
        total_lines = 0
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(tqdm(f, desc="Loading DuReader"), 1):
                    if line.strip():
                        total_lines += 1
                        try:
                            item = json.loads(line)

                            # Convert DuReader format
                            processed = {
                                'query': item.get('question', '').strip(),
                                'pos': [],
                                'neg': []
                            }

                            # Extract positive passages from answers
                            if 'answers' in item:
                                for answer in item['answers']:
                                    if isinstance(answer, dict) and 'text' in answer:
                                        processed['pos'].append(answer['text'].strip())
                                    elif isinstance(answer, str):
                                        processed['pos'].append(answer.strip())

                            # Extract documents as potential negatives
                            if 'documents' in item:
                                for doc in item['documents']:
                                    if isinstance(doc, dict):
                                        if 'paragraphs' in doc:
                                            for para in doc['paragraphs']:
                                                para_text = para.strip() if isinstance(para, str) else str(para).strip()
                                                if para_text and para_text not in processed['pos']:
                                                    processed['neg'].append(para_text)
                                        elif 'content' in doc:
                                            content = doc['content'].strip()
                                            if content and content not in processed['pos']:
                                                processed['neg'].append(content)

                            if processed['query'] and processed['pos']:
                                data.append(processed)

                        except json.JSONDecodeError as e:
                            logger.warning(f"Skipping invalid JSON at {file_path}:{line_num} (DuReader): {e}")

        except Exception as e:
            logger.error(f"Error loading DuReader file {file_path}: {e}")
            raise
        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data

    def _load_msmarco(self, file_path: str) -> List[Dict]:
        """Load MS MARCO format data"""
        data = []
        total_lines = 0
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(tqdm(f, desc="Loading MS MARCO"), 1):
                    line = line.strip()
                    if line:
                        total_lines += 1
                        parts = line.split('\t')
                        if len(parts) >= 2:
                            item = {
                                'query': parts[0].strip(),
                                'pos': [parts[1].strip()],
                                'neg': [p.strip() for p in parts[2:] if p.strip()]
                            }
                            data.append(item)

        except Exception as e:
            logger.error(f"Error loading MS MARCO file {file_path}: {e}")
            raise
        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data
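
    # Expected MS MARCO-style line layout (tab-separated, illustrative only):
    #   what is bm25<TAB>BM25 is a bag-of-words ranking function...<TAB>hard negative 1<TAB>hard negative 2
    # The first column is the query, the second a positive passage, and any
    # remaining columns are treated as negatives.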

    def _load_beir(self, file_path: str) -> List[Dict]:
        """Load BEIR format data"""
        data = []
        total_lines = 0
        try:
            # BEIR format is typically JSONL with a specific structure
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(tqdm(f, desc="Loading BEIR"), 1):
                    if line.strip():
                        total_lines += 1
                        try:
                            item = json.loads(line)

                            # BEIR records usually have 'query', 'positive_passages', 'negative_passages'
                            processed = {}

                            if 'query' in item:
                                processed['query'] = item['query'].strip()
                            elif 'text' in item:
                                processed['query'] = item['text'].strip()

                            if 'positive_passages' in item:
                                processed['pos'] = [p.strip() for p in item['positive_passages'] if p.strip()]
                            elif 'pos' in item:
                                processed['pos'] = [p.strip() for p in item['pos'] if p.strip()]

                            if 'negative_passages' in item:
                                processed['neg'] = [n.strip() for n in item['negative_passages'] if n.strip()]
                            elif 'neg' in item:
                                processed['neg'] = [n.strip() for n in item['neg'] if n.strip()]

                            if 'query' in processed and 'pos' in processed:
                                data.append(processed)

                        except json.JSONDecodeError as e:
                            logger.warning(f"Skipping invalid JSON at {file_path}:{line_num} (BEIR): {e}")

        except Exception as e:
            logger.error(f"Error loading BEIR file {file_path}: {e}")
            raise
        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data

    def _validate_and_clean(self, data: List[Dict]) -> List[Dict]:
        """Validate and clean data"""
        cleaned = []

        for item in tqdm(data, desc="Validating data"):
            try:
                # Ensure required fields
                if 'query' not in item or not item['query'].strip():
                    continue

                # Truncate the query (token-based if a tokenizer is available,
                # otherwise character-based)
                if self.tokenizer:
                    encoding = self.tokenizer(item['query'], add_special_tokens=False)  # type: ignore[operator]
                    input_ids = encoding['input_ids']
                    if len(input_ids) > self.max_query_length:
                        input_ids = input_ids[:self.max_query_length]
                        item['query'] = self.tokenizer.decode(input_ids, skip_special_tokens=True)  # type: ignore[attr-defined]
                else:
                    if len(item['query']) > self.max_query_length:
                        item['query'] = item['query'][:self.max_query_length]

                # Ensure at least one positive
                if 'pos' not in item or not item['pos']:
                    continue

                # Filter positives by passage length
                item['pos'] = self._filter_passages_by_length(item['pos'])
                if not item['pos']:
                    continue

                # Filter negatives by passage length
                if 'neg' in item:
                    item['neg'] = self._filter_passages_by_length(item['neg'])
                else:
                    item['neg'] = []

                # Keep the item only if positives survived filtering
                if item['pos']:
                    cleaned.append(item)

            except Exception as e:
                logger.warning(f"Error processing item: {e}")
                continue

        logger.info(f"Cleaned {len(cleaned)} examples from {len(data)} original")
        return cleaned

    def _filter_passages_by_length(self, passages: List[str]) -> List[str]:
        """Filter passages by token length (no-op when no tokenizer is set)."""
        if not self.tokenizer:
            return passages

        filtered = []

        for passage in passages:
            try:
                encoding = self.tokenizer(passage, add_special_tokens=False)  # type: ignore[operator]
                input_ids = encoding['input_ids']

                if len(input_ids) < self.min_passage_length:
                    continue

                if len(input_ids) > self.max_passage_length:
                    # Truncate to max length
                    input_ids = input_ids[:self.max_passage_length]
                    passage = self.tokenizer.decode(input_ids, skip_special_tokens=True)  # type: ignore[attr-defined]

                filtered.append(passage)

            except Exception as e:
                logger.warning(f"Error processing passage: {e}")
                continue

        return filtered
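
    # Behaviour sketch (assuming a whitespace-like tokenizer, min_passage_length=3,
    # max_passage_length=5; illustrative only): a 2-token passage is dropped,
    # a 5-token passage is kept as-is, and an 8-token passage is truncated to
    # its first 5 tokens and re-decoded.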

    def _split_data(
        self,
        data: List[Dict],
        validation_split: float
    ) -> Tuple[List[Dict], List[Dict]]:
        """Split data into train/validation sets"""
        # Shuffle data
        shuffled_data = data.copy()
        random.shuffle(shuffled_data)

        # Calculate split point
        split_idx = int(len(shuffled_data) * (1 - validation_split))

        train_data = shuffled_data[:split_idx]
        val_data = shuffled_data[split_idx:]

        return train_data, val_data

    def _save_jsonl(self, data: List[Dict], output_path: str):
        """Save data in JSONL format"""
        # Only create a directory when the path actually contains one
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        try:
            with open(output_path, 'w', encoding='utf-8') as f:
                for item in data:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
        except Exception as e:
            logger.error(f"Error saving JSONL file {output_path}: {e}")
            raise


class DataAugmenter:
    """
    Data augmentation strategies for retrieval training
    """

    def __init__(self, tokenizer: Optional[AutoTokenizer] = None):
        self.tokenizer = tokenizer

    def augment_with_paraphrases(
        self,
        data: List[Dict],
        num_paraphrases: int = 2
    ) -> List[Dict]:
        """
        Augment queries with paraphrases using a rule-based approach.
        """
        augmented = []

        for item in tqdm(data, desc="Augmenting with paraphrases"):
            # Add the original example
            augmented.append(item)

            # Generate paraphrased variants
            for i in range(num_paraphrases):
                para_item = item.copy()
                para_query = self._generate_paraphrase(item['query'], i)
                if para_query != item['query']:
                    para_item['query'] = para_query
                    para_item['is_augmented'] = True
                    para_item['augmentation_type'] = 'paraphrase'
                    augmented.append(para_item)

        return augmented
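
    # Example (illustrative): with num_paraphrases=2, an item whose query is
    # "how to train a reranker" yields the original plus variants such as
    # "how can I train a reranker", each tagged with is_augmented=True and
    # augmentation_type="paraphrase".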

    def _generate_paraphrase(self, query: str, variant: int) -> str:
        """Generate a paraphrase using rule-based patterns"""
        if self._is_chinese(query):
            # Chinese paraphrase patterns
            patterns = [
                ("什么是", "请解释一下"),
                ("如何", "怎样才能"),
                ("为什么", "什么原因导致"),
                ("在哪里", "什么地方"),
                ("是什么", "指的是什么"),
                ("有什么", "存在哪些"),
                ("怎么办", "如何处理"),
                ("多少", "数量是多少")
            ]

            # Apply patterns based on the variant index
            if variant == 0:
                for original, replacement in patterns:
                    if original in query:
                        return query.replace(original, replacement, 1)
            else:
                # Add question prefixes
                prefixes = ["请问", "我想知道", "能否告诉我", "想了解"]
                return f"{prefixes[variant % len(prefixes)]}{query}"
        else:
            # English paraphrase patterns
            patterns = [
                ("what is", "what does"),
                ("how to", "how can I"),
                ("why does", "what causes"),
                ("where is", "what location"),
                ("when did", "at what time")
            ]

            query_lower = query.lower()
            for original, replacement in patterns:
                if original in query_lower:
                    return query_lower.replace(original, replacement, 1)

            # Add question prefixes
            prefixes = ["Please explain", "Could you tell me", "I need to know", "Help me understand"]
            return f"{prefixes[variant % len(prefixes)]} {query.lower()}"

        # Fallback: no pattern applied (Chinese query, variant 0, no match)
        return query

    def _is_chinese(self, text: str) -> bool:
        """Check if text contains Chinese characters"""
        return bool(re.search(r'[\u4e00-\u9fff]', text))
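

# Minimal end-to-end sketch. The paths below are hypothetical placeholders, not
# files shipped with this module; adjust them (and pass a tokenizer if you want
# token-based length filtering) before running.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    preprocessor = DataPreprocessor(max_query_length=128, max_passage_length=512)
    train_path, val_path = preprocessor.preprocess_file(
        input_path="data/raw_pairs.csv",   # hypothetical input file
        output_path="data/pairs.jsonl",    # hypothetical output prefix
        file_format="auto",
        validation_split=0.1,
    )

    # Optionally augment the training split with rule-based paraphrases
    augmenter = DataAugmenter()
    train_examples = preprocessor._load_jsonl(train_path)
    augmented = augmenter.augment_with_paraphrases(train_examples, num_paraphrases=1)
    preprocessor._save_jsonl(augmented, train_path.replace(".jsonl", "_augmented.jsonl"))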