init commit

ldy 2025-07-22 16:55:25 +08:00
commit 36003b83e2
67 changed files with 76613 additions and 0 deletions

166
.gitignore vendored Normal file

@@ -0,0 +1,166 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
.idea/
# Cache
cache/
.cursor/
# Models
models/bge-m3/
models/bge-reranker-base/

105
__init__.py Normal file

@@ -0,0 +1,105 @@
"""
BGE Fine-tuning Package
A comprehensive implementation for fine-tuning BAAI BGE models
"""
__version__ = "1.0.0"
__author__ = "ldy"
import logging
import os
import sys
import torch
logger = logging.getLogger(__name__)
# Mandatory dependency check
def _check_dependencies():
"""Ensure core dependencies are available."""
try:
_ = torch.__version__
except Exception as exc: # pragma: no cover - should not happen in tests
raise ImportError("PyTorch is required for bge_finetune") from exc
try:
import transformers # noqa: F401
except Exception as exc: # pragma: no cover - should not happen in tests
raise ImportError("transformers is required for bge_finetune") from exc
# Import guard for optional dependencies
def _check_optional_dependencies():
"""Check for optional dependencies and log warnings if missing when features are likely to be used."""
# Check for torch_npu if running on Ascend
try:
npu_available = hasattr(torch, 'npu') and torch.npu.is_available()
except Exception:
npu_available = False
if npu_available or os.environ.get('DEVICE', '').lower() == 'npu':
try:
import torch_npu
except ImportError:
logger.warning("Optional dependency 'torch_npu' not found. Ascend NPU support will be unavailable.")
# Check for wandb if Weights & Biases logging is enabled
try:
from utils.config_loader import get_config
config = get_config()
use_wandb = False
try:
use_wandb = config.get('logging.use_wandb', False) # type: ignore[arg-type]
except Exception:
pass
if use_wandb:
try:
import wandb
except ImportError:
logger.warning("Optional dependency 'wandb' not found. Weights & Biases logging will be unavailable.")
except Exception:
pass
# Check for tensorboard if tensorboard logging is likely to be used
tensorboard_dir = None
try:
tensorboard_dir = config.get('logging.tensorboard_dir', None) # type: ignore[arg-type]
except Exception:
pass
if tensorboard_dir:
try:
import tensorboard
except ImportError:
logger.warning("Optional dependency 'tensorboard' not found. TensorBoard logging will be unavailable.")
# Check for pytest if running tests
if any('pytest' in arg for arg in sys.argv):
try:
import pytest
except ImportError:
logger.warning("Optional dependency 'pytest' not found. Some tests may not run.")
# Check dependencies on import
_check_dependencies()
_check_optional_dependencies()
# Make key components available at package level
try:
from .config import BaseConfig, BGEM3Config, BGERerankerConfig
from .models import BGEM3Model, BGERerankerModel
from .utils import setup_logging, get_logger
__all__ = [
'BaseConfig',
'BGEM3Config',
'BGERerankerConfig',
'BGEM3Model',
'BGERerankerModel',
'setup_logging',
'get_logger'
]
except ImportError as e:
# Allow partial imports during development
import warnings
warnings.warn(f"Some components not available: {e}")
__all__ = []

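For orientation, a minimal usage sketch of the package initializer above. Names and call signatures are assumptions based on the exports in this file; the constructors themselves are defined in files later in this commit.

# Hypothetical sketch: importing the package runs _check_dependencies() and
# _check_optional_dependencies(), then exposes the classes listed in __all__.
import bge_finetune

print(bge_finetune.__version__)        # "1.0.0"
cfg = bge_finetune.BGEM3Config()       # dataclass defaults plus [training]/[hardware]/[m3] from config.toml
print(cfg.temperature, cfg.train_group_size)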
292
config.toml Normal file

@@ -0,0 +1,292 @@
[api]
# API settings for data augmentation (used in data/augmentation.py)
# NOTE: You can override api_key by setting the environment variable BGE_API_KEY.
base_url = "https://api.deepseek.com/v1"
api_key = "sk-1bf2a963751_deletethis_b4e5e96d319810f8926e5" # Replace with your actual API key or set BGE_API_KEY env var
model = "deepseek-chat"
max_retries = 3 # Number of times to retry API calls on failure
timeout = 30 # Timeout for API requests (seconds)
temperature = 0.7 # Sampling temperature for augmentation
[training]
# Default training parameters (overridden by [hardware] or model-specific sections if present)
# If a parameter is present in both [training] and [hardware], [hardware] takes precedence.
default_batch_size = 4 # Default batch size per device for training
default_learning_rate = 1e-5 # Default learning rate
default_num_epochs = 3 # Default number of epochs
default_warmup_ratio = 0.1 # Warmup steps as a ratio of total steps
default_weight_decay = 0.01 # Weight decay for optimizer
gradient_accumulation_steps = 2 # Number of steps to accumulate gradients - reduced for small datasets
[model_paths]
# Model source: 'huggingface' or 'modelscope'. Determines which hub to use for downloading/loading models.
source = "huggingface" # Options: 'huggingface', 'modelscope'
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
# Cache directory for downloaded models (tokenizers, config files, etc.)
cache_dir = "./cache/models"
# To use ModelScope, set source = "modelscope" and provide the correct model id/path.
[data]
# Enhanced data processing settings with unified schema support
max_query_length = 64 # Max length for query tokens
max_passage_length = 512 # Max length for passage tokens
max_seq_length = 512 # Max total input sequence length
# For testing only; in production, set to 0.3
validation_split = 0.0 # Fraction of data for validation
random_seed = 42 # Random seed for reproducibility
# Cache directory for processed/optimized datasets (.pkl files)
cache_dir = "./cache/data"
# Default output directory for optimization script
optimization_output_dir = "./cache/data"
# New: Enhanced dataset format settings
auto_detect_format = true # Enable automatic format detection
legacy_support = true # Support legacy nested reranker format
migrate_legacy_on_load = false # Automatically migrate legacy data during loading
embedding_schema_version = "v3" # Use enhanced embedding schema
reranker_schema_version = "v3" # Use enhanced reranker schema
# New: Enhanced sampling settings
use_score_weighted_sampling = true # Enable score-weighted positive sampling
multiple_positives_per_batch = true # Use multiple positives for richer contrastive signals
hard_negative_ratio = 0.7 # Ratio of hard to random negatives
difficulty_threshold = 0.1 # Minimum difficulty threshold for hard negatives
[augmentation]
# Data augmentation settings (used in data/augmentation.py)
enable_query_augmentation = false # Enable query augmentation
enable_passage_augmentation = false # Enable passage augmentation
augmentation_methods = ["paraphrase", "back_translation"] # List of augmentation methods
api_first = true # Prefer the LLM API for all augmentation if enabled; fall back to rule-based methods if the API is unavailable
# Augmentation is now language-aware: prompts and templates are selected based on query language (Chinese/English)
num_augmentations_per_sample = 2 # Number of augmentations per sample
augmentation_temperature = 0.8 # Sampling temperature for augmentation
[hard_negative_mining]
# ANCE hard negative mining settings (used in data/hard_negative_mining.py)
enable_mining = true # Enable hard negative mining
num_hard_negatives = 15 # Number of hard negatives to mine
negative_range_min = 10 # Min index for negative mining
negative_range_max = 100 # Max index for negative mining
margin = 0.1 # Margin for ANCE mining
index_refresh_interval = 1000 # Steps between index refreshes
index_type = "IVF" # FAISS index type
nlist = 100 # Number of clusters for IVF
[evaluation]
# Evaluation settings (used in evaluation/evaluator.py)
eval_batch_size = 32 # Batch size for evaluation
k_values = [1, 3, 5, 10, 20, 50, 100] # k values for metrics
metrics = ["recall", "precision", "map", "mrr", "ndcg"] # Metrics to compute
save_predictions = true # Save predictions to file
predictions_dir = "./output/predictions" # Directory for predictions
[logging]
# Logging configuration (used in utils/logging.py)
log_level = "INFO" # Logging level (DEBUG, INFO, WARNING, ERROR)
log_to_file = true # Log to file in addition to stdout
log_dir = "./logs" # Directory for log files
tensorboard_dir = "./logs/tensorboard" # Directory for TensorBoard logs
wandb_project = "bge-finetuning" # Weights & Biases project name
wandb_entity = "" # Your wandb entity
use_wandb = false # Enable Weights & Biases logging
[distributed]
# Distributed training settings (used in utils/distributed.py)
# backend: 'nccl' for CUDA GPUs, 'hccl' for Huawei Ascend NPUs, 'gloo' for CPU
backend = "nccl" # Backend for distributed training (nccl for CUDA, hccl for Ascend, gloo for CPU)
init_method = "env://" # Initialization method for distributed training
find_unused_parameters = true # Find unused parameters in DDP
[ascend]
# Ascend NPU-specific settings (used for Huawei hardware)
# Set device_type to 'npu' to enable Ascend support
# backend: 'hccl' is required for distributed training on Ascend
# mixed_precision: true enables float16/bfloat16 training if supported
# Example usage:
# device_type = "npu"
# backend = "hccl"
# mixed_precision = true
# visible_devices = "0,1,2,3" # Comma-separated list of NPU device IDs
# dataloader_num_workers = 4
# per_device_train_batch_size = 4
# per_device_eval_batch_size = 8
# Uncomment and set as needed:
# device_type = "npu"
# backend = "hccl"
# mixed_precision = true
# visible_devices = "0,1,2,3"
# dataloader_num_workers = 4
# per_device_train_batch_size = 4
# per_device_eval_batch_size = 8
[optimization]
# Advanced optimization settings (used in models/losses.py, training/*_trainer.py)
gradient_checkpointing = true # Enable gradient checkpointing for memory savings
fp16 = true # Use mixed precision (float16) training
bf16 = false # Set to true for A100/H100 GPUs (bfloat16)
fp16_opt_level = "O1" # Apex AMP optimization level
max_grad_norm = 1.0 # Max gradient norm for clipping
adam_epsilon = 1e-8 # Epsilon for Adam optimizer
adam_beta1 = 0.9 # Beta1 for Adam optimizer
adam_beta2 = 0.999 # Beta2 for Adam optimizer
[hardware]
# Hardware-specific settings (overrides [training] for batch sizes, etc.)
cuda_visible_devices = "0,1,2,3" # Comma-separated list of visible CUDA devices
dataloader_num_workers = 4 # Number of workers for DataLoader
dataloader_pin_memory = true # Use pinned memory in DataLoader
per_device_train_batch_size = 4 # Batch size per device for training (overrides [training])
per_device_eval_batch_size = 8 # Batch size per device for evaluation
[checkpoint]
# Checkpointing settings (used in training/*_trainer.py)
save_strategy = "steps" # Save checkpoint by steps or epoch
save_steps = 1000 # Steps between checkpoints
save_total_limit = 3 # Max number of checkpoints to keep
load_best_model_at_end = true # Load best model at end of training
metric_for_best_model = "eval_recall@10" # Metric to select best model
greater_is_better = true # Whether higher metric is better
resume_from_checkpoint = "" # Path to checkpoint to resume from
[export]
# Model export settings (used in scripts/export.py)
export_format = ["pytorch", "onnx"] # Export formats
optimize_for_inference = true # Optimize model for inference
quantize = false # Enable quantization
quantization_bits = 8 # Number of bits for quantization
[m3]
# Enhanced BGE-M3 model-specific parameters with new features
# If a parameter is present in both [m3] and [training]/[hardware], [m3] takes precedence for BGE-M3.
use_dense = true # Enable dense embedding head
use_sparse = true # Enable sparse (SPLADE-style) embedding head
use_colbert = true # Enable ColBERT multi-vector head
unified_finetuning = true # Unified fine-tuning for all heads
sentence_pooling_method = "cls" # Pooling method for sentence embeddings
normalize_embeddings = true # Normalize output embeddings
colbert_dim = -1 # Output dimension for ColBERT head (-1 uses model hidden size)
sparse_top_k = 100 # Top-k for sparse representations
# Training
train_group_size = 8 # Number of passages per query in training
temperature = 0.1 # InfoNCE loss temperature - increased from 0.02 for better stability
negatives_cross_device = true # Share negatives across devices/GPUs
# Enhanced hard negative mining
use_hard_negatives = true # Enable hard negative mining
num_hard_negatives = 15 # Number of hard negatives to mine
hard_negative_mining_steps = 1000 # Steps between hard negative mining refreshes
ance_negative_range = [10, 100] # Range for ANCE negative mining
ance_margin = 0.1 # Margin for ANCE mining
hard_negative_ratio = 0.7 # Ratio of hard to random negatives
difficulty_threshold = 0.1 # Minimum difficulty for hard negatives
max_triplets_per_query = 8 # Maximum triplets per query for training
# Enhanced self-distillation
use_self_distill = false # Enable self-knowledge distillation - changed default to false
self_distill_temperature = 4.0 # Temperature for self-distillation
self_distill_alpha = 0.5 # Alpha (weight) for self-distillation loss
# Enhanced query instruction
use_query_instruction = false # Add instruction to queries
query_instruction_for_retrieval = "Represent this sentence for searching relevant passages:"
query_instruction_for_reranking = ""
# Enhanced data augmentation
augment_queries = false # Enable query augmentation
augmentation_methods = ["paraphrase", "back_translation"] # List of augmentation methods
create_hard_negatives = false # Create hard negatives via augmentation
# Enhanced sampling features
use_score_weighted_sampling = true # Enable score-weighted positive sampling
multiple_positives = true # Use multiple positives per batch for richer contrastive signals
enhanced_contrastive_learning = true # Enable enhanced contrastive learning features
# Loss weights
# For multi-task learning, these weights control the contribution of each head/loss
dense_weight = 1.0
sparse_weight = 0.3
colbert_weight = 0.5
distillation_weight = 1.0
# Evaluation
metric_for_best_model = "eval_recall@10"
greater_is_better = true
# Memory optimization
use_gradient_checkpointing = true # Enable gradient checkpointing
max_passages_per_query = 32 # Max passages per query in memory
# Multi-granularity
max_query_length = 64 # Max query length
max_passage_length = 512 # Max passage length
max_length = 8192 # Max total input length (multi-granularity)
# New: Enhanced dataset integration
use_new_dataset_api = true # Use integrated dataset API with format detection
auto_migrate_legacy = false # Automatically migrate legacy formats during training
[reranker]
# Enhanced BGE-Reranker model-specific parameters with flat pairs support
# If a parameter is present in both [reranker] and [training]/[hardware], [reranker] takes precedence for BGE-Reranker.
add_pooling_layer = false # Not used for cross-encoder
use_fp16 = true # Use mixed precision (float16) training
train_group_size = 16 # Number of passages per query for listwise training
loss_type = "listwise_ce" # Loss function type ('listwise_ce', 'pairwise_margin', 'combined', 'cross_entropy', 'binary_cross_entropy')
margin = 1.0 # Margin for pairwise loss
temperature = 1.0 # Temperature for listwise CE loss
# Enhanced knowledge distillation
use_knowledge_distillation = false # Enable knowledge distillation from teacher model
teacher_model_name_or_path = "" # Path to teacher model for distillation
distillation_temperature = 3.0 # Temperature for distillation
distillation_alpha = 0.7 # Alpha (weight) for distillation loss
# Enhanced hard negatives
use_hard_negatives = true # Enable hard negative mining
hard_negative_ratio = 0.8 # Ratio of hard to random negatives
# Enhanced denoising
use_denoising = false # Enable denoising strategy
noise_probability = 0.1 # Probability of noise for denoising
# Performance optimizations
max_pairs_per_device = 128 # Max query-passage pairs per device
accumulate_gradients_from_all_devices = true # Accumulate gradients from all devices
max_seq_length = 512 # Max combined query+passage length
query_max_length = 64 # Max query length (if separate encoding)
passage_max_length = 512 # Max passage length (if separate encoding)
# Enhanced training features
shuffle_passages = true # Shuffle passage order during training
position_bias_cutoff = 10 # Top-k passages to consider for position bias
use_gradient_caching = true # Enable gradient caching for efficiency
gradient_cache_batch_size = 32 # Batch size for gradient caching
eval_group_size = 100 # Number of candidates during evaluation
metric_for_best_model = "eval_mrr@10" # Metric to select best model
eval_normalize_scores = true # Normalize scores during evaluation
# Loss weights for combined training
listwise_weight = 0.7 # Loss weight for listwise loss
pairwise_weight = 0.3 # Loss weight for pairwise loss
# Data quality controls
filter_empty_passages = true # Filter out empty passages
min_passage_length = 10 # Minimum passage length to keep
max_passage_length_ratio = 50.0 # Max ratio between query and passage length
# Advanced features
gradient_checkpointing = true # Enable gradient checkpointing
recompute_scores = false # Recompute scores instead of storing all
use_dynamic_teacher = false # Enable dynamic teacher for RocketQAv2-style training
teacher_update_steps = 1000 # Steps between teacher updates
teacher_momentum = 0.999 # Momentum for teacher updates
# New: Enhanced format support
support_flat_pairs = true # Support flat pairs format (query, passage, label)
support_nested_format = true # Support legacy nested format for backward compatibility
auto_convert_nested_to_flat = false # Automatically convert nested to flat during training
use_new_dataset_api = true # Use integrated dataset API with format detection

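The sections above are read through utils/config_loader.py (not part of this excerpt) via dotted-key lookups such as config.get('logging.use_wandb', False) and config.get_section('hardware'). Below is a minimal sketch of that access pattern, assuming Python 3.11+ for the standard-library tomllib; the real loader may differ.

# Illustrative only; the project's actual loader lives in utils/config_loader.py.
import tomllib

class TomlConfig:
    def __init__(self, path="config.toml"):
        with open(path, "rb") as f:
            self._config = tomllib.load(f)

    def get(self, key, default=None):
        # Dotted-key lookup, e.g. get("training.default_batch_size")
        value = self._config
        for part in key.split("."):
            if not isinstance(value, dict) or part not in value:
                return default
            value = value[part]
        return value

    def get_section(self, section):
        return self._config.get(section, {})

config = TomlConfig()
print(config.get("training.default_batch_size"))                     # 4
print(config.get_section("hardware")["per_device_eval_batch_size"])  # 8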
13
config/__init__.py Normal file

@@ -0,0 +1,13 @@
"""
Configuration module for BGE fine-tuning
"""
from .base_config import BaseConfig
from .m3_config import BGEM3Config
from .reranker_config import BGERerankerConfig
__all__ = [
'BaseConfig',
'BGEM3Config',
'BGERerankerConfig'
]

560
config/base_config.py Normal file

@@ -0,0 +1,560 @@
import os
import torch
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
import logging
logger = logging.getLogger(__name__)
@dataclass
class BaseConfig:
"""
Base configuration for all trainers and models.
Configuration Precedence Order:
1. Command-line arguments (highest priority)
2. Model-specific config sections ([m3], [reranker])
3. [hardware] section (for hardware-specific overrides)
4. [training] section (general training defaults)
5. Dataclass field defaults (lowest priority)
Example:
If batch_size is defined in multiple places:
- CLI arg --batch_size 32 (wins)
- [m3] batch_size = 16
- [hardware] per_device_train_batch_size = 8
- [training] default_batch_size = 4
- Dataclass default = 2
Note: Model-specific configs ([m3], [reranker]) take precedence over
[hardware] for their respective models to allow fine-grained control.
"""
temperature: float = 1.0 # Softmax temperature for loss functions
use_in_batch_negatives: bool = False # Use in-batch negatives for contrastive loss
negatives_cross_device: bool = False # Use negatives across devices (DDP)
label_smoothing: float = 0.0 # Label smoothing for classification losses
alpha: float = 1.0 # Weight for auxiliary losses
gradient_accumulation_steps: int = 1 # Steps to accumulate gradients
num_train_epochs: int = 3 # Number of training epochs
    eval_steps: int = 1 # Evaluate every N steps
    save_steps: int = 1 # Save checkpoint every N steps
warmup_steps: int = 0 # Number of warmup steps for scheduler
warmup_ratio: float = 0.0 # Warmup ratio for scheduler
metric_for_best_model: str = "accuracy" # Metric to select best model
checkpoint_dir: str = "./checkpoints" # Directory for saving checkpoints
resume_from_checkpoint: str = "" # Path to resume checkpoint
max_grad_norm: float = 1.0 # Max gradient norm for clipping
use_amp: bool = False # Use automatic mixed precision
# Model configuration
model_name_or_path: str = "BAAI/bge-m3"
cache_dir: Optional[str] = None
# Data configuration
train_data_path: str = "./data/train.jsonl"
eval_data_path: Optional[str] = "./data/eval.jsonl"
max_seq_length: int = 512
query_max_length: int = 64
passage_max_length: int = 512
# Training configuration
output_dir: str = "./output"
learning_rate: float = 1e-5
weight_decay: float = 0.01
adam_epsilon: float = 1e-8
adam_beta1: float = 0.9
adam_beta2: float = 0.999
# Advanced training configuration
fp16: bool = True
bf16: bool = False
gradient_checkpointing: bool = True
deepspeed: Optional[str] = None
local_rank: int = -1
dataloader_num_workers: int = 4
dataloader_pin_memory: bool = True
# Batch sizes
per_device_train_batch_size: int = 4
per_device_eval_batch_size: int = 8
# Logging and saving
logging_dir: str = "./logs"
    # For testing only; in production, set to 100
logging_steps: int = 100
save_total_limit: int = 3
load_best_model_at_end: bool = True
greater_is_better: bool = False
evaluation_strategy: str = "steps"
# Checkpoint and resume
overwrite_output_dir: bool = False
# Seed
seed: int = 42
# Platform compatibility
device: str = "auto"
n_gpu: int = 0
def __post_init__(self):
"""Post-initialization with robust config loading and device setup
Notes:
- [hardware] section overrides [training] for overlapping parameters.
- Model-specific configs ([m3], [reranker]) override [training]/[hardware] for their models.
"""
# Load configuration from centralized config system
config = self._get_config_safely()
# Apply configuration values with fallbacks
self._apply_config_values(config)
# Setup device detection
self._setup_device()
    def _get_config_safely(self):
        """Safely get configuration with multiple fallback strategies"""
        try:
            # Try top-level import (project root on sys.path)
            from utils.config_loader import get_config
            return get_config()
        except ImportError:
            try:
                # Try package-qualified import (installed package or scripts)
                from bge_finetune.utils.config_loader import get_config  # type: ignore[import]
                return get_config()
            except ImportError:
                logger.error("Could not import config_loader; using fallback configuration. Please check your environment and config loader path.")
                return self._create_fallback_config()
def _create_fallback_config(self):
"""Create a minimal fallback configuration object"""
class FallbackConfig:
def __init__(self):
self._config = {
'training': {
'default_batch_size': 4,
'default_learning_rate': 1e-5,
'default_num_epochs': 3,
'default_warmup_ratio': 0.1,
'default_weight_decay': 0.01,
'gradient_accumulation_steps': 4
},
'model_paths': {
'cache_dir': './cache/data'
},
'data': {
'max_seq_length': 512,
'max_query_length': 64,
'max_passage_length': 512,
'random_seed': 42
},
'hardware': {
'dataloader_num_workers': 4,
'dataloader_pin_memory': True,
'per_device_train_batch_size': 4,
'per_device_eval_batch_size': 8
},
'optimization': {
'fp16': True,
'bf16': False,
'gradient_checkpointing': True,
'max_grad_norm': 1.0,
'adam_epsilon': 1e-8,
'adam_beta1': 0.9,
'adam_beta2': 0.999
},
'checkpoint': {
'save_steps': 1000,
'save_total_limit': 3,
'load_best_model_at_end': True,
'metric_for_best_model': 'eval_loss',
'greater_is_better': False
},
'logging': {
'log_dir': './logs'
}
}
def get(self, key, default=None):
keys = key.split('.')
value = self._config
try:
for k in keys:
value = value[k]
return value
except (KeyError, TypeError):
return default
def get_section(self, section):
return self._config.get(section, {})
return FallbackConfig()
def _apply_config_values(self, config):
"""
Apply configuration values with proper precedence order.
Precedence (lowest to highest):
1. Dataclass field defaults (already set)
2. [training] section values
3. [hardware] section values (override training)
4. Model-specific sections (handled in subclasses)
5. Command-line arguments (handled by training scripts)
This method handles steps 2-3. Model-specific configs are applied
in subclass __post_init__ methods AFTER this base method runs.
"""
try:
# === STEP 2: Apply [training] section (general defaults) ===
training_config = config.get_section('training')
if training_config:
# Only update if value exists in config
if 'default_num_epochs' in training_config:
old_value = self.num_train_epochs
self.num_train_epochs = training_config['default_num_epochs']
if old_value != self.num_train_epochs:
logger.debug(f"[training] num_train_epochs: {old_value} -> {self.num_train_epochs}")
if 'default_batch_size' in training_config:
self.per_device_train_batch_size = training_config['default_batch_size']
logger.debug(f"[training] batch_size set to {self.per_device_train_batch_size}")
if 'default_learning_rate' in training_config:
self.learning_rate = training_config['default_learning_rate']
if 'default_warmup_ratio' in training_config:
self.warmup_ratio = training_config['default_warmup_ratio']
if 'gradient_accumulation_steps' in training_config:
self.gradient_accumulation_steps = training_config['gradient_accumulation_steps']
if 'default_weight_decay' in training_config:
self.weight_decay = training_config['default_weight_decay']
# === STEP 3: Apply [hardware] section (overrides training) ===
hardware_config = config.get_section('hardware')
if hardware_config:
# Hardware settings take precedence over training defaults
if 'dataloader_num_workers' in hardware_config:
self.dataloader_num_workers = hardware_config['dataloader_num_workers']
if 'dataloader_pin_memory' in hardware_config:
self.dataloader_pin_memory = hardware_config['dataloader_pin_memory']
# Batch sizes from hardware override training defaults
if 'per_device_train_batch_size' in hardware_config:
old_value = self.per_device_train_batch_size
self.per_device_train_batch_size = hardware_config['per_device_train_batch_size']
if old_value != self.per_device_train_batch_size:
logger.info(f"[hardware] overrides batch_size: {old_value} -> {self.per_device_train_batch_size}")
if 'per_device_eval_batch_size' in hardware_config:
self.per_device_eval_batch_size = hardware_config['per_device_eval_batch_size']
# === Apply other configuration sections ===
# Model paths
if self.cache_dir is None:
cache_dir = config.get('model_paths.cache_dir')
if cache_dir is not None:
self.cache_dir = cache_dir
logger.debug(f"cache_dir set from config: {self.cache_dir}")
# Data configuration
data_config = config.get_section('data')
if data_config:
if 'max_seq_length' in data_config:
self.max_seq_length = data_config['max_seq_length']
if 'max_query_length' in data_config:
self.query_max_length = data_config['max_query_length']
if 'max_passage_length' in data_config:
self.passage_max_length = data_config['max_passage_length']
if 'random_seed' in data_config:
self.seed = data_config['random_seed']
# Optimization configuration
opt_config = config.get_section('optimization')
if opt_config:
if 'fp16' in opt_config:
self.fp16 = opt_config['fp16']
if 'bf16' in opt_config:
self.bf16 = opt_config['bf16']
if 'gradient_checkpointing' in opt_config:
self.gradient_checkpointing = opt_config['gradient_checkpointing']
if 'max_grad_norm' in opt_config:
self.max_grad_norm = opt_config['max_grad_norm']
if 'adam_epsilon' in opt_config:
self.adam_epsilon = opt_config['adam_epsilon']
if 'adam_beta1' in opt_config:
self.adam_beta1 = opt_config['adam_beta1']
if 'adam_beta2' in opt_config:
self.adam_beta2 = opt_config['adam_beta2']
# Checkpoint configuration
checkpoint_config = config.get_section('checkpoint')
if checkpoint_config:
if 'save_steps' in checkpoint_config:
self.save_steps = checkpoint_config['save_steps']
if 'save_total_limit' in checkpoint_config:
self.save_total_limit = checkpoint_config['save_total_limit']
if 'load_best_model_at_end' in checkpoint_config:
self.load_best_model_at_end = checkpoint_config['load_best_model_at_end']
if 'metric_for_best_model' in checkpoint_config:
self.metric_for_best_model = checkpoint_config['metric_for_best_model']
if 'greater_is_better' in checkpoint_config:
self.greater_is_better = checkpoint_config['greater_is_better']
# Logging configuration
logging_config = config.get_section('logging')
if logging_config:
if 'log_dir' in logging_config:
self.logging_dir = logging_config['log_dir']
except Exception as e:
logger.warning(f"Error applying configuration values: {e}")
def _setup_device(self):
"""Enhanced device detection with Ascend NPU support"""
if self.device == "auto":
device_info = self._detect_available_devices()
self.device = device_info['device']
self.n_gpu = device_info['count']
logger.info(f"Auto-detected device: {self.device} (count: {self.n_gpu})")
else:
# Manual device specification
if self.device == "cuda" and torch.cuda.is_available():
self.n_gpu = torch.cuda.device_count()
elif self.device == "npu":
try:
import torch_npu # type: ignore[import]
npu = getattr(torch, 'npu', None)
if npu is not None and npu.is_available():
self.n_gpu = npu.device_count()
else:
logger.warning("NPU requested but not available, falling back to CPU")
self.device = "cpu"
self.n_gpu = 0
except ImportError:
logger.warning("NPU requested but torch_npu not installed, falling back to CPU")
self.device = "cpu"
self.n_gpu = 0
else:
self.n_gpu = 0
def _detect_available_devices(self) -> Dict[str, Any]:
"""Comprehensive device detection with priority ordering"""
device_priority = ['npu', 'cuda', 'mps', 'cpu']
available_devices = []
# Check Ascend NPU
try:
import torch_npu # type: ignore[import]
npu = getattr(torch, 'npu', None)
if npu is not None and npu.is_available():
count = npu.device_count()
available_devices.append(('npu', count))
logger.debug(f"Ascend NPU available: {count} devices")
except ImportError:
logger.debug("Ascend NPU library not available")
# Check CUDA
if torch.cuda.is_available():
count = torch.cuda.device_count()
available_devices.append(('cuda', count))
logger.debug(f"CUDA available: {count} devices")
# Check MPS (Apple Silicon)
if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and hasattr(torch.backends.mps, 'is_available') and torch.backends.mps.is_available(): # type: ignore[attr-defined]
available_devices.append(('mps', 1))
logger.debug("Apple MPS available")
# CPU always available
available_devices.append(('cpu', 1))
# Select best available device based on priority
for device_type in device_priority:
for available_type, count in available_devices:
if available_type == device_type:
return {'device': device_type, 'count': count}
return {'device': 'cpu', 'count': 1} # Fallback
def validate(self):
"""Comprehensive configuration validation"""
errors = []
# Check required paths
if self.train_data_path and not os.path.exists(self.train_data_path):
errors.append(f"Training data not found: {self.train_data_path}")
if self.eval_data_path and not os.path.exists(self.eval_data_path):
errors.append(f"Evaluation data not found: {self.eval_data_path}")
# Validate numeric parameters
if self.learning_rate is None or self.learning_rate <= 0:
errors.append("Learning rate must be positive and not None")
if self.num_train_epochs is None or self.num_train_epochs <= 0:
errors.append("Number of epochs must be positive and not None")
if self.per_device_train_batch_size is None or self.per_device_train_batch_size <= 0:
errors.append("Training batch size must be positive and not None")
if self.gradient_accumulation_steps is None or self.gradient_accumulation_steps <= 0:
errors.append("Gradient accumulation steps must be positive and not None")
if self.warmup_ratio is None or self.warmup_ratio < 0 or self.warmup_ratio > 1:
errors.append("Warmup ratio must be between 0 and 1 and not None")
# Validate device
if self.device not in ['auto', 'cpu', 'cuda', 'npu', 'mps']:
errors.append(f"Invalid device: {self.device}")
# Check sequence lengths
if self.max_seq_length is None or self.max_seq_length <= 0:
errors.append("Maximum sequence length must be positive and not None")
if self.query_max_length is None or self.query_max_length <= 0:
errors.append("Maximum query length must be positive and not None")
if self.passage_max_length is None or self.passage_max_length <= 0:
errors.append("Maximum passage length must be positive and not None")
if errors:
raise ValueError(f"Configuration validation failed:\n" + "\n".join(f" - {error}" for error in errors))
logger.info("Configuration validation passed")
return True
def to_dict(self) -> Dict[str, Any]:
"""Convert config to dictionary with type handling"""
config_dict = {}
for key, value in self.__dict__.items():
if isinstance(value, (str, int, float, bool, type(None))):
config_dict[key] = value
elif isinstance(value, (list, tuple)):
config_dict[key] = list(value)
elif isinstance(value, dict):
config_dict[key] = dict(value)
else:
config_dict[key] = str(value)
return config_dict
def save(self, save_path: str):
"""Save configuration to file with error handling"""
try:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
import json
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
logger.info(f"Configuration saved to {save_path}")
except Exception as e:
logger.error(f"Failed to save configuration: {e}")
raise
@classmethod
def load(cls, load_path: str):
"""Load configuration from file with validation"""
try:
import json
with open(load_path, 'r', encoding='utf-8') as f:
config_dict = json.load(f)
# Filter only valid field names
import inspect
sig = inspect.signature(cls)
valid_keys = set(sig.parameters.keys())
filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys}
config = cls(**filtered_dict)
config.validate() # Validate after loading
logger.info(f"Configuration loaded from {load_path}")
return config
except Exception as e:
logger.error(f"Failed to load configuration from {load_path}: {e}")
raise
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]):
"""Create config from dictionary with validation"""
try:
# Filter only valid field names
import inspect
sig = inspect.signature(cls)
valid_keys = set(sig.parameters.keys())
filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys}
config = cls(**filtered_dict)
config.validate() # Validate after creation
return config
except Exception as e:
logger.error(f"Failed to create configuration from dictionary: {e}")
raise
def update_from_args(self, args):
"""Update config from command line arguments with validation"""
try:
updated_fields = []
for key, value in vars(args).items():
if hasattr(self, key) and value is not None:
old_value = getattr(self, key)
setattr(self, key, value)
updated_fields.append(f"{key}: {old_value} -> {value}")
if updated_fields:
logger.info("Updated configuration from command line arguments:")
for field in updated_fields:
logger.info(f" {field}")
# Validate after updates
self.validate()
except Exception as e:
logger.error(f"Failed to update configuration from arguments: {e}")
raise
def get_effective_batch_size(self) -> int:
"""Calculate effective batch size considering gradient accumulation"""
return self.per_device_train_batch_size * self.gradient_accumulation_steps * max(1, self.n_gpu)
def get_training_steps(self, dataset_size: int) -> int:
"""Calculate total training steps"""
steps_per_epoch = dataset_size // self.get_effective_batch_size()
return steps_per_epoch * self.num_train_epochs
def print_effective_config(self):
"""Print the effective value and its source for each parameter."""
print("Parameter | Value | Source")
print("----------|-------|-------")
for field in self.__dataclass_fields__:
value = getattr(self, field)
# For now, just print the value; source tracking can be added if needed
print(f"{field} | {value} | (see docstring for precedence)")
def __str__(self) -> str:
"""String representation of configuration"""
return f"BaseConfig(device={self.device}, batch_size={self.per_device_train_batch_size}, lr={self.learning_rate})"
def __repr__(self) -> str:
return self.__str__()
def get_device(self):
"""Return the device string for training (cuda, npu, mps, cpu)"""
if torch.cuda.is_available():
return "cuda"
npu = getattr(torch, "npu", None)
if npu is not None and hasattr(npu, "is_available") and npu.is_available():
return "npu"
backends = getattr(torch, "backends", None)
if backends is not None and hasattr(backends, "mps") and backends.mps.is_available():
return "mps"
return "cpu"
    def is_npu(self):
        """Return True if an NPU is available and selected as the device."""
        npu = getattr(torch, "npu", None)
        if npu is not None and hasattr(npu, "is_available") and npu.is_available():
            return self.device == "npu"
        return False

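A short, hypothetical example of how BaseConfig above resolves values and computes the effective batch size (paths and device counts are assumptions; this mirrors the precedence documented in the class docstring):

# With the config.toml above ([hardware] per_device_train_batch_size = 4,
# [training] gradient_accumulation_steps = 2) and, say, 2 visible GPUs,
# the effective batch size is 4 * 2 * 2 = 16.
from config.base_config import BaseConfig   # assumes the repo root is on sys.path

cfg = BaseConfig()                           # __post_init__ applies [training], then [hardware]
print(cfg.get_effective_batch_size())        # per_device_train_batch_size * grad_accum * max(1, n_gpu)
print(cfg.get_training_steps(dataset_size=10_000))
cfg.save("./output/base_config.json")        # JSON dump of to_dict(); validate() checks paths and values separately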
254
config/m3_config.py Normal file

@@ -0,0 +1,254 @@
"""
BGE-M3 specific configuration
"""
from dataclasses import dataclass, field
from typing import Optional, List, Dict
from .base_config import BaseConfig
import logging
logger = logging.getLogger(__name__)
@dataclass
class BGEM3Config(BaseConfig):
"""
Configuration for BGE-M3 model and trainer.
"""
use_in_batch_negatives: bool = False # Use in-batch negatives for contrastive loss
teacher_temperature: float = 1.0 # Temperature for teacher distillation
kd_loss_weight: float = 1.0 # Weight for KD loss
kd_loss_type: str = "mse" # Type of KD loss (e.g., mse, kl)
hard_negative_margin: float = 0.0 # Margin for hard negative mining
"""Configuration for BGE-M3 embedding model fine-tuning
Attributes:
model_name_or_path: Path or identifier for the pretrained BGE-M3 model.
use_dense: Enable dense embedding head.
use_sparse: Enable sparse (SPLADE-style) embedding head.
use_colbert: Enable ColBERT multi-vector head.
unified_finetuning: Unified fine-tuning for all functionalities.
sentence_pooling_method: Pooling method for sentence embeddings (e.g., 'cls').
normalize_embeddings: Whether to normalize output embeddings.
colbert_dim: Output dimension for ColBERT head (-1 uses model hidden size).
sparse_top_k: Top-k for sparse representations.
temperature: InfoNCE loss temperature.
train_group_size: Number of passages per query in training.
negatives_cross_device: Share negatives across devices/GPUs.
use_hard_negatives: Enable hard negative mining.
num_hard_negatives: Number of hard negatives to mine.
hard_negative_mining_steps: Steps between hard negative mining refreshes.
ance_negative_range: Range for ANCE negative mining.
ance_margin: Margin for ANCE mining.
use_self_distill: Enable self-knowledge distillation.
self_distill_temperature: Temperature for self-distillation.
self_distill_alpha: Alpha (weight) for self-distillation loss.
use_query_instruction: Add instruction to queries.
query_instruction_for_retrieval: Instruction for retrieval queries.
query_instruction_for_reranking: Instruction for reranking queries.
augment_queries: Enable query augmentation.
augmentation_methods: List of augmentation methods to use.
create_hard_negatives: Create hard negatives via augmentation.
dense_weight: Loss weight for dense head.
sparse_weight: Loss weight for sparse head.
colbert_weight: Loss weight for ColBERT head.
distillation_weight: Loss weight for distillation.
eval_strategy: Evaluation strategy ('epoch', 'steps').
metric_for_best_model: Metric to select best model.
greater_is_better: Whether higher metric is better.
use_gradient_checkpointing: Enable gradient checkpointing.
max_passages_per_query: Max passages per query in memory.
max_query_length: Max query length.
max_passage_length: Max passage length.
max_length: Max total input length (multi-granularity).
"""
# Model specific
model_name_or_path: str = "BAAI/bge-m3"
# Multi-functionality settings
use_dense: bool = True
use_sparse: bool = True
use_colbert: bool = True
unified_finetuning: bool = True
# Architecture settings
sentence_pooling_method: str = "cls" # BGE-M3 uses CLS pooling
normalize_embeddings: bool = True
colbert_dim: int = -1 # -1 means use model hidden size
sparse_top_k: int = 100 # Top-k for sparse representations
# Training specific
temperature: float = 0.02 # InfoNCE temperature (from BGE-M3 paper)
train_group_size: int = 8 # Number of passages per query
negatives_cross_device: bool = True # Share negatives across GPUs
# Hard negative mining
use_hard_negatives: bool = True
num_hard_negatives: int = 15
hard_negative_mining_steps: int = 1000
ance_negative_range: tuple = (10, 100)
ance_margin: float = 0.1
# Self-distillation
use_self_distill: bool = False # Changed default to False - make KD opt-in
self_distill_temperature: float = 4.0
self_distill_alpha: float = 0.5
# Query instruction
use_query_instruction: bool = False
query_instruction_for_retrieval: str = "Represent this sentence for searching relevant passages:"
query_instruction_for_reranking: str = ""
# Data augmentation
augment_queries: bool = False
augmentation_methods: List[str] = field(default_factory=lambda: ["paraphrase", "back_translation"])
create_hard_negatives: bool = False
# Loss weights for multi-task learning
dense_weight: float = 1.0
sparse_weight: float = 0.3
colbert_weight: float = 0.5
distillation_weight: float = 1.0
# Evaluation
eval_strategy: str = "epoch"
metric_for_best_model: str = "eval_recall@10"
greater_is_better: bool = True
# Memory optimization
use_gradient_checkpointing: bool = True
max_passages_per_query: int = 32
# Specific to M3's multi-granularity
max_query_length: int = 64
max_passage_length: int = 512
max_length: int = 8192 # BGE-M3 supports up to 8192 tokens
def __post_init__(self):
"""Load model-specific parameters from [m3] section in config.toml"""
super().__post_init__()
try:
from utils.config_loader import get_config
except ImportError:
from bge_finetune.utils.config_loader import get_config # type: ignore[import]
config = get_config()
self._apply_m3_config(config)
def _apply_m3_config(self, config):
"""
Apply BGE-M3 specific configuration from [m3] section in config.toml.
"""
m3_config = config.get_section('m3')
if not m3_config:
logger.warning("No [m3] section found in config.toml")
return
# Helper to convert tuple from config to list if needed
def get_val(key, default):
if key not in m3_config:
return default
v = m3_config[key]
if isinstance(default, list) and isinstance(v, tuple):
return list(v)
return v
# Only update if key exists in config
if 'use_dense' in m3_config:
self.use_dense = m3_config['use_dense']
if 'use_sparse' in m3_config:
self.use_sparse = m3_config['use_sparse']
if 'use_colbert' in m3_config:
self.use_colbert = m3_config['use_colbert']
if 'unified_finetuning' in m3_config:
self.unified_finetuning = m3_config['unified_finetuning']
if 'sentence_pooling_method' in m3_config:
self.sentence_pooling_method = m3_config['sentence_pooling_method']
if 'normalize_embeddings' in m3_config:
self.normalize_embeddings = m3_config['normalize_embeddings']
if 'colbert_dim' in m3_config:
self.colbert_dim = m3_config['colbert_dim']
if 'sparse_top_k' in m3_config:
self.sparse_top_k = m3_config['sparse_top_k']
if 'temperature' in m3_config:
self.temperature = m3_config['temperature']
if 'train_group_size' in m3_config:
self.train_group_size = m3_config['train_group_size']
if 'negatives_cross_device' in m3_config:
self.negatives_cross_device = m3_config['negatives_cross_device']
if 'use_hard_negatives' in m3_config:
self.use_hard_negatives = m3_config['use_hard_negatives']
if 'num_hard_negatives' in m3_config:
self.num_hard_negatives = m3_config['num_hard_negatives']
if 'hard_negative_mining_steps' in m3_config:
self.hard_negative_mining_steps = m3_config['hard_negative_mining_steps']
if 'ance_negative_range' in m3_config:
self.ance_negative_range = tuple(get_val('ance_negative_range', self.ance_negative_range))
if 'augmentation_methods' in m3_config:
self.augmentation_methods = list(get_val('augmentation_methods', self.augmentation_methods))
if 'ance_margin' in m3_config:
self.ance_margin = m3_config['ance_margin']
if 'use_self_distill' in m3_config:
self.use_self_distill = m3_config['use_self_distill']
if 'self_distill_temperature' in m3_config:
self.self_distill_temperature = m3_config['self_distill_temperature']
if 'self_distill_alpha' in m3_config:
self.self_distill_alpha = m3_config['self_distill_alpha']
if 'use_query_instruction' in m3_config:
self.use_query_instruction = m3_config['use_query_instruction']
if 'query_instruction_for_retrieval' in m3_config:
self.query_instruction_for_retrieval = m3_config['query_instruction_for_retrieval']
if 'query_instruction_for_reranking' in m3_config:
self.query_instruction_for_reranking = m3_config['query_instruction_for_reranking']
if 'augment_queries' in m3_config:
self.augment_queries = m3_config['augment_queries']
if 'create_hard_negatives' in m3_config:
self.create_hard_negatives = m3_config['create_hard_negatives']
if 'dense_weight' in m3_config:
self.dense_weight = m3_config['dense_weight']
if 'sparse_weight' in m3_config:
self.sparse_weight = m3_config['sparse_weight']
if 'colbert_weight' in m3_config:
self.colbert_weight = m3_config['colbert_weight']
if 'distillation_weight' in m3_config:
self.distillation_weight = m3_config['distillation_weight']
if 'metric_for_best_model' in m3_config:
self.metric_for_best_model = m3_config['metric_for_best_model']
if 'greater_is_better' in m3_config:
self.greater_is_better = m3_config['greater_is_better']
if 'use_gradient_checkpointing' in m3_config:
self.use_gradient_checkpointing = m3_config['use_gradient_checkpointing']
if 'max_passages_per_query' in m3_config:
self.max_passages_per_query = m3_config['max_passages_per_query']
if 'max_query_length' in m3_config:
self.max_query_length = m3_config['max_query_length']
if 'max_passage_length' in m3_config:
self.max_passage_length = m3_config['max_passage_length']
if 'max_length' in m3_config:
self.max_length = m3_config['max_length']
# Log parameter overrides
logger.info("Loaded BGE-M3 config from [m3] section in config.toml")
# Check for required parameters
if self.temperature is None or self.temperature <= 0:
raise ValueError("Temperature must be positive and not None")
if self.train_group_size is None or self.train_group_size < 2:
raise ValueError("train_group_size must be at least 2 and not None")
def validate(self):
"""Validate configuration settings"""
if not any([self.use_dense, self.use_sparse, self.use_colbert]):
raise ValueError("At least one of use_dense, use_sparse, or use_colbert must be True")
if self.unified_finetuning and not all([self.use_dense, self.use_sparse, self.use_colbert]):
logger.warning("unified_finetuning is True but not all functionalities are enabled")
if self.temperature <= 0:
raise ValueError("Temperature must be positive")
if self.train_group_size < 2:
raise ValueError("train_group_size must be at least 2")
if self.max_length > 8192:
logger.warning("max_length > 8192 may not be supported by BGE-M3")
return True

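A corresponding sketch for BGEM3Config, showing how the [m3] section overrides the dataclass defaults (values assume the config.toml shown earlier in this commit is present and readable):

# Hypothetical usage; assumes the repo root is on sys.path.
from config.m3_config import BGEM3Config

cfg = BGEM3Config()
print(cfg.temperature)         # 0.1 from [m3], overriding the 0.02 dataclass default
print(cfg.train_group_size)    # 8
cfg.validate()                 # requires at least one of the dense/sparse/ColBERT heads to be enabled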
264
config/reranker_config.py Normal file

@@ -0,0 +1,264 @@
"""
BGE-Reranker specific configuration
"""
from dataclasses import dataclass, field
from typing import Optional, List, Dict
from .base_config import BaseConfig
import logging
logger = logging.getLogger(__name__)
@dataclass
class BGERerankerConfig(BaseConfig):
"""
Configuration for BGE-Reranker model and trainer.
"""
num_labels: int = 1 # Number of output labels
listwise_temperature: float = 1.0 # Temperature for listwise loss
kd_temperature: float = 1.0 # Temperature for KD loss
kd_loss_weight: float = 1.0 # Weight for KD loss
denoise_confidence_threshold: float = 0.0 # Threshold for denoising
"""Configuration for BGE-Reranker cross-encoder fine-tuning
Attributes:
model_name_or_path: Path or identifier for the pretrained BGE-Reranker model.
add_pooling_layer: Whether to add a pooling layer (not used for cross-encoder).
use_fp16: Use mixed precision (float16) training.
train_group_size: Number of passages per query for listwise training.
loss_type: Loss function type ('listwise_ce', 'pairwise_margin', 'combined').
margin: Margin for pairwise loss.
temperature: Temperature for listwise CE loss.
use_knowledge_distillation: Enable knowledge distillation from teacher model.
teacher_model_name_or_path: Path to teacher model for distillation.
distillation_temperature: Temperature for distillation.
distillation_alpha: Alpha (weight) for distillation loss.
use_hard_negatives: Enable hard negative mining.
hard_negative_ratio: Ratio of hard to random negatives.
use_denoising: Enable denoising strategy.
noise_probability: Probability of noise for denoising.
max_pairs_per_device: Max query-passage pairs per device.
accumulate_gradients_from_all_devices: Accumulate gradients from all devices.
max_seq_length: Max combined query+passage length.
query_max_length: Max query length (if separate encoding).
passage_max_length: Max passage length (if separate encoding).
shuffle_passages: Shuffle passage order during training.
position_bias_cutoff: Top-k passages to consider for position bias.
use_gradient_caching: Enable gradient caching for efficiency.
gradient_cache_batch_size: Batch size for gradient caching.
eval_group_size: Number of candidates during evaluation.
metric_for_best_model: Metric to select best model.
eval_normalize_scores: Whether to normalize scores during evaluation.
listwise_weight: Loss weight for listwise loss.
pairwise_weight: Loss weight for pairwise loss.
filter_empty_passages: Filter out empty passages.
min_passage_length: Minimum passage length to keep.
max_passage_length_ratio: Max ratio between query and passage length.
gradient_checkpointing: Enable gradient checkpointing.
recompute_scores: Recompute scores instead of storing all.
use_dynamic_teacher: Enable dynamic teacher for RocketQAv2-style training.
teacher_update_steps: Steps between teacher updates.
teacher_momentum: Momentum for teacher updates.
"""
# Model specific
model_name_or_path: str = "BAAI/bge-reranker-base"
# Architecture settings
add_pooling_layer: bool = False # Cross-encoder doesn't use pooling
use_fp16: bool = True # Use mixed precision by default
# Training specific
train_group_size: int = 16 # Number of passages per query for listwise training
loss_type: str = "listwise_ce" # "listwise_ce", "pairwise_margin", "combined"
margin: float = 1.0 # For pairwise margin loss
temperature: float = 1.0 # For listwise CE loss
# Knowledge distillation
use_knowledge_distillation: bool = False
teacher_model_name_or_path: Optional[str] = None
distillation_temperature: float = 3.0
distillation_alpha: float = 0.7
# Hard negative strategies
use_hard_negatives: bool = True
hard_negative_ratio: float = 0.8 # Ratio of hard to random negatives
use_denoising: bool = False # Denoising strategy from paper
noise_probability: float = 0.1
# Listwise training specific
max_pairs_per_device: int = 128 # Maximum query-passage pairs per device
accumulate_gradients_from_all_devices: bool = True
# Cross-encoder specific settings
max_seq_length: int = 512 # Maximum combined query+passage length
query_max_length: int = 64 # For separate encoding (if needed)
passage_max_length: int = 512 # For separate encoding (if needed)
# Position bias mitigation
shuffle_passages: bool = True # Shuffle passage order during training
position_bias_cutoff: int = 10 # Consider position bias for top-k passages
# Advanced training strategies
use_gradient_caching: bool = True # Cache gradients for efficient training
gradient_cache_batch_size: int = 32
# Evaluation specific
eval_group_size: int = 100 # Number of candidates during evaluation
metric_for_best_model: str = "eval_mrr@10"
eval_normalize_scores: bool = True
# Loss combination weights (for combined loss)
listwise_weight: float = 0.7
pairwise_weight: float = 0.3
# Data filtering
filter_empty_passages: bool = True
min_passage_length: int = 10
max_passage_length_ratio: float = 50.0 # Max ratio between query and passage length
# Memory optimization for large group sizes
gradient_checkpointing: bool = True
recompute_scores: bool = False # Recompute instead of storing all scores
# Dynamic teacher (for RocketQAv2-style training)
use_dynamic_teacher: bool = False
teacher_update_steps: int = 1000
teacher_momentum: float = 0.999
def __post_init__(self):
"""Load model-specific parameters from [reranker] section in config.toml"""
super().__post_init__()
try:
from utils.config_loader import get_config
except ImportError:
from bge_finetune.utils.config_loader import get_config # type: ignore[import]
config = get_config()
self._apply_reranker_config(config)
def _apply_reranker_config(self, config):
"""
Apply BGE-Reranker specific configuration from [reranker] section in config.toml.
"""
reranker_config = config.get_section('reranker')
if not reranker_config:
logger.warning("No [reranker] section found in config.toml")
return
        # Apply every recognized key from the [reranker] section. All of these
        # fields are plain scalars (bools, numbers, strings), so values from
        # config.toml can be assigned as-is; keys that are absent keep their defaults.
        recognized_keys = (
            'add_pooling_layer', 'use_fp16', 'train_group_size', 'loss_type',
            'margin', 'temperature', 'use_knowledge_distillation',
            'teacher_model_name_or_path', 'distillation_temperature',
            'distillation_alpha', 'use_hard_negatives', 'hard_negative_ratio',
            'use_denoising', 'noise_probability', 'max_pairs_per_device',
            'accumulate_gradients_from_all_devices', 'max_seq_length',
            'query_max_length', 'passage_max_length', 'shuffle_passages',
            'position_bias_cutoff', 'use_gradient_caching',
            'gradient_cache_batch_size', 'eval_group_size',
            'metric_for_best_model', 'eval_normalize_scores', 'listwise_weight',
            'pairwise_weight', 'filter_empty_passages', 'min_passage_length',
            'max_passage_length_ratio', 'gradient_checkpointing',
            'recompute_scores', 'use_dynamic_teacher', 'teacher_update_steps',
            'teacher_momentum',
        )
        for key in recognized_keys:
            if key in reranker_config:
                setattr(self, key, reranker_config[key])
# Log parameter overrides
logger.info("Loaded BGE-Reranker config from [reranker] section in config.toml")
# Check for required parameters
if self.loss_type not in ["listwise_ce", "pairwise_margin", "combined", "cross_entropy", "binary_cross_entropy"]:
raise ValueError("loss_type must be one of ['listwise_ce', 'pairwise_margin', 'combined', 'cross_entropy', 'binary_cross_entropy']")
if self.train_group_size is None or self.train_group_size < 2:
raise ValueError("train_group_size must be at least 2 and not None for reranking")
if self.use_knowledge_distillation and not self.teacher_model_name_or_path:
raise ValueError("teacher_model_name_or_path required when use_knowledge_distillation=True")
def validate(self):
"""Validate configuration settings"""
        # Keep this list in sync with the loss types accepted in __post_init__
        valid_loss_types = ["listwise_ce", "pairwise_margin", "combined", "cross_entropy", "binary_cross_entropy"]
if self.loss_type not in valid_loss_types:
raise ValueError(f"loss_type must be one of {valid_loss_types}")
if self.train_group_size < 2:
raise ValueError("train_group_size must be at least 2 for reranking")
if self.use_knowledge_distillation and not self.teacher_model_name_or_path:
raise ValueError("teacher_model_name_or_path required when use_knowledge_distillation=True")
if self.hard_negative_ratio < 0 or self.hard_negative_ratio > 1:
raise ValueError("hard_negative_ratio must be between 0 and 1")
if self.max_seq_length > 512 and "base" in self.model_name_or_path:
logger.warning("max_seq_length > 512 may not be optimal for base models")
if self.gradient_cache_batch_size > self.per_device_train_batch_size:
logger.warning("gradient_cache_batch_size > per_device_train_batch_size may not improve efficiency")
return True
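# For reference, a minimal sketch of the [reranker] section that _apply_reranker_config
# consumes via config.get_section('reranker'). Keys mirror the dataclass fields above;
# the values shown here are purely illustrative, and any key left out keeps its default.
#
# [reranker]
# loss_type = "listwise_ce"
# train_group_size = 16
# temperature = 1.0
# use_knowledge_distillation = false
# metric_for_best_model = "eval_mrr@10"
# max_seq_length = 512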

85
data/__init__.py Normal file
View File

@ -0,0 +1,85 @@
"""
Integrated data processing and augmentation module for BGE fine-tuning
Now with unified dataset pipeline supporting both embedding and reranker formats
"""
def __getattr__(name):
"""Lazy loading to avoid import errors for missing dependencies"""
# Core dataset classes - from integrated dataset.py
if name == 'BGEDataset':
from .dataset import BGEDataset
return BGEDataset
elif name == 'BGEM3Dataset':
from .dataset import BGEM3Dataset
return BGEM3Dataset
elif name == 'BGERerankerDataset':
from .dataset import BGERerankerDataset
return BGERerankerDataset
elif name == 'JointDataset':
from .dataset import JointDataset
return JointDataset
# Factory functions for dataset creation
elif name == 'create_embedding_dataset':
from .dataset import create_embedding_dataset
return create_embedding_dataset
elif name == 'create_reranker_dataset':
from .dataset import create_reranker_dataset
return create_reranker_dataset
# Data migration utilities
elif name == 'migrate_nested_to_flat':
from .dataset import migrate_nested_to_flat
return migrate_nested_to_flat
# Collate functions
elif name == 'collate_embedding_batch':
from .dataset import collate_embedding_batch
return collate_embedding_batch
elif name == 'collate_reranker_batch':
from .dataset import collate_reranker_batch
return collate_reranker_batch
# Data preprocessing and augmentation
elif name == 'DataPreprocessor':
from .preprocessing import DataPreprocessor
return DataPreprocessor
elif name == 'DataAugmenter':
from .preprocessing import DataAugmenter
return DataAugmenter
elif name == 'QueryAugmenter':
from .augmentation import QueryAugmenter
return QueryAugmenter
elif name == 'PassageAugmenter':
from .augmentation import PassageAugmenter
return PassageAugmenter
elif name == 'HardNegativeAugmenter':
from .augmentation import HardNegativeAugmenter
return HardNegativeAugmenter
elif name == 'CrossLingualAugmenter':
from .augmentation import CrossLingualAugmenter
return CrossLingualAugmenter
# Hard negative mining
elif name == 'ANCEHardNegativeMiner':
from .hard_negative_mining import ANCEHardNegativeMiner
return ANCEHardNegativeMiner
elif name == 'HardNegativeDataAugmenter':
from .hard_negative_mining import HardNegativeDataAugmenter
return HardNegativeDataAugmenter
# Backward compatibility - remove after migration
elif name == 'M3Dataset':
from .dataset import BGEM3Dataset
return BGEM3Dataset
elif name == 'RerankerDataset':
from .dataset import BGERerankerDataset
return BGERerankerDataset
elif name == 'NestedDataset':
# Deprecated - use BGEDataset with legacy_support=True
from .dataset import BGEDataset
return BGEDataset
else:
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
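# Usage sketch (illustrative): the module-level __getattr__ above (PEP 562) defers the
# heavy imports until a symbol is first touched, e.g.
#
#   from data import BGEDataset          # triggers `from .dataset import BGEDataset`
#   from data import DataPreprocessor    # triggers `from .preprocessing import DataPreprocessor`
#
# so importing the package itself stays cheap even when optional dependencies of some
# submodules are missing.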

1307
data/augmentation.py Normal file

File diff suppressed because it is too large Load Diff

857
data/dataset.py Normal file
View File

@ -0,0 +1,857 @@
"""
Unified BGE dataset pipeline with comprehensive format support
- Triplets: Query-triplets format (query, pos, neg) for embedding training
- Pairs: Flat pairs format (query, passage, label) for reranker training
- Candidates: Unified format (query, candidates) with threshold-based label assignment
- Legacy: Supports nested formats with automatic conversion to flat pairs
- Enhanced: Score-weighted sampling, hard negative mining, multiple positives
"""
import json
import random
from typing import Dict, List, Optional, Union, Tuple
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
import logging
from tqdm import tqdm
import numpy as np
logger = logging.getLogger(__name__)
class BGEDataset(Dataset):
"""
Unified BGE dataset with automatic format detection and enhanced features
Supports: triplets, pairs, candidates, and legacy nested formats
"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
query_max_length: int = 64,
passage_max_length: int = 512,
is_train: bool = True,
cache_dir: Optional[str] = None,
legacy_support: bool = True,
format_type: str = "auto", # "triplets", "pairs", "candidates", "legacy_nested", "auto"
candidates_threshold: float = 0.5, # Threshold for candidates format label assignment
cache_in_memory: bool = False, # Cache tokenized data in memory for small datasets
max_cache_size: int = 100000 # Maximum number of samples to cache in memory
):
self.tokenizer = tokenizer
self.query_max_length = query_max_length
self.passage_max_length = passage_max_length
self.is_train = is_train
self.legacy_support = legacy_support
self.format_type = format_type
self.candidates_threshold = candidates_threshold
self.cache_in_memory = cache_in_memory
self.max_cache_size = max_cache_size
self._tokenized_cache = {}
logger.info(f"Loading data from {data_path}")
self.data, self.detected_format = self._load_and_detect_format(data_path)
logger.info(f"Loaded {len(self.data)} examples in {self.detected_format} format")
# Pre-tokenize data if caching is enabled and dataset is small enough
if self.cache_in_memory and len(self.data) <= self.max_cache_size:
logger.info(f"Pre-tokenizing {len(self.data)} samples for in-memory cache...")
self._pretokenize_all()
logger.info("Pre-tokenization complete")
def _load_and_detect_format(self, data_path: str) -> Tuple[List[Dict], str]:
"""Load data and automatically detect format with improved detection logic"""
data = []
format_type = "unknown"
total_lines = 0
with open(data_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(tqdm(f, desc="Loading and detecting format"), 1):
if line.strip():
total_lines += 1
try:
item = json.loads(line)
# Detect format on first valid item
if format_type == "unknown":
format_type = self._detect_schema_type(item)
logger.info(f"Detected schema: {format_type}")
# Validate and process item
if self._validate_item(item, format_type):
processed_items = self._process_item(item, format_type)
# Process_item may return multiple items (for candidates format)
if isinstance(processed_items, list):
data.extend(processed_items)
else:
data.append(processed_items)
else:
logger.warning(f"Invalid item at line {line_num}: {item}")
except json.JSONDecodeError as e:
logger.warning(f"Invalid JSON at line {line_num}: {e}")
continue
if not data:
raise ValueError(f"No valid data found in {data_path}")
# Validation check
if total_lines > 10 and (len(data) < 10 or len(data) < 0.01 * total_lines):
logger.warning(f"Few valid data lines loaded from {data_path}: {len(data)} out of {total_lines}")
return data, format_type
def _detect_schema_type(self, item: Dict) -> str:
"""
Enhanced format detection based on key presence and structure:
- triplets: has 'pos' and 'neg' keys (for embedding training)
- pairs: has 'passage' and 'label' keys (for reranker training)
- candidates: has 'candidates' key with list of candidate objects
- legacy_nested: has 'pos' but no 'label' (backward compatibility)
"""
try:
# Use explicitly set format if provided
if self.format_type != "auto":
return self.format_type
# Ensure item is a dictionary
if not isinstance(item, dict):
raise ValueError(f"Expected dict but got {type(item)}")
# Extract available keys for debugging
available_keys = set(item.keys())
# Check for candidates format first (most specific)
if 'candidates' in item:
# Validate candidates structure
candidates = item.get('candidates', [])
if isinstance(candidates, list) and len(candidates) > 0:
# Check if first candidate has expected structure
first_candidate = candidates[0]
if isinstance(first_candidate, dict) and 'text' in first_candidate:
return "candidates"
# Also support simple string list as candidates
elif isinstance(first_candidate, str):
logger.info("Detected candidates format with string list - will convert to dict format")
return "candidates"
# Check for pairs format (reranker)
if 'passage' in item and 'label' in item:
# Additional check for query field
if 'query' in item:
return "pairs"
else:
logger.warning(f"Found 'passage' and 'label' but missing 'query'. Available keys: {available_keys}")
# Check for triplets format (embedding)
if 'pos' in item:
# Standard triplets format has both pos and neg
if 'neg' in item:
return "triplets"
# Might be triplets with only positives
elif 'query' in item and not ('label' in item or 'passage' in item):
logger.info("Detected triplets format with only positive samples")
return "triplets"
# Legacy nested format
elif self.legacy_support and 'query' in item:
logger.warning("Detected legacy nested format. Converting to flat pairs for reranker training.")
return "legacy_nested"
# Try to provide helpful error message
common_keys = available_keys & {'query', 'pos', 'neg', 'passage', 'label', 'candidates'}
error_msg = f"Unknown data format. Available keys: {available_keys}"
if common_keys:
error_msg += f", Common keys found: {common_keys}"
else:
error_msg += ". No common keys found. Expected 'query' with 'pos'/'neg' (triplets), 'passage'/'label' (pairs), or 'candidates' (candidates format)"
raise ValueError(error_msg)
except Exception as e:
logger.error(f"Format detection failed: {e}")
raise
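    # Illustrative JSONL records (hypothetical data, not taken from the shipped files)
    # for the three auto-detected schemas:
    #   triplets:   {"query": "q", "pos": ["relevant passage"], "neg": ["irrelevant passage"]}
    #   pairs:      {"query": "q", "passage": "some passage", "label": 1}
    #   candidates: {"query": "q", "candidates": [{"text": "some passage", "score": 0.9}]}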
def _validate_item(self, item: Dict, format_type: str) -> bool:
"""
Validate item based on detected format with comprehensive checks
Returns:
bool: True if item is valid, False otherwise
"""
try:
# Basic type check
if not isinstance(item, dict):
return False
# Format-specific validation
if format_type == "triplets":
# Check required fields
if not all(field in item for field in ['query', 'pos']):
return False
# Validate query
if not isinstance(item['query'], str) or not item['query'].strip():
return False
# Validate positives
if not isinstance(item.get('pos', []), list) or len(item['pos']) == 0:
return False
# Check each positive is a non-empty string
for pos in item['pos']:
if not isinstance(pos, str) or not pos.strip():
return False
# Validate negatives if present
if 'neg' in item:
if not isinstance(item['neg'], list):
return False
for neg in item['neg']:
if not isinstance(neg, str) or not neg.strip():
return False
return True
elif format_type == "pairs":
# Check required fields
if not all(field in item for field in ['query', 'passage', 'label']):
return False
# Validate query and passage
if not isinstance(item['query'], str) or not item['query'].strip():
return False
if not isinstance(item['passage'], str) or not item['passage'].strip():
return False
# Validate label (should be numeric 0 or 1, or float between 0 and 1)
label = item['label']
if isinstance(label, (int, float)):
if not (0 <= label <= 1):
return False
else:
return False
return True
elif format_type == "candidates":
# Check required fields
if not ('query' in item and 'candidates' in item):
return False
# Validate query
if not isinstance(item['query'], str) or not item['query'].strip():
return False
# Validate candidates
if not isinstance(item['candidates'], list) or len(item['candidates']) == 0:
return False
# Check each candidate
for i, candidate in enumerate(item['candidates']):
if not isinstance(candidate, dict):
return False
if 'text' not in candidate:
return False
if not isinstance(candidate['text'], str) or not candidate['text'].strip():
return False
# Validate score if present
if 'score' in candidate:
score = candidate['score']
if not isinstance(score, (int, float)):
return False
return True
elif format_type == "legacy_nested":
# Basic validation for legacy format
if not ('query' in item and 'pos' in item):
return False
if not isinstance(item['query'], str) or not item['query'].strip():
return False
return True
else:
return False
except Exception as e:
logger.debug(f"Validation error: {e}")
return False
def _process_item(self, item: Dict, format_type: str) -> Union[Dict, List[Dict]]:
"""Process item based on format type - may return multiple items for candidates"""
if format_type == "triplets":
return self._process_triplets_item(item)
elif format_type == "pairs":
return self._process_pairs_item(item)
elif format_type == "candidates":
return self._process_candidates_item(item)
elif format_type == "legacy_nested":
return self._process_legacy_nested_item(item)
else:
raise ValueError(f"Unknown format type: {format_type}")
def _process_triplets_item(self, item: Dict) -> Dict:
"""Process triplets format item for embedding training"""
processed = {
'query': item['query'],
'pos': item.get('pos', []),
'neg': item.get('neg', []),
'pos_scores': item.get('pos_scores', [1.0] * len(item.get('pos', []))),
'neg_scores': item.get('neg_scores', [0.0] * len(item.get('neg', []))),
'prompt': item.get('prompt', ''),
'type': item.get('type', 'normal'),
'format': 'triplets'
}
return processed
def _process_pairs_item(self, item: Dict) -> Dict:
"""Process pairs format item for reranker training"""
processed = {
'query': item['query'],
'passage': item['passage'],
'label': item['label'],
'score': item.get('score', 1.0 if item['label'] > 0 else 0.0),
'qid': item.get('qid', ''),
'pid': item.get('pid', ''),
'format': 'pairs'
}
return processed
def _process_candidates_item(self, item: Dict) -> List[Dict]:
"""
Process candidates format item - explode into multiple pairs
Each candidate becomes a separate (query, passage, label) pair
Label is assigned based on score threshold
"""
query = item['query']
candidates = item['candidates']
qid = item.get('qid', '')
pairs = []
for i, candidate in enumerate(candidates):
if isinstance(candidate, dict):
text = candidate.get('text', '')
score = candidate.get('score', 0.0)
pid = candidate.get('pid', f'cand_{i}')
else:
# Handle simple string candidates
text = str(candidate)
score = 1.0 # Default to positive if no score
pid = f'cand_{i}'
# Assign label based on threshold
label = 1 if score >= self.candidates_threshold else 0
pair = {
'query': query,
'passage': text,
'label': label,
'score': score,
'qid': qid,
'pid': pid,
'format': 'pairs', # Convert to pairs format
'source': 'candidates' # Track original source
}
pairs.append(pair)
return pairs
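    # Worked example (hypothetical record, candidates_threshold = 0.5):
    #   {"query": "q", "candidates": [{"text": "a", "score": 0.9}, {"text": "b", "score": 0.2}]}
    # explodes into two flat pairs:
    #   {"query": "q", "passage": "a", "label": 1, "score": 0.9, "pid": "cand_0", ...}
    #   {"query": "q", "passage": "b", "label": 0, "score": 0.2, "pid": "cand_1", ...}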
def _process_legacy_nested_item(self, item: Dict) -> List[Dict]:
"""
Process legacy nested format - convert to flat pairs for reranker training
Each pos[i] becomes (query, passage, label=1)
Each neg[j] becomes (query, passage, label=0)
"""
query = item['query']
positives = item.get('pos', [])
negatives = item.get('neg', [])
pos_scores = item.get('pos_scores', [1.0] * len(positives))
neg_scores = item.get('neg_scores', [0.0] * len(negatives))
qid = item.get('qid', '')
pairs = []
# Convert positives
for i, passage in enumerate(positives):
score = pos_scores[i] if i < len(pos_scores) else 1.0
pair = {
'query': query,
'passage': passage,
'label': 1,
'score': score,
'qid': qid,
'pid': f'pos_{i}',
'format': 'pairs', # Convert to pairs format
'source': 'legacy_nested' # Track original source
}
pairs.append(pair)
# Convert negatives
for i, passage in enumerate(negatives):
score = neg_scores[i] if i < len(neg_scores) else 0.0
pair = {
'query': query,
'passage': passage,
'label': 0,
'score': score,
'qid': qid,
'pid': f'neg_{i}',
'format': 'pairs', # Convert to pairs format
'source': 'legacy_nested' # Track original source
}
pairs.append(pair)
return pairs
def _pretokenize_all(self):
"""Pre-tokenize all samples for in-memory caching"""
for idx in tqdm(range(len(self.data)), desc="Pre-tokenizing"):
# Tokenize based on format
item = self.data[idx]
if item['format'] == 'triplets':
# Pre-tokenize query and passages for triplets
query = item['query']
query_encoding = self.tokenizer(
query,
max_length=self.query_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
self._tokenized_cache[f"{idx}_query"] = {
'input_ids': query_encoding.input_ids.squeeze(0),
'attention_mask': query_encoding.attention_mask.squeeze(0)
}
elif item['format'] == 'pairs':
# Pre-tokenize query-passage pair
query = item['query']
passage = item['passage']
                # Encode as a sentence pair so the cached encoding matches the
                # uncached path in _get_pairs_item exactly
                encoding = self.tokenizer(
                    query,
                    passage,
                    max_length=self.passage_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
self._tokenized_cache[idx] = {
'input_ids': encoding.input_ids.squeeze(0),
'attention_mask': encoding.attention_mask.squeeze(0)
}
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Dict:
"""Get item based on detected format"""
item = self.data[idx]
if item['format'] == 'triplets':
return self._get_triplets_item(item, idx)
elif item['format'] == 'pairs':
return self._get_pairs_item(item, idx)
else:
raise ValueError(f"Unknown format: {item['format']}")
def _get_triplets_item(self, item: Dict, idx: int) -> Dict:
"""Get triplets training example for embedding model (contrastive learning)"""
query = item['query']
positives = item['pos']
negatives = item['neg']
pos_scores = item['pos_scores']
neg_scores = item['neg_scores']
prompt = item['prompt']
# Apply prompt if provided
if prompt and self.is_train:
query = f"{prompt}{query}"
# Check if we have cached tokenized query
if f"{idx}_query" in self._tokenized_cache:
query_encodings = self._tokenized_cache[f"{idx}_query"]
else:
# Tokenize query
query_encodings = self.tokenizer(
query,
max_length=self.query_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
query_encodings = {
'input_ids': query_encodings.input_ids.squeeze(0),
'attention_mask': query_encodings.attention_mask.squeeze(0)
}
# Enhanced sampling for contrastive learning
if self.is_train:
# Sample multiple positives for richer contrastive signals
            num_positives = min(2, len(positives))
            pos_indices = random.sample(range(len(positives)), num_positives) if positives else []
            sampled_positives = [positives[i] for i in pos_indices]
            sampled_pos_scores = [pos_scores[i] if i < len(pos_scores) else 1.0 for i in pos_indices]
# Hard negative mining with score-based sampling
num_negatives = min(6, len(negatives))
if negatives and neg_scores:
# Weight by scores for hard negative selection
neg_weights = np.array(neg_scores)
                neg_weights = neg_weights / neg_weights.sum() if neg_weights.sum() > 0 else np.full(len(neg_weights), 1.0 / len(neg_weights))
sampled_neg_indices = np.random.choice(
len(negatives), size=min(num_negatives, len(negatives)),
replace=False, p=neg_weights
)
sampled_negatives = [negatives[i] for i in sampled_neg_indices]
sampled_neg_scores = [neg_scores[i] for i in sampled_neg_indices]
else:
sampled_negatives = random.sample(negatives, num_negatives) if negatives else []
sampled_neg_scores = [0.0] * len(sampled_negatives)
passages = sampled_positives + sampled_negatives
labels = [1] * len(sampled_positives) + [0] * len(sampled_negatives)
            scores = sampled_pos_scores + sampled_neg_scores
else:
# Use all for evaluation
passages = positives + negatives
labels = [1] * len(positives) + [0] * len(negatives)
scores = pos_scores + neg_scores
# Tokenize passages
passage_encodings = []
for passage in passages:
encoding = self.tokenizer(
passage,
max_length=self.passage_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
passage_encodings.append({
'input_ids': torch.as_tensor(encoding['input_ids']).squeeze(0),
'attention_mask': torch.as_tensor(encoding['attention_mask']).squeeze(0)
})
# Pad to consistent size (max 8 passages for triplets)
max_passages = 8
while len(passage_encodings) < max_passages:
passage_encodings.append({
'input_ids': torch.zeros(self.passage_max_length, dtype=torch.long),
'attention_mask': torch.zeros(self.passage_max_length, dtype=torch.long)
})
labels.append(-1)
scores.append(0.0)
# Stack tensors
passage_input_ids = torch.stack([p['input_ids'] for p in passage_encodings])
passage_attention_mask = torch.stack([p['attention_mask'] for p in passage_encodings])
return {
'query_input_ids': query_encodings['input_ids'],
'query_attention_mask': query_encodings['attention_mask'],
'passage_input_ids': passage_input_ids,
'passage_attention_mask': passage_attention_mask,
'labels': torch.tensor(labels, dtype=torch.long),
'scores': torch.tensor(scores, dtype=torch.float),
'query_text': query,
'passage_texts': passages + [''] * (max_passages - len(passages)),
'item_idx': idx,
'format': 'triplets',
# Backward compatibility
'positives': positives,
'negatives': negatives
}
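    # Shape sketch for one triplets sample with the defaults used above
    # (query_max_length=64, passage_max_length=512, max_passages=8):
    #   query_input_ids / query_attention_mask     : [64]
    #   passage_input_ids / passage_attention_mask : [8, 512]
    #   labels : [8]   (padded slots carry label -1)
    #   scores : [8]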
def _get_pairs_item(self, item: Dict, idx: int) -> Dict:
"""Get pairs training example for reranker (cross-encoder training)"""
query = item['query']
passage = item['passage']
label = item['label']
score = item['score']
# Check if we have cached tokenized data
if idx in self._tokenized_cache:
encoding = self._tokenized_cache[idx]
return {
'input_ids': encoding['input_ids'],
'attention_mask': encoding['attention_mask'],
'labels': torch.tensor(label, dtype=torch.long),
'scores': torch.tensor(score, dtype=torch.float),
'query_text': query,
'passage_text': passage,
'qid': item.get('qid', ''),
'pid': item.get('pid', ''),
'item_idx': idx,
'format': 'pairs',
'source': item.get('source', 'original')
}
# Tokenize query-passage pair if not cached
encoding = self.tokenizer(
query,
passage,
max_length=self.passage_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'input_ids': torch.as_tensor(encoding['input_ids']).squeeze(0),
'attention_mask': torch.as_tensor(encoding['attention_mask']).squeeze(0),
'labels': torch.tensor(label, dtype=torch.long),
'scores': torch.tensor(score, dtype=torch.float),
'query_text': query,
'passage_text': passage,
'qid': item['qid'],
'pid': item['pid'],
'item_idx': idx,
'format': 'pairs',
'source': item.get('source', 'original')
}
# Specialized dataset classes for backward compatibility
class BGEM3Dataset(BGEDataset):
"""BGE-M3 specialized dataset (embedding model) with enhanced contrastive learning"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
query_max_length: int = 64,
passage_max_length: int = 512,
is_train: bool = True,
train_group_size: int = 8,
cache_dir: Optional[str] = None
):
# Force embedding format
super().__init__(
data_path=data_path,
tokenizer=tokenizer,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
is_train=is_train,
cache_dir=cache_dir,
format_type="triplets"
)
self.train_group_size = train_group_size
# Warn if wrong format detected
if self.detected_format not in ['triplets']:
logger.warning(f"Expected triplets format for BGEM3Dataset, got {self.detected_format}")
class BGERerankerDataset(BGEDataset):
"""BGE-Reranker specialized dataset (cross-encoder model) with flat pairs support"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
query_max_length: int = 64,
passage_max_length: int = 448,
is_train: bool = True,
train_group_size: int = 16,
cache_dir: Optional[str] = None,
legacy_support: bool = True
):
# Force reranker format (or allow legacy)
format_type = "auto" if legacy_support else "pairs"
super().__init__(
data_path=data_path,
tokenizer=tokenizer,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
is_train=is_train,
cache_dir=cache_dir,
legacy_support=legacy_support,
format_type=format_type
)
self.train_group_size = train_group_size
# Warn if wrong format detected
if self.detected_format not in ['pairs', 'legacy_nested']:
logger.warning(f"Expected pairs format for BGERerankerDataset, got {self.detected_format}")
class JointDataset(Dataset):
"""Dataset for joint training of retriever and reranker"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
retriever_config: Dict,
reranker_config: Dict,
is_train: bool = True,
cache_dir: Optional[str] = None
):
self.tokenizer = tokenizer
self.is_train = is_train
# Create both datasets
self.retriever_dataset = BGEM3Dataset(
data_path=data_path,
tokenizer=tokenizer,
is_train=is_train,
cache_dir=cache_dir,
**retriever_config
)
self.reranker_dataset = BGERerankerDataset(
data_path=data_path,
tokenizer=tokenizer,
is_train=is_train,
cache_dir=cache_dir,
**reranker_config
)
def __len__(self) -> int:
return len(self.retriever_dataset)
def __getitem__(self, idx: int) -> Dict:
"""Get example for both retriever and reranker"""
retriever_item = self.retriever_dataset[idx]
reranker_item = self.reranker_dataset[idx]
return {
'retriever': retriever_item,
'reranker': reranker_item,
'item_idx': idx
}
# Factory functions for easy dataset creation
def create_embedding_dataset(
data_path: str,
tokenizer: PreTrainedTokenizer,
query_max_length: int = 64,
passage_max_length: int = 512,
is_train: bool = True
) -> BGEDataset:
"""Create dataset specifically for embedding model training"""
dataset = BGEDataset(
data_path=data_path,
tokenizer=tokenizer,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
is_train=is_train,
format_type="triplets"
)
if dataset.detected_format not in ['triplets']:
logger.warning(f"Expected triplets format, got {dataset.detected_format}")
return dataset
def create_reranker_dataset(
data_path: str,
tokenizer: PreTrainedTokenizer,
query_max_length: int = 64,
passage_max_length: int = 448,
is_train: bool = True,
legacy_support: bool = True
) -> BGEDataset:
"""Create dataset specifically for reranker model training"""
dataset = BGEDataset(
data_path=data_path,
tokenizer=tokenizer,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
is_train=is_train,
legacy_support=legacy_support,
format_type="auto" if legacy_support else "pairs"
)
if dataset.detected_format not in ['pairs', 'legacy_nested']:
logger.warning(f"Expected pairs format, got {dataset.detected_format}")
return dataset
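# Usage sketch (model name and data path are placeholders):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-base")
#   train_set = create_reranker_dataset("data/train_pairs.jsonl", tokenizer, is_train=True)
#   print(len(train_set), train_set.detected_format)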
def migrate_nested_to_flat(input_file: str, output_file: str) -> str:
"""
Migrate legacy nested reranker JSONL to flat pairs format
Explodes each (pos[i], label=1) and (neg[j], label=0) into separate lines
"""
logger.info(f"Migrating nested reranker format: {input_file} -> {output_file}")
flat_pairs = []
with open(input_file, 'r', encoding='utf-8') as f:
for line_no, line in enumerate(tqdm(f, desc="Converting to flat pairs"), 1):
if line.strip():
try:
item = json.loads(line)
query = item['query']
positives = item.get('pos', [])
negatives = item.get('neg', [])
pos_scores = item.get('pos_scores', [1.0] * len(positives))
neg_scores = item.get('neg_scores', [0.0] * len(negatives))
# Create flat pairs for positives
for i, passage in enumerate(positives):
flat_pairs.append({
'query': query,
'passage': passage,
'label': 1,
'score': pos_scores[i] if i < len(pos_scores) else 1.0,
'qid': f"Q{line_no}",
'pid': f"P{line_no}_{i}"
})
# Create flat pairs for negatives
for i, passage in enumerate(negatives):
flat_pairs.append({
'query': query,
'passage': passage,
'label': 0,
'score': neg_scores[i] if i < len(neg_scores) else 0.0,
'qid': f"Q{line_no}",
'pid': f"N{line_no}_{i}"
})
except json.JSONDecodeError as e:
logger.warning(f"Skipping invalid JSON on line {line_no}: {e}")
continue
# Save flat pairs
with open(output_file, 'w', encoding='utf-8') as f:
for pair in flat_pairs:
f.write(json.dumps(pair, ensure_ascii=False) + '\n')
logger.info(f"✅ Migrated {len(flat_pairs)} flat pairs to {output_file}")
return output_file
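# Example invocation (file names are placeholders):
#   migrate_nested_to_flat("data/reranker_nested.jsonl", "data/reranker_pairs.jsonl")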
# Collate functions for different formats
def collate_embedding_batch(batch: List[Dict]) -> Dict:
"""Collate function for embedding model batches"""
return {
'query_input_ids': torch.stack([item['query_input_ids'] for item in batch]),
'query_attention_mask': torch.stack([item['query_attention_mask'] for item in batch]),
'passage_input_ids': torch.stack([item['passage_input_ids'] for item in batch]),
'passage_attention_mask': torch.stack([item['passage_attention_mask'] for item in batch]),
'labels': torch.stack([item['labels'] for item in batch]),
'scores': torch.stack([item['scores'] for item in batch]),
'query_text': [item['query_text'] for item in batch],
'passage_texts': [item['passage_texts'] for item in batch],
'item_idx': [item['item_idx'] for item in batch]
}
def collate_reranker_batch(batch: List[Dict]) -> Dict:
"""Collate function for reranker model batches"""
if batch[0]['format'] == 'pairs':
# Flat pairs format
return {
'input_ids': torch.stack([item['input_ids'] for item in batch]),
'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
'labels': torch.stack([item['labels'] for item in batch]),
'scores': torch.stack([item['scores'] for item in batch]),
'query_text': [item['query_text'] for item in batch],
'passage_text': [item['passage_text'] for item in batch],
'qid': [item['qid'] for item in batch],
'pid': [item['pid'] for item in batch],
'item_idx': [item['item_idx'] for item in batch]
}
else:
# Legacy nested format
return {
'input_ids': torch.stack([item['input_ids'] for item in batch]),
'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
'labels': torch.stack([item['labels'] for item in batch]),
'query_text': [item['query_text'] for item in batch],
'passage_texts': [item['passage_texts'] for item in batch],
'item_idx': [item['item_idx'] for item in batch]
}
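# Sketch: wiring the collate functions into PyTorch DataLoaders; `embedding_dataset`
# and `reranker_dataset` are assumed to come from the factory functions above, and the
# batch sizes are illustrative.
#
#   from torch.utils.data import DataLoader
#   embed_loader  = DataLoader(embedding_dataset, batch_size=8, shuffle=True,
#                              collate_fn=collate_embedding_batch)
#   rerank_loader = DataLoader(reranker_dataset, batch_size=32, shuffle=True,
#                              collate_fn=collate_reranker_batch)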

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,400 @@
{"query": "花呗账单金额和还款金额不一致", "pos": ["我的花呗账单是***,还款怎么是***", "我的花呗,月结出来说让我还***元,我自己算了一下详细名单我应该还***元"], "neg": ["蚂蚁借呗还款时,为什么会自动扣款", "为什么我现在花呗额度为负", "花呗什么情况不能使用", "申请借呗界面审核多久", "借呗分期利息怎样算", "蚂蚁借呗不按期还款什么处罚"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.12, 0.135, 0.076, 0.053, 0.062], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "支付宝找不到花呗入口", "pos": ["支付宝系统点我的里面没有花呗这一项", "我下载支付宝怎么没有花呗的"], "neg": ["蚂蚁借呗必须使用本人名下储蓄卡才能借款吗", "花呗,已经还款,怎么还有显示没还款", "我的花呗刷脸为什么不成功", "几个小时借呗能到账", "今天用了花呗是下个月才还款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.107, 0.071, 0.061, 0.163, 0.069], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗支持的商品", "pos": ["花呗标示的商品", "花呗识别的商品"], "neg": ["蚂蚁借呗这么没有了", "借呗还款单笔支付宝限额", "花呗自动扣款是在哪里扣", "花呗授权怎么设置", "有蚂蚁借呗怎么没额度", "蚂蚁借呗只能选择***个月的期限"], "pos_scores": [1.0, 0.95], "neg_scores": [0.103, 0.087, 0.198, 0.102, 0.2, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗超额消费或负额度的后果", "pos": ["花呗消费超过额度有什么影响吗", "花呗额度成负数有啥影响吗"], "neg": ["为什么我的蚂蚁花呗无法为游戏充值了", "在么花呗分期后可以一次性还么", "花呗可以通过哪些途径消费", "花呗分期还款利息如何算", "使用花呗分期后退货", "花呗和借呗有什么事区别", "花呗是否有手续费和利息", "蚂蚁借呗可以还***期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.112, 0.058, 0.09, 0.165, 0.06, 0.096, 0.054], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还清后仍显示待还款", "pos": ["还款还清了,为什么花呗账单显示还要还款", "花呗全额还清怎么显示没有还款"], "neg": ["我的借呗关闭了是什么情况", "花呗交易记录被删怎么退款", "小蓝单车的押金用的花呗", "花呗 还好了", "花呗还款逾期***天要利息不", "花呗退款后还款额未变"], "pos_scores": [1.0, 0.95], "neg_scores": [0.076, 0.079, 0.189, 0.087, 0.097, 0.148], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗评估周期", "pos": ["借呗一般是多久评估", "大概多久评估一次借呗条件"], "neg": ["如何取消借呗自动还款", "花呗可以在银行卡扣款不", "芝麻信用是不是借呗信用额度", "花呗如果逾期几天可以吗", "之前的支付宝花呗 我已经关闭了", "你能帮查下我的花呗吗?点不进去的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.076, 0.134, 0.188, 0.095, 0.112], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗无法使用原因", "pos": ["么,这个是不是我现在不能不能用花呗", "我的花呗不能用了是吧"], "neg": ["借呗会影响证信吗", "借呗分期还款可以在第一个月还清吗", "花呗预留号码", "花呗被冻结是否能帮忙解开", "怎么改变花呗扣款日期", "我的花呗信用卡收钱码跟我支付宝收钱码一样吗", "我用花呗东西的话,忘记还了,他会自动从余额宝里面扣吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.099, 0.17, 0.133, 0.123, 0.081, 0.138, 0.073], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗客服电话", "pos": ["把你们花呗的电话给我", "给花呗打电话"], "neg": ["花呗是使用多少还多少还是还额度", "花呗的账单我已还为什么上边显示还款的金额没有变", "借呗提前还款可以提升信用分吗", "花呗分期退款为什么少钱", "余额宝有钱为什么蚂蚁借呗到还款期不扣钱", "花呗算是借钱,这个月借下个月还吗", "花呗是商家开通还是卖家", "商家花呗分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.184, 0.08, 0.19, 0.071, 0.14, 0.095, 0.124, 0.188], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "不使用花呗是否需要还款", "pos": ["要是不用花呗 还需要还吗", "不用花呗里面的钱 是不是就可以不还款了"], "neg": ["小号不开花呗的情况下,客人能不能用花呗支付", "花呗额度*** 怎么在淘宝里使用", "我没取消授权,我的借贝怎么不见了", "花呗已还款,退款又打回花呗", "花呗什么时候邀请开通", "网商银行还借呗", "每次使用花呗后有一个抽奖的页面在哪里"], "pos_scores": [1.0, 0.95], "neg_scores": [0.183, 0.113, 0.084, 0.139, 0.089, 0.171, 0.131], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "交易不支持花呗付款原因", "pos": ["当前交易不支持花呗付款怎么回事", "该付款不支持花呗是什么意思"], "neg": ["花呗怎么删除账单", "花呗有逾期记录了,还能使用吗", "花呗***号还款逾期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.075, 0.143, 0.143], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期开始时间", "pos": ["花呗分期是从什么时候开始", "花呗分期的钱什么时候扣除"], "neg": ["开通花呗不用有什么费用", "花呗上月只还最低,本月提前还款,利息怎么算", "***w额度免费提现是跟蚂蚁借呗一样吗", "花呗收款,别人怎么付不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.063, 0.085, 0.183, 0.192], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗开通额度条件", "pos": ["借呗信用到多少可以借款", "达到多少额度可以申请借呗"], "neg": ["我怎么能暂时关闭花呗", "花呗分期付款期间还可以涨额度吗", "怎么关 花呗", "花呗超出部分怎么还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.192, 0.111, 0.093, 0.168], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗借款未到账", "pos": ["昨天晚上借呗借了***,怎么还没到账", "申请的蚂蚁借呗怎么还没到账了"], "neg": ["我现在不知道怎么用上花呗", "为何开通了花呗,支付宝无显示", "花呗分期提前还完后还可分期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.082, 0.104, 0.112], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "网商贷转借呗方法", "pos": ["网商贷换回借呗贷……晕", "网商贷要怎么换借呗"], "neg": ["本款衣服有标识花呗支付,为什么不能", "花呗欠款逾期,我现在花呗被关闭了,要怎么还款", "花呗***分***期每月给多少钱", "第一次认证花呗没有成功 要多久又才可以认证", "为什么说我 花呗操作失误", "我想知道花呗临时额度多少", "借呗必须还了债才能再借吗", "我的花呗已经全部还款了,什么时候可以恢复"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.173, 0.06, 0.163, 0.056, 0.141, 0.094, 0.163], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款后无法使用", "pos": ["账单都还清楚了!怎么花呗不能用了", "为什么我的花呗还款后还不能用"], "neg": ["注意主动申请蚂蚁借呗", "我打算在淘宝上买,能用花呗吗", "需求人工客服", "花呗能不能转账到另一个账号的", "花呗充值到支付宝账号", "花呗不支持充游戏吗", "花呗上的临时额度是什么", "为什么花呗里面提醒我打开允许读取联系人我打开了允许读取联系人以后还不行"], "pos_scores": [1.0, 0.95], "neg_scores": [0.093, 0.176, 0.063, 0.107, 0.067, 0.071, 0.125, 0.176], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗消费记录查询", "pos": ["我能查看用花呗买什么东西吗", "花呗都能购买那些东西"], "neg": ["我的有花呗额度挺多", "蚂蚁借呗怎么关闭借款功能了", "开通了花呗信用卡怎么还不能收款", "借呗日利率怎么制定的", "我这个花呗怎么退款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.059, 0.105, 0.096, 0.084, 0.193], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期提前还款规则", "pos": ["花呗分期提前还完不可以吗", "花呗分期能提前全部还清"], "neg": ["借呗还款日期可以根改吗", "子账户花呗怎么取消", "如何能够快速的使用借呗", "我花呗可以给我开通了吗", "为什么二维码收款用花呗的才几百元", "为什么我的花呗还了还有账单", "蚂蚁借呗当前操作可在风险"], "pos_scores": [1.0, 0.95], "neg_scores": [0.074, 0.059, 0.196, 0.143, 0.176, 0.059, 0.17], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗天猫券领取位置", "pos": ["花呗抽天猫券放哪了", "花呗抽了***元天猫抵用券到哪去了"], "neg": ["花呗的消费在哪里能查得到", "花呗额度怎么涨的", "邮政储蓄卡不能使用借呗", "花呗分期付款有利息吗", "找不到蚂蚁借呗还款记录"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.085, 0.115, 0.077, 0.188], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗自动扣款失败", "pos": ["借呗借呗,今天还没扣款", "借呗还款金额未扣"], "neg": ["蚂蚁花呗额度查询", "我的花呗不能用是怎么回事", "为什么我的蚂蚁借呗分期还款,现在换不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.087, 0.099, 0.105], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗支付产生蚂蚁森林能量", "pos": ["用蚂蚁花呗缴费产生能量球吗", "使用花呗能收集蚂蚁深林能量不"], "neg": ["我这花呗为什么冻结了", "主申请蚂蚁借呗", "如果退花呗的压金", "我蚂蚁借呗什么时候还钱", "不是淘宝绑定的支付宝可以用花呗么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.083, 0.145, 0.185, 0.122, 0.15], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗绑定手机号不一致", "pos": ["我点花呗,怎么显示之前绑定手机号", "花呗手机号和绑定手机号不一致"], "neg": ["商家收款,花呗有手续费吗", "花呗提前结清了,但是还有未还进去的,怎么把他还进去", "我想知道怎么还花呗的款", "借呗还款怎么还没扣", "花呗额度为负值还能继续使用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.156, 0.14, 0.158, 0.183, 0.187], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗充值话费入口", "pos": ["花呗充话费在哪个界面", "在那里用花呗充值话费"], "neg": ["关闭花呗多久能在开通", "我的押金是花呗教的,退出来之后为什么还要还花呗的***", "买东西要用花呗分期付款,为什么提示额度不够用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.176, 0.131, 0.095], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "加油卡能否用花呗支付", "pos": ["油卡充值可以使用花呗吗", "加油卡花呗可以不"], "neg": ["如何看我的蚂蚁借呗账单", "我想查花呗怎么查", "用蚂蚁借呗借了钱怎么还", "蚂蚁借呗我想推迟几天还款", "我的花呗现在自己用不了了", "借呗还款自动扣还是", "花呗可以微信还款不可以"], "pos_scores": [1.0, 0.95], "neg_scores": [0.101, 0.106, 0.165, 0.057, 0.178, 0.132, 0.064], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "无法开通花呗解决方法", "pos": ["说暂时无法开通花呗", "我开通不了花呗"], "neg": ["蚂蚁花呗体验卡卷在哪", "该商户不支持花呗支付 怎么办", "花呗最多分期几次", "花呗账单钱对不上", "和花呗付款后蚂蚁森林怎么不产生能量球"], "pos_scores": [1.0, 0.95], "neg_scores": [0.145, 0.137, 0.161, 0.171, 0.158], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭花呗收款服务", "pos": ["花呗收钱钱码关闭", "如何关闭信用卡和花呗收款服务"], "neg": ["我的花呗什么时候才能再开通", "我是用花呗交的话费,是不是要等自己还了才能退", "我要花呗退款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.095, 0.199], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗使用教程", "pos": ["花呗是怎么一回事吗", "花呗怎么玩吗"], "neg": ["蚂蚁借呗可以借支吗", "花呗额度支付不够", "我那花呗还了,什么又显示还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.2, 0.159, 0.106], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗是否产生蚂蚁森林能量", "pos": ["用花呗 蚂蚁森林能产生能量吗", "花呗会产生能量吗"], "neg": ["不知道花呗有什么好处", "我的蚂蚁借呗怎么没有了", "开通了花呗不使用要扣钱吗", "过了***点怎么借呗还款", "花呗总共还有多少负债"], "pos_scores": [1.0, 0.95], "neg_scores": [0.144, 0.095, 0.185, 0.068, 0.085], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗自动还款说明", "pos": ["自动还款 花呗是什么", "这花呗自动还款怎么回事"], "neg": ["为什么暂时开通不了花呗", "两个账号只有一个能使用花呗", "借呗可以就是后面一点还吗拖点时间", "蚂蚁借呗提额要不要输密码"], "pos_scores": [1.0, 0.95], "neg_scores": [0.18, 0.086, 0.074, 0.122], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通余额宝代扣花呗", "pos": ["开通余额宝代扣功能还款蚂蚁花呗", "蚂蚁花呗余额宝代扣协议怎么开通"], "neg": ["为什么我的借呗不能提高额度", "芝麻分要多少才可以用花呗", "花呗的账单周期是每月几号到几号"], "pos_scores": [1.0, 0.95], "neg_scores": [0.119, 0.055, 0.18], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗误发欠款通知", "pos": ["我的花呗怎么会欠款泥", "我都没欠花呗钱,怎么老是发信息来"], "neg": ["一个营业执照可以开通几个支付宝信用卡花呗收款", "我花呗人脸识别一直失败", "花呗几号出帐单的", "我上次申请借呗的时间", "已经还款的花呗账单明细怎么查"], "pos_scores": [1.0, 0.95], "neg_scores": [0.096, 0.126, 0.09, 0.094, 0.153], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通花呗信用分要求", "pos": ["开通蚂蚁花呗需要多少分", "花呗要多少信用额度才可以开通"], "neg": ["花呗退回来的钱去哪里了", "我想提升借呗都可以吗", "花呗立减十元", "为什么开通花呗要常"], "pos_scores": [1.0, 0.95], "neg_scores": [0.068, 0.177, 0.164, 0.124], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收货时间与还款关系", "pos": ["我用蚂蚁花呗确认收货的时间是不一样的,还款期限也是一样的嘛", "确认收货时间与花呗还款时间有关吗"], "neg": ["为什么花呗不能使用共享单车了", "为什么借呗提示网商贷", "我昨天晚上给151******11号交了***元花费怎么还没到账", "我的花呗双***的时候临时额度是***现在为什么就没有了", "我的收款码怎么用不了花呗", "花呗能从银行卡直接扣除还款吗", "我的花呗额度用完了 怎么还能用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.172, 0.13, 0.159, 0.052, 0.168, 0.075, 0.056], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "提前还清花呗分期", "pos": ["怎么样提前还花呗的分期账单", "怎么样提前还清花呗分期"], "neg": ["借呗交易可以删除吗", "的是蚂蚁借呗为什么关了,我之前还是开着的", "花呗还完款了,怎么还让还款", "之前的支付宝不用了 花呗怎么办", "花呗没按期还", "在那里提升花呗临时额度", "用花呗买的东西。现在退货了。怎么退款的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.193, 0.132, 0.122, 0.1, 0.145, 0.196, 0.148], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗最低还款影响信用吗", "pos": ["花呗最低还款会不会扣信誉", "花呗最低还款影响信用吗"], "neg": ["滴滴打车能不能使支付宝花呗了", "我的花呗是否还请清了", "蚂蚁借呗可以提现到支付宝吗,怎样提现", "花呗提前还款手续费怎么算"], "pos_scores": [1.0, 0.95], "neg_scores": [0.165, 0.098, 0.126, 0.13], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "提升借呗额度方法", "pos": ["如何增加?蚂蚁借呗额度", "蚂蚁借呗要怎么样才能提升额度"], "neg": ["为什么花呗会忽然空白看不到额度", "我要在本账户开通花呗 但是另一个账户已经开通了 怎么办", "如何解绑花呗", "花呗脸部识别", "花呗还了一部分 下月可以分期还吗", "蚂蚁借呗分期还一次之后,还可以继续用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.101, 0.18, 0.154, 0.062, 0.112, 0.093], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查询花呗总还款额", "pos": ["花呗怎么的看借款总金额", "怎样看花呗总需还款额"], "neg": ["蚂蚁借呗自功还款系统升级我怎样还款", "刚刚我用花呗付款了,要怎么还款", "怎么在手机上使用花呗收款", "开通花呗收款后二维码怎么用", "我不想用花呗,如何退掉", "还花呗。为啥。额度满了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.144, 0.139, 0.144, 0.091, 0.14, 0.197], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "冻结花呗额度", "pos": ["可以帮我冻结花呗吗", "里冻结花呗额度"], "neg": ["蚂蚁借呗什么时候通过审核", "我花呗逾期了,现在已经还清了,什么时候可以用", "借呗分期逾期以后 下次还款需要一次性还清吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.17, 0.106], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "个人收款码不能用花呗", "pos": ["扫个人收钱码,为什么不能用花呗", "收钱码为什么不能使用花呗"], "neg": ["我的花呗不小心分了***期", "花呗要怎昵解冻结", "怎么永久关闭借呗", "刚刚还蚂蚁花呗,银行卡扣了费,怎么花呗没有清楚", "如何把借呗删了", "为啥没有借呗信用额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.109, 0.158, 0.053, 0.059, 0.174], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款是否收费", "pos": ["花呗退款还要收费吗", "花呗退的款有利息吗"], "neg": ["如何更改借呗里预留的手机号", "我注销这个账户以后,我另外一个账户就可以申请花呗了吗", "我想借蚂蚁借呗为什么收不到手机验证码", "我上个月把花呗不知道是冻结还是调了。现在怎么恢复之前额度", "为什么我滴奖励金不能翻倍?用了一个月了,就出现过两次翻倍机会,其它时候没有翻倍?我一直设置的付款方式是花呗,余额宝"], "pos_scores": [1.0, 0.95], "neg_scores": [0.102, 0.141, 0.179, 0.157, 0.123], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通花呗条件", "pos": ["我怎么还没花呗", "我的花呗还没有开户"], "neg": ["花呗不能支付共享单车押金我", "我的借呗还款日几号", "怎么没有借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.155, 0.09], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款日当天还款", "pos": ["蚂蚁花呗还钱,,***号之前还款,那九号当天还可以吗", "花呗还款是***号前,***号当天还可以吧"], "neg": ["借呗的开通和信用分数有关吗", "这两三个月可能借呗还不起款", "花呗额度会随着还款而变吗", "我从来没有延迟还过钱,为什么把我的蚂蚁借呗给停用了", "借呗还款了还可以借出来么", "花呗这个账号有,另一个账号没有"], "pos_scores": [1.0, 0.95], "neg_scores": [0.077, 0.179, 0.12, 0.077, 0.053, 0.059], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款日逾期判定", "pos": ["我的花呗最后还款日为***号,我十号还算不算逾期", "花呗***号还款 十号还中不"], "neg": ["花呗不够用怎么申请", "花呗可以每天还款不", "花呗如何退款", "关闭了借呗怎么在开通", "花呗不能付单车押金", "怎么提前还款一个月的借呗", "花呗充话费!充到空号了!花呗会不会自动退回"], "pos_scores": [1.0, 0.95], "neg_scores": [0.137, 0.079, 0.135, 0.145, 0.194, 0.132, 0.146], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查询花呗消费明细", "pos": ["我的花呗账单有***我不记得是买什么", "我花呗***块是买什么"], "neg": ["我的蚂蚁借呗额度申请提升了怎么", "为什么我的账户没有条件开通花呗收款", "已经开通花呗 商品也支持 为什么还不能付款", "我怎样可以使用蚂蚁花呗", "怎么授权芝麻花呗", "我已经用银行卡还了花呗的钱,可是支付宝显示没还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.109, 0.081, 0.053, 0.082, 0.059], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "电脑端开通花呗", "pos": ["用电脑客户端可以开通花呗么", "电脑能开通花呗吗"], "neg": ["花呗付款,后来还款了,可是退款又退款到花呗,钱去哪了", "我的花呗额度变为***是怎么回事", "我问的是我怎么不能用花呗和借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.191, 0.067, 0.133], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗借款失败原因", "pos": ["为什么借呗都用不了", "我在借呗怎么借不了钱"], "neg": ["我的借呗了", "蚂蚁花呗还款了银行卡扣钱了为啥还显示没有还款", "哪些航空公司可以使用花呗", "花呗还款之后,之前用花呗买的火车票退款了,最后没有收到退款", "花呗的期限"], "pos_scores": [1.0, 0.95], "neg_scores": [0.199, 0.174, 0.177, 0.133, 0.125], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗支付退款流程", "pos": ["怎么要求花呗退款", "花呗支付怎么退款"], "neg": ["支付宝付款摩拜单车用花呗支付", "我是为什么花呗降额", "余额能还借呗么", "之前,蚂蚁借呗不是可以分为***期借贷吗,为什么现在只能分为***期", "我的借呗开了之后怎么又关闭了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.109, 0.114, 0.16, 0.052], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款后金额未更新", "pos": ["我是还没发退款的,退款到了花呗还是金额没变动", "花呗用出去,但是退款了,还款还是木有变的"], "neg": ["花呗,我最多能用多少", "花呗分期后能不能提前部分结清", "我花呗都取消自动首选了", "查询蚂蚁借呗***年***月账单", "花呗分期后一次性付清会有手续费吗", "还不了解花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.185, 0.189, 0.087, 0.12, 0.116, 0.054], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗利率差异原因", "pos": ["蚂蚁借呗怎么利息不同", "为什么借呗的利息每个人的都不一样"], "neg": ["我是用银行卡还花呗的,扣钱的短信收到了花呗哪里没有少钱", "怎么查花呗明细", "身份信息完善了,为什么还不可以借呗", "怎么查询蚂蚁借呗还款记录利息", "花呗已自动还款,为什么提示还款失败", "为什么借呗还了就不能用了", "如何知道花呗是否还清", "蚂蚁借呗上的提前还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.055, 0.125, 0.138, 0.081, 0.078, 0.177, 0.138, 0.113], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "密付是否支持花呗", "pos": ["密付可以当使用花呗吗", "密付。能用花呗吗"], "neg": ["注册了网商银行还有没有借呗", "退款是在花呗里再转去银行卡的吗", "为什么不支持花呗付款了", "借唄可以分期嗎", "花呗逾期 以后要还能用吗", "我换手机了我的花呗登不上怎么办", "花呗逾期 借呗能不能继续使用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.107, 0.08, 0.13, 0.135, 0.114, 0.186, 0.147], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期恢复使用", "pos": ["花呗逾期被停用怎么才能恢复使用", "蚂蚁花呗逾期被停止使用 如何再次使用"], "neg": ["花呗分期不能还全款吗", "从我银行卡提现到花呗一百,说提升额度,提现一百额度没动,那怎么算提升。还是我自己的钱", "花呗要多久才能充值话费", "花呗a帐户关闭如何开通b帐户", "借呗利息多少钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.196, 0.142, 0.11, 0.19, 0.114], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗利率不同原因", "pos": ["借呗利率为什么和别人的不一样", "为什么我的借呗收万分四"], "neg": ["等多久才可以开通花呗", "花呗邀请好友代还", "什么情况下平台会关闭个人借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.091, 0.183], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗入口消失", "pos": ["我的蚂蚁借呗怎么关闭了", "我的借呗已经有一万六的额度,今天登录支付宝一看怎么没有入口了"], "neg": ["信用卡逾期也影响花呗的使用吗", "花呗大额付款有限制嘛", "花呗分期不够额度怎么办", "花呗开通不使用会收费吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.173, 0.164, 0.185, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通信用卡花呗收款", "pos": ["怎样设置信用卡花呗收款", "怎么开通,信用卡和花呗收款"], "neg": ["花呗红包,专享红包都没法用", "花呗里的钱和余额宝的钱能同时用吗", "刚刚都可以开通借呗", "花呗,可以让朋友带还吗", "花呗***线下支付不了", "借呗是怎么样的还款方式", "花呗收款为什么会显示不满足升级条件"], "pos_scores": [1.0, 0.95], "neg_scores": [0.088, 0.1, 0.172, 0.145, 0.101, 0.186, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗最高额度限制", "pos": ["我想蚂蚁借呗最高金额有上线么", "借呗有最高额度限制吗"], "neg": ["有个有花呗,怎么有一个没有", "为什么我这个没有花呗的", "为什么支付宝花呗还不了款", "为什么我的花呗用不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.145, 0.075, 0.111, 0.083], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "余额宝自动还花呗", "pos": ["如果花呗,忘记还款,自己在余额宝里扣除", "花呗还钱,可以自动从余额宝扣除吗"], "neg": ["蚂蚁花呗不够不能分期", "我的支付宝添加的是我交通银行信用卡,怎么会用了花呗里的钱", "为什么我找人代付成功了花呗却要我还钱", "我用花呗买的东西 退货后 怎么退款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.124, 0.061, 0.126, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "网商贷转借呗", "pos": ["网商贷怎么样变借呗", "能把网商贷换成借呗吗"], "neg": ["花呗还需还账总额", "为什么借呗要上传身份证件", "花呗最多支付***是什么原因", "上午从借呗上借款到现在没有发款到卡上是怎么回事", "为什么我的花呗不能提前还完", "花呗收款码额度提升"], "pos_scores": [1.0, 0.95], "neg_scores": [0.058, 0.098, 0.074, 0.063, 0.096, 0.087], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "降低借呗额度", "pos": ["怎么能降低蚂蚁借呗", "借呗可以降额么"], "neg": ["我的花呗怎么一直不长", "借呗还款有个先息后本和等额还款 有什么区别", "借呗首月还款", "关团花呗,别一个帐户什么时候能开通", "花呗的流量提取", "花呗如何解除子账户", "a账号关闭花呗b账号多久可以开通"], "pos_scores": [1.0, 0.95], "neg_scores": [0.166, 0.194, 0.056, 0.123, 0.08, 0.089, 0.196], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗放款时间", "pos": ["蚂蚁 借呗借款多长时间放款", "蚂蚁借呗,什么时候放款"], "neg": ["***不是临时额度吗", "花呗金额为什么这么底", "花呗可以用来订机票么", "花呗无法使用、是怎么回事"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.111, 0.144, 0.078], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款去向查询", "pos": ["我的退款到哪去了?我是提前还款花呗,但是收到货产生退货退款,卖家退回花呗了,那我的钱到哪去了", "我之前的花呗已经还清了,但是我在淘宝上买的东西退货退钱了又退到了花呗上面,那钱去哪了"], "neg": ["花呗点餐能点几次", "花呗说我额度涨了,可是却没有变,这是怎么回事", "我用银行卡多还了花呗怎么办", "暂停花呗额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.148, 0.113, 0.111, 0.097], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查询上月花呗账单", "pos": ["想查上个月的花呗账单", "把我花呗这个月的账单发给我一下"], "neg": ["花呗提前结清多久额度能提", "用花呗分期买的怎么还款", "蚂蚁借呗只能分***个月或者***个月还么", "花呗上的支付宝账户和我需要添加的银行卡不一致怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.131, 0.189, 0.067, 0.197], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗和借呗区别", "pos": ["蚂蚁借呗和蚂蚁花呗一样", "花呗和蚂蚁借呗是一回事吗"], "neg": ["什么时候提的借呗额度", "我的***元花呗还款咋回事", "蚂蚁借呗多久重新可以评估额度", "花呗分期影响", "逾期还清后再使用借呗为什么还显示有逾期", "中专生如何完善花呗的学历资料"], "pos_scores": [1.0, 0.95], "neg_scores": [0.118, 0.131, 0.129, 0.125, 0.174, 0.157], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗提前还款规则", "pos": ["蚂蚁借呗提前还款一个月这个月还需要再次还款吗", "提前把蚂蚁借呗分期的钱还了,账单日还扣款吗"], "neg": ["花呗的手续费是什么", "我的额度是多少花呗", "花呗,怎么老是要我绑定银行卡,,我都绑定银行卡了都不行", "我的花呗为什么只能在淘宝买东西,其他功能没有", "能帮我开开花呗", "花呗 能转给他人支付宝吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.173, 0.191, 0.063, 0.082, 0.195, 0.16], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查看借呗账单", "pos": ["怎么看借呗还款的哪个账单", "借呗账单怎么看"], "neg": ["为什么我的花呗,不是我的支付宝账号", "花呗也关了", "用花呗支付手续费怎么扣", "购物津贴花呗可以用吗", "第一次开通花呗有多少额度", "需要具备什么样的条件才能向蚂蚁借呗借钱", "花呗什么安全验证不通过是怎么回事", "借呗能修改还款方式吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.108, 0.103, 0.083, 0.068, 0.137, 0.156, 0.149, 0.076], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗被关闭原因", "pos": ["借呗为和又又把我关了", "我的蚂蚁借呗怎么还关了"], "neg": ["借呗日息如何再降低", "好友没有花呗怎么邀请", "因为个人原因,我花呗逾期了好久,我现在还清了,还可以使用吗", "而且网商贷没有额度", "我的花呗所有分期账单如何提前还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.181, 0.081, 0.105, 0.139, 0.125], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "线下无法用花呗", "pos": ["为什么商家支持花呗,而我用不了", "我的花呗为什么好多线下支付不了"], "neg": ["花呗怎么读取手机联系人名单", "我的花呗额度是够的 店主也支持花呗 为什么付款的时候不能显示花呗支付", "借呗把钱还了会有吗", "花呗的还款日是几号", "余额里的钱还蚂蚁借呗的钱有限额吗", "花呗临时额度入口", "借呗没有自动还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.079, 0.099, 0.072, 0.115, 0.163, 0.151, 0.158], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗开通时间", "pos": ["我什么时候可以开通支付宝花呗", "花呗何时能开通"], "neg": ["能自动提高借呗额度", "支付宝显示花呗没开通,天猫却用了花呗,怎么还款", "花呗支付没享受到优惠"], "pos_scores": [1.0, 0.95], "neg_scores": [0.096, 0.141, 0.121], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭借呗功能", "pos": ["我的意思是,要像刚申请支付宝那种状态,没有借呗", "能不能把我借呗关闭成没开通状态"], "neg": ["借呗还款为什么没有自动扣款", "付尾款怎么使用花呗", "我的借呗怎么没有提升额度的", "花呗可以自动扣除绑定银行卡吗", "花呗临时额度到期后还能不能使用", "我的花呗为什么扫码不能付款呀"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.198, 0.064, 0.16, 0.149, 0.12], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款方式", "pos": ["花呗 还款怎么还", "不知道怎样还花呗"], "neg": ["开通花呗分期条件", "花呗体验额度多少", "花呗 绑定的卡 怎么更换"], "pos_scores": [1.0, 0.95], "neg_scores": [0.153, 0.138, 0.157], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "已还花呗的退款处理", "pos": ["花呗钱己经还了,后面退了款的钱退到那里", "之前的花呗已还,但后来有笔花呗退款"], "neg": ["怎么不支持花呗", "咋要我还款,我没用花呗", "我能继续使用支付宝花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.193, 0.13, 0.069], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "删除借呗记录", "pos": ["你可以帮我删掉借呗里的账单记录吗", "蚂蚁借呗还款记录可以清除吗"], "neg": ["花呗分期还款支持邮政储蓄银行绿卡通支付吗", "花呗款已还,怎么还显示下个月还可还十几元,说未出帐的,什么意思", "借呗能逾期多久不会影响信用", "花呗充值话费退款是什么原因", "花呗能让别人先还吗", "为什么没收到花呗的优惠券", "花呗绑定了一个支付宝另一个还能开通吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.137, 0.074, 0.102, 0.125, 0.077, 0.164, 0.105], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账单月划分规则", "pos": ["***号当天花呗消费是算当月的,还是次月的账单", "我一号过后用的花呗支付,是当月还还是下月还"], "neg": ["蚂蚁借呗分期的一次性还完", "蚂蚁借呗授权书生效是什么意思", "钱可以退到花呗里面吗", "把以前的帐号注销,现在这个能开通花呗吗", "我的花呗还欠有多少钱", "网商货和借呗可以同时有吗", "蚂蚁借呗借款超时没到账", "达到什么条件才能使用花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.155, 0.098, 0.128, 0.18, 0.068, 0.086, 0.19, 0.11], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期后能否使用", "pos": ["我的花呗逾期两个月了现在还款还能用吗", "花呗逾期***填还掉后还能不能用"], "neg": ["我现在花呗不可以用吗", "可以有花呗充话费吗", "花呗黄金会员", "花呗支付尾款可以用购物津贴吗", "积分可以兑换花呗免费分期吗", "蚂蚁借呗可以每月还款吗", "蚂蚁花呗额度怎么提升"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.154, 0.17, 0.086, 0.062, 0.098, 0.192], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗到账时间", "pos": ["借呗什么时候会到账", "借呗蚂蚁几天到账"], "neg": ["借呗利润怎么降低", "我的花呗钱被扣了***我该怎么找回", "我想还了花呗能关闭功能不", "花呗有没有每个月必须花多少吗", "为什么我开不了花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.16, 0.072, 0.139, 0.081, 0.068], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期提前结清", "pos": ["花呗已经分期的可以一次还款吗", "花呗欠款***元,然后我分期付款,到下一个月***日前可以一次还清***元吗"], "neg": ["为什么我的收钱码开通了花呗支付别人付款却不能用花呗", "借呗按日借款怎么没了,只能按月了", "花呗分期提前还款收手续费么", "我花呗里的钱应该多出***", "花呗对应的保险"], "pos_scores": [1.0, 0.95], "neg_scores": [0.078, 0.179, 0.18, 0.095, 0.184], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗系统繁忙", "pos": ["花呗点进去总是提示系统繁忙", "打开花呗显示系统繁忙,请稍后再试"], "neg": ["我的蚂蚁借呗还款了嘛", "我怎么没法还花呗", "蚂蚁借呗随时还随时借可以吗", "我都退了,怎么还在花呗里", "信用卡花呗是哪个收钱码"], "pos_scores": [1.0, 0.95], "neg_scores": [0.113, 0.13, 0.132, 0.174, 0.182], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗随时还款", "pos": ["花呗是不是可以随时还", "花呗可不可以随时还款"], "neg": ["是的 花呗账单分期", "在我的里面找不到花呗", "花呗分期提前还要利息吗", "为什么我的花呗上没看到,没有关闭", "借呗到账时间变慢了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.138, 0.16, 0.127, 0.199, 0.107], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期后可用额度", "pos": ["花呗分期后 还能用吗", "分期付款这个月还能使用花呗"], "neg": ["没有找到花呗的设置", "取消花呗、信用卡收款服务", "我用花呗支付了,为什么对方还没有收到钱", "淘宝花呗信息截图在哪里", "我的蚂蚁借呗可以借***千的额度,如果是提前还款的话,利息是多少钱", "花呗退款如何拿到这比退款", "为什么我开通了花呗,就是用不了", "我在商场买东西不能使用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.111, 0.158, 0.194, 0.108, 0.051, 0.096, 0.181], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗不显示原因", "pos": ["我的蚂蚁借呗功能怎么没有", "为啥我的蚂蚁借呗不显示"], "neg": ["分期花呗提前还,有什么危害吗", "现在分期花呗,到时候想一次性还,可以一次性还吗", "蚂蚁借呗额度没用完,还可以继续借嘛", "本月花呗已还款,银行卡钱也扣了 ,为什么显示没还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.103, 0.143, 0.092, 0.111], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗停用原因", "pos": ["为什么花呗不能使用了", "花呗很久用不了了"], "neg": ["登陆花呗提示你的支付宝账号存在安全风险", "花呗已经还款,还显示有余额", "花呗退款后,还款额不变"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.198, 0.118], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期影响信用", "pos": ["花呗分期会对个人信用不好吗", "花呗分期对分数有影响吗"], "neg": ["为什么我花呗只能淘宝付款", "用了花呗没还款会怎样", "口碑商家开通花呗收款", "查询申请借呗进度", "冻结花呗什么时候能解开", "花呗还款怎么没到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.081, 0.176, 0.127, 0.183, 0.122, 0.066], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗只能分6期", "pos": ["为什么蚂蚁借呗里借多久的选项中只有六个月", "我得借呗为什么只能分六个月"], "neg": ["花呗不用咯。 他会不会扣钱的", "借呗额度能关闭不", "我的花呗额度被冻结了怎么办", "花呗请朋友代付,没有链接", "我有一个花呗账单被我删除了", "我是卖家,但是我没有开通蚂蚁花呗,现在有个买家买东西的时候说他不能使用蚂蚁花呗付款", "为什么我的花呗额度这么底", "为什么支付宝手机号跟花呗手机号不一样"], "pos_scores": [1.0, 0.95], "neg_scores": [0.103, 0.131, 0.115, 0.147, 0.144, 0.18, 0.193, 0.181], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗好友代付", "pos": ["朋友代付朋友可以用花呗帮我付款吗", "可以用花呗帮好友代付么"], "neg": ["蚂蚁借呗开通最少能借多少", "花呗开通后,不用会不会扣取费用", "支付宝上面的花呗可以每个月分期吗", "要是光开通花呗,不用有费用吗", "彻底关闭借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.078, 0.129, 0.066, 0.144, 0.133], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗二次分期", "pos": ["借呗借款分期后到期还可以再分期还款吗", "蚂蚁借呗能分期后能再次分期吗"], "neg": ["蚂蚁借呗提前还款有副", "附近哪些门店可以用蚂蚁花呗", "借呗***天免息券"], "pos_scores": [1.0, 0.95], "neg_scores": [0.073, 0.14, 0.187], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗临时额度申请", "pos": ["花呗怎么才能使用临时额度", "我需要花呗临时额度"], "neg": ["我后期有网商贷,为什么借呗不见了", "蚂蚁借呗还款有短信通知吗", "花呗如何点提前结清", "花呗冻结好几个月了,以后还能用么", "花呗能支付微信零钱吗", "短信提示借呗额度提升成功为什么我额度没变", "花呗额度有***为啥不能分期付款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.14, 0.159, 0.188, 0.167, 0.163, 0.061, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗重复扣款", "pos": ["银行卡支付了什么花呗再重复扣款", "我用花呗付款,银行扣了一次款,还花呗银行又扣了一次款,重复扣款了!怎么回事"], "neg": ["花呗分期买手机在哪查看物流", "我开通不了花呗也开通不了体验额度", "我刚申请关闭借呗,是否借呗关闭了,花呗也不能用了", "借呗怎么不给借"], "pos_scores": [1.0, 0.95], "neg_scores": [0.116, 0.098, 0.139, 0.145], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗临时额度分期", "pos": ["花呗临时额度用了能分期还吗", "花呗临时额度用掉可以分期吗"], "neg": ["花呗临时额度也可以分期还款么", "花呗可以他人代还么", "借呗手机号码换了收不到验证码怎么办", "蚂蚁借呗怎么每天那么多人申请", "花呗临时额度可以改几次", "我刚刚花呗点错了分期,可以取消吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.157, 0.128, 0.193, 0.118, 0.081], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "清除借呗记录", "pos": ["能不能把借呗记录给删除了", "蚂蚁借呗还完后记录怎样删除"], "neg": ["我的借呗开通过,无不良记录,为何停了", "蚂蚁借呗可推迟还款吗", "我的借呗为什么没有按天还款的选项", "花呗邦定的手机", "蚂蚁借呗和网商银行是一家吗", "我想要使用余额宝自动还款花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.067, 0.088, 0.164, 0.124, 0.131, 0.058], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查询花呗额度", "pos": ["花呗额度怎么看", "花呗额度查不了"], "neg": ["花呗只能在网上消费吗", "没有在花呗", "花呗支付笔笔抽奖,抽到奖品在哪查看", "为什么分期额度那么低", "我的花呗账单想先还一半可以吗", "怎么提升花呗额度", "蚂蚁花呗怎么设置余额宝不自动还款", "实体店买手机能用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.129, 0.06, 0.068, 0.1, 0.119, 0.088, 0.089, 0.141], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "芝麻信用与借呗关系", "pos": ["芝麻信用跟借呗有关吗", "芝麻信用跟借呗开通有关嘛"], "neg": ["花呗开通了为什么还不能使用", "在淘宝可以使用花呗吗", "怎么还借呗", "有手续费吗借呗", "花呗为出账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.089, 0.196, 0.147, 0.156, 0.194], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "网商贷和借呗利息对比", "pos": ["网商贷的利息跟借呗利息差多少", "网上贷和借呗哪个利息高"], "neg": ["蚂蚁借呗手续费多钱", "怎么样才可以用借呗", "花呗逾期还款之后什么时候可以恢复使用", "借呗怎么不能分***期还咯", "能不能提高蚂蚁借呗的额度", "我好端端的花呗 不能使用了", "蚂蚁借呗一般是几号提额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.099, 0.101, 0.157, 0.1, 0.061, 0.131, 0.092], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款期限调整", "pos": ["怎么现在借呗只能***个月还款", "蚂蚁借呗怎么不能借***个月了"], "neg": ["为什么我花呗不能使用", "蚂蚁借呗客服", "花呗的钱我不花,是不是不用还钱", "我的蚂蚁借呗是不是已经升级成网商贷了", "刚用花呗支付的可以退款吗", "借呗还款期最多只有六个月"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.118, 0.176, 0.185, 0.166, 0.108], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账号转移", "pos": ["如何把花呗功能转到主账号", "花呗怎么能从一个账号转到另一个账号"], "neg": ["为什么我的蚂蚁借呗还是不能使用", "借呗还款不显示还款", "这个花呗我都不知道怎么弄", "我已经开通了花呗收款和信用卡 可是客户扫码付款确是不支持花呗付款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.157, 0.183, 0.059, 0.125], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期后恢复", "pos": ["我逾期还款已还还能用花呗吗", "花呗逾期后还款了,用额已经冻结,以后是否可以正常使用花呗"], "neg": ["蚂蚁借呗还款日期还没有到,就扣钱", "花呗从哪里可以手机充值", "上传身份证后多久才能给我开通借呗", "账户有花呗", "花呗可以 透支吗", "刚刚还的花呗怎么没到账", "为什么不见了花呗的二维码了", "花呗还完了多久可以用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.104, 0.156, 0.183, 0.119, 0.143, 0.086, 0.175, 0.123], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗分期还款", "pos": ["借呗可以分期还", "蚂蚁借呗分期还"], "neg": ["就帮我恢复借呗吧", "我开通花呗信用卡功能,是不是有二维码给我", "这个月的***号用花呗买 多久还", "我的借呗今天是还款日,刚才***点还上的款,算逾期吗", "怎样查看花呗总还额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.146, 0.195, 0.194, 0.06], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗重复还款提醒", "pos": ["我这个月还完花呗了,怎么还要还", "花呗钱我还了,怎么又出来了让我还"], "neg": ["蚂蚁借呗不用额度是变吗", "花呗多少分可以免押金", "信用积分***多可惜申请蚂蚁借呗吗", "我的花呗里,为什么少了***元钱", "开通花呗需要多少额度", "愈期还过要多长时间才能使用花呗", "借呗每月的还款日可以改么", "我的蚂蚁花呗在哪儿"], "pos_scores": [1.0, 0.95], "neg_scores": [0.075, 0.128, 0.171, 0.194, 0.055, 0.174, 0.132, 0.177], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗结清证明", "pos": ["蚂蚁借呗还请证明", "个人借呗 结清证明"], "neg": ["已经还款的花呗账单明细怎么查", "花呗标识图案是什么样", "借呗当月还不上咋办", "借呗能邀请朋友开通吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.057, 0.079, 0.091, 0.175], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收款码查看", "pos": ["查看花呗收款码", "我的花呗收钱码"], "neg": ["蚂蚁借呗什么能评估完", "可以预期还花呗么", "支付宝花呗只能有一个么", "支付宝借呗可以推后一天还款吗", "已申请退款退货成功了,退货后退的款是还退到花呗吗", "我要关闭蚂蚁庄园,不是关闭蚂蚁花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.172, 0.07, 0.07, 0.18, 0.173, 0.056], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗审核未通过", "pos": ["蚂蚁借呗审核没通过,好何重新提交", "借呗申诉审核未通过"], "neg": ["你花呗是什么意思", "怎么来取消花呗", "银行卡还借呗有限额吗", "为什么借呗申请失败"], "pos_scores": [1.0, 0.95], "neg_scores": [0.057, 0.11, 0.056, 0.051], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期提前还款手续费", "pos": ["花呗账单分期提起还款还有手续费么", "花呗分期以后,如果提前还款需要手续费吗"], "neg": ["花呗能提前还下一期吗", "花呗还款可以推迟麽", "用花呗怎么开发票", "花呗分期想到能提前还款么", "蚂蚁借呗支持邮政储蓄卡吗", "为什么我的码蚁借呗不能用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.168, 0.196, 0.097, 0.177, 0.055, 0.181], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "没有花呗原因", "pos": ["为什么,我的没有花呗", "我为什么没有花呗呀"], "neg": ["些条件满足借呗", "我的借呗额度有点不够用,可以提额吗", "蚂蚁借呗可以用网商银行帐户还款", "花呗额度是高多久", "为啥不开通借呗", "花呗支付的时候会有手机短信通知吗", "才申请的口碑码怎么不支持花呗支付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.159, 0.072, 0.187, 0.153, 0.123, 0.094, 0.083], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗逾期几天不影响芝麻信用分", "pos": ["借呗逾期***天,会影响多少信用分", "借呗逾期几天不影响芝麻信用分"], "neg": ["什么额度能用借呗", "花呗的单笔消费最多是多少", "打算花呗分期买个手机 拜托快些提升额度", "花呗上提示可以用,蚂蚁借呗可用***为什么不能借", "如何删除借呗借还款记录", "为什么我不能使用花呗付电费", "我想知道花呗为什么用不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.172, 0.18, 0.108, 0.136, 0.168, 0.125, 0.101], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款去向 提前还款后", "pos": ["我用花呗买东西、买错了、退款回花呗我查看花呗没钱、因为我一用花呗买东西了就立马还花呗的钱了、那退回来的款在花呗也没有、那退去哪了", "我是用的花呗,然后在淘宝上面买了东西,我就用银行卡提前还款了,之后我要把淘宝的那样东西退款了,但是退款的钱去了哪里,没有收到银行卡,上面也花呗,上面也没有看到"], "neg": ["我订了份外卖用花呗付的款还用给卖家钱或者还钱吗", "开通花呗,芝麻分要多少", "我已经还款两个月为什么没有恢复借呗额度", "我这借呗都有一阵显示异常了 啥意思", "蚂蚁借呗还了,啥时候能把逾期关拉"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.092, 0.072, 0.16, 0.063], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗负分影响", "pos": ["我的花呗是负的有影响吗", "花呗负分有影响吗"], "neg": ["我没有任何不良记录怎么花呗开通不了", "网商贷切回借呗操作", "花呗以还款为什么还有账单", "我满足借呗开通的要求嘛", "花呗临时额度可以分期嘛"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.129, 0.153, 0.081, 0.186], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "口碑不能用花呗原因", "pos": ["口碑支付是不是不能用花呗", "我的口碑为什么不能用花呗"], "neg": ["我现在开通花呗什么时候还款", "花呗最低额度设置", "钱怎么退到花呗了什么时候到卡上", "花呗换起安全核身验证失败", "蚂蚁借呗电脑怎么使用", "花呗怎么删除账单"], "pos_scores": [1.0, 0.95], "neg_scores": [0.176, 0.057, 0.178, 0.194, 0.149, 0.087], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "无银行卡如何还花呗", "pos": ["没有银行卡可以还花呗嘛", "支付宝没有绑定银行卡,能不能还花呗"], "neg": ["怎么不显示蚂蚁借呗", "花呗还款是直接从绑定银行卡上扣吗", "商家花呗收款有手续费么", "借呗取消授权", "花呗付款是不是都是货到付款", "跟借呗有什么不同", "我花呗是***我想先还***"], "pos_scores": [1.0, 0.95], "neg_scores": [0.08, 0.198, 0.191, 0.177, 0.11, 0.098, 0.148], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款去向 已还款", "pos": ["我已经还完款了,可是这是淘宝退款自动退到花呗里的,可是在他退之前我已经先还完了这个款,他这个钱退去哪里了", "蚂蚁花呗退款的钱去哪里了"], "neg": ["花呗分期还款费用怎么计算", "买家用花呗分期 在我店里买了***元的", "蚂蚁借呗分期还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.15, 0.095, 0.183], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗无法使用原因", "pos": ["我蚂蚁花呗怎么不能使用", "我的花呗怎么花不了了"], "neg": ["双***花呗***期免息", "借呗可以拖一天还款吗", "借呗可以***期还款吗", "花呗额度成负数有啥影响吗", "借呗分期最长只有***个月吗", "蚂蚁借呗***期还款怎么不能选了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.118, 0.064, 0.109, 0.179, 0.075, 0.063], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期还款设置", "pos": ["如果将花呗用完是否可以分期还款", "花呗分期还款如何设置"], "neg": ["手机丢了花呗还能用吗", "这个已经退款了,为什么还没把钱返回花呗", "使用借呗对以后房贷有影响吗", "借呗长的太慢了", "为什么我***号借呗还款后,一直没有恢复额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.184, 0.177, 0.182, 0.163, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无法借款原因", "pos": ["怎么在借呗借不了钱了", "借呗里面还有钱\\怎么借不了"], "neg": ["蚂蚁借呗的分期手续费", "蚂蚁借呗输错三次密码", "借呗分期,提前还了是否会影响再次借款", "按时还款为什么借呗额度不恢复了", "蚂蚁借呗可以延长分期", "花呗分期后提前还款,要收取手续费吗", "花呗总额度是不是可以透支的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.052, 0.094, 0.068, 0.13, 0.075, 0.053, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "支付方式无花呗选项", "pos": ["付款的地方没有花呗支付", "每次支付方式里直接没有了花呗这一个"], "neg": ["我通过支付宝购物时每次都从银行卡扣过款项了,但是花呗还让还款", "花呗分期付款手机 利息是怎么算的", "我用花呗交摩拜单车押金,退押金后没有到账", "我怎么删除用花呗", "借呗还款日,没提醒吗", "使用蚂蚁花呗支付时会有短信通知吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.153, 0.161, 0.14, 0.115, 0.183, 0.053], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗余额宝自动还款设置", "pos": ["花呗自动还款余额宝代扣功能怎么开", "怎么设置花呗从余额宝中自动还款"], "neg": ["取消花呗支付", "花呗可以换最低吗", "花呗额度是多少那"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.152, 0.129], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款期修改", "pos": ["蚂蚁借呗换还款时间", "蚂蚁借呗分期还款,还款期可以改吗"], "neg": ["花呗关闭后,怎么另外一个还是不能开通", "花呗,如何为好友额度加分", "借呗还款日是", "关了借呗能再开通吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.065, 0.055, 0.191, 0.093], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗评估时间", "pos": ["系统评估借呗需要几天", "系统多久评估借呗可以使用"], "neg": ["蚂蚁借呗可以延长时间还款吗", "原先蚂蚁借呗都没有手续费", "蚂蚁借呗提前还款先还一部分可以吗", "花呗分期可以多个么", "我的花呗***分,为什么没有借呗入口"], "pos_scores": [1.0, 0.95], "neg_scores": [0.132, 0.137, 0.158, 0.121, 0.143], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗额度突然降低原因", "pos": ["为什么我的蚂蚁借呗额度突然降了那么多", "为什么我的借呗降低额度了"], "neg": ["这个月经济紧张推迟***天还花呗行不行", "怎么提升花呗的额度", "花呗还的时候有利息么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.078, 0.103, 0.057], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "支付宝人工客服联系", "pos": ["找客服人员", "请联系人工客服为我解答"], "neg": ["用花呗分期买的东西,结果不合适退款了,手续费怎么没退", "蚂蚁借呗如何邀请好友开通", "花呗更换手机号了怎么换新的手机号", "我花呗是以前的号码,我想换现在的,怎么换", "花呗可以在店里用吗", "我的借呗为什么只能选择六个月和十二个月"], "pos_scores": [1.0, 0.95], "neg_scores": [0.163, 0.07, 0.124, 0.108, 0.123, 0.117], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗开通方法", "pos": ["花呗怎么样开通", "我的账户如何开通花呗"], "neg": ["我的花呗怎么才能开通", "蚂蚁借呗属于贷款吗,对于以后买房首贷会有影响吗", "怎样才是蚂蚁借呗关闭", "能不能申请借呗最低还款", "怎么看是不是使用花呗支付", "我说的借呗分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.119, 0.134, 0.19, 0.11, 0.071, 0.193], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗生活圈加入方式", "pos": ["蚂蚁借呗生活圈怎么关注", "如何加入借呗生活圈"], "neg": ["我想蚂蚁借呗提升额度怎么办", "我用花呗已经付款了。怎么我在淘宝上看了。怎么显示没有付款", "网商贷额度和借呗额度哪个更高"], "pos_scores": [1.0, 0.95], "neg_scores": [0.055, 0.193, 0.09], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗提前还款后无额度", "pos": ["蚂蚁借呗为什么提前还款了之后没有额度", "为什么提前还款借呗,没有额度了"], "neg": ["我的花呗有什么微贷商品逾期,能不能查出来", "借呗现在还可以提现吗", "花呗额度为何会下降", "为什么我的借呗额度没有了", "花呗还了最低额度了会有影响吗", "花呗临时额度完了会回到原有额度吗", "花呗分期最多是多少钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.166, 0.148, 0.185, 0.141, 0.112, 0.056, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "商家支持花呗支付吗", "pos": ["你们家支持花呗支付么", "花呗支付可以吗"], "neg": ["要用多久支付宝才能用花呗", "蚂蚁借呗会影响信誉吗", "还花呗为什么发验证码还发以前的号码", "我怎么在电脑上进行花呗的还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.165, 0.102, 0.121, 0.155], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭花呗方法", "pos": ["花呗如关闭", "花呗额度不长、想关闭"], "neg": ["关于花呗最低还款的", "花呗自动还款红功能", "为什么我的借呗借款方式跟别人不一样", "花呗一次性还款怎么操作", "花呗分期中,剩下的额度还能用吗", "信用卡与花呗有什么不同"], "pos_scores": [1.0, 0.95], "neg_scores": [0.15, 0.156, 0.151, 0.166, 0.085, 0.071], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗长期不用会降额吗", "pos": ["蚂蚁借呗不用额度是变吗", "蚂蚁借呗长期不使用是否会降低额度"], "neg": ["为什么我美团外卖不能花呗", "借呗也没有", "***号之前花呗会自动扣款吗", "代付可以用花呗吗", "换个手机借呗怎么没有了", "借呗如何能提升额度", "花呗怎么不能在美团上买东西了", "电费怎么不能用花呗交"], "pos_scores": [1.0, 0.95], "neg_scores": [0.073, 0.136, 0.11, 0.123, 0.14, 0.055, 0.2, 0.063], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收款功能异常", "pos": ["为什么我现在花呗收款用不了", "花呗收不了钱怎么了"], "neg": ["花呗的优惠券有吗", "借呗怎么解除等额还款", "还什么花呗还款", "蚂蚁借呗不可以分期还款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.152, 0.063, 0.179, 0.112], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗延期还款规则", "pos": ["花呗可以迟一个星期还吗", "花呗可以延期几天还吗"], "neg": ["花呗临时额度也能分期账单吗", "花呗分期后,额度是多少", "买手机用花呗分期付款", "我为什么没有蚂蚁借呗信用额度", "我用花呗买的火车票,然后我还上了,改签火车票退回的钱又给我退到额度一部分"], "pos_scores": [1.0, 0.95], "neg_scores": [0.146, 0.195, 0.056, 0.075, 0.098], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗开通条件不满足原因", "pos": ["我哪里不满足借呗条件了", "蚂蚁借呗不满足条件的原因"], "neg": ["帮我看一下,我这个花呗是什么时候用的吗", "怎么样限额借呗", "购物津贴可以用来花呗分期吗", "系统一般好久评估一下借呗", "花呗可以让别人帮你还吗", "花呗分期后能提前还清么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.075, 0.086, 0.112, 0.097, 0.1, 0.14], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款去向 已还款", "pos": ["我的这笔款是已经还了的 为什么不退现金而是退给花呗", "我上次买袜子的钱自己还款了 但是退钱的钱为什么又到花呗了"], "neg": ["交易退款至花呗为什么需要还款的金额没变", "花呗开通不能在淘宝使用", "借呗还款日", "我已经花呗分期了,还款日是***月***号,那我这个月***号提前还,还要手续费吗", "如何采用花呗二维码收款", "网商跟花呗有什么关系", "我这个账号不是有欠借呗吗还是贷款", "我花呗在淘宝买的衣服,然后我提前还上了花呗,淘宝退回的款到哪里"], "pos_scores": [1.0, 0.95], "neg_scores": [0.148, 0.097, 0.12, 0.07, 0.193, 0.066, 0.188, 0.182], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗提前还款方法", "pos": ["花呗钱可以提前还么", "花呗付款了,可以提前还款吗"], "neg": ["在淘宝上买东西可以不用花呗吗", "花呗瑜期怎么办", "商品可以用花呗支付", "我的借呗为什么突然没有额度", "如果退出花呗", "我得支付宝,花呗是什么意思"], "pos_scores": [1.0, 0.95], "neg_scores": [0.113, 0.15, 0.121, 0.147, 0.111, 0.076], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还清后显示欠款", "pos": ["花呗的钱我全部还完了滴按,怎么显示还欠款", "我的花呗明明已经还清了。为什么显示有未还完的"], "neg": ["花呗款怎么办", "口碑怎么用花呗支付不了了", "我买东西用花呗付款,怎么还钱", "我的信用很高 为什么不能使用借呗", "花呗临时额度会到期", "蚂蚁借呗分期得能在分期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.182, 0.137, 0.123, 0.077, 0.17], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗上传身份证要求", "pos": ["?借呗上传身份证", "开通蚂蚁借呗可以不用上传身份证吗"], "neg": ["花呗还款实际到账不符", "有花呗怎么才能尽快有借呗", "退款回复花呗额度怎么做", "花呗分期会不会影响信用问题", "我支付宝就是现在这个支付宝账号登入但是花呗显示不是我登入支付宝的手机号怎么办", "花呗最低还款后剩下的金额可以分期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.185, 0.149, 0.147, 0.058, 0.135], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通花呗不用会收费吗", "pos": ["花呗如果我开通了不用会有费用吗", "开能花呗不用要收费的吗"], "neg": ["我的蚂蚁花呗无法正常使用,为什么是怎么回事", "花呗欠款可以用微信还吗", "怎样设置付款方式为花呗", "蚂蚁花呗逾期被停止使用 如何再次使用", "花呗可以在外面店里购物吗", "取消的订单花呗还要还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.166, 0.073, 0.175, 0.182, 0.194], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗额度无法使用原因", "pos": ["我的花呗额度为什么不可以用了", "为啥我现在用不了蚂蚁花呗"], "neg": ["借呗是不是临时额度的", "查询我的花呗注册手机号是", "系统多久评估一次借呗权限", "花呗分期后还能不能继续分期", "分期提前还款花呗", "蚂蚁借呗降了***额度", "可以把花呗权限给自己的另一个支付宝账号嘛"], "pos_scores": [1.0, 0.95], "neg_scores": [0.069, 0.161, 0.129, 0.119, 0.144, 0.116, 0.068], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通花呗收费吗", "pos": ["开通花呗需不需要费用", "开通花呗会收取费用吗"], "neg": ["企业支付宝如何开通信用卡与花呗支付功能", "支付宝的号码换过来了,花呗的号码还是没有换过来", "最高借呗多少", "蚂蚁花呗怎么还了两次钱", "借呗钱能转给别人吗", "如何寻找带有花呗标识的商品", "在花呗里没有看到退回来的钱", "为什么无缘无故就开通花呗,我要取消怎样操作"], "pos_scores": [1.0, 0.95], "neg_scores": [0.072, 0.165, 0.186, 0.199, 0.117, 0.16, 0.097, 0.119], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期能用购物津贴吗", "pos": ["购物津贴能用在花呗分期上吗", "使用花呗分期,能否享受购物津贴"], "neg": ["我的花呗什么时候才能用了", "借呗没有还款功能", "我开通花呗收款现在怎么不能用了", "我的花呗写的是***号可以***号还款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.087, 0.113, 0.088, 0.052], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款状态查询", "pos": ["花呗的钱是否还清", "我的花呗钱还完了吗"], "neg": ["我消费的花呗没有那么多钱,为什么我的还款还要多于这个钱", "可以用支付宝吗", "订单取消后花呗什么时候恢复", "花呗对芝麻信用有多大影响", "借呗只有***个月和***个月的期限", "支付宝花呗营业照主体和本账户名称不一致"], "pos_scores": [1.0, 0.95], "neg_scores": [0.101, 0.091, 0.056, 0.086, 0.089, 0.149], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "淘宝支付宝手机号不一致能用花呗吗", "pos": ["淘宝手机号和支付宝手机号码不一致,不能用花呗吗", "支付宝支付宝绑定的手机号和淘宝绑定的手机号不一样可以用花呗吗"], "neg": ["怎么查用花呗买的物品", "我明明付钱了,为什么确认收货后花呗上还要多次付款", "共享单车退钱了可是支付宝花呗没有收到", "在花呗怎么绑定银行卡"], "pos_scores": [1.0, 0.95], "neg_scores": [0.109, 0.122, 0.179, 0.11], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗利息计算", "pos": ["***元花呗利息多少", "我上个月扣了我多少花呗利息"], "neg": ["朋友帮还花呗有什么好处", "蚂蚁借呗最后还款日是什么时候", "花呗是分期的吗", "我要怎样才能用借呗", "蚂蚁借呗用钱什么时候到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.134, 0.168, 0.072, 0.148, 0.11], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "尽情花呗含义", "pos": ["显示尽情花呗什么意思", "尽情花呗怎么回事"], "neg": ["蚂蚁借呗还款日设置", "花呗逾期不还会影响黑户吗", "能不能注销支付宝", "花呗返利活动", "为什么这期花呗不能分期", "密付不能用花呗么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.179, 0.063, 0.066, 0.144, 0.145, 0.1], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "密付支持花呗吗", "pos": ["买东西密付可以用花呗吗", "密付能使用花呗么"], "neg": ["花呗能在淘宝里用吗", "我的花呗还款日期能延长到***号吗", "花呗分期第一个月需要付钱吗", "我的借呗额度什么时候才能恢复", "借呗怎么到期了不扣款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.081, 0.181, 0.054, 0.115, 0.16], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗利息差异原因", "pos": ["蚂蚁借呗利息怎么别人低我的高", "怎么蚂蚁借呗我的利息和别人的不同,万分之六,别人的万分之***"], "neg": ["比如我用花呗分期买了东西,然后我应该怎么还款", "花呗能不能提前还清", "我是在还花呗的时候扣了***次款", "我的花呗为什么还有账单?我从未贷过款", "加好友可以发花呗吗", "我不是网商", "支付宝里找不到花呗――我的账单"], "pos_scores": [1.0, 0.95], "neg_scores": [0.166, 0.11, 0.057, 0.12, 0.19, 0.124, 0.198], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗提前还款利息", "pos": ["提前还款的话利息会少么", "花呗 提前还款要利息吗"], "neg": ["怎么看是否支持花呗", "怎么完善借呗的信息", "借呗信用额度不够了", "花呗借呗开通不了是怎么回事"], "pos_scores": [1.0, 0.95], "neg_scores": [0.159, 0.178, 0.052, 0.136], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "网商贷和借呗区别", "pos": ["网商贷好借呗有什么区别", "我想知道蚂蚁借呗和网商贷的区别"], "neg": ["花呗短信怎样用", "怎么提前还已分期花呗", "我的借呗未还成功", "蚂蚁借呗没有三个月的借款期限", "花呗可以用余额还吗", "为什么借呗无信用额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.104, 0.17, 0.055, 0.175, 0.15, 0.171], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款周期", "pos": ["蚂蚁借呗多久还一次", "借呗最长多久要还款"], "neg": ["蚂蚁借呗还款账单", "花呗分期期间提升可用额度吗", "信用额度怎样才能蚂蚁借呗的条件"], "pos_scores": [1.0, 0.95], "neg_scores": [0.125, 0.128, 0.055], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗开通时间", "pos": ["我的什么时候才能用花呗", "我的蚂蚁花呗什么时候能给开通,我现在一直信誉良好"], "neg": ["怎么样查询自己有没有开通花呗收款功能", "上个月花呗买的东西已经扣款为什么这个月又扣一次", "花呗可以换红包不", "如何设置花呗绑定银行卡", "我花呗付款,退货的话钱是不是也退到花呗了", "我的花呗为什么不能用?说有逾期的,但是找不到"], "pos_scores": [1.0, 0.95], "neg_scores": [0.135, 0.16, 0.084, 0.164, 0.064, 0.152], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "支付宝无花呗选项", "pos": ["为什么我没有花呗可以用", "支付宝里怎么没有花呗"], "neg": ["为什么这个月借呗不能借", "借呗一次性还款后能不能再借钱", "我花呗分三期还款的", "花呗充话费怎么充不了啦"], "pos_scores": [1.0, 0.95], "neg_scores": [0.167, 0.154, 0.084, 0.169], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无额度原因", "pos": ["借吧怎么没有额度啦", "借呗怎么没有咯"], "neg": ["为什么我用花呗一直提示满***用花呗分期", "花呗冻结啥时候能开通", "花呗的最大额度是多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.128, 0.119, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗专享红包使用条件", "pos": ["开通了花呗,专享红包能用吗", "开通花呗是不是就可以申请专享红包门店"], "neg": ["总是显示我的另一个账户,怎么能把花呗转到我这个账户上", "花呗我为什么绑定不了银行卡", "为什么在淘宝购物不能用花呗", "我看了我的花呗没有这个", "中专生如何完善花呗的学历资料", "借呗有最低还款额,没有", "花呗贷款一百还多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.074, 0.12, 0.182, 0.182, 0.159, 0.104, 0.08], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗未还完能否再借", "pos": ["借呗是要全部还清才可以再次借款吗", "借呗分期未还完可以再借吗"], "neg": ["我啥时候能开通蚂蚁花呗", "借呗额度全部还清可以恢复吗", "在共享单车里的押金可以用花呗吗", "我想请你们把我花呗开通", "刚开通花呗有多少透支额", "我的花呗收款吗,不能收款了怎么办", "为什么我的借呗收回去了", "花呗还款 组合付款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.069, 0.084, 0.113, 0.062, 0.158, 0.136, 0.099, 0.176], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗通讯录权限设置", "pos": ["花呗允许读取通讯录怎么搞", "花呗怎么设置通讯录"], "neg": ["花呗限制了虚拟支付是什么意思", "可以使用花呗里的余额吗", "商品价格大于花呗额度但能分期", "为什么退款了之后花呗还款还是没有减"], "pos_scores": [1.0, 0.95], "neg_scores": [0.175, 0.195, 0.094, 0.166], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗购买火车票限制", "pos": ["花呗,能买火车票吗", "支付宝花呗不能买火车票对吗"], "neg": ["买花呗保险", "借呗十二点以后怎么还款", "花呗临时额度啥时还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.113, 0.171, 0.198], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无法借款原因", "pos": ["为什么借呗不能借钱了", "为什么我的蚂蚁借呗不能用"], "neg": ["开通了花呗就成了蚂蚁会员吗", "蚂蚁花呗可以提前还款有手续费吗", "商家支持花呗付,但是付不了", "什么时间可以使用花呗", "***月***号用的花呗,是否可以到下个月***月***号还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.077, 0.063, 0.193, 0.087], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还清后无法使用", "pos": ["花呗我已经还了,为什么现在无法使用", "我的花呗还完了怎么不能用"], "neg": ["花呗完善学历资料怎么完成", "我蚂蚁借呗今天还款没有", "花呗免单活动怎么报名", "怎么填写申请蚂蚁借呗", "从淘宝买东西用花呗只用几十块钱可以吗", "蚂蚁借呗进不去", "花呗账单分期免息券在哪儿领"], "pos_scores": [1.0, 0.95], "neg_scores": [0.058, 0.124, 0.189, 0.131, 0.098, 0.053, 0.112], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗不支持电费缴费", "pos": ["交电费不能扣花呗", "支付宝里的花呗怎不支持交电费了"], "neg": ["花呗在还款日的几点自动还款", "为什么我改了手机号码还是不可以使用花呗", "刚才没有输入付款码花呗就付款了", "可不可以更改花呗还款日", "淘宝花呗支付宝花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.076, 0.106, 0.071, 0.187, 0.119], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退票未到账", "pos": ["用花呗买的汽车票退票为什么钱还没到账", "退款钱到花呗,查询没到账"], "neg": ["花呗有额度已选了花呗付款,为什么还是在我银行卡里扣", "借呗分期还可以分期吗", "花呗可以定酒店麽", "花呗开通了,不使用会产生费用么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.197, 0.091, 0.151, 0.121], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账单查询", "pos": ["在哪看用了几次花呗", "查看花呗所有账单在哪"], "neg": ["收钱码在哪儿?我开通了花呗收款,但是找不到收钱吗", "我前面忘了花呗还款有几条 我的花呗被冻结了", "蚂蚁借呗可以借***年", "我九月这期蚂蚁借呗是否还款成功", "大概多久评估一次借呗条件", "借呗没有了成了网商贷", "我刚才在淘宝买东西了,是用花呗付的钱怎么办", "***元已经退了,花呗怎么没有减掉"], "pos_scores": [1.0, 0.95], "neg_scores": [0.192, 0.058, 0.115, 0.171, 0.174, 0.082, 0.187, 0.1], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期一天影响", "pos": ["花呗只逾期一天还款后可以用嘛", "我用的花呗我晚***天还可以吗"], "neg": ["什么交易能用花呗", "借呗和花呗不能同时开通吗", "花呗的还快款金额不对", "花呗的钱还不进"], "pos_scores": [1.0, 0.95], "neg_scores": [0.135, 0.067, 0.148, 0.196], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗更换手机号", "pos": ["花呗能不能换手机号码呀", "我花呗如何更改手机号"], "neg": ["我没有逾期怎么花呗就不能用了", "另一个号关闭花呗了,这个号可以开了吗", "借呗怎么再用", "同一个身份证花呗在另一个账户上面,转到这一个上面怎么转", "花呗在哪里查询"], "pos_scores": [1.0, 0.95], "neg_scores": [0.099, 0.192, 0.111, 0.154, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗有额度借不出", "pos": ["我借呗还有额度,为什么借不出来了", "借呗上有额度,但是借不出钱"], "neg": ["花呗就能使用一个淘宝账号", "主动申请花呗审核时间", "花呗手机直接还款", "可以用蚂蚁花呗购买商品吗", "花呗还最低,影响个人信用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.055, 0.089, 0.065, 0.085, 0.174], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款记录查询", "pos": ["蚂蚁借呗今天的还款记录", "我是说借呗还款记录"], "neg": ["开通了花呗,不花钱是不是也不算", "为什么我的没有蚂蚁借呗功能", "花呗逾期未还有利息吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.096, 0.152, 0.192], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款处理", "pos": ["退到花呗里的钱怎么办", "钱款退到我的花呗里"], "neg": ["花呗还款后才能恢复可用额度", "花呗扣钱是在哪里扣", "每个月都可以继续花呗分期吗", "为什么我的账号不能用花呗分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.186, 0.056, 0.155, 0.179], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗实体店使用范围", "pos": ["花呗可以在实体店使用么", "花呗可以用于超市吗"], "neg": ["花呗刷脸为什么老是失败", "为什么花呗每个月减***元", "为什么我在借呗哪里借的钱一天都没到帐", "信用卡花呗收款怎么开通不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.172, 0.052, 0.079, 0.081], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗额度查询", "pos": ["我有多少的花呗额度", "开通花呗我的额度是多少"], "neg": ["***.***显示两笔退款,但我的花呗账单***月和***月只有支出金额,没有这两笔收益", "还有一个开通了网商贷,开通不了借呗", "蚂蚁花呗被冻结了怎么解冻", "花呗能让别人先还吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.171, 0.162, 0.173, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗线下支付限额", "pos": ["花呗线下支付能用多少额度", "线下商家花呗数额"], "neg": ["用花呗买东西不收取利息和手续费吗", "花呗***号自动还款扣利息吗", "蚂蚁借呗提前还款后可以在提现吗", "信贷产品要确定永久关闭关闭关闭", "我花呗有***额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.146, 0.14, 0.083, 0.095, 0.109], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "注销花呗方法", "pos": ["怎么消花呗卡", "我要怎么注销蚂蚁花呗"], "neg": ["花呗临时额度消费了可以分期吗", "为什么我的花呗要利息", "借呗还款可以自动在支付宝余额扣钱吗", "花呗提升额度需要学历我没有学历怎么办", "花呗额度负数还能买吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.163, 0.192, 0.125, 0.173, 0.067], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通花呗不使用影响", "pos": ["开通花呗不使用行吗", "开通花呗都不用会怎么样"], "neg": ["为什么借呗只能最高借六个月了", "为什么我花呗刷脸刷不过去", "花呗还上次分期的怎么还", "花呗分期付款的苹果手机"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.057, 0.096, 0.098], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗关联账号更换", "pos": ["我能把我另一个关联账号花呗转到现在用的账号吗", "换花呗关联账号"], "neg": ["花呗支付加油卡充值", "身份证过期了咋办借呗", "已经形成的借呗账单可以分期吗", "借呗放款了没到银行卡", "花呗点进去总是提示系统繁忙", "我想退出花呗,不想再使用", "花呗短信提醒免费吗", "花呗可以使用的途径"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.083, 0.13, 0.054, 0.139, 0.196, 0.071, 0.184], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗缴费功能异常", "pos": ["以前能用花呗缴费着嘛", "前几天还能用花呗充值"], "neg": ["怎么查看花呗分期还的开销", "差了多少钱,我给她还了,把她的蚂蚁花呗给关了吧", "花呗是不是全部还好就能使用了", "花呗分期的费率是多少", "网商贷和蚂蚁借呗额度共享吗", "二维码申请花呗怎么申请"], "pos_scores": [1.0, 0.95], "neg_scores": [0.068, 0.16, 0.09, 0.162, 0.058, 0.194], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗突然无法使用", "pos": ["我现在为什么用不了蚂蚁花呗", "怎么不能用花呗了"], "neg": ["借呗还款日最晚当天什么时间", "这个花呗里面借的钱怎么还呀", "也就是花呗账单多于实际消费,而且我是有明细的", "借呗提额为什么要输入密码", "借呗为什么不借了", "花呗分期后提前还款,要收取手续费吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.148, 0.166, 0.112, 0.149, 0.156, 0.081], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还清后能否再借", "pos": ["借呗一次还完本金下次还能借吗", "借呗还完后还能继续借款不"], "neg": ["花呗手续费怎么算", "花呗能在闲鱼上使用ok", "我都不用花呗买东西。怎么会有费用叫还款", "为什么我的花呗账户冻结了", "花呗钱能不还吗", "为什么我已经退款了 还款还是多一笔花呗? 退款在那里去了", "花呗冻结了。怎么开通", "我这个花呗里面的钱怎么扫不出去是怎么回事"], "pos_scores": [1.0, 0.95], "neg_scores": [0.116, 0.094, 0.123, 0.159, 0.161, 0.117, 0.174, 0.074], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗最新活动", "pos": ["花呗近期活动", "花呗有新活动吗"], "neg": ["淘宝退款还到花呗我什么没看到钱", "怎么不能用花呗充话费了", "花呗借款,还款日期是固定的么", "花呗分期付款的没有积分", "我借呗是每月等额的,现在可以改成先息后款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.085, 0.149, 0.141, 0.104, 0.15], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "收款码不支持花呗", "pos": ["扫手机上的收钱码怎么不支持花呗收钱", "为什么别人扫我的收款码不支持花呗付款"], "neg": ["花呗pass授权", "花呗显示上期账单已还清,但还有短信提示欠费,怎么回事", "我要还花呗什么还", "借呗最高额度能提到多少钱", "借呗分期还款能先还利息么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.196, 0.15, 0.163, 0.067, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗提现到银行卡", "pos": ["我要从花呗里提到卡里面", "我要花呗滴钱转到银行卡"], "neg": ["我现在应该怎么还款,我花呗是花着,我号欠款的,我都不用了,怎么还款", "借呗说已经到银行卡了,可是银行卡没收到", "可以用密付支付花呗账单吗", "借呗每个月几号还", "怎样修改花呗密码", "花呗自己充的钱怎么提出", "支付宝里如何切换成花呗支付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.141, 0.153, 0.086, 0.152, 0.095, 0.061], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗重新开通时间", "pos": ["我的花呗什么时候可以再开通回来", "我什么时候可以重新开启花呗"], "neg": ["蚂蚁借呗的借款记录怎么删", "系统平的花呗一万我可以定五千吗", "我的怎么找不到花呗在哪", "花呗逾期不能超过今天", "为什么我点花呗,进去显示系统繁忙"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.083, 0.139, 0.192, 0.18], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款未到账", "pos": ["说是退回了花呗但没看到钱进了哪儿了", "退到花呗,没看到进帐"], "neg": ["花呗拖一天还款", "淘宝花呗如何使用", "借呗还了怎么还有负面记录"], "pos_scores": [1.0, 0.95], "neg_scores": [0.073, 0.125, 0.195], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期还款后无法使用", "pos": ["花呗逾期还款了为啥还不能用", "花呗逾期还款以后怎么还不能用"], "neg": ["蚂蚁借呗额度提高", "花呗可不可以用美团", "要开通花呗线下收款怎样开通", "花呗最多能透支多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.19, 0.107, 0.196, 0.155], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还清后关闭方法", "pos": ["花呗全部还了 能不能关闭", "还款以后能不能关闭花呗"], "neg": ["花呗什么时候能分期", "花呗那有条信息说,当前剩余款***应还,是什么意思", "什么时候花呗能提前还款", "花呗用了就要还吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.077, 0.161, 0.165, 0.056], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗可用额度查询", "pos": ["花呗我能接多少钱", "我花呗现在多少钱"], "neg": ["没开通,怎么用了花呗的钱", "为什么我的花呗页面没有账单分期", "花呗还款可以不分期么", "商家蚂蚁花呗怎么收手续费", "用花呗按时还款的话要不要利息"], "pos_scores": [1.0, 0.95], "neg_scores": [0.176, 0.172, 0.099, 0.157, 0.191], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗强制开通方法", "pos": ["可以帮我开启花呗没", "帮我打开花呗好嘛"], "neg": ["杂样申通蚂蚁借呗", "支付宝已经绑定银行卡,为什么花呗还要绑定银行卡", "花呗还款日到了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.121, 0.111, 0.066], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "微信还花呗", "pos": ["花呗还款可以在微信还吗", "我想知道微信可以还支付宝上的花呗欠款吗"], "neg": ["这是退了的订单,花呗还有帐单要还", "花呗逾期,利息怎么计算", "花呗分期付款的利息多少", "花呗能买飞机票", "我不想用花呗了,想还款", "借呗银行卡体现放款多久到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.079, 0.183, 0.183, 0.188, 0.164, 0.134], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款无短信验证码", "pos": ["蚂蚁借呗还钱要手机验证码,我手机号码没有再用,所以收不到信息,应该怎么还", "借呗自助还款收不到短信验证码怎么办"], "neg": ["美团花呗用不了", "这个月,花呗还能不能用", "花呗为什么不自动缴电费了", "为什么查不到蚂蚁借呗的还款记录", "花呗能否预订酒店", "问,我的借呗当期还款怎么提前还", "集分宝可以还花呗借款吗", "花呗里的钱只能买东西吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.105, 0.174, 0.13, 0.074, 0.2, 0.145, 0.093, 0.108], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收钱开通", "pos": ["花呗,我给我开通,我现在急需要收钱", "我开通了花呗收钱了"], "neg": ["原来的邮寄的二维码可以用花呗吗现在", "为什么我现在还是无法开通花呗", "我如果全部还款完成了还可以用花呗分期购物吗", "花呗申请分期后,一次还清也需要还全部的手续费", "帮我花呗打开", "为什么在淘宝付款用了花呗,而在支付宝上显示未开通花呗", "芝麻信用够了,身份证绑定了,银行卡绑定了,为什么还不能申请蚂蚁借呗", "借呗会自动扣除余额宝欠款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.169, 0.165, 0.056, 0.154, 0.098, 0.092, 0.17, 0.151], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期影响芝麻分", "pos": ["花呗分期影响,芝麻分吗", "花呗分期付款 影响信用额度那"], "neg": ["我的借呗超过了***笔就不能结款了吗", "花呗可以支付滴滴打车的费用吗", "打开花呗总显示登录超时怎么回事", "花呗付款后退款会收手续费吗", "退款回到花呗的钱还需要还吗", "在淘宝和天猫上什么东西都可以用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.175, 0.194, 0.106, 0.083, 0.169, 0.123], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗手机号不一致", "pos": ["花呗和现在是一个手机号,不符,怎么办", "花呗和绑定支付宝的手机号不一致"], "neg": ["花呗付款了,可以提前还款吗", "花呗还款提醒怎么取消", "花呗还款是优先还最低的吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.065, 0.119, 0.086], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗需旧手机号登录", "pos": ["为何花呗还需使用旧号码登陆", "为什么我登录的是新手机号,花呗怎么让我登录旧手机号码使用花呗"], "neg": ["前几天花***花呗积分兑换的优酷会员在哪里可以领取", "退回花呗我还会不会有影响", "付的钱是花呗的,该怎么退款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.107, 0.157, 0.137], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗绑定抽奖", "pos": ["蚂蚁花呗抽奖机会", "绑定花呗,抽奖"], "neg": ["我的花呗额度是", "借呗借第二次用不用验证码", "没有人工客服吗", "双十二可以摇花呗额度吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.139, 0.104, 0.082], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "密付花呗支付", "pos": ["密付能用花呗支付嘛", "密付用花呗"], "neg": ["蚂蚁借呗有保证金吗", "花呗可以借钱", "可以给个借呗免息卷", "参加花呗的活动什么时候到账", "花呗降额度是什么原因"], "pos_scores": [1.0, 0.95], "neg_scores": [0.057, 0.138, 0.082, 0.098, 0.102], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗放款未到账", "pos": ["借呗提示已放款,但是银行卡没有收到钱", "借呗提到银行卡,也什么一天了都还没到"], "neg": ["更改开通花呗的支付宝", "借呗能提前还清吗", "花呗分期免息卷怎么积分兑换", "为什么淘宝买的东西付了款,支付宝花呗还要扣除一次"], "pos_scores": [1.0, 0.95], "neg_scores": [0.173, 0.091, 0.144, 0.195], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收款码贴纸更换", "pos": ["我以前有申请过的贴纸。现在申请了花呗收款需要重新再申请贴纸吗", "收到申请的收款码贴纸后再申请花呗收款用更换贴纸吗"], "neg": ["花呗在我另一个支付宝账号上我可以改在我这个账号", "花呗付款三次 ,显示二个订单。怎么回事", "如何撤销花呗这个软件", "为什么我的收钱码不能用花呗了", "我的借呗怎么开通了,现在又没有了,怎么开通"], "pos_scores": [1.0, 0.95], "neg_scores": [0.066, 0.175, 0.084, 0.175, 0.141], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗无法买车票", "pos": ["为什么花呗无法买车票", "花呗能买火车票ma"], "neg": ["花呗可以提红包吗", "花呗额度转账号", "花呗最低还款以后每日利息多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.112, 0.124, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款后未恢复额度", "pos": ["蚂蚁借呗钱已经还上了怎么还没有恢复额度", "我的借呗怎么还款了还没有恢复额度"], "neg": ["我的蚂蚁借呗几号还", "支付宝与淘宝上的花呗额不一样", "怎么提高花呗的信用额度", "更改花呗最低消费金额"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.149, 0.063, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款日修改", "pos": ["借呗还款日能换日期么", "借呗还款日期可以改不"], "neg": ["花呗自动抵扣红包", "为什么我的花呗额度不能提高", "我的花呗收钱怎解收不了钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.158, 0.116, 0.17], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗审核放款时间", "pos": ["申请借呗,多久能审核额度", "借呗多久审核完,能放款"], "neg": ["借呗必须还款日还款吗", "花呗信息补全", "借呗手机号码换了收不到验证码怎么办", "花呗分期取消后重新分期", "怎么花呗借款不能分***期了", "怎么样才能继续使用借呗", "我之前用了花呗分期,今天提前还款了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.186, 0.18, 0.193, 0.099, 0.053, 0.171, 0.063], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗开通失败原因", "pos": ["借呗开通不了", "为什么我开通不了借呗"], "neg": ["我用花呗付款成功,商家未收到?是什么情况", "另一个账户开通了花呗,怎么转到这个账户", "花呗的钱还完了。那退款怎么办", "怎么没蚂蚁借呗", "为什么给我关了借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.064, 0.055, 0.077, 0.119, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "朋友代还花呗", "pos": ["可以让朋友替还花呗吗", "花呗可不可以让朋友帮忙还款"], "neg": ["刚一笔还款,我银行卡有信息扣取成功,为何花呗没还成功", "更改花呗还款日期", "蚂蚁借呗可以恢复那", "问什么关闭我的借呗", "我的蚂蚁借呗刚才领了一个什么礼包。额度就涨到***了", "花呗开通后一定要使用吗", "花呗分期,如果提前还完,会扣后边的手续费吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.101, 0.112, 0.093, 0.137, 0.084, 0.149, 0.107], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗降额原因", "pos": ["花呗额度为什么会降下", "为什么我的花呗降额了"], "neg": ["花呗怎么是旧手机号", "蚂蚁花呗要什么分期付款不能一下还完", "我上个月花呗的账单是什么", "借呗逾期三天已经还了,为什么还显示有逾期", "怎么关闭首选花呗支付方式", "花呗用了最低还款算逾期吗", "你的,为什么我的花呗不能用,求解", "花呗最高分多少期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.065, 0.094, 0.169, 0.14, 0.152, 0.068, 0.127, 0.101], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款短信通知", "pos": ["借呗还款有短信通知吗", "借呗借钱短信通知"], "neg": ["花呗年货购物津贴有什么用", "蚂蚁借呗怎么解呀", "花呗没绑定银行卡能用吗", "借呗想分期还款", "花呗分期提前还款扣手续费吗", "我的支付宝账户被提示司法冻结!我的借呗还款怎么还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.095, 0.101, 0.191, 0.081, 0.107, 0.131], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无法确认借款", "pos": ["借呗输入借款金额无法点击确认", "借呗怎么无法确认借款了"], "neg": ["花呗逾期怎么办,会影响信用吗", "蚂蚁借呗怎么绑定网商银行", "手机号被停了 然后支付宝号登不上了 花呗逾期未还", "为什么我用了花呗分期后就一堆贷款的打电话加微信", "为什么每次点开花呗都叫我绑定银行卡", "蚂蚁借呗借***万的手续费是多少", "蚂蚁借呗服务密码是什么", "花呗额度还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.086, 0.158, 0.102, 0.175, 0.15, 0.161, 0.057, 0.061], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗额度恢复时间", "pos": ["我的借呗额度什么时候恢复,我希望快点", "我问借呗提前完款,什么时候恢复额度"], "neg": ["我的蚂蚁花呗怎么可用余额变少了", "把网商贷变成借呗", "蚂蚁借呗能不能***个月后再还款", "花呗消费了在哪里抽奖", "借呗总额度多少", "我想启用花呗,但是我的银行卡绑定的不是我本人的怎么办", "为什么花呗申请退款还没到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.066, 0.069, 0.147, 0.124, 0.149, 0.069, 0.051], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗无法付款", "pos": ["为什么我有花呗没办法付款", "为什么我这个花呗不能用"], "neg": ["为什么我的好花呗重复扣款", "怎么开通不了花呗了", "为什么蚂蚁借呗身份信息无法完善", "花呗返现是怎么", "花呗用多久能用借呗", "花呗一次分期后下个月还能分期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.156, 0.191, 0.099, 0.196, 0.083, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗开通不用收费", "pos": ["开通花呗不使用需要缴费吗", "如果我开通花呗,确一直没有使用,会扣费用吗"], "neg": ["我的这个月花呗可以分期吗", "我花呗***元,十一月一号已经还了怎么还显示没还款", "蚂蚁借呗怎么解呀", "我不想用花呗,更不想体验"], "pos_scores": [1.0, 0.95], "neg_scores": [0.163, 0.072, 0.053, 0.181], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗停用原因", "pos": ["怎么花呗使用不了", "花呗为什么停用"], "neg": ["我花呗怎么还款了,不开通", "为什么为花呗主页看不到我的账单", "信用卡没有限额,而花呗限额***", "有过逾期记录会影响借呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.117, 0.072, 0.173, 0.172], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭花呗服务", "pos": ["我不想用花呗,终止合同怎么操作", "花呗我不想使用"], "neg": ["为什么y每个人的花呗可用额度不一样", "问用了花呗付款,下个月是自动扣款吗", "花呗额度掉了怎么回事", "怎样弄,借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.147, 0.066, 0.114, 0.077], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗需要支付利息吗", "pos": ["花呗要利息口马", "花呗他需要付利息吗"], "neg": ["我不想用了怎样卸载花呗", "花呗交了话费怎么还没到账", "蚂蚁花呗分期如果提前还还需要手续费吗", "花呗能转到另一个支付宝吗", "这个账号花呗已经关闭,为什么另一个账号不能开通花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.125, 0.165, 0.132, 0.113, 0.078], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗可以申请临时额度吗", "pos": ["花呗可以提临额吗", "我的花呗额度可以申请调高一点吗"], "neg": ["要开通信用卡与花呗收款功能,但提示说我的账户不符合开通条件", "花呗额度怎么评估", "花呗帮还对以后提额有影响吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.054, 0.069], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款后为什么没有恢复额度", "pos": ["蚂蚁借呗按时还款了。但没有余额", "为什么借呗还款了没有可用余额"], "neg": ["企业支付宝可以使用花呗吗", "我的借呗什么时候降的额度", "现在这个账号的花呗怎么开通", "借呗到帐了,怎么钱还不到了", "蚂蚁借呗如果借***一个月的利息多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.125, 0.095, 0.183, 0.071, 0.1], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗和支付宝账号是否通用", "pos": ["花呗号和支付宝账号不一样", "支付宝花呗和淘宝花呗不是一个账号"], "neg": ["开通了另一个支付宝怎么我的花呗开通不了", "再次解冻花呗", "我想取消花呗怎样才能取消", "我的花呗之前绑定了一个邮箱密码和账号都忘记了怎么办", "花呗每月最低还款额度怎么看", "花呗的金额可以冻结吗", "顾客花呗付款卖家收不到钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.055, 0.142, 0.165, 0.097, 0.118, 0.101, 0.078], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款日期可以修改吗", "pos": ["蚂蚁花呗是可以自己设置还款时间吗", "蚂蚁花呗的还款日期可以调整么"], "neg": ["之前用完花呗有一次抽奖的机会现在还有吗", "借呗提前全额还款了,怎么没有恢复额度", "花呗每天只能用一次吗", "交电费是不是欠费了才可以用花呗", "可以提前还完借呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.117, 0.172, 0.071, 0.058, 0.12], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期还款会影响信用吗", "pos": ["花呗晚两天还款,会对信用有影响吗", "花呗还了最低后 剩下的不及时还 会影响个人信用吗"], "neg": ["卖家说花呗分期付款,是按每月付款吗", "为何我开不了蚂蚁借呗", "花呗负额度分期还款后还能用吗", "能不能提前还花呗分期", "***号退款金额退到花呗里,可是已经还完款了", "借呗怎么不自动扣款的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.056, 0.143, 0.142, 0.154, 0.198, 0.074], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "商家支持花呗但我无法使用", "pos": ["为什么商家开通了花呗服务 我却无法用花呗付款", "为什么商家开启了花呗支付,到我这里不能用花呗"], "neg": ["花呗还有余额,为什么刷不出来", "花呗分期可以分别结清吗", "淘宝蚂蚁花呗怎么开通", "蚂蚁借呗额度消失两年", "我指的是蚂蚁借呗怎么还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.117, 0.134, 0.123, 0.161, 0.163], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款后仍收到催款通知", "pos": ["借呗还款了还在催我还款是怎么回事", "借呗昨天还过了,怎么还显示还款"], "neg": ["没钱还给花呗", "网商银行也不能用", "余额宝里有钱直接还蚂蚁借呗可以吗", "蚂蚁花呗分期后如果提前换该期的欠款,那还会扣利息吗", "为什么不能申请借呗", "怎么查看对方有没有花呗开通", "如果不使用手机需要关闭花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.198, 0.108, 0.092, 0.086, 0.197, 0.099, 0.095], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账单和还款金额不一致", "pos": ["我用花呗,账单与实际还款不一样", "我已还款的数额与花呗显示的数额不一致"], "neg": ["花呗可以微信发红包吗", "点解在我支付宝扣处了,花呗又扣我的支付宝", "美团订单花呗扣款了 但是美团显示未付款", "我之前用了花呗买手机,按时还完了,为什么这一年都没提升", "花呗邮箱如何登录"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.175, 0.194, 0.086, 0.067], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期提前还款是否收费", "pos": ["假如现在手头有点紧,***号还不了花呗,如果选择分期这个月底前一次还清还收取额外费用吗", "上个月花呗分期,这个月全部还清应该不收取任何费用吧"], "neg": ["我花呗多还了两次", "为什么我开通不了蚂蚁借呗,花呗也是***提不了额度", "不是花呗是收钱码", "现在支付宝没有借呗了吗", "***借呗一个月忘还推迟了二天利息要多少", "为什么我花呗退回了金额,怎么未出账的金额没有变", "为什么我的芝麻借呗不开", "花呗每月充值不了话费"], "pos_scores": [1.0, 0.95], "neg_scores": [0.083, 0.153, 0.127, 0.094, 0.147, 0.123, 0.059, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗突然被关闭原因", "pos": ["为什么我的蚂蚁借呗突然封了", "为啥借呗会被封了"], "neg": ["花呗,买车险", "蚂蚁借呗自动扣款默认开通在哪里", "我的花呗无法使用了怎么办", "蚂蚁花呗额度冻结", "我花呗开通了吗", "天猫怎么不能用花呗", "想开通商家版花呗", "蚂蚁借呗怎么知道是否还清了最低"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.166, 0.169, 0.199, 0.083, 0.114, 0.199, 0.106], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "如何开通借呗", "pos": ["借呗开通在哪里开通", "借呗能不能给我开通"], "neg": ["花呗充错号码了", "花呗逾期了还完了不能用了", "花呗的补充学历填错了怎么办", "借呗审核未通过怎么解决", "我花呗还清款", "蚂蚁借呗可以取现金吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.136, 0.121, 0.101, 0.117, 0.14, 0.118], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗无故扣款原因", "pos": ["未买东西怎么自动从花呗里扣款", "我没用花呗买东西咋扣费"], "neg": ["借呗 银行卡错了怎么办", "这个没有付款,但花呗叫我还款", "密付可以使用花呗支付吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.129, 0.146, 0.111], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款收不到验证码怎么办", "pos": ["还款花呗,手机停机收不到短信", "花呗还款要手机验证码,我手机号码没用,收不到怎么办"], "neg": ["花呗怎么在淘宝使用", "假如说我的花呗逾期一天,会怎样,有没有负面影响", "借呗还了之后能再借吗", "到期蚂蚁花呗会自动扣款吗", "借呗可以越期吗", "蚂蚁借呗逾期***月有多大影响", "为什么我的花呗还了,然后我我买东西要申请退款,钱原来退回花呗,有什么看不到***块多的,那哪里去了", "花呗没有临时额度为什么可以为负"], "pos_scores": [1.0, 0.95], "neg_scores": [0.059, 0.054, 0.116, 0.069, 0.148, 0.074, 0.077, 0.189], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收款码失效", "pos": ["蚂蚁花呗收款码不能收钱了", "怎么现在不能花呗收款"], "neg": ["花呗冻结了怎么办", "借呗说已经到账。银行卡并没有收到短信", "***月*** 有没有临时花呗额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.096, 0.072, 0.071], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期利息计算方式", "pos": ["花呗逾期一两天会影响什么?利息是多少", "花呗逾期***天利息怎么算"], "neg": ["花呗多还钱了,怎么搞的", "借呗可以推迟一期还款吗", "进入花呗总是显示服务繁忙怎么回事"], "pos_scores": [1.0, 0.95], "neg_scores": [0.173, 0.058, 0.083], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗借款期限选项", "pos": ["借呗借款期限只有***个月和***个月", "借呗怎么只有六个月十二个月的周期选项"], "neg": ["为什么确认收货了 花呗没账单", "蚂蚁借呗逾期几天罚息", "我的是花呗,是极速退款的", "如何改蚂蚁花呗的还款日期", "借呗完善不了信息", "小黄单车押金用花呗支付 退款到哪里"], "pos_scores": [1.0, 0.95], "neg_scores": [0.171, 0.073, 0.101, 0.167, 0.093, 0.193], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗借呗消失原因", "pos": ["花呗借呗都不见了", "我的蚂蚁花呗和借呗怎么不见"], "neg": ["商家收款 支付方用花呗支付 现在没收到款", "我用了花呗付款怎么换钱", "我为什么没有花呗呀", "蚂蚁借呗分期了之后还能分期吗", "借呗借钱了,影响以后在银行贷款吗", "成年当天开通花呗可以吗", "借呗礼包在那里看", "我想把花呗里面的钱,转到支付宝里面能吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.078, 0.099, 0.088, 0.19, 0.196, 0.063, 0.168, 0.155], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭花呗后能否重新开通", "pos": ["花呗关了还能开吗", "花呗关闭之后。还可以再次开通吗"], "neg": ["为什么我的借呗借不借了", "怎么样申请花呗的临时额度", "花呗还款金额错误怎么办", "这一次借呗的红包有效期", "一句话,不耽误你时间,我取消这次交易,永远关闭花呗", "我的花呗逾期了,我现在还款了,花呗还可以使用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.083, 0.084, 0.125, 0.118, 0.178, 0.174], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无法开通原因", "pos": ["蚂蚁借唄怎么补补不能开通", "蚂蚁借呗 我为什么开通不了"], "neg": ["是不是我这种手机号不是自己的不能开通花呗", "带有花唄标示的是什么样的", "花呗分期有哪几种", "借呗全款还", "花呗是怎么付款的,是不是直接从支付宝付钱", "蚂蚁借呗开启和信用分数有关系吗", "花呗十号还款怎么十号没过就显示我逾期了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.092, 0.058, 0.116, 0.131, 0.099, 0.072, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账单周期说明", "pos": ["花呗记账周期", "花呗账单是***号到***号为一周期吗"], "neg": ["芝麻信用和和花呗怎么重新授权", "花呗收款吗开通之后什么时候能支持花呗信用卡收款", "借呗手动还款不是还分期的", "十二点以过为什么借呗还没扣款,银行卡己开通网银"], "pos_scores": [1.0, 0.95], "neg_scores": [0.183, 0.177, 0.071, 0.111], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查询花呗开通状态", "pos": ["我的账号花呗开通了吗", "现在我的账户没有开通花呗吗"], "neg": ["蚂蚁借呗超过还款,日中午***点,有影响吗", "花呗上个月有分期,这个月要取消分期付款,还全款可以吗", "现在花呗提示我不够权限,下个月就又可以开通了,这些反复的这怎么样", "身份证过期对借呗开通有影响吗", "app怎么退到花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.137, 0.067, 0.17, 0.061, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗能否用于外卖支付", "pos": ["花呗可以在支付宝里点外卖嘛", "外卖商家可以使用花呗嘛"], "neg": ["花呗临时额度的有效期在哪", "花呗可用数目是怎么算的", "借呗利息分期是算一共借多少吗", "为什么我的借呗不能循环使用", "蚂蚁借呗没通过啥意思"], "pos_scores": [1.0, 0.95], "neg_scores": [0.156, 0.192, 0.089, 0.058, 0.196], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "网商贷切换回借呗方法", "pos": ["不想要网商贷怎么变回借呗", "怎样把网商贷换回借呗"], "neg": ["借呗还款日都没有短信通知", "怎么改借呗借款限额", "花呗没有花过***多这么多钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.108, 0.135, 0.126], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗延期还款政策", "pos": ["蚂蚁借呗能延长还款时间吗", "借呗可以延后还款吗"], "neg": ["用花呗买的东西没到没账单怎么还款", "为什么花呗还不能用", "花呗***号可以还款", "花呗每月按时还款,收息吗", "借呗可以申请吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.191, 0.059, 0.131, 0.148, 0.124], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款日次日还款影响", "pos": ["我的蚂蚁借呗可以明天还吗 今天是还款日", "借呗还款日,今天到期,我明天还,会有影响吗"], "neg": ["我的花呗现在不可以用了,那么以后还可以用吗", "蚂蚁借呗首月免息吗", "我原来的花呗怎么登不上去", "我蚂蚁花呗的借呗都不可以用", "花呗刷脸通不过"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.135, 0.162, 0.186, 0.077], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗消费还款日期查询", "pos": ["一号花的花呗什么时候还款", "本月一号使用花呗,次月几号还款"], "neg": ["怎么更换蚂蚁借呗银行卡", "为借呗申诉", "花呗临时怎么提额", "我的借呗什么时候可以恢复", "我借呗借了五万可以分期付款吗", "昨天在借呗借了一百元!今天怎么还没到卡里", "添加店员可以花呗收款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.106, 0.178, 0.08, 0.189, 0.173, 0.08, 0.052], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "美团无法使用花呗支付", "pos": ["花呗用不了美团是怎么回事", "我美团点好餐,花唄怎么支付不了"], "neg": ["不记得是不是花呗付的款", "我是商家想查询花呗服务是否开通", "为什么我的花呗额度是***,但是刷不了", "经常使用支付宝,可是为什么开通不了花呗", "蚂蚁借呗还款日期能改吗", "借呗还款没到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.123, 0.133, 0.099, 0.17, 0.134], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查看花呗还款明细", "pos": ["如何查看花呗还款的具体", "怎么查看还清花呗"], "neg": ["花呗可以先下使用吗", "为什么我的借呗利息是万五呀,我朋友的就是万三", "我想开通下花呗给申请下呗", "借呗什么时候恢复评估", "花呗分期还款可以提前还", "合同违约花呗", "蚂蚁借呗的钱能买房"], "pos_scores": [1.0, 0.95], "neg_scores": [0.139, 0.106, 0.182, 0.179, 0.052, 0.179, 0.132], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "隐藏借呗入口方法", "pos": ["可以把借呗这个标栏隐藏起来,别人看不见", "借呗工具栏能隐藏起来吗"], "neg": ["提前还款还是算上分期付款那些账单的利息吗", "为什么我的花呗额度能为负", "花呗还款分期了可以改吗", "为什么花呗交电费了马上就退款", "蚂蚁借呗还款日期是怎么计算的", "借呗今日借款哪天还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.057, 0.153, 0.094, 0.082, 0.139, 0.077], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗分期还款金额计算", "pos": ["蚂蚁借呗借***一个月还多少", "蚂蚁借呗借******个月还。一个月还多少"], "neg": ["为什么我的花呗当时借了我就还了 怎么现在还要我还", "蚂蚁借呗提示当前存在逾期", "为什么我的花呗可用度只有***", "今天借呗还款,怎么余额宝里没有扣款", "开通花呗显示手机验证未通过", "如何删除花呗里面以出账的记录", "支付宝小黄车押金用不了花呗付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.102, 0.188, 0.084, 0.119, 0.178, 0.125, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "永久关闭借呗流程", "pos": ["我想永久关闭蚂蚁借呗", "永久关闭借呗"], "neg": ["完善蚂蚁借呗信息", "二唯码可以收款花呗吗", "借呗最多能上升到多少", "借呗可以借几期", "怎么可以再开蚂蚁借呗", "为什么商家支持花呗分期!我确不能分期", "花呗显示为负数,还可以使用么", "花呗红包用手机支付宝付款时为什么没有自动抵扣"], "pos_scores": [1.0, 0.95], "neg_scores": [0.15, 0.069, 0.151, 0.179, 0.077, 0.135, 0.135, 0.053], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款操作指南", "pos": ["要还花呗的钱,怎么操作,昨晚用的", "我要花呗还款,怎么操作"], "neg": ["为什么我的花呗额度被调低了那么多", "借呗无法一次还清还能分期吗", "能查查我欠不欠花呗里面的钱吗", "花呗可以透支分期吗", "借呗每月几号评估一次额度", "怎么查看我的花呗账单", "花呗还款日期每个人都一样吗", "如何查询还款借呗帐单"], "pos_scores": [1.0, 0.95], "neg_scores": [0.087, 0.067, 0.181, 0.063, 0.115, 0.126, 0.059, 0.095], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗关闭后退款处理", "pos": ["花呗里有退款,但是后来我关闭了,退款怎么办", "花呗已经关闭,如果有退款怎么办"], "neg": ["为啥我信用那么高,我的花呗和借呗那么点", "花呗怎么转移手机号", "之前不需要钱用就有借呗", "为什么花呗未还显示为***"], "pos_scores": [1.0, 0.95], "neg_scores": [0.157, 0.16, 0.064, 0.088], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗免息券使用规则", "pos": ["我有一个***天免息的借呗券", "借呗***天免息券"], "neg": ["我蚂蚁借呗按期还完之后额度怎么没有回复", "怎么找回另外一个有花呗的账号", "我的钱怎么去了花呗?怎么弄的", "我想知道花呗的未出账怎么查详单", "为什么在支付宝上设置了花呗在淘宝上用不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.104, 0.14, 0.159, 0.192, 0.154], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗缴纳电费限制", "pos": ["缴费电费不能用花呗", "花呗不可以付电费了吗"], "neg": ["花呗为什么没有体验额", "花呗账单金额是什么", "我的蚂蚁借呗利息去哪里看", "我的花呗分期付款的,,为什么不管来借呗", "我用花呗分期付了", "戒备怎么关闭"], "pos_scores": [1.0, 0.95], "neg_scores": [0.126, 0.181, 0.087, 0.081, 0.122, 0.136], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "更改花呗绑定支付宝账号", "pos": ["花呗绑定的支付宝能改吗", "花呗换绑定的支付宝"], "neg": ["显示,钱已经退到花呗里了", "花呗消费的账单被删掉了,怎么找回来", "借呗要时逾期一天怎么办", "借呗额度还清整笔之后恢复吗", "花呗不能使用美团了么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.119, 0.076, 0.068, 0.085], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗充值加油卡支持", "pos": ["花呗充加油卡吗", "可以使用蚂蚁花呗充值加油卡吗"], "neg": ["为什么花呗还款日还没到 就显示我逾期了", "怎么找不到花呗签约", "电脑端怎么设置花呗还款顺序"], "pos_scores": [1.0, 0.95], "neg_scores": [0.107, 0.082, 0.184], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款去向查询", "pos": ["拿花呗支付的退款后钱是退到哪儿", "花呗买的东西,退款就退到哪了"], "neg": ["为什么我点击花呗应还款这里没有反应,点击可用额度有反应", "为什么显示我没有花呗", "为什么我不可以办花呗,什么原因", "我的蚂蚁借呗刚才领了一个什么礼包。额度就涨到***了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.197, 0.059, 0.067, 0.136], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "芝麻信用分不提升原因", "pos": ["芝麻信用分怎么提升不了", "为啥不加芝麻分了"], "neg": ["换新手机号码怎么授权花呗", "花呗开通不用有额外费用吗", "我在淘宝上买东西用花呗钱了然后我申请退款是不是退花呗里面", "我的花呗还能开通了么", "信用卡花呗收取服务费是什么", "没有人工客服吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.139, 0.05, 0.195, 0.098, 0.131, 0.101], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗额度与芝麻分关系", "pos": ["花呗额度与芝麻分有关系吗", "信用分跟花呗有关吗"], "neg": ["花呗如何一次还清分期付款", "我要开通花呗 买手机", "我欠款已经还清了,淘宝上退货的款直接退到花呗,这个款在哪看到"], "pos_scores": [1.0, 0.95], "neg_scores": [0.111, 0.178, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗不显示解决方法", "pos": ["借呗也不显示了", "我为什么不显示蚂蚁借呗"], "neg": ["花呗自主还款最晚几点", "点了分***期 钱从花呗一次性扣的", "我从借呗借的钱,光显示放款中,什么时候能到账", "花呗分期能修改呗", "帮我查询一下 花呗什么时候可以用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.194, 0.124, 0.087, 0.131], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无法提现原因", "pos": ["嗨为什么不借呗提现了", "为什么我支付宝借呗不能取款了"], "neg": ["花呗申请后一次性还清", "缴费电费不能用花呗", "还有花呗不涨", "怎么看以前的借呗记录"], "pos_scores": [1.0, 0.95], "neg_scores": [0.198, 0.181, 0.15, 0.145], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗转账到其他账户", "pos": ["我想把花呗转到这个支付宝账号上,怎么操作", "花呗怎么转到另个账户"], "neg": ["另一个不常用的支付宝以前开通了花呗,我刚刚关闭了,我想开通在现在用的多的支付宝号上", "花呗未出账单不可以分期", "花呗关闭后,再开启里面的余额还是一样吗", "我用花呗可以用多少钱", "为什么我买的时候在余额里扣钱了,现在花呗还让我还钱", "如何主动关闭蚂蚁借呗", "如何清除所有花呗交易记录", "星力超市可以用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.114, 0.097, 0.114, 0.053, 0.148, 0.089, 0.09, 0.127], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "退款退回花呗原因", "pos": ["为什么把钱退到花呗", "给我退到花呗上了"], "neg": ["花呗额度不是最低***吗", "我强,,蚂蚁借呗和蚂蚁微贷是不是一个产品", "花呗手机号多少", "临时额度用完后会不会对以后使用花呗有影响", "多少额度才能开蚂蚁借呗", "借呗额度提升流程", "有借呗入口", "花呗分期购提前还款收不收手续费"], "pos_scores": [1.0, 0.95], "neg_scores": [0.162, 0.148, 0.147, 0.071, 0.068, 0.068, 0.19, 0.082], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "登录指定支付宝使用花呗", "pos": ["您可登录支付宝号使用花呗", "您可登录某支付宝使用花呗"], "neg": ["我的花呗到期后会从余额宝自动扣款吗", "怎么样才能得到花呗临时额度", "比如我今天月末最后一天用的花呗下个月月初才发货那我的花呗是下个月还还是下下个月还", "我花呗还清款", "花呗还款可以选择用银行卡吗", "花呗买单限额"], "pos_scores": [1.0, 0.95], "neg_scores": [0.147, 0.095, 0.191, 0.182, 0.169, 0.059], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "快速提升花呗额度方法", "pos": ["怎么快速提高花呗额度", "蚂蚁花呗可用额度怎么样才能提高"], "neg": ["收款花呗有没有限制", "我的借呗还款日是***月***日", "我的花唄額度", "借呗分期还可以更改时间么", "花呗钱已还清,退回花呗的钱为什么不还给我", "花呗授权在哪"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.158, 0.119, 0.125, 0.064, 0.054], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭网商贷流程", "pos": ["我把网商贷取消", "我把网商关掉了"], "neg": ["花呗临时额度为什么用", "为什么我现在的支付宝用花呗要用以前的号码", "显示,钱已经退到花呗里了", "借呗借来的钱可以还花呗吗", "为什么同一个手机号登陆的花呗还显示要手机号登陆", "开通了花呗,收款收不了款是什么原因"], "pos_scores": [1.0, 0.95], "neg_scores": [0.15, 0.058, 0.14, 0.052, 0.165, 0.082], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗最新活动查询", "pos": ["今天借呗有活动", "借呗里面的活动"], "neg": ["我***点左右在借呗借的钱怎么现在还没到账", "账单都还清楚了!怎么花呗不能用了", "花呗抽了一个天猫券 不知道从哪里再找出来", "商店或超市购物可以用花呗支付吗", "我现在怎么开通花呗", "我的花呗是昨天还的,今天可以使用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.182, 0.172, 0.18, 0.139, 0.117], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "无法登录花呗账户解决", "pos": ["您说的是您登录不上您开通花呗的 支付宝账户", "您说的是您登录不上您开通花呗的"], "neg": ["手机丢失导致借呗逾期了。现在怎么办", "借呗最低还款还了吗", "借呗里面还有三百", "花呗分期免息卷在哪里兑换", "为什么花呗不能滴滴打车和美团了", "蚂蚁借呗进不去怎么还", "借呗这么变成网商货了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.172, 0.188, 0.111, 0.068, 0.163, 0.16, 0.085], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗提前还款操作", "pos": ["蚂蚁借呗怎么每期提前还", "借呗能提前还清吗"], "neg": ["所以请把花呗关掉", "我这个是花呗红包", "蚂蚁借呗三个月了不提额", "为什么我绑定银行卡了 不能使用蚂蚁花呗", "花呗有额度但是淘宝显示花呗分期暂不可用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.114, 0.082, 0.103, 0.095, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗利息计算公式", "pos": ["蚂蚁借呗利率怎么计算", "详细说算下蚂蚁借呗利息"], "neg": ["我的蚂蚁借呗还没通过", "我为什么不可以蚂蚁借呗", "蚂蚁花呗有问题,我明明已经还款了,为什么还显示我没有还", "为什么总让我用另一个支付宝用花呗", "我的借呗可以通过吗", "为什么我花呗只能扫***", "花呗提示我可以提额是怎么回事", "为什么我的花呗的额度半年都没有升过了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.124, 0.067, 0.151, 0.089, 0.11, 0.184, 0.179, 0.105], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "提高借呗额度方法", "pos": ["我怎么才能提高借呗金额", "借呗额度怎样提升"], "neg": ["花呗分期和不分期不一样吗", "我是用建行卡付的款,他却退到了花呗上", "刚才一家店里吃了拉面 我用花呗付过款了 为什么人家没收到钱", "怎么再次用借呗", "花呗最低还款的利息咋算"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.123, 0.092, 0.18, 0.099], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭新开通的花呗", "pos": ["怎么关闭刚开的花呗", "花呗怎么关掉我不用花呗了"], "neg": ["我的蚂蚁借呗最迟到几号还款", "银行留的手机号已经改了,但是为什么支付宝借呗的验证码还是原来的手机号", "我的花呗还款日期可以改为***号吗", "我用我的微信付款了咋就成了还花呗的钱那我的钱去那里了", "我的花呗之前逾期了,现在什么时候能用", "为什么我蚂蚁借呗给我取消啦", "我购物刷的都是银行卡,为什么花呗还有账单?我购物刷的都是银行卡,为什么花呗还有账单?我购物刷的都是银行卡,为什么花呗还有账单?我购物刷的都是银行卡,为什么花呗还有账单", "调花呗快捷支付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.179, 0.198, 0.194, 0.149, 0.068, 0.107, 0.054, 0.071], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "重新开通花呗流程", "pos": ["开通下花呗", "从新开通蚂蚁花呗"], "neg": ["现在蚂蚁借呗还可以开通吗", "刚刚提示了关于借呗提额了", "花呗是到期自动从银行卡或余额宝扣款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.072, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗开通条件不满足", "pos": ["为什么说不满足花呗开通条件", "开花呗,显示不符合条件"], "neg": ["花呗分期低一点", "请把蚂蚁借呗发过来", "借呗能分期还嘛", "蚂蚁借呗可以申请", "花呗能帮我开一下吗", "花呗已经激活"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.056, 0.155, 0.161, 0.078, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "符合条件但无法开通借呗", "pos": ["借呗开通的条件都满足了 为啥开不了", "为什么我不能开通借呗呀"], "neg": ["网贷商和蚂蚁借呗", "***元花呗分***期还款每个月还多少钱", "借呗还款日能否更改", "我在蚂蚁借呗上借的钱,然后转到银行卡里了,但是我的卡挂失了,我的钱去哪了", "以前能用花呗交电费,现在不能了", "蚂蚁借呗多久放款", "除了蚂蚁借呗,还有可以借钱", "为什么我的花呗显示暂时无法使用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.086, 0.159, 0.094, 0.148, 0.065, 0.081, 0.114], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗和花呗区别", "pos": ["借呗和花呗有冲突吗", "借呗和花呗都一样的吗"], "neg": ["积分要怎样才可以兑换蚂蚁花呗额度", "额度怎么会降定借呗", "修改花呗手机怎么吗", "花呗如何返现花呗如何返现花呗如何返现", "借呗什么时候可以分期还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.176, 0.19, 0.19, 0.106, 0.129], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "网商贷与借呗关系", "pos": ["网商贷和借呗是一个平台吗", "网商银行和借呗是一起的吗"], "neg": ["为啥不能花呗账单分期", "我可以用微信还花呗钱嘛", "花呗消费会有什么提醒", "能不能不要网商贷要借呗", "花呗分期后一次还清需要手续费么", "花呗分期可以提前还款完吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.086, 0.083, 0.174, 0.107, 0.098, 0.113], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗支付抽奖活动", "pos": ["花呗支付成功,怎么抽奖", "花呗还能抽奖"], "neg": ["商家支持花呗我不能用花呗", "然后还款的时候我忘了是用花呗退款的", "花呗临时额度怎么是自动领取的", "为什么我花呗用不了了", "蚂蚁借呗***月份我还了吗", "为什么我的借呗申请人太多了,叫明天在来"], "pos_scores": [1.0, 0.95], "neg_scores": [0.176, 0.17, 0.185, 0.197, 0.087, 0.084], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "滴滴支持花呗支付吗", "pos": ["花呗能否滴滴", "滴滴花呗支付"], "neg": ["借呗中可以分十二期还吗", "淘宝退款退回花呗钱怎么看不到", "收钱码开通信用卡花呗付款", "但我付款时,却让我开通花呗,否则付不了款", "借呗的借款记录怎么删除"], "pos_scores": [1.0, 0.95], "neg_scores": [0.191, 0.102, 0.142, 0.088, 0.064], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "退款转入花呗原因", "pos": ["为什么我的退款***元会转入花呗", "小蜜,退款怎么会转花呗里"], "neg": ["我没有逾期还款,也一直用花呗,信用也很好,为什么额度那么少", "我是还蚂蚁借呗", "怎么我还不上借呗呀", "收款码收花呗", "我再办一张银行卡可以开通花呗吗", "借呗逾期***天有事情吗", "借呗提现到卡怎么查"], "pos_scores": [1.0, 0.95], "neg_scores": [0.127, 0.097, 0.058, 0.126, 0.146, 0.067, 0.1], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "提升花呗额度方法", "pos": ["我的花呗额度能调高吗", "花呗额度怎么提升"], "neg": ["昨天晚上还花呗 银行卡已经扣款 但是花呗没有显示成功", "花呗可不可以还款", "我的蚂蚁花呗已经还最低还款 为什么还自动扣款", "花呗可以调临时额度吗", "为什么花呗不能滴滴打车和美团了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.095, 0.165, 0.094, 0.127, 0.132], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查询借呗账单", "pos": ["查一下蚂蚁借呗的账单", "我的借呗账单在哪里找"], "neg": ["为什么我的收钱码不能用花呗", "其他付款方式怎么转到花呗", "预期的钱还完了还能使用借呗吗", "花呗分期如何取消", "开通花呗后不使用会产生费用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.138, 0.138, 0.132, 0.119, 0.186], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "个人花呗收款资格", "pos": ["我现在可以向客户花呗收款吗", "个体可以花呗收款吗"], "neg": ["***月一号用的花呗,什么时候还", "如何完善借呗额度申请资料", "意思说支付宝里的花呗,我不可以用现在的支付宝还款吗", "花呗消费返还是什么", "使用购物津贴,还能用花呗吗", "全款还了蚂蚁借呗能恢复吗", "有了网商贷还能有借呗么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.079, 0.114, 0.199, 0.188, 0.187, 0.087, 0.183], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "恢复借呗额度方法", "pos": ["我的借呗什么时候恢复额度", "怎么样才能恢复借呗额度"], "neg": ["花呗冻结什么时候才能使用", "蚂蚁借呗还款日主动还款怎么还", "不是本人身份证可以开通花呗嘛", "我,修改借呗还款日", "花呗能设置还款提醒吗", "我能不能撤消开通花呗", "广西农村信用卡可以开通花呗吗", "花呗可以***号还钱吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.132, 0.131, 0.197, 0.157, 0.127, 0.095, 0.162], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗短期借款选项", "pos": ["借呗没出现短期借款", "借呗短期在哪选"], "neg": ["小蓝单车以前用花呗交的押金,花呗还了,退款退回哪", "花呗,如果到期无法还款可以延期还款吗", "花呗返现***元"], "pos_scores": [1.0, 0.95], "neg_scores": [0.109, 0.08, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "有贷款资格但无法开通借呗", "pos": ["为什么我有现金贷款资格 借呗开通不了", "为什么我开不了借呗呀"], "neg": ["我以开通花呗收款,为什么客户还是用花呗付不了款", "借呗上月***日借的,什么时候还款", "借呗为什么不能借一年了", "已经关闭了花呗怎么另外一个账户还让我到这个账户来看", "花呗可以接到哪些银行卡机"], "pos_scores": [1.0, 0.95], "neg_scores": [0.128, 0.107, 0.152, 0.148, 0.156], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账单不能分期原因", "pos": ["花呗账单为什么不能分期还款", "花呗账单不可分期了吗"], "neg": ["登录了支付宝,还是登录不了花呗", "我的花呗为什么没有最低还款", "花呗最低还款剩余部分可以提前结清吗", "我花呗里的钱充话费充错了能退回吗", "怎么开了借呗", "实名认证过期了蚂蚁借呗就不能用了吗", "如何申请能用花呗的收钱码"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.148, 0.05, 0.089, 0.175, 0.076, 0.059], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无法借款原因", "pos": ["借呗***为什么不可以借", "天天用支付宝 为什么借呗用不了"], "neg": ["什么时候给我恢复花呗嘛,我真的很需要这个功能", "我明明付钱了,为什么确认收货后花呗上还要多次付款", "花呗,余额可以用吗", "我的借呗怎么最低分期***个月", "借呗怎么改自动还款账号"], "pos_scores": [1.0, 0.95], "neg_scores": [0.076, 0.163, 0.16, 0.169, 0.119], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗额度评估周期", "pos": ["蚂蚁借呗多久重新可以评估额度", "蚂蚁借呗能提额多久能评估"], "neg": ["也就是我不用花呗了", "支付宝交电费怎么不支持花呗了", "花呗添加银行卡"], "pos_scores": [1.0, 0.95], "neg_scores": [0.055, 0.077, 0.143], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款日说明", "pos": ["花呗还款日为***号或***号,有可能是***号,请以页面提示为准", "花呗还款日 一般是每月的***号或者***号哈 具体"], "neg": ["花呗换前结清", "花呗用不了,多久能使用", "为什么我已经推过了蚂蚁花呗上面还是有金额产生", "蚂蚁借呗没有去还款", "花呗密码和支付密码能不一样吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.064, 0.188, 0.19, 0.082, 0.191], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款后仍需还款", "pos": ["都已经退款了,为什么花呗上还是有这个要钱要还", "已退款到花呗,怎么每个月还要还"], "neg": ["花呗支付快,还是余额宝支付快", "花呗的钱怎么还不上去", "为什么我的支付宝账号没法开通花呗收钱", "蚂蚁花呗不是免费体验的吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.109, 0.115, 0.136, 0.185], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账单金额不符", "pos": ["为什么我花呗数对不上", "为什么我的花呗三月有***要还"], "neg": ["如何查花呗账单看还款情况", "我可以使用支付宝花呗吗", "使用花呗分期付款,退款后依然需要付手续费", "借呗最最低还多少", "如果蚂蚁花呗脱期了分期还有用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.131, 0.063, 0.117, 0.165, 0.075], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期额度不足", "pos": ["为什么我的花呗分期额度显示余额不足", "花呗分期总是提示额度不足是什么意思"], "neg": ["为什么付款了之后花呗还要收我的钱", "花呗分期购物的花呗额度需要大于物品金额么", "已经付款到花呗,为什么还有", "用花呗充值苹果商店", "商家花呗额度怎么提高", "花呗做了分期可以提前还吗", "怎么显示不了借呗金额了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.072, 0.16, 0.06, 0.121, 0.094, 0.101, 0.101], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗放款时间", "pos": ["借呗的钱什么时候到", "为什么在借呗借钱要那么久还不到账"], "neg": ["花呗帮还,我能有积分吗", "如果上次花呗一次性还完,还可以在分期吗", "借呗要满多少分可以借"], "pos_scores": [1.0, 0.95], "neg_scores": [0.108, 0.148, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款提醒设置", "pos": ["借呗有还款信息提醒吗", "借呗还款不提醒吗"], "neg": ["花呗交电费付不了款", "花呗的钱可以还款", "退款了花呗还款不变", "怎么突然花呗少了这么多"], "pos_scores": [1.0, 0.95], "neg_scores": [0.194, 0.08, 0.2, 0.15], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗借款期限变更", "pos": ["借呗的期限怎么变短了", "借呗可以借***期,现在怎么只可以借***期了"], "neg": ["花呗逾期一天咋办", "花呗有支付宝有分期付款的车吗", "花呗每日限额使用么", "我是商家能能用花呗收款吗", "花呗怎么能调额", "怎样用支付宝花呗买淘宝东西", "借呗还款 我可以一起还吗", "卖东西花呗图标是什么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.076, 0.084, 0.155, 0.179, 0.1, 0.147, 0.111, 0.098], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "降低借呗利息方法", "pos": ["蚂蚁借呗怎么把利息降低", "借呗怎么降低利息"], "neg": ["几号自动扣除花呗", "蚂蚁借呗也能提前还款", "我的花呗怎么降低额度了", "为什么我的花呗不能用了?还有借呗", "花呗用户立减***元", "也弄删除花呗账单", "我的蚂蚁借呗现在能不能提额", "蚂蚁花呗可以找朋友代付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.123, 0.103, 0.129, 0.172, 0.119, 0.063, 0.073], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗分期选项消失", "pos": ["为什么借呗不能十二个月分期还款了", "蚂蚁借呗***期还款怎么不能选了"], "neg": ["从借呗借款影响从银行贷款吗", "我刚刚是不是用花呗消费了", "我的花呗提醒还了,但又没扣银行卡钱", "使用借呗会影响银行贷款吗", "蚂蚁借呗的记录了", "我申请了退款怎么花呗额度没有恢复", "我的花呗怎么显示的以前的手机号", "我的花呗到期了没还进去怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.085, 0.092, 0.146, 0.18, 0.093, 0.151, 0.09, 0.157], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "解冻花呗账户方法", "pos": ["花呗冻结了怎么把花呗恢复", "花呗冻结了 给我开开"], "neg": ["蚂蚁花呗账户怎么解冻", "电费怎么不可以用花呗了", "而且现在淘宝买东西也不可以用花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.179, 0.079, 0.178], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "删除借呗还款记录", "pos": ["如何删除支付宝借呗的还款记录", "我想你给我蚂蚁借呗记录删除"], "neg": ["花呗分期可以一次还请吗", "花呗已还款。怎么再次开通", "我的花呗提示还款失败,怎么回事", "我是我是花呗,我是我,是花呗花呗分期分期付款付款,商品", "商家开通花呗啦,我就买啦一件", "我咨询下 为什么我的退款金额已经退回到花呗了 花呗的还款金额还是有", "开通花呗后未使用会产生费用吗", "实|店可以用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.07, 0.121, 0.173, 0.092, 0.106, 0.187, 0.089, 0.075], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗被关闭原因", "pos": ["我不明白为什么停了我的借呗", "我的借呗给我关了干嘛"], "neg": ["蚂蚁借呗蚂蚁花呗可以一块吗", "蚂蚁借呗还款算法", "花呗的账户出错,怎么改回来", "花呗已还款,仍提示我需要还款", "蚂蚁借呗能提***吗", "到时候我把蚂蚁借呗给关了", "花呗未出账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.073, 0.111, 0.148, 0.097, 0.107, 0.098, 0.145], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗功能异常", "pos": ["为什么我的借呗用不了", "为什么我的借呗不能接了"], "neg": ["借呗关闭了!还可以再开通吗", "蚂蚁花呗的二维码", "花呗逾期还了过了快一年还有影响吗", "借呗临时提额", "为什么我花呗和借呗那么少", "借呗最后最后还款日"], "pos_scores": [1.0, 0.95], "neg_scores": [0.178, 0.093, 0.151, 0.145, 0.12, 0.187], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查询花呗开通状态", "pos": ["我的花呗开没开", "我这个支付宝还没开通花呗吗"], "neg": ["借呗审核不通过用什么原因", "花呗可以付款腾讯会员吗", "花呗提额为什么不能提了", "借呗上个月***号借的款为什么到下个月***号就得还款了,问题是还没有到一个月"], "pos_scores": [1.0, 0.95], "neg_scores": [0.118, 0.08, 0.19, 0.19], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗自动扣款失败", "pos": ["借呗 没有自动扣款", "为什么借呗到现在还没有扣"], "neg": ["借呗还款不对", "花呗还款,一定要设置银行卡吗", "我花呗预期还款都还完了怎么不能用了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.132, 0.092, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "新账号开通花呗限制", "pos": ["怎么可以再另一个账户开通花呗", "我的这个账号花呗关掉了,我得新账号怎么不能开通花呗"], "neg": ["花呗退了款 下个月还用还吗", "退回花呗的钱我有收到信息,但是退款金额里没有看到", "你能看出来我花呗有没有分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.177, 0.19, 0.108], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款日期查询", "pos": ["花呗还款详情时间", "花呗每月几时还款"], "neg": ["花呗还款找怎么分期", "借呗还款会通知么", "花呗的可用额度什么意思", "花呗dut额度怎么提高", "花呗分期,优惠卷", "花呗账户怎么切换手机号"], "pos_scores": [1.0, 0.95], "neg_scores": [0.101, 0.073, 0.079, 0.108, 0.167, 0.116], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "商家花呗收款失败", "pos": ["安安,我申请的花呗收款怎么不能用", "我的商家收款二维码为什么花呗支付不了"], "neg": ["为什么借呗之前可以借***个月的,现在最多只有***个月的", "怎样申请花呗临时额度", "花呗分期有手续费没", "花呗支付怎么用", "花呗还款能不能不从余额里面扣"], "pos_scores": [1.0, 0.95], "neg_scores": [0.151, 0.149, 0.19, 0.117, 0.117], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗额度消失原因", "pos": ["我的蚂蚁借呗怎么沒有了,我要借钱", "我的借呗以前有两万多的,现在为什么没有了"], "neg": ["花呗付款的限制", "什么的商品可以用花呗", "花呗分期如何提前"], "pos_scores": [1.0, 0.95], "neg_scores": [0.112, 0.15, 0.19], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期服务费发票", "pos": ["花呗分期服务开票", "花呗分期服务费开票类型"], "neg": ["花呗购买有短信吗", "商家花呗服务怎么开通", "提高芝麻借呗的额度", "为啥我用不了蚂蚁借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.072, 0.18, 0.127, 0.102], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗当月账单分期", "pos": ["借呗本月的款可以分期吗", "借呗分期可以"], "neg": ["借呗额度停用和身份证过期有关吗", "我看我的蚂蚁借呗可以转两万元到我卡上是真的吗", "花呗啥时候还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.178, 0.069, 0.182], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗支付电费限制", "pos": ["不可以用花呗支付电费了么", "用花呗可以叫电费嘛"], "neg": ["花呗到了时间就自动扣款吗", "我这个为什么开不成花呗", "为什么我美团都使用不到花呗", "我的花呗***月账单上个月就还了", "花呗逾期后还了款,无法使用,该如何才能恢复"], "pos_scores": [1.0, 0.95], "neg_scores": [0.178, 0.101, 0.19, 0.121, 0.119], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗入口消失", "pos": ["找不到我的花呗了", "为什么我的蚂蚁花呗找不到了"], "neg": ["花呗怎么转到另一个账号", "为什么我花呗没有临时额度了", "使用花呗有哪些好处", "花呗还款用微信红包可以吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.177, 0.102, 0.136, 0.134], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "网商贷与借呗额度对比", "pos": ["蚂蚁借呗升级为网商贷后额度是一样的吗", "网商贷是否跟原有的借呗一样的额度吗"], "neg": ["连锁超市可以用花呗吗", "花呗账单怎样查看", "没有营业执照怎么办花呗二维码收款", "为什么我又有花呗账单", "为什么我的收钱码不能用花呗", "花呗怎么向支付宝付款", "我把钱放到了余额宝咋样才能转到蚂蚁借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.13, 0.095, 0.068, 0.09, 0.134, 0.16], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期提前结清", "pos": ["花呗已经办理分期还款,现在怎么样才可以提前一次性能还清", "花呗分期提前还清 在哪里"], "neg": ["可以用支付宝花呗还信用卡吗", "我的蚂蚁借呗怎么才可以继续使用", "借呗余额减少了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.129, 0.198, 0.133], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "淘宝花呗与支付宝花呗区别", "pos": ["我没有支付宝花呗,只有淘宝花呗", "淘宝上有花呗为什么支付宝上没有"], "neg": ["花呗的余付款怎么能还上", "支付宝花呗昨天交天然气费交易成功没到账", "卖家开通了花呗付款,买家为什么不能使用花呗付款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.194, 0.083, 0.169], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "美团支持花呗支付吗", "pos": ["我用花呗在美团订餐饮", "花呗可以在美团用嘛"], "neg": ["我想取消花呗怎么取消的", "什么是花呗标志", "请帮我关闭***的花呗", "怎么取消借呗的自动还款", "dc花呗分期修改", "我用了花呗的钱怎么还上", "花呗用积分兑换的免费提现额度是什么意思", "我要停花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.185, 0.16, 0.097, 0.172, 0.13, 0.155, 0.085, 0.075], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "取消借呗身份证验证", "pos": ["我刚才注册了蚂蚁借呗我想取消我扫的身份证", "蚂蚁借呗怎么取消验证身份证"], "neg": ["花呗分期,提前还清下个月,没找到提前还款", "借呗转账可以取消吗", "我的花呗账单已还清,为何还自动扣我钱", "我想还四月的花呗账单", "借呗现在的还款日是几号", "什么时候能设置花呗分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.116, 0.194, 0.144, 0.122, 0.072, 0.062], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗火车票退款差额", "pos": ["火车票花呗退款金额差额很大", "支付宝火车票退款与花呗不符"], "neg": ["花呗逾期就不能申请分期还款了吗", "邀请朋友开借呗", "借呗还款逾期一天怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.102, 0.145, 0.193], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗自动还款失败怎么办", "pos": ["【花呗】你的支付宝******@qq.com花呗***月账单自动还款失败,***元未还,请登录“支付宝-花呗”核对并还款,已还忽略", "【花呗】你的支付宝******@qq.com花呗***月账单自动还款失败,***元未还,请登录“支付宝-花呗”核对并还款,已还忽略"], "neg": ["怎么样能把借呗关闭", "花呗分期期间提升可用额度吗", "我没有银行卡怎么还花呗款", "我今天在淘宝下了个单。交易关闭但花呗又扣款,怎么办", "我用花呗买东西,但是我的钱被扣了,为什么花呗还要还钱", "花呗付款的押金退回后钱在哪"], "pos_scores": [1.0, 0.95], "neg_scores": [0.181, 0.181, 0.189, 0.061, 0.089, 0.107], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗刷脸验证不成功解决方法", "pos": ["花呗,刷脸没成功", "多次刷脸花呗不成功,咋办"], "neg": ["可以重新給我開通花唄嗎", "蚂蚁借呗始终是放款中", "借呗还能借吗", "信用卡和花呗收不了款了", "花呗集盒子活动在哪里报名", "花呗如何绑定到别外的手机号上"], "pos_scores": [1.0, 0.95], "neg_scores": [0.084, 0.146, 0.194, 0.173, 0.106, 0.113], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗登录问题解决方法", "pos": ["花呗登不进去了", "蚂蚁花呗,怎么登不进去"], "neg": ["为什么我的功能里没有花呗", "现在不能借***个月了吗蚂蚁借呗", "申请借呗是上传身份证,正反面吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.195, 0.167, 0.11], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗额度提升方法", "pos": ["花呗怎么提升到几千", "花呗怎么提额"], "neg": ["花呗开通条件有什么为什么说我不符合", "蚂蚁借呗提前还利息i", "蚂蚁借呗只能选择***个月的期限", "花呗分期可以一次多还嘛", "花呗已经还款 为什么发信息说没有还", "花呗临时的额度也能分期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.07, 0.113, 0.095, 0.198, 0.115, 0.138], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "商家开通花呗的条件", "pos": ["我是商家,店铺目前无法开通花呗", "我的店铺刚开一个月 想开通花呗 但是好像不能开 对店铺有什么具体要求 才能开通花呗"], "neg": ["我就是蚂蚁花呗,这个需要利息不", "蚂蚁借呗分期了,提前还款一期到期了还会扣钱吗", "花呗按揭退款还会扣手续费嘛", "申请蚂蚁借呗的钱,显示已通过申请,为什么没有看到钱,在哪里看", "蚂蚁借呗降额度了", "借呗最长能分几期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.135, 0.124, 0.186, 0.119, 0.143, 0.088], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "忘记花呗账号怎么找回", "pos": ["我之前的蚂蚁花呗忘记账号了", "花呗使用的账户忘记是什么号码"], "neg": ["我的退款到哪去了?我是提前还款花呗,但是收到货产生退货退款,卖家退回花呗了,那我的钱到哪去了", "我蚂蚁花呗还款晚了一个小时有影响吗", "借呗没有进度查询", "为什么花呗不能充值华费", "尾款支付使用花呗", "借呗会影响我房贷吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.147, 0.147, 0.087, 0.175, 0.137, 0.109], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "如何快速开通蚂蚁借呗", "pos": ["怎样可以快速使用蚂蚁借呗", "我想使用蚂蚁借吧"], "neg": ["花呗分期付款利息那么高", "我的借呗怎么还用了吗", "我借呗和花呗都不能用了", "不可以花呗什么情况"], "pos_scores": [1.0, 0.95], "neg_scores": [0.095, 0.083, 0.056, 0.198], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "未开通花呗但误用如何还款", "pos": ["我还没有开通花呗但是淘宝付款使用了要怎么还款", "我并没有开通花呗,付款按错用了***次,怎么样还款"], "neg": ["蚂蚁花呗双十二有临时额度吗", "我记得蚂蚁借呗可以分***个月还款的", "我在淘宝消费 花呗的上限会涨吗", "花呗扣款的具体时间是什么时候"], "pos_scores": [1.0, 0.95], "neg_scores": [0.178, 0.113, 0.17, 0.099], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "如何关闭花呗", "pos": ["我得花呗怎么停用了", "怎么样把花呗关了"], "neg": ["花呗总欠款怎么看", "花呗为啥不能开", "花呗分期付款可以一起还吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.16, 0.175, 0.18], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期后能否一次性还清", "pos": ["开启花呗账单分期后可以一次还清吗", "花呗分期了以后还可以一次性还上么"], "neg": ["充话费为什么现在不能用花呗了", "借呗还款余额中有钱为什么从卡里扣金额", "为什么付款时用不了花呗", "怎么我的花呗不可以充话费"], "pos_scores": [1.0, 0.95], "neg_scores": [0.189, 0.144, 0.055, 0.139], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期涉及的小贷公司", "pos": ["还有重庆阿里小额小贷、重庆蚂蚁小额小贷 这两个公司也收取了花呗分期", "***重庆阿里小额小贷,***重庆蚂蚁小额小贷 这两个也是花呗分期"], "neg": ["花呗分期免息卷在哪里兑换", "商家开通了花呗支付,为什么不能付款", "我从借呗借的钱转到银行卡里去了现在用不了", "花呗收款已经开通了但是不能用", "蚂蚁借呗还款只能本人还款吗", "关了花呗,什么时候可以重新开通"], "pos_scores": [1.0, 0.95], "neg_scores": [0.052, 0.166, 0.198, 0.061, 0.157, 0.121], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "如何查询借呗还款是否成功", "pos": ["我得借呗还款成功了吗", "我九月这期蚂蚁借呗是否还款成功"], "neg": ["该怎么弄!我现在着急用,额度不够淘宝不能花呗分期", "我为什么开不起花呗", "花呗开通后一定要使用吗", "我为什么花呗人脸识别不通过", "借呗能可以把六个月变成***个月还款吗", "花呗,可以提额"], "pos_scores": [1.0, 0.95], "neg_scores": [0.197, 0.102, 0.083, 0.103, 0.059, 0.196], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "如何退出花呗", "pos": ["我还清了花呗怎样退出", "我怎样才能退出花呗"], "neg": ["借呗还款后额度没有恢复还显示借款", "后面几个月的花呗账单怎么查询", "花呗额度也降低了,咋回事", "借呗到期会不会自动还款吗", "花呗与信用卡收钱码", "我别的支付宝账号花呗额度能转移到这个支付宝上吗", "蚂蚁借呗提示暂时不能使用刷脸为什么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.126, 0.198, 0.065, 0.182, 0.16, 0.198, 0.098], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗额度恢复方法", "pos": ["怎么样恢复调整前的花呗额度", "花呗额度怎么恢复原来的花呗钱数"], "neg": ["提前还花呗咋样操作", "花呗余额为什么是负***", "我不是花呗分期吗!我需要全款支付吗", "花呗能分期不", "借呗短期借款不是免息的吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.114, 0.092, 0.118, 0.058, 0.127], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗开通资格查询", "pos": ["借呗怎么没资格申请", "借呗怎么还没开通"], "neg": ["花呗分期付款后可以一次付清", "***花呗试用", "花呗付款申请售后怎么退款", "打开花呗的时候显示的系统繁忙是咋回事", "推迟借呗还款", "申请借呗学历要怎么填写", "怎么解冻结的花呗", "在支付宝支付多少次才可以开通花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.137, 0.095, 0.132, 0.195, 0.196, 0.101, 0.11, 0.12], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "如何使用借呗", "pos": ["我想用借呗怎么做", "怎么样才可以用借呗"], "neg": ["借呗 预留号码变了 怎么办", "借呗***个月还上还是每个月都还点", "蚂蚁借呗额度关闭了怎么能快速恢复"], "pos_scores": [1.0, 0.95], "neg_scores": [0.128, 0.141, 0.136], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗付款抽奖活动", "pos": ["之前用完花呗有一次抽奖的机会现在还有吗", "之前花呗付款有抽奖机会的"], "neg": ["花呗返现金额", "这个花呗是以前的号码", "选择了花呗分期的预售商品,尾款可以用红包支付吗", "使用花呗需要额外收服务费吗", "可以取消花呗支付,用别的支付方式吗", "有蚂蚁借呗怎么还要申请", "为什么邦定银行卡,还使用不了花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.104, 0.067, 0.165, 0.194, 0.162, 0.185, 0.162], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗使用限制解除方法", "pos": ["我的花呗什么时候才可以y", "怎么才可以用回花呗"], "neg": ["求花呗额度要买电话", "花呗怎么算还款的", "花呗付了最低款,剩余的怎么算利息", "关联账户能使用花呗或借呗吗", "为什么我没有借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.162, 0.134, 0.181, 0.097, 0.081], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "重新开通花呗的步骤", "pos": ["蚂蚁花呗重新怎么开通", "我请求再次开通蚂蚁花呗"], "neg": ["我花呗分了几期?每期多少钱", "从新开通蚂蚁花呗", "借呗怎么取消身份证号", "我的花呗收款码怎么制作"], "pos_scores": [1.0, 0.95], "neg_scores": [0.104, 0.116, 0.129, 0.16], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期后恢复使用条件", "pos": ["花呗逾期款返清后花呗可以使用吗", "逾期一天还可以用花呗不"], "neg": ["花呗可以扫别人的支付宝二维码吗", "花呗逾期会利滚利算利息么", "怎样退花呗,不用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.179, 0.115, 0.088], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期不可用原因", "pos": ["花呗分期暂不可用", "到提交订单的时候显示花呗分期不可用"], "neg": ["怎么才能开通花呗", "我要暂时花呗的可用额可以吗", "如何临时提升蚂蚁花呗额度", "蚂蚁借呗自动降低还会提升吗", "一分钱花呗体验活动", "借呗和网商贷不能共存吗", "如何取消花呗的零时额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.114, 0.196, 0.141, 0.094, 0.184, 0.052, 0.088], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通花呗的芝麻信用分要求", "pos": ["要多少芝麻信用分,花呗才可以用", "花呗要多少分可以开通"], "neg": ["为什么我用花呗帮朋友代付不行", "蚂蚁借呗,怎么完善信息", "花呗额度在哪边调", "信用住支持花呗付款吗", "已经有的借呗为什么突然变成网商贷了", "支付宝怎么更改花呗登录", "借呗关闭了怎么才能在开通"], "pos_scores": [1.0, 0.95], "neg_scores": [0.193, 0.122, 0.082, 0.104, 0.185, 0.173, 0.065], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗最低还款设置方法", "pos": ["花呗借呗设计最低还款哪里设计", "花呗怎么自动按最低额还款"], "neg": ["花呗扣款是从哪里扣,是支付宝绑定", "为什么我的没有蚂蚁借呗功能", "为什么我的花呗有***分也用不了借呗", "花呗为啥不能充游戏点券", "我还不上花呗", "境外商家申请花呗支付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.105, 0.065, 0.085, 0.07, 0.197, 0.103], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "绑定银行卡后花呗无法使用原因", "pos": ["绑定了银行卡为什么还用不了花呗", "为什么绑定银行卡了,还是使用不了花呗支付"], "neg": ["借呗转给我的款,都超过***个小时了,还是没有收到", "网商贷如何降回借呗", "怎么绑定不到花呗的", "蚂蚁借呗不见了,怎么还钱", "花呗 退款", "蚂蚁借呗怎么分***个月", "蚂蚁花呗忘记了以前的手机号账号用户名怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.112, 0.173, 0.189, 0.121, 0.178, 0.137, 0.106], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "如何重新授权花呗", "pos": ["怎么样可以授权花呗", "重新授权蚂蚁花呗"], "neg": ["花呗人脉关系授权", "怎么显示不了借呗金额了", "帮我查询一下 花呗什么时候可以用", "花呗还款怎么查询详单"], "pos_scores": [1.0, 0.95], "neg_scores": [0.164, 0.156, 0.106, 0.13], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期后恢复使用步骤", "pos": ["花呗预期,如何可以恢复使用", "我的花呗有逾期,怎样才能恢复正常使用"], "neg": ["如果停用借呗", "借呗可以当天还款,当天借出来,再当天还款吗", "花呗分期后可以第三个月把其他的钱还完吗", "我支付宝帐户是手机号,咋叫我登入花呗时显示支付宝", "蚂蚁花呗还款延期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.184, 0.133, 0.162, 0.145, 0.094], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗额度更新时间", "pos": ["花呗额度用完什么时候可以更新", "花呗的零时额度,什么时候到期"], "neg": ["花呗分期已提前还清,为什么花呗还不能用", "我凌晨用借呗借的钱怎么到现在还没到账", "蚂蚁花呗多久提一次额度", "怎么能重新开通花呗", "我没有用花呗一分钱为什么让我还款", "花呗未收获之前怎么分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.113, 0.086, 0.168, 0.123, 0.1, 0.185], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款使用余额宝设置", "pos": ["花呗还款如何选择用余额宝", "从余额宝怎么设置代扣到花呗还款"], "neg": ["借呗的还款日可不可以修改", "借呗自动关闭是什么情况", "电脑端如花使用花呗支付", "借呗能延伸还款时间吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.169, 0.121, 0.076, 0.064], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗支持加油充值吗", "pos": ["花呗是否支持加油充值", "可以用花呗充加油卡"], "neg": ["网商银行跟蚂蚁借呗", "手机上没有支付宝了", "网商袋怎么换借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.08, 0.172, 0.109], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款方法", "pos": ["花呗有余额,但不知道怎么还款", "我就花呗,怎么还款"], "neg": ["我刚还清蚂蚁借呗,什么时候能再借呗", "开通花呗,不使用会收费用吗", "怎麼关掉借呗", "如何递增花呗金额", "借呗要不按期还,推迟两天还会怎么样"], "pos_scores": [1.0, 0.95], "neg_scores": [0.147, 0.103, 0.061, 0.184, 0.18], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款金额不符解决方法", "pos": ["花呗的还快款金额不对", "还的钱跟花呗对不上"], "neg": ["借呗提前还款后,额度为何不恢复", "为什么借呗申请的时候说账户有风险", "花呗有消费限制"], "pos_scores": [1.0, 0.95], "neg_scores": [0.065, 0.147, 0.105], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "双十一花呗付款支持情况", "pos": ["双十一可以用蚂蚁花呗付款吗", "双十一抢购的东西 可以用花呗付款吗"], "neg": ["借呗是券怎么领", "借呗提示提现额度", "退款后花呗分期暂不可用", "但是我分了期 分了三个月", "花呗在哪 我这没有", "花呗的运费险退回哪", "申请退款后花呗支付金额什么时间到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.077, 0.108, 0.178, 0.112, 0.086, 0.137, 0.2], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗额度下降原因", "pos": ["花呗额度下降了", "花呗降额,有病"], "neg": ["我要花呗退款", "从花呗还借款", "怎么样才能把花呗的钱转到我账户上", "我想知道我的花呗有多少可用额度", "香港花呗购物返现", "花呗有逾期 怎么办", "借呗***期怎么还款", "我想注销的帐户开通了花呗,另一个开通不了,那我想使用的花呗该怎样开通"], "pos_scores": [1.0, 0.95], "neg_scores": [0.138, 0.182, 0.165, 0.098, 0.096, 0.091, 0.104, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗无法使用解决方法", "pos": ["我的花贝怎么用不了", "我的花呗怎么没法使用"], "neg": ["支付宝花呗只能一个号码使用吗", "密付可以扣花呗的吗", "花呗收钱码开通不了了,怎么开通", "花呗***次月还是什么意思"], "pos_scores": [1.0, 0.95], "neg_scores": [0.121, 0.052, 0.179, 0.172], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款后账单问题", "pos": ["花呗退款了为什么在还花呗账单里面还有这个账单", "这单淘宝已显示退款,这里也显示退回花呗,为什么帐单还款里还要再还"], "neg": ["加油可用花呗付款吗", "我借呗还有额度,为什么借不出来了", "可以消除借呗还款记录么", "如何开通余额宝自动还花呗", "花呗怎么返现"], "pos_scores": [1.0, 0.95], "neg_scores": [0.171, 0.06, 0.192, 0.138, 0.127], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭花呗后能否再次开通", "pos": ["主动关闭花呗,是否以后都不能再开启", "花呗关闭后还能再次开通嘛"], "neg": ["我的蚂蚁借呗一年了怎么还没提额", "能把网商贷转为借呗吗", "我的花呗已经花完了,怎么还让我还款", "蚂蚁借呗快到还款日有短信提醒不", "花呗消费限制额度吗", "我想咨询一下我的花呗逾期一两天可以吗", "怎么把花呗什么的设置到我的页面"], "pos_scores": [1.0, 0.95], "neg_scores": [0.161, 0.106, 0.094, 0.191, 0.076, 0.164, 0.162], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗用余额宝还款设置", "pos": ["余额宝设置还花呗", "如何设置,花呗用余额宝还款"], "neg": ["今天借呗没有自动还款", "我的卡怎么绑定花呗绑定不了", "花呗怎么不支持", "花呗分期付款的事", "我花呗逾期了,可以晚两天还么", "花呗退款,问什么还款额没变,余额也没有收到", "花呗退款手续费怎么收取", "不要在页面昱显示花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.162, 0.116, 0.071, 0.14, 0.172, 0.063, 0.136, 0.101], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗额度冻结原因", "pos": ["花呗也被冻结了", "蚂蚁花呗额度冻结"], "neg": ["怎么取消花呗的授权", "我想花呗自动在余额宝里扣款,怎样操作", "我的花呗双***之前有临时额度", "蚂蚁花呗收款时在哪里上传营业执照", "关闭花呗自动交话费", "查询花呗未还款款余", "借呗金额全部还清可以再借么", "我现在用花呗买电瓶车,额度不够怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.171, 0.062, 0.19, 0.135, 0.127, 0.143, 0.132, 0.196], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "收钱码支持花呗付款吗", "pos": ["收钱码付款可以用花呗吗", "收钱码花呗能用吗"], "neg": ["如何用花呗付款", "花呗我为什么又履约了", "我想永久的关掉借呗", "我有蚂蚁借呗,现在给我停了", "全部还请花呗账单", "为什么借呗提前还了没有额度", "提醒花呗额度", "花呗在有些商家用不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.124, 0.056, 0.132, 0.074, 0.193, 0.109, 0.087, 0.073], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "蚂蚁借呗所属公司", "pos": ["蚂蚁借呗是哪个公司的产品", "蚂蚁借呗贷款公司是谁"], "neg": ["花呗晚一天还款怎么", "花呗收钱码哪里查看", "在蚂蚁借呗里我怎么没有信誉额度", "如何取消花呗代扣电话费", "我的商品退货了,淘宝显示款项已退回,为什么我的花呗还要还钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.057, 0.135, 0.103, 0.073, 0.067], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗借款期限选项", "pos": ["借呗只有***个月和***个月", "蚂蚁借呗***个月"], "neg": ["花呗怎么冲qb", "但是我银行卡支付的为什么退回花呗", "开通花呗,如果不用会产生费用吗", "我在花呗里面充值话费到现在没到账", "退货了,可是还有***元运费没有退到花呗怎么处理"], "pos_scores": [1.0, 0.95], "neg_scores": [0.196, 0.123, 0.165, 0.159, 0.174], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账户转移方法", "pos": ["另一个账户开通了花呗,怎么转到这个账户", "我之前的手机号账户花呗已经不记得了,怎么转到现在这个手机号账户上"], "neg": ["积分可换花呗额度吗", "我想把这个花呗给注销了", "花呗的账期", "蚂蚁借呗怎么不能借一个月", "借呗超过***笔然后什么时候可以恢复", "***花呗,用了***,已经还了,现在怎么还是***", "是不是借呗和网商贷两个只能开一个", "花呗不是一个手机号码怎么改"], "pos_scores": [1.0, 0.95], "neg_scores": [0.192, 0.131, 0.075, 0.159, 0.181, 0.133, 0.189, 0.194], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗被取消原因", "pos": ["为什么我蚂蚁借呗给我取消啦", "我又没有逾期过,怎么取消我的蚂蚁借贝"], "neg": ["花呗关闭还能再使用吗", "借呗申请提高额度,但是审核进度不能查询", "蚂蚁花呗能绑定到我的另一个支付宝账号吗", "借呗还款日期多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.168, 0.2, 0.077, 0.06], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗游戏充值限制", "pos": ["现在游戏内不可以用花呗充值点卷吗", "花呗不能用于游戏充值吗"], "neg": ["怎么才能支持花呗充话费", "我老公有营业执照,我可以开通花呗收款吗", "我找不到我的花呗是哪个号码了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.183, 0.088, 0.06], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期利息计算", "pos": ["花呗分期还款怎么算利息", "花呗分期付款 利息多少"], "neg": ["花呗分期客服", "借呗要求信用额度是什么", "我想知道为什么花呗分期暂不可用", "花呗怎么计息的", "借呗还款 银行卡只能转入一万", "明天是还款日,可是今天我的钱转不进余额宝?明天怎么还蚂蚁借呗的钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.088, 0.115, 0.132, 0.057, 0.105, 0.178], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "取消花呗分期方法", "pos": ["取消花呗分期吗", "我的花呗不小心点到分期了可以帮我取消分期了吗"], "neg": ["蚂蚁借呗提前还款可以得到蚂蚁庄园的饲料", "找不到花呗在哪里", "花呗立减十元", "ofo共享单车为什么不能用花呗", "花呗可用额度多少才可以分期", "为什么花呗自动还款失败,什么原因"], "pos_scores": [1.0, 0.95], "neg_scores": [0.075, 0.068, 0.054, 0.158, 0.147, 0.109], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款方式", "pos": ["花呗怎样还款", "花呗支付是怎么还款的"], "neg": ["为什么申请蚂蚁借呗要刷脸", "借呗借钱需要绑定银行卡嘛", "花呗是自动扣款是从哪里扣款的", "我的蚂蚁借呗为什么会被取消", "花呗还了最低还款,还需要分期吗", "借呗还钱在余额宝里吗扣", "这个花呗还款日可以调一下吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.123, 0.175, 0.195, 0.135, 0.167, 0.069], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗分期还款选项", "pos": ["借呗分期还款可以", "我的借呗借到钱了,可以分期还不"], "neg": ["我的花呗可以提升吗", "我用花呗付的款,要什么时候还", "我的花呗还款日期可以改一下。推迟几天吗", "借呗 借***个月 什么时候开始还", "怎么我的支付宝没有花贝"], "pos_scores": [1.0, 0.95], "neg_scores": [0.126, 0.087, 0.1, 0.128, 0.084], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "芝麻信用解除授权对花呗影响", "pos": ["我的芝麻信用解除了授权,对芝麻分和花呗有影响吗", "花呗和芝麻信用"], "neg": ["我的蚂蚁借呗咋不能用了", "我的花呗怎么别另一个手机号绑定了", "花呗分期的能提前还款吗", "花呗买飞机票"], "pos_scores": [1.0, 0.95], "neg_scores": [0.112, 0.092, 0.174, 0.127], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗额度调整选项消失", "pos": ["借呗右上角找不到额度调整", "借呗的额度调控不见了"], "neg": ["花呗分期可以提前还清吗", "为什么我的花呗***月是未出账,***月为什么还要还款", "为什么借呗不能超过***笔", "花呗分期签约", "花呗如何取消自动缴费", "蚂蚁借呗利息这么算"], "pos_scores": [1.0, 0.95], "neg_scores": [0.189, 0.09, 0.186, 0.104, 0.106, 0.07], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗开通时间查询", "pos": ["什么时候能使用花呗", "我的花呗什么时候可以开阿"], "neg": ["借呗已还清,还停我花呗", "网商贷换借呗", "没有app 花呗怎么还款", "蚂蚁借呗 借***"], "pos_scores": [1.0, 0.95], "neg_scores": [0.192, 0.189, 0.099, 0.124], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗短期借款方法", "pos": ["蚂蚁借呗如何短期借款", "蚂蚁借呗怎么选择短期的,比如借半月或者一个月"], "neg": ["就 我之前有个支付宝开了花呗但是我号不用了", "蚂蚁借呗用不了怎么办", "花呗会寄账单过来吗", "怎么解除花呗贷"], "pos_scores": [1.0, 0.95], "neg_scores": [0.051, 0.182, 0.168, 0.182], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "退款后花呗未恢复原因", "pos": ["我在小米商城买了东西,已经退款了,花呗的的钱怎么还没回来", "已经退款了 花呗怎么还欠款"], "neg": ["支持花呗付款我为什么付不了", "我借呗分两次借的账单怎么分开查", "借呗的借款日和还款日相差多久", "花呗多少信用度可以开通", "花呗返还卷", "我的花呗为什不能用了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.142, 0.063, 0.09, 0.126, 0.074, 0.075], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款未更新问题", "pos": ["我银行卡已经被花呗扣款为什么打开支付宝花呗还显示没还款", "我的花呗还款银行卡钱已经扣了 花呗还是现实没有还"], "neg": ["我的花呗是否是开通状态", "花呗的额度包括手续费吗", "借呗还钱的日期是什么时候"], "pos_scores": [1.0, 0.95], "neg_scores": [0.05, 0.186, 0.069], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "支付宝没有花呗选项", "pos": ["点开右下角我的,没有花呗", "我手机上没有花呗"], "neg": ["我还了借呗的一部分钱,为什么额度没恢复", "蚂蚁花呗最低还款会降低信用", "我为什么花呗不能还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.161, 0.153, 0.081], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款操作步骤", "pos": ["怎么可以把蚂蚁花呗还了,我现在换不了呀", "我刚才点错了,我要还款给花呗,怎么还"], "neg": ["蚂蚁借呗如何取消最低还款", "花呗支付成功但订单显示未支付", "我怎样才能使用借呗贷款", "最近花呗无法提前还款", "你看我开通花呗了吗", "蚂蚁信用***分借呗额度多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.095, 0.081, 0.146, 0.087, 0.121, 0.099], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期还款后果", "pos": ["花呗没有按期偿还", "花呗没有正常还款"], "neg": ["我现在是开通了花呗吗", "花呗分期可以提前还吗?还需要手续费吗", "为什么花呗我没花那么多钱,花呗让我还那么多那", "蚂蚁借呗可以在***个月后一次还清吗", "花呗付的话便宜吗", "如果将网商贷改为借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.182, 0.14, 0.145, 0.127, 0.158, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗免分期券领取方法", "pos": ["花呗免分期券在哪可以领", "花呗***期分期券哪里领"], "neg": ["之前我看了可以用花呗买机票", "下线门店怎么设置花呗交易额度", "还差多少才能开借呗", "花呗买电影票退票", "为什么商品支持花呗但是不能用花呗付款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.107, 0.137, 0.068, 0.141, 0.07], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款验证码发送到旧手机号", "pos": ["花呗还款验证码还发在旧手机号上怎么办", "修改玩手机号为什么用花呗付款还是发短信到旧手机号码"], "neg": ["我想知道花呗提前还,怎么额度还降低了", "怎么开通商家花呗", "花呗吗。为什么提前不能还 还是等到***月份还", "我想提高花呗额度买东西", "我花呗账号忘记啦"], "pos_scores": [1.0, 0.95], "neg_scores": [0.171, 0.153, 0.079, 0.075, 0.144], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还清后无法使用原因", "pos": ["为什么还钱完了花呗却用不了", "所有账单已还清,花呗为什么还不能用"], "neg": ["花呗关闭了可以再开启吗", "支付宝里面查不到花呗", "花呗怎么样分期", "我的花呗双***之前有临时额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.166, 0.096, 0.165, 0.158], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗自动还款设置", "pos": ["花呗支付后。还款日自动扣款吗", "花呗还l款会自动扣除还款钱吗"], "neg": ["我要开通花呗 买手机", "花呗开通需要多少芝麻信用", "蚂蚁借呗分期还款可不可以提前还", "花呗最多可以透支多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.151, 0.181, 0.15, 0.082], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗购买火车票支持情况", "pos": ["花呗支持飞天猪买火车票吗", "支付宝上买火车票支持花呗吗"], "neg": ["领取借呗提升额度为什么来输入密码", "花呗咋样才能还款", "我的花呗该还款了,但是我不会还,应该怎么还款", "花呗一般什么时候评估额度", "我想降低借呗额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.069, 0.081, 0.178, 0.066, 0.128], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "收钱码接收花呗付款", "pos": ["收钱码收不收对方的花呗", "收钱码可以收花呗点钱吗"], "neg": ["怎么样才可以换掉之前的花呗手机号", "展开 花呗付款的时候验证码怎么发到旧手机号了?这怎么改", "我现在还不了分期花呗吗", "花呗还清 额度不对", "花呗客服多少", "花呗未还清还可以用借呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.189, 0.198, 0.135, 0.077, 0.08, 0.125], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账单金额不符", "pos": ["花呗的数额与实际不符", "我花呗与还款不符合"], "neg": ["我现在想分期买苹果***花呗不够怎么办", "对多花呗可以消费多少", "用余额宝自动向花贝借呗还款要收费吗", "为什么我的支付宝不显示借呗", "我现在想在蚂蚁借呗上借钱。但是我的电话收不到短信的验证码", "蚂蚁借呗怎么看还款记录"], "pos_scores": [1.0, 0.95], "neg_scores": [0.126, 0.121, 0.084, 0.112, 0.086, 0.054], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗欠款还款方法", "pos": ["怎样还花呗欠款", "花呗还钱?怎么换"], "neg": ["我刚买的东西还没收到货就要显示交易成功要还花呗", "共享单车的押金为什么不能用花呗付了", "花呗这个月用下个月还", "我要还蚂蚁花呗的款,怎么办,我弄不好怎么办", "花呗未还款可以申请借呗嘛", "花呗只能在淘宝购物了,其它地方不能用是怎么回事", "我想成为商家,花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.185, 0.062, 0.163, 0.061, 0.097, 0.106, 0.078], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗被关闭原因", "pos": ["借呗正在用怎么给我关闭了", "信用记录一直挺好 为什么现在借呗给我取消了"], "neg": ["使用花呗支付花呗红包没有自动扣除", "花呗可用于付帐么", "蚂蚁借呗最高额多是多少", "开通花呗刷脸一直不通过", "花呗和信用卡收款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.182, 0.112, 0.16, 0.156, 0.176], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗提前还款操作", "pos": ["花呗申请提前还款", "我花呗要提前还款"], "neg": ["退款***快多,为什么我的花呗没有显示", "余额宝可以直接扣蚂蚁借呗吗", "本身积分发放的功能是受到花呗限制影响的这个无法恢复如果您是对花呗的限制有疑问可以帮您zhuanj", "借呗还款日是每个月的几号"], "pos_scores": [1.0, 0.95], "neg_scores": [0.129, 0.058, 0.163, 0.146], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "取消花呗方法", "pos": ["是我点错了,怎么取消花呗", "花呗怎么取消,我取消嘞"], "neg": ["摇一摇花呗临时额度", "怎么兑换花呗免息券", "过了***点怎么借呗还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.18, 0.099, 0.187], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款未到账问题", "pos": ["账单已还清,但是退款到花呗的钱没收到", "退款退回花呗,可是花呗并没有收到那退款的钱"], "neg": ["借呗一千一个月还多少钱", "我的花呗返现", "借呗当天还最迟几点", "花呗可以定餐吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.168, 0.133, 0.055, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗多扣款解决方法", "pos": ["花呗多收了我的钱", "花呗为什么多让我还了***"], "neg": ["多少芝麻信用开通花呗", "借呗还款日未自动还款", "淘宝买手机朋友代付可以用花呗吗", "我已经还了花呗,那*** 退额度是怎么算的", "花呗下月还款日期", "蚂蚁借呗多少分才能借", "花呗免息优惠券在哪里", "怎么取消朋友带还花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.145, 0.165, 0.195, 0.157, 0.148, 0.144, 0.05, 0.07], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "蚂蚁借呗入口位置", "pos": ["借呗点哪里", "从哪找蚂蚁借呗"], "neg": ["我房贷需要借呗的结清证明怎么弄", "网商银行就是借呗吗", "可以还下下个月的花呗吗", "代收可以使用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.197, 0.088, 0.19, 0.187], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通借呗方法", "pos": ["我要怎么才能使用借呗", "怎么能用借呗"], "neg": ["我的花呗账号怎么跟我登陆的账号不一样", "花呗逾期利息和罚款怎么算", "为什么我游戏充值时无法使用花呗支付", "借呗分期能分几个月", "花呗分期还有多少期", "你妈的逼我为什么没有花呗", "花呗使用临时额度后,还款会优先还临时额度吗", "飞猪买机票为什么不能用花呗支付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.187, 0.187, 0.191, 0.175, 0.174, 0.18, 0.164, 0.063], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗账单消失原因", "pos": ["为什么蚂蚁借呗看不到账单了,变成重新申请了", "我的借呗账单看不到了"], "neg": ["在花呗哪一个里面", "为什么打开花呗,没有分期还款", "我要怎样才能用蚂蚁借呗", "我之前可以用借呗现在为什么又不能用了", "我有花呗为什么查账,都显沒有邦定银行卡,支付宝 上的银行卡不就是的吗", "我支付宝是现在的手机号,怎么花呗还是原来的手机号?怎么修改", "我想看一下我的花呗之前的分期还有多少期", "为什么蚂蚁借呗是暂无额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.052, 0.071, 0.195, 0.156, 0.133, 0.187, 0.158, 0.089], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "商户开通花呗收款步骤", "pos": ["怎么开通商户的花呗收钱", "商户如何开通花呗支付"], "neg": ["手机花呗 如何分期", "花呗临时额度下个月必须还清吗", "还花呗从哪扣款", "为什么借呗不能借钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.196, 0.137, 0.154, 0.195], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款手机号更换方法", "pos": ["我的花呗,我想还钱,但是预留手机号换了,怎么办", "我的花呗还款手机号不用怎么改"], "neg": ["借呗活动网址", "使用花呗和来分期可以增加芝麻分", "如何减低借呗额度", "蚂蚁借呗首月免息"], "pos_scores": [1.0, 0.95], "neg_scores": [0.147, 0.064, 0.088, 0.058], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗放款时间", "pos": ["申请借呗后多久放款", "借呗借后多久到账"], "neg": ["借呗冻结存在那些风险", "为什么这么久还没帮我提", "要用多少信用额度才能开通花呗", "我***月花呗还***,这部分钱怎么算的", "我不用花呗行关掉", "花呗怎么买qb"], "pos_scores": [1.0, 0.95], "neg_scores": [0.067, 0.097, 0.111, 0.151, 0.073, 0.082], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗登录淘宝方法", "pos": ["花呗怎么登录淘宝", "淘宝的蚂蚁花呗"], "neg": ["我用着来分期 可以用蚂蚁借呗吗", "借呗还款日在", "商铺可以用花呗付,但我拍下后花呗不能用", "为什么我已经完成大学生认证但花呗分三期还是要手续费"], "pos_scores": [1.0, 0.95], "neg_scores": [0.175, 0.069, 0.121, 0.066], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗身份信息完善问题", "pos": ["借呗完善不了信息", "蚂蚁借呗不能完善身份信息"], "neg": ["怎样能够恢复蚂蚁借呗的功能", "借呗入口不显示额度", "使用购物津贴,还能用花呗吗", "花呗支付退款如何处理"], "pos_scores": [1.0, 0.95], "neg_scores": [0.187, 0.117, 0.069, 0.059], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗购买手机物流查询", "pos": ["花呗上买的手机去哪里查有没有发货", "花呗上买的手机怎么查快递"], "neg": ["什么是花呗冻结额度", "用了购物津贴蚂蚁花呗用不了", "我的花呗总了消费了多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.067, 0.137], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款后仍需还款原因", "pos": ["为什么退款回花呗了,还显示要我还款", "花呗已经提示退款成功,可为什么还是显示帐单"], "neg": ["花呗晚两天还款,会对信用有影响吗", "花呗上个月可以交电费,这个月为什么不行", "蚂蚁借呗最多一次还多少", "借呗最多可分多少期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.065, 0.083, 0.14, 0.12], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款日期查询", "pos": ["花呗使用什么时候还款", "我要是今天用花呗啥时候还款"], "neg": ["蚂蚁借呗还款提醒设置", "不想用花呗付款,又怎么换付款方式", "我能帮助朋友还花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.059, 0.099, 0.198], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无法借款原因", "pos": ["现在不能借呗了吗", "借呗下不了单"], "neg": ["花呗的余付款怎么能还上", "信用额度怎样才能蚂蚁借呗的条件", "蚂蚁借呗入口怎么找不到", "花呗可以在京东上购物吗", "我是商家,如何关闭客户用花呗付款", "怎样主动申请蚂蚁借呗提额", "怎么挂失花呗账号"], "pos_scores": [1.0, 0.95], "neg_scores": [0.151, 0.121, 0.087, 0.198, 0.167, 0.062, 0.129], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗和花呗区别", "pos": ["蚂蚁借呗,和花呗不样吗", "借呗花呗有冲突吗"], "neg": ["蚂蚁借呗可用花呗还吗", "我已经还完花呗的款了,为什么还显示未还款", "花呗你查不到次笔交易", "花呗被冻结咋办", "花呗也还,为什么额度不更新", "今天是借呗还款日,可以超过中午***点"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.05, 0.109, 0.146, 0.18, 0.073], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收款码设置方法", "pos": ["怎么样能花呗收款", "怎么设置花呗收钱码"], "neg": ["花呗还款日是***号,***号当天还是逾期吗", "借呗分期还款账单", "取消蚂蚁借呗提现", "给我恢复借呗", "借呗跟花呗的联系", "花呗还款为什么额外扣了手续费", "有没有直接转账给花呗的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.076, 0.077, 0.186, 0.136, 0.157, 0.079, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款困难解决方法", "pos": ["花呗还不了整么办", "花呗还不到这么多怎么办"], "neg": ["更改电话号码之后,借呗显示没有额度,是不是关闭了", "不是可以花呗分期吗", "我的花呗为什么不能用了 什么原因"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.085, 0.155], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗有额度但无法使用", "pos": ["我什么我不能用花呗", "为什么我的花呗不能用,显示我的花呗还有***可用余额"], "neg": ["我想看看我***月的借呗还款记录", "借呗活动页面", "我的花呗之前退回了***"], "pos_scores": [1.0, 0.95], "neg_scores": [0.068, 0.154, 0.104], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款记录查询", "pos": ["如果查询借呗还款", "看蚂蚁借呗还款记录"], "neg": ["借呗没有临时额度吗", "花呗临时额度查得到吗", "蚂蚁会员和蚂蚁借呗的区别", "我花呗可以先提前还一部分吗", "花呗还款扣了银行卡的钱但显示没有这笔账单", "退款成功怎么花呗没入帐"], "pos_scores": [1.0, 0.95], "neg_scores": [0.083, 0.115, 0.17, 0.095, 0.167, 0.055], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "重新开通花呗方法", "pos": ["之前把花呗关闭了 现在怎么开通", "不小心关闭花呗啦,怎么开通"], "neg": ["花呗额度***千元为什么不能用分期付", "花呗额度分期是要商品的全额吗", "我说怎么删除花呗消费记录", "去饭店吃饭可以用花呗那吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.078, 0.091, 0.087, 0.156], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗人脸识别失败解决方法", "pos": ["借呗无法人脸识别", "借呗密码输了,人脸识别没用怎么办"], "neg": ["蚂蚁花呗不到下个月为什么不能提前还款", "借呗关了还能开么", "借呗还款日期不是固定每月四号吗", "借呗的钱没有到银行卡怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.14, 0.129, 0.109, 0.061], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗分期期限选项", "pos": ["借呗!可以分***个月还吗", "蚂蚁借呗不是可以分***期么"], "neg": ["分期商品占用花呗额度吗", "借呗还款日利息", "我的花呗这个月的还款退货退款了,为什么还显示让我还款", "花呗已还款,银行卡里的钱都扣了,为什么还显示没还", "花呗里面怎么一次还请", "借呗停用了,什么时候能恢复", "能帮我关闭花呗吗", "花呗被关闭了怎么开启"], "pos_scores": [1.0, 0.95], "neg_scores": [0.071, 0.166, 0.185, 0.186, 0.057, 0.062, 0.191, 0.116], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗登录问题解决方法", "pos": ["花呗显示可登陆支付宝账号使用花呗,为什么我登了还是用不了", "登陆支付宝账号都对,可是花呗登不上"], "neg": ["不开通花呗看不到可用额度", "蚂蚁借呗还款时间没***个月了", "退款回复花呗额度怎么做", "解除花呗关联", "怎么主动申请蚂蚁借呗的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.197, 0.178, 0.165, 0.134, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗开通方法", "pos": ["如何开通“花呗”服务", "花呗哪里开通"], "neg": ["花呗是哪个银行给的额度", "花呗授权什么用", "花呗分期以后 提前还款", "透支花呗额度", "本月还款借呗还能开通吗", "为何蚂蚁借呗还完后没有信用额度", "花呗最低还款的利息咋算"], "pos_scores": [1.0, 0.95], "neg_scores": [0.091, 0.064, 0.071, 0.194, 0.149, 0.05, 0.153], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款日修改方法", "pos": ["花呗还款日能不能改到***号", "花呗还款日期可以改到每月***号吗"], "neg": ["另一个账号用花呗,这个账号可以还吗", "我的花呗额度原来是***的", "还能重新开通花呗吗", "花呗还款用银行卡号支付密码在哪设置", "收钱码花呗怎么不能收大额的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.092, 0.151, 0.051, 0.168], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗ETC缴费支持", "pos": ["花呗是否可以开道etc高速免费通道", "用花呗可以走高速etc缴费吗"], "neg": ["我刚用了借呗", "怎么样能提升蚂蚁借呗额度", "如何取消蚂蚁借呗这一项,不是关闭,是完全的取消"], "pos_scores": [1.0, 0.95], "neg_scores": [0.182, 0.1, 0.167], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "更换手机号后花呗无法使用", "pos": ["手机号码更换,无法使用花呗", "换了手机号花呗用不了"], "neg": ["换手机号码 对花呗有没有影响", "我关闭了花呗,想用另一个号开通怎么办", "花呗还最低还款额会影响信誉吗", "我要查询是哪次花呗产生的罚款", "我刚刚已经还了花呗了可是怎么还是显示没还,银行卡也扣钱了", "花呗已经还款了,怎么还有未出账", "我的花呗开通没有成功怎么办", "蚂蚁借呗之前能用为什么现在说没额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.123, 0.146, 0.142, 0.162, 0.123, 0.082, 0.109, 0.185], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗临时额度获取方法", "pos": ["花呗的临时额度我怎么没有", "临时花呗没有"], "neg": ["为什么不能用花呗支付hello单车押金", "花呗逾期用不了 多久才能在用", "当初用蚂蚁花呗付的款,现在花呗已经还完了,要退货,钱退在哪里", "怎么查花呗临时额度有效期", "借呗还款跟借款记录怎么删掉", "蚂蚁借呗最快到账", "借呗逾期了自动扣款", "花呗可以生活交费吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.079, 0.142, 0.177, 0.081, 0.121, 0.134, 0.191, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗购买机票支持情况", "pos": ["我上个月用蚂蚁花呗买了机票", "花呗支持买机票吗"], "neg": ["借呗钱以还清?额度什么时候恢复", "花呗冻结了,怎样才可以开", "花呗号码和支付宝号码不一致怎么可以统一", "花呗支付了怎么没到账", "怎么取消消费优先选花呗", "什么时候才能再用花呗", "借呗额度自行消失怎么回事"], "pos_scores": [1.0, 0.95], "neg_scores": [0.19, 0.17, 0.115, 0.087, 0.075, 0.094, 0.083], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗最低还款影响信用吗", "pos": ["最低还花呗会影响信用度吗", "花呗最低还款后会不会影响信誉"], "neg": ["花呗付款成功却显示没有付款", "蚂蚁借呗没有去还款", "我得借呗怎么给停了", "借呗到还款日为什么不自动还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.074, 0.183, 0.089, 0.139], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收款码办理方法", "pos": ["怎么样来办理花呗收款码", "怎么蚂蚁花呗收款"], "neg": ["蚂蚁借呗可以按天借款吗", "我想退款说什么双***红包如果被使用就扣花呗什么意思", "可以花呗付款 怎么显示不支持花呗付款", "花呗能支付宝购物吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.142, 0.084, 0.087, 0.186], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "商家花呗收款手续费", "pos": ["花呗 像商家收取手续费么", "我是商家,用花呗收款,需要收费吗"], "neg": ["借呗提前还款,本金的利息会相应减少吗", "主动申请蚂蚁借呗后又占无使用资格", "借呗不能一年还款么", "怎么花呗也扣了,银行卡里也出了的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.065, 0.103, 0.179], "prompt": "为此查询生成表示:", "type": "normal"}

View File

@ -0,0 +1,5 @@
{"query":"什么是深度学习?","pos":["深度学习是机器学习的一个子领域,使用多层神经网络模拟人脑处理信息的方式","深度学习Deep Learning是人工智能和机器学习的重要分支通过构建深层神经网络来学习数据表示"],"neg":["机器学习包含多种算法,如决策树、支持向量机、神经网络等","人工智能是计算机科学的一个分支,目标是创建能够执行智能任务的系统","数据挖掘是从大型数据集中提取模式和知识的过程"],"pos_scores":[1.0,0.95],"neg_scores":[0.6,0.4,0.3],"prompt":"为此查询生成表示:","type":"normal"}
{"query":"如何制作红烧肉?","pos":["红烧肉制作步骤选五花肉切块焯水去腥热锅炒糖色下肉翻炒加生抽老抽料酒小火炖煮30-40分钟至软烂","红烧肉是经典上海菜,用五花肉、冰糖、生抽、老抽、料酒等,关键在于炒糖色和火候控制"],"neg":["红烧肉是中国传统菜肴,口感香甜软糯,深受大众喜爱","五花肉含有丰富的蛋白质和脂肪,营养价值较高","上海菜以红烧见长,口味偏甜,适合南方人饮食习惯"],"pos_scores":[1.0,0.9],"neg_scores":[0.5,0.3,0.2],"prompt":"为此查询生成表示:","type":"recipe"}
{"query":"Python编程入门应该学习哪些内容","pos":["Python编程入门需要掌握基本语法、数据类型(字符串、列表、字典)、控制结构(if/for/while)、函数定义、面向对象编程基础","Python新手学习路线先学变量和数据类型然后是条件语句和循环接着学习函数和模块最后是类和对象"],"neg":["Python是一种简单易学的编程语言语法清晰适合初学者","编程思维比语法更重要,需要培养逻辑思维能力","计算机科学包含算法、数据结构、操作系统等多个分支"],"pos_scores":[1.0,0.95],"neg_scores":[0.4,0.3,0.2],"prompt":"为此查询生成表示:","type":"programming"}
{"query":"北京有哪些必去的旅游景点?","pos":["北京必去景点推荐:故宫(紫禁城)、天安门广场、万里长城、颐和园、天坛、圆明园、北海公园、什刹海等","故宫是明清两代皇宫占地72万平方米是世界最大的古代宫殿建筑群必须提前预约参观"],"neg":["北京是中华人民共和国首都有3000多年建城史","中国拥有众多世界文化遗产,旅游资源丰富","春秋季节是北京旅游的最佳时间,气候宜人"],"pos_scores":[1.0,0.9],"neg_scores":[0.3,0.25,0.4],"prompt":"为此查询生成表示:","type":"travel"}
{"query":"机器学习和深度学习有什么区别?","pos":["机器学习是更广义的概念,深度学习是机器学习的一个子集,主要区别在于深度学习使用深层神经网络进行特征自动提取","深度学习使用多层神经网络自动学习特征,而传统机器学习通常需要人工设计特征,深度学习在图像和语音识别方面表现更优"],"neg":["机器学习算法包括监督学习、无监督学习和强化学习三大类","人工智能技术发展迅速,在各行各业都有广泛应用","神经网络是深度学习的基础,模拟人脑神经元连接"],"pos_scores":[1.0,0.95],"neg_scores":[0.4,0.3,0.5],"prompt":"为此查询生成表示:","type":"tech_comparison"}

View File

@ -0,0 +1,580 @@
#!/usr/bin/env python3
"""
Comprehensive usage examples for the refactored BGE data pipeline
Demonstrates all four supported formats: triplets, pairs, candidates, and legacy nested.

Data Formats Supported:
1. Triplets: query-triplets format (query, pos, neg) for BGE-M3 embedding training
2. Pairs: flat pairs format (query, passage, label) for BGE-reranker training
3. Candidates: unified format (query, candidates) with threshold-based label assignment
4. Legacy Nested: nested pos/neg format with automatic conversion to flat pairs
"""
import os
import sys
import json
import tempfile
from transformers import AutoTokenizer
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import BGEDataset, BGEM3Dataset, BGERerankerDataset
from data.optimization import optimize_data
def create_sample_data():
"""Create sample data files for all four formats"""
# Create temporary directory
temp_dir = tempfile.mkdtemp(prefix="bge_examples_")
print(f"📁 Creating sample data in: {temp_dir}")
# 1. Triplets format (for BGE-M3 embedding training)
triplets_file = os.path.join(temp_dir, "triplets_data.jsonl")
triplets_data = [
{
"query": "What is machine learning?",
"pos": [
"Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
"ML algorithms use statistical techniques to learn patterns from data and make predictions."
],
"neg": [
"Weather forecasting uses meteorological data to predict atmospheric conditions.",
"Cooking recipes require specific ingredients and step-by-step instructions."
],
"pos_scores": [0.95, 0.88],
"neg_scores": [0.15, 0.08],
"prompt": "为此查询生成表示:",
"type": "definition"
},
{
"query": "How does deep learning work?",
"pos": [
"Deep learning uses neural networks with multiple layers to learn complex patterns.",
"Deep neural networks automatically extract features from raw data through backpropagation."
],
"neg": [
"Baking bread requires yeast, flour, and proper kneading techniques.",
"Solar panels convert sunlight into electrical energy through photovoltaic cells."
],
"pos_scores": [0.92, 0.89],
"neg_scores": [0.12, 0.18]
}
]
with open(triplets_file, 'w', encoding='utf-8') as f:
for item in triplets_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# 2. Pairs format (for BGE-reranker training)
pairs_file = os.path.join(temp_dir, "pairs_data.jsonl")
pairs_data = [
{
"query": "What is machine learning?",
"passage": "Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
"label": 1,
"score": 0.95,
"qid": "q1",
"pid": "p1"
},
{
"query": "What is machine learning?",
"passage": "Weather forecasting uses meteorological data to predict atmospheric conditions.",
"label": 0,
"score": 0.15,
"qid": "q1",
"pid": "p2"
},
{
"query": "How does deep learning work?",
"passage": "Deep learning uses neural networks with multiple layers to learn complex patterns.",
"label": 1,
"score": 0.92,
"qid": "q2",
"pid": "p3"
},
{
"query": "How does deep learning work?",
"passage": "Baking bread requires yeast, flour, and proper kneading techniques.",
"label": 0,
"score": 0.12,
"qid": "q2",
"pid": "p4"
}
]
with open(pairs_file, 'w', encoding='utf-8') as f:
for item in pairs_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# 3. Candidates format (unified format with threshold-based labels)
candidates_file = os.path.join(temp_dir, "candidates_data.jsonl")
candidates_data = [
{
"query": "What is artificial intelligence?",
"qid": "q3",
"candidates": [
{
"text": "Artificial intelligence is the simulation of human intelligence in machines.",
"score": 0.94,
"pid": "c1"
},
{
"text": "AI systems can perform tasks that typically require human intelligence.",
"score": 0.87,
"pid": "c2"
},
{
"text": "Mountain climbing requires proper equipment and training.",
"score": 0.23,
"pid": "c3"
},
{
"text": "Chess is a strategic board game played between two players.",
"score": 0.31,
"pid": "c4"
}
]
},
{
"query": "Explain neural networks",
"qid": "q4",
"candidates": [
{
"text": "Neural networks are computing systems inspired by biological neural networks.",
"score": 0.91,
"pid": "c5"
},
{
"text": "They consist of interconnected nodes that process information in layers.",
"score": 0.85,
"pid": "c6"
},
{
"text": "Gardening involves planting seeds and watering plants regularly.",
"score": 0.19,
"pid": "c7"
}
]
}
]
with open(candidates_file, 'w', encoding='utf-8') as f:
for item in candidates_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# 4. Legacy nested format (for backward compatibility)
legacy_file = os.path.join(temp_dir, "legacy_nested_data.jsonl")
legacy_data = [
{
"query": "What is natural language processing?",
"pos": [
"Natural language processing is a field of AI that helps computers understand human language.",
"NLP combines computational linguistics with machine learning to process text and speech."
],
"neg": [
"Automobile manufacturing involves assembly lines and quality control processes.",
"Photography captures light through camera lenses to create images."
],
"pos_scores": [0.93, 0.86],
"neg_scores": [0.14, 0.22]
},
{
"query": "How do transformers work in NLP?",
"pos": [
"Transformers use self-attention mechanisms to process sequential data in parallel.",
"The transformer architecture revolutionized NLP with attention-based models like BERT and GPT."
],
"neg": [
"Electric vehicles use batteries to power electric motors for transportation.",
"Coffee brewing requires proper water temperature and grinding consistency."
],
"pos_scores": [0.89, 0.91],
"neg_scores": [0.16, 0.13]
}
]
with open(legacy_file, 'w', encoding='utf-8') as f:
for item in legacy_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
return {
'triplets': triplets_file,
'pairs': pairs_file,
'candidates': candidates_file,
'legacy_nested': legacy_file,
'temp_dir': temp_dir
}
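# --- Illustrative sketch (not part of the pipeline API): a minimal sanity check for the
# --- generated JSONL files. The helper name and the required-key lists are assumptions
# --- chosen to mirror the sample records created above.
def validate_jsonl_file(path, required_keys):
    """Count records in a JSONL file that contain all of the required keys."""
    valid = 0
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            if all(key in record for key in required_keys):
                valid += 1
    return valid
# Example (hypothetical call): validate_jsonl_file(files['triplets'], ["query", "pos", "neg"])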
def demonstrate_fine_tuning_loader():
"""Demonstrate fine-tuning data loader with all formats"""
print("\n" + "="*80)
print("🎯 FINE-TUNING DATA LOADER EXAMPLES")
print("="*80)
# Create sample data
data_files = create_sample_data()
# Initialize tokenizer (use a simple one for demo)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# 1. Triplets format with BGE-M3 Dataset
print("\n📊 1. TRIPLETS FORMAT (BGE-M3 Embedding Training)")
print("-" * 60)
try:
triplets_dataset = BGEM3Dataset(
data_path=data_files['triplets'],
tokenizer=tokenizer,
query_max_length=64,
passage_max_length=256,
is_train=True
)
print(f"✅ Loaded {len(triplets_dataset)} triplet examples")
print(f" Detected format: {triplets_dataset.detected_format}")
# Get a sample
sample = triplets_dataset[0]
print(f" Sample keys: {list(sample.keys())}")
print(f" Query shape: {sample['query_input_ids'].shape}")
print(f" Passages shape: {sample['passage_input_ids'].shape}")
print(f" Labels: {sample['labels']}")
except Exception as e:
print(f"❌ Error: {e}")
# 2. Pairs format with BGE-Reranker Dataset
print("\n📊 2. PAIRS FORMAT (BGE-Reranker Training)")
print("-" * 60)
try:
pairs_dataset = BGERerankerDataset(
data_path=data_files['pairs'],
tokenizer=tokenizer,
query_max_length=64,
passage_max_length=256,
is_train=True,
legacy_support=False
)
print(f"✅ Loaded {len(pairs_dataset)} pair examples")
print(f" Detected format: {pairs_dataset.detected_format}")
# Get a sample
sample = pairs_dataset[0]
print(f" Sample keys: {list(sample.keys())}")
print(f" Input shape: {sample['input_ids'].shape}")
print(f" Label: {sample['labels']}")
print(f" Query: {sample['query_text'][:50]}...")
except Exception as e:
print(f"❌ Error: {e}")
# 3. Candidates format with automatic conversion
print("\n📊 3. CANDIDATES FORMAT (Threshold-based Conversion)")
print("-" * 60)
try:
candidates_dataset = BGEDataset(
data_path=data_files['candidates'],
tokenizer=tokenizer,
query_max_length=64,
passage_max_length=256,
is_train=True,
format_type="candidates",
candidates_threshold=0.5 # Scores >= 0.5 become label=1
)
print(f"✅ Loaded {len(candidates_dataset)} exploded examples")
print(f" Detected format: {candidates_dataset.detected_format}")
print(f" Note: Each query with multiple candidates exploded into separate pairs")
# Get a sample
sample = candidates_dataset[0]
print(f" Sample keys: {list(sample.keys())}")
print(f" Format after processing: {sample['format']}")
print(f" Source: {sample.get('source', 'N/A')}")
except Exception as e:
print(f"❌ Error: {e}")
# 4. Legacy nested format with automatic conversion
print("\n📊 4. LEGACY NESTED FORMAT (Automatic Conversion)")
print("-" * 60)
try:
legacy_dataset = BGERerankerDataset(
data_path=data_files['legacy_nested'],
tokenizer=tokenizer,
query_max_length=64,
passage_max_length=256,
is_train=True,
legacy_support=True
)
print(f"✅ Loaded {len(legacy_dataset)} converted examples")
print(f" Detected format: {legacy_dataset.detected_format}")
print(f" Note: Legacy nested format automatically converted to flat pairs")
# Get a sample
sample = legacy_dataset[0]
print(f" Sample keys: {list(sample.keys())}")
print(f" Format after processing: {sample['format']}")
print(f" Source: {sample.get('source', 'N/A')}")
except Exception as e:
print(f"❌ Error: {e}")
return data_files
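# --- Illustrative sketch only (BGEDataset performs this conversion internally when
# --- format_type="candidates"): exploding one candidates record into flat reranker pairs
# --- by score threshold, as noted in example 3 above. The helper name is an assumption;
# --- the 0.5 default mirrors the candidates_threshold used in that example.
def explode_candidates_to_pairs(record, threshold=0.5):
    """Convert a single candidates-format record into flat (query, passage, label) pairs."""
    pairs = []
    for candidate in record.get("candidates", []):
        score = candidate.get("score", 0.0)
        pairs.append({
            "query": record["query"],
            "passage": candidate["text"],
            "label": 1 if score >= threshold else 0,
            "score": score,
            "qid": record.get("qid"),
            "pid": candidate.get("pid"),
        })
    return pairs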
def demonstrate_optimization_pipeline(data_files):
"""Demonstrate optimization pipeline with all formats"""
print("\n" + "="*80)
print("⚡ OPTIMIZATION PIPELINE EXAMPLES")
print("="*80)
cache_dir = os.path.join(data_files['temp_dir'], 'cache')
os.makedirs(cache_dir, exist_ok=True)
# 1. BGE-M3 optimization with triplets format
print("\n🚀 1. BGE-M3 OPTIMIZATION (Triplets Format)")
print("-" * 60)
try:
cached_files = optimize_data(
model_type='bge-m3',
input_files=[data_files['triplets']],
output_dir=cache_dir,
tokenizer_path='bert-base-uncased',
force=True,
hard_negative_ratio=0.7,
max_triplets_per_query=8,
difficulty_threshold=0.1
)
print(f"✅ M3 optimization complete: {len(cached_files)} files")
for file in cached_files:
print(f" 📦 {file}")
except Exception as e:
print(f"❌ M3 optimization error: {e}")
# 2. BGE-Reranker optimization with pairs format
print("\n🚀 2. BGE-RERANKER OPTIMIZATION (Pairs Format)")
print("-" * 60)
try:
cached_files = optimize_data(
model_type='bge-reranker',
input_files=[data_files['pairs']],
output_dir=cache_dir,
tokenizer_path='bert-base-uncased',
force=True,
output_format='flat'
)
print(f"✅ Reranker optimization complete: {len(cached_files)} files")
for file in cached_files:
print(f" 📦 {file}")
except Exception as e:
print(f"❌ Reranker optimization error: {e}")
# 3. Candidates format optimization (converts to pairs for reranker)
print("\n🚀 3. CANDIDATES FORMAT OPTIMIZATION")
print("-" * 60)
try:
cached_files = optimize_data(
model_type='bge-reranker',
input_files=[data_files['candidates']],
output_dir=cache_dir,
tokenizer_path='bert-base-uncased',
force=True,
output_format='flat'
)
print(f"✅ Candidates optimization complete: {len(cached_files)} files")
print(" Note: Candidates exploded into pairs with threshold-based labels")
for file in cached_files:
print(f" 📦 {file}")
except Exception as e:
print(f"❌ Candidates optimization error: {e}")
# 4. Candidates to M3 triplets conversion
print("\n🚀 4. CANDIDATES TO M3 TRIPLETS CONVERSION")
print("-" * 60)
try:
cached_files = optimize_data(
model_type='bge-m3',
input_files=[data_files['candidates']],
output_dir=cache_dir,
tokenizer_path='bert-base-uncased',
force=True,
hard_negative_ratio=0.8
)
print(f"✅ Candidates-to-M3 optimization complete: {len(cached_files)} files")
print(" Note: Candidates split by threshold into pos/neg for triplets")
for file in cached_files:
print(f" 📦 {file}")
except Exception as e:
print(f"❌ Candidates-to-M3 optimization error: {e}")
# 5. Legacy nested format optimization
print("\n🚀 5. LEGACY NESTED FORMAT OPTIMIZATION")
print("-" * 60)
try:
cached_files = optimize_data(
model_type='bge-reranker',
input_files=[data_files['legacy_nested']],
output_dir=cache_dir,
tokenizer_path='bert-base-uncased',
force=True,
output_format='nested' # Keep nested format
)
print(f"✅ Legacy optimization complete: {len(cached_files)} files")
print(" Note: Legacy nested format preserved for compatibility")
for file in cached_files:
print(f" 📦 {file}")
except Exception as e:
print(f"❌ Legacy optimization error: {e}")
def show_data_format_schemas():
"""Show clean data format schemas without inline comments"""
print("\n" + "="*80)
print("📋 DATA FORMAT SCHEMAS")
print("="*80)
print("\n1. 🎯 TRIPLETS FORMAT (BGE-M3 Embedding Training)")
print("-" * 60)
triplets_schema = {
"query": "What is machine learning?",
"pos": [
"Machine learning is a subset of artificial intelligence...",
"ML algorithms learn patterns from data..."
],
"neg": [
"Weather forecasting uses meteorological data...",
"Cooking recipes require specific ingredients..."
],
"pos_scores": [0.95, 0.88],
"neg_scores": [0.15, 0.08],
"prompt": "为此查询生成表示:",
"type": "definition"
}
print(json.dumps(triplets_schema, indent=2, ensure_ascii=False))
print("\n2. 🔍 PAIRS FORMAT (BGE-Reranker Training)")
print("-" * 60)
pairs_schema = {
"query": "What is machine learning?",
"passage": "Machine learning is a subset of artificial intelligence...",
"label": 1,
"score": 0.95,
"qid": "q1",
"pid": "p1"
}
print(json.dumps(pairs_schema, indent=2, ensure_ascii=False))
print("\n3. 🎲 CANDIDATES FORMAT (Threshold-based Conversion)")
print("-" * 60)
candidates_schema = {
"query": "What is artificial intelligence?",
"qid": "q1",
"candidates": [
{
"text": "AI is the simulation of human intelligence in machines.",
"score": 0.94,
"pid": "c1"
},
{
"text": "Mountain climbing requires proper equipment.",
"score": 0.23,
"pid": "c2"
}
]
}
print(json.dumps(candidates_schema, indent=2, ensure_ascii=False))
print("📝 Note: Scores >= threshold become label=1, others become label=0")
print("\n4. 🔄 LEGACY NESTED FORMAT (Backward Compatibility)")
print("-" * 60)
legacy_schema = {
"query": "What is natural language processing?",
"pos": [
"NLP is a field of AI that helps computers understand human language.",
"NLP combines computational linguistics with machine learning..."
],
"neg": [
"Automobile manufacturing involves assembly lines...",
"Photography captures light through camera lenses..."
],
"pos_scores": [0.93, 0.86],
"neg_scores": [0.14, 0.22]
}
print(json.dumps(legacy_schema, indent=2, ensure_ascii=False))
print("📝 Note: Automatically converted to flat pairs for reranker training")
def main():
"""Main demonstration function"""
print("🚀 BGE Data Pipeline Format Examples")
print("="*80)
print("Demonstrating all four supported formats:")
print("1. Triplets (BGE-M3 embedding training)")
print("2. Pairs (BGE-reranker training)")
print("3. Candidates (unified format with threshold conversion)")
print("4. Legacy nested (backward compatibility)")
print("="*80)
try:
# Show data schemas
show_data_format_schemas()
# Demonstrate fine-tuning loader
data_files = demonstrate_fine_tuning_loader()
# Demonstrate optimization pipeline
demonstrate_optimization_pipeline(data_files)
print("\n" + "="*80)
print("🎉 ALL EXAMPLES COMPLETED SUCCESSFULLY!")
print("="*80)
print("\n📋 SUMMARY:")
print("✅ Format Detection: Automatic detection of all four formats")
print("✅ Conversion Logic: Seamless conversion between formats")
print("✅ Threshold Assignment: Candidates → pairs with configurable threshold")
print("✅ Legacy Support: Automatic migration of nested → flat pairs")
print("✅ Optimization: Pre-tokenization and caching for 10-15x speedup")
print("✅ Backward Compatibility: Full support for existing datasets")
print(f"\n📁 Sample data and cache files created in: {data_files['temp_dir']}")
print(" You can examine the generated files to see the format conversions")
except Exception as e:
print(f"\n❌ Error during demonstration: {e}")
import traceback
traceback.print_exc()
print("\n" + "="*80)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,20 @@
{"query":"什么是深度学习?","passage":"深度学习是机器学习的一个子领域,使用多层神经网络模拟人脑处理信息的方式","label":1,"score":1.0,"qid":"Q001","pid":"P001"}
{"query":"什么是深度学习?","passage":"深度学习Deep Learning是人工智能和机器学习的重要分支通过构建深层神经网络来学习数据表示","label":1,"score":0.95,"qid":"Q001","pid":"P002"}
{"query":"什么是深度学习?","passage":"机器学习包含多种算法,如决策树、支持向量机、神经网络等","label":0,"score":0.6,"qid":"Q001","pid":"N001"}
{"query":"什么是深度学习?","passage":"人工智能是计算机科学的一个分支,目标是创建能够执行智能任务的系统","label":0,"score":0.4,"qid":"Q001","pid":"N002"}
{"query":"什么是深度学习?","passage":"数据挖掘是从大型数据集中提取模式和知识的过程","label":0,"score":0.3,"qid":"Q001","pid":"N003"}
{"query":"如何制作红烧肉?","passage":"红烧肉制作步骤选五花肉切块焯水去腥热锅炒糖色下肉翻炒加生抽老抽料酒小火炖煮30-40分钟至软烂","label":1,"score":1.0,"qid":"Q002","pid":"P003"}
{"query":"如何制作红烧肉?","passage":"红烧肉是经典上海菜,用五花肉、冰糖、生抽、老抽、料酒等,关键在于炒糖色和火候控制","label":1,"score":0.9,"qid":"Q002","pid":"P004"}
{"query":"如何制作红烧肉?","passage":"红烧肉是中国传统菜肴,口感香甜软糯,深受大众喜爱","label":0,"score":0.5,"qid":"Q002","pid":"N004"}
{"query":"如何制作红烧肉?","passage":"五花肉含有丰富的蛋白质和脂肪,营养价值较高","label":0,"score":0.3,"qid":"Q002","pid":"N005"}
{"query":"如何制作红烧肉?","passage":"上海菜以红烧见长,口味偏甜,适合南方人饮食习惯","label":0,"score":0.2,"qid":"Q002","pid":"N006"}
{"query":"Python编程入门应该学习哪些内容","passage":"Python编程入门需要掌握基本语法、数据类型(字符串、列表、字典)、控制结构(if/for/while)、函数定义、面向对象编程基础","label":1,"score":1.0,"qid":"Q003","pid":"P005"}
{"query":"Python编程入门应该学习哪些内容","passage":"Python新手学习路线先学变量和数据类型然后是条件语句和循环接着学习函数和模块最后是类和对象","label":1,"score":0.95,"qid":"Q003","pid":"P006"}
{"query":"Python编程入门应该学习哪些内容","passage":"Python是一种简单易学的编程语言语法清晰适合初学者","label":0,"score":0.4,"qid":"Q003","pid":"N007"}
{"query":"Python编程入门应该学习哪些内容","passage":"编程思维比语法更重要,需要培养逻辑思维能力","label":0,"score":0.3,"qid":"Q003","pid":"N008"}
{"query":"Python编程入门应该学习哪些内容","passage":"计算机科学包含算法、数据结构、操作系统等多个分支","label":0,"score":0.2,"qid":"Q003","pid":"N009"}
{"query":"北京有哪些必去的旅游景点?","passage":"北京必去景点推荐:故宫(紫禁城)、天安门广场、万里长城、颐和园、天坛、圆明园、北海公园、什刹海等","label":1,"score":1.0,"qid":"Q004","pid":"P007"}
{"query":"北京有哪些必去的旅游景点?","passage":"故宫是明清两代皇宫占地72万平方米是世界最大的古代宫殿建筑群必须提前预约参观","label":1,"score":0.9,"qid":"Q004","pid":"P008"}
{"query":"北京有哪些必去的旅游景点?","passage":"北京是中华人民共和国首都有3000多年建城史","label":0,"score":0.3,"qid":"Q004","pid":"N010"}
{"query":"北京有哪些必去的旅游景点?","passage":"中国拥有众多世界文化遗产,旅游资源丰富","label":0,"score":0.25,"qid":"Q004","pid":"N011"}
{"query":"北京有哪些必去的旅游景点?","passage":"春秋季节是北京旅游的最佳时间,气候宜人","label":0,"score":0.4,"qid":"Q004","pid":"N012"}

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,534 @@
"""
ANCE-style hard negative mining for BGE models
"""
import os
import json
import logging
import threading
import queue
from typing import List, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import faiss
# Optional imports for text processing
try:
import jieba
JIEBA_AVAILABLE = True
except ImportError:
JIEBA_AVAILABLE = False
jieba = None
try:
from rank_bm25 import BM25Okapi
BM25_AVAILABLE = True
except ImportError:
BM25_AVAILABLE = False
BM25Okapi = None
from collections import defaultdict
logger = logging.getLogger(__name__)
class ANCEHardNegativeMiner:
"""
Approximate Nearest Neighbor Negative Contrastive Learning
Dynamically mines hard negatives using the evolving model during training
Notes:
- All parameters (embedding_dim, num_hard_negatives, negative_range, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
embedding_dim: int = 768,
num_hard_negatives: int = 15,
negative_range: Tuple[int, int] = (10, 100),
margin: float = 0.1,
index_refresh_interval: int = 1000,
index_type: str = "IVF",
nlist: int = 100,
use_gpu: bool = True,
max_index_size: int = 1000000
):
self.embedding_dim = embedding_dim
self.num_hard_negatives = num_hard_negatives
self.negative_range = negative_range
self.margin = margin
self.index_refresh_interval = index_refresh_interval
self.use_gpu = use_gpu and torch.cuda.is_available()
self.max_index_size = max_index_size
# Initialize FAISS index
self.index = self._create_index(index_type, nlist)
self.passage_map: Dict[int, str] = {} # Maps index ID to passage text
self.query_map: Dict[str, List[str]] = defaultdict(list) # Maps query to positive passages
# For asynchronous index updates
self.update_queue = queue.Queue()
self.update_thread = None
self.stop_update_thread = threading.Event()
self.current_step = 0
# Statistics
self.stats = {
'total_queries': 0,
'total_negatives_mined': 0,
'index_updates': 0,
'cache_hits': 0,
'cache_misses': 0
}
def _create_index(self, index_type: str, nlist: int) -> faiss.Index:
"""Create FAISS index, warn if falling back to CPU."""
if index_type == "IVF":
quantizer = faiss.IndexFlatIP(self.embedding_dim)
index = faiss.IndexIVFFlat(quantizer, self.embedding_dim, nlist, faiss.METRIC_INNER_PRODUCT)
elif index_type == "HNSW":
index = faiss.IndexHNSWFlat(self.embedding_dim, 32, faiss.METRIC_INNER_PRODUCT)
else:
index = faiss.IndexFlatIP(self.embedding_dim)
if self.use_gpu:
try:
res = faiss.StandardGpuResources() # type: ignore[attr-defined]
index = faiss.index_cpu_to_gpu(res, 0, index) # type: ignore[attr-defined]
logger.info("Using GPU for FAISS index")
except Exception as e:
logger.warning(f"Failed to use GPU for FAISS: {e}. Using CPU.")
self.use_gpu = False
else:
logger.warning("FAISS is running on CPU. This may be slower than GPU.")
return index
def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters"""
for char in text:
if '\u4e00' <= char <= '\u9fff':
return True
return False
def start_async_updates(self):
"""Start background thread for asynchronous index updates"""
if self.update_thread is None or not self.update_thread.is_alive():
self.stop_update_thread.clear()
self.update_thread = threading.Thread(target=self._update_worker)
self.update_thread.daemon = True
self.update_thread.start()
logger.info("Started asynchronous index update thread")
def stop_async_updates(self):
"""Stop background update thread and ensure clean shutdown."""
if self.update_thread and self.update_thread.is_alive():
self.stop_update_thread.set()
self.update_thread.join(timeout=5)
logger.info("Stopped asynchronous index update thread")
def _update_worker(self):
"""Background worker for index updates"""
while not self.stop_update_thread.is_set():
try:
# Get update task with timeout
task = self.update_queue.get(timeout=1)
if task['type'] == 'add_embeddings':
self._add_to_index(task['embeddings'], task['passages'])
elif task['type'] == 'rebuild_index':
self._rebuild_index(task['model'])
self.stats['index_updates'] += 1
except queue.Empty:
continue
except Exception as e:
logger.error(f"Error in update worker: {e}")
def _add_to_index(self, embeddings: np.ndarray, passages: List[str]):
"""Add embeddings to index"""
try:
# Normalize embeddings for inner product search
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
# Add to index
start_idx = len(self.passage_map)
# Debug log for embeddings
logger.debug(f"embeddings shape: {getattr(embeddings, 'shape', None)}, type: {type(embeddings)}")
if hasattr(self.index, 'is_trained') and not self.index.is_trained:
if embeddings is not None and len(embeddings) > 0:
if hasattr(self.index, 'nlist') and len(embeddings) >= getattr(self.index, 'nlist', 0):
# Safe to access self.index.nlist and call train
logger.info("Training FAISS index...")
self.index.train(embeddings) # type: ignore[call-arg]
elif hasattr(self.index, 'nlist'):
logger.warning(f"Not enough data to train index. Need {getattr(self.index, 'nlist', 0)}, got {len(embeddings)}")
return
else:
# For index types without nlist/train, skip training
pass
else:
logger.warning("Embeddings is None or empty, skipping training and addition to index.")
return
if embeddings is not None and len(embeddings) > 0:
self.index.add(embeddings) # type: ignore[call-arg]
else:
logger.warning("Embeddings is None or empty, skipping addition to index.")
# Update passage map
for i, passage in enumerate(passages):
self.passage_map[start_idx + i] = passage
logger.info(f"Added {len(passages)} passages to index. Total: {self.index.ntotal}")
except Exception as e:
logger.error(f"Error adding embeddings to index: {e}")
def _rebuild_index(self, model):
"""Rebuild entire index with current model"""
try:
logger.info("Rebuilding entire index...")
# Clear current index
self.index.reset()
old_passage_map = self.passage_map.copy()
self.passage_map.clear()
# Re-encode all passages
all_passages = list(set(sum(self.query_map.values(), [])))
if not all_passages:
logger.warning("No passages to rebuild index with")
return
embeddings = []
batch_size = 32
for i in tqdm(range(0, len(all_passages), batch_size), desc="Encoding passages"):
batch = all_passages[i:i + batch_size]
try:
with torch.no_grad():
batch_embeddings = model.encode(batch, convert_to_numpy=True, batch_size=len(batch))
embeddings.append(batch_embeddings)
except Exception as e:
logger.error(f"Error encoding batch {i // batch_size}: {e}")
continue
if embeddings:
embeddings = np.vstack(embeddings)
self._add_to_index(embeddings, all_passages)
logger.info(f"Rebuilt index with {len(all_passages)} passages")
else:
# Restore old passage map if rebuild failed
self.passage_map = old_passage_map
logger.error("Failed to rebuild index - no embeddings generated")
except Exception as e:
logger.error(f"Error rebuilding index: {e}")
def mine_hard_negatives(
self,
queries: List[str],
query_embeddings: np.ndarray,
positive_passages: List[List[str]],
corpus_embeddings: Optional[np.ndarray] = None,
corpus_passages: Optional[List[str]] = None
) -> List[List[str]]:
"""
Mine hard negatives for queries using ANCE approach
Args:
queries: List of query texts
query_embeddings: Query embeddings [num_queries, embedding_dim]
positive_passages: List of positive passages for each query
corpus_embeddings: Optional corpus embeddings to add to index
corpus_passages: Optional corpus passages
Returns:
List of hard negative passages for each query
"""
# Add corpus to index if provided
if corpus_embeddings is not None and corpus_passages is not None:
self.update_queue.put({
'type': 'add_embeddings',
'embeddings': corpus_embeddings,
'passages': corpus_passages
})
# Update query map
for query, positives in zip(queries, positive_passages):
self.query_map[query].extend(positives)
# Normalize query embeddings
query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
# Search for hard negatives
hard_negatives = []
k = min(self.negative_range[1], self.index.ntotal)
if k == 0:
logger.warning("Index is empty, returning empty hard negatives")
return [[] for _ in queries]
similarities, indices = self.index.search(query_embeddings, k) # type: ignore[call-arg]
for i, (query, positives) in enumerate(zip(queries, positive_passages)):
query_hard_negs = []
candidates = []
# Collect candidates within negative range
for j in range(self.negative_range[0], min(k, self.negative_range[1])):
if indices[i, j] >= 0: # Valid index
passage_idx = indices[i, j]
if passage_idx in self.passage_map:
passage = self.passage_map[passage_idx]
score = similarities[i, j]
# Skip if it's a positive passage
if passage not in positives:
candidates.append((passage, score))
# Select hard negatives with margin
if candidates:
# Sort by score (descending)
candidates.sort(key=lambda x: x[1], reverse=True)
# Select with margin constraint
selected = []
last_score = float('inf')
for passage, score in candidates:
if last_score - score >= self.margin or len(selected) == 0:
selected.append(passage)
last_score = score
if len(selected) >= self.num_hard_negatives:
break
query_hard_negs = selected
# Pad with random negatives if needed
if len(query_hard_negs) < self.num_hard_negatives:
available = [p for p in self.passage_map.values()
if p not in positives and p not in query_hard_negs]
if available:
additional_count = min(len(available), self.num_hard_negatives - len(query_hard_negs))
additional = np.random.choice(available, size=additional_count, replace=False).tolist()
query_hard_negs.extend(additional)
hard_negatives.append(query_hard_negs)
self.stats['total_queries'] += len(queries)
self.stats['total_negatives_mined'] += sum(len(negs) for negs in hard_negatives)
return hard_negatives
def update_index_if_needed(self, step: int, model=None):
"""Check if index needs update and trigger if necessary"""
self.current_step = step
if step > 0 and step % self.index_refresh_interval == 0 and model is not None:
logger.info(f"Triggering index rebuild at step {step}")
self.update_queue.put({
'type': 'rebuild_index',
'model': model
})
def save_mined_data(self, output_path: str, original_data_path: str):
"""Save data with mined hard negatives"""
logger.info(f"Saving mined data to {output_path}")
# Load original data
original_data = []
with open(original_data_path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
original_data.append(json.loads(line))
# This is a simplified implementation - in practice you'd batch mine negatives
with open(output_path, 'w', encoding='utf-8') as f:
for item in tqdm(original_data, desc="Saving mined data"):
f.write(json.dumps(item, ensure_ascii=False) + '\n')
logger.info(f"Saved {len(original_data)} examples")
def get_statistics(self) -> Dict[str, int]:
"""Get mining statistics"""
self.stats['index_size'] = self.index.ntotal
self.stats['unique_passages'] = len(self.passage_map)
return self.stats.copy()
class HardNegativeDataAugmenter:
"""
Augment training data with various hard negative strategies
"""
def __init__(self, config: Dict):
self.config = config
self.strategies = config.get('strategies', ['ance', 'bm25', 'semantic'])
self.bm25_index = None
self.corpus_passages = None
def build_bm25_index(self, corpus: List[str]):
"""Build BM25 index for hard negative mining"""
try:
logger.info("Building BM25 index...")
self.corpus_passages = corpus
# Tokenize corpus
tokenized_corpus = []
for doc in tqdm(corpus, desc="Tokenizing corpus"):
if self._is_chinese(doc) and JIEBA_AVAILABLE:
tokens = list(jieba.cut(doc))
else:
tokens = doc.lower().split()
tokenized_corpus.append(tokens)
if BM25_AVAILABLE:
self.bm25_index = BM25Okapi(tokenized_corpus)
logger.info(f"Built BM25 index with {len(corpus)} documents")
else:
logger.warning("BM25Okapi not available, skipping BM25 index")
self.bm25_index = None
except Exception as e:
logger.error(f"Error building BM25 index: {e}")
self.bm25_index = None
def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters"""
for char in text:
if '\u4e00' <= char <= '\u9fff':
return True
return False
def augment_data(
self,
data: List[Dict],
model=None,
corpus: Optional[List[str]] = None
) -> List[Dict]:
"""Apply multiple hard negative augmentation strategies"""
# Build BM25 index if corpus provided
if corpus and 'bm25' in self.strategies:
self.build_bm25_index(corpus)
augmented_data = []
for item in tqdm(data, desc="Augmenting data"):
query = item['query']
positives = item.get('pos', [])
negatives = item.get('neg', [])
# Apply different strategies
if 'bm25' in self.strategies and self.bm25_index:
bm25_negs = self._mine_bm25_negatives(query, positives)
negatives.extend(bm25_negs[:5]) # Add top 5 BM25 negatives
if 'semantic' in self.strategies and model:
semantic_negs = self._mine_semantic_negatives(query, positives, model)
negatives.extend(semantic_negs[:5])
if 'random_swap' in self.strategies:
swap_negs = self._create_swap_negatives(positives)
negatives.extend(swap_negs[:3])
# Remove duplicates while preserving order
seen = set()
unique_negatives = []
for neg in negatives:
if neg not in seen and neg not in positives:
seen.add(neg)
unique_negatives.append(neg)
augmented_item = item.copy()
augmented_item['neg'] = unique_negatives[:self.config.get('max_negatives', 30)]
augmented_data.append(augmented_item)
return augmented_data
def _mine_bm25_negatives(self, query: str, positives: List[str]) -> List[str]:
"""Mine hard negatives using BM25"""
if not self.bm25_index or not self.corpus_passages or not BM25_AVAILABLE:
return []
# Tokenize query
if self._is_chinese(query) and JIEBA_AVAILABLE:
query_tokens = list(jieba.cut(query))
else:
query_tokens = query.lower().split()
# Get BM25 scores
scores = self.bm25_index.get_scores(query_tokens)
# Get top passages excluding positives
top_indices = np.argsort(scores)[::-1]
negatives = []
for idx in top_indices:
if len(negatives) >= 10:
break
passage = self.corpus_passages[idx]
if passage not in positives:
negatives.append(passage)
return negatives
def _mine_semantic_negatives(self, query: str, positives: List[str], model) -> List[str]:
"""
Mine semantically similar but irrelevant passages using the model's encode method.
Returns top-k most similar passages (excluding positives).
"""
try:
if not self.corpus_passages or not model:
logger.warning("No corpus or model provided for semantic negative mining.")
return []
batch_size = self.config.get('semantic_batch_size', 64)
num_negatives = self.config.get('semantic_num_negatives', 10)
device = self.config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
# Encode query
query_embedding = model.encode([query], convert_to_numpy=True, device=device)
# Encode all corpus passages in batches
passage_embeddings = []
for i in range(0, len(self.corpus_passages), batch_size):
batch = self.corpus_passages[i:i+batch_size]
batch_emb = model.encode(batch, convert_to_numpy=True, device=device)
passage_embeddings.append(batch_emb)
passage_embeddings = np.vstack(passage_embeddings)
# Compute similarity
scores = np.dot(passage_embeddings, query_embedding[0])
# Exclude positives
candidates = [(idx, score) for idx, (score, passage) in enumerate(zip(scores, self.corpus_passages)) if passage not in positives]
# Sort and select top-k
candidates.sort(key=lambda x: x[1], reverse=True)
top_indices = [idx for idx, _ in candidates[:num_negatives]]
negatives = [self.corpus_passages[idx] for idx in top_indices]
logger.info(f"Mined {len(negatives)} semantic negatives for query: {query}")
return negatives
except Exception as e:
logger.error(f"Error mining semantic negatives: {e}")
return []
def _create_swap_negatives(self, positives: List[str]) -> List[str]:
"""Create negatives by swapping words in positive passages"""
negatives = []
for pos in positives[:2]: # Only use first 2 positives
words = pos.split()
if len(words) > 5:
# Swap random words
import random
swapped = words.copy()
for _ in range(min(3, len(words) // 3)):
i, j = random.sample(range(len(words)), 2)
swapped[i], swapped[j] = swapped[j], swapped[i]
negatives.append(' '.join(swapped))
return negatives
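# Illustrative usage sketch (not executed anywhere in this module; parameter values are
# examples only -- pass real values from your config-driven training script):
#
#   miner = ANCEHardNegativeMiner(embedding_dim=768, num_hard_negatives=15, use_gpu=True)
#   miner.start_async_updates()
#   hard_negs = miner.mine_hard_negatives(
#       queries, query_embeddings, positive_passages,
#       corpus_embeddings=corpus_embeddings, corpus_passages=corpus_passages,
#   )
#   miner.update_index_if_needed(step=global_step, model=embedding_model)
#   miner.stop_async_updates()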

687
data/preprocessing.py Normal file
View File

@ -0,0 +1,687 @@
import json
import os
import random
from typing import List, Dict, Optional, Tuple, Union
import pandas as pd
import numpy as np
from tqdm import tqdm
import logging
from transformers import AutoTokenizer
import torch
from collections import defaultdict
import re
import xml.etree.ElementTree as ET
import csv
logger = logging.getLogger(__name__)
class DataPreprocessor:
"""
Comprehensive data preprocessor for various input formats
Supports JSON, JSONL, CSV, TSV, XML, and custom formats
Notes:
- All parameters (tokenizer, max_query_length, max_passage_length, min_passage_length, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
tokenizer: Optional[AutoTokenizer] = None,
max_query_length: int = 512,
max_passage_length: int = 8192,
# For production, set min_passage_length to 10
min_passage_length: int = 3
):
self.tokenizer = tokenizer
self.max_query_length = max_query_length
self.max_passage_length = max_passage_length
self.min_passage_length = min_passage_length
def preprocess_file(
self,
input_path: str,
output_path: str,
file_format: str = "auto",
validation_split: float = 0.1
) -> Tuple[str, Optional[str]]:
"""
Preprocess input file and convert to standard JSONL format.
Args:
input_path: Path to input file
output_path: Path to output JSONL file
file_format: Input file format (auto, json, jsonl, csv, tsv, xml, dureader, msmarco)
validation_split: Fraction of data to use for validation
Returns:
Tuple of (train_path, val_path)
"""
# Auto-detect format
if file_format == "auto":
file_format = self._detect_format(input_path)
logger.info(f"Processing {input_path} as {file_format} format")
# Load data based on format
if file_format == "jsonl":
data = self._load_jsonl(input_path)
elif file_format == "json":
data = self._load_json(input_path)
elif file_format in ["csv", "tsv"]:
data = self._load_csv(input_path, delimiter="," if file_format == "csv" else "\t")
elif file_format == "xml":
data = self._load_xml(input_path)
elif file_format == "dureader":
data = self._load_dureader(input_path)
elif file_format == "msmarco":
data = self._load_msmarco(input_path)
elif file_format == "beir":
data = self._load_beir(input_path)
else:
raise ValueError(f"Unsupported format: {file_format}")
# Validate and clean data
data = self._validate_and_clean(data)
# Split into train/val
if validation_split > 0:
train_data, val_data = self._split_data(data, validation_split)
# Save training data
train_path = output_path.replace(".jsonl", "_train.jsonl")
self._save_jsonl(train_data, train_path)
# Save validation data
val_path = output_path.replace(".jsonl", "_val.jsonl")
self._save_jsonl(val_data, val_path)
logger.info(f"Saved {len(train_data)} training and {len(val_data)} validation examples")
return train_path, val_path
else:
self._save_jsonl(data, output_path)
logger.info(f"Saved {len(data)} examples to {output_path}")
return output_path, None
def _detect_format(self, file_path: str) -> str:
"""Auto-detect file format based on extension and content"""
ext = os.path.splitext(file_path)[1].lower()
# First try by extension
if ext == ".jsonl":
return "jsonl"
elif ext == ".json":
return "json"
elif ext == ".csv":
return "csv"
elif ext == ".tsv":
return "tsv"
elif ext == ".xml":
return "xml"
# Try to detect by content
try:
with open(file_path, 'r', encoding='utf-8') as f:
first_lines = [f.readline().strip() for _ in range(5)]
first_content = '\n'.join(first_lines)
# Check for JSON Lines
if all(line.startswith("{") and line.endswith("}") for line in first_lines if line):
return "jsonl"
# Check for JSON
if first_content.strip().startswith("[") or first_content.strip().startswith("{"):
return "json"
# Check for XML
if first_content.strip().startswith("<?xml") or first_content.strip().startswith("<"):
return "xml"
# Check for TSV (tabs)
if any("\t" in line for line in first_lines):
return "tsv"
# Check for CSV (commas)
if any("," in line for line in first_lines):
return "csv"
except Exception as e:
logger.warning(f"Could not auto-detect format for {file_path}: {e}")
# Default fallback
return "jsonl"
def _load_jsonl(self, file_path: str) -> List[Dict]:
"""Load JSONL file with robust error handling and detailed logging."""
data = []
total_lines = 0
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(tqdm(f, desc="Loading JSONL"), 1):
line = line.strip()
if line:
total_lines += 1
try:
data.append(json.loads(line))
except json.JSONDecodeError as e:
logger.warning(f"Skipping invalid JSON at {file_path}:{line_num}: {e} | Content: {line[:100]}")
except Exception as e:
logger.warning(f"Skipping malformed line at {file_path}:{line_num}: {e} | Content: {line[:100]}")
except Exception as e:
logger.error(f"Error loading JSONL file {file_path}: {e}")
raise
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
logger.warning(
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
)
return data
def _load_json(self, file_path: str) -> List[Dict]:
"""Load JSON file"""
total_lines = 0
try:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Convert to list format if needed
if isinstance(data, dict):
if 'data' in data:
data = data['data']
else:
data = [data]
elif not isinstance(data, list):
data = [data]
total_lines = len(data)
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON in file {file_path}: {e}")
raise
except Exception as e:
logger.error(f"Error loading JSON file {file_path}: {e}")
raise
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
logger.warning(
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
)
return data
def _load_csv(self, file_path: str, delimiter: str = ",") -> List[Dict]:
"""Load CSV/TSV file with robust parsing"""
data = []
total_lines = 0
try:
# Try different encodings
encodings = ['utf-8', 'gb2312', 'gbk', 'big5', 'latin-1']
df = None
for encoding in encodings:
try:
df = pd.read_csv(file_path, delimiter=delimiter, encoding=encoding)
logger.info(f"Successfully loaded CSV with {encoding} encoding")
break
except UnicodeDecodeError:
continue
except Exception as e:
logger.warning(f"Error with encoding {encoding} in {file_path}: {e}")
continue
if df is None:
raise ValueError(f"Could not load CSV file {file_path} with any encoding")
# Process each row
for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading CSV/TSV"):
total_lines += 1
try:
item = {}
# Map common column names to standard format
row_dict = {str(k).lower().strip(): v for k, v in row.items() if pd.notna(v)}
# Extract query
query_fields = ['query', 'question', 'q', 'text', 'sentence']
for field in query_fields:
if field in row_dict:
item['query'] = str(row_dict[field]).strip()
break
# Extract positive passages
pos_fields = ['positive', 'pos', 'passage', 'document', 'answer', 'response']
positives = []
for field in pos_fields:
if field in row_dict:
pos_text = str(row_dict[field]).strip()
if '|||' in pos_text:
positives.extend([p.strip() for p in pos_text.split('|||') if p.strip()])
else:
positives.append(pos_text)
break
if 'passage' in row_dict and 'label' in row_dict:
try:
label = int(float(row_dict['label']))
passage = str(row_dict['passage']).strip()
if label == 1 and passage:
positives.append(passage)
except (ValueError, TypeError):
pass
if positives:
item['pos'] = positives
# Extract negative passages
neg_fields = ['negative', 'neg', 'hard_negative']
negatives = []
for field in neg_fields:
if field in row_dict:
neg_text = str(row_dict[field]).strip()
if '|||' in neg_text:
negatives.extend([n.strip() for n in neg_text.split('|||') if n.strip()])
else:
negatives.append(neg_text)
break
if 'passage' in row_dict and 'label' in row_dict:
try:
label = int(float(row_dict['label']))
passage = str(row_dict['passage']).strip()
if label == 0 and passage:
negatives.append(passage)
except (ValueError, TypeError):
pass
if negatives:
item['neg'] = negatives
# Extract scores if available
score_fields = ['score', 'relevance', 'rating']
for field in score_fields:
if field in row_dict:
try:
score = float(row_dict[field])
if 'pos' in item:
item['pos_scores'] = [score] * len(item['pos'])
except (ValueError, TypeError):
pass
# Only add if we have minimum required fields
if 'query' in item and ('pos' in item or 'neg' in item):
data.append(item)
except Exception as e:
logger.warning(f"Skipping malformed row in {file_path} at index {idx}: {e} | Row: {row}")
except Exception as e:
logger.error(f"Error loading CSV file {file_path}: {e}")
raise
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
logger.warning(
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
)
return data
def _load_tsv(self, file_path: str) -> List[Dict]:
"""Load TSV file (wrapper around _load_csv)."""
return self._load_csv(file_path, delimiter="\t")
def _load_xml(self, file_path: str) -> List[Dict]:
"""Load XML file (e.g., TREC format)"""
data = []
total_lines = 0
try:
tree = ET.parse(file_path)
root = tree.getroot()
# Handle different XML structures
if root.tag == 'questions' or root.tag == 'queries':
# TREC-style format
for question in tqdm(root.findall('.//question') or root.findall('.//query'), desc="Loading XML"):
item = {}
# Extract query
query_elem = next((e for e in (question.find('text'), question.find('title'), question.find('query')) if e is not None), None)  # explicit None checks; chaining with "or" would skip childless elements
if query_elem is not None and query_elem.text is not None:
item['query'] = query_elem.text.strip()
# Extract passages/documents
passages = []
for doc in question.findall('.//document') or question.findall('.//passage'):
if doc.text:
passages.append(doc.text.strip())
if passages:
item['pos'] = passages
if 'query' in item:
data.append(item)
else:
# Generic XML - try to extract text content
for elem in tqdm(root.iter(), desc="Processing XML"):
if elem.text and len(elem.text.strip()) > 20: # Minimum length check
item = {
'query': elem.text.strip()[:100], # First part as query
'pos': [elem.text.strip()] # Full text as passage
}
data.append(item)
except Exception as e:
logger.error(f"Error loading XML file {file_path}: {e}")
raise
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
logger.warning(
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
)
return data
def _load_dureader(self, file_path: str) -> List[Dict]:
"""Load DuReader format data"""
data = []
total_lines = 0
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(tqdm(f, desc="Loading DuReader"), 1):
if line.strip():
total_lines += 1
try:
item = json.loads(line)
# Convert DuReader format
processed = {
'query': item.get('question', '').strip(),
'pos': [],
'neg': []
}
# Extract positive passages from answers
if 'answers' in item:
for answer in item['answers']:
if isinstance(answer, dict) and 'text' in answer:
processed['pos'].append(answer['text'].strip())
elif isinstance(answer, str):
processed['pos'].append(answer.strip())
# Extract documents as potential negatives
if 'documents' in item:
for doc in item['documents']:
if isinstance(doc, dict):
if 'paragraphs' in doc:
for para in doc['paragraphs']:
para_text = para.strip() if isinstance(para, str) else str(para).strip()
if para_text and para_text not in processed['pos']:
processed['neg'].append(para_text)
elif 'content' in doc:
content = doc['content'].strip()
if content and content not in processed['pos']:
processed['neg'].append(content)
if processed['query'] and processed['pos']:
data.append(processed)
except json.JSONDecodeError as e:
logger.warning(f"Skipping invalid JSON line in DuReader: {e}")
except Exception as e:
logger.error(f"Error loading DuReader file {file_path}: {e}")
raise
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
logger.warning(
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
)
return data
def _load_msmarco(self, file_path: str) -> List[Dict]:
"""Load MS MARCO format data"""
data = []
total_lines = 0
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(tqdm(f, desc="Loading MS MARCO"), 1):
line = line.strip()
if line:
total_lines += 1
parts = line.split('\t')
if len(parts) >= 2:
item = {
'query': parts[0].strip(),
'pos': [parts[1].strip()],
'neg': [p.strip() for p in parts[2:] if p.strip()]
}
data.append(item)
except Exception as e:
logger.error(f"Error loading MS MARCO file {file_path}: {e}")
raise
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
logger.warning(
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
)
return data
def _load_beir(self, file_path: str) -> List[Dict]:
"""Load BEIR format data"""
data = []
total_lines = 0
try:
# BEIR format is typically JSONL with specific structure
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(tqdm(f, desc="Loading BEIR"), 1):
if line.strip():
total_lines += 1
try:
item = json.loads(line)
# BEIR format usually has 'query', 'positive_passages', 'negative_passages'
processed = {}
if 'query' in item:
processed['query'] = item['query'].strip()
elif 'text' in item:
processed['query'] = item['text'].strip()
if 'positive_passages' in item:
processed['pos'] = [p.strip() for p in item['positive_passages'] if p.strip()]
elif 'pos' in item:
processed['pos'] = [p.strip() for p in item['pos'] if p.strip()]
if 'negative_passages' in item:
processed['neg'] = [n.strip() for n in item['negative_passages'] if n.strip()]
elif 'neg' in item:
processed['neg'] = [n.strip() for n in item['neg'] if n.strip()]
if 'query' in processed and 'pos' in processed:
data.append(processed)
except json.JSONDecodeError as e:
logger.warning(f"Skipping invalid JSON line in BEIR: {e}")
except Exception as e:
logger.error(f"Error loading BEIR file {file_path}: {e}")
raise
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
logger.warning(
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
)
return data
def _validate_and_clean(self, data: List[Dict]) -> List[Dict]:
"""Validate and clean data"""
cleaned = []
for item in tqdm(data, desc="Validating data"):
try:
# Ensure required fields
if 'query' not in item or not item['query'].strip():
continue
# Clean query
if self.tokenizer:
encoding = self.tokenizer(item['query'], add_special_tokens=False) # type: ignore[operator]
input_ids = encoding['input_ids']
if len(input_ids) > self.max_query_length:
input_ids = input_ids[:self.max_query_length]
item['query'] = self.tokenizer.decode(input_ids, skip_special_tokens=True) # type: ignore[attr-defined]
else:
if len(item['query']) > self.max_query_length:
item['query'] = item['query'][: self.max_query_length]
# Ensure at least one positive
if 'pos' not in item or not item['pos']:
continue
# Clean positives
item['pos'] = self._filter_passages_by_length(item['pos'])
if not item['pos']:
continue
# Clean negatives
if 'neg' in item:
item['neg'] = self._filter_passages_by_length(item['neg'])
else:
item['neg'] = []
# Positives and negatives were already length-filtered above; no second pass is needed
# Ensure we still have data after filtering
if item['pos']:
cleaned.append(item)
except Exception as e:
logger.warning(f"Error processing item: {e}")
continue
logger.info(f"Cleaned {len(cleaned)} examples from {len(data)} original")
return cleaned
def _filter_passages_by_length(self, passages: List[str]) -> List[str]:
"""Filter passages by token length"""
if not self.tokenizer:
return passages
filtered = []
for passage in passages:
try:
encoding = self.tokenizer(passage, add_special_tokens=False) # type: ignore[operator]
input_ids = encoding['input_ids']
if len(input_ids) < self.min_passage_length:
continue
if len(input_ids) > self.max_passage_length:
# Truncate to max length
input_ids = input_ids[:self.max_passage_length]
passage = self.tokenizer.decode(input_ids, skip_special_tokens=True) # type: ignore[attr-defined]
filtered.append(passage)
except Exception as e:
logger.warning(f"Error processing passage: {e}")
continue
return filtered
def _split_data(
self,
data: List[Dict],
validation_split: float
) -> Tuple[List[Dict], List[Dict]]:
"""Split data into train/validation sets"""
# Shuffle data
shuffled_data = data.copy()
random.shuffle(shuffled_data)
# Calculate split point
split_idx = int(len(shuffled_data) * (1 - validation_split))
train_data = shuffled_data[:split_idx]
val_data = shuffled_data[split_idx:]
return train_data, val_data
def _save_jsonl(self, data: List[Dict], output_path: str):
"""Save data in JSONL format"""
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)  # handle bare filenames with no directory component
try:
with open(output_path, 'w', encoding='utf-8') as f:
for item in data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
except Exception as e:
logger.error(f"Error saving JSONL file {output_path}: {e}")
raise
class DataAugmenter:
"""
Data augmentation strategies for retrieval training
"""
def __init__(self, tokenizer: Optional[AutoTokenizer] = None):
self.tokenizer = tokenizer
def augment_with_paraphrases(
self,
data: List[Dict],
num_paraphrases: int = 2
) -> List[Dict]:
"""
Augment queries with paraphrases using rule-based approach
"""
augmented = []
for item in tqdm(data, desc="Augmenting with paraphrases"):
# Add original
augmented.append(item)
# Generate paraphrases
for i in range(num_paraphrases):
para_item = item.copy()
para_query = self._generate_paraphrase(item['query'], i)
if para_query != item['query']:
para_item['query'] = para_query
para_item['is_augmented'] = True
para_item['augmentation_type'] = 'paraphrase'
augmented.append(para_item)
return augmented
def _generate_paraphrase(self, query: str, variant: int) -> str:
"""Generate paraphrase using rule-based patterns"""
if self._is_chinese(query):
# Chinese paraphrase patterns
patterns = [
("什么是", "请解释一下"),
("如何", "怎样才能"),
("为什么", "什么原因导致"),
("在哪里", "什么地方"),
("是什么", "指的是什么"),
("有什么", "存在哪些"),
("怎么办", "如何处理"),
("多少", "数量是多少")
]
# Apply patterns based on variant
if variant == 0:
for original, replacement in patterns:
if original in query:
return query.replace(original, replacement, 1)
else:
# Add question prefixes
prefixes = ["请问", "我想知道", "能否告诉我", "想了解"]
return f"{prefixes[variant % len(prefixes)]}{query}"
else:
# English paraphrase patterns
patterns = [
("what is", "what does"),
("how to", "how can I"),
("why does", "what causes"),
("where is", "what location"),
("when did", "at what time")
]
query_lower = query.lower()
for original, replacement in patterns:
if original in query_lower:
return query_lower.replace(original, replacement, 1)
# Add question prefixes
prefixes = ["Please explain", "Could you tell me", "I need to know", "Help me understand"]
return f"{prefixes[variant % len(prefixes)]} {query.lower()}"
return query
def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters"""
return bool(re.search(r'[\u4e00-\u9fff]', text))
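# Illustrative usage sketch (file paths and parameter values are examples only):
#
#   preprocessor = DataPreprocessor(tokenizer=None, max_query_length=512, max_passage_length=8192)
#   train_path, val_path = preprocessor.preprocess_file(
#       "data/raw.csv", "cache/data/processed.jsonl", file_format="auto", validation_split=0.1
#   )
#   augmenter = DataAugmenter(tokenizer=None)
#   # augment_with_paraphrases expects the list of dicts loaded from the processed JSONL file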

254
docs/data_formats.md Normal file
View File

@ -0,0 +1,254 @@
# BGE Data Pipeline Format Specifications
The refactored BGE data pipeline supports four distinct input formats with automatic detection and conversion:
## 📊 Format Overview
| Format | Use Case | Detection Key | Output |
|--------|----------|---------------|---------|
| **Triplets** | BGE-M3 embedding training | `pos` + `neg` | Contrastive learning batches |
| **Pairs** | BGE-reranker training | `passage` + `label` | Cross-encoder pairs |
| **Candidates** | Unified threshold conversion | `candidates` | Auto-exploded pairs |
| **Legacy Nested** | Backward compatibility | `pos` only | Converted to pairs |
---
## 1. 🎯 Triplets Format (BGE-M3 Embedding)
For contrastive learning with multiple positives and negatives per query.
```json
{
"query": "What is machine learning?",
"pos": [
"Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
"ML algorithms use statistical techniques to learn patterns from data and make predictions."
],
"neg": [
"Weather forecasting uses meteorological data to predict atmospheric conditions.",
"Cooking recipes require specific ingredients and step-by-step instructions."
],
"pos_scores": [0.95, 0.88],
"neg_scores": [0.15, 0.08],
"prompt": "为此查询生成表示:",
"type": "definition"
}
```
### Required Fields
- `query`: Query text
- `pos`: List of positive passages
### Optional Fields
- `neg`: List of negative passages
- `pos_scores`: Relevance scores for positives
- `neg_scores`: Relevance scores for negatives
- `prompt`: Query instruction prefix
- `type`: Query type classification
---
## 2. 🔍 Pairs Format (BGE-Reranker)
For cross-encoder training with individual query-passage pairs.
```json
{
"query": "What is machine learning?",
"passage": "Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
"label": 1,
"score": 0.95,
"qid": "q1",
"pid": "p1"
}
```
### Required Fields
- `query`: Query text
- `passage`: Passage text
- `label`: Relevance label (0 or 1)
### Optional Fields
- `score`: Relevance score (0.0-1.0)
- `qid`: Query identifier
- `pid`: Passage identifier
---
## 3. 🎲 Candidates Format (Unified)
For datasets with multiple candidates per query. Each candidate is automatically exploded into a separate pair based on a score threshold.
```json
{
"query": "What is artificial intelligence?",
"qid": "q1",
"candidates": [
{
"text": "Artificial intelligence is the simulation of human intelligence in machines.",
"score": 0.94,
"pid": "c1"
},
{
"text": "AI systems can perform tasks that typically require human intelligence.",
"score": 0.87,
"pid": "c2"
},
{
"text": "Mountain climbing requires proper equipment and training.",
"score": 0.23,
"pid": "c3"
}
]
}
```
### Required Fields
- `query`: Query text
- `candidates`: List of candidate objects
### Candidate Object Fields
- `text`: Candidate passage text
- `score`: Relevance score (optional, defaults to 1.0)
- `pid`: Passage identifier (optional, auto-generated)
### Processing Logic
- **Threshold Assignment**: Candidates with `score >= threshold` become `label=1`, others become `label=0`
- **Default Threshold**: 0.5 (configurable)
- **Output**: Multiple pairs, one per candidate
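
As an illustration of this logic, the sketch below explodes one candidates record into flat pairs with a configurable threshold (the helper name and defaults are illustrative, not the pipeline's actual API):

```python
def explode_candidates(record: dict, threshold: float = 0.5) -> list[dict]:
    """Turn one candidates-format record into flat pairs (illustrative sketch)."""
    pairs = []
    for i, cand in enumerate(record["candidates"]):
        score = cand.get("score", 1.0)  # missing score defaults to 1.0
        pairs.append({
            "query": record["query"],
            "passage": cand["text"],
            "label": 1 if score >= threshold else 0,  # threshold assignment
            "score": score,
            "qid": record.get("qid"),
            "pid": cand.get("pid", f"c{i + 1}"),  # auto-generate pid if absent
        })
    return pairs
```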
---
## 4. 🔄 Legacy Nested Format (Backward Compatibility)
A legacy format that is automatically converted to flat pairs for reranker training.
```json
{
"query": "What is natural language processing?",
"pos": [
"Natural language processing is a field of AI that helps computers understand human language.",
"NLP combines computational linguistics with machine learning to process text and speech."
],
"neg": [
"Automobile manufacturing involves assembly lines and quality control processes.",
"Photography captures light through camera lenses to create images."
],
"pos_scores": [0.93, 0.86],
"neg_scores": [0.14, 0.22]
}
```
### Required Fields
- `query`: Query text
- `pos`: List of positive passages
### Optional Fields
- `neg`: List of negative passages
- `pos_scores`: Scores for positives
- `neg_scores`: Scores for negatives
### Processing Logic
- Each `pos[i]` becomes `(query, passage, label=1)`
- Each `neg[j]` becomes `(query, passage, label=0)`
- Maintains original scores
- Tracks source as `legacy_nested`
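
A minimal sketch of this conversion (the helper name is illustrative; the real pipeline also applies its own validation):

```python
def nested_to_pairs(record: dict) -> list[dict]:
    """Flatten a legacy nested record into (query, passage, label) pairs."""
    query = record["query"]
    pos_scores = record.get("pos_scores") or [None] * len(record["pos"])
    neg_scores = record.get("neg_scores") or [None] * len(record.get("neg", []))
    pairs = [{"query": query, "passage": p, "label": 1, "score": s, "source": "legacy_nested"}
             for p, s in zip(record["pos"], pos_scores)]
    pairs += [{"query": query, "passage": n, "label": 0, "score": s, "source": "legacy_nested"}
              for n, s in zip(record.get("neg", []), neg_scores)]
    return pairs
```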
---
## 🔧 Usage Examples
### Fine-tuning Data Loader
```python
from data.dataset import BGEM3Dataset, BGERerankerDataset, BGEDataset
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Triplets for BGE-M3
triplets_dataset = BGEM3Dataset(
data_path="triplets_data.jsonl",
tokenizer=tokenizer,
format_type="triplets"
)
# Pairs for BGE-reranker
pairs_dataset = BGERerankerDataset(
data_path="pairs_data.jsonl",
tokenizer=tokenizer,
format_type="pairs"
)
# Candidates with threshold conversion
candidates_dataset = BGEDataset(
data_path="candidates_data.jsonl",
tokenizer=tokenizer,
format_type="candidates",
candidates_threshold=0.6 # Custom threshold
)
# Legacy with automatic conversion
legacy_dataset = BGERerankerDataset(
data_path="legacy_data.jsonl",
tokenizer=tokenizer,
legacy_support=True
)
```
### Optimization Pipeline
```bash
# BGE-M3 optimization (triplets → cached triplets)
python -m data.optimization bge-m3 triplets_data.jsonl \
--output_dir ./cache/data \
--hard_negative_ratio 0.8 \
--difficulty_threshold 0.2
# BGE-reranker optimization (pairs → cached pairs)
python -m data.optimization bge-reranker pairs_data.jsonl \
--output_dir ./cache/data \
--output_format flat
# Candidates optimization (candidates → cached pairs)
python -m data.optimization bge-reranker candidates_data.jsonl \
--output_dir ./cache/data
# Legacy optimization (nested → cached pairs)
python -m data.optimization bge-reranker legacy_data.jsonl \
--output_dir ./cache/data \
--output_format nested
```
---
## ⚡ Performance Features
### Enhanced Processing
- **10-15x faster loading** through pre-tokenization
- **Automatic format detection** based on key presence
- **Score-weighted sampling** for better training quality
- **Hard negative mining** with difficulty thresholds
### Conversion Utilities
- **Candidates → Pairs**: Threshold-based label assignment
- **Legacy → Pairs**: Automatic nested-to-flat conversion
- **Pairs → Triplets**: Grouping by query for embedding training
- **Format validation**: Comprehensive error checking
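
For example, the `Pairs → Triplets` grouping listed above can be approximated as follows (a sketch of the idea, not the pipeline's actual conversion utility):

```python
from collections import defaultdict

def pairs_to_triplets(pairs: list[dict]) -> list[dict]:
    """Group flat (query, passage, label) pairs into triplets for embedding training."""
    grouped = defaultdict(lambda: {"pos": [], "neg": []})
    for p in pairs:
        bucket = "pos" if p["label"] == 1 else "neg"
        grouped[p["query"]][bucket].append(p["passage"])
    # Keep only queries that end up with at least one positive passage
    return [{"query": q, "pos": v["pos"], "neg": v["neg"]}
            for q, v in grouped.items() if v["pos"]]
```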
### Backward Compatibility
- **Legacy support**: Full compatibility with existing datasets
- **Gradual migration**: Convert formats incrementally
- **Source tracking**: Know the original format of converted data
- **Flexible thresholds**: Configurable conversion parameters
---
## 🚀 Key Benefits
1. **Unified Pipeline**: Single codebase handles all formats
2. **Automatic Detection**: No manual format specification needed
3. **Seamless Conversion**: Transparent format transformations
4. **Performance Optimized**: Significant speedup through caching
5. **Backward Compatible**: Works with existing datasets
6. **Flexible Configuration**: Customizable thresholds and parameters

159
docs/directory_structure.md Normal file
View File

@ -0,0 +1,159 @@
# BGE Fine-tuning Directory Structure
This document clarifies the directory structure and cache usage in the BGE fine-tuning project.
## 📂 Project Directory Structure
```
bge_finetune/
├── 📁 cache/ # Cache directories
│ ├── 📁 data/ # Preprocessed dataset cache
│ │ ├── 📄 *.jsonl # Preprocessed JSONL files
│ │ └── 📄 *.json # Preprocessed JSON files
│ └── 📁 models/ # Downloaded model cache (HuggingFace/ModelScope)
│ ├── 📁 tokenizers/ # Cached tokenizer files
│ ├── 📁 config/ # Cached model config files
│ └── 📁 downloads/ # Temporary download files
├── 📁 models/ # Local model storage
│ ├── 📁 bge-m3/ # Local BGE-M3 model files
│ └── 📁 bge-reranker-base/ # Local BGE-reranker model files
├── 📁 data/ # Raw training data
│ ├── 📄 train.jsonl # Input training data
│ ├── 📄 eval.jsonl # Input evaluation data
│ └── 📁 processed/ # Preprocessed (but not cached) data
├── 📁 output/ # Training outputs
│ ├── 📁 bge-m3/ # M3 training results
│ ├── 📁 bge-reranker/ # Reranker training results
│ └── 📁 checkpoints/ # Model checkpoints
└── 📁 logs/ # Training logs
├── 📁 tensorboard/ # TensorBoard logs
└── 📄 training.log # Text logs
```
## 🎯 Directory Purposes
### 1. **`./cache/models/`** - Model Downloads Cache
- **Purpose**: HuggingFace/ModelScope model download cache
- **Contents**: Tokenizers, config files, temporary downloads
- **Usage**: Automatic caching of remote model downloads
- **Config**: `model_paths.cache_dir = "./cache/models"`
### 2. **`./cache/data/`** - Processed Dataset Cache
- **Purpose**: Preprocessed and cleaned dataset files
- **Contents**: Validated JSONL/JSON files
- **Usage**: Store cleaned and validated training data
- **Config**: `data.cache_dir = "./cache/data"`
### 3. **`./models/`** - Local Model Storage
- **Purpose**: Complete local model files (when available)
- **Contents**: Full model directories with all files
- **Usage**: Direct loading without downloads
- **Config**: `model_paths.bge_m3 = "./models/bge-m3"`
### 4. **`./data/`** - Raw Input Data
- **Purpose**: Original training/evaluation datasets
- **Contents**: `.jsonl` files in various formats
- **Usage**: Input to preprocessing and optimization
- **Config**: User-specified paths in training scripts
### 5. **`./output/`** - Training Results
- **Purpose**: Fine-tuned models and training artifacts
- **Contents**: Saved models, checkpoints, metrics
- **Usage**: Results of training scripts
- **Config**: `--output_dir` in training scripts
## ⚙️ Configuration Mapping
### Config File Settings
```toml
[model_paths]
cache_dir = "./cache/models" # Model download cache
bge_m3 = "./models/bge-m3" # Local BGE-M3 path
bge_reranker = "./models/bge-reranker-base" # Local reranker path
[data]
cache_dir = "./cache/data" # Dataset cache
optimization_output_dir = "./cache/data" # Optimization default
```
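
As a quick illustration, these settings can be read with the standard-library TOML parser (Python 3.11+; the project's own config loader may differ):

```python
import tomllib  # Python 3.11+; use the third-party "tomli" package on older versions

with open("config.toml", "rb") as f:
    config = tomllib.load(f)

model_cache_dir = config["model_paths"]["cache_dir"]   # "./cache/models"
bge_m3_path = config["model_paths"]["bge_m3"]           # "./models/bge-m3"
data_cache_dir = config["data"]["cache_dir"]            # "./cache/data"
```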
### Script Defaults
```bash
# Data optimization (processed datasets)
python -m data.optimization bge-m3 data/train.jsonl --output_dir ./cache/data
# Model training (output models)
python scripts/train_m3.py --output_dir ./output/bge-m3 --cache_dir ./cache/models
# Model evaluation (results)
python scripts/evaluate.py --output_dir ./evaluation_results
```
## 🔄 Data Flow
```
Raw Data      →  Preprocessing  →  Optimization        →  Training       →  Output
data/         →  (cleaning)     →  cache/data/         →  model loading  →  output/
input JSONL                        pre-tokenized .pkl                       fine-tuned models
```
## 🛠️ Usage Examples
### Model Cache (Download Cache)
```python
# Automatic model caching
tokenizer = AutoTokenizer.from_pretrained(
"BAAI/bge-m3",
cache_dir="./cache/models" # Downloads cached here
)
```
### Data Cache (Optimization Cache)
```bash
# Create optimized datasets
python -m data.optimization bge-m3 data/train.jsonl \
--output_dir ./cache/data # .pkl files created here
# Use cached datasets in training
python scripts/train_m3.py \
--train_data ./cache/data/train_m3_cached.pkl
```
### Local Models (No Downloads)
```python
# Use local model (no downloads)
model = BGEM3Model.from_pretrained("./models/bge-m3")
```
## 📋 Best Practices
1. **Separate Concerns**: Keep model cache, data cache, and outputs separate
2. **Clear Naming**: Use descriptive suffixes (`.pkl` for cache, `_cached` for optimized)
3. **Consistent Paths**: Use config.toml for centralized path management
4. **Cache Management**:
- Model cache: Can be shared across projects
- Data cache: Project-specific optimized datasets
- Output: Training results and checkpoints
## 🧹 Cleanup Commands
```bash
# Clean data cache (re-optimization needed)
rm -rf ./cache/data/
# Clean model cache (re-download needed)
rm -rf ./cache/models/
# Clean training outputs
rm -rf ./output/
# Clean all caches
rm -rf ./cache/
```
This structure ensures clear separation of concerns and efficient caching for both model downloads and processed datasets.

298
docs/usage_guide.md Normal file
View File

@ -0,0 +1,298 @@
# BGE Fine-tuning Usage Guide
## Overview
This guide provides detailed instructions on fine-tuning BGE (BAAI General Embedding) models using our enhanced training framework. The framework supports both BGE-M3 (embedding) and BGE-reranker (cross-encoder) models with state-of-the-art training techniques.
## Table of Contents
1. [Installation](#installation)
2. [Quick Start](#quick-start)
3. [Data Preparation](#data-preparation)
4. [Training Scripts](#training-scripts)
5. [Advanced Usage](#advanced-usage)
6. [Troubleshooting](#troubleshooting)
## Installation
```bash
# Clone the repository
git clone https://github.com/yourusername/bge-finetune.git
cd bge-finetune
# Install dependencies
pip install -r requirements.txt
# Optional: Install Ascend NPU support
# pip install torch-npu # If using Huawei Ascend NPUs
```
## Quick Start
### 1. Prepare Your Data
The framework supports multiple data formats:
- **JSONL** (recommended): One JSON object per line
- **CSV/TSV**: Tabular data with headers
- **JSON**: Single JSON file with array of samples
Example JSONL format for embedding model:
```json
{"query": "What is machine learning?", "pos": ["ML is a subset of AI..."], "neg": ["The weather today..."]}
{"query": "How to cook pasta?", "pos": ["Boil water and add pasta..."], "neg": ["Machine learning uses..."]}
```
### 2. Fine-tune BGE-M3 (Embedding Model)
```bash
python scripts/train_m3.py \
--model_name_or_path BAAI/bge-m3 \
--train_data data/train.jsonl \
--output_dir ./output/bge-m3-finetuned \
--num_train_epochs 3 \
--per_device_train_batch_size 16 \
--learning_rate 1e-5
```
### 3. Fine-tune BGE-Reranker
```bash
python scripts/train_reranker.py \
--model_name_or_path BAAI/bge-reranker-base \
--train_data data/train_pairs.jsonl \
--output_dir ./output/bge-reranker-finetuned \
--num_train_epochs 3 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5
```
## Data Preparation
### Data Formats
#### 1. Embedding Model (Triplets Format)
```json
{
"query": "query text",
"pos": ["positive passage 1", "positive passage 2"],
"neg": ["negative passage 1", "negative passage 2"],
"pos_scores": [1.0, 0.9], // Optional: relevance scores
"neg_scores": [0.1, 0.2] // Optional: relevance scores
}
```
#### 2. Reranker Model (Pairs Format)
```json
{
"query": "query text",
"passage": "passage text",
"label": 1, // 1 for relevant, 0 for not relevant
"score": 0.95 // Optional: relevance score
}
```
### Performance Optimization
For small datasets (< 100k samples), enable in-memory caching:
```python
# In your custom script
dataset = BGEM3Dataset(
data_path="train.jsonl",
tokenizer=tokenizer,
cache_in_memory=True, # Pre-tokenize and cache in memory
max_cache_size=100000 # Maximum samples to cache
)
```
## Training Scripts
### BGE-M3 Training Options
```bash
python scripts/train_m3.py \
--model_name_or_path BAAI/bge-m3 \
--train_data data/train.jsonl \
--eval_data data/eval.jsonl \
--output_dir ./output/bge-m3-finetuned \
--num_train_epochs 3 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 32 \
--learning_rate 1e-5 \
--warmup_ratio 0.1 \
--gradient_accumulation_steps 2 \
--fp16 \
--train_group_size 8 \
--use_hard_negatives \
--temperature 0.02 \
--save_steps 500 \
--eval_steps 500 \
--logging_steps 100
```

Key options: `--fp16` enables mixed-precision training, `--train_group_size` sets the number of passages per query, `--use_hard_negatives` turns on hard negative mining, and `--temperature` sets the contrastive loss temperature.
### BGE-Reranker Training Options
```bash
python scripts/train_reranker.py \
--model_name_or_path BAAI/bge-reranker-base \
--train_data data/train_pairs.jsonl \
--eval_data data/eval_pairs.jsonl \
--output_dir ./output/bge-reranker-finetuned \
--num_train_epochs 3 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--max_length 512 \
--train_group_size 16 \
--gradient_checkpointing \
--logging_steps 50
```

Here `--train_group_size` sets the number of pairs per query and `--gradient_checkpointing` trades extra compute for lower GPU memory usage.
### Joint Training (RocketQAv2 Approach)
Train retriever and reranker together:
```bash
python scripts/train_joint.py \
--retriever_model BAAI/bge-m3 \
--reranker_model BAAI/bge-reranker-base \
--train_data data/train.jsonl \
--output_dir ./output/joint-training \
--num_train_epochs 3 \
--alternate_steps 100 # Switch between models every N steps
```
## Advanced Usage
### Configuration Files
Use TOML configuration for complex setups:
```toml
# config.toml
[training]
default_batch_size = 16
default_num_epochs = 3
gradient_accumulation_steps = 2
[m3]
model_name_or_path = "BAAI/bge-m3"
train_group_size = 8
temperature = 0.02
[hardware]
device_type = "cuda" # or "npu" for Ascend
```
Run with config:
```bash
python scripts/train_m3.py --config_path config.toml --train_data data/train.jsonl
```
### Multi-GPU Training
```bash
# Single node, multiple GPUs
torchrun --nproc_per_node=4 scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/bge-m3-multigpu
# Multiple nodes
torchrun --nnodes=2 --nproc_per_node=4 \
--rdzv_id=100 --rdzv_backend=c10d \
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
scripts/train_m3.py --train_data data/train.jsonl
```
### Custom Data Preprocessing
```python
from data.preprocessing import DataPreprocessor
# Preprocess custom format data
preprocessor = DataPreprocessor(
tokenizer=tokenizer,
max_query_length=64,
max_passage_length=512
)
# Convert and clean data
output_path, val_path = preprocessor.preprocess_file(
input_path="raw_data.csv",
output_path="processed_data.jsonl",
file_format="csv",
validation_split=0.1
)
```
### Model Evaluation
```bash
python scripts/evaluate.py \
--retriever_model ./output/bge-m3-finetuned \
--reranker_model ./output/bge-reranker-finetuned \
--eval_data data/test.jsonl \
--output_dir ./evaluation_results \
--metrics ndcg mrr recall \
--top_k 10 20 50
```
## Troubleshooting
### Common Issues
1. **Out of Memory**
- Reduce batch size
- Enable gradient checkpointing: `--gradient_checkpointing`
- Use fp16 training: `--fp16`
- Enable in-memory caching for small datasets
2. **Slow Training**
- Increase number of data loader workers: `--dataloader_num_workers 8`
- Enable in-memory caching for datasets < 100k samples
- Use SSD for data storage
3. **Poor Performance**
- Check data quality and format
- Adjust learning rate and warmup
- Use hard negative mining: `--use_hard_negatives`
- Increase training epochs
### Debugging Tips
```bash
# Enable debug logging
export LOG_LEVEL=DEBUG
python scripts/train_m3.py ...
# Profile training
python scripts/benchmark.py \
--model_type retriever \
--model_path ./output/bge-m3-finetuned \
--data_path data/test.jsonl
```
## Best Practices
1. **Data Quality**
- Ensure balanced positive/negative samples
- Use relevance scores when available
- Clean and normalize text data
2. **Training Strategy**
- Start with small learning rate (1e-5 to 2e-5)
- Use warmup (10% of training steps; see the scheduler sketch after this list)
- Monitor evaluation metrics
3. **Resource Optimization**
- Use gradient accumulation for large batches
- Enable mixed precision training
- Consider in-memory caching for small datasets
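
To make the warmup recommendation concrete, here is a minimal sketch using the Hugging Face `transformers` scheduler helper (the tiny linear layer and step count are placeholders, not the actual training setup):

```python
import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(8, 8)        # stand-in for the BGE model being fine-tuned
num_training_steps = 10_000          # e.g. len(train_dataloader) * num_train_epochs

optimizer = AdamW(model.parameters(), lr=1e-5)  # small learning rate, per the guidance above
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),  # 10% warmup
    num_training_steps=num_training_steps,
)
```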
## Additional Resources
- [Data Format Examples](data_formats.md)
- [Model Architecture Details](../models/README.md)
- [Evaluation Metrics Guide](../evaluation/README.md)
- [API Reference](api_reference.md)

432
docs/validation_guide.md Normal file
View File

@ -0,0 +1,432 @@
# BGE Fine-tuning Validation Guide
This guide provides comprehensive instructions for validating your fine-tuned BGE models to ensure they actually perform better than the baseline models.
## 🎯 Overview
After fine-tuning BGE models, it's crucial to validate that your models have actually improved. This guide covers multiple validation approaches:
1. **Quick Validation** - Fast sanity checks
2. **Comprehensive Validation** - Detailed performance analysis
3. **Comparison Benchmarks** - Head-to-head baseline comparisons
4. **Pipeline Validation** - End-to-end retrieval + reranking tests
5. **Statistical Analysis** - Significance testing and detailed metrics
## 🚀 Quick Start
### Option 1: Run Full Validation Suite (Recommended)
The easiest way to validate your models is using the validation suite:
```bash
# Auto-discover trained models and run complete validation
python scripts/run_validation_suite.py --auto-discover
# Or specify models explicitly
python scripts/run_validation_suite.py \
--retriever_model ./output/bge-m3-enhanced/final_model \
--reranker_model ./output/bge-reranker/final_model
```
This runs all validation tests and generates a comprehensive report.
### Option 2: Quick Validation Only
For a fast check if your models are working correctly:
```bash
python scripts/quick_validation.py \
--retriever_model ./output/bge-m3-enhanced/final_model \
--reranker_model ./output/bge-reranker/final_model \
--test_pipeline
```
## 📊 Validation Tools
### 1. Quick Validation (`quick_validation.py`)
**Purpose**: Fast sanity checks to verify models are working and show basic improvements.
**Features**:
- Tests model loading and basic functionality
- Uses predefined test queries in multiple languages
- Compares relevance scoring between baseline and fine-tuned models
- Tests pipeline performance (retrieval + reranking)
**Usage**:
```bash
# Test both models
python scripts/quick_validation.py \
--retriever_model ./output/bge-m3-enhanced/final_model \
--reranker_model ./output/bge-reranker/final_model \
--test_pipeline
# Test only retriever
python scripts/quick_validation.py \
--retriever_model ./output/bge-m3-enhanced/final_model
# Save results to file
python scripts/quick_validation.py \
--retriever_model ./output/bge-m3-enhanced/final_model \
--output_file ./results.json
```
**Output**: Console summary + optional JSON results file
### 2. Comprehensive Validation (`comprehensive_validation.py`)
**Purpose**: Detailed validation across multiple datasets with statistical analysis.
**Features**:
- Tests on all available datasets automatically
- Computes comprehensive metrics (Recall@K, Precision@K, MAP, MRR, NDCG)
- Statistical significance testing
- Performance degradation detection
- Generates HTML reports with visualizations
- Cross-dataset performance analysis
**Usage**:
```bash
# Comprehensive validation with auto-discovered datasets
python scripts/comprehensive_validation.py \
--retriever_finetuned ./output/bge-m3-enhanced/final_model \
--reranker_finetuned ./output/bge-reranker/final_model
# Use specific datasets
python scripts/comprehensive_validation.py \
--retriever_finetuned ./output/bge-m3-enhanced/final_model \
--test_datasets data/datasets/examples/embedding_data.jsonl \
data/datasets/Reranker_AFQMC/dev.json \
--output_dir ./validation_results
# Custom evaluation settings
python scripts/comprehensive_validation.py \
--retriever_finetuned ./output/bge-m3-enhanced/final_model \
--batch_size 16 \
--k_values 1 3 5 10 20 \
--retrieval_top_k 100 \
--rerank_top_k 10
```
**Output**:
- JSON results file
- HTML report with detailed analysis
- Performance comparison tables
- Statistical significance analysis
### 3. Model Comparison Scripts
**Purpose**: Head-to-head performance comparisons with baseline models.
#### Retriever Comparison (`compare_retriever.py`)
```bash
python scripts/compare_retriever.py \
--finetuned_model_path ./output/bge-m3-enhanced/final_model \
--baseline_model_path BAAI/bge-m3 \
--data_path data/datasets/examples/embedding_data.jsonl \
--batch_size 16 \
--max_samples 1000 \
--output ./retriever_comparison.txt
```
#### Reranker Comparison (`compare_reranker.py`)
```bash
python scripts/compare_reranker.py \
--finetuned_model_path ./output/bge-reranker/final_model \
--baseline_model_path BAAI/bge-reranker-base \
--data_path data/datasets/examples/reranker_data.jsonl \
--batch_size 16 \
--max_samples 1000 \
--output ./reranker_comparison.txt
```
**Output**: Tab-separated comparison tables with metrics and performance deltas.
### 4. Validation Suite Runner (`run_validation_suite.py`)
**Purpose**: Orchestrates all validation tests in a systematic way.
**Features**:
- Auto-discovers trained models
- Runs multiple validation approaches
- Generates comprehensive reports
- Provides final verdict on model quality
**Usage**:
```bash
# Full auto-discovery and validation
python scripts/run_validation_suite.py --auto-discover
# Specify models and validation types
python scripts/run_validation_suite.py \
--retriever_model ./output/bge-m3-enhanced/final_model \
--reranker_model ./output/bge-reranker/final_model \
--validation_types quick comprehensive comparison
# Custom settings
python scripts/run_validation_suite.py \
--retriever_model ./output/bge-m3-enhanced/final_model \
--batch_size 32 \
--max_samples 500 \
--output_dir ./my_validation_results
```
**Output**:
- Multiple result files from each validation type
- HTML summary report
- Final verdict and recommendations
### 5. Validation Utilities (`validation_utils.py`)
**Purpose**: Helper utilities for validation preparation and analysis.
**Features**:
- Discover trained models in workspace
- Analyze available test datasets
- Create sample test data
- Analyze validation results
**Usage**:
```bash
# Discover trained models
python scripts/validation_utils.py --discover-models
# Analyze available test datasets
python scripts/validation_utils.py --analyze-datasets
# Create sample test data
python scripts/validation_utils.py --create-sample-data --output-dir ./test_data
# Analyze validation results
python scripts/validation_utils.py --analyze-results ./validation_results
```
## 📁 Understanding Results
### Result Files
After running validation, you'll get several result files (a small loader sketch follows this list):
1. **`validation_suite_report.html`** - Main HTML report with visual summary
2. **`validation_results.json`** - Complete detailed results in JSON format
3. **`quick_validation_results.json`** - Quick validation results
4. **`*_comparison_*.txt`** - Model comparison tables
5. **Comprehensive validation folder** with detailed metrics
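To quickly inspect these files programmatically, a minimal sketch like the following can help; it assumes the results are plain JSON and makes no assumption about the exact key layout:
```python
import json
from pathlib import Path

results_dir = Path("./validation_results")  # adjust to your --output_dir

# Load every JSON result file and print its top-level structure.
for path in sorted(results_dir.glob("*.json")):
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    print(f"\n=== {path.name} ===")
    if isinstance(data, dict):
        for key, value in data.items():
            if isinstance(value, (int, float, str, bool)):
                print(f"  {key}: {value}")          # scalar metric or flag
            elif hasattr(value, "__len__"):
                print(f"  {key}: <{type(value).__name__} with {len(value)} entries>")
            else:
                print(f"  {key}: <{type(value).__name__}>")
    else:
        print(f"  <{type(data).__name__}> with {len(data)} entries")
```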
### Interpreting Results
#### Overall Verdict
- **🌟 EXCELLENT** - Significant improvements across most metrics
- **✅ GOOD** - Clear improvements with minor degradations
- **👌 FAIR** - Mixed results, some improvements
- **⚠️ PARTIAL** - Some validation tests failed
- **❌ POOR** - More degradations than improvements
#### Key Metrics to Watch
The lists below are followed by a toy computation sketch that shows how these metrics are calculated.
**For Retrievers**:
- **Recall@K** - How many relevant documents are in top-K results
- **MAP (Mean Average Precision)** - Overall ranking quality
- **MRR (Mean Reciprocal Rank)** - Position of first relevant result
- **NDCG@K** - Normalized ranking quality with graded relevance
**For Rerankers**:
- **Accuracy** - Correct classification of relevant/irrelevant pairs
- **MRR@K** - Mean reciprocal rank in top-K results
- **NDCG@K** - Ranking quality with position discounting
**For Pipeline**:
- **End-to-End Recall@K** - Overall system performance
- **Pipeline Accuracy** - Correct top results after retrieval + reranking
#### Statistical Significance
The comprehensive validation includes statistical tests to determine whether improvements are significant (see the sketch after this list):
- **Improvement %** - Percentage improvement over baseline
- **Statistical Significance** - Whether improvements are statistically meaningful
- **Effect Size** - Magnitude of the improvement
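As an illustration of what such a test can look like, the sketch below runs a paired t-test over hypothetical per-query scores; the numbers are made up, and the actual statistical procedure used by `comprehensive_validation.py` may differ:
```python
import numpy as np
from scipy import stats

# Illustrative per-query metric values (e.g., Recall@10) for the same queries under both models.
baseline = np.array([0.40, 0.55, 0.30, 0.62, 0.48, 0.51, 0.44, 0.59])
finetuned = np.array([0.52, 0.60, 0.35, 0.70, 0.47, 0.58, 0.50, 0.66])

diff = finetuned - baseline
improvement_pct = 100.0 * diff.mean() / baseline.mean()

# Paired t-test: are the per-query differences significantly different from zero?
t_stat, p_value = stats.ttest_rel(finetuned, baseline)

# Effect size (Cohen's d for paired samples): mean difference / std of differences.
effect_size = diff.mean() / diff.std(ddof=1)

print(f"Improvement: {improvement_pct:.1f}%  p-value: {p_value:.4f}  effect size: {effect_size:.2f}")
```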
## 🔧 Troubleshooting
### Common Issues
#### 1. Models Not Found
```
❌ Retriever model not found: ./output/bge-m3-enhanced/final_model
```
**Solution**:
- Check that training completed successfully
- Verify the model path exists
- Use `--discover-models` to find available models
#### 2. No Test Datasets Found
```
❌ No test datasets found!
```
**Solution**:
- Place test data in `data/datasets/` directory
- Use `--create-sample-data` to generate sample datasets
- Specify datasets explicitly with `--test_datasets`
#### 3. Validation Failures
```
❌ Comprehensive validation failed
```
**Solution**:
- Check model compatibility (tokenizer issues)
- Reduce batch size if running out of memory
- Verify dataset format is correct
- Check logs for detailed error messages
#### 4. Poor Performance Results
```
⚠️ More degradations than improvements
```
**Analysis Steps**:
1. Check training data quality and format
2. Verify hyperparameters (learning rate might be too high)
3. Ensure sufficient training data and epochs
4. Check for data leakage or format mismatches
5. Consider different base models
### Validation Best Practices
#### 1. Before Running Validation
- ✅ Ensure training completed without errors
- ✅ Verify model files exist and are complete
- ✅ Check that test datasets are in correct format
- ✅ Have baseline models available for comparison
#### 2. During Validation
- 📊 Start with quick validation for fast feedback
- 🔬 Run comprehensive validation on representative datasets
- 📈 Use multiple metrics to get a complete picture
- ⏱️ Be patient - comprehensive validation takes time
#### 3. Analyzing Results
- 🎯 Focus on metrics most relevant to your use case
- 📊 Look at trends across multiple datasets
- ⚖️ Consider statistical significance, not just raw improvements
- 🔍 Investigate datasets where performance degraded
## 🎯 Validation Checklist
Use this checklist to ensure thorough validation:
### Pre-Validation Setup
- [ ] Models trained and saved successfully
- [ ] Test datasets available and formatted correctly
- [ ] Baseline models accessible (BAAI/bge-m3, BAAI/bge-reranker-base)
- [ ] Sufficient compute resources for validation
### Quick Validation (5-10 minutes)
- [ ] Run quick validation script
- [ ] Verify models load correctly
- [ ] Check basic performance improvements
- [ ] Test pipeline functionality (if both models available)
### Comprehensive Validation (30-60 minutes)
- [ ] Run comprehensive validation script
- [ ] Test on multiple datasets
- [ ] Analyze statistical significance
- [ ] Review HTML report for detailed insights
### Comparison Benchmarks (15-30 minutes)
- [ ] Run retriever comparison (if applicable)
- [ ] Run reranker comparison (if applicable)
- [ ] Analyze performance deltas
- [ ] Check throughput and memory usage
### Final Analysis
- [ ] Review overall verdict and recommendations
- [ ] Identify best and worst performing datasets
- [ ] Check for any critical issues or degradations
- [ ] Document findings and next steps
## 🚀 Advanced Usage
### Custom Test Datasets
Create your own test datasets for domain-specific validation:
For retriever testing (embedding format):
```jsonl
{"query": "Your query", "pos": ["relevant doc 1", "relevant doc 2"], "neg": ["irrelevant doc 1"]}
```
For reranker testing (pairs format):
```jsonl
{"query": "Your query", "passage": "Document text", "label": 1}
{"query": "Your query", "passage": "Irrelevant text", "label": 0}
```
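A minimal sketch for writing and sanity-checking files in these two formats (the example texts are placeholders):
```python
import json

retriever_examples = [
    {"query": "What is dense retrieval?",
     "pos": ["Dense retrieval encodes queries and documents into vectors."],
     "neg": ["The weather today is sunny."]},
]
reranker_examples = [
    {"query": "What is dense retrieval?",
     "passage": "Dense retrieval encodes queries and documents into vectors.",
     "label": 1},
    {"query": "What is dense retrieval?",
     "passage": "The weather today is sunny.",
     "label": 0},
]

def write_jsonl(path, records):
    """Write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

def check_jsonl(path, required_fields):
    """Verify that every line parses and contains the required fields."""
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            record = json.loads(line)
            missing = [k for k in required_fields if k not in record]
            if missing:
                raise ValueError(f"{path}:{line_no} missing fields {missing}")

write_jsonl("my_retriever_test.jsonl", retriever_examples)
write_jsonl("my_reranker_test.jsonl", reranker_examples)
check_jsonl("my_retriever_test.jsonl", ["query", "pos", "neg"])
check_jsonl("my_reranker_test.jsonl", ["query", "passage", "label"])
```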
### Cross-Domain Validation
Test model generalization across different domains:
```bash
python scripts/comprehensive_validation.py \
--retriever_finetuned ./output/bge-m3-enhanced/final_model \
--test_datasets data/medical_qa.jsonl data/legal_docs.jsonl data/tech_support.jsonl
```
### Performance Profiling
Monitor resource usage during validation:
```bash
# Reduce --batch_size if you run into memory issues; cap --max_samples for speed
python scripts/comprehensive_validation.py \
--retriever_finetuned ./output/bge-m3-enhanced/final_model \
--batch_size 8 \
--max_samples 100
```
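If you need finer-grained numbers than the script reports, you can wrap an encoding call in a small timing and memory probe. The sketch below is generic; the commented usage assumes a fine-tuned BGE-M3 checkpoint saved by this project:
```python
import time
import torch

def profile_encode(encode_fn, texts, batch_size=8):
    """Time an encoding call and report peak GPU memory (if CUDA is available)."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    embeddings = encode_fn(texts, batch_size=batch_size)
    elapsed = time.perf_counter() - start
    peak_mb = (torch.cuda.max_memory_allocated() / 1024 ** 2
               if torch.cuda.is_available() else 0.0)
    print(f"{len(texts)} texts in {elapsed:.2f}s "
          f"({len(texts) / elapsed:.1f} texts/s), peak GPU memory {peak_mb:.0f} MB")
    return embeddings

# Example usage (assuming a fine-tuned BGE-M3 checkpoint saved by this repo):
# from models.bge_m3 import BGEM3Model
# model = BGEM3Model.from_pretrained("./output/bge-m3-enhanced/final_model")
# profile_encode(model.encode, ["query one", "query two"] * 50, batch_size=8)
```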
### A/B Testing Multiple Models
Compare different fine-tuning approaches:
```bash
# Test Model A
python scripts/comprehensive_validation.py \
--retriever_finetuned ./output/model_A/final_model \
--output_dir ./validation_A
# Test Model B
python scripts/comprehensive_validation.py \
--retriever_finetuned ./output/model_B/final_model \
--output_dir ./validation_B
# Compare results
python scripts/validation_utils.py --analyze-results ./validation_A
python scripts/validation_utils.py --analyze-results ./validation_B
```
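To diff the two runs programmatically, you can walk whatever numeric metrics appear in the result JSON files without assuming specific key names; adjust the file names to match what your runs actually produced:
```python
import json

def flatten_metrics(obj, prefix=""):
    """Recursively collect numeric leaves from a nested JSON structure."""
    metrics = {}
    if isinstance(obj, dict):
        for key, value in obj.items():
            metrics.update(flatten_metrics(value, f"{prefix}{key}."))
    elif isinstance(obj, (int, float)) and not isinstance(obj, bool):
        metrics[prefix.rstrip(".")] = float(obj)
    return metrics

def compare_runs(path_a, path_b):
    with open(path_a) as fa, open(path_b) as fb:
        a, b = flatten_metrics(json.load(fa)), flatten_metrics(json.load(fb))
    for name in sorted(set(a) & set(b)):
        delta = b[name] - a[name]
        print(f"{name:<50} A={a[name]:.4f}  B={b[name]:.4f}  delta={delta:+.4f}")

# compare_runs("./validation_A/validation_results.json",
#              "./validation_B/validation_results.json")
```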
## 📚 Additional Resources
- **BGE Model Documentation**: [BAAI/bge](https://github.com/FlagOpen/FlagEmbedding)
- **Evaluation Metrics Guide**: See `evaluation/metrics.py` for detailed metric implementations
- **Dataset Format Guide**: See `data/datasets/examples/format_usage_examples.py`
- **Training Guide**: See main README for training instructions
## 💡 Tips for Better Validation
1. **Use Multiple Datasets**: Don't rely on a single test set
2. **Check Statistical Significance**: Small improvements might not be meaningful
3. **Monitor Resource Usage**: Ensure models are efficient enough for production
4. **Validate Domain Generalization**: Test on data similar to production use case
5. **Document Results**: Keep validation reports for future reference
6. **Iterate Based on Results**: Use validation insights to improve training
## 🤝 Getting Help
If you encounter issues with validation:
1. Check the troubleshooting section above
2. Review the logs for detailed error messages
3. Use `--discover-models` and `--analyze-datasets` to debug setup issues
4. Start with quick validation to isolate problems
5. Check model and dataset formats are correct
Remember: Good validation is key to successful fine-tuning. Take time to thoroughly test your models before deploying them!

20
evaluation/__init__.py Normal file
View File

@ -0,0 +1,20 @@
"""
Evaluation modules and metrics
"""
from .evaluator import RetrievalEvaluator, InteractiveEvaluator
from .metrics import (
compute_retrieval_metrics,
compute_reranker_metrics,
evaluate_cross_lingual_retrieval,
MetricsTracker
)
__all__ = [
'RetrievalEvaluator',
'InteractiveEvaluator',
'compute_retrieval_metrics',
'compute_reranker_metrics',
'evaluate_cross_lingual_retrieval',
'MetricsTracker'
]

511
evaluation/evaluator.py Normal file
View File

@ -0,0 +1,511 @@
import torch
import numpy as np
from typing import List, Dict, Optional, Tuple, Union
from tqdm import tqdm
import logging
from torch.utils.data import DataLoader
import json
import os
from collections import defaultdict
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from .metrics import (
compute_retrieval_metrics,
compute_reranker_metrics,
evaluate_cross_lingual_retrieval,
MetricsTracker
)
logger = logging.getLogger(__name__)
class RetrievalEvaluator:
"""
Comprehensive evaluator for retrieval models
Notes:
- All parameters (model, tokenizer, device, metrics, k_values, etc.) should be passed from config-driven scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
model: Union[BGEM3Model, BGERerankerModel],
tokenizer,
device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
metrics: List[str] = ['recall', 'precision', 'map', 'mrr', 'ndcg'],
k_values: List[int] = [1, 5, 10, 20, 100]
):
self.model = model
self.tokenizer = tokenizer
self.device = device
self.metrics = metrics
self.k_values = k_values
# Move model to device
self.model.to(device)
self.model.eval()
# Metrics tracker
self.tracker = MetricsTracker()
def evaluate(
self,
dataloader: DataLoader,
desc: str = "Evaluating",
return_embeddings: bool = False
) -> Dict[str, Union[float, list, np.ndarray]]:
"""
Evaluate model on dataset
Args:
dataloader: DataLoader with evaluation data
desc: Description for progress bar
return_embeddings: Whether to return computed embeddings
Returns:
Dictionary of evaluation metrics
"""
self.model.eval()
all_query_embeddings = []
all_passage_embeddings = []
all_labels = []
all_queries = []
all_passages = []
with torch.no_grad():
for batch in tqdm(dataloader, desc=desc):
# Move batch to device
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()}
# Get embeddings based on model type
if isinstance(self.model, BGEM3Model):
# Bi-encoder evaluation
query_embeddings, passage_embeddings = self._evaluate_biencoder_batch(batch)
all_query_embeddings.append(query_embeddings.cpu())
all_passage_embeddings.append(passage_embeddings.cpu())
else:
# Cross-encoder evaluation
scores = self._evaluate_crossencoder_batch(batch)
# For cross-encoder, we store scores instead of embeddings
all_query_embeddings.append(scores.cpu())
# Collect labels
if 'labels' in batch:
all_labels.append(batch['labels'].cpu())
# Collect text if available
if 'query' in batch:
all_queries.extend(batch['query'])
if 'passages' in batch:
all_passages.extend(batch['passages'])
# Concatenate results
if isinstance(self.model, BGEM3Model):
all_query_embeddings = torch.cat(all_query_embeddings, dim=0).numpy()
all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0).numpy()
all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None
# Compute retrieval metrics
if all_labels is not None:
metrics = compute_retrieval_metrics(
all_query_embeddings,
all_passage_embeddings,
all_labels,
k_values=self.k_values
)
else:
metrics = {}
else:
# Cross-encoder metrics
all_scores = torch.cat(all_query_embeddings, dim=0).numpy()
all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None
if all_labels is not None:
# Determine if we have groups (listwise) or pairs
if hasattr(dataloader.dataset, 'train_group_size'):
groups = [getattr(dataloader.dataset, 'train_group_size')] * len(dataloader.dataset) # type: ignore[arg-type]
else:
groups = None
metrics = compute_reranker_metrics(
all_scores,
all_labels,
groups=groups,
k_values=self.k_values
)
else:
metrics = {}
# Update tracker
self.tracker.update(metrics, step=0)
# Return embeddings if requested
if return_embeddings and isinstance(self.model, BGEM3Model):
metrics['query_embeddings'] = all_query_embeddings # type: ignore[assignment]
metrics['passage_embeddings'] = all_passage_embeddings # type: ignore[assignment]
metrics['queries'] = all_queries # type: ignore[assignment]
metrics['passages'] = all_passages # type: ignore[assignment]
return metrics # type: ignore[return-value]
def _evaluate_biencoder_batch(self, batch: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
"""Evaluate a batch for bi-encoder model"""
# Get query embeddings
query_outputs = self.model(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True
)
query_embeddings = query_outputs['dense']
# Get passage embeddings
# Handle multiple passages per query
if batch['passage_input_ids'].dim() == 3:
# Flatten batch dimension
batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
passage_input_ids = batch['passage_input_ids'].view(-1, seq_len)
passage_attention_mask = batch['passage_attention_mask'].view(-1, seq_len)
else:
passage_input_ids = batch['passage_input_ids']
passage_attention_mask = batch['passage_attention_mask']
passage_outputs = self.model(
input_ids=passage_input_ids,
attention_mask=passage_attention_mask,
return_dense=True
)
passage_embeddings = passage_outputs['dense']
return query_embeddings, passage_embeddings
def _evaluate_crossencoder_batch(self, batch: Dict) -> torch.Tensor:
"""Evaluate a batch for cross-encoder model"""
outputs = self.model(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask']
)
return outputs['logits']
def evaluate_retrieval_pipeline(
self,
queries: List[str],
corpus: List[str],
relevant_docs: Dict[str, List[int]],
batch_size: int = 32
) -> Dict[str, float]:
"""
Evaluate full retrieval pipeline
Args:
queries: List of queries
corpus: List of all documents
relevant_docs: Dict mapping query to relevant doc indices
batch_size: Batch size for encoding
Returns:
Evaluation metrics
"""
logger.info(f"Evaluating retrieval pipeline with {len(queries)} queries over {len(corpus)} documents")
# Encode all documents
logger.info("Encoding corpus...")
corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")
# Encode queries and evaluate
all_metrics = defaultdict(list)
for i in tqdm(range(0, len(queries), batch_size), desc="Evaluating queries"):
batch_queries = queries[i:i + batch_size]
batch_relevant = [relevant_docs.get(q, []) for q in batch_queries]
# Encode queries
query_embeddings = self._encode_texts(batch_queries, batch_size, show_progress=False)
# Create labels matrix
labels = np.zeros((len(batch_queries), len(corpus)))
for j, relevant_indices in enumerate(batch_relevant):
for idx in relevant_indices:
if idx < len(corpus):
labels[j, idx] = 1
# Compute metrics
batch_metrics = compute_retrieval_metrics(
query_embeddings,
corpus_embeddings,
labels,
k_values=self.k_values
)
# Accumulate metrics
for metric, value in batch_metrics.items():
if metric != 'scores':
all_metrics[metric].append(value)
# Average metrics
final_metrics = {
metric: np.mean(values) for metric, values in all_metrics.items()
}
return final_metrics # type: ignore[return-value]
def _encode_texts(
self,
texts: List[str],
batch_size: int,
desc: str = "Encoding",
show_progress: bool = True
) -> np.ndarray:
"""Encode texts to embeddings"""
if isinstance(self.model, BGEM3Model):
# Use model's encode method
embeddings = self.model.encode(
texts,
batch_size=batch_size,
device=self.device
)
if isinstance(embeddings, torch.Tensor):
embeddings = embeddings.cpu().numpy()
else:
# For cross-encoder, this doesn't make sense
raise ValueError("Cannot encode texts with cross-encoder model")
return embeddings # type: ignore[return-value]
def evaluate_cross_lingual(
self,
queries: Dict[str, List[str]],
corpus: List[str],
relevant_docs: Dict[str, Dict[str, List[int]]],
batch_size: int = 32
) -> Dict[str, Dict[str, float]]:
"""
Evaluate cross-lingual retrieval
Args:
queries: Dict mapping language to queries
corpus: List of documents (assumed to be in source language)
relevant_docs: Nested dict of language -> query -> relevant indices
batch_size: Batch size
Returns:
Metrics per language
"""
# Encode corpus once
corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")
# Evaluate each language
results = {}
for lang, lang_queries in queries.items():
logger.info(f"Evaluating {lang} queries...")
# Encode queries
query_embeddings = self._encode_texts(lang_queries, batch_size, desc=f"Encoding {lang} queries")
# Create labels
labels = np.zeros((len(lang_queries), len(corpus)))
for i, query in enumerate(lang_queries):
if query in relevant_docs.get(lang, {}):
for idx in relevant_docs[lang][query]:
if idx < len(corpus):
labels[i, idx] = 1
# Compute metrics
metrics = compute_retrieval_metrics(
query_embeddings,
corpus_embeddings,
labels,
k_values=self.k_values
)
results[lang] = metrics
# Compute average across languages
if len(results) > 1:
avg_metrics = {}
for metric in results[list(results.keys())[0]]:
if metric != 'scores':
values = [results[lang][metric] for lang in results]
avg_metrics[metric] = np.mean(values)
results['average'] = avg_metrics
return results
def save_evaluation_results(
self,
metrics: Dict[str, float],
output_dir: str,
prefix: str = "eval"
):
"""Save evaluation results"""
os.makedirs(output_dir, exist_ok=True)
# Save metrics
metrics_file = os.path.join(output_dir, f"{prefix}_metrics.json")
with open(metrics_file, 'w') as f:
json.dump(metrics, f, indent=2)
# Save summary
summary_file = os.path.join(output_dir, f"{prefix}_summary.txt")
with open(summary_file, 'w') as f:
f.write(f"Evaluation Results - {prefix}\n")
f.write("=" * 50 + "\n\n")
for metric, value in sorted(metrics.items()):
if isinstance(value, float):
f.write(f"{metric:.<30} {value:.4f}\n")
logger.info(f"Saved evaluation results to {output_dir}")
class InteractiveEvaluator:
"""
Interactive evaluation for qualitative analysis
"""
def __init__(
self,
retriever: BGEM3Model,
reranker: Optional[BGERerankerModel],
tokenizer,
corpus: List[str],
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
):
self.retriever = retriever
self.reranker = reranker
self.tokenizer = tokenizer
self.corpus = corpus
self.device = device
# Pre-encode corpus
logger.info("Pre-encoding corpus...")
self.corpus_embeddings = self._encode_corpus()
logger.info(f"Encoded {len(self.corpus)} documents")
def _encode_corpus(self) -> np.ndarray:
"""Pre-encode all corpus documents"""
embeddings = self.retriever.encode(
self.corpus,
batch_size=32,
device=self.device
)
if isinstance(embeddings, torch.Tensor):
embeddings = embeddings.cpu().numpy()
return embeddings # type: ignore[return-value]
def search(
self,
query: str,
top_k: int = 10,
rerank: bool = True
) -> List[Tuple[int, float, str]]:
"""
Search for query in corpus
Args:
query: Search query
top_k: Number of results to return
rerank: Whether to use reranker
Returns:
List of (index, score, text) tuples
"""
# Encode query
query_embedding = self.retriever.encode(
[query],
device=self.device
)
if isinstance(query_embedding, torch.Tensor):
query_embedding = query_embedding.cpu().numpy()
# Compute similarities
similarities = np.dot(query_embedding, self.corpus_embeddings.T).squeeze() # type: ignore[arg-type]
# Get top-k
top_indices = np.argsort(similarities)[::-1][:top_k * 2 if rerank else top_k]
results = []
for idx in top_indices:
results.append((int(idx), float(similarities[idx]), self.corpus[idx]))
# Rerank if requested
if rerank and self.reranker:
reranked_indices = self.reranker.rerank(
query,
[self.corpus[idx] for idx in top_indices],
return_scores=False
)
# Reorder results
reranked_results = []
for i, rerank_idx in enumerate(reranked_indices[:top_k]):
original_idx = top_indices[rerank_idx] # type: ignore[index]
reranked_results.append((
int(original_idx),
float(i + 1), # Rank as score
self.corpus[original_idx]
))
results = reranked_results
return results[:top_k]
def evaluate_query_interactively(
self,
query: str,
relevant_indices: List[int],
top_k: int = 10
) -> Dict[str, float]:
"""
Evaluate a single query interactively
Args:
query: Query to evaluate
relevant_indices: Ground truth relevant document indices
top_k: Number of results to consider
Returns:
Metrics for this query
"""
# Get search results
results = self.search(query, top_k=top_k, rerank=True)
retrieved_indices = [r[0] for r in results]
# Calculate metrics
metrics = {}
# Precision@k
relevant_retrieved = sum(1 for idx in retrieved_indices if idx in relevant_indices)
metrics[f'precision@{top_k}'] = relevant_retrieved / top_k
# Recall@k
if relevant_indices:
metrics[f'recall@{top_k}'] = relevant_retrieved / len(relevant_indices)
else:
metrics[f'recall@{top_k}'] = 0.0
# MRR
for i, idx in enumerate(retrieved_indices):
if idx in relevant_indices:
metrics['mrr'] = 1.0 / (i + 1)
break
else:
metrics['mrr'] = 0.0
# Print results for inspection
print(f"\nQuery: {query}")
print(f"Relevant docs: {relevant_indices}")
print(f"\nTop {top_k} results:")
for i, (idx, score, text) in enumerate(results):
relevance = "✓" if idx in relevant_indices else "✗"
print(f"{i + 1}. [{relevance}] (idx={idx}, score={score:.3f}) {text[:100]}...")
print(f"\nMetrics: {metrics}")
return metrics

542
evaluation/metrics.py Normal file
View File

@ -0,0 +1,542 @@
import numpy as np
from typing import List, Dict, Tuple, Optional, Union
import torch
from sklearn.metrics import average_precision_score, ndcg_score
from collections import defaultdict
import logging
logger = logging.getLogger(__name__)
def compute_retrieval_metrics(
query_embeddings: np.ndarray,
passage_embeddings: np.ndarray,
labels: np.ndarray,
k_values: List[int] = [1, 5, 10, 20, 100],
return_all_scores: bool = False
) -> Dict[str, float]:
"""
Compute comprehensive retrieval metrics
Args:
query_embeddings: Query embeddings [num_queries, embedding_dim]
passage_embeddings: Passage embeddings [num_passages, embedding_dim]
labels: Binary relevance labels [num_queries, num_passages] or list of relevant indices
k_values: List of k values for metrics@k
return_all_scores: Whether to return all similarity scores
Returns:
Dictionary of metrics
"""
metrics = {}
# Compute similarity scores
scores = np.matmul(query_embeddings, passage_embeddings.T)
# Get number of queries
num_queries = query_embeddings.shape[0]
# Convert labels to binary matrix if needed
if labels.ndim == 1:
# Assume labels contains relevant passage indices
binary_labels = np.zeros((num_queries, passage_embeddings.shape[0]))
for i, idx in enumerate(labels):
if idx < passage_embeddings.shape[0]:
binary_labels[i, idx] = 1
labels = binary_labels
# Compute metrics for each k
for k in k_values:
metrics[f'recall@{k}'] = compute_recall_at_k(scores, labels, k)
metrics[f'precision@{k}'] = compute_precision_at_k(scores, labels, k)
metrics[f'map@{k}'] = compute_map_at_k(scores, labels, k)
metrics[f'mrr@{k}'] = compute_mrr_at_k(scores, labels, k)
metrics[f'ndcg@{k}'] = compute_ndcg_at_k(scores, labels, k)
# Overall metrics
metrics['map'] = compute_mean_average_precision(scores, labels)
metrics['mrr'] = compute_mean_reciprocal_rank(scores, labels)
if return_all_scores:
metrics['scores'] = scores
return metrics
def compute_retrieval_metrics_comprehensive(
query_embeddings: np.ndarray,
passage_embeddings: np.ndarray,
labels: np.ndarray,
k_values: List[int] = [1, 5, 10, 20, 100],
return_all_scores: bool = False,
compute_additional_metrics: bool = True
) -> Dict[str, float]:
"""
Compute comprehensive retrieval metrics with additional statistics
Args:
query_embeddings: Query embeddings [num_queries, embedding_dim]
passage_embeddings: Passage embeddings [num_passages, embedding_dim]
labels: Binary relevance labels [num_queries, num_passages] or list of relevant indices
k_values: List of k values for metrics@k
return_all_scores: Whether to return all similarity scores
compute_additional_metrics: Whether to compute additional metrics like success rate
Returns:
Dictionary of comprehensive metrics
"""
metrics = {}
# Validate inputs
if query_embeddings.shape[0] == 0 or passage_embeddings.shape[0] == 0:
raise ValueError("Empty embeddings provided")
# Normalize embeddings for cosine similarity
query_embeddings = query_embeddings / (np.linalg.norm(query_embeddings, axis=1, keepdims=True) + 1e-8)
passage_embeddings = passage_embeddings / (np.linalg.norm(passage_embeddings, axis=1, keepdims=True) + 1e-8)
# Compute similarity scores
scores = np.matmul(query_embeddings, passage_embeddings.T)
# Get number of queries
num_queries = query_embeddings.shape[0]
# Convert labels to binary matrix if needed
if labels.ndim == 1:
# Assume labels contains relevant passage indices
binary_labels = np.zeros((num_queries, passage_embeddings.shape[0]))
for i, idx in enumerate(labels):
if idx < passage_embeddings.shape[0]:
binary_labels[i, idx] = 1
labels = binary_labels
# Compute metrics for each k
for k in k_values:
metrics[f'recall@{k}'] = compute_recall_at_k(scores, labels, k)
metrics[f'precision@{k}'] = compute_precision_at_k(scores, labels, k)
metrics[f'map@{k}'] = compute_map_at_k(scores, labels, k)
metrics[f'mrr@{k}'] = compute_mrr_at_k(scores, labels, k)
metrics[f'ndcg@{k}'] = compute_ndcg_at_k(scores, labels, k)
if compute_additional_metrics:
# Success rate (queries with at least one relevant doc in top-k)
success_count = 0
for i in range(num_queries):
top_k_indices = np.argsort(scores[i])[::-1][:k]
if np.any(labels[i][top_k_indices] > 0):
success_count += 1
metrics[f'success@{k}'] = success_count / num_queries if num_queries > 0 else 0.0
# Overall metrics
metrics['map'] = compute_mean_average_precision(scores, labels)
metrics['mrr'] = compute_mean_reciprocal_rank(scores, labels)
if compute_additional_metrics:
# Average number of relevant documents per query
metrics['avg_relevant_per_query'] = np.mean(np.sum(labels > 0, axis=1))
# Coverage: fraction of relevant documents that appear in any top-k list
max_k = max(k_values)
covered_relevant = set()
for i in range(num_queries):
top_k_indices = np.argsort(scores[i])[::-1][:max_k]
relevant_indices = np.where(labels[i] > 0)[0]
for idx in relevant_indices:
if idx in top_k_indices:
covered_relevant.add((i, idx))
total_relevant = np.sum(labels > 0)
metrics['coverage'] = len(covered_relevant) / total_relevant if total_relevant > 0 else 0.0
if return_all_scores:
metrics['scores'] = scores
return metrics
def compute_recall_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
"""
Compute Recall@K
Args:
scores: Similarity scores [num_queries, num_passages]
labels: Binary relevance labels
k: Top-k to consider
Returns:
Average recall@k
"""
num_queries = scores.shape[0]
recall_scores = []
for i in range(num_queries):
# Get top-k indices
top_k_indices = np.argsort(scores[i])[::-1][:k]
# Count relevant documents
num_relevant = labels[i].sum()
if num_relevant == 0:
continue
# Count relevant in top-k
num_relevant_in_topk = labels[i][top_k_indices].sum()
recall_scores.append(num_relevant_in_topk / num_relevant)
return float(np.mean(recall_scores)) if recall_scores else 0.0
def compute_precision_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
"""
Compute Precision@K
"""
num_queries = scores.shape[0]
precision_scores = []
for i in range(num_queries):
# Get top-k indices
top_k_indices = np.argsort(scores[i])[::-1][:k]
# Count relevant in top-k
num_relevant_in_topk = labels[i][top_k_indices].sum()
precision_scores.append(num_relevant_in_topk / k)
return float(np.mean(precision_scores))
def compute_map_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
"""
Compute Mean Average Precision@K
"""
num_queries = scores.shape[0]
ap_scores = []
for i in range(num_queries):
# Get top-k indices
top_k_indices = np.argsort(scores[i])[::-1][:k]
# Compute average precision
num_relevant = 0
sum_precision = 0
for j, idx in enumerate(top_k_indices):
if labels[i][idx] == 1:
num_relevant += 1
sum_precision += num_relevant / (j + 1)
if labels[i].sum() > 0:
ap_scores.append(sum_precision / min(labels[i].sum(), k))
return float(np.mean(ap_scores)) if ap_scores else 0.0
def compute_mrr_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
"""
Compute Mean Reciprocal Rank@K with proper error handling
Args:
scores: Similarity scores [num_queries, num_passages]
labels: Binary relevance labels
k: Top-k to consider
Returns:
MRR@k score
"""
if scores.shape != labels.shape:
raise ValueError(f"Scores shape {scores.shape} != labels shape {labels.shape}")
num_queries = scores.shape[0]
rr_scores = []
for i in range(num_queries):
# Skip queries with no relevant documents
if labels[i].sum() == 0:
continue
# Get top-k indices
top_k_indices = np.argsort(scores[i])[::-1][:k]
# Find first relevant document
for j, idx in enumerate(top_k_indices):
if labels[i][idx] > 0: # Support graded relevance
rr_scores.append(1.0 / (j + 1))
break
else:
# No relevant document in top-k
rr_scores.append(0.0)
return float(np.mean(rr_scores)) if rr_scores else 0.0
def compute_ndcg_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
"""
Compute Normalized Discounted Cumulative Gain@K with robust error handling
Args:
scores: Similarity scores [num_queries, num_passages]
labels: Binary or graded relevance labels
k: Top-k to consider
Returns:
NDCG@k score
"""
if scores.shape != labels.shape:
raise ValueError(f"Scores shape {scores.shape} != labels shape {labels.shape}")
num_queries = scores.shape[0]
ndcg_scores = []
for i in range(num_queries):
# Skip queries with no relevant documents
if labels[i].sum() == 0:
continue
# Get scores and labels for this query
query_scores = scores[i]
query_labels = labels[i]
# Get top-k indices by score
top_k_indices = np.argsort(query_scores)[::-1][:k]
# Compute DCG@k
dcg = 0.0
for j, idx in enumerate(top_k_indices):
if query_labels[idx] > 0:
# Use log2(j+2) as denominator (j+1 for rank, +1 to avoid log(1)=0)
dcg += query_labels[idx] / np.log2(j + 2)
# Compute ideal DCG@k
ideal_labels = np.sort(query_labels)[::-1][:k]
idcg = 0.0
for j, label in enumerate(ideal_labels):
if label > 0:
idcg += label / np.log2(j + 2)
# Compute NDCG
if idcg > 0:
ndcg_scores.append(dcg / idcg)
else:
ndcg_scores.append(0.0)
return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
def compute_mean_average_precision(scores: np.ndarray, labels: np.ndarray) -> float:
"""
Compute Mean Average Precision (MAP) over all positions
"""
num_queries = scores.shape[0]
ap_scores = []
for i in range(num_queries):
if labels[i].sum() > 0:
ap = average_precision_score(labels[i], scores[i])
ap_scores.append(ap)
return float(np.mean(ap_scores)) if ap_scores else 0.0
def compute_mean_reciprocal_rank(scores: np.ndarray, labels: np.ndarray) -> float:
"""
Compute Mean Reciprocal Rank (MRR) over all positions
"""
num_queries = scores.shape[0]
rr_scores = []
for i in range(num_queries):
# Get ranking
ranked_indices = np.argsort(scores[i])[::-1]
# Find first relevant document
for j, idx in enumerate(ranked_indices):
if labels[i][idx] == 1:
rr_scores.append(1.0 / (j + 1))
break
else:
rr_scores.append(0.0)
return float(np.mean(rr_scores))
def compute_reranker_metrics(
scores: Union[np.ndarray, torch.Tensor],
labels: Union[np.ndarray, torch.Tensor],
groups: Optional[List[int]] = None,
k_values: List[int] = [1, 5, 10]
) -> Dict[str, float]:
"""
Compute metrics specifically for reranker evaluation
Args:
scores: Relevance scores from reranker
labels: Binary relevance labels
groups: Group sizes for each query (for listwise evaluation)
k_values: List of k values
Returns:
Dictionary of reranker-specific metrics
"""
# Convert to numpy if needed
if isinstance(scores, torch.Tensor):
scores = scores.cpu().numpy()
if isinstance(labels, torch.Tensor):
labels = labels.cpu().numpy()
metrics = {}
if groups is None:
# Assume pairwise evaluation
# Binary classification metrics
predictions = (scores > 0.5).astype(int) # type: ignore[attr-defined]
# Accuracy
metrics['accuracy'] = np.mean(predictions == labels)
# Compute AUC if we have both positive and negative examples
if len(np.unique(labels)) > 1:
from sklearn.metrics import roc_auc_score
metrics['auc'] = roc_auc_score(labels, scores)
else:
# Listwise evaluation
start_idx = 0
mrr_scores = []
ndcg_scores = defaultdict(list)
for group_size in groups:
# Get scores and labels for this group
group_scores = scores[start_idx:start_idx + group_size]
group_labels = labels[start_idx:start_idx + group_size]
# Sort by scores
sorted_indices = np.argsort(group_scores)[::-1]
# MRR
for i, idx in enumerate(sorted_indices):
if group_labels[idx] == 1:
mrr_scores.append(1.0 / (i + 1))
break
else:
mrr_scores.append(0.0)
# NDCG@k
for k in k_values:
if group_size >= k:
try:
ndcg = ndcg_score(
group_labels.reshape(1, -1),
group_scores.reshape(1, -1),
k=k
)
ndcg_scores[k].append(ndcg)
except Exception:
pass
start_idx += group_size
# Aggregate metrics
metrics['mrr'] = np.mean(mrr_scores)
for k in k_values:
if ndcg_scores[k]:
metrics[f'ndcg@{k}'] = np.mean(ndcg_scores[k])
return metrics
def evaluate_cross_lingual_retrieval(
query_embeddings: Dict[str, np.ndarray],
passage_embeddings: np.ndarray,
labels: Dict[str, np.ndarray],
languages: List[str] = ['en', 'zh']
) -> Dict[str, Dict[str, float]]:
"""
Evaluate cross-lingual retrieval performance
Args:
query_embeddings: Dict mapping language to query embeddings
passage_embeddings: Passage embeddings
labels: Dict mapping language to relevance labels
languages: List of languages to evaluate
Returns:
Nested dict of metrics per language
"""
results = {}
for lang in languages:
if lang in query_embeddings and lang in labels:
logger.info(f"Evaluating {lang} queries")
results[lang] = compute_retrieval_metrics(
query_embeddings[lang],
passage_embeddings,
labels[lang]
)
# Compute cross-lingual metrics if we have multiple languages
if len(languages) > 1 and all(lang in results for lang in languages):
# Average performance across languages
avg_metrics = {}
for metric in results[languages[0]].keys():
if metric != 'scores':
values = [results[lang][metric] for lang in languages]
avg_metrics[metric] = np.mean(values)
results['average'] = avg_metrics
return results
class MetricsTracker:
"""
Track and aggregate metrics during training
"""
def __init__(self):
self.metrics = defaultdict(list)
self.best_metrics = {}
self.best_steps = {}
def update(self, metrics: Dict[str, float], step: int):
"""Update metrics for current step"""
for key, value in metrics.items():
self.metrics[key].append((step, value))
# Track best metrics
if key not in self.best_metrics or value > self.best_metrics[key]:
self.best_metrics[key] = value
self.best_steps[key] = step
def get_current(self, metric: str) -> Optional[float]:
"""Get most recent value for a metric"""
if metric in self.metrics and self.metrics[metric]:
return self.metrics[metric][-1][1]
return None
def get_best(self, metric: str) -> Optional[Tuple[int, float]]:
"""Get best value and step for a metric"""
if metric in self.best_metrics:
return self.best_steps[metric], self.best_metrics[metric]
return None
def get_history(self, metric: str) -> List[Tuple[int, float]]:
"""Get full history for a metric"""
return self.metrics.get(metric, [])
def summary(self) -> Dict[str, Dict[str, float]]:
"""Get summary of all metrics"""
summary = {
'current': {},
'best': {},
'best_steps': {}
}
for metric in self.metrics:
summary['current'][metric] = self.get_current(metric)
summary['best'][metric] = self.best_metrics.get(metric)
summary['best_steps'][metric] = self.best_steps.get(metric)
return summary

25
models/__init__.py Normal file
View File

@ -0,0 +1,25 @@
"""
Model implementations for BGE fine-tuning
"""
from .bge_m3 import BGEM3Model
from .bge_reranker import BGERerankerModel
from .losses import (
InfoNCELoss,
M3KnowledgeDistillationLoss,
ListwiseCrossEntropyLoss,
PairwiseMarginLoss,
DynamicListwiseDistillationLoss,
MultiTaskLoss
)
__all__ = [
'BGEM3Model',
'BGERerankerModel',
'InfoNCELoss',
'M3KnowledgeDistillationLoss',
'ListwiseCrossEntropyLoss',
'PairwiseMarginLoss',
'DynamicListwiseDistillationLoss',
'MultiTaskLoss'
]

768
models/bge_m3.py Normal file
View File

@ -0,0 +1,768 @@
"""
BGE-M3 model wrapper for multi-functionality embedding
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig
from typing import Dict, List, Optional, Union, Tuple
import numpy as np
import logging
from utils.config_loader import get_config
import os
logger = logging.getLogger(__name__)
def load_model_and_tokenizer(model_name_or_path, cache_dir=None):
"""
Load model and tokenizer with comprehensive error handling and fallback strategies
Args:
model_name_or_path: Model identifier or local path
cache_dir: Cache directory for downloaded models
Returns:
Tuple of (model, tokenizer, config)
Raises:
RuntimeError: If model loading fails after all attempts
"""
config = get_config()
model_source = config.get('model_paths', 'source', 'huggingface')
# Track loading attempts for debugging
attempts = []
# Try primary loading method
try:
if model_source == 'huggingface':
return _load_from_huggingface(model_name_or_path, cache_dir)
elif model_source == 'modelscope':
return _load_from_modelscope(model_name_or_path, cache_dir)
else:
raise ValueError(f"Unknown model source: {model_source}")
except Exception as e:
attempts.append(f"{model_source}: {str(e)}")
logger.warning(f"Primary loading from {model_source} failed: {e}")
# Fallback strategies
fallback_sources = ['huggingface'] if model_source == 'modelscope' else ['modelscope']
for fallback in fallback_sources:
try:
logger.info(f"Attempting fallback loading from {fallback}")
if fallback == 'huggingface':
return _load_from_huggingface(model_name_or_path, cache_dir)
elif fallback == 'modelscope':
return _load_from_modelscope(model_name_or_path, cache_dir)
except Exception as e:
attempts.append(f"{fallback}: {str(e)}")
logger.warning(f"Fallback loading from {fallback} failed: {e}")
# Try loading from local path if it exists
if os.path.exists(model_name_or_path):
try:
logger.info(f"Attempting to load from local path: {model_name_or_path}")
return _load_from_local(model_name_or_path, cache_dir)
except Exception as e:
attempts.append(f"local: {str(e)}")
logger.warning(f"Local loading failed: {e}")
# All attempts failed
error_msg = f"Failed to load model from any source. Attempts:\n" + "\n".join(attempts)
logger.error(error_msg)
raise RuntimeError(error_msg)
def _load_from_huggingface(model_name_or_path, cache_dir=None):
"""Load model from HuggingFace with error handling"""
try:
from transformers import AutoModel, AutoTokenizer, AutoConfig
# Load config first to validate model
model_config = AutoConfig.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
trust_remote_code=True
)
# Load model
model = AutoModel.from_pretrained(
model_name_or_path,
config=model_config,
cache_dir=cache_dir,
trust_remote_code=True
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
trust_remote_code=True
)
logger.info(f"Successfully loaded model from HuggingFace: {model_name_or_path}")
return model, tokenizer, model_config
except ImportError as e:
raise ImportError(f"Transformers library not available: {e}")
except Exception as e:
raise RuntimeError(f"HuggingFace loading error: {e}")
def _load_from_modelscope(model_name_or_path, cache_dir=None):
"""Load model from ModelScope with error handling"""
try:
from modelscope.models import Model as MSModel
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
# Download model if needed
local_path = ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
# Load model
model = MSModel.from_pretrained(local_path)
# Try to load tokenizer
tokenizer = None
try:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(local_path)
except Exception as e:
logger.warning(f"Could not load tokenizer from ModelScope: {e}")
# Try to get config
config = None
config_path = os.path.join(local_path, 'config.json')
if os.path.exists(config_path):
import json
with open(config_path, 'r') as f:
config = json.load(f)
logger.info(f"Successfully loaded model from ModelScope: {model_name_or_path}")
return model, tokenizer, config
except ImportError as e:
raise ImportError(f"ModelScope library not available: {e}")
except Exception as e:
raise RuntimeError(f"ModelScope loading error: {e}")
def _load_from_local(model_path, cache_dir=None):
"""Load model from local directory with error handling"""
try:
from transformers import AutoModel, AutoTokenizer, AutoConfig
if not os.path.isdir(model_path):
raise ValueError(f"Local path is not a directory: {model_path}")
# Check for required files
required_files = ['config.json']
missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
if missing_files:
raise FileNotFoundError(f"Missing required files: {missing_files}")
# Load components
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, config=model_config, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
logger.info(f"Successfully loaded model from local path: {model_path}")
return model, tokenizer, model_config
except Exception as e:
raise RuntimeError(f"Local loading error: {e}")
class BGEM3Model(nn.Module):
"""
BGE-M3 model wrapper supporting dense, sparse, and multi-vector retrieval
Notes:
- All parameters (model_name_or_path, use_dense, use_sparse, use_colbert, pooling_method, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
model_name_or_path: str = "BAAI/bge-m3",
use_dense: bool = True,
use_sparse: bool = True,
use_colbert: bool = True,
pooling_method: str = "cls",
normalize_embeddings: bool = True,
sentence_pooling_method: str = "cls",
colbert_dim: int = -1,
sparse_top_k: int = 100,
cache_dir: Optional[str] = None,
device: Optional[str] = None,
gradient_checkpointing: bool = False
):
super().__init__()
self.model_name_or_path = model_name_or_path
self.use_dense = use_dense
self.use_sparse = use_sparse
self.use_colbert = use_colbert
self.pooling_method = pooling_method
self.sentence_pooling_method = sentence_pooling_method
self.normalize_embeddings = normalize_embeddings
self.colbert_dim = colbert_dim
self.sparse_top_k = sparse_top_k
self.gradient_checkpointing = gradient_checkpointing
# Initialize model with comprehensive error handling
try:
logger.info(f"Loading BGE-M3 model from {model_name_or_path} using source from config.toml...")
self.model, self.tokenizer, self.config = load_model_and_tokenizer(
model_name_or_path,
cache_dir=cache_dir
)
logger.info("BGE-M3 model loaded successfully")
except Exception as e:
logger.error(f"Failed to load BGE-M3 model: {e}")
raise RuntimeError(f"Model initialization failed: {e}")
# Enable gradient checkpointing if requested
if self.gradient_checkpointing:
self.enable_gradient_checkpointing()
# Ensure we have a valid config
if self.config is None:
logger.warning("Model config is None, creating a basic config")
# Create a basic config object
class BasicConfig:
def __init__(self):
self.hidden_size = 768 # Default BERT hidden size
self.config = BasicConfig()
# Initialize heads for different functionalities
hidden_size = getattr(self.config, 'hidden_size', 768)
if self.use_dense:
# Dense retrieval doesn't need additional layers (uses CLS token)
self.dense_linear = nn.Identity()
if self.use_sparse:
# Sparse retrieval head (SPLADE-style)
self.sparse_linear = nn.Linear(hidden_size, 1)
self.sparse_activation = nn.ReLU()
if self.use_colbert:
# ColBERT projection layer
colbert_dim = colbert_dim if colbert_dim > 0 else hidden_size
self.colbert_linear = nn.Linear(hidden_size, colbert_dim)
# Set device with enhanced detection
if device:
self.device = torch.device(device)
else:
self.device = self._detect_device()
try:
self.to(self.device)
logger.info(f"BGE-M3 model moved to device: {self.device}")
except Exception as e:
logger.warning(f"Failed to move model to {self.device}: {e}, using CPU")
self.device = torch.device('cpu')
self.to(self.device)
# Cache for tokenizer
self._tokenizer = None
def enable_gradient_checkpointing(self):
"""Enable gradient checkpointing for memory efficiency"""
if hasattr(self.model, 'gradient_checkpointing_enable'):
self.model.gradient_checkpointing_enable()
logger.info("Gradient checkpointing enabled for BGE-M3")
else:
logger.warning("Model does not support gradient checkpointing")
def disable_gradient_checkpointing(self):
"""Disable gradient checkpointing"""
if hasattr(self.model, 'gradient_checkpointing_disable'):
self.model.gradient_checkpointing_disable()
logger.info("Gradient checkpointing disabled for BGE-M3")
def _detect_device(self) -> torch.device:
"""Enhanced device detection with proper error handling"""
# Check Ascend NPU first
try:
import torch_npu
if hasattr(torch, 'npu') and torch.npu.is_available(): # type: ignore[attr-defined]
logger.info("Using Ascend NPU")
return torch.device("npu")
except ImportError:
pass
except Exception as e:
logger.debug(f"NPU detection failed: {e}")
pass
# Check CUDA
if torch.cuda.is_available():
logger.info("Using CUDA GPU")
return torch.device("cuda")
# Check MPS (Apple Silicon)
try:
if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): # type: ignore[attr-defined]
logger.info("Using Apple MPS")
return torch.device("mps")
except Exception as e:
logger.debug(f"MPS detection failed: {e}")
pass
# Default to CPU
logger.info("Using CPU")
return torch.device("cpu")
def _get_tokenizer(self):
"""Lazy load tokenizer with error handling"""
if self._tokenizer is None:
try:
self._tokenizer = AutoTokenizer.from_pretrained(
self.model_name_or_path,
trust_remote_code=True
)
logger.info("Tokenizer loaded successfully")
except Exception as e:
logger.error(f"Failed to load tokenizer: {e}")
raise RuntimeError(f"Tokenizer loading failed: {e}")
return self._tokenizer
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
return_dense: bool = True,
return_sparse: bool = False,
return_colbert: bool = False,
return_dict: bool = True
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass for BGE-M3 with dense, sparse (SPLADE+), and ColBERT heads.
Args:
input_ids: Token IDs [batch_size, seq_length]
attention_mask: Attention mask [batch_size, seq_length]
token_type_ids: Token type IDs (optional)
return_dense: Return dense embeddings
return_sparse: Return sparse embeddings
return_colbert: Return ColBERT embeddings
return_dict: Return dictionary of outputs
Returns:
Embeddings or dictionary of embeddings
"""
# Validate inputs
if input_ids is None or attention_mask is None:
raise ValueError("input_ids and attention_mask are required")
if input_ids.shape != attention_mask.shape:
raise ValueError("input_ids and attention_mask must have the same shape")
# Determine what to return based on model configuration and parameters
return_dense = return_dense and self.use_dense
return_sparse = return_sparse and self.use_sparse
return_colbert = return_colbert and self.use_colbert
if not (return_dense or return_sparse or return_colbert):
logger.warning("No output head selected in forward. At least one of dense, sparse, or colbert must be True.")
return_dense = True # Default to dense if nothing selected
try:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
last_hidden_state = outputs.last_hidden_state # [batch_size, seq_length, hidden_size]
except Exception as e:
logger.error(f"Model forward pass failed: {e}")
raise RuntimeError(f"Forward pass failed: {e}")
results = {}
# Dense embeddings (CLS or mean pooling)
if return_dense:
if self.pooling_method == "cls" or self.sentence_pooling_method == "cls":
dense_vecs = last_hidden_state[:, 0]
elif self.pooling_method == "mean" or self.sentence_pooling_method == "mean":
expanded_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.shape)
masked_embeddings = last_hidden_state * expanded_mask
sum_embeddings = torch.sum(masked_embeddings, dim=1)
sum_mask = torch.clamp(attention_mask.sum(dim=1, keepdim=True), min=1e-9)
dense_vecs = sum_embeddings / sum_mask
else:
raise ValueError(f"Unknown pooling method: {self.pooling_method}")
dense_vecs = self.dense_linear(dense_vecs)
# Ensure embeddings are float type
dense_vecs = dense_vecs.float()
if self.normalize_embeddings:
dense_vecs = F.normalize(dense_vecs, p=2, dim=-1)
results['dense'] = dense_vecs
# SPLADE+ sparse head
if return_sparse:
# SPLADE+ uses log1p(ReLU(.))
sparse_logits = self.sparse_linear(last_hidden_state).squeeze(-1) # [batch, seq_len]
sparse_activations = torch.log1p(F.relu(sparse_logits))
# Mask out padding tokens
sparse_activations = sparse_activations * attention_mask.float()
# Optionally normalize sparse activations
if self.normalize_embeddings:
sparse_activations = F.normalize(sparse_activations, p=2, dim=-1)
results['sparse'] = sparse_activations
# FLOPS regularization (for loss, not forward, but add as attribute)
self.splade_flops = torch.sum(sparse_activations)
# ColBERT head with late interaction (maxsim)
if return_colbert:
colbert_vecs = self.colbert_linear(last_hidden_state) # [batch, seq_len, colbert_dim]
# Mask out padding tokens
expanded_mask = attention_mask.unsqueeze(-1).expand(colbert_vecs.shape)
colbert_vecs = colbert_vecs * expanded_mask
if self.normalize_embeddings:
colbert_vecs = F.normalize(colbert_vecs, p=2, dim=-1)
results['colbert'] = colbert_vecs
if return_dict:
return results
# If only one head requested, return that
if return_dense and len(results) == 1:
return results['dense']
if return_sparse and len(results) == 1:
return results['sparse']
if return_colbert and len(results) == 1:
return results['colbert']
return results
def encode(
self,
sentences: Union[str, List[str]],
batch_size: int = 32,
max_length: int = 512,
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
show_progress_bar: bool = False,
normalize_embeddings: bool = True,
return_dense: bool = True,
return_sparse: bool = False,
return_colbert: bool = False,
device: Optional[str] = None
) -> Union[np.ndarray, torch.Tensor, Dict[str, Union[np.ndarray, torch.Tensor]]]:
"""
Complete encode implementation with comprehensive error handling
Args:
sentences: Input sentences
batch_size: Batch size for encoding
max_length: Maximum sequence length
convert_to_numpy: Convert to numpy arrays
convert_to_tensor: Keep as tensors
show_progress_bar: Show progress bar
normalize_embeddings: Normalize embeddings
return_dense/sparse/colbert: Which embeddings to return
device: Device to use
Returns:
Embeddings
"""
# Input validation
if isinstance(sentences, str):
sentences = [sentences]
if not sentences:
raise ValueError("Empty sentences list provided")
if batch_size <= 0:
raise ValueError("batch_size must be positive")
if max_length <= 0:
raise ValueError("max_length must be positive")
# Get tokenizer
try:
tokenizer = self._get_tokenizer()
except Exception as e:
raise RuntimeError(f"Failed to load tokenizer: {e}")
# Determine target device
target_device = device if device else self.device
# Initialize storage for embeddings
all_embeddings = {
'dense': [] if return_dense else None,
'sparse': [] if return_sparse else None,
'colbert': [] if return_colbert else None
}
# Set model to evaluation mode
self.eval()
try:
# Import tqdm if show_progress_bar is True
if show_progress_bar:
try:
from tqdm import tqdm
iterator = tqdm(
range(0, len(sentences), batch_size),
desc="Encoding",
total=(len(sentences) + batch_size - 1) // batch_size
)
except ImportError:
logger.warning("tqdm not available, disabling progress bar")
iterator = range(0, len(sentences), batch_size)
else:
iterator = range(0, len(sentences), batch_size)
# Process in batches
for i in iterator:
batch_sentences = sentences[i:i + batch_size]
# Tokenize with proper error handling
try:
encoded = tokenizer(
batch_sentences,
padding=True,
truncation=True,
max_length=max_length,
return_tensors='pt'
)
except Exception as e:
raise RuntimeError(f"Tokenization failed for batch {i//batch_size + 1}: {e}")
# Move to device
try:
encoded = {k: v.to(target_device) for k, v in encoded.items()}
except Exception as e:
raise RuntimeError(f"Failed to move batch to device {target_device}: {e}")
# Get embeddings
with torch.no_grad():
try:
outputs = self.forward(
**encoded,
return_dense=return_dense,
return_sparse=return_sparse,
return_colbert=return_colbert,
return_dict=True
)
except Exception as e:
raise RuntimeError(f"Forward pass failed for batch {i//batch_size + 1}: {e}")
# Collect embeddings
for key in ['dense', 'sparse', 'colbert']:
if all_embeddings[key] is not None and key in outputs:
all_embeddings[key].append(outputs[key].cpu()) # type: ignore[arg-type]
# Concatenate and convert results
final_embeddings = {}
for key in ['dense', 'sparse', 'colbert']:
if all_embeddings[key] is not None and len(all_embeddings[key]) > 0: # type: ignore[arg-type]
try:
embeddings = torch.cat(all_embeddings[key], dim=0) # type: ignore[arg-type]
if normalize_embeddings and key == 'dense':
embeddings = F.normalize(embeddings, p=2, dim=-1)
if convert_to_numpy:
embeddings = embeddings.numpy()
elif not convert_to_tensor:
embeddings = embeddings
final_embeddings[key] = embeddings
except Exception as e:
logger.error(f"Failed to process {key} embeddings: {e}")
raise
# Return format
if len(final_embeddings) == 1:
return list(final_embeddings.values())[0]
if not final_embeddings:
raise RuntimeError("No embeddings were generated")
return final_embeddings
except Exception as e:
logger.error(f"Encoding failed: {e}")
raise
def save_pretrained(self, save_path: str):
"""Save model to directory with error handling, including custom heads and config."""
try:
import os
import json
os.makedirs(save_path, exist_ok=True)
# Save the base model
self.model.save_pretrained(save_path)
# Save additional configuration (including [m3] section)
from utils.config_loader import get_config
config_dict = get_config().get_section('m3')
config = {
'use_dense': self.use_dense,
'use_sparse': self.use_sparse,
'use_colbert': self.use_colbert,
'pooling_method': self.pooling_method,
'normalize_embeddings': self.normalize_embeddings,
'colbert_dim': self.colbert_dim,
'sparse_top_k': self.sparse_top_k,
'model_name_or_path': self.model_name_or_path,
'm3_config': config_dict
}
with open(f"{save_path}/m3_config.json", 'w') as f:
json.dump(config, f, indent=2)
# Save custom heads if needed
if self.use_sparse:
torch.save(self.sparse_linear.state_dict(), os.path.join(save_path, 'sparse_linear.pt'))
if self.use_colbert:
torch.save(self.colbert_linear.state_dict(), os.path.join(save_path, 'colbert_linear.pt'))
# Save tokenizer if available
if self._tokenizer is not None:
self._tokenizer.save_pretrained(save_path)
logger.info(f"Model and config saved to {save_path}")
except Exception as e:
logger.error(f"Failed to save model: {e}")
raise
@classmethod
def from_pretrained(cls, model_path: str, **kwargs):
"""Load model from directory, restoring custom heads and config."""
import os
import json
try:
# Load config
config_path = os.path.join(model_path, "m3_config.json")
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config = json.load(f)
else:
config = {}
# Instantiate model
model = cls(
model_name_or_path=config.get('model_name_or_path', model_path),
use_dense=config.get('use_dense', True),
use_sparse=config.get('use_sparse', True),
use_colbert=config.get('use_colbert', True),
pooling_method=config.get('pooling_method', 'cls'),
normalize_embeddings=config.get('normalize_embeddings', True),
colbert_dim=config.get('colbert_dim', -1),
sparse_top_k=config.get('sparse_top_k', 100),
**kwargs
)
# Load base model weights
model.model = model.model.from_pretrained(model_path)
# Load custom heads if present
if model.use_sparse:
sparse_path = os.path.join(model_path, 'sparse_linear.pt')
if os.path.exists(sparse_path):
model.sparse_linear.load_state_dict(torch.load(sparse_path, map_location='cpu'))
if model.use_colbert:
colbert_path = os.path.join(model_path, 'colbert_linear.pt')
if os.path.exists(colbert_path):
model.colbert_linear.load_state_dict(torch.load(colbert_path, map_location='cpu'))
logger.info(f"Loaded BGE-M3 model from {model_path}")
return model
except Exception as e:
logger.error(f"Failed to load model from {model_path}: {e}")
raise
def compute_similarity(
self,
queries: Union[torch.Tensor, Dict[str, torch.Tensor]],
passages: Union[torch.Tensor, Dict[str, torch.Tensor]],
mode: str = "dense"
) -> torch.Tensor:
"""
Compute similarity scores between queries and passages with error handling
Args:
queries: Query embeddings
passages: Passage embeddings
mode: Similarity mode ('dense', 'sparse', 'colbert')
Returns:
Similarity scores [num_queries, num_passages]
"""
try:
# Extract embeddings based on mode
if isinstance(queries, dict):
if mode not in queries:
raise ValueError(f"Mode '{mode}' not found in query embeddings")
query_emb = queries[mode]
else:
query_emb = queries
if isinstance(passages, dict):
if mode not in passages:
raise ValueError(f"Mode '{mode}' not found in passage embeddings")
passage_emb = passages[mode]
else:
passage_emb = passages
# Validate tensor shapes
if query_emb.dim() < 2 or passage_emb.dim() < 2:
raise ValueError("Embeddings must be at least 2-dimensional")
if mode == "dense":
# Simple dot product for dense
return torch.matmul(query_emb, passage_emb.t())
elif mode == "sparse":
# Dot product for sparse (simplified)
return torch.matmul(query_emb, passage_emb.t())
elif mode == "colbert":
# MaxSim for ColBERT
if query_emb.dim() != 3 or passage_emb.dim() != 3:
raise ValueError("ColBERT embeddings must be 3-dimensional [batch, seq_len, dim]")
scores = []
for q in query_emb:
# q: [query_len, dim]
# passage_emb: [num_passages, passage_len, dim]
passage_scores = []
for p in passage_emb:
# Compute similarity matrix
sim_matrix = torch.matmul(q, p.t()) # [query_len, passage_len]
# Max over passage tokens for each query token
max_sim_scores = sim_matrix.max(dim=1)[0] # [query_len]
# Sum over query tokens
score = max_sim_scores.sum()
passage_scores.append(score)
scores.append(torch.stack(passage_scores))
return torch.stack(scores)
else:
raise ValueError(f"Unknown similarity mode: {mode}")
except Exception as e:
logger.error(f"Similarity computation failed: {e}")
raise
def get_sentence_embedding_dimension(self) -> int:
"""Get the dimension of sentence embeddings"""
return self.config.hidden_size # type: ignore[attr-defined]
def get_word_embedding_dimension(self) -> int:
"""Get the dimension of word embeddings"""
return self.config.hidden_size # type: ignore[attr-defined]
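# Usage sketch (illustrative only; the checkpoint path and tensor shapes are assumptions,
# and the lines are left as comments so nothing runs on import):
#   model = BGEM3Model.from_pretrained("./output/bge-m3")
#   dim = model.get_sentence_embedding_dimension()
#   q = torch.randn(4, dim)   # 4 query embeddings
#   p = torch.randn(8, dim)   # 8 passage embeddings
#   scores = model.compute_similarity(q, p, mode="dense")  # [4, 8] similarity matrix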

483
models/bge_reranker.py Normal file
View File

@ -0,0 +1,483 @@
"""
BGE-Reranker cross-encoder model wrapper
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
from typing import Dict, List, Optional, Union, Tuple
import numpy as np
import logging
from dataclasses import dataclass
from utils.config_loader import get_config
logger = logging.getLogger(__name__)
@dataclass
class RerankerOutput:
"""Output class for reranker predictions"""
scores: torch.Tensor
logits: torch.Tensor
embeddings: Optional[torch.Tensor] = None
def load_reranker_model_and_tokenizer(model_name_or_path, cache_dir=None, num_labels=1):
"""Load reranker model and tokenizer from HuggingFace or ModelScope based on config.toml"""
config = get_config()
# Try to get model source from config, with fallback to huggingface
try:
model_source = config.get('model_paths', 'source', 'huggingface')
except Exception:
# Fallback if config loading fails
logger.warning("Failed to load model source from config, defaulting to 'huggingface'")
model_source = 'huggingface'
logger.info(f"Loading reranker model from source: {model_source}")
if model_source == 'huggingface':
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
model_config = AutoConfig.from_pretrained(
model_name_or_path,
num_labels=num_labels,
cache_dir=cache_dir
)
try:
model = AutoModelForSequenceClassification.from_pretrained(
model_name_or_path,
config=model_config,
cache_dir=cache_dir
)
use_base_model = False
except Exception:
# Fall back to the base encoder with a custom classification head
model = AutoModel.from_pretrained(
model_name_or_path,
config=model_config,
cache_dir=cache_dir
)
use_base_model = True
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
cache_dir=cache_dir
)
return model, tokenizer, model_config, use_base_model
elif model_source == 'modelscope':
try:
from modelscope.models import Model as MSModel
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
model = MSModel.from_pretrained(model_name_or_path, cache_dir=cache_dir)
tokenizer = None # ModelScope tokenizer loading can be added if needed
model_config = None
use_base_model = False
return model, tokenizer, model_config, use_base_model
except ImportError:
logger.error("modelscope not available. Install it first: pip install modelscope")
raise
except Exception as e:
logger.error(f"Failed to load reranker model from ModelScope: {e}")
raise
else:
logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
raise RuntimeError(f"Unknown model source: {model_source}")
class BGERerankerModel(nn.Module):
"""
BGE-Reranker cross-encoder model for passage reranking
Notes:
- All parameters (model_name_or_path, num_labels, cache_dir, use_fp16, gradient_checkpointing, device, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
model_name_or_path: str = "BAAI/bge-reranker-base",
num_labels: int = 1,
cache_dir: Optional[str] = None,
use_fp16: bool = False,
gradient_checkpointing: bool = False,
device: Optional[str] = None
):
"""
Initialize BGE-Reranker model.
Args:
model_name_or_path: Path or identifier for pretrained model
num_labels: Number of output labels
cache_dir: Directory for model cache
use_fp16: Use mixed precision
gradient_checkpointing: Enable gradient checkpointing
device: Device to use
"""
super().__init__()
self.model_name_or_path = model_name_or_path
self.num_labels = num_labels
self.use_fp16 = use_fp16
# Load configuration and model
try:
logger.info(f"Loading BGE-Reranker model from {model_name_or_path} using source from config.toml...")
self.model, self.tokenizer, self.config, self.use_base_model = load_reranker_model_and_tokenizer(
model_name_or_path,
cache_dir=cache_dir,
num_labels=num_labels
)
logger.info("BGE-Reranker model loaded successfully")
except Exception as e:
logger.error(f"Failed to load BGE-Reranker model: {e}")
raise RuntimeError(f"Model initialization failed: {e}")
# Add classification head
if self.use_base_model:
self.classifier = nn.Linear(self.config.hidden_size, num_labels) # type: ignore[attr-defined]
self.dropout = nn.Dropout(
self.config.hidden_dropout_prob if hasattr(self.config, 'hidden_dropout_prob') else 0.1) # type: ignore[attr-defined]
# Enable gradient checkpointing
if gradient_checkpointing:
try:
self.model.gradient_checkpointing_enable()
except Exception:
logger.warning("Gradient checkpointing requested but not supported by this model.")
# Set device
if device:
try:
self.device = torch.device(device)
self.to(self.device)
except Exception as e:
logger.warning(f"Failed to move model to {device}: {e}, using CPU")
self.device = torch.device('cpu')
self.to(self.device)
else:
self.device = next(self.parameters()).device
# Mixed precision
if self.use_fp16:
try:
self.half()
except Exception as e:
logger.warning(f"Failed to set model to half precision: {e}")
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
return_dict: bool = True
) -> Union[torch.Tensor, RerankerOutput]:
"""
Forward pass for cross-encoder reranking
Args:
input_ids: Token IDs [batch_size, seq_length]
attention_mask: Attention mask [batch_size, seq_length]
token_type_ids: Token type IDs (optional)
labels: Labels for training (optional)
return_dict: Return RerankerOutput object
Returns:
Scores or RerankerOutput
"""
if self.use_base_model:
# Use base model + custom head
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
# Get pooled output (CLS token)
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
pooled_output = outputs.pooler_output
else:
# Manual pooling if pooler not available
last_hidden_state = outputs.last_hidden_state
pooled_output = last_hidden_state[:, 0]
# Apply dropout and classifier
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
embeddings = pooled_output
else:
# Use sequence classification model
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
labels=labels,
return_dict=True
)
logits = outputs.logits
embeddings = None
# Convert logits to scores (sigmoid for binary classification)
if self.num_labels == 1:
scores = torch.sigmoid(logits.squeeze(-1))
else:
scores = F.softmax(logits, dim=-1)
if not return_dict:
return scores
return RerankerOutput(
scores=scores,
logits=logits,
embeddings=embeddings
)
def predict(
self,
queries: Union[str, List[str]],
passages: Union[str, List[str], List[List[str]]],
batch_size: int = 32,
max_length: int = 512,
show_progress_bar: bool = False,
convert_to_numpy: bool = True,
normalize_scores: bool = False
) -> Union[np.ndarray, torch.Tensor, List[float]]:
"""
Predict reranking scores for query-passage pairs
Args:
queries: Query or list of queries
passages: Passage(s) or list of passages per query
batch_size: Batch size for inference
max_length: Maximum sequence length
show_progress_bar: Show progress bar
convert_to_numpy: Convert to numpy array
normalize_scores: Normalize scores to [0, 1]
Returns:
Reranking scores
"""
# Handle input formats
if isinstance(queries, str):
queries = [queries]
if isinstance(passages, str):
passages = [[passages]]
elif isinstance(passages[0], str):
passages = [passages] # type: ignore[assignment]
# Ensure queries and passages match
if len(queries) == 1 and len(passages) > 1:
queries = queries * len(passages)
# Create pairs
pairs = []
for query, passage_list in zip(queries, passages):
if isinstance(passage_list, str):
passage_list = [passage_list]
for passage in passage_list:
pairs.append((query, passage))
# Reuse the tokenizer loaded at init when available; otherwise load it here
tokenizer = self.tokenizer if getattr(self, 'tokenizer', None) is not None else AutoTokenizer.from_pretrained(self.model_name_or_path)
# Score in batches
all_scores = []
if show_progress_bar:
from tqdm import tqdm
pbar = tqdm(total=len(pairs), desc="Scoring")
for i in range(0, len(pairs), batch_size):
batch_pairs = pairs[i:i + batch_size]
# Tokenize
batch_queries = [pair[0] for pair in batch_pairs]
batch_passages = [pair[1] for pair in batch_pairs]
encoded = tokenizer(
batch_queries,
batch_passages,
padding=True,
truncation=True,
max_length=max_length,
return_tensors='pt'
)
# Move to device
encoded = {k: v.to(self.device) for k, v in encoded.items()}
# Get scores
with torch.no_grad():
outputs = self.forward(**encoded, return_dict=True)
scores = outputs.scores # type: ignore[attr-defined]
all_scores.append(scores)
if show_progress_bar:
pbar.update(len(batch_pairs))
if show_progress_bar:
pbar.close()
# Concatenate scores
all_scores = torch.cat(all_scores, dim=0)
# Normalize if requested
if normalize_scores:
all_scores = (all_scores - all_scores.min()) / (all_scores.max() - all_scores.min() + 1e-8)
# Convert format
if convert_to_numpy:
all_scores = all_scores.cpu().numpy()
else:
all_scores = all_scores.cpu()
# Return a scalar when only a single query-passage pair was scored
if len(pairs) == 1:
return all_scores.item() if convert_to_numpy else all_scores  # type: ignore[attr-defined]
return all_scores
def rerank(
self,
query: str,
passages: List[str],
top_k: Optional[int] = None,
return_scores: bool = False,
batch_size: int = 32,
max_length: int = 512
) -> Union[List[str], Tuple[List[str], List[float]]]:
"""
Rerank passages for a query
Args:
query: Query string
passages: List of passages to rerank
top_k: Return only top-k passages
return_scores: Also return scores
batch_size: Batch size for scoring
max_length: Maximum sequence length
Returns:
Reranked passages (and optionally scores)
"""
# Get scores
scores = self.predict(
query,
passages,
batch_size=batch_size,
max_length=max_length,
convert_to_numpy=True
)
# Sort by scores (descending)
sorted_indices = np.argsort(scores)[::-1]
# Apply top_k if specified
if top_k is not None:
sorted_indices = sorted_indices[:top_k]
# Get reranked passages
reranked_passages = [passages[i] for i in sorted_indices]
if return_scores:
reranked_scores = [float(scores[i]) for i in sorted_indices]
return reranked_passages, reranked_scores
return reranked_passages
def save_pretrained(self, save_path: str):
"""Save model to directory, including custom heads and config."""
try:
import os
import json
os.makedirs(save_path, exist_ok=True)
# Save the base model
self.model.save_pretrained(save_path)
# Save additional configuration (including [reranker] section)
from utils.config_loader import get_config
config_dict = get_config().get_section('reranker')
config = {
'use_fp16': self.use_fp16 if hasattr(self, 'use_fp16') else False,
'num_labels': getattr(self, 'num_labels', 1),
'reranker_config': config_dict
}
with open(f"{save_path}/reranker_config.json", 'w') as f:
json.dump(config, f, indent=2)
# Save custom heads if present
if hasattr(self, 'classifier'):
torch.save(self.classifier.state_dict(), os.path.join(save_path, 'classifier.pt'))
if hasattr(self, 'dropout'):
torch.save(self.dropout.state_dict(), os.path.join(save_path, 'dropout.pt'))
# Save tokenizer if available (loaded into self.tokenizer in __init__)
if getattr(self, 'tokenizer', None) is not None:
self.tokenizer.save_pretrained(save_path)
logger.info(f"Reranker model and config saved to {save_path}")
except Exception as e:
logger.error(f"Failed to save reranker model: {e}")
raise
@classmethod
def from_pretrained(cls, model_path: str, **kwargs):
"""Load model from directory, restoring custom heads and config."""
import os
import json
try:
# Load config
config_path = os.path.join(model_path, "reranker_config.json")
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config = json.load(f)
else:
config = {}
# Instantiate model
model = cls(
model_name_or_path=model_path,
num_labels=config.get('num_labels', 1),
use_fp16=config.get('use_fp16', False),
**kwargs
)
# Load base model weights
model.model = model.model.from_pretrained(model_path)
# Load custom heads if present
classifier_path = os.path.join(model_path, 'classifier.pt')
if hasattr(model, 'classifier') and os.path.exists(classifier_path):
model.classifier.load_state_dict(torch.load(classifier_path, map_location='cpu'))
dropout_path = os.path.join(model_path, 'dropout.pt')
if hasattr(model, 'dropout') and os.path.exists(dropout_path):
model.dropout.load_state_dict(torch.load(dropout_path, map_location='cpu'))
logger.info(f"Loaded BGE-Reranker model from {model_path}")
return model
except Exception as e:
logger.error(f"Failed to load reranker model from {model_path}: {e}")
raise
def compute_loss(
self,
scores: torch.Tensor,
labels: torch.Tensor,
loss_type: str = "binary_cross_entropy"
) -> torch.Tensor:
"""
Compute loss for training
Args:
scores: Model scores
labels: Ground truth labels
loss_type: Type of loss function
Returns:
Loss value
"""
if loss_type == "binary_cross_entropy":
return F.binary_cross_entropy(scores, labels.float())
elif loss_type == "cross_entropy":
return F.cross_entropy(scores, labels.long())
elif loss_type == "mse":
return F.mse_loss(scores, labels.float())
else:
raise ValueError(f"Unknown loss type: {loss_type}")

507
models/losses.py Normal file
View File

@ -0,0 +1,507 @@
"""
Loss functions for BGE fine-tuning
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import numpy as np
class InfoNCELoss(nn.Module):
"""
InfoNCE loss for contrastive learning
As used in BGE-M3 paper with temperature τ=0.02
Notes:
- All parameters (temperature, reduction, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(self, temperature: float = 0.02, reduction: str = 'mean'):
super().__init__()
self.temperature = temperature
self.reduction = reduction
self.cross_entropy = nn.CrossEntropyLoss(reduction=reduction)
def forward(
self,
query_embeddings: torch.Tensor,
positive_embeddings: torch.Tensor,
negative_embeddings: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute InfoNCE loss with proper normalization and numerical stability.
Args:
query_embeddings: [batch_size, embedding_dim]
positive_embeddings: [batch_size, embedding_dim]
negative_embeddings: [batch_size, num_negatives, embedding_dim] or None
"""
batch_size = query_embeddings.size(0)
device = query_embeddings.device
# Normalize embeddings (L2 normalization)
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
positive_embeddings = F.normalize(positive_embeddings, p=2, dim=1)
# Compute positive scores (query-positive similarity)
# Shape: [batch_size]
pos_scores = torch.sum(query_embeddings * positive_embeddings, dim=1)
pos_logits = pos_scores / self.temperature
# Compute negative scores if provided
if negative_embeddings is not None and negative_embeddings.numel() > 0:
# Normalize negative embeddings
negative_embeddings = F.normalize(negative_embeddings, p=2, dim=-1)
# Compute negative scores
# Shape: [batch_size, num_negatives]
neg_scores = torch.matmul(
query_embeddings.unsqueeze(1),
negative_embeddings.transpose(1, 2)
).squeeze(1)
neg_logits = neg_scores / self.temperature
# Concatenate positive and negative logits
# Shape: [batch_size, 1 + num_negatives]
logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
# Labels: positive is always at index 0
labels = torch.zeros(batch_size, device=device, dtype=torch.long)
else:
# Use in-batch negatives as fallback
# Compute similarity matrix for in-batch negatives
similarity_matrix = torch.matmul(query_embeddings, positive_embeddings.t())
logits = similarity_matrix / self.temperature
# Labels: positive pairs are on the diagonal
labels = torch.arange(batch_size, device=device, dtype=torch.long)
# Compute cross-entropy loss
loss = self.cross_entropy(logits, labels)
return loss
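# Usage sketch for InfoNCELoss (shapes are illustrative assumptions; comments only, nothing runs on import):
#   loss_fn = InfoNCELoss(temperature=0.02)
#   q = torch.randn(16, 768)       # query embeddings
#   p = torch.randn(16, 768)       # positive passage embeddings
#   n = torch.randn(16, 7, 768)    # 7 hard negatives per query (optional)
#   loss = loss_fn(q, p, n)        # omit n to fall back to in-batch negatives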
class M3KnowledgeDistillationLoss(nn.Module):
"""
Knowledge distillation loss for BGE-M3's multi-functionality training
Distills knowledge from dense, sparse, and multi-vector representations
Notes:
- All parameters (temperature, alpha, distill_types, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
temperature: float = 4.0,
alpha: float = 0.5,
distill_types: Optional[List[str]] = None
):
super().__init__()
self.temperature = temperature
self.alpha = alpha
# Avoid a mutable default argument; default to all three representation types
self.distill_types = distill_types if distill_types is not None else ['dense', 'sparse', 'colbert']
self.kl_div = nn.KLDivLoss(reduction='batchmean')
def forward(
self,
student_outputs: Dict[str, torch.Tensor],
teacher_outputs: Dict[str, torch.Tensor],
labels: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Args:
student_outputs: Dictionary with 'dense', 'sparse', 'colbert' outputs
teacher_outputs: Dictionary with teacher model outputs
labels: Optional labels for supervised loss
Returns:
total_loss: Combined distillation loss
loss_dict: Dictionary of individual losses
"""
total_loss = torch.tensor(0.0, device=next(iter(student_outputs.values())).device if student_outputs else 'cpu')
loss_dict = {}
# Distillation loss for each representation type
for repr_type in self.distill_types:
if repr_type in student_outputs and repr_type in teacher_outputs:
student_logits = student_outputs[repr_type]
teacher_logits = teacher_outputs[repr_type]
# Apply temperature scaling
student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
# KL divergence loss
kl_loss = self.kl_div(student_log_probs, teacher_probs) * (self.temperature ** 2)
total_loss = total_loss + self.alpha * kl_loss
loss_dict[f'{repr_type}_distill_loss'] = kl_loss.detach().cpu().item()
# Add supervised loss if labels provided
if labels is not None and 'dense' in student_outputs:
ce_loss = F.cross_entropy(student_outputs['dense'], labels)
total_loss = total_loss + (1 - self.alpha) * ce_loss
loss_dict['supervised_loss'] = ce_loss.detach().cpu().item()
loss_dict['total_loss'] = total_loss.detach().cpu().item()
return total_loss, loss_dict
class ListwiseCrossEntropyLoss(nn.Module):
"""
Listwise cross-entropy loss for reranker training
Notes:
- All parameters (temperature, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(self, temperature: float = 1.0):
super().__init__()
self.temperature = temperature
def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
"""
Compute listwise cross-entropy loss with proper normalization.
Args:
scores: Relevance scores [total_passages]
labels: Binary relevance labels [total_passages]
groups: List of group sizes (passages per query)
"""
total_loss = torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
offset = 0
for group_size in groups:
if group_size == 0:
continue
# Get scores and labels for this group
group_scores = scores[offset:offset + group_size]
group_labels = labels[offset:offset + group_size]
# Skip if no positive labels
if group_labels.sum() == 0:
offset += group_size
continue
# Apply temperature scaling
scaled_scores = group_scores / self.temperature
# Apply log_softmax for numerical stability
log_probs = F.log_softmax(scaled_scores, dim=0)
# Normalize labels to create a probability distribution
label_probs = group_labels.float() / group_labels.sum()
# Compute cross-entropy: -sum(p_true * log(p_pred))
loss = -(label_probs * log_probs).sum()
total_loss = total_loss + loss
offset += group_size
# Average over number of groups
num_valid_groups = sum(1 for g in groups if g > 0)
if num_valid_groups > 0:
return total_loss / num_valid_groups
else:
return total_loss
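# Usage sketch for ListwiseCrossEntropyLoss (values are illustrative assumptions; comments only):
#   loss_fn = ListwiseCrossEntropyLoss(temperature=1.0)
#   scores = torch.tensor([2.1, 0.3, -0.5, 1.7, 0.2])  # 5 passages total
#   labels = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0])   # binary relevance
#   groups = [3, 2]                                     # 3 passages for query 1, 2 for query 2
#   loss = loss_fn(scores, labels, groups)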
class PairwiseMarginLoss(nn.Module):
"""
Pairwise margin ranking loss for reranker
Notes:
- All parameters (margin, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(self, margin: float = 1.0):
super().__init__()
self.margin = margin
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
"""
Args:
scores: Relevance scores [total_passages]
labels: Binary relevance labels [total_passages]
groups: List of group sizes
"""
total_loss = torch.tensor(0.0, device=scores.device)
num_pairs = 0
offset = 0
for group_size in groups:
group_scores = scores[offset:offset + group_size]
group_labels = labels[offset:offset + group_size]
# Find positive and negative indices
pos_indices = torch.where(group_labels == 1)[0]
neg_indices = torch.where(group_labels == 0)[0]
if len(pos_indices) > 0 and len(neg_indices) > 0:
# Create all positive-negative pairs
for pos_idx in pos_indices:
for neg_idx in neg_indices:
pos_score = group_scores[pos_idx].unsqueeze(0)
neg_score = group_scores[neg_idx].unsqueeze(0)
target = torch.ones(1, device=scores.device)
loss = self.ranking_loss(pos_score, neg_score, target)
total_loss = total_loss + loss
num_pairs += 1
offset += group_size
return total_loss / max(num_pairs, 1)
class DynamicListwiseDistillationLoss(nn.Module):
"""
Dynamic listwise distillation loss for RocketQAv2-style joint training
Notes:
- All parameters (temperature, dynamic_weight, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(self, temperature: float = 3.0, dynamic_weight: bool = True):
super().__init__()
self.temperature = temperature
self.dynamic_weight = dynamic_weight
self.kl_div = nn.KLDivLoss(reduction='none')
def forward(
self,
student_scores: torch.Tensor,
teacher_scores: torch.Tensor,
groups: List[int],
confidence_scores: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
student_scores: Student model scores [total_passages]
teacher_scores: Teacher model scores [total_passages]
groups: List of group sizes
confidence_scores: Optional confidence for dynamic weighting
"""
total_loss = torch.tensor(0.0, device=student_scores.device)
offset = 0
for i, group_size in enumerate(groups):
student_group = student_scores[offset:offset + group_size] / self.temperature
teacher_group = teacher_scores[offset:offset + group_size] / self.temperature
# Convert to distributions (student in log space, as required by KLDivLoss)
student_log_probs = F.log_softmax(student_group, dim=0)
teacher_probs = F.softmax(teacher_group, dim=0)
# Compute KL divergence
kl_loss = self.kl_div(student_log_probs, teacher_probs).sum()
# Apply dynamic weighting based on confidence
if self.dynamic_weight and confidence_scores is not None:
weight = confidence_scores[i]
kl_loss = kl_loss * weight
total_loss = total_loss + kl_loss * (self.temperature ** 2)
offset += group_size
return total_loss / max(len(groups), 1)
class MultiTaskLoss(nn.Module):
"""
Multi-task loss for combining multiple objectives
Notes:
- All parameters (loss_weights, adaptive_weights, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
loss_weights: Optional[Dict[str, float]] = None,
adaptive_weights: bool = False
):
super().__init__()
self.loss_weights = loss_weights or {
'dense': 1.0,
'sparse': 0.3,
'colbert': 0.5,
'distillation': 1.0
}
self.adaptive_weights = adaptive_weights
if adaptive_weights:
# Learnable loss weights (uncertainty weighting)
self.log_vars = nn.Parameter(torch.zeros(len(self.loss_weights)))
def forward(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Args:
losses: Dictionary of individual losses
Returns:
total_loss: Combined loss
loss_dict: Dictionary of individual loss values
"""
total_loss = torch.tensor(0.0, device=next(iter(losses.values())).device if losses else 'cpu')
loss_dict = {}
if self.adaptive_weights:
# Uncertainty-based weighting
for i, (loss_name, loss_value) in enumerate(losses.items()):
if loss_name in self.loss_weights:
precision = torch.exp(-self.log_vars[i])
weighted_loss = precision * loss_value + self.log_vars[i]
total_loss = total_loss + weighted_loss
loss_dict[f'{loss_name}_loss'] = loss_value.detach().cpu().item() if isinstance(loss_value, torch.Tensor) else float(loss_value)
loss_dict[f'{loss_name}_weight'] = precision.detach().cpu().item() if isinstance(precision, torch.Tensor) else float(precision)
else:
# Fixed weighting
for loss_name, loss_value in losses.items():
if loss_name in self.loss_weights:
weight = self.loss_weights[loss_name]
weighted_loss = weight * loss_value
total_loss = total_loss + weighted_loss
loss_dict[f'{loss_name}_loss'] = loss_value.detach().cpu().item() if isinstance(loss_value, torch.Tensor) else float(loss_value)
loss_dict[f'{loss_name}_weight'] = weight
loss_dict['total_loss'] = total_loss.detach().cpu().item() if isinstance(total_loss, torch.Tensor) else float(total_loss)
return total_loss, loss_dict
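# Usage sketch for MultiTaskLoss (the individual loss tensors are assumptions; comments only):
#   mt_loss = MultiTaskLoss(loss_weights={'dense': 1.0, 'sparse': 0.3, 'colbert': 0.5})
#   total, loss_dict = mt_loss({'dense': dense_loss, 'sparse': sparse_loss, 'colbert': colbert_loss})
#   # loss_dict holds the per-task loss values and the weights actually applied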
class SparseLoss(nn.Module):
"""
SPLADE+ sparse loss with log1p(ReLU(.)) and FLOPS regularization.
For BGE-M3's token-position based sparse representations.
Args:
flops_loss_weight: Weight for FLOPS regularization
sparse_reg_weight: Weight for sparsity regularization
temperature: Temperature for similarity computation
"""
def __init__(
self,
flops_loss_weight: float = 1e-3,
sparse_reg_weight: float = 1e-4,
temperature: float = 0.02
):
super().__init__()
self.flops_loss_weight = flops_loss_weight
self.sparse_reg_weight = sparse_reg_weight
self.temperature = temperature
self.cross_entropy = nn.CrossEntropyLoss()
def forward(
self,
query_sparse: torch.Tensor,
doc_sparse: torch.Tensor,
labels: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Args:
query_sparse: [batch, query_seq_len] - Query sparse embeddings (token-position based)
doc_sparse: [batch, doc_seq_len] - Document sparse embeddings (token-position based)
labels: [batch] - Labels for contrastive learning
Returns:
Tuple of (loss, loss_dict)
"""
batch_size = query_sparse.size(0)
device = query_sparse.device
# SPLADE+ uses log1p(ReLU(.))
query_splade = torch.log1p(F.relu(query_sparse))
doc_splade = torch.log1p(F.relu(doc_sparse))
# For BGE-M3's token-position sparse embeddings, compute similarity using aggregated scores
# Since we have different sequence lengths, we need to aggregate to comparable representations
# Sum pooling to get document-level sparse scores
query_scores = query_splade.sum(dim=1, keepdim=True) # [batch, 1]
doc_scores = doc_splade.sum(dim=1, keepdim=True) # [batch, 1]
# Build a [batch, batch] similarity matrix for in-batch contrastive learning
# from the cosine similarity of the aggregated sparse scores
query_norm = F.normalize(query_scores, p=2, dim=1)
doc_norm = F.normalize(doc_scores, p=2, dim=1)
similarity_matrix = torch.matmul(query_norm, doc_norm.t())
# Scale by temperature for contrastive loss
logits = similarity_matrix / self.temperature
# Labels for in-batch contrastive learning (diagonal elements are positive pairs)
contrastive_labels = torch.arange(batch_size, device=device, dtype=torch.long)
# Contrastive loss
main_loss = self.cross_entropy(logits, contrastive_labels)
# FLOPS regularization (encourage sparsity)
flops = (query_splade.sum(dim=-1) + doc_splade.sum(dim=-1)).mean()
# L1 sparsity regularization
sparsity_loss = (query_splade.sum() + doc_splade.sum()) / (query_splade.numel() + doc_splade.numel())
# Combine losses
total_loss = main_loss + self.flops_loss_weight * flops + self.sparse_reg_weight * sparsity_loss
# Ensure all losses are tensors
if not isinstance(total_loss, torch.Tensor):
total_loss = torch.tensor(total_loss, device=device)
if not isinstance(main_loss, torch.Tensor):
main_loss = torch.tensor(main_loss, device=device)
if not isinstance(flops, torch.Tensor):
flops = torch.tensor(flops, device=device)
if not isinstance(sparsity_loss, torch.Tensor):
sparsity_loss = torch.tensor(sparsity_loss, device=device)
loss_dict = {
'main_loss': main_loss.detach().cpu().item(),
'flops': flops.detach().cpu().item(),
'sparsity_loss': sparsity_loss.detach().cpu().item(),
'total_loss': total_loss.detach().cpu().item()
}
return total_loss, loss_dict
class ColBERTLoss(nn.Module):
"""
ColBERT late interaction (maxsim) loss.
Args:
temperature: Softmax temperature
"""
def __init__(self, temperature: float = 0.02):
super().__init__()
self.temperature = temperature
def forward(
self,
query_embeds: torch.Tensor,
doc_embeds: torch.Tensor,
labels: torch.Tensor
) -> torch.Tensor:
"""
Args:
query_embeds: [batch, seq_len, dim]
doc_embeds: [batch, seq_len, dim]
labels: [batch]
Returns:
Loss value
"""
# Late interaction: maxsim over doc tokens for each query token
# [batch, query_len, doc_len]
sim_matrix = torch.einsum('bqd,bkd->bqk', query_embeds, doc_embeds) / self.temperature
maxsim = sim_matrix.max(dim=2)[0] # [batch, query_len]
# Aggregate over query tokens (mean)
scores = maxsim.mean(dim=1) # [batch]
# Binary cross-entropy loss
loss = F.binary_cross_entropy_with_logits(scores, labels.float())
return loss
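# Usage sketch for ColBERTLoss (shapes are illustrative assumptions; comments only, nothing runs on import):
#   colbert_loss = ColBERTLoss(temperature=0.02)
#   q = torch.randn(8, 32, 128)    # [batch, query_len, dim] token-level query embeddings
#   d = torch.randn(8, 160, 128)   # [batch, doc_len, dim] token-level passage embeddings
#   labels = torch.randint(0, 2, (8,))
#   loss = colbert_loss(q, d, labels)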

1627
readme.md Normal file

File diff suppressed because it is too large Load Diff

154
readme_new.md Normal file
View File

@ -0,0 +1,154 @@
# BGE Fine-tuning for Chinese Book Retrieval
Robust, production-ready tooling to fine-tune the BAAI **BGE-M3** embedding model and **BGE-Reranker** cross-encoder for high-performance retrieval of Chinese book content (with English query support).
---
## Table of Contents
1. Overview
2. Requirements & Compatibility
3. Core Features
4. Installation
5. Quick Start
6. Configuration System
7. Training & Evaluation
8. Distributed & Ascend NPU Support
9. Troubleshooting & FAQ
10. References
11. Roadmap
---
## 1. Overview
* Unified repository for dense, sparse (SPLADE-style) and multi-vector (ColBERT-style) retrieval.
* Joint training (RocketQAv2) and ANCE-style hard-negative mining included.
* Built-in evaluation, benchmarking and statistical significance testing.
---
## 2. Requirements & Compatibility
* **Python ≥ 3.10**
* **PyTorch ≥ 2.2.0** (nightly recommended for CUDA 12.8+)
* **CUDA ≥ 12.8** *(RTX 5090 / A100 / H100)* or **Torch-NPU** *(Ascend 910 / 910B)*
See `requirements.txt` for the full dependency list.
---
## 3. Core Features
* **InfoNCE (τ = 0.02) & Listwise CE** losses with self-distillation.
* **ANCE-style Hard-Negative Mining** with FAISS index refresh.
* **Atomic Checkpointing** & robust failure recovery.
* **Automatic Mixed Precision** (fp16 / bf16) & gradient checkpointing.
* **Comprehensive Metrics**: Recall@k, MRR@k, NDCG@k, MAP, Success-Rate.
* **Cross-hardware Support**: CUDA (NCCL), Ascend NPU (HCCL), CPU (Gloo).
---
## 4. Installation
```bash
# 1. Install PyTorch nightly (CUDA 12.8)
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
# 2. Install project dependencies
pip install -r requirements.txt
# 3. (Optional) Ascend NPU backend
pip install torch_npu # then edit config.toml [ascend]
```
Models auto-download on first use, or fetch manually:
```bash
huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3
huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base
```
---
## 5. Quick Start
### Fine-tune BGE-M3 (Embedding)
```bash
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir output/bge-m3 \
--per_device_train_batch_size 8 \
--learning_rate 1e-5
```
### Fine-tune BGE-Reranker (Cross-encoder)
```bash
python scripts/train_reranker.py \
--train_data data/train.jsonl \
--output_dir output/bge-reranker \
--learning_rate 6e-5 \
--train_group_size 16
```
### Joint Training (RocketQAv2-style)
```bash
python scripts/train_joint.py \
--train_data data/train.jsonl \
--output_dir output/joint-training \
--bf16 --gradient_checkpointing
```
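After training, the checkpoints can be used directly from Python. The sketch below is illustrative only: the output path and the example query are assumptions; the method names follow `models/bge_reranker.py` in this repository.
```python
from models.bge_reranker import BGERerankerModel

# Load a fine-tuned cross-encoder checkpoint (path is an example)
reranker = BGERerankerModel.from_pretrained("output/bge-reranker")

query = "三国演义的作者是谁？"
passages = ["《三国演义》由罗贯中所著。", "《红楼梦》是曹雪芹的作品。"]

# Returns passages sorted by descending relevance, optionally with scores
ranked, scores = reranker.rerank(query, passages, top_k=2, return_scores=True)
```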
---
## 6. Configuration System
Configuration is **TOML-driven** with clear precedence:
1. Command-line → 2. Model (`[m3]` / `[reranker]`) → 3. Hardware (`[hardware]` / `[ascend]`) → 4. Training defaults (`[training]`) → 5. Dataclass field defaults.
Load & validate:
```python
from config.base_config import BaseConfig
cfg = BaseConfig.from_toml("config.toml")
```
Key sections: `[training]`, `[hardware]`, `[distributed]`, `[m3]`, `[reranker]`, `[augmentation]`, `[checkpoint]`, `[logging]`.
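Inside the code base, individual sections are read through the shared loader. A minimal sketch (call signatures follow their use elsewhere in this repository; the section and key names are examples):
```python
from utils.config_loader import get_config

config = get_config()
m3_cfg = config.get_section("m3")                             # dict for the [m3] section
source = config.get("model_paths", "source", "huggingface")   # key lookup with a default
```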
---
## 7. Training & Evaluation
* **Scripts**: `train_m3.py`, `train_reranker.py`, `train_joint.py`
* **Evaluation**: `scripts/evaluate.py` supports retriever, reranker or full pipeline.
* **Benchmarking**: `scripts/benchmark.py` for latency / throughput; `compare_*` for baseline comparisons.
Example evaluation:
```bash
python scripts/evaluate.py \
--retriever_model output/bge-m3 \
--eval_data data/test.jsonl \
--k_values 1 5 10
```
---
## 8. Distributed & Ascend NPU Support
### Multi-GPU (NVIDIA)
```bash
torchrun --nproc_per_node=4 scripts/train_m3.py --train_data data/train.jsonl --output_dir output/distributed
```
The backend auto-selects **NCCL**. Use the `[distributed]` section in `config.toml` for fine-grained control of distributed settings.
### Huawei Ascend NPU
```bash
pip install torch_npu
python scripts/train_m3.py --device npu --train_data data/train.jsonl --output_dir output/ascend
```
Backend auto-selects **HCCL**.
---
## 9. Troubleshooting & FAQ
* **CUDA mismatch**: install the matching PyTorch + CUDA build (`nvidia-smi` to verify).
* **Ascend setup**: ensure `torch_npu` and `[ascend]` config.
* **NCCL errors**: set `NCCL_DEBUG=INFO` and disable IB if needed (`NCCL_IB_DISABLE=1`).
* **Memory issues**: enable `--gradient_checkpointing` & mixed precision.
---
## 10. References
1. BGE-M3 Embeddings [arXiv:2402.03216](https://arxiv.org/abs/2402.03216)
2. RocketQAv2 EMNLP 2021
3. ANCE [arXiv:2007.00808](https://arxiv.org/abs/2007.00808)
4. ColBERT SIGIR 2020
5. SPLADE [arXiv:2109.10086](https://arxiv.org/abs/2109.10086)

49
requirements.txt Normal file
View File

@ -0,0 +1,49 @@
# Core dependencies
# IMPORTANT: Install PyTorch with CUDA 12.8 support first using:
# pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
# Then install remaining dependencies with: pip install -r requirements.txt
# PyTorch - version will be satisfied by pre-installed nightly build
# Minimum version for CUDA 12.8 support (will use nightly/2.3.0+ in practice)
torch>=2.2.0
torchvision>=0.17.0
torchaudio>=2.2.0
# HuggingFace Transformers
transformers>=4.38.0
# HuggingFace Datasets
datasets>=2.14.0
# HuggingFace Tokenizers
tokenizers>=0.15.0
numpy>=1.23.0
scipy>=1.10.0
scikit-learn>=1.3.0
pandas>=1.5.0
FlagEmbedding>=1.2.0
# For CPU use faiss-cpu
faiss-gpu>=1.7.4
psutil>=5.9.0
omegaconf>=2.3.0
toml>=0.10.2
tqdm>=4.64.0
pyyaml>=6.0
jsonlines>=3.1.0
huggingface-hub>=0.19.0
jieba>=0.42.1
requests>=2.28.0
# Optional/dev dependencies
# For Huawei Ascend NPU support (install only if needed)
# torch-npu>=1.11.0.post1; platform_system=="Linux" and platform_machine=="aarch64"
# ascend-compiler; platform_system=="Linux" and platform_machine=="aarch64"
# For experiment tracking (optional)
wandb # optional, for Weights & Biases logging
# For visualization (optional)
tensorboard # optional, for TensorBoard logging
# For testing and development
pytest # optional, for running tests
# For ModelScope support (optional)
modelscope # optional, for ModelScope model hub
sentence-transformers>=2.2.2
accelerate>=0.25.0
rank-bm25>=0.2.2
# For logging and distributed
loguru>=0.7.0

144
scripts/benchmark.py Normal file
View File

@ -0,0 +1,144 @@
import os
import time
import argparse
import logging
import torch
import psutil
import tracemalloc
import numpy as np
from utils.config_loader import get_config
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("benchmark")
def parse_args():
"""Parse command line arguments for benchmarking."""
parser = argparse.ArgumentParser(description="Benchmark BGE fine-tuned models.")
parser.add_argument('--model_type', type=str, choices=['retriever', 'reranker'], required=True, help='Model type to benchmark')
parser.add_argument('--model_path', type=str, required=True, help='Path to fine-tuned model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
parser.add_argument('--output', type=str, default='benchmark_results.txt', help='File to save benchmark results')
return parser.parse_args()
def measure_memory():
"""Measure the current memory usage of the process."""
process = psutil.Process(os.getpid())
return process.memory_info().rss / (1024 * 1024) # MB
def benchmark_retriever(model_path, data_path, batch_size, max_samples, device):
"""Benchmark the BGE retriever model."""
logger.info(f"Loading retriever model from {model_path}")
model = BGEM3Model(model_name_or_path=model_path, device=device)
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
logger.info(f"Loaded {len(dataset)} samples for benchmarking")
dataloader = DataLoader(dataset, batch_size=batch_size)
model.eval()
model.to(device)
n_samples = 0
times = []
tracemalloc.start()
with torch.no_grad():
for batch in dataloader:
if n_samples >= max_samples:
break
start = time.time()
# Only dense embedding for speed
_ = model(
batch['query_input_ids'].to(device),
batch['query_attention_mask'].to(device),
return_dense=True, return_sparse=False, return_colbert=False
)
if device.startswith('cuda'):
torch.cuda.synchronize()
end = time.time()
times.append(end - start)
n_samples += batch['query_input_ids'].shape[0]
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
throughput = n_samples / sum(times)
latency = np.mean(times) / batch_size
mem_mb = peak / (1024 * 1024)
logger.info(f"Retriever throughput: {throughput:.2f} samples/sec, latency: {latency*1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
return throughput, latency, mem_mb
def benchmark_reranker(model_path, data_path, batch_size, max_samples, device):
"""Benchmark the BGE reranker model."""
logger.info(f"Loading reranker model from {model_path}")
model = BGERerankerModel(model_name_or_path=model_path, device=device)
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
logger.info(f"Loaded {len(dataset)} samples for benchmarking")
dataloader = DataLoader(dataset, batch_size=batch_size)
model.eval()
model.to(device)
n_samples = 0
times = []
tracemalloc.start()
with torch.no_grad():
for batch in dataloader:
if n_samples >= max_samples:
break
start = time.time()
_ = model.model(
input_ids=batch['input_ids'].to(device),
attention_mask=batch['attention_mask'].to(device)
)
if device.startswith('cuda'):
torch.cuda.synchronize()
end = time.time()
times.append(end - start)
n_samples += batch['input_ids'].shape[0]
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
throughput = n_samples / sum(times)
latency = np.mean(times) / batch_size
mem_mb = peak / (1024 * 1024)
logger.info(f"Reranker throughput: {throughput:.2f} samples/sec, latency: {latency*1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
return throughput, latency, mem_mb
def main():
"""Run benchmark with robust error handling and config-driven defaults."""
parser = argparse.ArgumentParser(description="Benchmark BGE models.")
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
parser.add_argument('--model_type', type=str, choices=['retriever', 'reranker'], required=True, help='Model type to benchmark')
parser.add_argument('--model_path', type=str, required=True, help='Path to fine-tuned model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--output', type=str, default='benchmark_results.txt', help='File to save benchmark results')
args = parser.parse_args()
# Load config and set defaults
config = get_config()
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
try:
if args.model_type == 'retriever':
throughput, latency, mem_mb = benchmark_retriever(
args.model_path, args.data_path, args.batch_size, args.max_samples, device)
else:
throughput, latency, mem_mb = benchmark_reranker(
args.model_path, args.data_path, args.batch_size, args.max_samples, device)
# Save results
with open(args.output, 'w') as f:
f.write(f"Model: {args.model_path}\n")
f.write(f"Data: {args.data_path}\n")
f.write(f"Batch size: {args.batch_size}\n")
f.write(f"Throughput: {throughput:.2f} samples/sec\n")
f.write(f"Latency: {latency*1000:.2f} ms/sample\n")
f.write(f"Peak memory: {mem_mb:.2f} MB\n")
logger.info(f"Benchmark results saved to {args.output}")
logging.info(f"Benchmark completed on device: {device}")
except Exception as e:
logging.error(f"Benchmark failed: {e}")
raise
if __name__ == '__main__':
main()
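# Example invocation (paths are placeholders):
#   python scripts/benchmark.py --model_type retriever --model_path output/bge-m3 \
#       --data_path data/test.jsonl --batch_size 16 --output benchmark_results.txt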

138
scripts/compare_reranker.py Normal file
View File

@ -0,0 +1,138 @@
import os
import time
import argparse
import logging
import torch
import psutil
import tracemalloc
import numpy as np
from utils.config_loader import get_config
from models.bge_reranker import BGERerankerModel
from data.dataset import BGERerankerDataset
from transformers import AutoTokenizer
from evaluation.metrics import compute_reranker_metrics
from torch.utils.data import DataLoader
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("compare_reranker")
def parse_args():
"""Parse command-line arguments for reranker comparison."""
parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline reranker models.")
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned reranker model')
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline reranker model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
parser.add_argument('--output', type=str, default='compare_reranker_results.txt', help='File to save comparison results')
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10], help='K values for metrics@k')
return parser.parse_args()
def measure_memory():
"""Measure the current memory usage of the process."""
process = psutil.Process(os.getpid())
return process.memory_info().rss / (1024 * 1024) # MB
def run_reranker(model_path, data_path, batch_size, max_samples, device):
"""
Run inference on a reranker model and measure throughput, latency, and memory.
Args:
model_path (str): Path to the reranker model.
data_path (str): Path to the evaluation data.
batch_size (int): Batch size for inference.
max_samples (int): Maximum number of samples to process.
device (str): Device to use (cuda, cpu, npu).
Returns:
tuple: (throughput, latency, mem_mb, all_scores, all_labels)
"""
model = BGERerankerModel(model_name_or_path=model_path, device=device)
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
dataloader = DataLoader(dataset, batch_size=batch_size)
model.eval()
model.to(device)
n_samples = 0
times = []
all_scores = []
all_labels = []
tracemalloc.start()
with torch.no_grad():
for batch in dataloader:
if n_samples >= max_samples:
break
start = time.time()
outputs = model.model(
input_ids=batch['input_ids'].to(device),
attention_mask=batch['attention_mask'].to(device)
)
if device.startswith('cuda'):
torch.cuda.synchronize()
end = time.time()
times.append(end - start)
n = batch['input_ids'].shape[0]
n_samples += n
# For metrics: collect scores and labels
scores = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
all_scores.append(scores.cpu().numpy())
all_labels.append(batch['labels'].cpu().numpy())
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
throughput = n_samples / sum(times)
latency = np.mean(times) / batch_size
mem_mb = peak / (1024 * 1024)
# Prepare for metrics
all_scores = np.concatenate(all_scores, axis=0)
all_labels = np.concatenate(all_labels, axis=0)
return throughput, latency, mem_mb, all_scores, all_labels
def main():
"""Run reranker comparison with robust error handling and config-driven defaults."""
parser = argparse.ArgumentParser(description="Compare reranker models.")
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned reranker model')
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline reranker model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--output', type=str, default='compare_reranker_results.txt', help='File to save comparison results')
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10], help='K values for metrics@k')
args = parser.parse_args()
# Load config and set defaults
config = get_config()
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
try:
# Fine-tuned model
logger.info("Running fine-tuned reranker...")
ft_throughput, ft_latency, ft_mem, ft_scores, ft_labels = run_reranker(
args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
ft_metrics = compute_reranker_metrics(ft_scores, ft_labels, k_values=args.k_values)
# Baseline model
logger.info("Running baseline reranker...")
bl_throughput, bl_latency, bl_mem, bl_scores, bl_labels = run_reranker(
args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
bl_metrics = compute_reranker_metrics(bl_scores, bl_labels, k_values=args.k_values)
# Output comparison
with open(args.output, 'w') as f:
f.write(f"Metric\tBaseline\tFine-tuned\tDelta\n")
for k in args.k_values:
for metric in [f'accuracy@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
bl = bl_metrics.get(metric, 0)
ft = ft_metrics.get(metric, 0)
delta = ft - bl
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
for metric in ['map', 'mrr']:
bl = bl_metrics.get(metric, 0)
ft = ft_metrics.get(metric, 0)
delta = ft - bl
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput-bl_throughput:+.2f}\n")
f.write(f"Latency (ms/sample)\t{bl_latency*1000:.2f}\t{ft_latency*1000:.2f}\t{(ft_latency-bl_latency)*1000:+.2f}\n")
f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem-bl_mem:+.2f}\n")
logger.info(f"Comparison results saved to {args.output}")
print(f"\nComparison complete. See {args.output} for details.")
except Exception as e:
logging.error(f"Reranker comparison failed: {e}")
raise
if __name__ == '__main__':
main()
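# Example invocation (paths are placeholders):
#   python scripts/compare_reranker.py --finetuned_model_path output/bge-reranker \
#       --baseline_model_path BAAI/bge-reranker-base --data_path data/test.jsonl \
#       --k_values 1 5 10 --output compare_reranker_results.txt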

144
scripts/compare_retriever.py Normal file
View File

@ -0,0 +1,144 @@
import os
import time
import argparse
import logging
import torch
import psutil
import tracemalloc
import numpy as np
from utils.config_loader import get_config
from models.bge_m3 import BGEM3Model
from data.dataset import BGEM3Dataset
from transformers import AutoTokenizer
from evaluation.metrics import compute_retrieval_metrics
from torch.utils.data import DataLoader
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("compare_retriever")
def parse_args():
"""Parse command line arguments for retriever comparison."""
parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline retriever models.")
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned retriever model')
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline retriever model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
parser.add_argument('--output', type=str, default='compare_retriever_results.txt', help='File to save comparison results')
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10, 20, 100], help='K values for metrics@k')
return parser.parse_args()
def measure_memory():
"""Measure the current process memory usage in MB."""
process = psutil.Process(os.getpid())
return process.memory_info().rss / (1024 * 1024) # MB
def run_retriever(model_path, data_path, batch_size, max_samples, device):
"""
Run inference on a retriever model and measure throughput, latency, and memory.
Args:
model_path (str): Path to the retriever model.
data_path (str): Path to the evaluation data.
batch_size (int): Batch size for inference.
max_samples (int): Maximum number of samples to process.
device (str): Device to use (cuda, cpu, npu).
Returns:
tuple: throughput (samples/sec), latency (ms/sample), peak memory (MB),
query_embeds (np.ndarray), passage_embeds (np.ndarray), labels (np.ndarray).
"""
model = BGEM3Model(model_name_or_path=model_path, device=device)
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
dataloader = DataLoader(dataset, batch_size=batch_size)
model.eval()
model.to(device)
n_samples = 0
times = []
query_embeds = []
passage_embeds = []
labels = []
tracemalloc.start()
with torch.no_grad():
for batch in dataloader:
if n_samples >= max_samples:
break
start = time.time()
out = model(
batch['query_input_ids'].to(device),
batch['query_attention_mask'].to(device),
return_dense=True, return_sparse=False, return_colbert=False
)
if device.startswith('cuda'):
torch.cuda.synchronize()
end = time.time()
times.append(end - start)
n = batch['query_input_ids'].shape[0]
n_samples += n
# For metrics: collect embeddings and labels
query_embeds.append(out['query_embeds'].cpu().numpy())
passage_embeds.append(out['passage_embeds'].cpu().numpy())
labels.append(batch['labels'].cpu().numpy())
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
throughput = n_samples / sum(times)
latency = np.mean(times) / batch_size
mem_mb = peak / (1024 * 1024)
# Prepare for metrics
query_embeds = np.concatenate(query_embeds, axis=0)
passage_embeds = np.concatenate(passage_embeds, axis=0)
labels = np.concatenate(labels, axis=0)
return throughput, latency, mem_mb, query_embeds, passage_embeds, labels
def main():
"""Run retriever comparison with robust error handling and config-driven defaults."""
parser = argparse.ArgumentParser(description="Compare retriever models.")
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned retriever model')
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline retriever model')
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
parser.add_argument('--output', type=str, default='compare_retriever_results.txt', help='File to save comparison results')
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10, 20, 100], help='K values for metrics@k')
args = parser.parse_args()
# Load config and set defaults
config = get_config()
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
try:
# Fine-tuned model
logger.info("Running fine-tuned retriever...")
ft_throughput, ft_latency, ft_mem, ft_qe, ft_pe, ft_labels = run_retriever(
args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
ft_metrics = compute_retrieval_metrics(ft_qe, ft_pe, ft_labels, k_values=args.k_values)
# Baseline model
logger.info("Running baseline retriever...")
bl_throughput, bl_latency, bl_mem, bl_qe, bl_pe, bl_labels = run_retriever(
args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
bl_metrics = compute_retrieval_metrics(bl_qe, bl_pe, bl_labels, k_values=args.k_values)
# Output comparison
with open(args.output, 'w') as f:
f.write(f"Metric\tBaseline\tFine-tuned\tDelta\n")
for k in args.k_values:
for metric in [f'recall@{k}', f'precision@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
bl = bl_metrics.get(metric, 0)
ft = ft_metrics.get(metric, 0)
delta = ft - bl
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
for metric in ['map', 'mrr']:
bl = bl_metrics.get(metric, 0)
ft = ft_metrics.get(metric, 0)
delta = ft - bl
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput-bl_throughput:+.2f}\n")
f.write(f"Latency (ms/sample)\t{bl_latency*1000:.2f}\t{ft_latency*1000:.2f}\t{(ft_latency-bl_latency)*1000:+.2f}\n")
f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem-bl_mem:+.2f}\n")
logger.info(f"Comparison results saved to {args.output}")
print(f"\nComparison complete. See {args.output} for details.")
except Exception as e:
logging.error(f"Retriever comparison failed: {e}")
raise
if __name__ == '__main__':
main()

990
scripts/comprehensive_validation.py Normal file
View File

@ -0,0 +1,990 @@
#!/usr/bin/env python3
"""
Comprehensive Validation Script for BGE Fine-tuned Models
This script provides a complete validation framework to determine if fine-tuned models
actually perform better than baseline models across multiple datasets and metrics.
Features:
- Tests both M3 retriever and reranker models
- Compares against baseline models on multiple datasets
- Provides statistical significance testing
- Tests pipeline performance (retrieval + reranking)
- Generates comprehensive reports
- Includes performance degradation detection
- Supports cross-dataset validation
Usage:
python scripts/comprehensive_validation.py \\
--retriever_finetuned ./output/bge-m3-enhanced/final_model \\
--reranker_finetuned ./output/bge-reranker/final_model \\
--output_dir ./validation_results
"""
import os
import sys
import json
import logging
import argparse
import pandas as pd
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from scipy import stats
import torch
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from evaluation.metrics import (
compute_retrieval_metrics, compute_reranker_metrics,
compute_retrieval_metrics_comprehensive
)
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
from data.preprocessing import DataPreprocessor
from utils.logging import setup_logging, get_logger
from utils.config_loader import get_config
@dataclass
class ValidationConfig:
"""Configuration for comprehensive validation"""
# Model paths
retriever_baseline: str = "BAAI/bge-m3"
reranker_baseline: str = "BAAI/bge-reranker-base"
retriever_finetuned: Optional[str] = None
reranker_finetuned: Optional[str] = None
# Test datasets
test_datasets: Optional[List[str]] = None
# Evaluation settings
batch_size: int = 32
max_length: int = 512
k_values: Optional[List[int]] = None
significance_level: float = 0.05
min_samples_for_significance: int = 30
# Pipeline settings
retrieval_top_k: int = 100
rerank_top_k: int = 10
# Output settings
output_dir: str = "./validation_results"
save_embeddings: bool = False
create_report: bool = True
def __post_init__(self):
if self.test_datasets is None:
self.test_datasets = [
"data/datasets/examples/embedding_data.jsonl",
"data/datasets/examples/reranker_data.jsonl",
"data/datasets/Reranker_AFQMC/dev.json",
"data/datasets/三国演义/reranker_train.jsonl"
]
if self.k_values is None:
self.k_values = [1, 3, 5, 10, 20]
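# The config can also be built programmatically (hypothetical paths shown), e.g.
#   cfg = ValidationConfig(retriever_finetuned="./output/bge-m3/final_model",
#                          test_datasets=["data/my_eval.jsonl"], k_values=[1, 5, 10])
#   ComprehensiveValidator(cfg).validate_all()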
class ComprehensiveValidator:
"""Comprehensive validation framework for BGE models"""
def __init__(self, config: ValidationConfig):
self.config = config
self.logger = get_logger(__name__)
self.results = {}
self.device = self._get_device()
# Create output directory
os.makedirs(config.output_dir, exist_ok=True)
# Setup logging
setup_logging(
output_dir=config.output_dir,
log_level=logging.INFO,
log_to_console=True,
log_to_file=True
)
def _get_device(self) -> torch.device:
"""Get optimal device for evaluation"""
try:
config = get_config()
device_type = config.get('hardware', {}).get('device', 'cuda')
if device_type == 'npu':
try:
import torch_npu
device = torch.device("npu:0")
self.logger.info("Using NPU device")
return device
except ImportError:
pass
if torch.cuda.is_available():
device = torch.device("cuda")
self.logger.info("Using CUDA device")
return device
except Exception as e:
self.logger.warning(f"Error detecting optimal device: {e}")
device = torch.device("cpu")
self.logger.info("Using CPU device")
return device
def validate_all(self) -> Dict[str, Any]:
"""Run comprehensive validation on all models and datasets"""
self.logger.info("🚀 Starting Comprehensive BGE Model Validation")
self.logger.info("=" * 80)
start_time = datetime.now()
try:
# 1. Validate retriever models
if self.config.retriever_finetuned:
self.logger.info("📊 Validating Retriever Models...")
retriever_results = self._validate_retrievers()
self.results['retriever'] = retriever_results
# 2. Validate reranker models
if self.config.reranker_finetuned:
self.logger.info("📊 Validating Reranker Models...")
reranker_results = self._validate_rerankers()
self.results['reranker'] = reranker_results
# 3. Validate pipeline (if both models available)
if self.config.retriever_finetuned and self.config.reranker_finetuned:
self.logger.info("📊 Validating Pipeline Performance...")
pipeline_results = self._validate_pipeline()
self.results['pipeline'] = pipeline_results
# 4. Generate comprehensive report
if self.config.create_report:
self._generate_report()
# 5. Statistical analysis
self._perform_statistical_analysis()
end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
self.logger.info("✅ Comprehensive validation completed!")
self.logger.info(f"⏱️ Total validation time: {duration:.2f} seconds")
return self.results
except Exception as e:
self.logger.error(f"❌ Validation failed: {e}", exc_info=True)
raise
def _validate_retrievers(self) -> Dict[str, Any]:
"""Validate retriever models on all suitable datasets"""
results = {
'baseline_metrics': {},
'finetuned_metrics': {},
'comparisons': {},
'datasets_tested': []
}
for dataset_path in self.config.test_datasets:
if not os.path.exists(dataset_path):
self.logger.warning(f"Dataset not found: {dataset_path}")
continue
# Skip the reranker-only AFQMC dataset when testing the retriever
if 'reranker' in dataset_path.lower() and 'AFQMC' in dataset_path:
continue
dataset_name = Path(dataset_path).stem
self.logger.info(f"Testing retriever on dataset: {dataset_name}")
try:
# Test baseline model
baseline_metrics = self._test_retriever_on_dataset(
model_path=self.config.retriever_baseline,
dataset_path=dataset_path,
is_baseline=True
)
# Test fine-tuned model
finetuned_metrics = self._test_retriever_on_dataset(
model_path=self.config.retriever_finetuned,
dataset_path=dataset_path,
is_baseline=False
)
# Compare results
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "retriever")
results['baseline_metrics'][dataset_name] = baseline_metrics
results['finetuned_metrics'][dataset_name] = finetuned_metrics
results['comparisons'][dataset_name] = comparison
results['datasets_tested'].append(dataset_name)
self.logger.info(f"✅ Retriever validation complete for {dataset_name}")
except Exception as e:
self.logger.error(f"❌ Failed to test retriever on {dataset_name}: {e}")
continue
return results
def _validate_rerankers(self) -> Dict[str, Any]:
"""Validate reranker models on all suitable datasets"""
results = {
'baseline_metrics': {},
'finetuned_metrics': {},
'comparisons': {},
'datasets_tested': []
}
for dataset_path in self.config.test_datasets:
if not os.path.exists(dataset_path):
self.logger.warning(f"Dataset not found: {dataset_path}")
continue
dataset_name = Path(dataset_path).stem
self.logger.info(f"Testing reranker on dataset: {dataset_name}")
try:
# Test baseline model
baseline_metrics = self._test_reranker_on_dataset(
model_path=self.config.reranker_baseline,
dataset_path=dataset_path,
is_baseline=True
)
# Test fine-tuned model
finetuned_metrics = self._test_reranker_on_dataset(
model_path=self.config.reranker_finetuned,
dataset_path=dataset_path,
is_baseline=False
)
# Compare results
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "reranker")
results['baseline_metrics'][dataset_name] = baseline_metrics
results['finetuned_metrics'][dataset_name] = finetuned_metrics
results['comparisons'][dataset_name] = comparison
results['datasets_tested'].append(dataset_name)
self.logger.info(f"✅ Reranker validation complete for {dataset_name}")
except Exception as e:
self.logger.error(f"❌ Failed to test reranker on {dataset_name}: {e}")
continue
return results
def _validate_pipeline(self) -> Dict[str, Any]:
"""Validate end-to-end retrieval + reranking pipeline"""
results = {
'baseline_pipeline': {},
'finetuned_pipeline': {},
'comparisons': {},
'datasets_tested': []
}
for dataset_path in self.config.test_datasets:
if not os.path.exists(dataset_path):
continue
# Only test on datasets suitable for pipeline evaluation
if 'embedding_data' not in dataset_path:
continue
dataset_name = Path(dataset_path).stem
self.logger.info(f"Testing pipeline on dataset: {dataset_name}")
try:
# Test baseline pipeline
baseline_metrics = self._test_pipeline_on_dataset(
retriever_path=self.config.retriever_baseline,
reranker_path=self.config.reranker_baseline,
dataset_path=dataset_path,
is_baseline=True
)
# Test fine-tuned pipeline
finetuned_metrics = self._test_pipeline_on_dataset(
retriever_path=self.config.retriever_finetuned,
reranker_path=self.config.reranker_finetuned,
dataset_path=dataset_path,
is_baseline=False
)
# Compare results
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "pipeline")
results['baseline_pipeline'][dataset_name] = baseline_metrics
results['finetuned_pipeline'][dataset_name] = finetuned_metrics
results['comparisons'][dataset_name] = comparison
results['datasets_tested'].append(dataset_name)
self.logger.info(f"✅ Pipeline validation complete for {dataset_name}")
except Exception as e:
self.logger.error(f"❌ Failed to test pipeline on {dataset_name}: {e}")
continue
return results
def _test_retriever_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
"""Test a retriever model on a specific dataset"""
try:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGEM3Model.from_pretrained(model_path)
model.to(self.device)
model.eval()
# Create dataset
dataset = BGEM3Dataset(
data_path=dataset_path,
tokenizer=tokenizer,
query_max_length=self.config.max_length,
passage_max_length=self.config.max_length,
is_train=False
)
dataloader = DataLoader(
dataset,
batch_size=self.config.batch_size,
shuffle=False
)
# Create evaluator
evaluator = RetrievalEvaluator(
model=model,
tokenizer=tokenizer,
device=self.device,
k_values=self.config.k_values
)
# Run evaluation
metrics = evaluator.evaluate(
dataloader=dataloader,
desc=f"{'Baseline' if is_baseline else 'Fine-tuned'} Retriever"
)
return metrics
except Exception as e:
self.logger.error(f"Error testing retriever: {e}")
return {}
def _test_reranker_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
"""Test a reranker model on a specific dataset"""
try:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGERerankerModel.from_pretrained(model_path)
model.to(self.device)
model.eval()
# Create dataset
dataset = BGERerankerDataset(
data_path=dataset_path,
tokenizer=tokenizer,
query_max_length=self.config.max_length,
passage_max_length=self.config.max_length,
is_train=False
)
dataloader = DataLoader(
dataset,
batch_size=self.config.batch_size,
shuffle=False
)
# Collect predictions and labels
all_scores = []
all_labels = []
with torch.no_grad():
for batch in dataloader:
# Move to device
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].cpu().numpy()
# Get predictions
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
scores = outputs.logits.cpu().numpy()
all_scores.append(scores)
all_labels.append(labels)
# Concatenate results
all_scores = np.concatenate(all_scores, axis=0)
all_labels = np.concatenate(all_labels, axis=0)
# Compute metrics
metrics = compute_reranker_metrics(
scores=all_scores,
labels=all_labels,
k_values=self.config.k_values
)
return metrics
except Exception as e:
self.logger.error(f"Error testing reranker: {e}")
return {}
def _test_pipeline_on_dataset(self, retriever_path: str, reranker_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
"""Test end-to-end pipeline on a dataset"""
try:
# Load models
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
retriever = BGEM3Model.from_pretrained(retriever_path)
retriever.to(self.device)
retriever.eval()
reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
reranker = BGERerankerModel.from_pretrained(reranker_path)
reranker.to(self.device)
reranker.eval()
# Load and preprocess data
data = []
with open(dataset_path, 'r', encoding='utf-8') as f:
for line in f:
try:
item = json.loads(line.strip())
if 'query' in item and 'pos' in item and 'neg' in item:
data.append(item)
except json.JSONDecodeError:
continue
if not data:
return {}
# Test pipeline performance
total_queries = 0
pipeline_metrics = {f'recall@{k}': [] for k in self.config.k_values}
pipeline_metrics.update({f'mrr@{k}': [] for k in self.config.k_values})
for item in data[:100]: # Limit to 100 queries for speed
query = item['query']
pos_docs = item['pos'][:2] # Take first 2 positive docs
neg_docs = item['neg'][:8] # Take first 8 negative docs
# Create candidate pool
candidates = pos_docs + neg_docs
labels = [1] * len(pos_docs) + [0] * len(neg_docs)
if len(candidates) < 3: # Need minimum candidates
continue
try:
# Step 1: Retrieval (encode query and candidates)
query_inputs = retriever_tokenizer(
query,
return_tensors='pt',
padding=True,
truncation=True,
max_length=self.config.max_length
).to(self.device)
with torch.no_grad():
query_outputs = retriever(**query_inputs, return_dense=True)
query_emb = query_outputs['dense']
# Encode candidates
candidate_embs = []
for candidate in candidates:
cand_inputs = retriever_tokenizer(
candidate,
return_tensors='pt',
padding=True,
truncation=True,
max_length=self.config.max_length
).to(self.device)
with torch.no_grad():
cand_outputs = retriever(**cand_inputs, return_dense=True)
candidate_embs.append(cand_outputs['dense'])
candidate_embs = torch.cat(candidate_embs, dim=0)
# Compute similarities
similarities = torch.cosine_similarity(query_emb, candidate_embs, dim=1)
# Get top-k for retrieval
top_k_retrieval = min(self.config.retrieval_top_k, len(candidates))
retrieval_indices = torch.argsort(similarities, descending=True)[:top_k_retrieval]
# Step 2: Reranking
rerank_candidates = [candidates[i] for i in retrieval_indices]
rerank_labels = [labels[i] for i in retrieval_indices]
# Score pairs with reranker
rerank_scores = []
for candidate in rerank_candidates:
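# NOTE: query and passage are joined with a literal "[SEP]" string below; if the
# reranker was trained on the tokenizer's native pair encoding
# (tokenizer(query, candidate, ...)), that form may match training more closely.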
pair_text = f"{query} [SEP] {candidate}"
pair_inputs = reranker_tokenizer(
pair_text,
return_tensors='pt',
padding=True,
truncation=True,
max_length=self.config.max_length
).to(self.device)
with torch.no_grad():
pair_outputs = reranker(**pair_inputs)
score = pair_outputs.logits.item()
rerank_scores.append(score)
# Final ranking
final_indices = np.argsort(rerank_scores)[::-1]
final_labels = [rerank_labels[i] for i in final_indices]
# Compute metrics
for k in self.config.k_values:
if k <= len(final_labels):
# Recall@k
relevant_in_topk = sum(final_labels[:k])
total_relevant = sum(rerank_labels)
recall = relevant_in_topk / total_relevant if total_relevant > 0 else 0
pipeline_metrics[f'recall@{k}'].append(recall)
# MRR@k
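# The for/else below records the reciprocal rank of the first relevant hit in the
# top-k; the else branch runs only if the loop finishes without a break, i.e. no
# relevant document appears in the top-k, in which case 0.0 is recorded.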
for i, label in enumerate(final_labels[:k]):
if label == 1:
pipeline_metrics[f'mrr@{k}'].append(1.0 / (i + 1))
break
else:
pipeline_metrics[f'mrr@{k}'].append(0.0)
total_queries += 1
except Exception as e:
self.logger.warning(f"Error processing query: {e}")
continue
# Average metrics
final_metrics = {}
for metric, values in pipeline_metrics.items():
if values:
final_metrics[metric] = np.mean(values)
else:
final_metrics[metric] = 0.0
final_metrics['total_queries'] = total_queries
return final_metrics
except Exception as e:
self.logger.error(f"Error testing pipeline: {e}")
return {}
def _compare_metrics(self, baseline: Dict[str, float], finetuned: Dict[str, float], model_type: str) -> Dict[str, Any]:
"""Compare baseline vs fine-tuned metrics with statistical analysis"""
comparison = {
'improvements': {},
'degradations': {},
'statistical_significance': {},
'summary': {}
}
# Compare common metrics
common_metrics = set(baseline.keys()) & set(finetuned.keys())
improvements = 0
degradations = 0
significant_improvements = 0
for metric in common_metrics:
if isinstance(baseline[metric], (int, float)) and isinstance(finetuned[metric], (int, float)):
baseline_val = baseline[metric]
finetuned_val = finetuned[metric]
# Calculate improvement
if baseline_val != 0:
improvement_pct = ((finetuned_val - baseline_val) / baseline_val) * 100
else:
improvement_pct = float('inf') if finetuned_val > 0 else 0
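# e.g. (hypothetical numbers) baseline recall@10 = 0.50 and fine-tuned = 0.55
# give improvement_pct = +10.0, i.e. a 10% relative improvement.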
# Classify improvement/degradation
if finetuned_val > baseline_val:
comparison['improvements'][metric] = {
'baseline': baseline_val,
'finetuned': finetuned_val,
'improvement_pct': improvement_pct,
'absolute_improvement': finetuned_val - baseline_val
}
improvements += 1
elif finetuned_val < baseline_val:
comparison['degradations'][metric] = {
'baseline': baseline_val,
'finetuned': finetuned_val,
'degradation_pct': -improvement_pct,
'absolute_degradation': baseline_val - finetuned_val
}
degradations += 1
# Summary statistics
comparison['summary'] = {
'total_metrics': len(common_metrics),
'improvements': improvements,
'degradations': degradations,
'unchanged': len(common_metrics) - improvements - degradations,
'improvement_ratio': improvements / len(common_metrics) if common_metrics else 0,
'net_improvement': improvements - degradations
}
return comparison
def _perform_statistical_analysis(self):
"""Perform statistical analysis on validation results"""
self.logger.info("📈 Performing Statistical Analysis...")
analysis = {
'overall_summary': {},
'significance_tests': {},
'confidence_intervals': {},
'effect_sizes': {}
}
# Analyze each model type
for model_type in ['retriever', 'reranker', 'pipeline']:
if model_type not in self.results:
continue
model_results = self.results[model_type]
# Collect all improvements across datasets
all_improvements = []
significant_improvements = 0
for dataset, comparison in model_results.get('comparisons', {}).items():
for metric, improvement_data in comparison.get('improvements', {}).items():
improvement_pct = improvement_data.get('improvement_pct', 0)
if improvement_pct > 0:
all_improvements.append(improvement_pct)
# Heuristic "significance": count improvements above a 5% relative threshold
# (a fixed threshold check, not a formal statistical test)
if improvement_pct > 5.0:
significant_improvements += 1
if all_improvements:
analysis[model_type] = {
'mean_improvement_pct': np.mean(all_improvements),
'median_improvement_pct': np.median(all_improvements),
'std_improvement_pct': np.std(all_improvements),
'min_improvement_pct': np.min(all_improvements),
'max_improvement_pct': np.max(all_improvements),
'significant_improvements': significant_improvements,
'total_improvements': len(all_improvements)
}
self.results['statistical_analysis'] = analysis
def _generate_report(self):
"""Generate comprehensive validation report"""
self.logger.info("📝 Generating Comprehensive Report...")
report_path = os.path.join(self.config.output_dir, "validation_report.html")
html_content = self._create_html_report()
with open(report_path, 'w', encoding='utf-8') as f:
f.write(html_content)
# Also save JSON results
json_path = os.path.join(self.config.output_dir, "validation_results.json")
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(self.results, f, indent=2, default=str)
self.logger.info(f"📊 Report saved to: {report_path}")
self.logger.info(f"📊 JSON results saved to: {json_path}")
def _create_html_report(self) -> str:
"""Create HTML report with comprehensive validation results"""
html = f"""
<!DOCTYPE html>
<html>
<head>
<title>BGE Model Validation Report</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.header {{ background: #f0f8ff; padding: 20px; border-radius: 10px; }}
.section {{ margin: 20px 0; }}
.metrics-table {{ border-collapse: collapse; width: 100%; }}
.metrics-table th, .metrics-table td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
.metrics-table th {{ background-color: #f2f2f2; }}
.improvement {{ color: green; font-weight: bold; }}
.degradation {{ color: red; font-weight: bold; }}
.summary-box {{ background: #f9f9f9; padding: 15px; border-left: 4px solid #4CAF50; margin: 10px 0; }}
</style>
</head>
<body>
<div class="header">
<h1>🚀 BGE Model Comprehensive Validation Report</h1>
<p><strong>Generated:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
<p><strong>Configuration:</strong></p>
<ul>
<li>Retriever Baseline: {self.config.retriever_baseline}</li>
<li>Retriever Fine-tuned: {self.config.retriever_finetuned or 'Not tested'}</li>
<li>Reranker Baseline: {self.config.reranker_baseline}</li>
<li>Reranker Fine-tuned: {self.config.reranker_finetuned or 'Not tested'}</li>
</ul>
</div>
"""
# Add executive summary
html += self._create_executive_summary()
# Add detailed results for each model type
for model_type in ['retriever', 'reranker', 'pipeline']:
if model_type in self.results:
html += self._create_model_section(model_type)
html += """
</body>
</html>
"""
return html
def _create_executive_summary(self) -> str:
"""Create executive summary section"""
html = """
<div class="section">
<h2>📊 Executive Summary</h2>
"""
# Overall verdict
total_improvements = 0
total_degradations = 0
for model_type in ['retriever', 'reranker', 'pipeline']:
if model_type in self.results and 'comparisons' in self.results[model_type]:
for dataset, comparison in self.results[model_type]['comparisons'].items():
total_improvements += len(comparison.get('improvements', {}))
total_degradations += len(comparison.get('degradations', {}))
if total_improvements > total_degradations:
verdict = "✅ <strong>FINE-TUNING SUCCESSFUL</strong> - Models show overall improvement"
verdict_color = "green"
elif total_improvements == total_degradations:
verdict = "⚠️ <strong>MIXED RESULTS</strong> - Equal improvements and degradations"
verdict_color = "orange"
else:
verdict = "❌ <strong>FINE-TUNING ISSUES</strong> - More degradations than improvements"
verdict_color = "red"
html += f"""
<div class="summary-box" style="border-left-color: {verdict_color};">
<h3>{verdict}</h3>
<ul>
<li>Total Metric Improvements: {total_improvements}</li>
<li>Total Metric Degradations: {total_degradations}</li>
<li>Net Improvement: {total_improvements - total_degradations}</li>
</ul>
</div>
"""
html += """
</div>
"""
return html
def _create_model_section(self, model_type: str) -> str:
"""Create detailed section for a model type"""
html = f"""
<div class="section">
<h2>📈 {model_type.title()} Model Results</h2>
"""
model_results = self.results[model_type]
for dataset in model_results.get('datasets_tested', []):
if dataset in model_results['comparisons']:
comparison = model_results['comparisons'][dataset]
html += f"""
<h3>Dataset: {dataset}</h3>
<table class="metrics-table">
<thead>
<tr>
<th>Metric</th>
<th>Baseline</th>
<th>Fine-tuned</th>
<th>Change</th>
<th>Status</th>
</tr>
</thead>
<tbody>
"""
# Add improvements
for metric, data in comparison.get('improvements', {}).items():
html += f"""
<tr>
<td>{metric}</td>
<td>{data['baseline']:.4f}</td>
<td>{data['finetuned']:.4f}</td>
<td class="improvement">+{data['improvement_pct']:.2f}%</td>
<td class="improvement">✅ Improved</td>
</tr>
"""
# Add degradations
for metric, data in comparison.get('degradations', {}).items():
html += f"""
<tr>
<td>{metric}</td>
<td>{data['baseline']:.4f}</td>
<td>{data['finetuned']:.4f}</td>
<td class="degradation">-{data['degradation_pct']:.2f}%</td>
<td class="degradation">❌ Degraded</td>
</tr>
"""
html += """
</tbody>
</table>
"""
html += """
</div>
"""
return html
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Comprehensive validation of BGE fine-tuned models",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Validate both retriever and reranker
python scripts/comprehensive_validation.py \\
--retriever_finetuned ./output/bge-m3-enhanced/final_model \\
--reranker_finetuned ./output/bge-reranker/final_model
# Validate only retriever
python scripts/comprehensive_validation.py \\
--retriever_finetuned ./output/bge-m3-enhanced/final_model
# Custom datasets and output
python scripts/comprehensive_validation.py \\
--retriever_finetuned ./output/bge-m3-enhanced/final_model \\
--test_datasets data/my_test1.jsonl data/my_test2.jsonl \\
--output_dir ./my_validation_results
"""
)
# Model paths
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
help="Path to baseline retriever model")
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
help="Path to baseline reranker model")
parser.add_argument("--retriever_finetuned", type=str, default=None,
help="Path to fine-tuned retriever model")
parser.add_argument("--reranker_finetuned", type=str, default=None,
help="Path to fine-tuned reranker model")
# Test data
parser.add_argument("--test_datasets", type=str, nargs="+", default=None,
help="Paths to test datasets (JSONL format)")
# Evaluation settings
parser.add_argument("--batch_size", type=int, default=32,
help="Batch size for evaluation")
parser.add_argument("--max_length", type=int, default=512,
help="Maximum sequence length")
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 3, 5, 10, 20],
help="K values for evaluation metrics")
# Output settings
parser.add_argument("--output_dir", type=str, default="./validation_results",
help="Directory to save validation results")
parser.add_argument("--save_embeddings", action="store_true",
help="Save embeddings for further analysis")
parser.add_argument("--no_report", action="store_true",
help="Skip HTML report generation")
# Pipeline settings
parser.add_argument("--retrieval_top_k", type=int, default=100,
help="Top-k for retrieval stage")
parser.add_argument("--rerank_top_k", type=int, default=10,
help="Top-k for reranking stage")
return parser.parse_args()
def main():
"""Main validation function"""
args = parse_args()
# Validate arguments
if not args.retriever_finetuned and not args.reranker_finetuned:
print("❌ Error: At least one fine-tuned model must be specified")
print(" Use --retriever_finetuned or --reranker_finetuned")
return 1
# Create configuration
config = ValidationConfig(
retriever_baseline=args.retriever_baseline,
reranker_baseline=args.reranker_baseline,
retriever_finetuned=args.retriever_finetuned,
reranker_finetuned=args.reranker_finetuned,
test_datasets=args.test_datasets,
batch_size=args.batch_size,
max_length=args.max_length,
k_values=args.k_values,
output_dir=args.output_dir,
save_embeddings=args.save_embeddings,
create_report=not args.no_report,
retrieval_top_k=args.retrieval_top_k,
rerank_top_k=args.rerank_top_k
)
# Run validation
try:
validator = ComprehensiveValidator(config)
results = validator.validate_all()
print("\n" + "=" * 80)
print("🎉 COMPREHENSIVE VALIDATION COMPLETED!")
print("=" * 80)
print(f"📊 Results saved to: {config.output_dir}")
print("\n📈 Quick Summary:")
# Print quick summary
for model_type in ['retriever', 'reranker', 'pipeline']:
if model_type in results:
model_results = results[model_type]
total_datasets = len(model_results.get('datasets_tested', []))
total_improvements = sum(len(comp.get('improvements', {}))
for comp in model_results.get('comparisons', {}).values())
total_degradations = sum(len(comp.get('degradations', {}))
for comp in model_results.get('comparisons', {}).values())
status = "✅" if total_improvements > total_degradations else "⚠️" if total_improvements == total_degradations else "❌"
print(f" {status} {model_type.title()}: {total_improvements} improvements, {total_degradations} degradations across {total_datasets} datasets")
return 0
except KeyboardInterrupt:
print("\n❌ Validation interrupted by user")
return 1
except Exception as e:
print(f"\n❌ Validation failed: {e}")
return 1
if __name__ == "__main__":
exit(main())

View File

@ -0,0 +1,291 @@
#!/usr/bin/env python3
"""
BGE-M3 Dataset Format Converter
Converts the nested dataset format where training data is stored inside a "data" field
to the flat format expected by the BGE-M3 training pipeline.
Input format:
{
"query": "...",
"data": "{\"pos\": [...], \"neg\": [...], \"pos_scores\": [...], ...}"
}
Output format:
{
"query": "...",
"pos": [...],
"neg": [...],
"pos_scores": [...],
"neg_scores": [...],
"prompt": "...",
"type": "..."
}
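Example (hypothetical data), when the "data" field is wrapped in a markdown code fence:
    input : {"query": "Q", "data": "```json\n{\"pos\": [\"P\"], \"neg\": [\"N\"]}\n```"}
    output: {"query": "Q", "pos": ["P"], "neg": ["N"], "pos_scores": [1.0],
             "neg_scores": [0.0], "prompt": "为此查询生成表示:", "type": "normal"}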
"""
import json
import argparse
import logging
import os
from tqdm import tqdm
from typing import Dict, Any, Optional
import re
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def extract_json_from_data_field(data_field: str) -> Optional[Dict[str, Any]]:
"""
Extract JSON data from the "data" field which contains a JSON string
wrapped in markdown code blocks.
Args:
data_field: String containing JSON data, possibly wrapped in markdown
Returns:
Parsed JSON dictionary or None if parsing fails
"""
try:
# Remove markdown code block markers if present
cleaned_data = data_field.strip()
# Remove ```json and ``` markers
if cleaned_data.startswith('```json'):
cleaned_data = cleaned_data[7:] # Remove ```json
elif cleaned_data.startswith('```'):
cleaned_data = cleaned_data[3:] # Remove ```
if cleaned_data.endswith('```'):
cleaned_data = cleaned_data[:-3] # Remove trailing ```
# Clean up any remaining whitespace
cleaned_data = cleaned_data.strip()
# Parse the JSON
return json.loads(cleaned_data)
except (json.JSONDecodeError, AttributeError) as e:
logger.warning(f"Failed to parse JSON from data field: {e}")
logger.warning(f"Data content (first 200 chars): {data_field[:200]}")
return None
def convert_dataset_format(input_file: str, output_file: str) -> None:
"""
Convert dataset from nested format to flat format.
Args:
input_file: Path to input JSONL file with nested format
output_file: Path to output JSONL file with flat format
"""
logger.info(f"Converting dataset from {input_file} to {output_file}")
# Count total lines for progress bar
total_lines = 0
with open(input_file, 'r', encoding='utf-8') as f:
for _ in f:
total_lines += 1
converted_count = 0
error_count = 0
with open(input_file, 'r', encoding='utf-8') as infile, \
open(output_file, 'w', encoding='utf-8') as outfile:
for line_num, line in enumerate(tqdm(infile, total=total_lines, desc="Converting"), 1):
line = line.strip()
if not line:
continue
try:
# Parse the input JSON line
item = json.loads(line)
# Validate required fields
if 'query' not in item:
logger.warning(f"Line {line_num}: Missing 'query' field")
error_count += 1
continue
if 'data' not in item:
logger.warning(f"Line {line_num}: Missing 'data' field")
error_count += 1
continue
# Extract data from the nested field
data_content = extract_json_from_data_field(item['data'])
if data_content is None:
logger.warning(f"Line {line_num}: Failed to parse data field")
error_count += 1
continue
# Create the flattened format
flattened_item = {
'query': item['query']
}
# Add all fields from the data content
expected_fields = ['pos', 'neg', 'pos_scores', 'neg_scores', 'prompt', 'type']
for field in expected_fields:
if field in data_content:
flattened_item[field] = data_content[field]
else:
# Provide reasonable defaults
if field == 'pos':
flattened_item[field] = []
elif field == 'neg':
flattened_item[field] = []
elif field == 'pos_scores':
flattened_item[field] = [1.0] * len(data_content.get('pos', []))
elif field == 'neg_scores':
flattened_item[field] = [0.0] * len(data_content.get('neg', []))
elif field == 'prompt':
flattened_item[field] = '为此查询生成表示:'
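# (the default prompt above means "Generate a representation for this query:")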
elif field == 'type':
flattened_item[field] = 'normal'
# Validate the converted item
if not flattened_item['pos'] and not flattened_item['neg']:
logger.warning(f"Line {line_num}: No positive or negative passages found")
error_count += 1
continue
# Ensure scores arrays match passage arrays
if len(flattened_item['pos_scores']) != len(flattened_item['pos']):
flattened_item['pos_scores'] = [1.0] * len(flattened_item['pos'])
if len(flattened_item['neg_scores']) != len(flattened_item['neg']):
flattened_item['neg_scores'] = [0.0] * len(flattened_item['neg'])
# Write the flattened item
outfile.write(json.dumps(flattened_item, ensure_ascii=False) + '\n')
converted_count += 1
except json.JSONDecodeError as e:
logger.warning(f"Line {line_num}: Invalid JSON - {e}")
error_count += 1
continue
except Exception as e:
logger.warning(f"Line {line_num}: Unexpected error - {e}")
error_count += 1
continue
logger.info(f"Conversion complete!")
logger.info(f"Successfully converted: {converted_count} items")
logger.info(f"Errors encountered: {error_count} items")
logger.info(f"Output saved to: {output_file}")
def validate_converted_file(file_path: str, sample_size: int = 5) -> None:
"""
Validate the converted file by showing a few sample entries.
Args:
file_path: Path to the converted JSONL file
sample_size: Number of sample entries to display
"""
logger.info(f"Validating converted file: {file_path}")
try:
with open(file_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
total_lines = len(lines)
logger.info(f"Total converted entries: {total_lines}")
if total_lines == 0:
logger.warning("No entries found in converted file!")
return
# Show sample entries
sample_indices = list(range(min(sample_size, total_lines)))
for i, idx in enumerate(sample_indices):
try:
item = json.loads(lines[idx])
logger.info(f"\nSample {i+1} (line {idx+1}):")
logger.info(f" Query: {item['query'][:100]}...")
logger.info(f" Positives: {len(item.get('pos', []))}")
logger.info(f" Negatives: {len(item.get('neg', []))}")
logger.info(f" Type: {item.get('type', 'N/A')}")
logger.info(f" Prompt: {item.get('prompt', 'N/A')}")
# Show first positive and negative if available
if item.get('pos'):
logger.info(f" First positive: {item['pos'][0][:150]}...")
if item.get('neg'):
logger.info(f" First negative: {item['neg'][0][:150]}...")
except json.JSONDecodeError as e:
logger.warning(f"Invalid JSON at line {idx+1}: {e}")
except Exception as e:
logger.error(f"Error validating file: {e}")
def main():
parser = argparse.ArgumentParser(
description="Convert BGE-M3 dataset from nested to flat format",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python convert_m3_dataset.py data/datasets/三国演义/m3_train.jsonl data/datasets/三国演义/m3_train_converted.jsonl
python convert_m3_dataset.py input.jsonl output.jsonl --validate
"""
)
parser.add_argument(
'input_file',
help='Path to input JSONL file with nested format'
)
parser.add_argument(
'output_file',
help='Path to output JSONL file with flat format'
)
parser.add_argument(
'--validate',
action='store_true',
help='Validate the converted file and show samples'
)
parser.add_argument(
'--sample-size',
type=int,
default=5,
help='Number of sample entries to show during validation (default: 5)'
)
args = parser.parse_args()
# Check if input file exists
if not os.path.exists(args.input_file):
logger.error(f"Input file not found: {args.input_file}")
return 1
# Create output directory if it doesn't exist
output_dir = os.path.dirname(args.output_file)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
logger.info(f"Created output directory: {output_dir}")
try:
# Convert the dataset
convert_dataset_format(args.input_file, args.output_file)
# Validate if requested
if args.validate:
validate_converted_file(args.output_file, args.sample_size)
logger.info("Dataset conversion completed successfully!")
return 0
except Exception as e:
logger.error(f"Dataset conversion failed: {e}")
return 1
if __name__ == '__main__':
exit(main())

584
scripts/evaluate.py Normal file
View File

@ -0,0 +1,584 @@
"""
Evaluation script for fine-tuned BGE models
"""
import os
import sys
import argparse
import logging
import json
import torch
from transformers import AutoTokenizer
import numpy as np
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
from data.preprocessing import DataPreprocessor
from utils.logging import setup_logging, get_logger
logger = logging.getLogger(__name__)
def parse_args():
"""Parse command line arguments for evaluation."""
parser = argparse.ArgumentParser(description="Evaluate fine-tuned BGE models")
# Model arguments
parser.add_argument("--retriever_model", type=str, required=True,
help="Path to fine-tuned retriever model")
parser.add_argument("--reranker_model", type=str, default=None,
help="Path to fine-tuned reranker model")
parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
help="Base retriever model for comparison")
parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
help="Base reranker model for comparison")
# Data arguments
parser.add_argument("--eval_data", type=str, required=True,
help="Path to evaluation data file")
parser.add_argument("--corpus_data", type=str, default=None,
help="Path to corpus data for retrieval evaluation")
parser.add_argument("--data_format", type=str, default="auto",
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
help="Format of input data")
# Evaluation arguments
parser.add_argument("--output_dir", type=str, default="./evaluation_results",
help="Directory to save evaluation results")
parser.add_argument("--batch_size", type=int, default=32,
help="Batch size for evaluation")
parser.add_argument("--max_length", type=int, default=512,
help="Maximum sequence length")
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
help="K values for evaluation metrics")
# Evaluation modes
parser.add_argument("--eval_retriever", action="store_true",
help="Evaluate retriever model")
parser.add_argument("--eval_reranker", action="store_true",
help="Evaluate reranker model")
parser.add_argument("--eval_pipeline", action="store_true",
help="Evaluate full retrieval pipeline")
parser.add_argument("--compare_baseline", action="store_true",
help="Compare with baseline models")
parser.add_argument("--interactive", action="store_true",
help="Run interactive evaluation")
# Pipeline settings
parser.add_argument("--retrieval_top_k", type=int, default=100,
help="Top-k for initial retrieval")
parser.add_argument("--rerank_top_k", type=int, default=10,
help="Top-k for reranking")
# Other
parser.add_argument("--device", type=str, default="auto",
help="Device to use (auto, cpu, cuda)")
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
args = parser.parse_args()
# Validation
if not any([args.eval_retriever, args.eval_reranker, args.eval_pipeline, args.interactive]):
args.eval_pipeline = True # Default to pipeline evaluation
return args
def load_evaluation_data(data_path: str, data_format: str):
"""Load evaluation data from a file."""
preprocessor = DataPreprocessor()
if data_format == "auto":
data_format = preprocessor._detect_format(data_path)
if data_format == "jsonl":
data = preprocessor._load_jsonl(data_path)
elif data_format == "json":
data = preprocessor._load_json(data_path)
elif data_format in ["csv", "tsv"]:
data = preprocessor._load_csv(data_path, delimiter="," if data_format == "csv" else "\t")
elif data_format == "dureader":
data = preprocessor._load_dureader(data_path)
elif data_format == "msmarco":
data = preprocessor._load_msmarco(data_path)
else:
raise ValueError(f"Unsupported format: {data_format}")
return data
def evaluate_retriever(args, retriever_model, tokenizer, eval_data):
"""Evaluate the retriever model."""
logger.info("Evaluating retriever model...")
evaluator = RetrievalEvaluator(
model=retriever_model,
tokenizer=tokenizer,
device=args.device,
k_values=args.k_values
)
# Prepare data for evaluation
queries = [item['query'] for item in eval_data]
# If we have passage data, create passage corpus
all_passages = set()
for item in eval_data:
if 'pos' in item:
all_passages.update(item['pos'])
if 'neg' in item:
all_passages.update(item['neg'])
passages = list(all_passages)
# Create relevance mapping
relevant_docs = {}
for i, item in enumerate(eval_data):
query = item['query']
if 'pos' in item:
relevant_indices = []
for pos in item['pos']:
if pos in passages:
relevant_indices.append(passages.index(pos))
if relevant_indices:
relevant_docs[query] = relevant_indices
# Evaluate
if passages and relevant_docs:
metrics = evaluator.evaluate_retrieval_pipeline(
queries=queries,
corpus=passages,
relevant_docs=relevant_docs,
batch_size=args.batch_size
)
else:
logger.warning("No valid corpus or relevance data found. Running basic evaluation.")
# Create dummy evaluation
metrics = {"note": "Limited evaluation due to data format"}
return metrics
def evaluate_reranker(args, reranker_model, tokenizer, eval_data):
"""Evaluate the reranker model."""
logger.info("Evaluating reranker model...")
evaluator = RetrievalEvaluator(
model=reranker_model,
tokenizer=tokenizer,
device=args.device,
k_values=args.k_values
)
# Prepare data for reranker evaluation
total_queries = 0
total_correct = 0
total_mrr = 0
for item in eval_data:
if 'pos' not in item or 'neg' not in item:
continue
query = item['query']
positives = item['pos']
negatives = item['neg']
if not positives or not negatives:
continue
# Create candidates (1 positive + negatives)
candidates = positives[:1] + negatives
# Score all candidates
pairs = [(query, candidate) for candidate in candidates]
scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
# Check if positive is ranked first
best_idx = np.argmax(scores)
if best_idx == 0: # First is positive
total_correct += 1
total_mrr += 1.0
else:
# Find rank of positive (it's always at index 0)
sorted_indices = np.argsort(scores)[::-1]
positive_rank = np.where(sorted_indices == 0)[0][0] + 1
total_mrr += 1.0 / positive_rank
total_queries += 1
if total_queries > 0:
metrics = {
'accuracy@1': total_correct / total_queries,
'mrr': total_mrr / total_queries,
'total_queries': total_queries
}
else:
metrics = {"note": "No valid query-passage pairs found for reranker evaluation"}
return metrics
def evaluate_pipeline(args, retriever_model, reranker_model, tokenizer, eval_data):
"""Evaluate the full retrieval + reranking pipeline."""
logger.info("Evaluating full pipeline...")
# Prepare corpus
all_passages = set()
for item in eval_data:
if 'pos' in item:
all_passages.update(item['pos'])
if 'neg' in item:
all_passages.update(item['neg'])
passages = list(all_passages)
if not passages:
return {"note": "No passages found for pipeline evaluation"}
# Pre-encode corpus with retriever
logger.info("Encoding corpus...")
corpus_embeddings = retriever_model.encode(
passages,
batch_size=args.batch_size,
return_numpy=True,
device=args.device
)
total_queries = 0
pipeline_metrics = {
'retrieval_recall@10': [],
'rerank_accuracy@1': [],
'rerank_mrr': []
}
for item in eval_data:
if 'pos' not in item or not item['pos']:
continue
query = item['query']
relevant_passages = item['pos']
# Step 1: Retrieval
query_embedding = retriever_model.encode(
[query],
return_numpy=True,
device=args.device
)
# Compute similarities
similarities = np.dot(query_embedding, corpus_embeddings.T).squeeze()
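# NOTE: the dot product equals cosine similarity only if encode() returns
# L2-normalised embeddings; otherwise these are unnormalised inner products.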
# Get top-k for retrieval
top_k_indices = np.argsort(similarities)[::-1][:args.retrieval_top_k]
retrieved_passages = [passages[i] for i in top_k_indices]
# Check retrieval recall
retrieved_relevant = sum(1 for p in retrieved_passages[:10] if p in relevant_passages)
retrieval_recall = retrieved_relevant / min(len(relevant_passages), 10)
pipeline_metrics['retrieval_recall@10'].append(retrieval_recall)
# Step 2: Reranking (if reranker available)
if reranker_model is not None:
# Rerank top candidates
rerank_candidates = retrieved_passages[:args.rerank_top_k]
pairs = [(query, candidate) for candidate in rerank_candidates]
if pairs:
rerank_scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
reranked_indices = np.argsort(rerank_scores)[::-1]
reranked_passages = [rerank_candidates[i] for i in reranked_indices]
# Check reranking performance
if reranked_passages[0] in relevant_passages:
pipeline_metrics['rerank_accuracy@1'].append(1.0)
pipeline_metrics['rerank_mrr'].append(1.0)
else:
pipeline_metrics['rerank_accuracy@1'].append(0.0)
# Find MRR
for rank, passage in enumerate(reranked_passages, 1):
if passage in relevant_passages:
pipeline_metrics['rerank_mrr'].append(1.0 / rank)
break
else:
pipeline_metrics['rerank_mrr'].append(0.0)
total_queries += 1
# Average metrics
final_metrics = {}
for metric, values in pipeline_metrics.items():
if values:
final_metrics[metric] = np.mean(values)
final_metrics['total_queries'] = total_queries
return final_metrics
def run_interactive_evaluation(args, retriever_model, reranker_model, tokenizer, eval_data):
"""Run interactive evaluation."""
logger.info("Starting interactive evaluation...")
# Prepare corpus
all_passages = set()
for item in eval_data:
if 'pos' in item:
all_passages.update(item['pos'])
if 'neg' in item:
all_passages.update(item['neg'])
passages = list(all_passages)
if not passages:
print("No passages found for interactive evaluation")
return
# Create interactive evaluator
evaluator = InteractiveEvaluator(
retriever=retriever_model,
reranker=reranker_model,
tokenizer=tokenizer,
corpus=passages,
device=args.device
)
print("\n=== Interactive Evaluation ===")
print("Enter queries to test the retrieval system.")
print("Type 'quit' to exit, 'sample' to test with sample queries.")
# Load some sample queries
sample_queries = [item['query'] for item in eval_data[:5]]
while True:
try:
user_input = input("\nEnter query: ").strip()
if user_input.lower() in ['quit', 'exit', 'q']:
break
elif user_input.lower() == 'sample':
print("\nSample queries:")
for i, query in enumerate(sample_queries):
print(f"{i + 1}. {query}")
continue
elif user_input.isdigit() and 1 <= int(user_input) <= len(sample_queries):
query = sample_queries[int(user_input) - 1]
else:
query = user_input
if not query:
continue
# Search
results = evaluator.search(query, top_k=10, rerank=reranker_model is not None)
print(f"\nResults for: {query}")
print("-" * 50)
for i, (idx, score, text) in enumerate(results):
print(f"{i + 1}. (Score: {score:.3f}) {text[:100]}...")
except KeyboardInterrupt:
print("\nExiting interactive evaluation...")
break
except Exception as e:
print(f"Error: {e}")
def main():
"""Run evaluation with robust error handling and config-driven defaults."""
import argparse
parser = argparse.ArgumentParser(description="Evaluate BGE models.")
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
# Model arguments
parser.add_argument("--retriever_model", type=str, required=True,
help="Path to fine-tuned retriever model")
parser.add_argument("--reranker_model", type=str, default=None,
help="Path to fine-tuned reranker model")
parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
help="Base retriever model for comparison")
parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
help="Base reranker model for comparison")
# Data arguments
parser.add_argument("--eval_data", type=str, required=True,
help="Path to evaluation data file")
parser.add_argument("--corpus_data", type=str, default=None,
help="Path to corpus data for retrieval evaluation")
parser.add_argument("--data_format", type=str, default="auto",
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
help="Format of input data")
# Evaluation arguments
parser.add_argument("--output_dir", type=str, default="./evaluation_results",
help="Directory to save evaluation results")
parser.add_argument("--batch_size", type=int, default=32,
help="Batch size for evaluation")
parser.add_argument("--max_length", type=int, default=512,
help="Maximum sequence length")
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
help="K values for evaluation metrics")
# Evaluation modes
parser.add_argument("--eval_retriever", action="store_true",
help="Evaluate retriever model")
parser.add_argument("--eval_reranker", action="store_true",
help="Evaluate reranker model")
parser.add_argument("--eval_pipeline", action="store_true",
help="Evaluate full retrieval pipeline")
parser.add_argument("--compare_baseline", action="store_true",
help="Compare with baseline models")
parser.add_argument("--interactive", action="store_true",
help="Run interactive evaluation")
# Pipeline settings
parser.add_argument("--retrieval_top_k", type=int, default=100,
help="Top-k for initial retrieval")
parser.add_argument("--rerank_top_k", type=int, default=10,
help="Top-k for reranking")
# Other
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
args = parser.parse_args()
# Load config and set defaults
from utils.config_loader import get_config
config = get_config()
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
# Setup logging
setup_logging(
output_dir=args.output_dir,
log_level=logging.INFO,
log_to_console=True,
log_to_file=True
)
logger = get_logger(__name__)
logger.info("Starting model evaluation")
logger.info(f"Arguments: {args}")
logger.info(f"Using device: {device}")
# Set device
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
elif device == "npu":
npu = getattr(torch, "npu", None)
if npu is not None and hasattr(npu, "is_available") and npu.is_available():
device = "npu"
else:
logger.warning("NPU (Ascend) is not available. Falling back to CPU.")
device = "cpu"
logger.info(f"Using device: {device}")
# Load models
logger.info("Loading models...")
model_source = get_config().get('model_paths.source', 'huggingface')
# Load retriever
try:
retriever_model = BGEM3Model.from_pretrained(args.retriever_model)
retriever_tokenizer = None # Already loaded inside BGEM3Model if needed
retriever_model.to(device)
retriever_model.eval()
logger.info(f"Loaded retriever from {args.retriever_model} (source: {model_source})")
except Exception as e:
logger.error(f"Failed to load retriever: {e}")
return
# Load reranker (optional)
reranker_model = None
reranker_tokenizer = None
if args.reranker_model:
try:
reranker_model = BGERerankerModel(args.reranker_model)
reranker_tokenizer = None # Already loaded inside BGERerankerModel if needed
reranker_model.to(device)
reranker_model.eval()
logger.info(f"Loaded reranker from {args.reranker_model} (source: {model_source})")
except Exception as e:
logger.warning(f"Failed to load reranker: {e}")
# Load evaluation data
logger.info("Loading evaluation data...")
eval_data = load_evaluation_data(args.eval_data, args.data_format)
logger.info(f"Loaded {len(eval_data)} evaluation examples")
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Run evaluations
all_results = {}
if args.eval_retriever:
retriever_metrics = evaluate_retriever(args, retriever_model, retriever_tokenizer, eval_data)
all_results['retriever'] = retriever_metrics
logger.info(f"Retriever metrics: {retriever_metrics}")
if args.eval_reranker and reranker_model:
reranker_metrics = evaluate_reranker(args, reranker_model, reranker_tokenizer, eval_data)
all_results['reranker'] = reranker_metrics
logger.info(f"Reranker metrics: {reranker_metrics}")
if args.eval_pipeline:
pipeline_metrics = evaluate_pipeline(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
all_results['pipeline'] = pipeline_metrics
logger.info(f"Pipeline metrics: {pipeline_metrics}")
# Compare with baseline if requested
if args.compare_baseline:
logger.info("Loading baseline models for comparison...")
try:
base_retriever = BGEM3Model(args.base_retriever)
base_retriever.to(device)
base_retriever.eval()
base_retriever_metrics = evaluate_retriever(args, base_retriever, retriever_tokenizer, eval_data)
all_results['baseline_retriever'] = base_retriever_metrics
logger.info(f"Baseline retriever metrics: {base_retriever_metrics}")
if reranker_model:
base_reranker = BGERerankerModel(args.base_reranker)
base_reranker.to(device)
base_reranker.eval()
base_reranker_metrics = evaluate_reranker(args, base_reranker, reranker_tokenizer, eval_data)
all_results['baseline_reranker'] = base_reranker_metrics
logger.info(f"Baseline reranker metrics: {base_reranker_metrics}")
except Exception as e:
logger.warning(f"Failed to load baseline models: {e}")
# Save results
results_file = os.path.join(args.output_dir, "evaluation_results.json")
with open(results_file, 'w') as f:
json.dump(all_results, f, indent=2)
logger.info(f"Evaluation results saved to {results_file}")
# Print summary
print("\n" + "=" * 50)
print("EVALUATION SUMMARY")
print("=" * 50)
for model_type, metrics in all_results.items():
print(f"\n{model_type.upper()}:")
for metric, value in metrics.items():
if isinstance(value, (int, float)):
print(f" {metric}: {value:.4f}")
else:
print(f" {metric}: {value}")
# Interactive evaluation
if args.interactive:
run_interactive_evaluation(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
if __name__ == "__main__":
main()

624
scripts/quick_validation.py Normal file
View File

@ -0,0 +1,624 @@
#!/usr/bin/env python3
"""
Quick Validation Script for BGE Fine-tuned Models
This script provides a fast sanity check that fine-tuned models load correctly
and outperform their baselines on a small set of built-in test queries.
Usage:
# Quick test both models
python scripts/quick_validation.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--reranker_model ./output/bge-reranker/final_model
# Test only retriever
python scripts/quick_validation.py \\
--retriever_model ./output/bge-m3-enhanced/final_model
"""
import os
import sys
import json
import argparse
import logging
import torch
import numpy as np
from transformers import AutoTokenizer
from datetime import datetime
from pathlib import Path
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from utils.config_loader import get_config
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
class QuickValidator:
"""Quick validation for BGE models"""
def __init__(self):
self.device = self._get_device()
self.test_queries = self._get_test_queries()
def _get_device(self):
"""Get optimal device"""
try:
config = get_config()
device_type = config.get('hardware', {}).get('device', 'cuda')
if device_type == 'npu':
try:
import torch_npu
return torch.device("npu:0")
except ImportError:
pass
if torch.cuda.is_available():
return torch.device("cuda")
except Exception:
pass
return torch.device("cpu")
def _get_test_queries(self):
"""Get standard test queries for validation"""
return [
{
"query": "什么是深度学习?",
"relevant": [
"深度学习是机器学习的一个子领域,使用多层神经网络来学习数据表示",
"深度学习通过神经网络的多个层次来学习数据的抽象特征"
],
"irrelevant": [
"今天天气很好,阳光明媚",
"我喜欢吃苹果和香蕉"
]
},
{
"query": "如何制作红烧肉?",
"relevant": [
"红烧肉的制作需要五花肉、生抽、老抽、冰糖等材料,先炒糖色再炖煮",
"制作红烧肉的关键是掌握火候和糖色的调制"
],
"irrelevant": [
"Python是一种编程语言",
"机器学习需要大量的数据"
]
},
{
"query": "Python编程入门",
"relevant": [
"Python是一种易学易用的编程语言适合初学者入门",
"学习Python需要掌握基本语法、数据类型和控制结构"
],
"irrelevant": [
"红烧肉是一道经典的中式菜肴",
"深度学习使用神经网络进行训练"
]
}
]
def validate_retriever(self, model_path, baseline_path="BAAI/bge-m3"):
"""Quick validation of retriever model"""
print(f"\n🔍 Validating Retriever Model")
print("=" * 50)
results = {
'model_path': model_path,
'baseline_path': baseline_path,
'test_results': {},
'summary': {}
}
try:
# Test fine-tuned model
print("Testing fine-tuned model...")
ft_scores = self._test_retriever_model(model_path, "Fine-tuned")
# Test baseline model
print("Testing baseline model...")
baseline_scores = self._test_retriever_model(baseline_path, "Baseline")
# Compare results
comparison = self._compare_retriever_results(baseline_scores, ft_scores)
results['test_results'] = {
'finetuned': ft_scores,
'baseline': baseline_scores,
'comparison': comparison
}
# Print summary
print(f"\n📊 Retriever Validation Summary:")
print(f" Baseline Avg Relevance Score: {baseline_scores['avg_relevant_score']:.4f}")
print(f" Fine-tuned Avg Relevance Score: {ft_scores['avg_relevant_score']:.4f}")
print(f" Improvement: {comparison['relevance_improvement']:.2f}%")
if comparison['relevance_improvement'] > 0:
print(" ✅ Fine-tuned model shows improvement!")
else:
print(" ❌ Fine-tuned model shows degradation!")
results['summary'] = {
'status': 'improved' if comparison['relevance_improvement'] > 0 else 'degraded',
'improvement_pct': comparison['relevance_improvement']
}
return results
except Exception as e:
logger.error(f"Retriever validation failed: {e}")
return None
def validate_reranker(self, model_path, baseline_path="BAAI/bge-reranker-base"):
"""Quick validation of reranker model"""
print(f"\n🏆 Validating Reranker Model")
print("=" * 50)
results = {
'model_path': model_path,
'baseline_path': baseline_path,
'test_results': {},
'summary': {}
}
try:
# Test fine-tuned model
print("Testing fine-tuned reranker...")
ft_scores = self._test_reranker_model(model_path, "Fine-tuned")
# Test baseline model
print("Testing baseline reranker...")
baseline_scores = self._test_reranker_model(baseline_path, "Baseline")
# Compare results
comparison = self._compare_reranker_results(baseline_scores, ft_scores)
results['test_results'] = {
'finetuned': ft_scores,
'baseline': baseline_scores,
'comparison': comparison
}
# Print summary
print(f"\n📊 Reranker Validation Summary:")
print(f" Baseline Accuracy: {baseline_scores['accuracy']:.4f}")
print(f" Fine-tuned Accuracy: {ft_scores['accuracy']:.4f}")
print(f" Improvement: {comparison['accuracy_improvement']:.2f}%")
if comparison['accuracy_improvement'] > 0:
print(" ✅ Fine-tuned reranker shows improvement!")
else:
print(" ❌ Fine-tuned reranker shows degradation!")
results['summary'] = {
'status': 'improved' if comparison['accuracy_improvement'] > 0 else 'degraded',
'improvement_pct': comparison['accuracy_improvement']
}
return results
except Exception as e:
logger.error(f"Reranker validation failed: {e}")
return None
def _test_retriever_model(self, model_path, model_type):
"""Test retriever model with predefined queries"""
try:
# Load model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGEM3Model.from_pretrained(model_path)
model.to(self.device)
model.eval()
print(f"{model_type} model loaded successfully")
relevant_scores = []
irrelevant_scores = []
ranking_accuracy = []
with torch.no_grad():
for test_case in self.test_queries:
query = test_case["query"]
relevant_docs = test_case["relevant"]
irrelevant_docs = test_case["irrelevant"]
# Encode query
query_inputs = tokenizer(
query,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
).to(self.device)
query_outputs = model(**query_inputs, return_dense=True)
query_emb = query_outputs['dense']
# Test relevant documents
for doc in relevant_docs:
doc_inputs = tokenizer(
doc,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
).to(self.device)
doc_outputs = model(**doc_inputs, return_dense=True)
doc_emb = doc_outputs['dense']
# Compute similarity
similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
relevant_scores.append(similarity)
# Test irrelevant documents
for doc in irrelevant_docs:
doc_inputs = tokenizer(
doc,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
).to(self.device)
doc_outputs = model(**doc_inputs, return_dense=True)
doc_emb = doc_outputs['dense']
similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
irrelevant_scores.append(similarity)
# Check ranking accuracy (relevant should score higher than irrelevant);
# reuse the similarities computed above instead of re-encoding every document
query_relevant_avg = np.mean(relevant_scores[-len(relevant_docs):])
query_irrelevant_avg = np.mean(irrelevant_scores[-len(irrelevant_docs):])
ranking_accuracy.append(1.0 if query_relevant_avg > query_irrelevant_avg else 0.0)
return {
'avg_relevant_score': np.mean(relevant_scores),
'avg_irrelevant_score': np.mean(irrelevant_scores),
'ranking_accuracy': np.mean(ranking_accuracy),
'score_separation': np.mean(relevant_scores) - np.mean(irrelevant_scores)
}
except Exception as e:
logger.error(f"Error testing {model_type} retriever: {e}")
return None
def _test_reranker_model(self, model_path, model_type):
"""Test reranker model with predefined query-document pairs"""
try:
# Load model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGERerankerModel.from_pretrained(model_path)
model.to(self.device)
model.eval()
print(f"{model_type} reranker loaded successfully")
correct_rankings = 0
total_pairs = 0
all_scores = []
all_labels = []
with torch.no_grad():
for test_case in self.test_queries:
query = test_case["query"]
relevant_docs = test_case["relevant"]
irrelevant_docs = test_case["irrelevant"]
# Test relevant documents (positive examples)
for doc in relevant_docs:
pair_text = f"{query} [SEP] {doc}"
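# Note: a literal "[SEP]" is only treated as a separator token by BERT-style tokenizers;
# for others it is tokenized as plain text. Passing the pair as tokenizer(query, doc, ...)
# would let the tokenizer insert its own separator instead.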
inputs = tokenizer(
pair_text,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
).to(self.device)
outputs = model(**inputs)
score = outputs.logits.item()
all_scores.append(score)
all_labels.append(1) # Relevant
# Check if score indicates relevance (positive score)
if score > 0:
correct_rankings += 1
total_pairs += 1
# Test irrelevant documents (negative examples)
for doc in irrelevant_docs:
pair_text = f"{query} [SEP] {doc}"
inputs = tokenizer(
pair_text,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
).to(self.device)
outputs = model(**inputs)
score = outputs.logits.item()
all_scores.append(score)
all_labels.append(0) # Irrelevant
# Check if score indicates irrelevance (negative score)
if score <= 0:
correct_rankings += 1
total_pairs += 1
return {
'accuracy': correct_rankings / total_pairs if total_pairs > 0 else 0,
'avg_positive_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 1]),
'avg_negative_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 0]),
'total_pairs': total_pairs
}
except Exception as e:
logger.error(f"Error testing {model_type} reranker: {e}")
return None
def _compare_retriever_results(self, baseline, finetuned):
"""Compare retriever results"""
if not baseline or not finetuned:
return {}
relevance_improvement = ((finetuned['avg_relevant_score'] - baseline['avg_relevant_score']) /
baseline['avg_relevant_score'] * 100) if baseline['avg_relevant_score'] != 0 else 0
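# e.g. baseline 0.62 -> fine-tuned 0.68 yields (0.68 - 0.62) / 0.62 * 100 ≈ 9.7% improvement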
separation_improvement = ((finetuned['score_separation'] - baseline['score_separation']) /
baseline['score_separation'] * 100) if baseline['score_separation'] != 0 else 0
ranking_improvement = ((finetuned['ranking_accuracy'] - baseline['ranking_accuracy']) /
baseline['ranking_accuracy'] * 100) if baseline['ranking_accuracy'] != 0 else 0
return {
'relevance_improvement': relevance_improvement,
'separation_improvement': separation_improvement,
'ranking_improvement': ranking_improvement
}
def _compare_reranker_results(self, baseline, finetuned):
"""Compare reranker results"""
if not baseline or not finetuned:
return {}
accuracy_improvement = ((finetuned['accuracy'] - baseline['accuracy']) /
baseline['accuracy'] * 100) if baseline['accuracy'] != 0 else 0
return {
'accuracy_improvement': accuracy_improvement
}
def validate_pipeline(self, retriever_path, reranker_path):
"""Quick validation of full pipeline"""
print(f"\n🚀 Validating Full Pipeline")
print("=" * 50)
try:
# Load models
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
retriever = BGEM3Model.from_pretrained(retriever_path)
retriever.to(self.device)
retriever.eval()
reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
reranker = BGERerankerModel.from_pretrained(reranker_path)
reranker.to(self.device)
reranker.eval()
print(" ✅ Pipeline models loaded successfully")
total_queries = 0
pipeline_accuracy = 0
with torch.no_grad():
for test_case in self.test_queries:
query = test_case["query"]
relevant_docs = test_case["relevant"]
irrelevant_docs = test_case["irrelevant"]
# Create candidate pool
all_docs = relevant_docs + irrelevant_docs
doc_labels = [1] * len(relevant_docs) + [0] * len(irrelevant_docs)
# Step 1: Retrieval
query_inputs = retriever_tokenizer(
query, return_tensors='pt', padding=True,
truncation=True, max_length=512
).to(self.device)
query_emb = retriever(**query_inputs, return_dense=True)['dense']
# Score all documents with retriever
doc_scores = []
for doc in all_docs:
doc_inputs = retriever_tokenizer(
doc, return_tensors='pt', padding=True,
truncation=True, max_length=512
).to(self.device)
doc_emb = retriever(**doc_inputs, return_dense=True)['dense']
score = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
doc_scores.append(score)
# Get top-k for reranking (all docs in this case)
retrieval_indices = sorted(range(len(doc_scores)),
key=lambda i: doc_scores[i], reverse=True)
# Step 2: Reranking
rerank_scores = []
rerank_labels = []
for idx in retrieval_indices:
doc = all_docs[idx]
label = doc_labels[idx]
pair_text = f"{query} [SEP] {doc}"
pair_inputs = reranker_tokenizer(
pair_text, return_tensors='pt', padding=True,
truncation=True, max_length=512
).to(self.device)
rerank_output = reranker(**pair_inputs)
rerank_score = rerank_output.logits.item()
rerank_scores.append(rerank_score)
rerank_labels.append(label)
# Check if top-ranked document is relevant
best_idx = np.argmax(rerank_scores)
if rerank_labels[best_idx] == 1:
pipeline_accuracy += 1
total_queries += 1
pipeline_acc = pipeline_accuracy / total_queries if total_queries > 0 else 0
print(f"\n📊 Pipeline Validation Summary:")
print(f" Pipeline Accuracy: {pipeline_acc:.4f}")
print(f" Queries Tested: {total_queries}")
if pipeline_acc > 0.5:
print(" ✅ Pipeline performing well!")
else:
print(" ❌ Pipeline needs improvement!")
return {
'accuracy': pipeline_acc,
'total_queries': total_queries,
'status': 'good' if pipeline_acc > 0.5 else 'needs_improvement'
}
except Exception as e:
logger.error(f"Pipeline validation failed: {e}")
return None
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Quick validation of BGE fine-tuned models",
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("--retriever_model", type=str, default=None,
help="Path to fine-tuned retriever model")
parser.add_argument("--reranker_model", type=str, default=None,
help="Path to fine-tuned reranker model")
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
help="Path to baseline retriever model")
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
help="Path to baseline reranker model")
parser.add_argument("--test_pipeline", action="store_true",
help="Test full retrieval + reranking pipeline")
parser.add_argument("--output_file", type=str, default=None,
help="Save results to JSON file")
return parser.parse_args()
def main():
"""Main validation function"""
args = parse_args()
if not args.retriever_model and not args.reranker_model:
print("❌ Error: At least one model must be specified")
print(" Use --retriever_model or --reranker_model")
return 1
print("🚀 BGE Quick Validation")
print("=" * 60)
print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
validator = QuickValidator()
results = {}
# Validate retriever
if args.retriever_model:
if os.path.exists(args.retriever_model):
retriever_results = validator.validate_retriever(
args.retriever_model, args.retriever_baseline
)
if retriever_results:
results['retriever'] = retriever_results
else:
print(f"❌ Retriever model not found: {args.retriever_model}")
# Validate reranker
if args.reranker_model:
if os.path.exists(args.reranker_model):
reranker_results = validator.validate_reranker(
args.reranker_model, args.reranker_baseline
)
if reranker_results:
results['reranker'] = reranker_results
else:
print(f"❌ Reranker model not found: {args.reranker_model}")
# Validate pipeline
if (args.test_pipeline and args.retriever_model and args.reranker_model and
os.path.exists(args.retriever_model) and os.path.exists(args.reranker_model)):
pipeline_results = validator.validate_pipeline(
args.retriever_model, args.reranker_model
)
if pipeline_results:
results['pipeline'] = pipeline_results
# Print final summary
print(f"\n{'='*60}")
print("🎯 FINAL VALIDATION SUMMARY")
print(f"{'='*60}")
overall_status = "✅ SUCCESS"
issues = []
for model_type, model_results in results.items():
if 'summary' in model_results:
status = model_results['summary'].get('status', 'unknown')
improvement = model_results['summary'].get('improvement_pct', 0)
if status == 'improved':
print(f"{model_type.title()}: Improved by {improvement:.2f}%")
elif status == 'degraded':
print(f"{model_type.title()}: Degraded by {abs(improvement):.2f}%")
issues.append(model_type)
overall_status = "⚠️ ISSUES DETECTED"
elif model_type == 'pipeline' and 'status' in model_results:
if model_results['status'] == 'good':
print(f"✅ Pipeline: Working well (accuracy: {model_results['accuracy']:.3f})")
else:
print(f"❌ Pipeline: Needs improvement (accuracy: {model_results['accuracy']:.3f})")
issues.append('pipeline')
overall_status = "⚠️ ISSUES DETECTED"
print(f"\n{overall_status}")
if issues:
print(f"⚠️ Issues detected in: {', '.join(issues)}")
print("💡 Consider running comprehensive validation for detailed analysis")
# Save results if requested
if args.output_file:
results['timestamp'] = datetime.now().isoformat()
results['overall_status'] = overall_status
results['issues'] = issues
with open(args.output_file, 'w', encoding='utf-8') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"📊 Results saved to: {args.output_file}")
return 0 if overall_status == "✅ SUCCESS" else 1
if __name__ == "__main__":
exit(main())

696
scripts/run_validation_suite.py Normal file
View File

@ -0,0 +1,696 @@
#!/usr/bin/env python3
"""
BGE Validation Suite Runner
This script runs a complete validation suite for BGE fine-tuned models,
combining multiple validation approaches and generating comprehensive reports.
Usage:
# Auto-discover models and run full suite
python scripts/run_validation_suite.py --auto_discover
# Run with specific models
python scripts/run_validation_suite.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--reranker_model ./output/bge-reranker/final_model
# Run only specific validation types
python scripts/run_validation_suite.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--validation_types quick comprehensive comparison
"""
import os
import sys
import json
import argparse
import subprocess
import logging
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Optional, Any
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from scripts.validation_utils import ModelDiscovery, TestDataManager, ValidationReportAnalyzer
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class ValidationSuiteRunner:
"""Orchestrates comprehensive validation of BGE models"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.results_dir = Path(config.get('output_dir', './validation_suite_results'))
self.results_dir.mkdir(parents=True, exist_ok=True)
# Initialize components
self.model_discovery = ModelDiscovery(config.get('workspace_root', '.'))
self.data_manager = TestDataManager(config.get('data_dir', 'data/datasets'))
# Suite results
self.suite_results = {
'start_time': datetime.now().isoformat(),
'config': config,
'validation_results': {},
'summary': {},
'issues': [],
'recommendations': []
}
def run_suite(self) -> Dict[str, Any]:
"""Run the complete validation suite"""
logger.info("🚀 Starting BGE Validation Suite")
logger.info("=" * 60)
try:
# Phase 1: Discovery and Setup
self._phase_discovery()
# Phase 2: Quick Validation
if 'quick' in self.config.get('validation_types', ['quick']):
self._phase_quick_validation()
# Phase 3: Comprehensive Validation
if 'comprehensive' in self.config.get('validation_types', ['comprehensive']):
self._phase_comprehensive_validation()
# Phase 4: Comparison Benchmarks
if 'comparison' in self.config.get('validation_types', ['comparison']):
self._phase_comparison_benchmarks()
# Phase 5: Analysis and Reporting
self._phase_analysis()
# Phase 6: Generate Final Report
self._generate_final_report()
self.suite_results['end_time'] = datetime.now().isoformat()
self.suite_results['status'] = 'completed'
logger.info("✅ Validation Suite Completed Successfully!")
except Exception as e:
logger.error(f"❌ Validation Suite Failed: {e}")
self.suite_results['status'] = 'failed'
self.suite_results['error'] = str(e)
raise
return self.suite_results
def _phase_discovery(self):
"""Phase 1: Discover models and datasets"""
logger.info("📡 Phase 1: Discovery and Setup")
# Auto-discover models if requested
if self.config.get('auto_discover', False):
discovered_models = self.model_discovery.find_trained_models()
# Set models from discovery if not provided
if not self.config.get('retriever_model'):
if discovered_models['retrievers']:
model_name = list(discovered_models['retrievers'].keys())[0]
self.config['retriever_model'] = discovered_models['retrievers'][model_name]['path']
logger.info(f" 📊 Auto-discovered retriever: {model_name}")
if not self.config.get('reranker_model'):
if discovered_models['rerankers']:
model_name = list(discovered_models['rerankers'].keys())[0]
self.config['reranker_model'] = discovered_models['rerankers'][model_name]['path']
logger.info(f" 🏆 Auto-discovered reranker: {model_name}")
self.suite_results['discovered_models'] = discovered_models
# Discover test datasets
discovered_datasets = self.data_manager.find_test_datasets()
self.suite_results['available_datasets'] = discovered_datasets
# Validate models exist
issues = []
if self.config.get('retriever_model'):
if not os.path.exists(self.config['retriever_model']):
issues.append(f"Retriever model not found: {self.config['retriever_model']}")
if self.config.get('reranker_model'):
if not os.path.exists(self.config['reranker_model']):
issues.append(f"Reranker model not found: {self.config['reranker_model']}")
if issues:
for issue in issues:
logger.error(f"{issue}")
self.suite_results['issues'].extend(issues)
return
logger.info(" ✅ Discovery phase completed")
def _phase_quick_validation(self):
"""Phase 2: Quick validation tests"""
logger.info("⚡ Phase 2: Quick Validation")
try:
# Run quick validation script
cmd = ['python', 'scripts/quick_validation.py']
if self.config.get('retriever_model'):
cmd.extend(['--retriever_model', self.config['retriever_model']])
if self.config.get('reranker_model'):
cmd.extend(['--reranker_model', self.config['reranker_model']])
if self.config.get('retriever_model') and self.config.get('reranker_model'):
cmd.append('--test_pipeline')
# Save results to file
quick_results_file = self.results_dir / 'quick_validation_results.json'
cmd.extend(['--output_file', str(quick_results_file)])
logger.info(f" Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
logger.info(" ✅ Quick validation completed")
# Load results
if quick_results_file.exists():
with open(quick_results_file, 'r') as f:
quick_results = json.load(f)
self.suite_results['validation_results']['quick'] = quick_results
else:
logger.error(f" ❌ Quick validation failed: {result.stderr}")
self.suite_results['issues'].append(f"Quick validation failed: {result.stderr}")
except Exception as e:
logger.error(f" ❌ Error in quick validation: {e}")
self.suite_results['issues'].append(f"Quick validation error: {e}")
def _phase_comprehensive_validation(self):
"""Phase 3: Comprehensive validation"""
logger.info("🔬 Phase 3: Comprehensive Validation")
try:
# Run comprehensive validation script
cmd = ['python', 'scripts/comprehensive_validation.py']
if self.config.get('retriever_model'):
cmd.extend(['--retriever_finetuned', self.config['retriever_model']])
if self.config.get('reranker_model'):
cmd.extend(['--reranker_finetuned', self.config['reranker_model']])
# Use specific datasets if provided
if self.config.get('test_datasets'):
cmd.extend(['--test_datasets'] + self.config['test_datasets'])
# Set output directory
comprehensive_dir = self.results_dir / 'comprehensive'
cmd.extend(['--output_dir', str(comprehensive_dir)])
logger.info(f" Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
logger.info(" ✅ Comprehensive validation completed")
# Load results
results_file = comprehensive_dir / 'validation_results.json'
if results_file.exists():
with open(results_file, 'r') as f:
comp_results = json.load(f)
self.suite_results['validation_results']['comprehensive'] = comp_results
else:
logger.error(f" ❌ Comprehensive validation failed: {result.stderr}")
self.suite_results['issues'].append(f"Comprehensive validation failed: {result.stderr}")
except Exception as e:
logger.error(f" ❌ Error in comprehensive validation: {e}")
self.suite_results['issues'].append(f"Comprehensive validation error: {e}")
def _phase_comparison_benchmarks(self):
"""Phase 4: Run comparison benchmarks"""
logger.info("📊 Phase 4: Comparison Benchmarks")
# Find suitable datasets for comparison
datasets = self.data_manager.find_test_datasets()
try:
# Run retriever comparison if model exists
if self.config.get('retriever_model'):
self._run_retriever_comparison(datasets)
# Run reranker comparison if model exists
if self.config.get('reranker_model'):
self._run_reranker_comparison(datasets)
except Exception as e:
logger.error(f" ❌ Error in comparison benchmarks: {e}")
self.suite_results['issues'].append(f"Comparison benchmarks error: {e}")
def _run_retriever_comparison(self, datasets: Dict[str, List[str]]):
"""Run retriever comparison benchmarks"""
logger.info(" 🔍 Running retriever comparison...")
# Find suitable datasets for retriever testing
test_datasets = datasets.get('retriever_datasets', []) + datasets.get('general_datasets', [])
if not test_datasets:
logger.warning(" ⚠️ No suitable datasets found for retriever comparison")
return
for dataset_path in test_datasets[:2]: # Limit to 2 datasets for speed
if not os.path.exists(dataset_path):
continue
try:
cmd = [
'python', 'scripts/compare_retriever.py',
'--finetuned_model_path', self.config['retriever_model'],
'--baseline_model_path', self.config.get('retriever_baseline', 'BAAI/bge-m3'),
'--data_path', dataset_path,
'--batch_size', str(self.config.get('batch_size', 16)),
'--max_samples', str(self.config.get('max_samples', 500)),
'--output', str(self.results_dir / f'retriever_comparison_{Path(dataset_path).stem}.txt')
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
logger.info(f" ✅ Retriever comparison on {Path(dataset_path).name}")
else:
logger.warning(f" ⚠️ Retriever comparison failed on {Path(dataset_path).name}")
except Exception as e:
logger.warning(f" ⚠️ Error comparing retriever on {dataset_path}: {e}")
def _run_reranker_comparison(self, datasets: Dict[str, List[str]]):
"""Run reranker comparison benchmarks"""
logger.info(" 🏆 Running reranker comparison...")
# Find suitable datasets for reranker testing
test_datasets = datasets.get('reranker_datasets', []) + datasets.get('general_datasets', [])
if not test_datasets:
logger.warning(" ⚠️ No suitable datasets found for reranker comparison")
return
for dataset_path in test_datasets[:2]: # Limit to 2 datasets for speed
if not os.path.exists(dataset_path):
continue
try:
cmd = [
'python', 'scripts/compare_reranker.py',
'--finetuned_model_path', self.config['reranker_model'],
'--baseline_model_path', self.config.get('reranker_baseline', 'BAAI/bge-reranker-base'),
'--data_path', dataset_path,
'--batch_size', str(self.config.get('batch_size', 16)),
'--max_samples', str(self.config.get('max_samples', 500)),
'--output', str(self.results_dir / f'reranker_comparison_{Path(dataset_path).stem}.txt')
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
logger.info(f" ✅ Reranker comparison on {Path(dataset_path).name}")
else:
logger.warning(f" ⚠️ Reranker comparison failed on {Path(dataset_path).name}")
except Exception as e:
logger.warning(f" ⚠️ Error comparing reranker on {dataset_path}: {e}")
def _phase_analysis(self):
"""Phase 5: Analyze all results"""
logger.info("📈 Phase 5: Results Analysis")
try:
# Analyze comprehensive results if available
comprehensive_dir = self.results_dir / 'comprehensive'
if comprehensive_dir.exists():
analyzer = ValidationReportAnalyzer(str(comprehensive_dir))
analysis = analyzer.analyze_results()
self.suite_results['analysis'] = analysis
logger.info(" ✅ Results analysis completed")
else:
logger.warning(" ⚠️ No comprehensive results found for analysis")
except Exception as e:
logger.error(f" ❌ Error in results analysis: {e}")
self.suite_results['issues'].append(f"Analysis error: {e}")
def _generate_final_report(self):
"""Generate final comprehensive report"""
logger.info("📝 Phase 6: Generating Final Report")
try:
# Create summary
self._create_suite_summary()
# Save suite results
suite_results_file = self.results_dir / 'validation_suite_results.json'
with open(suite_results_file, 'w', encoding='utf-8') as f:
json.dump(self.suite_results, f, indent=2, ensure_ascii=False)
# Generate HTML report
self._generate_html_report()
logger.info(f" 📊 Reports saved to: {self.results_dir}")
except Exception as e:
logger.error(f" ❌ Error generating final report: {e}")
self.suite_results['issues'].append(f"Report generation error: {e}")
def _create_suite_summary(self):
"""Create suite summary"""
summary = {
'total_validations': 0,
'successful_validations': 0,
'failed_validations': 0,
'overall_verdict': 'unknown',
'key_findings': [],
'critical_issues': []
}
# Count validations
for validation_type, results in self.suite_results.get('validation_results', {}).items():
summary['total_validations'] += 1
if results:
summary['successful_validations'] += 1
else:
summary['failed_validations'] += 1
# Determine overall verdict
if summary['successful_validations'] > 0 and summary['failed_validations'] == 0:
if self.suite_results.get('analysis', {}).get('overall_status') == 'excellent':
summary['overall_verdict'] = 'excellent'
elif self.suite_results.get('analysis', {}).get('overall_status') in ['good', 'mixed']:
summary['overall_verdict'] = 'good'
else:
summary['overall_verdict'] = 'fair'
elif summary['successful_validations'] > summary['failed_validations']:
summary['overall_verdict'] = 'partial'
else:
summary['overall_verdict'] = 'poor'
# Extract key findings from analysis
if 'analysis' in self.suite_results:
analysis = self.suite_results['analysis']
if 'recommendations' in analysis:
summary['key_findings'] = analysis['recommendations'][:5] # Top 5 recommendations
# Extract critical issues
summary['critical_issues'] = self.suite_results.get('issues', [])
self.suite_results['summary'] = summary
def _generate_html_report(self):
"""Generate HTML summary report"""
html_file = self.results_dir / 'validation_suite_report.html'
summary = self.suite_results.get('summary', {})
verdict_colors = {
'excellent': '#28a745',
'good': '#17a2b8',
'fair': '#ffc107',
'partial': '#fd7e14',
'poor': '#dc3545',
'unknown': '#6c757d'
}
verdict_color = verdict_colors.get(summary.get('overall_verdict', 'unknown'), '#6c757d')
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>BGE Validation Suite Report</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; line-height: 1.6; }}
.header {{ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white; padding: 30px; border-radius: 10px; text-align: center; }}
.verdict {{ font-size: 24px; font-weight: bold; color: {verdict_color};
margin: 20px 0; text-align: center; padding: 15px;
border: 3px solid {verdict_color}; border-radius: 10px; }}
.section {{ margin: 30px 0; }}
.metrics {{ display: flex; justify-content: space-around; margin: 20px 0; }}
.metric {{ text-align: center; padding: 15px; background: #f8f9fa;
border-radius: 8px; min-width: 120px; }}
.metric-value {{ font-size: 2em; font-weight: bold; color: #495057; }}
.metric-label {{ color: #6c757d; }}
.findings {{ background: #e3f2fd; padding: 20px; border-radius: 8px;
border-left: 4px solid #2196f3; }}
.issues {{ background: #ffebee; padding: 20px; border-radius: 8px;
border-left: 4px solid #f44336; }}
.timeline {{ margin: 20px 0; }}
.timeline-item {{ margin: 10px 0; padding: 10px; background: #f8f9fa;
border-left: 3px solid #007bff; border-radius: 0 5px 5px 0; }}
ul {{ padding-left: 20px; }}
li {{ margin: 8px 0; }}
</style>
</head>
<body>
<div class="header">
<h1>🚀 BGE Validation Suite Report</h1>
<p>Comprehensive Model Performance Analysis</p>
<p>Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</div>
<div class="verdict">
Overall Verdict: {summary.get('overall_verdict', 'UNKNOWN').upper()}
</div>
<div class="metrics">
<div class="metric">
<div class="metric-value">{summary.get('total_validations', 0)}</div>
<div class="metric-label">Total Validations</div>
</div>
<div class="metric">
<div class="metric-value">{summary.get('successful_validations', 0)}</div>
<div class="metric-label">Successful</div>
</div>
<div class="metric">
<div class="metric-value">{summary.get('failed_validations', 0)}</div>
<div class="metric-label">Failed</div>
</div>
</div>
"""
# Add key findings
if summary.get('key_findings'):
html_content += """
<div class="section">
<h2>🔍 Key Findings</h2>
<div class="findings">
<ul>
"""
for finding in summary['key_findings']:
html_content += f" <li>{finding}</li>\n"
html_content += """
</ul>
</div>
</div>
"""
# Add critical issues if any
if summary.get('critical_issues'):
html_content += """
<div class="section">
<h2>⚠️ Critical Issues</h2>
<div class="issues">
<ul>
"""
for issue in summary['critical_issues']:
html_content += f" <li>{issue}</li>\n"
html_content += """
</ul>
</div>
</div>
"""
# Add validation timeline
html_content += f"""
<div class="section">
<h2>⏱️ Validation Timeline</h2>
<div class="timeline">
<div class="timeline-item">
<strong>Suite Started:</strong> {self.suite_results.get('start_time', 'Unknown')}
</div>
"""
for validation_type in self.suite_results.get('validation_results', {}).keys():
html_content += f"""
<div class="timeline-item">
<strong>{validation_type.title()} Validation:</strong> Completed
</div>
"""
html_content += f"""
<div class="timeline-item">
<strong>Suite Completed:</strong> {self.suite_results.get('end_time', 'In Progress')}
</div>
</div>
</div>
<div class="section">
<h2>📁 Detailed Results</h2>
<p>For detailed validation results, check the following files in <code>{self.results_dir}</code>:</p>
<ul>
<li><code>validation_suite_results.json</code> - Complete suite results</li>
<li><code>quick_validation_results.json</code> - Quick validation results</li>
<li><code>comprehensive/</code> - Comprehensive validation reports</li>
<li><code>*_comparison_*.txt</code> - Model comparison results</li>
</ul>
</div>
</body>
</html>
"""
with open(html_file, 'w', encoding='utf-8') as f:
f.write(html_content)
logger.info(f" 📊 HTML report: {html_file}")
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Run comprehensive BGE validation suite",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Auto-discover models and run full suite
python scripts/run_validation_suite.py --auto_discover
# Run with specific models
python scripts/run_validation_suite.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--reranker_model ./output/bge-reranker/final_model
# Run only specific validation types
python scripts/run_validation_suite.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--validation_types quick comprehensive
"""
)
# Model paths
parser.add_argument("--retriever_model", type=str, default=None,
help="Path to fine-tuned retriever model")
parser.add_argument("--reranker_model", type=str, default=None,
help="Path to fine-tuned reranker model")
parser.add_argument("--auto_discover", action="store_true",
help="Auto-discover trained models in workspace")
# Baseline models
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
help="Baseline retriever model")
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
help="Baseline reranker model")
# Validation configuration
parser.add_argument("--validation_types", type=str, nargs="+",
default=["quick", "comprehensive", "comparison"],
choices=["quick", "comprehensive", "comparison"],
help="Types of validation to run")
parser.add_argument("--test_datasets", type=str, nargs="+", default=None,
help="Specific test datasets to use")
# Performance settings
parser.add_argument("--batch_size", type=int, default=16,
help="Batch size for evaluation")
parser.add_argument("--max_samples", type=int, default=1000,
help="Maximum samples per dataset for speed")
# Directories
parser.add_argument("--workspace_root", type=str, default=".",
help="Workspace root directory")
parser.add_argument("--data_dir", type=str, default="data/datasets",
help="Test datasets directory")
parser.add_argument("--output_dir", type=str, default="./validation_suite_results",
help="Output directory for all results")
return parser.parse_args()
def main():
"""Main function"""
args = parse_args()
# Validate input
if not args.auto_discover and not args.retriever_model and not args.reranker_model:
print("❌ Error: Must specify models or use --auto_discover")
print(" Use --retriever_model, --reranker_model, or --auto_discover")
return 1
# Create configuration
config = {
'retriever_model': args.retriever_model,
'reranker_model': args.reranker_model,
'auto_discover': args.auto_discover,
'retriever_baseline': args.retriever_baseline,
'reranker_baseline': args.reranker_baseline,
'validation_types': args.validation_types,
'test_datasets': args.test_datasets,
'batch_size': args.batch_size,
'max_samples': args.max_samples,
'workspace_root': args.workspace_root,
'data_dir': args.data_dir,
'output_dir': args.output_dir
}
# Run validation suite
try:
runner = ValidationSuiteRunner(config)
results = runner.run_suite()
# Print final summary
summary = results.get('summary', {})
verdict = summary.get('overall_verdict', 'unknown')
print("\n" + "=" * 80)
print("🎯 VALIDATION SUITE SUMMARY")
print("=" * 80)
verdict_emojis = {
'excellent': '🌟',
'good': '✅',
'fair': '👌',
'partial': '⚠️',
'poor': '❌',
'unknown': '❓'
}
print(f"{verdict_emojis.get(verdict, '')} Overall Verdict: {verdict.upper()}")
print(f"📊 Validations: {summary.get('successful_validations', 0)}/{summary.get('total_validations', 0)} successful")
if summary.get('key_findings'):
print(f"\n🔍 Key Findings:")
for i, finding in enumerate(summary['key_findings'][:3], 1):
print(f" {i}. {finding}")
if summary.get('critical_issues'):
print(f"\n⚠️ Critical Issues:")
for issue in summary['critical_issues'][:3]:
print(f"{issue}")
print(f"\n📁 Detailed results: {config['output_dir']}")
print("🌐 Open validation_suite_report.html in your browser for full report")
return 0 if verdict in ['excellent', 'good'] else 1
except KeyboardInterrupt:
print("\n❌ Validation suite interrupted by user")
return 1
except Exception as e:
print(f"\n❌ Validation suite failed: {e}")
return 1
if __name__ == "__main__":
exit(main())

401
scripts/setup.py Normal file
View File

@ -0,0 +1,401 @@
"""
Setup script for BGE fine-tuning environment
"""
import os
import sys
import subprocess
import argparse
import shutil
import urllib.request
import json
from pathlib import Path
import logging
from utils.config_loader import get_config
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def run_command(command, check=True):
"""Run a shell command"""
logger.info(f"Running: {command}")
result = subprocess.run(command, shell=True, capture_output=True, text=True)
if check and result.returncode != 0:
logger.error(f"Command failed: {command}")
logger.error(f"Error: {result.stderr}")
sys.exit(1)
return result
def check_python_version():
"""Check Python version"""
if sys.version_info < (3, 8):
logger.error("Python 3.8 or higher is required")
sys.exit(1)
logger.info(f"Python version: {sys.version}")
def check_gpu():
"""Check GPU availability"""
try:
result = run_command("nvidia-smi", check=False)
if result.returncode == 0:
logger.info("NVIDIA GPU detected")
return "cuda"
else:
logger.info("No NVIDIA GPU detected")
return "cpu"
except Exception:
logger.info("nvidia-smi not available")
return "cpu"
def install_pytorch(device_type="cuda"):
"""Install PyTorch"""
logger.info("Installing PyTorch...")
if device_type == "cuda":
# Install CUDA version for RTX 5090
pip_command = "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128"
else:
# Install CPU version
pip_command = "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu"
run_command(pip_command)
# Verify installation
try:
import torch
logger.info(f"PyTorch version: {torch.__version__}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
logger.info(f"CUDA version: {torch.version.cuda}")
logger.info(f"GPU count: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
except ImportError:
logger.error("Failed to import PyTorch")
sys.exit(1)
def install_requirements():
"""Install requirements"""
logger.info("Installing requirements...")
requirements_file = Path(__file__).parent.parent / "requirements.txt"
if requirements_file.exists():
run_command(f"pip install -r {requirements_file}")
else:
logger.warning("requirements.txt not found, installing core packages...")
# Core packages
packages = [
"transformers>=4.36.0",
"datasets>=2.14.0",
"tokenizers>=0.15.0",
"numpy>=1.24.0",
"scipy>=1.10.0",
"scikit-learn>=1.3.0",
"pandas>=2.0.0",
"FlagEmbedding>=1.2.0",
"faiss-gpu>=1.7.4",
"rank-bm25>=0.2.2",
"accelerate>=0.25.0",
"tensorboard>=2.14.0",
"tqdm>=4.66.0",
"pyyaml>=6.0",
"jsonlines>=3.1.0",
"huggingface-hub>=0.19.0",
"jieba>=0.42.1",
"toml>=0.10.2",
"requests>=2.28.0"
]
for package in packages:
run_command(f"pip install {package}")
def setup_directories():
"""Setup project directories"""
logger.info("Setting up directories...")
base_dir = Path(__file__).parent.parent
directories = [
"models",
"models/bge-m3",
"models/bge-reranker-base",
"data/raw",
"data/processed",
"output",
"logs",
"cache/data"
]
for directory in directories:
dir_path = base_dir / directory
dir_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Created directory: {dir_path}")
def download_models(download_models_flag=False):
"""Download base models from HuggingFace or ModelScope based on config.toml"""
if not download_models_flag:
logger.info("Skipping model download. Use --download-models to download base models.")
return
logger.info("Downloading base models...")
base_dir = Path(__file__).parent.parent
config = get_config()
model_source = config.get('model_paths.source', 'huggingface')
if model_source == 'huggingface':
try:
from huggingface_hub import snapshot_download
# Download BGE-M3
logger.info("Downloading BGE-M3 from HuggingFace...")
snapshot_download(
repo_id="BAAI/bge-m3",
local_dir=base_dir / "models/bge-m3",
local_dir_use_symlinks=False
)
# Download BGE-Reranker
logger.info("Downloading BGE-Reranker from HuggingFace...")
snapshot_download(
repo_id="BAAI/bge-reranker-base",
local_dir=base_dir / "models/bge-reranker-base",
local_dir_use_symlinks=False
)
logger.info("Models downloaded successfully from HuggingFace!")
except ImportError:
logger.error("huggingface_hub not available. Install it first: pip install huggingface_hub")
except Exception as e:
logger.error(f"Failed to download models from HuggingFace: {e}")
logger.info("You can download models manually or use the Hugging Face CLI:")
logger.info(" huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3")
logger.info(" huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base")
elif model_source == 'modelscope':
try:
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
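# NOTE: the model IDs below point to DAMO's CoROM sentence-embedding / cross-encoder models,
# not the BAAI BGE checkpoints; swap in the corresponding BGE model IDs on ModelScope if the
# downloads should match the HuggingFace branch above.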
# Download BGE-M3
logger.info("Downloading BGE-M3 from ModelScope...")
ms_snapshot_download(
model_id="damo/nlp_corom_sentence-embedding_chinese-base",
cache_dir=str(base_dir / "models/bge-m3")
)
# Download BGE-Reranker
logger.info("Downloading BGE-Reranker from ModelScope...")
ms_snapshot_download(
model_id="damo/nlp_corom_cross-encoder_chinese-base",
cache_dir=str(base_dir / "models/bge-reranker-base")
)
logger.info("Models downloaded successfully from ModelScope!")
except ImportError:
logger.error("modelscope not available. Install it first: pip install modelscope")
except Exception as e:
logger.error(f"Failed to download models from ModelScope: {e}")
logger.info("You can download models manually from ModelScope website or CLI.")
else:
logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
def create_sample_config():
"""Create sample configuration files"""
logger.info("Creating sample configuration...")
base_dir = Path(__file__).parent.parent
config_file = base_dir / "config.toml"
if not config_file.exists():
# Create default config (this would match our config.toml artifact)
sample_config = """# Configuration file for BGE fine-tuning project
[api]
# OpenAI-compatible API configuration (e.g., DeepSeek, OpenAI, etc.)
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here" # Replace with your actual API key
model = "deepseek-chat"
timeout = 30
[translation]
# Translation settings
enable_api = true
target_languages = ["en", "zh", "ja", "fr"]
max_retries = 3
fallback_to_simulation = true
[training]
# Training defaults
default_batch_size = 4
default_learning_rate = 1e-5
default_epochs = 3
default_warmup_ratio = 0.1
[model_paths]
# Default model paths
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
[data]
# Data processing settings
max_query_length = 512
max_passage_length = 8192
min_passage_length = 10
validation_split = 0.1
"""
with open(config_file, 'w') as f:
f.write(sample_config)
logger.info(f"Created sample config: {config_file}")
logger.info("Please edit config.toml to add your API keys if needed.")
def create_sample_data():
"""Create sample data files"""
logger.info("Creating sample data...")
base_dir = Path(__file__).parent.parent
sample_file = base_dir / "data/raw/sample_data.jsonl"
if not sample_file.exists():
sample_data = [
{
"query": "什么是机器学习?",
"pos": ["机器学习是人工智能的一个分支,它使用算法让计算机能够从数据中学习和做出决策。"],
"neg": ["深度学习是神经网络的一种形式。", "自然语言处理是计算机科学的一个领域。"]
},
{
"query": "How does neural network work?",
"pos": [
"Neural networks are computing systems inspired by biological neural networks that constitute the brain."],
"neg": ["Machine learning algorithms can be supervised or unsupervised.",
"Data preprocessing is an important step in ML."]
},
{
"query": "Python编程语言的特点",
"pos": ["Python是一种高级编程语言具有简洁的语法和强大的功能广泛用于数据科学和AI开发。"],
"neg": ["JavaScript是一种脚本语言。", "Java是一种面向对象的编程语言。"]
}
]
with open(sample_file, 'w', encoding='utf-8') as f:
for item in sample_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
logger.info(f"Created sample data: {sample_file}")
def verify_installation():
"""Verify installation"""
logger.info("Verifying installation...")
try:
# Test imports
import torch
import transformers
import numpy as np
import pandas as pd
logger.info("✓ Core packages imported successfully")
# Test GPU
if torch.cuda.is_available():
logger.info("✓ CUDA is available")
device = torch.device("cuda")
x = torch.randn(10, 10).to(device)
logger.info("✓ GPU tensor operations work")
else:
logger.info("✓ CPU-only setup")
# Test model loading (if models exist)
base_dir = Path(__file__).parent.parent
if (base_dir / "models/bge-m3").exists():
try:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(str(base_dir / "models/bge-m3"))
logger.info("✓ BGE-M3 model files are accessible")
except Exception as e:
logger.warning(f"BGE-M3 model check failed: {e}")
logger.info("Installation verification completed!")
except ImportError as e:
logger.error(f"Import error: {e}")
logger.error("Installation may be incomplete")
sys.exit(1)
def print_usage_info():
"""Print usage information"""
print("\n" + "=" * 60)
print("BGE FINE-TUNING SETUP COMPLETED!")
print("=" * 60)
print("\nQuick Start:")
print("1. Edit config.toml to add your API keys (optional)")
print("2. Prepare your training data in JSONL format")
print("3. Run training:")
print(" python scripts/train_m3.py --train_data data/raw/your_data.jsonl")
print(" python scripts/train_reranker.py --train_data data/raw/your_data.jsonl")
print("4. Evaluate models:")
print(" python scripts/evaluate.py --retriever_model output/model --eval_data data/raw/test.jsonl")
print("\nSample data created in: data/raw/sample_data.jsonl")
print("Configuration file: config.toml")
print("\nFor more information, see the README.md file.")
print("=" * 60)
def main():
parser = argparse.ArgumentParser(description="Setup BGE fine-tuning environment")
parser.add_argument("--download-models", action="store_true",
help="Download base models from Hugging Face")
parser.add_argument("--cpu-only", action="store_true",
help="Install CPU-only version of PyTorch")
parser.add_argument("--skip-pytorch", action="store_true",
help="Skip PyTorch installation")
parser.add_argument("--skip-requirements", action="store_true",
help="Skip requirements installation")
args = parser.parse_args()
logger.info("Starting BGE fine-tuning environment setup...")
# Check Python version
check_python_version()
# Detect device type
if args.cpu_only:
device_type = "cpu"
else:
device_type = check_gpu()
# Install PyTorch
if not args.skip_pytorch:
install_pytorch(device_type)
# Install requirements
if not args.skip_requirements:
install_requirements()
# Setup directories
setup_directories()
# Download models
download_models(args.download_models)
# Create sample files
create_sample_config()
create_sample_data()
# Verify installation
verify_installation()
# Print usage info
print_usage_info()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,91 @@
#!/usr/bin/env python3
"""
Test script to verify that the converted dataset can be loaded by the BGE training system
"""
from data.dataset import BGEDataset
from transformers import AutoTokenizer
import logging
import sys
import os
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_dataset_loading(dataset_path: str):
"""Test that the dataset can be loaded properly"""
# Check if dataset file exists
if not os.path.exists(dataset_path):
logger.error(f"Dataset file not found: {dataset_path}")
return False
try:
# Load tokenizer
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('models/bge-m3')
# Test loading the converted dataset
logger.info("Loading dataset...")
dataset = BGEDataset(
data_path=dataset_path,
tokenizer=tokenizer,
query_max_length=64,
passage_max_length=512,
format_type='triplets'
)
logger.info(f"✅ Dataset loaded successfully!")
logger.info(f"📊 Total samples: {len(dataset)}")
logger.info(f"🔍 Detected format: {dataset.detected_format}")
# Test getting a sample
logger.info("Testing sample retrieval...")
sample = dataset[0]
logger.info(f"🎯 Sample keys: {list(sample.keys())}")
logger.info(f"🔤 Query shape: {sample['query_input_ids'].shape}")
logger.info(f"📄 Passages shape: {sample['passage_input_ids'].shape}")
logger.info(f"🏷️ Labels shape: {sample['labels'].shape}")
logger.info(f"📊 Scores shape: {sample['scores'].shape}")
# Show sample content
logger.info(f"🔍 Query text: {sample['query_text'][:100]}...")
logger.info(f"📝 First passage: {sample['passage_texts'][0][:100]}...")
logger.info(f"🏷️ Labels: {sample['labels']}")
logger.info(f"📊 Scores: {sample['scores']}")
# Test a few more samples
logger.info("Testing multiple samples...")
for i in [1, 10, 100]:
if i < len(dataset):
sample_i = dataset[i]
logger.info(f"Sample {i}: Query length={len(sample_i['query_text'])}, Passages={len([p for p in sample_i['passage_texts'] if p])}")
logger.info("✅ All tests passed! Dataset is ready for BGE-M3 training.")
return True
except Exception as e:
logger.error(f"❌ Dataset loading failed: {e}")
import traceback
traceback.print_exc()
return False
def main():
dataset_path = "data/datasets/三国演义/m3_train_converted.jsonl"
logger.info("🧪 Testing converted BGE-M3 dataset...")
logger.info(f"📁 Dataset path: {dataset_path}")
success = test_dataset_loading(dataset_path)
if success:
logger.info("🎉 Dataset is ready for training!")
logger.info("💡 You can now use this dataset with:")
logger.info(" python scripts/train_m3.py --data_path data/datasets/三国演义/m3_train_converted.jsonl")
else:
logger.error("💥 Dataset testing failed!")
sys.exit(1)
if __name__ == '__main__':
main()

497
scripts/train_joint.py Normal file
View File

@ -0,0 +1,497 @@
"""
Joint training of BGE-M3 and BGE-Reranker using RocketQAv2 approach
"""
import os
import sys
import argparse
# Set tokenizer parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import logging
import json
from datetime import datetime
import torch
from transformers import AutoTokenizer, set_seed
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config.m3_config import BGEM3Config
from config.reranker_config import BGERerankerConfig
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset, JointDataset
from data.preprocessing import DataPreprocessor
from training.joint_trainer import JointTrainer
from utils.logging import setup_logging, get_logger
from utils.distributed import setup_distributed, is_main_process, set_random_seed
logger = logging.getLogger(__name__)
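# Illustrative sketch of the RocketQAv2-style dynamic listwise distillation enabled by
# --use_dynamic_distillation: the retriever's distribution over each query's candidates is
# pulled toward the (detached) reranker's distribution. This helper is an assumption for
# illustration only and is NOT the loss actually used by JointTrainer in training/joint_trainer.py.
def _listwise_distillation_sketch(retriever_scores, reranker_scores, temperature=0.02):
    """KL(reranker || retriever) over one query's candidate list (illustration only)."""
    import torch.nn.functional as F
    teacher = F.softmax(reranker_scores.detach() / temperature, dim=-1)
    student = F.log_softmax(retriever_scores / temperature, dim=-1)
    return F.kl_div(student, teacher, reduction="batchmean")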
def parse_args():
parser = argparse.ArgumentParser(description="Joint training of BGE-M3 and BGE-Reranker")
# Model arguments
parser.add_argument("--retriever_model", type=str, default="BAAI/bge-m3",
help="Path to retriever model")
parser.add_argument("--reranker_model", type=str, default="BAAI/bge-reranker-base",
help="Path to reranker model")
parser.add_argument("--cache_dir", type=str, default="./cache/data",
help="Directory to cache downloaded models")
# Data arguments
parser.add_argument("--train_data", type=str, required=True,
help="Path to training data file")
parser.add_argument("--eval_data", type=str, default=None,
help="Path to evaluation data file")
parser.add_argument("--data_format", type=str, default="auto",
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
help="Format of input data")
# Training arguments
parser.add_argument("--output_dir", type=str, default="./output/joint-training",
help="Directory to save the models")
parser.add_argument("--num_train_epochs", type=int, default=3,
help="Number of training epochs")
parser.add_argument("--per_device_train_batch_size", type=int, default=4,
help="Batch size per GPU/CPU for training")
parser.add_argument("--per_device_eval_batch_size", type=int, default=8,
help="Batch size per GPU/CPU for evaluation")
parser.add_argument("--gradient_accumulation_steps", type=int, default=4,
help="Number of updates steps to accumulate")
parser.add_argument("--learning_rate", type=float, default=1e-5,
help="Initial learning rate for retriever")
parser.add_argument("--reranker_learning_rate", type=float, default=6e-5,
help="Initial learning rate for reranker")
parser.add_argument("--warmup_ratio", type=float, default=0.1,
help="Ratio of warmup steps")
parser.add_argument("--weight_decay", type=float, default=0.01,
help="Weight decay")
# Joint training specific
parser.add_argument("--joint_loss_weight", type=float, default=0.3,
help="Weight for joint distillation loss")
parser.add_argument("--update_retriever_steps", type=int, default=100,
help="Update retriever every N steps")
parser.add_argument("--use_dynamic_distillation", action="store_true",
help="Use dynamic listwise distillation")
# Model configuration
parser.add_argument("--retriever_train_group_size", type=int, default=8,
help="Number of passages per query for retriever")
parser.add_argument("--reranker_train_group_size", type=int, default=16,
help="Number of passages per query for reranker")
parser.add_argument("--temperature", type=float, default=0.02,
help="Temperature for InfoNCE loss")
# BGE-M3 specific
parser.add_argument("--use_dense", action="store_true", default=True,
help="Use dense embeddings")
parser.add_argument("--use_sparse", action="store_true",
help="Use sparse embeddings")
parser.add_argument("--use_colbert", action="store_true",
help="Use ColBERT embeddings")
parser.add_argument("--query_max_length", type=int, default=512,
help="Maximum query length")
parser.add_argument("--passage_max_length", type=int, default=8192,
help="Maximum passage length")
# Reranker specific
parser.add_argument("--reranker_max_length", type=int, default=512,
help="Maximum sequence length for reranker")
# Performance
parser.add_argument("--fp16", action="store_true",
help="Use mixed precision training")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="Use gradient checkpointing")
parser.add_argument("--dataloader_num_workers", type=int, default=4,
help="Number of workers for data loading")
# Checkpointing
parser.add_argument("--save_steps", type=int, default=1000,
help="Save checkpoint every N steps")
parser.add_argument("--eval_steps", type=int, default=1000,
help="Evaluate every N steps")
parser.add_argument("--logging_steps", type=int, default=100,
help="Log every N steps")
parser.add_argument("--save_total_limit", type=int, default=3,
help="Maximum number of checkpoints to keep")
parser.add_argument("--resume_from_checkpoint", type=str, default=None,
help="Path to checkpoint to resume from")
# Distributed training
parser.add_argument("--local_rank", type=int, default=-1,
help="Local rank for distributed training")
# Other
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
parser.add_argument("--overwrite_output_dir", action="store_true",
help="Overwrite the output directory")
args = parser.parse_args()
# Validate arguments
if os.path.exists(args.output_dir) and not args.overwrite_output_dir:
raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.")
return args
def main():
# Parse arguments
args = parse_args()
# Setup logging
setup_logging(
output_dir=args.output_dir,
log_level=logging.INFO,
log_to_console=True,
log_to_file=True
)
logger = get_logger(__name__)
logger.info("Starting joint training of BGE-M3 and BGE-Reranker")
logger.info(f"Arguments: {args}")
# Set random seed
set_random_seed(args.seed, deterministic=False)
# Setup distributed training if needed
device = None
if args.local_rank != -1:
setup_distributed()
from utils.config_loader import get_config
config = get_config()
device_type = config.get('ascend.device_type', None) # type: ignore[attr-defined]
if device_type == 'npu':
try:
import torch_npu
torch_npu.npu.set_device(args.local_rank)
device = torch.device("npu", args.local_rank)
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
except ImportError:
logger.warning("torch_npu not installed, falling back to CPU")
device = torch.device("cpu")
elif torch.cuda.is_available():
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
else:
device = torch.device("cpu")
logger.info("Distributed training enabled: using CPU")
else:
from utils.config_loader import get_config
config = get_config()
device_type = config.get('ascend.device_type', None) # type: ignore[attr-defined]
if device_type == 'npu':
try:
import torch_npu
device = torch.device("npu")
logger.info("Using Ascend NPU")
except ImportError:
logger.warning("torch_npu not installed, falling back to CPU")
device = torch.device("cpu")
elif torch.cuda.is_available():
device = torch.device("cuda")
logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
else:
device = torch.device("cpu")
logger.info("Using CPU")
# Create configurations
retriever_config = BGEM3Config(
model_name_or_path=args.retriever_model,
cache_dir=args.cache_dir,
train_data_path=args.train_data,
eval_data_path=args.eval_data,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
output_dir=os.path.join(args.output_dir, "retriever"),
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
warmup_ratio=args.warmup_ratio,
weight_decay=args.weight_decay,
train_group_size=args.retriever_train_group_size,
temperature=args.temperature,
use_dense=args.use_dense,
use_sparse=args.use_sparse,
use_colbert=args.use_colbert,
fp16=args.fp16,
gradient_checkpointing=args.gradient_checkpointing,
dataloader_num_workers=args.dataloader_num_workers,
save_steps=args.save_steps,
eval_steps=args.eval_steps,
logging_steps=args.logging_steps,
save_total_limit=args.save_total_limit,
seed=args.seed,
local_rank=args.local_rank,
device=str(device)
)
reranker_config = BGERerankerConfig(
model_name_or_path=args.reranker_model,
cache_dir=args.cache_dir,
train_data_path=args.train_data,
eval_data_path=args.eval_data,
max_seq_length=args.reranker_max_length,
output_dir=os.path.join(args.output_dir, "reranker"),
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.reranker_learning_rate,
warmup_ratio=args.warmup_ratio,
weight_decay=args.weight_decay,
train_group_size=args.reranker_train_group_size,
joint_training=True, # type: ignore[call-arg]
joint_loss_weight=args.joint_loss_weight, # type: ignore[call-arg]
use_dynamic_listwise_distillation=args.use_dynamic_distillation, # type: ignore[call-arg]
update_retriever_steps=args.update_retriever_steps, # type: ignore[call-arg]
fp16=args.fp16,
gradient_checkpointing=args.gradient_checkpointing,
dataloader_num_workers=args.dataloader_num_workers,
save_steps=args.save_steps,
eval_steps=args.eval_steps,
logging_steps=args.logging_steps,
save_total_limit=args.save_total_limit,
seed=args.seed,
local_rank=args.local_rank,
device=str(device)
)
# Create joint config (use retriever config as base)
joint_config = retriever_config
joint_config.output_dir = args.output_dir
joint_config.joint_loss_weight = args.joint_loss_weight # type: ignore[attr-defined]
joint_config.use_dynamic_listwise_distillation = args.use_dynamic_distillation # type: ignore[attr-defined]
joint_config.update_retriever_steps = args.update_retriever_steps # type: ignore[attr-defined]
# Save configurations
if is_main_process():
os.makedirs(args.output_dir, exist_ok=True)
joint_config.save(os.path.join(args.output_dir, "joint_config.json"))
retriever_config.save(os.path.join(args.output_dir, "retriever_config.json"))
reranker_config.save(os.path.join(args.output_dir, "reranker_config.json"))
# Load tokenizers
logger.info("Loading tokenizers...")
retriever_tokenizer = AutoTokenizer.from_pretrained(
args.retriever_model,
cache_dir=args.cache_dir
)
reranker_tokenizer = AutoTokenizer.from_pretrained(
args.reranker_model,
cache_dir=args.cache_dir
)
# Prepare data
logger.info("Preparing data...")
preprocessor = DataPreprocessor(
tokenizer=retriever_tokenizer, # Use retriever tokenizer for preprocessing
max_query_length=args.query_max_length,
max_passage_length=args.passage_max_length
)
# Preprocess training data
train_path, val_path = preprocessor.preprocess_file(
input_path=args.train_data,
output_path=os.path.join(args.output_dir, "processed_train.jsonl"),
file_format=args.data_format,
validation_split=0.1 if args.eval_data is None else 0.0
)
# Use provided eval data or split validation
eval_path = args.eval_data if args.eval_data else val_path
# Create datasets
logger.info("Creating datasets...")
# Retriever dataset
retriever_train_dataset = BGEM3Dataset(
data_path=train_path,
tokenizer=retriever_tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
train_group_size=args.retriever_train_group_size,
is_train=True
)
# Reranker dataset
reranker_train_dataset = BGERerankerDataset(
data_path=train_path,
tokenizer=reranker_tokenizer,
max_seq_length=args.reranker_max_length, # type: ignore[call-arg]
train_group_size=args.reranker_train_group_size,
is_train=True
)
# Joint dataset
joint_train_dataset = JointDataset( # type: ignore[call-arg]
retriever_dataset=retriever_train_dataset, # type: ignore[call-arg]
reranker_dataset=reranker_train_dataset # type: ignore[call-arg]
)
# Evaluation datasets
joint_eval_dataset = None
if eval_path:
retriever_eval_dataset = BGEM3Dataset(
data_path=eval_path,
tokenizer=retriever_tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
train_group_size=args.retriever_train_group_size,
is_train=False
)
reranker_eval_dataset = BGERerankerDataset(
data_path=eval_path,
tokenizer=reranker_tokenizer,
max_seq_length=args.reranker_max_length, # type: ignore[call-arg]
train_group_size=args.reranker_train_group_size,
is_train=False
)
joint_eval_dataset = JointDataset( # type: ignore[call-arg]
retriever_dataset=retriever_eval_dataset, # type: ignore[call-arg]
reranker_dataset=reranker_eval_dataset # type: ignore[call-arg]
)
# Initialize models
logger.info("Loading models...")
retriever = BGEM3Model(
model_name_or_path=args.retriever_model,
use_dense=args.use_dense,
use_sparse=args.use_sparse,
use_colbert=args.use_colbert,
cache_dir=args.cache_dir
)
reranker = BGERerankerModel(
model_name_or_path=args.reranker_model,
cache_dir=args.cache_dir,
use_fp16=args.fp16
)
from utils.config_loader import get_config
logger.info("Retriever model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface'))
logger.info("Reranker model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface'))
# Initialize optimizers
from torch.optim import AdamW
retriever_optimizer = AdamW(
retriever.parameters(),
lr=args.learning_rate,
weight_decay=args.weight_decay
)
reranker_optimizer = AdamW(
reranker.parameters(),
lr=args.reranker_learning_rate,
weight_decay=args.weight_decay
)
# Initialize joint trainer
logger.info("Initializing joint trainer...")
trainer = JointTrainer(
config=joint_config,
retriever=retriever,
reranker=reranker,
retriever_optimizer=retriever_optimizer,
reranker_optimizer=reranker_optimizer,
device=device
)
# Create data loaders
from torch.utils.data import DataLoader
from utils.distributed import create_distributed_dataloader
if args.local_rank != -1:
train_dataloader = create_distributed_dataloader(
joint_train_dataset,
batch_size=args.per_device_train_batch_size,
shuffle=True,
num_workers=args.dataloader_num_workers,
seed=args.seed
)
eval_dataloader = None
if joint_eval_dataset:
eval_dataloader = create_distributed_dataloader(
joint_eval_dataset,
batch_size=args.per_device_eval_batch_size,
shuffle=False,
num_workers=args.dataloader_num_workers,
seed=args.seed
)
else:
train_dataloader = DataLoader(
joint_train_dataset,
batch_size=args.per_device_train_batch_size,
shuffle=True,
num_workers=args.dataloader_num_workers
)
eval_dataloader = None
if joint_eval_dataset:
eval_dataloader = DataLoader(
joint_eval_dataset,
batch_size=args.per_device_eval_batch_size,
num_workers=args.dataloader_num_workers
)
# Resume from checkpoint if specified
if args.resume_from_checkpoint:
trainer.load_checkpoint(args.resume_from_checkpoint)
# Train
logger.info("Starting joint training...")
try:
trainer.train(train_dataloader, eval_dataloader, args.num_train_epochs)
# Save final models
if is_main_process():
logger.info("Saving final models...")
final_dir = os.path.join(args.output_dir, "final_models")
trainer.save_models(final_dir)
# Save individual models
retriever_path = os.path.join(args.output_dir, "final_retriever")
reranker_path = os.path.join(args.output_dir, "final_reranker")
os.makedirs(retriever_path, exist_ok=True)
os.makedirs(reranker_path, exist_ok=True)
retriever.save_pretrained(retriever_path)
retriever_tokenizer.save_pretrained(retriever_path)
reranker.model.save_pretrained(reranker_path)
reranker_tokenizer.save_pretrained(reranker_path)
# Save training summary
summary = {
"training_args": vars(args),
"retriever_config": retriever_config.to_dict(),
"reranker_config": reranker_config.to_dict(),
"joint_config": joint_config.to_dict(),
"final_metrics": getattr(trainer, 'best_metric', None),
"total_steps": getattr(trainer, 'global_step', 0),
"training_completed": datetime.now().isoformat()
}
with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
json.dump(summary, f, indent=2)
logger.info(f"Joint training completed! Models saved to {args.output_dir}")
except Exception as e:
logger.error(f"Joint training failed with error: {e}", exc_info=True)
raise
if __name__ == "__main__":
main()

614
scripts/train_m3.py Normal file
View File

@ -0,0 +1,614 @@
"""
Enhanced BGE-M3 embedding model training script with integrated dataset pipeline
Supports both legacy nested format and new embedding schema with automatic detection
"""
import os
import sys
import argparse
# Set tokenizer parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import logging
import json
from datetime import datetime
import torch
from transformers import AutoTokenizer, set_seed
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
from config.m3_config import BGEM3Config
from models.bge_m3 import BGEM3Model
from data.dataset import BGEM3Dataset, create_embedding_dataset, collate_embedding_batch, migrate_nested_to_flat
from data.preprocessing import DataPreprocessor
from data.augmentation import QueryAugmenter, HardNegativeAugmenter
from data.hard_negative_mining import ANCEHardNegativeMiner
from training.m3_trainer import BGEM3Trainer
from evaluation.evaluator import RetrievalEvaluator
from utils.logging import setup_logging, get_logger
from utils.distributed import setup_distributed, is_main_process, set_random_seed
from utils.config_loader import get_config, load_config_for_training, resolve_model_path
except ImportError as e:
print(f"Import error: {e}")
print("Please ensure you're running from the project root directory")
sys.exit(1)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Enhanced BGE-M3 Embedding Model Training")
# Model arguments
parser.add_argument("--model_name_or_path", type=str, default=None,
help="Path to pretrained model or model identifier")
parser.add_argument("--cache_dir", type=str, default="./cache/data",
help="Directory to cache downloaded models")
parser.add_argument("--config_path", type=str, default="./config.toml",
help="Path to configuration file")
# Data arguments - support both legacy and new schemas
parser.add_argument("--train_data", type=str, required=True,
help="Path to training data file (supports both embedding and legacy nested formats)")
parser.add_argument("--eval_data", type=str, default=None,
help="Path to evaluation data file")
parser.add_argument("--data_format", type=str, default="auto",
choices=["auto", "embedding", "legacy_nested", "jsonl", "json", "csv", "tsv"],
help="Format of input data (auto-detects by default)")
parser.add_argument("--legacy_data", type=str, default=None,
help="Path to additional legacy nested data (for backward compatibility)")
parser.add_argument("--migrate_legacy", action="store_true",
help="Migrate legacy nested format to flat format")
# Sequence length arguments
parser.add_argument("--max_seq_length", type=int, default=None,
help="Maximum sequence length")
parser.add_argument("--query_max_length", type=int, default=None,
help="Maximum query length")
parser.add_argument("--passage_max_length", type=int, default=None,
help="Maximum passage length")
parser.add_argument("--max_length", type=int, default=8192,
help="Maximum total input length (multi-granularity)")
# Training arguments
parser.add_argument("--output_dir", type=str, default="./output/bge-m3-enhanced",
help="Directory to save the model")
parser.add_argument("--num_train_epochs", type=int, default=None,
help="Number of training epochs")
parser.add_argument("--per_device_train_batch_size", type=int, default=None,
help="Batch size per GPU/CPU for training")
parser.add_argument("--per_device_eval_batch_size", type=int, default=None,
help="Batch size per GPU/CPU for evaluation")
parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
help="Number of updates steps to accumulate before performing a backward/update pass")
parser.add_argument("--learning_rate", type=float, default=None,
help="Initial learning rate")
parser.add_argument("--warmup_ratio", type=float, default=None,
help="Warmup steps as a ratio of total steps")
parser.add_argument("--weight_decay", type=float, default=None,
help="Weight decay")
# Model specific arguments - enhanced features
parser.add_argument("--train_group_size", type=int, default=8,
help="Number of passages per query in training")
parser.add_argument("--temperature", type=float, default=0.1,
help="Temperature for InfoNCE loss")
parser.add_argument("--use_hard_negatives", action="store_true",
help="Enable hard negative mining")
parser.add_argument("--use_self_distill", action="store_true",
help="Enable self-knowledge distillation")
parser.add_argument("--use_score_weighted_sampling", action="store_true", default=True,
help="Enable score-weighted positive sampling")
parser.add_argument("--query_instruction", type=str, default="",
help="Query instruction prefix")
# Enhanced sampling features
parser.add_argument("--multiple_positives", action="store_true", default=True,
help="Use multiple positives per batch for richer contrastive signals")
parser.add_argument("--hard_negative_ratio", type=float, default=0.7,
help="Ratio of hard to random negatives")
parser.add_argument("--difficulty_threshold", type=float, default=0.1,
help="Minimum difficulty for hard negatives")
# Optimization and performance
parser.add_argument("--fp16", action="store_true",
help="Use mixed precision training")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="Use gradient checkpointing to save memory")
parser.add_argument("--dataloader_num_workers", type=int, default=4,
help="Number of workers for data loading")
# Checkpointing
parser.add_argument("--save_steps", type=int, default=500,
help="Save checkpoint every N steps")
parser.add_argument("--eval_steps", type=int, default=500,
help="Evaluate every N steps")
parser.add_argument("--logging_steps", type=int, default=100,
help="Log every N steps")
parser.add_argument("--save_total_limit", type=int, default=3,
help="Maximum number of checkpoints to keep")
parser.add_argument("--resume_from_checkpoint", type=str, default=None,
help="Path to checkpoint to resume from")
# Distributed training
parser.add_argument("--local_rank", type=int, default=-1,
help="Local rank for distributed training")
# Other
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
parser.add_argument("--overwrite_output_dir", action="store_true",
help="Overwrite the output directory")
parser.add_argument("--log_level", type=str, default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging level")
# Backward compatibility flags
parser.add_argument("--legacy_support", action="store_true", default=True,
help="Enable legacy nested format support")
parser.add_argument("--use_new_dataset_api", action="store_true", default=True,
help="Use new integrated dataset API with format detection")
args = parser.parse_args()
# Validate arguments
if os.path.exists(args.output_dir) and not args.overwrite_output_dir:
raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.")
return args
def create_config_from_args(args):
"""Create configuration from command line arguments and config file
Parameter Precedence (highest to lowest):
1. Command-line arguments (CLI)
2. Model-specific config ([m3], [reranker])
3. [hardware] section in config.toml
4. [training] section in config.toml
5. Hardcoded defaults in code
"""
# Load base configuration from TOML file
config_dict = load_config_for_training(
config_path=args.config_path,
override_params={}
)
# Get the centralized config instance
central_config = get_config()
# Create override parameters from command line arguments
override_params = {}
# Model configuration - use new resolve_model_path function
resolved_model_path, resolved_cache_dir = resolve_model_path(
model_key='bge_m3',
config_instance=central_config,
model_name_or_path=args.model_name_or_path,
cache_dir=args.cache_dir
)
args.model_name_or_path = resolved_model_path
args.cache_dir = resolved_cache_dir
# Data configuration
data_config = config_dict.get('data', {})
if args.max_seq_length is None:
args.max_seq_length = data_config.get('max_seq_length', 512)
if args.query_max_length is None:
args.query_max_length = data_config.get('max_query_length', 64)
if args.passage_max_length is None:
args.passage_max_length = data_config.get('max_passage_length', 512)
# Get validation split from config
validation_split = data_config.get('validation_split', 0.1)
# Training configuration ([hardware] values override [training] defaults; explicit CLI values always win)
training_config = config_dict.get('training', {})
hardware_config = config_dict.get('hardware', {})
if args.num_train_epochs is None:
args.num_train_epochs = training_config.get('default_num_epochs', 3)
if args.per_device_train_batch_size is None:
args.per_device_train_batch_size = hardware_config.get('per_device_train_batch_size', training_config.get('default_batch_size', 4))
if args.per_device_eval_batch_size is None:
args.per_device_eval_batch_size = hardware_config.get('per_device_eval_batch_size', args.per_device_train_batch_size * 2)
if args.learning_rate is None:
args.learning_rate = training_config.get('default_learning_rate', 1e-5)
if args.warmup_ratio is None:
args.warmup_ratio = training_config.get('default_warmup_ratio', 0.1)
if args.weight_decay is None:
args.weight_decay = training_config.get('default_weight_decay', 0.01)
# Model-specific (M3) configuration - highest precedence
m3_config = config_dict.get('m3', {})
train_group_size = m3_config.get('train_group_size', args.train_group_size)
temperature = m3_config.get('temperature', args.temperature)
use_hard_negatives = m3_config.get('use_hard_negatives', args.use_hard_negatives)
use_self_distill = m3_config.get('use_self_distill', args.use_self_distill)
query_instruction_for_retrieval = m3_config.get('query_instruction_for_retrieval', args.query_instruction)
# Hardware configuration: [hardware] overrides were already applied above, ahead of the [training] defaults
# Create BGE-M3 configuration
config = BGEM3Config(
model_name_or_path=args.model_name_or_path,
cache_dir=args.cache_dir,
train_data_path=args.train_data,
eval_data_path=args.eval_data,
max_seq_length=args.max_seq_length,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
max_length=args.max_length,
output_dir=args.output_dir,
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
warmup_ratio=args.warmup_ratio,
weight_decay=args.weight_decay,
train_group_size=train_group_size,
temperature=temperature,
use_hard_negatives=use_hard_negatives,
use_self_distill=use_self_distill,
query_instruction_for_retrieval=query_instruction_for_retrieval,
fp16=args.fp16,
gradient_checkpointing=args.gradient_checkpointing,
dataloader_num_workers=args.dataloader_num_workers,
save_steps=args.save_steps,
eval_steps=args.eval_steps,
logging_steps=args.logging_steps,
save_total_limit=args.save_total_limit,
seed=args.seed,
local_rank=args.local_rank,
overwrite_output_dir=args.overwrite_output_dir
)
return config
def setup_model_and_tokenizer(args):
"""Setup BGE-M3 model and tokenizer"""
logger.info(f"Loading BGE-M3 model from {args.model_name_or_path}")
# Detect offline mode if the path is local or HF_HUB_OFFLINE is set
local_files_only = os.path.isdir(args.model_name_or_path) or bool(os.environ.get("HF_HUB_OFFLINE"))
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
cache_dir=args.cache_dir,
trust_remote_code=True,
local_files_only=local_files_only
)
# Load model
try:
model = BGEM3Model(
model_name_or_path=args.model_name_or_path,
cache_dir=args.cache_dir
)
logger.info("Model loaded successfully")
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
return model, tokenizer
def setup_datasets(args, tokenizer):
"""Setup datasets with enhanced format detection and migration support"""
logger.info("Setting up datasets with enhanced format detection...")
# Handle legacy data migration if needed
if args.legacy_data and args.migrate_legacy:
logger.info(f"Migrating legacy data from {args.legacy_data}")
legacy_base = os.path.splitext(args.legacy_data)[0]
migrated_file = f"{legacy_base}_migrated_embedding.jsonl"
# For embedding, we don't need to migrate to flat pairs, but we can optimize the format
# Copy the legacy data to work with if needed
import shutil
shutil.copy2(args.legacy_data, migrated_file)
logger.info(f"Legacy data prepared for training: {migrated_file}")
# Setup training dataset
if args.use_new_dataset_api:
# Use new integrated dataset API with format detection
train_dataset = create_embedding_dataset(
data_path=args.train_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=True
)
logger.info(f"Loaded {len(train_dataset)} training examples from {args.train_data}")
logger.info(f"Detected format: {train_dataset.detected_format}")
# Add legacy data if provided
if args.legacy_data and args.legacy_support:
logger.info(f"Loading additional legacy data from {args.legacy_data}")
legacy_dataset = create_embedding_dataset(
data_path=args.legacy_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=True
)
logger.info(f"Loaded {len(legacy_dataset)} legacy examples")
# Combine datasets
train_dataset.data.extend(legacy_dataset.data)
logger.info(f"Total training examples: {len(train_dataset)}")
else:
# Use original dataset API for backward compatibility
train_dataset = BGEM3Dataset(
data_path=args.train_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
train_group_size=args.train_group_size,
is_train=True
)
logger.info(f"Loaded {len(train_dataset)} training examples (legacy API)")
# Setup evaluation dataset
eval_dataset = None
if args.eval_data:
if args.use_new_dataset_api:
eval_dataset = create_embedding_dataset(
data_path=args.eval_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=False
)
else:
eval_dataset = BGEM3Dataset(
data_path=args.eval_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
train_group_size=args.train_group_size,
is_train=False
)
logger.info(f"Loaded {len(eval_dataset)} evaluation examples")
return train_dataset, eval_dataset
def setup_data_loaders(args, train_dataset, eval_dataset):
"""Setup data loaders with appropriate collate functions"""
# Both triplets-format and standard embedding datasets use the same collate function
collate_fn = collate_embedding_batch
if hasattr(train_dataset, 'format') and train_dataset.format == 'triplets':
logger.info("Using embedding collate function for triplets format")
else:
logger.info("Using standard embedding collate function")
# Note: FastM3Dataset does not combine well with DataLoader multiprocessing (pickle-based
# loading); pass --dataloader_num_workers 0 if worker processes fail with it.
workers = args.dataloader_num_workers
train_dataloader = DataLoader(
train_dataset,
batch_size=args.per_device_train_batch_size,
shuffle=True,
num_workers=workers,
collate_fn=collate_fn,
pin_memory=torch.cuda.is_available()
)
# Evaluation data loader
eval_dataloader = None
if eval_dataset:
eval_dataloader = DataLoader(
eval_dataset,
batch_size=args.per_device_eval_batch_size,
shuffle=False,
num_workers=args.dataloader_num_workers,
collate_fn=collate_fn,
pin_memory=torch.cuda.is_available()
)
return train_dataloader, eval_dataloader
def main():
# Parse arguments
args = parse_args()
# Create configuration
config = create_config_from_args(args)
# Setup logging
setup_logging(
output_dir=args.output_dir,
log_level=getattr(logging, args.log_level),
log_to_console=True,
log_to_file=True
)
logger = get_logger(__name__)
logger.info("🚀 Starting Enhanced BGE-M3 Fine-tuning")
logger.info(f"Arguments: {args}")
logger.info(f"Configuration: {config.to_dict()}")
# Set random seed
set_random_seed(args.seed, deterministic=False)
# Setup distributed training if needed
device = None
if args.local_rank != -1:
setup_distributed()
# Device selection from config
from utils.config_loader import get_config
config_instance = get_config()
device_type = config_instance.get('ascend.device_type', 'cuda')
if device_type == 'npu':
try:
import torch_npu # type: ignore
torch_npu.npu.set_device(args.local_rank)
device = torch.device("npu", args.local_rank)
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
except ImportError:
logger.warning("torch_npu not installed, falling back to CPU")
device = torch.device("cpu")
elif torch.cuda.is_available():
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
else:
device = torch.device("cpu")
logger.info("Distributed training enabled: using CPU")
else:
# Non-distributed device selection
from utils.config_loader import get_config
config_instance = get_config()
device_type = config_instance.get('ascend.device_type', 'cuda')
if device_type == 'npu':
try:
import torch_npu # type: ignore
device = torch.device("npu")
logger.info("Using Ascend NPU")
except ImportError:
logger.warning("torch_npu not installed, falling back to CPU")
device = torch.device("cpu")
elif torch.cuda.is_available():
device = torch.device("cuda")
logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
else:
device = torch.device("cpu")
logger.info("Using CPU")
# Update config with device
config.device = str(device)
# Save configuration
if is_main_process():
os.makedirs(args.output_dir, exist_ok=True)
config.save(os.path.join(args.output_dir, "config.json"))
# Setup model and tokenizer
model, tokenizer = setup_model_and_tokenizer(args)
# Setup datasets with enhanced format detection
train_dataset, eval_dataset = setup_datasets(args, tokenizer)
# Setup data loaders
train_dataloader, eval_dataloader = setup_data_loaders(args, train_dataset, eval_dataset)
# Initialize trainer
logger.info("Initializing enhanced trainer...")
try:
trainer = BGEM3Trainer(
config=config,
model=model,
tokenizer=tokenizer
)
logger.info("Trainer initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize trainer: {e}")
raise
# Resume from checkpoint if specified
if args.resume_from_checkpoint:
logger.info(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
try:
trainer.load_checkpoint(args.resume_from_checkpoint)
logger.info("Checkpoint loaded successfully.")
except Exception as e:
logger.error(f"Failed to load checkpoint: {e}")
raise
# Train
logger.info("🎯 Starting enhanced training...")
try:
# Handle type compatibility for trainer
from typing import cast
trainer_train_dataset = cast(BGEM3Dataset, train_dataset) if train_dataset else None
trainer_eval_dataset = cast(BGEM3Dataset, eval_dataset) if eval_dataset else None
if not trainer_train_dataset:
raise ValueError("Training dataset is required but not provided")
trainer.train(
train_dataset=trainer_train_dataset,
eval_dataset=trainer_eval_dataset,
resume_from_checkpoint=args.resume_from_checkpoint
)
# Save final model
if is_main_process():
logger.info("Saving final model...")
final_path = os.path.join(args.output_dir, "final_model")
os.makedirs(final_path, exist_ok=True)
# Save model
model.save_pretrained(final_path)
tokenizer.save_pretrained(final_path)
# Save training summary
summary = {
"training_args": vars(args),
"model_config": config.to_dict(),
"final_metrics": getattr(trainer, 'best_metric', None),
"total_steps": getattr(trainer, 'global_step', 0),
"training_completed": datetime.now().isoformat(),
"device_used": str(device),
"dataset_format": getattr(train_dataset, 'detected_format', 'unknown'),
"enhanced_features": {
"use_new_dataset_api": args.use_new_dataset_api,
"use_hard_negatives": args.use_hard_negatives,
"use_self_distill": args.use_self_distill,
"use_score_weighted_sampling": args.use_score_weighted_sampling,
"multiple_positives": args.multiple_positives,
"legacy_support": args.legacy_support
}
}
with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
json.dump(summary, f, indent=2)
logger.info(f"✅ Enhanced training completed! Model saved to {final_path}")
# Show training summary
logger.info("📊 Training Summary:")
logger.info(f" Total examples: {len(train_dataset)}")
if hasattr(train_dataset, 'detected_format'):
logger.info(f" Format detected: {train_dataset.detected_format}") # type: ignore[attr-defined]
logger.info(f" Epochs: {args.num_train_epochs}")
logger.info(f" Learning rate: {args.learning_rate}")
logger.info(f" Batch size: {args.per_device_train_batch_size}")
logger.info(f" Temperature: {args.temperature}")
logger.info(f" Enhanced features enabled: {args.use_new_dataset_api}")
except Exception as e:
logger.error(f"Training failed with error: {e}", exc_info=True)
raise
finally:
# Cleanup
if hasattr(trainer, 'hard_negative_miner') and trainer.hard_negative_miner:
trainer.hard_negative_miner.stop_async_updates()
if args.local_rank != -1:
from utils.distributed import cleanup_distributed
cleanup_distributed()
if __name__ == "__main__":
main()

398
scripts/train_reranker.py Normal file
View File

@ -0,0 +1,398 @@
#!/usr/bin/env python3
"""
BGE-Reranker Cross-Encoder Model Training Script
Uses the refactored dataset pipeline with flat pairs schema
"""
import os
import sys
import argparse
import logging
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW  # transformers' AdamW is deprecated/removed in recent versions
import torch.nn.functional as F
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import create_reranker_dataset, collate_reranker_batch, migrate_nested_to_flat
from models.bge_reranker import BGERerankerModel
from training.reranker_trainer import BGERerankerTrainer
from utils.logging import setup_logging, get_logger
logger = get_logger(__name__)
def parse_args():
"""Parse command line arguments for reranker training"""
parser = argparse.ArgumentParser(description="Train BGE-Reranker Cross-Encoder Model")
# Model arguments
parser.add_argument("--model_name_or_path", type=str, default="BAAI/bge-reranker-base",
help="Path to pretrained BGE-Reranker model")
parser.add_argument("--cache_dir", type=str, default="./cache/data",
help="Directory to cache downloaded models")
# Data arguments
parser.add_argument("--reranker_data", type=str, required=True,
help="Path to reranker training data (reranker_data.jsonl)")
parser.add_argument("--eval_data", type=str, default=None,
help="Path to evaluation data")
parser.add_argument("--legacy_data", type=str, default=None,
help="Path to legacy nested data (will be migrated)")
# Training arguments
parser.add_argument("--output_dir", type=str, default="./output/bge-reranker",
help="Directory to save the trained model")
parser.add_argument("--num_train_epochs", type=int, default=3,
help="Number of training epochs")
parser.add_argument("--per_device_train_batch_size", type=int, default=8,
help="Batch size per GPU/CPU for training")
parser.add_argument("--per_device_eval_batch_size", type=int, default=16,
help="Batch size per GPU/CPU for evaluation")
parser.add_argument("--learning_rate", type=float, default=6e-5,
help="Learning rate")
parser.add_argument("--weight_decay", type=float, default=0.01,
help="Weight decay")
parser.add_argument("--warmup_ratio", type=float, default=0.1,
help="Warmup ratio")
# Model specific arguments
parser.add_argument("--query_max_length", type=int, default=64,
help="Maximum query length")
parser.add_argument("--passage_max_length", type=int, default=448,
help="Maximum passage length")
parser.add_argument("--loss_type", type=str, default="cross_entropy",
choices=["cross_entropy", "binary_cross_entropy", "listwise_ce"],
help="Loss function type")
# Enhanced features
parser.add_argument("--use_score_weighting", action="store_true",
help="Use score weighting in loss calculation")
parser.add_argument("--label_smoothing", type=float, default=0.0,
help="Label smoothing for cross entropy loss")
# System arguments
parser.add_argument("--fp16", action="store_true",
help="Use 16-bit floating point precision")
parser.add_argument("--dataloader_num_workers", type=int, default=4,
help="Number of workers for data loading")
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
parser.add_argument("--log_level", type=str, default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging level")
# Backward compatibility
parser.add_argument("--legacy_support", action="store_true", default=True,
help="Enable legacy nested format support")
parser.add_argument("--migrate_legacy", action="store_true",
help="Migrate legacy nested format to flat format")
return parser.parse_args()
def setup_model_and_tokenizer(args):
"""Setup BGE-Reranker model and tokenizer"""
logger.info(f"Loading BGE-Reranker model from {args.model_name_or_path}")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
cache_dir=args.cache_dir,
trust_remote_code=True
)
# Load model
model = BGERerankerModel(
model_name_or_path=args.model_name_or_path,
cache_dir=args.cache_dir
)
return model, tokenizer
def migrate_legacy_data(args):
"""Migrate legacy nested data to flat format if needed"""
if args.legacy_data and args.migrate_legacy:
logger.info("Migrating legacy nested data to flat format...")
# Create output file name
legacy_base = os.path.splitext(args.legacy_data)[0]
migrated_file = f"{legacy_base}_migrated_flat.jsonl"
# Migrate the data
migrate_nested_to_flat(args.legacy_data, migrated_file)
logger.info(f"✅ Migrated legacy data to {migrated_file}")
return migrated_file
return None
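# Hedged illustration of the migration (field names assumed from the sample data elsewhere in
# this repo; the authoritative schema is whatever migrate_nested_to_flat writes):
#   nested input : {"query": "q", "pos": ["p1"], "neg": ["n1"]}
#   flat output  : {"query": "q", "passage": "p1", "label": 1}
#                  {"query": "q", "passage": "n1", "label": 0}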
def setup_datasets(args, tokenizer):
"""Setup reranker datasets"""
logger.info("Setting up reranker datasets...")
# Migrate legacy data if needed
migrated_file = migrate_legacy_data(args)
# Primary reranker data
train_dataset = create_reranker_dataset(
data_path=args.reranker_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=True,
legacy_support=args.legacy_support
)
logger.info(f"Loaded {len(train_dataset)} training examples from {args.reranker_data}")
logger.info(f"Detected format: {train_dataset.format_type}")
# Add migrated legacy data if available
if migrated_file:
logger.info(f"Loading migrated legacy data from {migrated_file}")
migrated_dataset = create_reranker_dataset(
data_path=migrated_file,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=True,
legacy_support=False # Migrated data is in flat format
)
logger.info(f"Loaded {len(migrated_dataset)} migrated examples")
# Combine datasets
train_dataset.data.extend(migrated_dataset.data)
logger.info(f"Total training examples: {len(train_dataset)}")
# Evaluation dataset
eval_dataset = None
if args.eval_data:
eval_dataset = create_reranker_dataset(
data_path=args.eval_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=False,
legacy_support=args.legacy_support
)
logger.info(f"Loaded {len(eval_dataset)} evaluation examples")
return train_dataset, eval_dataset
def setup_loss_function(args):
"""Setup cross-encoder loss function"""
if args.loss_type == "cross_entropy":
if args.label_smoothing > 0:
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
else:
loss_fn = torch.nn.CrossEntropyLoss()
elif args.loss_type == "binary_cross_entropy":
loss_fn = torch.nn.BCEWithLogitsLoss()
elif args.loss_type == "listwise_ce":
# Listwise option: currently falls back to a standard CrossEntropyLoss over the candidate list
loss_fn = torch.nn.CrossEntropyLoss()
else:
raise ValueError(f"Unknown loss type: {args.loss_type}")
logger.info(f"Using loss function: {args.loss_type}")
return loss_fn
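# Hedged sketch (not wired into the training loop below): one way to compute a listwise
# cross-entropy over a single candidate group, assuming `scores` holds one relevance logit per
# passage in the group and `target_idx` marks the positive passage.
def listwise_ce_example(scores: torch.Tensor, target_idx: int) -> torch.Tensor:
    # Softmax over the whole candidate list, then negative log-likelihood of the positive passage
    return torch.nn.functional.cross_entropy(
        scores.unsqueeze(0),
        torch.tensor([target_idx], device=scores.device)
    )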
def train_epoch(model, dataloader, optimizer, scheduler, loss_fn, device, args):
"""Train for one epoch"""
model.train()
total_loss = 0
num_batches = 0
for batch_idx, batch in enumerate(dataloader):
# Move batch to device
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
# Forward pass
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
# Calculate loss
if args.loss_type == "binary_cross_entropy":
# For binary classification
loss = loss_fn(logits.squeeze(), labels.float())
else:
# For multi-class classification
loss = loss_fn(logits, labels)
# Backward pass
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
total_loss += loss.item()
num_batches += 1
# Log progress
if batch_idx % 100 == 0:
logger.info(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
return total_loss / num_batches
def evaluate(model, dataloader, loss_fn, device, args):
"""Evaluate the model"""
model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
# Forward pass
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
# Calculate loss
if args.loss_type == "binary_cross_entropy":
loss = loss_fn(logits.squeeze(), labels.float())
# Calculate accuracy for binary classification
predictions = (torch.sigmoid(logits.squeeze()) > 0.5).long()
else:
loss = loss_fn(logits, labels)
# Calculate accuracy for multi-class classification
predictions = torch.argmax(logits, dim=-1)
total_loss += loss.item()
correct += (predictions == labels).sum().item()
total += labels.size(0)
accuracy = correct / total if total > 0 else 0
return total_loss / len(dataloader), accuracy
def main():
"""Main training function"""
args = parse_args()
# Setup logging
setup_logging(
output_dir=args.output_dir,
log_level=getattr(logging, args.log_level),
log_to_file=True,
log_to_console=True
)
logger.info("🚀 Starting BGE-Reranker Training")
logger.info(f"Arguments: {vars(args)}")
# Set random seed
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
# Setup device
device = torch.device("cuda" if torch.cuda.is_available() and not args.fp16 else "cpu")
logger.info(f"Using device: {device}")
# Setup model and tokenizer
model, tokenizer = setup_model_and_tokenizer(args)
model.to(device)
# Setup datasets
train_dataset, eval_dataset = setup_datasets(args, tokenizer)
# Setup data loaders
train_dataloader = DataLoader(
train_dataset,
batch_size=args.per_device_train_batch_size,
shuffle=True,
num_workers=args.dataloader_num_workers,
collate_fn=collate_reranker_batch,
pin_memory=torch.cuda.is_available()
)
eval_dataloader = None
if eval_dataset:
eval_dataloader = DataLoader(
eval_dataset,
batch_size=args.per_device_eval_batch_size,
shuffle=False,
num_workers=args.dataloader_num_workers,
collate_fn=collate_reranker_batch,
pin_memory=torch.cuda.is_available()
)
# Setup optimizer and scheduler
optimizer = AdamW(
model.parameters(),
lr=args.learning_rate,
weight_decay=args.weight_decay
)
total_steps = len(train_dataloader) * args.num_train_epochs
warmup_steps = int(total_steps * args.warmup_ratio)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps
)
# Setup loss function
loss_fn = setup_loss_function(args)
# Training loop
logger.info("🎯 Starting training...")
best_eval_loss = float('inf')
for epoch in range(args.num_train_epochs):
logger.info(f"Epoch {epoch + 1}/{args.num_train_epochs}")
# Train
train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, loss_fn, device, args)
logger.info(f"Training loss: {train_loss:.4f}")
# Evaluate
if eval_dataloader:
eval_loss, eval_accuracy = evaluate(model, eval_dataloader, loss_fn, device, args)
logger.info(f"Evaluation loss: {eval_loss:.4f}, Accuracy: {eval_accuracy:.4f}")
# Save best model
if eval_loss < best_eval_loss:
best_eval_loss = eval_loss
best_model_path = os.path.join(args.output_dir, "best_model")
os.makedirs(best_model_path, exist_ok=True)
model.save_pretrained(best_model_path)
tokenizer.save_pretrained(best_model_path)
logger.info(f"Saved best model to {best_model_path}")
# Save final model
final_model_path = os.path.join(args.output_dir, "final_model")
os.makedirs(final_model_path, exist_ok=True)
model.save_pretrained(final_model_path)
tokenizer.save_pretrained(final_model_path)
logger.info(f"✅ Training completed! Final model saved to {final_model_path}")
# Show training summary
logger.info("📊 Training Summary:")
logger.info(f" Total examples: {len(train_dataset)}")
logger.info(f" Format detected: {train_dataset.format_type}")
logger.info(f" Epochs: {args.num_train_epochs}")
logger.info(f" Learning rate: {args.learning_rate}")
logger.info(f" Batch size: {args.per_device_train_batch_size}")
logger.info(f" Loss function: {args.loss_type}")
if __name__ == "__main__":
main()

122
scripts/validate_m3.py Normal file
View File

@ -0,0 +1,122 @@
"""
Quick validation script for the fine-tuned BGE-M3 model
"""
import os
import sys
import torch
import numpy as np
from transformers import AutoTokenizer
# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_m3 import BGEM3Model
def test_model_embeddings(model_path, test_queries=None):
"""Test if the model can generate embeddings"""
if test_queries is None:
test_queries = [
"什么是深度学习?", # Original training query
"什么是机器学习?", # Related query
"今天天气怎么样?" # Unrelated query
]
print(f"🧪 Testing model from: {model_path}")
try:
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGEM3Model.from_pretrained(model_path)
model.eval()
print("✅ Model and tokenizer loaded successfully")
# Test embeddings generation
embeddings = []
for i, query in enumerate(test_queries):
print(f"\n📝 Testing query {i+1}: '{query}'")
# Tokenize
inputs = tokenizer(
query,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
)
# Get embeddings
with torch.no_grad():
outputs = model.forward(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
return_dense=True,
return_dict=True
)
if 'dense' in outputs:
embedding = outputs['dense'].cpu().numpy() # type: ignore[index]
embeddings.append(embedding)
print(f" ✅ Embedding shape: {embedding.shape}")
print(f" 📊 Embedding norm: {np.linalg.norm(embedding):.4f}")
print(f" 📈 Mean activation: {embedding.mean():.6f}")
print(f" 📉 Std activation: {embedding.std():.6f}")
else:
print(f" ❌ No dense embedding returned")
return False
# Compare embeddings
if len(embeddings) >= 2:
print(f"\n🔗 Similarity Analysis:")
for i in range(len(embeddings)):
for j in range(i+1, len(embeddings)):
sim = np.dot(embeddings[i].flatten(), embeddings[j].flatten()) / (
np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
)
print(f" Query {i+1} vs Query {j+1}: {sim:.4f}")
print(f"\n🎉 Model validation successful!")
return True
except Exception as e:
print(f"❌ Model validation failed: {e}")
return False
def compare_models(original_path, finetuned_path):
"""Compare original vs fine-tuned model"""
print("🔄 Comparing original vs fine-tuned model...")
test_query = "什么是深度学习?"
# Test original model
print("\n📍 Original Model:")
orig_success = test_model_embeddings(original_path, [test_query])
# Test fine-tuned model
print("\n📍 Fine-tuned Model:")
ft_success = test_model_embeddings(finetuned_path, [test_query])
return orig_success and ft_success
if __name__ == "__main__":
print("🚀 BGE-M3 Model Validation")
print("=" * 50)
# Test fine-tuned model
finetuned_model_path = "./test_m3_output/final_model"
success = test_model_embeddings(finetuned_model_path)
# Optional: Compare with original model
original_model_path = "./models/bge-m3"
compare_models(original_model_path, finetuned_model_path)
if success:
print("\n✅ Fine-tuning validation: PASSED")
print(" The model can generate embeddings successfully!")
else:
print("\n❌ Fine-tuning validation: FAILED")
print(" The model has issues generating embeddings!")

View File

@ -0,0 +1,150 @@
"""
Quick validation script for the fine-tuned BGE Reranker model
"""
import os
import sys
import torch
import numpy as np
from transformers import AutoTokenizer
# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_reranker import BGERerankerModel
def test_reranker_scoring(model_path, test_queries=None):
"""Test if the reranker can score query-passage pairs"""
if test_queries is None:
# Test query with relevant and irrelevant passages
test_queries = [
{
"query": "什么是深度学习?",
"passages": [
"深度学习是机器学习的一个子领域", # Relevant (should score high)
"机器学习包含多种算法", # Somewhat relevant
"今天天气很好", # Irrelevant (should score low)
"深度学习使用神经网络进行学习" # Very relevant
]
}
]
print(f"🧪 Testing reranker from: {model_path}")
try:
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = BGERerankerModel.from_pretrained(model_path)
model.eval()
print("✅ Reranker model and tokenizer loaded successfully")
for query_idx, test_case in enumerate(test_queries):
query = test_case["query"]
passages = test_case["passages"]
print(f"\n📝 Test Case {query_idx + 1}")
print(f"Query: '{query}'")
print(f"Passages to rank:")
scores = []
for i, passage in enumerate(passages):
print(f" {i+1}. '{passage}'")
# Encode the query-passage pair; passing text and text_pair lets the tokenizer insert the
# separator token and build the segment layout the cross-encoder expects
inputs = tokenizer(
query,
passage,
return_tensors='pt',
padding=True,
truncation=True,
max_length=512
)
# Get relevance score
with torch.no_grad():
outputs = model(**inputs)
# Extract score (logits)
if hasattr(outputs, 'logits'):
score = outputs.logits.item()
elif isinstance(outputs, torch.Tensor):
score = outputs.item()
else:
score = outputs[0].item()
scores.append((i, passage, score))
# Sort by relevance score (highest first)
scores.sort(key=lambda x: x[2], reverse=True)
print(f"\n🏆 Ranking Results:")
for rank, (orig_idx, passage, score) in enumerate(scores, 1):
print(f" Rank {rank}: Score {score:.4f} - '{passage[:50]}{'...' if len(passage) > 50 else ''}'")
# Check if most relevant passages ranked highest
expected_relevant = ["深度学习是机器学习的一个子领域", "深度学习使用神经网络进行学习"]
top_passages = [item[1] for item in scores[:2]]
relevant_in_top = any(rel in top_passages for rel in expected_relevant)
irrelevant_in_top = "今天天气很好" in top_passages
if relevant_in_top and not irrelevant_in_top:
print(f" ✅ Good ranking - relevant passages ranked higher")
elif relevant_in_top:
print(f" ⚠️ Partial ranking - relevant high but irrelevant also high")
else:
print(f" ❌ Poor ranking - irrelevant passages ranked too high")
print(f"\n🎉 Reranker validation successful!")
return True
except Exception as e:
print(f"❌ Reranker validation failed: {e}")
return False
def compare_rerankers(original_path, finetuned_path):
"""Compare original vs fine-tuned reranker"""
print("🔄 Comparing original vs fine-tuned reranker...")
test_case = {
"query": "什么是深度学习?",
"passages": [
"深度学习是机器学习的一个子领域",
"今天天气很好"
]
}
# Test original model
print("\n📍 Original Reranker:")
orig_success = test_reranker_scoring(original_path, [test_case])
# Test fine-tuned model
print("\n📍 Fine-tuned Reranker:")
ft_success = test_reranker_scoring(finetuned_path, [test_case])
return orig_success and ft_success
if __name__ == "__main__":
print("🚀 BGE Reranker Model Validation")
print("=" * 50)
# Test fine-tuned model
finetuned_model_path = "./test_reranker_output/final_model"
success = test_reranker_scoring(finetuned_model_path)
# Optional: Compare with original model
original_model_path = "./models/bge-reranker-base"
compare_rerankers(original_model_path, finetuned_model_path)
if success:
print("\n✅ Reranker validation: PASSED")
print(" The reranker can score query-passage pairs successfully!")
else:
print("\n❌ Reranker validation: FAILED")
print(" The reranker has issues with scoring!")

641
scripts/validation_utils.py Normal file
View File

@ -0,0 +1,641 @@
#!/usr/bin/env python3
"""
Validation Utilities for BGE Fine-tuning
This module provides utilities for model validation including:
- Automatic model discovery
- Test data preparation and validation
- Results analysis and reporting
- Benchmark dataset creation
Usage:
python scripts/validation_utils.py --discover-models
python scripts/validation_utils.py --prepare-test-data
python scripts/validation_utils.py --analyze-results ./validation_results
"""
import os
import sys
import json
import glob
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any
from datetime import datetime
import pandas as pd
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
class ModelDiscovery:
"""Discover trained models in the workspace"""
def __init__(self, workspace_root: str = "."):
self.workspace_root = Path(workspace_root)
def find_trained_models(self) -> Dict[str, Dict[str, str]]:
"""Find all trained BGE models in the workspace"""
models = {
'retrievers': {},
'rerankers': {},
'joint': {}
}
# Common output directories to search
search_patterns = [
"output/*/final_model",
"output/*/best_model",
"output/*/*/final_model",
"output/*/*/best_model",
"test_*_output/final_model",
"training_output/*/final_model"
]
for pattern in search_patterns:
for model_path in glob.glob(str(self.workspace_root / pattern)):
model_path = Path(model_path)
if self._is_valid_model_directory(model_path):
model_info = self._analyze_model_directory(model_path)
if model_info['type'] == 'retriever':
models['retrievers'][model_info['name']] = model_info
elif model_info['type'] == 'reranker':
models['rerankers'][model_info['name']] = model_info
elif model_info['type'] == 'joint':
models['joint'][model_info['name']] = model_info
return models
def _is_valid_model_directory(self, path: Path) -> bool:
"""Check if directory contains a valid trained model"""
required_files = ['config.json', 'pytorch_model.bin']
alternative_files = ['model.safetensors', 'tokenizer.json']
has_required = any((path / f).exists() for f in required_files)
has_alternative = any((path / f).exists() for f in alternative_files)
return has_required or has_alternative
def _analyze_model_directory(self, path: Path) -> Dict[str, Any]:
"""Analyze model directory to determine type and metadata"""
info = {
'path': str(path),
'name': self._generate_model_name(path),
'type': 'unknown',
'created': 'unknown',
'training_info': {}
}
# Try to determine model type from path
path_str = str(path).lower()
if 'm3' in path_str or 'retriever' in path_str:
info['type'] = 'retriever'
elif 'reranker' in path_str:
info['type'] = 'reranker'
elif 'joint' in path_str:
info['type'] = 'joint'
# Get creation time
if path.exists():
info['created'] = datetime.fromtimestamp(path.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S')
# Try to read training info
training_summary_path = path.parent / 'training_summary.json'
if training_summary_path.exists():
try:
with open(training_summary_path, 'r') as f:
info['training_info'] = json.load(f)
except Exception:
pass  # training summary is optional metadata
# Try to read config
config_path = path / 'config.json'
if config_path.exists():
try:
with open(config_path, 'r') as f:
config = json.load(f)
if 'model_type' in config:
info['type'] = config['model_type']
except Exception:
pass  # config.json may be absent or unreadable; keep the path-based type guess
return info
def _generate_model_name(self, path: Path) -> str:
"""Generate a readable name for the model"""
parts = path.parts
# Try to find meaningful parts
meaningful_parts = []
for part in parts[-3:]: # Look at last 3 parts
if part not in ['final_model', 'best_model', 'output']:
meaningful_parts.append(part)
if meaningful_parts:
return '_'.join(meaningful_parts)
else:
return f"model_{path.parent.name}"
def print_discovered_models(self):
"""Print discovered models in a formatted way"""
models = self.find_trained_models()
print("\n🔍 DISCOVERED TRAINED MODELS")
print("=" * 60)
for model_type, model_dict in models.items():
if model_dict:
print(f"\n📊 {model_type.upper()}:")
for name, info in model_dict.items():
print(f"{name}")
print(f" Path: {info['path']}")
print(f" Created: {info['created']}")
if info['training_info']:
training_info = info['training_info']
if 'total_steps' in training_info:
print(f" Training Steps: {training_info['total_steps']}")
if 'final_metrics' in training_info and training_info['final_metrics']:
print(f" Final Metrics: {training_info['final_metrics']}")
if not any(models.values()):
print("❌ No trained models found!")
print("💡 Make sure you have trained some models first.")
return models
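# Illustrative programmatic usage (paths hypothetical):
#   discovery = ModelDiscovery(".")
#   models = discovery.find_trained_models()
#   # models -> {'retrievers': {...}, 'rerankers': {...}, 'joint': {...}}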
class TestDataManager:
"""Manage test datasets for validation"""
def __init__(self, data_dir: str = "data/datasets"):
self.data_dir = Path(data_dir)
def find_test_datasets(self) -> Dict[str, List[str]]:
"""Find available test datasets"""
datasets = {
'retriever_datasets': [],
'reranker_datasets': [],
'general_datasets': []
}
# Search for JSONL files in data directory
for jsonl_file in self.data_dir.rglob("*.jsonl"):
file_path = str(jsonl_file)
# Categorize based on filename/path
if 'reranker' in file_path.lower():
datasets['reranker_datasets'].append(file_path)
elif 'embedding' in file_path.lower() or 'm3' in file_path.lower():
datasets['retriever_datasets'].append(file_path)
else:
datasets['general_datasets'].append(file_path)
# Also search for JSON files
for json_file in self.data_dir.rglob("*.json"):
file_path = str(json_file)
if 'test' in file_path.lower() or 'dev' in file_path.lower():
if 'reranker' in file_path.lower() or 'AFQMC' in file_path:
datasets['reranker_datasets'].append(file_path)
else:
datasets['general_datasets'].append(file_path)
return datasets
def validate_dataset(self, dataset_path: str) -> Dict[str, Any]:
"""Validate a dataset file and analyze its format"""
validation_result = {
'path': dataset_path,
'exists': False,
'format': 'unknown',
'sample_count': 0,
'has_queries': False,
'has_passages': False,
'has_labels': False,
'issues': [],
'sample_data': {}
}
if not os.path.exists(dataset_path):
validation_result['issues'].append(f"File does not exist: {dataset_path}")
return validation_result
validation_result['exists'] = True
try:
# Read first few lines to analyze format
samples = []
with open(dataset_path, 'r', encoding='utf-8') as f:
for i, line in enumerate(f):
if i >= 5: # Only analyze first 5 lines
break
try:
sample = json.loads(line.strip())
samples.append(sample)
except json.JSONDecodeError:
validation_result['issues'].append(f"Invalid JSON at line {i+1}")
if samples:
validation_result['sample_count'] = len(samples)
validation_result['sample_data'] = samples[0] if samples else {}
# Analyze format
first_sample = samples[0]
# Check for common fields
if 'query' in first_sample:
validation_result['has_queries'] = True
if any(field in first_sample for field in ['passage', 'pos', 'neg', 'candidates']):
validation_result['has_passages'] = True
if any(field in first_sample for field in ['label', 'labels', 'score']):
validation_result['has_labels'] = True
# Determine format
if 'pos' in first_sample and 'neg' in first_sample:
validation_result['format'] = 'triplets'
elif 'candidates' in first_sample:
validation_result['format'] = 'candidates'
elif 'passage' in first_sample and 'label' in first_sample:
validation_result['format'] = 'pairs'
elif 'sentence1' in first_sample and 'sentence2' in first_sample:
validation_result['format'] = 'sentence_pairs'
else:
validation_result['format'] = 'custom'
except Exception as e:
validation_result['issues'].append(f"Error reading file: {e}")
return validation_result
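# Illustrative record shapes behind the format labels assigned above
# (field names follow the checks in validate_dataset; the values are made up):
#   triplets:        {"query": "...", "pos": ["..."], "neg": ["..."]}
#   candidates:      {"query": "...", "candidates": ["...", "..."], "labels": [1, 0]}
#   pairs:           {"query": "...", "passage": "...", "label": 1}
#   sentence_pairs:  {"sentence1": "...", "sentence2": "...", "label": 1}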
def print_dataset_analysis(self):
"""Print analysis of available datasets"""
datasets = self.find_test_datasets()
print("\n📁 AVAILABLE TEST DATASETS")
print("=" * 60)
for category, dataset_list in datasets.items():
if dataset_list:
print(f"\n📊 {category.upper().replace('_', ' ')}:")
for dataset_path in dataset_list:
validation = self.validate_dataset(dataset_path)
status = "" if validation['exists'] and not validation['issues'] else "⚠️"
print(f" {status} {Path(dataset_path).name}")
print(f" Path: {dataset_path}")
print(f" Format: {validation['format']}")
if validation['issues']:
for issue in validation['issues']:
print(f" ⚠️ {issue}")
if not any(datasets.values()):
print("❌ No test datasets found!")
print("💡 Place your test data in the data/datasets/ directory")
return datasets
def create_sample_test_data(self, output_dir: str = "data/test_samples"):
"""Create sample test datasets for validation"""
os.makedirs(output_dir, exist_ok=True)
# Create sample retriever test data
retriever_test = [
{
"query": "什么是人工智能?",
"pos": [
"人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
"AI技术包括机器学习、深度学习、自然语言处理等多个领域"
],
"neg": [
"今天的天气非常好,阳光明媚",
"我喜欢听音乐和看电影"
]
},
{
"query": "如何学习编程?",
"pos": [
"学习编程需要选择合适的语言如Python、Java等然后通过实践项目提升技能",
"编程学习的关键是多写代码、多调试、多思考算法逻辑"
],
"neg": [
"烹饪是一门艺术,需要掌握火候和调味",
"运动对健康很重要每天应该锻炼30分钟"
]
}
]
# Create sample reranker test data
reranker_test = [
{
"query": "什么是人工智能?",
"passage": "人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
"label": 1
},
{
"query": "什么是人工智能?",
"passage": "今天的天气非常好,阳光明媚",
"label": 0
},
{
"query": "如何学习编程?",
"passage": "学习编程需要选择合适的语言如Python、Java等然后通过实践项目提升技能",
"label": 1
},
{
"query": "如何学习编程?",
"passage": "烹饪是一门艺术,需要掌握火候和调味",
"label": 0
}
]
# Save retriever test data
retriever_path = os.path.join(output_dir, "retriever_test.jsonl")
with open(retriever_path, 'w', encoding='utf-8') as f:
for item in retriever_test:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# Save reranker test data
reranker_path = os.path.join(output_dir, "reranker_test.jsonl")
with open(reranker_path, 'w', encoding='utf-8') as f:
for item in reranker_test:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
print(f"✅ Sample test datasets created in: {output_dir}")
print(f" 📁 Retriever test: {retriever_path}")
print(f" 📁 Reranker test: {reranker_path}")
return {
'retriever_test': retriever_path,
'reranker_test': reranker_path
}
class ValidationReportAnalyzer:
"""Analyze validation results and generate insights"""
def __init__(self, results_dir: str):
self.results_dir = Path(results_dir)
def analyze_results(self) -> Dict[str, Any]:
"""Analyze validation results"""
analysis = {
'overall_status': 'unknown',
'model_performance': {},
'recommendations': [],
'summary_stats': {}
}
# Look for results files
json_results = self.results_dir / "validation_results.json"
if not json_results.exists():
analysis['recommendations'].append("No validation results found. Run validation first.")
return analysis
try:
with open(json_results, 'r') as f:
results = json.load(f)
# Analyze each model type
improvements = 0
degradations = 0
for model_type in ['retriever', 'reranker', 'pipeline']:
if model_type in results and 'comparisons' in results[model_type]:
model_analysis = self._analyze_model_results(results[model_type])
analysis['model_performance'][model_type] = model_analysis
improvements += model_analysis['total_improvements']
degradations += model_analysis['total_degradations']
# Overall status
if improvements > degradations * 1.5:
analysis['overall_status'] = 'excellent'
elif improvements > degradations:
analysis['overall_status'] = 'good'
elif improvements == degradations:
analysis['overall_status'] = 'mixed'
else:
analysis['overall_status'] = 'poor'
# Generate recommendations
analysis['recommendations'] = self._generate_recommendations(analysis)
# Summary statistics
analysis['summary_stats'] = {
'total_improvements': improvements,
'total_degradations': degradations,
'net_improvement': improvements - degradations,
'improvement_ratio': improvements / (improvements + degradations) if (improvements + degradations) > 0 else 0
}
except Exception as e:
analysis['recommendations'].append(f"Error analyzing results: {e}")
return analysis
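# How the status thresholds above play out, with illustrative counts:
#   10 improvements vs 5 degradations -> 10 > 5 * 1.5, status 'excellent'
#    6 improvements vs 5 degradations -> 6 <= 7.5 but 6 > 5, status 'good'
#    5 improvements vs 5 degradations -> status 'mixed'
#    3 improvements vs 5 degradations -> status 'poor'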
def _analyze_model_results(self, model_results: Dict[str, Any]) -> Dict[str, Any]:
"""Analyze results for a specific model type"""
analysis = {
'total_improvements': 0,
'total_degradations': 0,
'best_dataset': None,
'worst_dataset': None,
'avg_improvement': 0,
'issues': []
}
dataset_scores = {}
for dataset_name, comparison in model_results.get('comparisons', {}).items():
improvements = len(comparison.get('improvements', {}))
degradations = len(comparison.get('degradations', {}))
analysis['total_improvements'] += improvements
analysis['total_degradations'] += degradations
# Score dataset (improvements - degradations)
dataset_score = improvements - degradations
dataset_scores[dataset_name] = dataset_score
if dataset_scores:
analysis['best_dataset'] = max(dataset_scores.items(), key=lambda x: x[1])
analysis['worst_dataset'] = min(dataset_scores.items(), key=lambda x: x[1])
analysis['avg_improvement'] = sum(dataset_scores.values()) / len(dataset_scores)
# Identify issues
if analysis['total_degradations'] > analysis['total_improvements']:
analysis['issues'].append("More degradations than improvements")
if analysis['avg_improvement'] < 0:
analysis['issues'].append("Negative average improvement")
return analysis
def _generate_recommendations(self, analysis: Dict[str, Any]) -> List[str]:
"""Generate recommendations based on analysis"""
recommendations = []
overall_status = analysis['overall_status']
if overall_status == 'excellent':
recommendations.append("🎉 Great job! Your fine-tuned models show significant improvements.")
recommendations.append("Consider deploying these models to production.")
elif overall_status == 'good':
recommendations.append("✅ Good progress! Models show overall improvement.")
recommendations.append("Consider further fine-tuning on specific datasets where degradation occurred.")
elif overall_status == 'mixed':
recommendations.append("⚠️ Mixed results. Some improvements, some degradations.")
recommendations.append("Analyze which datasets/metrics are degrading and adjust training data.")
recommendations.append("Consider different hyperparameters or longer training.")
elif overall_status == 'poor':
recommendations.append("❌ Models show more degradations than improvements.")
recommendations.append("Review training data quality and format.")
recommendations.append("Check hyperparameters (learning rate might be too high).")
recommendations.append("Consider starting with different base models.")
# Specific model recommendations
for model_type, model_perf in analysis.get('model_performance', {}).items():
if model_perf['issues']:
recommendations.append(f"🔧 {model_type.title()} issues: {', '.join(model_perf['issues'])}")
if model_perf['best_dataset']:
best_dataset, best_score = model_perf['best_dataset']
recommendations.append(f"📊 {model_type.title()} performs best on: {best_dataset} (score: {best_score})")
if model_perf['worst_dataset']:
worst_dataset, worst_score = model_perf['worst_dataset']
recommendations.append(f"📉 {model_type.title()} needs improvement on: {worst_dataset} (score: {worst_score})")
return recommendations
def print_analysis(self):
"""Print comprehensive analysis of validation results"""
analysis = self.analyze_results()
print(f"\n📊 VALIDATION RESULTS ANALYSIS")
print("=" * 60)
# Overall status
status_emoji = {
'excellent': '🌟',
'good': '✅',
'mixed': '⚠️',
'poor': '❌',
'unknown': '❓'
}
print(f"{status_emoji.get(analysis['overall_status'], '')} Overall Status: {analysis['overall_status'].upper()}")
# Summary stats
if 'summary_stats' in analysis and analysis['summary_stats']:
stats = analysis['summary_stats']
print(f"\n📈 Summary Statistics:")
print(f" Total Improvements: {stats['total_improvements']}")
print(f" Total Degradations: {stats['total_degradations']}")
print(f" Net Improvement: {stats['net_improvement']:+d}")
print(f" Improvement Ratio: {stats['improvement_ratio']:.2%}")
# Model performance
if analysis['model_performance']:
print(f"\n📊 Model Performance:")
for model_type, perf in analysis['model_performance'].items():
print(f" {model_type.title()}:")
print(f" Improvements: {perf['total_improvements']}")
print(f" Degradations: {perf['total_degradations']}")
print(f" Average Score: {perf['avg_improvement']:.2f}")
if perf['best_dataset']:
print(f" Best Dataset: {perf['best_dataset'][0]} (score: {perf['best_dataset'][1]})")
if perf['issues']:
print(f" Issues: {', '.join(perf['issues'])}")
# Recommendations
if analysis['recommendations']:
print(f"\n💡 Recommendations:")
for i, rec in enumerate(analysis['recommendations'], 1):
print(f" {i}. {rec}")
return analysis
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Validation utilities for BGE models",
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("--discover-models", action="store_true",
help="Discover trained models in workspace")
parser.add_argument("--analyze-datasets", action="store_true",
help="Analyze available test datasets")
parser.add_argument("--create-sample-data", action="store_true",
help="Create sample test datasets")
parser.add_argument("--analyze-results", type=str, default=None,
help="Analyze validation results in specified directory")
parser.add_argument("--workspace-root", type=str, default=".",
help="Root directory of workspace")
parser.add_argument("--data-dir", type=str, default="data/datasets",
help="Directory containing test datasets")
parser.add_argument("--output-dir", type=str, default="data/test_samples",
help="Output directory for sample data")
return parser.parse_args()
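# Example invocations (illustrative; assumes this utility is saved as scripts/validation_utils.py,
# adjust the path to match your layout):
#   python scripts/validation_utils.py --discover-models --workspace-root .
#   python scripts/validation_utils.py --analyze-datasets --data-dir data/datasets
#   python scripts/validation_utils.py --create-sample-data --output-dir data/test_samples
#   python scripts/validation_utils.py --analyze-results output/validation_run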
def main():
"""Main utility function"""
args = parse_args()
print("🔧 BGE Validation Utilities")
print("=" * 50)
if args.discover_models:
print("\n🔍 Discovering trained models...")
discovery = ModelDiscovery(args.workspace_root)
models = discovery.print_discovered_models()
if any(models.values()):
print("\n💡 Use these model paths in your validation commands:")
for model_type, model_dict in models.items():
for name, info in model_dict.items():
print(f" --{model_type[:-1]}_model {info['path']}")
if args.analyze_datasets:
print("\n📁 Analyzing test datasets...")
data_manager = TestDataManager(args.data_dir)
datasets = data_manager.print_dataset_analysis()
if args.create_sample_data:
print("\n📝 Creating sample test data...")
data_manager = TestDataManager()
sample_paths = data_manager.create_sample_test_data(args.output_dir)
print("\n💡 Use these sample datasets for testing:")
print(f" --test_datasets {sample_paths['retriever_test']} {sample_paths['reranker_test']}")
if args.analyze_results:
print(f"\n📊 Analyzing results from: {args.analyze_results}")
analyzer = ValidationReportAnalyzer(args.analyze_results)
analysis = analyzer.print_analysis()
if not any([args.discover_models, args.analyze_datasets, args.create_sample_data, args.analyze_results]):
print("❓ No action specified. Use --help to see available options.")
return 1
return 0
if __name__ == "__main__":
exit(main())

View File

@ -0,0 +1,99 @@
import os
import sys
import tempfile
import pytest
# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.preprocessing import DataPreprocessor
from data.augmentation import QueryAugmenter
@pytest.fixture
def tmp_jsonl_file():
# Create a malformed JSONL file
fd, path = tempfile.mkstemp(suffix='.jsonl')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
f.write('{"query": "valid", "pos": ["a"], "neg": ["b"]}\n')
f.write('{invalid json}\n')
f.write('\n')  # trailing blank line as an additional edge case
yield path
os.remove(path)
@pytest.fixture
def tmp_csv_file():
fd, path = tempfile.mkstemp(suffix='.csv')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
f.write('query,pos,neg\n')
f.write('valid,"a","b"\n')
f.write('malformed line\n')
yield path
os.remove(path)
@pytest.fixture
def tmp_tsv_file():
fd, path = tempfile.mkstemp(suffix='.tsv')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
f.write('query\tpos\tneg\n')
f.write('valid\ta\tb\n')
f.write('malformed line\n')
yield path
os.remove(path)
def test_jsonl_malformed_lines(tmp_jsonl_file):
preprocessor = DataPreprocessor(tokenizer=None)
# Should skip malformed line and not raise
data = preprocessor._load_jsonl(tmp_jsonl_file)
assert isinstance(data, list)
assert len(data) == 1
assert data[0]['query'] == 'valid'
def test_csv_malformed_lines(tmp_csv_file):
preprocessor = DataPreprocessor(tokenizer=None)
# Should skip malformed line and not raise
data = preprocessor._load_csv(tmp_csv_file)
assert isinstance(data, list)
assert len(data) == 1
assert data[0]['query'] == 'valid'
def test_tsv_malformed_lines(tmp_tsv_file):
preprocessor = DataPreprocessor(tokenizer=None)
# Should skip malformed line and not raise
data = preprocessor._load_tsv(tmp_tsv_file)
assert isinstance(data, list)
assert len(data) == 1
assert data[0]['query'] == 'valid'
def test_augmentation_methods_config():
# Should use config-driven augmentation methods
augmenter = QueryAugmenter(augmentation_methods=None, config_path='config.toml')
assert isinstance(augmenter.augmentation_methods, list)
assert 'paraphrase' in augmenter.augmentation_methods
def test_config_list_parsing():
"""Test that config loader parses comma/semicolon-separated lists correctly."""
from utils.config_loader import ConfigLoader
loader = ConfigLoader()
loader._config = {'section': {'list_param': 'a, b; c'}}
value = loader.get('section', 'list_param', default=[])
assert isinstance(value, list)
assert set(value) == {'a', 'b', 'c'}
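# For reference, the list parsing exercised above corresponds to a config.toml entry
# such as (illustrative):
#   [section]
#   list_param = "a, b; c"
# which ConfigLoader.get is expected to split on commas/semicolons into ['a', 'b', 'c'].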
def test_config_param_source_logging(caplog):
"""Test that config loader logs parameter source correctly."""
from utils.config_loader import ConfigLoader
loader = ConfigLoader()
loader._config = {'section': {'param': 42}}
with caplog.at_level('INFO'):
value = loader.get('section', 'param', default=0)
assert 'source: config.toml' in caplog.text
value = loader.get('section', 'missing', default=99)
assert 'source: default' in caplog.text
def test_config_fallback_default():
"""Test that config loader falls back to default if key is missing."""
from utils.config_loader import ConfigLoader
loader = ConfigLoader()
loader._config = {'section': {}}
value = loader.get('section', 'not_found', default='fallback')
assert value == 'fallback'

121
tests/test_full_training.py Normal file
View File

@ -0,0 +1,121 @@
#!/usr/bin/env python
"""Test script to validate full training pipeline with learning rate scheduling"""
import torch
import json
import tempfile
import sys
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from config.m3_config import BGEM3Config
from models.bge_m3 import BGEM3Model
from data.dataset import BGEM3Dataset, collate_embedding_batch
from training.m3_trainer import BGEM3Trainer
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
def create_test_data():
"""Create minimal test data"""
data = [
{
"query": "What is machine learning?",
"pos": ["Machine learning is a subset of AI"],
"neg": ["The weather is sunny today"]
},
{
"query": "How does Python work?",
"pos": ["Python is a programming language"],
"neg": ["Cats are cute animals"]
},
{
"query": "What is deep learning?",
"pos": ["Deep learning uses neural networks"],
"neg": ["Pizza is delicious food"]
}
]
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
for item in data:
f.write(json.dumps(item) + '\n')
return f.name
def main():
print("🧪 Testing BGE-M3 training with learning rate scheduling...")
# Create test data
train_file = create_test_data()
print(f"✅ Created test data: {train_file}")
# Initialize config with small dataset settings
config = BGEM3Config(
model_name_or_path="./models/bge-m3",
train_data_path=train_file,
output_dir="output/test-lr-scheduler",
num_train_epochs=2,
per_device_train_batch_size=2,
gradient_accumulation_steps=1, # Small for testing
learning_rate=1e-4, # Higher for testing
logging_steps=1,
save_steps=10,
fp16=False,
use_amp=False,
dataloader_num_workers=0,
overwrite_output_dir=True,
use_self_distill=False,
use_hard_negatives=False,
temperature=0.1 # Fixed temperature
)
try:
# Load model and tokenizer
print("📦 Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
model = BGEM3Model.from_pretrained(config.model_name_or_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"🖥️ Using device: {device}")
# Create dataset
dataset = BGEM3Dataset(
data_path=train_file,
tokenizer=tokenizer,
query_max_length=config.query_max_length,
passage_max_length=config.passage_max_length,
is_train=True
)
print(f"✅ Dataset size: {len(dataset)}")
# Initialize trainer
trainer = BGEM3Trainer(
config=config,
model=model,
tokenizer=tokenizer,
device=device
)
# Test training for a few steps
print("\n🏋️ Starting training test...")
trainer.train(dataset)
print("✅ Training completed successfully!")
return True
except Exception as e:
print(f"❌ Training failed: {e}")
import traceback
traceback.print_exc()
return False
finally:
# Clean up
Path(train_file).unlink()
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

712
tests/test_installation.py Normal file
View File

@ -0,0 +1,712 @@
"""
Comprehensive testing script for BGE fine-tuning project
Tests all components, configurations, and integrations
"""
import os
import sys
import traceback
import tempfile
import json
from pathlib import Path
import torch
import numpy as np
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class ComprehensiveTestSuite:
"""Comprehensive test suite for BGE fine-tuning project"""
def __init__(self):
self.test_results = []
self.failed_tests = []
def run_test(self, test_name, test_func, *args, **kwargs):
"""Run a single test and record results"""
print(f"\n{'='*60}")
print(f"Testing: {test_name}")
print('='*60)
try:
result = test_func(*args, **kwargs)
if result:
print(f"{test_name} PASSED")
self.test_results.append((test_name, "PASSED", None))
return True
else:
print(f"{test_name} FAILED")
self.test_results.append((test_name, "FAILED", "Test returned False"))
self.failed_tests.append(test_name)
return False
except Exception as e:
print(f"{test_name} FAILED with exception: {e}")
traceback.print_exc()
self.test_results.append((test_name, "FAILED", str(e)))
self.failed_tests.append(test_name)
return False
def test_imports(self):
"""Test all critical imports"""
print("Testing critical imports...")
# Core dependencies
try:
import torch
print(f"✓ PyTorch {torch.__version__}")
print(f" CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f" CUDA version: {torch.version.cuda}")
print(f" GPU count: {torch.cuda.device_count()}")
except ImportError as e:
print(f"✗ PyTorch: {e}")
return False
# Test Ascend NPU support
try:
import torch_npu
if torch.npu.is_available(): # type: ignore[attr-defined]
print(f"✓ Ascend NPU available: {torch.npu.device_count()} devices") # type: ignore[attr-defined]
else:
print("! Ascend NPU library installed but no devices available")
except ImportError:
print("- Ascend NPU library not installed (optional)")
# Transformers and related
try:
import transformers
print(f"✓ Transformers {transformers.__version__}")
except ImportError as e:
print(f"✗ Transformers: {e}")
return False
# Scientific computing
try:
import numpy as np
import pandas as pd
import scipy
import sklearn
print("✓ Scientific computing libraries")
except ImportError as e:
print(f"✗ Scientific libraries: {e}")
return False
# BGE specific
try:
import faiss
print(f"✓ FAISS {faiss.__version__}")
except ImportError as e:
print(f"✗ FAISS: {e}")
return False
try:
from rank_bm25 import BM25Okapi
print("✓ BM25")
except ImportError as e:
print(f"✗ BM25: {e}")
return False
try:
import jieba
print("✓ Jieba (Chinese tokenization)")
except ImportError as e:
print(f"✗ Jieba: {e}")
return False
try:
import toml
print("✓ TOML")
except ImportError as e:
print(f"✗ TOML: {e}")
return False
return True
def test_project_imports(self):
"""Test project-specific imports"""
print("Testing project imports...")
# Configuration
try:
from utils.config_loader import get_config, load_config_for_training
from config.base_config import BaseConfig
from config.m3_config import BGEM3Config
from config.reranker_config import BGERerankerConfig
print("✓ Configuration classes")
except ImportError as e:
print(f"✗ Configuration imports: {e}")
traceback.print_exc()
return False
# Models
try:
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss
print("✓ Model classes")
except ImportError as e:
print(f"✗ Model imports: {e}")
traceback.print_exc()
return False
# Data processing
try:
from data.dataset import BGEM3Dataset, BGERerankerDataset
from data.preprocessing import DataPreprocessor
from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter
from data.hard_negative_mining import ANCEHardNegativeMiner
print("✓ Data processing classes")
except ImportError as e:
print(f"✗ Data imports: {e}")
traceback.print_exc()
return False
# Training
try:
from training.trainer import BaseTrainer
from training.m3_trainer import BGEM3Trainer
from training.reranker_trainer import BGERerankerTrainer
from training.joint_trainer import JointTrainer
print("✓ Training classes")
except ImportError as e:
print(f"✗ Training imports: {e}")
traceback.print_exc()
return False
# Evaluation
try:
from evaluation.evaluator import RetrievalEvaluator
from evaluation.metrics import compute_retrieval_metrics
print("✓ Evaluation classes")
except ImportError as e:
print(f"✗ Evaluation imports: {e}")
traceback.print_exc()
return False
# Utilities
try:
from utils.logging import setup_logging
from utils.checkpoint import CheckpointManager
from utils.distributed import setup_distributed, is_main_process
print("✓ Utility classes")
except ImportError as e:
print(f"✗ Utility imports: {e}")
traceback.print_exc()
return False
return True
def test_configuration_system(self):
"""Test the centralized configuration system"""
print("Testing configuration system...")
try:
from utils.config_loader import get_config, load_config_for_training
# Test singleton behavior
config1 = get_config()
config2 = get_config()
assert config1 is config2, "Config should be singleton"
print("✓ Singleton pattern working")
# Test configuration loading
test_value = config1.get('training.default_batch_size', None) # type: ignore[arg-type]
print(f"✓ Config loading: batch_size = {test_value}")
# Test section loading
training_section = config1.get_section('training')
assert isinstance(training_section, dict), "Should return dict"
print("✓ Section loading working")
# Test training configuration loading
training_config = load_config_for_training()
assert 'training' in training_config
assert 'model_paths' in training_config
print("✓ Training config loading working")
return True
except Exception as e:
print(f"✗ Configuration system error: {e}")
traceback.print_exc()
return False
def test_configuration_integration(self):
"""Test configuration integration in classes"""
print("Testing configuration integration...")
try:
from config.base_config import BaseConfig
from config.m3_config import BGEM3Config
# Test BaseConfig with config.toml integration
base_config = BaseConfig()
print(f"✓ BaseConfig created with batch_size: {base_config.per_device_train_batch_size}")
# Test M3Config
m3_config = BGEM3Config()
print(f"✓ M3Config created with learning_rate: {m3_config.learning_rate}")
# Test device detection
print(f"✓ Device detected: {base_config.device}")
return True
except Exception as e:
print(f"✗ Configuration integration error: {e}")
traceback.print_exc()
return False
def test_data_processing(self):
"""Test data processing functionality"""
print("Testing data processing...")
try:
from data.preprocessing import DataPreprocessor
# Create test data
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
test_data = [
{"query": "test query 1", "pos": ["positive passage 1"], "neg": ["negative passage 1"]},
{"query": "test query 2", "pos": ["positive passage 2"], "neg": ["negative passage 2"]}
]
for item in test_data:
f.write(json.dumps(item) + '\n')
temp_file = f.name
# Test preprocessing
preprocessor = DataPreprocessor()
data = preprocessor._load_jsonl(temp_file)
assert len(data) == 2
print("✓ JSONL loading working")
# Test validation
cleaned = preprocessor._validate_and_clean(data)
assert len(cleaned) == 2
print("✓ Data validation working")
# Cleanup
os.unlink(temp_file)
return True
except Exception as e:
print(f"✗ Data processing error: {e}")
traceback.print_exc()
return False
def test_data_augmentation(self):
"""Test data augmentation functionality"""
print("Testing data augmentation...")
try:
from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter
# Test QueryAugmenter
augmenter = QueryAugmenter(augmentation_methods=['paraphrase', 'synonym'])
test_data = [
{"query": "什么是机器学习?", "pos": ["机器学习是AI的分支"], "neg": ["深度学习不同"]}
]
augmented = augmenter.augment_queries(test_data, num_augmentations=1, keep_original=True)
assert len(augmented) >= len(test_data)
print("✓ Query augmentation working")
# Test PassageAugmenter
passage_aug = PassageAugmenter()
passage_augmented = passage_aug.augment_passages(test_data, augmentation_prob=1.0)
print("✓ Passage augmentation working")
# Test HardNegativeAugmenter
hn_aug = HardNegativeAugmenter()
hn_augmented = hn_aug.create_hard_negatives(test_data, num_hard_negatives=2)
print("✓ Hard negative augmentation working")
return True
except Exception as e:
print(f"✗ Data augmentation error: {e}")
traceback.print_exc()
return False
def test_model_functionality(self):
"""Test basic model functionality"""
print("Testing model functionality...")
try:
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss
# Test loss functions
infonce = InfoNCELoss()
query_emb = torch.randn(2, 128)
pos_emb = torch.randn(2, 128)
neg_emb = torch.randn(2, 5, 128)
loss = infonce(query_emb, pos_emb, neg_emb)
assert loss.item() > 0
print("✓ InfoNCE loss working")
# Test listwise loss
listwise = ListwiseCrossEntropyLoss()
scores = torch.randn(2, 5)
labels = torch.zeros(2, 5)
labels[:, 0] = 1 # First is positive
loss = listwise(scores, labels, [5, 5])
assert loss.item() > 0
print("✓ Listwise loss working")
return True
except Exception as e:
print(f"✗ Model functionality error: {e}")
traceback.print_exc()
return False
def test_hard_negative_mining(self):
"""Test hard negative mining"""
print("Testing hard negative mining...")
try:
from data.hard_negative_mining import ANCEHardNegativeMiner
# Create miner
miner = ANCEHardNegativeMiner(
embedding_dim=128,
num_hard_negatives=5,
use_gpu=False # Use CPU for testing
)
# Add some dummy embeddings
embeddings = np.random.rand(10, 128).astype(np.float32)
passages = [f"passage_{i}" for i in range(10)]
miner._add_to_index(embeddings, passages)
print("✓ Index creation working")
# Test mining
queries = ["query1", "query2"]
query_embeddings = np.random.rand(2, 128).astype(np.float32)
positive_passages = [["passage_0"], ["passage_5"]]
hard_negs = miner.mine_hard_negatives(
queries, query_embeddings, positive_passages
)
assert len(hard_negs) == 2
print("✓ Hard negative mining working")
return True
except Exception as e:
print(f"✗ Hard negative mining error: {e}")
traceback.print_exc()
return False
def test_evaluation_metrics(self):
"""Test evaluation metrics"""
print("Testing evaluation metrics...")
try:
from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics
# Create dummy data
query_embs = np.random.rand(5, 128)
passage_embs = np.random.rand(20, 128)
labels = np.zeros((5, 20))
labels[0, 0] = 1 # First query, first passage is relevant
labels[1, 5] = 1 # Second query, sixth passage is relevant
metrics = compute_retrieval_metrics(query_embs, passage_embs, labels)
assert 'recall@1' in metrics
assert 'precision@1' in metrics
print("✓ Retrieval metrics working")
# Test reranker metrics
scores = np.random.rand(10)
labels = np.random.randint(0, 2, 10)
reranker_metrics = compute_reranker_metrics(scores, labels)
assert 'accuracy' in reranker_metrics or 'mrr' in reranker_metrics
print("✓ Reranker metrics working")
return True
except Exception as e:
print(f"✗ Evaluation metrics error: {e}")
traceback.print_exc()
return False
def test_data_formats(self):
"""Test multiple data format support"""
print("Testing data format support...")
try:
from data.preprocessing import DataPreprocessor
preprocessor = DataPreprocessor()
# Test JSONL format
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
f.write('{"query": "test", "pos": ["pos1"], "neg": ["neg1"]}\n')
f.write('{"query": "test2", "pos": ["pos2"], "neg": ["neg2"]}\n')
temp_file = f.name
data = preprocessor._load_jsonl(temp_file)
assert len(data) == 2
os.unlink(temp_file)
print("✓ JSONL format working")
# Test CSV format
import pandas as pd
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
df = pd.DataFrame({
'query': ['test1', 'test2'],
'positive': ['pos1', 'pos2'],
'negative': ['neg1', 'neg2']
})
df.to_csv(f.name, index=False)
temp_file = f.name
data = preprocessor._load_csv(temp_file)
assert len(data) >= 1 # At least one valid record
os.unlink(temp_file)
print("✓ CSV format working")
return True
except Exception as e:
print(f"✗ Data format support error: {e}")
traceback.print_exc()
return False
def test_device_compatibility(self):
"""Test device compatibility (CUDA/Ascend/CPU)"""
print("Testing device compatibility...")
try:
from config.base_config import BaseConfig
# Test device detection
config = BaseConfig()
print(f"✓ Device detected: {config.device}")
print(f"✓ Number of devices: {config.n_gpu}")
# Test tensor operations
if config.device == "cuda":
x = torch.randn(10, 10).cuda()
y = torch.matmul(x, x.t())
assert y.device.type == "cuda"
print("✓ CUDA tensor operations working")
elif config.device == "npu":
print("✓ Ascend NPU detected")
# Note: Would need actual NPU to test operations
else:
x = torch.randn(10, 10)
y = torch.matmul(x, x.t())
print("✓ CPU tensor operations working")
return True
except Exception as e:
print(f"✗ Device compatibility error: {e}")
traceback.print_exc()
return False
def test_checkpoint_functionality(self):
"""Test checkpointing functionality"""
print("Testing checkpoint functionality...")
try:
from utils.checkpoint import CheckpointManager
with tempfile.TemporaryDirectory() as temp_dir:
manager = CheckpointManager(temp_dir)
# Test saving
state_dict = {
'model_state_dict': torch.randn(10, 10),
'optimizer_state_dict': {'lr': 0.001},
'step': 100
}
checkpoint_path = manager.save(state_dict, step=100)
assert os.path.exists(checkpoint_path)
print("✓ Checkpoint saving working")
# Test loading
loaded_state = manager.load_checkpoint(checkpoint_path)
assert 'model_state_dict' in loaded_state
print("✓ Checkpoint loading working")
return True
except Exception as e:
print(f"✗ Checkpoint functionality error: {e}")
traceback.print_exc()
return False
def test_logging_system(self):
"""Test logging system"""
print("Testing logging system...")
try:
from utils.logging import setup_logging, get_logger, MetricsLogger
with tempfile.TemporaryDirectory() as temp_dir:
# Test setup
setup_logging(temp_dir, log_to_file=True)
logger = get_logger("test")
logger.info("Test log message")
print("✓ Basic logging working")
# Test metrics logger
metrics_logger = MetricsLogger(temp_dir)
metrics_logger.log(step=1, metrics={'loss': 0.5}, phase='train')
print("✓ Metrics logging working")
return True
except Exception as e:
print(f"✗ Logging system error: {e}")
traceback.print_exc()
return False
def test_training_script_config_integration(self):
"""Test training script configuration integration"""
print("Testing training script configuration integration...")
try:
# Import the main training script functions
from scripts.train_m3 import create_config_from_args
# Create mock arguments
class MockArgs:
def __init__(self):
self.config_path = None
self.model_name_or_path = None
self.cache_dir = None
self.train_data = "dummy"
self.eval_data = None
self.max_seq_length = None
self.query_max_length = None
self.passage_max_length = None
self.max_length = 8192
self.output_dir = "/tmp/test"
self.num_train_epochs = None
self.per_device_train_batch_size = None
self.per_device_eval_batch_size = None
self.gradient_accumulation_steps = None
self.learning_rate = None
self.warmup_ratio = None
self.weight_decay = None
self.train_group_size = None
self.temperature = None
self.use_hard_negatives = False
self.num_hard_negatives = None
self.use_self_distill = False
self.use_query_instruction = False
self.query_instruction = None
self.use_dense = True
self.use_sparse = False
self.use_colbert = False
self.unified_finetuning = False
self.augment_queries = False
self.augmentation_methods = None
self.create_hard_negatives = False
self.fp16 = False
self.bf16 = False
self.gradient_checkpointing = False
self.dataloader_num_workers = None
self.save_steps = None
self.eval_steps = None
self.logging_steps = None
self.save_total_limit = None
self.resume_from_checkpoint = None
self.local_rank = -1
self.seed = None
self.overwrite_output_dir = True
args = MockArgs()
config = create_config_from_args(args)
assert config.learning_rate > 0
assert config.per_device_train_batch_size > 0
print("✓ Training script config integration working")
return True
except Exception as e:
print(f"✗ Training script config integration error: {e}")
traceback.print_exc()
return False
def run_all_tests(self):
"""Run all tests and generate report"""
print("\n" + "="*80)
print("BGE FINE-TUNING COMPREHENSIVE TEST SUITE")
print("="*80)
tests = [
("Core Dependencies Import", self.test_imports),
("Project Components Import", self.test_project_imports),
("Configuration System", self.test_configuration_system),
("Configuration Integration", self.test_configuration_integration),
("Data Processing", self.test_data_processing),
("Data Augmentation", self.test_data_augmentation),
("Model Functionality", self.test_model_functionality),
("Hard Negative Mining", self.test_hard_negative_mining),
("Evaluation Metrics", self.test_evaluation_metrics),
("Data Format Support", self.test_data_formats),
("Device Compatibility", self.test_device_compatibility),
("Checkpoint Functionality", self.test_checkpoint_functionality),
("Logging System", self.test_logging_system),
("Training Script Integration", self.test_training_script_config_integration),
]
passed = 0
total = len(tests)
for test_name, test_func in tests:
if self.run_test(test_name, test_func):
passed += 1
# Generate report
print("\n" + "="*80)
print("TEST RESULTS SUMMARY")
print("="*80)
print(f"Tests passed: {passed}/{total}")
print(f"Success rate: {passed/total*100:.1f}%")
if passed == total:
print("\n🎉 ALL TESTS PASSED! Project is ready for use.")
print("\nNext steps:")
print("1. Configure your API keys in config.toml if needed")
print("2. Prepare your training data in JSONL format")
print("3. Run training with: python scripts/train_m3.py --train_data your_data.jsonl")
print("4. Evaluate with: python scripts/evaluate.py --retriever_model output/model")
else:
print(f"\n{total - passed} tests failed. Please review the errors above.")
print("\nFailed tests:")
for test_name in self.failed_tests:
print(f" - {test_name}")
print("\nDetailed test results:")
for test_name, status, error in self.test_results:
status_emoji = "" if status == "PASSED" else ""
print(f" {status_emoji} {test_name}: {status}")
if error:
print(f" Error: {error}")
return passed == total
def main():
"""Run comprehensive test suite"""
test_suite = ComprehensiveTestSuite()
success = test_suite.run_all_tests()
return 0 if success else 1
if __name__ == "__main__":
sys.exit(main())

View File

@ -0,0 +1,299 @@
#!/usr/bin/env python
"""Unified test script to validate both BGE-M3 and BGE-Reranker training pipelines"""
import torch
import json
import tempfile
import sys
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from config.m3_config import BGEM3Config
from config.reranker_config import BGERerankerConfig
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset, collate_embedding_batch, collate_reranker_batch
from training.m3_trainer import BGEM3Trainer
from training.reranker_trainer import BGERerankerTrainer
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
def create_embedding_test_data():
"""Create test data for embedding model (triplets format)"""
data = [
{
"query": "What is machine learning?",
"pos": ["Machine learning is a subset of artificial intelligence that enables systems to learn from data."],
"neg": ["The weather today is sunny and warm."]
},
{
"query": "How does deep learning work?",
"pos": ["Deep learning uses neural networks with multiple layers to learn patterns."],
"neg": ["Pizza is a popular Italian dish."]
}
]
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
for item in data:
f.write(json.dumps(item) + '\n')
return f.name
def create_reranker_test_data():
"""Create test data for reranker model (pairs format)"""
data = [
{
"query": "What is machine learning?",
"passage": "Machine learning is a subset of artificial intelligence that enables systems to learn from data.",
"label": 1,
"score": 0.95,
"qid": "q1",
"pid": "p1"
},
{
"query": "What is machine learning?",
"passage": "The weather today is sunny and warm.",
"label": 0,
"score": 0.15,
"qid": "q1",
"pid": "p2"
},
{
"query": "How does deep learning work?",
"passage": "Deep learning uses neural networks with multiple layers to learn patterns.",
"label": 1,
"score": 0.92,
"qid": "q2",
"pid": "p3"
},
{
"query": "How does deep learning work?",
"passage": "Pizza is a popular Italian dish.",
"label": 0,
"score": 0.12,
"qid": "q2",
"pid": "p4"
}
]
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
for item in data:
f.write(json.dumps(item) + '\n')
return f.name
def test_bge_m3():
"""Test BGE-M3 embedding model"""
print("\n🎯 Testing BGE-M3 Embedding Model...")
# Create test data
train_file = create_embedding_test_data()
print(f"✅ Created M3 test data: {train_file}")
# Initialize config
config = BGEM3Config(
model_name_or_path="./models/bge-m3",
train_data_path=train_file,
output_dir="output/test-m3",
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=1,
learning_rate=1e-5,
logging_steps=1,
save_steps=10,
fp16=False,
use_amp=False,
dataloader_num_workers=0,
overwrite_output_dir=True,
use_self_distill=False, # Disable for testing
use_hard_negatives=False # Disable for testing
)
try:
# Load model and tokenizer
print("📦 Loading M3 model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
model = BGEM3Model.from_pretrained(config.model_name_or_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"🖥️ Using device: {device}")
# Create dataset
dataset = BGEM3Dataset(
data_path=train_file,
tokenizer=tokenizer,
query_max_length=config.query_max_length,
passage_max_length=config.passage_max_length,
is_train=True
)
print(f"✅ M3 Dataset size: {len(dataset)}")
# Create dataloader
dataloader = DataLoader(
dataset,
batch_size=config.per_device_train_batch_size,
shuffle=True,
collate_fn=collate_embedding_batch
)
# Initialize trainer
trainer = BGEM3Trainer(
config=config,
model=model,
tokenizer=tokenizer,
device=device
)
# Test one batch
for batch in dataloader:
print(f"M3 Batch keys: {list(batch.keys())}")
print(f"Query shape: {batch['query_input_ids'].shape}")
print(f"Passage shape: {batch['passage_input_ids'].shape}")
print(f"Labels shape: {batch['labels'].shape}")
# Move batch to device
batch = trainer._prepare_batch(batch)
# Compute loss
with torch.no_grad():
loss = trainer.compute_loss(batch)
print(f"✅ M3 Loss computed successfully: {loss.item():.4f}")
break
print("✅ BGE-M3 test passed!")
return True
except Exception as e:
print(f"❌ BGE-M3 test failed: {e}")
import traceback
traceback.print_exc()
return False
finally:
# Clean up
Path(train_file).unlink()
def test_bge_reranker():
"""Test BGE-Reranker model"""
print("\n🎯 Testing BGE-Reranker Model...")
# Create test data
train_file = create_reranker_test_data()
print(f"✅ Created Reranker test data: {train_file}")
# Initialize config
config = BGERerankerConfig(
model_name_or_path="./models/bge-reranker-base",
train_data_path=train_file,
output_dir="output/test-reranker",
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=1,
learning_rate=1e-5,
logging_steps=1,
save_steps=10,
fp16=False,
use_amp=False,
dataloader_num_workers=0,
overwrite_output_dir=True,
loss_type="binary_cross_entropy"
)
try:
# Load model and tokenizer
print("📦 Loading Reranker model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
model = BGERerankerModel(
model_name_or_path=config.model_name_or_path,
num_labels=1,
use_fp16=config.fp16,
gradient_checkpointing=config.gradient_checkpointing
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"🖥️ Using device: {device}")
# Create dataset
dataset = BGERerankerDataset(
data_path=train_file,
tokenizer=tokenizer,
query_max_length=config.query_max_length,
passage_max_length=config.passage_max_length,
is_train=True
)
print(f"✅ Reranker Dataset size: {len(dataset)}")
# Create dataloader
dataloader = DataLoader(
dataset,
batch_size=config.per_device_train_batch_size,
shuffle=True,
collate_fn=collate_reranker_batch
)
# Initialize trainer
trainer = BGERerankerTrainer(
config=config,
model=model,
tokenizer=tokenizer,
device=device
)
# Test one batch
for batch in dataloader:
print(f"Reranker Batch keys: {list(batch.keys())}")
print(f"Input IDs shape: {batch['input_ids'].shape}")
print(f"Attention mask shape: {batch['attention_mask'].shape}")
print(f"Labels shape: {batch['labels'].shape}")
print(f"Labels values: {batch['labels']}")
# Move batch to device
batch = trainer._prepare_batch(batch)
# Compute loss
with torch.no_grad():
loss = trainer.compute_loss(batch)
print(f"✅ Reranker Loss computed successfully: {loss.item():.4f}")
break
print("✅ BGE-Reranker test passed!")
return True
except Exception as e:
print(f"❌ BGE-Reranker test failed: {e}")
import traceback
traceback.print_exc()
return False
finally:
# Clean up
Path(train_file).unlink()
def main():
"""Run all tests"""
print("🚀 Starting comprehensive BGE training pipeline tests...")
m3_success = test_bge_m3()
reranker_success = test_bge_reranker()
print("\n📊 Test Results Summary:")
print(f" BGE-M3: {'✅ PASS' if m3_success else '❌ FAIL'}")
print(f" BGE-Reranker: {'✅ PASS' if reranker_success else '❌ FAIL'}")
if m3_success and reranker_success:
print("\n🎉 All tests passed! Both models are working correctly.")
return 0
else:
print("\n⚠️ Some tests failed. Please check the error messages above.")
return 1
if __name__ == "__main__":
sys.exit(main())

40
training/__init__.py Normal file
View File

@ -0,0 +1,40 @@
"""
Training modules for BGE models
"""
def get_trainer(trainer_type: str):
"""Lazy loading of trainers to avoid importing all at once"""
if trainer_type == 'base':
from .trainer import BaseTrainer
return BaseTrainer
elif trainer_type == 'm3':
from .m3_trainer import BGEM3Trainer
return BGEM3Trainer
elif trainer_type == 'reranker':
from .reranker_trainer import BGERerankerTrainer
return BGERerankerTrainer
elif trainer_type == 'joint':
from .joint_trainer import JointTrainer
return JointTrainer
else:
raise ValueError(f"Unknown trainer type: {trainer_type}")
# For backward compatibility, import the commonly used ones
from .trainer import BaseTrainer
from .m3_trainer import BGEM3Trainer
# Only import reranker and joint trainers when explicitly requested
def __getattr__(name):
if name == 'BGERerankerTrainer':
from .reranker_trainer import BGERerankerTrainer
return BGERerankerTrainer
elif name == 'JointTrainer':
from .joint_trainer import JointTrainer
return JointTrainer
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
__all__ = [
'BaseTrainer',
'BGEM3Trainer',
'get_trainer'
]
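# Example of the lazy-loading path (illustrative):
#   from training import get_trainer
#   RerankerTrainer = get_trainer('reranker')  # reranker_trainer is only imported here
#   JointTrainer = get_trainer('joint')        # likewise for joint_trainer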

419
training/joint_trainer.py Normal file
View File

@ -0,0 +1,419 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from typing import Dict, Optional, List, Tuple
import logging
from tqdm import tqdm
import numpy as np
import os
from .trainer import BaseTrainer
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from models.losses import (
InfoNCELoss, ListwiseCrossEntropyLoss,
DynamicListwiseDistillationLoss, MultiTaskLoss
)
from data.dataset import JointDataset
from config.base_config import BaseConfig
from utils.distributed import is_main_process, get_world_size, reduce_dict
from utils.logging import get_logger
# Import autocast and GradScaler with fallback
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler
# Scheduler import
try:
from transformers import get_linear_schedule_with_warmup
except ImportError:
get_linear_schedule_with_warmup = None
logger = get_logger(__name__)
class JointTrainer(BaseTrainer):
"""
Joint trainer for BGE-M3 and BGE-Reranker using RocketQAv2 approach
Notes:
- All parameters (config, retriever, reranker, optimizers, device, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
config: BaseConfig,
retriever: BGEM3Model,
reranker: BGERerankerModel,
retriever_optimizer: Optional[torch.optim.Optimizer] = None,
reranker_optimizer: Optional[torch.optim.Optimizer] = None,
device: Optional[torch.device] = None
):
# Create default optimizer if none provided
if retriever_optimizer is None:
retriever_optimizer = AdamW(
retriever.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay
)
# We'll use the retriever as the main model for the base class
super().__init__(config, retriever, retriever_optimizer, device=device)
self.retriever = self.model # Alias for clarity
self.reranker = reranker
self.reranker.to(self.device)
# Setup reranker optimizer if not provided
if reranker_optimizer is None:
self.reranker_optimizer = AdamW(
self.reranker.parameters(),
lr=config.learning_rate * 2, # Higher LR for reranker
weight_decay=config.weight_decay
)
else:
self.reranker_optimizer = reranker_optimizer
# Initialize losses
self.retriever_loss_fn = InfoNCELoss(
temperature=config.temperature if hasattr(config, 'temperature') else 0.02,
use_in_batch_negatives=True # type: ignore[call-arg]
)
self.reranker_loss_fn = ListwiseCrossEntropyLoss(
temperature=1.0,
label_smoothing=0.0 # type: ignore[call-arg]
)
self.distillation_loss_fn = DynamicListwiseDistillationLoss(
temperature=1.0,
alpha=0.5 # type: ignore[call-arg]
)
# Joint training specific parameters
self.update_retriever_interval = getattr(config, 'update_retriever_steps', 100)
self.joint_loss_weight = getattr(config, 'joint_loss_weight', 0.3)
self.use_dynamic_distillation = getattr(config, 'use_dynamic_listwise_distillation', True)
# Tracking metrics
self.retriever_metrics = []
self.reranker_metrics = []
def compute_loss(self, batch: Tuple[Dict, Dict]) -> torch.Tensor:
"""
Compute joint loss for both retriever and reranker
Args:
batch: Tuple of (retriever_batch, reranker_batch)
"""
retriever_batch, reranker_batch = batch
# 1. Compute retriever loss
retriever_loss = self._compute_retriever_loss(retriever_batch)
# 2. Compute reranker loss
reranker_loss = self._compute_reranker_loss(reranker_batch)
# 3. Compute distillation losses if enabled
if self.use_dynamic_distillation:
# Get scores from both models
with torch.no_grad():
retriever_scores = self._get_retriever_scores(retriever_batch)
reranker_scores = self._get_reranker_scores(reranker_batch)
# Compute mutual distillation
retriever_distill_loss, reranker_distill_loss = self.distillation_loss_fn(
retriever_scores,
reranker_scores,
labels=reranker_batch.get('labels')
)
# Add distillation to main losses
retriever_loss = retriever_loss + self.joint_loss_weight * retriever_distill_loss
reranker_loss = reranker_loss + self.joint_loss_weight * reranker_distill_loss
# Combine losses (we'll handle separate backward passes)
self.current_retriever_loss = retriever_loss
self.current_reranker_loss = reranker_loss
# Return combined loss for logging
return retriever_loss + reranker_loss
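# Effective per-model objectives when dynamic distillation is enabled (illustrative,
# using the default joint_loss_weight of 0.3 set in __init__):
#   retriever: InfoNCE loss + 0.3 * distillation-from-reranker term
#   reranker:  listwise/BCE loss + 0.3 * distillation-from-retriever term
# The summed value returned above is used only for logging; _train_epoch runs separate
# backward passes on self.current_retriever_loss and self.current_reranker_loss.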
def _compute_retriever_loss(self, batch: Dict) -> torch.Tensor:
"""Compute retriever (bi-encoder) loss"""
# Get embeddings
query_outputs = self.retriever(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True
)
passage_outputs = self.retriever(
input_ids=batch['passage_input_ids'],
attention_mask=batch['passage_attention_mask'],
return_dense=True
)
# Compute contrastive loss
loss = self.retriever_loss_fn(
query_outputs['dense'],
passage_outputs['dense'],
labels=batch.get('labels')
)
return loss
def _compute_reranker_loss(self, batch: Dict) -> torch.Tensor:
"""Compute reranker (cross-encoder) loss"""
outputs = self.reranker(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask']
)
logits = outputs['logits']
labels = batch['labels']
# Reshape for listwise loss if needed
if logits.dim() == 1:
# Binary classification
loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
else:
# Listwise ranking
loss = self.reranker_loss_fn(logits, labels)
return loss
def _get_retriever_scores(self, batch: Dict) -> torch.Tensor:
"""Get retriever similarity scores"""
with torch.no_grad():
query_outputs = self.retriever(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True
)
passage_outputs = self.retriever(
input_ids=batch['passage_input_ids'],
attention_mask=batch['passage_attention_mask'],
return_dense=True
)
# Compute similarities
scores = torch.matmul(
query_outputs['dense'],
passage_outputs['dense'].transpose(-1, -2)
)
return scores
def _get_reranker_scores(self, batch: Dict) -> torch.Tensor:
"""Get reranker scores"""
outputs = self.reranker(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask']
)
return outputs['logits']
def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
"""Train both models for one epoch"""
self.retriever.train()
self.reranker.train()
epoch_metrics = {
'retriever_loss': 0.0,
'reranker_loss': 0.0,
'total_loss': 0.0
}
num_batches = 0
# Set epoch for distributed sampler
if hasattr(dataloader, 'sampler') and hasattr(dataloader.sampler, 'set_epoch'):
dataloader.sampler.set_epoch(epoch) # type: ignore[attr-defined]
progress_bar = tqdm(
dataloader,
desc=f"Joint Training Epoch {epoch + 1}",
disable=not is_main_process()
)
for step, batch in enumerate(progress_bar):
# Prepare batch
retriever_batch, reranker_batch = batch
retriever_batch = self._prepare_batch(retriever_batch)
reranker_batch = self._prepare_batch(reranker_batch)
batch = (retriever_batch, reranker_batch)
# Forward pass
total_loss = self.compute_loss(batch)
# Scale losses for gradient accumulation
retriever_loss = self.current_retriever_loss / self.config.gradient_accumulation_steps
reranker_loss = self.current_reranker_loss / self.config.gradient_accumulation_steps
# Backward passes (separate for each model)
if self.scaler:
# Retriever backward
self.scaler.scale(retriever_loss).backward(retain_graph=True) # type: ignore[attr-defined]
# Reranker backward
self.scaler.scale(reranker_loss).backward() # type: ignore[attr-defined]
else:
retriever_loss.backward(retain_graph=True)
reranker_loss.backward()
# Update weights
if (step + 1) % self.config.gradient_accumulation_steps == 0:
# Gradient clipping
if self.scaler:
self.scaler.unscale_(self.optimizer)
self.scaler.unscale_(self.reranker_optimizer)
torch.nn.utils.clip_grad_norm_( # type: ignore[attr-defined]
self.retriever.parameters(),
self.config.max_grad_norm
)
torch.nn.utils.clip_grad_norm_( # type: ignore[attr-defined]
self.reranker.parameters(),
self.config.max_grad_norm
)
# Optimizer steps
if self.scaler:
self.scaler.step(self.optimizer)
self.scaler.step(self.reranker_optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.reranker_optimizer.step()
# Scheduler steps
if self.scheduler:
self.scheduler.step()
# Note: Add reranker scheduler if needed
# Zero gradients
self.optimizer.zero_grad()
self.reranker_optimizer.zero_grad()
self.global_step += 1
# Update hard negatives periodically
if self.global_step % self.update_retriever_interval == 0:
self._update_hard_negatives()
# Logging
if self.global_step % self.config.logging_steps == 0:
step_metrics = {
'retriever_loss': retriever_loss.item() * self.config.gradient_accumulation_steps,
'reranker_loss': reranker_loss.item() * self.config.gradient_accumulation_steps,
'total_loss': (
retriever_loss.item() + reranker_loss.item()) * self.config.gradient_accumulation_steps,
'retriever_lr': self.optimizer.param_groups[0]['lr'],
'reranker_lr': self.reranker_optimizer.param_groups[0]['lr']
}
if self.is_distributed:
step_metrics = reduce_dict(step_metrics)
progress_bar.set_postfix({
'ret_loss': f"{step_metrics['retriever_loss']:.4f}",
'rank_loss': f"{step_metrics['reranker_loss']:.4f}"
})
if is_main_process():
for key, value in step_metrics.items():
self.tb_logger.log_scalar(f"joint/{key}", value, self.global_step)
# Flush logs periodically
if self.global_step % (self.config.logging_steps * 10) == 0:
self.tb_logger.flush()
# Accumulate metrics
epoch_metrics['retriever_loss'] += retriever_loss.item()
epoch_metrics['reranker_loss'] += reranker_loss.item()
epoch_metrics['total_loss'] += (retriever_loss.item() + reranker_loss.item())
num_batches += 1
# Average metrics
for key in epoch_metrics:
epoch_metrics[key] = epoch_metrics[key] / num_batches * self.config.gradient_accumulation_steps
if self.is_distributed:
epoch_metrics = reduce_dict(epoch_metrics)
return epoch_metrics
def evaluate_step(self, batch: Tuple[Dict, Dict]) -> Dict[str, float]:
"""Evaluate both models on a batch"""
retriever_batch, reranker_batch = batch
metrics = {}
# Evaluate retriever
with torch.no_grad():
retriever_loss = self._compute_retriever_loss(retriever_batch)
metrics['retriever_loss'] = retriever_loss.item()
# Compute retrieval metrics
retriever_scores = self._get_retriever_scores(retriever_batch)
labels = retriever_batch['labels']
# Simple accuracy: is the positive passage ranked first?
top_indices = retriever_scores.argmax(dim=-1)
correct = (top_indices == labels.argmax(dim=-1)).float().mean()
metrics['retriever_accuracy'] = correct.item()
# Evaluate reranker
with torch.no_grad():
reranker_loss = self._compute_reranker_loss(reranker_batch)
metrics['reranker_loss'] = reranker_loss.item()
# Compute reranking metrics
reranker_scores = self._get_reranker_scores(reranker_batch)
labels = reranker_batch['labels']
if reranker_scores.dim() == 1:
# Binary accuracy
predictions = (reranker_scores > 0).float()
correct = (predictions == labels).float().mean()
else:
# Ranking accuracy
top_indices = reranker_scores.argmax(dim=-1)
correct = (top_indices == labels.argmax(dim=-1)).float().mean()
metrics['reranker_accuracy'] = correct.item()
metrics['total_loss'] = metrics['retriever_loss'] + metrics['reranker_loss']
return metrics
def _update_hard_negatives(self):
"""Update hard negatives using current models (simplified version)"""
# This would implement the dynamic hard negative mining
# For now, just log that we would update
if is_main_process():
logger.info(f"Would update hard negatives at step {self.global_step}")
def save_models(self, output_dir: str):
"""Save both models"""
if not is_main_process():
return
# Save retriever
retriever_path = os.path.join(output_dir, "retriever")
os.makedirs(retriever_path, exist_ok=True)
if hasattr(self.retriever, 'module'):
self.retriever.module.save_pretrained(retriever_path) # type: ignore[attr-defined]
else:
self.retriever.save_pretrained(retriever_path) # type: ignore[attr-defined]
# Save reranker
reranker_path = os.path.join(output_dir, "reranker")
os.makedirs(reranker_path, exist_ok=True)
if hasattr(self.reranker, 'module'):
self.reranker.module.save_pretrained(reranker_path) # type: ignore[attr-defined]
else:
self.reranker.save_pretrained(reranker_path) # type: ignore[attr-defined]
logger.info(f"Saved models to {output_dir}")

894
training/m3_trainer.py Normal file

@ -0,0 +1,894 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer
import os
import json
import logging
from tqdm import tqdm
from typing import Dict, Optional, List, Tuple
import numpy as np
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from .trainer import BaseTrainer
from models.bge_m3 import BGEM3Model
from models.losses import InfoNCELoss, M3KnowledgeDistillationLoss, MultiTaskLoss
from data.dataset import BGEM3Dataset
# Optional hard negative mining
try:
from data.hard_negative_mining import ANCEHardNegativeMiner
ANCE_AVAILABLE = True
except ImportError:
ANCE_AVAILABLE = False
ANCEHardNegativeMiner = None
from utils.checkpoint import CheckpointManager
from utils.logging import get_logger
from evaluation.metrics import compute_retrieval_metrics
from config.m3_config import BGEM3Config
# Import scheduler with fallback
try:
from transformers import get_linear_schedule_with_warmup
except ImportError:
get_linear_schedule_with_warmup = None
logger = get_logger(__name__)
class BGEM3Trainer(BaseTrainer):
"""
Trainer for the BGE-M3 embedding model covering dense, sparse (lexical) and ColBERT multi-vector objectives, optional self-distillation, and ANCE-style hard negative mining
Notes:
- All parameters (config, model, tokenizer, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
config: BGEM3Config,
model: Optional[BGEM3Model] = None,
tokenizer: Optional[AutoTokenizer] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
device: Optional[torch.device] = None
):
# Store config first
self.config = config
# Initialize tokenizer first
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
config.model_name_or_path,
cache_dir=config.cache_dir,
trust_remote_code=True
)
else:
self.tokenizer = tokenizer
# Initialize model
if model is None:
self.model = BGEM3Model(
model_name_or_path=config.model_name_or_path,
use_dense=config.use_dense,
use_sparse=config.use_sparse,
use_colbert=config.use_colbert,
pooling_method=config.sentence_pooling_method,
normalize_embeddings=config.normalize_embeddings,
colbert_dim=config.colbert_dim,
sparse_top_k=config.sparse_top_k,
cache_dir=config.cache_dir
)
else:
self.model = model
# Initialize optimizer if not provided
if optimizer is None:
optimizer = self._create_optimizer()
# Initialize scheduler if not provided and get_linear_schedule_with_warmup is available
if scheduler is None and get_linear_schedule_with_warmup is not None:
# We'll create it later when we know the total training steps
pass
# Initialize parent class
super().__init__(config, self.model, optimizer, scheduler, device)
# Initialize losses
self.infonce_loss = InfoNCELoss(
temperature=config.temperature,
reduction='mean'
)
self.kd_loss = M3KnowledgeDistillationLoss(
temperature=config.self_distill_temperature,
alpha=config.self_distill_alpha,
distill_types=['dense', 'sparse', 'colbert']
)
self.multi_task_loss = MultiTaskLoss(
loss_weights={
'dense': config.dense_weight,
'sparse': config.sparse_weight,
'colbert': config.colbert_weight,
'distillation': config.distillation_weight
}
)
# Initialize hard negative miner
if config.use_hard_negatives and ANCE_AVAILABLE:
self.hard_negative_miner = ANCEHardNegativeMiner(
embedding_dim=self.model.config.hidden_size, # type: ignore[attr-defined]
num_hard_negatives=config.num_hard_negatives,
negative_range=config.ance_negative_range,
margin=config.ance_margin,
index_refresh_interval=config.hard_negative_mining_steps,
use_gpu=(self.device.type == "cuda")
)
self.hard_negative_miner.start_async_updates()
else:
if config.use_hard_negatives and not ANCE_AVAILABLE:
logger.warning("Hard negative mining requested but ANCEHardNegativeMiner not available")
self.hard_negative_miner = None
logger.info(f"Initialized BGE-M3 trainer with config: {config.to_dict()}")
def _create_optimizer(self) -> torch.optim.Optimizer:
"""Create optimizer with parameter groups"""
# Separate parameters for different learning rates if needed
param_groups = [
{
'params': [p for n, p in self.model.named_parameters()
if 'colbert_linear' not in n and 'sparse_linear' not in n],
'lr': self.config.learning_rate
}
]
# Different LR for task-specific layers
if self.config.use_colbert and hasattr(self.model, 'colbert_linear'):
param_groups.append({
'params': self.model.colbert_linear.parameters(),
'lr': self.config.learning_rate * 2 # Higher LR for new layers
})
if self.config.use_sparse and hasattr(self.model, 'sparse_linear'):
param_groups.append({
'params': self.model.sparse_linear.parameters(),
'lr': self.config.learning_rate * 2
})
optimizer = AdamW(
param_groups,
eps=self.config.adam_epsilon,
weight_decay=self.config.weight_decay,
betas=(self.config.adam_beta1, self.config.adam_beta2)
)
return optimizer
def create_scheduler(self, num_training_steps: int):
"""Create learning rate scheduler"""
if get_linear_schedule_with_warmup is None:
logger.warning("transformers not available, using constant learning rate")
return None
num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps
)
return self.scheduler
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute M3 loss with multiple objectives"""
# Update hard negative mining if hook is available
if hasattr(self, '_mining_hook'):
self._mining_hook()
# Detect batch format and handle accordingly
if 'pos_input_ids' in batch and 'neg_input_ids' in batch:
# This is triplet format from FastM3Dataset
return self._compute_triplet_loss(batch)
elif 'passage_input_ids' in batch:
# This is standard embedding format
return self._compute_embedding_loss(batch)
else:
raise ValueError(f"Unknown batch format. Keys: {list(batch.keys())}")
def _compute_triplet_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute loss for triplet format data (FastM3Dataset)"""
# Get query embeddings
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Get positive embeddings
pos_outputs = self.model.forward(
input_ids=batch['pos_input_ids'],
attention_mask=batch['pos_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Get negative embeddings
neg_outputs = self.model.forward(
input_ids=batch['neg_input_ids'],
attention_mask=batch['neg_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
losses = {}
# Dense loss (InfoNCE with explicit pos/neg)
if 'dense' in query_outputs:
query_dense = query_outputs['dense'].float() # type: ignore[index]
pos_dense = pos_outputs['dense'].float() # type: ignore[index]
neg_dense = neg_outputs['dense'].float() # type: ignore[index]
# For triplet format, reshape negative embeddings to expected 3D format
# InfoNCE expects negative_embeddings: [batch_size, num_negatives, embedding_dim]
# But we have single negatives: [batch_size, embedding_dim]
# Reshape to [batch_size, 1, embedding_dim]
neg_dense_3d = neg_dense.unsqueeze(1) # [batch_size, 1, embedding_dim]
dense_loss = self.infonce_loss(
query_dense,
pos_dense,
neg_dense_3d # Now properly shaped 3D tensor
)
losses['dense'] = dense_loss
# Sparse loss
if self.config.use_sparse and 'sparse' in query_outputs:
from models.losses import SparseLoss
sparse_loss_fn = SparseLoss(
flops_loss_weight=1e-3,
sparse_reg_weight=1e-4
)
# Combine pos and neg for sparse loss
combined_sparse = torch.cat([pos_outputs['sparse'], neg_outputs['sparse']], dim=0) # type: ignore[index]
query_sparse_repeated = query_outputs['sparse'].repeat(2, 1) # type: ignore[index]
labels = torch.cat([torch.ones(len(pos_outputs['sparse'])), # type: ignore[index]
torch.zeros(len(neg_outputs['sparse']))], dim=0).to(query_sparse_repeated.device) # type: ignore[index]
sparse_loss, _ = sparse_loss_fn(
query_sparse_repeated,
combined_sparse,
labels
)
losses['sparse'] = sparse_loss
# ColBERT loss
if self.config.use_colbert and 'colbert' in query_outputs:
from models.losses import ColBERTLoss
colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)
# Combine pos and neg for ColBERT loss
combined_colbert = torch.cat([pos_outputs['colbert'], neg_outputs['colbert']], dim=0) # type: ignore[index]
query_colbert_repeated = query_outputs['colbert'].repeat(2, 1, 1) # type: ignore[index]
labels = torch.cat([torch.ones(len(pos_outputs['colbert'])), # type: ignore[index]
torch.zeros(len(neg_outputs['colbert']))], dim=0).to(query_colbert_repeated.device) # type: ignore[index]
colbert_loss = colbert_loss_fn(
query_colbert_repeated,
combined_colbert,
labels
)
losses['colbert'] = colbert_loss
# Combine losses
total_loss, _ = self.multi_task_loss(losses)
return total_loss
def _compute_embedding_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute loss for standard embedding format data"""
# Get query embeddings
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Get passage embeddings - handle both single and multiple passages
if batch['passage_input_ids'].dim() == 3:
# Multiple passages per query [batch, num_passages, seq_len]
batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
passage_input_ids = batch['passage_input_ids'].view(batch_size * num_passages, seq_len)
passage_attention_mask = batch['passage_attention_mask'].view(batch_size * num_passages, seq_len)
multiple_passages = True
else:
# Single passage per query [batch, seq_len]
passage_input_ids = batch['passage_input_ids']
passage_attention_mask = batch['passage_attention_mask']
batch_size = passage_input_ids.shape[0]
num_passages = 1
multiple_passages = False
passage_outputs = self.model.forward(
input_ids=passage_input_ids,
attention_mask=passage_attention_mask,
return_dense=True,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Compute losses for different components
losses = {}
# Dense loss (InfoNCE)
if 'dense' in query_outputs and 'dense' in passage_outputs:
# Get embeddings and ensure they're on the same device
query_dense = query_outputs['dense'].float() # type: ignore[index] # [batch_size, embedding_dim]
passage_dense = passage_outputs['dense'].float() # type: ignore[index] # [batch_size * num_passages, embedding_dim]
# Handle multiple passages per query differently for contrastive learning
if multiple_passages:
# Reshape passage embeddings back to [batch_size, num_passages, embedding_dim]
passage_dense = passage_dense.view(batch_size, num_passages, -1)
# For contrastive learning, we need to identify positive passages
# Use the labels to find the first positive passage for each query
labels = batch.get('labels') # [batch_size, num_passages]
if labels is not None:
# Find the first positive passage for each query
positive_indices = []
negative_passages = []
for i in range(batch_size):
query_labels = labels[i] # [num_passages]
positive_mask = query_labels > 0
if positive_mask.any():
# Get first positive passage
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
positive_indices.append(pos_idx)
# Collect negative passages
neg_mask = ~positive_mask
if neg_mask.any():
neg_embeddings = passage_dense[i][neg_mask] # [num_negatives, embedding_dim]
negative_passages.append(neg_embeddings)
else:
negative_passages.append(None)
else:
# No positive passage, use first passage as positive (fallback)
positive_indices.append(0)
negative_passages.append(passage_dense[i][1:])
# Extract positive embeddings
positive_embeddings = torch.stack([
passage_dense[i, positive_indices[i]] for i in range(batch_size)
]) # [batch_size, embedding_dim]
# Stack negative embeddings (if available)
max_neg = max(neg.shape[0] if neg is not None else 0 for neg in negative_passages)
if max_neg > 0:
# Pad negative embeddings to same size
padded_negatives = []
for neg in negative_passages:
if neg is not None and neg.shape[0] > 0:
if neg.shape[0] < max_neg:
# Pad with zeros
padding = torch.zeros(max_neg - neg.shape[0], neg.shape[1], device=neg.device)
neg = torch.cat([neg, padding], dim=0)
padded_negatives.append(neg)
else:
# No negatives, create zero tensor
padded_negatives.append(torch.zeros(max_neg, passage_dense.shape[-1], device=passage_dense.device))
negative_embeddings = torch.stack(padded_negatives) # [batch_size, max_neg, embedding_dim]
else:
negative_embeddings = None
# Compute InfoNCE loss with explicit positives and negatives
dense_loss = self.infonce_loss(
query_dense,
positive_embeddings,
negative_embeddings
)
else:
# No labels provided, use first passage as positive, rest as negatives
positive_embeddings = passage_dense[:, 0, :] # [batch_size, embedding_dim]
if num_passages > 1:
negative_embeddings = passage_dense[:, 1:, :] # [batch_size, num_passages-1, embedding_dim]
else:
negative_embeddings = None
dense_loss = self.infonce_loss(
query_dense,
positive_embeddings,
negative_embeddings
)
else:
# Single passage per query - use standard in-batch negatives
# InfoNCE will create contrastive pairs using in-batch negatives
dense_loss = self.infonce_loss(
query_dense,
passage_dense,
None # No explicit negatives, will use in-batch
)
losses['dense'] = dense_loss
# Sparse loss (if using sparse)
if self.config.use_sparse and 'sparse' in query_outputs and 'sparse' in passage_outputs:
from models.losses import SparseLoss
sparse_loss_fn = SparseLoss(
flops_loss_weight=1e-3,
sparse_reg_weight=1e-4
)
query_sparse = query_outputs['sparse'] # type: ignore[index] # [batch_size, query_seq_len]
passage_sparse = passage_outputs['sparse'] # type: ignore[index] # [batch_size * num_passages, doc_seq_len]
# For sparse loss, we need to handle the contrastive learning differently
# Since sparse embeddings have different sequence lengths, we'll use the same approach as dense
if multiple_passages:
# Use only the first positive passage for each query
labels = batch.get('labels')
if labels is not None:
positive_indices = []
for i in range(batch_size):
query_labels = labels[i]
positive_mask = query_labels > 0
if positive_mask.any():
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
else:
pos_idx = 0
positive_indices.append(pos_idx + i * num_passages)
# Extract positive sparse embeddings
positive_sparse = passage_sparse[positive_indices] # [batch_size, doc_seq_len]
else:
# Use first passage as positive
positive_sparse = passage_sparse[::num_passages] # [batch_size, doc_seq_len]
# Create contrastive labels for sparse loss
contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)
sparse_loss, _ = sparse_loss_fn(
query_sparse,
positive_sparse,
contrastive_labels
)
else:
# Single passage per query
contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)
sparse_loss, _ = sparse_loss_fn(
query_sparse,
passage_sparse,
contrastive_labels
)
losses['sparse'] = sparse_loss
# ColBERT loss (if using ColBERT)
if self.config.use_colbert and 'colbert' in query_outputs and 'colbert' in passage_outputs:
from models.losses import ColBERTLoss
colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)
query_colbert = query_outputs['colbert'] # type: ignore[index] # [batch_size, query_seq_len, dim]
passage_colbert = passage_outputs['colbert'] # type: ignore[index] # [batch_size * num_passages, doc_seq_len, dim]
# For ColBERT, we need pairwise interactions
if multiple_passages:
# Use only positive passages
labels = batch.get('labels')
if labels is not None:
positive_indices = []
for i in range(batch_size):
query_labels = labels[i]
positive_mask = query_labels > 0
if positive_mask.any():
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
else:
pos_idx = 0
positive_indices.append(pos_idx + i * num_passages)
positive_colbert = passage_colbert[positive_indices]
else:
positive_colbert = passage_colbert[::num_passages]
# Binary labels for ColBERT (1 for positive pairs)
colbert_labels = torch.ones(batch_size, device=query_colbert.device)
else:
positive_colbert = passage_colbert
colbert_labels = torch.ones(batch_size, device=query_colbert.device)
colbert_loss = colbert_loss_fn(
query_colbert,
positive_colbert,
colbert_labels
)
losses['colbert'] = colbert_loss
# Knowledge distillation loss (only if explicitly enabled)
if (getattr(self.config, 'use_self_distill', False) and
self.global_step >= getattr(self.config, 'self_distill_start_step', 0)):
try:
# Compute teacher scores (weighted combination)
teacher_scores = self._compute_teacher_scores(query_outputs, passage_outputs) # type: ignore[arg-type]
student_scores = self.model.compute_similarity(
query_outputs, passage_outputs, mode='dense'
)
kd_loss, _ = self.kd_loss(
{'dense': student_scores},
{'dense': teacher_scores}
)
losses['kd'] = kd_loss
except Exception as e:
logger.warning(f"Knowledge distillation failed: {e}")
# Combine losses
total_loss, _ = self.multi_task_loss(losses)
return total_loss
def _compute_teacher_scores(
self,
query_outputs: Dict[str, torch.Tensor],
passage_outputs: Dict[str, torch.Tensor]
) -> torch.Tensor:
"""Compute teacher scores from multiple signals"""
scores = []
weights = []
# Get original batch size for reshaping
query_batch_size = query_outputs['dense'].shape[0] if 'dense' in query_outputs else 1
passage_batch_size = passage_outputs['dense'].shape[0] if 'dense' in passage_outputs else 1
# Check if we have multiple passages per query
multiple_passages = passage_batch_size > query_batch_size
if multiple_passages:
# For multiple passages, we need to compute scores only for positive pairs
# Use only the first passage for each query as a simplified teacher score
num_passages = passage_batch_size // query_batch_size
# Dense scores (use only first passage per query)
if 'dense' in query_outputs and 'dense' in passage_outputs:
query_dense = query_outputs['dense'] # [batch_size, dim]
passage_dense = passage_outputs['dense'] # [batch_size * num_passages, dim]
# Extract first passage for each query
positive_passages = passage_dense[::num_passages] # [batch_size, dim]
try:
dense_scores = self.model.compute_similarity(
{'dense': query_dense},
{'dense': positive_passages},
mode='dense'
)
scores.append(dense_scores)
weights.append(self.config.dense_weight)
except Exception as e:
logger.warning(f"Dense similarity computation failed: {e}")
# Sparse scores (use only first passage per query)
if 'sparse' in query_outputs and 'sparse' in passage_outputs:
query_sparse = query_outputs['sparse'] # [batch_size, query_seq_len]
passage_sparse = passage_outputs['sparse'] # [batch_size * num_passages, doc_seq_len]
# Extract first passage for each query
positive_passages = passage_sparse[::num_passages] # [batch_size, doc_seq_len]
try:
sparse_scores = self.model.compute_similarity(
{'sparse': query_sparse},
{'sparse': positive_passages},
mode='sparse'
)
scores.append(sparse_scores)
weights.append(self.config.sparse_weight)
except Exception as e:
logger.warning(f"Sparse similarity computation failed: {e}")
# ColBERT scores (use only first passage per query)
if 'colbert' in query_outputs and 'colbert' in passage_outputs:
query_colbert = query_outputs['colbert'] # [batch_size, query_seq_len, dim]
passage_colbert = passage_outputs['colbert'] # [batch_size * num_passages, doc_seq_len, dim]
# Extract first passage for each query
positive_passages = passage_colbert[::num_passages] # [batch_size, doc_seq_len, dim]
try:
colbert_scores = self.model.compute_similarity(
{'colbert': query_colbert},
{'colbert': positive_passages},
mode='colbert'
)
scores.append(colbert_scores)
weights.append(self.config.colbert_weight)
except Exception as e:
logger.warning(f"ColBERT similarity computation failed: {e}")
else:
# Single passage per query - original logic
# Dense scores
if 'dense' in query_outputs:
try:
dense_scores = self.model.compute_similarity(
{'dense': query_outputs['dense']},
{'dense': passage_outputs['dense']},
mode='dense'
)
scores.append(dense_scores)
weights.append(self.config.dense_weight)
except Exception as e:
logger.warning(f"Dense similarity computation failed: {e}")
# Sparse scores
if 'sparse' in query_outputs:
try:
sparse_scores = self.model.compute_similarity(
{'sparse': query_outputs['sparse']},
{'sparse': passage_outputs['sparse']},
mode='sparse'
)
scores.append(sparse_scores)
weights.append(self.config.sparse_weight)
except Exception as e:
logger.warning(f"Sparse similarity computation failed: {e}")
# ColBERT scores
if 'colbert' in query_outputs:
try:
colbert_scores = self.model.compute_similarity(
{'colbert': query_outputs['colbert']},
{'colbert': passage_outputs['colbert']},
mode='colbert'
)
scores.append(colbert_scores)
weights.append(self.config.colbert_weight)
except Exception as e:
logger.warning(f"ColBERT similarity computation failed: {e}")
# Weighted combination
if scores:
weights = torch.tensor(weights, device=scores[0].device)
weights = weights / weights.sum()
teacher_scores = torch.stack([w * s for w, s in zip(weights, scores)]).sum(dim=0)
return teacher_scores
# Fallback to zero scores if no valid similarity computation
return torch.zeros(query_batch_size, device=query_outputs[next(iter(query_outputs))].device)
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Evaluate a single batch"""
# Get query embeddings
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True,
return_dict=True
)
# Get passage embeddings
passage_outputs = self.model.forward(
input_ids=batch['passage_input_ids'],
attention_mask=batch['passage_attention_mask'],
return_dense=True,
return_dict=True
)
# Compute loss
loss = self.compute_loss(batch)
# Compute similarity scores
scores = self.model.compute_similarity(
query_outputs, passage_outputs, mode='dense'
)
# Basic metrics
labels = batch.get('labels', torch.zeros(scores.shape[0]))
if labels.dim() > 1:
labels = labels.view(-1)
predictions = (scores > 0.5).float()
accuracy = (predictions == labels.float()).float().mean()
return {
'loss': loss.item(),
'accuracy': accuracy.item(),
'mean_score': scores.mean().item()
}
def train(
self,
train_dataset: BGEM3Dataset,
eval_dataset: Optional[BGEM3Dataset] = None,
resume_from_checkpoint: Optional[str] = None
):
"""
Main training loop for M3Trainer with robust error handling, config-driven parameters, and progress tracking.
"""
try:
train_dataloader = self._create_dataloader(train_dataset, is_train=True)
eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None
# Calculate total training steps
num_update_steps_per_epoch = max(1, len(train_dataloader) // self.config.gradient_accumulation_steps) # type: ignore[arg-type]
total_training_steps = num_update_steps_per_epoch * self.config.num_train_epochs
# Create scheduler with total steps
self.create_scheduler(total_training_steps)
if resume_from_checkpoint:
self.load_checkpoint(resume_from_checkpoint)
logger.info(f"Starting training for {self.config.num_train_epochs} epochs")
logger.info(f"Total training steps: {total_training_steps}")
# Use parent's train method
super().train(train_dataloader, eval_dataloader, self.config.num_train_epochs) # type: ignore[arg-type]
except Exception as e:
logger.error(f"M3 training failed: {e}")
raise
finally:
# Cleanup
if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
self.hard_negative_miner.stop_async_updates()
def _create_dataloader(
self,
dataset: Optional[BGEM3Dataset],
is_train: bool = True
) -> Optional[DataLoader]:
"""Create data loader"""
if dataset is None:
return None
# Import both collate functions
from data.dataset import collate_embedding_batch
# Choose appropriate collate function based on dataset type
collate_fn = collate_embedding_batch
num_workers = self.config.dataloader_num_workers
batch_size = self.config.per_device_train_batch_size if is_train else self.config.per_device_eval_batch_size
# Setup distributed sampler if needed
sampler = None
if self.is_distributed:
from torch.utils.data.distributed import DistributedSampler
sampler = DistributedSampler(dataset, shuffle=is_train)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=is_train and sampler is None,
sampler=sampler,
num_workers=num_workers,
pin_memory=self.config.dataloader_pin_memory,
drop_last=is_train and len(dataset) > batch_size, # Only drop last if we have multiple batches
collate_fn=collate_fn
)
return dataloader
def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
"""Train one epoch - use parent's method with hard negative mining integration"""
# Add hard negative mining updates before calling parent's method
if self.hard_negative_miner:
# Set up a hook to update hard negatives during training
def mining_hook():
if self.hard_negative_miner:
self.hard_negative_miner.update_index_if_needed(self.global_step, self.model)
# Store the hook for use in compute_loss
self._mining_hook = mining_hook
# Use parent's robust training loop
return super().train_epoch(dataloader, epoch)
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
"""Evaluate the model"""
self.model.eval()
all_query_embeddings = []
all_passage_embeddings = []
all_labels = []
with torch.no_grad():
with tqdm(dataloader, desc="Evaluating") as tbar:
for batch in tbar:
try:
batch = self._prepare_batch(batch)
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True,
return_dict=True
)
passage_outputs = self.model.forward(
input_ids=batch['passage_input_ids'],
attention_mask=batch['passage_attention_mask'],
return_dense=True,
return_dict=True
)
all_query_embeddings.append(query_outputs['dense'].cpu()) # type: ignore[index]
all_passage_embeddings.append(passage_outputs['dense'].cpu()) # type: ignore[index]
all_labels.append(batch['labels'].cpu())
except Exception as e:
logger.error(f"Error in M3 evaluation step: {e}")
continue
if not all_query_embeddings:
logger.warning("No valid evaluation batches processed")
return {'loss': float('inf'), 'accuracy': 0.0}
# Concatenate all embeddings
all_query_embeddings = torch.cat(all_query_embeddings, dim=0)
all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0)
all_labels = torch.cat(all_labels, dim=0)
# Compute metrics
try:
metrics = compute_retrieval_metrics(
all_query_embeddings.numpy(),
all_passage_embeddings.numpy(),
all_labels.numpy(),
k_values=[1, 3, 5, 10]
)
except Exception as e:
logger.error(f"Error computing retrieval metrics: {e}")
metrics = {'loss': float('inf'), 'accuracy': 0.0}
return metrics
def save_model(self, save_path: str):
"""Save the trained model"""
try:
os.makedirs(save_path, exist_ok=True)
# Save model
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
model_to_save.save_pretrained(save_path) # type: ignore[attr-defined]
# Save tokenizer
self.tokenizer.save_pretrained(save_path) # type: ignore[attr-defined]
# Save config
self.config.save(os.path.join(save_path, 'training_config.json'))
logger.info(f"Model saved to {save_path}")
except Exception as e:
logger.error(f"Error saving model: {e}")
raise
def __del__(self):
"""Cleanup when trainer is destroyed"""
if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
self.hard_negative_miner.stop_async_updates()
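As a usage reference, a minimal hypothetical driver for BGEM3Trainer. It relies only on attributes referenced in this file and the base trainer (model_name_or_path, num_train_epochs, output_dir); the BGEM3Dataset constructor arguments are assumed, not taken from this commit.

from config.m3_config import BGEM3Config
from data.dataset import BGEM3Dataset
from training.m3_trainer import BGEM3Trainer

config = BGEM3Config()                      # defaults assumed usable
config.model_name_or_path = "BAAI/bge-m3"
config.num_train_epochs = 1

trainer = BGEM3Trainer(config)              # builds tokenizer, model and optimizer internally
train_ds = BGEM3Dataset("data/train.jsonl", tokenizer=trainer.tokenizer)  # signature assumed
eval_ds = BGEM3Dataset("data/dev.jsonl", tokenizer=trainer.tokenizer)

trainer.train(train_ds, eval_dataset=eval_ds)
trainer.save_model(config.output_dir)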


@ -0,0 +1,489 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from typing import Dict, Optional, List, Tuple, Any
import logging
import numpy as np
from torch.nn import functional as F
from tqdm import tqdm
import os
from .trainer import BaseTrainer
from models.bge_reranker import BGERerankerModel
from models.losses import ListwiseCrossEntropyLoss, PairwiseMarginLoss
from data.dataset import BGERerankerDataset
from config.reranker_config import BGERerankerConfig
from evaluation.metrics import compute_reranker_metrics
from utils.distributed import is_main_process
# AMP imports
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler
# Scheduler import
try:
from transformers import get_linear_schedule_with_warmup
except ImportError:
get_linear_schedule_with_warmup = None
logger = logging.getLogger(__name__)
class BGERerankerTrainer(BaseTrainer):
"""
Trainer for BGE-Reranker cross-encoder model
Notes:
- All parameters (config, model, tokenizer, optimizer, scheduler, device, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
config: BGERerankerConfig,
model: Optional[BGERerankerModel] = None,
tokenizer=None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
device: Optional[torch.device] = None
):
# Initialize model if not provided
if model is None:
model = BGERerankerModel(
model_name_or_path=config.model_name_or_path,
num_labels=config.num_labels,
cache_dir=config.cache_dir,
use_fp16=config.fp16,
gradient_checkpointing=config.gradient_checkpointing
)
# Initialize optimizer if not provided
if optimizer is None:
optimizer = self._create_optimizer(model, config)
# Initialize parent class
super().__init__(config, model, optimizer, scheduler, device)
self.tokenizer = tokenizer
self.config = config
# Initialize losses based on config
if config.loss_type == "listwise_ce":
self.loss_fn = ListwiseCrossEntropyLoss(
temperature=getattr(config, 'temperature', 1.0)
)
elif config.loss_type == "pairwise_margin":
self.loss_fn = PairwiseMarginLoss(margin=config.margin)
elif config.loss_type == "combined":
self.listwise_loss = ListwiseCrossEntropyLoss(
temperature=getattr(config, 'temperature', 1.0)
)
self.pairwise_loss = PairwiseMarginLoss(margin=config.margin)
else:
self.loss_fn = nn.BCEWithLogitsLoss()
# Knowledge distillation setup
self.use_kd = config.use_knowledge_distillation
if self.use_kd:
self.kd_loss_fn = nn.KLDivLoss(reduction='batchmean')
self.kd_temperature = config.distillation_temperature
self.kd_loss_weight = config.distillation_alpha
# Hard negative configuration
self.hard_negative_ratio = config.hard_negative_ratio
self.use_denoising = config.use_denoising
self.denoise_threshold = getattr(config, 'denoise_confidence_threshold', 0.5)
def _create_optimizer(self, model: nn.Module, config: BGERerankerConfig) -> torch.optim.Optimizer:
"""Create optimizer with proper parameter groups"""
# Separate parameters
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{
'params': [p for n, p in model.named_parameters()
if not any(nd in n for nd in no_decay)],
'weight_decay': config.weight_decay,
},
{
'params': [p for n, p in model.named_parameters()
if any(nd in n for nd in no_decay)],
'weight_decay': 0.0,
}
]
optimizer = AdamW(
optimizer_grouped_parameters,
lr=config.learning_rate,
eps=config.adam_epsilon,
betas=(config.adam_beta1, config.adam_beta2)
)
return optimizer
def create_scheduler(self, num_training_steps: int) -> Optional[object]:
"""Create learning rate scheduler"""
if get_linear_schedule_with_warmup is None:
logger.warning("transformers not available, using constant learning rate")
return None
num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps
)
return self.scheduler
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute reranker loss"""
# Get model outputs
outputs = self.model.forward(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
return_dict=True
)
if isinstance(outputs, dict):
logits = outputs['logits']
else:
logits = outputs.scores
labels = batch['labels']
# Handle padding labels
if labels.dim() > 1:
# Remove padding (-1 labels)
valid_mask = labels != -1
if valid_mask.any():
logits = logits[valid_mask]
labels = labels[valid_mask]
else:
# If all labels are padding, return zero loss
return torch.tensor(0.0, device=logits.device, requires_grad=True)
# Handle different loss types
if self.config.loss_type == "listwise_ce":
# Listwise loss expects grouped format
try:
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
if groups:
loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups)
else:
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
except Exception as e:
# Fallback to BCE with specific error logging
logger.warning(f"Listwise CE loss failed: {e}. Falling back to BCE.")
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
elif self.config.loss_type == "pairwise_margin":
# Extract positive and negative scores
pos_mask = labels == 1
neg_mask = labels == 0
if pos_mask.any() and neg_mask.any():
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
try:
loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups)
except Exception as e:
# Fallback to BCE with specific error logging
logger.warning(f"Pairwise margin loss failed: {e}. Falling back to BCE.")
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
else:
# Fallback to BCE if no proper pairs
logger.warning("No positive-negative pairs found for pairwise loss. Using BCE.")
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
elif self.config.loss_type == "combined":
# Combine listwise and pairwise losses
try:
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
if groups:
listwise_loss = self.listwise_loss(logits.squeeze(-1), labels.float(), groups)
pairwise_loss = self.pairwise_loss(logits.squeeze(-1), labels.float(), groups)
loss = self.config.listwise_weight * listwise_loss + self.config.pairwise_weight * pairwise_loss
else:
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
except Exception as e:
# Fallback to BCE with specific error logging
logger.warning(f"Combined loss computation failed: {e}. Falling back to BCE.")
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
else:
# Default binary cross-entropy
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
# Add knowledge distillation loss if enabled
if self.use_kd and 'teacher_scores' in batch:
teacher_scores = batch['teacher_scores']
student_probs = F.log_softmax(logits / self.kd_temperature, dim=-1)
teacher_probs = F.softmax(teacher_scores / self.kd_temperature, dim=-1)
kd_loss = self.kd_loss_fn(student_probs, teacher_probs) * (self.kd_temperature ** 2)
loss = (1 - self.kd_loss_weight) * loss + self.kd_loss_weight * kd_loss
# Apply denoising if enabled
if self.use_denoising and 'teacher_scores' in batch:
# Denoise by confidence thresholding
confidence_scores = torch.sigmoid(batch['teacher_scores'])
confidence_mask = confidence_scores > self.denoise_threshold
if confidence_mask.any():
loss = loss * confidence_mask.float().mean()
return loss
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Evaluate a single batch"""
with torch.no_grad():
# Get model outputs
outputs = self.model.forward(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
return_dict=True
)
if isinstance(outputs, dict):
logits = outputs['logits']
else:
logits = outputs.scores
labels = batch['labels']
# Handle padding labels
if labels.dim() > 1:
valid_mask = labels != -1
if valid_mask.any():
logits = logits[valid_mask]
labels = labels[valid_mask]
else:
return {'loss': 0.0, 'accuracy': 0.0}
# Compute loss
loss = self.compute_loss(batch)
# Compute metrics based on group size
if hasattr(self.config, 'train_group_size') and self.config.train_group_size > 1:
# Listwise evaluation
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
if groups:
try:
metrics = compute_reranker_metrics(
logits.squeeze(-1).cpu().numpy(),
labels.cpu().numpy(),
groups=groups,
k_values=[1, 3, 5, 10]
)
metrics['loss'] = loss.item()
except Exception as e:
logger.warning(f"Error computing reranker metrics: {e}")
metrics = {'loss': loss.item(), 'accuracy': 0.0}
else:
metrics = {'loss': loss.item(), 'accuracy': 0.0}
else:
# Binary evaluation
predictions = (torch.sigmoid(logits.squeeze(-1)) > 0.5).float()
accuracy = (predictions == labels.float()).float().mean()
metrics = {
'accuracy': accuracy.item(),
'loss': loss.item()
}
# Add AUC if we have both positive and negative examples
if len(torch.unique(labels)) > 1:
try:
from sklearn.metrics import roc_auc_score
auc = roc_auc_score(labels.cpu().numpy(), torch.sigmoid(logits.squeeze(-1)).cpu().numpy())
metrics['auc'] = float(auc)
except Exception as e:
logger.warning(f"Error computing AUC: {e}")
return metrics
def train_with_hard_negative_mining(
self,
train_dataset: BGERerankerDataset,
eval_dataset: Optional[BGERerankerDataset] = None,
retriever_model: Optional[nn.Module] = None
):
"""
Train with periodic hard negative mining from retriever
"""
if retriever_model is None:
logger.warning("No retriever model provided for hard negative mining. Using standard training.")
train_dataloader = self._create_dataloader(train_dataset, is_train=True)
eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None
# Only proceed if we have a valid training dataloader
if train_dataloader is not None:
return super().train(train_dataloader, eval_dataloader)
else:
logger.error("No valid training dataloader created")
return
# Training loop with hard negative updates
for epoch in range(self.config.num_train_epochs):
# Mine hard negatives at the beginning of each epoch
if epoch > 0:
logger.info(f"Mining hard negatives for epoch {epoch + 1}")
self._mine_hard_negatives(train_dataset, retriever_model)
# Create data loader with new negatives
train_dataloader = self._create_dataloader(train_dataset, is_train=True)
eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None
# Only proceed if we have a valid training dataloader
if train_dataloader is not None:
# Train exactly one epoch: the base loop runs range(self.current_epoch, num_epochs),
# so start it at this epoch index instead of always passing num_epochs=1
self.current_epoch = epoch
super().train(train_dataloader, eval_dataloader, num_epochs=epoch + 1)
else:
logger.error(f"No valid training dataloader for epoch {epoch + 1}")
continue
def _create_dataloader(self, dataset, is_train: bool = True) -> Optional[DataLoader]:
"""Create data loader"""
if dataset is None:
return None
from data.dataset import collate_reranker_batch
batch_size = self.config.per_device_train_batch_size if is_train else self.config.per_device_eval_batch_size
# Setup distributed sampler if needed
sampler = None
if self.is_distributed:
from torch.utils.data.distributed import DistributedSampler
sampler = DistributedSampler(dataset, shuffle=is_train)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=is_train and sampler is None,
sampler=sampler,
num_workers=self.config.dataloader_num_workers,
pin_memory=self.config.dataloader_pin_memory,
drop_last=is_train,
collate_fn=collate_reranker_batch
)
def _mine_hard_negatives(
self,
dataset: BGERerankerDataset,
retriever_model: Any
):
"""Mine hard negatives using the retriever model and update the dataset's negatives."""
logger.info("Starting hard negative mining using retriever model...")
num_hard_negatives = getattr(self.config, 'num_hard_negatives', 10)
retriever_batch_size = getattr(self.config, 'retriever_batch_size', 32)
mined_count = 0
# Collect all unique candidate passages from the dataset
all_passages = set()
for item in dataset.data:
all_passages.update(item.get('pos', []))
all_passages.update(item.get('neg', []))
all_passages = list(all_passages)
if not all_passages:
logger.warning("No candidate passages found in dataset for hard negative mining.")
return
# Encode all candidate passages
try:
passage_embeddings = retriever_model.encode(
all_passages,
batch_size=retriever_batch_size,
convert_to_numpy=True
)
except Exception as e:
logger.error(f"Error encoding passages for hard negative mining: {e}")
return
# For each query, mine hard negatives
for idx in range(len(dataset)):
item = dataset.data[idx]
query = item['query']
positives = set(item.get('pos', []))
existing_negs = set(item.get('neg', []))
try:
query_emb = retriever_model.encode([query], batch_size=1, convert_to_numpy=True)[0]
# Compute similarity
scores = np.dot(passage_embeddings, query_emb)
# Rank passages by similarity
ranked_indices = np.argsort(scores)[::-1]
new_negs = []
for i in ranked_indices:
passage = all_passages[int(i)] # Convert to int for indexing
if passage not in positives and passage not in existing_negs:
new_negs.append(passage)
if len(new_negs) >= num_hard_negatives:
break
if new_negs:
item['neg'] = list(existing_negs) + new_negs
mined_count += 1
logger.debug(f"Query idx {idx}: Added {len(new_negs)} hard negatives.")
except Exception as e:
logger.error(f"Error mining hard negatives for query idx {idx}: {e}")
logger.info(f"Hard negative mining complete. Updated negatives for {mined_count} queries.")
def save_model(self, save_path: str):
"""Save the trained model"""
try:
os.makedirs(save_path, exist_ok=True)
# Save model
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
if hasattr(model_to_save, 'save_pretrained'):
model_to_save.save_pretrained(save_path) # type: ignore
else:
# Fallback to torch.save
torch.save(model_to_save.state_dict(), os.path.join(save_path, 'pytorch_model.bin')) # type: ignore
# Save tokenizer if available
if self.tokenizer and hasattr(self.tokenizer, 'save_pretrained'):
self.tokenizer.save_pretrained(save_path)
# Save config
if hasattr(self.config, 'save'):
self.config.save(os.path.join(save_path, 'training_config.json'))
logger.info(f"Model saved to {save_path}")
except Exception as e:
logger.error(f"Error saving model: {e}")
raise
def export_for_inference(self, output_path: str):
"""Export model for efficient inference"""
if is_main_process():
logger.info(f"Exporting model for inference to {output_path}")
# Save model in inference mode
self.model.eval()
# Save model state
torch.save({
'model_state_dict': self.model.state_dict(),
'config': self.config.to_dict(),
'model_type': 'bge_reranker'
}, output_path)
# Also save in HuggingFace format if possible
try:
hf_path = output_path.replace('.pt', '_hf')
self.save_model(hf_path)
except Exception as e:
logger.warning(f"Failed to save in HuggingFace format: {e}")
logger.info(f"Model exported successfully")

534
training/trainer.py Normal file

@ -0,0 +1,534 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from typing import Dict, Optional, List, Tuple, Any, Union, cast
import os
import json
import logging
from tqdm import tqdm
from abc import ABC, abstractmethod
import time
from datetime import datetime
import numpy as np
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler
from torch.nn.parallel import DataParallel, DistributedDataParallel as DDP
from torch.nn.utils.clip_grad import clip_grad_norm_
from collections import defaultdict
from utils.checkpoint import CheckpointManager
from utils.logging import MetricsLogger, TensorBoardLogger, ProgressLogger
from utils.distributed import (
setup_distributed, cleanup_distributed, is_main_process,
reduce_dict, synchronize, get_rank, get_world_size
)
from config.base_config import BaseConfig
try:
from transformers import get_linear_schedule_with_warmup
except ImportError:
get_linear_schedule_with_warmup = None
logger = logging.getLogger(__name__)
class BaseTrainer(ABC):
"""
Abstract base trainer class for BGE models
Notes:
- All parameters (config, model, optimizer, scheduler, device, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
config: BaseConfig,
model: nn.Module,
optimizer: Optimizer,
scheduler: Optional[_LRScheduler] = None,
device: Optional[torch.device] = None
):
self.config = config
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
# Setup device
if device is None:
self.device = torch.device(config.device)
else:
self.device = device
# Move model to device
self.model.to(self.device)
# Setup distributed training if needed
self.is_distributed = config.local_rank != -1
if self.is_distributed:
setup_distributed()
self.model = self._setup_distributed_model()
# Setup mixed precision training
self.use_amp = config.fp16 or config.bf16
if self.use_amp and config.bf16:
# Use bfloat16 for better numerical stability on RTX5090/A100
self.amp_dtype = torch.bfloat16
else:
self.amp_dtype = torch.float16
self.scaler = GradScaler() if self.use_amp and not config.bf16 else None
# Memory optimization settings
self.empty_cache_steps = getattr(config, 'empty_cache_steps', 100)
self.gradient_checkpointing = getattr(config, 'gradient_checkpointing', False)
# Setup logging
self.metrics_logger = MetricsLogger(config.output_dir) if is_main_process() else None
self.tb_logger = TensorBoardLogger(
os.path.join(config.logging_dir, 'tensorboard'),
enabled=is_main_process()
)
self.progress_logger = None # Will be initialized in train()
# Setup checkpointing
self.checkpoint_manager = CheckpointManager(
output_dir=config.output_dir,
save_total_limit=config.save_total_limit
) if is_main_process() else None
# Training state
self.global_step = 0
self.current_epoch = 0
self.best_metric = float('-inf') if config.greater_is_better else float('inf')
self.training_history = []
def _setup_distributed_model(self) -> nn.Module:
"""Setup model for distributed training"""
from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(
self.model,
device_ids=[self.config.local_rank],
output_device=self.config.local_rank,
find_unused_parameters=True
)
return model
@abstractmethod
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute loss for a batch - to be implemented by subclasses"""
pass
@abstractmethod
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Evaluate a single batch - to be implemented by subclasses"""
pass
def train(
self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
num_epochs: Optional[int] = None
) -> Dict[str, Any]:
"""
Main training loop with robust error handling, config-driven parameters, and progress tracking.
"""
try:
num_epochs = num_epochs or self.config.num_train_epochs
total_steps = len(train_dataloader) * num_epochs
self.progress_logger = ProgressLogger(
total_steps=total_steps,
log_interval=self.config.logging_steps
)
if is_main_process():
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataloader.dataset)}") # type: ignore[attr-defined]
logger.info(f" Num epochs = {num_epochs}")
logger.info(f" Batch size = {self.config.per_device_train_batch_size}")
logger.info(f" Gradient accumulation steps = {self.config.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {total_steps}")
logger.info(f" Device = {self.device}")
if self.is_distributed:
logger.info(f" Distributed training with {get_world_size()} GPUs")
for epoch in range(self.current_epoch, num_epochs):
self.current_epoch = epoch
epoch_start_time = time.time()
train_metrics = self.train_epoch(train_dataloader, epoch)
# Add metrics to training history
epoch_metrics = {
'epoch': epoch + 1,
'train_metrics': train_metrics,
'global_step': self.global_step,
'learning_rate': self.optimizer.param_groups[0]['lr'] if self.optimizer else 0,
}
if eval_dataloader is not None and (epoch + 1) % self.config.eval_steps == 0:
eval_metrics = self.evaluate(eval_dataloader)
epoch_metrics['eval_metrics'] = eval_metrics
current_metric = eval_metrics.get(self.config.metric_for_best_model, 0)
is_best = self._is_better_metric(current_metric)
if is_best:
self.best_metric = current_metric
if is_main_process() and self.checkpoint_manager:
self._save_checkpoint(is_best=is_best, metrics=eval_metrics)
if is_main_process():
logger.info(f"Epoch {epoch + 1} evaluation: {eval_metrics}")
# Add to training history
self.training_history.append(epoch_metrics)
if is_main_process() and (epoch + 1) % self.config.save_steps == 0:
self._save_checkpoint(metrics=train_metrics)
epoch_time = time.time() - epoch_start_time
if is_main_process():
logger.info(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds")
logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")
if is_main_process():
self._save_checkpoint(metrics=train_metrics)
self._save_training_summary()
# Close TensorBoard logger
self.tb_logger.close()
if self.is_distributed:
cleanup_distributed()
return self.training_history # type: ignore[return-value]
except Exception as e:
logger.error(f"Training failed: {e}")
raise
def train_epoch(
self,
dataloader: DataLoader,
epoch: int
) -> Dict[str, float]:
"""
Train for one epoch with comprehensive error handling and recovery
Args:
dataloader: Training data loader
epoch: Current epoch number
Returns:
Dictionary of training metrics for the epoch
Raises:
RuntimeError: If training fails and cannot recover
"""
self.model.train()
epoch_loss = 0.0
epoch_metrics = defaultdict(float)
num_batches = len(dataloader)
# Track failed batches for recovery
failed_batches = []
max_failures = 3
# Training loop with error recovery
progress_bar = tqdm(
dataloader,
desc=f"Epoch {epoch + 1}",
disable=not is_main_process()
)
for step, batch in enumerate(progress_bar):
try:
# Move batch to device with error handling
try:
batch = self._prepare_batch(batch)
except Exception as e:
logger.warning(f"Failed to prepare batch {step}: {e}")
failed_batches.append((step, 'prepare', str(e)))
if len(failed_batches) > max_failures:
raise RuntimeError(f"Too many batch preparation failures: {len(failed_batches)}")
continue
# Compute loss with error recovery
loss: torch.Tensor
try:
if self.use_amp:
with autocast(enabled=True, dtype=self.amp_dtype):
loss = self.compute_loss(batch)
loss = cast(torch.Tensor, loss)
else:
loss = self.compute_loss(batch)
loss = cast(torch.Tensor, loss)
except Exception as e:
logger.warning(f"Failed to compute loss for batch {step}: {e}")
failed_batches.append((step, 'forward', str(e)))
if len(failed_batches) > max_failures:
raise RuntimeError(f"Too many forward pass failures: {len(failed_batches)}")
continue
# Scale loss for gradient accumulation
loss = loss / self.config.gradient_accumulation_steps
# Backward pass with error handling
try:
if self.scaler:
self.scaler.scale(loss).backward() # type: ignore[attr-defined]
else:
loss.backward() # type: ignore[attr-defined]
except Exception as e:
logger.warning(f"Backward pass failed for batch {step}: {e}")
# Clear gradients and continue
self.optimizer.zero_grad()
failed_batches.append((step, 'backward', str(e)))
if len(failed_batches) > max_failures:
raise RuntimeError(f"Too many backward pass failures: {len(failed_batches)}")
continue
# Update weights with error handling
if (step + 1) % self.config.gradient_accumulation_steps == 0:
try:
# Gradient clipping
if self.scaler:
self.scaler.unscale_(self.optimizer)
clip_grad_norm_(
self.model.parameters(),
self.config.max_grad_norm
)
# Optimizer step
if self.scaler:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
# Scheduler step
if self.scheduler:
self.scheduler.step()
# Update global step
self.global_step += 1
# Clear cache periodically for memory efficiency
if self.empty_cache_steps > 0 and self.global_step % self.empty_cache_steps == 0:
self._clear_device_cache()
# Zero gradients
self.optimizer.zero_grad()
except Exception as e:
logger.warning(f"Optimizer update failed at step {step}: {e}")
# Reset optimizer state
self.optimizer.zero_grad()
if self.scaler:
self.scaler.update()
failed_batches.append((step, 'optimizer', str(e)))
continue
# Track the unscaled loss on every successful step so the epoch average is correct
current_loss = loss.item() * self.config.gradient_accumulation_steps
epoch_loss += current_loss
# Logging with error handling
if self.global_step % self.config.logging_steps == 0:
try:
# Update progress bar
progress_bar.set_postfix({
'loss': f'{current_loss:.4f}',
'lr': f'{self.scheduler.get_last_lr()[0]:.2e}' if self.scheduler else 'N/A',
'failures': len(failed_batches)
})
# Log to TensorBoard
if self.tb_logger:
self.tb_logger.log_scalar('train/loss', current_loss, self.global_step)
self.tb_logger.log_scalar('train/learning_rate',
self.scheduler.get_last_lr()[0] if self.scheduler else self.optimizer.param_groups[0]['lr'],
self.global_step)
except Exception as e:
logger.debug(f"Logging failed: {e}")
# Checkpoint saving with error handling
if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0 and is_main_process():
try:
if self.checkpoint_manager:
state_dict = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
'global_step': self.global_step,
'current_epoch': epoch,
'config': vars(self.config)
}
self.checkpoint_manager.save(state_dict, self.global_step)
except Exception as e:
logger.error(f"Failed to save checkpoint at step {self.global_step}: {e}")
# Continue training even if checkpoint fails
except Exception as e:
# Catch-all for any unexpected errors
logger.error(f"Unexpected error in training step {step}: {e}")
failed_batches.append((step, 'unknown', str(e)))
if len(failed_batches) > max_failures * 2:
raise RuntimeError(f"Training failed with too many errors: {len(failed_batches)}")
continue
# Log epoch summary
if failed_batches:
logger.warning(f"Epoch {epoch + 1} completed with {len(failed_batches)} failed batches")
for step, stage, error in failed_batches[:5]: # Show first 5 failures
logger.debug(f" Batch {step} failed at {stage}: {error}")
# Compute epoch metrics
epoch_metrics['loss'] = epoch_loss / max(num_batches - len(failed_batches), 1)
epoch_metrics['failed_batches'] = len(failed_batches)
epoch_metrics['success_rate'] = (num_batches - len(failed_batches)) / num_batches if num_batches > 0 else 0.0
return epoch_metrics
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
"""Evaluate the model"""
self.model.eval()
total_loss = 0
num_batches = 0
all_metrics = {}
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating", disable=not is_main_process()):
batch = self._prepare_batch(batch)
# Compute metrics
metrics = self.evaluate_step(batch)
# Accumulate metrics
for key, value in metrics.items():
if key not in all_metrics:
all_metrics[key] = 0
all_metrics[key] += value
num_batches += 1
# Average metrics
for key in all_metrics:
all_metrics[key] /= num_batches
# Gather metrics across devices
if self.is_distributed:
all_metrics = reduce_dict(all_metrics)
return all_metrics
def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
"""Move tensor values in batch dict to the target device, leave non-tensors untouched."""
moved = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
moved[k] = v.to(self.device, non_blocking=True)
else:
moved[k] = v # keep lists / strings / ints as-is
return moved
def _clear_device_cache(self):
"""Clear device memory cache with support for CUDA and Ascend NPU"""
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# Try Ascend support if available
try:
from utils.ascend_support import clear_ascend_cache
clear_ascend_cache()
except ImportError:
# Fallback to manual NPU check
try:
npu = getattr(torch, 'npu', None)
if npu is not None and hasattr(npu, 'is_available') and npu.is_available():
npu.empty_cache()
except Exception:
pass
except Exception as e:
logger.debug(f"Failed to clear device cache: {e}")
def _is_better_metric(self, current: float) -> bool:
"""Check if current metric is better than best"""
if self.config.greater_is_better:
return current > self.best_metric
else:
return current < self.best_metric
def _save_checkpoint(self, is_best: bool = False, metrics: Optional[Dict[str, float]] = None):
"""Save model checkpoint"""
if not is_main_process() or not self.checkpoint_manager:
return
checkpoint = {
'model_state_dict': self.model.module.state_dict() if self.is_distributed else self.model.state_dict(), # type: ignore[attr-defined]
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
'config': self.config.to_dict(),
'global_step': self.global_step,
'current_epoch': self.current_epoch,
'best_metric': self.best_metric
}
if self.scaler:
checkpoint['scaler_state_dict'] = self.scaler.state_dict()
self.checkpoint_manager.save(
checkpoint,
step=self.global_step,
is_best=is_best,
metrics=metrics
)
def _save_training_summary(self):
"""Save training summary"""
if not is_main_process():
return
summary = {
'config': self.config.to_dict(),
'final_metrics': self.training_history[-1] if self.training_history else {},
'best_metric': self.best_metric,
'total_steps': self.global_step,
'total_epochs': self.current_epoch,
'training_time': datetime.now().isoformat()
}
summary_path = os.path.join(self.config.output_dir, 'training_summary.json')
with open(summary_path, 'w') as f:
json.dump(summary, f, indent=2)
def load_checkpoint(self, checkpoint_path: str):
"""Load checkpoint"""
logger.info(f"Loading checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device)
# Load model state
if self.is_distributed:
self.model.module.load_state_dict(checkpoint['model_state_dict']) # type: ignore[attr-defined]
else:
self.model.load_state_dict(checkpoint['model_state_dict']) # type: ignore[attr-defined]
# Load optimizer state
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Load scheduler state
if self.scheduler and checkpoint.get('scheduler_state_dict'):
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# Load scaler state
if self.scaler and checkpoint.get('scaler_state_dict'):
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
# Load training state
self.global_step = checkpoint.get('global_step', 0)
self.current_epoch = checkpoint.get('current_epoch', 0)
        # Default must respect the metric direction; -inf only makes sense when greater is better
        default_best = float('-inf') if self.config.greater_is_better else float('inf')
        self.best_metric = checkpoint.get('best_metric', default_best)
logger.info(f"Resumed from step {self.global_step}, epoch {self.current_epoch}")

127
utils/__init__.py Normal file
View File

@ -0,0 +1,127 @@
"""
Utility modules for logging, checkpointing, and distributed training
"""
from .logging import (
setup_logging,
get_logger,
MetricsLogger,
TensorBoardLogger,
ProgressLogger,
log_model_info,
log_training_summary
)
from .checkpoint import CheckpointManager, save_model_for_inference
from .distributed import (
setup_distributed,
cleanup_distributed,
is_main_process,
get_rank,
get_world_size,
get_local_rank,
synchronize,
all_reduce,
all_gather,
broadcast,
reduce_dict,
setup_model_for_distributed,
create_distributed_dataloader,
set_random_seed,
print_distributed_info,
DistributedMetricAggregator,
save_on_master,
log_on_master
)
import os
import sys
import logging
logger = logging.getLogger(__name__)
def optional_import(module_name, feature_desc=None, warn_only=True):
"""Try to import an optional module. Log a warning if not found. Returns the module or None."""
import importlib
try:
return importlib.import_module(module_name)
except ImportError:
msg = f"Optional dependency '{module_name}' not found."
if feature_desc:
msg += f" {feature_desc} will be unavailable."
if warn_only:
logger.warning(msg)
else:
logger.error(msg)
return None
# Import guard for optional dependencies
def _check_optional_dependencies():
"""Check for optional dependencies and log warnings if missing when features are likely to be used."""
# Check for torch_npu if running on Ascend
torch = optional_import('torch', warn_only=True)
    if torch and (hasattr(torch, 'npu') or os.environ.get('DEVICE', '').lower() == 'npu'):
        optional_import('torch_npu', feature_desc='Ascend NPU support')
    # Check for wandb if Weights & Biases logging is enabled
    config = None
    try:
        from utils.config_loader import get_config
        config = get_config()
        use_wandb = config.get('logging', 'use_wandb', False)
        if use_wandb:
            optional_import('wandb', feature_desc='Weights & Biases logging')
    except Exception:
        pass
    # Check for tensorboard if tensorboard logging is likely to be used
    tensorboard_dir = None
    if config is not None:
        try:
            tensorboard_dir = config.get('logging', 'tensorboard_dir', None)
        except Exception:
            pass
    if tensorboard_dir:
        optional_import('tensorboard', feature_desc='TensorBoard logging')
    # Check for pytest if running tests
    if any('pytest' in arg for arg in sys.argv):
        optional_import('pytest', feature_desc='Some tests may not run')
__all__ = [
# Logging
'setup_logging',
'get_logger',
'MetricsLogger',
'TensorBoardLogger',
'ProgressLogger',
'log_model_info',
'log_training_summary',
# Checkpointing
'CheckpointManager',
'save_model_for_inference',
# Distributed training
'setup_distributed',
'cleanup_distributed',
'is_main_process',
'get_rank',
'get_world_size',
'get_local_rank',
'synchronize',
'all_reduce',
'all_gather',
'broadcast',
'reduce_dict',
'setup_model_for_distributed',
'create_distributed_dataloader',
'set_random_seed',
'print_distributed_info',
'DistributedMetricAggregator',
'save_on_master',
'log_on_master'
]
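A minimal usage sketch for `optional_import` defined above, assuming the package is importable as `utils`; `wandb` is just one example of an optional dependency, and the fallback branch is hypothetical.

```python
from utils import optional_import

# Returns the imported module, or None (with a logged warning) if it is missing.
wandb = optional_import('wandb', feature_desc='Weights & Biases logging')

if wandb is not None:
    run = wandb.init(project='bge-finetuning')
else:
    run = None  # fall back to file/TensorBoard logging only
```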

271
utils/ascend_support.py Normal file
View File

@ -0,0 +1,271 @@
"""
Huawei Ascend NPU support utilities for BGE fine-tuning
This module provides compatibility layer for running BGE models on Ascend NPUs.
Requires torch_npu to be installed: pip install torch_npu
"""
import os
import logging
from typing import Optional, Dict, Any, Tuple
import torch
logger = logging.getLogger(__name__)
# Global flag to track Ascend availability
_ASCEND_AVAILABLE = False
_ASCEND_CHECK_DONE = False
def check_ascend_availability() -> bool:
"""
Check if Ascend NPU is available
Returns:
bool: True if Ascend NPU is available and properly configured
"""
global _ASCEND_AVAILABLE, _ASCEND_CHECK_DONE
if _ASCEND_CHECK_DONE:
return _ASCEND_AVAILABLE
_ASCEND_CHECK_DONE = True
try:
import torch_npu
# Check if NPU is available
if hasattr(torch, 'npu') and torch.npu.is_available():
_ASCEND_AVAILABLE = True
npu_count = torch.npu.device_count()
logger.info(f"Ascend NPU available with {npu_count} devices")
# Log NPU information
for i in range(npu_count):
props = torch.npu.get_device_properties(i)
logger.info(f"NPU {i}: {props.name}, Total Memory: {props.total_memory / 1024**3:.2f} GB")
else:
logger.info("Ascend NPU not available")
except ImportError:
logger.debug("torch_npu not installed. Ascend NPU support disabled.")
except Exception as e:
logger.warning(f"Error checking Ascend availability: {e}")
return _ASCEND_AVAILABLE
def get_ascend_backend() -> str:
"""
Get the appropriate backend for Ascend NPU
Returns:
str: 'hccl' for Ascend, 'nccl' for CUDA, 'gloo' for CPU
"""
if check_ascend_availability():
return 'hccl'
elif torch.cuda.is_available():
return 'nccl'
else:
return 'gloo'
def setup_ascend_device(local_rank: int = 0) -> torch.device:
"""
Setup Ascend NPU device
Args:
local_rank: Local rank for distributed training
Returns:
torch.device: Configured device
"""
if not check_ascend_availability():
raise RuntimeError("Ascend NPU not available")
try:
# Set NPU device
torch.npu.set_device(local_rank)
device = torch.device(f'npu:{local_rank}')
logger.info(f"Using Ascend NPU device: {device}")
return device
except Exception as e:
logger.error(f"Failed to setup Ascend device: {e}")
raise
def convert_model_for_ascend(model: torch.nn.Module) -> torch.nn.Module:
"""
Convert model for Ascend NPU compatibility
Args:
model: PyTorch model
Returns:
Model configured for Ascend
"""
if not check_ascend_availability():
return model
try:
import torch_npu
# Apply Ascend-specific optimizations
model = torch_npu.contrib.transfer_to_npu(model)
# Enable Ascend graph mode if available
if hasattr(torch_npu, 'enable_graph_mode'):
torch_npu.enable_graph_mode()
logger.info("Enabled Ascend graph mode")
return model
except Exception as e:
logger.warning(f"Failed to apply Ascend optimizations: {e}")
return model
def optimize_for_ascend(config: Dict[str, Any]) -> Dict[str, Any]:
"""
Optimize configuration for Ascend NPU
Args:
config: Training configuration
Returns:
Optimized configuration
"""
if not check_ascend_availability():
return config
# Create a copy to avoid modifying original
config = config.copy()
# Ascend-specific optimizations
ascend_config = config.get('ascend', {})
# Set backend to hccl for distributed training
if ascend_config.get('backend') == 'hccl':
config['distributed'] = config.get('distributed', {})
config['distributed']['backend'] = 'hccl'
# Enable mixed precision if specified
if ascend_config.get('mixed_precision', False):
config['optimization'] = config.get('optimization', {})
config['optimization']['fp16'] = True
logger.info("Enabled mixed precision for Ascend")
# Set visible devices
if 'visible_devices' in ascend_config:
os.environ['ASCEND_RT_VISIBLE_DEVICES'] = ascend_config['visible_devices']
logger.info(f"Set ASCEND_RT_VISIBLE_DEVICES={ascend_config['visible_devices']}")
return config
def get_ascend_memory_info(device_id: int = 0) -> Tuple[float, float]:
"""
Get Ascend NPU memory information
Args:
device_id: NPU device ID
Returns:
Tuple of (used_memory_gb, total_memory_gb)
"""
if not check_ascend_availability():
return 0.0, 0.0
    try:
        import torch_npu  # noqa: F401 - patches torch.npu with the memory API
        # Get memory info (torch.npu mirrors the torch.cuda memory API)
        used = torch.npu.memory_allocated(device_id) / 1024**3
        total = torch.npu.get_device_properties(device_id).total_memory / 1024**3
        return used, total
except Exception as e:
logger.warning(f"Failed to get Ascend memory info: {e}")
return 0.0, 0.0
def clear_ascend_cache():
"""Clear Ascend NPU memory cache"""
if not check_ascend_availability():
return
try:
torch.npu.empty_cache()
logger.debug("Cleared Ascend NPU cache")
except Exception as e:
logger.warning(f"Failed to clear Ascend cache: {e}")
def is_ascend_available() -> bool:
"""
Simple check for Ascend availability
Returns:
bool: True if Ascend NPU is available
"""
return check_ascend_availability()
class AscendDataLoader:
"""
Wrapper for DataLoader with Ascend-specific optimizations
"""
def __init__(self, dataloader, device):
self.dataloader = dataloader
self.device = device
def __iter__(self):
for batch in self.dataloader:
# Move batch to NPU device
if isinstance(batch, dict):
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()}
elif isinstance(batch, (list, tuple)):
batch = [v.to(self.device) if isinstance(v, torch.Tensor) else v
for v in batch]
else:
batch = batch.to(self.device)
yield batch
def __len__(self):
return len(self.dataloader)
# Compatibility functions
def npu_synchronize():
"""Synchronize NPU device (compatibility wrapper)"""
if check_ascend_availability():
try:
torch.npu.synchronize()
except Exception as e:
logger.warning(f"NPU synchronize failed: {e}")
elif torch.cuda.is_available():
torch.cuda.synchronize()
def get_device_name(device: torch.device) -> str:
"""
Get device name with Ascend support
Args:
device: PyTorch device
Returns:
str: Device name
"""
if device.type == 'npu':
if check_ascend_availability():
try:
props = torch.npu.get_device_properties(device.index or 0)
return f"Ascend {props.name}"
            except Exception:
return "Ascend NPU"
return "NPU (unavailable)"
elif device.type == 'cuda':
return torch.cuda.get_device_name(device.index or 0)
else:
return str(device)
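A hedged sketch of device-agnostic selection built on the helpers above (Ascend NPU first, then CUDA, then CPU); the local rank of 0 is a placeholder for whatever the launcher provides.

```python
import torch
from utils.ascend_support import (
    check_ascend_availability,
    get_ascend_backend,
    setup_ascend_device,
    get_device_name,
)

# Pick a device in priority order: Ascend NPU -> CUDA -> CPU.
if check_ascend_availability():
    device = setup_ascend_device(local_rank=0)
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print(f"Selected {get_device_name(device)}; distributed backend would be '{get_ascend_backend()}'")
```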

465
utils/checkpoint.py Normal file
View File

@ -0,0 +1,465 @@
import os
import torch
import json
import shutil
from typing import Dict, Any, Optional, List
import logging
from datetime import datetime
import glob
logger = logging.getLogger(__name__)
class CheckpointManager:
"""
Manages model checkpoints with automatic cleanup and best model tracking
"""
def __init__(
self,
output_dir: str,
save_total_limit: int = 3,
best_model_dir: str = "best_model",
save_optimizer: bool = True,
save_scheduler: bool = True
):
self.output_dir = output_dir
self.save_total_limit = save_total_limit
self.best_model_dir = os.path.join(output_dir, best_model_dir)
self.save_optimizer = save_optimizer
self.save_scheduler = save_scheduler
# Create directories
os.makedirs(output_dir, exist_ok=True)
os.makedirs(self.best_model_dir, exist_ok=True)
# Track checkpoints
self.checkpoints = []
self._load_existing_checkpoints()
def _load_existing_checkpoints(self):
"""Load list of existing checkpoints"""
checkpoint_dirs = glob.glob(os.path.join(self.output_dir, "checkpoint-*"))
# Sort by step number
self.checkpoints = []
for cp_dir in checkpoint_dirs:
try:
step = int(cp_dir.split("-")[-1])
self.checkpoints.append((step, cp_dir))
            except (ValueError, IndexError):
                pass
self.checkpoints.sort(key=lambda x: x[0])
def save(
self,
state_dict: Dict[str, Any],
step: int,
is_best: bool = False,
metrics: Optional[Dict[str, float]] = None
) -> str:
"""
Save checkpoint with atomic writes and error handling
Args:
state_dict: State dictionary containing model, optimizer, etc.
step: Current training step
is_best: Whether this is the best model so far
metrics: Optional metrics to save with checkpoint
Returns:
Path to saved checkpoint
"""
checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
temp_dir = os.path.join(self.output_dir, f"checkpoint-{step}-tmp")
try:
# Create temporary directory for atomic write
os.makedirs(temp_dir, exist_ok=True)
# Save individual components
checkpoint_files = {}
# Save model with error handling
if 'model_state_dict' in state_dict:
model_path = os.path.join(temp_dir, 'pytorch_model.bin')
self._safe_torch_save(state_dict['model_state_dict'], model_path)
checkpoint_files['model'] = 'pytorch_model.bin'
# Save optimizer
if self.save_optimizer and 'optimizer_state_dict' in state_dict:
optimizer_path = os.path.join(temp_dir, 'optimizer.pt')
self._safe_torch_save(state_dict['optimizer_state_dict'], optimizer_path)
checkpoint_files['optimizer'] = 'optimizer.pt'
# Save scheduler
if self.save_scheduler and 'scheduler_state_dict' in state_dict:
scheduler_path = os.path.join(temp_dir, 'scheduler.pt')
self._safe_torch_save(state_dict['scheduler_state_dict'], scheduler_path)
checkpoint_files['scheduler'] = 'scheduler.pt'
# Save training state
training_state = {
'step': step,
'checkpoint_files': checkpoint_files,
'timestamp': datetime.now().isoformat()
}
# Add additional state info
for key in ['global_step', 'current_epoch', 'best_metric']:
if key in state_dict:
training_state[key] = state_dict[key]
if metrics:
training_state['metrics'] = metrics
# Save config if provided
if 'config' in state_dict:
config_path = os.path.join(temp_dir, 'config.json')
self._safe_json_save(state_dict['config'], config_path)
# Save training state
state_path = os.path.join(temp_dir, 'training_state.json')
self._safe_json_save(training_state, state_path)
# Atomic rename from temp to final directory
if os.path.exists(checkpoint_dir):
# Backup existing checkpoint
backup_dir = f"{checkpoint_dir}.backup"
if os.path.exists(backup_dir):
shutil.rmtree(backup_dir)
shutil.move(checkpoint_dir, backup_dir)
# Move temp directory to final location
shutil.move(temp_dir, checkpoint_dir)
# Remove backup if successful
if os.path.exists(f"{checkpoint_dir}.backup"):
shutil.rmtree(f"{checkpoint_dir}.backup")
# Add to checkpoint list
self.checkpoints.append((step, checkpoint_dir))
# Save best model
if is_best:
self._save_best_model(checkpoint_dir, step, metrics)
# Cleanup old checkpoints
self._cleanup_checkpoints()
logger.info(f"Saved checkpoint to {checkpoint_dir}")
return checkpoint_dir
except Exception as e:
logger.error(f"Failed to save checkpoint: {e}")
# Cleanup temp directory
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
# Restore backup if exists
backup_dir = f"{checkpoint_dir}.backup"
if os.path.exists(backup_dir):
if os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)
shutil.move(backup_dir, checkpoint_dir)
raise RuntimeError(f"Checkpoint save failed: {e}")
def _save_best_model(
self,
checkpoint_dir: str,
step: int,
metrics: Optional[Dict[str, float]] = None
):
"""Save the best model"""
# Copy checkpoint to best model directory
for filename in os.listdir(checkpoint_dir):
src = os.path.join(checkpoint_dir, filename)
dst = os.path.join(self.best_model_dir, filename)
if os.path.isfile(src):
shutil.copy2(src, dst)
# Save best model info
best_info = {
'step': step,
'source_checkpoint': checkpoint_dir,
'timestamp': datetime.now().isoformat()
}
if metrics:
best_info['metrics'] = metrics
best_info_path = os.path.join(self.best_model_dir, 'best_model_info.json')
with open(best_info_path, 'w') as f:
json.dump(best_info, f, indent=2)
logger.info(f"Saved best model from step {step}")
def _cleanup_checkpoints(self):
"""Remove old checkpoints to maintain save_total_limit"""
if self.save_total_limit is not None and len(self.checkpoints) > self.save_total_limit:
# Remove oldest checkpoints
checkpoints_to_remove = self.checkpoints[:-self.save_total_limit]
for step, checkpoint_dir in checkpoints_to_remove:
if os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)
logger.info(f"Removed old checkpoint: {checkpoint_dir}")
# Update checkpoint list
self.checkpoints = self.checkpoints[-self.save_total_limit:]
def load_checkpoint(
self,
checkpoint_path: Optional[str] = None,
load_best: bool = False,
map_location: str = 'cpu',
strict: bool = True
) -> Dict[str, Any]:
"""
Load checkpoint with robust error handling
Args:
checkpoint_path: Path to specific checkpoint
load_best: Whether to load best model
map_location: Device to map tensors to
strict: Whether to strictly enforce state dict matching
Returns:
Loaded state dictionary
Raises:
RuntimeError: If checkpoint loading fails
"""
try:
if load_best:
checkpoint_path = self.best_model_dir
elif checkpoint_path is None:
# Load latest checkpoint
if self.checkpoints:
checkpoint_path = self.checkpoints[-1][1]
else:
raise ValueError("No checkpoints found")
# Ensure checkpoint_path is a string
if checkpoint_path is None:
raise ValueError("Checkpoint path is None")
checkpoint_path = str(checkpoint_path) # Ensure it's a string
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
logger.info(f"Loading checkpoint from {checkpoint_path}")
# Load training state with error handling
state_dict = {}
state_path = os.path.join(checkpoint_path, 'training_state.json')
if os.path.exists(state_path):
try:
with open(state_path, 'r') as f:
training_state = json.load(f)
state_dict.update(training_state)
except Exception as e:
logger.warning(f"Failed to load training state: {e}")
training_state = {}
else:
training_state = {}
# Load model state
model_path = os.path.join(checkpoint_path, 'pytorch_model.bin')
if os.path.exists(model_path):
try:
state_dict['model_state_dict'] = torch.load(
model_path,
map_location=map_location,
weights_only=True # Security: prevent arbitrary code execution
)
except Exception as e:
# Try legacy loading without weights_only for backward compatibility
try:
state_dict['model_state_dict'] = torch.load(
model_path,
map_location=map_location
)
logger.warning("Loaded checkpoint using legacy mode (weights_only=False)")
except Exception as legacy_e:
raise RuntimeError(f"Failed to load model state: {e}, Legacy attempt: {legacy_e}")
else:
logger.warning(f"Model state not found at {model_path}")
# Load optimizer state
optimizer_path = os.path.join(checkpoint_path, 'optimizer.pt')
if os.path.exists(optimizer_path):
try:
state_dict['optimizer_state_dict'] = torch.load(
optimizer_path,
map_location=map_location,
weights_only=True
)
except Exception as e:
# Try legacy loading
try:
state_dict['optimizer_state_dict'] = torch.load(
optimizer_path,
map_location=map_location
)
                except Exception as legacy_e:
                    logger.warning(f"Failed to load optimizer state: {e} (legacy attempt: {legacy_e})")
# Load scheduler state
scheduler_path = os.path.join(checkpoint_path, 'scheduler.pt')
if os.path.exists(scheduler_path):
try:
state_dict['scheduler_state_dict'] = torch.load(
scheduler_path,
map_location=map_location,
weights_only=True
)
except Exception as e:
# Try legacy loading
try:
state_dict['scheduler_state_dict'] = torch.load(
scheduler_path,
map_location=map_location
)
                except Exception as legacy_e:
                    logger.warning(f"Failed to load scheduler state: {e} (legacy attempt: {legacy_e})")
# Load config if available
config_path = os.path.join(checkpoint_path, 'config.json')
if os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
state_dict['config'] = json.load(f)
except Exception as e:
logger.warning(f"Failed to load config: {e}")
logger.info(f"Successfully loaded checkpoint from {checkpoint_path}")
return state_dict
except Exception as e:
logger.error(f"Failed to load checkpoint: {e}")
raise RuntimeError(f"Checkpoint loading failed: {e}")
def get_latest_checkpoint(self) -> Optional[str]:
"""Get path to latest checkpoint"""
if self.checkpoints:
return self.checkpoints[-1][1]
return None
def get_best_checkpoint(self) -> str:
"""Get path to best model checkpoint"""
return self.best_model_dir
def checkpoint_exists(self, step: int) -> bool:
"""Check if checkpoint exists for given step"""
checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
return os.path.exists(checkpoint_dir)
def _safe_torch_save(self, obj: Any, path: str):
"""Save torch object with error handling"""
temp_path = f"{path}.tmp"
try:
torch.save(obj, temp_path)
# Atomic rename
os.replace(temp_path, path)
except Exception as e:
if os.path.exists(temp_path):
os.remove(temp_path)
raise RuntimeError(f"Failed to save {path}: {e}")
def _safe_json_save(self, obj: Any, path: str):
"""Save JSON with atomic write"""
temp_path = f"{path}.tmp"
try:
with open(temp_path, 'w') as f:
json.dump(obj, f, indent=2)
# Atomic rename
os.replace(temp_path, path)
except Exception as e:
if os.path.exists(temp_path):
os.remove(temp_path)
raise RuntimeError(f"Failed to save {path}: {e}")
def save_model_for_inference(
model: torch.nn.Module,
tokenizer: Any,
output_dir: str,
model_name: str = "model",
save_config: bool = True,
additional_files: Optional[Dict[str, str]] = None
):
"""
Save model in a format ready for inference
Args:
model: The model to save
tokenizer: Tokenizer to save
output_dir: Output directory
model_name: Name for the model
save_config: Whether to save model config
additional_files: Additional files to include
"""
os.makedirs(output_dir, exist_ok=True)
# Save model
model_path = os.path.join(output_dir, f"{model_name}.pt")
torch.save(model.state_dict(), model_path)
logger.info(f"Saved model to {model_path}")
# Save tokenizer
if hasattr(tokenizer, 'save_pretrained'):
tokenizer.save_pretrained(output_dir)
logger.info(f"Saved tokenizer to {output_dir}")
    # Save config if available
    if save_config and hasattr(model, 'config'):
        model.config.save_pretrained(output_dir)  # type: ignore[attr-defined]
# Save additional files
if additional_files:
for filename, content in additional_files.items():
filepath = os.path.join(output_dir, filename)
with open(filepath, 'w') as f:
if isinstance(content, dict):
json.dump(content, f, indent=2)
else:
f.write(str(content))
# Create README
readme_content = f"""# {model_name} Model
This directory contains a fine-tuned model ready for inference.
## Files:
- `{model_name}.pt`: Model weights
- `tokenizer_config.json`: Tokenizer configuration
- `special_tokens_map.json`: Special tokens mapping
- `vocab.txt` or `tokenizer.json`: Vocabulary file
## Loading the model:
```python
import torch
from transformers import AutoTokenizer
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('{output_dir}')
# Load model
model = YourModelClass() # Initialize your model class
model.load_state_dict(torch.load('{model_name}.pt'))
model.eval()
```
## Model Information:
- Model type: {model.__class__.__name__}
- Saved on: {datetime.now().isoformat()}
"""
readme_path = os.path.join(output_dir, 'README.md')
with open(readme_path, 'w') as f:
f.write(readme_content)
logger.info(f"Model saved for inference in {output_dir}")

518
utils/config_loader.py Normal file
View File

@ -0,0 +1,518 @@
"""
Centralized configuration loader for BGE fine-tuning project
"""
import os
import toml
import logging
from typing import Dict, Any, Optional, Tuple
from pathlib import Path
logger = logging.getLogger(__name__)
def resolve_model_path(
model_key: str,
config_instance: 'ConfigLoader',
model_name_or_path: Optional[str] = None,
cache_dir: Optional[str] = None
) -> Tuple[str, str]:
"""
Resolve model path with automatic local/remote fallback
Args:
model_key: Key in model_paths config (e.g., 'bge_m3', 'bge_reranker')
config_instance: ConfigLoader instance
model_name_or_path: Optional override from CLI
cache_dir: Optional cache directory
Returns:
Tuple of (resolved_model_path, resolved_cache_dir)
"""
# Default remote model names for each model type
remote_models = {
'bge_m3': 'BAAI/bge-m3',
'bge_reranker': 'BAAI/bge-reranker-base',
}
# Get model source (huggingface or modelscope)
model_source = config_instance.get('model_paths', 'source', 'huggingface')
# If ModelScope is used, adjust model names
if model_source == 'modelscope':
remote_models = {
'bge_m3': 'damo/nlp_gte_sentence-embedding_chinese-base',
'bge_reranker': 'damo/nlp_corom_passage-ranking_english-base',
}
# 1. Check CLI override first (only if explicitly provided)
if model_name_or_path is not None:
resolved_model_path = model_name_or_path
logger.info(f"[CLI OVERRIDE] Using model path: {resolved_model_path}")
else:
# 2. Check local path from config
local_path_raw = config_instance.get('model_paths', model_key, None)
# Handle list/string conversion
if isinstance(local_path_raw, list):
local_path = local_path_raw[0] if local_path_raw else None
else:
local_path = local_path_raw
if local_path and isinstance(local_path, str) and os.path.exists(local_path):
# Check if it's a valid model directory
if _is_valid_model_directory(local_path):
resolved_model_path = local_path
logger.info(f"[LOCAL] Using local model: {resolved_model_path}")
else:
logger.warning(f"[LOCAL] Invalid model directory: {local_path}, falling back to remote")
resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
else:
# 3. Fall back to remote model
resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
logger.info(f"[REMOTE] Using remote model: {resolved_model_path}")
# If local path was configured but doesn't exist, we'll download to that location
if local_path and isinstance(local_path, str):
logger.info(f"[DOWNLOAD] Will download to configured path: {local_path}")
# Create directory if it doesn't exist
os.makedirs(local_path, exist_ok=True)
# Resolve cache directory
if cache_dir is not None:
resolved_cache_dir = cache_dir
else:
cache_dir_raw = config_instance.get('model_paths', 'cache_dir', './cache/data')
if isinstance(cache_dir_raw, list):
resolved_cache_dir = cache_dir_raw[0] if cache_dir_raw else './cache/data'
else:
resolved_cache_dir = cache_dir_raw if cache_dir_raw else './cache/data'
# Ensure cache directory exists
if isinstance(resolved_cache_dir, str):
os.makedirs(resolved_cache_dir, exist_ok=True)
return resolved_model_path, resolved_cache_dir
def _is_valid_model_directory(path: str) -> bool:
"""Check if a directory contains a valid model"""
required_files = ['config.json']
optional_files = ['tokenizer_config.json', 'vocab.txt', 'tokenizer.json']
model_files = ['pytorch_model.bin', 'model.safetensors', 'pytorch_model.bin.index.json']
path_obj = Path(path)
if not path_obj.exists() or not path_obj.is_dir():
return False
# Check for required files
for file in required_files:
if not (path_obj / file).exists():
return False
# Check for at least one model file
has_model_file = any((path_obj / file).exists() for file in model_files)
if not has_model_file:
return False
# Check for at least one tokenizer file (for embedding models)
has_tokenizer_file = any((path_obj / file).exists() for file in optional_files)
if not has_tokenizer_file:
logger.warning(f"Model directory {path} missing tokenizer files, but proceeding anyway")
return True
def download_model_if_needed(
model_name_or_path: str,
local_path: Optional[str] = None,
cache_dir: Optional[str] = None,
model_source: str = 'huggingface'
) -> str:
"""
Download model if it's a remote identifier and local path is specified
Args:
model_name_or_path: Model identifier or path
local_path: Target local path for download
cache_dir: Cache directory
model_source: 'huggingface' or 'modelscope'
Returns:
Final model path (local if downloaded, original if already local)
"""
# If it's already a local path, return as is
if os.path.exists(model_name_or_path):
return model_name_or_path
# If no local path specified, return original (will be handled by transformers)
if not local_path:
return model_name_or_path
# Check if local path already exists and is valid
if os.path.exists(local_path) and _is_valid_model_directory(local_path):
logger.info(f"Model already exists locally: {local_path}")
return local_path
# Download the model
try:
logger.info(f"Downloading model {model_name_or_path} to {local_path}")
if model_source == 'huggingface':
from transformers import AutoModel, AutoTokenizer
# Download model and tokenizer
model = AutoModel.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
trust_remote_code=True
)
# Save to local path
os.makedirs(local_path, exist_ok=True)
model.save_pretrained(local_path)
tokenizer.save_pretrained(local_path)
logger.info(f"Model downloaded successfully to {local_path}")
return local_path
elif model_source == 'modelscope':
try:
from modelscope import snapshot_download # type: ignore
# Download using ModelScope
downloaded_path = snapshot_download(
model_name_or_path,
cache_dir=cache_dir,
local_dir=local_path
)
logger.info(f"Model downloaded successfully from ModelScope to {downloaded_path}")
return downloaded_path
except ImportError:
logger.warning("ModelScope not available, falling back to HuggingFace")
return download_model_if_needed(
model_name_or_path, local_path, cache_dir, 'huggingface'
)
except Exception as e:
logger.error(f"Failed to download model {model_name_or_path}: {e}")
logger.info("Returning original model path - transformers will handle caching")
# Fallback return
return model_name_or_path
class ConfigLoader:
"""Centralized configuration management for BGE fine-tuning
Parameter Precedence (highest to lowest):
1. Command-line arguments (CLI)
2. Model-specific config ([m3], [reranker])
3. [hardware] section in config.toml
4. [training] section in config.toml
5. Hardcoded defaults in code
"""
_instance = None
_config = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(ConfigLoader, cls).__new__(cls)
return cls._instance
def __init__(self):
if self._config is None:
self._load_config()
def _find_config_file(self) -> str:
"""Find config.toml file in project hierarchy"""
# Check environment variable FIRST (highest priority)
if 'BGE_CONFIG_PATH' in os.environ:
config_path = Path(os.environ['BGE_CONFIG_PATH'])
if config_path.exists():
return str(config_path)
else:
logger.warning(f"BGE_CONFIG_PATH set to {config_path} but file doesn't exist, falling back to default")
current_dir = Path(__file__).parent
# Search up the directory tree for config.toml
for _ in range(5): # Limit search depth
config_path = current_dir / "config.toml"
if config_path.exists():
return str(config_path)
current_dir = current_dir.parent
# Default location
return os.path.join(os.getcwd(), "config.toml")
def _load_config(self):
"""Load configuration from TOML file"""
config_path = self._find_config_file()
try:
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
loaded = toml.load(f)
if loaded is None:
loaded = {}
self._config = dict(loaded)
logger.info(f"Loaded configuration from {config_path}")
else:
logger.warning(f"Config file not found at {config_path}, using defaults")
self._config = self._get_default_config()
self._save_default_config(config_path)
logger.warning("Default config created. Please check your configuration file.")
except Exception as e:
logger.error(f"Error loading config: {e}")
self._config = self._get_default_config()
# Validation for unset/placeholder values
self._validate_config()
def _validate_config(self):
"""Validate config for unset/placeholder values"""
api_key = self.get('api', 'api_key', None)
if not api_key or api_key == 'your-api-key-here':
logger.warning("API key is unset or placeholder. Set it in config.toml or via the BGE_API_KEY environment variable.")
def _get_default_config(self) -> Dict[str, Any]:
"""Get default configuration"""
return {
'api': {
'base_url': 'https://api.deepseek.com/v1',
'api_key': 'your-api-key-here',
'model': 'deepseek-chat',
'max_retries': 3,
'timeout': 30,
'temperature': 0.7
},
'training': {
'default_batch_size': 4,
'default_learning_rate': 1e-5,
'default_num_epochs': 3,
'default_warmup_ratio': 0.1,
'default_weight_decay': 0.01,
'gradient_accumulation_steps': 4
},
'model_paths': {
'bge_m3': './models/bge-m3',
'bge_reranker': './models/bge-reranker-base',
'cache_dir': './cache/data'
},
'data': {
'max_query_length': 64,
'max_passage_length': 512,
'max_seq_length': 512,
'validation_split': 0.1,
'random_seed': 42
},
'augmentation': {
'enable_query_augmentation': False,
'enable_passage_augmentation': False,
'augmentation_methods': ['paraphrase', 'back_translation'],
'num_augmentations_per_sample': 2,
'augmentation_temperature': 0.8
},
'hard_negative_mining': {
'enable_mining': True,
'num_hard_negatives': 15,
'negative_range_min': 10,
'negative_range_max': 100,
'margin': 0.1,
'index_refresh_interval': 1000,
'index_type': 'IVF',
'nlist': 100
},
'evaluation': {
'eval_batch_size': 32,
'k_values': [1, 3, 5, 10, 20, 50, 100],
'metrics': ['recall', 'precision', 'map', 'mrr', 'ndcg'],
'save_predictions': True,
'predictions_dir': './output/predictions'
},
'logging': {
'log_level': 'INFO',
'log_to_file': True,
'log_dir': './logs',
'tensorboard_dir': './logs/tensorboard',
'wandb_project': 'bge-finetuning',
'wandb_entity': '',
'use_wandb': False
},
'distributed': {
'backend': 'nccl',
'init_method': 'env://',
'find_unused_parameters': True
},
'optimization': {
'gradient_checkpointing': True,
'fp16': True,
'bf16': False,
'fp16_opt_level': 'O1',
'max_grad_norm': 1.0,
'adam_epsilon': 1e-8,
'adam_beta1': 0.9,
'adam_beta2': 0.999
},
'hardware': {
'cuda_visible_devices': '0,1,2,3',
'dataloader_num_workers': 4,
'dataloader_pin_memory': True,
'per_device_train_batch_size': 4,
'per_device_eval_batch_size': 8
},
'checkpoint': {
'save_strategy': 'steps',
'save_steps': 1000,
'save_total_limit': 3,
'load_best_model_at_end': True,
'metric_for_best_model': 'eval_recall@10',
'greater_is_better': True,
'resume_from_checkpoint': ''
},
'export': {
'export_format': ['pytorch', 'onnx'],
'optimize_for_inference': True,
'quantize': False,
'quantization_bits': 8
}
}
def _save_default_config(self, config_path: str):
"""Save default configuration to file"""
try:
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, 'w', encoding='utf-8') as f:
config_to_save = self._config if isinstance(self._config, dict) else {}
toml.dump(config_to_save, f)
logger.info(f"Created default config at {config_path}")
except Exception as e:
logger.error(f"Error saving default config: {e}")
def get(self, section: str, key: str, default=None):
"""Get a config value, logging its source (CLI, config, default)."""
# CLI overrides are handled outside this loader; here we check config and default
value = default
source = 'default'
if isinstance(self._config, dict) and section in self._config and key in self._config[section]:
value = self._config[section][key]
source = 'config.toml'
# Robust list/tuple conversion
if isinstance(value, str) and (',' in value or ';' in value):
try:
value = [v.strip() for v in value.replace(';', ',').split(',') if v.strip()]
source += ' (parsed as list)'
            except Exception as e:
                logger.warning(f"Failed to parse list for {section}.{key}: {e}")
        log_func = logger.info if source != 'default' else logger.warning
        log_func(f"Config param {section}.{key} = {value} (source: {source})")
return value
def get_section(self, section: str) -> Dict[str, Any]:
"""Get entire configuration section"""
if isinstance(self._config, dict):
return self._config.get(section, {})
return {}
def update(self, key: str, value: Any):
"""Update configuration value. Logs a warning if overriding an existing value."""
if self._config is None:
self._config = {}
keys = key.split('.')
config = self._config
# Navigate to the parent of the final key
for k in keys[:-1]:
if k not in config:
config[k] = {}
config = config[k]
# Log if overriding
old_value = config.get(keys[-1], None)
if old_value is not None and old_value != value:
logger.warning(f"Parameter '{key}' overridden: {old_value} -> {value} (higher-precedence source)")
config[keys[-1]] = value
def reload(self):
"""Reload configuration from file"""
self._config = None
self._load_config()
# Global instance
config = ConfigLoader()
def get_config() -> ConfigLoader:
"""Get global configuration instance"""
return config
def load_config_for_training(
config_path: Optional[str] = None,
override_params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Load configuration specifically for training scripts
Args:
config_path: Optional path to config file
override_params: Parameters to override from command line
Returns:
Dictionary with all configuration parameters
"""
global config
# Reload with specific path if provided
if config_path:
old_path = os.environ.get('BGE_CONFIG_PATH')
os.environ['BGE_CONFIG_PATH'] = config_path
config.reload()
if old_path:
os.environ['BGE_CONFIG_PATH'] = old_path
else:
del os.environ['BGE_CONFIG_PATH']
# Get all configuration
training_config = {
'api': config.get_section('api'),
'training': config.get_section('training'),
'model_paths': config.get_section('model_paths'),
'data': config.get_section('data'),
'augmentation': config.get_section('augmentation'),
'hard_negative_mining': config.get_section('hard_negative_mining'),
'evaluation': config.get_section('evaluation'),
'logging': config.get_section('logging'),
'distributed': config.get_section('distributed'),
'optimization': config.get_section('optimization'),
'hardware': config.get_section('hardware'),
'checkpoint': config.get_section('checkpoint'),
'export': config.get_section('export'),
'm3': config.get_section('m3'),
'reranker': config.get_section('reranker'),
}
# Apply overrides if provided
if override_params:
for key, value in override_params.items():
config.update(key, value)
# Also update the returned dict
keys = key.split('.')
current = training_config
for k in keys[:-1]:
if k not in current:
current[k] = {}
current = current[k]
current[keys[-1]] = value
return training_config
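A short usage sketch for the loader above; the override key and values are illustrative, and `resolve_model_path` falls back to the remote identifier when no valid local directory is configured.

```python
from utils.config_loader import get_config, resolve_model_path, load_config_for_training

config = get_config()

# Plain section/key access with a default (the source of each value is logged).
batch_size = config.get('training', 'default_batch_size', 4)

# Resolve the BGE-M3 path: local directory if configured and valid, remote id otherwise.
model_path, cache_dir = resolve_model_path('bge_m3', config)

# Full training config with CLI-style overrides applied last (highest precedence).
training_config = load_config_for_training(
    override_params={'training.default_learning_rate': 2e-5}
)
print(model_path, training_config['training'].get('default_learning_rate'))
```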

425
utils/distributed.py Normal file
View File

@ -0,0 +1,425 @@
import os
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import logging
from typing import Optional, Any, List
import random
import numpy as np
from .config_loader import get_config
# Import Ascend support utilities
try:
from .ascend_support import (
check_ascend_availability,
get_ascend_backend,
setup_ascend_device,
clear_ascend_cache
)
ASCEND_SUPPORT_AVAILABLE = True
except ImportError:
ASCEND_SUPPORT_AVAILABLE = False
logger = logging.getLogger(__name__)
def setup_distributed(config=None):
"""Setup distributed training environment with proper error handling."""
try:
# Check if we're in a distributed environment
world_size = int(os.environ.get('WORLD_SIZE', 1))
rank = int(os.environ.get('RANK', 0))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
if world_size == 1:
logger.info("Single GPU/CPU training mode")
return
# Determine backend based on hardware availability
# First check for Ascend NPU support
if ASCEND_SUPPORT_AVAILABLE and check_ascend_availability():
backend = 'hccl'
try:
setup_ascend_device(local_rank)
logger.info("Using Ascend NPU with HCCL backend")
except Exception as e:
logger.warning(f"Failed to setup Ascend device: {e}, falling back")
backend = 'gloo'
elif torch.cuda.is_available():
backend = 'nccl'
# Set CUDA device
torch.cuda.set_device(local_rank)
else:
backend = 'gloo'
# Override backend from config if provided
if config is not None:
backend = config.get('distributed.backend', backend)
# Initialize process group
init_method = os.environ.get('MASTER_ADDR', None)
if init_method:
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ.get('MASTER_PORT', '29500')}"
else:
init_method = 'env://'
dist.init_process_group(
backend=backend,
init_method=init_method,
world_size=world_size,
rank=rank
)
logger.info(f"Distributed training initialized: backend={backend}, world_size={world_size}, rank={rank}, local_rank={local_rank}")
except Exception as e:
logger.warning(f"Failed to initialize distributed training: {e}")
logger.info("Falling back to single GPU/CPU training")
# Set environment variables to indicate single GPU mode
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'
os.environ['LOCAL_RANK'] = '0'
def cleanup_distributed():
"""Cleanup distributed training environment with error handling."""
try:
if dist.is_initialized():
dist.destroy_process_group()
logger.info("Distributed process group destroyed.")
except Exception as e:
logger.warning(f"Error during distributed cleanup: {e}")
# Continue execution even if cleanup fails
def is_main_process() -> bool:
"""Check if this is the main process"""
if not dist.is_initialized():
return True
return dist.get_rank() == 0
def get_rank() -> int:
"""Get current process rank"""
if not dist.is_initialized():
return 0
return dist.get_rank()
def get_world_size() -> int:
"""Get world size"""
if not dist.is_initialized():
return 1
return dist.get_world_size()
def get_local_rank() -> int:
"""Get local rank (GPU index on current node)"""
if not dist.is_initialized():
return 0
return int(os.environ.get('LOCAL_RANK', 0))
def synchronize():
"""Synchronize all processes"""
if dist.is_initialized():
dist.barrier()
def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
"""
All-reduce a tensor across all processes
Args:
tensor: Tensor to reduce
op: Reduction operation
Returns:
Reduced tensor
"""
if not dist.is_initialized():
return tensor
dist.all_reduce(tensor, op=op)
return tensor
def all_gather(tensor: torch.Tensor) -> List[torch.Tensor]:
"""
All-gather tensors from all processes
Args:
tensor: Tensor to gather
Returns:
List of tensors from all processes
"""
if not dist.is_initialized():
return [tensor]
world_size = get_world_size()
tensors = [torch.empty_like(tensor) for _ in range(world_size)]
dist.all_gather(tensors, tensor)
return tensors
def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""
Broadcast tensor from source to all processes
Args:
tensor: Tensor to broadcast
src: Source rank
Returns:
Broadcasted tensor
"""
if not dist.is_initialized():
return tensor
dist.broadcast(tensor, src=src)
return tensor
def reduce_dict(input_dict: dict, average: bool = True) -> dict:
"""
Reduce a dictionary of values across all processes
Args:
input_dict: Dictionary to reduce
average: Whether to average the values
Returns:
Reduced dictionary
"""
if not dist.is_initialized():
return input_dict
world_size = get_world_size()
if world_size < 2:
return input_dict
    with torch.no_grad():
        names = []
        values = []
        for k in sorted(input_dict.keys()):
            names.append(k)
            v = input_dict[k]
            if not isinstance(v, torch.Tensor):
                # Metrics may arrive as plain Python floats; convert before stacking
                v = torch.tensor(v, dtype=torch.float32)
            values.append(v)
        # all_reduce needs every tensor on the backend's device (GPU for NCCL)
        device = torch.device(f'cuda:{get_local_rank()}') if torch.cuda.is_available() else torch.device('cpu')
        values = torch.stack([v.to(device) for v in values], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v.item() for k, v in zip(names, values)}
    return reduced_dict
def setup_model_for_distributed(
model: torch.nn.Module,
device_ids: Optional[List[int]] = None,
output_device: Optional[int] = None,
find_unused_parameters: bool = True,
broadcast_buffers: bool = True
) -> torch.nn.Module:
"""
Wrap model for distributed training
Args:
model: Model to wrap
device_ids: GPU indices to use
output_device: Device where outputs are gathered
find_unused_parameters: Whether to find unused parameters
broadcast_buffers: Whether to broadcast buffers
Returns:
Distributed model
"""
if not dist.is_initialized():
return model
if device_ids is None:
device_ids = [get_local_rank()]
if output_device is None:
output_device = device_ids[0]
model = DDP(
model,
device_ids=device_ids,
output_device=output_device,
find_unused_parameters=find_unused_parameters,
broadcast_buffers=broadcast_buffers
)
logger.info(f"Wrapped model for distributed training on devices {device_ids}")
return model
def create_distributed_dataloader(
dataset,
batch_size: int,
shuffle: bool = True,
num_workers: int = 0,
pin_memory: bool = True,
drop_last: bool = False,
seed: int = 42
) -> DataLoader:
"""
Create a distributed dataloader
Args:
dataset: Dataset to load
batch_size: Batch size per GPU
shuffle: Whether to shuffle
num_workers: Number of data loading workers
pin_memory: Whether to pin memory
drop_last: Whether to drop last incomplete batch
seed: Random seed for shuffling
Returns:
Distributed dataloader
"""
if dist.is_initialized():
sampler = DistributedSampler(
dataset,
num_replicas=get_world_size(),
rank=get_rank(),
shuffle=shuffle,
seed=seed,
drop_last=drop_last
)
shuffle = False # Sampler handles shuffling
else:
sampler = None
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
num_workers=num_workers,
pin_memory=pin_memory,
drop_last=drop_last
)
return dataloader
def set_random_seed(seed: int, deterministic: bool = False):
"""
Set random seed for reproducibility
Args:
seed: Random seed
deterministic: Whether to use deterministic algorithms
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic and torch.cuda.is_available():
# Configure cuDNN determinism settings
cudnn.deterministic = True
cudnn.benchmark = False
logger.info(f"Set random seed to {seed} (deterministic={deterministic})")
def print_distributed_info():
"""Print distributed training information"""
if not dist.is_initialized():
logger.info("Distributed training not initialized")
return
logger.info(f"Distributed Info:")
logger.info(f" Backend: {dist.get_backend()}")
logger.info(f" World size: {get_world_size()}")
logger.info(f" Rank: {get_rank()}")
logger.info(f" Local rank: {get_local_rank()}")
logger.info(f" Master addr: {os.environ.get('MASTER_ADDR', 'N/A')}")
logger.info(f" Master port: {os.environ.get('MASTER_PORT', 'N/A')}")
class DistributedMetricAggregator:
"""Helper class to aggregate metrics across distributed processes"""
def __init__(self):
self.metrics = {}
def update(self, **kwargs):
"""Update metrics"""
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.detach()
else:
v = torch.tensor(v, dtype=torch.float32)
if torch.cuda.is_available():
v = v.cuda()
self.metrics[k] = v
def aggregate(self, average: bool = True) -> dict:
"""Aggregate metrics across all processes"""
if not dist.is_initialized():
return {k: v.item() if isinstance(v, torch.Tensor) else v
for k, v in self.metrics.items()}
aggregated = {}
world_size = get_world_size()
for k, v in self.metrics.items():
# All-reduce
all_reduce(v)
# Average if requested
if average:
v = v / world_size
aggregated[k] = v.item()
return aggregated
def reset(self):
"""Reset metrics"""
self.metrics = {}
def save_on_master(func):
    """Decorator that runs the wrapped function only on the master process"""
    def wrapper(*args, **kwargs):
        if is_main_process():
            return func(*args, **kwargs)
        # Return None on non-master processes
        return None
    return wrapper
def log_on_master(logger: logging.Logger):
"""Get a logger that only logs on master process"""
class MasterOnlyLogger:
def __init__(self, logger):
self.logger = logger
def __getattr__(self, name):
if is_main_process():
return getattr(self.logger, name)
else:
return lambda *args, **kwargs: None
return MasterOnlyLogger(logger)
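A skeleton of how these helpers compose in a training script, runnable as a single process (`setup_distributed` is a no-op unless a launcher sets WORLD_SIZE/RANK); the linear model and random dataset are hypothetical stand-ins for the BGE model and training data.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from utils.distributed import (
    setup_distributed, cleanup_distributed, is_main_process, get_local_rank,
    setup_model_for_distributed, create_distributed_dataloader, set_random_seed,
)

set_random_seed(42)
setup_distributed()  # single-process mode unless the launcher sets WORLD_SIZE/RANK

device = torch.device(f'cuda:{get_local_rank()}') if torch.cuda.is_available() else torch.device('cpu')

# Hypothetical stand-ins for the real BGE model and dataset.
model = nn.Linear(16, 2).to(device)
model = setup_model_for_distributed(model)  # wraps in DDP only when dist is initialized

dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
dataloader = create_distributed_dataloader(dataset, batch_size=8, shuffle=True)

for features, labels in dataloader:
    _ = model(features.to(device))

if is_main_process():
    print("step done on rank 0")
cleanup_distributed()
```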

344
utils/logging.py Normal file
View File

@ -0,0 +1,344 @@
import logging
import os
import sys
from datetime import datetime
from typing import Optional, Dict, Any
import json
import torch
from tqdm import tqdm
import numpy as np
def create_progress_bar(
iterable,
desc: str = "",
total: Optional[int] = None,
disable: bool = False,
position: int = 0,
leave: bool = True,
ncols: Optional[int] = None,
unit: str = "it",
dynamic_ncols: bool = True,
):
"""
Create a tqdm progress bar with consistent settings across the project.
Args:
iterable: Iterable to wrap with progress bar
desc: Description for the progress bar
total: Total number of iterations (if None, will try to get from iterable)
disable: Whether to disable the progress bar
position: Position of the progress bar (for multiple bars)
leave: Whether to leave the progress bar after completion
ncols: Width of the progress bar
unit: Unit of iteration (e.g., "batch", "step")
dynamic_ncols: Automatically adjust width based on terminal size
Returns:
tqdm progress bar object
"""
# Check if we're in distributed training
try:
import torch.distributed as dist
if dist.is_initialized() and dist.get_rank() != 0:
# Disable progress bar for non-main processes
disable = True
    except Exception:
        pass
# Create progress bar with consistent settings
return tqdm(
iterable,
desc=desc,
total=total,
disable=disable,
position=position,
leave=leave,
ncols=ncols,
unit=unit,
dynamic_ncols=dynamic_ncols,
# Use custom write method to avoid conflicts with logging
file=sys.stdout,
# Ensure smooth updates
smoothing=0.1,
# Format with consistent style
bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
)
class TqdmLoggingHandler(logging.Handler):
"""Custom logging handler that plays nicely with tqdm progress bars"""
def emit(self, record):
try:
msg = self.format(record)
# Use tqdm.write to avoid breaking progress bars
tqdm.write(msg, file=sys.stdout)
self.flush()
except Exception:
self.handleError(record)
def setup_logging(
output_dir: Optional[str] = None,
log_to_file: bool = False,
log_level: int = logging.INFO,
log_to_console: bool = True,
use_tqdm_handler: bool = True,
rank: int = 0,
world_size: int = 1,
):
"""
Setup logging with optional file and console handlers.
Args:
output_dir: Directory for log files
log_to_file: Whether to log to file
log_level: Logging level
log_to_console: Whether to log to console
use_tqdm_handler: Use tqdm-compatible handler for console logging
rank: Process rank for distributed training
world_size: World size for distributed training
"""
handlers: list[logging.Handler] = []
log_path = None
# Only log to console from main process in distributed training
if log_to_console and (world_size == 1 or rank == 0):
if use_tqdm_handler:
# Use tqdm-compatible handler
console_handler = TqdmLoggingHandler()
else:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(log_level)
handlers.append(console_handler)
if log_to_file:
if output_dir is None:
output_dir = "./logs"
os.makedirs(output_dir, exist_ok=True)
# Add rank to filename for distributed training
if world_size > 1:
log_filename = f"train_rank{rank}.log"
else:
log_filename = "train.log"
log_path = os.path.join(output_dir, log_filename)
file_handler = logging.FileHandler(log_path)
file_handler.setLevel(log_level)
handlers.append(file_handler)
# Create formatter
if world_size > 1:
# Include rank in log format for distributed training
formatter = logging.Formatter(
f'[%(asctime)s][Rank {rank}] %(levelname)s %(name)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
else:
formatter = logging.Formatter(
'[%(asctime)s] %(levelname)s %(name)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# Apply formatter to all handlers
for handler in handlers:
handler.setFormatter(formatter)
# Configure root logger
root_logger = logging.getLogger()
root_logger.setLevel(log_level)
# Remove existing handlers to avoid duplicates
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
# Add new handlers
for handler in handlers:
root_logger.addHandler(handler)
# Log initialization message
if rank == 0:
logging.info(
f"Logging initialized. Level: {logging.getLevelName(log_level)} | "
f"Log file: {log_path if log_to_file else 'none'} | "
f"World size: {world_size}"
)
def get_logger(name: str) -> logging.Logger:
"""Get a logger with the specified name"""
return logging.getLogger(name)
class MetricsLogger:
"""Logger specifically for training metrics"""
def __init__(self, output_dir: str):
self.output_dir = output_dir
self.metrics_file = os.path.join(output_dir, 'metrics.jsonl')
os.makedirs(output_dir, exist_ok=True)
def log(self, step: int, metrics: Dict[str, Any], phase: str = 'train'):
"""Log metrics to file"""
entry = {
'step': step,
'phase': phase,
'timestamp': datetime.now().isoformat(),
'metrics': metrics
}
with open(self.metrics_file, 'a') as f:
f.write(json.dumps(entry) + '\n')
def log_config(self, config: Dict[str, Any]):
"""Log configuration"""
config_file = os.path.join(self.output_dir, 'config.json')
with open(config_file, 'w') as f:
json.dump(config, f, indent=2)
def log_summary(self, summary: Dict[str, Any]):
"""Log training summary"""
summary_file = os.path.join(self.output_dir, 'summary.json')
with open(summary_file, 'w') as f:
json.dump(summary, f, indent=2)
class TensorBoardLogger:
"""TensorBoard logging wrapper"""
def __init__(self, log_dir: str, enabled: bool = True):
self.enabled = enabled
if self.enabled:
try:
from torch.utils.tensorboard.writer import SummaryWriter
self.writer = SummaryWriter(log_dir)
except ImportError:
get_logger(__name__).warning(
"TensorBoard not available. Install with: pip install tensorboard"
)
self.enabled = False
def log_scalar(self, tag: str, value: float, step: int):
"""Log a scalar value"""
if self.enabled:
self.writer.add_scalar(tag, value, step)
def log_scalars(self, main_tag: str, tag_scalar_dict: Dict[str, float], step: int):
"""Log multiple scalars"""
if self.enabled:
self.writer.add_scalars(main_tag, tag_scalar_dict, step)
def log_histogram(self, tag: str, values: np.ndarray, step: int):
"""Log a histogram"""
if self.enabled:
self.writer.add_histogram(tag, values, step)
def log_embedding(self, embeddings: torch.Tensor, metadata: Optional[list] = None, step: int = 0):
"""Log embeddings for visualization"""
if self.enabled:
self.writer.add_embedding(embeddings, metadata=metadata, global_step=step)
def flush(self):
"""Flush the writer"""
if self.enabled:
self.writer.flush()
def close(self):
"""Close the writer"""
if self.enabled:
self.writer.close()
class ProgressLogger:
"""Helper for logging training progress"""
def __init__(self, total_steps: int, log_interval: int = 100):
self.total_steps = total_steps
self.log_interval = log_interval
self.logger = get_logger(__name__)
self.start_time = datetime.now()
def log_step(self, step: int, loss: float, metrics: Optional[Dict[str, float]] = None):
"""Log progress for a training step"""
if step % self.log_interval == 0:
elapsed = (datetime.now() - self.start_time).total_seconds()
steps_per_sec = step / elapsed if elapsed > 0 else 0
eta_seconds = (self.total_steps - step) / steps_per_sec if steps_per_sec > 0 else 0
msg = f"Step {step}/{self.total_steps} | Loss: {loss:.4f}"
if metrics:
metric_str = " | ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
msg += f" | {metric_str}"
msg += f" | Speed: {steps_per_sec:.2f} steps/s | ETA: {self._format_time(eta_seconds)}"
self.logger.info(msg)
def _format_time(self, seconds: float) -> str:
"""Format seconds to human readable time"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
seconds = int(seconds % 60)
if hours > 0:
return f"{hours}h {minutes}m {seconds}s"
elif minutes > 0:
return f"{minutes}m {seconds}s"
else:
return f"{seconds}s"
def log_model_info(model: torch.nn.Module, logger: Optional[logging.Logger] = None):
"""Log model information"""
if logger is None:
logger = get_logger(__name__)
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Model: {model.__class__.__name__}")
logger.info(f"Total parameters: {total_params:,}")
logger.info(f"Trainable parameters: {trainable_params:,}")
logger.info(f"Non-trainable parameters: {total_params - trainable_params:,}")
# Log parameter shapes
logger.debug("Parameter shapes:")
for name, param in model.named_parameters():
logger.debug(f" {name}: {list(param.shape)}")
def log_training_summary(
config: Dict[str, Any],
metrics: Dict[str, float],
elapsed_time: float,
output_dir: str
):
"""Log comprehensive training summary"""
logger = get_logger(__name__)
summary = {
'config': config,
'final_metrics': metrics,
'training_time': elapsed_time,
'timestamp': datetime.now().isoformat()
}
# Save to file
summary_file = os.path.join(output_dir, 'training_summary.json')
with open(summary_file, 'w') as f:
json.dump(summary, f, indent=2)
# Log to console
logger.info("=" * 50)
logger.info("TRAINING SUMMARY")
logger.info("=" * 50)
logger.info(f"Training time: {elapsed_time:.2f} seconds")
logger.info("Final metrics:")
for metric, value in metrics.items():
logger.info(f" {metric}: {value:.4f}")
logger.info(f"Summary saved to: {summary_file}")
logger.info("=" * 50)