import json
import os
import random
from typing import List, Dict, Optional, Tuple, Union
import pandas as pd
import numpy as np
from tqdm import tqdm
import logging
from transformers import AutoTokenizer
import torch
from collections import defaultdict
import re
import xml.etree.ElementTree as ET
import csv

logger = logging.getLogger(__name__)


class DataPreprocessor:
    """
    Comprehensive data preprocessor for various input formats.
    Supports JSON, JSONL, CSV, TSV, XML, and custom formats.

    Notes:
        - All parameters (tokenizer, max_query_length, max_passage_length,
          min_passage_length, etc.) should be passed from config-driven
          training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(
        self,
        tokenizer: Optional[AutoTokenizer] = None,
        max_query_length: int = 512,
        max_passage_length: int = 8192,
        # For production, set min_passage_length to 10
        min_passage_length: int = 3
    ):
        self.tokenizer = tokenizer
        self.max_query_length = max_query_length
        self.max_passage_length = max_passage_length
        self.min_passage_length = min_passage_length

    def preprocess_file(
        self,
        input_path: str,
        output_path: str,
        file_format: str = "auto",
        validation_split: float = 0.1
    ) -> Tuple[str, Optional[str]]:
        """
        Preprocess input file and convert to standard JSONL format.

        Args:
            input_path: Path to input file
            output_path: Path to output JSONL file
            file_format: Input file format (auto, json, jsonl, csv, tsv, xml,
                dureader, msmarco, beir)
            validation_split: Fraction of data to use for validation

        Returns:
            Tuple of (train_path, val_path)
        """
        # Auto-detect format
        if file_format == "auto":
            file_format = self._detect_format(input_path)

        logger.info(f"Processing {input_path} as {file_format} format")

        # Load data based on format
        if file_format == "jsonl":
            data = self._load_jsonl(input_path)
        elif file_format == "json":
            data = self._load_json(input_path)
        elif file_format in ["csv", "tsv"]:
            data = self._load_csv(input_path, delimiter="," if file_format == "csv" else "\t")
        elif file_format == "xml":
            data = self._load_xml(input_path)
        elif file_format == "dureader":
            data = self._load_dureader(input_path)
        elif file_format == "msmarco":
            data = self._load_msmarco(input_path)
        elif file_format == "beir":
            data = self._load_beir(input_path)
        else:
            raise ValueError(f"Unsupported format: {file_format}")

        # Validate and clean data
        data = self._validate_and_clean(data)

        # Split into train/val
        if validation_split > 0:
            train_data, val_data = self._split_data(data, validation_split)

            # Save training data
            train_path = output_path.replace(".jsonl", "_train.jsonl")
            self._save_jsonl(train_data, train_path)

            # Save validation data
            val_path = output_path.replace(".jsonl", "_val.jsonl")
            self._save_jsonl(val_data, val_path)

            logger.info(f"Saved {len(train_data)} training and {len(val_data)} validation examples")
            return train_path, val_path
        else:
            self._save_jsonl(data, output_path)
            logger.info(f"Saved {len(data)} examples to {output_path}")
            return output_path, None
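    # Illustrative example of a standardized record as written to the output
    # JSONL by preprocess_file (the optional "pos_scores" field is only
    # produced by the CSV/TSV loader when a score column is present):
    #   {"query": "how can I reset my password",
    #    "pos": ["To reset your password, open Settings ..."],
    #    "neg": ["The password history policy controls ..."],
    #    "pos_scores": [1.0]}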
    def _detect_format(self, file_path: str) -> str:
        """Auto-detect file format based on extension and content"""
        ext = os.path.splitext(file_path)[1].lower()

        # First try by extension
        if ext == ".jsonl":
            return "jsonl"
        elif ext == ".json":
            return "json"
        elif ext == ".csv":
            return "csv"
        elif ext == ".tsv":
            return "tsv"
        elif ext == ".xml":
            return "xml"

        # Try to detect by content
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                first_lines = [f.readline().strip() for _ in range(5)]
                first_content = '\n'.join(first_lines)

            # Check for JSON Lines
            if all(line.startswith("{") and line.endswith("}") for line in first_lines if line):
                return "jsonl"

            # Check for JSON
            if first_content.strip().startswith("[") or first_content.strip().startswith("{"):
                return "json"

            # Check for XML
            if first_content.strip().startswith("<"):
                return "xml"
        except Exception as e:
            logger.warning(f"Could not detect format from content of {file_path}: {e}")

        # Fall back to JSONL by default
        return "jsonl"

    def _load_jsonl(self, file_path: str) -> List[Dict]:
        """Load JSONL file with robust error handling and detailed logging."""
        data = []
        total_lines = 0
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(tqdm(f, desc="Loading JSONL"), 1):
                    line = line.strip()
                    if line:
                        total_lines += 1
                        try:
                            data.append(json.loads(line))
                        except json.JSONDecodeError as e:
                            logger.warning(f"Skipping invalid JSON at {file_path}:{line_num}: {e} | Content: {line[:100]}")
                        except Exception as e:
                            logger.warning(f"Skipping malformed line at {file_path}:{line_num}: {e} | Content: {line[:100]}")
        except Exception as e:
            logger.error(f"Error loading JSONL file {file_path}: {e}")
            raise

        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data

    def _load_json(self, file_path: str) -> List[Dict]:
        """Load JSON file"""
        total_lines = 0
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Convert to list format if needed
            if isinstance(data, dict):
                if 'data' in data:
                    data = data['data']
                else:
                    data = [data]
            elif not isinstance(data, list):
                data = [data]
            total_lines = len(data)
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in file {file_path}: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading JSON file {file_path}: {e}")
            raise

        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data
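    # Illustrative CSV/TSV layouts accepted by _load_csv below (column names
    # are matched case-insensitively; "|||" separates multiple passages):
    #   query,positive,negative
    #   "how to bake bread","Mix flour and water ...|||Knead the dough ...","Bread is older than ..."
    # or, with a relevance label per (query, passage) pair:
    #   query,passage,label
    #   "how to bake bread","Mix flour and water ...",1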
    def _load_csv(self, file_path: str, delimiter: str = ",") -> List[Dict]:
        """Load CSV/TSV file with robust parsing"""
        data = []
        total_lines = 0
        try:
            # Try different encodings
            encodings = ['utf-8', 'gb2312', 'gbk', 'big5', 'latin-1']
            df = None
            for encoding in encodings:
                try:
                    df = pd.read_csv(file_path, delimiter=delimiter, encoding=encoding)
                    logger.info(f"Successfully loaded CSV with {encoding} encoding")
                    break
                except UnicodeDecodeError:
                    continue
                except Exception as e:
                    logger.warning(f"Error with encoding {encoding} in {file_path}: {e}")
                    continue

            if df is None:
                raise ValueError(f"Could not load CSV file {file_path} with any encoding")

            # Process each row
            for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading CSV/TSV"):
                total_lines += 1
                try:
                    item = {}

                    # Map common column names to standard format
                    row_dict = {str(k).lower().strip(): v for k, v in row.items() if pd.notna(v)}

                    # Extract query
                    query_fields = ['query', 'question', 'q', 'text', 'sentence']
                    for field in query_fields:
                        if field in row_dict:
                            item['query'] = str(row_dict[field]).strip()
                            break

                    # Extract positive passages
                    pos_fields = ['positive', 'pos', 'passage', 'document', 'answer', 'response']
                    positives = []
                    for field in pos_fields:
                        # When a label column is present, the generic 'passage'
                        # column is handled by the labeled branch below instead.
                        if field == 'passage' and 'label' in row_dict:
                            continue
                        if field in row_dict:
                            pos_text = str(row_dict[field]).strip()
                            if '|||' in pos_text:
                                positives.extend([p.strip() for p in pos_text.split('|||') if p.strip()])
                            else:
                                positives.append(pos_text)
                            break

                    if 'passage' in row_dict and 'label' in row_dict:
                        try:
                            label = int(float(row_dict['label']))
                            passage = str(row_dict['passage']).strip()
                            if label == 1 and passage:
                                positives.append(passage)
                        except (ValueError, TypeError):
                            pass

                    if positives:
                        item['pos'] = positives

                    # Extract negative passages
                    neg_fields = ['negative', 'neg', 'hard_negative']
                    negatives = []
                    for field in neg_fields:
                        if field in row_dict:
                            neg_text = str(row_dict[field]).strip()
                            if '|||' in neg_text:
                                negatives.extend([n.strip() for n in neg_text.split('|||') if n.strip()])
                            else:
                                negatives.append(neg_text)
                            break

                    if 'passage' in row_dict and 'label' in row_dict:
                        try:
                            label = int(float(row_dict['label']))
                            passage = str(row_dict['passage']).strip()
                            if label == 0 and passage:
                                negatives.append(passage)
                        except (ValueError, TypeError):
                            pass

                    if negatives:
                        item['neg'] = negatives

                    # Extract scores if available
                    score_fields = ['score', 'relevance', 'rating']
                    for field in score_fields:
                        if field in row_dict:
                            try:
                                score = float(row_dict[field])
                                if 'pos' in item:
                                    item['pos_scores'] = [score] * len(item['pos'])
                            except (ValueError, TypeError):
                                pass

                    # Only add if we have minimum required fields
                    if 'query' in item and ('pos' in item or 'neg' in item):
                        data.append(item)
                except Exception as e:
                    logger.warning(f"Skipping malformed row in {file_path} at index {idx}: {e} | Row: {row}")
        except Exception as e:
            logger.error(f"Error loading CSV file {file_path}: {e}")
            raise

        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data
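    # Illustrative TREC-style XML structure handled by _load_xml below
    # (element names follow what the loader looks for):
    #   <questions>
    #     <question>
    #       <text>how to bake bread</text>
    #       <document>Mix flour and water ...</document>
    #     </question>
    #   </questions>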
    def _load_tsv(self, file_path: str) -> List[Dict]:
        """Load TSV file (wrapper around _load_csv)."""
        return self._load_csv(file_path, delimiter="\t")

    def _load_xml(self, file_path: str) -> List[Dict]:
        """Load XML file (e.g., TREC format)"""
        data = []
        total_lines = 0
        try:
            tree = ET.parse(file_path)
            root = tree.getroot()

            # Handle different XML structures
            if root.tag == 'questions' or root.tag == 'queries':
                # TREC-style format
                for question in tqdm(root.findall('.//question') or root.findall('.//query'), desc="Loading XML"):
                    item = {}

                    # Extract query (explicit None checks: Elements without
                    # children are falsy, so `or`-chaining find() results
                    # would skip valid matches)
                    query_elem = question.find('text')
                    if query_elem is None:
                        query_elem = question.find('title')
                    if query_elem is None:
                        query_elem = question.find('query')
                    if query_elem is not None and query_elem.text is not None:
                        item['query'] = query_elem.text.strip()

                    # Extract passages/documents
                    passages = []
                    for doc in question.findall('.//document') or question.findall('.//passage'):
                        if doc.text:
                            passages.append(doc.text.strip())

                    if passages:
                        item['pos'] = passages

                    if 'query' in item:
                        data.append(item)
            else:
                # Generic XML - try to extract text content
                for elem in tqdm(root.iter(), desc="Processing XML"):
                    if elem.text and len(elem.text.strip()) > 20:  # Minimum length check
                        item = {
                            'query': elem.text.strip()[:100],  # First part as query
                            'pos': [elem.text.strip()]  # Full text as passage
                        }
                        data.append(item)
        except Exception as e:
            logger.error(f"Error loading XML file {file_path}: {e}")
            raise

        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data

    def _load_dureader(self, file_path: str) -> List[Dict]:
        """Load DuReader format data"""
        data = []
        total_lines = 0
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(tqdm(f, desc="Loading DuReader"), 1):
                    if line.strip():
                        total_lines += 1
                        try:
                            item = json.loads(line)

                            # Convert DuReader format
                            processed = {
                                'query': item.get('question', '').strip(),
                                'pos': [],
                                'neg': []
                            }

                            # Extract positive passages from answers
                            if 'answers' in item:
                                for answer in item['answers']:
                                    if isinstance(answer, dict) and 'text' in answer:
                                        processed['pos'].append(answer['text'].strip())
                                    elif isinstance(answer, str):
                                        processed['pos'].append(answer.strip())

                            # Extract documents as potential negatives
                            if 'documents' in item:
                                for doc in item['documents']:
                                    if isinstance(doc, dict):
                                        if 'paragraphs' in doc:
                                            for para in doc['paragraphs']:
                                                para_text = para.strip() if isinstance(para, str) else str(para).strip()
                                                if para_text and para_text not in processed['pos']:
                                                    processed['neg'].append(para_text)
                                        elif 'content' in doc:
                                            content = doc['content'].strip()
                                            if content and content not in processed['pos']:
                                                processed['neg'].append(content)

                            if processed['query'] and processed['pos']:
                                data.append(processed)
                        except json.JSONDecodeError as e:
                            logger.warning(f"Skipping invalid JSON line in DuReader: {e}")
        except Exception as e:
            logger.error(f"Error loading DuReader file {file_path}: {e}")
            raise

        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data
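    # _load_msmarco below expects tab-separated lines, one example per line
    # (illustrative; the first column is the query, the second a positive,
    # and any further columns are treated as negatives):
    #   how to bake bread\tMix flour and water ...\tBread is older than ...\tYeast was first ...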
    def _load_msmarco(self, file_path: str) -> List[Dict]:
        """Load MS MARCO format data"""
        data = []
        total_lines = 0
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(tqdm(f, desc="Loading MS MARCO"), 1):
                    line = line.strip()
                    if line:
                        total_lines += 1
                        parts = line.split('\t')
                        if len(parts) >= 2:
                            item = {
                                'query': parts[0].strip(),
                                'pos': [parts[1].strip()],
                                'neg': [p.strip() for p in parts[2:] if p.strip()]
                            }
                            data.append(item)
        except Exception as e:
            logger.error(f"Error loading MS MARCO file {file_path}: {e}")
            raise

        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data

    def _load_beir(self, file_path: str) -> List[Dict]:
        """Load BEIR format data"""
        data = []
        total_lines = 0
        try:
            # BEIR format is typically JSONL with specific structure
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(tqdm(f, desc="Loading BEIR"), 1):
                    if line.strip():
                        total_lines += 1
                        try:
                            item = json.loads(line)

                            # BEIR format usually has 'query', 'positive_passages', 'negative_passages'
                            processed = {}

                            if 'query' in item:
                                processed['query'] = item['query'].strip()
                            elif 'text' in item:
                                processed['query'] = item['text'].strip()

                            if 'positive_passages' in item:
                                processed['pos'] = [p.strip() for p in item['positive_passages'] if p.strip()]
                            elif 'pos' in item:
                                processed['pos'] = [p.strip() for p in item['pos'] if p.strip()]

                            if 'negative_passages' in item:
                                processed['neg'] = [n.strip() for n in item['negative_passages'] if n.strip()]
                            elif 'neg' in item:
                                processed['neg'] = [n.strip() for n in item['neg'] if n.strip()]

                            if 'query' in processed and 'pos' in processed:
                                data.append(processed)
                        except json.JSONDecodeError as e:
                            logger.warning(f"Skipping invalid JSON line in BEIR: {e}")
        except Exception as e:
            logger.error(f"Error loading BEIR file {file_path}: {e}")
            raise

        if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
            logger.warning(
                f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
            )
        return data
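    # Illustrative JSONL record accepted by _load_beir above (this loader
    # treats passages as plain strings; "pos"/"neg" are accepted as aliases):
    #   {"query": "how to bake bread",
    #    "positive_passages": ["Mix flour and water ..."],
    #    "negative_passages": ["Bread is older than ..."]}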
    def _validate_and_clean(self, data: List[Dict]) -> List[Dict]:
        """Validate and clean data"""
        cleaned = []
        for item in tqdm(data, desc="Validating data"):
            try:
                # Ensure required fields
                if 'query' not in item or not item['query'].strip():
                    continue

                # Clean query: truncate by tokens if a tokenizer is available,
                # otherwise by characters
                if self.tokenizer:
                    encoding = self.tokenizer(item['query'], add_special_tokens=False)  # type: ignore[operator]
                    input_ids = encoding['input_ids']
                    if len(input_ids) > self.max_query_length:
                        input_ids = input_ids[:self.max_query_length]
                        item['query'] = self.tokenizer.decode(input_ids, skip_special_tokens=True)  # type: ignore[attr-defined]
                else:
                    if len(item['query']) > self.max_query_length:
                        item['query'] = item['query'][:self.max_query_length]

                # Ensure at least one positive
                if 'pos' not in item or not item['pos']:
                    continue

                # Clean positives
                item['pos'] = self._filter_passages_by_length(item['pos'])
                if not item['pos']:
                    continue

                # Clean negatives
                if 'neg' in item:
                    item['neg'] = self._filter_passages_by_length(item['neg'])
                else:
                    item['neg'] = []

                # Ensure we still have data after filtering
                if item['pos']:
                    cleaned.append(item)
            except Exception as e:
                logger.warning(f"Error processing item: {e}")
                continue

        logger.info(f"Cleaned {len(cleaned)} examples from {len(data)} original")
        return cleaned

    def _filter_passages_by_length(self, passages: List[str]) -> List[str]:
        """Filter passages by token length"""
        if not self.tokenizer:
            return passages

        filtered = []
        for passage in passages:
            try:
                encoding = self.tokenizer(passage, add_special_tokens=False)  # type: ignore[operator]
                input_ids = encoding['input_ids']

                if len(input_ids) < self.min_passage_length:
                    continue

                if len(input_ids) > self.max_passage_length:
                    # Truncate to max length
                    input_ids = input_ids[:self.max_passage_length]
                    passage = self.tokenizer.decode(input_ids, skip_special_tokens=True)  # type: ignore[attr-defined]

                filtered.append(passage)
            except Exception as e:
                logger.warning(f"Error processing passage: {e}")
                continue

        return filtered

    def _split_data(
        self,
        data: List[Dict],
        validation_split: float
    ) -> Tuple[List[Dict], List[Dict]]:
        """Split data into train/validation sets"""
        # Shuffle data
        shuffled_data = data.copy()
        random.shuffle(shuffled_data)

        # Calculate split point
        split_idx = int(len(shuffled_data) * (1 - validation_split))

        train_data = shuffled_data[:split_idx]
        val_data = shuffled_data[split_idx:]

        return train_data, val_data

    def _save_jsonl(self, data: List[Dict], output_path: str):
        """Save data in JSONL format"""
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        try:
            with open(output_path, 'w', encoding='utf-8') as f:
                for item in data:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
        except Exception as e:
            logger.error(f"Error saving JSONL file {output_path}: {e}")
            raise
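# Illustrative effect of DataAugmenter.augment_with_paraphrases below
# (the exact paraphrase depends on the rule-based patterns):
#   in:  {"query": "how to bake bread", "pos": ["Mix flour and water ..."]}
#   out: the original record, plus copies such as
#        {"query": "how can I bake bread", "pos": ["Mix flour and water ..."],
#         "is_augmented": True, "augmentation_type": "paraphrase"}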
class DataAugmenter:
    """
    Data augmentation strategies for retrieval training
    """

    def __init__(self, tokenizer: Optional[AutoTokenizer] = None):
        self.tokenizer = tokenizer

    def augment_with_paraphrases(
        self,
        data: List[Dict],
        num_paraphrases: int = 2
    ) -> List[Dict]:
        """
        Augment queries with paraphrases using rule-based approach
        """
        augmented = []

        for item in tqdm(data, desc="Augmenting with paraphrases"):
            # Add original
            augmented.append(item)

            # Generate paraphrases
            for i in range(num_paraphrases):
                para_item = item.copy()
                para_query = self._generate_paraphrase(item['query'], i)
                if para_query != item['query']:
                    para_item['query'] = para_query
                    para_item['is_augmented'] = True
                    para_item['augmentation_type'] = 'paraphrase'
                    augmented.append(para_item)

        return augmented

    def _generate_paraphrase(self, query: str, variant: int) -> str:
        """Generate paraphrase using rule-based patterns"""
        if self._is_chinese(query):
            # Chinese paraphrase patterns
            patterns = [
                ("什么是", "请解释一下"),
                ("如何", "怎样才能"),
                ("为什么", "什么原因导致"),
                ("在哪里", "什么地方"),
                ("是什么", "指的是什么"),
                ("有什么", "存在哪些"),
                ("怎么办", "如何处理"),
                ("多少", "数量是多少")
            ]

            # Apply patterns based on variant
            if variant == 0:
                for original, replacement in patterns:
                    if original in query:
                        return query.replace(original, replacement, 1)
            else:
                # Add question prefixes
                prefixes = ["请问", "我想知道", "能否告诉我", "想了解"]
                return f"{prefixes[variant % len(prefixes)]}{query}"
        else:
            # English paraphrase patterns
            patterns = [
                ("what is", "what does"),
                ("how to", "how can I"),
                ("why does", "what causes"),
                ("where is", "what location"),
                ("when did", "at what time")
            ]

            query_lower = query.lower()
            for original, replacement in patterns:
                if original in query_lower:
                    return query_lower.replace(original, replacement, 1)

            # Add question prefixes
            prefixes = ["Please explain", "Could you tell me", "I need to know", "Help me understand"]
            return f"{prefixes[variant % len(prefixes)]} {query.lower()}"

        return query

    def _is_chinese(self, text: str) -> bool:
        """Check if text contains Chinese characters"""
        return bool(re.search(r'[\u4e00-\u9fff]', text))
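if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the paths below are
    # placeholders, and no tokenizer is passed, so queries are truncated by
    # characters and passages are not token-length filtered.
    logging.basicConfig(level=logging.INFO)

    preprocessor = DataPreprocessor(
        tokenizer=None,
        max_query_length=512,
        max_passage_length=8192,
        min_passage_length=3,
    )
    train_path, val_path = preprocessor.preprocess_file(
        input_path="data/raw/train.csv",           # placeholder input path
        output_path="data/processed/data.jsonl",   # placeholder output path
        file_format="auto",
        validation_split=0.1,
    )

    # Optionally augment the training split with rule-based paraphrases.
    augmenter = DataAugmenter()
    train_examples = preprocessor._load_jsonl(train_path)
    augmented = augmenter.augment_with_paraphrases(train_examples, num_paraphrases=2)
    preprocessor._save_jsonl(augmented, "data/processed/data_train_augmented.jsonl")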