import random
import torch
from typing import List, Dict, Optional
from tqdm import tqdm
import logging
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import requests
import time
import re
import os

logger = logging.getLogger(__name__)


def get_config():
    """Use the centralized config loader; fall back to a no-op config only when running standalone (no utils.config_loader available)."""
    try:
        from utils.config_loader import get_config as _get_config
        return _get_config()
    except ImportError:
        try:
            from bge_finetune.utils.config_loader import get_config as _get_config  # type: ignore[import]
            return _get_config()
        except ImportError:
            # Only fall back if truly standalone
            class FallbackConfig:
                def __init__(self):
                    self.config = {}

                def get(self, key, default=None):
                    return default

                def get_section(self, section):
                    return {}

            return FallbackConfig()
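
# The config keys this module reads -- an illustrative sketch of config.toml.
# Key names are taken from the getters below; the values shown are the code's
# defaults or labeled assumptions, not a verified shipped config:
#
#   [augmentation]
#   augmentation_methods = ["paraphrase", "back_translation"]
#   augmentation_temperature = 0.8
#   num_augmentations_per_sample = 2
#   enable_query_augmentation = false   # gate for LLM-API-based augmentation
#   api_first = true                    # try the API before local models
#   enable_passage_augmentation = false
#
#   [api]
#   base_url = ""                       # an OpenAI-compatible endpoint (assumption)
#   api_key = "your-api-key-here"       # or set the BGE_API_KEY env var
#   model = "deepseek-chat"
#   timeout = 30
#   max_retries = 3
#
#   [hard_negative_mining]
#   enable_mining = true
#   num_hard_negatives = 15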


class QueryAugmenter:
    """
    Generate augmented queries using various techniques, including API-based translation.
    """

    def __init__(
        self,
        augmentation_methods: Optional[List[str]] = None,
        device: Optional[str] = None,
        config_path: Optional[str] = None
    ):
        """
        Initialize QueryAugmenter.

        Args:
            augmentation_methods: List of augmentation methods to use.
            device: Device to use for augmentation models.
            config_path: Optional path to a config file.
        """
        # get_config() falls back internally, so it is safe to call directly.
        self.config = get_config()

        # Load configuration
        aug_config = self.config.get_section('augmentation')
        api_config = self.config.get_section('api')

        # Set augmentation methods with fallback
        if augmentation_methods is None:
            self.augmentation_methods = aug_config.get('augmentation_methods', ['paraphrase', 'back_translation'])
        else:
            self.augmentation_methods = augmentation_methods

        # Validate augmentation methods (extensible via _augment_<name> hooks)
        valid_methods = {'paraphrase', 'back_translation', 'synonym', 'template'}
        self.augmentation_methods = [
            m for m in self.augmentation_methods
            if m in valid_methods or callable(getattr(self, f'_augment_{m}', None))
        ]
        if not self.augmentation_methods:
            logger.warning("No valid augmentation methods specified, using default")
            self.augmentation_methods = ['paraphrase']

        # Set device with fallback
        if device is None:
            try:
                self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
            except Exception:
                self.device = 'cpu'
        else:
            self.device = device

        # Configuration parameters with safe defaults
        self.api_config = api_config
        self.augmentation_temperature = aug_config.get('augmentation_temperature', 0.8)
        self.num_augmentations_per_sample = aug_config.get('num_augmentations_per_sample', 2)
        self.enable_api = aug_config.get('enable_query_augmentation', False)
        self.api_first = aug_config.get('api_first', True)

        # API key check (env var overrides config)
        api_key = os.environ.get('BGE_API_KEY') or self.api_config.get('api_key', None)
        if self.enable_api:
            if not api_key or api_key == 'your-api-key-here':
                raise ValueError(
                    "API-based augmentation is enabled, but a valid API key is missing "
                    "in the config.toml [api] section and the BGE_API_KEY env var."
                )

        # Model cache
        self.models = {}
        self.paraphrase_model = None
        self.paraphrase_tokenizer = None

        # Load models for augmentation
        self._load_models()

    def _load_models(self):
        """Load required models for augmentation with comprehensive error handling."""
        logger.info(f"Loading augmentation models for methods: {self.augmentation_methods}")

        # Load paraphrase model if needed
        if 'paraphrase' in self.augmentation_methods:
            model_source = self.config.get('model_paths.source', 'huggingface')
            try:
                if model_source == 'huggingface':
                    logger.info("Loading T5 model for paraphrasing from HuggingFace...")
                    self.paraphrase_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
                    self.paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
                elif model_source == 'modelscope':
                    try:
                        from modelscope.models import Model as MSModel  # type: ignore[import]
                        from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download  # type: ignore[import]
                        ms_snapshot_download(model_id="damo/nlp_t5_paraphrase_chinese-base", cache_dir=None)
                        self.paraphrase_model = MSModel.from_pretrained("damo/nlp_t5_paraphrase_chinese-base")
                        self.paraphrase_tokenizer = None  # Add a ModelScope tokenizer here if needed
                        logger.info("T5 model loaded from ModelScope.")
                    except ImportError:
                        logger.error("modelscope not available. Install it first: pip install modelscope")
                        self.paraphrase_model = None
                        self.paraphrase_tokenizer = None
                        if not self.enable_api:
                            raise RuntimeError("ModelScope paraphrase model required but not available and API augmentation is not enabled.")
                        return
                else:
                    logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
                    self.paraphrase_model = None
                    self.paraphrase_tokenizer = None
                    if not self.enable_api:
                        logger.warning("Paraphrase model unavailable and API augmentation disabled")
                    return

                try:
                    if self.paraphrase_model is not None:
                        self.paraphrase_model.to(self.device)
                        self.paraphrase_model.eval()
                        logger.info(f"T5 model loaded successfully on {self.device}")
                except Exception as e:
                    logger.warning(f"Failed to move T5 model to {self.device}: {e}, using CPU")
                    self.device = 'cpu'
                    if self.paraphrase_model is not None:
                        self.paraphrase_model.to(self.device)
                        self.paraphrase_model.eval()

            except RuntimeError:
                # Deliberate "no model and no API" failure above: let it propagate
                raise
            except Exception as e:
                logger.error(f"Failed to load T5 model: {e}.")
                self.paraphrase_model = None
                self.paraphrase_tokenizer = None
                if not self.enable_api:
                    logger.warning("Paraphrase model failed to load and API augmentation disabled")

        # Back-translation model loading (if available)
        if 'back_translation' in self.augmentation_methods:
            # Use Helsinki-NLP/opus-mt-zh-en and en-zh for Chinese<->English back-translation
            try:
                self.bt_zh2en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
                self.bt_zh2en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
                self.bt_en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
                self.bt_en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
                # Move the MT models to the same device the inputs are placed on
                self.bt_zh2en_model.to(self.device)
                self.bt_zh2en_model.eval()
                self.bt_en2zh_model.to(self.device)
                self.bt_en2zh_model.eval()
                logger.info("Loaded back-translation models for zh<->en.")
            except Exception as e:
                logger.warning(f"Failed to load back-translation models: {e}")
                self.bt_zh2en_model = None
                self.bt_en2zh_model = None

    def augment_queries(
        self,
        data: List[Dict],
        num_augmentations: Optional[int] = None,
        keep_original: bool = True
    ) -> List[Dict]:
        """
        Augment queries in the dataset with comprehensive error handling.

        Args:
            data: Original data
            num_augmentations: Number of augmentations per query
            keep_original: Whether to keep original queries

        Returns:
            Augmented dataset
        """
        if not data:
            logger.warning("Empty data provided for augmentation")
            return []

        if num_augmentations is None:
            num_augmentations = self.num_augmentations_per_sample

        if not isinstance(num_augmentations, int) or num_augmentations <= 0:
            logger.warning("Invalid num_augmentations, setting to 1")
            num_augmentations = 1

        augmented_data = []

        try:
            for item in tqdm(data, desc="Augmenting queries"):
                # Validate item
                if not isinstance(item, dict) or 'query' not in item:
                    logger.warning(f"Invalid item format: {item}")
                    continue

                # Keep original
                if keep_original:
                    augmented_data.append(item)

                # Generate augmentations
                for i in range(num_augmentations):
                    # Pick the method outside the try so the except can name it
                    method = random.choice(self.augmentation_methods)
                    try:
                        aug_item = self._augment_single(item, method)
                        if aug_item:
                            augmented_data.append(aug_item)
                    except Exception as e:
                        logger.warning(f"Failed to augment item with {method}: {e}")
                        continue

        except Exception as e:
            logger.error(f"Query augmentation failed: {e}")
            # Return original data on complete failure
            return data

        logger.info(f"Augmented {len(data)} items to {len(augmented_data)} items")
        return augmented_data

    def _augment_single(self, item: Dict, method: str) -> Optional[Dict]:
        """Augment a single item using the specified method, with error handling."""
        if not item.get('query'):
            return None

        try:
            aug_item = item.copy()
            original_query = item['query']

            if method == 'paraphrase':
                aug_query = self._paraphrase_query(original_query)
            elif method == 'back_translation':
                aug_query = self._back_translate_query(original_query)
            elif method == 'synonym':
                aug_query = self._synonym_replacement(original_query)
            elif method == 'template':
                aug_query = self._template_based_augmentation(original_query)
            else:
                logger.warning(f"Unknown augmentation method: {method}")
                return None

            if aug_query and aug_query != original_query:
                aug_item['query'] = aug_query
                aug_item['augmentation_method'] = method
                aug_item['is_augmented'] = True
                return aug_item

        except Exception as e:
            logger.warning(f"Error in {method} augmentation: {e}")

        return None

    def _get_prompt(self, task: str, query: str, lang: str) -> str:
        """Build the LLM prompt for a given augmentation task and language."""
        if lang == 'zh':
            # The Chinese prompts mirror the English ones below: rewrite with the
            # same meaning / replace words with synonyms / recast with a different
            # question template, outputting only the result, no explanation.
            if task == 'paraphrase':
                return f"请将下列文本改写为意思相同的句子,只输出改写后的内容,不要解释:\n\n{query}"
            elif task == 'synonym':
                return f"请将下列文本中的词语替换为合适的同义词,只输出结果,不要解释:\n\n{query}"
            elif task == 'template':
                return f"请用不同的提问模板改写下列问题,只输出结果,不要解释:\n\n{query}"
        else:
            if task == 'paraphrase':
                return f"Paraphrase the following text. Only provide the paraphrase, no explanations:\n\n{query}"
            elif task == 'synonym':
                return f"Replace words in the following text with appropriate synonyms. Only provide the result, no explanations:\n\n{query}"
            elif task == 'template':
                return f"Rephrase the following text using a different question template. Only provide the result, no explanations:\n\n{query}"
        return query

    def _paraphrase_query(self, query: str) -> str:
        """Paraphrase a query using the LLM API, the local model, or rules."""
        if not query or not query.strip():
            return query
        lang = 'zh' if self._is_chinese(query) else 'en'

        # Try the LLM API first if enabled and configured
        if self.enable_api and self.api_first:
            try:
                prompt = self._get_prompt('paraphrase', query, lang)
                result = self._call_llm_api(prompt)
                if result and result != query:
                    return result
            except Exception as e:
                logger.warning(f"LLM API paraphrasing failed: {e}. Trying local model.")

        # Try T5-based paraphrasing next (English only)
        if lang == 'en' and self.paraphrase_model is not None and self.paraphrase_tokenizer is not None:
            try:
                input_text = f"paraphrase: {query}"
                inputs = self.paraphrase_tokenizer(
                    input_text,
                    return_tensors="pt",
                    max_length=512,
                    truncation=True
                ).to(self.device)
                with torch.no_grad():
                    outputs = self.paraphrase_model.generate(
                        **inputs,
                        max_length=150,
                        num_return_sequences=1,
                        temperature=self.augmentation_temperature,
                        do_sample=True,
                        pad_token_id=self.paraphrase_tokenizer.pad_token_id
                    )
                paraphrase = self.paraphrase_tokenizer.decode(outputs[0], skip_special_tokens=True)
                if paraphrase and paraphrase != query:
                    return paraphrase
            except Exception as e:
                logger.warning(f"T5 paraphrasing failed: {e}. Using rule-based approach.")

        # Fall back to rule-based paraphrasing
        return self._rule_based_paraphrase(query, lang)

    def _rule_based_paraphrase(self, query: str, lang: str = 'en') -> str:
        """Rule-based paraphrasing fallback."""
        if not query:
            return query
        try:
            if lang == 'zh':
                # Question-word rewrite pairs, e.g. "什么是"→"请解释" ("what is"→"please explain")
                patterns = [
                    ("什么是", "请解释"), ("如何", "怎样"), ("为什么", "什么原因"), ("在哪里", "什么地方"),
                    ("是什么", "指的是什么"), ("有什么", "存在哪些"), ("能否", "可以"), ("应该", "需要"),
                    ("必须", "一定要"), ("关于", "有关")
                ]
                paraphrased = query
                for original, replacement in patterns:
                    if original in query:
                        paraphrased = query.replace(original, replacement, 1)
                        break
                if paraphrased == query:
                    # No pattern matched: prepend a polite question prefix instead
                    prefixes = ["请问", "我想知道", "能否告诉我", "请教一下", "想了解"]
                    paraphrased = f"{random.choice(prefixes)}{query}"
                return paraphrased
            else:
                templates = [
                    "What is meant by {}?", "Can you explain {}?", "Tell me about {}", "I want to know about {}",
                    "Please describe {}", "Could you clarify {}?", "What does {} refer to?", "How would you define {}?"
                ]
                # Strip the question word and auxiliary verb to isolate the topic
                core = query.strip("?.,!").lower()
                question_starters = ["what", "how", "why", "when", "where", "who", "which"]
                for starter in question_starters:
                    if core.startswith(starter):
                        core = core[len(starter):].strip()
                        for aux in ["is", "are", "was", "were", "does", "do", "did"]:
                            if core.startswith(aux):
                                core = core[len(aux):].strip()
                                break
                        break
                if core:
                    return random.choice(templates).format(core)
                else:
                    return query
        except Exception as e:
            logger.warning(f"Rule-based paraphrasing failed: {e}")
            return query

    def _back_translate_query(self, query: str) -> str:
        """Back-translate a query via local MT models if available, else the API, else simulation. Always returns a str."""
        if getattr(self, 'bt_zh2en_model', None) is not None and getattr(self, 'bt_en2zh_model', None) is not None:
            try:
                if self._is_chinese(query):
                    # zh -> en -> zh round trip
                    inputs = self.bt_zh2en_tokenizer(query, return_tensors="pt", truncation=True, max_length=64).to(self.device)
                    en = self.bt_zh2en_model.generate(**inputs, max_length=64)
                    en_text = self.bt_zh2en_tokenizer.decode(en[0], skip_special_tokens=True)
                    inputs = self.bt_en2zh_tokenizer(en_text, return_tensors="pt", truncation=True, max_length=64).to(self.device)
                    zh = self.bt_en2zh_model.generate(**inputs, max_length=64)
                    zh_text = self.bt_en2zh_tokenizer.decode(zh[0], skip_special_tokens=True)
                    return zh_text if zh_text else query
                else:
                    # en -> zh -> en round trip
                    inputs = self.bt_en2zh_tokenizer(query, return_tensors="pt", truncation=True, max_length=64).to(self.device)
                    zh = self.bt_en2zh_model.generate(**inputs, max_length=64)
                    zh_text = self.bt_en2zh_tokenizer.decode(zh[0], skip_special_tokens=True)
                    inputs = self.bt_zh2en_tokenizer(zh_text, return_tensors="pt", truncation=True, max_length=64).to(self.device)
                    en = self.bt_zh2en_model.generate(**inputs, max_length=64)
                    en_text = self.bt_zh2en_tokenizer.decode(en[0], skip_special_tokens=True)
                    return en_text if en_text else query
            except Exception as e:
                logger.warning(f"Local back-translation failed: {e}")
        if self.enable_api:
            try:
                is_chinese = self._is_chinese(query)
                result = self._translate_with_api(query, 'zh' if is_chinese else 'en', 'en' if is_chinese else 'zh')
                if result:
                    return result
            except Exception as e:
                logger.warning(f"API back-translation failed: {e}")
        fallback = self._simulate_back_translation(query)
        return fallback if fallback is not None else query

    def _translate_with_api(self, text: str, source_lang: str, target_lang: str) -> Optional[str]:
        """Translate text via the configured chat-completions API."""
        if not text or not text.strip():
            return None

        base_url = self.api_config.get('base_url', '')
        api_key = os.environ.get('BGE_API_KEY') or self.api_config.get('api_key', '')
        model = self.api_config.get('model', 'deepseek-chat')
        timeout = self.api_config.get('timeout', 30)
        max_retries = self.api_config.get('max_retries', 3)

        if not base_url or not api_key or api_key == 'your-api-key-here':
            return None

        headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }

        # Create translation prompt
        lang_names = {'zh': 'Chinese', 'en': 'English', 'ja': 'Japanese', 'fr': 'French'}
        source_name = lang_names.get(source_lang, source_lang)
        target_name = lang_names.get(target_lang, target_lang)

        prompt = f"Translate the following {source_name} text to {target_name}. Only provide the translation, no explanations:\n\n{text}"

        data = {
            'model': model,
            'messages': [
                {'role': 'user', 'content': prompt}
            ],
            'temperature': self.augmentation_temperature,
            'max_tokens': 500
        }

        for attempt in range(max_retries):
            try:
                response = requests.post(
                    f"{base_url}/chat/completions",
                    headers=headers,
                    json=data,
                    timeout=timeout
                )

                if response.status_code == 200:
                    result = response.json()
                    if 'choices' in result and result['choices']:
                        translated = result['choices'][0]['message']['content'].strip()
                        return translated
                    else:
                        logger.warning(f"Invalid response format: {result}")
                else:
                    logger.warning(f"Translation API error: {response.status_code} - {response.text}")
                    if attempt < max_retries - 1:
                        time.sleep(2 ** attempt)  # Exponential backoff

            except requests.exceptions.Timeout:
                logger.warning(f"Translation API timeout on attempt {attempt + 1}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
            except requests.exceptions.RequestException as e:
                logger.warning(f"Translation API request failed: {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
            except Exception as e:
                logger.error(f"Unexpected error in translation API: {e}")
                break

        return None

    def _simulate_back_translation(self, query: str) -> str:
        """Simulate back-translation as a last-resort fallback."""
        if not query:
            return query

        try:
            # Apply one lexical substitution that mimics translation drift
            if self._is_chinese(query):
                substitutions = [
                    ("什么", "哪些"), ("如何", "怎么样"), ("为什么", "什么原因"),
                    ("非常", "很"), ("特别", "特殊"), ("重要", "关键"),
                    ("问题", "事情"), ("方法", "方式"), ("内容", "信息")
                ]
            else:
                substitutions = [
                    ("what", "which"), ("how", "in what way"), ("very", "extremely"),
                    ("important", "crucial"), ("method", "approach"), ("content", "information"),
                    ("question", "issue"), ("problem", "matter"), ("find", "locate")
                ]

            result = query
            for original, replacement in substitutions:
                if original in result.lower():
                    # Case-insensitive, single replacement
                    result = re.sub(re.escape(original), replacement, result, count=1, flags=re.IGNORECASE)
                    break

            return result

        except Exception as e:
            logger.warning(f"Back-translation simulation failed: {e}")
            return query

    def _synonym_replacement(self, query: str) -> str:
        """Replace words in the query with synonyms."""
        if not query or not query.strip():
            return query
        lang = 'zh' if self._is_chinese(query) else 'en'

        # Try the LLM API first if enabled and configured
        if self.enable_api and self.api_first:
            try:
                prompt = self._get_prompt('synonym', query, lang)
                result = self._call_llm_api(prompt)
                if result and result != query:
                    return result
            except Exception as e:
                logger.warning(f"LLM API synonym replacement failed: {e}. Using rule-based approach.")

        # Fall back to a small hard-coded synonym table
        try:
            if lang == 'zh':
                synonyms = {
                    "重要": ["关键", "核心", "主要"], "方法": ["办法", "手段", "途径"], "问题": ["疑问", "难题", "困惑"],
                    "内容": ["材料", "信息", "资料"], "增加": ["提升", "增长", "扩展"], "发现": ["查找", "寻找", "探索"]
                }
                # The keys are multi-character words, so replace substrings directly
                # rather than iterating over single characters (which never match).
                result = query
                replaced = False
                for word, options in synonyms.items():
                    if word in result:
                        result = result.replace(word, random.choice(options), 1)
                        replaced = True
                if replaced:
                    return result
                return query
            else:
                synonyms = {
                    "important": ["crucial", "vital", "significant"], "method": ["approach", "technique", "procedure"],
                    "find": ["locate", "discover", "identify"], "problem": ["issue", "challenge", "difficulty"],
                    "increase": ["rise", "growth", "expansion"], "content": ["material", "substance", "information"]
                }
                words = query.split()
                replaced = False
                for i, word in enumerate(words):
                    key = word.lower().strip('.,?!')
                    if key in synonyms:
                        words[i] = random.choice(synonyms[key])
                        replaced = True
                if replaced:
                    return ' '.join(words)
                return query
        except Exception as e:
            logger.warning(f"Rule-based synonym replacement failed: {e}")
            return query

    def _template_based_augmentation(self, query: str) -> str:
        """Template-based augmentation for queries."""
        if not query or not query.strip():
            return query
        lang = 'zh' if self._is_chinese(query) else 'en'

        # Try the LLM API first if enabled and configured
        if self.enable_api and self.api_first:
            try:
                prompt = self._get_prompt('template', query, lang)
                result = self._call_llm_api(prompt)
                if result and result != query:
                    return result
            except Exception as e:
                logger.warning(f"LLM API template-based augmentation failed: {e}. Using rule-based approach.")

        # Fall back to hard-coded templates
        try:
            if lang == 'zh':
                # Retrieval-style templates, e.g. "请找到关于{}的段落" ("find the passage about {}")
                templates = [
                    "请找到关于{}的段落", "搜索有关{}的信息", "寻找讨论{}的内容", "我需要了解{}的详细信息", "显示与{}相关的文本"
                ]
                core = query.strip("?.,!。:")
                # Strip a leading question word ("what is", "how", "why", "where", "who is")
                for prefix in ["什么是", "如何", "为什么", "在哪里", "谁是"]:
                    if core.startswith(prefix):
                        core = core[len(prefix):].strip()
                        break
                if core:
                    return random.choice(templates).format(core)
                else:
                    return query
            else:
                templates = [
                    "What is meant by {}?", "Can you explain {}?", "Tell me about {}", "I want to know about {}",
                    "Please describe {}", "Could you clarify {}?", "What does {} refer to?", "How would you define {}?"
                ]
                core = query.strip("?.,!").lower()
                if core:
                    return random.choice(templates).format(core)
                else:
                    return query
        except Exception as e:
            logger.warning(f"Rule-based template augmentation failed: {e}")
            return query

    def _call_llm_api(self, prompt: str) -> Optional[str]:
        """Call the configured chat-completions API for augmentation."""
        base_url = self.api_config.get('base_url', '')
        api_key = os.environ.get('BGE_API_KEY') or self.api_config.get('api_key', '')
        model = self.api_config.get('model', 'deepseek-chat')
        timeout = self.api_config.get('timeout', 30)
        max_retries = self.api_config.get('max_retries', 3)
        temperature = self.api_config.get('temperature', self.augmentation_temperature)

        if not base_url or not api_key or api_key == 'your-api-key-here':
            return None

        headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }
        data = {
            'model': model,
            'messages': [
                {'role': 'user', 'content': prompt}
            ],
            'temperature': temperature,
            'max_tokens': 500
        }

        for attempt in range(max_retries):
            try:
                response = requests.post(
                    f"{base_url}/chat/completions",
                    headers=headers,
                    json=data,
                    timeout=timeout
                )
                if response.status_code == 200:
                    result = response.json()
                    if 'choices' in result and result['choices']:
                        content = result['choices'][0]['message']['content'].strip()
                        return content
                    else:
                        logger.warning(f"Invalid response format: {result}")
                else:
                    logger.warning(f"LLM API error: {response.status_code} - {response.text}")
                    if attempt < max_retries - 1:
                        time.sleep(2 ** attempt)  # Exponential backoff
            except requests.exceptions.Timeout:
                logger.warning(f"LLM API timeout on attempt {attempt + 1}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
            except requests.exceptions.RequestException as e:
                logger.warning(f"LLM API request failed: {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
            except Exception as e:
                logger.error(f"Unexpected error in LLM API: {e}")
                break

        return None

    def _is_chinese(self, text: str) -> bool:
        """Return True if the text contains CJK Unified Ideographs."""
        if not text:
            return False
        try:
            return bool(re.search(r'[\u4e00-\u9fff]', text))
        except Exception:
            return False
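
# Illustrative usage of QueryAugmenter (a sketch; 'synonym' and 'template' are
# rule-based, so no model downloads or API keys are needed):
#
#     augmenter = QueryAugmenter(augmentation_methods=['synonym', 'template'])
#     data = [{'query': 'How to find important content?', 'pos': ['...']}]
#     augmented = augmenter.augment_queries(data, num_augmentations=2)
#     # Each augmented item carries 'augmentation_method' and 'is_augmented' flags.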


class PassageAugmenter:
    """
    Augment passages to create training variations.
    """

    def __init__(self):
        # get_config() falls back internally, so it is safe to call directly.
        self.config = get_config()

        aug_config = self.config.get_section('augmentation')
        self.enable_passage_augmentation = aug_config.get('enable_passage_augmentation', False)

        self.augmentation_strategies = [
            'sentence_shuffle',
            'span_masking',
            'paraphrase_sentences',
            'add_noise'
        ]

    def augment_passages(
        self,
        data: List[Dict],
        augmentation_prob: Optional[float] = None
    ) -> List[Dict]:
        """
        Augment passages in the dataset with comprehensive error handling.

        Args:
            data: Original data
            augmentation_prob: Probability of augmenting each item

        Returns:
            Augmented dataset
        """
        if not self.enable_passage_augmentation:
            logger.info("Passage augmentation disabled in config")
            return data

        if not data:
            logger.warning("Empty data provided for passage augmentation")
            return []

        if augmentation_prob is None:
            augmentation_prob = 0.3

        if not (0 <= augmentation_prob <= 1):
            logger.warning("Invalid augmentation_prob, setting to 0.3")
            augmentation_prob = 0.3

        augmented_data = []

        try:
            for item in tqdm(data, desc="Augmenting passages"):
                # Validate item
                if not isinstance(item, dict):
                    logger.warning(f"Invalid item format: {item}")
                    continue

                augmented_data.append(item)

                if random.random() < augmentation_prob:
                    # Augment positive passages
                    if 'pos' in item and item['pos']:
                        try:
                            aug_item = item.copy()
                            aug_item['pos'] = [
                                self._augment_passage(p, random.choice(self.augmentation_strategies))
                                for p in item['pos']
                                if p and p.strip()  # Skip empty passages
                            ]
                            if aug_item['pos']:  # Only add if we have valid augmented passages
                                aug_item['is_augmented'] = True
                                aug_item['augmentation_type'] = 'passage'
                                augmented_data.append(aug_item)
                        except Exception as e:
                            logger.warning(f"Failed to augment passage for item: {e}")
                            continue

        except Exception as e:
            logger.error(f"Passage augmentation failed: {e}")
            return data

        return augmented_data

    def _augment_passage(self, passage: str, strategy: str) -> str:
        """Augment a single passage with error handling."""
        if not passage or not passage.strip():
            return passage

        try:
            if strategy == 'sentence_shuffle':
                return self._shuffle_sentences(passage)
            elif strategy == 'span_masking':
                return self._mask_spans(passage)
            elif strategy == 'paraphrase_sentences':
                return self._paraphrase_sentences(passage)
            elif strategy == 'add_noise':
                return self._add_noise(passage)
            else:
                logger.warning(f"Unknown augmentation strategy: {strategy}")
                return passage
        except Exception as e:
            logger.warning(f"Error in passage augmentation with {strategy}: {e}")
            return passage

    def _shuffle_sentences(self, passage: str) -> str:
        """Shuffle the middle sentences of a passage."""
        if not passage or len(passage.strip()) < 10:
            return passage

        try:
            # Split by Chinese or English sentence endings
            if re.search(r'[\u4e00-\u9fff]', passage):
                sentences = re.split(r'[。!?]', passage)
                delimiter = '。'
            else:
                sentences = re.split(r'[.!?]', passage)
                delimiter = '. '

            sentences = [s.strip() for s in sentences if s.strip()]

            if len(sentences) > 2:
                # Keep the first and last sentences, shuffle the middle
                middle = sentences[1:-1]
                random.shuffle(middle)
                sentences = [sentences[0]] + middle + [sentences[-1]]

                # Rejoin with appropriate punctuation
                if delimiter == '。':
                    return delimiter.join(sentences) + '。'
                else:
                    return delimiter.join(sentences) + '.'

            return passage

        except Exception as e:
            logger.warning(f"Sentence shuffling failed: {e}")
            return passage

    def _mask_spans(self, passage: str, mask_ratio: float = 0.15) -> str:
        """Mask random spans in the passage."""
        if not passage or len(passage.strip()) < 20:
            return passage

        try:
            # Split into words/characters
            if re.search(r'[\u4e00-\u9fff]', passage):
                # Chinese: split by characters for finer granularity
                units = list(passage)
                join_char = ''
            else:
                # English: split by words
                units = passage.split()
                join_char = ' '

            num_masks = max(1, int(len(units) * mask_ratio))

            for _ in range(num_masks):
                if len(units) > 5:  # Ensure minimum length
                    start = random.randint(0, len(units) - 3)
                    length = random.randint(1, min(3, len(units) - start))
                    for i in range(start, start + length):
                        if i < len(units):
                            units[i] = '[MASK]'

            # Rejoin
            return join_char.join(units)

        except Exception as e:
            logger.warning(f"Span masking failed: {e}")
            return passage

    def _paraphrase_sentences(self, passage: str) -> str:
        """Lightly paraphrase one sentence in the passage."""
        if not passage or len(passage.strip()) < 20:
            return passage

        try:
            # Split sentences
            if re.search(r'[\u4e00-\u9fff]', passage):
                sentences = re.split(r'[。!?]', passage)
                delimiter = '。'
            else:
                sentences = re.split(r'[.!?]', passage)
                delimiter = '. '

            sentences = [s.strip() for s in sentences if s.strip()]

            if sentences:
                # Randomly modify one sentence
                idx = random.randint(0, len(sentences) - 1)
                original = sentences[idx]

                # Simple paraphrasing by prepending a discourse modifier
                if re.search(r'[\u4e00-\u9fff]', original):
                    # "Obviously" / "It is understood" / "In fact" / "Usually" / "According to research"
                    modifiers = ["显然", "据了解", "实际上", "通常情况下", "根据研究"]
                    sentences[idx] = f"{random.choice(modifiers)},{original}"
                else:
                    modifiers = ["Generally", "According to studies", "Typically", "In fact", "Usually"]
                    sentences[idx] = f"{random.choice(modifiers)}, {original.lower()}"

                result = delimiter.join(sentences)
                if delimiter == '。':
                    result += '。'
                else:
                    result += '.'

                return result

            # No usable sentences: return the passage unchanged
            return passage

        except Exception as e:
            logger.warning(f"Sentence paraphrasing failed: {e}")
            return passage

    def _add_noise(self, passage: str, noise_level: float = 0.02) -> str:
        """Add character-level noise."""
        if not passage or len(passage.strip()) < 10:
            return passage

        try:
            chars = list(passage)
            num_noisy = max(1, int(len(chars) * noise_level))

            if len(chars) > 10:  # Only add noise to sufficiently long passages
                noise_positions = random.sample(range(len(chars)), min(num_noisy, len(chars) // 5))

                for pos in noise_positions:
                    if pos < len(chars):
                        noise_type = random.choice(['duplicate', 'replace'])

                        if noise_type == 'duplicate':
                            chars[pos] = chars[pos] * 2
                        elif noise_type == 'replace':
                            # Replace with a frequent character of the same script
                            if re.search(r'[\u4e00-\u9fff]', chars[pos]):
                                similar_chars = ['的', '了', '在', '是', '我', '有', '他', '这', '中', '大']
                                chars[pos] = random.choice(similar_chars)
                            else:
                                if chars[pos].isalpha():
                                    chars[pos] = random.choice('abcdefghijklmnopqrstuvwxyz')

            # Always return a string, even when the passage was too short to touch
            return ''.join(chars)

        except Exception as e:
            logger.warning(f"Noise addition failed: {e}")
            return passage
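
# Illustrative usage of PassageAugmenter (a sketch; it is a no-op unless
# enable_passage_augmentation = true under [augmentation] in config.toml):
#
#     p_aug = PassageAugmenter()
#     data = [{'query': 'q', 'pos': ['First sentence. Second sentence. Third sentence.']}]
#     augmented = p_aug.augment_passages(data, augmentation_prob=0.5)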


class HardNegativeAugmenter:
    """
    Create hard negatives through various rule-based strategies.
    """

    def __init__(self):
        # get_config() falls back internally, so it is safe to call directly.
        self.config = get_config()

        mining_config = self.config.get_section('hard_negative_mining')
        self.enable_mining = mining_config.get('enable_mining', True)
        self.num_hard_negatives = mining_config.get('num_hard_negatives', 15)

        self.strategies = [
            'entity_swap',
            'number_change',
            'negation_addition',
            'similar_context'
        ]

    def create_hard_negatives(
        self,
        data: List[Dict],
        num_hard_negatives: Optional[int] = None
    ) -> List[Dict]:
        """
        Create hard negatives for training data with comprehensive error handling.

        Args:
            data: Original data
            num_hard_negatives: Number of hard negatives to create

        Returns:
            Data with hard negatives
        """
        if not self.enable_mining:
            logger.info("Hard negative mining disabled in config")
            return data

        if not data:
            logger.warning("Empty data provided for hard negative creation")
            return []

        if num_hard_negatives is None:
            num_hard_negatives = self.num_hard_negatives

        if not isinstance(num_hard_negatives, int) or num_hard_negatives <= 0:
            logger.warning("Invalid num_hard_negatives, setting to 5")
            num_hard_negatives = 5

        augmented_data = []

        try:
            for item in tqdm(data, desc="Creating hard negatives"):
                # Validate item
                if not isinstance(item, dict):
                    logger.warning(f"Invalid item format: {item}")
                    continue

                aug_item = item.copy()

                if 'pos' in item and item['pos']:
                    hard_negs = []

                    try:
                        for pos_passage in item['pos'][:2]:  # Use the first two positives
                            if not pos_passage or not pos_passage.strip():
                                continue

                            for i in range(min(num_hard_negatives, len(self.strategies) * 2)):
                                strategy = self.strategies[i % len(self.strategies)]
                                hard_neg = self._create_hard_negative(pos_passage, strategy)
                                if hard_neg and hard_neg != pos_passage and hard_neg.strip():
                                    hard_negs.append(hard_neg)

                        if hard_negs:
                            if 'hard_neg' not in aug_item:
                                aug_item['hard_neg'] = []
                            # Remove duplicates (order-preserving) and cap the count
                            unique_negs = list(dict.fromkeys(hard_negs))
                            aug_item['hard_neg'].extend(unique_negs[:num_hard_negatives])

                    except Exception as e:
                        logger.warning(f"Failed to create hard negatives for item: {e}")

                augmented_data.append(aug_item)

        except Exception as e:
            logger.error(f"Hard negative creation failed: {e}")
            return data

        return augmented_data

    def _create_hard_negative(self, passage: str, strategy: str) -> str:
        """Create a hard negative using the specified strategy, with error handling."""
        if not passage or not passage.strip():
            return passage

        try:
            if strategy == 'entity_swap':
                return self._swap_entities(passage)
            elif strategy == 'number_change':
                return self._change_numbers(passage)
            elif strategy == 'negation_addition':
                return self._add_negation(passage)
            elif strategy == 'similar_context':
                return self._create_similar_context(passage)
            else:
                logger.warning(f"Unknown hard negative strategy: {strategy}")
                return passage
        except Exception as e:
            logger.warning(f"Error creating hard negative with {strategy}: {e}")
            return passage

    def _swap_entities(self, passage: str) -> str:
        """Swap one entity to create a factually incorrect passage."""
        if not passage:
            return passage

        try:
            if re.search(r'[\u4e00-\u9fff]', passage):
                # Chinese pairs mirroring the English ones below
                # (Beijing/Shanghai, China/Japan, 2023/2024, first/second, ...)
                entity_pairs = [
                    ("北京", "上海"), ("中国", "日本"), ("2023年", "2024年"),
                    ("第一", "第二"), ("增加", "减少"), ("成功", "失败"),
                    ("重要", "次要"), ("大型", "小型"), ("新的", "旧的")
                ]
            else:
                entity_pairs = [
                    ("Beijing", "Shanghai"), ("China", "Japan"), ("2023", "2024"),
                    ("first", "second"), ("increase", "decrease"), ("success", "failure"),
                    ("important", "minor"), ("large", "small"), ("new", "old")
                ]

            aug_passage = passage
            for ent1, ent2 in entity_pairs:
                if ent1 in passage and ent2 not in passage:
                    aug_passage = aug_passage.replace(ent1, ent2, 1)
                    break

            return aug_passage

        except Exception as e:
            logger.warning(f"Entity swapping failed: {e}")
            return passage

    def _change_numbers(self, passage: str) -> str:
        """Perturb numbers in the passage."""
        if not passage:
            return passage

        try:
            def replace_number(match):
                try:
                    num = int(match.group())
                    # Change by 20-80%
                    change = random.uniform(0.2, 0.8)
                    if random.random() > 0.5:
                        new_num = int(num * (1 + change))
                    else:
                        new_num = int(num * (1 - change))
                    return str(max(1, new_num))  # Ensure positive
                except ValueError:
                    return match.group()

            # Find and replace standalone integers
            aug_passage = re.sub(r'\b\d+\b', replace_number, passage)
            return aug_passage

        except Exception as e:
            logger.warning(f"Number changing failed: {e}")
            return passage

    def _add_negation(self, passage: str) -> str:
        """Add a negation to flip the meaning."""
        if not passage:
            return passage

        try:
            aug_passage = passage
            if re.search(r'[\u4e00-\u9fff]', passage):
                # Chinese polarity pairs (is/is not, has/has not, will/will not, ...)
                negation_patterns = [
                    ("是", "不是"), ("有", "没有"), ("会", "不会"),
                    ("能", "不能"), ("可以", "不可以"), ("应该", "不应该"),
                    ("成功", "失败"), ("正确", "错误"), ("增长", "下降")
                ]
                for positive, negative in negation_patterns:
                    if positive in passage and negative not in passage:
                        aug_passage = passage.replace(positive, negative, 1)
                        break
            else:
                negation_patterns = [
                    ("is", "is not"), ("was", "was not"), ("can", "cannot"),
                    ("will", "will not"), ("should", "should not"), ("has", "has no"),
                    ("successful", "unsuccessful"), ("correct", "incorrect"), ("increase", "decrease")
                ]
                for positive, negative in negation_patterns:
                    # Match whole words only, so "is" does not hit the middle of "this"
                    pattern = rf'\b{re.escape(positive)}\b'
                    if re.search(pattern, passage) and negative not in passage:
                        aug_passage = re.sub(pattern, negative, passage, count=1)
                        break

            return aug_passage

        except Exception as e:
            logger.warning(f"Negation addition failed: {e}")
            return passage

    def _create_similar_context(self, passage: str) -> str:
        """Create a passage with similar context but different meaning."""
        if not passage:
            return passage

        try:
            # Split into sentences
            if re.search(r'[\u4e00-\u9fff]', passage):
                sentences = re.split(r'[。!?]', passage)
                delimiter = '。'
                # Generic filler statements ("this is reflected in many cases",
                # "related research shows different results", ...)
                generic_statements = [
                    "这一点在许多情况下都有所体现",
                    "相关研究表明了不同的结果",
                    "实际情况可能有所不同",
                    "这个观点存在一定的争议",
                    "具体细节仍在讨论之中"
                ]
            else:
                sentences = re.split(r'[.!?]', passage)
                delimiter = '. '
                generic_statements = [
                    "This point is reflected in many cases",
                    "Related research shows different results",
                    "The actual situation may vary",
                    "This viewpoint is somewhat controversial",
                    "Specific details are still under discussion"
                ]

            sentences = [s.strip() for s in sentences if s.strip()]
            if not sentences:
                return passage

            if len(sentences) >= 3:
                # Replace the middle sentence with a generic statement
                middle_idx = len(sentences) // 2
                sentences[middle_idx] = random.choice(generic_statements)

            result = delimiter.join(sentences)
            if delimiter == '。':
                result += '。'
            else:
                result += '.'

            return result

        except Exception as e:
            logger.warning(f"Similar context creation failed: {e}")
            return passage
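
# Illustrative usage of HardNegativeAugmenter (a sketch; rule-based, so it runs
# without models or API keys when enable_mining is true, the default):
#
#     hn_aug = HardNegativeAugmenter()
#     data = [{'query': 'q', 'pos': ['The first chapter was published in 2023.']}]
#     with_negs = hn_aug.create_hard_negatives(data, num_hard_negatives=4)
#     # Each item gains a 'hard_neg' list of perturbed (factually altered) passages.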


class CrossLingualAugmenter:
    """
    Create cross-lingual training data with enhanced error handling.
    """

    def __init__(self,
                 source_lang: str = 'zh',
                 target_langs: Optional[List[str]] = None):
        # get_config() falls back internally, so it is safe to call directly.
        self.config = get_config()

        self.source_lang = source_lang
        self.target_langs = target_langs or ['en']

        # Get API configuration (env var overrides config)
        self.api_config = self.config.get_section('api')
        api_key = os.environ.get('BGE_API_KEY') or self.api_config.get('api_key')
        self.enable_api = bool(api_key and api_key != 'your-api-key-here')

        # Lazily-created QueryAugmenter for API translation (see _translate_query)
        self._query_augmenter = None

    def create_cross_lingual_pairs(
        self,
        data: List[Dict],
        translation_prob: float = 0.3
    ) -> List[Dict]:
        """
        Create cross-lingual query-passage pairs with comprehensive error handling.

        Args:
            data: Original data
            translation_prob: Probability of creating a translated variant

        Returns:
            Augmented data with cross-lingual pairs
        """
        if not data:
            logger.warning("Empty data provided for cross-lingual augmentation")
            return []

        if not (0 <= translation_prob <= 1):
            logger.warning("Invalid translation_prob, setting to 0.3")
            translation_prob = 0.3

        augmented_data = []

        try:
            for item in tqdm(data, desc="Creating cross-lingual pairs"):
                # Validate item
                if not isinstance(item, dict) or 'query' not in item:
                    logger.warning(f"Invalid item format: {item}")
                    continue

                augmented_data.append(item)

                if random.random() < translation_prob:
                    # Create a translated query version per target language
                    for target_lang in self.target_langs:
                        try:
                            trans_item = item.copy()
                            trans_query = self._translate_query(
                                item['query'],
                                self.source_lang,
                                target_lang
                            )
                            if trans_query and trans_query != item['query']:
                                trans_item['query'] = trans_query
                                trans_item['cross_lingual'] = True
                                trans_item['query_lang'] = target_lang
                                trans_item['passage_lang'] = self.source_lang
                                augmented_data.append(trans_item)
                        except Exception as e:
                            logger.warning(f"Failed to create cross-lingual pair for {target_lang}: {e}")
                            continue

        except Exception as e:
            logger.error(f"Cross-lingual augmentation failed: {e}")
            return data

        return augmented_data

    def _translate_query(self, query: str, source: str, target: str) -> str:
        """Translate a query from source to target language with error handling."""
        if not query or not query.strip():
            return query

        # Try API translation first if enabled
        if self.enable_api:
            try:
                # Reuse a single QueryAugmenter so models are not reloaded on every
                # call; rule-based methods are passed so nothing is downloaded.
                if self._query_augmenter is None:
                    self._query_augmenter = QueryAugmenter(augmentation_methods=['template'])
                result = self._query_augmenter._translate_with_api(query, source, target)
                if result:
                    return result
            except Exception as e:
                logger.warning(f"API translation failed: {e}")

        # Fall back to rule-based simulation
        return self._simulate_translation(query, source, target)

    def _simulate_translation(self, query: str, source: str, target: str) -> str:
        """Simulate translation with pattern-based replacements."""
        if not query:
            return query

        try:
            if source == 'zh' and target == 'en':
                # Chinese-to-English keyword patterns
                replacements = {
                    "什么": "what", "如何": "how", "为什么": "why", "哪里": "where",
                    "谁": "who", "书": "book", "章节": "chapter", "信息": "information",
                    "内容": "content", "文本": "text", "段落": "passage"
                }
                translated = query
                for zh, en in replacements.items():
                    translated = translated.replace(zh, en)
                return f"[{source}→{target}] {translated}"

            elif source == 'en' and target == 'zh':
                # English-to-Chinese keyword patterns
                replacements = {
                    "what": "什么", "how": "如何", "why": "为什么", "where": "哪里",
                    "who": "谁", "book": "书", "chapter": "章节", "information": "信息",
                    "content": "内容", "text": "文本", "passage": "段落"
                }
                translated = query.lower()
                for en, zh in replacements.items():
                    translated = translated.replace(en, zh)
                return f"[{source}→{target}] {translated}"

            # For other language pairs, just add a marker
            return f"[{source}→{target}] {query}"

        except Exception as e:
            logger.warning(f"Translation simulation failed: {e}")
            return query
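

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch): exercises only the rule-based
    # code paths, so it needs no GPU, no model downloads, and no API key.
    logging.basicConfig(level=logging.INFO)
    sample = [{
        'query': 'How to find important content in the first chapter?',
        'pos': ['The first chapter explains the method. It was published in 2023. Details follow.']
    }]

    qa = QueryAugmenter(augmentation_methods=['synonym', 'template'])
    print(qa.augment_queries(sample, num_augmentations=1))

    hn = HardNegativeAugmenter()
    print(hn.create_hard_negatives(sample, num_hard_negatives=2))

    cl = CrossLingualAugmenter(source_lang='en', target_langs=['zh'])
    print(cl.create_cross_lingual_pairs(sample, translation_prob=1.0))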