import random
import torch
from typing import List, Dict, Optional
from tqdm import tqdm
import logging
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import requests
import time
import re
import os

logger = logging.getLogger(__name__)


def get_config():
    """Load the centralized config; fall back to a stub only when running
    standalone (no config loader package is importable)."""
    try:
        from utils.config_loader import get_config as _get_config
        return _get_config()
    except ImportError:
        try:
            from bge_finetune.utils.config_loader import get_config as _get_config  # type: ignore[import]
            return _get_config()
        except ImportError:
            # Truly standalone: return a stub that always yields defaults.
            class FallbackConfig:
                def __init__(self):
                    self.config = {}

                def get(self, key, default=None):
                    return default

                def get_section(self, section):
                    return {}

            return FallbackConfig()
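
# Illustrative config.toml sketch (an assumption about the expected layout,
# inferred from the get_section()/get() lookups in this module; the values
# shown are placeholders, not project defaults):
#
#   [api]
#   base_url = "https://api.example.com/v1"   # hypothetical endpoint
#   api_key = "your-api-key-here"
#   model = "deepseek-chat"
#   timeout = 30
#   max_retries = 3
#
#   [augmentation]
#   augmentation_methods = ["paraphrase", "back_translation"]
#   augmentation_temperature = 0.8
#   num_augmentations_per_sample = 2
#   enable_query_augmentation = false
#   api_first = true
#   enable_passage_augmentation = false
#
#   [hard_negative_mining]
#   enable_mining = true
#   num_hard_negatives = 15
#
#   [model_paths]
#   source = "huggingface"   # or "modelscope"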

class QueryAugmenter:
    """
    Generate augmented queries using various techniques, including
    API-based paraphrasing and translation.
    """

    def __init__(
        self,
        augmentation_methods: Optional[List[str]] = None,
        device: Optional[str] = None,
        config_path: Optional[str] = None
    ):
        """
        Initialize QueryAugmenter.

        Args:
            augmentation_methods: List of augmentation methods to use.
            device: Device to use for augmentation models.
            config_path: Optional path to config file (currently unused;
                get_config() locates the config itself).
        """
        # get_config() handles missing loader packages internally and
        # falls back to a stub, so it is safe to call directly.
        self.config = get_config()

        # Load configuration
        aug_config = self.config.get_section('augmentation')
        api_config = self.config.get_section('api')

        # Set augmentation methods with fallback
        if augmentation_methods is None:
            self.augmentation_methods = aug_config.get(
                'augmentation_methods', ['paraphrase', 'back_translation'])
        else:
            self.augmentation_methods = augmentation_methods

        # Validate augmentation methods (extensible: any method with a
        # matching _augment_<name> callable is also accepted)
        valid_methods = {'paraphrase', 'back_translation', 'synonym', 'template'}
        self.augmentation_methods = [
            m for m in self.augmentation_methods
            if m in valid_methods or callable(getattr(self, f'_augment_{m}', None))
        ]
        if not self.augmentation_methods:
            logger.warning("No valid augmentation methods specified, using default")
            self.augmentation_methods = ['paraphrase']

        # Set device with fallback
        if device is None:
            try:
                self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
            except Exception:
                self.device = 'cpu'
        else:
            self.device = device

        # Configuration parameters with safe defaults
        self.api_config = api_config
        self.augmentation_temperature = aug_config.get('augmentation_temperature', 0.8)
        self.num_augmentations_per_sample = aug_config.get('num_augmentations_per_sample', 2)
        self.enable_api = aug_config.get('enable_query_augmentation', False)
        self.api_first = aug_config.get('api_first', True)

        # API key check (env var overrides config)
        api_key = os.environ.get('BGE_API_KEY') or self.api_config.get('api_key', None)
        if self.enable_api:
            if not api_key or api_key == 'your-api-key-here':
                raise ValueError(
                    "API-based augmentation is enabled, but a valid API key is missing "
                    "in the config.toml [api] section and the BGE_API_KEY env var.")

        # Model cache
        self.models = {}
        self.paraphrase_model = None
        self.paraphrase_tokenizer = None

        # Load models for augmentation
        self._load_models()

    def _load_models(self):
        """Load required models for augmentation with comprehensive error handling."""
        logger.info(f"Loading augmentation models for methods: {self.augmentation_methods}")

        # Load paraphrase model if needed
        if 'paraphrase' in self.augmentation_methods:
            model_source = self.config.get('model_paths.source', 'huggingface')
            try:
                if model_source == 'huggingface':
                    logger.info("Loading T5 model for paraphrasing from HuggingFace...")
                    self.paraphrase_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
                    self.paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
                elif model_source == 'modelscope':
                    try:
                        from modelscope.models import Model as MSModel  # type: ignore[import]
                        from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download  # type: ignore[import]
                        ms_snapshot_download(model_id="damo/nlp_t5_paraphrase_chinese-base", cache_dir=None)
                        self.paraphrase_model = MSModel.from_pretrained("damo/nlp_t5_paraphrase_chinese-base")
                        self.paraphrase_tokenizer = None  # Add a ModelScope tokenizer here if needed
                        logger.info("T5 model loaded from ModelScope.")
                    except ImportError:
                        logger.error("modelscope not available. Install it first: pip install modelscope")
                        self.paraphrase_model = None
                        self.paraphrase_tokenizer = None
                        if not self.enable_api:
                            raise RuntimeError(
                                "ModelScope paraphrase model required but not available "
                                "and API augmentation is not enabled.")
                        return
                else:
                    logger.error(f"Unknown model source: {model_source}. "
                                 f"Supported: 'huggingface', 'modelscope'.")
                    self.paraphrase_model = None
                    self.paraphrase_tokenizer = None
                    if not self.enable_api:
                        logger.warning("Paraphrase model unavailable and API augmentation disabled")
                    return

                try:
                    if self.paraphrase_model is not None:
                        self.paraphrase_model.to(self.device)
                        self.paraphrase_model.eval()
                        logger.info(f"T5 model loaded successfully on {self.device}")
                except Exception as e:
                    logger.warning(f"Failed to move T5 model to {self.device}: {e}, using CPU")
                    self.device = 'cpu'
                    if self.paraphrase_model is not None:
                        self.paraphrase_model.to(self.device)
                        self.paraphrase_model.eval()
            except Exception as e:
                logger.error(f"Failed to load T5 model: {e}.")
                self.paraphrase_model = None
                self.paraphrase_tokenizer = None
                if not self.enable_api:
                    logger.warning("Paraphrase model failed to load and API augmentation disabled")

        # Back-translation model loading (if available). Uses
        # Helsinki-NLP/opus-mt-zh-en and en-zh for Chinese<->English round trips.
        if 'back_translation' in self.augmentation_methods:
            try:
                self.bt_zh2en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
                self.bt_zh2en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
                self.bt_en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
                self.bt_en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
                # Move the models to the same device the inputs will be placed on.
                self.bt_zh2en_model.to(self.device).eval()
                self.bt_en2zh_model.to(self.device).eval()
                logger.info("Loaded back-translation models for zh<->en.")
            except Exception as e:
                logger.warning(f"Failed to load back-translation models: {e}")
                self.bt_zh2en_model = None
                self.bt_en2zh_model = None
                self.bt_zh2en_tokenizer = None
                self.bt_en2zh_tokenizer = None

    def augment_queries(
        self,
        data: List[Dict],
        num_augmentations: Optional[int] = None,
        keep_original: bool = True
    ) -> List[Dict]:
        """
        Augment queries in the dataset with comprehensive error handling.

        Args:
            data: Original data
            num_augmentations: Number of augmentations per query
            keep_original: Whether to keep original queries

        Returns:
            Augmented dataset
        """
        if not data:
            logger.warning("Empty data provided for augmentation")
            return []

        if num_augmentations is None:
            num_augmentations = self.num_augmentations_per_sample
        if not isinstance(num_augmentations, int) or num_augmentations <= 0:
            logger.warning("Invalid num_augmentations, setting to 1")
            num_augmentations = 1

        augmented_data = []
        try:
            for item in tqdm(data, desc="Augmenting queries"):
                # Validate item
                if not isinstance(item, dict) or 'query' not in item:
                    logger.warning(f"Invalid item format: {item}")
                    continue

                # Keep original
                if keep_original:
                    augmented_data.append(item)

                # Generate augmentations
                for _ in range(num_augmentations):
                    # Choose the method outside the try block so the
                    # exception message below can always reference it.
                    method = random.choice(self.augmentation_methods)
                    try:
                        aug_item = self._augment_single(item, method)
                        if aug_item:
                            augmented_data.append(aug_item)
                    except Exception as e:
                        logger.warning(f"Failed to augment item with {method}: {e}")
                        continue
        except Exception as e:
            logger.error(f"Query augmentation failed: {e}")
            # Return original data on complete failure
            return data

        logger.info(f"Augmented {len(data)} items to {len(augmented_data)} items")
        return augmented_data

    def _augment_single(self, item: Dict, method: str) -> Optional[Dict]:
        """Augment a single item using the specified method with error handling."""
        if not item.get('query'):
            return None
        try:
            aug_item = item.copy()
            original_query = item['query']

            if method == 'paraphrase':
                aug_query = self._paraphrase_query(original_query)
            elif method == 'back_translation':
                aug_query = self._back_translate_query(original_query)
            elif method == 'synonym':
                aug_query = self._synonym_replacement(original_query)
            elif method == 'template':
                aug_query = self._template_based_augmentation(original_query)
            else:
                logger.warning(f"Unknown augmentation method: {method}")
                return None

            if aug_query and aug_query != original_query:
                aug_item['query'] = aug_query
                aug_item['augmentation_method'] = method
                aug_item['is_augmented'] = True
                return aug_item
        except Exception as e:
            logger.warning(f"Error in {method} augmentation: {e}")
        return None

    def _get_prompt(self, task: str, query: str, lang: str) -> str:
        """Build the prompt for LLM-based augmentation."""
        # The Chinese prompts mirror the English ones: rewrite with the same
        # meaning / swap in synonyms / use a different question template, and
        # output only the result with no explanation.
        if lang == 'zh':
            if task == 'paraphrase':
                return f"请将下列文本改写为意思相同的句子,只输出改写后的内容,不要解释:\n\n{query}"
            elif task == 'synonym':
                return f"请将下列文本中的词语替换为合适的同义词,只输出结果,不要解释:\n\n{query}"
            elif task == 'template':
                return f"请用不同的提问模板改写下列问题,只输出结果,不要解释:\n\n{query}"
        else:
            if task == 'paraphrase':
                return f"Paraphrase the following text. Only provide the paraphrase, no explanations:\n\n{query}"
            elif task == 'synonym':
                return f"Replace words in the following text with appropriate synonyms. Only provide the result, no explanations:\n\n{query}"
            elif task == 'template':
                return f"Rephrase the following text using a different question template. Only provide the result, no explanations:\n\n{query}"
        return query

    def _paraphrase_query(self, query: str) -> str:
        """Paraphrase a query using the loaded model or API."""
        if not query or not query.strip():
            return query

        lang = 'zh' if self._is_chinese(query) else 'en'

        # Try the LLM API first if enabled and configured
        if self.enable_api and self.api_first:
            try:
                prompt = self._get_prompt('paraphrase', query, lang)
                result = self._call_llm_api(prompt)
                if result and result != query:
                    return result
            except Exception as e:
                logger.warning(f"LLM API paraphrasing failed: {e}. Trying local model.")
Trying local model.") # Try T5-based paraphrasing next (only for English) if lang == 'en' and self.paraphrase_model is not None and self.paraphrase_tokenizer is not None: try: input_text = f"paraphrase: {query}" inputs = self.paraphrase_tokenizer( input_text, return_tensors="pt", max_length=512, truncation=True ).to(self.device) with torch.no_grad(): outputs = self.paraphrase_model.generate( **inputs, max_length=150, num_return_sequences=1, temperature=self.augmentation_temperature, do_sample=True, pad_token_id=self.paraphrase_tokenizer.eos_token_id ) paraphrase = self.paraphrase_tokenizer.decode(outputs[0], skip_special_tokens=True) if paraphrase and paraphrase != query: return paraphrase except Exception as e: logger.warning(f"T5 paraphrasing failed: {e}. Using rule-based approach.") # Fallback to rule-based paraphrasing return self._rule_based_paraphrase(query, lang) def _rule_based_paraphrase(self, query: str, lang: str = 'en') -> str: """Rule-based paraphrasing for fallback.""" if not query: return query try: if lang == 'zh': patterns = [ ("什么是", "请解释"), ("如何", "怎样"), ("为什么", "什么原因"), ("在哪里", "什么地方"), ("是什么", "指的是什么"), ("有什么", "存在哪些"), ("能否", "可以"), ("应该", "需要"), ("必须", "一定要"), ("关于", "有关") ] paraphrased = query for original, replacement in patterns: if original in query: paraphrased = query.replace(original, replacement, 1) break if paraphrased == query: prefixes = ["请问", "我想知道", "能否告诉我", "请教一下", "想了解"] paraphrased = f"{random.choice(prefixes)}{query}" return paraphrased else: templates = [ "What is meant by {}?", "Can you explain {}?", "Tell me about {}", "I want to know about {}", "Please describe {}", "Could you clarify {}?", "What does {} refer to?", "How would you define {}?" ] core = query.strip("?.,!").lower() question_starters = ["what", "how", "why", "when", "where", "who", "which"] for starter in question_starters: if core.startswith(starter): core = core[len(starter):].strip() for aux in ["is", "are", "was", "were", "does", "do", "did"]: if core.startswith(aux): core = core[len(aux):].strip() break break if core: return random.choice(templates).format(core) else: return query except Exception as e: logger.warning(f"Rule-based paraphrasing failed: {e}") return query def _back_translate_query(self, query: str) -> str: """Back-translate a query using SOTA models if available, else fallback to API or simulation. 

    def _back_translate_query(self, query: str) -> str:
        """Back-translate a query using local MarianMT models if available,
        else fall back to the API or to simulation. Always returns a str."""
        if (getattr(self, 'bt_zh2en_model', None) is not None
                and getattr(self, 'bt_en2zh_model', None) is not None):
            try:
                is_chinese = self._is_chinese(query)
                with torch.no_grad():
                    if is_chinese:
                        # zh -> en -> zh round trip
                        inputs = self.bt_zh2en_tokenizer(query, return_tensors="pt",
                                                         truncation=True, max_length=64).to(self.device)
                        en = self.bt_zh2en_model.generate(**inputs, max_length=64)
                        en_text = self.bt_zh2en_tokenizer.decode(en[0], skip_special_tokens=True)
                        inputs = self.bt_en2zh_tokenizer(en_text, return_tensors="pt",
                                                         truncation=True, max_length=64).to(self.device)
                        zh = self.bt_en2zh_model.generate(**inputs, max_length=64)
                        zh_text = self.bt_en2zh_tokenizer.decode(zh[0], skip_special_tokens=True)
                        return zh_text if zh_text else query
                    else:
                        # en -> zh -> en round trip
                        inputs = self.bt_en2zh_tokenizer(query, return_tensors="pt",
                                                         truncation=True, max_length=64).to(self.device)
                        zh = self.bt_en2zh_model.generate(**inputs, max_length=64)
                        zh_text = self.bt_en2zh_tokenizer.decode(zh[0], skip_special_tokens=True)
                        inputs = self.bt_zh2en_tokenizer(zh_text, return_tensors="pt",
                                                         truncation=True, max_length=64).to(self.device)
                        en = self.bt_zh2en_model.generate(**inputs, max_length=64)
                        en_text = self.bt_zh2en_tokenizer.decode(en[0], skip_special_tokens=True)
                        return en_text if en_text else query
            except Exception as e:
                logger.warning(f"Local back-translation failed: {e}")

        if self.enable_api:
            try:
                is_chinese = self._is_chinese(query)
                result = self._translate_with_api(
                    query,
                    'zh' if is_chinese else 'en',
                    'en' if is_chinese else 'zh'
                )
                if result:
                    return result
            except Exception as e:
                logger.warning(f"API back-translation failed: {e}")

        fallback = self._simulate_back_translation(query)
        return fallback if fallback is not None else query

    def _translate_with_api(self, text: str, source_lang: str, target_lang: str) -> Optional[str]:
        """Translate text via the configured chat-completions API."""
        if not text or not text.strip():
            return None

        base_url = self.api_config.get('base_url', '')
        api_key = os.environ.get('BGE_API_KEY') or self.api_config.get('api_key', '')
        model = self.api_config.get('model', 'deepseek-chat')
        timeout = self.api_config.get('timeout', 30)
        max_retries = self.api_config.get('max_retries', 3)

        if not base_url or not api_key or api_key == 'your-api-key-here':
            return None

        headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }

        # Create translation prompt
        lang_names = {'zh': 'Chinese', 'en': 'English', 'ja': 'Japanese', 'fr': 'French'}
        source_name = lang_names.get(source_lang, source_lang)
        target_name = lang_names.get(target_lang, target_lang)
        prompt = (f"Translate the following {source_name} text to {target_name}. "
                  f"Only provide the translation, no explanations:\n\n{text}")

        data = {
            'model': model,
            'messages': [
                {'role': 'user', 'content': prompt}
            ],
            'temperature': self.augmentation_temperature,
            'max_tokens': 500
        }

        for attempt in range(max_retries):
            try:
                response = requests.post(
                    f"{base_url.rstrip('/')}/chat/completions",
                    headers=headers,
                    json=data,
                    timeout=timeout
                )
                if response.status_code == 200:
                    result = response.json()
                    if 'choices' in result and result['choices']:
                        return result['choices'][0]['message']['content'].strip()
                    logger.warning(f"Invalid response format: {result}")
                else:
                    logger.warning(f"Translation API error: {response.status_code} - {response.text}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)  # Exponential backoff
            except requests.exceptions.Timeout:
                logger.warning(f"Translation API timeout on attempt {attempt + 1}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
            except requests.exceptions.RequestException as e:
                logger.warning(f"Translation API request failed: {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
            except Exception as e:
                logger.error(f"Unexpected error in translation API: {e}")
                break
        return None

    def _simulate_back_translation(self, query: str) -> str:
        """Simulate back-translation as a last-resort fallback."""
        if not query:
            return query
        try:
            # Apply substitutions that mimic the drift a translation
            # round trip typically introduces.
            if self._is_chinese(query):
                # Simulate zh -> en -> zh effects
                substitutions = [
                    ("什么", "哪些"),
                    ("如何", "怎么样"),
                    ("为什么", "什么原因"),
                    ("非常", "很"),
                    ("特别", "特殊"),
                    ("重要", "关键"),
                    ("问题", "事情"),
                    ("方法", "方式"),
                    ("内容", "信息")
                ]
                result = query
                for original, replacement in substitutions:
                    if original in result:
                        result = result.replace(original, replacement, 1)
                        break
                return result
            else:
                # Simulate en -> zh -> en effects
                substitutions = [
                    ("what", "which"),
                    ("how", "in what way"),
                    ("very", "extremely"),
                    ("important", "crucial"),
                    ("method", "approach"),
                    ("content", "information"),
                    ("question", "issue"),
                    ("problem", "matter"),
                    ("find", "locate")
                ]
                result = query
                for original, replacement in substitutions:
                    # Case-insensitive match with case-preserving replacement
                    pattern = re.compile(re.escape(original), re.IGNORECASE)
                    match = pattern.search(result)
                    if match:
                        found = match.group()
                        if found.isupper():
                            replacement = replacement.upper()
                        elif found.istitle():
                            replacement = replacement.title()
                        result = pattern.sub(replacement, result, count=1)
                        break
                return result
        except Exception as e:
            logger.warning(f"Back-translation simulation failed: {e}")
            return query

    def _synonym_replacement(self, query: str) -> str:
        """Replace words in the query with synonyms."""
        if not query or not query.strip():
            return query

        lang = 'zh' if self._is_chinese(query) else 'en'

        # Try the LLM API first if enabled and configured
        if self.enable_api and self.api_first:
            try:
                prompt = self._get_prompt('synonym', query, lang)
                result = self._call_llm_api(prompt)
                if result and result != query:
                    return result
            except Exception as e:
                logger.warning(f"LLM API synonym replacement failed: {e}. Using rule-based approach.")
Using rule-based approach.") # Fallback to hard-coded synonym replacement try: if lang == 'zh': synonyms = { "重要": ["关键", "核心", "主要"], "方法": ["办法", "手段", "途径"], "问题": ["疑问", "难题", "困惑"], "内容": ["材料", "信息", "资料"], "增加": ["提升", "增长", "扩展"], "发现": ["查找", "寻找", "探索"] } words = list(query) replaced = False for i, word in enumerate(words): if word in synonyms: words[i] = random.choice(synonyms[word]) replaced = True if replaced: return ''.join(words) return query else: synonyms = { "important": ["crucial", "vital", "significant"], "method": ["approach", "technique", "procedure"], "find": ["locate", "discover", "identify"], "problem": ["issue", "challenge", "difficulty"], "increase": ["rise", "growth", "expansion"], "content": ["material", "substance", "information"] } words = query.split() replaced = False for i, word in enumerate(words): key = word.lower().strip('.,?!') if key in synonyms: words[i] = random.choice(synonyms[key]) replaced = True if replaced: return ' '.join(words) return query except Exception as e: logger.warning(f"Rule-based synonym replacement failed: {e}") return query def _template_based_augmentation(self, query: str) -> str: """Template-based augmentation for queries.""" if not query or not query.strip(): return query lang = 'zh' if self._is_chinese(query) else 'en' # Try LLM API first if enabled and configured if self.enable_api and self.api_first: try: prompt = self._get_prompt('template', query, lang) result = self._call_llm_api(prompt) if result and result != query: return result except Exception as e: logger.warning(f"LLM API template-based augmentation failed: {e}. Using rule-based approach.") # Fallback to hard-coded templates try: if lang == 'zh': templates = [ "请找到关于{}的段落", "搜索有关{}的信息", "寻找讨论{}的内容", "我需要了解{}的详细信息", "显示与{}相关的文本" ] core = query.strip("?.,!。:") for prefix in ["什么是", "如何", "为什么", "在哪里", "谁是"]: if core.startswith(prefix): core = core[len(prefix):].strip() break if core: return random.choice(templates).format(core) else: return query else: templates = [ "What is meant by {}?", "Can you explain {}?", "Tell me about {}", "I want to know about {}", "Please describe {}", "Could you clarify {}?", "What does {} refer to?", "How would you define {}?" 

    def _call_llm_api(self, prompt: str) -> Optional[str]:
        """Call the LLM API for augmentation."""
        base_url = self.api_config.get('base_url', '')
        api_key = os.environ.get('BGE_API_KEY') or self.api_config.get('api_key', '')
        model = self.api_config.get('model', 'deepseek-chat')
        timeout = self.api_config.get('timeout', 30)
        max_retries = self.api_config.get('max_retries', 3)
        temperature = self.api_config.get('temperature', self.augmentation_temperature)

        if not base_url or not api_key or api_key == 'your-api-key-here':
            return None

        headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }
        data = {
            'model': model,
            'messages': [
                {'role': 'user', 'content': prompt}
            ],
            'temperature': temperature,
            'max_tokens': 500
        }

        for attempt in range(max_retries):
            try:
                response = requests.post(
                    f"{base_url.rstrip('/')}/chat/completions",
                    headers=headers,
                    json=data,
                    timeout=timeout
                )
                if response.status_code == 200:
                    result = response.json()
                    if 'choices' in result and result['choices']:
                        return result['choices'][0]['message']['content'].strip()
                    logger.warning(f"Invalid response format: {result}")
                else:
                    logger.warning(f"LLM API error: {response.status_code} - {response.text}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
            except requests.exceptions.Timeout:
                logger.warning(f"LLM API timeout on attempt {attempt + 1}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
            except requests.exceptions.RequestException as e:
                logger.warning(f"LLM API request failed: {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
            except Exception as e:
                logger.error(f"Unexpected error in LLM API: {e}")
                break
        return None

    def _is_chinese(self, text: str) -> bool:
        """Detect whether the text contains Chinese characters."""
        if not text:
            return False
        try:
            return bool(re.search(r'[\u4e00-\u9fff]', text))
        except Exception:
            return False
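
# --- Usage sketch (illustrative only) ---
# A minimal example of QueryAugmenter on a toy dataset. The sample data is
# hypothetical; 'synonym' and 'template' are chosen here because they need
# no model downloads or API key.
def _example_query_augmentation() -> None:
    augmenter = QueryAugmenter(augmentation_methods=['synonym', 'template'])
    sample = [{'query': 'What is the best method to find important content?', 'pos': ['...']}]
    for row in augmenter.augment_queries(sample, num_augmentations=2):
        print(row.get('augmentation_method', 'original'), '->', row['query'])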

class PassageAugmenter:
    """
    Augment passages to create training variations.
    """

    def __init__(self):
        # get_config() handles missing loader packages internally and
        # falls back to a stub, so it is safe to call directly.
        self.config = get_config()

        aug_config = self.config.get_section('augmentation')
        self.enable_passage_augmentation = aug_config.get('enable_passage_augmentation', False)
        self.augmentation_strategies = [
            'sentence_shuffle',
            'span_masking',
            'paraphrase_sentences',
            'add_noise'
        ]

    def augment_passages(
        self,
        data: List[Dict],
        augmentation_prob: Optional[float] = None
    ) -> List[Dict]:
        """
        Augment passages in the dataset with comprehensive error handling.

        Args:
            data: Original data
            augmentation_prob: Probability of augmenting each item

        Returns:
            Augmented dataset
        """
        if not self.enable_passage_augmentation:
            logger.info("Passage augmentation disabled in config")
            return data

        if not data:
            logger.warning("Empty data provided for passage augmentation")
            return []

        if augmentation_prob is None:
            augmentation_prob = 0.3
        if not (0 <= augmentation_prob <= 1):
            logger.warning("Invalid augmentation_prob, setting to 0.3")
            augmentation_prob = 0.3

        augmented_data = []
        try:
            for item in tqdm(data, desc="Augmenting passages"):
                # Validate item
                if not isinstance(item, dict):
                    logger.warning(f"Invalid item format: {item}")
                    continue

                augmented_data.append(item)

                if random.random() < augmentation_prob:
                    # Augment positive passages
                    if 'pos' in item and item['pos']:
                        try:
                            aug_item = item.copy()
                            aug_item['pos'] = [
                                self._augment_passage(p, random.choice(self.augmentation_strategies))
                                for p in item['pos']
                                if p and p.strip()  # Skip empty passages
                            ]
                            if aug_item['pos']:  # Only add if we have valid augmented passages
                                aug_item['is_augmented'] = True
                                aug_item['augmentation_type'] = 'passage'
                                augmented_data.append(aug_item)
                        except Exception as e:
                            logger.warning(f"Failed to augment passage for item: {e}")
                            continue
        except Exception as e:
            logger.error(f"Passage augmentation failed: {e}")
            return data

        return augmented_data

    def _augment_passage(self, passage: str, strategy: str) -> str:
        """Augment a single passage with error handling."""
        if not passage or not passage.strip():
            return passage
        try:
            if strategy == 'sentence_shuffle':
                return self._shuffle_sentences(passage)
            elif strategy == 'span_masking':
                return self._mask_spans(passage)
            elif strategy == 'paraphrase_sentences':
                return self._paraphrase_sentences(passage)
            elif strategy == 'add_noise':
                return self._add_noise(passage)
            else:
                logger.warning(f"Unknown augmentation strategy: {strategy}")
                return passage
        except Exception as e:
            logger.warning(f"Error in passage augmentation with {strategy}: {e}")
            return passage

    def _shuffle_sentences(self, passage: str) -> str:
        """Shuffle the middle sentences, keeping the first and last fixed."""
        if not passage or len(passage.strip()) < 10:
            return passage
        try:
            # Split by Chinese or English sentence endings
            if re.search(r'[\u4e00-\u9fff]', passage):
                sentences = re.split(r'[。!?]', passage)
                delimiter = '。'
            else:
                sentences = re.split(r'[.!?]', passage)
                delimiter = '. '

            sentences = [s.strip() for s in sentences if s.strip()]
            if len(sentences) > 2:
                # Keep first and last, shuffle the middle
                middle = sentences[1:-1]
                random.shuffle(middle)
                sentences = [sentences[0]] + middle + [sentences[-1]]
                # Rejoin with appropriate punctuation
                if delimiter == '。':
                    return delimiter.join(sentences) + '。'
                return delimiter.join(sentences) + '.'
            return passage
        except Exception as e:
            logger.warning(f"Sentence shuffling failed: {e}")
            return passage

    def _mask_spans(self, passage: str, mask_ratio: float = 0.15) -> str:
        """Mask random spans in the passage."""
        if not passage or len(passage.strip()) < 20:
            return passage
        try:
            # Split into words/characters
            if re.search(r'[\u4e00-\u9fff]', passage):
                # Chinese: split by characters for finer granularity
                units = list(passage)
                join_char = ''
            else:
                # English: split by words
                units = passage.split()
                join_char = ' '

            num_masks = max(1, int(len(units) * mask_ratio))
            for _ in range(num_masks):
                if len(units) > 5:  # Ensure minimum length
                    start = random.randint(0, len(units) - 3)
                    length = random.randint(1, min(3, len(units) - start))
                    for i in range(start, start + length):
                        if i < len(units):
                            units[i] = '[MASK]'
            return join_char.join(units)
        except Exception as e:
            logger.warning(f"Span masking failed: {e}")
            return passage

    def _paraphrase_sentences(self, passage: str) -> str:
        """Paraphrase one sentence by prepending a discourse modifier."""
        if not passage or len(passage.strip()) < 20:
            return passage
        try:
            # Split sentences
            if re.search(r'[\u4e00-\u9fff]', passage):
                sentences = re.split(r'[。!?]', passage)
                delimiter = '。'
            else:
                sentences = re.split(r'[.!?]', passage)
                delimiter = '. '
            sentences = [s.strip() for s in sentences if s.strip()]
            if sentences:
                # Randomly modify one sentence
                idx = random.randint(0, len(sentences) - 1)
                original = sentences[idx]

                # Simple paraphrasing by adding modifiers,
                # e.g. "显然" ("obviously"), "据了解" ("reportedly")
                if re.search(r'[\u4e00-\u9fff]', original):
                    modifiers = ["显然", "据了解", "实际上", "通常情况下", "根据研究"]
                    sentences[idx] = f"{random.choice(modifiers)},{original}"
                else:
                    modifiers = ["Generally", "According to studies", "Typically", "In fact", "Usually"]
                    sentences[idx] = f"{random.choice(modifiers)}, {original.lower()}"

                result = delimiter.join(sentences)
                result += '。' if delimiter == '。' else '.'
                return result
            return passage
        except Exception as e:
            logger.warning(f"Sentence paraphrasing failed: {e}")
            return passage

    def _add_noise(self, passage: str, noise_level: float = 0.02) -> str:
        """Add character-level noise."""
        if not passage or len(passage.strip()) < 10:
            return passage
        try:
            chars = list(passage)
            num_noisy = max(1, int(len(chars) * noise_level))

            if len(chars) > 10:  # Only add noise to sufficiently long passages
                noise_positions = random.sample(range(len(chars)), min(num_noisy, len(chars) // 5))
                for pos in noise_positions:
                    if pos < len(chars):
                        noise_type = random.choice(['duplicate', 'replace'])
                        if noise_type == 'duplicate':
                            chars[pos] = chars[pos] * 2
                        elif noise_type == 'replace':
                            # Replace with a similar character
                            if re.search(r'[\u4e00-\u9fff]', chars[pos]):
                                # Common Chinese characters
                                similar_chars = ['的', '了', '在', '是', '我', '有', '他', '这', '中', '大']
                                chars[pos] = random.choice(similar_chars)
                            elif chars[pos].isalpha():
                                chars[pos] = random.choice('abcdefghijklmnopqrstuvwxyz')
            return ''.join(chars)
        except Exception as e:
            logger.warning(f"Noise addition failed: {e}")
            return passage
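
# --- Usage sketch (illustrative only) ---
# PassageAugmenter is gated by enable_passage_augmentation in config.toml;
# the attribute override below is a hypothetical way to exercise it standalone.
def _example_passage_augmentation() -> None:
    augmenter = PassageAugmenter()
    augmenter.enable_passage_augmentation = True  # bypass the config gate for the demo
    sample = [{'query': 'q', 'pos': ['First sentence. Second sentence. Third sentence.']}]
    for row in augmenter.augment_passages(sample, augmentation_prob=1.0):
        print(row.get('augmentation_type', 'original'), '->', row['pos'])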

class HardNegativeAugmenter:
    """
    Create hard negatives through various corruption strategies.
    """

    def __init__(self):
        # get_config() handles missing loader packages internally and
        # falls back to a stub, so it is safe to call directly.
        self.config = get_config()

        mining_config = self.config.get_section('hard_negative_mining')
        self.enable_mining = mining_config.get('enable_mining', True)
        self.num_hard_negatives = mining_config.get('num_hard_negatives', 15)
        self.strategies = [
            'entity_swap',
            'number_change',
            'negation_addition',
            'similar_context'
        ]

    def create_hard_negatives(
        self,
        data: List[Dict],
        num_hard_negatives: Optional[int] = None
    ) -> List[Dict]:
        """
        Create hard negatives for training data with comprehensive error handling.

        Args:
            data: Original data
            num_hard_negatives: Number of hard negatives to create

        Returns:
            Data with hard negatives
        """
        if not self.enable_mining:
            logger.info("Hard negative mining disabled in config")
            return data

        if not data:
            logger.warning("Empty data provided for hard negative creation")
            return []

        if num_hard_negatives is None:
            num_hard_negatives = self.num_hard_negatives
        if not isinstance(num_hard_negatives, int) or num_hard_negatives <= 0:
            logger.warning("Invalid num_hard_negatives, setting to 5")
            num_hard_negatives = 5

        augmented_data = []
        try:
            for item in tqdm(data, desc="Creating hard negatives"):
                # Validate item
                if not isinstance(item, dict):
                    logger.warning(f"Invalid item format: {item}")
                    continue

                aug_item = item.copy()
                if 'pos' in item and item['pos']:
                    hard_negs = []
                    try:
                        for pos_passage in item['pos'][:2]:  # Use first 2 positives
                            if not pos_passage or not pos_passage.strip():
                                continue
                            for i in range(min(num_hard_negatives, len(self.strategies) * 2)):
                                strategy = self.strategies[i % len(self.strategies)]
                                hard_neg = self._create_hard_negative(pos_passage, strategy)
                                if hard_neg and hard_neg != pos_passage and hard_neg.strip():
                                    hard_negs.append(hard_neg)

                        if hard_negs:
                            if 'hard_neg' not in aug_item:
                                aug_item['hard_neg'] = []
                            # Remove duplicates and limit
                            unique_negs = list(dict.fromkeys(hard_negs))
                            aug_item['hard_neg'].extend(unique_negs[:num_hard_negatives])
                    except Exception as e:
                        logger.warning(f"Failed to create hard negatives for item: {e}")

                augmented_data.append(aug_item)
        except Exception as e:
            logger.error(f"Hard negative creation failed: {e}")
            return data

        return augmented_data

    def _create_hard_negative(self, passage: str, strategy: str) -> str:
        """Create a hard negative using the specified strategy with error handling."""
        if not passage or not passage.strip():
            return passage
        try:
            if strategy == 'entity_swap':
                return self._swap_entities(passage)
            elif strategy == 'number_change':
                return self._change_numbers(passage)
            elif strategy == 'negation_addition':
                return self._add_negation(passage)
            elif strategy == 'similar_context':
                return self._create_similar_context(passage)
            else:
                logger.warning(f"Unknown hard negative strategy: {strategy}")
                return passage
        except Exception as e:
            logger.warning(f"Error creating hard negative with {strategy}: {e}")
            return passage

    def _swap_entities(self, passage: str) -> str:
        """Swap entities to create a factually incorrect passage."""
        if not passage:
            return passage
        try:
            if re.search(r'[\u4e00-\u9fff]', passage):
                entity_pairs = [
                    ("北京", "上海"), ("中国", "日本"), ("2023年", "2024年"),
                    ("第一", "第二"), ("增加", "减少"), ("成功", "失败"),
                    ("重要", "次要"), ("大型", "小型"), ("新的", "旧的")
                ]
                aug_passage = passage
                for ent1, ent2 in entity_pairs:
                    if ent1 in passage and ent2 not in passage:
                        aug_passage = aug_passage.replace(ent1, ent2, 1)
                        break
                return aug_passage
            else:
                entity_pairs = [
                    ("Beijing", "Shanghai"), ("China", "Japan"), ("2023", "2024"),
                    ("first", "second"), ("increase", "decrease"), ("success", "failure"),
                    ("important", "minor"), ("large", "small"), ("new", "old")
                ]
                aug_passage = passage
                for ent1, ent2 in entity_pairs:
                    # Whole-word match so e.g. "newer" is not touched
                    pattern = re.compile(rf'\b{re.escape(ent1)}\b')
                    if pattern.search(passage) and ent2 not in passage:
                        aug_passage = pattern.sub(ent2, aug_passage, count=1)
                        break
                return aug_passage
        except Exception as e:
            logger.warning(f"Entity swapping failed: {e}")
            return passage

    def _change_numbers(self, passage: str) -> str:
        """Perturb numbers in the passage."""
        if not passage:
            return passage
        try:
            def replace_number(match):
                try:
                    num = int(match.group())
                    # Change by 20-80%
                    change = random.uniform(0.2, 0.8)
                    if random.random() > 0.5:
                        new_num = int(num * (1 + change))
                    else:
                        new_num = int(num * (1 - change))
                    return str(max(1, new_num))  # Ensure positive
                except ValueError:
                    return match.group()

            # Find and replace standalone integers
            return re.sub(r'\b\d+\b', replace_number, passage)
        except Exception as e:
            logger.warning(f"Number changing failed: {e}")
            return passage

    def _add_negation(self, passage: str) -> str:
        """Add or flip negation to change the meaning."""
        if not passage:
            return passage
        try:
            if re.search(r'[\u4e00-\u9fff]', passage):
                negation_patterns = [
                    ("是", "不是"), ("有", "没有"), ("会", "不会"),
                    ("能", "不能"), ("可以", "不可以"), ("应该", "不应该"),
                    ("成功", "失败"), ("正确", "错误"), ("增长", "下降")
                ]
                aug_passage = passage
                for positive, negative in negation_patterns:
                    if positive in passage and negative not in passage:
                        aug_passage = aug_passage.replace(positive, negative, 1)
                        break
                return aug_passage
            else:
                negation_patterns = [
                    ("is", "is not"), ("was", "was not"), ("can", "cannot"),
                    ("will", "will not"), ("should", "should not"), ("has", "has no"),
                    ("successful", "unsuccessful"), ("correct", "incorrect"),
                    ("increase", "decrease")
                ]
                aug_passage = passage
                for positive, negative in negation_patterns:
                    # Whole-word match so "is" does not hit "this"
                    pattern = re.compile(rf'\b{re.escape(positive)}\b')
                    if pattern.search(passage) and negative not in passage:
                        aug_passage = pattern.sub(negative, aug_passage, count=1)
                        break
                return aug_passage
        except Exception as e:
            logger.warning(f"Negation addition failed: {e}")
            return passage

    def _create_similar_context(self, passage: str) -> str:
        """Create a passage with similar context but a different meaning."""
        if not passage:
            return passage
        try:
            # Split into sentences
            if re.search(r'[\u4e00-\u9fff]', passage):
                sentences = re.split(r'[。!?]', passage)
                delimiter = '。'
                generic_statements = [
                    "这一点在许多情况下都有所体现",
                    "相关研究表明了不同的结果",
                    "实际情况可能有所不同",
                    "这个观点存在一定的争议",
                    "具体细节仍在讨论之中"
                ]
            else:
                sentences = re.split(r'[.!?]', passage)
                delimiter = '. '
                generic_statements = [
                    "This point is reflected in many cases",
                    "Related research shows different results",
                    "The actual situation may vary",
                    "This viewpoint is somewhat controversial",
                    "Specific details are still under discussion"
                ]

            sentences = [s.strip() for s in sentences if s.strip()]
            if len(sentences) >= 3:
                # Replace the middle sentence with a generic statement
                middle_idx = len(sentences) // 2
                sentences[middle_idx] = random.choice(generic_statements)
                result = delimiter.join(sentences)
                result += '。' if delimiter == '。' else '.'
                return result
            return passage
        except Exception as e:
            logger.warning(f"Similar context creation failed: {e}")
            return passage
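
# --- Usage sketch (illustrative only) ---
# The toy passage below is hypothetical; with the default config,
# enable_mining is true, so no override is needed.
def _example_hard_negatives() -> None:
    augmenter = HardNegativeAugmenter()
    sample = [{'query': 'q', 'pos': ['The project was successful in 2023 and Beijing saw an increase of 40 units.']}]
    for row in augmenter.create_hard_negatives(sample, num_hard_negatives=4):
        for neg in row.get('hard_neg', []):
            print('hard_neg ->', neg)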

class CrossLingualAugmenter:
    """
    Create cross-lingual training data with enhanced error handling.
    """

    def __init__(self, source_lang: str = 'zh', target_langs: Optional[List[str]] = None):
        # get_config() handles missing loader packages internally and
        # falls back to a stub, so it is safe to call directly.
        self.config = get_config()

        self.source_lang = source_lang
        self.target_langs = target_langs or ['en']

        # Get API configuration (env var overrides config)
        self.api_config = self.config.get_section('api')
        api_key = os.environ.get('BGE_API_KEY') or self.api_config.get('api_key')
        self.enable_api = bool(api_key and api_key != 'your-api-key-here')
        # Lazily created QueryAugmenter used only for its API translation
        # helper; creating one per call would reload models every time.
        self._api_augmenter: Optional[QueryAugmenter] = None

    def create_cross_lingual_pairs(
        self,
        data: List[Dict],
        translation_prob: float = 0.3
    ) -> List[Dict]:
        """
        Create cross-lingual query-passage pairs with comprehensive error handling.

        Args:
            data: Original data
            translation_prob: Probability of creating a translation

        Returns:
            Augmented data with cross-lingual pairs
        """
        if not data:
            logger.warning("Empty data provided for cross-lingual augmentation")
            return []

        if not (0 <= translation_prob <= 1):
            logger.warning("Invalid translation_prob, setting to 0.3")
            translation_prob = 0.3

        augmented_data = []
        try:
            for item in tqdm(data, desc="Creating cross-lingual pairs"):
                # Validate item
                if not isinstance(item, dict) or 'query' not in item:
                    logger.warning(f"Invalid item format: {item}")
                    continue

                augmented_data.append(item)

                if random.random() < translation_prob:
                    # Create a translated query version per target language
                    for target_lang in self.target_langs:
                        try:
                            trans_item = item.copy()
                            trans_query = self._translate_query(
                                item['query'], self.source_lang, target_lang
                            )
                            if trans_query and trans_query != item['query']:
                                trans_item['query'] = trans_query
                                trans_item['cross_lingual'] = True
                                trans_item['query_lang'] = target_lang
                                trans_item['passage_lang'] = self.source_lang
                                augmented_data.append(trans_item)
                        except Exception as e:
                            logger.warning(f"Failed to create cross-lingual pair for {target_lang}: {e}")
                            continue
        except Exception as e:
            logger.error(f"Cross-lingual augmentation failed: {e}")
            return data

        return augmented_data

    def _translate_query(self, query: str, source: str, target: str) -> str:
        """Translate a query from source to target language with error handling."""
        if not query or not query.strip():
            return query

        # Try API translation first if enabled
        if self.enable_api:
            try:
                # Reuse a single QueryAugmenter for its API translation helper;
                # 'synonym' avoids loading any local models.
                if self._api_augmenter is None:
                    self._api_augmenter = QueryAugmenter(augmentation_methods=['synonym'])
                result = self._api_augmenter._translate_with_api(query, source, target)
                if result:
                    return result
            except Exception as e:
                logger.warning(f"API translation failed: {e}")

        # Fall back to rule-based simulation
        return self._simulate_translation(query, source, target)

    def _simulate_translation(self, query: str, source: str, target: str) -> str:
        """Simulate translation with pattern-based replacements."""
        if not query:
            return query
        try:
            if source == 'zh' and target == 'en':
                # Chinese-to-English keyword patterns
                replacements = {
                    "什么": "what", "如何": "how", "为什么": "why",
                    "哪里": "where", "谁": "who", "书": "book",
                    "章节": "chapter", "信息": "information", "内容": "content",
                    "文本": "text", "段落": "passage"
                }
                translated = query
                for zh, en in replacements.items():
                    translated = translated.replace(zh, en)
                return f"[{source}→{target}] {translated}"
            elif source == 'en' and target == 'zh':
                # English-to-Chinese keyword patterns
                replacements = {
                    "what": "什么", "how": "如何", "why": "为什么",
                    "where": "哪里", "who": "谁", "book": "书",
                    "chapter": "章节", "information": "信息", "content": "内容",
                    "text": "文本", "passage": "段落"
                }
                translated = query.lower()
                for en, zh in replacements.items():
                    translated = translated.replace(en, zh)
                return f"[{source}→{target}] {translated}"

            # For other language pairs, just add a marker
            return f"[{source}→{target}] {query}"
        except Exception as e:
            logger.warning(f"Translation simulation failed: {e}")
            return query
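
# --- Usage sketch (illustrative only) ---
# Without an API key this falls back to _simulate_translation, which tags
# output with a [zh→en] marker; the sample query is hypothetical.
def _example_cross_lingual() -> None:
    augmenter = CrossLingualAugmenter(source_lang='zh', target_langs=['en'])
    sample = [{'query': '这本书的章节信息是什么', 'pos': ['...']}]
    for row in augmenter.create_cross_lingual_pairs(sample, translation_prob=1.0):
        print(row.get('query_lang', 'zh'), '->', row['query'])


if __name__ == "__main__":
    # Run the hedged demos above; all of them work offline with the defaults.
    logging.basicConfig(level=logging.INFO)
    _example_query_augmentation()
    _example_passage_augmentation()
    _example_hard_negatives()
    _example_cross_lingual()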