init commit
Commit 36003b83e2

.gitignore (vendored, new file, 166 lines added)
@@ -0,0 +1,166 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
.idea/

# Cache
cache/
.cursor/

# Models
models/bge-m3/
models/bge-reranker-base/
__init__.py (new file, 105 lines added)
@@ -0,0 +1,105 @@
"""
BGE Fine-tuning Package
A comprehensive implementation for fine-tuning BAAI BGE models
"""

__version__ = "1.0.0"
__author__ = "ldy"

import logging
import os
import sys

import torch

logger = logging.getLogger(__name__)


# Mandatory dependency check
def _check_dependencies():
    """Ensure core dependencies are available."""
    try:
        _ = torch.__version__
    except Exception as exc: # pragma: no cover - should not happen in tests
        raise ImportError("PyTorch is required for bge_finetune") from exc
    try:
        import transformers # noqa: F401
    except Exception as exc: # pragma: no cover - should not happen in tests
        raise ImportError("transformers is required for bge_finetune") from exc


# Import guard for optional dependencies
def _check_optional_dependencies():
    """Check optional dependencies and log a warning when a likely-needed one is missing."""
    # Check for torch_npu if running on Ascend
    try:
        npu_available = hasattr(torch, 'npu') and torch.npu.is_available()
    except Exception:
        npu_available = False
    if npu_available or os.environ.get('DEVICE', '').lower() == 'npu':
        try:
            import torch_npu # noqa: F401
        except ImportError:
            logger.warning("Optional dependency 'torch_npu' not found. Ascend NPU support will be unavailable.")

    # Check for wandb if Weights & Biases logging is enabled
    try:
        from utils.config_loader import get_config
        config = get_config()
        use_wandb = False
        try:
            use_wandb = config.get('logging.use_wandb', False) # type: ignore[arg-type]
        except Exception:
            pass
        if use_wandb:
            try:
                import wandb # noqa: F401
            except ImportError:
                logger.warning("Optional dependency 'wandb' not found. Weights & Biases logging will be unavailable.")
    except Exception:
        pass

    # Check for tensorboard if tensorboard logging is likely to be used
    tensorboard_dir = None
    try:
        tensorboard_dir = config.get('logging.tensorboard_dir', None) # type: ignore[arg-type]
    except Exception:
        pass
    if tensorboard_dir:
        try:
            import tensorboard # noqa: F401
        except ImportError:
            logger.warning("Optional dependency 'tensorboard' not found. TensorBoard logging will be unavailable.")

    # Check for pytest if running tests
    if any('pytest' in arg for arg in sys.argv):
        try:
            import pytest # noqa: F401
        except ImportError:
            logger.warning("Optional dependency 'pytest' not found. Some tests may not run.")


# Check dependencies on import
_check_dependencies()
_check_optional_dependencies()

# Make key components available at package level
try:
    from .config import BaseConfig, BGEM3Config, BGERerankerConfig
    from .models import BGEM3Model, BGERerankerModel
    from .utils import setup_logging, get_logger

    __all__ = [
        'BaseConfig',
        'BGEM3Config',
        'BGERerankerConfig',
        'BGEM3Model',
        'BGERerankerModel',
        'setup_logging',
        'get_logger'
    ]
except ImportError as e:
    # Allow partial imports during development
    import warnings

    warnings.warn(f"Some components not available: {e}")
    __all__ = []
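
# Minimal usage sketch (illustrative, not part of the package code; assumes the optional
# imports above resolved so the names below are actually exported):
#   from bge_finetune import BGEM3Config, setup_logging
#   setup_logging()
#   cfg = BGEM3Config()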
config.toml (new file, 292 lines added)
@@ -0,0 +1,292 @@
[api]
# API settings for data augmentation (used in data/augmentation.py)
# NOTE: You can override api_key by setting the environment variable BGE_API_KEY.
base_url = "https://api.deepseek.com/v1"
api_key = "sk-1bf2a963751_deletethis_b4e5e96d319810f8926e5" # Replace with your actual API key or set BGE_API_KEY env var
model = "deepseek-chat"
max_retries = 3 # Number of times to retry API calls on failure
timeout = 30 # Timeout for API requests (seconds)
temperature = 0.7 # Sampling temperature for augmentation

[training]
# Default training parameters (overridden by [hardware] or model-specific sections if present)
# If a parameter is present in both [training] and [hardware], [hardware] takes precedence.
default_batch_size = 4 # Default batch size per device for training
default_learning_rate = 1e-5 # Default learning rate
default_num_epochs = 3 # Default number of epochs
default_warmup_ratio = 0.1 # Warmup steps as a ratio of total steps
default_weight_decay = 0.01 # Weight decay for optimizer
gradient_accumulation_steps = 2 # Number of steps to accumulate gradients - reduced for small datasets
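# Illustrative precedence check (values assumed, not prescriptive): with default_batch_size = 4
# here and per_device_train_batch_size = 8 under [hardware], training runs with a per-device
# batch size of 8; a model-specific [m3]/[reranker] value or a command-line flag would
# override both.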

[model_paths]
# Model source: 'huggingface' or 'modelscope'. Determines which hub to use for downloading/loading models.
source = "huggingface" # Options: 'huggingface', 'modelscope'
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
# Cache directory for downloaded models (tokenizers, config files, etc.)
cache_dir = "./cache/models"
# To use ModelScope, set source = "modelscope" and provide the correct model id/path.

[data]
# Enhanced data processing settings with unified schema support
max_query_length = 64 # Max length for query tokens
max_passage_length = 512 # Max length for passage tokens
max_seq_length = 512 # Max total input sequence length
# For testing only; in production, set to 0.3
validation_split = 0.0 # Fraction of data for validation
random_seed = 42 # Random seed for reproducibility

# Cache directory for processed/optimized datasets (.pkl files)
cache_dir = "./cache/data"
# Default output directory for optimization script
optimization_output_dir = "./cache/data"

# New: Enhanced dataset format settings
auto_detect_format = true # Enable automatic format detection
legacy_support = true # Support legacy nested reranker format
migrate_legacy_on_load = false # Automatically migrate legacy data during loading
embedding_schema_version = "v3" # Use enhanced embedding schema
reranker_schema_version = "v3" # Use enhanced reranker schema

# New: Enhanced sampling settings
use_score_weighted_sampling = true # Enable score-weighted positive sampling
multiple_positives_per_batch = true # Use multiple positives for richer contrastive signals
hard_negative_ratio = 0.7 # Ratio of hard to random negatives
difficulty_threshold = 0.1 # Minimum difficulty threshold for hard negatives

[augmentation]
# Data augmentation settings (used in data/augmentation.py)
enable_query_augmentation = false # Enable query augmentation
enable_passage_augmentation = false # Enable passage augmentation
augmentation_methods = ["paraphrase", "back_translation"] # List of augmentation methods
api_first = true # Prefer the LLM API for all augmentation if enabled; fall back to rule-based methods if the API is unavailable
# Augmentation is language-aware: prompts and templates are selected based on query language (Chinese/English)
num_augmentations_per_sample = 2 # Number of augmentations per sample
augmentation_temperature = 0.8 # Sampling temperature for augmentation

[hard_negative_mining]
# ANCE hard negative mining settings (used in data/hard_negative_mining.py)
enable_mining = true # Enable hard negative mining
num_hard_negatives = 15 # Number of hard negatives to mine
negative_range_min = 10 # Min index for negative mining
negative_range_max = 100 # Max index for negative mining
margin = 0.1 # Margin for ANCE mining
index_refresh_interval = 1000 # Steps between index refreshes
index_type = "IVF" # FAISS index type
nlist = 100 # Number of clusters for IVF

[evaluation]
# Evaluation settings (used in evaluation/evaluator.py)
eval_batch_size = 32 # Batch size for evaluation
k_values = [1, 3, 5, 10, 20, 50, 100] # k values for metrics
metrics = ["recall", "precision", "map", "mrr", "ndcg"] # Metrics to compute
save_predictions = true # Save predictions to file
predictions_dir = "./output/predictions" # Directory for predictions
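# For reference: recall@k (for each k in k_values) is the fraction of queries whose relevant
# passage appears in the top-k results, and mrr is the mean reciprocal rank of the first
# relevant passage.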

[logging]
# Logging configuration (used in utils/logging.py)
log_level = "INFO" # Logging level (DEBUG, INFO, WARNING, ERROR)
log_to_file = true # Log to file in addition to stdout
log_dir = "./logs" # Directory for log files
tensorboard_dir = "./logs/tensorboard" # Directory for TensorBoard logs
wandb_project = "bge-finetuning" # Weights & Biases project name
wandb_entity = "" # Your wandb entity
use_wandb = false # Enable Weights & Biases logging

[distributed]
# Distributed training settings (used in utils/distributed.py)
# backend: 'nccl' for CUDA GPUs, 'hccl' for Huawei Ascend NPUs, 'gloo' for CPU
backend = "nccl" # Backend for distributed training (nccl for CUDA, hccl for Ascend, gloo for CPU)
init_method = "env://" # Initialization method for distributed training
find_unused_parameters = true # Find unused parameters in DDP

[ascend]
# Ascend NPU-specific settings (used for Huawei hardware)
# Set device_type to 'npu' to enable Ascend support
# backend: 'hccl' is required for distributed training on Ascend
# mixed_precision: true enables float16/bfloat16 training if supported
# Example usage (uncomment and set as needed):
# device_type = "npu"
# backend = "hccl"
# mixed_precision = true
# visible_devices = "0,1,2,3" # Comma-separated list of NPU device IDs
# dataloader_num_workers = 4
# per_device_train_batch_size = 4
# per_device_eval_batch_size = 8

[optimization]
# Advanced optimization settings (used in models/losses.py, training/*_trainer.py)
gradient_checkpointing = true # Enable gradient checkpointing for memory savings
fp16 = true # Use mixed precision (float16) training
bf16 = false # Set to true for A100/H100 GPUs (bfloat16)
fp16_opt_level = "O1" # Apex AMP optimization level
max_grad_norm = 1.0 # Max gradient norm for clipping
adam_epsilon = 1e-8 # Epsilon for Adam optimizer
adam_beta1 = 0.9 # Beta1 for Adam optimizer
adam_beta2 = 0.999 # Beta2 for Adam optimizer

[hardware]
# Hardware-specific settings (overrides [training] for batch sizes, etc.)
cuda_visible_devices = "0,1,2,3" # Comma-separated list of visible CUDA devices
dataloader_num_workers = 4 # Number of workers for DataLoader
dataloader_pin_memory = true # Use pinned memory in DataLoader
per_device_train_batch_size = 4 # Batch size per device for training (overrides [training])
per_device_eval_batch_size = 8 # Batch size per device for evaluation

[checkpoint]
# Checkpointing settings (used in training/*_trainer.py)
save_strategy = "steps" # Save checkpoint by steps or epoch
save_steps = 1000 # Steps between checkpoints
save_total_limit = 3 # Max number of checkpoints to keep
load_best_model_at_end = true # Load best model at end of training
metric_for_best_model = "eval_recall@10" # Metric to select best model
greater_is_better = true # Whether higher metric is better
resume_from_checkpoint = "" # Path to checkpoint to resume from

[export]
# Model export settings (used in scripts/export.py)
export_format = ["pytorch", "onnx"] # Export formats
optimize_for_inference = true # Optimize model for inference
quantize = false # Enable quantization
quantization_bits = 8 # Number of bits for quantization

[m3]
# Enhanced BGE-M3 model-specific parameters with new features
# If a parameter is present in both [m3] and [training]/[hardware], [m3] takes precedence for BGE-M3.
use_dense = true # Enable dense embedding head
use_sparse = true # Enable sparse (SPLADE-style) embedding head
use_colbert = true # Enable ColBERT multi-vector head
unified_finetuning = true # Unified fine-tuning for all heads
sentence_pooling_method = "cls" # Pooling method for sentence embeddings
normalize_embeddings = true # Normalize output embeddings
colbert_dim = -1 # Output dimension for ColBERT head (-1 uses model hidden size)
sparse_top_k = 100 # Top-k for sparse representations

# Training
train_group_size = 8 # Number of passages per query in training
temperature = 0.1 # InfoNCE loss temperature - increased from 0.02 for better stability
negatives_cross_device = true # Share negatives across devices/GPUs

# Enhanced hard negative mining
use_hard_negatives = true # Enable hard negative mining
num_hard_negatives = 15 # Number of hard negatives to mine
hard_negative_mining_steps = 1000 # Steps between hard negative mining refreshes
ance_negative_range = [10, 100] # Range for ANCE negative mining
ance_margin = 0.1 # Margin for ANCE mining
hard_negative_ratio = 0.7 # Ratio of hard to random negatives
difficulty_threshold = 0.1 # Minimum difficulty for hard negatives
max_triplets_per_query = 8 # Maximum triplets per query for training

# Enhanced self-distillation
use_self_distill = false # Enable self-knowledge distillation - changed default to false
self_distill_temperature = 4.0 # Temperature for self-distillation
self_distill_alpha = 0.5 # Alpha (weight) for self-distillation loss

# Enhanced query instruction
use_query_instruction = false # Add instruction to queries
query_instruction_for_retrieval = "Represent this sentence for searching relevant passages:"
query_instruction_for_reranking = ""

# Enhanced data augmentation
augment_queries = false # Enable query augmentation
augmentation_methods = ["paraphrase", "back_translation"] # List of augmentation methods
create_hard_negatives = false # Create hard negatives via augmentation

# Enhanced sampling features
use_score_weighted_sampling = true # Enable score-weighted positive sampling
multiple_positives = true # Use multiple positives per batch for richer contrastive signals
enhanced_contrastive_learning = true # Enable enhanced contrastive learning features

# Loss weights
# For multi-task learning, these weights control the contribution of each head/loss
dense_weight = 1.0
sparse_weight = 0.3
colbert_weight = 0.5
distillation_weight = 1.0
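# Schematically (an illustration of how these weights combine, as a weighted sum):
#   total_loss ~= 1.0*dense_loss + 0.3*sparse_loss + 0.5*colbert_loss (+ 1.0*distillation_loss when self-distillation is enabled)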

# Evaluation
metric_for_best_model = "eval_recall@10"
greater_is_better = true

# Memory optimization
use_gradient_checkpointing = true # Enable gradient checkpointing
max_passages_per_query = 32 # Max passages per query in memory

# Multi-granularity
max_query_length = 64 # Max query length
max_passage_length = 512 # Max passage length
max_length = 8192 # Max total input length (multi-granularity)

# New: Enhanced dataset integration
use_new_dataset_api = true # Use integrated dataset API with format detection
auto_migrate_legacy = false # Automatically migrate legacy formats during training

[reranker]
# Enhanced BGE-Reranker model-specific parameters with flat pairs support
# If a parameter is present in both [reranker] and [training]/[hardware], [reranker] takes precedence for BGE-Reranker.
add_pooling_layer = false # Not used for cross-encoder
use_fp16 = true # Use mixed precision (float16) training
train_group_size = 16 # Number of passages per query for listwise training
loss_type = "listwise_ce" # Loss function type ('listwise_ce', 'pairwise_margin', 'combined', 'cross_entropy', 'binary_cross_entropy')
margin = 1.0 # Margin for pairwise loss
temperature = 1.0 # Temperature for listwise CE loss

# Enhanced knowledge distillation
use_knowledge_distillation = false # Enable knowledge distillation from teacher model
teacher_model_name_or_path = "" # Path to teacher model for distillation
distillation_temperature = 3.0 # Temperature for distillation
distillation_alpha = 0.7 # Alpha (weight) for distillation loss

# Enhanced hard negatives
use_hard_negatives = true # Enable hard negative mining
hard_negative_ratio = 0.8 # Ratio of hard to random negatives

# Enhanced denoising
use_denoising = false # Enable denoising strategy
noise_probability = 0.1 # Probability of noise for denoising

# Performance optimizations
max_pairs_per_device = 128 # Max query-passage pairs per device
accumulate_gradients_from_all_devices = true # Accumulate gradients from all devices
max_seq_length = 512 # Max combined query+passage length
query_max_length = 64 # Max query length (if separate encoding)
passage_max_length = 512 # Max passage length (if separate encoding)

# Enhanced training features
shuffle_passages = true # Shuffle passage order during training
position_bias_cutoff = 10 # Top-k passages to consider for position bias
use_gradient_caching = true # Enable gradient caching for efficiency
gradient_cache_batch_size = 32 # Batch size for gradient caching
eval_group_size = 100 # Number of candidates during evaluation
metric_for_best_model = "eval_mrr@10" # Metric to select best model
eval_normalize_scores = true # Normalize scores during evaluation

# Loss weights for combined training
listwise_weight = 0.7 # Loss weight for listwise loss
pairwise_weight = 0.3 # Loss weight for pairwise loss
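# For loss_type = "combined", these weights blend the two objectives, schematically:
#   total_loss ~= 0.7*listwise_loss + 0.3*pairwise_loss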

# Data quality controls
filter_empty_passages = true # Filter out empty passages
min_passage_length = 10 # Minimum passage length to keep
max_passage_length_ratio = 50.0 # Max ratio between query and passage length

# Advanced features
gradient_checkpointing = true # Enable gradient checkpointing
recompute_scores = false # Recompute scores instead of storing all
use_dynamic_teacher = false # Enable dynamic teacher for RocketQAv2-style training
teacher_update_steps = 1000 # Steps between teacher updates
teacher_momentum = 0.999 # Momentum for teacher updates

# New: Enhanced format support
support_flat_pairs = true # Support flat pairs format (query, passage, label)
support_nested_format = true # Support legacy nested format for backward compatibility
auto_convert_nested_to_flat = false # Automatically convert nested to flat during training
use_new_dataset_api = true # Use integrated dataset API with format detection
config/__init__.py (new file, 13 lines added)
@@ -0,0 +1,13 @@
"""
Configuration module for BGE fine-tuning
"""
from .base_config import BaseConfig
from .m3_config import BGEM3Config
from .reranker_config import BGERerankerConfig


__all__ = [
    'BaseConfig',
    'BGEM3Config',
    'BGERerankerConfig'
]
config/base_config.py (new file, 560 lines added)
@@ -0,0 +1,560 @@
|
||||
import os
|
||||
import torch
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Any
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseConfig:
|
||||
"""
|
||||
Base configuration for all trainers and models.
|
||||
|
||||
Configuration Precedence Order:
|
||||
1. Command-line arguments (highest priority)
|
||||
2. Model-specific config sections ([m3], [reranker])
|
||||
3. [hardware] section (for hardware-specific overrides)
|
||||
4. [training] section (general training defaults)
|
||||
5. Dataclass field defaults (lowest priority)
|
||||
|
||||
Example:
|
||||
If batch_size is defined in multiple places:
|
||||
- CLI arg --batch_size 32 (wins)
|
||||
- [m3] batch_size = 16
|
||||
- [hardware] per_device_train_batch_size = 8
|
||||
- [training] default_batch_size = 4
|
||||
- Dataclass default = 2
|
||||
|
||||
Note: Model-specific configs ([m3], [reranker]) take precedence over
|
||||
[hardware] for their respective models to allow fine-grained control.
|
||||
"""
|
||||
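    # Worked example of the precedence above (all values illustrative, not defaults):
    # given --batch_size 32 on the command line, [m3] batch_size = 16,
    # [hardware] per_device_train_batch_size = 8 and [training] default_batch_size = 4,
    # the effective per-device train batch size ends up as 32.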
temperature: float = 1.0 # Softmax temperature for loss functions
|
||||
use_in_batch_negatives: bool = False # Use in-batch negatives for contrastive loss
|
||||
negatives_cross_device: bool = False # Use negatives across devices (DDP)
|
||||
label_smoothing: float = 0.0 # Label smoothing for classification losses
|
||||
alpha: float = 1.0 # Weight for auxiliary losses
|
||||
gradient_accumulation_steps: int = 1 # Steps to accumulate gradients
|
||||
num_train_epochs: int = 3 # Number of training epochs
|
||||
eval_steps: int = 1 # Evaluate every N epochs
|
||||
save_steps: int = 1 # Save checkpoint every N epochs
|
||||
warmup_steps: int = 0 # Number of warmup steps for scheduler
|
||||
warmup_ratio: float = 0.0 # Warmup ratio for scheduler
|
||||
metric_for_best_model: str = "accuracy" # Metric to select best model
|
||||
checkpoint_dir: str = "./checkpoints" # Directory for saving checkpoints
|
||||
resume_from_checkpoint: str = "" # Path to resume checkpoint
|
||||
max_grad_norm: float = 1.0 # Max gradient norm for clipping
|
||||
use_amp: bool = False # Use automatic mixed precision
|
||||
|
||||
# Model configuration
|
||||
model_name_or_path: str = "BAAI/bge-m3"
|
||||
cache_dir: Optional[str] = None
|
||||
|
||||
# Data configuration
|
||||
train_data_path: str = "./data/train.jsonl"
|
||||
eval_data_path: Optional[str] = "./data/eval.jsonl"
|
||||
max_seq_length: int = 512
|
||||
query_max_length: int = 64
|
||||
passage_max_length: int = 512
|
||||
|
||||
# Training configuration
|
||||
output_dir: str = "./output"
|
||||
learning_rate: float = 1e-5
|
||||
weight_decay: float = 0.01
|
||||
adam_epsilon: float = 1e-8
|
||||
adam_beta1: float = 0.9
|
||||
adam_beta2: float = 0.999
|
||||
|
||||
# Advanced training configuration
|
||||
fp16: bool = True
|
||||
bf16: bool = False
|
||||
gradient_checkpointing: bool = True
|
||||
deepspeed: Optional[str] = None
|
||||
local_rank: int = -1
|
||||
dataloader_num_workers: int = 4
|
||||
dataloader_pin_memory: bool = True
|
||||
|
||||
# Batch sizes
|
||||
per_device_train_batch_size: int = 4
|
||||
per_device_eval_batch_size: int = 8
|
||||
|
||||
# Logging and saving
|
||||
logging_dir: str = "./logs"
|
||||
# For testing only; in production, set to 100
logging_steps: int = 100
|
||||
save_total_limit: int = 3
|
||||
load_best_model_at_end: bool = True
|
||||
greater_is_better: bool = False
|
||||
evaluation_strategy: str = "steps"
|
||||
|
||||
# Checkpoint and resume
|
||||
overwrite_output_dir: bool = False
|
||||
|
||||
# Seed
|
||||
seed: int = 42
|
||||
|
||||
# Platform compatibility
|
||||
device: str = "auto"
|
||||
n_gpu: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post-initialization with robust config loading and device setup
|
||||
Notes:
|
||||
- [hardware] section overrides [training] for overlapping parameters.
|
||||
- Model-specific configs ([m3], [reranker]) override [training]/[hardware] for their models.
|
||||
"""
|
||||
# Load configuration from centralized config system
|
||||
config = self._get_config_safely()
|
||||
|
||||
# Apply configuration values with fallbacks
|
||||
self._apply_config_values(config)
|
||||
|
||||
# Setup device detection
|
||||
self._setup_device()
|
||||
|
||||
    def _get_config_safely(self):
        """Safely get configuration with multiple fallback strategies"""
        try:
            # Try the top-level import used when the repo root is on sys.path
            from utils.config_loader import get_config
            return get_config()
        except ImportError:
            try:
                # Try the package-qualified import (when installed or run as a package)
                from bge_finetune.utils.config_loader import get_config # type: ignore[import]
                return get_config()
            except ImportError:
                logger.error("Could not import config_loader, using fallback configuration. Please check your environment and config loader path.")
                return self._create_fallback_config()
|
||||
|
||||
def _create_fallback_config(self):
|
||||
"""Create a minimal fallback configuration object"""
|
||||
|
||||
class FallbackConfig:
|
||||
def __init__(self):
|
||||
self._config = {
|
||||
'training': {
|
||||
'default_batch_size': 4,
|
||||
'default_learning_rate': 1e-5,
|
||||
'default_num_epochs': 3,
|
||||
'default_warmup_ratio': 0.1,
|
||||
'default_weight_decay': 0.01,
|
||||
'gradient_accumulation_steps': 4
|
||||
},
|
||||
'model_paths': {
|
||||
'cache_dir': './cache/data'
|
||||
},
|
||||
'data': {
|
||||
'max_seq_length': 512,
|
||||
'max_query_length': 64,
|
||||
'max_passage_length': 512,
|
||||
'random_seed': 42
|
||||
},
|
||||
'hardware': {
|
||||
'dataloader_num_workers': 4,
|
||||
'dataloader_pin_memory': True,
|
||||
'per_device_train_batch_size': 4,
|
||||
'per_device_eval_batch_size': 8
|
||||
},
|
||||
'optimization': {
|
||||
'fp16': True,
|
||||
'bf16': False,
|
||||
'gradient_checkpointing': True,
|
||||
'max_grad_norm': 1.0,
|
||||
'adam_epsilon': 1e-8,
|
||||
'adam_beta1': 0.9,
|
||||
'adam_beta2': 0.999
|
||||
},
|
||||
'checkpoint': {
|
||||
'save_steps': 1000,
|
||||
'save_total_limit': 3,
|
||||
'load_best_model_at_end': True,
|
||||
'metric_for_best_model': 'eval_loss',
|
||||
'greater_is_better': False
|
||||
},
|
||||
'logging': {
|
||||
'log_dir': './logs'
|
||||
}
|
||||
}
|
||||
|
||||
def get(self, key, default=None):
|
||||
keys = key.split('.')
|
||||
value = self._config
|
||||
try:
|
||||
for k in keys:
|
||||
value = value[k]
|
||||
return value
|
||||
except (KeyError, TypeError):
|
||||
return default
|
||||
|
||||
def get_section(self, section):
|
||||
return self._config.get(section, {})
|
||||
|
||||
return FallbackConfig()
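    # Example of the dotted-key lookup that both the real loader and this fallback support
    # (values are the fallback defaults defined above):
    #   config.get('training.default_batch_size', 4)  -> 4
    #   config.get('logging.use_wandb', False)         -> False (key absent, default returned)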
|
||||
|
||||
def _apply_config_values(self, config):
|
||||
"""
|
||||
Apply configuration values with proper precedence order.
|
||||
|
||||
Precedence (lowest to highest):
|
||||
1. Dataclass field defaults (already set)
|
||||
2. [training] section values
|
||||
3. [hardware] section values (override training)
|
||||
4. Model-specific sections (handled in subclasses)
|
||||
5. Command-line arguments (handled by training scripts)
|
||||
|
||||
This method handles steps 2-3. Model-specific configs are applied
|
||||
in subclass __post_init__ methods AFTER this base method runs.
|
||||
"""
|
||||
try:
|
||||
# === STEP 2: Apply [training] section (general defaults) ===
|
||||
training_config = config.get_section('training')
|
||||
if training_config:
|
||||
# Only update if value exists in config
|
||||
if 'default_num_epochs' in training_config:
|
||||
old_value = self.num_train_epochs
|
||||
self.num_train_epochs = training_config['default_num_epochs']
|
||||
if old_value != self.num_train_epochs:
|
||||
logger.debug(f"[training] num_train_epochs: {old_value} -> {self.num_train_epochs}")
|
||||
|
||||
if 'default_batch_size' in training_config:
|
||||
self.per_device_train_batch_size = training_config['default_batch_size']
|
||||
logger.debug(f"[training] batch_size set to {self.per_device_train_batch_size}")
|
||||
|
||||
if 'default_learning_rate' in training_config:
|
||||
self.learning_rate = training_config['default_learning_rate']
|
||||
|
||||
if 'default_warmup_ratio' in training_config:
|
||||
self.warmup_ratio = training_config['default_warmup_ratio']
|
||||
|
||||
if 'gradient_accumulation_steps' in training_config:
|
||||
self.gradient_accumulation_steps = training_config['gradient_accumulation_steps']
|
||||
|
||||
if 'default_weight_decay' in training_config:
|
||||
self.weight_decay = training_config['default_weight_decay']
|
||||
|
||||
# === STEP 3: Apply [hardware] section (overrides training) ===
|
||||
hardware_config = config.get_section('hardware')
|
||||
if hardware_config:
|
||||
# Hardware settings take precedence over training defaults
|
||||
if 'dataloader_num_workers' in hardware_config:
|
||||
self.dataloader_num_workers = hardware_config['dataloader_num_workers']
|
||||
|
||||
if 'dataloader_pin_memory' in hardware_config:
|
||||
self.dataloader_pin_memory = hardware_config['dataloader_pin_memory']
|
||||
|
||||
# Batch sizes from hardware override training defaults
|
||||
if 'per_device_train_batch_size' in hardware_config:
|
||||
old_value = self.per_device_train_batch_size
|
||||
self.per_device_train_batch_size = hardware_config['per_device_train_batch_size']
|
||||
if old_value != self.per_device_train_batch_size:
|
||||
logger.info(f"[hardware] overrides batch_size: {old_value} -> {self.per_device_train_batch_size}")
|
||||
|
||||
if 'per_device_eval_batch_size' in hardware_config:
|
||||
self.per_device_eval_batch_size = hardware_config['per_device_eval_batch_size']
|
||||
|
||||
# === Apply other configuration sections ===
|
||||
|
||||
# Model paths
|
||||
if self.cache_dir is None:
|
||||
cache_dir = config.get('model_paths.cache_dir')
|
||||
if cache_dir is not None:
|
||||
self.cache_dir = cache_dir
|
||||
logger.debug(f"cache_dir set from config: {self.cache_dir}")
|
||||
|
||||
# Data configuration
|
||||
data_config = config.get_section('data')
|
||||
if data_config:
|
||||
if 'max_seq_length' in data_config:
|
||||
self.max_seq_length = data_config['max_seq_length']
|
||||
if 'max_query_length' in data_config:
|
||||
self.query_max_length = data_config['max_query_length']
|
||||
if 'max_passage_length' in data_config:
|
||||
self.passage_max_length = data_config['max_passage_length']
|
||||
if 'random_seed' in data_config:
|
||||
self.seed = data_config['random_seed']
|
||||
|
||||
# Optimization configuration
|
||||
opt_config = config.get_section('optimization')
|
||||
if opt_config:
|
||||
if 'fp16' in opt_config:
|
||||
self.fp16 = opt_config['fp16']
|
||||
if 'bf16' in opt_config:
|
||||
self.bf16 = opt_config['bf16']
|
||||
if 'gradient_checkpointing' in opt_config:
|
||||
self.gradient_checkpointing = opt_config['gradient_checkpointing']
|
||||
if 'max_grad_norm' in opt_config:
|
||||
self.max_grad_norm = opt_config['max_grad_norm']
|
||||
if 'adam_epsilon' in opt_config:
|
||||
self.adam_epsilon = opt_config['adam_epsilon']
|
||||
if 'adam_beta1' in opt_config:
|
||||
self.adam_beta1 = opt_config['adam_beta1']
|
||||
if 'adam_beta2' in opt_config:
|
||||
self.adam_beta2 = opt_config['adam_beta2']
|
||||
|
||||
# Checkpoint configuration
|
||||
checkpoint_config = config.get_section('checkpoint')
|
||||
if checkpoint_config:
|
||||
if 'save_steps' in checkpoint_config:
|
||||
self.save_steps = checkpoint_config['save_steps']
|
||||
if 'save_total_limit' in checkpoint_config:
|
||||
self.save_total_limit = checkpoint_config['save_total_limit']
|
||||
if 'load_best_model_at_end' in checkpoint_config:
|
||||
self.load_best_model_at_end = checkpoint_config['load_best_model_at_end']
|
||||
if 'metric_for_best_model' in checkpoint_config:
|
||||
self.metric_for_best_model = checkpoint_config['metric_for_best_model']
|
||||
if 'greater_is_better' in checkpoint_config:
|
||||
self.greater_is_better = checkpoint_config['greater_is_better']
|
||||
|
||||
# Logging configuration
|
||||
logging_config = config.get_section('logging')
|
||||
if logging_config:
|
||||
if 'log_dir' in logging_config:
|
||||
self.logging_dir = logging_config['log_dir']
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error applying configuration values: {e}")
|
||||
|
||||
def _setup_device(self):
|
||||
"""Enhanced device detection with Ascend NPU support"""
|
||||
if self.device == "auto":
|
||||
device_info = self._detect_available_devices()
|
||||
self.device = device_info['device']
|
||||
self.n_gpu = device_info['count']
|
||||
logger.info(f"Auto-detected device: {self.device} (count: {self.n_gpu})")
|
||||
else:
|
||||
# Manual device specification
|
||||
if self.device == "cuda" and torch.cuda.is_available():
|
||||
self.n_gpu = torch.cuda.device_count()
|
||||
elif self.device == "npu":
|
||||
try:
|
||||
import torch_npu # type: ignore[import]
|
||||
npu = getattr(torch, 'npu', None)
|
||||
if npu is not None and npu.is_available():
|
||||
self.n_gpu = npu.device_count()
|
||||
else:
|
||||
logger.warning("NPU requested but not available, falling back to CPU")
|
||||
self.device = "cpu"
|
||||
self.n_gpu = 0
|
||||
except ImportError:
|
||||
logger.warning("NPU requested but torch_npu not installed, falling back to CPU")
|
||||
self.device = "cpu"
|
||||
self.n_gpu = 0
|
||||
else:
|
||||
self.n_gpu = 0
|
||||
|
||||
def _detect_available_devices(self) -> Dict[str, Any]:
|
||||
"""Comprehensive device detection with priority ordering"""
|
||||
device_priority = ['npu', 'cuda', 'mps', 'cpu']
|
||||
available_devices = []
|
||||
|
||||
# Check Ascend NPU
|
||||
try:
|
||||
import torch_npu # type: ignore[import]
|
||||
npu = getattr(torch, 'npu', None)
|
||||
if npu is not None and npu.is_available():
|
||||
count = npu.device_count()
|
||||
available_devices.append(('npu', count))
|
||||
logger.debug(f"Ascend NPU available: {count} devices")
|
||||
except ImportError:
|
||||
logger.debug("Ascend NPU library not available")
|
||||
|
||||
# Check CUDA
|
||||
if torch.cuda.is_available():
|
||||
count = torch.cuda.device_count()
|
||||
available_devices.append(('cuda', count))
|
||||
logger.debug(f"CUDA available: {count} devices")
|
||||
|
||||
# Check MPS (Apple Silicon)
|
||||
if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and hasattr(torch.backends.mps, 'is_available') and torch.backends.mps.is_available(): # type: ignore[attr-defined]
|
||||
available_devices.append(('mps', 1))
|
||||
logger.debug("Apple MPS available")
|
||||
|
||||
# CPU always available
|
||||
available_devices.append(('cpu', 1))
|
||||
|
||||
# Select best available device based on priority
|
||||
for device_type in device_priority:
|
||||
for available_type, count in available_devices:
|
||||
if available_type == device_type:
|
||||
return {'device': device_type, 'count': count}
|
||||
|
||||
return {'device': 'cpu', 'count': 1} # Fallback
|
||||
|
||||
def validate(self):
|
||||
"""Comprehensive configuration validation"""
|
||||
errors = []
|
||||
# Check required paths
|
||||
if self.train_data_path and not os.path.exists(self.train_data_path):
|
||||
errors.append(f"Training data not found: {self.train_data_path}")
|
||||
if self.eval_data_path and not os.path.exists(self.eval_data_path):
|
||||
errors.append(f"Evaluation data not found: {self.eval_data_path}")
|
||||
# Validate numeric parameters
|
||||
if self.learning_rate is None or self.learning_rate <= 0:
|
||||
errors.append("Learning rate must be positive and not None")
|
||||
if self.num_train_epochs is None or self.num_train_epochs <= 0:
|
||||
errors.append("Number of epochs must be positive and not None")
|
||||
if self.per_device_train_batch_size is None or self.per_device_train_batch_size <= 0:
|
||||
errors.append("Training batch size must be positive and not None")
|
||||
if self.gradient_accumulation_steps is None or self.gradient_accumulation_steps <= 0:
|
||||
errors.append("Gradient accumulation steps must be positive and not None")
|
||||
if self.warmup_ratio is None or self.warmup_ratio < 0 or self.warmup_ratio > 1:
|
||||
errors.append("Warmup ratio must be between 0 and 1 and not None")
|
||||
# Validate device
|
||||
if self.device not in ['auto', 'cpu', 'cuda', 'npu', 'mps']:
|
||||
errors.append(f"Invalid device: {self.device}")
|
||||
# Check sequence lengths
|
||||
if self.max_seq_length is None or self.max_seq_length <= 0:
|
||||
errors.append("Maximum sequence length must be positive and not None")
|
||||
if self.query_max_length is None or self.query_max_length <= 0:
|
||||
errors.append("Maximum query length must be positive and not None")
|
||||
if self.passage_max_length is None or self.passage_max_length <= 0:
|
||||
errors.append("Maximum passage length must be positive and not None")
|
||||
if errors:
|
||||
raise ValueError(f"Configuration validation failed:\n" + "\n".join(f" - {error}" for error in errors))
|
||||
logger.info("Configuration validation passed")
|
||||
return True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert config to dictionary with type handling"""
|
||||
config_dict = {}
|
||||
for key, value in self.__dict__.items():
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
config_dict[key] = value
|
||||
elif isinstance(value, (list, tuple)):
|
||||
config_dict[key] = list(value)
|
||||
elif isinstance(value, dict):
|
||||
config_dict[key] = dict(value)
|
||||
else:
|
||||
config_dict[key] = str(value)
|
||||
return config_dict
|
||||
|
||||
def save(self, save_path: str):
|
||||
"""Save configuration to file with error handling"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
|
||||
import json
|
||||
with open(save_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"Configuration saved to {save_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save configuration: {e}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def load(cls, load_path: str):
|
||||
"""Load configuration from file with validation"""
|
||||
try:
|
||||
import json
|
||||
with open(load_path, 'r', encoding='utf-8') as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Filter only valid field names
|
||||
import inspect
|
||||
sig = inspect.signature(cls)
|
||||
valid_keys = set(sig.parameters.keys())
|
||||
filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys}
|
||||
|
||||
config = cls(**filtered_dict)
|
||||
config.validate() # Validate after loading
|
||||
|
||||
logger.info(f"Configuration loaded from {load_path}")
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load configuration from {load_path}: {e}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any]):
|
||||
"""Create config from dictionary with validation"""
|
||||
try:
|
||||
# Filter only valid field names
|
||||
import inspect
|
||||
sig = inspect.signature(cls)
|
||||
valid_keys = set(sig.parameters.keys())
|
||||
filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys}
|
||||
|
||||
config = cls(**filtered_dict)
|
||||
config.validate() # Validate after creation
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create configuration from dictionary: {e}")
|
||||
raise
|
||||
|
||||
def update_from_args(self, args):
|
||||
"""Update config from command line arguments with validation"""
|
||||
try:
|
||||
updated_fields = []
|
||||
for key, value in vars(args).items():
|
||||
if hasattr(self, key) and value is not None:
|
||||
old_value = getattr(self, key)
|
||||
setattr(self, key, value)
|
||||
updated_fields.append(f"{key}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
logger.info("Updated configuration from command line arguments:")
|
||||
for field in updated_fields:
|
||||
logger.info(f" {field}")
|
||||
|
||||
# Validate after updates
|
||||
self.validate()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update configuration from arguments: {e}")
|
||||
raise
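    # Assumed CLI flow (the argparse wiring lives in the training scripts, not in this file):
    #   args = parser.parse_args()        # e.g. --learning_rate 2e-5
    #   config.update_from_args(args)     # only non-None arguments override config values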
|
||||
|
||||
def get_effective_batch_size(self) -> int:
|
||||
"""Calculate effective batch size considering gradient accumulation"""
|
||||
return self.per_device_train_batch_size * self.gradient_accumulation_steps * max(1, self.n_gpu)
|
||||
|
||||
def get_training_steps(self, dataset_size: int) -> int:
|
||||
"""Calculate total training steps"""
|
||||
steps_per_epoch = dataset_size // self.get_effective_batch_size()
|
||||
return steps_per_epoch * self.num_train_epochs
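    # Worked example (illustrative numbers): per_device_train_batch_size=4,
    # gradient_accumulation_steps=2, n_gpu=2 gives an effective batch size of 4*2*2 = 16;
    # with dataset_size=10000 that is 10000 // 16 = 625 steps per epoch, or 1875 steps
    # over num_train_epochs=3.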
|
||||
|
||||
def print_effective_config(self):
|
||||
"""Print the effective value and its source for each parameter."""
|
||||
print("Parameter | Value | Source")
|
||||
print("----------|-------|-------")
|
||||
for field in self.__dataclass_fields__:
|
||||
value = getattr(self, field)
|
||||
# For now, just print the value; source tracking can be added if needed
|
||||
print(f"{field} | {value} | (see docstring for precedence)")
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of configuration"""
|
||||
return f"BaseConfig(device={self.device}, batch_size={self.per_device_train_batch_size}, lr={self.learning_rate})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
def get_device(self):
|
||||
"""Return the device string for training (cuda, npu, mps, cpu)"""
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
npu = getattr(torch, "npu", None)
|
||||
if npu is not None and hasattr(npu, "is_available") and npu.is_available():
|
||||
return "npu"
|
||||
backends = getattr(torch, "backends", None)
|
||||
if backends is not None and hasattr(backends, "mps") and backends.mps.is_available():
|
||||
return "mps"
|
||||
return "cpu"
|
||||
|
||||
def is_npu(self):
|
||||
"""Return True if NPU is available and selected as device."""
|
||||
npu = getattr(torch, "npu", None)
|
||||
if npu is not None and hasattr(npu, "is_available") and npu.is_available():
|
||||
device = self.get_device() if hasattr(self, 'get_device') else "cpu"
|
||||
return device == "npu"
|
||||
return False
|
||||
config/m3_config.py (new file, 254 lines added)
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
BGE-M3 specific configuration
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict
|
||||
from .base_config import BaseConfig
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BGEM3Config(BaseConfig):
|
||||
"""
|
||||
Configuration for BGE-M3 model and trainer.
|
||||
"""
|
||||
use_in_batch_negatives: bool = False # Use in-batch negatives for contrastive loss
|
||||
teacher_temperature: float = 1.0 # Temperature for teacher distillation
|
||||
kd_loss_weight: float = 1.0 # Weight for KD loss
|
||||
kd_loss_type: str = "mse" # Type of KD loss (e.g., mse, kl)
|
||||
hard_negative_margin: float = 0.0 # Margin for hard negative mining
|
||||
"""Configuration for BGE-M3 embedding model fine-tuning
|
||||
|
||||
Attributes:
|
||||
model_name_or_path: Path or identifier for the pretrained BGE-M3 model.
|
||||
use_dense: Enable dense embedding head.
|
||||
use_sparse: Enable sparse (SPLADE-style) embedding head.
|
||||
use_colbert: Enable ColBERT multi-vector head.
|
||||
unified_finetuning: Unified fine-tuning for all functionalities.
|
||||
sentence_pooling_method: Pooling method for sentence embeddings (e.g., 'cls').
|
||||
normalize_embeddings: Whether to normalize output embeddings.
|
||||
colbert_dim: Output dimension for ColBERT head (-1 uses model hidden size).
|
||||
sparse_top_k: Top-k for sparse representations.
|
||||
temperature: InfoNCE loss temperature.
|
||||
train_group_size: Number of passages per query in training.
|
||||
negatives_cross_device: Share negatives across devices/GPUs.
|
||||
use_hard_negatives: Enable hard negative mining.
|
||||
num_hard_negatives: Number of hard negatives to mine.
|
||||
hard_negative_mining_steps: Steps between hard negative mining refreshes.
|
||||
ance_negative_range: Range for ANCE negative mining.
|
||||
ance_margin: Margin for ANCE mining.
|
||||
use_self_distill: Enable self-knowledge distillation.
|
||||
self_distill_temperature: Temperature for self-distillation.
|
||||
self_distill_alpha: Alpha (weight) for self-distillation loss.
|
||||
use_query_instruction: Add instruction to queries.
|
||||
query_instruction_for_retrieval: Instruction for retrieval queries.
|
||||
query_instruction_for_reranking: Instruction for reranking queries.
|
||||
augment_queries: Enable query augmentation.
|
||||
augmentation_methods: List of augmentation methods to use.
|
||||
create_hard_negatives: Create hard negatives via augmentation.
|
||||
dense_weight: Loss weight for dense head.
|
||||
sparse_weight: Loss weight for sparse head.
|
||||
colbert_weight: Loss weight for ColBERT head.
|
||||
distillation_weight: Loss weight for distillation.
|
||||
eval_strategy: Evaluation strategy ('epoch', 'steps').
|
||||
metric_for_best_model: Metric to select best model.
|
||||
greater_is_better: Whether higher metric is better.
|
||||
use_gradient_checkpointing: Enable gradient checkpointing.
|
||||
max_passages_per_query: Max passages per query in memory.
|
||||
max_query_length: Max query length.
|
||||
max_passage_length: Max passage length.
|
||||
max_length: Max total input length (multi-granularity).
|
||||
"""
|
||||
|
||||
# Model specific
|
||||
model_name_or_path: str = "BAAI/bge-m3"
|
||||
|
||||
# Multi-functionality settings
|
||||
use_dense: bool = True
|
||||
use_sparse: bool = True
|
||||
use_colbert: bool = True
|
||||
unified_finetuning: bool = True
|
||||
|
||||
# Architecture settings
|
||||
sentence_pooling_method: str = "cls" # BGE-M3 uses CLS pooling
|
||||
normalize_embeddings: bool = True
|
||||
colbert_dim: int = -1 # -1 means use model hidden size
|
||||
sparse_top_k: int = 100 # Top-k for sparse representations
|
||||
|
||||
# Training specific
|
||||
temperature: float = 0.02 # InfoNCE temperature (from BGE-M3 paper)
|
||||
train_group_size: int = 8 # Number of passages per query
|
||||
negatives_cross_device: bool = True # Share negatives across GPUs
|
||||
|
||||
# Hard negative mining
|
||||
use_hard_negatives: bool = True
|
||||
num_hard_negatives: int = 15
|
||||
hard_negative_mining_steps: int = 1000
|
||||
ance_negative_range: tuple = (10, 100)
|
||||
ance_margin: float = 0.1
|
||||
|
||||
# Self-distillation
|
||||
use_self_distill: bool = False # Changed default to False - make KD opt-in
|
||||
self_distill_temperature: float = 4.0
|
||||
self_distill_alpha: float = 0.5
|
||||
|
||||
# Query instruction
|
||||
use_query_instruction: bool = False
|
||||
query_instruction_for_retrieval: str = "Represent this sentence for searching relevant passages:"
|
||||
query_instruction_for_reranking: str = ""
|
||||
|
||||
# Data augmentation
|
||||
augment_queries: bool = False
|
||||
augmentation_methods: List[str] = field(default_factory=lambda: ["paraphrase", "back_translation"])
|
||||
create_hard_negatives: bool = False
|
||||
|
||||
# Loss weights for multi-task learning
|
||||
dense_weight: float = 1.0
|
||||
sparse_weight: float = 0.3
|
||||
colbert_weight: float = 0.5
|
||||
distillation_weight: float = 1.0
|
||||
|
||||
# Evaluation
|
||||
eval_strategy: str = "epoch"
|
||||
metric_for_best_model: str = "eval_recall@10"
|
||||
greater_is_better: bool = True
|
||||
|
||||
# Memory optimization
|
||||
use_gradient_checkpointing: bool = True
|
||||
max_passages_per_query: int = 32
|
||||
|
||||
# Specific to M3's multi-granularity
|
||||
max_query_length: int = 64
|
||||
max_passage_length: int = 512
|
||||
max_length: int = 8192 # BGE-M3 supports up to 8192 tokens
|
||||
|
||||
def __post_init__(self):
|
||||
"""Load model-specific parameters from [m3] section in config.toml"""
|
||||
super().__post_init__()
|
||||
try:
|
||||
from utils.config_loader import get_config
|
||||
except ImportError:
|
||||
from bge_finetune.utils.config_loader import get_config # type: ignore[import]
|
||||
config = get_config()
|
||||
self._apply_m3_config(config)
|
||||
|
||||
def _apply_m3_config(self, config):
|
||||
"""
|
||||
Apply BGE-M3 specific configuration from [m3] section in config.toml.
|
||||
"""
|
||||
m3_config = config.get_section('m3')
|
||||
if not m3_config:
|
||||
logger.warning("No [m3] section found in config.toml")
|
||||
return
|
||||
|
||||
# Helper to convert tuple from config to list if needed
|
||||
def get_val(key, default):
|
||||
if key not in m3_config:
|
||||
return default
|
||||
v = m3_config[key]
|
||||
if isinstance(default, list) and isinstance(v, tuple):
|
||||
return list(v)
|
||||
return v
|
||||
|
||||
# Only update if key exists in config
|
||||
if 'use_dense' in m3_config:
|
||||
self.use_dense = m3_config['use_dense']
|
||||
if 'use_sparse' in m3_config:
|
||||
self.use_sparse = m3_config['use_sparse']
|
||||
if 'use_colbert' in m3_config:
|
||||
self.use_colbert = m3_config['use_colbert']
|
||||
if 'unified_finetuning' in m3_config:
|
||||
self.unified_finetuning = m3_config['unified_finetuning']
|
||||
if 'sentence_pooling_method' in m3_config:
|
||||
self.sentence_pooling_method = m3_config['sentence_pooling_method']
|
||||
if 'normalize_embeddings' in m3_config:
|
||||
self.normalize_embeddings = m3_config['normalize_embeddings']
|
||||
if 'colbert_dim' in m3_config:
|
||||
self.colbert_dim = m3_config['colbert_dim']
|
||||
if 'sparse_top_k' in m3_config:
|
||||
self.sparse_top_k = m3_config['sparse_top_k']
|
||||
if 'temperature' in m3_config:
|
||||
self.temperature = m3_config['temperature']
|
||||
if 'train_group_size' in m3_config:
|
||||
self.train_group_size = m3_config['train_group_size']
|
||||
if 'negatives_cross_device' in m3_config:
|
||||
self.negatives_cross_device = m3_config['negatives_cross_device']
|
||||
if 'use_hard_negatives' in m3_config:
|
||||
self.use_hard_negatives = m3_config['use_hard_negatives']
|
||||
if 'num_hard_negatives' in m3_config:
|
||||
self.num_hard_negatives = m3_config['num_hard_negatives']
|
||||
if 'hard_negative_mining_steps' in m3_config:
|
||||
self.hard_negative_mining_steps = m3_config['hard_negative_mining_steps']
|
||||
if 'ance_negative_range' in m3_config:
|
||||
self.ance_negative_range = tuple(get_val('ance_negative_range', self.ance_negative_range))
|
||||
if 'augmentation_methods' in m3_config:
|
||||
self.augmentation_methods = list(get_val('augmentation_methods', self.augmentation_methods))
|
||||
if 'ance_margin' in m3_config:
|
||||
self.ance_margin = m3_config['ance_margin']
|
||||
if 'use_self_distill' in m3_config:
|
||||
self.use_self_distill = m3_config['use_self_distill']
|
||||
if 'self_distill_temperature' in m3_config:
|
||||
self.self_distill_temperature = m3_config['self_distill_temperature']
|
||||
if 'self_distill_alpha' in m3_config:
|
||||
self.self_distill_alpha = m3_config['self_distill_alpha']
|
||||
if 'use_query_instruction' in m3_config:
|
||||
self.use_query_instruction = m3_config['use_query_instruction']
|
||||
if 'query_instruction_for_retrieval' in m3_config:
|
||||
self.query_instruction_for_retrieval = m3_config['query_instruction_for_retrieval']
|
||||
if 'query_instruction_for_reranking' in m3_config:
|
||||
self.query_instruction_for_reranking = m3_config['query_instruction_for_reranking']
|
||||
if 'augment_queries' in m3_config:
|
||||
self.augment_queries = m3_config['augment_queries']
|
||||
if 'create_hard_negatives' in m3_config:
|
||||
self.create_hard_negatives = m3_config['create_hard_negatives']
|
||||
if 'dense_weight' in m3_config:
|
||||
self.dense_weight = m3_config['dense_weight']
|
||||
if 'sparse_weight' in m3_config:
|
||||
self.sparse_weight = m3_config['sparse_weight']
|
||||
if 'colbert_weight' in m3_config:
|
||||
self.colbert_weight = m3_config['colbert_weight']
|
||||
if 'distillation_weight' in m3_config:
|
||||
self.distillation_weight = m3_config['distillation_weight']
|
||||
if 'metric_for_best_model' in m3_config:
|
||||
self.metric_for_best_model = m3_config['metric_for_best_model']
|
||||
if 'greater_is_better' in m3_config:
|
||||
self.greater_is_better = m3_config['greater_is_better']
|
||||
if 'use_gradient_checkpointing' in m3_config:
|
||||
self.use_gradient_checkpointing = m3_config['use_gradient_checkpointing']
|
||||
if 'max_passages_per_query' in m3_config:
|
||||
self.max_passages_per_query = m3_config['max_passages_per_query']
|
||||
if 'max_query_length' in m3_config:
|
||||
self.max_query_length = m3_config['max_query_length']
|
||||
if 'max_passage_length' in m3_config:
|
||||
self.max_passage_length = m3_config['max_passage_length']
|
||||
if 'max_length' in m3_config:
|
||||
self.max_length = m3_config['max_length']
|
||||
|
||||
# Log parameter overrides
|
||||
logger.info("Loaded BGE-M3 config from [m3] section in config.toml")
|
||||
|
||||
# Check for required parameters
|
||||
if self.temperature is None or self.temperature <= 0:
|
||||
raise ValueError("Temperature must be positive and not None")
|
||||
if self.train_group_size is None or self.train_group_size < 2:
|
||||
raise ValueError("train_group_size must be at least 2 and not None")
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration settings"""
|
||||
if not any([self.use_dense, self.use_sparse, self.use_colbert]):
|
||||
raise ValueError("At least one of use_dense, use_sparse, or use_colbert must be True")
|
||||
|
||||
if self.unified_finetuning and not all([self.use_dense, self.use_sparse, self.use_colbert]):
|
||||
logger.warning("unified_finetuning is True but not all functionalities are enabled")
|
||||
|
||||
if self.temperature <= 0:
|
||||
raise ValueError("Temperature must be positive")
|
||||
|
||||
if self.train_group_size < 2:
|
||||
raise ValueError("train_group_size must be at least 2")
|
||||
|
||||
if self.max_length > 8192:
|
||||
logger.warning("max_length > 8192 may not be supported by BGE-M3")
|
||||
|
||||
return True
|
||||
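A hedged sketch of how a matching [m3] section in config.toml might look; the key names mirror the `if key in m3_config` checks above, while the values (and the use of tomllib for parsing) are illustrative assumptions, not taken from this repo.
import tomllib  # stdlib TOML parser, Python 3.11+

_example_toml = """
[m3]
use_dense = true
use_sparse = true
use_colbert = false
temperature = 0.05
train_group_size = 8
"""

m3_config = tomllib.loads(_example_toml)["m3"]
# Keys left out of the file (e.g. 'colbert_dim') keep their dataclass defaults,
# because every override above is guarded by an `if key in m3_config` check.
assert "colbert_dim" not in m3_config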
264
config/reranker_config.py
Normal file
264
config/reranker_config.py
Normal file
@ -0,0 +1,264 @@
|
||||
"""
|
||||
BGE-Reranker specific configuration
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict
|
||||
from .base_config import BaseConfig
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BGERerankerConfig(BaseConfig):
|
||||
"""
|
||||
Configuration for BGE-Reranker model and trainer.
|
||||
"""
|
||||
num_labels: int = 1 # Number of output labels
|
||||
listwise_temperature: float = 1.0 # Temperature for listwise loss
|
||||
kd_temperature: float = 1.0 # Temperature for KD loss
|
||||
kd_loss_weight: float = 1.0 # Weight for KD loss
|
||||
denoise_confidence_threshold: float = 0.0 # Threshold for denoising
|
||||
"""Configuration for BGE-Reranker cross-encoder fine-tuning
|
||||
|
||||
Attributes:
|
||||
model_name_or_path: Path or identifier for the pretrained BGE-Reranker model.
|
||||
add_pooling_layer: Whether to add a pooling layer (not used for cross-encoder).
|
||||
use_fp16: Use mixed precision (float16) training.
|
||||
train_group_size: Number of passages per query for listwise training.
|
||||
loss_type: Loss function type ('listwise_ce', 'pairwise_margin', 'combined').
|
||||
margin: Margin for pairwise loss.
|
||||
temperature: Temperature for listwise CE loss.
|
||||
use_knowledge_distillation: Enable knowledge distillation from teacher model.
|
||||
teacher_model_name_or_path: Path to teacher model for distillation.
|
||||
distillation_temperature: Temperature for distillation.
|
||||
distillation_alpha: Alpha (weight) for distillation loss.
|
||||
use_hard_negatives: Enable hard negative mining.
|
||||
hard_negative_ratio: Ratio of hard to random negatives.
|
||||
use_denoising: Enable denoising strategy.
|
||||
noise_probability: Probability of noise for denoising.
|
||||
max_pairs_per_device: Max query-passage pairs per device.
|
||||
accumulate_gradients_from_all_devices: Accumulate gradients from all devices.
|
||||
max_seq_length: Max combined query+passage length.
|
||||
query_max_length: Max query length (if separate encoding).
|
||||
passage_max_length: Max passage length (if separate encoding).
|
||||
shuffle_passages: Shuffle passage order during training.
|
||||
position_bias_cutoff: Top-k passages to consider for position bias.
|
||||
use_gradient_caching: Enable gradient caching for efficiency.
|
||||
gradient_cache_batch_size: Batch size for gradient caching.
|
||||
eval_group_size: Number of candidates during evaluation.
|
||||
metric_for_best_model: Metric to select best model.
|
||||
eval_normalize_scores: Whether to normalize scores during evaluation.
|
||||
listwise_weight: Loss weight for listwise loss.
|
||||
pairwise_weight: Loss weight for pairwise loss.
|
||||
filter_empty_passages: Filter out empty passages.
|
||||
min_passage_length: Minimum passage length to keep.
|
||||
max_passage_length_ratio: Max ratio between query and passage length.
|
||||
gradient_checkpointing: Enable gradient checkpointing.
|
||||
recompute_scores: Recompute scores instead of storing all.
|
||||
use_dynamic_teacher: Enable dynamic teacher for RocketQAv2-style training.
|
||||
teacher_update_steps: Steps between teacher updates.
|
||||
teacher_momentum: Momentum for teacher updates.
|
||||
"""
|
||||
|
||||
# Model specific
|
||||
model_name_or_path: str = "BAAI/bge-reranker-base"
|
||||
|
||||
# Architecture settings
|
||||
add_pooling_layer: bool = False # Cross-encoder doesn't use pooling
|
||||
use_fp16: bool = True # Use mixed precision by default
|
||||
|
||||
# Training specific
|
||||
train_group_size: int = 16 # Number of passages per query for listwise training
|
||||
loss_type: str = "listwise_ce" # "listwise_ce", "pairwise_margin", "combined"
|
||||
margin: float = 1.0 # For pairwise margin loss
|
||||
temperature: float = 1.0 # For listwise CE loss
|
||||
|
||||
# Knowledge distillation
|
||||
use_knowledge_distillation: bool = False
|
||||
teacher_model_name_or_path: Optional[str] = None
|
||||
distillation_temperature: float = 3.0
|
||||
distillation_alpha: float = 0.7
|
||||
|
||||
# Hard negative strategies
|
||||
use_hard_negatives: bool = True
|
||||
hard_negative_ratio: float = 0.8 # Ratio of hard to random negatives
|
||||
use_denoising: bool = False # Denoising strategy from paper
|
||||
noise_probability: float = 0.1
|
||||
|
||||
# Listwise training specific
|
||||
max_pairs_per_device: int = 128 # Maximum query-passage pairs per device
|
||||
accumulate_gradients_from_all_devices: bool = True
|
||||
|
||||
# Cross-encoder specific settings
|
||||
max_seq_length: int = 512 # Maximum combined query+passage length
|
||||
query_max_length: int = 64 # For separate encoding (if needed)
|
||||
passage_max_length: int = 512 # For separate encoding (if needed)
|
||||
|
||||
# Position bias mitigation
|
||||
shuffle_passages: bool = True # Shuffle passage order during training
|
||||
position_bias_cutoff: int = 10 # Consider position bias for top-k passages
|
||||
|
||||
# Advanced training strategies
|
||||
use_gradient_caching: bool = True # Cache gradients for efficient training
|
||||
gradient_cache_batch_size: int = 32
|
||||
|
||||
# Evaluation specific
|
||||
eval_group_size: int = 100 # Number of candidates during evaluation
|
||||
metric_for_best_model: str = "eval_mrr@10"
|
||||
eval_normalize_scores: bool = True
|
||||
|
||||
# Loss combination weights (for combined loss)
|
||||
listwise_weight: float = 0.7
|
||||
pairwise_weight: float = 0.3
|
||||
|
||||
# Data filtering
|
||||
filter_empty_passages: bool = True
|
||||
min_passage_length: int = 10
|
||||
max_passage_length_ratio: float = 50.0 # Max ratio between query and passage length
|
||||
|
||||
# Memory optimization for large group sizes
|
||||
gradient_checkpointing: bool = True
|
||||
recompute_scores: bool = False # Recompute instead of storing all scores
|
||||
|
||||
# Dynamic teacher (for RocketQAv2-style training)
|
||||
use_dynamic_teacher: bool = False
|
||||
teacher_update_steps: int = 1000
|
||||
teacher_momentum: float = 0.999
|
||||
|
||||
def __post_init__(self):
|
||||
"""Load model-specific parameters from [reranker] section in config.toml"""
|
||||
super().__post_init__()
|
||||
try:
|
||||
from utils.config_loader import get_config
|
||||
except ImportError:
|
||||
from bge_finetune.utils.config_loader import get_config # type: ignore[import]
|
||||
config = get_config()
|
||||
self._apply_reranker_config(config)
|
||||
|
||||
def _apply_reranker_config(self, config):
|
||||
"""
|
||||
Apply BGE-Reranker specific configuration from [reranker] section in config.toml.
|
||||
"""
|
||||
reranker_config = config.get_section('reranker')
|
||||
if not reranker_config:
|
||||
logger.warning("No [reranker] section found in config.toml")
|
||||
return
|
||||
|
||||
# Helper to convert tuple from config to list if needed
|
||||
def get_val(key, default):
|
||||
if key not in reranker_config:
|
||||
return default
|
||||
v = reranker_config[key]
|
||||
if isinstance(default, tuple) and isinstance(v, list):
|
||||
return tuple(v)
|
||||
if isinstance(default, list) and isinstance(v, tuple):
|
||||
return list(v)
|
||||
return v
|
||||
|
||||
# Only update if key exists in config
|
||||
if 'add_pooling_layer' in reranker_config:
|
||||
self.add_pooling_layer = reranker_config['add_pooling_layer']
|
||||
if 'use_fp16' in reranker_config:
|
||||
self.use_fp16 = reranker_config['use_fp16']
|
||||
if 'train_group_size' in reranker_config:
|
||||
self.train_group_size = reranker_config['train_group_size']
|
||||
if 'loss_type' in reranker_config:
|
||||
self.loss_type = reranker_config['loss_type']
|
||||
if 'margin' in reranker_config:
|
||||
self.margin = reranker_config['margin']
|
||||
if 'temperature' in reranker_config:
|
||||
self.temperature = reranker_config['temperature']
|
||||
if 'use_knowledge_distillation' in reranker_config:
|
||||
self.use_knowledge_distillation = reranker_config['use_knowledge_distillation']
|
||||
if 'teacher_model_name_or_path' in reranker_config:
|
||||
self.teacher_model_name_or_path = reranker_config['teacher_model_name_or_path']
|
||||
if 'distillation_temperature' in reranker_config:
|
||||
self.distillation_temperature = reranker_config['distillation_temperature']
|
||||
if 'distillation_alpha' in reranker_config:
|
||||
self.distillation_alpha = reranker_config['distillation_alpha']
|
||||
if 'use_hard_negatives' in reranker_config:
|
||||
self.use_hard_negatives = reranker_config['use_hard_negatives']
|
||||
if 'hard_negative_ratio' in reranker_config:
|
||||
self.hard_negative_ratio = reranker_config['hard_negative_ratio']
|
||||
if 'use_denoising' in reranker_config:
|
||||
self.use_denoising = reranker_config['use_denoising']
|
||||
if 'noise_probability' in reranker_config:
|
||||
self.noise_probability = reranker_config['noise_probability']
|
||||
if 'max_pairs_per_device' in reranker_config:
|
||||
self.max_pairs_per_device = reranker_config['max_pairs_per_device']
|
||||
if 'accumulate_gradients_from_all_devices' in reranker_config:
|
||||
self.accumulate_gradients_from_all_devices = reranker_config['accumulate_gradients_from_all_devices']
|
||||
if 'max_seq_length' in reranker_config:
|
||||
self.max_seq_length = reranker_config['max_seq_length']
|
||||
if 'query_max_length' in reranker_config:
|
||||
self.query_max_length = reranker_config['query_max_length']
|
||||
if 'passage_max_length' in reranker_config:
|
||||
self.passage_max_length = reranker_config['passage_max_length']
|
||||
if 'shuffle_passages' in reranker_config:
|
||||
self.shuffle_passages = reranker_config['shuffle_passages']
|
||||
if 'position_bias_cutoff' in reranker_config:
|
||||
self.position_bias_cutoff = reranker_config['position_bias_cutoff']
|
||||
if 'use_gradient_caching' in reranker_config:
|
||||
self.use_gradient_caching = reranker_config['use_gradient_caching']
|
||||
if 'gradient_cache_batch_size' in reranker_config:
|
||||
self.gradient_cache_batch_size = reranker_config['gradient_cache_batch_size']
|
||||
if 'eval_group_size' in reranker_config:
|
||||
self.eval_group_size = reranker_config['eval_group_size']
|
||||
if 'metric_for_best_model' in reranker_config:
|
||||
self.metric_for_best_model = reranker_config['metric_for_best_model']
|
||||
if 'eval_normalize_scores' in reranker_config:
|
||||
self.eval_normalize_scores = reranker_config['eval_normalize_scores']
|
||||
if 'listwise_weight' in reranker_config:
|
||||
self.listwise_weight = reranker_config['listwise_weight']
|
||||
if 'pairwise_weight' in reranker_config:
|
||||
self.pairwise_weight = reranker_config['pairwise_weight']
|
||||
if 'filter_empty_passages' in reranker_config:
|
||||
self.filter_empty_passages = reranker_config['filter_empty_passages']
|
||||
if 'min_passage_length' in reranker_config:
|
||||
self.min_passage_length = reranker_config['min_passage_length']
|
||||
if 'max_passage_length_ratio' in reranker_config:
|
||||
self.max_passage_length_ratio = reranker_config['max_passage_length_ratio']
|
||||
if 'gradient_checkpointing' in reranker_config:
|
||||
self.gradient_checkpointing = reranker_config['gradient_checkpointing']
|
||||
if 'recompute_scores' in reranker_config:
|
||||
self.recompute_scores = reranker_config['recompute_scores']
|
||||
if 'use_dynamic_teacher' in reranker_config:
|
||||
self.use_dynamic_teacher = reranker_config['use_dynamic_teacher']
|
||||
if 'teacher_update_steps' in reranker_config:
|
||||
self.teacher_update_steps = reranker_config['teacher_update_steps']
|
||||
if 'teacher_momentum' in reranker_config:
|
||||
self.teacher_momentum = reranker_config['teacher_momentum']
|
||||
|
||||
# Log parameter overrides
|
||||
logger.info("Loaded BGE-Reranker config from [reranker] section in config.toml")
|
||||
|
||||
# Check for required parameters
|
||||
if self.loss_type not in ["listwise_ce", "pairwise_margin", "combined", "cross_entropy", "binary_cross_entropy"]:
|
||||
raise ValueError("loss_type must be one of ['listwise_ce', 'pairwise_margin', 'combined', 'cross_entropy', 'binary_cross_entropy']")
|
||||
if self.train_group_size is None or self.train_group_size < 2:
|
||||
raise ValueError("train_group_size must be at least 2 and not None for reranking")
|
||||
if self.use_knowledge_distillation and not self.teacher_model_name_or_path:
|
||||
raise ValueError("teacher_model_name_or_path required when use_knowledge_distillation=True")
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration settings"""
|
||||
valid_loss_types = ["listwise_ce", "pairwise_margin", "combined"]
|
||||
if self.loss_type not in valid_loss_types:
|
||||
raise ValueError(f"loss_type must be one of {valid_loss_types}")
|
||||
|
||||
if self.train_group_size < 2:
|
||||
raise ValueError("train_group_size must be at least 2 for reranking")
|
||||
|
||||
if self.use_knowledge_distillation and not self.teacher_model_name_or_path:
|
||||
raise ValueError("teacher_model_name_or_path required when use_knowledge_distillation=True")
|
||||
|
||||
if self.hard_negative_ratio < 0 or self.hard_negative_ratio > 1:
|
||||
raise ValueError("hard_negative_ratio must be between 0 and 1")
|
||||
|
||||
if self.max_seq_length > 512 and "base" in self.model_name_or_path:
|
||||
logger.warning("max_seq_length > 512 may not be optimal for base models")
|
||||
|
||||
if self.gradient_cache_batch_size > self.per_device_train_batch_size:
|
||||
logger.warning("gradient_cache_batch_size > per_device_train_batch_size may not improve efficiency")
|
||||
|
||||
return True
|
||||
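A hedged usage sketch for the config class above: constructing it directly and validating it. It assumes the package layout keeps config/ importable and that BaseConfig/get_config can locate a config.toml as in __post_init__; all field values shown are illustrative.
from config.reranker_config import BGERerankerConfig

cfg = BGERerankerConfig(
    loss_type="combined",          # listwise CE plus pairwise margin
    listwise_weight=0.7,
    pairwise_weight=0.3,
    train_group_size=16,
    use_knowledge_distillation=False,
)
cfg.validate()  # raises ValueError on invalid combinations, otherwise returns True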
85
data/__init__.py
Normal file
85
data/__init__.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""
|
||||
Integrated data processing and augmentation module for BGE fine-tuning
|
||||
Now with unified dataset pipeline supporting both embedding and reranker formats
|
||||
"""
|
||||
|
||||
def __getattr__(name):
|
||||
"""Lazy loading to avoid import errors for missing dependencies"""
|
||||
|
||||
# Core dataset classes - from integrated dataset.py
|
||||
if name == 'BGEDataset':
|
||||
from .dataset import BGEDataset
|
||||
return BGEDataset
|
||||
elif name == 'BGEM3Dataset':
|
||||
from .dataset import BGEM3Dataset
|
||||
return BGEM3Dataset
|
||||
elif name == 'BGERerankerDataset':
|
||||
from .dataset import BGERerankerDataset
|
||||
return BGERerankerDataset
|
||||
elif name == 'JointDataset':
|
||||
from .dataset import JointDataset
|
||||
return JointDataset
|
||||
|
||||
# Factory functions for dataset creation
|
||||
elif name == 'create_embedding_dataset':
|
||||
from .dataset import create_embedding_dataset
|
||||
return create_embedding_dataset
|
||||
elif name == 'create_reranker_dataset':
|
||||
from .dataset import create_reranker_dataset
|
||||
return create_reranker_dataset
|
||||
|
||||
# Data migration utilities
|
||||
elif name == 'migrate_nested_to_flat':
|
||||
from .dataset import migrate_nested_to_flat
|
||||
return migrate_nested_to_flat
|
||||
|
||||
# Collate functions
|
||||
elif name == 'collate_embedding_batch':
|
||||
from .dataset import collate_embedding_batch
|
||||
return collate_embedding_batch
|
||||
elif name == 'collate_reranker_batch':
|
||||
from .dataset import collate_reranker_batch
|
||||
return collate_reranker_batch
|
||||
|
||||
# Data preprocessing and augmentation
|
||||
elif name == 'DataPreprocessor':
|
||||
from .preprocessing import DataPreprocessor
|
||||
return DataPreprocessor
|
||||
elif name == 'DataAugmenter':
|
||||
from .preprocessing import DataAugmenter
|
||||
return DataAugmenter
|
||||
elif name == 'QueryAugmenter':
|
||||
from .augmentation import QueryAugmenter
|
||||
return QueryAugmenter
|
||||
elif name == 'PassageAugmenter':
|
||||
from .augmentation import PassageAugmenter
|
||||
return PassageAugmenter
|
||||
elif name == 'HardNegativeAugmenter':
|
||||
from .augmentation import HardNegativeAugmenter
|
||||
return HardNegativeAugmenter
|
||||
elif name == 'CrossLingualAugmenter':
|
||||
from .augmentation import CrossLingualAugmenter
|
||||
return CrossLingualAugmenter
|
||||
|
||||
# Hard negative mining
|
||||
elif name == 'ANCEHardNegativeMiner':
|
||||
from .hard_negative_mining import ANCEHardNegativeMiner
|
||||
return ANCEHardNegativeMiner
|
||||
elif name == 'HardNegativeDataAugmenter':
|
||||
from .hard_negative_mining import HardNegativeDataAugmenter
|
||||
return HardNegativeDataAugmenter
|
||||
|
||||
# Backward compatibility - remove after migration
|
||||
elif name == 'M3Dataset':
|
||||
from .dataset import BGEM3Dataset
|
||||
return BGEM3Dataset
|
||||
elif name == 'RerankerDataset':
|
||||
from .dataset import BGERerankerDataset
|
||||
return BGERerankerDataset
|
||||
elif name == 'NestedDataset':
|
||||
# Deprecated - use BGEDataset with legacy_support=True
|
||||
from .dataset import BGEDataset
|
||||
return BGEDataset
|
||||
|
||||
else:
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
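A brief, hedged sketch of what the lazy __getattr__ above enables: a single class can be pulled from the package without importing the other submodules or their optional dependencies (assumes the package root is on sys.path).
import data

# Only at this point is data.dataset imported; data.augmentation and
# data.hard_negative_mining remain untouched.
RerankerDataset = data.BGERerankerDataset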
1307
data/augmentation.py
Normal file
1307
data/augmentation.py
Normal file
File diff suppressed because it is too large
857
data/dataset.py
Normal file
857
data/dataset.py
Normal file
@ -0,0 +1,857 @@
|
||||
"""
|
||||
Unified BGE dataset pipeline with comprehensive format support
|
||||
- Triplets: Query-triplets format (query, pos, neg) for embedding training
|
||||
- Pairs: Flat pairs format (query, passage, label) for reranker training
|
||||
- Candidates: Unified format (query, candidates) with threshold-based label assignment
|
||||
- Legacy: Supports nested formats with automatic conversion to flat pairs
|
||||
- Enhanced: Score-weighted sampling, hard negative mining, multiple positives
|
||||
"""
|
||||
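# Hedged, illustrative examples of the record shapes listed in the module
# docstring above (one JSON object per .jsonl line; every value is invented):
_example_triplets_record = {"query": "q", "pos": ["relevant passage"], "neg": ["off-topic passage"]}
_example_pairs_record = {"query": "q", "passage": "some passage", "label": 1}
_example_candidates_record = {
    "query": "q",
    "candidates": [{"text": "some passage", "score": 0.9},
                   {"text": "another passage", "score": 0.1}],
}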
|
||||
import json
|
||||
import random
|
||||
from typing import Dict, List, Optional, Union, Tuple
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import PreTrainedTokenizer
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BGEDataset(Dataset):
|
||||
"""
|
||||
Unified BGE dataset with automatic format detection and enhanced features
|
||||
Supports: triplets, pairs, candidates, and legacy nested formats
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_path: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
query_max_length: int = 64,
|
||||
passage_max_length: int = 512,
|
||||
is_train: bool = True,
|
||||
cache_dir: Optional[str] = None,
|
||||
legacy_support: bool = True,
|
||||
format_type: str = "auto", # "triplets", "pairs", "candidates", "legacy_nested", "auto"
|
||||
candidates_threshold: float = 0.5, # Threshold for candidates format label assignment
|
||||
cache_in_memory: bool = False, # Cache tokenized data in memory for small datasets
|
||||
max_cache_size: int = 100000 # Maximum number of samples to cache in memory
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.query_max_length = query_max_length
|
||||
self.passage_max_length = passage_max_length
|
||||
self.is_train = is_train
|
||||
self.legacy_support = legacy_support
|
||||
self.format_type = format_type
|
||||
self.candidates_threshold = candidates_threshold
|
||||
self.cache_in_memory = cache_in_memory
|
||||
self.max_cache_size = max_cache_size
|
||||
self._tokenized_cache = {}
|
||||
|
||||
logger.info(f"Loading data from {data_path}")
|
||||
self.data, self.detected_format = self._load_and_detect_format(data_path)
|
||||
logger.info(f"Loaded {len(self.data)} examples in {self.detected_format} format")
|
||||
|
||||
# Pre-tokenize data if caching is enabled and dataset is small enough
|
||||
if self.cache_in_memory and len(self.data) <= self.max_cache_size:
|
||||
logger.info(f"Pre-tokenizing {len(self.data)} samples for in-memory cache...")
|
||||
self._pretokenize_all()
|
||||
logger.info("Pre-tokenization complete")
|
||||
|
||||
def _load_and_detect_format(self, data_path: str) -> Tuple[List[Dict], str]:
|
||||
"""Load data and automatically detect format with improved detection logic"""
|
||||
data = []
|
||||
format_type = "unknown"
|
||||
total_lines = 0
|
||||
|
||||
with open(data_path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(tqdm(f, desc="Loading and detecting format"), 1):
|
||||
if line.strip():
|
||||
total_lines += 1
|
||||
try:
|
||||
item = json.loads(line)
|
||||
|
||||
# Detect format on first valid item
|
||||
if format_type == "unknown":
|
||||
format_type = self._detect_schema_type(item)
|
||||
logger.info(f"Detected schema: {format_type}")
|
||||
|
||||
# Validate and process item
|
||||
if self._validate_item(item, format_type):
|
||||
processed_items = self._process_item(item, format_type)
|
||||
# Process_item may return multiple items (for candidates format)
|
||||
if isinstance(processed_items, list):
|
||||
data.extend(processed_items)
|
||||
else:
|
||||
data.append(processed_items)
|
||||
else:
|
||||
logger.warning(f"Invalid item at line {line_num}: {item}")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Invalid JSON at line {line_num}: {e}")
|
||||
continue
|
||||
|
||||
if not data:
|
||||
raise ValueError(f"No valid data found in {data_path}")
|
||||
|
||||
# Validation check
|
||||
if total_lines > 10 and (len(data) < 10 or len(data) < 0.01 * total_lines):
|
||||
logger.warning(f"Few valid data lines loaded from {data_path}: {len(data)} out of {total_lines}")
|
||||
|
||||
return data, format_type
|
||||
|
||||
def _detect_schema_type(self, item: Dict) -> str:
|
||||
"""
|
||||
Enhanced format detection based on key presence and structure:
|
||||
- triplets: has 'pos' and 'neg' keys (for embedding training)
|
||||
- pairs: has 'passage' and 'label' keys (for reranker training)
|
||||
- candidates: has 'candidates' key with list of candidate objects
|
||||
- legacy_nested: has 'pos' but no 'label' (backward compatibility)
|
||||
"""
|
||||
try:
|
||||
# Use explicitly set format if provided
|
||||
if self.format_type != "auto":
|
||||
return self.format_type
|
||||
|
||||
# Ensure item is a dictionary
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError(f"Expected dict but got {type(item)}")
|
||||
|
||||
# Extract available keys for debugging
|
||||
available_keys = set(item.keys())
|
||||
|
||||
# Check for candidates format first (most specific)
|
||||
if 'candidates' in item:
|
||||
# Validate candidates structure
|
||||
candidates = item.get('candidates', [])
|
||||
if isinstance(candidates, list) and len(candidates) > 0:
|
||||
# Check if first candidate has expected structure
|
||||
first_candidate = candidates[0]
|
||||
if isinstance(first_candidate, dict) and 'text' in first_candidate:
|
||||
return "candidates"
|
||||
# Also support simple string list as candidates
|
||||
elif isinstance(first_candidate, str):
|
||||
logger.info("Detected candidates format with string list - will convert to dict format")
|
||||
return "candidates"
|
||||
|
||||
# Check for pairs format (reranker)
|
||||
if 'passage' in item and 'label' in item:
|
||||
# Additional check for query field
|
||||
if 'query' in item:
|
||||
return "pairs"
|
||||
else:
|
||||
logger.warning(f"Found 'passage' and 'label' but missing 'query'. Available keys: {available_keys}")
|
||||
|
||||
# Check for triplets format (embedding)
|
||||
if 'pos' in item:
|
||||
# Standard triplets format has both pos and neg
|
||||
if 'neg' in item:
|
||||
return "triplets"
|
||||
# Might be triplets with only positives
|
||||
elif 'query' in item and not ('label' in item or 'passage' in item):
|
||||
logger.info("Detected triplets format with only positive samples")
|
||||
return "triplets"
|
||||
# Legacy nested format
|
||||
elif self.legacy_support and 'query' in item:
|
||||
logger.warning("Detected legacy nested format. Converting to flat pairs for reranker training.")
|
||||
return "legacy_nested"
|
||||
|
||||
# Try to provide helpful error message
|
||||
common_keys = available_keys & {'query', 'pos', 'neg', 'passage', 'label', 'candidates'}
|
||||
error_msg = f"Unknown data format. Available keys: {available_keys}"
|
||||
if common_keys:
|
||||
error_msg += f", Common keys found: {common_keys}"
|
||||
else:
|
||||
error_msg += ". No common keys found. Expected 'query' with 'pos'/'neg' (triplets), 'passage'/'label' (pairs), or 'candidates' (candidates format)"
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Format detection failed: {e}")
|
||||
raise
|
||||
|
||||
def _validate_item(self, item: Dict, format_type: str) -> bool:
|
||||
"""
|
||||
Validate item based on detected format with comprehensive checks
|
||||
|
||||
Returns:
|
||||
bool: True if item is valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Basic type check
|
||||
if not isinstance(item, dict):
|
||||
return False
|
||||
|
||||
# Format-specific validation
|
||||
if format_type == "triplets":
|
||||
# Check required fields
|
||||
if not all(field in item for field in ['query', 'pos']):
|
||||
return False
|
||||
# Validate query
|
||||
if not isinstance(item['query'], str) or not item['query'].strip():
|
||||
return False
|
||||
# Validate positives
|
||||
if not isinstance(item.get('pos', []), list) or len(item['pos']) == 0:
|
||||
return False
|
||||
# Check each positive is a non-empty string
|
||||
for pos in item['pos']:
|
||||
if not isinstance(pos, str) or not pos.strip():
|
||||
return False
|
||||
# Validate negatives if present
|
||||
if 'neg' in item:
|
||||
if not isinstance(item['neg'], list):
|
||||
return False
|
||||
for neg in item['neg']:
|
||||
if not isinstance(neg, str) or not neg.strip():
|
||||
return False
|
||||
return True
|
||||
|
||||
elif format_type == "pairs":
|
||||
# Check required fields
|
||||
if not all(field in item for field in ['query', 'passage', 'label']):
|
||||
return False
|
||||
# Validate query and passage
|
||||
if not isinstance(item['query'], str) or not item['query'].strip():
|
||||
return False
|
||||
if not isinstance(item['passage'], str) or not item['passage'].strip():
|
||||
return False
|
||||
# Validate label (should be numeric 0 or 1, or float between 0 and 1)
|
||||
label = item['label']
|
||||
if isinstance(label, (int, float)):
|
||||
if not (0 <= label <= 1):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
elif format_type == "candidates":
|
||||
# Check required fields
|
||||
if not ('query' in item and 'candidates' in item):
|
||||
return False
|
||||
# Validate query
|
||||
if not isinstance(item['query'], str) or not item['query'].strip():
|
||||
return False
|
||||
# Validate candidates
|
||||
if not isinstance(item['candidates'], list) or len(item['candidates']) == 0:
|
||||
return False
|
||||
# Check each candidate
|
||||
for i, candidate in enumerate(item['candidates']):
|
||||
if not isinstance(candidate, dict):
|
||||
return False
|
||||
if 'text' not in candidate:
|
||||
return False
|
||||
if not isinstance(candidate['text'], str) or not candidate['text'].strip():
|
||||
return False
|
||||
# Validate score if present
|
||||
if 'score' in candidate:
|
||||
score = candidate['score']
|
||||
if not isinstance(score, (int, float)):
|
||||
return False
|
||||
return True
|
||||
|
||||
elif format_type == "legacy_nested":
|
||||
# Basic validation for legacy format
|
||||
if not ('query' in item and 'pos' in item):
|
||||
return False
|
||||
if not isinstance(item['query'], str) or not item['query'].strip():
|
||||
return False
|
||||
return True
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Validation error: {e}")
|
||||
return False
|
||||
|
||||
def _process_item(self, item: Dict, format_type: str) -> Union[Dict, List[Dict]]:
|
||||
"""Process item based on format type - may return multiple items for candidates"""
|
||||
if format_type == "triplets":
|
||||
return self._process_triplets_item(item)
|
||||
elif format_type == "pairs":
|
||||
return self._process_pairs_item(item)
|
||||
elif format_type == "candidates":
|
||||
return self._process_candidates_item(item)
|
||||
elif format_type == "legacy_nested":
|
||||
return self._process_legacy_nested_item(item)
|
||||
else:
|
||||
raise ValueError(f"Unknown format type: {format_type}")
|
||||
|
||||
def _process_triplets_item(self, item: Dict) -> Dict:
|
||||
"""Process triplets format item for embedding training"""
|
||||
processed = {
|
||||
'query': item['query'],
|
||||
'pos': item.get('pos', []),
|
||||
'neg': item.get('neg', []),
|
||||
'pos_scores': item.get('pos_scores', [1.0] * len(item.get('pos', []))),
|
||||
'neg_scores': item.get('neg_scores', [0.0] * len(item.get('neg', []))),
|
||||
'prompt': item.get('prompt', ''),
|
||||
'type': item.get('type', 'normal'),
|
||||
'format': 'triplets'
|
||||
}
|
||||
return processed
|
||||
|
||||
def _process_pairs_item(self, item: Dict) -> Dict:
|
||||
"""Process pairs format item for reranker training"""
|
||||
processed = {
|
||||
'query': item['query'],
|
||||
'passage': item['passage'],
|
||||
'label': item['label'],
|
||||
'score': item.get('score', 1.0 if item['label'] > 0 else 0.0),
|
||||
'qid': item.get('qid', ''),
|
||||
'pid': item.get('pid', ''),
|
||||
'format': 'pairs'
|
||||
}
|
||||
return processed
|
||||
|
||||
def _process_candidates_item(self, item: Dict) -> List[Dict]:
|
||||
"""
|
||||
Process candidates format item - explode into multiple pairs
|
||||
Each candidate becomes a separate (query, passage, label) pair
|
||||
Label is assigned based on score threshold
|
||||
"""
|
||||
query = item['query']
|
||||
candidates = item['candidates']
|
||||
qid = item.get('qid', '')
|
||||
|
||||
pairs = []
|
||||
for i, candidate in enumerate(candidates):
|
||||
if isinstance(candidate, dict):
|
||||
text = candidate.get('text', '')
|
||||
score = candidate.get('score', 0.0)
|
||||
pid = candidate.get('pid', f'cand_{i}')
|
||||
else:
|
||||
# Handle simple string candidates
|
||||
text = str(candidate)
|
||||
score = 1.0 # Default to positive if no score
|
||||
pid = f'cand_{i}'
|
||||
|
||||
# Assign label based on threshold
|
||||
label = 1 if score >= self.candidates_threshold else 0
|
||||
|
||||
pair = {
|
||||
'query': query,
|
||||
'passage': text,
|
||||
'label': label,
|
||||
'score': score,
|
||||
'qid': qid,
|
||||
'pid': pid,
|
||||
'format': 'pairs', # Convert to pairs format
|
||||
'source': 'candidates' # Track original source
|
||||
}
|
||||
pairs.append(pair)
|
||||
|
||||
return pairs
|
||||
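# Hedged worked example for the method above, assuming candidates_threshold=0.5
# (values invented; the real output also carries qid/pid/format/source keys):
#   input : {"query": "q", "candidates": [{"text": "a", "score": 0.9},
#                                          {"text": "b", "score": 0.2}]}
#   output: [{"query": "q", "passage": "a", "label": 1, "score": 0.9},
#            {"query": "q", "passage": "b", "label": 0, "score": 0.2}]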
|
||||
def _process_legacy_nested_item(self, item: Dict) -> List[Dict]:
|
||||
"""
|
||||
Process legacy nested format - convert to flat pairs for reranker training
|
||||
Each pos[i] becomes (query, passage, label=1)
|
||||
Each neg[j] becomes (query, passage, label=0)
|
||||
"""
|
||||
query = item['query']
|
||||
positives = item.get('pos', [])
|
||||
negatives = item.get('neg', [])
|
||||
pos_scores = item.get('pos_scores', [1.0] * len(positives))
|
||||
neg_scores = item.get('neg_scores', [0.0] * len(negatives))
|
||||
qid = item.get('qid', '')
|
||||
|
||||
pairs = []
|
||||
|
||||
# Convert positives
|
||||
for i, passage in enumerate(positives):
|
||||
score = pos_scores[i] if i < len(pos_scores) else 1.0
|
||||
pair = {
|
||||
'query': query,
|
||||
'passage': passage,
|
||||
'label': 1,
|
||||
'score': score,
|
||||
'qid': qid,
|
||||
'pid': f'pos_{i}',
|
||||
'format': 'pairs', # Convert to pairs format
|
||||
'source': 'legacy_nested' # Track original source
|
||||
}
|
||||
pairs.append(pair)
|
||||
|
||||
# Convert negatives
|
||||
for i, passage in enumerate(negatives):
|
||||
score = neg_scores[i] if i < len(neg_scores) else 0.0
|
||||
pair = {
|
||||
'query': query,
|
||||
'passage': passage,
|
||||
'label': 0,
|
||||
'score': score,
|
||||
'qid': qid,
|
||||
'pid': f'neg_{i}',
|
||||
'format': 'pairs', # Convert to pairs format
|
||||
'source': 'legacy_nested' # Track original source
|
||||
}
|
||||
pairs.append(pair)
|
||||
|
||||
return pairs
|
||||
|
||||
def _pretokenize_all(self):
|
||||
"""Pre-tokenize all samples for in-memory caching"""
|
||||
for idx in tqdm(range(len(self.data)), desc="Pre-tokenizing"):
|
||||
# Tokenize based on format
|
||||
item = self.data[idx]
|
||||
if item['format'] == 'triplets':
|
||||
# Pre-tokenize the query for triplets (passages are tokenized on access)
|
||||
query = item['query']
|
||||
query_encoding = self.tokenizer(
|
||||
query,
|
||||
max_length=self.query_max_length,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
self._tokenized_cache[f"{idx}_query"] = {
|
||||
'input_ids': query_encoding.input_ids.squeeze(0),
|
||||
'attention_mask': query_encoding.attention_mask.squeeze(0)
|
||||
}
|
||||
elif item['format'] == 'pairs':
|
||||
# Pre-tokenize query-passage pair
|
||||
query = item['query']
|
||||
passage = item['passage']
|
||||
pair_text = f"{query} [SEP] {passage}"
|
||||
encoding = self.tokenizer(
|
||||
pair_text,
|
||||
max_length=self.query_max_length + self.passage_max_length,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
self._tokenized_cache[idx] = {
|
||||
'input_ids': encoding.input_ids.squeeze(0),
|
||||
'attention_mask': encoding.attention_mask.squeeze(0)
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict:
|
||||
"""Get item based on detected format"""
|
||||
item = self.data[idx]
|
||||
|
||||
if item['format'] == 'triplets':
|
||||
return self._get_triplets_item(item, idx)
|
||||
elif item['format'] == 'pairs':
|
||||
return self._get_pairs_item(item, idx)
|
||||
else:
|
||||
raise ValueError(f"Unknown format: {item['format']}")
|
||||
|
||||
def _get_triplets_item(self, item: Dict, idx: int) -> Dict:
|
||||
"""Get triplets training example for embedding model (contrastive learning)"""
|
||||
query = item['query']
|
||||
positives = item['pos']
|
||||
negatives = item['neg']
|
||||
pos_scores = item['pos_scores']
|
||||
neg_scores = item['neg_scores']
|
||||
prompt = item['prompt']
|
||||
|
||||
# Apply prompt if provided
|
||||
if prompt and self.is_train:
|
||||
query = f"{prompt}{query}"
|
||||
|
||||
# Check if we have cached tokenized query
|
||||
if f"{idx}_query" in self._tokenized_cache:
|
||||
query_encodings = self._tokenized_cache[f"{idx}_query"]
|
||||
else:
|
||||
# Tokenize query
|
||||
query_encodings = self.tokenizer(
|
||||
query,
|
||||
max_length=self.query_max_length,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
query_encodings = {
|
||||
'input_ids': query_encodings.input_ids.squeeze(0),
|
||||
'attention_mask': query_encodings.attention_mask.squeeze(0)
|
||||
}
|
||||
|
||||
# Enhanced sampling for contrastive learning
|
||||
if self.is_train:
|
||||
# Sample multiple positives for richer contrastive signals
|
||||
num_positives = min(2, len(positives))
|
||||
sampled_positives = random.sample(positives, num_positives) if positives else []
|
||||
|
||||
# Hard negative mining with score-based sampling
|
||||
num_negatives = min(6, len(negatives))
|
||||
if negatives and neg_scores:
|
||||
# Weight by scores for hard negative selection
|
||||
neg_weights = np.array(neg_scores)
|
||||
neg_weights = neg_weights / neg_weights.sum() if neg_weights.sum() > 0 else np.full_like(neg_weights, 1.0 / len(neg_weights))
|
||||
sampled_neg_indices = np.random.choice(
|
||||
len(negatives), size=min(num_negatives, len(negatives)),
|
||||
replace=False, p=neg_weights
|
||||
)
|
||||
sampled_negatives = [negatives[i] for i in sampled_neg_indices]
|
||||
sampled_neg_scores = [neg_scores[i] for i in sampled_neg_indices]
|
||||
else:
|
||||
sampled_negatives = random.sample(negatives, num_negatives) if negatives else []
|
||||
sampled_neg_scores = [0.0] * len(sampled_negatives)
|
||||
|
||||
passages = sampled_positives + sampled_negatives
|
||||
labels = [1] * len(sampled_positives) + [0] * len(sampled_negatives)
|
||||
scores = pos_scores[:len(sampled_positives)] + sampled_neg_scores
|
||||
else:
|
||||
# Use all for evaluation
|
||||
passages = positives + negatives
|
||||
labels = [1] * len(positives) + [0] * len(negatives)
|
||||
scores = pos_scores + neg_scores
|
||||
|
||||
# Tokenize passages
|
||||
passage_encodings = []
|
||||
for passage in passages:
|
||||
encoding = self.tokenizer(
|
||||
passage,
|
||||
max_length=self.passage_max_length,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
passage_encodings.append({
|
||||
'input_ids': torch.as_tensor(encoding['input_ids']).squeeze(0),
|
||||
'attention_mask': torch.as_tensor(encoding['attention_mask']).squeeze(0)
|
||||
})
|
||||
|
||||
# Pad to consistent size (max 8 passages for triplets)
|
||||
max_passages = 8
|
||||
while len(passage_encodings) < max_passages:
|
||||
passage_encodings.append({
|
||||
'input_ids': torch.zeros(self.passage_max_length, dtype=torch.long),
|
||||
'attention_mask': torch.zeros(self.passage_max_length, dtype=torch.long)
|
||||
})
|
||||
labels.append(-1)
|
||||
scores.append(0.0)
|
||||
|
||||
# Stack tensors
|
||||
passage_input_ids = torch.stack([p['input_ids'] for p in passage_encodings])
|
||||
passage_attention_mask = torch.stack([p['attention_mask'] for p in passage_encodings])
|
||||
|
||||
return {
|
||||
'query_input_ids': query_encodings['input_ids'],
|
||||
'query_attention_mask': query_encodings['attention_mask'],
|
||||
'passage_input_ids': passage_input_ids,
|
||||
'passage_attention_mask': passage_attention_mask,
|
||||
'labels': torch.tensor(labels, dtype=torch.long),
|
||||
'scores': torch.tensor(scores, dtype=torch.float),
|
||||
'query_text': query,
|
||||
'passage_texts': passages + [''] * (max_passages - len(passages)),
|
||||
'item_idx': idx,
|
||||
'format': 'triplets',
|
||||
# Backward compatibility
|
||||
'positives': positives,
|
||||
'negatives': negatives
|
||||
}
|
||||
|
||||
def _get_pairs_item(self, item: Dict, idx: int) -> Dict:
|
||||
"""Get pairs training example for reranker (cross-encoder training)"""
|
||||
query = item['query']
|
||||
passage = item['passage']
|
||||
label = item['label']
|
||||
score = item['score']
|
||||
|
||||
# Check if we have cached tokenized data
|
||||
if idx in self._tokenized_cache:
|
||||
encoding = self._tokenized_cache[idx]
|
||||
return {
|
||||
'input_ids': encoding['input_ids'],
|
||||
'attention_mask': encoding['attention_mask'],
|
||||
'labels': torch.tensor(label, dtype=torch.long),
|
||||
'scores': torch.tensor(score, dtype=torch.float),
|
||||
'query_text': query,
|
||||
'passage_text': passage,
|
||||
'qid': item.get('qid', ''),
|
||||
'pid': item.get('pid', ''),
|
||||
'item_idx': idx,
|
||||
'format': 'pairs',
|
||||
'source': item.get('source', 'original')
|
||||
}
|
||||
|
||||
# Tokenize query-passage pair if not cached
|
||||
encoding = self.tokenizer(
|
||||
query,
|
||||
passage,
|
||||
max_length=self.passage_max_length,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
return {
|
||||
'input_ids': torch.as_tensor(encoding['input_ids']).squeeze(0),
|
||||
'attention_mask': torch.as_tensor(encoding['attention_mask']).squeeze(0),
|
||||
'labels': torch.tensor(label, dtype=torch.long),
|
||||
'scores': torch.tensor(score, dtype=torch.float),
|
||||
'query_text': query,
|
||||
'passage_text': passage,
|
||||
'qid': item['qid'],
|
||||
'pid': item['pid'],
|
||||
'item_idx': idx,
|
||||
'format': 'pairs',
|
||||
'source': item.get('source', 'original')
|
||||
}
|
||||
|
||||
|
||||
# Specialized dataset classes for backward compatibility
|
||||
class BGEM3Dataset(BGEDataset):
|
||||
"""BGE-M3 specialized dataset (embedding model) with enhanced contrastive learning"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_path: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
query_max_length: int = 64,
|
||||
passage_max_length: int = 512,
|
||||
is_train: bool = True,
|
||||
train_group_size: int = 8,
|
||||
cache_dir: Optional[str] = None
|
||||
):
|
||||
# Force embedding format
|
||||
super().__init__(
|
||||
data_path=data_path,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=query_max_length,
|
||||
passage_max_length=passage_max_length,
|
||||
is_train=is_train,
|
||||
cache_dir=cache_dir,
|
||||
format_type="triplets"
|
||||
)
|
||||
self.train_group_size = train_group_size
|
||||
|
||||
# Warn if wrong format detected
|
||||
if self.detected_format not in ['triplets']:
|
||||
logger.warning(f"Expected triplets format for BGEM3Dataset, got {self.detected_format}")
|
||||
|
||||
|
||||
class BGERerankerDataset(BGEDataset):
|
||||
"""BGE-Reranker specialized dataset (cross-encoder model) with flat pairs support"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_path: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
query_max_length: int = 64,
|
||||
passage_max_length: int = 448,
|
||||
is_train: bool = True,
|
||||
train_group_size: int = 16,
|
||||
cache_dir: Optional[str] = None,
|
||||
legacy_support: bool = True
|
||||
):
|
||||
# Force reranker format (or allow legacy)
|
||||
format_type = "auto" if legacy_support else "pairs"
|
||||
super().__init__(
|
||||
data_path=data_path,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=query_max_length,
|
||||
passage_max_length=passage_max_length,
|
||||
is_train=is_train,
|
||||
cache_dir=cache_dir,
|
||||
legacy_support=legacy_support,
|
||||
format_type=format_type
|
||||
)
|
||||
self.train_group_size = train_group_size
|
||||
|
||||
# Warn if wrong format detected
|
||||
if self.detected_format not in ['pairs', 'legacy_nested']:
|
||||
logger.warning(f"Expected pairs format for BGERerankerDataset, got {self.detected_format}")
|
||||
|
||||
|
||||
class JointDataset(Dataset):
|
||||
"""Dataset for joint training of retriever and reranker"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_path: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
retriever_config: Dict,
|
||||
reranker_config: Dict,
|
||||
is_train: bool = True,
|
||||
cache_dir: Optional[str] = None
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.is_train = is_train
|
||||
|
||||
# Create both datasets
|
||||
self.retriever_dataset = BGEM3Dataset(
|
||||
data_path=data_path,
|
||||
tokenizer=tokenizer,
|
||||
is_train=is_train,
|
||||
cache_dir=cache_dir,
|
||||
**retriever_config
|
||||
)
|
||||
|
||||
self.reranker_dataset = BGERerankerDataset(
|
||||
data_path=data_path,
|
||||
tokenizer=tokenizer,
|
||||
is_train=is_train,
|
||||
cache_dir=cache_dir,
|
||||
**reranker_config
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.retriever_dataset)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict:
|
||||
"""Get example for both retriever and reranker"""
|
||||
retriever_item = self.retriever_dataset[idx]
|
||||
reranker_item = self.reranker_dataset[idx]
|
||||
|
||||
return {
|
||||
'retriever': retriever_item,
|
||||
'reranker': reranker_item,
|
||||
'item_idx': idx
|
||||
}
|
||||
|
||||
|
||||
# Factory functions for easy dataset creation
|
||||
def create_embedding_dataset(
|
||||
data_path: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
query_max_length: int = 64,
|
||||
passage_max_length: int = 512,
|
||||
is_train: bool = True
|
||||
) -> BGEDataset:
|
||||
"""Create dataset specifically for embedding model training"""
|
||||
dataset = BGEDataset(
|
||||
data_path=data_path,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=query_max_length,
|
||||
passage_max_length=passage_max_length,
|
||||
is_train=is_train,
|
||||
format_type="triplets"
|
||||
)
|
||||
|
||||
if dataset.detected_format not in ['triplets']:
|
||||
logger.warning(f"Expected triplets format, got {dataset.detected_format}")
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def create_reranker_dataset(
|
||||
data_path: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
query_max_length: int = 64,
|
||||
passage_max_length: int = 448,
|
||||
is_train: bool = True,
|
||||
legacy_support: bool = True
|
||||
) -> BGEDataset:
|
||||
"""Create dataset specifically for reranker model training"""
|
||||
dataset = BGEDataset(
|
||||
data_path=data_path,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=query_max_length,
|
||||
passage_max_length=passage_max_length,
|
||||
is_train=is_train,
|
||||
legacy_support=legacy_support,
|
||||
format_type="auto" if legacy_support else "pairs"
|
||||
)
|
||||
|
||||
if dataset.detected_format not in ['pairs', 'legacy_nested']:
|
||||
logger.warning(f"Expected pairs format, got {dataset.detected_format}")
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def migrate_nested_to_flat(input_file: str, output_file: str) -> str:
|
||||
"""
|
||||
Migrate legacy nested reranker JSONL to flat pairs format
|
||||
Explodes each (pos[i], label=1) and (neg[j], label=0) into separate lines
|
||||
"""
|
||||
logger.info(f"Migrating nested reranker format: {input_file} -> {output_file}")
|
||||
|
||||
flat_pairs = []
|
||||
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for line_no, line in enumerate(tqdm(f, desc="Converting to flat pairs"), 1):
|
||||
if line.strip():
|
||||
try:
|
||||
item = json.loads(line)
|
||||
query = item['query']
|
||||
positives = item.get('pos', [])
|
||||
negatives = item.get('neg', [])
|
||||
pos_scores = item.get('pos_scores', [1.0] * len(positives))
|
||||
neg_scores = item.get('neg_scores', [0.0] * len(negatives))
|
||||
|
||||
# Create flat pairs for positives
|
||||
for i, passage in enumerate(positives):
|
||||
flat_pairs.append({
|
||||
'query': query,
|
||||
'passage': passage,
|
||||
'label': 1,
|
||||
'score': pos_scores[i] if i < len(pos_scores) else 1.0,
|
||||
'qid': f"Q{line_no}",
|
||||
'pid': f"P{line_no}_{i}"
|
||||
})
|
||||
|
||||
# Create flat pairs for negatives
|
||||
for i, passage in enumerate(negatives):
|
||||
flat_pairs.append({
|
||||
'query': query,
|
||||
'passage': passage,
|
||||
'label': 0,
|
||||
'score': neg_scores[i] if i < len(neg_scores) else 0.0,
|
||||
'qid': f"Q{line_no}",
|
||||
'pid': f"N{line_no}_{i}"
|
||||
})
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Skipping invalid JSON on line {line_no}: {e}")
|
||||
continue
|
||||
|
||||
# Save flat pairs
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for pair in flat_pairs:
|
||||
f.write(json.dumps(pair, ensure_ascii=False) + '\n')
|
||||
|
||||
logger.info(f"✅ Migrated {len(flat_pairs)} flat pairs to {output_file}")
|
||||
return output_file
|
||||
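# Hypothetical usage of the migration helper above; both paths are placeholders.
flat_file = migrate_nested_to_flat(
    "data/datasets/my_corpus/train_nested.jsonl",
    "data/datasets/my_corpus/train_pairs.jsonl",
)
# A nested record with 2 positives and 6 negatives expands into 8 flat lines,
# each directly consumable by BGERerankerDataset / create_reranker_dataset.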
|
||||
|
||||
# Collate functions for different formats
|
||||
def collate_embedding_batch(batch: List[Dict]) -> Dict:
|
||||
"""Collate function for embedding model batches"""
|
||||
return {
|
||||
'query_input_ids': torch.stack([item['query_input_ids'] for item in batch]),
|
||||
'query_attention_mask': torch.stack([item['query_attention_mask'] for item in batch]),
|
||||
'passage_input_ids': torch.stack([item['passage_input_ids'] for item in batch]),
|
||||
'passage_attention_mask': torch.stack([item['passage_attention_mask'] for item in batch]),
|
||||
'labels': torch.stack([item['labels'] for item in batch]),
|
||||
'scores': torch.stack([item['scores'] for item in batch]),
|
||||
'query_text': [item['query_text'] for item in batch],
|
||||
'passage_texts': [item['passage_texts'] for item in batch],
|
||||
'item_idx': [item['item_idx'] for item in batch]
|
||||
}
|
||||
|
||||
|
||||
def collate_reranker_batch(batch: List[Dict]) -> Dict:
|
||||
"""Collate function for reranker model batches"""
|
||||
if batch[0]['format'] == 'pairs':
|
||||
# Flat pairs format
|
||||
return {
|
||||
'input_ids': torch.stack([item['input_ids'] for item in batch]),
|
||||
'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
|
||||
'labels': torch.stack([item['labels'] for item in batch]),
|
||||
'scores': torch.stack([item['scores'] for item in batch]),
|
||||
'query_text': [item['query_text'] for item in batch],
|
||||
'passage_text': [item['passage_text'] for item in batch],
|
||||
'qid': [item['qid'] for item in batch],
|
||||
'pid': [item['pid'] for item in batch],
|
||||
'item_idx': [item['item_idx'] for item in batch]
|
||||
}
|
||||
else:
|
||||
# Legacy nested format
|
||||
return {
|
||||
'input_ids': torch.stack([item['input_ids'] for item in batch]),
|
||||
'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
|
||||
'labels': torch.stack([item['labels'] for item in batch]),
|
||||
'query_text': [item['query_text'] for item in batch],
|
||||
'passage_texts': [item['passage_texts'] for item in batch],
|
||||
'item_idx': [item['item_idx'] for item in batch]
|
||||
}
|
||||
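# Hedged end-to-end sketch wiring the pieces above together; the checkpoint
# name matches the reranker config default, but the data path is a placeholder.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-base")
_train_set = create_reranker_dataset("data/datasets/my_corpus/train_pairs.jsonl", _tokenizer)
_train_loader = DataLoader(_train_set, batch_size=8, shuffle=True,
                           collate_fn=collate_reranker_batch)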
4316
data/datasets/Reranker_AFQMC/dev.json
Normal file
4316
data/datasets/Reranker_AFQMC/dev.json
Normal file
File diff suppressed because it is too large
3861
data/datasets/Reranker_AFQMC/test.json
Normal file
3861
data/datasets/Reranker_AFQMC/test.json
Normal file
File diff suppressed because it is too large
34334
data/datasets/Reranker_AFQMC/train.json
Normal file
34334
data/datasets/Reranker_AFQMC/train.json
Normal file
File diff suppressed because it is too large
400
data/datasets/Reranker_AFQMC/train_triplets.jsonl
Normal file
400
data/datasets/Reranker_AFQMC/train_triplets.jsonl
Normal file
@ -0,0 +1,400 @@
|
||||
{"query": "花呗账单金额和还款金额不一致", "pos": ["我的花呗账单是***,还款怎么是***", "我的花呗,月结出来说让我还***元,我自己算了一下详细名单我应该还***元"], "neg": ["蚂蚁借呗还款时,为什么会自动扣款", "为什么我现在花呗额度为负", "花呗什么情况不能使用", "申请借呗界面审核多久", "借呗分期利息怎样算", "蚂蚁借呗不按期还款什么处罚"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.12, 0.135, 0.076, 0.053, 0.062], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "支付宝找不到花呗入口", "pos": ["支付宝系统点我的里面没有花呗这一项", "我下载支付宝怎么没有花呗的"], "neg": ["蚂蚁借呗必须使用本人名下储蓄卡才能借款吗", "花呗,已经还款,怎么还有显示没还款", "我的花呗刷脸为什么不成功", "几个小时借呗能到账", "今天用了花呗是下个月才还款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.107, 0.071, 0.061, 0.163, 0.069], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗支持的商品", "pos": ["花呗标示的商品", "花呗识别的商品"], "neg": ["蚂蚁借呗这么没有了", "借呗还款单笔支付宝限额", "花呗自动扣款是在哪里扣", "花呗授权怎么设置", "有蚂蚁借呗怎么没额度", "蚂蚁借呗只能选择***个月的期限"], "pos_scores": [1.0, 0.95], "neg_scores": [0.103, 0.087, 0.198, 0.102, 0.2, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗超额消费或负额度的后果", "pos": ["花呗消费超过额度有什么影响吗", "花呗额度成负数有啥影响吗"], "neg": ["为什么我的蚂蚁花呗无法为游戏充值了", "在么花呗分期后可以一次性还么", "花呗可以通过哪些途径消费", "花呗分期还款利息如何算", "使用花呗分期后退货", "花呗和借呗有什么事区别", "花呗是否有手续费和利息", "蚂蚁借呗可以还***期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.112, 0.058, 0.09, 0.165, 0.06, 0.096, 0.054], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还清后仍显示待还款", "pos": ["还款还清了,为什么花呗账单显示还要还款", "花呗全额还清怎么显示没有还款"], "neg": ["我的借呗关闭了是什么情况", "花呗交易记录被删怎么退款", "小蓝单车的押金用的花呗", "花呗 还好了", "花呗还款逾期***天要利息不", "花呗退款后还款额未变"], "pos_scores": [1.0, 0.95], "neg_scores": [0.076, 0.079, 0.189, 0.087, 0.097, 0.148], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗评估周期", "pos": ["借呗一般是多久评估", "大概多久评估一次借呗条件"], "neg": ["如何取消借呗自动还款", "花呗可以在银行卡扣款不", "芝麻信用是不是借呗信用额度", "花呗如果逾期几天可以吗", "之前的支付宝花呗 我已经关闭了", "你能帮查下我的花呗吗?点不进去的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.076, 0.134, 0.188, 0.095, 0.112], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗无法使用原因", "pos": ["么,这个是不是我现在不能不能用花呗", "我的花呗不能用了是吧"], "neg": ["借呗会影响证信吗", "借呗分期还款可以在第一个月还清吗", "花呗预留号码", "花呗被冻结是否能帮忙解开", "怎么改变花呗扣款日期", "我的花呗信用卡收钱码跟我支付宝收钱码一样吗", "我用花呗东西的话,忘记还了,他会自动从余额宝里面扣吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.099, 0.17, 0.133, 0.123, 0.081, 0.138, 0.073], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗客服电话", "pos": ["把你们花呗的电话给我", "给花呗打电话"], "neg": ["花呗是使用多少还多少还是还额度", "花呗的账单我已还为什么上边显示还款的金额没有变", "借呗提前还款可以提升信用分吗", "花呗分期退款为什么少钱", "余额宝有钱为什么蚂蚁借呗到还款期不扣钱", "花呗算是借钱,这个月借下个月还吗", "花呗是商家开通还是卖家", "商家花呗分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.184, 0.08, 0.19, 0.071, 0.14, 0.095, 0.124, 0.188], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "不使用花呗是否需要还款", "pos": ["要是不用花呗 还需要还吗", "不用花呗里面的钱 是不是就可以不还款了"], "neg": ["小号不开花呗的情况下,客人能不能用花呗支付", "花呗额度*** 怎么在淘宝里使用", "我没取消授权,我的借贝怎么不见了", "花呗已还款,退款又打回花呗", "花呗什么时候邀请开通", "网商银行还借呗", "每次使用花呗后有一个抽奖的页面在哪里"], "pos_scores": [1.0, 0.95], "neg_scores": [0.183, 0.113, 0.084, 0.139, 0.089, 0.171, 0.131], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "交易不支持花呗付款原因", "pos": ["当前交易不支持花呗付款怎么回事", "该付款不支持花呗是什么意思"], "neg": ["花呗怎么删除账单", "花呗有逾期记录了,还能使用吗", "花呗***号还款逾期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.075, 0.143, 0.143], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗分期开始时间", "pos": ["花呗分期是从什么时候开始", "花呗分期的钱什么时候扣除"], "neg": ["开通花呗不用有什么费用", "花呗上月只还最低,本月提前还款,利息怎么算", "***w额度免费提现是跟蚂蚁借呗一样吗", "花呗收款,别人怎么付不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.063, 0.085, 0.183, 0.192], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗开通额度条件", "pos": ["借呗信用到多少可以借款", "达到多少额度可以申请借呗"], "neg": ["我怎么能暂时关闭花呗", "花呗分期付款期间还可以涨额度吗", "怎么关 花呗", "花呗超出部分怎么还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.192, 0.111, 0.093, 0.168], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗借款未到账", "pos": ["昨天晚上借呗借了***,怎么还没到账", "申请的蚂蚁借呗怎么还没到账了"], "neg": ["我现在不知道怎么用上花呗", "为何开通了花呗,支付宝无显示", "花呗分期提前还完后还可分期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.082, 0.104, 0.112], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "网商贷转借呗方法", "pos": ["网商贷换回借呗贷……晕", "网商贷要怎么换借呗"], "neg": ["本款衣服有标识花呗支付,为什么不能", "花呗欠款逾期,我现在花呗被关闭了,要怎么还款", "花呗***分***期每月给多少钱", "第一次认证花呗没有成功 要多久又才可以认证", "为什么说我 花呗操作失误", "我想知道花呗临时额度多少", "借呗必须还了债才能再借吗", "我的花呗已经全部还款了,什么时候可以恢复"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.173, 0.06, 0.163, 0.056, 0.141, 0.094, 0.163], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款后无法使用", "pos": ["账单都还清楚了!怎么花呗不能用了", "为什么我的花呗还款后还不能用"], "neg": ["注意主动申请蚂蚁借呗", "我打算在淘宝上买,能用花呗吗", "需求人工客服", "花呗能不能转账到另一个账号的", "花呗充值到支付宝账号", "花呗不支持充游戏吗", "花呗上的临时额度是什么", "为什么花呗里面提醒我打开允许读取联系人我打开了允许读取联系人以后还不行"], "pos_scores": [1.0, 0.95], "neg_scores": [0.093, 0.176, 0.063, 0.107, 0.067, 0.071, 0.125, 0.176], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗消费记录查询", "pos": ["我能查看用花呗买什么东西吗", "花呗都能购买那些东西"], "neg": ["我的有花呗额度挺多", "蚂蚁借呗怎么关闭借款功能了", "开通了花呗信用卡怎么还不能收款", "借呗日利率怎么制定的", "我这个花呗怎么退款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.059, 0.105, 0.096, 0.084, 0.193], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗分期提前还款规则", "pos": ["花呗分期提前还完不可以吗", "花呗分期能提前全部还清"], "neg": ["借呗还款日期可以根改吗", "子账户花呗怎么取消", "如何能够快速的使用借呗", "我花呗可以给我开通了吗", "为什么二维码收款用花呗的才几百元", "为什么我的花呗还了还有账单", "蚂蚁借呗当前操作可在风险"], "pos_scores": [1.0, 0.95], "neg_scores": [0.074, 0.059, 0.196, 0.143, 0.176, 0.059, 0.17], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗天猫券领取位置", "pos": ["花呗抽天猫券放哪了", "花呗抽了***元天猫抵用券到哪去了"], "neg": ["花呗的消费在哪里能查得到", "花呗额度怎么涨的", "邮政储蓄卡不能使用借呗", "花呗分期付款有利息吗", "找不到蚂蚁借呗还款记录"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.085, 0.115, 0.077, 0.188], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗自动扣款失败", "pos": ["借呗借呗,今天还没扣款", "借呗还款金额未扣"], "neg": ["蚂蚁花呗额度查询", "我的花呗不能用是怎么回事", "为什么我的蚂蚁借呗分期还款,现在换不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.087, 0.099, 0.105], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗支付产生蚂蚁森林能量", "pos": ["用蚂蚁花呗缴费产生能量球吗", "使用花呗能收集蚂蚁深林能量不"], "neg": ["我这花呗为什么冻结了", "主申请蚂蚁借呗", "如果退花呗的压金", "我蚂蚁借呗什么时候还钱", "不是淘宝绑定的支付宝可以用花呗么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.083, 0.145, 0.185, 0.122, 0.15], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗绑定手机号不一致", "pos": ["我点花呗,怎么显示之前绑定手机号", "花呗手机号和绑定手机号不一致"], "neg": ["商家收款,花呗有手续费吗", "花呗提前结清了,但是还有未还进去的,怎么把他还进去", "我想知道怎么还花呗的款", "借呗还款怎么还没扣", "花呗额度为负值还能继续使用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.156, 0.14, 0.158, 0.183, 0.187], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗充值话费入口", "pos": ["花呗充话费在哪个界面", "在那里用花呗充值话费"], "neg": ["关闭花呗多久能在开通", "我的押金是花呗教的,退出来之后为什么还要还花呗的***", "买东西要用花呗分期付款,为什么提示额度不够用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.176, 0.131, 0.095], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "加油卡能否用花呗支付", "pos": ["油卡充值可以使用花呗吗", "加油卡花呗可以不"], "neg": ["如何看我的蚂蚁借呗账单", "我想查花呗怎么查", "用蚂蚁借呗借了钱怎么还", "蚂蚁借呗我想推迟几天还款", "我的花呗现在自己用不了了", "借呗还款自动扣还是", "花呗可以微信还款不可以"], "pos_scores": [1.0, 0.95], "neg_scores": [0.101, 0.106, 0.165, 0.057, 0.178, 0.132, 0.064], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "无法开通花呗解决方法", "pos": ["说暂时无法开通花呗", "我开通不了花呗"], "neg": ["蚂蚁花呗体验卡卷在哪", "该商户不支持花呗支付 怎么办", "花呗最多分期几次", "花呗账单钱对不上", "和花呗付款后蚂蚁森林怎么不产生能量球"], "pos_scores": [1.0, 0.95], "neg_scores": [0.145, 0.137, 0.161, 0.171, 0.158], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭花呗收款服务", "pos": ["花呗收钱钱码关闭", "如何关闭信用卡和花呗收款服务"], "neg": ["我的花呗什么时候才能再开通", "我是用花呗交的话费,是不是要等自己还了才能退", "我要花呗退款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.095, 0.199], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗使用教程", "pos": ["花呗是怎么一回事吗", "花呗怎么玩吗"], "neg": ["蚂蚁借呗可以借支吗", "花呗额度支付不够", "我那花呗还了,什么又显示还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.2, 0.159, 0.106], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗是否产生蚂蚁森林能量", "pos": ["用花呗 蚂蚁森林能产生能量吗", "花呗会产生能量吗"], "neg": ["不知道花呗有什么好处", "我的蚂蚁借呗怎么没有了", "开通了花呗不使用要扣钱吗", "过了***点怎么借呗还款", "花呗总共还有多少负债"], "pos_scores": [1.0, 0.95], "neg_scores": [0.144, 0.095, 0.185, 0.068, 0.085], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗自动还款说明", "pos": ["自动还款 花呗是什么", "这花呗自动还款怎么回事"], "neg": ["为什么暂时开通不了花呗", "两个账号只有一个能使用花呗", "借呗可以就是后面一点还吗拖点时间", "蚂蚁借呗提额要不要输密码"], "pos_scores": [1.0, 0.95], "neg_scores": [0.18, 0.086, 0.074, 0.122], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通余额宝代扣花呗", "pos": ["开通余额宝代扣功能还款蚂蚁花呗", "蚂蚁花呗余额宝代扣协议怎么开通"], "neg": ["为什么我的借呗不能提高额度", "芝麻分要多少才可以用花呗", "花呗的账单周期是每月几号到几号"], "pos_scores": [1.0, 0.95], "neg_scores": [0.119, 0.055, 0.18], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗误发欠款通知", "pos": ["我的花呗怎么会欠款泥", "我都没欠花呗钱,怎么老是发信息来"], "neg": ["一个营业执照可以开通几个支付宝信用卡花呗收款", "我花呗人脸识别一直失败", "花呗几号出帐单的", "我上次申请借呗的时间", "已经还款的花呗账单明细怎么查"], "pos_scores": [1.0, 0.95], "neg_scores": [0.096, 0.126, 0.09, 0.094, 0.153], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通花呗信用分要求", "pos": ["开通蚂蚁花呗需要多少分", "花呗要多少信用额度才可以开通"], "neg": ["花呗退回来的钱去哪里了", "我想提升借呗都可以吗", "花呗立减十元", "为什么开通花呗要常"], "pos_scores": [1.0, 0.95], "neg_scores": [0.068, 0.177, 0.164, 0.124], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收货时间与还款关系", "pos": ["我用蚂蚁花呗确认收货的时间是不一样的,还款期限也是一样的嘛", "确认收货时间与花呗还款时间有关吗"], "neg": ["为什么花呗不能使用共享单车了", "为什么借呗提示网商贷", "我昨天晚上给151******11号交了***元花费怎么还没到账", "我的花呗双***的时候临时额度是***现在为什么就没有了", "我的收款码怎么用不了花呗", "花呗能从银行卡直接扣除还款吗", "我的花呗额度用完了 怎么还能用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.172, 0.13, 0.159, 0.052, 0.168, 0.075, 0.056], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "提前还清花呗分期", "pos": ["怎么样提前还花呗的分期账单", "怎么样提前还清花呗分期"], "neg": ["借呗交易可以删除吗", "的是蚂蚁借呗为什么关了,我之前还是开着的", "花呗还完款了,怎么还让还款", "之前的支付宝不用了 花呗怎么办", "花呗没按期还", "在那里提升花呗临时额度", "用花呗买的东西。现在退货了。怎么退款的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.193, 0.132, 0.122, 0.1, 0.145, 0.196, 0.148], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗最低还款影响信用吗", "pos": ["花呗最低还款会不会扣信誉", "花呗最低还款影响信用吗"], "neg": ["滴滴打车能不能使支付宝花呗了", "我的花呗是否还请清了", "蚂蚁借呗可以提现到支付宝吗,怎样提现", "花呗提前还款手续费怎么算"], "pos_scores": [1.0, 0.95], "neg_scores": [0.165, 0.098, 0.126, 0.13], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "提升借呗额度方法", "pos": ["如何增加?蚂蚁借呗额度", "蚂蚁借呗要怎么样才能提升额度"], "neg": ["为什么花呗会忽然空白看不到额度", "我要在本账户开通花呗 但是另一个账户已经开通了 怎么办", "如何解绑花呗", "花呗脸部识别", "花呗还了一部分 下月可以分期还吗", "蚂蚁借呗分期还一次之后,还可以继续用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.101, 0.18, 0.154, 0.062, 0.112, 0.093], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查询花呗总还款额", "pos": ["花呗怎么的看借款总金额", "怎样看花呗总需还款额"], "neg": ["蚂蚁借呗自功还款系统升级我怎样还款", "刚刚我用花呗付款了,要怎么还款", "怎么在手机上使用花呗收款", "开通花呗收款后二维码怎么用", "我不想用花呗,如何退掉", "还花呗。为啥。额度满了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.144, 0.139, 0.144, 0.091, 0.14, 0.197], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "冻结花呗额度", "pos": ["可以帮我冻结花呗吗", "里冻结花呗额度"], "neg": ["蚂蚁借呗什么时候通过审核", "我花呗逾期了,现在已经还清了,什么时候可以用", "借呗分期逾期以后 下次还款需要一次性还清吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.17, 0.106], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "个人收款码不能用花呗", "pos": ["扫个人收钱码,为什么不能用花呗", "收钱码为什么不能使用花呗"], "neg": ["我的花呗不小心分了***期", "花呗要怎昵解冻结", "怎么永久关闭借呗", "刚刚还蚂蚁花呗,银行卡扣了费,怎么花呗没有清楚", "如何把借呗删了", "为啥没有借呗信用额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.109, 0.158, 0.053, 0.059, 0.174], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款是否收费", "pos": ["花呗退款还要收费吗", "花呗退的款有利息吗"], "neg": ["如何更改借呗里预留的手机号", "我注销这个账户以后,我另外一个账户就可以申请花呗了吗", "我想借蚂蚁借呗为什么收不到手机验证码", "我上个月把花呗不知道是冻结还是调了。现在怎么恢复之前额度", "为什么我滴奖励金不能翻倍?用了一个月了,就出现过两次翻倍机会,其它时候没有翻倍?我一直设置的付款方式是花呗,余额宝"], "pos_scores": [1.0, 0.95], "neg_scores": [0.102, 0.141, 0.179, 0.157, 0.123], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通花呗条件", "pos": ["我怎么还没花呗", "我的花呗还没有开户"], "neg": ["花呗不能支付共享单车押金我", "我的借呗还款日几号", "怎么没有借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.155, 0.09], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款日当天还款", "pos": ["蚂蚁花呗还钱,,***号之前还款,那九号当天还可以吗", "花呗还款是***号前,***号当天还可以吧"], "neg": ["借呗的开通和信用分数有关吗", "这两三个月可能借呗还不起款", "花呗额度会随着还款而变吗", "我从来没有延迟还过钱,为什么把我的蚂蚁借呗给停用了", "借呗还款了还可以借出来么", "花呗这个账号有,另一个账号没有"], "pos_scores": [1.0, 0.95], "neg_scores": [0.077, 0.179, 0.12, 0.077, 0.053, 0.059], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款日逾期判定", "pos": ["我的花呗最后还款日为***号,我十号还算不算逾期", "花呗***号还款 十号还中不"], "neg": ["花呗不够用怎么申请", "花呗可以每天还款不", "花呗如何退款", "关闭了借呗怎么在开通", "花呗不能付单车押金", "怎么提前还款一个月的借呗", "花呗充话费!充到空号了!花呗会不会自动退回"], "pos_scores": [1.0, 0.95], "neg_scores": [0.137, 0.079, 0.135, 0.145, 0.194, 0.132, 0.146], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查询花呗消费明细", "pos": ["我的花呗账单有***我不记得是买什么", "我花呗***块是买什么"], "neg": ["我的蚂蚁借呗额度申请提升了怎么", "为什么我的账户没有条件开通花呗收款", "已经开通花呗 商品也支持 为什么还不能付款", "我怎样可以使用蚂蚁花呗", "怎么授权芝麻花呗", "我已经用银行卡还了花呗的钱,可是支付宝显示没还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.109, 0.081, 0.053, 0.082, 0.059], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "电脑端开通花呗", "pos": ["用电脑客户端可以开通花呗么", "电脑能开通花呗吗"], "neg": ["花呗付款,后来还款了,可是退款又退款到花呗,钱去哪了", "我的花呗额度变为***是怎么回事", "我问的是我怎么不能用花呗和借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.191, 0.067, 0.133], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗借款失败原因", "pos": ["为什么借呗都用不了", "我在借呗怎么借不了钱"], "neg": ["我的借呗了", "蚂蚁花呗还款了银行卡扣钱了为啥还显示没有还款", "哪些航空公司可以使用花呗", "花呗还款之后,之前用花呗买的火车票退款了,最后没有收到退款", "花呗的期限"], "pos_scores": [1.0, 0.95], "neg_scores": [0.199, 0.174, 0.177, 0.133, 0.125], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗支付退款流程", "pos": ["怎么要求花呗退款", "花呗支付怎么退款"], "neg": ["支付宝付款摩拜单车用花呗支付", "我是为什么花呗降额", "余额能还借呗么", "之前,蚂蚁借呗不是可以分为***期借贷吗,为什么现在只能分为***期", "我的借呗开了之后怎么又关闭了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.109, 0.114, 0.16, 0.052], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款后金额未更新", "pos": ["我是还没发退款的,退款到了花呗还是金额没变动", "花呗用出去,但是退款了,还款还是木有变的"], "neg": ["花呗,我最多能用多少", "花呗分期后能不能提前部分结清", "我花呗都取消自动首选了", "查询蚂蚁借呗***年***月账单", "花呗分期后一次性付清会有手续费吗", "还不了解花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.185, 0.189, 0.087, 0.12, 0.116, 0.054], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗利率差异原因", "pos": ["蚂蚁借呗怎么利息不同", "为什么借呗的利息每个人的都不一样"], "neg": ["我是用银行卡还花呗的,扣钱的短信收到了花呗哪里没有少钱", "怎么查花呗明细", "身份信息完善了,为什么还不可以借呗", "怎么查询蚂蚁借呗还款记录利息", "花呗已自动还款,为什么提示还款失败", "为什么借呗还了就不能用了", "如何知道花呗是否还清", "蚂蚁借呗上的提前还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.055, 0.125, 0.138, 0.081, 0.078, 0.177, 0.138, 0.113], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "密付是否支持花呗", "pos": ["密付可以当使用花呗吗", "密付。能用花呗吗"], "neg": ["注册了网商银行还有没有借呗", "退款是在花呗里再转去银行卡的吗", "为什么不支持花呗付款了", "借唄可以分期嗎", "花呗逾期 以后要还能用吗", "我换手机了我的花呗登不上怎么办", "花呗逾期 借呗能不能继续使用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.107, 0.08, 0.13, 0.135, 0.114, 0.186, 0.147], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期恢复使用", "pos": ["花呗逾期被停用怎么才能恢复使用", "蚂蚁花呗逾期被停止使用 如何再次使用"], "neg": ["花呗分期不能还全款吗", "从我银行卡提现到花呗一百,说提升额度,提现一百额度没动,那怎么算提升。还是我自己的钱", "花呗要多久才能充值话费", "花呗a帐户关闭,如何开通b帐户", "借呗利息多少钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.196, 0.142, 0.11, 0.19, 0.114], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗利率不同原因", "pos": ["借呗利率为什么和别人的不一样", "为什么我的借呗收万分四"], "neg": ["等多久才可以开通花呗", "花呗邀请好友代还", "什么情况下平台会关闭个人借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.091, 0.183], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗入口消失", "pos": ["我的蚂蚁借呗怎么关闭了", "我的借呗已经有一万六的额度,今天登录支付宝一看怎么没有入口了"], "neg": ["信用卡逾期也影响花呗的使用吗", "花呗大额付款有限制嘛", "花呗分期不够额度怎么办", "花呗开通不使用会收费吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.173, 0.164, 0.185, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通信用卡花呗收款", "pos": ["怎样设置信用卡花呗收款", "怎么开通,信用卡和花呗收款"], "neg": ["花呗红包,专享红包都没法用", "花呗里的钱和余额宝的钱能同时用吗", "刚刚都可以开通借呗", "花呗,可以让朋友带还吗", "花呗***线下支付不了", "借呗是怎么样的还款方式", "花呗收款为什么会显示不满足升级条件"], "pos_scores": [1.0, 0.95], "neg_scores": [0.088, 0.1, 0.172, 0.145, 0.101, 0.186, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗最高额度限制", "pos": ["我想蚂蚁借呗最高金额有上线么", "借呗有最高额度限制吗"], "neg": ["有个有花呗,怎么有一个没有", "为什么我这个没有花呗的", "为什么支付宝花呗还不了款", "为什么我的花呗用不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.145, 0.075, 0.111, 0.083], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "余额宝自动还花呗", "pos": ["如果花呗,忘记还款,自己在余额宝里扣除", "花呗还钱,可以自动从余额宝扣除吗"], "neg": ["蚂蚁花呗不够不能分期", "我的支付宝添加的是我交通银行信用卡,怎么会用了花呗里的钱", "为什么我找人代付成功了花呗却要我还钱", "我用花呗买的东西 退货后 怎么退款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.124, 0.061, 0.126, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "网商贷转借呗", "pos": ["网商贷怎么样变借呗", "能把网商贷换成借呗吗"], "neg": ["花呗还需还账总额", "为什么借呗要上传身份证件", "花呗最多支付***是什么原因", "上午从借呗上借款到现在没有发款到卡上是怎么回事", "为什么我的花呗不能提前还完", "花呗收款码额度提升"], "pos_scores": [1.0, 0.95], "neg_scores": [0.058, 0.098, 0.074, 0.063, 0.096, 0.087], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "降低借呗额度", "pos": ["怎么能降低蚂蚁借呗", "借呗可以降额么"], "neg": ["我的花呗怎么一直不长", "借呗还款有个先息后本和等额还款 有什么区别", "借呗首月还款", "关团花呗,别一个帐户什么时候能开通", "花呗的流量提取", "花呗如何解除子账户", "a账号关闭花呗b账号多久可以开通"], "pos_scores": [1.0, 0.95], "neg_scores": [0.166, 0.194, 0.056, 0.123, 0.08, 0.089, 0.196], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗放款时间", "pos": ["蚂蚁 借呗借款多长时间放款", "蚂蚁借呗,什么时候放款"], "neg": ["***不是临时额度吗", "花呗金额为什么这么底", "花呗可以用来订机票么", "花呗无法使用、是怎么回事"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.111, 0.144, 0.078], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款去向查询", "pos": ["我的退款到哪去了?我是提前还款花呗,但是收到货产生退货退款,卖家退回花呗了,那我的钱到哪去了", "我之前的花呗已经还清了,但是我在淘宝上买的东西退货退钱了又退到了花呗上面,那钱去哪了"], "neg": ["花呗点餐能点几次", "花呗说我额度涨了,可是却没有变,这是怎么回事", "我用银行卡多还了花呗怎么办", "暂停花呗额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.148, 0.113, 0.111, 0.097], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查询上月花呗账单", "pos": ["想查上个月的花呗账单", "把我花呗这个月的账单发给我一下"], "neg": ["花呗提前结清多久额度能提", "用花呗分期买的怎么还款", "蚂蚁借呗只能分***个月或者***个月还么", "花呗上的支付宝账户和我需要添加的银行卡不一致怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.131, 0.189, 0.067, 0.197], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗和借呗区别", "pos": ["蚂蚁借呗和蚂蚁花呗一样", "花呗和蚂蚁借呗是一回事吗"], "neg": ["什么时候提的借呗额度", "我的***元花呗还款咋回事", "蚂蚁借呗多久重新可以评估额度", "花呗分期影响", "逾期还清后再使用借呗为什么还显示有逾期", "中专生如何完善花呗的学历资料"], "pos_scores": [1.0, 0.95], "neg_scores": [0.118, 0.131, 0.129, 0.125, 0.174, 0.157], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗提前还款规则", "pos": ["蚂蚁借呗提前还款一个月这个月还需要再次还款吗", "提前把蚂蚁借呗分期的钱还了,账单日还扣款吗"], "neg": ["花呗的手续费是什么", "我的额度是多少花呗", "花呗,怎么老是要我绑定银行卡,,我都绑定银行卡了都不行", "我的花呗为什么只能在淘宝买东西,其他功能没有", "能帮我开开花呗", "花呗 能转给他人支付宝吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.173, 0.191, 0.063, 0.082, 0.195, 0.16], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查看借呗账单", "pos": ["怎么看借呗还款的哪个账单", "借呗账单怎么看"], "neg": ["为什么我的花呗,不是我的支付宝账号", "花呗也关了", "用花呗支付手续费怎么扣", "购物津贴花呗可以用吗", "第一次开通花呗有多少额度", "需要具备什么样的条件才能向蚂蚁借呗借钱", "花呗什么安全验证不通过是怎么回事", "借呗能修改还款方式吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.108, 0.103, 0.083, 0.068, 0.137, 0.156, 0.149, 0.076], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗被关闭原因", "pos": ["借呗为和又又把我关了", "我的蚂蚁借呗怎么还关了"], "neg": ["借呗日息如何再降低", "好友没有花呗怎么邀请", "因为个人原因,我花呗逾期了好久,我现在还清了,还可以使用吗", "而且网商贷没有额度", "我的花呗所有分期账单如何提前还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.181, 0.081, 0.105, 0.139, 0.125], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "线下无法用花呗", "pos": ["为什么商家支持花呗,而我用不了", "我的花呗为什么好多线下支付不了"], "neg": ["花呗怎么读取手机联系人名单", "我的花呗额度是够的 店主也支持花呗 为什么付款的时候不能显示花呗支付", "借呗把钱还了会有吗", "花呗的还款日是几号", "余额里的钱还蚂蚁借呗的钱有限额吗", "花呗临时额度入口", "借呗没有自动还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.079, 0.099, 0.072, 0.115, 0.163, 0.151, 0.158], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗开通时间", "pos": ["我什么时候可以开通支付宝花呗", "花呗何时能开通"], "neg": ["能自动提高借呗额度", "支付宝显示花呗没开通,天猫却用了花呗,怎么还款", "花呗支付没享受到优惠"], "pos_scores": [1.0, 0.95], "neg_scores": [0.096, 0.141, 0.121], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭借呗功能", "pos": ["我的意思是,要像刚申请支付宝那种状态,没有借呗", "能不能把我借呗关闭成没开通状态"], "neg": ["借呗还款为什么没有自动扣款", "付尾款怎么使用花呗", "我的借呗怎么没有提升额度的", "花呗可以自动扣除绑定银行卡吗", "花呗临时额度到期后还能不能使用", "我的花呗为什么扫码不能付款呀"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.198, 0.064, 0.16, 0.149, 0.12], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款方式", "pos": ["花呗 还款怎么还", "不知道怎样还花呗"], "neg": ["开通花呗分期条件", "花呗体验额度多少", "花呗 绑定的卡 怎么更换"], "pos_scores": [1.0, 0.95], "neg_scores": [0.153, 0.138, 0.157], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "已还花呗的退款处理", "pos": ["花呗钱己经还了,后面退了款的钱退到那里", "之前的花呗已还,但后来有笔花呗退款"], "neg": ["怎么不支持花呗", "咋要我还款,我没用花呗", "我能继续使用支付宝花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.193, 0.13, 0.069], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "删除借呗记录", "pos": ["你可以帮我删掉借呗里的账单记录吗", "蚂蚁借呗还款记录可以清除吗"], "neg": ["花呗分期还款支持邮政储蓄银行绿卡通支付吗", "花呗款已还,怎么还显示下个月还可还十几元,说未出帐的,什么意思", "借呗能逾期多久不会影响信用", "花呗充值话费退款是什么原因", "花呗能让别人先还吗", "为什么没收到花呗的优惠券", "花呗绑定了一个支付宝另一个还能开通吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.137, 0.074, 0.102, 0.125, 0.077, 0.164, 0.105], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账单月划分规则", "pos": ["***号当天花呗消费是算当月的,还是次月的账单", "我一号过后用的花呗支付,是当月还还是下月还"], "neg": ["蚂蚁借呗分期的一次性还完", "蚂蚁借呗授权书生效是什么意思", "钱可以退到花呗里面吗", "把以前的帐号注销,现在这个能开通花呗吗", "我的花呗还欠有多少钱", "网商货和借呗可以同时有吗", "蚂蚁借呗借款超时没到账", "达到什么条件才能使用花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.155, 0.098, 0.128, 0.18, 0.068, 0.086, 0.19, 0.11], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期后能否使用", "pos": ["我的花呗逾期两个月了现在还款还能用吗", "花呗逾期***填还掉后还能不能用"], "neg": ["我现在花呗不可以用吗", "可以有花呗充话费吗", "花呗黄金会员", "花呗支付尾款可以用购物津贴吗", "积分可以兑换花呗免费分期吗", "蚂蚁借呗可以每月还款吗", "蚂蚁花呗额度怎么提升"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.154, 0.17, 0.086, 0.062, 0.098, 0.192], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗到账时间", "pos": ["借呗什么时候会到账", "借呗蚂蚁几天到账"], "neg": ["借呗利润怎么降低", "我的花呗钱被扣了***我该怎么找回", "我想还了花呗能关闭功能不", "花呗有没有每个月必须花多少吗", "为什么我开不了花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.16, 0.072, 0.139, 0.081, 0.068], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期提前结清", "pos": ["花呗已经分期的可以一次还款吗", "花呗欠款***元,然后我分期付款,到下一个月***日前可以一次还清***元吗"], "neg": ["为什么我的收钱码开通了花呗支付别人付款却不能用花呗", "借呗按日借款怎么没了,只能按月了", "花呗分期提前还款收手续费么", "我花呗里的钱应该多出***", "花呗对应的保险"], "pos_scores": [1.0, 0.95], "neg_scores": [0.078, 0.179, 0.18, 0.095, 0.184], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗系统繁忙", "pos": ["花呗点进去总是提示系统繁忙", "打开花呗显示系统繁忙,请稍后再试"], "neg": ["我的蚂蚁借呗还款了嘛", "我怎么没法还花呗", "蚂蚁借呗随时还随时借可以吗", "我都退了,怎么还在花呗里", "信用卡花呗是哪个收钱码"], "pos_scores": [1.0, 0.95], "neg_scores": [0.113, 0.13, 0.132, 0.174, 0.182], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗随时还款", "pos": ["花呗是不是可以随时还", "花呗可不可以随时还款"], "neg": ["是的 花呗账单分期", "在我的里面找不到花呗", "花呗分期提前还要利息吗", "为什么我的花呗上没看到,没有关闭", "借呗到账时间变慢了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.138, 0.16, 0.127, 0.199, 0.107], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期后可用额度", "pos": ["花呗分期后 还能用吗", "分期付款这个月还能使用花呗"], "neg": ["没有找到花呗的设置", "取消花呗、信用卡收款服务", "我用花呗支付了,为什么对方还没有收到钱", "淘宝花呗信息截图在哪里", "我的蚂蚁借呗可以借***千的额度,如果是提前还款的话,利息是多少钱", "花呗退款如何拿到这比退款", "为什么我开通了花呗,就是用不了", "我在商场买东西不能使用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.111, 0.158, 0.194, 0.108, 0.051, 0.096, 0.181], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗不显示原因", "pos": ["我的蚂蚁借呗功能怎么没有", "为啥我的蚂蚁借呗不显示"], "neg": ["分期花呗提前还,有什么危害吗", "现在分期花呗,到时候想一次性还,可以一次性还吗", "蚂蚁借呗额度没用完,还可以继续借嘛", "本月花呗已还款,银行卡钱也扣了 ,为什么显示没还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.103, 0.143, 0.092, 0.111], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗停用原因", "pos": ["为什么花呗不能使用了", "花呗很久用不了了"], "neg": ["登陆花呗提示你的支付宝账号存在安全风险", "花呗已经还款,还显示有余额", "花呗退款后,还款额不变"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.198, 0.118], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期影响信用", "pos": ["花呗分期会对个人信用不好吗", "花呗分期对分数有影响吗"], "neg": ["为什么我花呗只能淘宝付款", "用了花呗没还款会怎样", "口碑商家开通花呗收款", "查询申请借呗进度", "冻结花呗什么时候能解开", "花呗还款怎么没到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.081, 0.176, 0.127, 0.183, 0.122, 0.066], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗只能分6期", "pos": ["为什么蚂蚁借呗里借多久的选项中只有六个月", "我得借呗为什么只能分六个月"], "neg": ["花呗不用咯。 他会不会扣钱的", "借呗额度能关闭不", "我的花呗额度被冻结了怎么办", "花呗请朋友代付,没有链接", "我有一个花呗账单被我删除了", "我是卖家,但是我没有开通蚂蚁花呗,现在有个买家买东西的时候说他不能使用蚂蚁花呗付款", "为什么我的花呗额度这么底", "为什么支付宝手机号跟花呗手机号不一样"], "pos_scores": [1.0, 0.95], "neg_scores": [0.103, 0.131, 0.115, 0.147, 0.144, 0.18, 0.193, 0.181], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗好友代付", "pos": ["朋友代付朋友可以用花呗帮我付款吗", "可以用花呗帮好友代付么"], "neg": ["蚂蚁借呗开通最少能借多少", "花呗开通后,不用会不会扣取费用", "支付宝上面的花呗可以每个月分期吗", "要是光开通花呗,不用有费用吗", "彻底关闭借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.078, 0.129, 0.066, 0.144, 0.133], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗二次分期", "pos": ["借呗借款分期后到期还可以再分期还款吗", "蚂蚁借呗能分期后能再次分期吗"], "neg": ["蚂蚁借呗提前还款有副", "附近哪些门店可以用蚂蚁花呗", "借呗***天免息券"], "pos_scores": [1.0, 0.95], "neg_scores": [0.073, 0.14, 0.187], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗临时额度申请", "pos": ["花呗怎么才能使用临时额度", "我需要花呗临时额度"], "neg": ["我后期有网商贷,为什么借呗不见了", "蚂蚁借呗还款有短信通知吗", "花呗如何点提前结清", "花呗冻结好几个月了,以后还能用么", "花呗能支付微信零钱吗", "短信提示借呗额度提升成功为什么我额度没变", "花呗额度有***为啥不能分期付款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.14, 0.159, 0.188, 0.167, 0.163, 0.061, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗重复扣款", "pos": ["银行卡支付了什么花呗再重复扣款", "我用花呗付款,银行扣了一次款,还花呗银行又扣了一次款,重复扣款了!怎么回事"], "neg": ["花呗分期买手机在哪查看物流", "我开通不了花呗也开通不了体验额度", "我刚申请关闭借呗,是否借呗关闭了,花呗也不能用了", "借呗怎么不给借"], "pos_scores": [1.0, 0.95], "neg_scores": [0.116, 0.098, 0.139, 0.145], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗临时额度分期", "pos": ["花呗临时额度用了能分期还吗", "花呗临时额度用掉可以分期吗"], "neg": ["花呗临时额度也可以分期还款么", "花呗可以他人代还么", "借呗手机号码换了收不到验证码怎么办", "蚂蚁借呗怎么每天那么多人申请", "花呗临时额度可以改几次", "我刚刚花呗点错了分期,可以取消吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.157, 0.128, 0.193, 0.118, 0.081], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "清除借呗记录", "pos": ["能不能把借呗记录给删除了", "蚂蚁借呗还完后记录怎样删除"], "neg": ["我的借呗开通过,无不良记录,为何停了", "蚂蚁借呗可推迟还款吗", "我的借呗为什么没有按天还款的选项", "花呗邦定的手机", "蚂蚁借呗和网商银行是一家吗", "我想要使用余额宝自动还款花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.067, 0.088, 0.164, 0.124, 0.131, 0.058], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "查询花呗额度", "pos": ["花呗额度怎么看", "花呗额度查不了"], "neg": ["花呗只能在网上消费吗", "没有在花呗", "花呗支付笔笔抽奖,抽到奖品在哪查看", "为什么分期额度那么低", "我的花呗账单想先还一半可以吗", "怎么提升花呗额度", "蚂蚁花呗怎么设置余额宝不自动还款", "实体店买手机能用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.129, 0.06, 0.068, 0.1, 0.119, 0.088, 0.089, 0.141], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "芝麻信用与借呗关系", "pos": ["芝麻信用跟借呗有关吗", "芝麻信用跟借呗开通有关嘛"], "neg": ["花呗开通了为什么还不能使用", "在淘宝可以使用花呗吗", "怎么还借呗", "有手续费吗借呗", "花呗为出账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.089, 0.196, 0.147, 0.156, 0.194], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "网商贷和借呗利息对比", "pos": ["网商贷的利息跟借呗利息差多少", "网上贷和借呗哪个利息高"], "neg": ["蚂蚁借呗手续费多钱", "怎么样才可以用借呗", "花呗逾期还款之后什么时候可以恢复使用", "借呗怎么不能分***期还咯", "能不能提高蚂蚁借呗的额度", "我好端端的花呗 不能使用了", "蚂蚁借呗一般是几号提额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.099, 0.101, 0.157, 0.1, 0.061, 0.131, 0.092], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款期限调整", "pos": ["怎么现在借呗只能***个月还款", "蚂蚁借呗怎么不能借***个月了"], "neg": ["为什么我花呗不能使用", "蚂蚁借呗客服", "花呗的钱我不花,是不是不用还钱", "我的蚂蚁借呗是不是已经升级成网商贷了", "刚用花呗支付的可以退款吗", "借呗还款期最多只有六个月"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.118, 0.176, 0.185, 0.166, 0.108], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账号转移", "pos": ["如何把花呗功能转到主账号", "花呗怎么能从一个账号转到另一个账号"], "neg": ["为什么我的蚂蚁借呗还是不能使用", "借呗还款不显示还款", "这个花呗我都不知道怎么弄", "我已经开通了花呗收款和信用卡 可是客户扫码付款确是不支持花呗付款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.157, 0.183, 0.059, 0.125], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期后恢复", "pos": ["我逾期还款已还还能用花呗吗", "花呗逾期后还款了,用额已经冻结,以后是否可以正常使用花呗"], "neg": ["蚂蚁借呗还款日期还没有到,就扣钱", "花呗从哪里可以手机充值", "上传身份证后多久才能给我开通借呗", "账户有花呗", "花呗可以 透支吗", "刚刚还的花呗怎么没到账", "为什么不见了花呗的二维码了", "花呗还完了多久可以用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.104, 0.156, 0.183, 0.119, 0.143, 0.086, 0.175, 0.123], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗分期还款", "pos": ["借呗可以分期还", "蚂蚁借呗分期还"], "neg": ["就帮我恢复借呗吧", "我开通花呗信用卡功能,是不是有二维码给我", "这个月的***号用花呗买 多久还", "我的借呗今天是还款日,刚才***点还上的款,算逾期吗", "怎样查看花呗总还额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.146, 0.195, 0.194, 0.06], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗重复还款提醒", "pos": ["我这个月还完花呗了,怎么还要还", "花呗钱我还了,怎么又出来了让我还"], "neg": ["蚂蚁借呗不用额度是变吗", "花呗多少分可以免押金", "信用积分***多可惜申请蚂蚁借呗吗", "我的花呗里,为什么少了***元钱", "开通花呗需要多少额度", "愈期还过要多长时间才能使用花呗", "借呗每月的还款日可以改么", "我的蚂蚁花呗在哪儿"], "pos_scores": [1.0, 0.95], "neg_scores": [0.075, 0.128, 0.171, 0.194, 0.055, 0.174, 0.132, 0.177], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗结清证明", "pos": ["蚂蚁借呗还请证明", "个人借呗 结清证明"], "neg": ["已经还款的花呗账单明细怎么查", "花呗标识图案是什么样", "借呗当月还不上咋办", "借呗能邀请朋友开通吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.057, 0.079, 0.091, 0.175], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收款码查看", "pos": ["查看花呗收款码", "我的花呗收钱码"], "neg": ["蚂蚁借呗什么能评估完", "可以预期还花呗么", "支付宝花呗只能有一个么", "支付宝借呗可以推后一天还款吗", "已申请退款退货成功了,退货后退的款是还退到花呗吗", "我要关闭蚂蚁庄园,不是关闭蚂蚁花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.172, 0.07, 0.07, 0.18, 0.173, 0.056], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗审核未通过", "pos": ["蚂蚁借呗审核没通过,好何重新提交", "借呗申诉审核未通过"], "neg": ["你花呗是什么意思", "怎么来取消花呗", "银行卡还借呗有限额吗", "为什么借呗申请失败"], "pos_scores": [1.0, 0.95], "neg_scores": [0.057, 0.11, 0.056, 0.051], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期提前还款手续费", "pos": ["花呗账单分期提起还款还有手续费么", "花呗分期以后,如果提前还款需要手续费吗"], "neg": ["花呗能提前还下一期吗", "花呗还款可以推迟麽", "用花呗怎么开发票", "花呗分期想到能提前还款么", "蚂蚁借呗支持邮政储蓄卡吗", "为什么我的码蚁借呗不能用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.168, 0.196, 0.097, 0.177, 0.055, 0.181], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "没有花呗原因", "pos": ["为什么,我的没有花呗", "我为什么没有花呗呀"], "neg": ["些条件满足借呗", "我的借呗额度有点不够用,可以提额吗", "蚂蚁借呗可以用网商银行帐户还款", "花呗额度是高多久", "为啥不开通借呗", "花呗支付的时候会有手机短信通知吗", "才申请的口碑码怎么不支持花呗支付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.159, 0.072, 0.187, 0.153, 0.123, 0.094, 0.083], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗逾期几天不影响芝麻信用分", "pos": ["借呗逾期***天,会影响多少信用分", "借呗逾期几天不影响芝麻信用分"], "neg": ["什么额度能用借呗", "花呗的单笔消费最多是多少", "打算花呗分期买个手机 拜托快些提升额度", "花呗上提示可以用,蚂蚁借呗可用***为什么不能借", "如何删除借呗借还款记录", "为什么我不能使用花呗付电费", "我想知道花呗为什么用不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.172, 0.18, 0.108, 0.136, 0.168, 0.125, 0.101], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款去向 提前还款后", "pos": ["我用花呗买东西、买错了、退款回花呗我查看花呗没钱、因为我一用花呗买东西了就立马还花呗的钱了、那退回来的款在花呗也没有、那退去哪了", "我是用的花呗,然后在淘宝上面买了东西,我就用银行卡提前还款了,之后我要把淘宝的那样东西退款了,但是退款的钱去了哪里,没有收到银行卡,上面也花呗,上面也没有看到"], "neg": ["我订了份外卖用花呗付的款还用给卖家钱或者还钱吗", "开通花呗,芝麻分要多少", "我已经还款两个月为什么没有恢复借呗额度", "我这借呗都有一阵显示异常了 啥意思", "蚂蚁借呗还了,啥时候能把逾期关拉"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.092, 0.072, 0.16, 0.063], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗负分影响", "pos": ["我的花呗是负的有影响吗", "花呗负分有影响吗"], "neg": ["我没有任何不良记录怎么花呗开通不了", "网商贷切回借呗操作", "花呗以还款为什么还有账单", "我满足借呗开通的要求嘛", "花呗临时额度可以分期嘛"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.129, 0.153, 0.081, 0.186], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "口碑不能用花呗原因", "pos": ["口碑支付是不是不能用花呗", "我的口碑为什么不能用花呗"], "neg": ["我现在开通花呗什么时候还款", "花呗最低额度设置", "钱怎么退到花呗了什么时候到卡上", "花呗换起安全核身验证失败", "蚂蚁借呗电脑怎么使用", "花呗怎么删除账单"], "pos_scores": [1.0, 0.95], "neg_scores": [0.176, 0.057, 0.178, 0.194, 0.149, 0.087], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "无银行卡如何还花呗", "pos": ["没有银行卡可以还花呗嘛", "支付宝没有绑定银行卡,能不能还花呗"], "neg": ["怎么不显示蚂蚁借呗", "花呗还款是直接从绑定银行卡上扣吗", "商家花呗收款有手续费么", "借呗取消授权", "花呗付款是不是都是货到付款", "跟借呗有什么不同", "我花呗是***我想先还***"], "pos_scores": [1.0, 0.95], "neg_scores": [0.08, 0.198, 0.191, 0.177, 0.11, 0.098, 0.148], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款去向 已还款", "pos": ["我已经还完款了,可是这是淘宝退款自动退到花呗里的,可是在他退之前我已经先还完了这个款,他这个钱退去哪里了", "蚂蚁花呗退款的钱去哪里了"], "neg": ["花呗分期还款费用怎么计算", "买家用花呗分期 在我店里买了***元的", "蚂蚁借呗分期还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.15, 0.095, 0.183], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗无法使用原因", "pos": ["我蚂蚁花呗怎么不能使用", "我的花呗怎么花不了了"], "neg": ["双***花呗***期免息", "借呗可以拖一天还款吗", "借呗可以***期还款吗", "花呗额度成负数有啥影响吗", "借呗分期最长只有***个月吗", "蚂蚁借呗***期还款怎么不能选了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.118, 0.064, 0.109, 0.179, 0.075, 0.063], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期还款设置", "pos": ["如果将花呗用完是否可以分期还款", "花呗分期还款如何设置"], "neg": ["手机丢了花呗还能用吗", "这个已经退款了,为什么还没把钱返回花呗", "使用借呗对以后房贷有影响吗", "借呗长的太慢了", "为什么我***号借呗还款后,一直没有恢复额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.184, 0.177, 0.182, 0.163, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无法借款原因", "pos": ["怎么在借呗借不了钱了", "借呗里面还有钱\\怎么借不了"], "neg": ["蚂蚁借呗的分期手续费", "蚂蚁借呗输错三次密码", "借呗分期,提前还了是否会影响再次借款", "按时还款为什么借呗额度不恢复了", "蚂蚁借呗可以延长分期", "花呗分期后提前还款,要收取手续费吗", "花呗总额度是不是可以透支的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.052, 0.094, 0.068, 0.13, 0.075, 0.053, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "支付方式无花呗选项", "pos": ["付款的地方没有花呗支付", "每次支付方式里直接没有了花呗这一个"], "neg": ["我通过支付宝购物时每次都从银行卡扣过款项了,但是花呗还让还款", "花呗分期付款手机 利息是怎么算的", "我用花呗交摩拜单车押金,退押金后没有到账", "我怎么删除用花呗", "借呗还款日,没提醒吗", "使用蚂蚁花呗支付时会有短信通知吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.153, 0.161, 0.14, 0.115, 0.183, 0.053], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗余额宝自动还款设置", "pos": ["花呗自动还款余额宝代扣功能怎么开", "怎么设置花呗从余额宝中自动还款"], "neg": ["取消花呗支付", "花呗可以换最低吗", "花呗额度是多少那"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.152, 0.129], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款期修改", "pos": ["蚂蚁借呗换还款时间", "蚂蚁借呗分期还款,还款期可以改吗"], "neg": ["花呗关闭后,怎么另外一个还是不能开通", "花呗,如何为好友额度加分", "借呗还款日是", "关了借呗能再开通吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.065, 0.055, 0.191, 0.093], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗评估时间", "pos": ["系统评估借呗需要几天", "系统多久评估借呗可以使用"], "neg": ["蚂蚁借呗可以延长时间还款吗", "原先蚂蚁借呗都没有手续费", "蚂蚁借呗提前还款先还一部分可以吗", "花呗分期可以多个么", "我的花呗***分,为什么没有借呗入口"], "pos_scores": [1.0, 0.95], "neg_scores": [0.132, 0.137, 0.158, 0.121, 0.143], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗额度突然降低原因", "pos": ["为什么我的蚂蚁借呗额度突然降了那么多", "为什么我的借呗降低额度了"], "neg": ["这个月经济紧张推迟***天还花呗行不行", "怎么提升花呗的额度", "花呗还的时候有利息么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.078, 0.103, 0.057], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "支付宝人工客服联系", "pos": ["找客服人员", "请联系人工客服为我解答"], "neg": ["用花呗分期买的东西,结果不合适退款了,手续费怎么没退", "蚂蚁借呗如何邀请好友开通", "花呗更换手机号了怎么换新的手机号", "我花呗是以前的号码,我想换现在的,怎么换", "花呗可以在店里用吗", "我的借呗为什么只能选择六个月和十二个月"], "pos_scores": [1.0, 0.95], "neg_scores": [0.163, 0.07, 0.124, 0.108, 0.123, 0.117], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗开通方法", "pos": ["花呗怎么样开通", "我的账户如何开通花呗"], "neg": ["我的花呗怎么才能开通", "蚂蚁借呗属于贷款吗,对于以后买房首贷会有影响吗", "怎样才是蚂蚁借呗关闭", "能不能申请借呗最低还款", "怎么看是不是使用花呗支付", "我说的借呗分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.119, 0.134, 0.19, 0.11, 0.071, 0.193], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗生活圈加入方式", "pos": ["蚂蚁借呗生活圈怎么关注", "如何加入借呗生活圈"], "neg": ["我想蚂蚁借呗提升额度怎么办", "我用花呗已经付款了。怎么我在淘宝上看了。怎么显示没有付款", "网商贷额度和借呗额度哪个更高"], "pos_scores": [1.0, 0.95], "neg_scores": [0.055, 0.193, 0.09], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗提前还款后无额度", "pos": ["蚂蚁借呗为什么提前还款了之后没有额度", "为什么提前还款借呗,没有额度了"], "neg": ["我的花呗有什么微贷商品逾期,能不能查出来", "借呗现在还可以提现吗", "花呗额度为何会下降", "为什么我的借呗额度没有了", "花呗还了最低额度了会有影响吗", "花呗临时额度完了会回到原有额度吗", "花呗分期最多是多少钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.166, 0.148, 0.185, 0.141, 0.112, 0.056, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "商家支持花呗支付吗", "pos": ["你们家支持花呗支付么", "花呗支付可以吗"], "neg": ["要用多久支付宝才能用花呗", "蚂蚁借呗会影响信誉吗", "还花呗为什么发验证码还发以前的号码", "我怎么在电脑上进行花呗的还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.165, 0.102, 0.121, 0.155], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭花呗方法", "pos": ["花呗如关闭", "花呗额度不长、想关闭"], "neg": ["关于花呗最低还款的", "花呗自动还款红功能", "为什么我的借呗借款方式跟别人不一样", "花呗一次性还款怎么操作", "花呗分期中,剩下的额度还能用吗", "信用卡与花呗有什么不同"], "pos_scores": [1.0, 0.95], "neg_scores": [0.15, 0.156, 0.151, 0.166, 0.085, 0.071], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗长期不用会降额吗", "pos": ["蚂蚁借呗不用额度是变吗", "蚂蚁借呗长期不使用是否会降低额度"], "neg": ["为什么我美团外卖不能花呗", "借呗也没有", "***号之前花呗会自动扣款吗", "代付可以用花呗吗", "换个手机借呗怎么没有了", "借呗如何能提升额度", "花呗怎么不能在美团上买东西了", "电费怎么不能用花呗交"], "pos_scores": [1.0, 0.95], "neg_scores": [0.073, 0.136, 0.11, 0.123, 0.14, 0.055, 0.2, 0.063], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收款功能异常", "pos": ["为什么我现在花呗收款用不了", "花呗收不了钱怎么了"], "neg": ["花呗的优惠券有吗", "借呗怎么解除等额还款", "还什么花呗还款", "蚂蚁借呗不可以分期还款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.152, 0.063, 0.179, 0.112], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗延期还款规则", "pos": ["花呗可以迟一个星期还吗", "花呗可以延期几天还吗"], "neg": ["花呗临时额度也能分期账单吗", "花呗分期后,额度是多少", "买手机用花呗分期付款", "我为什么没有蚂蚁借呗信用额度", "我用花呗买的火车票,然后我还上了,改签火车票退回的钱又给我退到额度一部分"], "pos_scores": [1.0, 0.95], "neg_scores": [0.146, 0.195, 0.056, 0.075, 0.098], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗开通条件不满足原因", "pos": ["我哪里不满足借呗条件了", "蚂蚁借呗不满足条件的原因"], "neg": ["帮我看一下,我这个花呗是什么时候用的吗", "怎么样限额借呗", "购物津贴可以用来花呗分期吗", "系统一般好久评估一下借呗", "花呗可以让别人帮你还吗", "花呗分期后能提前还清么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.075, 0.086, 0.112, 0.097, 0.1, 0.14], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款去向 已还款", "pos": ["我的这笔款是已经还了的 为什么不退现金而是退给花呗", "我上次买袜子的钱自己还款了 但是退钱的钱为什么又到花呗了"], "neg": ["交易退款至花呗为什么需要还款的金额没变", "花呗开通不能在淘宝使用", "借呗还款日", "我已经花呗分期了,还款日是***月***号,那我这个月***号提前还,还要手续费吗", "如何采用花呗二维码收款", "网商跟花呗有什么关系", "我这个账号不是有欠借呗吗还是贷款", "我花呗在淘宝买的衣服,然后我提前还上了花呗,淘宝退回的款到哪里"], "pos_scores": [1.0, 0.95], "neg_scores": [0.148, 0.097, 0.12, 0.07, 0.193, 0.066, 0.188, 0.182], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗提前还款方法", "pos": ["花呗钱可以提前还么", "花呗付款了,可以提前还款吗"], "neg": ["在淘宝上买东西可以不用花呗吗", "花呗瑜期怎么办", "商品可以用花呗支付", "我的借呗为什么突然没有额度", "如果退出花呗", "我得支付宝,花呗是什么意思"], "pos_scores": [1.0, 0.95], "neg_scores": [0.113, 0.15, 0.121, 0.147, 0.111, 0.076], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还清后显示欠款", "pos": ["花呗的钱我全部还完了滴按,怎么显示还欠款", "我的花呗明明已经还清了。为什么显示有未还完的"], "neg": ["花呗款怎么办", "口碑怎么用花呗支付不了了", "我买东西用花呗付款,怎么还钱", "我的信用很高 为什么不能使用借呗", "花呗临时额度会到期", "蚂蚁借呗分期得能在分期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.182, 0.137, 0.123, 0.077, 0.17], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗上传身份证要求", "pos": ["?借呗上传身份证", "开通蚂蚁借呗可以不用上传身份证吗"], "neg": ["花呗还款实际到账不符", "有花呗怎么才能尽快有借呗", "退款回复花呗额度怎么做", "花呗分期会不会影响信用问题", "我支付宝就是现在这个支付宝账号登入但是花呗显示不是我登入支付宝的手机号怎么办", "花呗最低还款后剩下的金额可以分期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.185, 0.149, 0.147, 0.058, 0.135], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通花呗不用会收费吗", "pos": ["花呗如果我开通了不用会有费用吗", "开能花呗不用要收费的吗"], "neg": ["我的蚂蚁花呗无法正常使用,为什么是怎么回事", "花呗欠款可以用微信还吗", "怎样设置付款方式为花呗", "蚂蚁花呗逾期被停止使用 如何再次使用", "花呗可以在外面店里购物吗", "取消的订单花呗还要还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.166, 0.073, 0.175, 0.182, 0.194], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗额度无法使用原因", "pos": ["我的花呗额度为什么不可以用了", "为啥我现在用不了蚂蚁花呗"], "neg": ["借呗是不是临时额度的", "查询我的花呗注册手机号是", "系统多久评估一次借呗权限", "花呗分期后还能不能继续分期", "分期提前还款花呗", "蚂蚁借呗降了***额度", "可以把花呗权限给自己的另一个支付宝账号嘛"], "pos_scores": [1.0, 0.95], "neg_scores": [0.069, 0.161, 0.129, 0.119, 0.144, 0.116, 0.068], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通花呗收费吗", "pos": ["开通花呗需不需要费用", "开通花呗会收取费用吗"], "neg": ["企业支付宝如何开通信用卡与花呗支付功能", "支付宝的号码换过来了,花呗的号码还是没有换过来", "最高借呗多少", "蚂蚁花呗怎么还了两次钱", "借呗钱能转给别人吗", "如何寻找带有花呗标识的商品", "在花呗里没有看到退回来的钱", "为什么无缘无故就开通花呗,我要取消怎样操作"], "pos_scores": [1.0, 0.95], "neg_scores": [0.072, 0.165, 0.186, 0.199, 0.117, 0.16, 0.097, 0.119], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期能用购物津贴吗", "pos": ["购物津贴能用在花呗分期上吗", "使用花呗分期,能否享受购物津贴"], "neg": ["我的花呗什么时候才能用了", "借呗没有还款功能", "我开通花呗收款现在怎么不能用了", "我的花呗写的是***号可以***号还款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.087, 0.113, 0.088, 0.052], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款状态查询", "pos": ["花呗的钱是否还清", "我的花呗钱还完了吗"], "neg": ["我消费的花呗没有那么多钱,为什么我的还款还要多于这个钱", "可以用支付宝吗", "订单取消后花呗什么时候恢复", "花呗对芝麻信用有多大影响", "借呗只有***个月和***个月的期限", "支付宝花呗营业照主体和本账户名称不一致"], "pos_scores": [1.0, 0.95], "neg_scores": [0.101, 0.091, 0.056, 0.086, 0.089, 0.149], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "淘宝支付宝手机号不一致能用花呗吗", "pos": ["淘宝手机号和支付宝手机号码不一致,不能用花呗吗", "支付宝支付宝绑定的手机号和淘宝绑定的手机号不一样可以用花呗吗"], "neg": ["怎么查用花呗买的物品", "我明明付钱了,为什么确认收货后花呗上还要多次付款", "共享单车退钱了可是支付宝花呗没有收到", "在花呗怎么绑定银行卡"], "pos_scores": [1.0, 0.95], "neg_scores": [0.109, 0.122, 0.179, 0.11], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗利息计算", "pos": ["***元花呗利息多少", "我上个月扣了我多少花呗利息"], "neg": ["朋友帮还花呗有什么好处", "蚂蚁借呗最后还款日是什么时候", "花呗是分期的吗", "我要怎样才能用借呗", "蚂蚁借呗用钱什么时候到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.134, 0.168, 0.072, 0.148, 0.11], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "尽情花呗含义", "pos": ["显示尽情花呗什么意思", "尽情花呗怎么回事"], "neg": ["蚂蚁借呗还款日设置", "花呗逾期不还会影响黑户吗", "能不能注销支付宝", "花呗返利活动", "为什么这期花呗不能分期", "密付不能用花呗么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.179, 0.063, 0.066, 0.144, 0.145, 0.1], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "密付支持花呗吗", "pos": ["买东西密付可以用花呗吗", "密付能使用花呗么"], "neg": ["花呗能在淘宝里用吗", "我的花呗还款日期能延长到***号吗", "花呗分期第一个月需要付钱吗", "我的借呗额度什么时候才能恢复", "借呗怎么到期了不扣款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.081, 0.181, 0.054, 0.115, 0.16], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗利息差异原因", "pos": ["蚂蚁借呗利息怎么别人低我的高", "怎么蚂蚁借呗我的利息和别人的不同,万分之六,别人的万分之***"], "neg": ["比如我用花呗分期买了东西,然后我应该怎么还款", "花呗能不能提前还清", "我是在还花呗的时候扣了***次款", "我的花呗为什么还有账单?我从未贷过款", "加好友可以发花呗吗", "我不是网商", "支付宝里找不到花呗――我的账单"], "pos_scores": [1.0, 0.95], "neg_scores": [0.166, 0.11, 0.057, 0.12, 0.19, 0.124, 0.198], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗提前还款利息", "pos": ["提前还款的话利息会少么", "花呗 提前还款要利息吗"], "neg": ["怎么看是否支持花呗", "怎么完善借呗的信息", "借呗信用额度不够了", "花呗借呗开通不了是怎么回事"], "pos_scores": [1.0, 0.95], "neg_scores": [0.159, 0.178, 0.052, 0.136], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "网商贷和借呗区别", "pos": ["网商贷好借呗有什么区别", "我想知道蚂蚁借呗和网商贷的区别"], "neg": ["花呗短信怎样用", "怎么提前还已分期花呗", "我的借呗未还成功", "蚂蚁借呗没有三个月的借款期限", "花呗可以用余额还吗", "为什么借呗无信用额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.104, 0.17, 0.055, 0.175, 0.15, 0.171], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款周期", "pos": ["蚂蚁借呗多久还一次", "借呗最长多久要还款"], "neg": ["蚂蚁借呗还款账单", "花呗分期期间提升可用额度吗", "信用额度怎样才能蚂蚁借呗的条件"], "pos_scores": [1.0, 0.95], "neg_scores": [0.125, 0.128, 0.055], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗开通时间", "pos": ["我的什么时候才能用花呗", "我的蚂蚁花呗什么时候能给开通,我现在一直信誉良好"], "neg": ["怎么样查询自己有没有开通花呗收款功能", "上个月花呗买的东西已经扣款为什么这个月又扣一次", "花呗可以换红包不", "如何设置花呗绑定银行卡", "我花呗付款,退货的话钱是不是也退到花呗了", "我的花呗为什么不能用?说有逾期的,但是找不到"], "pos_scores": [1.0, 0.95], "neg_scores": [0.135, 0.16, 0.084, 0.164, 0.064, 0.152], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "支付宝无花呗选项", "pos": ["为什么我没有花呗可以用", "支付宝里怎么没有花呗"], "neg": ["为什么这个月借呗不能借", "借呗一次性还款后能不能再借钱", "我花呗分三期还款的", "花呗充话费怎么充不了啦"], "pos_scores": [1.0, 0.95], "neg_scores": [0.167, 0.154, 0.084, 0.169], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无额度原因", "pos": ["借吧怎么没有额度啦", "借呗怎么没有咯"], "neg": ["为什么我用花呗一直提示满***用花呗分期", "花呗冻结啥时候能开通", "花呗的最大额度是多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.128, 0.119, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗专享红包使用条件", "pos": ["开通了花呗,专享红包能用吗", "开通花呗是不是就可以申请专享红包门店"], "neg": ["总是显示我的另一个账户,怎么能把花呗转到我这个账户上", "花呗我为什么绑定不了银行卡", "为什么在淘宝购物不能用花呗", "我看了我的花呗没有这个", "中专生如何完善花呗的学历资料", "借呗有最低还款额,没有", "花呗贷款一百还多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.074, 0.12, 0.182, 0.182, 0.159, 0.104, 0.08], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗未还完能否再借", "pos": ["借呗是要全部还清才可以再次借款吗", "借呗分期未还完可以再借吗"], "neg": ["我啥时候能开通蚂蚁花呗", "借呗额度全部还清可以恢复吗", "在共享单车里的押金可以用花呗吗", "我想请你们把我花呗开通", "刚开通花呗有多少透支额", "我的花呗收款吗,不能收款了怎么办", "为什么我的借呗收回去了", "花呗还款 组合付款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.069, 0.084, 0.113, 0.062, 0.158, 0.136, 0.099, 0.176], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗通讯录权限设置", "pos": ["花呗允许读取通讯录怎么搞", "花呗怎么设置通讯录"], "neg": ["花呗限制了虚拟支付是什么意思", "可以使用花呗里的余额吗", "商品价格大于花呗额度但能分期", "为什么退款了之后花呗还款还是没有减"], "pos_scores": [1.0, 0.95], "neg_scores": [0.175, 0.195, 0.094, 0.166], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗购买火车票限制", "pos": ["花呗,能买火车票吗", "支付宝花呗不能买火车票对吗"], "neg": ["买花呗保险", "借呗十二点以后怎么还款", "花呗临时额度啥时还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.113, 0.171, 0.198], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无法借款原因", "pos": ["为什么借呗不能借钱了", "为什么我的蚂蚁借呗不能用"], "neg": ["开通了花呗就成了蚂蚁会员吗", "蚂蚁花呗可以提前还款有手续费吗", "商家支持花呗付,但是付不了", "什么时间可以使用花呗", "***月***号用的花呗,是否可以到下个月***月***号还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.077, 0.063, 0.193, 0.087], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还清后无法使用", "pos": ["花呗我已经还了,为什么现在无法使用", "我的花呗还完了怎么不能用"], "neg": ["花呗完善学历资料怎么完成", "我蚂蚁借呗今天还款没有", "花呗免单活动怎么报名", "怎么填写申请蚂蚁借呗", "从淘宝买东西用花呗只用几十块钱可以吗", "蚂蚁借呗进不去", "花呗账单分期免息券在哪儿领"], "pos_scores": [1.0, 0.95], "neg_scores": [0.058, 0.124, 0.189, 0.131, 0.098, 0.053, 0.112], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗不支持电费缴费", "pos": ["交电费不能扣花呗", "支付宝里的花呗怎不支持交电费了"], "neg": ["花呗在还款日的几点自动还款", "为什么我改了手机号码还是不可以使用花呗", "刚才没有输入付款码花呗就付款了", "可不可以更改花呗还款日", "淘宝花呗支付宝花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.076, 0.106, 0.071, 0.187, 0.119], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退票未到账", "pos": ["用花呗买的汽车票退票为什么钱还没到账", "退款钱到花呗,查询没到账"], "neg": ["花呗有额度已选了花呗付款,为什么还是在我银行卡里扣", "借呗分期还可以分期吗", "花呗可以定酒店麽", "花呗开通了,不使用会产生费用么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.197, 0.091, 0.151, 0.121], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账单查询", "pos": ["在哪看用了几次花呗", "查看花呗所有账单在哪"], "neg": ["收钱码在哪儿?我开通了花呗收款,但是找不到收钱吗", "我前面忘了花呗还款有几条 我的花呗被冻结了", "蚂蚁借呗可以借***年", "我九月这期蚂蚁借呗是否还款成功", "大概多久评估一次借呗条件", "借呗没有了成了网商贷", "我刚才在淘宝买东西了,是用花呗付的钱怎么办", "***元已经退了,花呗怎么没有减掉"], "pos_scores": [1.0, 0.95], "neg_scores": [0.192, 0.058, 0.115, 0.171, 0.174, 0.082, 0.187, 0.1], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期一天影响", "pos": ["花呗只逾期一天还款后可以用嘛", "我用的花呗我晚***天还可以吗"], "neg": ["什么交易能用花呗", "借呗和花呗不能同时开通吗", "花呗的还快款金额不对", "花呗的钱还不进"], "pos_scores": [1.0, 0.95], "neg_scores": [0.135, 0.067, 0.148, 0.196], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗更换手机号", "pos": ["花呗能不能换手机号码呀", "我花呗如何更改手机号"], "neg": ["我没有逾期怎么花呗就不能用了", "另一个号关闭花呗了,这个号可以开了吗", "借呗怎么再用", "同一个身份证花呗在另一个账户上面,转到这一个上面怎么转", "花呗在哪里查询"], "pos_scores": [1.0, 0.95], "neg_scores": [0.099, 0.192, 0.111, 0.154, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗有额度借不出", "pos": ["我借呗还有额度,为什么借不出来了", "借呗上有额度,但是借不出钱"], "neg": ["花呗就能使用一个淘宝账号", "主动申请花呗审核时间", "花呗手机直接还款", "可以用蚂蚁花呗购买商品吗", "花呗还最低,影响个人信用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.055, 0.089, 0.065, 0.085, 0.174], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款记录查询", "pos": ["蚂蚁借呗今天的还款记录", "我是说借呗还款记录"], "neg": ["开通了花呗,不花钱是不是也不算", "为什么我的没有蚂蚁借呗功能", "花呗逾期未还有利息吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.096, 0.152, 0.192], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款处理", "pos": ["退到花呗里的钱怎么办", "钱款退到我的花呗里"], "neg": ["花呗还款后才能恢复可用额度", "花呗扣钱是在哪里扣", "每个月都可以继续花呗分期吗", "为什么我的账号不能用花呗分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.186, 0.056, 0.155, 0.179], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗实体店使用范围", "pos": ["花呗可以在实体店使用么", "花呗可以用于超市吗"], "neg": ["花呗刷脸为什么老是失败", "为什么花呗每个月减***元", "为什么我在借呗哪里借的钱一天都没到帐", "信用卡花呗收款怎么开通不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.172, 0.052, 0.079, 0.081], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗额度查询", "pos": ["我有多少的花呗额度", "开通花呗我的额度是多少"], "neg": ["***.***显示两笔退款,但我的花呗账单***月和***月只有支出金额,没有这两笔收益", "还有一个开通了网商贷,开通不了借呗", "蚂蚁花呗被冻结了怎么解冻", "花呗能让别人先还吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.171, 0.162, 0.173, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗线下支付限额", "pos": ["花呗线下支付能用多少额度", "线下商家花呗数额"], "neg": ["用花呗买东西不收取利息和手续费吗", "花呗***号自动还款扣利息吗", "蚂蚁借呗提前还款后可以在提现吗", "信贷产品要确定永久关闭关闭关闭", "我花呗有***额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.146, 0.14, 0.083, 0.095, 0.109], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "注销花呗方法", "pos": ["怎么消花呗卡", "我要怎么注销蚂蚁花呗"], "neg": ["花呗临时额度消费了可以分期吗", "为什么我的花呗要利息", "借呗还款可以自动在支付宝余额扣钱吗", "花呗提升额度需要学历我没有学历怎么办", "花呗额度负数还能买吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.163, 0.192, 0.125, 0.173, 0.067], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "开通花呗不使用影响", "pos": ["开通花呗不使用行吗", "开通花呗都不用会怎么样"], "neg": ["为什么借呗只能最高借六个月了", "为什么我花呗刷脸刷不过去", "花呗还上次分期的怎么还", "花呗分期付款的苹果手机"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.057, 0.096, 0.098], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗关联账号更换", "pos": ["我能把我另一个关联账号花呗转到现在用的账号吗", "换花呗关联账号"], "neg": ["花呗支付加油卡充值", "身份证过期了咋办借呗", "已经形成的借呗账单可以分期吗", "借呗放款了没到银行卡", "花呗点进去总是提示系统繁忙", "我想退出花呗,不想再使用", "花呗短信提醒免费吗", "花呗可以使用的途径"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.083, 0.13, 0.054, 0.139, 0.196, 0.071, 0.184], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗缴费功能异常", "pos": ["以前能用花呗缴费着嘛", "前几天还能用花呗充值"], "neg": ["怎么查看花呗分期还的开销", "差了多少钱,我给她还了,把她的蚂蚁花呗给关了吧", "花呗是不是全部还好就能使用了", "花呗分期的费率是多少", "网商贷和蚂蚁借呗额度共享吗", "二维码申请花呗怎么申请"], "pos_scores": [1.0, 0.95], "neg_scores": [0.068, 0.16, 0.09, 0.162, 0.058, 0.194], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗突然无法使用", "pos": ["我现在为什么用不了蚂蚁花呗", "怎么不能用花呗了"], "neg": ["借呗还款日最晚当天什么时间", "这个花呗里面借的钱怎么还呀", "也就是花呗账单多于实际消费,而且我是有明细的", "借呗提额为什么要输入密码", "借呗为什么不借了", "花呗分期后提前还款,要收取手续费吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.148, 0.166, 0.112, 0.149, 0.156, 0.081], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还清后能否再借", "pos": ["借呗一次还完本金下次还能借吗", "借呗还完后还能继续借款不"], "neg": ["花呗手续费怎么算", "花呗能在闲鱼上使用ok", "我都不用花呗买东西。怎么会有费用叫还款", "为什么我的花呗账户冻结了", "花呗钱能不还吗", "为什么我已经退款了 还款还是多一笔花呗? 退款在那里去了", "花呗冻结了。怎么开通", "我这个花呗里面的钱怎么扫不出去是怎么回事"], "pos_scores": [1.0, 0.95], "neg_scores": [0.116, 0.094, 0.123, 0.159, 0.161, 0.117, 0.174, 0.074], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗最新活动", "pos": ["花呗近期活动", "花呗有新活动吗"], "neg": ["淘宝退款还到花呗我什么没看到钱", "怎么不能用花呗充话费了", "花呗借款,还款日期是固定的么", "花呗分期付款的没有积分", "我借呗是每月等额的,现在可以改成先息后款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.085, 0.149, 0.141, 0.104, 0.15], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "收款码不支持花呗", "pos": ["扫手机上的收钱码怎么不支持花呗收钱", "为什么别人扫我的收款码不支持花呗付款"], "neg": ["花呗pass授权", "花呗显示上期账单已还清,但还有短信提示欠费,怎么回事", "我要还花呗什么还", "借呗最高额度能提到多少钱", "借呗分期还款能先还利息么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.196, 0.15, 0.163, 0.067, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗提现到银行卡", "pos": ["我要从花呗里提到卡里面", "我要花呗滴钱转到银行卡"], "neg": ["我现在应该怎么还款,我花呗是花着,我号欠款的,我都不用了,怎么还款", "借呗说已经到银行卡了,可是银行卡没收到", "可以用密付支付花呗账单吗", "借呗每个月几号还", "怎样修改花呗密码", "花呗自己充的钱怎么提出", "支付宝里如何切换成花呗支付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.141, 0.153, 0.086, 0.152, 0.095, 0.061], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗重新开通时间", "pos": ["我的花呗什么时候可以再开通回来", "我什么时候可以重新开启花呗"], "neg": ["蚂蚁借呗的借款记录怎么删", "系统平的花呗一万我可以定五千吗", "我的怎么找不到花呗在哪", "花呗逾期不能超过今天", "为什么我点花呗,进去显示系统繁忙"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.083, 0.139, 0.192, 0.18], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗退款未到账", "pos": ["说是退回了花呗但没看到钱进了哪儿了", "退到花呗,没看到进帐"], "neg": ["花呗拖一天还款", "淘宝花呗如何使用", "借呗还了怎么还有负面记录"], "pos_scores": [1.0, 0.95], "neg_scores": [0.073, 0.125, 0.195], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期还款后无法使用", "pos": ["花呗逾期还款了为啥还不能用", "花呗逾期还款以后怎么还不能用"], "neg": ["蚂蚁借呗额度提高", "花呗可不可以用美团", "要开通花呗线下收款怎样开通", "花呗最多能透支多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.19, 0.107, 0.196, 0.155], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还清后关闭方法", "pos": ["花呗全部还了 能不能关闭", "还款以后能不能关闭花呗"], "neg": ["花呗什么时候能分期", "花呗那有条信息说,当前剩余款***应还,是什么意思", "什么时候花呗能提前还款", "花呗用了就要还吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.077, 0.161, 0.165, 0.056], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗可用额度查询", "pos": ["花呗我能接多少钱", "我花呗现在多少钱"], "neg": ["没开通,怎么用了花呗的钱", "为什么我的花呗页面没有账单分期", "花呗还款可以不分期么", "商家蚂蚁花呗怎么收手续费", "用花呗按时还款的话要不要利息"], "pos_scores": [1.0, 0.95], "neg_scores": [0.176, 0.172, 0.099, 0.157, 0.191], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗强制开通方法", "pos": ["可以帮我开启花呗没", "帮我打开花呗好嘛"], "neg": ["杂样申通蚂蚁借呗", "支付宝已经绑定银行卡,为什么花呗还要绑定银行卡", "花呗还款日到了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.121, 0.111, 0.066], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "微信还花呗", "pos": ["花呗还款可以在微信还吗", "我想知道微信可以还支付宝上的花呗欠款吗"], "neg": ["这是退了的订单,花呗还有帐单要还", "花呗逾期,利息怎么计算", "花呗分期付款的利息多少", "花呗能买飞机票", "我不想用花呗了,想还款", "借呗银行卡体现放款多久到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.079, 0.183, 0.183, 0.188, 0.164, 0.134], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款无短信验证码", "pos": ["蚂蚁借呗还钱要手机验证码,我手机号码没有再用,所以收不到信息,应该怎么还", "借呗自助还款收不到短信验证码怎么办"], "neg": ["美团花呗用不了", "这个月,花呗还能不能用", "花呗为什么不自动缴电费了", "为什么查不到蚂蚁借呗的还款记录", "花呗能否预订酒店", "问,我的借呗当期还款怎么提前还", "集分宝可以还花呗借款吗", "花呗里的钱只能买东西吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.105, 0.174, 0.13, 0.074, 0.2, 0.145, 0.093, 0.108], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收钱开通", "pos": ["花呗,我给我开通,我现在急需要收钱", "我开通了花呗收钱了"], "neg": ["原来的邮寄的二维码可以用花呗吗现在", "为什么我现在还是无法开通花呗", "我如果全部还款完成了还可以用花呗分期购物吗", "花呗申请分期后,一次还清也需要还全部的手续费", "帮我花呗打开", "为什么在淘宝付款用了花呗,而在支付宝上显示未开通花呗", "芝麻信用够了,身份证绑定了,银行卡绑定了,为什么还不能申请蚂蚁借呗", "借呗会自动扣除余额宝欠款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.169, 0.165, 0.056, 0.154, 0.098, 0.092, 0.17, 0.151], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期影响芝麻分", "pos": ["花呗分期影响,芝麻分吗", "花呗分期付款 影响信用额度那"], "neg": ["我的借呗超过了***笔就不能结款了吗", "花呗可以支付滴滴打车的费用吗", "打开花呗总显示登录超时怎么回事", "花呗付款后退款会收手续费吗", "退款回到花呗的钱还需要还吗", "在淘宝和天猫上什么东西都可以用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.175, 0.194, 0.106, 0.083, 0.169, 0.123], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗手机号不一致", "pos": ["花呗和现在是一个手机号,不符,怎么办", "花呗和绑定支付宝的手机号不一致"], "neg": ["花呗付款了,可以提前还款吗", "花呗还款提醒怎么取消", "花呗还款是优先还最低的吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.065, 0.119, 0.086], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗需旧手机号登录", "pos": ["为何花呗还需使用旧号码登陆", "为什么我登录的是新手机号,花呗怎么让我登录旧手机号码使用花呗"], "neg": ["前几天花***花呗积分兑换的优酷会员在哪里可以领取", "退回花呗我还会不会有影响", "付的钱是花呗的,该怎么退款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.107, 0.157, 0.137], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗绑定抽奖", "pos": ["蚂蚁花呗抽奖机会", "绑定花呗,抽奖"], "neg": ["我的花呗额度是", "借呗借第二次用不用验证码", "没有人工客服吗", "双十二可以摇花呗额度吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.139, 0.104, 0.082], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "密付花呗支付", "pos": ["密付能用花呗支付嘛", "密付用花呗"], "neg": ["蚂蚁借呗有保证金吗", "花呗可以借钱", "可以给个借呗免息卷", "参加花呗的活动什么时候到账", "花呗降额度是什么原因"], "pos_scores": [1.0, 0.95], "neg_scores": [0.057, 0.138, 0.082, 0.098, 0.102], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗放款未到账", "pos": ["借呗提示已放款,但是银行卡没有收到钱", "借呗提到银行卡,也什么一天了都还没到"], "neg": ["更改开通花呗的支付宝", "借呗能提前还清吗", "花呗分期免息卷怎么积分兑换", "为什么淘宝买的东西付了款,支付宝花呗还要扣除一次"], "pos_scores": [1.0, 0.95], "neg_scores": [0.173, 0.091, 0.144, 0.195], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗收款码贴纸更换", "pos": ["我以前有申请过的贴纸。现在申请了花呗收款需要重新再申请贴纸吗", "收到申请的收款码贴纸后再申请花呗收款用更换贴纸吗"], "neg": ["花呗在我另一个支付宝账号上我可以改在我这个账号", "花呗付款三次 ,显示二个订单。怎么回事", "如何撤销花呗这个软件", "为什么我的收钱码不能用花呗了", "我的借呗怎么开通了,现在又没有了,怎么开通"], "pos_scores": [1.0, 0.95], "neg_scores": [0.066, 0.175, 0.084, 0.175, 0.141], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗无法买车票", "pos": ["为什么花呗无法买车票", "花呗能买火车票ma"], "neg": ["花呗可以提红包吗", "花呗额度转账号", "花呗最低还款以后每日利息多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.112, 0.124, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款后未恢复额度", "pos": ["蚂蚁借呗钱已经还上了怎么还没有恢复额度", "我的借呗怎么还款了还没有恢复额度"], "neg": ["我的蚂蚁借呗几号还", "支付宝与淘宝上的花呗额不一样", "怎么提高花呗的信用额度", "更改花呗最低消费金额"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.149, 0.063, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款日修改", "pos": ["借呗还款日能换日期么", "借呗还款日期可以改不"], "neg": ["花呗自动抵扣红包", "为什么我的花呗额度不能提高", "我的花呗收钱怎解收不了钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.158, 0.116, 0.17], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗审核放款时间", "pos": ["申请借呗,多久能审核额度", "借呗多久审核完,能放款"], "neg": ["借呗必须还款日还款吗", "花呗信息补全", "借呗手机号码换了收不到验证码怎么办", "花呗分期取消后重新分期", "怎么花呗借款不能分***期了", "怎么样才能继续使用借呗", "我之前用了花呗分期,今天提前还款了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.186, 0.18, 0.193, 0.099, 0.053, 0.171, 0.063], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗开通失败原因", "pos": ["借呗开通不了", "为什么我开通不了借呗"], "neg": ["我用花呗付款成功,商家未收到?是什么情况", "另一个账户开通了花呗,怎么转到这个账户", "花呗的钱还完了。那退款怎么办", "怎么没蚂蚁借呗", "为什么给我关了借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.064, 0.055, 0.077, 0.119, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "朋友代还花呗", "pos": ["可以让朋友替还花呗吗", "花呗可不可以让朋友帮忙还款"], "neg": ["刚一笔还款,我银行卡有信息扣取成功,为何花呗没还成功", "更改花呗还款日期", "蚂蚁借呗可以恢复那", "问什么关闭我的借呗", "我的蚂蚁借呗刚才领了一个什么礼包。额度就涨到***了", "花呗开通后一定要使用吗", "花呗分期,如果提前还完,会扣后边的手续费吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.101, 0.112, 0.093, 0.137, 0.084, 0.149, 0.107], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗降额原因", "pos": ["花呗额度为什么会降下", "为什么我的花呗降额了"], "neg": ["花呗怎么是旧手机号", "蚂蚁花呗要什么分期付款不能一下还完", "我上个月花呗的账单是什么", "借呗逾期三天已经还了,为什么还显示有逾期", "怎么关闭首选花呗支付方式", "花呗用了最低还款算逾期吗", "你的,为什么我的花呗不能用,求解", "花呗最高分多少期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.065, 0.094, 0.169, 0.14, 0.152, 0.068, 0.127, 0.101], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款短信通知", "pos": ["借呗还款有短信通知吗", "借呗借钱短信通知"], "neg": ["花呗年货购物津贴有什么用", "蚂蚁借呗怎么解呀", "花呗没绑定银行卡能用吗", "借呗想分期还款", "花呗分期提前还款扣手续费吗", "我的支付宝账户被提示司法冻结!我的借呗还款怎么还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.095, 0.101, 0.191, 0.081, 0.107, 0.131], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗无法确认借款", "pos": ["借呗输入借款金额无法点击确认", "借呗怎么无法确认借款了"], "neg": ["花呗逾期怎么办,会影响信用吗", "蚂蚁借呗怎么绑定网商银行", "手机号被停了 然后支付宝号登不上了 花呗逾期未还", "为什么我用了花呗分期后就一堆贷款的打电话加微信", "为什么每次点开花呗都叫我绑定银行卡", "蚂蚁借呗借***万的手续费是多少", "蚂蚁借呗服务密码是什么", "花呗额度还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.086, 0.158, 0.102, 0.175, 0.15, 0.161, 0.057, 0.061], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗额度恢复时间", "pos": ["我的借呗额度什么时候恢复,我希望快点", "我问借呗提前完款,什么时候恢复额度"], "neg": ["我的蚂蚁花呗怎么可用余额变少了", "把网商贷变成借呗", "蚂蚁借呗能不能***个月后再还款", "花呗消费了在哪里抽奖", "借呗总额度多少", "我想启用花呗,但是我的银行卡绑定的不是我本人的怎么办", "为什么花呗申请退款还没到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.066, 0.069, 0.147, 0.124, 0.149, 0.069, 0.051], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗无法付款", "pos": ["为什么我有花呗没办法付款", "为什么我这个花呗不能用"], "neg": ["为什么我的好花呗重复扣款", "怎么开通不了花呗了", "为什么蚂蚁借呗身份信息无法完善", "花呗返现是怎么", "花呗用多久能用借呗", "花呗一次分期后下个月还能分期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.156, 0.191, 0.099, 0.196, 0.083, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗开通不用收费", "pos": ["开通花呗不使用需要缴费吗", "如果我开通花呗,确一直没有使用,会扣费用吗"], "neg": ["我的这个月花呗可以分期吗", "我花呗***元,十一月一号已经还了怎么还显示没还款", "蚂蚁借呗怎么解呀", "我不想用花呗,更不想体验"], "pos_scores": [1.0, 0.95], "neg_scores": [0.163, 0.072, 0.053, 0.181], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗停用原因", "pos": ["怎么花呗使用不了", "花呗为什么停用"], "neg": ["我花呗怎么还款了,不开通", "为什么为花呗主页看不到我的账单", "信用卡没有限额,而花呗限额***", "有过逾期记录会影响借呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.117, 0.072, 0.173, 0.172], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "关闭花呗服务", "pos": ["我不想用花呗,终止合同怎么操作", "花呗我不想使用"], "neg": ["为什么y每个人的花呗可用额度不一样", "问用了花呗付款,下个月是自动扣款吗", "花呗额度掉了怎么回事", "怎样弄,借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.147, 0.066, 0.114, 0.077], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗需要支付利息吗", "pos": ["花呗要利息口马", "花呗他需要付利息吗"], "neg": ["我不想用了怎样卸载花呗", "花呗交了话费怎么还没到账", "蚂蚁花呗分期如果提前还还需要手续费吗", "花呗能转到另一个支付宝吗", "这个账号花呗已经关闭,为什么另一个账号不能开通花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.125, 0.165, 0.132, 0.113, 0.078], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗可以申请临时额度吗", "pos": ["花呗可以提临额吗", "我的花呗额度可以申请调高一点吗"], "neg": ["要开通信用卡与花呗收款功能,但提示说我的账户不符合开通条件", "花呗额度怎么评估", "花呗帮还对以后提额有影响吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.054, 0.069], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款后为什么没有恢复额度", "pos": ["蚂蚁借呗按时还款了。但没有余额", "为什么借呗还款了没有可用余额"], "neg": ["企业支付宝可以使用花呗吗", "我的借呗什么时候降的额度", "现在这个账号的花呗怎么开通", "借呗到帐了,怎么钱还不到了", "蚂蚁借呗如果借***一个月的利息多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.125, 0.095, 0.183, 0.071, 0.1], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗和支付宝账号是否通用", "pos": ["花呗号和支付宝账号不一样", "支付宝花呗和淘宝花呗不是一个账号"], "neg": ["开通了另一个支付宝怎么我的花呗开通不了", "再次解冻花呗", "我想取消花呗怎样才能取消", "我的花呗之前绑定了一个邮箱密码和账号都忘记了怎么办", "花呗每月最低还款额度怎么看", "花呗的金额可以冻结吗", "顾客花呗付款卖家收不到钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.055, 0.142, 0.165, 0.097, 0.118, 0.101, 0.078], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗还款日期可以修改吗", "pos": ["蚂蚁花呗是可以自己设置还款时间吗", "蚂蚁花呗的还款日期可以调整么"], "neg": ["之前用完花呗有一次抽奖的机会现在还有吗", "借呗提前全额还款了,怎么没有恢复额度", "花呗每天只能用一次吗", "交电费是不是欠费了才可以用花呗", "可以提前还完借呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.117, 0.172, 0.071, 0.058, 0.12], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗逾期还款会影响信用吗", "pos": ["花呗晚两天还款,会对信用有影响吗", "花呗还了最低后 剩下的不及时还 会影响个人信用吗"], "neg": ["卖家说花呗分期付款,是按每月付款吗", "为何我开不了蚂蚁借呗", "花呗负额度分期还款后还能用吗", "能不能提前还花呗分期", "***号退款金额退到花呗里,可是已经还完款了", "借呗怎么不自动扣款的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.056, 0.143, 0.142, 0.154, 0.198, 0.074], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "商家支持花呗但我无法使用", "pos": ["为什么商家开通了花呗服务 我却无法用花呗付款", "为什么商家开启了花呗支付,到我这里不能用花呗"], "neg": ["花呗还有余额,为什么刷不出来", "花呗分期可以分别结清吗", "淘宝蚂蚁花呗怎么开通", "蚂蚁借呗额度消失两年", "我指的是蚂蚁借呗怎么还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.117, 0.134, 0.123, 0.161, 0.163], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗还款后仍收到催款通知", "pos": ["借呗还款了还在催我还款是怎么回事", "借呗昨天还过了,怎么还显示还款"], "neg": ["没钱还给花呗", "网商银行也不能用", "余额宝里有钱直接还蚂蚁借呗可以吗", "蚂蚁花呗分期后如果提前换该期的欠款,那还会扣利息吗", "为什么不能申请借呗", "怎么查看对方有没有花呗开通", "如果不使用手机需要关闭花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.198, 0.108, 0.092, 0.086, 0.197, 0.099, 0.095], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗账单和还款金额不一致", "pos": ["我用花呗,账单与实际还款不一样", "我已还款的数额与花呗显示的数额不一致"], "neg": ["花呗可以微信发红包吗", "点解在我支付宝扣处了,花呗又扣我的支付宝", "美团订单花呗扣款了 但是美团显示未付款", "我之前用了花呗买手机,按时还完了,为什么这一年都没提升", "花呗邮箱如何登录"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.175, 0.194, 0.086, 0.067], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "花呗分期提前还款是否收费", "pos": ["假如现在手头有点紧,***号还不了花呗,如果选择分期这个月底前一次还清还收取额外费用吗", "上个月花呗分期,这个月全部还清应该不收取任何费用吧"], "neg": ["我花呗多还了两次", "为什么我开通不了蚂蚁借呗,花呗也是***提不了额度", "不是花呗是收钱码", "现在支付宝没有借呗了吗", "***借呗一个月忘还推迟了二天利息要多少", "为什么我花呗退回了金额,怎么未出账的金额没有变", "为什么我的芝麻借呗不开", "花呗每月充值不了话费"], "pos_scores": [1.0, 0.95], "neg_scores": [0.083, 0.153, 0.127, 0.094, 0.147, 0.123, 0.059, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "借呗突然被关闭原因", "pos": ["为什么我的蚂蚁借呗突然封了", "为啥借呗会被封了"], "neg": ["花呗,买车险", "蚂蚁借呗自动扣款默认开通在哪里", "我的花呗无法使用了怎么办", "蚂蚁花呗额度冻结", "我花呗开通了吗", "天猫怎么不能用花呗", "想开通商家版花呗", "蚂蚁借呗怎么知道是否还清了最低"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.166, 0.169, 0.199, 0.083, 0.114, 0.199, 0.106], "prompt": "为此查询生成表示:", "type": "normal"}
{"query": "如何开通借呗", "pos": ["借呗开通在哪里开通", "借呗能不能给我开通"], "neg": ["花呗充错号码了", "花呗逾期了还完了不能用了", "花呗的补充学历填错了怎么办", "借呗审核未通过怎么解决", "我花呗还清款", "蚂蚁借呗可以取现金吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.136, 0.121, 0.101, 0.117, 0.14, 0.118], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗无故扣款原因", "pos": ["未买东西怎么自动从花呗里扣款", "我没用花呗买东西咋扣费"], "neg": ["借呗 银行卡错了怎么办", "这个没有付款,但花呗叫我还款", "密付可以使用花呗支付吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.129, 0.146, 0.111], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款收不到验证码怎么办", "pos": ["还款花呗,手机停机收不到短信", "花呗还款要手机验证码,我手机号码没用,收不到怎么办"], "neg": ["花呗怎么在淘宝使用", "假如说我的花呗逾期一天,会怎样,有没有负面影响", "借呗还了之后能再借吗", "到期蚂蚁花呗会自动扣款吗", "借呗可以越期吗", "蚂蚁借呗逾期***月有多大影响", "为什么我的花呗还了,然后我我买东西要申请退款,钱原来退回花呗,有什么看不到***块多的,那哪里去了", "花呗没有临时额度为什么可以为负"], "pos_scores": [1.0, 0.95], "neg_scores": [0.059, 0.054, 0.116, 0.069, 0.148, 0.074, 0.077, 0.189], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗收款码失效", "pos": ["蚂蚁花呗收款码不能收钱了", "怎么现在不能花呗收款"], "neg": ["花呗冻结了怎么办", "借呗说已经到账。银行卡并没有收到短信", "***月*** 有没有临时花呗额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.096, 0.072, 0.071], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗逾期利息计算方式", "pos": ["花呗逾期一两天会影响什么?利息是多少", "花呗逾期***天利息怎么算"], "neg": ["花呗多还钱了,怎么搞的", "借呗可以推迟一期还款吗", "进入花呗总是显示服务繁忙怎么回事"], "pos_scores": [1.0, 0.95], "neg_scores": [0.173, 0.058, 0.083], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗借款期限选项", "pos": ["借呗借款期限只有***个月和***个月", "借呗怎么只有六个月十二个月的周期选项"], "neg": ["为什么确认收货了 花呗没账单", "蚂蚁借呗逾期几天罚息", "我的是花呗,是极速退款的", "如何改蚂蚁花呗的还款日期", "借呗完善不了信息", "小黄单车押金用花呗支付 退款到哪里"], "pos_scores": [1.0, 0.95], "neg_scores": [0.171, 0.073, 0.101, 0.167, 0.093, 0.193], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗借呗消失原因", "pos": ["花呗借呗都不见了", "我的蚂蚁花呗和借呗怎么不见"], "neg": ["商家收款 支付方用花呗支付 现在没收到款", "我用了花呗付款怎么换钱", "我为什么没有花呗呀", "蚂蚁借呗分期了之后还能分期吗", "借呗借钱了,影响以后在银行贷款吗", "成年当天开通花呗可以吗", "借呗礼包在那里看", "我想把花呗里面的钱,转到支付宝里面能吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.078, 0.099, 0.088, 0.19, 0.196, 0.063, 0.168, 0.155], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "关闭花呗后能否重新开通", "pos": ["花呗关了还能开吗", "花呗关闭之后。还可以再次开通吗"], "neg": ["为什么我的借呗借不借了", "怎么样申请花呗的临时额度", "花呗还款金额错误怎么办", "这一次借呗的红包有效期", "一句话,不耽误你时间,我取消这次交易,永远关闭花呗", "我的花呗逾期了,我现在还款了,花呗还可以使用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.083, 0.084, 0.125, 0.118, 0.178, 0.174], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗无法开通原因", "pos": ["蚂蚁借唄怎么补补不能开通", "蚂蚁借呗 我为什么开通不了"], "neg": ["是不是我这种手机号不是自己的不能开通花呗", "带有花唄标示的是什么样的", "花呗分期有哪几种", "借呗全款还", "花呗是怎么付款的,是不是直接从支付宝付钱", "蚂蚁借呗开启和信用分数有关系吗", "花呗十号还款怎么十号没过就显示我逾期了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.092, 0.058, 0.116, 0.131, 0.099, 0.072, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗账单周期说明", "pos": ["花呗记账周期", "花呗账单是***号到***号为一周期吗"], "neg": ["芝麻信用和和花呗怎么重新授权", "花呗收款吗开通之后什么时候能支持花呗信用卡收款", "借呗手动还款不是还分期的", "十二点以过为什么借呗还没扣款,银行卡己开通网银"], "pos_scores": [1.0, 0.95], "neg_scores": [0.183, 0.177, 0.071, 0.111], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "查询花呗开通状态", "pos": ["我的账号花呗开通了吗", "现在我的账户没有开通花呗吗"], "neg": ["蚂蚁借呗超过还款,日中午***点,有影响吗", "花呗上个月有分期,这个月要取消分期付款,还全款可以吗", "现在花呗提示我不够权限,下个月就又可以开通了,这些反复的这怎么样", "身份证过期对借呗开通有影响吗", "app怎么退到花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.137, 0.067, 0.17, 0.061, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗能否用于外卖支付", "pos": ["花呗可以在支付宝里点外卖嘛", "外卖商家可以使用花呗嘛"], "neg": ["花呗临时额度的有效期在哪", "花呗可用数目是怎么算的", "借呗利息分期是算一共借多少吗", "为什么我的借呗不能循环使用", "蚂蚁借呗没通过啥意思"], "pos_scores": [1.0, 0.95], "neg_scores": [0.156, 0.192, 0.089, 0.058, 0.196], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "网商贷切换回借呗方法", "pos": ["不想要网商贷怎么变回借呗", "怎样把网商贷换回借呗"], "neg": ["借呗还款日都没有短信通知", "怎么改借呗借款限额", "花呗没有花过***多这么多钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.108, 0.135, 0.126], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗延期还款政策", "pos": ["蚂蚁借呗能延长还款时间吗", "借呗可以延后还款吗"], "neg": ["用花呗买的东西没到没账单怎么还款", "为什么花呗还不能用", "花呗***号可以还款", "花呗每月按时还款,收息吗", "借呗可以申请吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.191, 0.059, 0.131, 0.148, 0.124], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗还款日次日还款影响", "pos": ["我的蚂蚁借呗可以明天还吗 今天是还款日", "借呗还款日,今天到期,我明天还,会有影响吗"], "neg": ["我的花呗现在不可以用了,那么以后还可以用吗", "蚂蚁借呗首月免息吗", "我原来的花呗怎么登不上去", "我蚂蚁花呗的借呗都不可以用", "花呗刷脸通不过"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.135, 0.162, 0.186, 0.077], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗消费还款日期查询", "pos": ["一号花的花呗什么时候还款", "本月一号使用花呗,次月几号还款"], "neg": ["怎么更换蚂蚁借呗银行卡", "为借呗申诉", "花呗临时怎么提额", "我的借呗什么时候可以恢复", "我借呗借了五万可以分期付款吗", "昨天在借呗借了一百元!今天怎么还没到卡里", "添加店员可以花呗收款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.106, 0.178, 0.08, 0.189, 0.173, 0.08, 0.052], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "美团无法使用花呗支付", "pos": ["花呗用不了美团是怎么回事", "我美团点好餐,花唄怎么支付不了"], "neg": ["不记得是不是花呗付的款", "我是商家想查询花呗服务是否开通", "为什么我的花呗额度是***,但是刷不了", "经常使用支付宝,可是为什么开通不了花呗", "蚂蚁借呗还款日期能改吗", "借呗还款没到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.123, 0.133, 0.099, 0.17, 0.134], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "查看花呗还款明细", "pos": ["如何查看花呗还款的具体", "怎么查看还清花呗"], "neg": ["花呗可以先下使用吗", "为什么我的借呗利息是万五呀,我朋友的就是万三", "我想开通下花呗给申请下呗", "借呗什么时候恢复评估", "花呗分期还款可以提前还", "合同违约花呗", "蚂蚁借呗的钱能买房"], "pos_scores": [1.0, 0.95], "neg_scores": [0.139, 0.106, 0.182, 0.179, 0.052, 0.179, 0.132], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "隐藏借呗入口方法", "pos": ["可以把借呗这个标栏隐藏起来,别人看不见", "借呗工具栏能隐藏起来吗"], "neg": ["提前还款还是算上分期付款那些账单的利息吗", "为什么我的花呗额度能为负", "花呗还款分期了可以改吗", "为什么花呗交电费了马上就退款", "蚂蚁借呗还款日期是怎么计算的", "借呗今日借款哪天还"], "pos_scores": [1.0, 0.95], "neg_scores": [0.057, 0.153, 0.094, 0.082, 0.139, 0.077], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗分期还款金额计算", "pos": ["蚂蚁借呗借***一个月还多少", "蚂蚁借呗借***,***个月还。一个月还多少"], "neg": ["为什么我的花呗当时借了我就还了 怎么现在还要我还", "蚂蚁借呗提示当前存在逾期", "为什么我的花呗可用度只有***", "今天借呗还款,怎么余额宝里没有扣款", "开通花呗显示手机验证未通过", "如何删除花呗里面以出账的记录", "支付宝小黄车押金用不了花呗付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.102, 0.188, 0.084, 0.119, 0.178, 0.125, 0.115], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "永久关闭借呗流程", "pos": ["我想永久关闭蚂蚁借呗", "永久关闭借呗"], "neg": ["完善蚂蚁借呗信息", "二唯码可以收款花呗吗", "借呗最多能上升到多少", "借呗可以借几期", "怎么可以再开蚂蚁借呗", "为什么商家支持花呗分期!我确不能分期", "花呗显示为负数,还可以使用么", "花呗红包用手机支付宝付款时为什么没有自动抵扣"], "pos_scores": [1.0, 0.95], "neg_scores": [0.15, 0.069, 0.151, 0.179, 0.077, 0.135, 0.135, 0.053], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款操作指南", "pos": ["要还花呗的钱,怎么操作,昨晚用的", "我要花呗还款,怎么操作"], "neg": ["为什么我的花呗额度被调低了那么多", "借呗无法一次还清还能分期吗", "能查查我欠不欠花呗里面的钱吗", "花呗可以透支分期吗", "借呗每月几号评估一次额度", "怎么查看我的花呗账单", "花呗还款日期每个人都一样吗", "如何查询还款借呗帐单"], "pos_scores": [1.0, 0.95], "neg_scores": [0.087, 0.067, 0.181, 0.063, 0.115, 0.126, 0.059, 0.095], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗关闭后退款处理", "pos": ["花呗里有退款,但是后来我关闭了,退款怎么办", "花呗已经关闭,如果有退款怎么办"], "neg": ["为啥我信用那么高,我的花呗和借呗那么点", "花呗怎么转移手机号", "之前不需要钱用就有借呗", "为什么花呗未还显示为***"], "pos_scores": [1.0, 0.95], "neg_scores": [0.157, 0.16, 0.064, 0.088], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗免息券使用规则", "pos": ["我有一个***天免息的借呗券", "借呗***天免息券"], "neg": ["我蚂蚁借呗按期还完之后额度怎么没有回复", "怎么找回另外一个有花呗的账号", "我的钱怎么去了花呗?怎么弄的", "我想知道花呗的未出账怎么查详单", "为什么在支付宝上设置了花呗在淘宝上用不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.104, 0.14, 0.159, 0.192, 0.154], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗缴纳电费限制", "pos": ["缴费电费不能用花呗", "花呗不可以付电费了吗"], "neg": ["花呗为什么没有体验额", "花呗账单金额是什么", "我的蚂蚁借呗利息去哪里看", "我的花呗分期付款的,,为什么不管来借呗", "我用花呗分期付了", "戒备怎么关闭"], "pos_scores": [1.0, 0.95], "neg_scores": [0.126, 0.181, 0.087, 0.081, 0.122, 0.136], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "更改花呗绑定支付宝账号", "pos": ["花呗绑定的支付宝能改吗", "花呗换绑定的支付宝"], "neg": ["显示,钱已经退到花呗里了", "花呗消费的账单被删掉了,怎么找回来", "借呗要时逾期一天怎么办", "借呗额度还清整笔之后恢复吗", "花呗不能使用美团了么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.098, 0.119, 0.076, 0.068, 0.085], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗充值加油卡支持", "pos": ["花呗充加油卡吗", "可以使用蚂蚁花呗充值加油卡吗"], "neg": ["为什么花呗还款日还没到 就显示我逾期了", "怎么找不到花呗签约", "电脑端怎么设置花呗还款顺序"], "pos_scores": [1.0, 0.95], "neg_scores": [0.107, 0.082, 0.184], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗退款去向查询", "pos": ["拿花呗支付的退款后钱是退到哪儿", "花呗买的东西,退款就退到哪了"], "neg": ["为什么我点击花呗应还款这里没有反应,点击可用额度有反应", "为什么显示我没有花呗", "为什么我不可以办花呗,什么原因", "我的蚂蚁借呗刚才领了一个什么礼包。额度就涨到***了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.197, 0.059, 0.067, 0.136], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "芝麻信用分不提升原因", "pos": ["芝麻信用分怎么提升不了", "为啥不加芝麻分了"], "neg": ["换新手机号码怎么授权花呗", "花呗开通不用有额外费用吗", "我在淘宝上买东西用花呗钱了然后我申请退款是不是退花呗里面", "我的花呗还能开通了么", "信用卡花呗收取服务费是什么", "没有人工客服吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.139, 0.05, 0.195, 0.098, 0.131, 0.101], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗额度与芝麻分关系", "pos": ["花呗额度与芝麻分有关系吗", "信用分跟花呗有关吗"], "neg": ["花呗如何一次还清分期付款", "我要开通花呗 买手机", "我欠款已经还清了,淘宝上退货的款直接退到花呗,这个款在哪看到"], "pos_scores": [1.0, 0.95], "neg_scores": [0.111, 0.178, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗不显示解决方法", "pos": ["借呗也不显示了", "我为什么不显示蚂蚁借呗"], "neg": ["花呗自主还款最晚几点", "点了分***期 钱从花呗一次性扣的", "我从借呗借的钱,光显示放款中,什么时候能到账", "花呗分期能修改呗", "帮我查询一下 花呗什么时候可以用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.194, 0.124, 0.087, 0.131], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗无法提现原因", "pos": ["嗨为什么不借呗提现了", "为什么我支付宝借呗不能取款了"], "neg": ["花呗申请后一次性还清", "缴费电费不能用花呗", "还有花呗不涨", "怎么看以前的借呗记录"], "pos_scores": [1.0, 0.95], "neg_scores": [0.198, 0.181, 0.15, 0.145], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗转账到其他账户", "pos": ["我想把花呗转到这个支付宝账号上,怎么操作", "花呗怎么转到另个账户"], "neg": ["另一个不常用的支付宝以前开通了花呗,我刚刚关闭了,我想开通在现在用的多的支付宝号上", "花呗未出账单不可以分期", "花呗关闭后,再开启里面的余额还是一样吗", "我用花呗可以用多少钱", "为什么我买的时候在余额里扣钱了,现在花呗还让我还钱", "如何主动关闭蚂蚁借呗", "如何清除所有花呗交易记录", "星力超市可以用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.114, 0.097, 0.114, 0.053, 0.148, 0.089, 0.09, 0.127], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "退款退回花呗原因", "pos": ["为什么把钱退到花呗", "给我退到花呗上了"], "neg": ["花呗额度不是最低***吗", "我强,,蚂蚁借呗和蚂蚁微贷是不是一个产品", "花呗手机号多少", "临时额度用完后会不会对以后使用花呗有影响", "多少额度才能开蚂蚁借呗", "借呗额度提升流程", "有借呗入口", "花呗分期购提前还款收不收手续费"], "pos_scores": [1.0, 0.95], "neg_scores": [0.162, 0.148, 0.147, 0.071, 0.068, 0.068, 0.19, 0.082], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "登录指定支付宝使用花呗", "pos": ["您可登录支付宝号使用花呗", "您可登录某支付宝使用花呗"], "neg": ["我的花呗到期后会从余额宝自动扣款吗", "怎么样才能得到花呗临时额度", "比如我今天月末最后一天用的花呗下个月月初才发货那我的花呗是下个月还还是下下个月还", "我花呗还清款", "花呗还款可以选择用银行卡吗", "花呗买单限额"], "pos_scores": [1.0, 0.95], "neg_scores": [0.147, 0.095, 0.191, 0.182, 0.169, 0.059], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "快速提升花呗额度方法", "pos": ["怎么快速提高花呗额度", "蚂蚁花呗可用额度怎么样才能提高"], "neg": ["收款花呗有没有限制", "我的借呗还款日是***月***日", "我的花唄額度", "借呗分期还可以更改时间么", "花呗钱已还清,退回花呗的钱为什么不还给我", "花呗授权在哪"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.158, 0.119, 0.125, 0.064, 0.054], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "关闭网商贷流程", "pos": ["我把网商贷取消", "我把网商关掉了"], "neg": ["花呗临时额度为什么用", "为什么我现在的支付宝用花呗要用以前的号码", "显示,钱已经退到花呗里了", "借呗借来的钱可以还花呗吗", "为什么同一个手机号登陆的花呗还显示要手机号登陆", "开通了花呗,收款收不了款是什么原因"], "pos_scores": [1.0, 0.95], "neg_scores": [0.15, 0.058, 0.14, 0.052, 0.165, 0.082], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗最新活动查询", "pos": ["今天借呗有活动", "借呗里面的活动"], "neg": ["我***点左右在借呗借的钱怎么现在还没到账", "账单都还清楚了!怎么花呗不能用了", "花呗抽了一个天猫券 不知道从哪里再找出来", "商店或超市购物可以用花呗支付吗", "我现在怎么开通花呗", "我的花呗是昨天还的,今天可以使用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.094, 0.182, 0.172, 0.18, 0.139, 0.117], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "无法登录花呗账户解决", "pos": ["您说的是您登录不上您开通花呗的 支付宝账户", "您说的是您登录不上您开通花呗的"], "neg": ["手机丢失导致借呗逾期了。现在怎么办", "借呗最低还款还了吗", "借呗里面还有三百", "花呗分期免息卷在哪里兑换", "为什么花呗不能滴滴打车和美团了", "蚂蚁借呗进不去怎么还", "借呗这么变成网商货了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.172, 0.188, 0.111, 0.068, 0.163, 0.16, 0.085], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗提前还款操作", "pos": ["蚂蚁借呗怎么每期提前还", "借呗能提前还清吗"], "neg": ["所以请把花呗关掉", "我这个是花呗红包", "蚂蚁借呗三个月了不提额", "为什么我绑定银行卡了 不能使用蚂蚁花呗", "花呗有额度但是淘宝显示花呗分期暂不可用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.114, 0.082, 0.103, 0.095, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗利息计算公式", "pos": ["蚂蚁借呗利率怎么计算", "详细说算下蚂蚁借呗利息"], "neg": ["我的蚂蚁借呗还没通过", "我为什么不可以蚂蚁借呗", "蚂蚁花呗有问题,我明明已经还款了,为什么还显示我没有还", "为什么总让我用另一个支付宝用花呗", "我的借呗可以通过吗", "为什么我花呗只能扫***", "花呗提示我可以提额是怎么回事", "为什么我的花呗的额度半年都没有升过了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.124, 0.067, 0.151, 0.089, 0.11, 0.184, 0.179, 0.105], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "提高借呗额度方法", "pos": ["我怎么才能提高借呗金额", "借呗额度怎样提升"], "neg": ["花呗分期和不分期不一样吗", "我是用建行卡付的款,他却退到了花呗上", "刚才一家店里吃了拉面 我用花呗付过款了 为什么人家没收到钱", "怎么再次用借呗", "花呗最低还款的利息咋算"], "pos_scores": [1.0, 0.95], "neg_scores": [0.062, 0.123, 0.092, 0.18, 0.099], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "关闭新开通的花呗", "pos": ["怎么关闭刚开的花呗", "花呗怎么关掉我不用花呗了"], "neg": ["我的蚂蚁借呗最迟到几号还款", "银行留的手机号已经改了,但是为什么支付宝借呗的验证码还是原来的手机号", "我的花呗还款日期可以改为***号吗", "我用我的微信付款了咋就成了还花呗的钱那我的钱去那里了", "我的花呗之前逾期了,现在什么时候能用", "为什么我蚂蚁借呗给我取消啦", "我购物刷的都是银行卡,为什么花呗还有账单?我购物刷的都是银行卡,为什么花呗还有账单?我购物刷的都是银行卡,为什么花呗还有账单?我购物刷的都是银行卡,为什么花呗还有账单", "调花呗快捷支付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.179, 0.198, 0.194, 0.149, 0.068, 0.107, 0.054, 0.071], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "重新开通花呗流程", "pos": ["开通下花呗", "从新开通蚂蚁花呗"], "neg": ["现在蚂蚁借呗还可以开通吗", "刚刚提示了关于借呗提额了", "花呗是到期自动从银行卡或余额宝扣款吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.072, 0.079], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗开通条件不满足", "pos": ["为什么说不满足花呗开通条件", "开花呗,显示不符合条件"], "neg": ["花呗分期低一点", "请把蚂蚁借呗发过来", "借呗能分期还嘛", "蚂蚁借呗可以申请", "花呗能帮我开一下吗", "花呗已经激活"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.056, 0.155, 0.161, 0.078, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "符合条件但无法开通借呗", "pos": ["借呗开通的条件都满足了 为啥开不了", "为什么我不能开通借呗呀"], "neg": ["网贷商和蚂蚁借呗", "***元花呗分***期还款每个月还多少钱", "借呗还款日能否更改", "我在蚂蚁借呗上借的钱,然后转到银行卡里了,但是我的卡挂失了,我的钱去哪了", "以前能用花呗交电费,现在不能了", "蚂蚁借呗多久放款", "除了蚂蚁借呗,还有可以借钱", "为什么我的花呗显示暂时无法使用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.086, 0.159, 0.094, 0.148, 0.065, 0.081, 0.114], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗和花呗区别", "pos": ["借呗和花呗有冲突吗", "借呗和花呗都一样的吗"], "neg": ["积分要怎样才可以兑换蚂蚁花呗额度", "额度怎么会降定借呗", "修改花呗手机怎么吗", "花呗如何返现花呗如何返现花呗如何返现", "借呗什么时候可以分期还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.176, 0.19, 0.19, 0.106, 0.129], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "网商贷与借呗关系", "pos": ["网商贷和借呗是一个平台吗", "网商银行和借呗是一起的吗"], "neg": ["为啥不能花呗账单分期", "我可以用微信还花呗钱嘛", "花呗消费会有什么提醒", "能不能不要网商贷要借呗", "花呗分期后一次还清需要手续费么", "花呗分期可以提前还款完吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.086, 0.083, 0.174, 0.107, 0.098, 0.113], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗支付抽奖活动", "pos": ["花呗支付成功,怎么抽奖", "花呗还能抽奖"], "neg": ["商家支持花呗我不能用花呗", "然后还款的时候我忘了是用花呗退款的", "花呗临时额度怎么是自动领取的", "为什么我花呗用不了了", "蚂蚁借呗***月份我还了吗", "为什么我的借呗申请人太多了,叫明天在来"], "pos_scores": [1.0, 0.95], "neg_scores": [0.176, 0.17, 0.185, 0.197, 0.087, 0.084], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "滴滴支持花呗支付吗", "pos": ["花呗能否滴滴", "滴滴花呗支付"], "neg": ["借呗中可以分十二期还吗", "淘宝退款退回花呗钱怎么看不到", "收钱码开通信用卡花呗付款", "但我付款时,却让我开通花呗,否则付不了款", "借呗的借款记录怎么删除"], "pos_scores": [1.0, 0.95], "neg_scores": [0.191, 0.102, 0.142, 0.088, 0.064], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "退款转入花呗原因", "pos": ["为什么我的退款***元会转入花呗", "小蜜,退款怎么会转花呗里"], "neg": ["我没有逾期还款,也一直用花呗,信用也很好,为什么额度那么少", "我是还蚂蚁借呗", "怎么我还不上借呗呀", "收款码收花呗", "我再办一张银行卡可以开通花呗吗", "借呗逾期***天有事情吗", "借呗提现到卡怎么查"], "pos_scores": [1.0, 0.95], "neg_scores": [0.127, 0.097, 0.058, 0.126, 0.146, 0.067, 0.1], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "提升花呗额度方法", "pos": ["我的花呗额度能调高吗", "花呗额度怎么提升"], "neg": ["昨天晚上还花呗 银行卡已经扣款 但是花呗没有显示成功", "花呗可不可以还款", "我的蚂蚁花呗已经还最低还款 为什么还自动扣款", "花呗可以调临时额度吗", "为什么花呗不能滴滴打车和美团了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.095, 0.165, 0.094, 0.127, 0.132], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "查询借呗账单", "pos": ["查一下蚂蚁借呗的账单", "我的借呗账单在哪里找"], "neg": ["为什么我的收钱码不能用花呗", "其他付款方式怎么转到花呗", "预期的钱还完了还能使用借呗吗", "花呗分期如何取消", "开通花呗后不使用会产生费用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.138, 0.138, 0.132, 0.119, 0.186], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "个人花呗收款资格", "pos": ["我现在可以向客户花呗收款吗", "个体可以花呗收款吗"], "neg": ["***月一号用的花呗,什么时候还", "如何完善借呗额度申请资料", "意思说支付宝里的花呗,我不可以用现在的支付宝还款吗", "花呗消费返还是什么", "使用购物津贴,还能用花呗吗", "全款还了蚂蚁借呗能恢复吗", "有了网商贷还能有借呗么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.079, 0.114, 0.199, 0.188, 0.187, 0.087, 0.183], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "恢复借呗额度方法", "pos": ["我的借呗什么时候恢复额度", "怎么样才能恢复借呗额度"], "neg": ["花呗冻结什么时候才能使用", "蚂蚁借呗还款日主动还款怎么还", "不是本人身份证可以开通花呗嘛", "我,修改借呗还款日", "花呗能设置还款提醒吗", "我能不能撤消开通花呗", "广西农村信用卡可以开通花呗吗", "花呗可以***号还钱吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.132, 0.131, 0.197, 0.157, 0.127, 0.095, 0.162], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗短期借款选项", "pos": ["借呗没出现短期借款", "借呗短期在哪选"], "neg": ["小蓝单车以前用花呗交的押金,花呗还了,退款退回哪", "花呗,如果到期无法还款可以延期还款吗", "花呗返现***元"], "pos_scores": [1.0, 0.95], "neg_scores": [0.109, 0.08, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "有贷款资格但无法开通借呗", "pos": ["为什么我有现金贷款资格 借呗开通不了", "为什么我开不了借呗呀"], "neg": ["我以开通花呗收款,为什么客户还是用花呗付不了款", "借呗上月***日借的,什么时候还款", "借呗为什么不能借一年了", "已经关闭了花呗怎么另外一个账户还让我到这个账户来看", "花呗可以接到哪些银行卡机"], "pos_scores": [1.0, 0.95], "neg_scores": [0.128, 0.107, 0.152, 0.148, 0.156], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗账单不能分期原因", "pos": ["花呗账单为什么不能分期还款", "花呗账单不可分期了吗"], "neg": ["登录了支付宝,还是登录不了花呗", "我的花呗为什么没有最低还款", "花呗最低还款剩余部分可以提前结清吗", "我花呗里的钱充话费充错了能退回吗", "怎么开了借呗", "实名认证过期了蚂蚁借呗就不能用了吗", "如何申请能用花呗的收钱码"], "pos_scores": [1.0, 0.95], "neg_scores": [0.174, 0.148, 0.05, 0.089, 0.175, 0.076, 0.059], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗无法借款原因", "pos": ["借呗***为什么不可以借", "天天用支付宝 为什么借呗用不了"], "neg": ["什么时候给我恢复花呗嘛,我真的很需要这个功能", "我明明付钱了,为什么确认收货后花呗上还要多次付款", "花呗,余额可以用吗", "我的借呗怎么最低分期***个月", "借呗怎么改自动还款账号"], "pos_scores": [1.0, 0.95], "neg_scores": [0.076, 0.163, 0.16, 0.169, 0.119], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗额度评估周期", "pos": ["蚂蚁借呗多久重新可以评估额度", "蚂蚁借呗能提额多久能评估"], "neg": ["也就是我不用花呗了", "支付宝交电费怎么不支持花呗了", "花呗添加银行卡"], "pos_scores": [1.0, 0.95], "neg_scores": [0.055, 0.077, 0.143], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款日说明", "pos": ["花呗还款日为***号或***号,有可能是***号,请以页面提示为准", "花呗还款日 一般是每月的***号或者***号哈 具体"], "neg": ["花呗换前结清", "花呗用不了,多久能使用", "为什么我已经推过了蚂蚁花呗上面还是有金额产生", "蚂蚁借呗没有去还款", "花呗密码和支付密码能不一样吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.064, 0.188, 0.19, 0.082, 0.191], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗退款后仍需还款", "pos": ["都已经退款了,为什么花呗上还是有这个要钱要还", "已退款到花呗,怎么每个月还要还"], "neg": ["花呗支付快,还是余额宝支付快", "花呗的钱怎么还不上去", "为什么我的支付宝账号没法开通花呗收钱", "蚂蚁花呗不是免费体验的吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.109, 0.115, 0.136, 0.185], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗账单金额不符", "pos": ["为什么我花呗数对不上", "为什么我的花呗三月有***要还"], "neg": ["如何查花呗账单看还款情况", "我可以使用支付宝花呗吗", "使用花呗分期付款,退款后依然需要付手续费", "借呗最最低还多少", "如果蚂蚁花呗脱期了分期还有用吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.131, 0.063, 0.117, 0.165, 0.075], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗分期额度不足", "pos": ["为什么我的花呗分期额度显示余额不足", "花呗分期总是提示额度不足是什么意思"], "neg": ["为什么付款了之后花呗还要收我的钱", "花呗分期购物的花呗额度需要大于物品金额么", "已经付款到花呗,为什么还有", "用花呗充值苹果商店", "商家花呗额度怎么提高", "花呗做了分期可以提前还吗", "怎么显示不了借呗金额了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.072, 0.16, 0.06, 0.121, 0.094, 0.101, 0.101], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗放款时间", "pos": ["借呗的钱什么时候到", "为什么在借呗借钱要那么久还不到账"], "neg": ["花呗帮还,我能有积分吗", "如果上次花呗一次性还完,还可以在分期吗", "借呗要满多少分可以借"], "pos_scores": [1.0, 0.95], "neg_scores": [0.108, 0.148, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗还款提醒设置", "pos": ["借呗有还款信息提醒吗", "借呗还款不提醒吗"], "neg": ["花呗交电费付不了款", "花呗的钱可以还款", "退款了花呗还款不变", "怎么突然花呗少了这么多"], "pos_scores": [1.0, 0.95], "neg_scores": [0.194, 0.08, 0.2, 0.15], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗借款期限变更", "pos": ["借呗的期限怎么变短了", "借呗可以借***期,现在怎么只可以借***期了"], "neg": ["花呗逾期一天咋办", "花呗有支付宝有分期付款的车吗", "花呗每日限额使用么", "我是商家能能用花呗收款吗", "花呗怎么能调额", "怎样用支付宝花呗买淘宝东西", "借呗还款 我可以一起还吗", "卖东西花呗图标是什么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.076, 0.084, 0.155, 0.179, 0.1, 0.147, 0.111, 0.098], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "降低借呗利息方法", "pos": ["蚂蚁借呗怎么把利息降低", "借呗怎么降低利息"], "neg": ["几号自动扣除花呗", "蚂蚁借呗也能提前还款", "我的花呗怎么降低额度了", "为什么我的花呗不能用了?还有借呗", "花呗用户立减***元", "也弄删除花呗账单", "我的蚂蚁借呗现在能不能提额", "蚂蚁花呗可以找朋友代付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.123, 0.103, 0.129, 0.172, 0.119, 0.063, 0.073], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗分期选项消失", "pos": ["为什么借呗不能十二个月分期还款了", "蚂蚁借呗***期还款怎么不能选了"], "neg": ["从借呗借款影响从银行贷款吗", "我刚刚是不是用花呗消费了", "我的花呗提醒还了,但又没扣银行卡钱", "使用借呗会影响银行贷款吗", "蚂蚁借呗的记录了", "我申请了退款怎么花呗额度没有恢复", "我的花呗怎么显示的以前的手机号", "我的花呗到期了没还进去怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.085, 0.092, 0.146, 0.18, 0.093, 0.151, 0.09, 0.157], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "解冻花呗账户方法", "pos": ["花呗冻结了怎么把花呗恢复", "花呗冻结了 给我开开"], "neg": ["蚂蚁花呗账户怎么解冻", "电费怎么不可以用花呗了", "而且现在淘宝买东西也不可以用花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.179, 0.079, 0.178], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "删除借呗还款记录", "pos": ["如何删除支付宝借呗的还款记录", "我想你给我蚂蚁借呗记录删除"], "neg": ["花呗分期可以一次还请吗", "花呗已还款。怎么再次开通", "我的花呗提示还款失败,怎么回事", "我是我是花呗,我是我,是花呗花呗分期分期付款付款,商品", "商家开通花呗啦,我就买啦一件", "我咨询下 为什么我的退款金额已经退回到花呗了 花呗的还款金额还是有", "开通花呗后未使用会产生费用吗", "实|店可以用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.07, 0.121, 0.173, 0.092, 0.106, 0.187, 0.089, 0.075], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗被关闭原因", "pos": ["我不明白为什么停了我的借呗", "我的借呗给我关了干嘛"], "neg": ["蚂蚁借呗蚂蚁花呗可以一块吗", "蚂蚁借呗还款算法", "花呗的账户出错,怎么改回来", "花呗已还款,仍提示我需要还款", "蚂蚁借呗能提***吗", "到时候我把蚂蚁借呗给关了", "花呗未出账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.073, 0.111, 0.148, 0.097, 0.107, 0.098, 0.145], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗功能异常", "pos": ["为什么我的借呗用不了", "为什么我的借呗不能接了"], "neg": ["借呗关闭了!还可以再开通吗", "蚂蚁花呗的二维码", "花呗逾期还了过了快一年还有影响吗", "借呗临时提额", "为什么我花呗和借呗那么少", "借呗最后最后还款日"], "pos_scores": [1.0, 0.95], "neg_scores": [0.178, 0.093, 0.151, 0.145, 0.12, 0.187], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "查询花呗开通状态", "pos": ["我的花呗开没开", "我这个支付宝还没开通花呗吗"], "neg": ["借呗审核不通过用什么原因", "花呗可以付款腾讯会员吗", "花呗提额为什么不能提了", "借呗上个月***号借的款为什么到下个月***号就得还款了,问题是还没有到一个月"], "pos_scores": [1.0, 0.95], "neg_scores": [0.118, 0.08, 0.19, 0.19], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗自动扣款失败", "pos": ["借呗 没有自动扣款", "为什么借呗到现在还没有扣"], "neg": ["借呗还款不对", "花呗还款,一定要设置银行卡吗", "我花呗预期还款都还完了怎么不能用了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.132, 0.092, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "新账号开通花呗限制", "pos": ["怎么可以再另一个账户开通花呗", "我的这个账号花呗关掉了,我得新账号怎么不能开通花呗"], "neg": ["花呗退了款 下个月还用还吗", "退回花呗的钱我有收到信息,但是退款金额里没有看到", "你能看出来我花呗有没有分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.177, 0.19, 0.108], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款日期查询", "pos": ["花呗还款详情时间", "花呗每月几时还款"], "neg": ["花呗还款找怎么分期", "借呗还款会通知么", "花呗的可用额度什么意思", "花呗dut额度怎么提高", "花呗分期,优惠卷", "花呗账户怎么切换手机号"], "pos_scores": [1.0, 0.95], "neg_scores": [0.101, 0.073, 0.079, 0.108, 0.167, 0.116], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "商家花呗收款失败", "pos": ["安安,我申请的花呗收款怎么不能用", "我的商家收款二维码为什么花呗支付不了"], "neg": ["为什么借呗之前可以借***个月的,现在最多只有***个月的", "怎样申请花呗临时额度", "花呗分期有手续费没", "花呗支付怎么用", "花呗还款能不能不从余额里面扣"], "pos_scores": [1.0, 0.95], "neg_scores": [0.151, 0.149, 0.19, 0.117, 0.117], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗额度消失原因", "pos": ["我的蚂蚁借呗怎么沒有了,我要借钱", "我的借呗以前有两万多的,现在为什么没有了"], "neg": ["花呗付款的限制", "什么的商品可以用花呗", "花呗分期如何提前"], "pos_scores": [1.0, 0.95], "neg_scores": [0.112, 0.15, 0.19], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗分期服务费发票", "pos": ["花呗分期服务开票", "花呗分期服务费开票类型"], "neg": ["花呗购买有短信吗", "商家花呗服务怎么开通", "提高芝麻借呗的额度", "为啥我用不了蚂蚁借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.072, 0.18, 0.127, 0.102], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗当月账单分期", "pos": ["借呗本月的款可以分期吗", "借呗分期可以"], "neg": ["借呗额度停用和身份证过期有关吗", "我看我的蚂蚁借呗可以转两万元到我卡上是真的吗", "花呗啥时候还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.178, 0.069, 0.182], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗支付电费限制", "pos": ["不可以用花呗支付电费了么", "用花呗可以叫电费嘛"], "neg": ["花呗到了时间就自动扣款吗", "我这个为什么开不成花呗", "为什么我美团都使用不到花呗", "我的花呗***月账单上个月就还了", "花呗逾期后还了款,无法使用,该如何才能恢复"], "pos_scores": [1.0, 0.95], "neg_scores": [0.178, 0.101, 0.19, 0.121, 0.119], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗入口消失", "pos": ["找不到我的花呗了", "为什么我的蚂蚁花呗找不到了"], "neg": ["花呗怎么转到另一个账号", "为什么我花呗没有临时额度了", "使用花呗有哪些好处", "花呗还款用微信红包可以吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.177, 0.102, 0.136, 0.134], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "网商贷与借呗额度对比", "pos": ["蚂蚁借呗升级为网商贷后额度是一样的吗", "网商贷是否跟原有的借呗一样的额度吗"], "neg": ["连锁超市可以用花呗吗", "花呗账单怎样查看", "没有营业执照怎么办花呗二维码收款", "为什么我又有花呗账单", "为什么我的收钱码不能用花呗", "花呗怎么向支付宝付款", "我把钱放到了余额宝咋样才能转到蚂蚁借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.13, 0.095, 0.068, 0.09, 0.134, 0.16], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗分期提前结清", "pos": ["花呗已经办理分期还款,现在怎么样才可以提前一次性能还清", "花呗分期提前还清 在哪里"], "neg": ["可以用支付宝花呗还信用卡吗", "我的蚂蚁借呗怎么才可以继续使用", "借呗余额减少了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.129, 0.198, 0.133], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "淘宝花呗与支付宝花呗区别", "pos": ["我没有支付宝花呗,只有淘宝花呗", "淘宝上有花呗为什么支付宝上没有"], "neg": ["花呗的余付款怎么能还上", "支付宝花呗昨天交天然气费交易成功没到账", "卖家开通了花呗付款,买家为什么不能使用花呗付款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.194, 0.083, 0.169], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "美团支持花呗支付吗", "pos": ["我用花呗在美团订餐饮", "花呗可以在美团用嘛"], "neg": ["我想取消花呗怎么取消的", "什么是花呗标志", "请帮我关闭***的花呗", "怎么取消借呗的自动还款", "dc花呗分期修改", "我用了花呗的钱怎么还上", "花呗用积分兑换的免费提现额度是什么意思", "我要停花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.185, 0.16, 0.097, 0.172, 0.13, 0.155, 0.085, 0.075], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "取消借呗身份证验证", "pos": ["我刚才注册了蚂蚁借呗我想取消我扫的身份证", "蚂蚁借呗怎么取消验证身份证"], "neg": ["花呗分期,提前还清下个月,没找到提前还款", "借呗转账可以取消吗", "我的花呗账单已还清,为何还自动扣我钱", "我想还四月的花呗账单", "借呗现在的还款日是几号", "什么时候能设置花呗分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.116, 0.194, 0.144, 0.122, 0.072, 0.062], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗火车票退款差额", "pos": ["火车票花呗退款金额差额很大", "支付宝火车票退款与花呗不符"], "neg": ["花呗逾期就不能申请分期还款了吗", "邀请朋友开借呗", "借呗还款逾期一天怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.102, 0.145, 0.193], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗自动还款失败怎么办", "pos": ["【花呗】你的支付宝******@qq.com花呗***月账单自动还款失败,***元未还,请登录“支付宝-花呗”核对并还款,已还忽略", "【花呗】你的支付宝******@qq.com花呗***月账单自动还款失败,***元未还,请登录“支付宝-花呗”核对并还款,已还忽略"], "neg": ["怎么样能把借呗关闭", "花呗分期期间提升可用额度吗", "我没有银行卡怎么还花呗款", "我今天在淘宝下了个单。交易关闭但花呗又扣款,怎么办", "我用花呗买东西,但是我的钱被扣了,为什么花呗还要还钱", "花呗付款的押金退回后钱在哪"], "pos_scores": [1.0, 0.95], "neg_scores": [0.181, 0.181, 0.189, 0.061, 0.089, 0.107], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗刷脸验证不成功解决方法", "pos": ["花呗,刷脸没成功", "多次刷脸花呗不成功,咋办"], "neg": ["可以重新給我開通花唄嗎", "蚂蚁借呗始终是放款中", "借呗还能借吗", "信用卡和花呗收不了款了", "花呗集盒子活动在哪里报名", "花呗如何绑定到别外的手机号上"], "pos_scores": [1.0, 0.95], "neg_scores": [0.084, 0.146, 0.194, 0.173, 0.106, 0.113], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗登录问题解决方法", "pos": ["花呗登不进去了", "蚂蚁花呗,怎么登不进去"], "neg": ["为什么我的功能里没有花呗", "现在不能借***个月了吗蚂蚁借呗", "申请借呗是上传身份证,正反面吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.195, 0.167, 0.11], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗额度提升方法", "pos": ["花呗怎么提升到几千", "花呗怎么提额"], "neg": ["花呗开通条件有什么为什么说我不符合", "蚂蚁借呗提前还利息i", "蚂蚁借呗只能选择***个月的期限", "花呗分期可以一次多还嘛", "花呗已经还款 为什么发信息说没有还", "花呗临时的额度也能分期吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.07, 0.113, 0.095, 0.198, 0.115, 0.138], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "商家开通花呗的条件", "pos": ["我是商家,店铺目前无法开通花呗", "我的店铺刚开一个月 想开通花呗 但是好像不能开 对店铺有什么具体要求 才能开通花呗"], "neg": ["我就是蚂蚁花呗,这个需要利息不", "蚂蚁借呗分期了,提前还款一期到期了还会扣钱吗", "花呗按揭退款还会扣手续费嘛", "申请蚂蚁借呗的钱,显示已通过申请,为什么没有看到钱,在哪里看", "蚂蚁借呗降额度了", "借呗最长能分几期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.135, 0.124, 0.186, 0.119, 0.143, 0.088], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "忘记花呗账号怎么找回", "pos": ["我之前的蚂蚁花呗忘记账号了", "花呗使用的账户忘记是什么号码"], "neg": ["我的退款到哪去了?我是提前还款花呗,但是收到货产生退货退款,卖家退回花呗了,那我的钱到哪去了", "我蚂蚁花呗还款晚了一个小时有影响吗", "借呗没有进度查询", "为什么花呗不能充值华费", "尾款支付使用花呗", "借呗会影响我房贷吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.147, 0.147, 0.087, 0.175, 0.137, 0.109], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "如何快速开通蚂蚁借呗", "pos": ["怎样可以快速使用蚂蚁借呗", "我想使用蚂蚁借吧"], "neg": ["花呗分期付款利息那么高", "我的借呗怎么还用了吗", "我借呗和花呗都不能用了", "不可以花呗什么情况"], "pos_scores": [1.0, 0.95], "neg_scores": [0.095, 0.083, 0.056, 0.198], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "未开通花呗但误用如何还款", "pos": ["我还没有开通花呗但是淘宝付款使用了要怎么还款", "我并没有开通花呗,付款按错用了***次,怎么样还款"], "neg": ["蚂蚁花呗双十二有临时额度吗", "我记得蚂蚁借呗可以分***个月还款的", "我在淘宝消费 花呗的上限会涨吗", "花呗扣款的具体时间是什么时候"], "pos_scores": [1.0, 0.95], "neg_scores": [0.178, 0.113, 0.17, 0.099], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "如何关闭花呗", "pos": ["我得花呗怎么停用了", "怎么样把花呗关了"], "neg": ["花呗总欠款怎么看", "花呗为啥不能开", "花呗分期付款可以一起还吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.16, 0.175, 0.18], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗分期后能否一次性还清", "pos": ["开启花呗账单分期后可以一次还清吗", "花呗分期了以后还可以一次性还上么"], "neg": ["充话费为什么现在不能用花呗了", "借呗还款余额中有钱为什么从卡里扣金额", "为什么付款时用不了花呗", "怎么我的花呗不可以充话费"], "pos_scores": [1.0, 0.95], "neg_scores": [0.189, 0.144, 0.055, 0.139], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗分期涉及的小贷公司", "pos": ["还有重庆阿里小额小贷、重庆蚂蚁小额小贷 这两个公司也收取了花呗分期", "***重庆阿里小额小贷,***重庆蚂蚁小额小贷 这两个也是花呗分期"], "neg": ["花呗分期免息卷在哪里兑换", "商家开通了花呗支付,为什么不能付款", "我从借呗借的钱转到银行卡里去了现在用不了", "花呗收款已经开通了但是不能用", "蚂蚁借呗还款只能本人还款吗", "关了花呗,什么时候可以重新开通"], "pos_scores": [1.0, 0.95], "neg_scores": [0.052, 0.166, 0.198, 0.061, 0.157, 0.121], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "如何查询借呗还款是否成功", "pos": ["我得借呗还款成功了吗", "我九月这期蚂蚁借呗是否还款成功"], "neg": ["该怎么弄!我现在着急用,额度不够淘宝不能花呗分期", "我为什么开不起花呗", "花呗开通后一定要使用吗", "我为什么花呗人脸识别不通过", "借呗能可以把六个月变成***个月还款吗", "花呗,可以提额"], "pos_scores": [1.0, 0.95], "neg_scores": [0.197, 0.102, 0.083, 0.103, 0.059, 0.196], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "如何退出花呗", "pos": ["我还清了花呗怎样退出", "我怎样才能退出花呗"], "neg": ["借呗还款后额度没有恢复还显示借款", "后面几个月的花呗账单怎么查询", "花呗额度也降低了,咋回事", "借呗到期会不会自动还款吗", "花呗与信用卡收钱码", "我别的支付宝账号花呗额度能转移到这个支付宝上吗", "蚂蚁借呗提示暂时不能使用刷脸为什么"], "pos_scores": [1.0, 0.95], "neg_scores": [0.126, 0.198, 0.065, 0.182, 0.16, 0.198, 0.098], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗额度恢复方法", "pos": ["怎么样恢复调整前的花呗额度", "花呗额度怎么恢复原来的花呗钱数"], "neg": ["提前还花呗咋样操作", "花呗余额为什么是负***", "我不是花呗分期吗!我需要全款支付吗", "花呗能分期不", "借呗短期借款不是免息的吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.114, 0.092, 0.118, 0.058, 0.127], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗开通资格查询", "pos": ["借呗怎么没资格申请", "借呗怎么还没开通"], "neg": ["花呗分期付款后可以一次付清", "***花呗试用", "花呗付款申请售后怎么退款", "打开花呗的时候显示的系统繁忙是咋回事", "推迟借呗还款", "申请借呗学历要怎么填写", "怎么解冻结的花呗", "在支付宝支付多少次才可以开通花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.137, 0.095, 0.132, 0.195, 0.196, 0.101, 0.11, 0.12], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "如何使用借呗", "pos": ["我想用借呗怎么做", "怎么样才可以用借呗"], "neg": ["借呗 预留号码变了 怎么办", "借呗***个月还上还是每个月都还点", "蚂蚁借呗额度关闭了怎么能快速恢复"], "pos_scores": [1.0, 0.95], "neg_scores": [0.128, 0.141, 0.136], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗付款抽奖活动", "pos": ["之前用完花呗有一次抽奖的机会现在还有吗", "之前花呗付款有抽奖机会的"], "neg": ["花呗返现金额", "这个花呗是以前的号码", "选择了花呗分期的预售商品,尾款可以用红包支付吗", "使用花呗需要额外收服务费吗", "可以取消花呗支付,用别的支付方式吗", "有蚂蚁借呗怎么还要申请", "为什么邦定银行卡,还使用不了花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.104, 0.067, 0.165, 0.194, 0.162, 0.185, 0.162], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗使用限制解除方法", "pos": ["我的花呗什么时候才可以y", "怎么才可以用回花呗"], "neg": ["求花呗额度要买电话", "花呗怎么算还款的", "花呗付了最低款,剩余的怎么算利息", "关联账户能使用花呗或借呗吗", "为什么我没有借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.162, 0.134, 0.181, 0.097, 0.081], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "重新开通花呗的步骤", "pos": ["蚂蚁花呗重新怎么开通", "我请求再次开通蚂蚁花呗"], "neg": ["我花呗分了几期?每期多少钱", "从新开通蚂蚁花呗", "借呗怎么取消身份证号", "我的花呗收款码怎么制作"], "pos_scores": [1.0, 0.95], "neg_scores": [0.104, 0.116, 0.129, 0.16], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗逾期后恢复使用条件", "pos": ["花呗逾期款返清后花呗可以使用吗", "逾期一天还可以用花呗不"], "neg": ["花呗可以扫别人的支付宝二维码吗", "花呗逾期会利滚利算利息么", "怎样退花呗,不用"], "pos_scores": [1.0, 0.95], "neg_scores": [0.179, 0.115, 0.088], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗分期不可用原因", "pos": ["花呗分期暂不可用", "到提交订单的时候显示花呗分期不可用"], "neg": ["怎么才能开通花呗", "我要暂时花呗的可用额可以吗", "如何临时提升蚂蚁花呗额度", "蚂蚁借呗自动降低还会提升吗", "一分钱花呗体验活动", "借呗和网商贷不能共存吗", "如何取消花呗的零时额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.114, 0.196, 0.141, 0.094, 0.184, 0.052, 0.088], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "开通花呗的芝麻信用分要求", "pos": ["要多少芝麻信用分,花呗才可以用", "花呗要多少分可以开通"], "neg": ["为什么我用花呗帮朋友代付不行", "蚂蚁借呗,怎么完善信息", "花呗额度在哪边调", "信用住支持花呗付款吗", "已经有的借呗为什么突然变成网商贷了", "支付宝怎么更改花呗登录", "借呗关闭了怎么才能在开通"], "pos_scores": [1.0, 0.95], "neg_scores": [0.193, 0.122, 0.082, 0.104, 0.185, 0.173, 0.065], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗最低还款设置方法", "pos": ["花呗借呗设计最低还款哪里设计", "花呗怎么自动按最低额还款"], "neg": ["花呗扣款是从哪里扣,是支付宝绑定", "为什么我的没有蚂蚁借呗功能", "为什么我的花呗有***分也用不了借呗", "花呗为啥不能充游戏点券", "我还不上花呗", "境外商家申请花呗支付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.105, 0.065, 0.085, 0.07, 0.197, 0.103], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "绑定银行卡后花呗无法使用原因", "pos": ["绑定了银行卡为什么还用不了花呗", "为什么绑定银行卡了,还是使用不了花呗支付"], "neg": ["借呗转给我的款,都超过***个小时了,还是没有收到", "网商贷如何降回借呗", "怎么绑定不到花呗的", "蚂蚁借呗不见了,怎么还钱", "花呗 退款", "蚂蚁借呗怎么分***个月", "蚂蚁花呗忘记了以前的手机号账号用户名怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.112, 0.173, 0.189, 0.121, 0.178, 0.137, 0.106], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "如何重新授权花呗", "pos": ["怎么样可以授权花呗", "重新授权蚂蚁花呗"], "neg": ["花呗人脉关系授权", "怎么显示不了借呗金额了", "帮我查询一下 花呗什么时候可以用", "花呗还款怎么查询详单"], "pos_scores": [1.0, 0.95], "neg_scores": [0.164, 0.156, 0.106, 0.13], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗逾期后恢复使用步骤", "pos": ["花呗预期,如何可以恢复使用", "我的花呗有逾期,怎样才能恢复正常使用"], "neg": ["如果停用借呗", "借呗可以当天还款,当天借出来,再当天还款吗", "花呗分期后可以第三个月把其他的钱还完吗", "我支付宝帐户是手机号,咋叫我登入花呗时显示支付宝", "蚂蚁花呗还款延期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.184, 0.133, 0.162, 0.145, 0.094], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗额度更新时间", "pos": ["花呗额度用完什么时候可以更新", "花呗的零时额度,什么时候到期"], "neg": ["花呗分期已提前还清,为什么花呗还不能用", "我凌晨用借呗借的钱怎么到现在还没到账", "蚂蚁花呗多久提一次额度", "怎么能重新开通花呗", "我没有用花呗一分钱为什么让我还款", "花呗未收获之前怎么分期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.113, 0.086, 0.168, 0.123, 0.1, 0.185], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款使用余额宝设置", "pos": ["花呗还款如何选择用余额宝", "从余额宝怎么设置代扣到花呗还款"], "neg": ["借呗的还款日可不可以修改", "借呗自动关闭是什么情况", "电脑端如花使用花呗支付", "借呗能延伸还款时间吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.169, 0.121, 0.076, 0.064], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗支持加油充值吗", "pos": ["花呗是否支持加油充值", "可以用花呗充加油卡"], "neg": ["网商银行跟蚂蚁借呗", "手机上没有支付宝了", "网商袋怎么换借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.08, 0.172, 0.109], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款方法", "pos": ["花呗有余额,但不知道怎么还款", "我就花呗,怎么还款"], "neg": ["我刚还清蚂蚁借呗,什么时候能再借呗", "开通花呗,不使用会收费用吗", "怎麼关掉借呗", "如何递增花呗金额", "借呗要不按期还,推迟两天还会怎么样"], "pos_scores": [1.0, 0.95], "neg_scores": [0.147, 0.103, 0.061, 0.184, 0.18], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款金额不符解决方法", "pos": ["花呗的还快款金额不对", "还的钱跟花呗对不上"], "neg": ["借呗提前还款后,额度为何不恢复", "为什么借呗申请的时候说账户有风险", "花呗有消费限制"], "pos_scores": [1.0, 0.95], "neg_scores": [0.065, 0.147, 0.105], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "双十一花呗付款支持情况", "pos": ["双十一可以用蚂蚁花呗付款吗", "双十一抢购的东西 可以用花呗付款吗"], "neg": ["借呗是券怎么领", "借呗提示提现额度", "退款后花呗分期暂不可用", "但是我分了期 分了三个月", "花呗在哪 我这没有", "花呗的运费险退回哪", "申请退款后花呗支付金额什么时间到账"], "pos_scores": [1.0, 0.95], "neg_scores": [0.077, 0.108, 0.178, 0.112, 0.086, 0.137, 0.2], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗额度下降原因", "pos": ["花呗额度下降了", "花呗降额,有病"], "neg": ["我要花呗退款", "从花呗还借款", "怎么样才能把花呗的钱转到我账户上", "我想知道我的花呗有多少可用额度", "香港花呗购物返现", "花呗有逾期 怎么办", "借呗***期怎么还款", "我想注销的帐户开通了花呗,另一个开通不了,那我想使用的花呗该怎样开通"], "pos_scores": [1.0, 0.95], "neg_scores": [0.138, 0.182, 0.165, 0.098, 0.096, 0.091, 0.104, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗无法使用解决方法", "pos": ["我的花贝怎么用不了", "我的花呗怎么没法使用"], "neg": ["支付宝花呗只能一个号码使用吗", "密付可以扣花呗的吗", "花呗收钱码开通不了了,怎么开通", "花呗***次月还是什么意思"], "pos_scores": [1.0, 0.95], "neg_scores": [0.121, 0.052, 0.179, 0.172], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗退款后账单问题", "pos": ["花呗退款了为什么在还花呗账单里面还有这个账单", "这单淘宝已显示退款,这里也显示退回花呗,为什么帐单还款里还要再还"], "neg": ["加油可用花呗付款吗", "我借呗还有额度,为什么借不出来了", "可以消除借呗还款记录么", "如何开通余额宝自动还花呗", "花呗怎么返现"], "pos_scores": [1.0, 0.95], "neg_scores": [0.171, 0.06, 0.192, 0.138, 0.127], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "关闭花呗后能否再次开通", "pos": ["主动关闭花呗,是否以后都不能再开启", "花呗关闭后还能再次开通嘛"], "neg": ["我的蚂蚁借呗一年了怎么还没提额", "能把网商贷转为借呗吗", "我的花呗已经花完了,怎么还让我还款", "蚂蚁借呗快到还款日有短信提醒不", "花呗消费限制额度吗", "我想咨询一下我的花呗逾期一两天可以吗", "怎么把花呗什么的设置到我的页面"], "pos_scores": [1.0, 0.95], "neg_scores": [0.161, 0.106, 0.094, 0.191, 0.076, 0.164, 0.162], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗用余额宝还款设置", "pos": ["余额宝设置还花呗", "如何设置,花呗用余额宝还款"], "neg": ["今天借呗没有自动还款", "我的卡怎么绑定花呗绑定不了", "花呗怎么不支持", "花呗分期付款的事", "我花呗逾期了,可以晚两天还么", "花呗退款,问什么还款额没变,余额也没有收到", "花呗退款手续费怎么收取", "不要在页面昱显示花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.162, 0.116, 0.071, 0.14, 0.172, 0.063, 0.136, 0.101], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗额度冻结原因", "pos": ["花呗也被冻结了", "蚂蚁花呗额度冻结"], "neg": ["怎么取消花呗的授权", "我想花呗自动在余额宝里扣款,怎样操作", "我的花呗双***之前有临时额度", "蚂蚁花呗收款时在哪里上传营业执照", "关闭花呗自动交话费", "查询花呗未还款款余", "借呗金额全部还清可以再借么", "我现在用花呗买电瓶车,额度不够怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.171, 0.062, 0.19, 0.135, 0.127, 0.143, 0.132, 0.196], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "收钱码支持花呗付款吗", "pos": ["收钱码付款可以用花呗吗", "收钱码花呗能用吗"], "neg": ["如何用花呗付款", "花呗我为什么又履约了", "我想永久的关掉借呗", "我有蚂蚁借呗,现在给我停了", "全部还请花呗账单", "为什么借呗提前还了没有额度", "提醒花呗额度", "花呗在有些商家用不了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.124, 0.056, 0.132, 0.074, 0.193, 0.109, 0.087, 0.073], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "蚂蚁借呗所属公司", "pos": ["蚂蚁借呗是哪个公司的产品", "蚂蚁借呗贷款公司是谁"], "neg": ["花呗晚一天还款怎么", "花呗收钱码哪里查看", "在蚂蚁借呗里我怎么没有信誉额度", "如何取消花呗代扣电话费", "我的商品退货了,淘宝显示款项已退回,为什么我的花呗还要还钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.057, 0.135, 0.103, 0.073, 0.067], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗借款期限选项", "pos": ["借呗只有***个月和***个月", "蚂蚁借呗***个月"], "neg": ["花呗怎么冲qb", "但是我银行卡支付的为什么退回花呗", "开通花呗,如果不用会产生费用吗", "我在花呗里面充值话费到现在没到账", "退货了,可是还有***元运费没有退到花呗怎么处理"], "pos_scores": [1.0, 0.95], "neg_scores": [0.196, 0.123, 0.165, 0.159, 0.174], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗账户转移方法", "pos": ["另一个账户开通了花呗,怎么转到这个账户", "我之前的手机号账户花呗已经不记得了,怎么转到现在这个手机号账户上"], "neg": ["积分可换花呗额度吗", "我想把这个花呗给注销了", "花呗的账期", "蚂蚁借呗怎么不能借一个月", "借呗超过***笔然后什么时候可以恢复", "***花呗,用了***,已经还了,现在怎么还是***", "是不是借呗和网商贷两个只能开一个", "花呗不是一个手机号码怎么改"], "pos_scores": [1.0, 0.95], "neg_scores": [0.192, 0.131, 0.075, 0.159, 0.181, 0.133, 0.189, 0.194], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗被取消原因", "pos": ["为什么我蚂蚁借呗给我取消啦", "我又没有逾期过,怎么取消我的蚂蚁借贝"], "neg": ["花呗关闭还能再使用吗", "借呗申请提高额度,但是审核进度不能查询", "蚂蚁花呗能绑定到我的另一个支付宝账号吗", "借呗还款日期多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.168, 0.2, 0.077, 0.06], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗游戏充值限制", "pos": ["现在游戏内不可以用花呗充值点卷吗", "花呗不能用于游戏充值吗"], "neg": ["怎么才能支持花呗充话费", "我老公有营业执照,我可以开通花呗收款吗", "我找不到我的花呗是哪个号码了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.183, 0.088, 0.06], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗分期利息计算", "pos": ["花呗分期还款怎么算利息", "花呗分期付款 利息多少"], "neg": ["花呗分期客服", "借呗要求信用额度是什么", "我想知道为什么花呗分期暂不可用", "花呗怎么计息的", "借呗还款 银行卡只能转入一万", "明天是还款日,可是今天我的钱转不进余额宝?明天怎么还蚂蚁借呗的钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.088, 0.115, 0.132, 0.057, 0.105, 0.178], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "取消花呗分期方法", "pos": ["取消花呗分期吗", "我的花呗不小心点到分期了可以帮我取消分期了吗"], "neg": ["蚂蚁借呗提前还款可以得到蚂蚁庄园的饲料", "找不到花呗在哪里", "花呗立减十元", "ofo共享单车为什么不能用花呗", "花呗可用额度多少才可以分期", "为什么花呗自动还款失败,什么原因"], "pos_scores": [1.0, 0.95], "neg_scores": [0.075, 0.068, 0.054, 0.158, 0.147, 0.109], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款方式", "pos": ["花呗怎样还款", "花呗支付是怎么还款的"], "neg": ["为什么申请蚂蚁借呗要刷脸", "借呗借钱需要绑定银行卡嘛", "花呗是自动扣款是从哪里扣款的", "我的蚂蚁借呗为什么会被取消", "花呗还了最低还款,还需要分期吗", "借呗还钱在余额宝里吗扣", "这个花呗还款日可以调一下吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.123, 0.175, 0.195, 0.135, 0.167, 0.069], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗分期还款选项", "pos": ["借呗分期还款可以", "我的借呗借到钱了,可以分期还不"], "neg": ["我的花呗可以提升吗", "我用花呗付的款,要什么时候还", "我的花呗还款日期可以改一下。推迟几天吗", "借呗 借***个月 什么时候开始还", "怎么我的支付宝没有花贝"], "pos_scores": [1.0, 0.95], "neg_scores": [0.126, 0.087, 0.1, 0.128, 0.084], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "芝麻信用解除授权对花呗影响", "pos": ["我的芝麻信用解除了授权,对芝麻分和花呗有影响吗", "花呗和芝麻信用"], "neg": ["我的蚂蚁借呗咋不能用了", "我的花呗怎么别另一个手机号绑定了", "花呗分期的能提前还款吗", "花呗买飞机票"], "pos_scores": [1.0, 0.95], "neg_scores": [0.112, 0.092, 0.174, 0.127], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗额度调整选项消失", "pos": ["借呗右上角找不到额度调整", "借呗的额度调控不见了"], "neg": ["花呗分期可以提前还清吗", "为什么我的花呗***月是未出账,***月为什么还要还款", "为什么借呗不能超过***笔", "花呗分期签约", "花呗如何取消自动缴费", "蚂蚁借呗利息这么算"], "pos_scores": [1.0, 0.95], "neg_scores": [0.189, 0.09, 0.186, 0.104, 0.106, 0.07], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗开通时间查询", "pos": ["什么时候能使用花呗", "我的花呗什么时候可以开阿"], "neg": ["借呗已还清,还停我花呗", "网商贷换借呗", "没有app 花呗怎么还款", "蚂蚁借呗 借***"], "pos_scores": [1.0, 0.95], "neg_scores": [0.192, 0.189, 0.099, 0.124], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗短期借款方法", "pos": ["蚂蚁借呗如何短期借款", "蚂蚁借呗怎么选择短期的,比如借半月或者一个月"], "neg": ["就 我之前有个支付宝开了花呗但是我号不用了", "蚂蚁借呗用不了怎么办", "花呗会寄账单过来吗", "怎么解除花呗贷"], "pos_scores": [1.0, 0.95], "neg_scores": [0.051, 0.182, 0.168, 0.182], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "退款后花呗未恢复原因", "pos": ["我在小米商城买了东西,已经退款了,花呗的的钱怎么还没回来", "已经退款了 花呗怎么还欠款"], "neg": ["支持花呗付款我为什么付不了", "我借呗分两次借的账单怎么分开查", "借呗的借款日和还款日相差多久", "花呗多少信用度可以开通", "花呗返还卷", "我的花呗为什不能用了"], "pos_scores": [1.0, 0.95], "neg_scores": [0.142, 0.063, 0.09, 0.126, 0.074, 0.075], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款未更新问题", "pos": ["我银行卡已经被花呗扣款为什么打开支付宝花呗还显示没还款", "我的花呗还款银行卡钱已经扣了 花呗还是现实没有还"], "neg": ["我的花呗是否是开通状态", "花呗的额度包括手续费吗", "借呗还钱的日期是什么时候"], "pos_scores": [1.0, 0.95], "neg_scores": [0.05, 0.186, 0.069], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "支付宝没有花呗选项", "pos": ["点开右下角我的,没有花呗", "我手机上没有花呗"], "neg": ["我还了借呗的一部分钱,为什么额度没恢复", "蚂蚁花呗最低还款会降低信用", "我为什么花呗不能还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.161, 0.153, 0.081], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款操作步骤", "pos": ["怎么可以把蚂蚁花呗还了,我现在换不了呀", "我刚才点错了,我要还款给花呗,怎么还"], "neg": ["蚂蚁借呗如何取消最低还款", "花呗支付成功但订单显示未支付", "我怎样才能使用借呗贷款", "最近花呗无法提前还款", "你看我开通花呗了吗", "蚂蚁信用***分借呗额度多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.095, 0.081, 0.146, 0.087, 0.121, 0.099], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗逾期还款后果", "pos": ["花呗没有按期偿还", "花呗没有正常还款"], "neg": ["我现在是开通了花呗吗", "花呗分期可以提前还吗?还需要手续费吗", "为什么花呗我没花那么多钱,花呗让我还那么多那", "蚂蚁借呗可以在***个月后一次还清吗", "花呗付的话便宜吗", "如果将网商贷改为借呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.182, 0.14, 0.145, 0.127, 0.158, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗免分期券领取方法", "pos": ["花呗免分期券在哪可以领", "花呗***期分期券哪里领"], "neg": ["之前我看了可以用花呗买机票", "下线门店怎么设置花呗交易额度", "还差多少才能开借呗", "花呗买电影票退票", "为什么商品支持花呗但是不能用花呗付款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.107, 0.137, 0.068, 0.141, 0.07], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款验证码发送到旧手机号", "pos": ["花呗还款验证码还发在旧手机号上怎么办", "修改玩手机号为什么用花呗付款还是发短信到旧手机号码"], "neg": ["我想知道花呗提前还,怎么额度还降低了", "怎么开通商家花呗", "花呗吗。为什么提前不能还 还是等到***月份还", "我想提高花呗额度买东西", "我花呗账号忘记啦"], "pos_scores": [1.0, 0.95], "neg_scores": [0.171, 0.153, 0.079, 0.075, 0.144], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还清后无法使用原因", "pos": ["为什么还钱完了花呗却用不了", "所有账单已还清,花呗为什么还不能用"], "neg": ["花呗关闭了可以再开启吗", "支付宝里面查不到花呗", "花呗怎么样分期", "我的花呗双***之前有临时额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.166, 0.096, 0.165, 0.158], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗自动还款设置", "pos": ["花呗支付后。还款日自动扣款吗", "花呗还l款会自动扣除还款钱吗"], "neg": ["我要开通花呗 买手机", "花呗开通需要多少芝麻信用", "蚂蚁借呗分期还款可不可以提前还", "花呗最多可以透支多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.151, 0.181, 0.15, 0.082], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗购买火车票支持情况", "pos": ["花呗支持飞天猪买火车票吗", "支付宝上买火车票支持花呗吗"], "neg": ["领取借呗提升额度为什么来输入密码", "花呗咋样才能还款", "我的花呗该还款了,但是我不会还,应该怎么还款", "花呗一般什么时候评估额度", "我想降低借呗额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.069, 0.081, 0.178, 0.066, 0.128], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "收钱码接收花呗付款", "pos": ["收钱码收不收对方的花呗", "收钱码可以收花呗点钱吗"], "neg": ["怎么样才可以换掉之前的花呗手机号", "展开 花呗付款的时候验证码怎么发到旧手机号了?这怎么改", "我现在还不了分期花呗吗", "花呗还清 额度不对", "花呗客服多少", "花呗未还清还可以用借呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.189, 0.198, 0.135, 0.077, 0.08, 0.125], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗账单金额不符", "pos": ["花呗的数额与实际不符", "我花呗与还款不符合"], "neg": ["我现在想分期买苹果***花呗不够怎么办", "对多花呗可以消费多少", "用余额宝自动向花贝借呗还款要收费吗", "为什么我的支付宝不显示借呗", "我现在想在蚂蚁借呗上借钱。但是我的电话收不到短信的验证码", "蚂蚁借呗怎么看还款记录"], "pos_scores": [1.0, 0.95], "neg_scores": [0.126, 0.121, 0.084, 0.112, 0.086, 0.054], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗欠款还款方法", "pos": ["怎样还花呗欠款", "花呗还钱?怎么换"], "neg": ["我刚买的东西还没收到货就要显示交易成功要还花呗", "共享单车的押金为什么不能用花呗付了", "花呗这个月用下个月还", "我要还蚂蚁花呗的款,怎么办,我弄不好怎么办", "花呗未还款可以申请借呗嘛", "花呗只能在淘宝购物了,其它地方不能用是怎么回事", "我想成为商家,花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.185, 0.062, 0.163, 0.061, 0.097, 0.106, 0.078], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗被关闭原因", "pos": ["借呗正在用怎么给我关闭了", "信用记录一直挺好 为什么现在借呗给我取消了"], "neg": ["使用花呗支付花呗红包没有自动扣除", "花呗可用于付帐么", "蚂蚁借呗最高额多是多少", "开通花呗刷脸一直不通过", "花呗和信用卡收款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.182, 0.112, 0.16, 0.156, 0.176], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗提前还款操作", "pos": ["花呗申请提前还款", "我花呗要提前还款"], "neg": ["退款***快多,为什么我的花呗没有显示", "余额宝可以直接扣蚂蚁借呗吗", "本身积分发放的功能是受到花呗限制影响的,这个无法恢复,如果您是对花呗的限制有疑问,可以帮您zhuanj", "借呗还款日是每个月的几号"], "pos_scores": [1.0, 0.95], "neg_scores": [0.129, 0.058, 0.163, 0.146], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "取消花呗方法", "pos": ["是我点错了,怎么取消花呗", "花呗怎么取消,我取消嘞"], "neg": ["摇一摇花呗临时额度", "怎么兑换花呗免息券", "过了***点怎么借呗还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.18, 0.099, 0.187], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗退款未到账问题", "pos": ["账单已还清,但是退款到花呗的钱没收到", "退款退回花呗,可是花呗并没有收到那退款的钱"], "neg": ["借呗一千一个月还多少钱", "我的花呗返现", "借呗当天还最迟几点", "花呗可以定餐吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.168, 0.133, 0.055, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗多扣款解决方法", "pos": ["花呗多收了我的钱", "花呗为什么多让我还了***"], "neg": ["多少芝麻信用开通花呗", "借呗还款日未自动还款", "淘宝买手机朋友代付可以用花呗吗", "我已经还了花呗,那*** 退额度是怎么算的", "花呗下月还款日期", "蚂蚁借呗多少分才能借", "花呗免息优惠券在哪里", "怎么取消朋友带还花呗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.145, 0.165, 0.195, 0.157, 0.148, 0.144, 0.05, 0.07], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "蚂蚁借呗入口位置", "pos": ["借呗点哪里", "从哪找蚂蚁借呗"], "neg": ["我房贷需要借呗的结清证明怎么弄", "网商银行就是借呗吗", "可以还下下个月的花呗吗", "代收可以使用花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.197, 0.088, 0.19, 0.187], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "开通借呗方法", "pos": ["我要怎么才能使用借呗", "怎么能用借呗"], "neg": ["我的花呗账号怎么跟我登陆的账号不一样", "花呗逾期利息和罚款怎么算", "为什么我游戏充值时无法使用花呗支付", "借呗分期能分几个月", "花呗分期还有多少期", "你妈的逼我为什么没有花呗", "花呗使用临时额度后,还款会优先还临时额度吗", "飞猪买机票为什么不能用花呗支付"], "pos_scores": [1.0, 0.95], "neg_scores": [0.187, 0.187, 0.191, 0.175, 0.174, 0.18, 0.164, 0.063], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗账单消失原因", "pos": ["为什么蚂蚁借呗看不到账单了,变成重新申请了", "我的借呗账单看不到了"], "neg": ["在花呗哪一个里面", "为什么打开花呗,没有分期还款", "我要怎样才能用蚂蚁借呗", "我之前可以用借呗现在为什么又不能用了", "我有花呗为什么查账,都显沒有邦定银行卡,支付宝 上的银行卡不就是的吗", "我支付宝是现在的手机号,怎么花呗还是原来的手机号?怎么修改", "我想看一下我的花呗之前的分期还有多少期", "为什么蚂蚁借呗是暂无额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.052, 0.071, 0.195, 0.156, 0.133, 0.187, 0.158, 0.089], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "商户开通花呗收款步骤", "pos": ["怎么开通商户的花呗收钱", "商户如何开通花呗支付"], "neg": ["手机花呗 如何分期", "花呗临时额度下个月必须还清吗", "还花呗从哪扣款", "为什么借呗不能借钱"], "pos_scores": [1.0, 0.95], "neg_scores": [0.196, 0.137, 0.154, 0.195], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款手机号更换方法", "pos": ["我的花呗,我想还钱,但是预留手机号换了,怎么办", "我的花呗还款手机号不用怎么改"], "neg": ["借呗活动网址", "使用花呗和来分期可以增加芝麻分", "如何减低借呗额度", "蚂蚁借呗首月免息"], "pos_scores": [1.0, 0.95], "neg_scores": [0.147, 0.064, 0.088, 0.058], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗放款时间", "pos": ["申请借呗后多久放款", "借呗借后多久到账"], "neg": ["借呗冻结存在那些风险", "为什么这么久还没帮我提", "要用多少信用额度才能开通花呗", "我***月花呗还***,这部分钱怎么算的", "我不用花呗行关掉", "花呗怎么买qb"], "pos_scores": [1.0, 0.95], "neg_scores": [0.067, 0.097, 0.111, 0.151, 0.073, 0.082], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗登录淘宝方法", "pos": ["花呗怎么登录淘宝", "淘宝的蚂蚁花呗"], "neg": ["我用着来分期 可以用蚂蚁借呗吗", "借呗还款日在", "商铺可以用花呗付,但我拍下后花呗不能用", "为什么我已经完成大学生认证但花呗分三期还是要手续费"], "pos_scores": [1.0, 0.95], "neg_scores": [0.175, 0.069, 0.121, 0.066], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗身份信息完善问题", "pos": ["借呗完善不了信息", "蚂蚁借呗不能完善身份信息"], "neg": ["怎样能够恢复蚂蚁借呗的功能", "借呗入口不显示额度", "使用购物津贴,还能用花呗吗", "花呗支付退款如何处理"], "pos_scores": [1.0, 0.95], "neg_scores": [0.187, 0.117, 0.069, 0.059], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗购买手机物流查询", "pos": ["花呗上买的手机去哪里查有没有发货", "花呗上买的手机怎么查快递"], "neg": ["什么是花呗冻结额度", "用了购物津贴蚂蚁花呗用不了", "我的花呗总了消费了多少"], "pos_scores": [1.0, 0.95], "neg_scores": [0.12, 0.067, 0.137], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗退款后仍需还款原因", "pos": ["为什么退款回花呗了,还显示要我还款", "花呗已经提示退款成功,可为什么还是显示帐单"], "neg": ["花呗晚两天还款,会对信用有影响吗", "花呗上个月可以交电费,这个月为什么不行", "蚂蚁借呗最多一次还多少", "借呗最多可分多少期"], "pos_scores": [1.0, 0.95], "neg_scores": [0.065, 0.083, 0.14, 0.12], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款日期查询", "pos": ["花呗使用什么时候还款", "我要是今天用花呗啥时候还款"], "neg": ["蚂蚁借呗还款提醒设置", "不想用花呗付款,又怎么换付款方式", "我能帮助朋友还花呗吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.059, 0.099, 0.198], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗无法借款原因", "pos": ["现在不能借呗了吗", "借呗下不了单"], "neg": ["花呗的余付款怎么能还上", "信用额度怎样才能蚂蚁借呗的条件", "蚂蚁借呗入口怎么找不到", "花呗可以在京东上购物吗", "我是商家,如何关闭客户用花呗付款", "怎样主动申请蚂蚁借呗提额", "怎么挂失花呗账号"], "pos_scores": [1.0, 0.95], "neg_scores": [0.151, 0.121, 0.087, 0.198, 0.167, 0.062, 0.129], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗和花呗区别", "pos": ["蚂蚁借呗,和花呗不样吗", "借呗花呗有冲突吗"], "neg": ["蚂蚁借呗可用花呗还吗", "我已经还完花呗的款了,为什么还显示未还款", "花呗你查不到次笔交易", "花呗被冻结咋办", "花呗也还,为什么额度不更新", "今天是借呗还款日,可以超过中午***点"], "pos_scores": [1.0, 0.95], "neg_scores": [0.061, 0.05, 0.109, 0.146, 0.18, 0.073], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗收款码设置方法", "pos": ["怎么样能花呗收款", "怎么设置花呗收钱码"], "neg": ["花呗还款日是***号,***号当天还是逾期吗", "借呗分期还款账单", "取消蚂蚁借呗提现", "给我恢复借呗", "借呗跟花呗的联系", "花呗还款为什么额外扣了手续费", "有没有直接转账给花呗的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.076, 0.077, 0.186, 0.136, 0.157, 0.079, 0.159], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款困难解决方法", "pos": ["花呗还不了整么办", "花呗还不到这么多怎么办"], "neg": ["更改电话号码之后,借呗显示没有额度,是不是关闭了", "不是可以花呗分期吗", "我的花呗为什么不能用了 什么原因"], "pos_scores": [1.0, 0.95], "neg_scores": [0.097, 0.085, 0.155], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗有额度但无法使用", "pos": ["我什么我不能用花呗", "为什么我的花呗不能用,显示我的花呗还有***可用余额"], "neg": ["我想看看我***月的借呗还款记录", "借呗活动页面", "我的花呗之前退回了***"], "pos_scores": [1.0, 0.95], "neg_scores": [0.068, 0.154, 0.104], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗还款记录查询", "pos": ["如果查询借呗还款", "看蚂蚁借呗还款记录"], "neg": ["借呗没有临时额度吗", "花呗临时额度查得到吗", "蚂蚁会员和蚂蚁借呗的区别", "我花呗可以先提前还一部分吗", "花呗还款扣了银行卡的钱但显示没有这笔账单", "退款成功怎么花呗没入帐"], "pos_scores": [1.0, 0.95], "neg_scores": [0.083, 0.115, 0.17, 0.095, 0.167, 0.055], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "重新开通花呗方法", "pos": ["之前把花呗关闭了 现在怎么开通", "不小心关闭花呗啦,怎么开通"], "neg": ["花呗额度***千元为什么不能用分期付", "花呗额度分期是要商品的全额吗", "我说怎么删除花呗消费记录", "去饭店吃饭可以用花呗那吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.078, 0.091, 0.087, 0.156], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗人脸识别失败解决方法", "pos": ["借呗无法人脸识别", "借呗密码输了,人脸识别没用怎么办"], "neg": ["蚂蚁花呗不到下个月为什么不能提前还款", "借呗关了还能开么", "借呗还款日期不是固定每月四号吗", "借呗的钱没有到银行卡怎么办"], "pos_scores": [1.0, 0.95], "neg_scores": [0.14, 0.129, 0.109, 0.061], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "借呗分期期限选项", "pos": ["借呗!可以分***个月还吗", "蚂蚁借呗不是可以分***期么"], "neg": ["分期商品占用花呗额度吗", "借呗还款日利息", "我的花呗这个月的还款退货退款了,为什么还显示让我还款", "花呗已还款,银行卡里的钱都扣了,为什么还显示没还", "花呗里面怎么一次还请", "借呗停用了,什么时候能恢复", "能帮我关闭花呗吗", "花呗被关闭了怎么开启"], "pos_scores": [1.0, 0.95], "neg_scores": [0.071, 0.166, 0.185, 0.186, 0.057, 0.062, 0.191, 0.116], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗登录问题解决方法", "pos": ["花呗显示可登陆支付宝账号使用花呗,为什么我登了还是用不了", "登陆支付宝账号都对,可是花呗登不上"], "neg": ["不开通花呗看不到可用额度", "蚂蚁借呗还款时间没***个月了", "退款回复花呗额度怎么做", "解除花呗关联", "怎么主动申请蚂蚁借呗的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.197, 0.178, 0.165, 0.134, 0.072], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗开通方法", "pos": ["如何开通“花呗”服务", "花呗哪里开通"], "neg": ["花呗是哪个银行给的额度", "花呗授权什么用", "花呗分期以后 提前还款", "透支花呗额度", "本月还款借呗还能开通吗", "为何蚂蚁借呗还完后没有信用额度", "花呗最低还款的利息咋算"], "pos_scores": [1.0, 0.95], "neg_scores": [0.091, 0.064, 0.071, 0.194, 0.149, 0.05, 0.153], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗还款日修改方法", "pos": ["花呗还款日能不能改到***号", "花呗还款日期可以改到每月***号吗"], "neg": ["另一个账号用花呗,这个账号可以还吗", "我的花呗额度原来是***的", "还能重新开通花呗吗", "花呗还款用银行卡号支付密码在哪设置", "收钱码花呗怎么不能收大额的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.154, 0.092, 0.151, 0.051, 0.168], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗ETC缴费支持", "pos": ["花呗是否可以开道etc高速免费通道", "用花呗可以走高速etc缴费吗"], "neg": ["我刚用了借呗", "怎么样能提升蚂蚁借呗额度", "如何取消蚂蚁借呗这一项,不是关闭,是完全的取消"], "pos_scores": [1.0, 0.95], "neg_scores": [0.182, 0.1, 0.167], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "更换手机号后花呗无法使用", "pos": ["手机号码更换,无法使用花呗", "换了手机号花呗用不了"], "neg": ["换手机号码 对花呗有没有影响", "我关闭了花呗,想用另一个号开通怎么办", "花呗还最低还款额会影响信誉吗", "我要查询是哪次花呗产生的罚款", "我刚刚已经还了花呗了可是怎么还是显示没还,银行卡也扣钱了", "花呗已经还款了,怎么还有未出账", "我的花呗开通没有成功怎么办", "蚂蚁借呗之前能用为什么现在说没额度"], "pos_scores": [1.0, 0.95], "neg_scores": [0.123, 0.146, 0.142, 0.162, 0.123, 0.082, 0.109, 0.185], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗临时额度获取方法", "pos": ["花呗的临时额度我怎么没有", "临时花呗没有"], "neg": ["为什么不能用花呗支付hello单车押金", "花呗逾期用不了 多久才能在用", "当初用蚂蚁花呗付的款,现在花呗已经还完了,要退货,钱退在哪里", "怎么查花呗临时额度有效期", "借呗还款跟借款记录怎么删掉", "蚂蚁借呗最快到账", "借呗逾期了自动扣款", "花呗可以生活交费吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.079, 0.142, 0.177, 0.081, 0.121, 0.134, 0.191, 0.165], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗购买机票支持情况", "pos": ["我上个月用蚂蚁花呗买了机票", "花呗支持买机票吗"], "neg": ["借呗钱以还清?额度什么时候恢复", "花呗冻结了,怎样才可以开", "花呗号码和支付宝号码不一致怎么可以统一", "花呗支付了怎么没到账", "怎么取消消费优先选花呗", "什么时候才能再用花呗", "借呗额度自行消失怎么回事"], "pos_scores": [1.0, 0.95], "neg_scores": [0.19, 0.17, 0.115, 0.087, 0.075, 0.094, 0.083], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗最低还款影响信用吗", "pos": ["最低还花呗会影响信用度吗", "花呗最低还款后会不会影响信誉"], "neg": ["花呗付款成功却显示没有付款", "蚂蚁借呗没有去还款", "我得借呗怎么给停了", "借呗到还款日为什么不自动还款"], "pos_scores": [1.0, 0.95], "neg_scores": [0.074, 0.183, 0.089, 0.139], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "花呗收款码办理方法", "pos": ["怎么样来办理花呗收款码", "怎么蚂蚁花呗收款"], "neg": ["蚂蚁借呗可以按天借款吗", "我想退款说什么双***红包如果被使用就扣花呗什么意思", "可以花呗付款 怎么显示不支持花呗付款", "花呗能支付宝购物吗"], "pos_scores": [1.0, 0.95], "neg_scores": [0.142, 0.084, 0.087, 0.186], "prompt": "为此查询生成表示:", "type": "normal"}
|
||||
{"query": "商家花呗收款手续费", "pos": ["花呗 像商家收取手续费么", "我是商家,用花呗收款,需要收费吗"], "neg": ["借呗提前还款,本金的利息会相应减少吗", "主动申请蚂蚁借呗后又占无使用资格", "借呗不能一年还款么", "怎么花呗也扣了,银行卡里也出了的"], "pos_scores": [1.0, 0.95], "neg_scores": [0.054, 0.065, 0.103, 0.179], "prompt": "为此查询生成表示:", "type": "normal"}
5
data/datasets/examples/embedding_data.jsonl
Normal file
@ -0,0 +1,5 @@
{"query":"什么是深度学习?","pos":["深度学习是机器学习的一个子领域,使用多层神经网络模拟人脑处理信息的方式","深度学习(Deep Learning)是人工智能和机器学习的重要分支,通过构建深层神经网络来学习数据表示"],"neg":["机器学习包含多种算法,如决策树、支持向量机、神经网络等","人工智能是计算机科学的一个分支,目标是创建能够执行智能任务的系统","数据挖掘是从大型数据集中提取模式和知识的过程"],"pos_scores":[1.0,0.95],"neg_scores":[0.6,0.4,0.3],"prompt":"为此查询生成表示:","type":"normal"}
{"query":"如何制作红烧肉?","pos":["红烧肉制作步骤:选五花肉切块,焯水去腥,热锅炒糖色,下肉翻炒,加生抽老抽料酒,小火炖煮30-40分钟至软烂","红烧肉是经典上海菜,用五花肉、冰糖、生抽、老抽、料酒等,关键在于炒糖色和火候控制"],"neg":["红烧肉是中国传统菜肴,口感香甜软糯,深受大众喜爱","五花肉含有丰富的蛋白质和脂肪,营养价值较高","上海菜以红烧见长,口味偏甜,适合南方人饮食习惯"],"pos_scores":[1.0,0.9],"neg_scores":[0.5,0.3,0.2],"prompt":"为此查询生成表示:","type":"recipe"}
{"query":"Python编程入门应该学习哪些内容?","pos":["Python编程入门需要掌握:基本语法、数据类型(字符串、列表、字典)、控制结构(if/for/while)、函数定义、面向对象编程基础","Python新手学习路线:先学变量和数据类型,然后是条件语句和循环,接着学习函数和模块,最后是类和对象"],"neg":["Python是一种简单易学的编程语言,语法清晰,适合初学者","编程思维比语法更重要,需要培养逻辑思维能力","计算机科学包含算法、数据结构、操作系统等多个分支"],"pos_scores":[1.0,0.95],"neg_scores":[0.4,0.3,0.2],"prompt":"为此查询生成表示:","type":"programming"}
{"query":"北京有哪些必去的旅游景点?","pos":["北京必去景点推荐:故宫(紫禁城)、天安门广场、万里长城、颐和园、天坛、圆明园、北海公园、什刹海等","故宫是明清两代皇宫,占地72万平方米,是世界最大的古代宫殿建筑群,必须提前预约参观"],"neg":["北京是中华人民共和国首都,有3000多年建城史","中国拥有众多世界文化遗产,旅游资源丰富","春秋季节是北京旅游的最佳时间,气候宜人"],"pos_scores":[1.0,0.9],"neg_scores":[0.3,0.25,0.4],"prompt":"为此查询生成表示:","type":"travel"}
{"query":"机器学习和深度学习有什么区别?","pos":["机器学习是更广义的概念,深度学习是机器学习的一个子集,主要区别在于深度学习使用深层神经网络进行特征自动提取","深度学习使用多层神经网络自动学习特征,而传统机器学习通常需要人工设计特征,深度学习在图像和语音识别方面表现更优"],"neg":["机器学习算法包括监督学习、无监督学习和强化学习三大类","人工智能技术发展迅速,在各行各业都有广泛应用","神经网络是深度学习的基础,模拟人脑神经元连接"],"pos_scores":[1.0,0.95],"neg_scores":[0.4,0.3,0.5],"prompt":"为此查询生成表示:","type":"tech_comparison"}
|
||||
580
data/datasets/examples/format_usage_examples.py
Normal file
@ -0,0 +1,580 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive usage examples for the refactored BGE data pipeline
|
||||
Demonstrates all four supported formats: triplets, pairs, candidates, and legacy nested
|
||||
|
||||
Data Formats Supported:
|
||||
1. Triplets: Query-triplets format (query, pos, neg) for BGE-M3 embedding training
|
||||
2. Pairs: Flat pairs format (query, passage, label) for BGE-reranker training
|
||||
3. Candidates: Unified format (query, candidates) with threshold-based label assignment
|
||||
4. Legacy Nested: Legacy nested format with automatic conversion to flat pairs
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import tempfile
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from data.dataset import BGEDataset, BGEM3Dataset, BGERerankerDataset
|
||||
from data.optimization import optimize_data
|
||||
|
||||
|
||||
def create_sample_data():
|
||||
"""Create sample data files for all four formats"""
|
||||
|
||||
# Create temporary directory
|
||||
temp_dir = tempfile.mkdtemp(prefix="bge_examples_")
|
||||
print(f"📁 Creating sample data in: {temp_dir}")
|
||||
|
||||
# 1. Triplets format (for BGE-M3 embedding training)
|
||||
triplets_file = os.path.join(temp_dir, "triplets_data.jsonl")
|
||||
triplets_data = [
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"pos": [
|
||||
"Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
|
||||
"ML algorithms use statistical techniques to learn patterns from data and make predictions."
|
||||
],
|
||||
"neg": [
|
||||
"Weather forecasting uses meteorological data to predict atmospheric conditions.",
|
||||
"Cooking recipes require specific ingredients and step-by-step instructions."
|
||||
],
|
||||
"pos_scores": [0.95, 0.88],
|
||||
"neg_scores": [0.15, 0.08],
|
||||
"prompt": "为此查询生成表示:",
|
||||
"type": "definition"
|
||||
},
|
||||
{
|
||||
"query": "How does deep learning work?",
|
||||
"pos": [
|
||||
"Deep learning uses neural networks with multiple layers to learn complex patterns.",
|
||||
"Deep neural networks automatically extract features from raw data through backpropagation."
|
||||
],
|
||||
"neg": [
|
||||
"Baking bread requires yeast, flour, and proper kneading techniques.",
|
||||
"Solar panels convert sunlight into electrical energy through photovoltaic cells."
|
||||
],
|
||||
"pos_scores": [0.92, 0.89],
|
||||
"neg_scores": [0.12, 0.18]
|
||||
}
|
||||
]
|
||||
|
||||
with open(triplets_file, 'w', encoding='utf-8') as f:
|
||||
for item in triplets_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
# 2. Pairs format (for BGE-reranker training)
|
||||
pairs_file = os.path.join(temp_dir, "pairs_data.jsonl")
|
||||
pairs_data = [
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"passage": "Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
|
||||
"label": 1,
|
||||
"score": 0.95,
|
||||
"qid": "q1",
|
||||
"pid": "p1"
|
||||
},
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"passage": "Weather forecasting uses meteorological data to predict atmospheric conditions.",
|
||||
"label": 0,
|
||||
"score": 0.15,
|
||||
"qid": "q1",
|
||||
"pid": "p2"
|
||||
},
|
||||
{
|
||||
"query": "How does deep learning work?",
|
||||
"passage": "Deep learning uses neural networks with multiple layers to learn complex patterns.",
|
||||
"label": 1,
|
||||
"score": 0.92,
|
||||
"qid": "q2",
|
||||
"pid": "p3"
|
||||
},
|
||||
{
|
||||
"query": "How does deep learning work?",
|
||||
"passage": "Baking bread requires yeast, flour, and proper kneading techniques.",
|
||||
"label": 0,
|
||||
"score": 0.12,
|
||||
"qid": "q2",
|
||||
"pid": "p4"
|
||||
}
|
||||
]
|
||||
|
||||
with open(pairs_file, 'w', encoding='utf-8') as f:
|
||||
for item in pairs_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
# 3. Candidates format (unified format with threshold-based labels)
|
||||
candidates_file = os.path.join(temp_dir, "candidates_data.jsonl")
|
||||
candidates_data = [
|
||||
{
|
||||
"query": "What is artificial intelligence?",
|
||||
"qid": "q3",
|
||||
"candidates": [
|
||||
{
|
||||
"text": "Artificial intelligence is the simulation of human intelligence in machines.",
|
||||
"score": 0.94,
|
||||
"pid": "c1"
|
||||
},
|
||||
{
|
||||
"text": "AI systems can perform tasks that typically require human intelligence.",
|
||||
"score": 0.87,
|
||||
"pid": "c2"
|
||||
},
|
||||
{
|
||||
"text": "Mountain climbing requires proper equipment and training.",
|
||||
"score": 0.23,
|
||||
"pid": "c3"
|
||||
},
|
||||
{
|
||||
"text": "Chess is a strategic board game played between two players.",
|
||||
"score": 0.31,
|
||||
"pid": "c4"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Explain neural networks",
|
||||
"qid": "q4",
|
||||
"candidates": [
|
||||
{
|
||||
"text": "Neural networks are computing systems inspired by biological neural networks.",
|
||||
"score": 0.91,
|
||||
"pid": "c5"
|
||||
},
|
||||
{
|
||||
"text": "They consist of interconnected nodes that process information in layers.",
|
||||
"score": 0.85,
|
||||
"pid": "c6"
|
||||
},
|
||||
{
|
||||
"text": "Gardening involves planting seeds and watering plants regularly.",
|
||||
"score": 0.19,
|
||||
"pid": "c7"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
with open(candidates_file, 'w', encoding='utf-8') as f:
|
||||
for item in candidates_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
# 4. Legacy nested format (for backward compatibility)
|
||||
legacy_file = os.path.join(temp_dir, "legacy_nested_data.jsonl")
|
||||
legacy_data = [
|
||||
{
|
||||
"query": "What is natural language processing?",
|
||||
"pos": [
|
||||
"Natural language processing is a field of AI that helps computers understand human language.",
|
||||
"NLP combines computational linguistics with machine learning to process text and speech."
|
||||
],
|
||||
"neg": [
|
||||
"Automobile manufacturing involves assembly lines and quality control processes.",
|
||||
"Photography captures light through camera lenses to create images."
|
||||
],
|
||||
"pos_scores": [0.93, 0.86],
|
||||
"neg_scores": [0.14, 0.22]
|
||||
},
|
||||
{
|
||||
"query": "How do transformers work in NLP?",
|
||||
"pos": [
|
||||
"Transformers use self-attention mechanisms to process sequential data in parallel.",
|
||||
"The transformer architecture revolutionized NLP with attention-based models like BERT and GPT."
|
||||
],
|
||||
"neg": [
|
||||
"Electric vehicles use batteries to power electric motors for transportation.",
|
||||
"Coffee brewing requires proper water temperature and grinding consistency."
|
||||
],
|
||||
"pos_scores": [0.89, 0.91],
|
||||
"neg_scores": [0.16, 0.13]
|
||||
}
|
||||
]
|
||||
|
||||
with open(legacy_file, 'w', encoding='utf-8') as f:
|
||||
for item in legacy_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
return {
|
||||
'triplets': triplets_file,
|
||||
'pairs': pairs_file,
|
||||
'candidates': candidates_file,
|
||||
'legacy_nested': legacy_file,
|
||||
'temp_dir': temp_dir
|
||||
}
|
||||
|
||||
|
||||
def demonstrate_fine_tuning_loader():
|
||||
"""Demonstrate fine-tuning data loader with all formats"""
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("🎯 FINE-TUNING DATA LOADER EXAMPLES")
|
||||
print("="*80)
|
||||
|
||||
# Create sample data
|
||||
data_files = create_sample_data()
|
||||
|
||||
# Initialize tokenizer (use a simple one for demo)
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# 1. Triplets format with BGE-M3 Dataset
|
||||
print("\n📊 1. TRIPLETS FORMAT (BGE-M3 Embedding Training)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
triplets_dataset = BGEM3Dataset(
|
||||
data_path=data_files['triplets'],
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=64,
|
||||
passage_max_length=256,
|
||||
is_train=True
|
||||
)
|
||||
|
||||
print(f"✅ Loaded {len(triplets_dataset)} triplet examples")
|
||||
print(f" Detected format: {triplets_dataset.detected_format}")
|
||||
|
||||
# Get a sample
|
||||
sample = triplets_dataset[0]
|
||||
print(f" Sample keys: {list(sample.keys())}")
|
||||
print(f" Query shape: {sample['query_input_ids'].shape}")
|
||||
print(f" Passages shape: {sample['passage_input_ids'].shape}")
|
||||
print(f" Labels: {sample['labels']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
|
||||
# 2. Pairs format with BGE-Reranker Dataset
|
||||
print("\n📊 2. PAIRS FORMAT (BGE-Reranker Training)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
pairs_dataset = BGERerankerDataset(
|
||||
data_path=data_files['pairs'],
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=64,
|
||||
passage_max_length=256,
|
||||
is_train=True,
|
||||
legacy_support=False
|
||||
)
|
||||
|
||||
print(f"✅ Loaded {len(pairs_dataset)} pair examples")
|
||||
print(f" Detected format: {pairs_dataset.detected_format}")
|
||||
|
||||
# Get a sample
|
||||
sample = pairs_dataset[0]
|
||||
print(f" Sample keys: {list(sample.keys())}")
|
||||
print(f" Input shape: {sample['input_ids'].shape}")
|
||||
print(f" Label: {sample['labels']}")
|
||||
print(f" Query: {sample['query_text'][:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
|
||||
# 3. Candidates format with automatic conversion
|
||||
print("\n📊 3. CANDIDATES FORMAT (Threshold-based Conversion)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
candidates_dataset = BGEDataset(
|
||||
data_path=data_files['candidates'],
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=64,
|
||||
passage_max_length=256,
|
||||
is_train=True,
|
||||
format_type="candidates",
|
||||
candidates_threshold=0.5 # Scores >= 0.5 become label=1
|
||||
)
|
||||
|
||||
print(f"✅ Loaded {len(candidates_dataset)} exploded examples")
|
||||
print(f" Detected format: {candidates_dataset.detected_format}")
|
||||
print(f" Note: Each query with multiple candidates exploded into separate pairs")
|
||||
|
||||
# Get a sample
|
||||
sample = candidates_dataset[0]
|
||||
print(f" Sample keys: {list(sample.keys())}")
|
||||
print(f" Format after processing: {sample['format']}")
|
||||
print(f" Source: {sample.get('source', 'N/A')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
|
||||
# 4. Legacy nested format with automatic conversion
|
||||
print("\n📊 4. LEGACY NESTED FORMAT (Automatic Conversion)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
legacy_dataset = BGERerankerDataset(
|
||||
data_path=data_files['legacy_nested'],
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=64,
|
||||
passage_max_length=256,
|
||||
is_train=True,
|
||||
legacy_support=True
|
||||
)
|
||||
|
||||
print(f"✅ Loaded {len(legacy_dataset)} converted examples")
|
||||
print(f" Detected format: {legacy_dataset.detected_format}")
|
||||
print(f" Note: Legacy nested format automatically converted to flat pairs")
|
||||
|
||||
# Get a sample
|
||||
sample = legacy_dataset[0]
|
||||
print(f" Sample keys: {list(sample.keys())}")
|
||||
print(f" Format after processing: {sample['format']}")
|
||||
print(f" Source: {sample.get('source', 'N/A')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
|
||||
return data_files
|
||||
|
||||
|
||||
def demonstrate_optimization_pipeline(data_files):
|
||||
"""Demonstrate optimization pipeline with all formats"""
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("⚡ OPTIMIZATION PIPELINE EXAMPLES")
|
||||
print("="*80)
|
||||
|
||||
cache_dir = os.path.join(data_files['temp_dir'], 'cache')
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# 1. BGE-M3 optimization with triplets format
|
||||
print("\n🚀 1. BGE-M3 OPTIMIZATION (Triplets Format)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
cached_files = optimize_data(
|
||||
model_type='bge-m3',
|
||||
input_files=[data_files['triplets']],
|
||||
output_dir=cache_dir,
|
||||
tokenizer_path='bert-base-uncased',
|
||||
force=True,
|
||||
hard_negative_ratio=0.7,
|
||||
max_triplets_per_query=8,
|
||||
difficulty_threshold=0.1
|
||||
)
|
||||
|
||||
print(f"✅ M3 optimization complete: {len(cached_files)} files")
|
||||
for file in cached_files:
|
||||
print(f" 📦 {file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ M3 optimization error: {e}")
|
||||
|
||||
# 2. BGE-Reranker optimization with pairs format
|
||||
print("\n🚀 2. BGE-RERANKER OPTIMIZATION (Pairs Format)")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
cached_files = optimize_data(
|
||||
model_type='bge-reranker',
|
||||
input_files=[data_files['pairs']],
|
||||
output_dir=cache_dir,
|
||||
tokenizer_path='bert-base-uncased',
|
||||
force=True,
|
||||
output_format='flat'
|
||||
)
|
||||
|
||||
print(f"✅ Reranker optimization complete: {len(cached_files)} files")
|
||||
for file in cached_files:
|
||||
print(f" 📦 {file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Reranker optimization error: {e}")
|
||||
|
||||
# 3. Candidates format optimization (converts to pairs for reranker)
|
||||
print("\n🚀 3. CANDIDATES FORMAT OPTIMIZATION")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
cached_files = optimize_data(
|
||||
model_type='bge-reranker',
|
||||
input_files=[data_files['candidates']],
|
||||
output_dir=cache_dir,
|
||||
tokenizer_path='bert-base-uncased',
|
||||
force=True,
|
||||
output_format='flat'
|
||||
)
|
||||
|
||||
print(f"✅ Candidates optimization complete: {len(cached_files)} files")
|
||||
print(" Note: Candidates exploded into pairs with threshold-based labels")
|
||||
for file in cached_files:
|
||||
print(f" 📦 {file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Candidates optimization error: {e}")
|
||||
|
||||
# 4. Candidates to M3 triplets conversion
|
||||
print("\n🚀 4. CANDIDATES TO M3 TRIPLETS CONVERSION")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
cached_files = optimize_data(
|
||||
model_type='bge-m3',
|
||||
input_files=[data_files['candidates']],
|
||||
output_dir=cache_dir,
|
||||
tokenizer_path='bert-base-uncased',
|
||||
force=True,
|
||||
hard_negative_ratio=0.8
|
||||
)
|
||||
|
||||
print(f"✅ Candidates-to-M3 optimization complete: {len(cached_files)} files")
|
||||
print(" Note: Candidates split by threshold into pos/neg for triplets")
|
||||
for file in cached_files:
|
||||
print(f" 📦 {file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Candidates-to-M3 optimization error: {e}")
|
||||
|
||||
# 5. Legacy nested format optimization
|
||||
print("\n🚀 5. LEGACY NESTED FORMAT OPTIMIZATION")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
cached_files = optimize_data(
|
||||
model_type='bge-reranker',
|
||||
input_files=[data_files['legacy_nested']],
|
||||
output_dir=cache_dir,
|
||||
tokenizer_path='bert-base-uncased',
|
||||
force=True,
|
||||
output_format='nested' # Keep nested format
|
||||
)
|
||||
|
||||
print(f"✅ Legacy optimization complete: {len(cached_files)} files")
|
||||
print(" Note: Legacy nested format preserved for compatibility")
|
||||
for file in cached_files:
|
||||
print(f" 📦 {file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Legacy optimization error: {e}")
|
||||
|
||||
|
||||
def show_data_format_schemas():
|
||||
"""Show clean data format schemas without inline comments"""
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("📋 DATA FORMAT SCHEMAS")
|
||||
print("="*80)
|
||||
|
||||
print("\n1. 🎯 TRIPLETS FORMAT (BGE-M3 Embedding Training)")
|
||||
print("-" * 60)
|
||||
triplets_schema = {
|
||||
"query": "What is machine learning?",
|
||||
"pos": [
|
||||
"Machine learning is a subset of artificial intelligence...",
|
||||
"ML algorithms learn patterns from data..."
|
||||
],
|
||||
"neg": [
|
||||
"Weather forecasting uses meteorological data...",
|
||||
"Cooking recipes require specific ingredients..."
|
||||
],
|
||||
"pos_scores": [0.95, 0.88],
|
||||
"neg_scores": [0.15, 0.08],
|
||||
"prompt": "为此查询生成表示:",
|
||||
"type": "definition"
|
||||
}
|
||||
print(json.dumps(triplets_schema, indent=2, ensure_ascii=False))
|
||||
|
||||
print("\n2. 🔍 PAIRS FORMAT (BGE-Reranker Training)")
|
||||
print("-" * 60)
|
||||
pairs_schema = {
|
||||
"query": "What is machine learning?",
|
||||
"passage": "Machine learning is a subset of artificial intelligence...",
|
||||
"label": 1,
|
||||
"score": 0.95,
|
||||
"qid": "q1",
|
||||
"pid": "p1"
|
||||
}
|
||||
print(json.dumps(pairs_schema, indent=2, ensure_ascii=False))
|
||||
|
||||
print("\n3. 🎲 CANDIDATES FORMAT (Threshold-based Conversion)")
|
||||
print("-" * 60)
|
||||
candidates_schema = {
|
||||
"query": "What is artificial intelligence?",
|
||||
"qid": "q1",
|
||||
"candidates": [
|
||||
{
|
||||
"text": "AI is the simulation of human intelligence in machines.",
|
||||
"score": 0.94,
|
||||
"pid": "c1"
|
||||
},
|
||||
{
|
||||
"text": "Mountain climbing requires proper equipment.",
|
||||
"score": 0.23,
|
||||
"pid": "c2"
|
||||
}
|
||||
]
|
||||
}
|
||||
print(json.dumps(candidates_schema, indent=2, ensure_ascii=False))
|
||||
print("📝 Note: Scores >= threshold become label=1, others become label=0")
|
||||
|
||||
print("\n4. 🔄 LEGACY NESTED FORMAT (Backward Compatibility)")
|
||||
print("-" * 60)
|
||||
legacy_schema = {
|
||||
"query": "What is natural language processing?",
|
||||
"pos": [
|
||||
"NLP is a field of AI that helps computers understand human language.",
|
||||
"NLP combines computational linguistics with machine learning..."
|
||||
],
|
||||
"neg": [
|
||||
"Automobile manufacturing involves assembly lines...",
|
||||
"Photography captures light through camera lenses..."
|
||||
],
|
||||
"pos_scores": [0.93, 0.86],
|
||||
"neg_scores": [0.14, 0.22]
|
||||
}
|
||||
print(json.dumps(legacy_schema, indent=2, ensure_ascii=False))
|
||||
print("📝 Note: Automatically converted to flat pairs for reranker training")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main demonstration function"""
|
||||
|
||||
print("🚀 BGE Data Pipeline Format Examples")
|
||||
print("="*80)
|
||||
print("Demonstrating all four supported formats:")
|
||||
print("1. Triplets (BGE-M3 embedding training)")
|
||||
print("2. Pairs (BGE-reranker training)")
|
||||
print("3. Candidates (unified format with threshold conversion)")
|
||||
print("4. Legacy nested (backward compatibility)")
|
||||
print("="*80)
|
||||
|
||||
try:
|
||||
# Show data schemas
|
||||
show_data_format_schemas()
|
||||
|
||||
# Demonstrate fine-tuning loader
|
||||
data_files = demonstrate_fine_tuning_loader()
|
||||
|
||||
# Demonstrate optimization pipeline
|
||||
demonstrate_optimization_pipeline(data_files)
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("🎉 ALL EXAMPLES COMPLETED SUCCESSFULLY!")
|
||||
print("="*80)
|
||||
|
||||
print("\n📋 SUMMARY:")
|
||||
print("✅ Format Detection: Automatic detection of all four formats")
|
||||
print("✅ Conversion Logic: Seamless conversion between formats")
|
||||
print("✅ Threshold Assignment: Candidates → pairs with configurable threshold")
|
||||
print("✅ Legacy Support: Automatic migration of nested → flat pairs")
|
||||
print("✅ Optimization: Pre-tokenization and caching for 10-15x speedup")
|
||||
print("✅ Backward Compatibility: Full support for existing datasets")
|
||||
|
||||
print(f"\n📁 Sample data and cache files created in: {data_files['temp_dir']}")
|
||||
print(" You can examine the generated files to see the format conversions")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error during demonstration: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
20
data/datasets/examples/reranker_data.jsonl
Normal file
@ -0,0 +1,20 @@
|
||||
{"query":"什么是深度学习?","passage":"深度学习是机器学习的一个子领域,使用多层神经网络模拟人脑处理信息的方式","label":1,"score":1.0,"qid":"Q001","pid":"P001"}
|
||||
{"query":"什么是深度学习?","passage":"深度学习(Deep Learning)是人工智能和机器学习的重要分支,通过构建深层神经网络来学习数据表示","label":1,"score":0.95,"qid":"Q001","pid":"P002"}
|
||||
{"query":"什么是深度学习?","passage":"机器学习包含多种算法,如决策树、支持向量机、神经网络等","label":0,"score":0.6,"qid":"Q001","pid":"N001"}
|
||||
{"query":"什么是深度学习?","passage":"人工智能是计算机科学的一个分支,目标是创建能够执行智能任务的系统","label":0,"score":0.4,"qid":"Q001","pid":"N002"}
|
||||
{"query":"什么是深度学习?","passage":"数据挖掘是从大型数据集中提取模式和知识的过程","label":0,"score":0.3,"qid":"Q001","pid":"N003"}
|
||||
{"query":"如何制作红烧肉?","passage":"红烧肉制作步骤:选五花肉切块,焯水去腥,热锅炒糖色,下肉翻炒,加生抽老抽料酒,小火炖煮30-40分钟至软烂","label":1,"score":1.0,"qid":"Q002","pid":"P003"}
|
||||
{"query":"如何制作红烧肉?","passage":"红烧肉是经典上海菜,用五花肉、冰糖、生抽、老抽、料酒等,关键在于炒糖色和火候控制","label":1,"score":0.9,"qid":"Q002","pid":"P004"}
|
||||
{"query":"如何制作红烧肉?","passage":"红烧肉是中国传统菜肴,口感香甜软糯,深受大众喜爱","label":0,"score":0.5,"qid":"Q002","pid":"N004"}
|
||||
{"query":"如何制作红烧肉?","passage":"五花肉含有丰富的蛋白质和脂肪,营养价值较高","label":0,"score":0.3,"qid":"Q002","pid":"N005"}
|
||||
{"query":"如何制作红烧肉?","passage":"上海菜以红烧见长,口味偏甜,适合南方人饮食习惯","label":0,"score":0.2,"qid":"Q002","pid":"N006"}
|
||||
{"query":"Python编程入门应该学习哪些内容?","passage":"Python编程入门需要掌握:基本语法、数据类型(字符串、列表、字典)、控制结构(if/for/while)、函数定义、面向对象编程基础","label":1,"score":1.0,"qid":"Q003","pid":"P005"}
|
||||
{"query":"Python编程入门应该学习哪些内容?","passage":"Python新手学习路线:先学变量和数据类型,然后是条件语句和循环,接着学习函数和模块,最后是类和对象","label":1,"score":0.95,"qid":"Q003","pid":"P006"}
|
||||
{"query":"Python编程入门应该学习哪些内容?","passage":"Python是一种简单易学的编程语言,语法清晰,适合初学者","label":0,"score":0.4,"qid":"Q003","pid":"N007"}
|
||||
{"query":"Python编程入门应该学习哪些内容?","passage":"编程思维比语法更重要,需要培养逻辑思维能力","label":0,"score":0.3,"qid":"Q003","pid":"N008"}
|
||||
{"query":"Python编程入门应该学习哪些内容?","passage":"计算机科学包含算法、数据结构、操作系统等多个分支","label":0,"score":0.2,"qid":"Q003","pid":"N009"}
|
||||
{"query":"北京有哪些必去的旅游景点?","passage":"北京必去景点推荐:故宫(紫禁城)、天安门广场、万里长城、颐和园、天坛、圆明园、北海公园、什刹海等","label":1,"score":1.0,"qid":"Q004","pid":"P007"}
|
||||
{"query":"北京有哪些必去的旅游景点?","passage":"故宫是明清两代皇宫,占地72万平方米,是世界最大的古代宫殿建筑群,必须提前预约参观","label":1,"score":0.9,"qid":"Q004","pid":"P008"}
|
||||
{"query":"北京有哪些必去的旅游景点?","passage":"北京是中华人民共和国首都,有3000多年建城史","label":0,"score":0.3,"qid":"Q004","pid":"N010"}
|
||||
{"query":"北京有哪些必去的旅游景点?","passage":"中国拥有众多世界文化遗产,旅游资源丰富","label":0,"score":0.25,"qid":"Q004","pid":"N011"}
|
||||
{"query":"北京有哪些必去的旅游景点?","passage":"春秋季节是北京旅游的最佳时间,气候宜人","label":0,"score":0.4,"qid":"Q004","pid":"N012"}
|
||||
2776
data/datasets/三国演义/m3_train.jsonl
Normal file
File diff suppressed because it is too large
2776
data/datasets/三国演义/m3_train_converted.jsonl
Normal file
File diff suppressed because it is too large
4310
data/datasets/三国演义/reranker_train.jsonl
Normal file
File diff suppressed because one or more lines are too long
534
data/hard_negative_mining.py
Normal file
@ -0,0 +1,534 @@
|
||||
"""
|
||||
ANCE-style hard negative mining for BGE models
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import queue
|
||||
from typing import List, Dict, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
import faiss
|
||||
|
||||
# Optional imports for text processing
|
||||
try:
|
||||
import jieba
|
||||
JIEBA_AVAILABLE = True
|
||||
except ImportError:
|
||||
JIEBA_AVAILABLE = False
|
||||
jieba = None
|
||||
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
BM25_AVAILABLE = True
|
||||
except ImportError:
|
||||
BM25_AVAILABLE = False
|
||||
BM25Okapi = None
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ANCEHardNegativeMiner:
|
||||
"""
|
||||
Approximate Nearest Neighbor Negative Contrastive Learning
|
||||
Dynamically mines hard negatives with the evolving model as training progresses
|
||||
|
||||
Notes:
|
||||
- All parameters (embedding_dim, num_hard_negatives, negative_range, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
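Illustrative usage (sketch only; the query/corpus embeddings are assumed to come from the BGE model being trained):
    miner = ANCEHardNegativeMiner(embedding_dim=768, num_hard_negatives=5, use_gpu=False)
    miner.start_async_updates()
    hard_negs = miner.mine_hard_negatives(
        queries=["what is machine learning"],
        query_embeddings=query_emb,        # np.ndarray, shape [1, 768]
        positive_passages=[["ML is a subfield of AI ..."]],
        corpus_embeddings=corpus_emb,      # np.ndarray, shape [N, 768]
        corpus_passages=corpus_texts,
    )
    miner.stop_async_updates()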
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 768,
|
||||
num_hard_negatives: int = 15,
|
||||
negative_range: Tuple[int, int] = (10, 100),
|
||||
margin: float = 0.1,
|
||||
index_refresh_interval: int = 1000,
|
||||
index_type: str = "IVF",
|
||||
nlist: int = 100,
|
||||
use_gpu: bool = True,
|
||||
max_index_size: int = 1000000
|
||||
):
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_hard_negatives = num_hard_negatives
|
||||
self.negative_range = negative_range
|
||||
self.margin = margin
|
||||
self.index_refresh_interval = index_refresh_interval
|
||||
self.use_gpu = use_gpu and torch.cuda.is_available()
|
||||
self.max_index_size = max_index_size
|
||||
|
||||
# Initialize FAISS index
|
||||
self.index = self._create_index(index_type, nlist)
|
||||
self.passage_map: Dict[int, str] = {} # Maps index ID to passage text
|
||||
self.query_map: Dict[str, List[str]] = defaultdict(list) # Maps query to positive passages
|
||||
|
||||
# For asynchronous index updates
|
||||
self.update_queue = queue.Queue()
|
||||
self.update_thread = None
|
||||
self.stop_update_thread = threading.Event()
|
||||
self.current_step = 0
|
||||
|
||||
# Statistics
|
||||
self.stats = {
|
||||
'total_queries': 0,
|
||||
'total_negatives_mined': 0,
|
||||
'index_updates': 0,
|
||||
'cache_hits': 0,
|
||||
'cache_misses': 0
|
||||
}
|
||||
|
||||
def _create_index(self, index_type: str, nlist: int) -> faiss.Index:
|
||||
"""Create FAISS index, warn if falling back to CPU."""
|
||||
if index_type == "IVF":
|
||||
quantizer = faiss.IndexFlatIP(self.embedding_dim)
|
||||
index = faiss.IndexIVFFlat(quantizer, self.embedding_dim, nlist, faiss.METRIC_INNER_PRODUCT)
|
||||
elif index_type == "HNSW":
|
||||
index = faiss.IndexHNSWFlat(self.embedding_dim, 32, faiss.METRIC_INNER_PRODUCT)
|
||||
else:
|
||||
index = faiss.IndexFlatIP(self.embedding_dim)
|
||||
if self.use_gpu:
|
||||
try:
|
||||
res = faiss.StandardGpuResources() # type: ignore[attr-defined]
|
||||
index = faiss.index_cpu_to_gpu(res, 0, index) # type: ignore[attr-defined]
|
||||
logger.info("Using GPU for FAISS index")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to use GPU for FAISS: {e}. Using CPU.")
|
||||
self.use_gpu = False
|
||||
else:
|
||||
logger.warning("FAISS is running on CPU. This may be slower than GPU.")
|
||||
return index
|
||||
|
||||
def _is_chinese(self, text: str) -> bool:
|
||||
"""Check if text contains Chinese characters"""
|
||||
for char in text:
|
||||
if '\u4e00' <= char <= '\u9fff':
|
||||
return True
|
||||
return False
|
||||
|
||||
def start_async_updates(self):
|
||||
"""Start background thread for asynchronous index updates"""
|
||||
if self.update_thread is None or not self.update_thread.is_alive():
|
||||
self.stop_update_thread.clear()
|
||||
self.update_thread = threading.Thread(target=self._update_worker)
|
||||
self.update_thread.daemon = True
|
||||
self.update_thread.start()
|
||||
logger.info("Started asynchronous index update thread")
|
||||
|
||||
def stop_async_updates(self):
|
||||
"""Stop background update thread and ensure clean shutdown."""
|
||||
if self.update_thread and self.update_thread.is_alive():
|
||||
self.stop_update_thread.set()
|
||||
self.update_thread.join(timeout=5)
|
||||
logger.info("Stopped asynchronous index update thread")
|
||||
|
||||
def _update_worker(self):
|
||||
"""Background worker for index updates"""
|
||||
while not self.stop_update_thread.is_set():
|
||||
try:
|
||||
# Get update task with timeout
|
||||
task = self.update_queue.get(timeout=1)
|
||||
|
||||
if task['type'] == 'add_embeddings':
|
||||
self._add_to_index(task['embeddings'], task['passages'])
|
||||
elif task['type'] == 'rebuild_index':
|
||||
self._rebuild_index(task['model'])
|
||||
|
||||
self.stats['index_updates'] += 1
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error in update worker: {e}")
|
||||
|
||||
def _add_to_index(self, embeddings: np.ndarray, passages: List[str]):
|
||||
"""Add embeddings to index"""
|
||||
try:
|
||||
# Normalize embeddings for inner product search
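# With unit-norm vectors, inner-product search is equivalent to cosine similarity.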
|
||||
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
|
||||
# Add to index
|
||||
start_idx = len(self.passage_map)
|
||||
|
||||
# Debug log for embeddings
|
||||
logger.debug(f"embeddings shape: {getattr(embeddings, 'shape', None)}, type: {type(embeddings)}")
|
||||
|
||||
if hasattr(self.index, 'is_trained') and not self.index.is_trained:
|
||||
if embeddings is not None and len(embeddings) > 0:
|
||||
if hasattr(self.index, 'nlist') and len(embeddings) >= getattr(self.index, 'nlist', 0):
|
||||
# Safe to access self.index.nlist and call train
|
||||
logger.info("Training FAISS index...")
|
||||
self.index.train(embeddings) # type: ignore[call-arg]
|
||||
elif hasattr(self.index, 'nlist'):
|
||||
logger.warning(f"Not enough data to train index. Need {getattr(self.index, 'nlist', 0)}, got {len(embeddings)}")
|
||||
return
|
||||
else:
|
||||
# For index types without nlist/train, skip training
|
||||
pass
|
||||
else:
|
||||
logger.warning("Embeddings is None or empty, skipping training and addition to index.")
|
||||
return
|
||||
|
||||
if embeddings is not None and len(embeddings) > 0:
|
||||
self.index.add(embeddings) # type: ignore[call-arg]
|
||||
else:
|
||||
logger.warning("Embeddings is None or empty, skipping addition to index.")
|
||||
|
||||
# Update passage map
|
||||
for i, passage in enumerate(passages):
|
||||
self.passage_map[start_idx + i] = passage
|
||||
|
||||
logger.info(f"Added {len(passages)} passages to index. Total: {self.index.ntotal}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding embeddings to index: {e}")
|
||||
|
||||
def _rebuild_index(self, model):
|
||||
"""Rebuild entire index with current model"""
|
||||
try:
|
||||
logger.info("Rebuilding entire index...")
|
||||
|
||||
# Clear current index
|
||||
self.index.reset()
|
||||
old_passage_map = self.passage_map.copy()
|
||||
self.passage_map.clear()
|
||||
|
||||
# Re-encode all passages
|
||||
all_passages = list(set(sum(self.query_map.values(), [])))
|
||||
if not all_passages:
|
||||
logger.warning("No passages to rebuild index with")
|
||||
return
|
||||
|
||||
embeddings = []
|
||||
batch_size = 32
|
||||
|
||||
for i in tqdm(range(0, len(all_passages), batch_size), desc="Encoding passages"):
|
||||
batch = all_passages[i:i + batch_size]
|
||||
try:
|
||||
with torch.no_grad():
|
||||
batch_embeddings = model.encode(batch, convert_to_numpy=True, batch_size=len(batch))
|
||||
embeddings.append(batch_embeddings)
|
||||
except Exception as e:
|
||||
logger.error(f"Error encoding batch {i // batch_size}: {e}")
|
||||
continue
|
||||
|
||||
if embeddings:
|
||||
embeddings = np.vstack(embeddings)
|
||||
self._add_to_index(embeddings, all_passages)
|
||||
logger.info(f"Rebuilt index with {len(all_passages)} passages")
|
||||
else:
|
||||
# Restore old passage map if rebuild failed
|
||||
self.passage_map = old_passage_map
|
||||
logger.error("Failed to rebuild index - no embeddings generated")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error rebuilding index: {e}")
|
||||
|
||||
def mine_hard_negatives(
|
||||
self,
|
||||
queries: List[str],
|
||||
query_embeddings: np.ndarray,
|
||||
positive_passages: List[List[str]],
|
||||
corpus_embeddings: Optional[np.ndarray] = None,
|
||||
corpus_passages: Optional[List[str]] = None
|
||||
) -> List[List[str]]:
|
||||
"""
|
||||
Mine hard negatives for queries using ANCE approach
|
||||
|
||||
Args:
|
||||
queries: List of query texts
|
||||
query_embeddings: Query embeddings [num_queries, embedding_dim]
|
||||
positive_passages: List of positive passages for each query
|
||||
corpus_embeddings: Optional corpus embeddings to add to index
|
||||
corpus_passages: Optional corpus passages
|
||||
|
||||
Returns:
|
||||
List of hard negative passages for each query
|
||||
"""
|
||||
# Add corpus to index if provided
|
||||
if corpus_embeddings is not None and corpus_passages is not None:
|
||||
self.update_queue.put({
|
||||
'type': 'add_embeddings',
|
||||
'embeddings': corpus_embeddings,
|
||||
'passages': corpus_passages
|
||||
})
|
||||
|
||||
# Update query map
|
||||
for query, positives in zip(queries, positive_passages):
|
||||
self.query_map[query].extend(positives)
|
||||
|
||||
# Normalize query embeddings
|
||||
query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
|
||||
|
||||
# Search for hard negatives
|
||||
hard_negatives = []
|
||||
k = min(self.negative_range[1], self.index.ntotal)
|
||||
|
||||
if k == 0:
|
||||
logger.warning("Index is empty, returning empty hard negatives")
|
||||
return [[] for _ in queries]
|
||||
|
||||
similarities, indices = self.index.search(query_embeddings, k) # type: ignore[call-arg]
|
||||
|
||||
for i, (query, positives) in enumerate(zip(queries, positive_passages)):
|
||||
query_hard_negs = []
|
||||
candidates = []
|
||||
|
||||
# Collect candidates within negative range
|
||||
for j in range(self.negative_range[0], min(k, self.negative_range[1])):
|
||||
if indices[i, j] >= 0: # Valid index
|
||||
passage_idx = indices[i, j]
|
||||
if passage_idx in self.passage_map:
|
||||
passage = self.passage_map[passage_idx]
|
||||
score = similarities[i, j]
|
||||
|
||||
# Skip if it's a positive passage
|
||||
if passage not in positives:
|
||||
candidates.append((passage, score))
|
||||
|
||||
# Select hard negatives with margin
|
||||
if candidates:
|
||||
# Sort by score (descending)
|
||||
candidates.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Select with margin constraint
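# Greedily keep a candidate only if its score is at least `margin` below the previously kept
# one, so the selected negatives cover a spread of difficulty instead of near-duplicates.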
|
||||
selected = []
|
||||
last_score = float('inf')
|
||||
for passage, score in candidates:
|
||||
if last_score - score >= self.margin or len(selected) == 0:
|
||||
selected.append(passage)
|
||||
last_score = score
|
||||
if len(selected) >= self.num_hard_negatives:
|
||||
break
|
||||
|
||||
query_hard_negs = selected
|
||||
|
||||
# Pad with random negatives if needed
|
||||
if len(query_hard_negs) < self.num_hard_negatives:
|
||||
available = [p for p in self.passage_map.values()
|
||||
if p not in positives and p not in query_hard_negs]
|
||||
if available:
|
||||
additional_count = min(len(available), self.num_hard_negatives - len(query_hard_negs))
|
||||
additional = np.random.choice(available, size=additional_count, replace=False).tolist()
|
||||
query_hard_negs.extend(additional)
|
||||
|
||||
hard_negatives.append(query_hard_negs)
|
||||
|
||||
self.stats['total_queries'] += len(queries)
|
||||
self.stats['total_negatives_mined'] += sum(len(negs) for negs in hard_negatives)
|
||||
|
||||
return hard_negatives
|
||||
|
||||
def update_index_if_needed(self, step: int, model=None):
|
||||
"""Check if index needs update and trigger if necessary"""
|
||||
self.current_step = step
|
||||
|
||||
if step > 0 and step % self.index_refresh_interval == 0 and model is not None:
|
||||
logger.info(f"Triggering index rebuild at step {step}")
|
||||
self.update_queue.put({
|
||||
'type': 'rebuild_index',
|
||||
'model': model
|
||||
})
|
||||
|
||||
def save_mined_data(self, output_path: str, original_data_path: str):
|
||||
"""Save data with mined hard negatives"""
|
||||
logger.info(f"Saving mined data to {output_path}")
|
||||
|
||||
# Load original data
|
||||
original_data = []
|
||||
with open(original_data_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
original_data.append(json.loads(line))
|
||||
|
||||
# Simplified implementation: the original items are written back unchanged; in practice, batch-mine hard negatives and merge them into each item's 'neg' list before saving
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for item in tqdm(original_data, desc="Saving mined data"):
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
logger.info(f"Saved {len(original_data)} examples")
|
||||
|
||||
def get_statistics(self) -> Dict[str, int]:
|
||||
"""Get mining statistics"""
|
||||
self.stats['index_size'] = self.index.ntotal
|
||||
self.stats['unique_passages'] = len(self.passage_map)
|
||||
return self.stats.copy()
|
||||
|
||||
|
||||
class HardNegativeDataAugmenter:
|
||||
"""
|
||||
Augment training data with various hard negative strategies
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
self.strategies = config.get('strategies', ['ance', 'bm25', 'semantic'])
|
||||
self.bm25_index = None
|
||||
self.corpus_passages = None
|
||||
|
||||
def build_bm25_index(self, corpus: List[str]):
|
||||
"""Build BM25 index for hard negative mining"""
|
||||
try:
|
||||
logger.info("Building BM25 index...")
|
||||
self.corpus_passages = corpus
|
||||
|
||||
# Tokenize corpus
|
||||
tokenized_corpus = []
|
||||
for doc in tqdm(corpus, desc="Tokenizing corpus"):
|
||||
if self._is_chinese(doc) and JIEBA_AVAILABLE:
|
||||
tokens = list(jieba.cut(doc))
|
||||
else:
|
||||
tokens = doc.lower().split()
|
||||
tokenized_corpus.append(tokens)
|
||||
|
||||
if BM25_AVAILABLE:
|
||||
self.bm25_index = BM25Okapi(tokenized_corpus)
|
||||
logger.info(f"Built BM25 index with {len(corpus)} documents")
|
||||
else:
|
||||
logger.warning("BM25Okapi not available, skipping BM25 index")
|
||||
self.bm25_index = None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building BM25 index: {e}")
|
||||
self.bm25_index = None
|
||||
|
||||
def _is_chinese(self, text: str) -> bool:
|
||||
"""Check if text contains Chinese characters"""
|
||||
for char in text:
|
||||
if '\u4e00' <= char <= '\u9fff':
|
||||
return True
|
||||
return False
|
||||
|
||||
def augment_data(
|
||||
self,
|
||||
data: List[Dict],
|
||||
model=None,
|
||||
corpus: Optional[List[str]] = None
|
||||
) -> List[Dict]:
|
||||
"""Apply multiple hard negative augmentation strategies"""
|
||||
|
||||
# Build BM25 index if corpus provided
|
||||
if corpus and 'bm25' in self.strategies:
|
||||
self.build_bm25_index(corpus)
|
||||
|
||||
augmented_data = []
|
||||
|
||||
for item in tqdm(data, desc="Augmenting data"):
|
||||
query = item['query']
|
||||
positives = item.get('pos', [])
|
||||
negatives = item.get('neg', [])
|
||||
|
||||
# Apply different strategies
|
||||
if 'bm25' in self.strategies and self.bm25_index:
|
||||
bm25_negs = self._mine_bm25_negatives(query, positives)
|
||||
negatives.extend(bm25_negs[:5]) # Add top 5 BM25 negatives
|
||||
|
||||
if 'semantic' in self.strategies and model:
|
||||
semantic_negs = self._mine_semantic_negatives(query, positives, model)
|
||||
negatives.extend(semantic_negs[:5])
|
||||
|
||||
if 'random_swap' in self.strategies:
|
||||
swap_negs = self._create_swap_negatives(positives)
|
||||
negatives.extend(swap_negs[:3])
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
unique_negatives = []
|
||||
for neg in negatives:
|
||||
if neg not in seen and neg not in positives:
|
||||
seen.add(neg)
|
||||
unique_negatives.append(neg)
|
||||
|
||||
augmented_item = item.copy()
|
||||
augmented_item['neg'] = unique_negatives[:self.config.get('max_negatives', 30)]
|
||||
augmented_data.append(augmented_item)
|
||||
|
||||
return augmented_data
|
||||
|
||||
def _mine_bm25_negatives(self, query: str, positives: List[str]) -> List[str]:
|
||||
"""Mine hard negatives using BM25"""
|
||||
if not self.bm25_index or not self.corpus_passages or not BM25_AVAILABLE:
|
||||
return []
|
||||
|
||||
# Tokenize query
|
||||
if self._is_chinese(query) and JIEBA_AVAILABLE:
|
||||
query_tokens = list(jieba.cut(query))
|
||||
else:
|
||||
query_tokens = query.lower().split()
|
||||
|
||||
# Get BM25 scores
|
||||
scores = self.bm25_index.get_scores(query_tokens)
|
||||
|
||||
# Get top passages excluding positives
|
||||
top_indices = np.argsort(scores)[::-1]
|
||||
negatives = []
|
||||
|
||||
for idx in top_indices:
|
||||
if len(negatives) >= 10:
|
||||
break
|
||||
passage = self.corpus_passages[idx]
|
||||
if passage not in positives:
|
||||
negatives.append(passage)
|
||||
|
||||
return negatives
|
||||
|
||||
def _mine_semantic_negatives(self, query: str, positives: List[str], model) -> List[str]:
|
||||
"""
|
||||
Mine semantically similar but irrelevant passages using the model's encode method.
|
||||
Returns top-k most similar passages (excluding positives).
|
||||
"""
|
||||
try:
|
||||
if not self.corpus_passages or not model:
|
||||
logger.warning("No corpus or model provided for semantic negative mining.")
|
||||
return []
|
||||
|
||||
batch_size = self.config.get('semantic_batch_size', 64)
|
||||
num_negatives = self.config.get('semantic_num_negatives', 10)
|
||||
device = self.config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Encode query
|
||||
query_embedding = model.encode([query], convert_to_numpy=True, device=device)
|
||||
# Encode all corpus passages in batches
|
||||
passage_embeddings = []
|
||||
for i in range(0, len(self.corpus_passages), batch_size):
|
||||
batch = self.corpus_passages[i:i+batch_size]
|
||||
batch_emb = model.encode(batch, convert_to_numpy=True, device=device)
|
||||
passage_embeddings.append(batch_emb)
|
||||
passage_embeddings = np.vstack(passage_embeddings)
|
||||
|
||||
# Compute similarity
|
||||
scores = np.dot(passage_embeddings, query_embedding[0])
|
||||
# Exclude positives
|
||||
candidates = [
    (idx, score)
    for idx, (score, passage) in enumerate(zip(scores, self.corpus_passages))
    if passage not in positives
]
|
||||
# Sort and select top-k
|
||||
candidates.sort(key=lambda x: x[1], reverse=True)
|
||||
top_indices = [idx for idx, _ in candidates[:num_negatives]]
|
||||
negatives = [self.corpus_passages[idx] for idx in top_indices]
|
||||
logger.info(f"Mined {len(negatives)} semantic negatives for query: {query}")
|
||||
return negatives
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error mining semantic negatives: {e}")
|
||||
return []
|
||||
|
||||
def _create_swap_negatives(self, positives: List[str]) -> List[str]:
|
||||
"""Create negatives by swapping words in positive passages"""
|
||||
negatives = []
|
||||
|
||||
for pos in positives[:2]: # Only use first 2 positives
|
||||
words = pos.split()
|
||||
if len(words) > 5:
|
||||
# Swap random words
|
||||
import random
|
||||
swapped = words.copy()
|
||||
for _ in range(min(3, len(words) // 3)):
|
||||
i, j = random.sample(range(len(words)), 2)
|
||||
swapped[i], swapped[j] = swapped[j], swapped[i]
|
||||
negatives.append(' '.join(swapped))
|
||||
|
||||
return negatives
|
||||
687
data/preprocessing.py
Normal file
@ -0,0 +1,687 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import List, Dict, Optional, Tuple, Union
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
from transformers import AutoTokenizer
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
import csv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataPreprocessor:
|
||||
"""
|
||||
Comprehensive data preprocessor for various input formats
|
||||
Supports JSON, JSONL, CSV, TSV, XML, and custom formats
|
||||
|
||||
Notes:
|
||||
- All parameters (tokenizer, max_query_length, max_passage_length, min_passage_length, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
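Illustrative usage (sketch only; the file names are placeholders):
    preprocessor = DataPreprocessor(max_query_length=128, max_passage_length=2048)
    train_path, val_path = preprocessor.preprocess_file(
        "raw_data.csv", "clean.jsonl", file_format="csv", validation_split=0.1
    )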
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Optional[AutoTokenizer] = None,
|
||||
max_query_length: int = 512,
|
||||
max_passage_length: int = 8192,
|
||||
# For production, set min_passage_length to 10
|
||||
min_passage_length: int = 3
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_query_length = max_query_length
|
||||
self.max_passage_length = max_passage_length
|
||||
self.min_passage_length = min_passage_length
|
||||
|
||||
def preprocess_file(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
file_format: str = "auto",
|
||||
validation_split: float = 0.1
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
"""
|
||||
Preprocess input file and convert to standard JSONL format.
|
||||
Args:
|
||||
input_path: Path to input file
|
||||
output_path: Path to output JSONL file
|
||||
file_format: Input file format (auto, json, jsonl, csv, tsv, xml, dureader, msmarco)
|
||||
validation_split: Fraction of data to use for validation
|
||||
Returns:
|
||||
Tuple of (train_path, val_path)
|
||||
"""
|
||||
# Auto-detect format
|
||||
if file_format == "auto":
|
||||
file_format = self._detect_format(input_path)
|
||||
|
||||
logger.info(f"Processing {input_path} as {file_format} format")
|
||||
|
||||
# Load data based on format
|
||||
if file_format == "jsonl":
|
||||
data = self._load_jsonl(input_path)
|
||||
elif file_format == "json":
|
||||
data = self._load_json(input_path)
|
||||
elif file_format in ["csv", "tsv"]:
|
||||
data = self._load_csv(input_path, delimiter="," if file_format == "csv" else "\t")
|
||||
elif file_format == "xml":
|
||||
data = self._load_xml(input_path)
|
||||
elif file_format == "dureader":
|
||||
data = self._load_dureader(input_path)
|
||||
elif file_format == "msmarco":
|
||||
data = self._load_msmarco(input_path)
|
||||
elif file_format == "beir":
|
||||
data = self._load_beir(input_path)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {file_format}")
|
||||
|
||||
# Validate and clean data
|
||||
data = self._validate_and_clean(data)
|
||||
|
||||
# Split into train/val
|
||||
if validation_split > 0:
|
||||
train_data, val_data = self._split_data(data, validation_split)
|
||||
|
||||
# Save training data
|
||||
train_path = output_path.replace(".jsonl", "_train.jsonl")
|
||||
self._save_jsonl(train_data, train_path)
|
||||
|
||||
# Save validation data
|
||||
val_path = output_path.replace(".jsonl", "_val.jsonl")
|
||||
self._save_jsonl(val_data, val_path)
|
||||
|
||||
logger.info(f"Saved {len(train_data)} training and {len(val_data)} validation examples")
|
||||
return train_path, val_path
|
||||
else:
|
||||
self._save_jsonl(data, output_path)
|
||||
logger.info(f"Saved {len(data)} examples to {output_path}")
|
||||
return output_path, None
|
||||
|
||||
def _detect_format(self, file_path: str) -> str:
|
||||
"""Auto-detect file format based on extension and content"""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
# First try by extension
|
||||
if ext == ".jsonl":
|
||||
return "jsonl"
|
||||
elif ext == ".json":
|
||||
return "json"
|
||||
elif ext == ".csv":
|
||||
return "csv"
|
||||
elif ext == ".tsv":
|
||||
return "tsv"
|
||||
elif ext == ".xml":
|
||||
return "xml"
|
||||
|
||||
# Try to detect by content
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
first_lines = [f.readline().strip() for _ in range(5)]
|
||||
first_content = '\n'.join(first_lines)
|
||||
|
||||
# Check for JSON Lines
|
||||
if all(line.startswith("{") and line.endswith("}") for line in first_lines if line):
|
||||
return "jsonl"
|
||||
|
||||
# Check for JSON
|
||||
if first_content.strip().startswith("[") or first_content.strip().startswith("{"):
|
||||
return "json"
|
||||
|
||||
# Check for XML
|
||||
if first_content.strip().startswith("<?xml") or first_content.strip().startswith("<"):
|
||||
return "xml"
|
||||
|
||||
# Check for TSV (tabs)
|
||||
if any("\t" in line for line in first_lines):
|
||||
return "tsv"
|
||||
|
||||
# Check for CSV (commas)
|
||||
if any("," in line for line in first_lines):
|
||||
return "csv"
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not auto-detect format for {file_path}: {e}")
|
||||
|
||||
# Default fallback
|
||||
return "jsonl"
|
||||
|
||||
def _load_jsonl(self, file_path: str) -> List[Dict]:
|
||||
"""Load JSONL file with robust error handling and detailed logging."""
|
||||
data = []
|
||||
total_lines = 0
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(tqdm(f, desc="Loading JSONL"), 1):
|
||||
line = line.strip()
|
||||
if line:
|
||||
total_lines += 1
|
||||
try:
|
||||
data.append(json.loads(line))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Skipping invalid JSON at {file_path}:{line_num}: {e} | Content: {line[:100]}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping malformed line at {file_path}:{line_num}: {e} | Content: {line[:100]}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading JSONL file {file_path}: {e}")
|
||||
raise
|
||||
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
|
||||
logger.warning(
|
||||
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
|
||||
)
|
||||
return data
|
||||
|
||||
def _load_json(self, file_path: str) -> List[Dict]:
|
||||
"""Load JSON file"""
|
||||
total_lines = 0
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Convert to list format if needed
|
||||
if isinstance(data, dict):
|
||||
if 'data' in data:
|
||||
data = data['data']
|
||||
else:
|
||||
data = [data]
|
||||
elif not isinstance(data, list):
|
||||
data = [data]
|
||||
total_lines = len(data)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Invalid JSON in file {file_path}: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading JSON file {file_path}: {e}")
|
||||
raise
|
||||
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
|
||||
logger.warning(
|
||||
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
|
||||
)
|
||||
return data
|
||||
|
||||
def _load_csv(self, file_path: str, delimiter: str = ",") -> List[Dict]:
|
||||
"""Load CSV/TSV file with robust parsing"""
|
||||
data = []
|
||||
total_lines = 0
|
||||
try:
|
||||
# Try different encodings
|
||||
encodings = ['utf-8', 'gb2312', 'gbk', 'big5', 'latin-1']
|
||||
df = None
|
||||
|
||||
for encoding in encodings:
|
||||
try:
|
||||
df = pd.read_csv(file_path, delimiter=delimiter, encoding=encoding)
|
||||
logger.info(f"Successfully loaded CSV with {encoding} encoding")
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"Error with encoding {encoding} in {file_path}: {e}")
|
||||
continue
|
||||
|
||||
if df is None:
|
||||
raise ValueError(f"Could not load CSV file {file_path} with any encoding")
|
||||
|
||||
# Process each row
|
||||
for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading CSV/TSV"):
|
||||
total_lines += 1
|
||||
try:
|
||||
item = {}
|
||||
# Map common column names to standard format
|
||||
row_dict = {str(k).lower().strip(): v for k, v in row.items() if pd.notna(v)}
|
||||
# Extract query
|
||||
query_fields = ['query', 'question', 'q', 'text', 'sentence']
|
||||
for field in query_fields:
|
||||
if field in row_dict:
|
||||
item['query'] = str(row_dict[field]).strip()
|
||||
break
|
||||
# Extract positive passages
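# A single cell may pack several passages separated by '|||'; these are split into a list below.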
|
||||
pos_fields = ['positive', 'pos', 'passage', 'document', 'answer', 'response']
|
||||
positives = []
|
||||
for field in pos_fields:
|
||||
if field in row_dict:
|
||||
pos_text = str(row_dict[field]).strip()
|
||||
if '|||' in pos_text:
|
||||
positives.extend([p.strip() for p in pos_text.split('|||') if p.strip()])
|
||||
else:
|
||||
positives.append(pos_text)
|
||||
break
|
||||
if 'passage' in row_dict and 'label' in row_dict:
|
||||
try:
|
||||
label = int(float(row_dict['label']))
|
||||
passage = str(row_dict['passage']).strip()
|
||||
if label == 1 and passage:
|
||||
positives.append(passage)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if positives:
|
||||
item['pos'] = positives
|
||||
# Extract negative passages
|
||||
neg_fields = ['negative', 'neg', 'hard_negative']
|
||||
negatives = []
|
||||
for field in neg_fields:
|
||||
if field in row_dict:
|
||||
neg_text = str(row_dict[field]).strip()
|
||||
if '|||' in neg_text:
|
||||
negatives.extend([n.strip() for n in neg_text.split('|||') if n.strip()])
|
||||
else:
|
||||
negatives.append(neg_text)
|
||||
break
|
||||
if 'passage' in row_dict and 'label' in row_dict:
|
||||
try:
|
||||
label = int(float(row_dict['label']))
|
||||
passage = str(row_dict['passage']).strip()
|
||||
if label == 0 and passage:
|
||||
negatives.append(passage)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if negatives:
|
||||
item['neg'] = negatives
|
||||
# Extract scores if available
|
||||
score_fields = ['score', 'relevance', 'rating']
|
||||
for field in score_fields:
|
||||
if field in row_dict:
|
||||
try:
|
||||
score = float(row_dict[field])
|
||||
if 'pos' in item:
|
||||
item['pos_scores'] = [score] * len(item['pos'])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
# Only add if we have minimum required fields
|
||||
if 'query' in item and ('pos' in item or 'neg' in item):
|
||||
data.append(item)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping malformed row in {file_path} at index {idx}: {e} | Row: {row}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading CSV file {file_path}: {e}")
|
||||
raise
|
||||
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
|
||||
logger.warning(
|
||||
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
|
||||
)
|
||||
return data
|
||||
|
||||
def _load_tsv(self, file_path: str) -> List[Dict]:
|
||||
"""Load TSV file (wrapper around _load_csv)."""
|
||||
return self._load_csv(file_path, delimiter="\t")
|
||||
|
||||
def _load_xml(self, file_path: str) -> List[Dict]:
|
||||
"""Load XML file (e.g., TREC format)"""
|
||||
data = []
|
||||
total_lines = 0
|
||||
try:
|
||||
tree = ET.parse(file_path)
|
||||
root = tree.getroot()
|
||||
|
||||
# Handle different XML structures
|
||||
if root.tag == 'questions' or root.tag == 'queries':
|
||||
# TREC-style format
|
||||
for question in tqdm(root.findall('.//question') or root.findall('.//query'), desc="Loading XML"):
|
||||
item = {}
|
||||
|
||||
# Extract query
|
||||
# find() results are not `or`-chained: an ElementTree element with no children evaluates as falsy
query_elem = next((question.find(tag) for tag in ('text', 'title', 'query') if question.find(tag) is not None), None)
|
||||
if query_elem is not None and query_elem.text is not None:
|
||||
item['query'] = query_elem.text.strip()
|
||||
|
||||
# Extract passages/documents
|
||||
passages = []
|
||||
for doc in question.findall('.//document') or question.findall('.//passage'):
|
||||
if doc.text:
|
||||
passages.append(doc.text.strip())
|
||||
|
||||
if passages:
|
||||
item['pos'] = passages
|
||||
|
||||
if 'query' in item:
|
||||
data.append(item)
|
||||
|
||||
else:
|
||||
# Generic XML - try to extract text content
|
||||
for elem in tqdm(root.iter(), desc="Processing XML"):
|
||||
if elem.text and len(elem.text.strip()) > 20: # Minimum length check
|
||||
item = {
|
||||
'query': elem.text.strip()[:100], # First part as query
|
||||
'pos': [elem.text.strip()] # Full text as passage
|
||||
}
|
||||
data.append(item)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading XML file {file_path}: {e}")
|
||||
raise
|
||||
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
|
||||
logger.warning(
|
||||
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
|
||||
)
|
||||
return data
|
||||
|
||||
def _load_dureader(self, file_path: str) -> List[Dict]:
|
||||
"""Load DuReader format data"""
|
||||
data = []
|
||||
total_lines = 0
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(tqdm(f, desc="Loading DuReader"), 1):
|
||||
if line.strip():
|
||||
total_lines += 1
|
||||
try:
|
||||
item = json.loads(line)
|
||||
|
||||
# Convert DuReader format
|
||||
processed = {
|
||||
'query': item.get('question', '').strip(),
|
||||
'pos': [],
|
||||
'neg': []
|
||||
}
|
||||
|
||||
# Extract positive passages from answers
|
||||
if 'answers' in item:
|
||||
for answer in item['answers']:
|
||||
if isinstance(answer, dict) and 'text' in answer:
|
||||
processed['pos'].append(answer['text'].strip())
|
||||
elif isinstance(answer, str):
|
||||
processed['pos'].append(answer.strip())
|
||||
|
||||
# Extract documents as potential negatives
|
||||
if 'documents' in item:
|
||||
for doc in item['documents']:
|
||||
if isinstance(doc, dict):
|
||||
if 'paragraphs' in doc:
|
||||
for para in doc['paragraphs']:
|
||||
para_text = para.strip() if isinstance(para, str) else str(para).strip()
|
||||
if para_text and para_text not in processed['pos']:
|
||||
processed['neg'].append(para_text)
|
||||
elif 'content' in doc:
|
||||
content = doc['content'].strip()
|
||||
if content and content not in processed['pos']:
|
||||
processed['neg'].append(content)
|
||||
|
||||
if processed['query'] and processed['pos']:
|
||||
data.append(processed)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Skipping invalid JSON line in DuReader: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading DuReader file {file_path}: {e}")
|
||||
raise
|
||||
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
|
||||
logger.warning(
|
||||
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
|
||||
)
|
||||
return data
|
||||
|
||||
def _load_msmarco(self, file_path: str) -> List[Dict]:
|
||||
"""Load MS MARCO format data"""
|
||||
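# Expected line format: "query<TAB>positive passage[<TAB>negative passage ...]"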
data = []
|
||||
total_lines = 0
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(tqdm(f, desc="Loading MS MARCO"), 1):
|
||||
line = line.strip()
|
||||
if line:
|
||||
total_lines += 1
|
||||
parts = line.split('\t')
|
||||
if len(parts) >= 2:
|
||||
item = {
|
||||
'query': parts[0].strip(),
|
||||
'pos': [parts[1].strip()],
|
||||
'neg': [p.strip() for p in parts[2:] if p.strip()]
|
||||
}
|
||||
data.append(item)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading MS MARCO file {file_path}: {e}")
|
||||
raise
|
||||
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
|
||||
logger.warning(
|
||||
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
|
||||
)
|
||||
return data
|
||||
|
||||
def _load_beir(self, file_path: str) -> List[Dict]:
|
||||
"""Load BEIR format data"""
|
||||
data = []
|
||||
total_lines = 0
|
||||
try:
|
||||
# BEIR format is typically JSONL with specific structure
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(tqdm(f, desc="Loading BEIR"), 1):
|
||||
if line.strip():
|
||||
total_lines += 1
|
||||
try:
|
||||
item = json.loads(line)
|
||||
|
||||
# BEIR format usually has 'query', 'positive_passages', 'negative_passages'
|
||||
processed = {}
|
||||
|
||||
if 'query' in item:
|
||||
processed['query'] = item['query'].strip()
|
||||
elif 'text' in item:
|
||||
processed['query'] = item['text'].strip()
|
||||
|
||||
if 'positive_passages' in item:
|
||||
processed['pos'] = [p.strip() for p in item['positive_passages'] if p.strip()]
|
||||
elif 'pos' in item:
|
||||
processed['pos'] = [p.strip() for p in item['pos'] if p.strip()]
|
||||
|
||||
if 'negative_passages' in item:
|
||||
processed['neg'] = [n.strip() for n in item['negative_passages'] if n.strip()]
|
||||
elif 'neg' in item:
|
||||
processed['neg'] = [n.strip() for n in item['neg'] if n.strip()]
|
||||
|
||||
if 'query' in processed and 'pos' in processed:
|
||||
data.append(processed)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Skipping invalid JSON line in BEIR: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading BEIR file {file_path}: {e}")
|
||||
raise
|
||||
if total_lines > 0 and (len(data) < 10 or len(data) < 0.01 * total_lines):
|
||||
logger.warning(
|
||||
f"Too few valid data lines loaded from {file_path}: {len(data)} out of {total_lines}."
|
||||
)
|
||||
return data
|
||||
|
||||
def _validate_and_clean(self, data: List[Dict]) -> List[Dict]:
|
||||
"""Validate and clean data"""
|
||||
cleaned = []
|
||||
|
||||
for item in tqdm(data, desc="Validating data"):
|
||||
try:
|
||||
# Ensure required fields
|
||||
if 'query' not in item or not item['query'].strip():
|
||||
continue
|
||||
|
||||
# Clean query
|
||||
if self.tokenizer:
|
||||
encoding = self.tokenizer(item['query'], add_special_tokens=False) # type: ignore[operator]
|
||||
input_ids = encoding['input_ids']
|
||||
if len(input_ids) > self.max_query_length:
|
||||
input_ids = input_ids[:self.max_query_length]
|
||||
item['query'] = self.tokenizer.decode(input_ids, skip_special_tokens=True) # type: ignore[attr-defined]
|
||||
else:
|
||||
if len(item['query']) > self.max_query_length:
|
||||
item['query'] = item['query'][: self.max_query_length]
|
||||
|
||||
# Ensure at least one positive
|
||||
if 'pos' not in item or not item['pos']:
|
||||
continue
|
||||
|
||||
# Clean positives
|
||||
item['pos'] = self._filter_passages_by_length(item['pos'])
|
||||
if not item['pos']:
|
||||
continue
|
||||
|
||||
# Clean negatives
|
||||
if 'neg' in item:
|
||||
item['neg'] = self._filter_passages_by_length(item['neg'])
|
||||
else:
|
||||
item['neg'] = []
|
||||
|
||||
# Filter by length if tokenizer available
|
||||
if self.tokenizer:
|
||||
# Check query length
|
||||
# Check passage lengths
|
||||
item['pos'] = self._filter_passages_by_length(item['pos'])
|
||||
item['neg'] = self._filter_passages_by_length(item['neg'])
|
||||
|
||||
# Ensure we still have data after filtering
|
||||
if item['pos']:
|
||||
cleaned.append(item)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing item: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Cleaned {len(cleaned)} examples from {len(data)} original")
|
||||
return cleaned
|
||||
|
||||
def _filter_passages_by_length(self, passages: List[str]) -> List[str]:
|
||||
"""Filter passages by token length"""
|
||||
if not self.tokenizer:
|
||||
return passages
|
||||
|
||||
filtered = []
|
||||
|
||||
for passage in passages:
|
||||
try:
|
||||
encoding = self.tokenizer(passage, add_special_tokens=False) # type: ignore[operator]
|
||||
input_ids = encoding['input_ids']
|
||||
|
||||
if len(input_ids) < self.min_passage_length:
|
||||
continue
|
||||
|
||||
if len(input_ids) > self.max_passage_length:
|
||||
# Truncate to max length
|
||||
input_ids = input_ids[:self.max_passage_length]
|
||||
passage = self.tokenizer.decode(input_ids, skip_special_tokens=True) # type: ignore[attr-defined]
|
||||
|
||||
filtered.append(passage)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing passage: {e}")
|
||||
continue
|
||||
|
||||
return filtered
|
||||
|
||||
def _split_data(
|
||||
self,
|
||||
data: List[Dict],
|
||||
validation_split: float
|
||||
) -> Tuple[List[Dict], List[Dict]]:
|
||||
"""Split data into train/validation sets"""
|
||||
# Shuffle data
|
||||
shuffled_data = data.copy()
|
||||
random.shuffle(shuffled_data)
|
||||
|
||||
# Calculate split point
|
||||
split_idx = int(len(shuffled_data) * (1 - validation_split))
|
||||
|
||||
train_data = shuffled_data[:split_idx]
|
||||
val_data = shuffled_data[split_idx:]
|
||||
|
||||
return train_data, val_data
|
||||
|
||||
def _save_jsonl(self, data: List[Dict], output_path: str):
|
||||
"""Save data in JSONL format"""
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
try:
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for item in data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving JSONL file {output_path}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class DataAugmenter:
|
||||
"""
|
||||
Data augmentation strategies for retrieval training
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: Optional[AutoTokenizer] = None):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def augment_with_paraphrases(
|
||||
self,
|
||||
data: List[Dict],
|
||||
num_paraphrases: int = 2
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Augment queries with paraphrases using rule-based approach
|
||||
"""
|
||||
augmented = []
|
||||
|
||||
for item in tqdm(data, desc="Augmenting with paraphrases"):
|
||||
# Add original
|
||||
augmented.append(item)
|
||||
|
||||
# Generate paraphrases
|
||||
for i in range(num_paraphrases):
|
||||
para_item = item.copy()
|
||||
para_query = self._generate_paraphrase(item['query'], i)
|
||||
if para_query != item['query']:
|
||||
para_item['query'] = para_query
|
||||
para_item['is_augmented'] = True
|
||||
para_item['augmentation_type'] = 'paraphrase'
|
||||
augmented.append(para_item)
|
||||
|
||||
return augmented
|
||||
|
||||
def _generate_paraphrase(self, query: str, variant: int) -> str:
|
||||
"""Generate paraphrase using rule-based patterns"""
|
||||
if self._is_chinese(query):
|
||||
# Chinese paraphrase patterns
|
||||
patterns = [
|
||||
("什么是", "请解释一下"),
|
||||
("如何", "怎样才能"),
|
||||
("为什么", "什么原因导致"),
|
||||
("在哪里", "什么地方"),
|
||||
("是什么", "指的是什么"),
|
||||
("有什么", "存在哪些"),
|
||||
("怎么办", "如何处理"),
|
||||
("多少", "数量是多少")
|
||||
]
|
||||
|
||||
# Apply patterns based on variant
|
||||
if variant == 0:
|
||||
for original, replacement in patterns:
|
||||
if original in query:
|
||||
return query.replace(original, replacement, 1)
|
||||
else:
|
||||
# Add question prefixes
|
||||
prefixes = ["请问", "我想知道", "能否告诉我", "想了解"]
|
||||
return f"{prefixes[variant % len(prefixes)]}{query}"
|
||||
else:
|
||||
# English paraphrase patterns
|
||||
patterns = [
|
||||
("what is", "what does"),
|
||||
("how to", "how can I"),
|
||||
("why does", "what causes"),
|
||||
("where is", "what location"),
|
||||
("when did", "at what time")
|
||||
]
|
||||
|
||||
query_lower = query.lower()
|
||||
for original, replacement in patterns:
|
||||
if original in query_lower:
|
||||
return query_lower.replace(original, replacement, 1)
|
||||
|
||||
# Add question prefixes
|
||||
prefixes = ["Please explain", "Could you tell me", "I need to know", "Help me understand"]
|
||||
return f"{prefixes[variant % len(prefixes)]} {query.lower()}"
|
||||
|
||||
return query
|
||||
|
||||
def _is_chinese(self, text: str) -> bool:
|
||||
"""Check if text contains Chinese characters"""
|
||||
return bool(re.search(r'[\u4e00-\u9fff]', text))
|
||||
254
docs/data_formats.md
Normal file
254
docs/data_formats.md
Normal file
@ -0,0 +1,254 @@
|
||||
# BGE Data Pipeline Format Specifications
|
||||
|
||||
The refactored BGE data pipeline supports four distinct input formats with automatic detection and conversion:
|
||||
|
||||
## 📊 Format Overview
|
||||
|
||||
| Format | Use Case | Detection Key | Output |
|
||||
|--------|----------|---------------|---------|
|
||||
| **Triplets** | BGE-M3 embedding training | `pos` + `neg` | Contrastive learning batches |
|
||||
| **Pairs** | BGE-reranker training | `passage` + `label` | Cross-encoder pairs |
|
||||
| **Candidates** | Unified threshold conversion | `candidates` | Auto-exploded pairs |
|
||||
| **Legacy Nested** | Backward compatibility | `pos` only | Converted to pairs |
|
||||
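The detection column above can be read as a simple precedence rule. The sketch below is one plausible reading of it; the helper name `detect_format` is illustrative and not part of the pipeline API:

```python
def detect_format(record: dict) -> str:
    """Guess the input format of a single record from its keys (see the table above)."""
    if "candidates" in record:
        return "candidates"
    if "passage" in record and "label" in record:
        return "pairs"
    if "pos" in record and "neg" in record:
        return "triplets"
    if "pos" in record:
        return "legacy_nested"
    raise ValueError(f"Unrecognized record keys: {sorted(record)}")
```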
|
||||
---
|
||||
|
||||
## 1. 🎯 Triplets Format (BGE-M3 Embedding)
|
||||
|
||||
For contrastive learning with multiple positives and negatives per query.
|
||||
|
||||
```json
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"pos": [
|
||||
"Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
|
||||
"ML algorithms use statistical techniques to learn patterns from data and make predictions."
|
||||
],
|
||||
"neg": [
|
||||
"Weather forecasting uses meteorological data to predict atmospheric conditions.",
|
||||
"Cooking recipes require specific ingredients and step-by-step instructions."
|
||||
],
|
||||
"pos_scores": [0.95, 0.88],
|
||||
"neg_scores": [0.15, 0.08],
|
||||
"prompt": "为此查询生成表示:",
|
||||
"type": "definition"
|
||||
}
|
||||
```
|
||||
|
||||
### Required Fields
|
||||
- `query`: Query text
|
||||
- `pos`: List of positive passages
|
||||
|
||||
### Optional Fields
|
||||
- `neg`: List of negative passages
|
||||
- `pos_scores`: Relevance scores for positives
|
||||
- `neg_scores`: Relevance scores for negatives
|
||||
- `prompt`: Query instruction prefix
|
||||
- `type`: Query type classification
|
||||
|
||||
---
|
||||
|
||||
## 2. 🔍 Pairs Format (BGE-Reranker)
|
||||
|
||||
For cross-encoder training with individual query-passage pairs.
|
||||
|
||||
```json
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"passage": "Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
|
||||
"label": 1,
|
||||
"score": 0.95,
|
||||
"qid": "q1",
|
||||
"pid": "p1"
|
||||
}
|
||||
```
|
||||
|
||||
### Required Fields
|
||||
- `query`: Query text
|
||||
- `passage`: Passage text
|
||||
- `label`: Relevance label (0 or 1)
|
||||
|
||||
### Optional Fields
|
||||
- `score`: Relevance score (0.0-1.0)
|
||||
- `qid`: Query identifier
|
||||
- `pid`: Passage identifier
|
||||
|
||||
---
|
||||
|
||||
## 3. 🎲 Candidates Format (Unified)
|
||||
|
||||
For datasets with multiple candidates per query. Each candidate is automatically exploded into a separate pair based on a score threshold.
|
||||
|
||||
```json
|
||||
{
|
||||
"query": "What is artificial intelligence?",
|
||||
"qid": "q1",
|
||||
"candidates": [
|
||||
{
|
||||
"text": "Artificial intelligence is the simulation of human intelligence in machines.",
|
||||
"score": 0.94,
|
||||
"pid": "c1"
|
||||
},
|
||||
{
|
||||
"text": "AI systems can perform tasks that typically require human intelligence.",
|
||||
"score": 0.87,
|
||||
"pid": "c2"
|
||||
},
|
||||
{
|
||||
"text": "Mountain climbing requires proper equipment and training.",
|
||||
"score": 0.23,
|
||||
"pid": "c3"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Required Fields
|
||||
- `query`: Query text
|
||||
- `candidates`: List of candidate objects
|
||||
|
||||
### Candidate Object Fields
|
||||
- `text`: Candidate passage text
|
||||
- `score`: Relevance score (optional, defaults to 1.0)
|
||||
- `pid`: Passage identifier (optional, auto-generated)
|
||||
|
||||
### Processing Logic
|
||||
- **Threshold Assignment**: Candidates with `score >= threshold` become `label=1`, others become `label=0`
|
||||
- **Default Threshold**: 0.5 (configurable)
|
||||
- **Output**: Multiple pairs, one per candidate (see the sketch below)
|
||||
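A minimal sketch of this explosion step, assuming a configurable `threshold`; the helper name is illustrative, not the pipeline's actual API:

```python
def explode_candidates(record: dict, threshold: float = 0.5) -> list[dict]:
    """Convert one candidates-format record into flat pairs via threshold-based labels."""
    pairs = []
    for i, cand in enumerate(record["candidates"]):
        score = float(cand.get("score", 1.0))            # score defaults to 1.0
        pairs.append({
            "query": record["query"],
            "passage": cand["text"],
            "label": 1 if score >= threshold else 0,     # threshold assignment
            "score": score,
            "qid": record.get("qid"),
            "pid": cand.get("pid", f"c{i + 1}"),         # pid auto-generated if missing
        })
    return pairs
```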
|
||||
---
|
||||
|
||||
## 4. 🔄 Legacy Nested Format (Backward Compatibility)
|
||||
|
||||
The legacy format is automatically converted to flat pairs for reranker training.
|
||||
|
||||
```json
|
||||
{
|
||||
"query": "What is natural language processing?",
|
||||
"pos": [
|
||||
"Natural language processing is a field of AI that helps computers understand human language.",
|
||||
"NLP combines computational linguistics with machine learning to process text and speech."
|
||||
],
|
||||
"neg": [
|
||||
"Automobile manufacturing involves assembly lines and quality control processes.",
|
||||
"Photography captures light through camera lenses to create images."
|
||||
],
|
||||
"pos_scores": [0.93, 0.86],
|
||||
"neg_scores": [0.14, 0.22]
|
||||
}
|
||||
```
|
||||
|
||||
### Required Fields
|
||||
- `query`: Query text
|
||||
- `pos`: List of positive passages
|
||||
|
||||
### Optional Fields
|
||||
- `neg`: List of negative passages
|
||||
- `pos_scores`: Scores for positives
|
||||
- `neg_scores`: Scores for negatives
|
||||
|
||||
### Processing Logic
|
||||
- Each `pos[i]` becomes `(query, passage, label=1)`
|
||||
- Each `neg[j]` becomes `(query, passage, label=0)`
|
||||
- Maintains original scores
|
||||
- Tracks source as `legacy_nested` (see the sketch below)
|
||||
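A minimal sketch of the flattening (the pipeline performs this automatically, e.g. when `legacy_support=True` is passed to `BGERerankerDataset`); the helper name is illustrative:

```python
def legacy_to_pairs(record: dict) -> list[dict]:
    """Flatten a legacy nested record into (query, passage, label) pairs."""
    pairs = []
    pos_scores = record.get("pos_scores", [None] * len(record["pos"]))
    neg_scores = record.get("neg_scores", [None] * len(record.get("neg", [])))
    for passage, score in zip(record["pos"], pos_scores):
        pairs.append({"query": record["query"], "passage": passage,
                      "label": 1, "score": score, "source": "legacy_nested"})
    for passage, score in zip(record.get("neg", []), neg_scores):
        pairs.append({"query": record["query"], "passage": passage,
                      "label": 0, "score": score, "source": "legacy_nested"})
    return pairs
```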
|
||||
---
|
||||
|
||||
## 🔧 Usage Examples
|
||||
|
||||
### Fine-tuning Data Loader
|
||||
|
||||
```python
|
||||
from data.dataset import BGEM3Dataset, BGERerankerDataset, BGEDataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Triplets for BGE-M3
|
||||
triplets_dataset = BGEM3Dataset(
|
||||
data_path="triplets_data.jsonl",
|
||||
tokenizer=tokenizer,
|
||||
format_type="triplets"
|
||||
)
|
||||
|
||||
# Pairs for BGE-reranker
|
||||
pairs_dataset = BGERerankerDataset(
|
||||
data_path="pairs_data.jsonl",
|
||||
tokenizer=tokenizer,
|
||||
format_type="pairs"
|
||||
)
|
||||
|
||||
# Candidates with threshold conversion
|
||||
candidates_dataset = BGEDataset(
|
||||
data_path="candidates_data.jsonl",
|
||||
tokenizer=tokenizer,
|
||||
format_type="candidates",
|
||||
candidates_threshold=0.6 # Custom threshold
|
||||
)
|
||||
|
||||
# Legacy with automatic conversion
|
||||
legacy_dataset = BGERerankerDataset(
|
||||
data_path="legacy_data.jsonl",
|
||||
tokenizer=tokenizer,
|
||||
legacy_support=True
|
||||
)
|
||||
```
|
||||
|
||||
### Optimization Pipeline
|
||||
|
||||
```bash
|
||||
# BGE-M3 optimization (triplets → cached triplets)
|
||||
python -m data.optimization bge-m3 triplets_data.jsonl \
|
||||
--output_dir ./cache/data \
|
||||
--hard_negative_ratio 0.8 \
|
||||
--difficulty_threshold 0.2
|
||||
|
||||
# BGE-reranker optimization (pairs → cached pairs)
|
||||
python -m data.optimization bge-reranker pairs_data.jsonl \
|
||||
--output_dir ./cache/data \
|
||||
--output_format flat
|
||||
|
||||
# Candidates optimization (candidates → cached pairs)
|
||||
python -m data.optimization bge-reranker candidates_data.jsonl \
|
||||
--output_dir ./cache/data
|
||||
|
||||
# Legacy optimization (nested → cached pairs)
|
||||
python -m data.optimization bge-reranker legacy_data.jsonl \
|
||||
--output_dir ./cache/data \
|
||||
--output_format nested
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ⚡ Performance Features
|
||||
|
||||
### Enhanced Processing
|
||||
- **10-15x faster loading** through pre-tokenization
|
||||
- **Automatic format detection** based on key presence
|
||||
- **Score-weighted sampling** for better training quality
|
||||
- **Hard negative mining** with difficulty thresholds (sketched below)
|
||||
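Hard-negative mining keeps only negatives that are similar enough to the query to be informative. The sketch below is only an illustration of the idea behind the `--difficulty_threshold` option, not the pipeline's actual implementation:

```python
import numpy as np

def mine_hard_negatives(query_emb: np.ndarray,
                        neg_embs: np.ndarray,
                        neg_texts: list[str],
                        difficulty_threshold: float = 0.2) -> list[str]:
    """Keep negatives whose cosine similarity to the query meets the difficulty threshold."""
    q = query_emb / (np.linalg.norm(query_emb) + 1e-9)
    n = neg_embs / (np.linalg.norm(neg_embs, axis=1, keepdims=True) + 1e-9)
    sims = n @ q  # cosine similarity per negative
    return [text for text, sim in zip(neg_texts, sims) if sim >= difficulty_threshold]
```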
|
||||
### Conversion Utilities
|
||||
- **Candidates → Pairs**: Threshold-based label assignment
|
||||
- **Legacy → Pairs**: Automatic nested-to-flat conversion
|
||||
- **Pairs → Triplets**: Grouping by query for embedding training (sketched below)
|
||||
- **Format validation**: Comprehensive error checking
|
||||
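The pairs-to-triplets grouping can be sketched as follows; this is illustrative only and the helper name is not part of the pipeline API:

```python
from collections import defaultdict

def pairs_to_triplets(pairs: list[dict]) -> list[dict]:
    """Group flat (query, passage, label) pairs into per-query triplet records."""
    grouped: dict[str, dict] = defaultdict(lambda: {"pos": [], "neg": []})
    for pair in pairs:
        bucket = "pos" if pair["label"] == 1 else "neg"
        grouped[pair["query"]][bucket].append(pair["passage"])
    return [{"query": query, **groups} for query, groups in grouped.items() if groups["pos"]]
```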
|
||||
### Backward Compatibility
|
||||
- **Legacy support**: Full compatibility with existing datasets
|
||||
- **Gradual migration**: Convert formats incrementally
|
||||
- **Source tracking**: Know the original format of converted data
|
||||
- **Flexible thresholds**: Configurable conversion parameters
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Key Benefits
|
||||
|
||||
1. **Unified Pipeline**: Single codebase handles all formats
|
||||
2. **Automatic Detection**: No manual format specification needed
|
||||
3. **Seamless Conversion**: Transparent format transformations
|
||||
4. **Performance Optimized**: Significant speedup through caching
|
||||
5. **Backward Compatible**: Works with existing datasets
|
||||
6. **Flexible Configuration**: Customizable thresholds and parameters
|
||||
159
docs/directory_structure.md
Normal file
159
docs/directory_structure.md
Normal file
@ -0,0 +1,159 @@
|
||||
# BGE Fine-tuning Directory Structure
|
||||
|
||||
This document clarifies the directory structure and cache usage in the BGE fine-tuning project.
|
||||
|
||||
## 📂 Project Directory Structure
|
||||
|
||||
```
|
||||
bge_finetune/
|
||||
├── 📁 cache/ # Cache directories
|
||||
│ ├── 📁 data/ # Preprocessed dataset cache
|
||||
│ │ ├── 📄 *.jsonl # Preprocessed JSONL files
|
||||
│ │ └── 📄 *.json # Preprocessed JSON files
|
||||
│ └── 📁 models/ # Downloaded model cache (HuggingFace/ModelScope)
|
||||
│ ├── 📁 tokenizers/ # Cached tokenizer files
|
||||
│ ├── 📁 config/ # Cached model config files
|
||||
│ └── 📁 downloads/ # Temporary download files
|
||||
│
|
||||
├── 📁 models/ # Local model storage
|
||||
│ ├── 📁 bge-m3/ # Local BGE-M3 model files
|
||||
│ └── 📁 bge-reranker-base/ # Local BGE-reranker model files
|
||||
│
|
||||
├── 📁 data/ # Raw training data
|
||||
│ ├── 📄 train.jsonl # Input training data
|
||||
│ ├── 📄 eval.jsonl # Input evaluation data
|
||||
│ └── 📁 processed/ # Preprocessed (but not cached) data
|
||||
│
|
||||
├── 📁 output/ # Training outputs
|
||||
│ ├── 📁 bge-m3/ # M3 training results
|
||||
│ ├── 📁 bge-reranker/ # Reranker training results
|
||||
│ └── 📁 checkpoints/ # Model checkpoints
|
||||
│
|
||||
└── 📁 logs/ # Training logs
|
||||
├── 📁 tensorboard/ # TensorBoard logs
|
||||
└── 📄 training.log # Text logs
|
||||
```
|
||||
|
||||
## 🎯 Directory Purposes
|
||||
|
||||
### 1. **`./cache/models/`** - Model Downloads Cache
|
||||
- **Purpose**: HuggingFace/ModelScope model download cache
|
||||
- **Contents**: Tokenizers, config files, temporary downloads
|
||||
- **Usage**: Automatic caching of remote model downloads
|
||||
- **Config**: `model_paths.cache_dir = "./cache/models"`
|
||||
|
||||
### 2. **`./cache/data/`** - Processed Dataset Cache
|
||||
- **Purpose**: Preprocessed and cleaned dataset files
|
||||
- **Contents**: Validated JSONL/JSON files
|
||||
- **Usage**: Store cleaned and validated training data
|
||||
- **Config**: `data.cache_dir = "./cache/data"`
|
||||
|
||||
### 3. **`./models/`** - Local Model Storage
|
||||
- **Purpose**: Complete local model files (when available)
|
||||
- **Contents**: Full model directories with all files
|
||||
- **Usage**: Direct loading without downloads
|
||||
- **Config**: `model_paths.bge_m3 = "./models/bge-m3"`
|
||||
|
||||
### 4. **`./data/`** - Raw Input Data
|
||||
- **Purpose**: Original training/evaluation datasets
|
||||
- **Contents**: `.jsonl` files in various formats
|
||||
- **Usage**: Input to preprocessing and optimization
|
||||
- **Config**: User-specified paths in training scripts
|
||||
|
||||
### 5. **`./output/`** - Training Results
|
||||
- **Purpose**: Fine-tuned models and training artifacts
|
||||
- **Contents**: Saved models, checkpoints, metrics
|
||||
- **Usage**: Results of training scripts
|
||||
- **Config**: `--output_dir` in training scripts
|
||||
|
||||
## ⚙️ Configuration Mapping
|
||||
|
||||
### Config File Settings
|
||||
```toml
|
||||
[model_paths]
|
||||
cache_dir = "./cache/models" # Model download cache
|
||||
bge_m3 = "./models/bge-m3" # Local BGE-M3 path
|
||||
bge_reranker = "./models/bge-reranker-base" # Local reranker path
|
||||
|
||||
[data]
|
||||
cache_dir = "./cache/data" # Dataset cache
|
||||
optimization_output_dir = "./cache/data" # Optimization default
|
||||
```
|
||||
|
||||
### Script Defaults
|
||||
```bash
|
||||
# Data optimization (processed datasets)
|
||||
python -m data.optimization bge-m3 data/train.jsonl --output_dir ./cache/data
|
||||
|
||||
# Model training (output models)
|
||||
python scripts/train_m3.py --output_dir ./output/bge-m3 --cache_dir ./cache/models
|
||||
|
||||
# Model evaluation (results)
|
||||
python scripts/evaluate.py --output_dir ./evaluation_results
|
||||
```
|
||||
|
||||
## 🔄 Data Flow
|
||||
|
||||
```
|
||||
Raw Data → Preprocessing → Optimization → Training → Output
|
||||
data/         →  –  →  cache/data/               →  models         →  output/
|
||||
Input JSONL            Pre-tokenized .pkl files     Model loading     Fine-tuned models
|
||||
```
|
||||
|
||||
## 🛠️ Usage Examples
|
||||
|
||||
### Model Cache (Download Cache)
|
||||
```python
|
||||
# Automatic model caching
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"BAAI/bge-m3",
|
||||
cache_dir="./cache/models" # Downloads cached here
|
||||
)
|
||||
```
|
||||
|
||||
### Data Cache (Optimization Cache)
|
||||
```bash
|
||||
# Create optimized datasets
|
||||
python -m data.optimization bge-m3 data/train.jsonl \
|
||||
--output_dir ./cache/data # .pkl files created here
|
||||
|
||||
# Use cached datasets in training
|
||||
python scripts/train_m3.py \
|
||||
--train_data ./cache/data/train_m3_cached.pkl
|
||||
```
|
||||
|
||||
### Local Models (No Downloads)
|
||||
```python
|
||||
# Use local model (no downloads)
|
||||
model = BGEM3Model.from_pretrained("./models/bge-m3")
|
||||
```
|
||||
|
||||
## 📋 Best Practices
|
||||
|
||||
1. **Separate Concerns**: Keep model cache, data cache, and outputs separate
|
||||
2. **Clear Naming**: Use descriptive suffixes (`.pkl` for cache, `_cached` for optimized)
|
||||
3. **Consistent Paths**: Use config.toml for centralized path management
|
||||
4. **Cache Management**:
|
||||
- Model cache: Can be shared across projects
|
||||
- Data cache: Project-specific optimized datasets
|
||||
- Output: Training results and checkpoints
|
||||
|
||||
## 🧹 Cleanup Commands
|
||||
|
||||
```bash
|
||||
# Clean data cache (re-optimization needed)
|
||||
rm -rf ./cache/data/
|
||||
|
||||
# Clean model cache (re-download needed)
|
||||
rm -rf ./cache/models/
|
||||
|
||||
# Clean training outputs
|
||||
rm -rf ./output/
|
||||
|
||||
# Clean all caches
|
||||
rm -rf ./cache/
|
||||
```
|
||||
|
||||
This structure ensures clear separation of concerns and efficient caching for both model downloads and processed datasets.
|
||||
298
docs/usage_guide.md
Normal file
298
docs/usage_guide.md
Normal file
@ -0,0 +1,298 @@
|
||||
# BGE Fine-tuning Usage Guide
|
||||
|
||||
## Overview
|
||||
|
||||
This guide provides detailed instructions on fine-tuning BGE (BAAI General Embedding) models using our enhanced training framework. The framework supports both BGE-M3 (embedding) and BGE-reranker (cross-encoder) models with state-of-the-art training techniques.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Installation](#installation)
|
||||
2. [Quick Start](#quick-start)
|
||||
3. [Data Preparation](#data-preparation)
|
||||
4. [Training Scripts](#training-scripts)
|
||||
5. [Advanced Usage](#advanced-usage)
|
||||
6. [Troubleshooting](#troubleshooting)
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/yourusername/bge-finetune.git
|
||||
cd bge-finetune
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Optional: Install Ascend NPU support
|
||||
# pip install torch-npu # If using Huawei Ascend NPUs
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Prepare Your Data
|
||||
|
||||
The framework supports multiple data formats:
|
||||
- **JSONL** (recommended): One JSON object per line
|
||||
- **CSV/TSV**: Tabular data with headers
|
||||
- **JSON**: Single JSON file with array of samples
|
||||
|
||||
Example JSONL format for embedding model:
|
||||
```json
|
||||
{"query": "What is machine learning?", "pos": ["ML is a subset of AI..."], "neg": ["The weather today..."]}
|
||||
{"query": "How to cook pasta?", "pos": ["Boil water and add pasta..."], "neg": ["Machine learning uses..."]}
|
||||
```
|
||||
|
||||
### 2. Fine-tune BGE-M3 (Embedding Model)
|
||||
|
||||
```bash
|
||||
python scripts/train_m3.py \
|
||||
--model_name_or_path BAAI/bge-m3 \
|
||||
--train_data data/train.jsonl \
|
||||
--output_dir ./output/bge-m3-finetuned \
|
||||
--num_train_epochs 3 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--learning_rate 1e-5
|
||||
```
|
||||
|
||||
### 3. Fine-tune BGE-Reranker
|
||||
|
||||
```bash
|
||||
python scripts/train_reranker.py \
|
||||
--model_name_or_path BAAI/bge-reranker-base \
|
||||
--train_data data/train_pairs.jsonl \
|
||||
--output_dir ./output/bge-reranker-finetuned \
|
||||
--num_train_epochs 3 \
|
||||
--per_device_train_batch_size 32 \
|
||||
--learning_rate 2e-5
|
||||
```
|
||||
|
||||
## Data Preparation
|
||||
|
||||
### Data Formats
|
||||
|
||||
#### 1. Embedding Model (Triplets Format)
|
||||
|
||||
```json
|
||||
{
|
||||
"query": "query text",
|
||||
"pos": ["positive passage 1", "positive passage 2"],
|
||||
"neg": ["negative passage 1", "negative passage 2"],
|
||||
"pos_scores": [1.0, 0.9], // Optional: relevance scores
|
||||
"neg_scores": [0.1, 0.2] // Optional: relevance scores
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. Reranker Model (Pairs Format)
|
||||
|
||||
```json
|
||||
{
|
||||
"query": "query text",
|
||||
"passage": "passage text",
|
||||
"label": 1, // 1 for relevant, 0 for not relevant
|
||||
"score": 0.95 // Optional: relevance score
|
||||
}
|
||||
```
|
||||
|
||||
### Performance Optimization
|
||||
|
||||
For small datasets (< 100k samples), enable in-memory caching:
|
||||
|
||||
```python
|
||||
# In your custom script
|
||||
dataset = BGEM3Dataset(
|
||||
data_path="train.jsonl",
|
||||
tokenizer=tokenizer,
|
||||
cache_in_memory=True, # Pre-tokenize and cache in memory
|
||||
max_cache_size=100000 # Maximum samples to cache
|
||||
)
|
||||
```
|
||||
|
||||
## Training Scripts
|
||||
|
||||
### BGE-M3 Training Options
|
||||
|
||||
```bash
|
||||
python scripts/train_m3.py \
|
||||
--model_name_or_path BAAI/bge-m3 \
|
||||
--train_data data/train.jsonl \
|
||||
--eval_data data/eval.jsonl \
|
||||
--output_dir ./output/bge-m3-finetuned \
|
||||
--num_train_epochs 3 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--per_device_eval_batch_size 32 \
|
||||
--learning_rate 1e-5 \
|
||||
--warmup_ratio 0.1 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--fp16 \
|
||||
--train_group_size 8 \
|
||||
--use_hard_negatives \
|
||||
--temperature 0.02 \
|
||||
--save_steps 500 \
|
||||
--eval_steps 500 \
|
||||
--logging_steps 100
|
||||
```
|
||||
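Here `--fp16` enables mixed-precision training, `--train_group_size` sets the number of passages sampled per query, `--use_hard_negatives` turns on hard negative mining, and `--temperature` is the contrastive loss temperature. (Inline `#` comments after a trailing `\` would break the command, so the flags are listed without them.)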
|
||||
### BGE-Reranker Training Options
|
||||
|
||||
```bash
|
||||
python scripts/train_reranker.py \
|
||||
--model_name_or_path BAAI/bge-reranker-base \
|
||||
--train_data data/train_pairs.jsonl \
|
||||
--eval_data data/eval_pairs.jsonl \
|
||||
--output_dir ./output/bge-reranker-finetuned \
|
||||
--num_train_epochs 3 \
|
||||
--per_device_train_batch_size 32 \
|
||||
--learning_rate 2e-5 \
|
||||
--max_length 512 \
|
||||
--train_group_size 16 \
|
||||
--gradient_checkpointing \
|
||||
--logging_steps 50
|
||||
```
|
||||
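Here `--train_group_size` is the number of pairs per query and `--gradient_checkpointing` trades extra compute for lower memory usage.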
|
||||
### Joint Training (RocketQAv2 Approach)
|
||||
|
||||
Train retriever and reranker together:
|
||||
|
||||
```bash
|
||||
python scripts/train_joint.py \
|
||||
--retriever_model BAAI/bge-m3 \
|
||||
--reranker_model BAAI/bge-reranker-base \
|
||||
--train_data data/train.jsonl \
|
||||
--output_dir ./output/joint-training \
|
||||
--num_train_epochs 3 \
|
||||
--alternate_steps 100 # Switch between models every N steps
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Configuration Files
|
||||
|
||||
Use TOML configuration for complex setups:
|
||||
|
||||
```toml
|
||||
# config.toml
|
||||
[training]
|
||||
default_batch_size = 16
|
||||
default_num_epochs = 3
|
||||
gradient_accumulation_steps = 2
|
||||
|
||||
[m3]
|
||||
model_name_or_path = "BAAI/bge-m3"
|
||||
train_group_size = 8
|
||||
temperature = 0.02
|
||||
|
||||
[hardware]
|
||||
device_type = "cuda" # or "npu" for Ascend
|
||||
```
|
||||
|
||||
Run with config:
|
||||
```bash
|
||||
python scripts/train_m3.py --config_path config.toml --train_data data/train.jsonl
|
||||
```
|
||||
|
||||
### Multi-GPU Training
|
||||
|
||||
```bash
|
||||
# Single node, multiple GPUs
|
||||
torchrun --nproc_per_node=4 scripts/train_m3.py \
|
||||
--train_data data/train.jsonl \
|
||||
--output_dir ./output/bge-m3-multigpu
|
||||
|
||||
# Multiple nodes
|
||||
torchrun --nnodes=2 --nproc_per_node=4 \
|
||||
--rdzv_id=100 --rdzv_backend=c10d \
|
||||
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
|
||||
scripts/train_m3.py --train_data data/train.jsonl
|
||||
```
|
||||
|
||||
### Custom Data Preprocessing
|
||||
|
||||
```python
|
||||
from data.preprocessing import DataPreprocessor
|
||||
|
||||
# Preprocess custom format data
|
||||
preprocessor = DataPreprocessor(
|
||||
tokenizer=tokenizer,
|
||||
max_query_length=64,
|
||||
max_passage_length=512
|
||||
)
|
||||
|
||||
# Convert and clean data
|
||||
output_path, val_path = preprocessor.preprocess_file(
|
||||
input_path="raw_data.csv",
|
||||
output_path="processed_data.jsonl",
|
||||
file_format="csv",
|
||||
validation_split=0.1
|
||||
)
|
||||
```
|
||||
|
||||
### Model Evaluation
|
||||
|
||||
```bash
|
||||
python scripts/evaluate.py \
|
||||
--retriever_model ./output/bge-m3-finetuned \
|
||||
--reranker_model ./output/bge-reranker-finetuned \
|
||||
--eval_data data/test.jsonl \
|
||||
--output_dir ./evaluation_results \
|
||||
--metrics ndcg mrr recall \
|
||||
--top_k 10 20 50
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Out of Memory**
|
||||
- Reduce batch size
|
||||
- Enable gradient checkpointing: `--gradient_checkpointing`
|
||||
- Use fp16 training: `--fp16`
|
||||
- Enable in-memory caching for small datasets
|
||||
|
||||
2. **Slow Training**
|
||||
- Increase number of data loader workers: `--dataloader_num_workers 8`
|
||||
- Enable in-memory caching for datasets < 100k samples
|
||||
- Use SSD for data storage
|
||||
|
||||
3. **Poor Performance**
|
||||
- Check data quality and format
|
||||
- Adjust learning rate and warmup
|
||||
- Use hard negative mining: `--use_hard_negatives`
|
||||
- Increase training epochs
|
||||
|
||||
### Debugging Tips
|
||||
|
||||
```bash
|
||||
# Enable debug logging
|
||||
export LOG_LEVEL=DEBUG
|
||||
python scripts/train_m3.py ...
|
||||
|
||||
# Profile training
|
||||
python scripts/benchmark.py \
|
||||
--model_type retriever \
|
||||
--model_path ./output/bge-m3-finetuned \
|
||||
--data_path data/test.jsonl
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Data Quality**
|
||||
- Ensure balanced positive/negative samples
|
||||
- Use relevance scores when available
|
||||
- Clean and normalize text data
|
||||
|
||||
2. **Training Strategy**
|
||||
- Start with small learning rate (1e-5 to 2e-5)
|
||||
- Use warmup (10% of steps)
|
||||
- Monitor evaluation metrics
|
||||
|
||||
3. **Resource Optimization**
|
||||
- Use gradient accumulation for large batches
|
||||
- Enable mixed precision training
|
||||
- Consider in-memory caching for small datasets
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Data Format Examples](data_formats.md)
|
||||
- [Model Architecture Details](../models/README.md)
|
||||
- [Evaluation Metrics Guide](../evaluation/README.md)
|
||||
- [API Reference](api_reference.md)
|
||||
432
docs/validation_guide.md
Normal file
432
docs/validation_guide.md
Normal file
@ -0,0 +1,432 @@
|
||||
# BGE Fine-tuning Validation Guide
|
||||
|
||||
This guide provides comprehensive instructions for validating your fine-tuned BGE models to ensure they actually perform better than the baseline models.
|
||||
|
||||
## 🎯 Overview
|
||||
|
||||
After fine-tuning BGE models, it's crucial to validate that your models have actually improved. This guide covers multiple validation approaches:
|
||||
|
||||
1. **Quick Validation** - Fast sanity checks
|
||||
2. **Comprehensive Validation** - Detailed performance analysis
|
||||
3. **Comparison Benchmarks** - Head-to-head baseline comparisons
|
||||
4. **Pipeline Validation** - End-to-end retrieval + reranking tests
|
||||
5. **Statistical Analysis** - Significance testing and detailed metrics
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### Option 1: Run Full Validation Suite (Recommended)
|
||||
|
||||
The easiest way to validate your models is using the validation suite:
|
||||
|
||||
```bash
|
||||
# Auto-discover trained models and run complete validation
|
||||
python scripts/run_validation_suite.py --auto-discover
|
||||
|
||||
# Or specify models explicitly
|
||||
python scripts/run_validation_suite.py \
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \
|
||||
--reranker_model ./output/bge-reranker/final_model
|
||||
```
|
||||
|
||||
This runs all validation tests and generates a comprehensive report.
|
||||
|
||||
### Option 2: Quick Validation Only
|
||||
|
||||
For a fast check if your models are working correctly:
|
||||
|
||||
```bash
|
||||
python scripts/quick_validation.py \
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \
|
||||
--reranker_model ./output/bge-reranker/final_model \
|
||||
--test_pipeline
|
||||
```
|
||||
|
||||
## 📊 Validation Tools
|
||||
|
||||
### 1. Quick Validation (`quick_validation.py`)
|
||||
|
||||
**Purpose**: Fast sanity checks to verify models are working and show basic improvements.
|
||||
|
||||
**Features**:
|
||||
- Tests model loading and basic functionality
|
||||
- Uses predefined test queries in multiple languages
|
||||
- Compares relevance scoring between baseline and fine-tuned models
|
||||
- Tests pipeline performance (retrieval + reranking)
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
# Test both models
|
||||
python scripts/quick_validation.py \
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \
|
||||
--reranker_model ./output/bge-reranker/final_model \
|
||||
--test_pipeline
|
||||
|
||||
# Test only retriever
|
||||
python scripts/quick_validation.py \
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model
|
||||
|
||||
# Save results to file
|
||||
python scripts/quick_validation.py \
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \
|
||||
--output_file ./results.json
|
||||
```
|
||||
|
||||
**Output**: Console summary + optional JSON results file
|
||||
|
||||
### 2. Comprehensive Validation (`comprehensive_validation.py`)
|
||||
|
||||
**Purpose**: Detailed validation across multiple datasets with statistical analysis.
|
||||
|
||||
**Features**:
|
||||
- Tests on all available datasets automatically
|
||||
- Computes comprehensive metrics (Recall@K, Precision@K, MAP, MRR, NDCG)
|
||||
- Statistical significance testing
|
||||
- Performance degradation detection
|
||||
- Generates HTML reports with visualizations
|
||||
- Cross-dataset performance analysis
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
# Comprehensive validation with auto-discovered datasets
|
||||
python scripts/comprehensive_validation.py \
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model \
|
||||
--reranker_finetuned ./output/bge-reranker/final_model
|
||||
|
||||
# Use specific datasets
|
||||
python scripts/comprehensive_validation.py \
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model \
|
||||
--test_datasets data/datasets/examples/embedding_data.jsonl \
|
||||
data/datasets/Reranker_AFQMC/dev.json \
|
||||
--output_dir ./validation_results
|
||||
|
||||
# Custom evaluation settings
|
||||
python scripts/comprehensive_validation.py \
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model \
|
||||
--batch_size 16 \
|
||||
--k_values 1 3 5 10 20 \
|
||||
--retrieval_top_k 100 \
|
||||
--rerank_top_k 10
|
||||
```
|
||||
|
||||
**Output**:
|
||||
- JSON results file
|
||||
- HTML report with detailed analysis
|
||||
- Performance comparison tables
|
||||
- Statistical significance analysis
|
||||
|
||||
### 3. Model Comparison Scripts
|
||||
|
||||
**Purpose**: Head-to-head performance comparisons with baseline models.
|
||||
|
||||
#### Retriever Comparison (`compare_retriever.py`)
|
||||
|
||||
```bash
|
||||
python scripts/compare_retriever.py \
|
||||
--finetuned_model_path ./output/bge-m3-enhanced/final_model \
|
||||
--baseline_model_path BAAI/bge-m3 \
|
||||
--data_path data/datasets/examples/embedding_data.jsonl \
|
||||
--batch_size 16 \
|
||||
--max_samples 1000 \
|
||||
--output ./retriever_comparison.txt
|
||||
```
|
||||
|
||||
#### Reranker Comparison (`compare_reranker.py`)
|
||||
|
||||
```bash
|
||||
python scripts/compare_reranker.py \
|
||||
--finetuned_model_path ./output/bge-reranker/final_model \
|
||||
--baseline_model_path BAAI/bge-reranker-base \
|
||||
--data_path data/datasets/examples/reranker_data.jsonl \
|
||||
--batch_size 16 \
|
||||
--max_samples 1000 \
|
||||
--output ./reranker_comparison.txt
|
||||
```
|
||||
|
||||
**Output**: Tab-separated comparison tables with metrics and performance deltas.
|
||||
|
||||
### 4. Validation Suite Runner (`run_validation_suite.py`)
|
||||
|
||||
**Purpose**: Orchestrates all validation tests in a systematic way.
|
||||
|
||||
**Features**:
|
||||
- Auto-discovers trained models
|
||||
- Runs multiple validation approaches
|
||||
- Generates comprehensive reports
|
||||
- Provides final verdict on model quality
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
# Full auto-discovery and validation
|
||||
python scripts/run_validation_suite.py --auto-discover
|
||||
|
||||
# Specify models and validation types
|
||||
python scripts/run_validation_suite.py \
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \
|
||||
--reranker_model ./output/bge-reranker/final_model \
|
||||
--validation_types quick comprehensive comparison
|
||||
|
||||
# Custom settings
|
||||
python scripts/run_validation_suite.py \
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \
|
||||
--batch_size 32 \
|
||||
--max_samples 500 \
|
||||
--output_dir ./my_validation_results
|
||||
```
|
||||
|
||||
**Output**:
|
||||
- Multiple result files from each validation type
|
||||
- HTML summary report
|
||||
- Final verdict and recommendations
|
||||
|
||||
### 5. Validation Utilities (`validation_utils.py`)
|
||||
|
||||
**Purpose**: Helper utilities for validation preparation and analysis.
|
||||
|
||||
**Features**:
|
||||
- Discover trained models in workspace
|
||||
- Analyze available test datasets
|
||||
- Create sample test data
|
||||
- Analyze validation results
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
# Discover trained models
|
||||
python scripts/validation_utils.py --discover-models
|
||||
|
||||
# Analyze available test datasets
|
||||
python scripts/validation_utils.py --analyze-datasets
|
||||
|
||||
# Create sample test data
|
||||
python scripts/validation_utils.py --create-sample-data --output-dir ./test_data
|
||||
|
||||
# Analyze validation results
|
||||
python scripts/validation_utils.py --analyze-results ./validation_results
|
||||
```
|
||||
|
||||
## 📁 Understanding Results
|
||||
|
||||
### Result Files
|
||||
|
||||
After running validation, you'll get several result files:
|
||||
|
||||
1. **`validation_suite_report.html`** - Main HTML report with visual summary
|
||||
2. **`validation_results.json`** - Complete detailed results in JSON format
|
||||
3. **`quick_validation_results.json`** - Quick validation results
|
||||
4. **`*_comparison_*.txt`** - Model comparison tables
|
||||
5. **Comprehensive validation folder** with detailed metrics
|
||||
|
||||
### Interpreting Results
|
||||
|
||||
#### Overall Verdict
|
||||
|
||||
- **🌟 EXCELLENT** - Significant improvements across most metrics
|
||||
- **✅ GOOD** - Clear improvements with minor degradations
|
||||
- **👌 FAIR** - Mixed results, some improvements
|
||||
- **⚠️ PARTIAL** - Some validation tests failed
|
||||
- **❌ POOR** - More degradations than improvements
|
||||
|
||||
#### Key Metrics to Watch
|
||||
|
||||
**For Retrievers**:
|
||||
- **Recall@K** - How many relevant documents are in top-K results
|
||||
- **MAP (Mean Average Precision)** - Overall ranking quality
|
||||
- **MRR (Mean Reciprocal Rank)** - Position of first relevant result
|
||||
- **NDCG@K** - Normalized ranking quality with graded relevance
|
||||
|
||||
**For Rerankers**:
|
||||
- **Accuracy** - Correct classification of relevant/irrelevant pairs
|
||||
- **MRR@K** - Mean reciprocal rank in top-K results
|
||||
- **NDCG@K** - Ranking quality with position discounting
|
||||
|
||||
**For Pipeline**:
|
||||
- **End-to-End Recall@K** - Overall system performance
|
||||
- **Pipeline Accuracy** - Correct top results after retrieval + reranking
|
||||
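For reference, Recall@K and MRR can be computed as in the minimal sketch below; this is illustrative only and not the project's `evaluation/metrics.py` implementation:

```python
def recall_at_k(ranked_ids: list[str], relevant_ids: set[str], k: int) -> float:
    """Fraction of the relevant documents that appear in the top-k results."""
    if not relevant_ids:
        return 0.0
    hits = sum(1 for doc_id in ranked_ids[:k] if doc_id in relevant_ids)
    return hits / len(relevant_ids)

def mrr(ranked_ids: list[str], relevant_ids: set[str]) -> float:
    """Reciprocal rank of the first relevant document (0.0 if none is retrieved)."""
    for rank, doc_id in enumerate(ranked_ids, start=1):
        if doc_id in relevant_ids:
            return 1.0 / rank
    return 0.0
```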
|
||||
#### Statistical Significance
|
||||
|
||||
The comprehensive validation includes statistical tests to determine if improvements are significant:
|
||||
|
||||
- **Improvement %** - Percentage improvement over baseline
|
||||
- **Statistical Significance** - Whether improvements are statistically meaningful
|
||||
- **Effect Size** - Magnitude of the improvement
|
||||
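One way to check significance is a paired test over per-query scores for the baseline and fine-tuned models. The sketch below uses `scipy.stats.ttest_rel` with hypothetical numbers; the validation scripts may use a different test:

```python
from scipy.stats import ttest_rel

# Hypothetical per-query NDCG@10 scores, in the same query order for both models
baseline  = [0.42, 0.55, 0.31, 0.48, 0.60]
finetuned = [0.47, 0.58, 0.36, 0.50, 0.66]

t_stat, p_value = ttest_rel(finetuned, baseline)
improvement = (sum(finetuned) / sum(baseline) - 1) * 100
print(f"Improvement: {improvement:.1f}%, paired t-test p-value: {p_value:.4f}")
```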
|
||||
## 🔧 Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
#### 1. Models Not Found
|
||||
```
|
||||
❌ Retriever model not found: ./output/bge-m3-enhanced/final_model
|
||||
```
|
||||
**Solution**:
|
||||
- Check that training completed successfully
|
||||
- Verify the model path exists
|
||||
- Use `--discover-models` to find available models
|
||||
|
||||
#### 2. No Test Datasets Found
|
||||
```
|
||||
❌ No test datasets found!
|
||||
```
|
||||
**Solution**:
|
||||
- Place test data in `data/datasets/` directory
|
||||
- Use `--create-sample-data` to generate sample datasets
|
||||
- Specify datasets explicitly with `--test_datasets`
|
||||
|
||||
#### 3. Validation Failures
|
||||
```
|
||||
❌ Comprehensive validation failed
|
||||
```
|
||||
**Solution**:
|
||||
- Check model compatibility (tokenizer issues)
|
||||
- Reduce batch size if running out of memory
|
||||
- Verify dataset format is correct
|
||||
- Check logs for detailed error messages
|
||||
|
||||
#### 4. Poor Performance Results
|
||||
```
|
||||
⚠️ More degradations than improvements
|
||||
```
|
||||
**Analysis Steps**:
|
||||
1. Check training data quality and format
|
||||
2. Verify hyperparameters (learning rate might be too high)
|
||||
3. Ensure sufficient training data and epochs
|
||||
4. Check for data leakage or format mismatches
|
||||
5. Consider different base models
|
||||
|
||||
### Validation Best Practices
|
||||
|
||||
#### 1. Before Running Validation
|
||||
- ✅ Ensure training completed without errors
|
||||
- ✅ Verify model files exist and are complete
|
||||
- ✅ Check that test datasets are in correct format
|
||||
- ✅ Have baseline models available for comparison
|
||||
|
||||
#### 2. During Validation
|
||||
- 📊 Start with quick validation for fast feedback
|
||||
- 🔬 Run comprehensive validation on representative datasets
|
||||
- 📈 Use multiple metrics to get complete picture
|
||||
- ⏱️ Be patient - comprehensive validation takes time
|
||||
|
||||
#### 3. Analyzing Results
|
||||
- 🎯 Focus on metrics most relevant to your use case
|
||||
- 📊 Look at trends across multiple datasets
|
||||
- ⚖️ Consider statistical significance, not just raw improvements
|
||||
- 🔍 Investigate datasets where performance degraded
|
||||
|
||||
## 🎯 Validation Checklist
|
||||
|
||||
Use this checklist to ensure thorough validation:
|
||||
|
||||
### Pre-Validation Setup
|
||||
- [ ] Models trained and saved successfully
|
||||
- [ ] Test datasets available and formatted correctly
|
||||
- [ ] Baseline models accessible (BAAI/bge-m3, BAAI/bge-reranker-base)
|
||||
- [ ] Sufficient compute resources for validation
|
||||
|
||||
### Quick Validation (5-10 minutes)
|
||||
- [ ] Run quick validation script
|
||||
- [ ] Verify models load correctly
|
||||
- [ ] Check basic performance improvements
|
||||
- [ ] Test pipeline functionality (if both models available)
|
||||
|
||||
### Comprehensive Validation (30-60 minutes)
|
||||
- [ ] Run comprehensive validation script
|
||||
- [ ] Test on multiple datasets
|
||||
- [ ] Analyze statistical significance
|
||||
- [ ] Review HTML report for detailed insights
|
||||
|
||||
### Comparison Benchmarks (15-30 minutes)
|
||||
- [ ] Run retriever comparison (if applicable)
|
||||
- [ ] Run reranker comparison (if applicable)
|
||||
- [ ] Analyze performance deltas
|
||||
- [ ] Check throughput and memory usage
|
||||
|
||||
### Final Analysis
|
||||
- [ ] Review overall verdict and recommendations
|
||||
- [ ] Identify best and worst performing datasets
|
||||
- [ ] Check for any critical issues or degradations
|
||||
- [ ] Document findings and next steps
|
||||
|
||||
## 🚀 Advanced Usage
|
||||
|
||||
### Custom Test Datasets
|
||||
|
||||
Create your own test datasets for domain-specific validation:
|
||||
|
||||
```jsonl
|
||||
// For retriever testing (embedding_format)
|
||||
{"query": "Your query", "pos": ["relevant doc 1", "relevant doc 2"], "neg": ["irrelevant doc 1"]}
|
||||
|
||||
// For reranker testing (pairs format)
|
||||
{"query": "Your query", "passage": "Document text", "label": 1}
|
||||
{"query": "Your query", "passage": "Irrelevant text", "label": 0}
|
||||
```
|
||||
|
||||
### Cross-Domain Validation
|
||||
|
||||
Test model generalization across different domains:
|
||||
|
||||
```bash
|
||||
python scripts/comprehensive_validation.py \
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model \
|
||||
--test_datasets data/medical_qa.jsonl data/legal_docs.jsonl data/tech_support.jsonl
|
||||
```
|
||||
|
||||
### Performance Profiling
|
||||
|
||||
Monitor resource usage during validation:
|
||||
|
||||
```bash
|
||||
python scripts/comprehensive_validation.py \
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model \
|
||||
--batch_size 8 \
|
||||
--max_samples 100   # limit samples for speed; reduce batch size if memory is tight
|
||||
```
|
||||
|
||||
### A/B Testing Multiple Models
|
||||
|
||||
Compare different fine-tuning approaches:
|
||||
|
||||
```bash
|
||||
# Test Model A
|
||||
python scripts/comprehensive_validation.py \
|
||||
--retriever_finetuned ./output/model_A/final_model \
|
||||
--output_dir ./validation_A
|
||||
|
||||
# Test Model B
|
||||
python scripts/comprehensive_validation.py \
|
||||
--retriever_finetuned ./output/model_B/final_model \
|
||||
--output_dir ./validation_B
|
||||
|
||||
# Compare results
|
||||
python scripts/validation_utils.py --analyze-results ./validation_A
|
||||
python scripts/validation_utils.py --analyze-results ./validation_B
|
||||
```
|
||||
|
||||
## 📚 Additional Resources
|
||||
|
||||
- **BGE Model Documentation**: [BAAI/bge](https://github.com/FlagOpen/FlagEmbedding)
|
||||
- **Evaluation Metrics Guide**: See `evaluation/metrics.py` for detailed metric implementations
|
||||
- **Dataset Format Guide**: See `data/datasets/examples/format_usage_examples.py`
|
||||
- **Training Guide**: See main README for training instructions
|
||||
|
||||
## 💡 Tips for Better Validation
|
||||
|
||||
1. **Use Multiple Datasets**: Don't rely on a single test set
|
||||
2. **Check Statistical Significance**: Small improvements might not be meaningful
|
||||
3. **Monitor Resource Usage**: Ensure models are efficient enough for production
|
||||
4. **Validate Domain Generalization**: Test on data similar to production use case
|
||||
5. **Document Results**: Keep validation reports for future reference
|
||||
6. **Iterate Based on Results**: Use validation insights to improve training
|
||||
|
||||
## 🤝 Getting Help
|
||||
|
||||
If you encounter issues with validation:
|
||||
|
||||
1. Check the troubleshooting section above
|
||||
2. Review the logs for detailed error messages
|
||||
3. Use `--discover-models` and `--analyze-datasets` to debug setup issues
|
||||
4. Start with quick validation to isolate problems
|
||||
5. Check model and dataset formats are correct
|
||||
|
||||
Remember: Good validation is key to successful fine-tuning. Take time to thoroughly test your models before deploying them!
|
||||
20
evaluation/__init__.py
Normal file
20
evaluation/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
"""
|
||||
Evaluation modules and metrics
|
||||
"""
|
||||
from .evaluator import RetrievalEvaluator, InteractiveEvaluator
|
||||
from .metrics import (
|
||||
compute_retrieval_metrics,
|
||||
compute_reranker_metrics,
|
||||
evaluate_cross_lingual_retrieval,
|
||||
MetricsTracker
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'RetrievalEvaluator',
|
||||
'InteractiveEvaluator',
|
||||
'compute_retrieval_metrics',
|
||||
'compute_reranker_metrics',
|
||||
'evaluate_cross_lingual_retrieval',
|
||||
'MetricsTracker'
|
||||
]
|
||||
511
evaluation/evaluator.py
Normal file
511
evaluation/evaluator.py
Normal file
@ -0,0 +1,511 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import List, Dict, Optional, Tuple, Union
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
from torch.utils.data import DataLoader
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from .metrics import (
|
||||
compute_retrieval_metrics,
|
||||
compute_reranker_metrics,
|
||||
evaluate_cross_lingual_retrieval,
|
||||
MetricsTracker
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetrievalEvaluator:
|
||||
"""
|
||||
Comprehensive evaluator for retrieval models
|
||||
|
||||
Notes:
|
||||
- All parameters (model, tokenizer, device, metrics, k_values, etc.) should be passed from config-driven scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[BGEM3Model, BGERerankerModel],
|
||||
tokenizer,
|
||||
device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
|
||||
metrics: List[str] = ['recall', 'precision', 'map', 'mrr', 'ndcg'],
|
||||
k_values: List[int] = [1, 5, 10, 20, 100]
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.metrics = metrics
|
||||
self.k_values = k_values
|
||||
|
||||
# Move model to device
|
||||
self.model.to(device)
|
||||
self.model.eval()
|
||||
|
||||
# Metrics tracker
|
||||
self.tracker = MetricsTracker()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
dataloader: DataLoader,
|
||||
desc: str = "Evaluating",
|
||||
return_embeddings: bool = False
|
||||
) -> Dict[str, Union[float, list, np.ndarray]]:
|
||||
"""
|
||||
Evaluate model on dataset
|
||||
|
||||
Args:
|
||||
dataloader: DataLoader with evaluation data
|
||||
desc: Description for progress bar
|
||||
return_embeddings: Whether to return computed embeddings
|
||||
|
||||
Returns:
|
||||
Dictionary of evaluation metrics
|
||||
"""
|
||||
self.model.eval()
|
||||
|
||||
all_query_embeddings = []
|
||||
all_passage_embeddings = []
|
||||
all_labels = []
|
||||
all_queries = []
|
||||
all_passages = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc=desc):
|
||||
# Move batch to device
|
||||
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in batch.items()}
|
||||
|
||||
# Get embeddings based on model type
|
||||
if isinstance(self.model, BGEM3Model):
|
||||
# Bi-encoder evaluation
|
||||
query_embeddings, passage_embeddings = self._evaluate_biencoder_batch(batch)
|
||||
|
||||
all_query_embeddings.append(query_embeddings.cpu())
|
||||
all_passage_embeddings.append(passage_embeddings.cpu())
|
||||
|
||||
else:
|
||||
# Cross-encoder evaluation
|
||||
scores = self._evaluate_crossencoder_batch(batch)
|
||||
# For cross-encoder, we store scores instead of embeddings
|
||||
all_query_embeddings.append(scores.cpu())
|
||||
|
||||
# Collect labels
|
||||
if 'labels' in batch:
|
||||
all_labels.append(batch['labels'].cpu())
|
||||
|
||||
# Collect text if available
|
||||
if 'query' in batch:
|
||||
all_queries.extend(batch['query'])
|
||||
if 'passages' in batch:
|
||||
all_passages.extend(batch['passages'])
|
||||
|
||||
# Concatenate results
|
||||
if isinstance(self.model, BGEM3Model):
|
||||
all_query_embeddings = torch.cat(all_query_embeddings, dim=0).numpy()
|
||||
all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0).numpy()
|
||||
all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None
|
||||
|
||||
# Compute retrieval metrics
|
||||
if all_labels is not None:
|
||||
metrics = compute_retrieval_metrics(
|
||||
all_query_embeddings,
|
||||
all_passage_embeddings,
|
||||
all_labels,
|
||||
k_values=self.k_values
|
||||
)
|
||||
else:
|
||||
metrics = {}
|
||||
|
||||
else:
|
||||
# Cross-encoder metrics
|
||||
all_scores = torch.cat(all_query_embeddings, dim=0).numpy()
|
||||
all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None
|
||||
|
||||
if all_labels is not None:
|
||||
# Determine if we have groups (listwise) or pairs
|
||||
if hasattr(dataloader.dataset, 'train_group_size'):
|
||||
groups = [getattr(dataloader.dataset, 'train_group_size')] * len(dataloader.dataset) # type: ignore[arg-type]
|
||||
else:
|
||||
groups = None
|
||||
|
||||
metrics = compute_reranker_metrics(
|
||||
all_scores,
|
||||
all_labels,
|
||||
groups=groups,
|
||||
k_values=self.k_values
|
||||
)
|
||||
else:
|
||||
metrics = {}
|
||||
|
||||
# Update tracker
|
||||
self.tracker.update(metrics, step=0)
|
||||
|
||||
# Return embeddings if requested
|
||||
if return_embeddings and isinstance(self.model, BGEM3Model):
|
||||
metrics['query_embeddings'] = all_query_embeddings # type: ignore[assignment]
|
||||
metrics['passage_embeddings'] = all_passage_embeddings # type: ignore[assignment]
|
||||
metrics['queries'] = all_queries # type: ignore[assignment]
|
||||
metrics['passages'] = all_passages # type: ignore[assignment]
|
||||
|
||||
return metrics # type: ignore[return-value]
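# Illustrative sketch (not in the original file): driving this evaluator from an
# evaluation script; `evaluator` and `eval_dataloader` are assumed names.
#
#   metrics = evaluator.evaluate(eval_dataloader, desc="Validation")
#   print(metrics.get('recall@10'), metrics.get('ndcg@10'), metrics.get('mrr'))
#   # With return_embeddings=True (bi-encoder only) the returned dict also carries
#   # 'query_embeddings', 'passage_embeddings', 'queries' and 'passages'.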
|
||||
|
||||
def _evaluate_biencoder_batch(self, batch: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Evaluate a batch for bi-encoder model"""
|
||||
# Get query embeddings
|
||||
query_outputs = self.model(
|
||||
input_ids=batch['query_input_ids'],
|
||||
attention_mask=batch['query_attention_mask'],
|
||||
return_dense=True
|
||||
)
|
||||
query_embeddings = query_outputs['dense']
|
||||
|
||||
# Get passage embeddings
|
||||
# Handle multiple passages per query
|
||||
if batch['passage_input_ids'].dim() == 3:
|
||||
# Flatten batch dimension
|
||||
batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
|
||||
passage_input_ids = batch['passage_input_ids'].view(-1, seq_len)
|
||||
passage_attention_mask = batch['passage_attention_mask'].view(-1, seq_len)
|
||||
else:
|
||||
passage_input_ids = batch['passage_input_ids']
|
||||
passage_attention_mask = batch['passage_attention_mask']
|
||||
|
||||
passage_outputs = self.model(
|
||||
input_ids=passage_input_ids,
|
||||
attention_mask=passage_attention_mask,
|
||||
return_dense=True
|
||||
)
|
||||
passage_embeddings = passage_outputs['dense']
|
||||
|
||||
return query_embeddings, passage_embeddings
|
||||
|
||||
def _evaluate_crossencoder_batch(self, batch: Dict) -> torch.Tensor:
|
||||
"""Evaluate a batch for cross-encoder model"""
|
||||
outputs = self.model(
|
||||
input_ids=batch['input_ids'],
|
||||
attention_mask=batch['attention_mask']
|
||||
)
|
||||
return outputs['logits']
|
||||
|
||||
def evaluate_retrieval_pipeline(
|
||||
self,
|
||||
queries: List[str],
|
||||
corpus: List[str],
|
||||
relevant_docs: Dict[str, List[int]],
|
||||
batch_size: int = 32
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Evaluate full retrieval pipeline
|
||||
|
||||
Args:
|
||||
queries: List of queries
|
||||
corpus: List of all documents
|
||||
relevant_docs: Dict mapping query to relevant doc indices
|
||||
batch_size: Batch size for encoding
|
||||
|
||||
Returns:
|
||||
Evaluation metrics
|
||||
"""
|
||||
logger.info(f"Evaluating retrieval pipeline with {len(queries)} queries over {len(corpus)} documents")
|
||||
|
||||
# Encode all documents
|
||||
logger.info("Encoding corpus...")
|
||||
corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")
|
||||
|
||||
# Encode queries and evaluate
|
||||
all_metrics = defaultdict(list)
|
||||
|
||||
for i in tqdm(range(0, len(queries), batch_size), desc="Evaluating queries"):
|
||||
batch_queries = queries[i:i + batch_size]
|
||||
batch_relevant = [relevant_docs.get(q, []) for q in batch_queries]
|
||||
|
||||
# Encode queries
|
||||
query_embeddings = self._encode_texts(batch_queries, batch_size, show_progress=False)
|
||||
|
||||
# Create labels matrix
|
||||
labels = np.zeros((len(batch_queries), len(corpus)))
|
||||
for j, relevant_indices in enumerate(batch_relevant):
|
||||
for idx in relevant_indices:
|
||||
if idx < len(corpus):
|
||||
labels[j, idx] = 1
|
||||
|
||||
# Compute metrics
|
||||
batch_metrics = compute_retrieval_metrics(
|
||||
query_embeddings,
|
||||
corpus_embeddings,
|
||||
labels,
|
||||
k_values=self.k_values
|
||||
)
|
||||
|
||||
# Accumulate metrics
|
||||
for metric, value in batch_metrics.items():
|
||||
if metric != 'scores':
|
||||
all_metrics[metric].append(value)
|
||||
|
||||
# Average metrics
|
||||
final_metrics = {
|
||||
metric: np.mean(values) for metric, values in all_metrics.items()
|
||||
}
|
||||
|
||||
return final_metrics # type: ignore[return-value]
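# Illustrative sketch (not in the original file): the input layout expected by
# evaluate_retrieval_pipeline; names and texts are invented for the example.
#
#   queries = ["what is bge-m3?", "how does colbert score?"]
#   corpus = ["BGE-M3 is a multi-functional embedder.",
#             "ColBERT uses late interaction.",
#             "Unrelated text."]
#   relevant_docs = {queries[0]: [0], queries[1]: [1]}   # query text -> relevant corpus indices
#   metrics = evaluator.evaluate_retrieval_pipeline(queries, corpus, relevant_docs, batch_size=2)
#   # e.g. metrics['recall@1'] == 1.0 if both queries rank their own document first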
|
||||
|
||||
def _encode_texts(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: int,
|
||||
desc: str = "Encoding",
|
||||
show_progress: bool = True
|
||||
) -> np.ndarray:
|
||||
"""Encode texts to embeddings"""
|
||||
if isinstance(self.model, BGEM3Model):
|
||||
# Use model's encode method
|
||||
embeddings = self.model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
device=self.device
|
||||
)
|
||||
if isinstance(embeddings, torch.Tensor):
|
||||
embeddings = embeddings.cpu().numpy()
|
||||
else:
|
||||
# For cross-encoder, this doesn't make sense
|
||||
raise ValueError("Cannot encode texts with cross-encoder model")
|
||||
|
||||
return embeddings # type: ignore[return-value]
|
||||
|
||||
def evaluate_cross_lingual(
|
||||
self,
|
||||
queries: Dict[str, List[str]],
|
||||
corpus: List[str],
|
||||
relevant_docs: Dict[str, Dict[str, List[int]]],
|
||||
batch_size: int = 32
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
Evaluate cross-lingual retrieval
|
||||
|
||||
Args:
|
||||
queries: Dict mapping language to queries
|
||||
corpus: List of documents (assumed to be in source language)
|
||||
relevant_docs: Nested dict of language -> query -> relevant indices
|
||||
batch_size: Batch size
|
||||
|
||||
Returns:
|
||||
Metrics per language
|
||||
"""
|
||||
# Encode corpus once
|
||||
corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")
|
||||
|
||||
# Evaluate each language
|
||||
results = {}
|
||||
|
||||
for lang, lang_queries in queries.items():
|
||||
logger.info(f"Evaluating {lang} queries...")
|
||||
|
||||
# Encode queries
|
||||
query_embeddings = self._encode_texts(lang_queries, batch_size, desc=f"Encoding {lang} queries")
|
||||
|
||||
# Create labels
|
||||
labels = np.zeros((len(lang_queries), len(corpus)))
|
||||
for i, query in enumerate(lang_queries):
|
||||
if query in relevant_docs.get(lang, {}):
|
||||
for idx in relevant_docs[lang][query]:
|
||||
if idx < len(corpus):
|
||||
labels[i, idx] = 1
|
||||
|
||||
# Compute metrics
|
||||
metrics = compute_retrieval_metrics(
|
||||
query_embeddings,
|
||||
corpus_embeddings,
|
||||
labels,
|
||||
k_values=self.k_values
|
||||
)
|
||||
|
||||
results[lang] = metrics
|
||||
|
||||
# Compute average across languages
|
||||
if len(results) > 1:
|
||||
avg_metrics = {}
|
||||
for metric in results[list(results.keys())[0]]:
|
||||
if metric != 'scores':
|
||||
values = [results[lang][metric] for lang in results]
|
||||
avg_metrics[metric] = np.mean(values)
|
||||
results['average'] = avg_metrics
|
||||
|
||||
return results
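# Illustrative sketch (not in the original file): the nested structures this method
# expects, with invented data; `evaluator` and `corpus` are assumed names.
#
#   queries = {'en': ["capital of france"], 'zh': ["法国的首都"]}
#   relevant_docs = {'en': {"capital of france": [2]}, 'zh': {"法国的首都": [2]}}
#   results = evaluator.evaluate_cross_lingual(queries, corpus, relevant_docs)
#   # per-language metrics plus a cross-language mean:
#   # results['en']['ndcg@10'], results['zh']['ndcg@10'], results['average']['ndcg@10']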
|
||||
|
||||
def save_evaluation_results(
|
||||
self,
|
||||
metrics: Dict[str, float],
|
||||
output_dir: str,
|
||||
prefix: str = "eval"
|
||||
):
|
||||
"""Save evaluation results"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save metrics
|
||||
metrics_file = os.path.join(output_dir, f"{prefix}_metrics.json")
|
||||
with open(metrics_file, 'w') as f:
|
||||
json.dump(metrics, f, indent=2)
|
||||
|
||||
# Save summary
|
||||
summary_file = os.path.join(output_dir, f"{prefix}_summary.txt")
|
||||
with open(summary_file, 'w') as f:
|
||||
f.write(f"Evaluation Results - {prefix}\n")
|
||||
f.write("=" * 50 + "\n\n")
|
||||
|
||||
for metric, value in sorted(metrics.items()):
|
||||
if isinstance(value, float):
|
||||
f.write(f"{metric:.<30} {value:.4f}\n")
|
||||
|
||||
logger.info(f"Saved evaluation results to {output_dir}")
|
||||
|
||||
|
||||
class InteractiveEvaluator:
|
||||
"""
|
||||
Interactive evaluation for qualitative analysis
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever: BGEM3Model,
|
||||
reranker: Optional[BGERerankerModel],
|
||||
tokenizer,
|
||||
corpus: List[str],
|
||||
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
):
|
||||
self.retriever = retriever
|
||||
self.reranker = reranker
|
||||
self.tokenizer = tokenizer
|
||||
self.corpus = corpus
|
||||
self.device = device
|
||||
|
||||
# Pre-encode corpus
|
||||
logger.info("Pre-encoding corpus...")
|
||||
self.corpus_embeddings = self._encode_corpus()
|
||||
logger.info(f"Encoded {len(self.corpus)} documents")
|
||||
|
||||
def _encode_corpus(self) -> np.ndarray:
|
||||
"""Pre-encode all corpus documents"""
|
||||
embeddings = self.retriever.encode(
|
||||
self.corpus,
|
||||
batch_size=32,
|
||||
device=self.device
|
||||
)
|
||||
if isinstance(embeddings, torch.Tensor):
|
||||
embeddings = embeddings.cpu().numpy()
|
||||
return embeddings # type: ignore[return-value]
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
rerank: bool = True
|
||||
) -> List[Tuple[int, float, str]]:
|
||||
"""
|
||||
Search for query in corpus
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
top_k: Number of results to return
|
||||
rerank: Whether to use reranker
|
||||
|
||||
Returns:
|
||||
List of (index, score, text) tuples
|
||||
"""
|
||||
# Encode query
|
||||
query_embedding = self.retriever.encode(
|
||||
[query],
|
||||
device=self.device
|
||||
)
|
||||
if isinstance(query_embedding, torch.Tensor):
|
||||
query_embedding = query_embedding.cpu().numpy()
|
||||
|
||||
# Compute similarities
|
||||
similarities = np.dot(query_embedding, self.corpus_embeddings.T).squeeze() # type: ignore[arg-type]
|
||||
|
||||
# Get top-k
|
||||
top_indices = np.argsort(similarities)[::-1][:top_k * 2 if rerank else top_k]
|
||||
|
||||
results = []
|
||||
for idx in top_indices:
|
||||
results.append((int(idx), float(similarities[idx]), self.corpus[idx]))
|
||||
|
||||
# Rerank if requested
|
||||
if rerank and self.reranker:
|
||||
reranked_indices = self.reranker.rerank(
|
||||
query,
|
||||
[self.corpus[idx] for idx in top_indices],
|
||||
return_scores=False
|
||||
)
|
||||
|
||||
# Reorder results
|
||||
reranked_results = []
|
||||
for i, rerank_idx in enumerate(reranked_indices[:top_k]):
|
||||
original_idx = top_indices[rerank_idx] # type: ignore[index]
|
||||
reranked_results.append((
|
||||
int(original_idx),
|
||||
float(i + 1), # Rank as score
|
||||
self.corpus[original_idx]
|
||||
))
|
||||
results = reranked_results
|
||||
|
||||
return results[:top_k]
|
||||
|
||||
def evaluate_query_interactively(
|
||||
self,
|
||||
query: str,
|
||||
relevant_indices: List[int],
|
||||
top_k: int = 10
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Evaluate a single query interactively
|
||||
|
||||
Args:
|
||||
query: Query to evaluate
|
||||
relevant_indices: Ground truth relevant document indices
|
||||
top_k: Number of results to consider
|
||||
|
||||
Returns:
|
||||
Metrics for this query
|
||||
"""
|
||||
# Get search results
|
||||
results = self.search(query, top_k=top_k, rerank=True)
|
||||
retrieved_indices = [r[0] for r in results]
|
||||
|
||||
# Calculate metrics
|
||||
metrics = {}
|
||||
|
||||
# Precision@k
|
||||
relevant_retrieved = sum(1 for idx in retrieved_indices if idx in relevant_indices)
|
||||
metrics[f'precision@{top_k}'] = relevant_retrieved / top_k
|
||||
|
||||
# Recall@k
|
||||
if relevant_indices:
|
||||
metrics[f'recall@{top_k}'] = relevant_retrieved / len(relevant_indices)
|
||||
else:
|
||||
metrics[f'recall@{top_k}'] = 0.0
|
||||
|
||||
# MRR
|
||||
for i, idx in enumerate(retrieved_indices):
|
||||
if idx in relevant_indices:
|
||||
metrics['mrr'] = 1.0 / (i + 1)
|
||||
break
|
||||
else:
|
||||
metrics['mrr'] = 0.0
|
||||
|
||||
# Print results for inspection
|
||||
print(f"\nQuery: {query}")
|
||||
print(f"Relevant docs: {relevant_indices}")
|
||||
print(f"\nTop {top_k} results:")
|
||||
for i, (idx, score, text) in enumerate(results):
|
||||
relevance = "✓" if idx in relevant_indices else "✗"
|
||||
print(f"{i + 1}. [{relevance}] (idx={idx}, score={score:.3f}) {text[:100]}...")
|
||||
|
||||
print(f"\nMetrics: {metrics}")
|
||||
|
||||
return metrics
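# Illustrative sketch (not in the original file): a minimal interactive session,
# assuming `retriever`, `reranker`, `tokenizer` and `docs` are already loaded.
#
#   explorer = InteractiveEvaluator(retriever, reranker, tokenizer, corpus=docs)
#   hits = explorer.search("multilingual dense retrieval", top_k=5, rerank=True)
#   for idx, score, text in hits:
#       print(idx, score, text[:80])
#   explorer.evaluate_query_interactively("multilingual dense retrieval",
#                                         relevant_indices=[3, 17], top_k=5)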
|
||||
542
evaluation/metrics.py
Normal file
542
evaluation/metrics.py
Normal file
@ -0,0 +1,542 @@
|
||||
import numpy as np
|
||||
from typing import List, Dict, Tuple, Optional, Union
|
||||
import torch
|
||||
from sklearn.metrics import average_precision_score, ndcg_score
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_retrieval_metrics(
|
||||
query_embeddings: np.ndarray,
|
||||
passage_embeddings: np.ndarray,
|
||||
labels: np.ndarray,
|
||||
k_values: List[int] = [1, 5, 10, 20, 100],
|
||||
return_all_scores: bool = False
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Compute comprehensive retrieval metrics
|
||||
|
||||
Args:
|
||||
query_embeddings: Query embeddings [num_queries, embedding_dim]
|
||||
passage_embeddings: Passage embeddings [num_passages, embedding_dim]
|
||||
labels: Binary relevance labels [num_queries, num_passages] or list of relevant indices
|
||||
k_values: List of k values for metrics@k
|
||||
return_all_scores: Whether to return all similarity scores
|
||||
|
||||
Returns:
|
||||
Dictionary of metrics
|
||||
"""
|
||||
metrics = {}
|
||||
|
||||
# Compute similarity scores
|
||||
scores = np.matmul(query_embeddings, passage_embeddings.T)
|
||||
|
||||
# Get number of queries
|
||||
num_queries = query_embeddings.shape[0]
|
||||
|
||||
# Convert labels to binary matrix if needed
|
||||
if labels.ndim == 1:
|
||||
# Assume labels contains relevant passage indices
|
||||
binary_labels = np.zeros((num_queries, passage_embeddings.shape[0]))
|
||||
for i, idx in enumerate(labels):
|
||||
if idx < passage_embeddings.shape[0]:
|
||||
binary_labels[i, idx] = 1
|
||||
labels = binary_labels
|
||||
|
||||
# Compute metrics for each k
|
||||
for k in k_values:
|
||||
metrics[f'recall@{k}'] = compute_recall_at_k(scores, labels, k)
|
||||
metrics[f'precision@{k}'] = compute_precision_at_k(scores, labels, k)
|
||||
metrics[f'map@{k}'] = compute_map_at_k(scores, labels, k)
|
||||
metrics[f'mrr@{k}'] = compute_mrr_at_k(scores, labels, k)
|
||||
metrics[f'ndcg@{k}'] = compute_ndcg_at_k(scores, labels, k)
|
||||
|
||||
# Overall metrics
|
||||
metrics['map'] = compute_mean_average_precision(scores, labels)
|
||||
metrics['mrr'] = compute_mean_reciprocal_rank(scores, labels)
|
||||
|
||||
if return_all_scores:
|
||||
metrics['scores'] = scores
|
||||
|
||||
return metrics
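# Illustrative worked example (not in the original file) for compute_retrieval_metrics,
# using tiny made-up embeddings so the numbers can be checked by hand:
#
#   >>> q = np.array([[1.0, 0.0], [0.0, 1.0]])                      # 2 queries
#   >>> p = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])          # 3 passages
#   >>> y = np.array([[1, 0, 0], [0, 1, 0]], dtype=float)           # one relevant passage per query
#   >>> m = compute_retrieval_metrics(q, p, y, k_values=[1, 2])
#   >>> m['recall@1'], m['precision@2'], m['mrr']
#   (1.0, 0.5, 1.0)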
|
||||
|
||||
|
||||
def compute_retrieval_metrics_comprehensive(
|
||||
query_embeddings: np.ndarray,
|
||||
passage_embeddings: np.ndarray,
|
||||
labels: np.ndarray,
|
||||
k_values: List[int] = [1, 5, 10, 20, 100],
|
||||
return_all_scores: bool = False,
|
||||
compute_additional_metrics: bool = True
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Compute comprehensive retrieval metrics with additional statistics
|
||||
|
||||
Args:
|
||||
query_embeddings: Query embeddings [num_queries, embedding_dim]
|
||||
passage_embeddings: Passage embeddings [num_passages, embedding_dim]
|
||||
labels: Binary relevance labels [num_queries, num_passages] or list of relevant indices
|
||||
k_values: List of k values for metrics@k
|
||||
return_all_scores: Whether to return all similarity scores
|
||||
compute_additional_metrics: Whether to compute additional metrics like success rate
|
||||
|
||||
Returns:
|
||||
Dictionary of comprehensive metrics
|
||||
"""
|
||||
metrics = {}
|
||||
|
||||
# Validate inputs
|
||||
if query_embeddings.shape[0] == 0 or passage_embeddings.shape[0] == 0:
|
||||
raise ValueError("Empty embeddings provided")
|
||||
|
||||
# Normalize embeddings for cosine similarity
|
||||
query_embeddings = query_embeddings / (np.linalg.norm(query_embeddings, axis=1, keepdims=True) + 1e-8)
|
||||
passage_embeddings = passage_embeddings / (np.linalg.norm(passage_embeddings, axis=1, keepdims=True) + 1e-8)
|
||||
|
||||
# Compute similarity scores
|
||||
scores = np.matmul(query_embeddings, passage_embeddings.T)
|
||||
|
||||
# Get number of queries
|
||||
num_queries = query_embeddings.shape[0]
|
||||
|
||||
# Convert labels to binary matrix if needed
|
||||
if labels.ndim == 1:
|
||||
# Assume labels contains relevant passage indices
|
||||
binary_labels = np.zeros((num_queries, passage_embeddings.shape[0]))
|
||||
for i, idx in enumerate(labels):
|
||||
if idx < passage_embeddings.shape[0]:
|
||||
binary_labels[i, idx] = 1
|
||||
labels = binary_labels
|
||||
|
||||
# Compute metrics for each k
|
||||
for k in k_values:
|
||||
metrics[f'recall@{k}'] = compute_recall_at_k(scores, labels, k)
|
||||
metrics[f'precision@{k}'] = compute_precision_at_k(scores, labels, k)
|
||||
metrics[f'map@{k}'] = compute_map_at_k(scores, labels, k)
|
||||
metrics[f'mrr@{k}'] = compute_mrr_at_k(scores, labels, k)
|
||||
metrics[f'ndcg@{k}'] = compute_ndcg_at_k(scores, labels, k)
|
||||
|
||||
if compute_additional_metrics:
|
||||
# Success rate (queries with at least one relevant doc in top-k)
|
||||
success_count = 0
|
||||
for i in range(num_queries):
|
||||
top_k_indices = np.argsort(scores[i])[::-1][:k]
|
||||
if np.any(labels[i][top_k_indices] > 0):
|
||||
success_count += 1
|
||||
metrics[f'success@{k}'] = success_count / num_queries if num_queries > 0 else 0.0
|
||||
|
||||
# Overall metrics
|
||||
metrics['map'] = compute_mean_average_precision(scores, labels)
|
||||
metrics['mrr'] = compute_mean_reciprocal_rank(scores, labels)
|
||||
|
||||
if compute_additional_metrics:
|
||||
# Average number of relevant documents per query
|
||||
metrics['avg_relevant_per_query'] = np.mean(np.sum(labels > 0, axis=1))
|
||||
|
||||
# Coverage: fraction of relevant documents that appear in any top-k list
|
||||
max_k = max(k_values)
|
||||
covered_relevant = set()
|
||||
for i in range(num_queries):
|
||||
top_k_indices = np.argsort(scores[i])[::-1][:max_k]
|
||||
relevant_indices = np.where(labels[i] > 0)[0]
|
||||
for idx in relevant_indices:
|
||||
if idx in top_k_indices:
|
||||
covered_relevant.add((i, idx))
|
||||
|
||||
total_relevant = np.sum(labels > 0)
|
||||
metrics['coverage'] = len(covered_relevant) / total_relevant if total_relevant > 0 else 0.0
|
||||
|
||||
if return_all_scores:
|
||||
metrics['scores'] = scores
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def compute_recall_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
|
||||
"""
|
||||
Compute Recall@K
|
||||
|
||||
Args:
|
||||
scores: Similarity scores [num_queries, num_passages]
|
||||
labels: Binary relevance labels
|
||||
k: Top-k to consider
|
||||
|
||||
Returns:
|
||||
Average recall@k
|
||||
"""
|
||||
num_queries = scores.shape[0]
|
||||
recall_scores = []
|
||||
|
||||
for i in range(num_queries):
|
||||
# Get top-k indices
|
||||
top_k_indices = np.argsort(scores[i])[::-1][:k]
|
||||
|
||||
# Count relevant documents
|
||||
num_relevant = labels[i].sum()
|
||||
if num_relevant == 0:
|
||||
continue
|
||||
|
||||
# Count relevant in top-k
|
||||
num_relevant_in_topk = labels[i][top_k_indices].sum()
|
||||
|
||||
recall_scores.append(num_relevant_in_topk / num_relevant)
|
||||
|
||||
return float(np.mean(recall_scores)) if recall_scores else 0.0
|
||||
|
||||
|
||||
def compute_precision_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
|
||||
"""
|
||||
Compute Precision@K
|
||||
"""
|
||||
num_queries = scores.shape[0]
|
||||
precision_scores = []
|
||||
|
||||
for i in range(num_queries):
|
||||
# Get top-k indices
|
||||
top_k_indices = np.argsort(scores[i])[::-1][:k]
|
||||
|
||||
# Count relevant in top-k
|
||||
num_relevant_in_topk = labels[i][top_k_indices].sum()
|
||||
|
||||
precision_scores.append(num_relevant_in_topk / k)
|
||||
|
||||
return float(np.mean(precision_scores)) if precision_scores else 0.0
|
||||
|
||||
|
||||
def compute_map_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
|
||||
"""
|
||||
Compute Mean Average Precision@K
|
||||
"""
|
||||
num_queries = scores.shape[0]
|
||||
ap_scores = []
|
||||
|
||||
for i in range(num_queries):
|
||||
# Get top-k indices
|
||||
top_k_indices = np.argsort(scores[i])[::-1][:k]
|
||||
|
||||
# Compute average precision
|
||||
num_relevant = 0
|
||||
sum_precision = 0
|
||||
|
||||
for j, idx in enumerate(top_k_indices):
|
||||
if labels[i][idx] == 1:
|
||||
num_relevant += 1
|
||||
sum_precision += num_relevant / (j + 1)
|
||||
|
||||
if labels[i].sum() > 0:
|
||||
ap_scores.append(sum_precision / min(labels[i].sum(), k))
|
||||
|
||||
return float(np.mean(ap_scores)) if ap_scores else 0.0
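# Illustrative worked example (not in the original file): average precision at k for a
# single query whose ranking is [relevant, non-relevant, relevant].
#
#   >>> s = np.array([[0.9, 0.8, 0.7]])
#   >>> y = np.array([[1, 0, 1]], dtype=float)
#   >>> round(compute_map_at_k(s, y, k=3), 4)    # (1/1 + 2/3) / min(2, 3)
#   0.8333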
|
||||
|
||||
|
||||
def compute_mrr_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
|
||||
"""
|
||||
Compute Mean Reciprocal Rank@K with proper error handling
|
||||
|
||||
Args:
|
||||
scores: Similarity scores [num_queries, num_passages]
|
||||
labels: Binary relevance labels
|
||||
k: Top-k to consider
|
||||
|
||||
Returns:
|
||||
MRR@k score
|
||||
"""
|
||||
if scores.shape != labels.shape:
|
||||
raise ValueError(f"Scores shape {scores.shape} != labels shape {labels.shape}")
|
||||
|
||||
num_queries = scores.shape[0]
|
||||
rr_scores = []
|
||||
|
||||
for i in range(num_queries):
|
||||
# Skip queries with no relevant documents
|
||||
if labels[i].sum() == 0:
|
||||
continue
|
||||
|
||||
# Get top-k indices
|
||||
top_k_indices = np.argsort(scores[i])[::-1][:k]
|
||||
|
||||
# Find first relevant document
|
||||
for j, idx in enumerate(top_k_indices):
|
||||
if labels[i][idx] > 0: # Support graded relevance
|
||||
rr_scores.append(1.0 / (j + 1))
|
||||
break
|
||||
else:
|
||||
# No relevant document in top-k
|
||||
rr_scores.append(0.0)
|
||||
|
||||
return float(np.mean(rr_scores)) if rr_scores else 0.0
|
||||
|
||||
|
||||
def compute_ndcg_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
|
||||
"""
|
||||
Compute Normalized Discounted Cumulative Gain@K with robust error handling
|
||||
|
||||
Args:
|
||||
scores: Similarity scores [num_queries, num_passages]
|
||||
labels: Binary or graded relevance labels
|
||||
k: Top-k to consider
|
||||
|
||||
Returns:
|
||||
NDCG@k score
|
||||
"""
|
||||
if scores.shape != labels.shape:
|
||||
raise ValueError(f"Scores shape {scores.shape} != labels shape {labels.shape}")
|
||||
|
||||
num_queries = scores.shape[0]
|
||||
ndcg_scores = []
|
||||
|
||||
for i in range(num_queries):
|
||||
# Skip queries with no relevant documents
|
||||
if labels[i].sum() == 0:
|
||||
continue
|
||||
|
||||
# Get scores and labels for this query
|
||||
query_scores = scores[i]
|
||||
query_labels = labels[i]
|
||||
|
||||
# Get top-k indices by score
|
||||
top_k_indices = np.argsort(query_scores)[::-1][:k]
|
||||
|
||||
# Compute DCG@k
|
||||
dcg = 0.0
|
||||
for j, idx in enumerate(top_k_indices):
|
||||
if query_labels[idx] > 0:
|
||||
# Rank of this hit is j+1, so the standard DCG discount is log2((j+1) + 1) = log2(j + 2)
|
||||
dcg += query_labels[idx] / np.log2(j + 2)
|
||||
|
||||
# Compute ideal DCG@k
|
||||
ideal_labels = np.sort(query_labels)[::-1][:k]
|
||||
idcg = 0.0
|
||||
for j, label in enumerate(ideal_labels):
|
||||
if label > 0:
|
||||
idcg += label / np.log2(j + 2)
|
||||
|
||||
# Compute NDCG
|
||||
if idcg > 0:
|
||||
ndcg_scores.append(dcg / idcg)
|
||||
else:
|
||||
ndcg_scores.append(0.0)
|
||||
|
||||
return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
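# Illustrative worked example (not in the original file): the log2 discount on the same
# toy ranking [relevant, non-relevant, relevant] with binary gains.
#
#   DCG@3  = 1/log2(2) + 1/log2(4) = 1.0 + 0.5    = 1.5
#   IDCG@3 = 1/log2(2) + 1/log2(3) = 1.0 + 0.6309 = 1.6309
#   NDCG@3 = 1.5 / 1.6309 ≈ 0.9197
#
#   >>> s = np.array([[0.9, 0.8, 0.7]])
#   >>> y = np.array([[1, 0, 1]], dtype=float)
#   >>> round(compute_ndcg_at_k(s, y, k=3), 4)
#   0.9197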
|
||||
|
||||
|
||||
def compute_mean_average_precision(scores: np.ndarray, labels: np.ndarray) -> float:
|
||||
"""
|
||||
Compute Mean Average Precision (MAP) over all positions
|
||||
"""
|
||||
num_queries = scores.shape[0]
|
||||
ap_scores = []
|
||||
|
||||
for i in range(num_queries):
|
||||
if labels[i].sum() > 0:
|
||||
ap = average_precision_score(labels[i], scores[i])
|
||||
ap_scores.append(ap)
|
||||
|
||||
return float(np.mean(ap_scores)) if ap_scores else 0.0
|
||||
|
||||
|
||||
def compute_mean_reciprocal_rank(scores: np.ndarray, labels: np.ndarray) -> float:
|
||||
"""
|
||||
Compute Mean Reciprocal Rank (MRR) over all positions
|
||||
"""
|
||||
num_queries = scores.shape[0]
|
||||
rr_scores = []
|
||||
|
||||
for i in range(num_queries):
|
||||
# Get ranking
|
||||
ranked_indices = np.argsort(scores[i])[::-1]
|
||||
|
||||
# Find first relevant document
|
||||
for j, idx in enumerate(ranked_indices):
|
||||
if labels[i][idx] == 1:
|
||||
rr_scores.append(1.0 / (j + 1))
|
||||
break
|
||||
else:
|
||||
rr_scores.append(0.0)
|
||||
|
||||
return float(np.mean(rr_scores)) if rr_scores else 0.0
|
||||
|
||||
|
||||
def compute_reranker_metrics(
|
||||
scores: Union[np.ndarray, torch.Tensor],
|
||||
labels: Union[np.ndarray, torch.Tensor],
|
||||
groups: Optional[List[int]] = None,
|
||||
k_values: List[int] = [1, 5, 10]
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Compute metrics specifically for reranker evaluation
|
||||
|
||||
Args:
|
||||
scores: Relevance scores from reranker
|
||||
labels: Binary relevance labels
|
||||
groups: Group sizes for each query (for listwise evaluation)
|
||||
k_values: List of k values
|
||||
|
||||
Returns:
|
||||
Dictionary of reranker-specific metrics
|
||||
"""
|
||||
# Convert to numpy if needed
|
||||
if isinstance(scores, torch.Tensor):
|
||||
scores = scores.cpu().numpy()
|
||||
if isinstance(labels, torch.Tensor):
|
||||
labels = labels.cpu().numpy()
|
||||
|
||||
metrics = {}
|
||||
|
||||
if groups is None:
|
||||
# Assume pairwise evaluation
|
||||
# Binary classification metrics
|
||||
predictions = (scores > 0.5).astype(int) # type: ignore[attr-defined]
|
||||
|
||||
# Accuracy
|
||||
metrics['accuracy'] = np.mean(predictions == labels)
|
||||
|
||||
# Compute AUC if we have both positive and negative examples
|
||||
if len(np.unique(labels)) > 1:
|
||||
from sklearn.metrics import roc_auc_score
|
||||
metrics['auc'] = roc_auc_score(labels, scores)
|
||||
|
||||
else:
|
||||
# Listwise evaluation
|
||||
start_idx = 0
|
||||
mrr_scores = []
|
||||
ndcg_scores = defaultdict(list)
|
||||
|
||||
for group_size in groups:
|
||||
# Get scores and labels for this group
|
||||
group_scores = scores[start_idx:start_idx + group_size]
|
||||
group_labels = labels[start_idx:start_idx + group_size]
|
||||
|
||||
# Sort by scores
|
||||
sorted_indices = np.argsort(group_scores)[::-1]
|
||||
|
||||
# MRR
|
||||
for i, idx in enumerate(sorted_indices):
|
||||
if group_labels[idx] == 1:
|
||||
mrr_scores.append(1.0 / (i + 1))
|
||||
break
|
||||
else:
|
||||
mrr_scores.append(0.0)
|
||||
|
||||
# NDCG@k
|
||||
for k in k_values:
|
||||
if group_size >= k:
|
||||
try:
|
||||
ndcg = ndcg_score(
|
||||
group_labels.reshape(1, -1),
|
||||
group_scores.reshape(1, -1),
|
||||
k=k
|
||||
)
|
||||
ndcg_scores[k].append(ndcg)
|
||||
except Exception:  # ndcg_score can raise for degenerate groups (e.g. a single candidate)
|
||||
pass
|
||||
|
||||
start_idx += group_size
|
||||
|
||||
# Aggregate metrics
|
||||
metrics['mrr'] = float(np.mean(mrr_scores)) if mrr_scores else 0.0
|
||||
|
||||
for k in k_values:
|
||||
if ndcg_scores[k]:
|
||||
metrics[f'ndcg@{k}'] = float(np.mean(ndcg_scores[k]))
|
||||
|
||||
return metrics
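# Illustrative sketch (not in the original file): listwise reranker evaluation with
# two query groups of three candidates each (values invented).
#
#   >>> scores = np.array([0.2, 0.9, 0.1,   0.7, 0.3, 0.5])
#   >>> labels = np.array([0,   1,   0,     0,   0,   1  ])
#   >>> m = compute_reranker_metrics(scores, labels, groups=[3, 3], k_values=[1])
#   >>> m['mrr']            # query 1 hits at rank 1, query 2 at rank 2 -> (1.0 + 0.5) / 2
#   0.75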
|
||||
|
||||
|
||||
def evaluate_cross_lingual_retrieval(
|
||||
query_embeddings: Dict[str, np.ndarray],
|
||||
passage_embeddings: np.ndarray,
|
||||
labels: Dict[str, np.ndarray],
|
||||
languages: List[str] = ['en', 'zh']
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
Evaluate cross-lingual retrieval performance
|
||||
|
||||
Args:
|
||||
query_embeddings: Dict mapping language to query embeddings
|
||||
passage_embeddings: Passage embeddings
|
||||
labels: Dict mapping language to relevance labels
|
||||
languages: List of languages to evaluate
|
||||
|
||||
Returns:
|
||||
Nested dict of metrics per language
|
||||
"""
|
||||
results = {}
|
||||
|
||||
for lang in languages:
|
||||
if lang in query_embeddings and lang in labels:
|
||||
logger.info(f"Evaluating {lang} queries")
|
||||
|
||||
results[lang] = compute_retrieval_metrics(
|
||||
query_embeddings[lang],
|
||||
passage_embeddings,
|
||||
labels[lang]
|
||||
)
|
||||
|
||||
# Compute cross-lingual metrics if we have multiple languages
|
||||
if len(languages) > 1 and all(lang in results for lang in languages):
|
||||
# Average performance across languages
|
||||
avg_metrics = {}
|
||||
|
||||
for metric in results[languages[0]].keys():
|
||||
if metric != 'scores':
|
||||
values = [results[lang][metric] for lang in languages]
|
||||
avg_metrics[metric] = np.mean(values)
|
||||
|
||||
results['average'] = avg_metrics
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class MetricsTracker:
|
||||
"""
|
||||
Track and aggregate metrics during training
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.metrics = defaultdict(list)
|
||||
self.best_metrics = {}
|
||||
self.best_steps = {}
|
||||
|
||||
def update(self, metrics: Dict[str, float], step: int):
|
||||
"""Update metrics for current step"""
|
||||
for key, value in metrics.items():
|
||||
self.metrics[key].append((step, value))
|
||||
|
||||
# Track best metrics
|
||||
if key not in self.best_metrics or value > self.best_metrics[key]:
|
||||
self.best_metrics[key] = value
|
||||
self.best_steps[key] = step
|
||||
|
||||
def get_current(self, metric: str) -> Optional[float]:
|
||||
"""Get most recent value for a metric"""
|
||||
if metric in self.metrics and self.metrics[metric]:
|
||||
return self.metrics[metric][-1][1]
|
||||
return None
|
||||
|
||||
def get_best(self, metric: str) -> Optional[Tuple[int, float]]:
|
||||
"""Get best value and step for a metric"""
|
||||
if metric in self.best_metrics:
|
||||
return self.best_steps[metric], self.best_metrics[metric]
|
||||
return None
|
||||
|
||||
def get_history(self, metric: str) -> List[Tuple[int, float]]:
|
||||
"""Get full history for a metric"""
|
||||
return self.metrics.get(metric, [])
|
||||
|
||||
def summary(self) -> Dict[str, Dict[str, float]]:
|
||||
"""Get summary of all metrics"""
|
||||
summary = {
|
||||
'current': {},
|
||||
'best': {},
|
||||
'best_steps': {}
|
||||
}
|
||||
|
||||
for metric in self.metrics:
|
||||
summary['current'][metric] = self.get_current(metric)
|
||||
summary['best'][metric] = self.best_metrics.get(metric)
|
||||
summary['best_steps'][metric] = self.best_steps.get(metric)
|
||||
|
||||
return summary
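# Illustrative sketch (not in the original file): tracking a metric across training steps.
#
#   >>> tracker = MetricsTracker()
#   >>> tracker.update({'ndcg@10': 0.61}, step=100)
#   >>> tracker.update({'ndcg@10': 0.58}, step=200)
#   >>> tracker.get_current('ndcg@10')
#   0.58
#   >>> tracker.get_best('ndcg@10')
#   (100, 0.61)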
|
||||
25
models/__init__.py
Normal file
25
models/__init__.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""
|
||||
Model implementations for BGE fine-tuning
|
||||
"""
|
||||
from .bge_m3 import BGEM3Model
|
||||
from .bge_reranker import BGERerankerModel
|
||||
from .losses import (
|
||||
InfoNCELoss,
|
||||
M3KnowledgeDistillationLoss,
|
||||
ListwiseCrossEntropyLoss,
|
||||
PairwiseMarginLoss,
|
||||
DynamicListwiseDistillationLoss,
|
||||
MultiTaskLoss
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'BGEM3Model',
|
||||
'BGERerankerModel',
|
||||
'InfoNCELoss',
|
||||
'M3KnowledgeDistillationLoss',
|
||||
'ListwiseCrossEntropyLoss',
|
||||
'PairwiseMarginLoss',
|
||||
'DynamicListwiseDistillationLoss',
|
||||
'MultiTaskLoss'
|
||||
]
|
||||
768
models/bge_m3.py
Normal file
768
models/bge_m3.py
Normal file
@ -0,0 +1,768 @@
|
||||
"""
|
||||
BGE-M3 model wrapper for multi-functionality embedding
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
||||
from typing import Dict, List, Optional, Union, Tuple
|
||||
import numpy as np
|
||||
import logging
|
||||
from utils.config_loader import get_config
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(model_name_or_path, cache_dir=None):
|
||||
"""
|
||||
Load model and tokenizer with comprehensive error handling and fallback strategies
|
||||
|
||||
Args:
|
||||
model_name_or_path: Model identifier or local path
|
||||
cache_dir: Cache directory for downloaded models
|
||||
|
||||
Returns:
|
||||
Tuple of (model, tokenizer, config)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails after all attempts
|
||||
"""
|
||||
config = get_config()
|
||||
model_source = config.get('model_paths', 'source', 'huggingface')
|
||||
|
||||
# Track loading attempts for debugging
|
||||
attempts = []
|
||||
|
||||
# Try primary loading method
|
||||
try:
|
||||
if model_source == 'huggingface':
|
||||
return _load_from_huggingface(model_name_or_path, cache_dir)
|
||||
elif model_source == 'modelscope':
|
||||
return _load_from_modelscope(model_name_or_path, cache_dir)
|
||||
else:
|
||||
raise ValueError(f"Unknown model source: {model_source}")
|
||||
except Exception as e:
|
||||
attempts.append(f"{model_source}: {str(e)}")
|
||||
logger.warning(f"Primary loading from {model_source} failed: {e}")
|
||||
|
||||
# Fallback strategies
|
||||
fallback_sources = ['huggingface'] if model_source == 'modelscope' else ['modelscope']
|
||||
|
||||
for fallback in fallback_sources:
|
||||
try:
|
||||
logger.info(f"Attempting fallback loading from {fallback}")
|
||||
if fallback == 'huggingface':
|
||||
return _load_from_huggingface(model_name_or_path, cache_dir)
|
||||
elif fallback == 'modelscope':
|
||||
return _load_from_modelscope(model_name_or_path, cache_dir)
|
||||
except Exception as e:
|
||||
attempts.append(f"{fallback}: {str(e)}")
|
||||
logger.warning(f"Fallback loading from {fallback} failed: {e}")
|
||||
|
||||
# Try loading from local path if it exists
|
||||
if os.path.exists(model_name_or_path):
|
||||
try:
|
||||
logger.info(f"Attempting to load from local path: {model_name_or_path}")
|
||||
return _load_from_local(model_name_or_path, cache_dir)
|
||||
except Exception as e:
|
||||
attempts.append(f"local: {str(e)}")
|
||||
logger.warning(f"Local loading failed: {e}")
|
||||
|
||||
# All attempts failed
|
||||
error_msg = f"Failed to load model from any source. Attempts:\n" + "\n".join(attempts)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def _load_from_huggingface(model_name_or_path, cache_dir=None):
|
||||
"""Load model from HuggingFace with error handling"""
|
||||
try:
|
||||
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
||||
|
||||
# Load config first to validate model
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Load model
|
||||
model = AutoModel.from_pretrained(
|
||||
model_name_or_path,
|
||||
config=model_config,
|
||||
cache_dir=cache_dir,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
logger.info(f"Successfully loaded model from HuggingFace: {model_name_or_path}")
|
||||
return model, tokenizer, model_config
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Transformers library not available: {e}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"HuggingFace loading error: {e}")
|
||||
|
||||
|
||||
def _load_from_modelscope(model_name_or_path, cache_dir=None):
|
||||
"""Load model from ModelScope with error handling"""
|
||||
try:
|
||||
from modelscope.models import Model as MSModel
|
||||
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
|
||||
|
||||
# Download model if needed
|
||||
local_path = ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
|
||||
|
||||
# Load model
|
||||
model = MSModel.from_pretrained(local_path)
|
||||
|
||||
# Try to load tokenizer
|
||||
tokenizer = None
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load tokenizer from ModelScope: {e}")
|
||||
|
||||
# Try to get config
|
||||
config = None
|
||||
config_path = os.path.join(local_path, 'config.json')
|
||||
if os.path.exists(config_path):
|
||||
import json
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
logger.info(f"Successfully loaded model from ModelScope: {model_name_or_path}")
|
||||
return model, tokenizer, config
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(f"ModelScope library not available: {e}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"ModelScope loading error: {e}")
|
||||
|
||||
|
||||
def _load_from_local(model_path, cache_dir=None):
|
||||
"""Load model from local directory with error handling"""
|
||||
try:
|
||||
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
||||
|
||||
if not os.path.isdir(model_path):
|
||||
raise ValueError(f"Local path is not a directory: {model_path}")
|
||||
|
||||
# Check for required files
|
||||
required_files = ['config.json']
|
||||
missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
|
||||
if missing_files:
|
||||
raise FileNotFoundError(f"Missing required files: {missing_files}")
|
||||
|
||||
# Load components
|
||||
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = AutoModel.from_pretrained(model_path, config=model_config, trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
logger.info(f"Successfully loaded model from local path: {model_path}")
|
||||
return model, tokenizer, model_config
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Local loading error: {e}")
|
||||
|
||||
|
||||
class BGEM3Model(nn.Module):
|
||||
"""
|
||||
BGE-M3 model wrapper supporting dense, sparse, and multi-vector retrieval
|
||||
|
||||
Notes:
|
||||
- All parameters (model_name_or_path, use_dense, use_sparse, use_colbert, pooling_method, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str = "BAAI/bge-m3",
|
||||
use_dense: bool = True,
|
||||
use_sparse: bool = True,
|
||||
use_colbert: bool = True,
|
||||
pooling_method: str = "cls",
|
||||
normalize_embeddings: bool = True,
|
||||
sentence_pooling_method: str = "cls",
|
||||
colbert_dim: int = -1,
|
||||
sparse_top_k: int = 100,
|
||||
cache_dir: Optional[str] = None,
|
||||
device: Optional[str] = None,
|
||||
gradient_checkpointing: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model_name_or_path = model_name_or_path
|
||||
self.use_dense = use_dense
|
||||
self.use_sparse = use_sparse
|
||||
self.use_colbert = use_colbert
|
||||
self.pooling_method = pooling_method
self.sentence_pooling_method = sentence_pooling_method  # read by forward() when selecting dense pooling
|
||||
self.normalize_embeddings = normalize_embeddings
|
||||
self.colbert_dim = colbert_dim
|
||||
self.sparse_top_k = sparse_top_k
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
# Initialize model with comprehensive error handling
|
||||
try:
|
||||
logger.info(f"Loading BGE-M3 model from {model_name_or_path} using source from config.toml...")
|
||||
self.model, self.tokenizer, self.config = load_model_and_tokenizer(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir
|
||||
)
|
||||
logger.info("BGE-M3 model loaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load BGE-M3 model: {e}")
|
||||
raise RuntimeError(f"Model initialization failed: {e}")
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if self.gradient_checkpointing:
|
||||
self.enable_gradient_checkpointing()
|
||||
|
||||
# Ensure we have a valid config
|
||||
if self.config is None:
|
||||
logger.warning("Model config is None, creating a basic config")
|
||||
# Create a basic config object
|
||||
class BasicConfig:
|
||||
def __init__(self):
|
||||
self.hidden_size = 768 # Default BERT hidden size
|
||||
self.config = BasicConfig()
|
||||
|
||||
# Initialize heads for different functionalities
|
||||
hidden_size = getattr(self.config, 'hidden_size', 768)
|
||||
|
||||
if self.use_dense:
|
||||
# Dense retrieval doesn't need additional layers (uses CLS token)
|
||||
self.dense_linear = nn.Identity()
|
||||
|
||||
if self.use_sparse:
|
||||
# Sparse retrieval head (SPLADE-style)
|
||||
self.sparse_linear = nn.Linear(hidden_size, 1)
|
||||
self.sparse_activation = nn.ReLU()
|
||||
|
||||
if self.use_colbert:
|
||||
# ColBERT projection layer
|
||||
colbert_dim = colbert_dim if colbert_dim > 0 else hidden_size
|
||||
self.colbert_linear = nn.Linear(hidden_size, colbert_dim)
|
||||
|
||||
# Set device with enhanced detection
|
||||
if device:
|
||||
self.device = torch.device(device)
|
||||
else:
|
||||
self.device = self._detect_device()
|
||||
|
||||
try:
|
||||
self.to(self.device)
|
||||
logger.info(f"BGE-M3 model moved to device: {self.device}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to move model to {self.device}: {e}, using CPU")
|
||||
self.device = torch.device('cpu')
|
||||
self.to(self.device)
|
||||
|
||||
# Cache for tokenizer
|
||||
self._tokenizer = None
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
"""Enable gradient checkpointing for memory efficiency"""
|
||||
if hasattr(self.model, 'gradient_checkpointing_enable'):
|
||||
self.model.gradient_checkpointing_enable()
|
||||
logger.info("Gradient checkpointing enabled for BGE-M3")
|
||||
else:
|
||||
logger.warning("Model does not support gradient checkpointing")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
"""Disable gradient checkpointing"""
|
||||
if hasattr(self.model, 'gradient_checkpointing_disable'):
|
||||
self.model.gradient_checkpointing_disable()
|
||||
logger.info("Gradient checkpointing disabled for BGE-M3")
|
||||
|
||||
def _detect_device(self) -> torch.device:
|
||||
"""Enhanced device detection with proper error handling"""
|
||||
# Check Ascend NPU first
|
||||
try:
|
||||
import torch_npu
|
||||
if hasattr(torch, 'npu') and torch.npu.is_available(): # type: ignore[attr-defined]
|
||||
logger.info("Using Ascend NPU")
|
||||
return torch.device("npu")
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"NPU detection failed: {e}")
|
||||
|
||||
|
||||
# Check CUDA
|
||||
if torch.cuda.is_available():
|
||||
logger.info("Using CUDA GPU")
|
||||
return torch.device("cuda")
|
||||
|
||||
# Check MPS (Apple Silicon)
|
||||
try:
|
||||
if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): # type: ignore[attr-defined]
|
||||
logger.info("Using Apple MPS")
|
||||
return torch.device("mps")
|
||||
except Exception as e:
|
||||
logger.debug(f"MPS detection failed: {e}")
|
||||
|
||||
|
||||
# Default to CPU
|
||||
logger.info("Using CPU")
|
||||
return torch.device("cpu")
|
||||
|
||||
def _get_tokenizer(self):
|
||||
"""Lazy load tokenizer with error handling"""
|
||||
if self._tokenizer is None:
|
||||
try:
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name_or_path,
|
||||
trust_remote_code=True
|
||||
)
|
||||
logger.info("Tokenizer loaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load tokenizer: {e}")
|
||||
raise RuntimeError(f"Tokenizer loading failed: {e}")
|
||||
return self._tokenizer
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
return_dense: bool = True,
|
||||
return_sparse: bool = False,
|
||||
return_colbert: bool = False,
|
||||
return_dict: bool = True
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Forward pass for BGE-M3 with dense, sparse (SPLADE+), and ColBERT heads.
|
||||
Args:
|
||||
input_ids: Token IDs [batch_size, seq_length]
|
||||
attention_mask: Attention mask [batch_size, seq_length]
|
||||
token_type_ids: Token type IDs (optional)
|
||||
return_dense: Return dense embeddings
|
||||
return_sparse: Return sparse embeddings
|
||||
return_colbert: Return ColBERT embeddings
|
||||
return_dict: Return dictionary of outputs
|
||||
Returns:
|
||||
Embeddings or dictionary of embeddings
|
||||
"""
|
||||
# Validate inputs
|
||||
if input_ids is None or attention_mask is None:
|
||||
raise ValueError("input_ids and attention_mask are required")
|
||||
if input_ids.shape != attention_mask.shape:
|
||||
raise ValueError("input_ids and attention_mask must have the same shape")
|
||||
|
||||
# Determine what to return based on model configuration and parameters
|
||||
return_dense = return_dense and self.use_dense
|
||||
return_sparse = return_sparse and self.use_sparse
|
||||
return_colbert = return_colbert and self.use_colbert
|
||||
|
||||
if not (return_dense or return_sparse or return_colbert):
|
||||
logger.warning("No output head selected in forward. At least one of dense, sparse, or colbert must be True.")
|
||||
return_dense = True # Default to dense if nothing selected
|
||||
|
||||
try:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
return_dict=True
|
||||
)
|
||||
last_hidden_state = outputs.last_hidden_state # [batch_size, seq_length, hidden_size]
|
||||
except Exception as e:
|
||||
logger.error(f"Model forward pass failed: {e}")
|
||||
raise RuntimeError(f"Forward pass failed: {e}")
|
||||
|
||||
results = {}
|
||||
|
||||
# Dense embeddings (CLS or mean pooling)
|
||||
if return_dense:
|
||||
if self.pooling_method == "cls" or self.sentence_pooling_method == "cls":
|
||||
dense_vecs = last_hidden_state[:, 0]
|
||||
elif self.pooling_method == "mean" or self.sentence_pooling_method == "mean":
|
||||
expanded_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.shape)
|
||||
masked_embeddings = last_hidden_state * expanded_mask
|
||||
sum_embeddings = torch.sum(masked_embeddings, dim=1)
|
||||
sum_mask = torch.clamp(attention_mask.sum(dim=1, keepdim=True), min=1e-9)
|
||||
dense_vecs = sum_embeddings / sum_mask
|
||||
else:
|
||||
raise ValueError(f"Unknown pooling method: {self.pooling_method}")
|
||||
|
||||
dense_vecs = self.dense_linear(dense_vecs)
|
||||
# Ensure embeddings are float type
|
||||
dense_vecs = dense_vecs.float()
|
||||
if self.normalize_embeddings:
|
||||
dense_vecs = F.normalize(dense_vecs, p=2, dim=-1)
|
||||
results['dense'] = dense_vecs
|
||||
|
||||
# SPLADE+ sparse head
|
||||
if return_sparse:
|
||||
# SPLADE+ uses log1p(ReLU(.))
|
||||
sparse_logits = self.sparse_linear(last_hidden_state).squeeze(-1) # [batch, seq_len]
|
||||
sparse_activations = torch.log1p(F.relu(sparse_logits))
|
||||
|
||||
# Mask out padding tokens
|
||||
sparse_activations = sparse_activations * attention_mask.float()
|
||||
|
||||
# Optionally normalize sparse activations
|
||||
if self.normalize_embeddings:
|
||||
sparse_activations = F.normalize(sparse_activations, p=2, dim=-1)
|
||||
|
||||
results['sparse'] = sparse_activations
|
||||
|
||||
# FLOPS regularization (for loss, not forward, but add as attribute)
|
||||
self.splade_flops = torch.sum(sparse_activations)
|
||||
|
||||
# ColBERT head with late interaction (maxsim)
|
||||
if return_colbert:
|
||||
colbert_vecs = self.colbert_linear(last_hidden_state) # [batch, seq_len, colbert_dim]
|
||||
|
||||
# Mask out padding tokens
|
||||
expanded_mask = attention_mask.unsqueeze(-1).expand(colbert_vecs.shape)
|
||||
colbert_vecs = colbert_vecs * expanded_mask
|
||||
|
||||
if self.normalize_embeddings:
|
||||
colbert_vecs = F.normalize(colbert_vecs, p=2, dim=-1)
|
||||
|
||||
results['colbert'] = colbert_vecs
|
||||
|
||||
if return_dict:
|
||||
return results
|
||||
|
||||
# If only one head requested, return that
|
||||
if return_dense and len(results) == 1:
|
||||
return results['dense']
|
||||
if return_sparse and len(results) == 1:
|
||||
return results['sparse']
|
||||
if return_colbert and len(results) == 1:
|
||||
return results['colbert']
|
||||
|
||||
return results
|
||||
|
||||
def encode(
|
||||
self,
|
||||
sentences: Union[str, List[str]],
|
||||
batch_size: int = 32,
|
||||
max_length: int = 512,
|
||||
convert_to_numpy: bool = True,
|
||||
convert_to_tensor: bool = False,
|
||||
show_progress_bar: bool = False,
|
||||
normalize_embeddings: bool = True,
|
||||
return_dense: bool = True,
|
||||
return_sparse: bool = False,
|
||||
return_colbert: bool = False,
|
||||
device: Optional[str] = None
|
||||
) -> Union[np.ndarray, torch.Tensor, Dict[str, Union[np.ndarray, torch.Tensor]]]:
|
||||
"""
|
||||
Complete encode implementation with comprehensive error handling
|
||||
|
||||
Args:
|
||||
sentences: Input sentences
|
||||
batch_size: Batch size for encoding
|
||||
max_length: Maximum sequence length
|
||||
convert_to_numpy: Convert to numpy arrays
|
||||
convert_to_tensor: Keep as tensors
|
||||
show_progress_bar: Show progress bar
|
||||
normalize_embeddings: Normalize embeddings
|
||||
return_dense/sparse/colbert: Which embeddings to return
|
||||
device: Device to use
|
||||
|
||||
Returns:
|
||||
Embeddings
|
||||
"""
|
||||
# Input validation
|
||||
if isinstance(sentences, str):
|
||||
sentences = [sentences]
|
||||
|
||||
if not sentences:
|
||||
raise ValueError("Empty sentences list provided")
|
||||
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size must be positive")
|
||||
|
||||
if max_length <= 0:
|
||||
raise ValueError("max_length must be positive")
|
||||
|
||||
# Get tokenizer
|
||||
try:
|
||||
tokenizer = self._get_tokenizer()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load tokenizer: {e}")
|
||||
|
||||
# Determine target device
|
||||
target_device = device if device else self.device
|
||||
|
||||
# Initialize storage for embeddings
|
||||
all_embeddings = {
|
||||
'dense': [] if return_dense else None,
|
||||
'sparse': [] if return_sparse else None,
|
||||
'colbert': [] if return_colbert else None
|
||||
}
|
||||
|
||||
# Set model to evaluation mode
|
||||
self.eval()
|
||||
|
||||
try:
|
||||
# Import tqdm if show_progress_bar is True
|
||||
if show_progress_bar:
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
iterator = tqdm(
|
||||
range(0, len(sentences), batch_size),
|
||||
desc="Encoding",
|
||||
total=(len(sentences) + batch_size - 1) // batch_size
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("tqdm not available, disabling progress bar")
|
||||
iterator = range(0, len(sentences), batch_size)
|
||||
else:
|
||||
iterator = range(0, len(sentences), batch_size)
|
||||
|
||||
# Process in batches
|
||||
for i in iterator:
|
||||
batch_sentences = sentences[i:i + batch_size]
|
||||
|
||||
# Tokenize with proper error handling
|
||||
try:
|
||||
encoded = tokenizer(
|
||||
batch_sentences,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
return_tensors='pt'
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Tokenization failed for batch {i//batch_size + 1}: {e}")
|
||||
|
||||
# Move to device
|
||||
try:
|
||||
encoded = {k: v.to(target_device) for k, v in encoded.items()}
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to move batch to device {target_device}: {e}")
|
||||
|
||||
# Get embeddings
|
||||
with torch.no_grad():
|
||||
try:
|
||||
outputs = self.forward(
|
||||
**encoded,
|
||||
return_dense=return_dense,
|
||||
return_sparse=return_sparse,
|
||||
return_colbert=return_colbert,
|
||||
return_dict=True
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Forward pass failed for batch {i//batch_size + 1}: {e}")
|
||||
|
||||
# Collect embeddings
|
||||
for key in ['dense', 'sparse', 'colbert']:
|
||||
if all_embeddings[key] is not None and key in outputs:
|
||||
all_embeddings[key].append(outputs[key].cpu()) # type: ignore[arg-type]
|
||||
|
||||
# Concatenate and convert results
|
||||
final_embeddings = {}
|
||||
for key in ['dense', 'sparse', 'colbert']:
|
||||
if all_embeddings[key] is not None and len(all_embeddings[key]) > 0: # type: ignore[arg-type]
|
||||
try:
|
||||
embeddings = torch.cat(all_embeddings[key], dim=0) # type: ignore[arg-type]
|
||||
|
||||
if normalize_embeddings and key == 'dense':
|
||||
embeddings = F.normalize(embeddings, p=2, dim=-1)
|
||||
|
||||
if convert_to_numpy:
|
||||
embeddings = embeddings.numpy()
|
||||
elif not convert_to_tensor:
|
||||
embeddings = embeddings
|
||||
|
||||
final_embeddings[key] = embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process {key} embeddings: {e}")
|
||||
raise
|
||||
|
||||
# Return format
|
||||
if len(final_embeddings) == 1:
|
||||
return list(final_embeddings.values())[0]
|
||||
|
||||
if not final_embeddings:
|
||||
raise RuntimeError("No embeddings were generated")
|
||||
|
||||
return final_embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Encoding failed: {e}")
|
||||
raise
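# Illustrative sketch (not in the original file): typical single-head and multi-head
# calls to encode(); the input strings are invented.
#
#   model = BGEM3Model("BAAI/bge-m3")
#   dense = model.encode(["hello world", "bge-m3 embeddings"])       # np.ndarray [2, hidden_size]
#   multi = model.encode(["hello world"], return_dense=True,
#                        return_sparse=True, return_colbert=True)
#   # multi is a dict with the keys 'dense', 'sparse' and 'colbert'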
|
||||
|
||||
def save_pretrained(self, save_path: str):
|
||||
"""Save model to directory with error handling, including custom heads and config."""
|
||||
try:
|
||||
import os
|
||||
import json
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
# Save the base model
|
||||
self.model.save_pretrained(save_path)
|
||||
|
||||
# Save additional configuration (including [m3] section)
|
||||
from utils.config_loader import get_config
|
||||
config_dict = get_config().get_section('m3')
|
||||
config = {
|
||||
'use_dense': self.use_dense,
|
||||
'use_sparse': self.use_sparse,
|
||||
'use_colbert': self.use_colbert,
|
||||
'pooling_method': self.pooling_method,
|
||||
'normalize_embeddings': self.normalize_embeddings,
|
||||
'colbert_dim': self.colbert_dim,
|
||||
'sparse_top_k': self.sparse_top_k,
|
||||
'model_name_or_path': self.model_name_or_path,
|
||||
'm3_config': config_dict
|
||||
}
|
||||
with open(f"{save_path}/m3_config.json", 'w') as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
# Save custom heads if needed
|
||||
if self.use_sparse:
|
||||
torch.save(self.sparse_linear.state_dict(), os.path.join(save_path, 'sparse_linear.pt'))
|
||||
if self.use_colbert:
|
||||
torch.save(self.colbert_linear.state_dict(), os.path.join(save_path, 'colbert_linear.pt'))
|
||||
|
||||
# Save tokenizer if available
|
||||
if self._tokenizer is not None:
|
||||
self._tokenizer.save_pretrained(save_path)
|
||||
|
||||
logger.info(f"Model and config saved to {save_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save model: {e}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str, **kwargs):
|
||||
"""Load model from directory, restoring custom heads and config."""
|
||||
import os
|
||||
import json
|
||||
try:
|
||||
# Load config
|
||||
config_path = os.path.join(model_path, "m3_config.json")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
config = {}
|
||||
|
||||
# Instantiate model
|
||||
model = cls(
|
||||
model_name_or_path=config.get('model_name_or_path', model_path),
|
||||
use_dense=config.get('use_dense', True),
|
||||
use_sparse=config.get('use_sparse', True),
|
||||
use_colbert=config.get('use_colbert', True),
|
||||
pooling_method=config.get('pooling_method', 'cls'),
|
||||
normalize_embeddings=config.get('normalize_embeddings', True),
|
||||
colbert_dim=config.get('colbert_dim', -1),
|
||||
sparse_top_k=config.get('sparse_top_k', 100),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Load base model weights
|
||||
model.model = model.model.from_pretrained(model_path)
|
||||
|
||||
# Load custom heads if present
|
||||
if model.use_sparse:
|
||||
sparse_path = os.path.join(model_path, 'sparse_linear.pt')
|
||||
if os.path.exists(sparse_path):
|
||||
model.sparse_linear.load_state_dict(torch.load(sparse_path, map_location='cpu'))
|
||||
if model.use_colbert:
|
||||
colbert_path = os.path.join(model_path, 'colbert_linear.pt')
|
||||
if os.path.exists(colbert_path):
|
||||
model.colbert_linear.load_state_dict(torch.load(colbert_path, map_location='cpu'))
|
||||
|
||||
logger.info(f"Loaded BGE-M3 model from {model_path}")
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model from {model_path}: {e}")
|
||||
raise
|
||||
|
||||
def compute_similarity(
|
||||
self,
|
||||
queries: Union[torch.Tensor, Dict[str, torch.Tensor]],
|
||||
passages: Union[torch.Tensor, Dict[str, torch.Tensor]],
|
||||
mode: str = "dense"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute similarity scores between queries and passages with error handling
|
||||
|
||||
Args:
|
||||
queries: Query embeddings
|
||||
passages: Passage embeddings
|
||||
mode: Similarity mode ('dense', 'sparse', 'colbert')
|
||||
|
||||
Returns:
|
||||
Similarity scores [num_queries, num_passages]
|
||||
"""
|
||||
try:
|
||||
# Extract embeddings based on mode
|
||||
if isinstance(queries, dict):
|
||||
if mode not in queries:
|
||||
raise ValueError(f"Mode '{mode}' not found in query embeddings")
|
||||
query_emb = queries[mode]
|
||||
else:
|
||||
query_emb = queries
|
||||
|
||||
if isinstance(passages, dict):
|
||||
if mode not in passages:
|
||||
raise ValueError(f"Mode '{mode}' not found in passage embeddings")
|
||||
passage_emb = passages[mode]
|
||||
else:
|
||||
passage_emb = passages
|
||||
|
||||
# Validate tensor shapes
|
||||
if query_emb.dim() < 2 or passage_emb.dim() < 2:
|
||||
raise ValueError("Embeddings must be at least 2-dimensional")
|
||||
|
||||
if mode == "dense":
|
||||
# Simple dot product for dense
|
||||
return torch.matmul(query_emb, passage_emb.t())
|
||||
|
||||
elif mode == "sparse":
|
||||
# Dot product for sparse (simplified)
|
||||
return torch.matmul(query_emb, passage_emb.t())
|
||||
|
||||
elif mode == "colbert":
|
||||
# MaxSim for ColBERT
|
||||
if query_emb.dim() != 3 or passage_emb.dim() != 3:
|
||||
raise ValueError("ColBERT embeddings must be 3-dimensional [batch, seq_len, dim]")
|
||||
|
||||
scores = []
|
||||
for q in query_emb:
|
||||
# q: [query_len, dim]
|
||||
# passage_emb: [num_passages, passage_len, dim]
|
||||
passage_scores = []
|
||||
for p in passage_emb:
|
||||
# Compute similarity matrix
|
||||
sim_matrix = torch.matmul(q, p.t()) # [query_len, passage_len]
|
||||
# Max over passage tokens for each query token
|
||||
max_sim_scores = sim_matrix.max(dim=1)[0] # [query_len]
|
||||
# Sum over query tokens
|
||||
score = max_sim_scores.sum()
|
||||
passage_scores.append(score)
|
||||
scores.append(torch.stack(passage_scores))
|
||||
return torch.stack(scores)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown similarity mode: {mode}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Similarity computation failed: {e}")
|
||||
raise
|
||||
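# Usage sketch (illustrative): dense-mode scoring with random embeddings, assuming a
# 1024-dimensional embedding space.
#   q = torch.randn(2, 1024)                                  # two query embeddings
#   p = torch.randn(5, 1024)                                  # five passage embeddings
#   scores = model.compute_similarity(q, p, mode="dense")     # -> shape [2, 5]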
|
||||
def get_sentence_embedding_dimension(self) -> int:
|
||||
"""Get the dimension of sentence embeddings"""
|
||||
return self.config.hidden_size # type: ignore[attr-defined]
|
||||
|
||||
def get_word_embedding_dimension(self) -> int:
|
||||
"""Get the dimension of word embeddings"""
|
||||
return self.config.hidden_size # type: ignore[attr-defined]
|
||||
models/bge_reranker.py (new file, 483 lines)
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
BGE-Reranker cross-encoder model wrapper
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
|
||||
from typing import Dict, List, Optional, Union, Tuple
|
||||
import numpy as np
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from utils.config_loader import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RerankerOutput:
|
||||
"""Output class for reranker predictions"""
|
||||
scores: torch.Tensor
|
||||
logits: torch.Tensor
|
||||
embeddings: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def load_reranker_model_and_tokenizer(model_name_or_path, cache_dir=None, num_labels=1):
|
||||
"""Load reranker model and tokenizer from HuggingFace or ModelScope based on config.toml"""
|
||||
config = get_config()
|
||||
|
||||
# Try to get model source from config, with fallback to huggingface
|
||||
try:
|
||||
model_source = config.get('model_paths', 'source', 'huggingface')
|
||||
except Exception:
|
||||
# Fallback if config loading fails
|
||||
logger.warning("Failed to load model source from config, defaulting to 'huggingface'")
|
||||
model_source = 'huggingface'
|
||||
|
||||
logger.info(f"Loading reranker model from source: {model_source}")
|
||||
|
||||
if model_source == 'huggingface':
|
||||
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
cache_dir=cache_dir
|
||||
)
|
||||
try:
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name_or_path,
|
||||
config=model_config,
|
||||
cache_dir=cache_dir
|
||||
)
|
||||
use_base_model = False
|
||||
except Exception:
|
||||
model = AutoModel.from_pretrained(
|
||||
model_name_or_path,
|
||||
config=model_config,
|
||||
cache_dir=cache_dir
|
||||
)
|
||||
use_base_model = True
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir
|
||||
)
|
||||
return model, tokenizer, model_config, use_base_model
|
||||
elif model_source == 'modelscope':
|
||||
try:
|
||||
from modelscope.models import Model as MSModel
|
||||
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
|
||||
ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
|
||||
model = MSModel.from_pretrained(model_name_or_path, cache_dir=cache_dir)
|
||||
tokenizer = None # ModelScope tokenizer loading can be added if needed
|
||||
model_config = None
|
||||
use_base_model = False
|
||||
return model, tokenizer, model_config, use_base_model
|
||||
except ImportError:
|
||||
logger.error("modelscope not available. Install it first: pip install modelscope")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load reranker model from ModelScope: {e}")
|
||||
raise
|
||||
else:
|
||||
logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
|
||||
raise RuntimeError(f"Unknown model source: {model_source}")
|
||||
|
||||
|
||||
class BGERerankerModel(nn.Module):
|
||||
"""
|
||||
BGE-Reranker cross-encoder model for passage reranking
|
||||
|
||||
Notes:
|
||||
- All parameters (model_name_or_path, num_labels, cache_dir, use_fp16, gradient_checkpointing, device, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str = "BAAI/bge-reranker-base",
|
||||
num_labels: int = 1,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_fp16: bool = False,
|
||||
gradient_checkpointing: bool = False,
|
||||
device: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize BGE-Reranker model.
|
||||
Args:
|
||||
model_name_or_path: Path or identifier for pretrained model
|
||||
num_labels: Number of output labels
|
||||
cache_dir: Directory for model cache
|
||||
use_fp16: Use mixed precision
|
||||
gradient_checkpointing: Enable gradient checkpointing
|
||||
device: Device to use
|
||||
"""
|
||||
super().__init__()
|
||||
self.model_name_or_path = model_name_or_path
|
||||
self.num_labels = num_labels
|
||||
self.use_fp16 = use_fp16
|
||||
# Load configuration and model
|
||||
try:
|
||||
logger.info(f"Loading BGE-Reranker model from {model_name_or_path} using source from config.toml...")
|
||||
self.model, self.tokenizer, self.config, self.use_base_model = load_reranker_model_and_tokenizer(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
num_labels=num_labels
|
||||
)
|
||||
logger.info("BGE-Reranker model loaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load BGE-Reranker model: {e}")
|
||||
raise RuntimeError(f"Model initialization failed: {e}")
|
||||
# Add classification head
|
||||
if self.use_base_model:
|
||||
self.classifier = nn.Linear(self.config.hidden_size, num_labels) # type: ignore[attr-defined]
|
||||
self.dropout = nn.Dropout(
|
||||
self.config.hidden_dropout_prob if hasattr(self.config, 'hidden_dropout_prob') else 0.1) # type: ignore[attr-defined]
|
||||
# Enable gradient checkpointing
|
||||
if gradient_checkpointing:
|
||||
try:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
except Exception:
|
||||
logger.warning("Gradient checkpointing requested but not supported by this model.")
|
||||
# Set device
|
||||
if device:
|
||||
try:
|
||||
self.device = torch.device(device)
|
||||
self.to(self.device)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to move model to {device}: {e}, using CPU")
|
||||
self.device = torch.device('cpu')
|
||||
self.to(self.device)
|
||||
else:
|
||||
self.device = next(self.parameters()).device
|
||||
# Mixed precision
|
||||
if self.use_fp16:
|
||||
try:
|
||||
self.half()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set model to half precision: {e}")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True
|
||||
) -> Union[torch.Tensor, RerankerOutput]:
|
||||
"""
|
||||
Forward pass for cross-encoder reranking
|
||||
|
||||
Args:
|
||||
input_ids: Token IDs [batch_size, seq_length]
|
||||
attention_mask: Attention mask [batch_size, seq_length]
|
||||
token_type_ids: Token type IDs (optional)
|
||||
labels: Labels for training (optional)
|
||||
return_dict: Return RerankerOutput object
|
||||
|
||||
Returns:
|
||||
Scores or RerankerOutput
|
||||
"""
|
||||
if self.use_base_model:
|
||||
# Use base model + custom head
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Get pooled output (CLS token)
|
||||
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
|
||||
pooled_output = outputs.pooler_output
|
||||
else:
|
||||
# Manual pooling if pooler not available
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
pooled_output = last_hidden_state[:, 0]
|
||||
|
||||
# Apply dropout and classifier
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
embeddings = pooled_output
|
||||
else:
|
||||
# Use sequence classification model
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=labels,
|
||||
return_dict=True
|
||||
)
|
||||
logits = outputs.logits
|
||||
embeddings = None
|
||||
|
||||
# Convert logits to scores (sigmoid for binary classification)
|
||||
if self.num_labels == 1:
|
||||
scores = torch.sigmoid(logits.squeeze(-1))
|
||||
else:
|
||||
scores = F.softmax(logits, dim=-1)
|
||||
|
||||
if not return_dict:
|
||||
return scores
|
||||
|
||||
return RerankerOutput(
|
||||
scores=scores,
|
||||
logits=logits,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
||||
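# Forward-pass sketch (illustrative): score one query/passage pair, assuming `reranker`
# is an initialized instance of this class with a loaded tokenizer.
#   enc = reranker.tokenizer("what is BGE?", "BGE is a family of embedding models.",
#                            return_tensors="pt", truncation=True, max_length=512)
#   out = reranker(**{k: v.to(reranker.device) for k, v in enc.items()})
#   print(out.scores)   # sigmoid score in [0, 1] when num_labels == 1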
def predict(
|
||||
self,
|
||||
queries: Union[str, List[str]],
|
||||
passages: Union[str, List[str], List[List[str]]],
|
||||
batch_size: int = 32,
|
||||
max_length: int = 512,
|
||||
show_progress_bar: bool = False,
|
||||
convert_to_numpy: bool = True,
|
||||
normalize_scores: bool = False
|
||||
) -> Union[np.ndarray, torch.Tensor, List[float]]:
|
||||
"""
|
||||
Predict reranking scores for query-passage pairs
|
||||
|
||||
Args:
|
||||
queries: Query or list of queries
|
||||
passages: Passage(s) or list of passages per query
|
||||
batch_size: Batch size for inference
|
||||
max_length: Maximum sequence length
|
||||
show_progress_bar: Show progress bar
|
||||
convert_to_numpy: Convert to numpy array
|
||||
normalize_scores: Normalize scores to [0, 1]
|
||||
|
||||
Returns:
|
||||
Reranking scores
|
||||
"""
|
||||
# Handle input formats
|
||||
if isinstance(queries, str):
|
||||
queries = [queries]
|
||||
|
||||
if isinstance(passages, str):
|
||||
passages = [[passages]]
|
||||
elif isinstance(passages[0], str):
|
||||
passages = [passages] # type: ignore[assignment]
|
||||
|
||||
# Ensure queries and passages match
|
||||
if len(queries) == 1 and len(passages) > 1:
|
||||
queries = queries * len(passages)
|
||||
|
||||
# Create pairs
|
||||
pairs = []
|
||||
for query, passage_list in zip(queries, passages):
|
||||
if isinstance(passage_list, str):
|
||||
passage_list = [passage_list]
|
||||
for passage in passage_list:
|
||||
pairs.append((query, passage))
|
||||
|
||||
# Reuse the tokenizer loaded at init; fall back to loading it from the model path
|
||||
tokenizer = self.tokenizer if getattr(self, 'tokenizer', None) is not None else AutoTokenizer.from_pretrained(self.model_name_or_path)
|
||||
|
||||
# Score in batches
|
||||
all_scores = []
|
||||
|
||||
if show_progress_bar:
|
||||
from tqdm import tqdm
|
||||
pbar = tqdm(total=len(pairs), desc="Scoring")
|
||||
|
||||
for i in range(0, len(pairs), batch_size):
|
||||
batch_pairs = pairs[i:i + batch_size]
|
||||
|
||||
# Tokenize
|
||||
batch_queries = [pair[0] for pair in batch_pairs]
|
||||
batch_passages = [pair[1] for pair in batch_pairs]
|
||||
|
||||
encoded = tokenizer(
|
||||
batch_queries,
|
||||
batch_passages,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
# Move to device
|
||||
encoded = {k: v.to(self.device) for k, v in encoded.items()}
|
||||
|
||||
# Get scores
|
||||
with torch.no_grad():
|
||||
outputs = self.forward(**encoded, return_dict=True)
|
||||
scores = outputs.scores # type: ignore[attr-defined]
|
||||
|
||||
all_scores.append(scores)
|
||||
|
||||
if show_progress_bar:
|
||||
pbar.update(len(batch_pairs))
|
||||
|
||||
if show_progress_bar:
|
||||
pbar.close()
|
||||
|
||||
# Concatenate scores
|
||||
all_scores = torch.cat(all_scores, dim=0)
|
||||
|
||||
# Normalize if requested
|
||||
if normalize_scores:
|
||||
all_scores = (all_scores - all_scores.min()) / (all_scores.max() - all_scores.min() + 1e-8)
|
||||
|
||||
# Convert format
|
||||
if convert_to_numpy:
|
||||
all_scores = all_scores.cpu().numpy()
|
||||
else:
|
||||
all_scores = all_scores.cpu()
|
||||
|
||||
# Return a scalar when a single query with a single passage list was given
|
||||
if len(queries) == 1 and len(passages) == 1:
|
||||
return all_scores.item() if convert_to_numpy else all_scores # type: ignore[attr-defined]
|
||||
|
||||
return all_scores
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
query: str,
|
||||
passages: List[str],
|
||||
top_k: Optional[int] = None,
|
||||
return_scores: bool = False,
|
||||
batch_size: int = 32,
|
||||
max_length: int = 512
|
||||
) -> Union[List[str], Tuple[List[str], List[float]]]:
|
||||
"""
|
||||
Rerank passages for a query
|
||||
|
||||
Args:
|
||||
query: Query string
|
||||
passages: List of passages to rerank
|
||||
top_k: Return only top-k passages
|
||||
return_scores: Also return scores
|
||||
batch_size: Batch size for scoring
|
||||
max_length: Maximum sequence length
|
||||
|
||||
Returns:
|
||||
Reranked passages (and optionally scores)
|
||||
"""
|
||||
# Get scores
|
||||
scores = self.predict(
|
||||
query,
|
||||
passages,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
convert_to_numpy=True
|
||||
)
|
||||
|
||||
# Sort by scores (descending)
|
||||
sorted_indices = np.argsort(scores)[::-1]
|
||||
|
||||
# Apply top_k if specified
|
||||
if top_k is not None:
|
||||
sorted_indices = sorted_indices[:top_k]
|
||||
|
||||
# Get reranked passages
|
||||
reranked_passages = [passages[i] for i in sorted_indices]
|
||||
|
||||
if return_scores:
|
||||
reranked_scores = [float(scores[i]) for i in sorted_indices]
|
||||
return reranked_passages, reranked_scores
|
||||
|
||||
return reranked_passages
|
||||
|
||||
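# Reranking sketch (illustrative), assuming `reranker` is an initialized instance:
#   docs = ["passage A", "passage B", "passage C"]
#   top, scores = reranker.rerank("my query", docs, top_k=2, return_scores=True)
#   # `top` holds the two best passages; `scores` their relevance scores in descending order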
def save_pretrained(self, save_path: str):
|
||||
"""Save model to directory, including custom heads and config."""
|
||||
try:
|
||||
import os
|
||||
import json
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
# Save the base model
|
||||
self.model.save_pretrained(save_path)
|
||||
|
||||
# Save additional configuration (including [reranker] section)
|
||||
from utils.config_loader import get_config
|
||||
config_dict = get_config().get_section('reranker')
|
||||
config = {
|
||||
'use_fp16': self.use_fp16 if hasattr(self, 'use_fp16') else False,
|
||||
'num_labels': getattr(self, 'num_labels', 1),
|
||||
'reranker_config': config_dict
|
||||
}
|
||||
with open(f"{save_path}/reranker_config.json", 'w') as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
# Save custom heads if present
|
||||
if hasattr(self, 'classifier'):
|
||||
torch.save(self.classifier.state_dict(), os.path.join(save_path, 'classifier.pt'))
|
||||
if hasattr(self, 'dropout'):
|
||||
torch.save(self.dropout.state_dict(), os.path.join(save_path, 'dropout.pt'))
|
||||
|
||||
# Save tokenizer if available
|
||||
if hasattr(self, '_tokenizer') and self._tokenizer is not None:
|
||||
self._tokenizer.save_pretrained(save_path) # type: ignore[attr-defined]
|
||||
|
||||
logger.info(f"Reranker model and config saved to {save_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save reranker model: {e}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str, **kwargs):
|
||||
"""Load model from directory, restoring custom heads and config."""
|
||||
import os
|
||||
import json
|
||||
try:
|
||||
# Load config
|
||||
config_path = os.path.join(model_path, "reranker_config.json")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
config = {}
|
||||
|
||||
# Instantiate model
|
||||
model = cls(
|
||||
model_name_or_path=model_path,
|
||||
num_labels=config.get('num_labels', 1),
|
||||
use_fp16=config.get('use_fp16', False),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Load base model weights
|
||||
model.model = model.model.from_pretrained(model_path)
|
||||
|
||||
# Load custom heads if present
|
||||
classifier_path = os.path.join(model_path, 'classifier.pt')
|
||||
if hasattr(model, 'classifier') and os.path.exists(classifier_path):
|
||||
model.classifier.load_state_dict(torch.load(classifier_path, map_location='cpu'))
|
||||
dropout_path = os.path.join(model_path, 'dropout.pt')
|
||||
if hasattr(model, 'dropout') and os.path.exists(dropout_path):
|
||||
model.dropout.load_state_dict(torch.load(dropout_path, map_location='cpu'))
|
||||
|
||||
logger.info(f"Loaded BGE-Reranker model from {model_path}")
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load reranker model from {model_path}: {e}")
|
||||
raise
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
scores: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
loss_type: str = "binary_cross_entropy"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute loss for training
|
||||
|
||||
Args:
|
||||
scores: Model scores
|
||||
labels: Ground truth labels
|
||||
loss_type: Type of loss function
|
||||
|
||||
Returns:
|
||||
Loss value
|
||||
"""
|
||||
if loss_type == "binary_cross_entropy":
|
||||
return F.binary_cross_entropy(scores, labels.float())
|
||||
elif loss_type == "cross_entropy":
|
||||
return F.cross_entropy(scores, labels.long())
|
||||
elif loss_type == "mse":
|
||||
return F.mse_loss(scores, labels.float())
|
||||
else:
|
||||
raise ValueError(f"Unknown loss type: {loss_type}")
|
||||
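# Loss sketch (illustrative): binary cross-entropy on sigmoid scores, assuming `reranker`
# is an initialized instance.
#   scores = torch.tensor([0.9, 0.2, 0.7])
#   labels = torch.tensor([1, 0, 1])
#   loss = reranker.compute_loss(scores, labels, loss_type="binary_cross_entropy")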
models/losses.py (new file, 507 lines)
@@ -0,0 +1,507 @@
|
||||
"""
|
||||
Loss functions for BGE fine-tuning
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
|
||||
class InfoNCELoss(nn.Module):
|
||||
"""
|
||||
InfoNCE loss for contrastive learning
|
||||
As used in BGE-M3 paper with temperature τ=0.02
|
||||
|
||||
Notes:
|
||||
- All parameters (temperature, reduction, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(self, temperature: float = 0.02, reduction: str = 'mean'):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
self.reduction = reduction
|
||||
self.cross_entropy = nn.CrossEntropyLoss(reduction=reduction)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query_embeddings: torch.Tensor,
|
||||
positive_embeddings: torch.Tensor,
|
||||
negative_embeddings: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute InfoNCE loss with proper normalization and numerical stability.
|
||||
|
||||
Args:
|
||||
query_embeddings: [batch_size, embedding_dim]
|
||||
positive_embeddings: [batch_size, embedding_dim]
|
||||
negative_embeddings: [batch_size, num_negatives, embedding_dim] or None
|
||||
"""
|
||||
batch_size = query_embeddings.size(0)
|
||||
device = query_embeddings.device
|
||||
|
||||
# Normalize embeddings (L2 normalization)
|
||||
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
|
||||
positive_embeddings = F.normalize(positive_embeddings, p=2, dim=1)
|
||||
|
||||
# Compute positive scores (query-positive similarity)
|
||||
# Shape: [batch_size]
|
||||
pos_scores = torch.sum(query_embeddings * positive_embeddings, dim=1)
|
||||
pos_logits = pos_scores / self.temperature
|
||||
|
||||
# Compute negative scores if provided
|
||||
if negative_embeddings is not None and negative_embeddings.numel() > 0:
|
||||
# Normalize negative embeddings
|
||||
negative_embeddings = F.normalize(negative_embeddings, p=2, dim=-1)
|
||||
# Compute negative scores
|
||||
# Shape: [batch_size, num_negatives]
|
||||
neg_scores = torch.matmul(
|
||||
query_embeddings.unsqueeze(1),
|
||||
negative_embeddings.transpose(1, 2)
|
||||
).squeeze(1)
|
||||
neg_logits = neg_scores / self.temperature
|
||||
|
||||
# Concatenate positive and negative logits
|
||||
# Shape: [batch_size, 1 + num_negatives]
|
||||
logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
|
||||
|
||||
# Labels: positive is always at index 0
|
||||
labels = torch.zeros(batch_size, device=device, dtype=torch.long)
|
||||
else:
|
||||
# Use in-batch negatives as fallback
|
||||
# Compute similarity matrix for in-batch negatives
|
||||
similarity_matrix = torch.matmul(query_embeddings, positive_embeddings.t())
|
||||
logits = similarity_matrix / self.temperature
|
||||
|
||||
# Labels: positive pairs are on the diagonal
|
||||
labels = torch.arange(batch_size, device=device, dtype=torch.long)
|
||||
|
||||
# Compute cross-entropy loss
|
||||
loss = self.cross_entropy(logits, labels)
|
||||
|
||||
return loss
|
||||
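# Usage sketch (illustrative): in-batch negatives with random 1024-dim embeddings.
#   loss_fn = InfoNCELoss(temperature=0.02)
#   q = torch.randn(8, 1024)      # query embeddings
#   p = torch.randn(8, 1024)      # positive embeddings, aligned row-by-row with q
#   loss = loss_fn(q, p)          # negatives omitted -> falls back to in-batch negatives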
|
||||
|
||||
class M3KnowledgeDistillationLoss(nn.Module):
|
||||
"""
|
||||
Knowledge distillation loss for BGE-M3's multi-functionality training
|
||||
Distills knowledge from dense, sparse, and multi-vector representations
|
||||
|
||||
Notes:
|
||||
- All parameters (temperature, alpha, distill_types, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
temperature: float = 4.0,
|
||||
alpha: float = 0.5,
|
||||
distill_types: List[str] = ['dense', 'sparse', 'colbert']
|
||||
):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
self.alpha = alpha
|
||||
self.distill_types = distill_types
|
||||
self.kl_div = nn.KLDivLoss(reduction='batchmean')
|
||||
|
||||
def forward(
|
||||
self,
|
||||
student_outputs: Dict[str, torch.Tensor],
|
||||
teacher_outputs: Dict[str, torch.Tensor],
|
||||
labels: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
"""
|
||||
Args:
|
||||
student_outputs: Dictionary with 'dense', 'sparse', 'colbert' outputs
|
||||
teacher_outputs: Dictionary with teacher model outputs
|
||||
labels: Optional labels for supervised loss
|
||||
|
||||
Returns:
|
||||
total_loss: Combined distillation loss
|
||||
loss_dict: Dictionary of individual losses
|
||||
"""
|
||||
total_loss = torch.tensor(0.0, device=next(iter(student_outputs.values())).device if student_outputs else 'cpu')
|
||||
loss_dict = {}
|
||||
|
||||
# Distillation loss for each representation type
|
||||
for repr_type in self.distill_types:
|
||||
if repr_type in student_outputs and repr_type in teacher_outputs:
|
||||
student_logits = student_outputs[repr_type]
|
||||
teacher_logits = teacher_outputs[repr_type]
|
||||
|
||||
# Apply temperature scaling
|
||||
student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
|
||||
teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
|
||||
|
||||
# KL divergence loss
|
||||
kl_loss = self.kl_div(student_log_probs, teacher_probs) * (self.temperature ** 2)
|
||||
total_loss = total_loss + self.alpha * kl_loss
|
||||
loss_dict[f'{repr_type}_distill_loss'] = kl_loss.detach().cpu().item()
|
||||
|
||||
# Add supervised loss if labels provided
|
||||
if labels is not None and 'dense' in student_outputs:
|
||||
ce_loss = F.cross_entropy(student_outputs['dense'], labels)
|
||||
total_loss = total_loss + (1 - self.alpha) * ce_loss
|
||||
loss_dict['supervised_loss'] = ce_loss.detach().cpu().item()
|
||||
|
||||
loss_dict['total_loss'] = total_loss.detach().cpu().item()
|
||||
return total_loss, loss_dict
|
||||
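# Usage sketch (illustrative): distill teacher score distributions into the student heads.
#   kd_loss = M3KnowledgeDistillationLoss(temperature=4.0, alpha=0.5)
#   student = {'dense': torch.randn(8, 8), 'sparse': torch.randn(8, 8)}
#   teacher = {'dense': torch.randn(8, 8), 'sparse': torch.randn(8, 8)}
#   loss, parts = kd_loss(student, teacher)   # KL over temperature-scaled distributions per head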
|
||||
|
||||
class ListwiseCrossEntropyLoss(nn.Module):
|
||||
"""
|
||||
Listwise cross-entropy loss for reranker training
|
||||
|
||||
Notes:
|
||||
- All parameters (temperature, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(self, temperature: float = 1.0):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
|
||||
def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
|
||||
"""
|
||||
Compute listwise cross-entropy loss with proper normalization.
|
||||
|
||||
Args:
|
||||
scores: Relevance scores [total_passages]
|
||||
labels: Binary relevance labels [total_passages]
|
||||
groups: List of group sizes (passages per query)
|
||||
"""
|
||||
total_loss = torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
|
||||
offset = 0
|
||||
|
||||
for group_size in groups:
|
||||
if group_size == 0:
|
||||
continue
|
||||
|
||||
# Get scores and labels for this group
|
||||
group_scores = scores[offset:offset + group_size]
|
||||
group_labels = labels[offset:offset + group_size]
|
||||
|
||||
# Skip if no positive labels
|
||||
if group_labels.sum() == 0:
|
||||
offset += group_size
|
||||
continue
|
||||
|
||||
# Apply temperature scaling
|
||||
scaled_scores = group_scores / self.temperature
|
||||
|
||||
# Apply log_softmax for numerical stability
|
||||
log_probs = F.log_softmax(scaled_scores, dim=0)
|
||||
|
||||
# Normalize labels to create a probability distribution
|
||||
label_probs = group_labels.float() / group_labels.sum()
|
||||
|
||||
# Compute cross-entropy: -sum(p_true * log(p_pred))
|
||||
loss = -(label_probs * log_probs).sum()
|
||||
|
||||
total_loss = total_loss + loss
|
||||
offset += group_size
|
||||
|
||||
# Average over number of groups
|
||||
num_valid_groups = sum(1 for g in groups if g > 0)
|
||||
if num_valid_groups > 0:
|
||||
return total_loss / num_valid_groups
|
||||
else:
|
||||
return total_loss
|
||||
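# Usage sketch (illustrative): two queries with 3 and 2 candidate passages respectively.
#   loss_fn = ListwiseCrossEntropyLoss(temperature=1.0)
#   scores = torch.tensor([2.1, 0.3, -0.5, 1.0, 0.2])
#   labels = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0])
#   loss = loss_fn(scores, labels, groups=[3, 2])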
|
||||
|
||||
class PairwiseMarginLoss(nn.Module):
|
||||
"""
|
||||
Pairwise margin ranking loss for reranker
|
||||
|
||||
Notes:
|
||||
- All parameters (margin, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(self, margin: float = 1.0):
|
||||
super().__init__()
|
||||
self.margin = margin
|
||||
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
|
||||
|
||||
def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
scores: Relevance scores [total_passages]
|
||||
labels: Binary relevance labels [total_passages]
|
||||
groups: List of group sizes
|
||||
"""
|
||||
total_loss = torch.tensor(0.0, device=scores.device)
|
||||
num_pairs = 0
|
||||
offset = 0
|
||||
|
||||
for group_size in groups:
|
||||
group_scores = scores[offset:offset + group_size]
|
||||
group_labels = labels[offset:offset + group_size]
|
||||
|
||||
# Find positive and negative indices
|
||||
pos_indices = torch.where(group_labels == 1)[0]
|
||||
neg_indices = torch.where(group_labels == 0)[0]
|
||||
|
||||
if len(pos_indices) > 0 and len(neg_indices) > 0:
|
||||
# Create all positive-negative pairs
|
||||
for pos_idx in pos_indices:
|
||||
for neg_idx in neg_indices:
|
||||
pos_score = group_scores[pos_idx].unsqueeze(0)
|
||||
neg_score = group_scores[neg_idx].unsqueeze(0)
|
||||
target = torch.ones(1, device=scores.device)
|
||||
|
||||
loss = self.ranking_loss(pos_score, neg_score, target)
|
||||
total_loss = total_loss + loss
|
||||
num_pairs += 1
|
||||
|
||||
offset += group_size
|
||||
|
||||
return total_loss / max(num_pairs, 1)
|
||||
|
||||
|
||||
class DynamicListwiseDistillationLoss(nn.Module):
|
||||
"""
|
||||
Dynamic listwise distillation loss for RocketQAv2-style joint training
|
||||
|
||||
Notes:
|
||||
- All parameters (temperature, dynamic_weight, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(self, temperature: float = 3.0, dynamic_weight: bool = True):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
self.dynamic_weight = dynamic_weight
|
||||
self.kl_div = nn.KLDivLoss(reduction='none')
|
||||
|
||||
def forward(
|
||||
self,
|
||||
student_scores: torch.Tensor,
|
||||
teacher_scores: torch.Tensor,
|
||||
groups: List[int],
|
||||
confidence_scores: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
student_scores: Student model scores [total_passages]
|
||||
teacher_scores: Teacher model scores [total_passages]
|
||||
groups: List of group sizes
|
||||
confidence_scores: Optional confidence for dynamic weighting
|
||||
"""
|
||||
total_loss = torch.tensor(0.0, device=student_scores.device)
|
||||
offset = 0
|
||||
|
||||
for i, group_size in enumerate(groups):
|
||||
student_group = student_scores[offset:offset + group_size] / self.temperature
|
||||
teacher_group = teacher_scores[offset:offset + group_size] / self.temperature
|
||||
|
||||
# Convert to distributions
|
||||
student_log_probs = F.log_softmax(student_group, dim=0)
|
||||
teacher_probs = F.softmax(teacher_group, dim=0)
|
||||
|
||||
# Compute KL divergence
|
||||
kl_loss = self.kl_div(student_log_probs, teacher_probs).sum()
|
||||
|
||||
# Apply dynamic weighting based on confidence
|
||||
if self.dynamic_weight and confidence_scores is not None:
|
||||
weight = confidence_scores[i]
|
||||
kl_loss = kl_loss * weight
|
||||
|
||||
total_loss = total_loss + kl_loss * (self.temperature ** 2)
|
||||
offset += group_size
|
||||
|
||||
return total_loss / len(groups)
|
||||
|
||||
|
||||
class MultiTaskLoss(nn.Module):
|
||||
"""
|
||||
Multi-task loss for combining multiple objectives
|
||||
|
||||
Notes:
|
||||
- All parameters (loss_weights, adaptive_weights, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loss_weights: Optional[Dict[str, float]] = None,
|
||||
adaptive_weights: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
self.loss_weights = loss_weights or {
|
||||
'dense': 1.0,
|
||||
'sparse': 0.3,
|
||||
'colbert': 0.5,
|
||||
'distillation': 1.0
|
||||
}
|
||||
self.adaptive_weights = adaptive_weights
|
||||
|
||||
if adaptive_weights:
|
||||
# Learnable loss weights (uncertainty weighting)
|
||||
self.log_vars = nn.Parameter(torch.zeros(len(self.loss_weights)))
|
||||
|
||||
def forward(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
"""
|
||||
Args:
|
||||
losses: Dictionary of individual losses
|
||||
|
||||
Returns:
|
||||
total_loss: Combined loss
|
||||
loss_dict: Dictionary of individual loss values
|
||||
"""
|
||||
total_loss = torch.tensor(0.0, device=next(iter(losses.values())).device if losses else 'cpu')
|
||||
loss_dict = {}
|
||||
|
||||
if self.adaptive_weights:
|
||||
# Uncertainty-based weighting
|
||||
for i, (loss_name, loss_value) in enumerate(losses.items()):
|
||||
if loss_name in self.loss_weights:
|
||||
precision = torch.exp(-self.log_vars[i])
|
||||
weighted_loss = precision * loss_value + self.log_vars[i]
|
||||
total_loss = total_loss + weighted_loss
|
||||
loss_dict[f'{loss_name}_loss'] = loss_value.detach().cpu().item() if isinstance(loss_value, torch.Tensor) else float(loss_value)
|
||||
loss_dict[f'{loss_name}_weight'] = precision.detach().cpu().item() if isinstance(precision, torch.Tensor) else float(precision)
|
||||
else:
|
||||
# Fixed weighting
|
||||
for loss_name, loss_value in losses.items():
|
||||
if loss_name in self.loss_weights:
|
||||
weight = self.loss_weights[loss_name]
|
||||
weighted_loss = weight * loss_value
|
||||
total_loss = total_loss + weighted_loss
|
||||
loss_dict[f'{loss_name}_loss'] = loss_value.detach().cpu().item() if isinstance(loss_value, torch.Tensor) else float(loss_value)
|
||||
loss_dict[f'{loss_name}_weight'] = weight
|
||||
|
||||
loss_dict['total_loss'] = total_loss.detach().cpu().item() if isinstance(total_loss, torch.Tensor) else float(total_loss)
|
||||
return total_loss, loss_dict
|
||||
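# Usage sketch (illustrative): combine per-head losses with fixed weights, assuming
# dense_loss / sparse_loss / colbert_loss are scalar loss tensors from other criteria.
#   mt_loss = MultiTaskLoss(loss_weights={'dense': 1.0, 'sparse': 0.3, 'colbert': 0.5})
#   total, parts = mt_loss({'dense': dense_loss, 'sparse': sparse_loss, 'colbert': colbert_loss})
#   total.backward()   # `parts` additionally records each weighted component for logging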
|
||||
|
||||
class SparseLoss(nn.Module):
|
||||
"""
|
||||
SPLADE+ sparse loss with log1p(ReLU(.)) and FLOPS regularization.
|
||||
For BGE-M3's token-position based sparse representations.
|
||||
Args:
|
||||
flops_loss_weight: Weight for FLOPS regularization
|
||||
sparse_reg_weight: Weight for sparsity regularization
|
||||
temperature: Temperature for similarity computation
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
flops_loss_weight: float = 1e-3,
|
||||
sparse_reg_weight: float = 1e-4,
|
||||
temperature: float = 0.02
|
||||
):
|
||||
super().__init__()
|
||||
self.flops_loss_weight = flops_loss_weight
|
||||
self.sparse_reg_weight = sparse_reg_weight
|
||||
self.temperature = temperature
|
||||
self.cross_entropy = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query_sparse: torch.Tensor,
|
||||
doc_sparse: torch.Tensor,
|
||||
labels: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
"""
|
||||
Args:
|
||||
query_sparse: [batch, query_seq_len] - Query sparse embeddings (token-position based)
|
||||
doc_sparse: [batch, doc_seq_len] - Document sparse embeddings (token-position based)
|
||||
labels: [batch] - Labels for contrastive learning
|
||||
Returns:
|
||||
Tuple of (loss, loss_dict)
|
||||
"""
|
||||
batch_size = query_sparse.size(0)
|
||||
device = query_sparse.device
|
||||
|
||||
# SPLADE+ uses log1p(ReLU(.))
|
||||
query_splade = torch.log1p(F.relu(query_sparse))
|
||||
doc_splade = torch.log1p(F.relu(doc_sparse))
|
||||
|
||||
# For BGE-M3's token-position sparse embeddings, compute similarity using aggregated scores
|
||||
# Since we have different sequence lengths, we need to aggregate to comparable representations
|
||||
|
||||
# Sum pooling to get document-level sparse scores
|
||||
query_scores = query_splade.sum(dim=1, keepdim=True) # [batch, 1]
|
||||
doc_scores = doc_splade.sum(dim=1, keepdim=True) # [batch, 1]
|
||||
|
||||
# Compute similarity matrix for contrastive learning
|
||||
# Use broadcasting to create [batch, batch] similarity matrix
|
||||
similarity_matrix = query_scores * doc_scores.t() # [batch, batch]
|
||||
|
||||
# NOTE: L2-normalizing the [batch, 1] pooled scores would collapse every similarity to 1.0
# (a constant, gradient-free loss), so the broadcasted product above is kept as the similarity matrix.
|
||||
|
||||
# Scale by temperature for contrastive loss
|
||||
logits = similarity_matrix / self.temperature
|
||||
|
||||
# Labels for in-batch contrastive learning (diagonal elements are positive pairs)
|
||||
contrastive_labels = torch.arange(batch_size, device=device, dtype=torch.long)
|
||||
|
||||
# Contrastive loss
|
||||
main_loss = self.cross_entropy(logits, contrastive_labels)
|
||||
|
||||
# FLOPS regularization (encourage sparsity)
|
||||
flops = (query_splade.sum(dim=-1) + doc_splade.sum(dim=-1)).mean()
|
||||
|
||||
# L1 sparsity regularization
|
||||
sparsity_loss = (query_splade.sum() + doc_splade.sum()) / (query_splade.numel() + doc_splade.numel())
|
||||
|
||||
# Combine losses
|
||||
total_loss = main_loss + self.flops_loss_weight * flops + self.sparse_reg_weight * sparsity_loss
|
||||
|
||||
# Ensure all losses are tensors
|
||||
if not isinstance(total_loss, torch.Tensor):
|
||||
total_loss = torch.tensor(total_loss, device=device)
|
||||
if not isinstance(main_loss, torch.Tensor):
|
||||
main_loss = torch.tensor(main_loss, device=device)
|
||||
if not isinstance(flops, torch.Tensor):
|
||||
flops = torch.tensor(flops, device=device)
|
||||
if not isinstance(sparsity_loss, torch.Tensor):
|
||||
sparsity_loss = torch.tensor(sparsity_loss, device=device)
|
||||
|
||||
loss_dict = {
|
||||
'main_loss': main_loss.detach().cpu().item(),
|
||||
'flops': flops.detach().cpu().item(),
|
||||
'sparsity_loss': sparsity_loss.detach().cpu().item(),
|
||||
'total_loss': total_loss.detach().cpu().item()
|
||||
}
|
||||
|
||||
return total_loss, loss_dict
|
||||
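# Usage sketch (illustrative): token-level sparse weights for a batch of 4 query/doc pairs.
#   sparse_loss = SparseLoss(flops_loss_weight=1e-3, sparse_reg_weight=1e-4)
#   q_w = torch.rand(4, 32)       # per-token weights for queries   [batch, query_seq_len]
#   d_w = torch.rand(4, 128)      # per-token weights for documents [batch, doc_seq_len]
#   loss, parts = sparse_loss(q_w, d_w, labels=torch.arange(4))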
|
||||
|
||||
class ColBERTLoss(nn.Module):
|
||||
"""
|
||||
ColBERT late interaction (maxsim) loss.
|
||||
Args:
|
||||
temperature: Softmax temperature
|
||||
"""
|
||||
def __init__(self, temperature: float = 0.02):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
def forward(
|
||||
self,
|
||||
query_embeds: torch.Tensor,
|
||||
doc_embeds: torch.Tensor,
|
||||
labels: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
query_embeds: [batch, seq_len, dim]
|
||||
doc_embeds: [batch, seq_len, dim]
|
||||
labels: [batch]
|
||||
Returns:
|
||||
Loss value
|
||||
"""
|
||||
# Late interaction: maxsim over doc tokens for each query token
|
||||
# [batch, query_len, doc_len]
|
||||
sim_matrix = torch.einsum('bqd,bkd->bqk', query_embeds, doc_embeds) / self.temperature
|
||||
maxsim = sim_matrix.max(dim=2)[0] # [batch, query_len]
|
||||
# Aggregate over query tokens (mean)
|
||||
scores = maxsim.mean(dim=1) # [batch]
|
||||
# Binary cross-entropy loss
|
||||
loss = F.binary_cross_entropy_with_logits(scores, labels.float())
|
||||
return loss
|
||||
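# Usage sketch (illustrative): late-interaction loss over random token embeddings.
#   colbert_loss = ColBERTLoss(temperature=0.02)
#   q = torch.randn(4, 16, 128)   # [batch, query_len, dim]
#   d = torch.randn(4, 64, 128)   # [batch, doc_len, dim]
#   loss = colbert_loss(q, d, labels=torch.ones(4))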
readme_new.md (new file, 154 lines)
@@ -0,0 +1,154 @@
|
||||
# BGE Fine-tuning for Chinese Book Retrieval
|
||||
|
||||
Robust, production-ready tooling to fine-tune the BAAI **BGE-M3** embedding model and **BGE-Reranker** cross-encoder for high-performance retrieval of Chinese book content (with English query support).
|
||||
|
||||
---
|
||||
|
||||
## Table of Contents
|
||||
1. Overview
|
||||
2. Requirements & Compatibility
|
||||
3. Core Features
|
||||
4. Installation
|
||||
5. Quick Start
|
||||
6. Configuration System
|
||||
7. Training & Evaluation
|
||||
8. Distributed & Ascend NPU Support
|
||||
9. Troubleshooting & FAQ
|
||||
10. References
|
||||
11. Roadmap
|
||||
|
||||
---
|
||||
|
||||
## 1. Overview
|
||||
* Unified repository for dense, sparse (SPLADE-style) and multi-vector (ColBERT-style) retrieval.
|
||||
* Joint training (RocketQAv2) and ANCE-style hard-negative mining included.
|
||||
* Built-in evaluation, benchmarking and statistical significance testing.
|
||||
|
||||
---
|
||||
|
||||
## 2. Requirements & Compatibility
|
||||
* **Python ≥ 3.10**
|
||||
* **PyTorch ≥ 2.2.0** (nightly recommended for CUDA 12.8+)
|
||||
* **CUDA ≥ 12.8** *(RTX 5090 / A100 / H100)* or **Torch-NPU** *(Ascend 910 / 910B)*
|
||||
|
||||
See `requirements.txt` for the full dependency list.
|
||||
|
||||
---
|
||||
|
||||
## 3. Core Features
|
||||
* **InfoNCE (τ = 0.02) & Listwise CE** losses with self-distillation.
|
||||
* **ANCE-style Hard-Negative Mining** with FAISS index refresh.
|
||||
* **Atomic Checkpointing** & robust failure recovery.
|
||||
* **Automatic Mixed Precision** (fp16 / bf16) & gradient checkpointing.
|
||||
* **Comprehensive Metrics**: Recall@k, MRR@k, NDCG@k, MAP, Success-Rate.
|
||||
* **Cross-hardware Support**: CUDA (NCCL), Ascend NPU (HCCL), CPU (Gloo).
|
||||
|
||||
---
|
||||
|
||||
## 4. Installation
|
||||
```bash
|
||||
# 1. Install PyTorch nightly (CUDA 12.8)
|
||||
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||
|
||||
# 2. Install project dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 3. (Optional) Ascend NPU backend
|
||||
pip install torch_npu # then edit config.toml [ascend]
|
||||
```
|
||||
Models auto-download on first use, or fetch manually:
|
||||
```bash
|
||||
huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3
|
||||
huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. Quick Start
|
||||
### Fine-tune BGE-M3 (Embedding)
|
||||
```bash
|
||||
python scripts/train_m3.py \
|
||||
--train_data data/train.jsonl \
|
||||
--output_dir output/bge-m3 \
|
||||
--per_device_train_batch_size 8 \
|
||||
--learning_rate 1e-5
|
||||
```
|
||||
|
||||
### Fine-tune BGE-Reranker (Cross-encoder)
|
||||
```bash
|
||||
python scripts/train_reranker.py \
|
||||
--train_data data/train.jsonl \
|
||||
--output_dir output/bge-reranker \
|
||||
--learning_rate 6e-5 \
|
||||
--train_group_size 16
|
||||
```
|
||||
|
||||
### Joint Training (RocketQAv2-style)
|
||||
```bash
|
||||
python scripts/train_joint.py \
|
||||
--train_data data/train.jsonl \
|
||||
--output_dir output/joint-training \
|
||||
--bf16 --gradient_checkpointing
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. Configuration System
|
||||
Configuration is **TOML-driven** with clear precedence:
|
||||
1. Command-line → 2. Model (`[m3]` / `[reranker]`) → 3. Hardware (`[hardware]` / `[ascend]`) → 4. Training defaults (`[training]`) → 5. Dataclass field defaults.
|
||||
|
||||
Load & validate:
|
||||
```python
|
||||
from config.base_config import BaseConfig
|
||||
cfg = BaseConfig.from_toml("config.toml")
|
||||
```
|
||||
|
||||
Key sections: `[training]`, `[hardware]`, `[distributed]`, `[m3]`, `[reranker]`, `[augmentation]`, `[checkpoint]`, `[logging]`.
|
||||
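A minimal illustrative snippet (the keys shown are examples, not the full schema):
```toml
[training]
learning_rate = 1e-5
num_train_epochs = 3

[hardware]
device = "cuda"   # or "npu" / "cpu"

[m3]
use_sparse = true
```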
|
||||
---
|
||||
|
||||
## 7. Training & Evaluation
|
||||
* **Scripts**: `train_m3.py`, `train_reranker.py`, `train_joint.py`
|
||||
* **Evaluation**: `scripts/evaluate.py` supports retriever, reranker or full pipeline.
|
||||
* **Benchmarking**: `scripts/benchmark.py` for latency / throughput; `compare_*` for baseline comparisons.
|
||||
|
||||
Example evaluation:
|
||||
```bash
|
||||
python scripts/evaluate.py \
|
||||
--retriever_model output/bge-m3 \
|
||||
--eval_data data/test.jsonl \
|
||||
--k_values 1 5 10
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Distributed & Ascend NPU Support
|
||||
### Multi-GPU (NVIDIA)
|
||||
```bash
|
||||
torchrun --nproc_per_node=4 scripts/train_m3.py --train_data data/train.jsonl --output_dir output/distributed
|
||||
```
|
||||
Backend auto-selects **NCCL**. Use the `[distributed]` section in `config.toml` for fine-grained control.
|
||||
|
||||
### Huawei Ascend NPU
|
||||
```bash
|
||||
pip install torch_npu
|
||||
python scripts/train_m3.py --device npu --train_data data/train.jsonl --output_dir output/ascend
|
||||
```
|
||||
Backend auto-selects **HCCL**.
|
||||
|
||||
---
|
||||
|
||||
## 9. Troubleshooting & FAQ
|
||||
* **CUDA mismatch**: install the matching PyTorch + CUDA build (`nvidia-smi` to verify).
|
||||
* **Ascend setup**: ensure `torch_npu` and `[ascend]` config.
|
||||
* **NCCL errors**: set `NCCL_DEBUG=INFO` and disable IB if needed (`NCCL_IB_DISABLE=1`).
|
||||
* **Memory issues**: enable `--gradient_checkpointing` & mixed precision.
|
||||
|
||||
---
|
||||
|
||||
## 10. References
|
||||
1. BGE-M3 Embeddings – [arXiv:2402.03216](https://arxiv.org/abs/2402.03216)
|
||||
2. RocketQAv2 – ACL 2021
|
||||
3. ANCE – [arXiv:2007.00808](https://arxiv.org/abs/2007.00808)
|
||||
4. ColBERT – SIGIR 2020
|
||||
5. SPLADE – [arXiv:2109.10086](https://arxiv.org/abs/2109.10086)
|
||||
requirements.txt (new file, 49 lines)
@@ -0,0 +1,49 @@
|
||||
# Core dependencies
|
||||
# IMPORTANT: Install PyTorch with CUDA 12.8 support first using:
|
||||
# pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||
# Then install remaining dependencies with: pip install -r requirements.txt
|
||||
|
||||
# PyTorch - version will be satisfied by pre-installed nightly build
|
||||
# Minimum version for CUDA 12.8 support (will use nightly/2.3.0+ in practice)
|
||||
torch>=2.2.0
|
||||
torchvision>=0.17.0
|
||||
torchaudio>=2.2.0
|
||||
# HuggingFace Transformers
|
||||
transformers>=4.38.0
|
||||
# HuggingFace Datasets
|
||||
datasets>=2.14.0
|
||||
# HuggingFace Tokenizers
|
||||
tokenizers>=0.15.0
|
||||
numpy>=1.23.0
|
||||
scipy>=1.10.0
|
||||
scikit-learn>=1.3.0
|
||||
pandas>=1.5.0
|
||||
FlagEmbedding>=1.2.0
|
||||
# For CPU-only setups, replace faiss-gpu with faiss-cpu
|
||||
faiss-gpu>=1.7.4
|
||||
psutil>=5.9.0
|
||||
omegaconf>=2.3.0
|
||||
toml>=0.10.2
|
||||
tqdm>=4.64.0
|
||||
pyyaml>=6.0
|
||||
jsonlines>=3.1.0
|
||||
huggingface-hub>=0.19.0
|
||||
jieba>=0.42.1
|
||||
requests>=2.28.0
|
||||
# Optional/dev dependencies
|
||||
# For Huawei Ascend NPU support (install only if needed)
|
||||
# torch-npu>=1.11.0.post1; platform_system=="Linux" and platform_machine=="aarch64"
|
||||
# ascend-compiler; platform_system=="Linux" and platform_machine=="aarch64"
|
||||
# For experiment tracking (optional)
|
||||
wandb # optional, for Weights & Biases logging
|
||||
# For visualization (optional)
|
||||
tensorboard # optional, for TensorBoard logging
|
||||
# For testing and development
|
||||
pytest # optional, for running tests
|
||||
# For ModelScope support (optional)
|
||||
modelscope # optional, for ModelScope model hub
|
||||
sentence-transformers>=2.2.2
|
||||
accelerate>=0.25.0
|
||||
rank-bm25>=0.2.2
|
||||
# For logging and distributed
|
||||
loguru>=0.7.0
|
||||
scripts/benchmark.py (new file, 144 lines)
@@ -0,0 +1,144 @@
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import psutil
|
||||
import tracemalloc
|
||||
import numpy as np
|
||||
from utils.config_loader import get_config
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from data.dataset import BGEM3Dataset, BGERerankerDataset
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("benchmark")
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments for benchmarking."""
|
||||
parser = argparse.ArgumentParser(description="Benchmark BGE fine-tuned models.")
|
||||
parser.add_argument('--model_type', type=str, choices=['retriever', 'reranker'], required=True, help='Model type to benchmark')
|
||||
parser.add_argument('--model_path', type=str, required=True, help='Path to fine-tuned model')
|
||||
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
|
||||
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
|
||||
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
|
||||
parser.add_argument('--output', type=str, default='benchmark_results.txt', help='File to save benchmark results')
|
||||
return parser.parse_args()
|
||||
|
||||
def measure_memory():
|
||||
"""Measure the current memory usage of the process."""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / (1024 * 1024) # MB
|
||||
|
||||
def benchmark_retriever(model_path, data_path, batch_size, max_samples, device):
|
||||
"""Benchmark the BGE retriever model."""
|
||||
logger.info(f"Loading retriever model from {model_path}")
|
||||
model = BGEM3Model(model_name_or_path=model_path, device=device)
|
||||
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
|
||||
dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
|
||||
logger.info(f"Loaded {len(dataset)} samples for benchmarking")
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
n_samples = 0
|
||||
times = []
|
||||
tracemalloc.start()
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
if n_samples >= max_samples:
|
||||
break
|
||||
start = time.time()
|
||||
# Only dense embedding for speed
|
||||
_ = model(
|
||||
batch['query_input_ids'].to(device),
|
||||
batch['query_attention_mask'].to(device),
|
||||
return_dense=True, return_sparse=False, return_colbert=False
|
||||
)
|
||||
if device.startswith('cuda'): torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
times.append(end - start)
|
||||
n_samples += batch['query_input_ids'].shape[0]
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
throughput = n_samples / sum(times)
|
||||
latency = np.mean(times) / batch_size
|
||||
mem_mb = peak / (1024 * 1024)
|
||||
logger.info(f"Retriever throughput: {throughput:.2f} samples/sec, latency: {latency*1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
|
||||
return throughput, latency, mem_mb
|
||||
|
||||
def benchmark_reranker(model_path, data_path, batch_size, max_samples, device):
|
||||
"""Benchmark the BGE reranker model."""
|
||||
logger.info(f"Loading reranker model from {model_path}")
|
||||
model = BGERerankerModel(model_name_or_path=model_path, device=device)
|
||||
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
|
||||
dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
|
||||
logger.info(f"Loaded {len(dataset)} samples for benchmarking")
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
n_samples = 0
|
||||
times = []
|
||||
tracemalloc.start()
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
if n_samples >= max_samples:
|
||||
break
|
||||
start = time.time()
|
||||
_ = model.model(
|
||||
input_ids=batch['input_ids'].to(device),
|
||||
attention_mask=batch['attention_mask'].to(device)
|
||||
)
|
||||
if device.startswith('cuda'): torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
times.append(end - start)
|
||||
n_samples += batch['input_ids'].shape[0]
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
throughput = n_samples / sum(times)
|
||||
latency = np.mean(times) / batch_size
|
||||
mem_mb = peak / (1024 * 1024)
|
||||
logger.info(f"Reranker throughput: {throughput:.2f} samples/sec, latency: {latency*1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
|
||||
return throughput, latency, mem_mb
|
||||
|
||||
def main():
|
||||
"""Run benchmark with robust error handling and config-driven defaults."""
|
||||
parser = argparse.ArgumentParser(description="Benchmark BGE models.")
|
||||
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
|
||||
parser.add_argument('--model_type', type=str, choices=['retriever', 'reranker'], required=True, help='Model type to benchmark')
|
||||
parser.add_argument('--model_path', type=str, required=True, help='Path to fine-tuned model')
|
||||
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
|
||||
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
|
||||
parser.add_argument('--output', type=str, default='benchmark_results.txt', help='File to save benchmark results')
|
||||
args = parser.parse_args()
|
||||
# Load config and set defaults
|
||||
config = get_config()
|
||||
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
|
||||
try:
|
||||
if args.model_type == 'retriever':
|
||||
throughput, latency, mem_mb = benchmark_retriever(
|
||||
args.model_path, args.data_path, args.batch_size, args.max_samples, device)
|
||||
else:
|
||||
throughput, latency, mem_mb = benchmark_reranker(
|
||||
args.model_path, args.data_path, args.batch_size, args.max_samples, device)
|
||||
# Save results
|
||||
with open(args.output, 'w') as f:
|
||||
f.write(f"Model: {args.model_path}\n")
|
||||
f.write(f"Data: {args.data_path}\n")
|
||||
f.write(f"Batch size: {args.batch_size}\n")
|
||||
f.write(f"Throughput: {throughput:.2f} samples/sec\n")
|
||||
f.write(f"Latency: {latency*1000:.2f} ms/sample\n")
|
||||
f.write(f"Peak memory: {mem_mb:.2f} MB\n")
|
||||
logger.info(f"Benchmark results saved to {args.output}")
|
||||
logging.info(f"Benchmark completed on device: {device}")
|
||||
except Exception as e:
|
||||
logging.error(f"Benchmark failed: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
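# Example invocation (illustrative):
#   python scripts/benchmark.py --model_type retriever --model_path output/bge-m3 \
#       --data_path data/test.jsonl --batch_size 16 --output benchmark_results.txt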
scripts/compare_reranker.py (new file, 138 lines)
@@ -0,0 +1,138 @@
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import psutil
|
||||
import tracemalloc
|
||||
import numpy as np
|
||||
from utils.config_loader import get_config
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from data.dataset import BGERerankerDataset
|
||||
from transformers import AutoTokenizer
|
||||
from evaluation.metrics import compute_reranker_metrics
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("compare_reranker")
|
||||
|
||||
def parse_args():
|
||||
"""Parse command-line arguments for reranker comparison."""
|
||||
parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline reranker models.")
|
||||
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned reranker model')
|
||||
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline reranker model')
|
||||
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
|
||||
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
|
||||
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
|
||||
parser.add_argument('--output', type=str, default='compare_reranker_results.txt', help='File to save comparison results')
|
||||
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10], help='K values for metrics@k')
|
||||
return parser.parse_args()
|
||||
|
||||
def measure_memory():
|
||||
"""Measure the current memory usage of the process."""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / (1024 * 1024) # MB
|
||||
|
||||
def run_reranker(model_path, data_path, batch_size, max_samples, device):
|
||||
"""
|
||||
Run inference on a reranker model and measure throughput, latency, and memory.
|
||||
Args:
|
||||
model_path (str): Path to the reranker model.
|
||||
data_path (str): Path to the evaluation data.
|
||||
batch_size (int): Batch size for inference.
|
||||
max_samples (int): Maximum number of samples to process.
|
||||
device (str): Device to use (cuda, cpu, npu).
|
||||
Returns:
|
||||
tuple: (throughput, latency, mem_mb, all_scores, all_labels)
|
||||
"""
|
||||
model = BGERerankerModel(model_name_or_path=model_path, device=device)
|
||||
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
|
||||
dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
n_samples = 0
|
||||
times = []
|
||||
all_scores = []
|
||||
all_labels = []
|
||||
tracemalloc.start()
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
if n_samples >= max_samples:
|
||||
break
|
||||
start = time.time()
|
||||
outputs = model.model(
|
||||
input_ids=batch['input_ids'].to(device),
|
||||
attention_mask=batch['attention_mask'].to(device)
|
||||
)
|
||||
if device.startswith('cuda'): torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
times.append(end - start)
|
||||
n = batch['input_ids'].shape[0]
|
||||
n_samples += n
|
||||
# For metrics: collect scores and labels
|
||||
scores = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
|
||||
all_scores.append(scores.cpu().numpy())
|
||||
all_labels.append(batch['labels'].cpu().numpy())
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
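# Aggregate measurements: throughput in samples/sec, latency as mean batch time / batch_size (seconds per sample), peak traced memory converted to MB below.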
throughput = n_samples / sum(times)
|
||||
latency = np.mean(times) / batch_size
|
||||
mem_mb = peak / (1024 * 1024)
|
||||
# Prepare for metrics
|
||||
all_scores = np.concatenate(all_scores, axis=0)
|
||||
all_labels = np.concatenate(all_labels, axis=0)
|
||||
return throughput, latency, mem_mb, all_scores, all_labels
|
||||
|
||||
def main():
|
||||
"""Run reranker comparison with robust error handling and config-driven defaults."""
|
||||
parser = argparse.ArgumentParser(description="Compare reranker models.")
|
||||
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
|
||||
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned reranker model')
|
||||
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline reranker model')
|
||||
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
|
||||
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
|
||||
parser.add_argument('--output', type=str, default='compare_reranker_results.txt', help='File to save comparison results')
|
||||
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10], help='K values for metrics@k')
|
||||
args = parser.parse_args()
|
||||
# Load config and set defaults
|
||||
config = get_config()
|
||||
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
|
||||
try:
|
||||
# Fine-tuned model
|
||||
logger.info("Running fine-tuned reranker...")
|
||||
ft_throughput, ft_latency, ft_mem, ft_scores, ft_labels = run_reranker(
|
||||
args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
|
||||
ft_metrics = compute_reranker_metrics(ft_scores, ft_labels, k_values=args.k_values)
|
||||
# Baseline model
|
||||
logger.info("Running baseline reranker...")
|
||||
bl_throughput, bl_latency, bl_mem, bl_scores, bl_labels = run_reranker(
|
||||
args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
|
||||
bl_metrics = compute_reranker_metrics(bl_scores, bl_labels, k_values=args.k_values)
|
||||
# Output comparison
|
||||
with open(args.output, 'w') as f:
|
||||
f.write(f"Metric\tBaseline\tFine-tuned\tDelta\n")
|
||||
for k in args.k_values:
|
||||
for metric in [f'accuracy@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
|
||||
bl = bl_metrics.get(metric, 0)
|
||||
ft = ft_metrics.get(metric, 0)
|
||||
delta = ft - bl
|
||||
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
|
||||
for metric in ['map', 'mrr']:
|
||||
bl = bl_metrics.get(metric, 0)
|
||||
ft = ft_metrics.get(metric, 0)
|
||||
delta = ft - bl
|
||||
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
|
||||
f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput-bl_throughput:+.2f}\n")
|
||||
f.write(f"Latency (ms/sample)\t{bl_latency*1000:.2f}\t{ft_latency*1000:.2f}\t{(ft_latency-bl_latency)*1000:+.2f}\n")
|
||||
f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem-bl_mem:+.2f}\n")
|
||||
logger.info(f"Comparison results saved to {args.output}")
|
||||
print(f"\nComparison complete. See {args.output} for details.")
|
||||
except Exception as e:
|
||||
logging.error(f"Reranker comparison failed: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
144
scripts/compare_retriever.py
Normal file
@ -0,0 +1,144 @@
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import psutil
|
||||
import tracemalloc
|
||||
import numpy as np
|
||||
from utils.config_loader import get_config
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from data.dataset import BGEM3Dataset
|
||||
from transformers import AutoTokenizer
|
||||
from evaluation.metrics import compute_retrieval_metrics
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("compare_retriever")
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments for retriever comparison."""
|
||||
parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline retriever models.")
|
||||
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned retriever model')
|
||||
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline retriever model')
|
||||
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
|
||||
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
|
||||
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
|
||||
parser.add_argument('--output', type=str, default='compare_retriever_results.txt', help='File to save comparison results')
|
||||
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10, 20, 100], help='K values for metrics@k')
|
||||
return parser.parse_args()
|
||||
|
||||
def measure_memory():
|
||||
"""Measure the current process memory usage in MB."""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / (1024 * 1024) # MB
|
||||
|
||||
def run_retriever(model_path, data_path, batch_size, max_samples, device):
|
||||
"""
|
||||
Run inference on a retriever model and measure throughput, latency, and memory.
|
||||
Args:
|
||||
model_path (str): Path to the retriever model.
|
||||
data_path (str): Path to the evaluation data.
|
||||
batch_size (int): Batch size for inference.
|
||||
max_samples (int): Maximum number of samples to process.
|
||||
device (str): Device to use (cuda, cpu, npu).
|
||||
Returns:
|
||||
tuple: throughput (samples/sec), latency (seconds/sample), peak memory (MB),
|
||||
query_embeds (np.ndarray), passage_embeds (np.ndarray), labels (np.ndarray).
|
||||
"""
|
||||
model = BGEM3Model(model_name_or_path=model_path, device=device)
|
||||
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
|
||||
dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
n_samples = 0
|
||||
times = []
|
||||
query_embeds = []
|
||||
passage_embeds = []
|
||||
labels = []
|
||||
tracemalloc.start()
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
if n_samples >= max_samples:
|
||||
break
|
||||
start = time.time()
|
||||
out = model(
|
||||
batch['query_input_ids'].to(device),
|
||||
batch['query_attention_mask'].to(device),
|
||||
return_dense=True, return_sparse=False, return_colbert=False
|
||||
)
|
||||
if device.startswith('cuda'): torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
times.append(end - start)
|
||||
n = batch['query_input_ids'].shape[0]
|
||||
n_samples += n
|
||||
# For metrics: collect embeddings and labels
|
||||
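# Note: this assumes the model's forward pass returns a dict with 'query_embeds' and 'passage_embeds'; adjust the keys here if the BGEM3Model output differs.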
query_embeds.append(out['query_embeds'].cpu().numpy())
|
||||
passage_embeds.append(out['passage_embeds'].cpu().numpy())
|
||||
labels.append(batch['labels'].cpu().numpy())
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
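# Same accounting as in compare_reranker.py: samples/sec, seconds per sample, and peak traced Python memory in MB.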
throughput = n_samples / sum(times)
|
||||
latency = np.mean(times) / batch_size
|
||||
mem_mb = peak / (1024 * 1024)
|
||||
# Prepare for metrics
|
||||
query_embeds = np.concatenate(query_embeds, axis=0)
|
||||
passage_embeds = np.concatenate(passage_embeds, axis=0)
|
||||
labels = np.concatenate(labels, axis=0)
|
||||
return throughput, latency, mem_mb, query_embeds, passage_embeds, labels
|
||||
|
||||
def main():
|
||||
"""Run retriever comparison with robust error handling and config-driven defaults."""
|
||||
parser = argparse.ArgumentParser(description="Compare retriever models.")
|
||||
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
|
||||
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned retriever model')
|
||||
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline retriever model')
|
||||
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
|
||||
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
|
||||
parser.add_argument('--output', type=str, default='compare_retriever_results.txt', help='File to save comparison results')
|
||||
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10, 20, 100], help='K values for metrics@k')
|
||||
args = parser.parse_args()
|
||||
# Load config and set defaults
|
||||
config = get_config()
|
||||
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
|
||||
try:
|
||||
# Fine-tuned model
|
||||
logger.info("Running fine-tuned retriever...")
|
||||
ft_throughput, ft_latency, ft_mem, ft_qe, ft_pe, ft_labels = run_retriever(
|
||||
args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
|
||||
ft_metrics = compute_retrieval_metrics(ft_qe, ft_pe, ft_labels, k_values=args.k_values)
|
||||
# Baseline model
|
||||
logger.info("Running baseline retriever...")
|
||||
bl_throughput, bl_latency, bl_mem, bl_qe, bl_pe, bl_labels = run_retriever(
|
||||
args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
|
||||
bl_metrics = compute_retrieval_metrics(bl_qe, bl_pe, bl_labels, k_values=args.k_values)
|
||||
# Output comparison
|
||||
with open(args.output, 'w') as f:
|
||||
f.write(f"Metric\tBaseline\tFine-tuned\tDelta\n")
|
||||
for k in args.k_values:
|
||||
for metric in [f'recall@{k}', f'precision@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
|
||||
bl = bl_metrics.get(metric, 0)
|
||||
ft = ft_metrics.get(metric, 0)
|
||||
delta = ft - bl
|
||||
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
|
||||
for metric in ['map', 'mrr']:
|
||||
bl = bl_metrics.get(metric, 0)
|
||||
ft = ft_metrics.get(metric, 0)
|
||||
delta = ft - bl
|
||||
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
|
||||
f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput-bl_throughput:+.2f}\n")
|
||||
f.write(f"Latency (ms/sample)\t{bl_latency*1000:.2f}\t{ft_latency*1000:.2f}\t{(ft_latency-bl_latency)*1000:+.2f}\n")
|
||||
f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem-bl_mem:+.2f}\n")
|
||||
logger.info(f"Comparison results saved to {args.output}")
|
||||
print(f"\nComparison complete. See {args.output} for details.")
|
||||
except Exception as e:
|
||||
logging.error(f"Retriever comparison failed: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
990
scripts/comprehensive_validation.py
Normal file
@ -0,0 +1,990 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive Validation Script for BGE Fine-tuned Models
|
||||
|
||||
This script provides a complete validation framework to determine if fine-tuned models
|
||||
actually perform better than baseline models across multiple datasets and metrics.
|
||||
|
||||
Features:
|
||||
- Tests both M3 retriever and reranker models
|
||||
- Compares against baseline models on multiple datasets
|
||||
- Provides statistical significance testing
|
||||
- Tests pipeline performance (retrieval + reranking)
|
||||
- Generates comprehensive reports
|
||||
- Includes performance degradation detection
|
||||
- Supports cross-dataset validation
|
||||
|
||||
Usage:
|
||||
python scripts/comprehensive_validation.py \\
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model \\
|
||||
--reranker_finetuned ./output/bge-reranker/final_model \\
|
||||
--output_dir ./validation_results
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Tuple, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
from scipy import stats
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from data.dataset import BGEM3Dataset, BGERerankerDataset
|
||||
from evaluation.metrics import (
|
||||
compute_retrieval_metrics, compute_reranker_metrics,
|
||||
compute_retrieval_metrics_comprehensive
|
||||
)
|
||||
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
|
||||
from data.preprocessing import DataPreprocessor
|
||||
from utils.logging import setup_logging, get_logger
|
||||
from utils.config_loader import get_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationConfig:
|
||||
"""Configuration for comprehensive validation"""
|
||||
# Model paths
|
||||
retriever_baseline: str = "BAAI/bge-m3"
|
||||
reranker_baseline: str = "BAAI/bge-reranker-base"
|
||||
retriever_finetuned: Optional[str] = None
|
||||
reranker_finetuned: Optional[str] = None
|
||||
|
||||
# Test datasets
|
||||
test_datasets: Optional[List[str]] = None
|
||||
|
||||
# Evaluation settings
|
||||
batch_size: int = 32
|
||||
max_length: int = 512
|
||||
k_values: Optional[List[int]] = None
|
||||
significance_level: float = 0.05
|
||||
min_samples_for_significance: int = 30
|
||||
|
||||
# Pipeline settings
|
||||
retrieval_top_k: int = 100
|
||||
rerank_top_k: int = 10
|
||||
|
||||
# Output settings
|
||||
output_dir: str = "./validation_results"
|
||||
save_embeddings: bool = False
|
||||
create_report: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.test_datasets is None:
|
||||
self.test_datasets = [
|
||||
"data/datasets/examples/embedding_data.jsonl",
|
||||
"data/datasets/examples/reranker_data.jsonl",
|
||||
"data/datasets/Reranker_AFQMC/dev.json",
|
||||
"data/datasets/三国演义/reranker_train.jsonl"
|
||||
]
|
||||
if self.k_values is None:
|
||||
self.k_values = [1, 3, 5, 10, 20]
|
||||
|
||||
|
||||
class ComprehensiveValidator:
|
||||
"""Comprehensive validation framework for BGE models"""
|
||||
|
||||
def __init__(self, config: ValidationConfig):
|
||||
self.config = config
|
||||
self.logger = get_logger(__name__)
|
||||
self.results = {}
|
||||
self.device = self._get_device()
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(config.output_dir, exist_ok=True)
|
||||
|
||||
# Setup logging
|
||||
setup_logging(
|
||||
output_dir=config.output_dir,
|
||||
log_level=logging.INFO,
|
||||
log_to_console=True,
|
||||
log_to_file=True
|
||||
)
|
||||
|
||||
def _get_device(self) -> torch.device:
|
||||
"""Get optimal device for evaluation"""
|
||||
try:
|
||||
config = get_config()
|
||||
device_type = config.get('hardware', {}).get('device', 'cuda')
|
||||
|
||||
if device_type == 'npu':
|
||||
try:
|
||||
import torch_npu
|
||||
device = torch.device("npu:0")
|
||||
self.logger.info("Using NPU device")
|
||||
return device
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
self.logger.info("Using CUDA device")
|
||||
return device
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error detecting optimal device: {e}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
self.logger.info("Using CPU device")
|
||||
return device
|
||||
|
||||
def validate_all(self) -> Dict[str, Any]:
|
||||
"""Run comprehensive validation on all models and datasets"""
|
||||
self.logger.info("🚀 Starting Comprehensive BGE Model Validation")
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
# 1. Validate retriever models
|
||||
if self.config.retriever_finetuned:
|
||||
self.logger.info("📊 Validating Retriever Models...")
|
||||
retriever_results = self._validate_retrievers()
|
||||
self.results['retriever'] = retriever_results
|
||||
|
||||
# 2. Validate reranker models
|
||||
if self.config.reranker_finetuned:
|
||||
self.logger.info("📊 Validating Reranker Models...")
|
||||
reranker_results = self._validate_rerankers()
|
||||
self.results['reranker'] = reranker_results
|
||||
|
||||
# 3. Validate pipeline (if both models available)
|
||||
if self.config.retriever_finetuned and self.config.reranker_finetuned:
|
||||
self.logger.info("📊 Validating Pipeline Performance...")
|
||||
pipeline_results = self._validate_pipeline()
|
||||
self.results['pipeline'] = pipeline_results
|
||||
|
||||
# 4. Generate comprehensive report
|
||||
if self.config.create_report:
|
||||
self._generate_report()
|
||||
|
||||
# 5. Statistical analysis
|
||||
self._perform_statistical_analysis()
|
||||
|
||||
end_time = datetime.now()
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
self.logger.info("✅ Comprehensive validation completed!")
|
||||
self.logger.info(f"⏱️ Total validation time: {duration:.2f} seconds")
|
||||
|
||||
return self.results
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"❌ Validation failed: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _validate_retrievers(self) -> Dict[str, Any]:
|
||||
"""Validate retriever models on all suitable datasets"""
|
||||
results = {
|
||||
'baseline_metrics': {},
|
||||
'finetuned_metrics': {},
|
||||
'comparisons': {},
|
||||
'datasets_tested': []
|
||||
}
|
||||
|
||||
for dataset_path in self.config.test_datasets:
|
||||
if not os.path.exists(dataset_path):
|
||||
self.logger.warning(f"Dataset not found: {dataset_path}")
|
||||
continue
|
||||
|
||||
# Skip reranker-only datasets for retriever testing
|
||||
if 'reranker' in dataset_path.lower() and 'AFQMC' in dataset_path:
|
||||
continue
|
||||
|
||||
dataset_name = Path(dataset_path).stem
|
||||
self.logger.info(f"Testing retriever on dataset: {dataset_name}")
|
||||
|
||||
try:
|
||||
# Test baseline model
|
||||
baseline_metrics = self._test_retriever_on_dataset(
|
||||
model_path=self.config.retriever_baseline,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=True
|
||||
)
|
||||
|
||||
# Test fine-tuned model
|
||||
finetuned_metrics = self._test_retriever_on_dataset(
|
||||
model_path=self.config.retriever_finetuned,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=False
|
||||
)
|
||||
|
||||
# Compare results
|
||||
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "retriever")
|
||||
|
||||
results['baseline_metrics'][dataset_name] = baseline_metrics
|
||||
results['finetuned_metrics'][dataset_name] = finetuned_metrics
|
||||
results['comparisons'][dataset_name] = comparison
|
||||
results['datasets_tested'].append(dataset_name)
|
||||
|
||||
self.logger.info(f"✅ Retriever validation complete for {dataset_name}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"❌ Failed to test retriever on {dataset_name}: {e}")
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _validate_rerankers(self) -> Dict[str, Any]:
|
||||
"""Validate reranker models on all suitable datasets"""
|
||||
results = {
|
||||
'baseline_metrics': {},
|
||||
'finetuned_metrics': {},
|
||||
'comparisons': {},
|
||||
'datasets_tested': []
|
||||
}
|
||||
|
||||
for dataset_path in self.config.test_datasets:
|
||||
if not os.path.exists(dataset_path):
|
||||
self.logger.warning(f"Dataset not found: {dataset_path}")
|
||||
continue
|
||||
|
||||
dataset_name = Path(dataset_path).stem
|
||||
self.logger.info(f"Testing reranker on dataset: {dataset_name}")
|
||||
|
||||
try:
|
||||
# Test baseline model
|
||||
baseline_metrics = self._test_reranker_on_dataset(
|
||||
model_path=self.config.reranker_baseline,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=True
|
||||
)
|
||||
|
||||
# Test fine-tuned model
|
||||
finetuned_metrics = self._test_reranker_on_dataset(
|
||||
model_path=self.config.reranker_finetuned,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=False
|
||||
)
|
||||
|
||||
# Compare results
|
||||
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "reranker")
|
||||
|
||||
results['baseline_metrics'][dataset_name] = baseline_metrics
|
||||
results['finetuned_metrics'][dataset_name] = finetuned_metrics
|
||||
results['comparisons'][dataset_name] = comparison
|
||||
results['datasets_tested'].append(dataset_name)
|
||||
|
||||
self.logger.info(f"✅ Reranker validation complete for {dataset_name}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"❌ Failed to test reranker on {dataset_name}: {e}")
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _validate_pipeline(self) -> Dict[str, Any]:
|
||||
"""Validate end-to-end retrieval + reranking pipeline"""
|
||||
results = {
|
||||
'baseline_pipeline': {},
|
||||
'finetuned_pipeline': {},
|
||||
'comparisons': {},
|
||||
'datasets_tested': []
|
||||
}
|
||||
|
||||
for dataset_path in self.config.test_datasets:
|
||||
if not os.path.exists(dataset_path):
|
||||
continue
|
||||
|
||||
# Only test on datasets suitable for pipeline evaluation
|
||||
if 'embedding_data' not in dataset_path:
|
||||
continue
|
||||
|
||||
dataset_name = Path(dataset_path).stem
|
||||
self.logger.info(f"Testing pipeline on dataset: {dataset_name}")
|
||||
|
||||
try:
|
||||
# Test baseline pipeline
|
||||
baseline_metrics = self._test_pipeline_on_dataset(
|
||||
retriever_path=self.config.retriever_baseline,
|
||||
reranker_path=self.config.reranker_baseline,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=True
|
||||
)
|
||||
|
||||
# Test fine-tuned pipeline
|
||||
finetuned_metrics = self._test_pipeline_on_dataset(
|
||||
retriever_path=self.config.retriever_finetuned,
|
||||
reranker_path=self.config.reranker_finetuned,
|
||||
dataset_path=dataset_path,
|
||||
is_baseline=False
|
||||
)
|
||||
|
||||
# Compare results
|
||||
comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "pipeline")
|
||||
|
||||
results['baseline_pipeline'][dataset_name] = baseline_metrics
|
||||
results['finetuned_pipeline'][dataset_name] = finetuned_metrics
|
||||
results['comparisons'][dataset_name] = comparison
|
||||
results['datasets_tested'].append(dataset_name)
|
||||
|
||||
self.logger.info(f"✅ Pipeline validation complete for {dataset_name}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"❌ Failed to test pipeline on {dataset_name}: {e}")
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _test_retriever_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
|
||||
"""Test a retriever model on a specific dataset"""
|
||||
try:
|
||||
# Load model and tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGEM3Model.from_pretrained(model_path)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
# Create dataset
|
||||
dataset = BGEM3Dataset(
|
||||
data_path=dataset_path,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=self.config.max_length,
|
||||
passage_max_length=self.config.max_length,
|
||||
is_train=False
|
||||
)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=self.config.batch_size,
|
||||
shuffle=False
|
||||
)
|
||||
|
||||
# Create evaluator
|
||||
evaluator = RetrievalEvaluator(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
device=self.device,
|
||||
k_values=self.config.k_values
|
||||
)
|
||||
|
||||
# Run evaluation
|
||||
metrics = evaluator.evaluate(
|
||||
dataloader=dataloader,
|
||||
desc=f"{'Baseline' if is_baseline else 'Fine-tuned'} Retriever"
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error testing retriever: {e}")
|
||||
return {}
|
||||
|
||||
def _test_reranker_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
|
||||
"""Test a reranker model on a specific dataset"""
|
||||
try:
|
||||
# Load model and tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGERerankerModel.from_pretrained(model_path)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
# Create dataset
|
||||
dataset = BGERerankerDataset(
|
||||
data_path=dataset_path,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=self.config.max_length,
|
||||
passage_max_length=self.config.max_length,
|
||||
is_train=False
|
||||
)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=self.config.batch_size,
|
||||
shuffle=False
|
||||
)
|
||||
|
||||
# Collect predictions and labels
|
||||
all_scores = []
|
||||
all_labels = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
# Move to device
|
||||
input_ids = batch['input_ids'].to(self.device)
|
||||
attention_mask = batch['attention_mask'].to(self.device)
|
||||
labels = batch['labels'].cpu().numpy()
|
||||
|
||||
# Get predictions
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
scores = outputs.logits.cpu().numpy()
|
||||
|
||||
all_scores.append(scores)
|
||||
all_labels.append(labels)
|
||||
|
||||
# Concatenate results
|
||||
all_scores = np.concatenate(all_scores, axis=0)
|
||||
all_labels = np.concatenate(all_labels, axis=0)
|
||||
|
||||
# Compute metrics
|
||||
metrics = compute_reranker_metrics(
|
||||
scores=all_scores,
|
||||
labels=all_labels,
|
||||
k_values=self.config.k_values
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error testing reranker: {e}")
|
||||
return {}
|
||||
|
||||
def _test_pipeline_on_dataset(self, retriever_path: str, reranker_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
|
||||
"""Test end-to-end pipeline on a dataset"""
|
||||
try:
|
||||
# Load models
|
||||
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
|
||||
retriever = BGEM3Model.from_pretrained(retriever_path)
|
||||
retriever.to(self.device)
|
||||
retriever.eval()
|
||||
|
||||
reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
|
||||
reranker = BGERerankerModel.from_pretrained(reranker_path)
|
||||
reranker.to(self.device)
|
||||
reranker.eval()
|
||||
|
||||
# Load and preprocess data
|
||||
data = []
|
||||
with open(dataset_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
item = json.loads(line.strip())
|
||||
if 'query' in item and 'pos' in item and 'neg' in item:
|
||||
data.append(item)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
# Test pipeline performance
|
||||
total_queries = 0
|
||||
pipeline_metrics = {f'recall@{k}': [] for k in self.config.k_values}
|
||||
pipeline_metrics.update({f'mrr@{k}': [] for k in self.config.k_values})
|
||||
|
||||
for item in data[:100]: # Limit to 100 queries for speed
|
||||
query = item['query']
|
||||
pos_docs = item['pos'][:2] # Take first 2 positive docs
|
||||
neg_docs = item['neg'][:8] # Take first 8 negative docs
|
||||
|
||||
# Create candidate pool
|
||||
candidates = pos_docs + neg_docs
|
||||
labels = [1] * len(pos_docs) + [0] * len(neg_docs)
|
||||
|
||||
if len(candidates) < 3: # Need minimum candidates
|
||||
continue
|
||||
|
||||
try:
|
||||
# Step 1: Retrieval (encode query and candidates)
|
||||
query_inputs = retriever_tokenizer(
|
||||
query,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.config.max_length
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
query_outputs = retriever(**query_inputs, return_dense=True)
|
||||
query_emb = query_outputs['dense']
|
||||
|
||||
# Encode candidates
|
||||
candidate_embs = []
|
||||
for candidate in candidates:
|
||||
cand_inputs = retriever_tokenizer(
|
||||
candidate,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.config.max_length
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
cand_outputs = retriever(**cand_inputs, return_dense=True)
|
||||
candidate_embs.append(cand_outputs['dense'])
|
||||
|
||||
candidate_embs = torch.cat(candidate_embs, dim=0)
|
||||
|
||||
# Compute similarities
|
||||
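# Dense retrieval score: cosine similarity between the query embedding and each candidate embedding.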
similarities = torch.cosine_similarity(query_emb, candidate_embs, dim=1)
|
||||
|
||||
# Get top-k for retrieval
|
||||
top_k_retrieval = min(self.config.retrieval_top_k, len(candidates))
|
||||
retrieval_indices = torch.argsort(similarities, descending=True)[:top_k_retrieval]
|
||||
|
||||
# Step 2: Reranking
|
||||
rerank_candidates = [candidates[i] for i in retrieval_indices]
|
||||
rerank_labels = [labels[i] for i in retrieval_indices]
|
||||
|
||||
# Score pairs with reranker
|
||||
rerank_scores = []
|
||||
for candidate in rerank_candidates:
|
||||
pair_text = f"{query} [SEP] {candidate}"
|
||||
pair_inputs = reranker_tokenizer(
|
||||
pair_text,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.config.max_length
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pair_outputs = reranker(**pair_inputs)
|
||||
score = pair_outputs.logits.item()
|
||||
rerank_scores.append(score)
|
||||
|
||||
# Final ranking
|
||||
final_indices = np.argsort(rerank_scores)[::-1]
|
||||
final_labels = [rerank_labels[i] for i in final_indices]
|
||||
|
||||
# Compute metrics
|
||||
for k in self.config.k_values:
|
||||
if k <= len(final_labels):
|
||||
# Recall@k
|
||||
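# Recall@k = relevant candidates in the top-k reranked results / total relevant candidates in the pool.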
relevant_in_topk = sum(final_labels[:k])
|
||||
total_relevant = sum(rerank_labels)
|
||||
recall = relevant_in_topk / total_relevant if total_relevant > 0 else 0
|
||||
pipeline_metrics[f'recall@{k}'].append(recall)
|
||||
|
||||
# MRR@k
|
||||
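# Reciprocal rank (1-based) of the first relevant document within the top-k; 0.0 if none appears.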
for i, label in enumerate(final_labels[:k]):
|
||||
if label == 1:
|
||||
pipeline_metrics[f'mrr@{k}'].append(1.0 / (i + 1))
|
||||
break
|
||||
else:
|
||||
pipeline_metrics[f'mrr@{k}'].append(0.0)
|
||||
|
||||
total_queries += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error processing query: {e}")
|
||||
continue
|
||||
|
||||
# Average metrics
|
||||
final_metrics = {}
|
||||
for metric, values in pipeline_metrics.items():
|
||||
if values:
|
||||
final_metrics[metric] = np.mean(values)
|
||||
else:
|
||||
final_metrics[metric] = 0.0
|
||||
|
||||
final_metrics['total_queries'] = total_queries
|
||||
|
||||
return final_metrics
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error testing pipeline: {e}")
|
||||
return {}
|
||||
|
||||
def _compare_metrics(self, baseline: Dict[str, float], finetuned: Dict[str, float], model_type: str) -> Dict[str, Any]:
|
||||
"""Compare baseline vs fine-tuned metrics with statistical analysis"""
|
||||
comparison = {
|
||||
'improvements': {},
|
||||
'degradations': {},
|
||||
'statistical_significance': {},
|
||||
'summary': {}
|
||||
}
|
||||
|
||||
# Compare common metrics
|
||||
common_metrics = set(baseline.keys()) & set(finetuned.keys())
|
||||
|
||||
improvements = 0
|
||||
degradations = 0
|
||||
significant_improvements = 0
|
||||
|
||||
for metric in common_metrics:
|
||||
if isinstance(baseline[metric], (int, float)) and isinstance(finetuned[metric], (int, float)):
|
||||
baseline_val = baseline[metric]
|
||||
finetuned_val = finetuned[metric]
|
||||
|
||||
# Calculate improvement
|
||||
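# Relative change in percent: (finetuned - baseline) / baseline * 100; treated as +inf when the baseline is zero and the fine-tuned value is positive.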
if baseline_val != 0:
|
||||
improvement_pct = ((finetuned_val - baseline_val) / baseline_val) * 100
|
||||
else:
|
||||
improvement_pct = float('inf') if finetuned_val > 0 else 0
|
||||
|
||||
# Classify improvement/degradation
|
||||
if finetuned_val > baseline_val:
|
||||
comparison['improvements'][metric] = {
|
||||
'baseline': baseline_val,
|
||||
'finetuned': finetuned_val,
|
||||
'improvement_pct': improvement_pct,
|
||||
'absolute_improvement': finetuned_val - baseline_val
|
||||
}
|
||||
improvements += 1
|
||||
elif finetuned_val < baseline_val:
|
||||
comparison['degradations'][metric] = {
|
||||
'baseline': baseline_val,
|
||||
'finetuned': finetuned_val,
|
||||
'degradation_pct': -improvement_pct,
|
||||
'absolute_degradation': baseline_val - finetuned_val
|
||||
}
|
||||
degradations += 1
|
||||
|
||||
# Summary statistics
|
||||
comparison['summary'] = {
|
||||
'total_metrics': len(common_metrics),
|
||||
'improvements': improvements,
|
||||
'degradations': degradations,
|
||||
'unchanged': len(common_metrics) - improvements - degradations,
|
||||
'improvement_ratio': improvements / len(common_metrics) if common_metrics else 0,
|
||||
'net_improvement': improvements - degradations
|
||||
}
|
||||
|
||||
return comparison
|
||||
|
||||
def _perform_statistical_analysis(self):
|
||||
"""Perform statistical analysis on validation results"""
|
||||
self.logger.info("📈 Performing Statistical Analysis...")
|
||||
|
||||
analysis = {
|
||||
'overall_summary': {},
|
||||
'significance_tests': {},
|
||||
'confidence_intervals': {},
|
||||
'effect_sizes': {}
|
||||
}
|
||||
|
||||
# Analyze each model type
|
||||
for model_type in ['retriever', 'reranker', 'pipeline']:
|
||||
if model_type not in self.results:
|
||||
continue
|
||||
|
||||
model_results = self.results[model_type]
|
||||
|
||||
# Collect all improvements across datasets
|
||||
all_improvements = []
|
||||
significant_improvements = 0
|
||||
|
||||
for dataset, comparison in model_results.get('comparisons', {}).items():
|
||||
for metric, improvement_data in comparison.get('improvements', {}).items():
|
||||
improvement_pct = improvement_data.get('improvement_pct', 0)
|
||||
if improvement_pct > 0:
|
||||
all_improvements.append(improvement_pct)
|
||||
|
||||
# Simple significance test (can be enhanced)
|
||||
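# Heuristic threshold only; a paired t-test or Wilcoxon signed-rank test over per-query metrics (scipy.stats is already imported) would give a proper significance estimate.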
if improvement_pct > 5.0: # 5% improvement threshold
|
||||
significant_improvements += 1
|
||||
|
||||
if all_improvements:
|
||||
analysis[model_type] = {
|
||||
'mean_improvement_pct': np.mean(all_improvements),
|
||||
'median_improvement_pct': np.median(all_improvements),
|
||||
'std_improvement_pct': np.std(all_improvements),
|
||||
'min_improvement_pct': np.min(all_improvements),
|
||||
'max_improvement_pct': np.max(all_improvements),
|
||||
'significant_improvements': significant_improvements,
|
||||
'total_improvements': len(all_improvements)
|
||||
}
|
||||
|
||||
self.results['statistical_analysis'] = analysis
|
||||
|
||||
def _generate_report(self):
|
||||
"""Generate comprehensive validation report"""
|
||||
self.logger.info("📝 Generating Comprehensive Report...")
|
||||
|
||||
report_path = os.path.join(self.config.output_dir, "validation_report.html")
|
||||
|
||||
html_content = self._create_html_report()
|
||||
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
f.write(html_content)
|
||||
|
||||
# Also save JSON results
|
||||
json_path = os.path.join(self.config.output_dir, "validation_results.json")
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.results, f, indent=2, default=str)
|
||||
|
||||
self.logger.info(f"📊 Report saved to: {report_path}")
|
||||
self.logger.info(f"📊 JSON results saved to: {json_path}")
|
||||
|
||||
def _create_html_report(self) -> str:
|
||||
"""Create HTML report with comprehensive validation results"""
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>BGE Model Validation Report</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 20px; }}
|
||||
.header {{ background: #f0f8ff; padding: 20px; border-radius: 10px; }}
|
||||
.section {{ margin: 20px 0; }}
|
||||
.metrics-table {{ border-collapse: collapse; width: 100%; }}
|
||||
.metrics-table th, .metrics-table td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
|
||||
.metrics-table th {{ background-color: #f2f2f2; }}
|
||||
.improvement {{ color: green; font-weight: bold; }}
|
||||
.degradation {{ color: red; font-weight: bold; }}
|
||||
.summary-box {{ background: #f9f9f9; padding: 15px; border-left: 4px solid #4CAF50; margin: 10px 0; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>🚀 BGE Model Comprehensive Validation Report</h1>
|
||||
<p><strong>Generated:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||
<p><strong>Configuration:</strong></p>
|
||||
<ul>
|
||||
<li>Retriever Baseline: {self.config.retriever_baseline}</li>
|
||||
<li>Retriever Fine-tuned: {self.config.retriever_finetuned or 'Not tested'}</li>
|
||||
<li>Reranker Baseline: {self.config.reranker_baseline}</li>
|
||||
<li>Reranker Fine-tuned: {self.config.reranker_finetuned or 'Not tested'}</li>
|
||||
</ul>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Add executive summary
|
||||
html += self._create_executive_summary()
|
||||
|
||||
# Add detailed results for each model type
|
||||
for model_type in ['retriever', 'reranker', 'pipeline']:
|
||||
if model_type in self.results:
|
||||
html += self._create_model_section(model_type)
|
||||
|
||||
html += """
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return html
|
||||
|
||||
def _create_executive_summary(self) -> str:
|
||||
"""Create executive summary section"""
|
||||
html = """
|
||||
<div class="section">
|
||||
<h2>📊 Executive Summary</h2>
|
||||
"""
|
||||
|
||||
# Overall verdict
|
||||
total_improvements = 0
|
||||
total_degradations = 0
|
||||
|
||||
for model_type in ['retriever', 'reranker', 'pipeline']:
|
||||
if model_type in self.results and 'comparisons' in self.results[model_type]:
|
||||
for dataset, comparison in self.results[model_type]['comparisons'].items():
|
||||
total_improvements += len(comparison.get('improvements', {}))
|
||||
total_degradations += len(comparison.get('degradations', {}))
|
||||
|
||||
if total_improvements > total_degradations:
|
||||
verdict = "✅ <strong>FINE-TUNING SUCCESSFUL</strong> - Models show overall improvement"
|
||||
verdict_color = "green"
|
||||
elif total_improvements == total_degradations:
|
||||
verdict = "⚠️ <strong>MIXED RESULTS</strong> - Equal improvements and degradations"
|
||||
verdict_color = "orange"
|
||||
else:
|
||||
verdict = "❌ <strong>FINE-TUNING ISSUES</strong> - More degradations than improvements"
|
||||
verdict_color = "red"
|
||||
|
||||
html += f"""
|
||||
<div class="summary-box" style="border-left-color: {verdict_color};">
|
||||
<h3>{verdict}</h3>
|
||||
<ul>
|
||||
<li>Total Metric Improvements: {total_improvements}</li>
|
||||
<li>Total Metric Degradations: {total_degradations}</li>
|
||||
<li>Net Improvement: {total_improvements - total_degradations}</li>
|
||||
</ul>
|
||||
</div>
|
||||
"""
|
||||
|
||||
html += """
|
||||
</div>
|
||||
"""
|
||||
return html
|
||||
|
||||
def _create_model_section(self, model_type: str) -> str:
|
||||
"""Create detailed section for a model type"""
|
||||
html = f"""
|
||||
<div class="section">
|
||||
<h2>📈 {model_type.title()} Model Results</h2>
|
||||
"""
|
||||
|
||||
model_results = self.results[model_type]
|
||||
|
||||
for dataset in model_results.get('datasets_tested', []):
|
||||
if dataset in model_results['comparisons']:
|
||||
comparison = model_results['comparisons'][dataset]
|
||||
|
||||
html += f"""
|
||||
<h3>Dataset: {dataset}</h3>
|
||||
<table class="metrics-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Metric</th>
|
||||
<th>Baseline</th>
|
||||
<th>Fine-tuned</th>
|
||||
<th>Change</th>
|
||||
<th>Status</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
|
||||
# Add improvements
|
||||
for metric, data in comparison.get('improvements', {}).items():
|
||||
html += f"""
|
||||
<tr>
|
||||
<td>{metric}</td>
|
||||
<td>{data['baseline']:.4f}</td>
|
||||
<td>{data['finetuned']:.4f}</td>
|
||||
<td class="improvement">+{data['improvement_pct']:.2f}%</td>
|
||||
<td class="improvement">✅ Improved</td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
# Add degradations
|
||||
for metric, data in comparison.get('degradations', {}).items():
|
||||
html += f"""
|
||||
<tr>
|
||||
<td>{metric}</td>
|
||||
<td>{data['baseline']:.4f}</td>
|
||||
<td>{data['finetuned']:.4f}</td>
|
||||
<td class="degradation">-{data['degradation_pct']:.2f}%</td>
|
||||
<td class="degradation">❌ Degraded</td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
html += """
|
||||
</tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
html += """
|
||||
</div>
|
||||
"""
|
||||
return html
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Comprehensive validation of BGE fine-tuned models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Validate both retriever and reranker
|
||||
python scripts/comprehensive_validation.py \\
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model \\
|
||||
--reranker_finetuned ./output/bge-reranker/final_model
|
||||
|
||||
# Validate only retriever
|
||||
python scripts/comprehensive_validation.py \\
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model
|
||||
|
||||
# Custom datasets and output
|
||||
python scripts/comprehensive_validation.py \\
|
||||
--retriever_finetuned ./output/bge-m3-enhanced/final_model \\
|
||||
--test_datasets data/my_test1.jsonl data/my_test2.jsonl \\
|
||||
--output_dir ./my_validation_results
|
||||
"""
|
||||
)
|
||||
|
||||
# Model paths
|
||||
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
|
||||
help="Path to baseline retriever model")
|
||||
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Path to baseline reranker model")
|
||||
parser.add_argument("--retriever_finetuned", type=str, default=None,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_finetuned", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
|
||||
# Test data
|
||||
parser.add_argument("--test_datasets", type=str, nargs="+", default=None,
|
||||
help="Paths to test datasets (JSONL format)")
|
||||
|
||||
# Evaluation settings
|
||||
parser.add_argument("--batch_size", type=int, default=32,
|
||||
help="Batch size for evaluation")
|
||||
parser.add_argument("--max_length", type=int, default=512,
|
||||
help="Maximum sequence length")
|
||||
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 3, 5, 10, 20],
|
||||
help="K values for evaluation metrics")
|
||||
|
||||
# Output settings
|
||||
parser.add_argument("--output_dir", type=str, default="./validation_results",
|
||||
help="Directory to save validation results")
|
||||
parser.add_argument("--save_embeddings", action="store_true",
|
||||
help="Save embeddings for further analysis")
|
||||
parser.add_argument("--no_report", action="store_true",
|
||||
help="Skip HTML report generation")
|
||||
|
||||
# Pipeline settings
|
||||
parser.add_argument("--retrieval_top_k", type=int, default=100,
|
||||
help="Top-k for retrieval stage")
|
||||
parser.add_argument("--rerank_top_k", type=int, default=10,
|
||||
help="Top-k for reranking stage")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main validation function"""
|
||||
args = parse_args()
|
||||
|
||||
# Validate arguments
|
||||
if not args.retriever_finetuned and not args.reranker_finetuned:
|
||||
print("❌ Error: At least one fine-tuned model must be specified")
|
||||
print(" Use --retriever_finetuned or --reranker_finetuned")
|
||||
return 1
|
||||
|
||||
# Create configuration
|
||||
config = ValidationConfig(
|
||||
retriever_baseline=args.retriever_baseline,
|
||||
reranker_baseline=args.reranker_baseline,
|
||||
retriever_finetuned=args.retriever_finetuned,
|
||||
reranker_finetuned=args.reranker_finetuned,
|
||||
test_datasets=args.test_datasets,
|
||||
batch_size=args.batch_size,
|
||||
max_length=args.max_length,
|
||||
k_values=args.k_values,
|
||||
output_dir=args.output_dir,
|
||||
save_embeddings=args.save_embeddings,
|
||||
create_report=not args.no_report,
|
||||
retrieval_top_k=args.retrieval_top_k,
|
||||
rerank_top_k=args.rerank_top_k
|
||||
)
|
||||
|
||||
# Run validation
|
||||
try:
|
||||
validator = ComprehensiveValidator(config)
|
||||
results = validator.validate_all()
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("🎉 COMPREHENSIVE VALIDATION COMPLETED!")
|
||||
print("=" * 80)
|
||||
print(f"📊 Results saved to: {config.output_dir}")
|
||||
print("\n📈 Quick Summary:")
|
||||
|
||||
# Print quick summary
|
||||
for model_type in ['retriever', 'reranker', 'pipeline']:
|
||||
if model_type in results:
|
||||
model_results = results[model_type]
|
||||
total_datasets = len(model_results.get('datasets_tested', []))
|
||||
total_improvements = sum(len(comp.get('improvements', {}))
|
||||
for comp in model_results.get('comparisons', {}).values())
|
||||
total_degradations = sum(len(comp.get('degradations', {}))
|
||||
for comp in model_results.get('comparisons', {}).values())
|
||||
|
||||
status = "✅" if total_improvements > total_degradations else "⚠️" if total_improvements == total_degradations else "❌"
|
||||
print(f" {status} {model_type.title()}: {total_improvements} improvements, {total_degradations} degradations across {total_datasets} datasets")
|
||||
|
||||
return 0
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n❌ Validation interrupted by user")
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"\n❌ Validation failed: {e}")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
291
scripts/convert_m3_dataset.py
Normal file
@ -0,0 +1,291 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
BGE-M3 Dataset Format Converter
|
||||
|
||||
Converts the nested dataset format where training data is stored inside a "data" field
|
||||
to the flat format expected by the BGE-M3 training pipeline.
|
||||
|
||||
Input format:
|
||||
{
|
||||
"query": "...",
|
||||
"data": "{\"pos\": [...], \"neg\": [...], \"pos_scores\": [...], ...}"
|
||||
}
|
||||
|
||||
Output format:
|
||||
{
|
||||
"query": "...",
|
||||
"pos": [...],
|
||||
"neg": [...],
|
||||
"pos_scores": [...],
|
||||
"neg_scores": [...],
|
||||
"prompt": "...",
|
||||
"type": "..."
|
||||
}
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_json_from_data_field(data_field: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract JSON data from the "data" field which contains a JSON string
|
||||
wrapped in markdown code blocks.
|
||||
|
||||
Args:
|
||||
data_field: String containing JSON data, possibly wrapped in markdown
|
||||
|
||||
Returns:
|
||||
Parsed JSON dictionary or None if parsing fails
|
||||
"""
|
||||
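# Example: a data field like '```json {"pos": [...], "neg": [...]} ```' is unwrapped and parsed into a plain dict.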
try:
|
||||
# Remove markdown code block markers if present
|
||||
cleaned_data = data_field.strip()
|
||||
|
||||
# Remove ```json and ``` markers
|
||||
if cleaned_data.startswith('```json'):
|
||||
cleaned_data = cleaned_data[7:] # Remove ```json
|
||||
elif cleaned_data.startswith('```'):
|
||||
cleaned_data = cleaned_data[3:] # Remove ```
|
||||
|
||||
if cleaned_data.endswith('```'):
|
||||
cleaned_data = cleaned_data[:-3] # Remove trailing ```
|
||||
|
||||
# Clean up any remaining whitespace
|
||||
cleaned_data = cleaned_data.strip()
|
||||
|
||||
# Parse the JSON
|
||||
return json.loads(cleaned_data)
|
||||
|
||||
except (json.JSONDecodeError, AttributeError) as e:
|
||||
logger.warning(f"Failed to parse JSON from data field: {e}")
|
||||
logger.warning(f"Data content (first 200 chars): {data_field[:200]}")
|
||||
return None
|
||||
|
||||
|
||||
def convert_dataset_format(input_file: str, output_file: str) -> None:
|
||||
"""
|
||||
Convert dataset from nested format to flat format.
|
||||
|
||||
Args:
|
||||
input_file: Path to input JSONL file with nested format
|
||||
output_file: Path to output JSONL file with flat format
|
||||
"""
|
||||
logger.info(f"Converting dataset from {input_file} to {output_file}")
|
||||
|
||||
# Count total lines for progress bar
|
||||
total_lines = 0
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for _ in f:
|
||||
total_lines += 1
|
||||
|
||||
converted_count = 0
|
||||
error_count = 0
|
||||
|
||||
with open(input_file, 'r', encoding='utf-8') as infile, \
|
||||
open(output_file, 'w', encoding='utf-8') as outfile:
|
||||
|
||||
for line_num, line in enumerate(tqdm(infile, total=total_lines, desc="Converting"), 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Parse the input JSON line
|
||||
item = json.loads(line)
|
||||
|
||||
# Validate required fields
|
||||
if 'query' not in item:
|
||||
logger.warning(f"Line {line_num}: Missing 'query' field")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
if 'data' not in item:
|
||||
logger.warning(f"Line {line_num}: Missing 'data' field")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
# Extract data from the nested field
|
||||
data_content = extract_json_from_data_field(item['data'])
|
||||
if data_content is None:
|
||||
logger.warning(f"Line {line_num}: Failed to parse data field")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
# Create the flattened format
|
||||
flattened_item = {
|
||||
'query': item['query']
|
||||
}
|
||||
|
||||
# Add all fields from the data content
|
||||
expected_fields = ['pos', 'neg', 'pos_scores', 'neg_scores', 'prompt', 'type']
|
||||
for field in expected_fields:
|
||||
if field in data_content:
|
||||
flattened_item[field] = data_content[field]
|
||||
else:
|
||||
# Provide reasonable defaults
|
||||
if field == 'pos':
|
||||
flattened_item[field] = []
|
||||
elif field == 'neg':
|
||||
flattened_item[field] = []
|
||||
elif field == 'pos_scores':
|
||||
flattened_item[field] = [1.0] * len(data_content.get('pos', []))
|
||||
elif field == 'neg_scores':
|
||||
flattened_item[field] = [0.0] * len(data_content.get('neg', []))
|
||||
elif field == 'prompt':
|
||||
flattened_item[field] = '为此查询生成表示:'
|
||||
elif field == 'type':
|
||||
flattened_item[field] = 'normal'
|
||||
|
||||
# Validate the converted item
|
||||
if not flattened_item['pos'] and not flattened_item['neg']:
|
||||
logger.warning(f"Line {line_num}: No positive or negative passages found")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
# Ensure scores arrays match passage arrays
|
||||
if len(flattened_item['pos_scores']) != len(flattened_item['pos']):
|
||||
flattened_item['pos_scores'] = [1.0] * len(flattened_item['pos'])
|
||||
|
||||
if len(flattened_item['neg_scores']) != len(flattened_item['neg']):
|
||||
flattened_item['neg_scores'] = [0.0] * len(flattened_item['neg'])
|
||||
|
||||
# Write the flattened item
|
||||
outfile.write(json.dumps(flattened_item, ensure_ascii=False) + '\n')
|
||||
converted_count += 1
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Line {line_num}: Invalid JSON - {e}")
|
||||
error_count += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"Line {line_num}: Unexpected error - {e}")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
logger.info(f"Conversion complete!")
|
||||
logger.info(f"Successfully converted: {converted_count} items")
|
||||
logger.info(f"Errors encountered: {error_count} items")
|
||||
logger.info(f"Output saved to: {output_file}")
|
||||
|
||||
|
||||
def validate_converted_file(file_path: str, sample_size: int = 5) -> None:
|
||||
"""
|
||||
Validate the converted file by showing a few sample entries.
|
||||
|
||||
Args:
|
||||
file_path: Path to the converted JSONL file
|
||||
sample_size: Number of sample entries to display
|
||||
"""
|
||||
logger.info(f"Validating converted file: {file_path}")
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
total_lines = len(lines)
|
||||
logger.info(f"Total converted entries: {total_lines}")
|
||||
|
||||
if total_lines == 0:
|
||||
logger.warning("No entries found in converted file!")
|
||||
return
|
||||
|
||||
# Show sample entries
|
||||
sample_indices = list(range(min(sample_size, total_lines)))
|
||||
|
||||
for i, idx in enumerate(sample_indices):
|
||||
try:
|
||||
item = json.loads(lines[idx])
|
||||
logger.info(f"\nSample {i+1} (line {idx+1}):")
|
||||
logger.info(f" Query: {item['query'][:100]}...")
|
||||
logger.info(f" Positives: {len(item.get('pos', []))}")
|
||||
logger.info(f" Negatives: {len(item.get('neg', []))}")
|
||||
logger.info(f" Type: {item.get('type', 'N/A')}")
|
||||
logger.info(f" Prompt: {item.get('prompt', 'N/A')}")
|
||||
|
||||
# Show first positive and negative if available
|
||||
if item.get('pos'):
|
||||
logger.info(f" First positive: {item['pos'][0][:150]}...")
|
||||
if item.get('neg'):
|
||||
logger.info(f" First negative: {item['neg'][0][:150]}...")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Invalid JSON at line {idx+1}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating file: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert BGE-M3 dataset from nested to flat format",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python convert_m3_dataset.py data/datasets/三国演义/m3_train.jsonl data/datasets/三国演义/m3_train_converted.jsonl
|
||||
python convert_m3_dataset.py input.jsonl output.jsonl --validate
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'input_file',
|
||||
help='Path to input JSONL file with nested format'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'output_file',
|
||||
help='Path to output JSONL file with flat format'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--validate',
|
||||
action='store_true',
|
||||
help='Validate the converted file and show samples'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--sample-size',
|
||||
type=int,
|
||||
default=5,
|
||||
help='Number of sample entries to show during validation (default: 5)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Check if input file exists
|
||||
if not os.path.exists(args.input_file):
|
||||
logger.error(f"Input file not found: {args.input_file}")
|
||||
return 1
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
output_dir = os.path.dirname(args.output_file)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
logger.info(f"Created output directory: {output_dir}")
|
||||
|
||||
try:
|
||||
# Convert the dataset
|
||||
convert_dataset_format(args.input_file, args.output_file)
|
||||
|
||||
# Validate if requested
|
||||
if args.validate:
|
||||
validate_converted_file(args.output_file, args.sample_size)
|
||||
|
||||
logger.info("Dataset conversion completed successfully!")
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Dataset conversion failed: {e}")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
exit(main())
|
||||
584
scripts/evaluate.py
Normal file
@ -0,0 +1,584 @@
|
||||
"""
|
||||
Evaluation script for fine-tuned BGE models
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
import numpy as np
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
|
||||
from data.preprocessing import DataPreprocessor
|
||||
from utils.logging import setup_logging, get_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments for evaluation."""
|
||||
parser = argparse.ArgumentParser(description="Evaluate fine-tuned BGE models")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--retriever_model", type=str, required=True,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
|
||||
help="Base retriever model for comparison")
|
||||
parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Base reranker model for comparison")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument("--eval_data", type=str, required=True,
|
||||
help="Path to evaluation data file")
|
||||
parser.add_argument("--corpus_data", type=str, default=None,
|
||||
help="Path to corpus data for retrieval evaluation")
|
||||
parser.add_argument("--data_format", type=str, default="auto",
|
||||
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
|
||||
help="Format of input data")
|
||||
|
||||
# Evaluation arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./evaluation_results",
|
||||
help="Directory to save evaluation results")
|
||||
parser.add_argument("--batch_size", type=int, default=32,
|
||||
help="Batch size for evaluation")
|
||||
parser.add_argument("--max_length", type=int, default=512,
|
||||
help="Maximum sequence length")
|
||||
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
|
||||
help="K values for evaluation metrics")
|
||||
|
||||
# Evaluation modes
|
||||
parser.add_argument("--eval_retriever", action="store_true",
|
||||
help="Evaluate retriever model")
|
||||
parser.add_argument("--eval_reranker", action="store_true",
|
||||
help="Evaluate reranker model")
|
||||
parser.add_argument("--eval_pipeline", action="store_true",
|
||||
help="Evaluate full retrieval pipeline")
|
||||
parser.add_argument("--compare_baseline", action="store_true",
|
||||
help="Compare with baseline models")
|
||||
parser.add_argument("--interactive", action="store_true",
|
||||
help="Run interactive evaluation")
|
||||
|
||||
# Pipeline settings
|
||||
parser.add_argument("--retrieval_top_k", type=int, default=100,
|
||||
help="Top-k for initial retrieval")
|
||||
parser.add_argument("--rerank_top_k", type=int, default=10,
|
||||
help="Top-k for reranking")
|
||||
|
||||
# Other
|
||||
parser.add_argument("--device", type=str, default="auto",
|
||||
help="Device to use (auto, cpu, cuda)")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validation
|
||||
if not any([args.eval_retriever, args.eval_reranker, args.eval_pipeline, args.interactive]):
|
||||
args.eval_pipeline = True # Default to pipeline evaluation
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def load_evaluation_data(data_path: str, data_format: str):
|
||||
"""Load evaluation data from a file."""
|
||||
preprocessor = DataPreprocessor()
|
||||
|
||||
if data_format == "auto":
|
||||
data_format = preprocessor._detect_format(data_path)
|
||||
|
||||
if data_format == "jsonl":
|
||||
data = preprocessor._load_jsonl(data_path)
|
||||
elif data_format == "json":
|
||||
data = preprocessor._load_json(data_path)
|
||||
elif data_format in ["csv", "tsv"]:
|
||||
data = preprocessor._load_csv(data_path, delimiter="," if data_format == "csv" else "\t")
|
||||
elif data_format == "dureader":
|
||||
data = preprocessor._load_dureader(data_path)
|
||||
elif data_format == "msmarco":
|
||||
data = preprocessor._load_msmarco(data_path)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {data_format}")
|
||||
|
||||
return data
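# Usage sketch (hypothetical path): the evaluation helpers below assume each
# loaded record is a dict with at least a 'query' string and, for retrieval
# and reranking metrics, 'pos'/'neg' lists of passage strings.
#
#   eval_data = load_evaluation_data("data/datasets/eval.jsonl", "auto")
#   assert eval_data and "query" in eval_data[0]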
|
||||
|
||||
|
||||
def evaluate_retriever(args, retriever_model, tokenizer, eval_data):
|
||||
"""Evaluate the retriever model."""
|
||||
logger.info("Evaluating retriever model...")
|
||||
|
||||
evaluator = RetrievalEvaluator(
|
||||
model=retriever_model,
|
||||
tokenizer=tokenizer,
|
||||
device=args.device,
|
||||
k_values=args.k_values
|
||||
)
|
||||
|
||||
# Prepare data for evaluation
|
||||
queries = [item['query'] for item in eval_data]
|
||||
|
||||
# If we have passage data, create passage corpus
|
||||
all_passages = set()
|
||||
for item in eval_data:
|
||||
if 'pos' in item:
|
||||
all_passages.update(item['pos'])
|
||||
if 'neg' in item:
|
||||
all_passages.update(item['neg'])
|
||||
|
||||
passages = list(all_passages)
|
||||
|
||||
# Create relevance mapping
|
||||
relevant_docs = {}
|
||||
for i, item in enumerate(eval_data):
|
||||
query = item['query']
|
||||
if 'pos' in item:
|
||||
relevant_indices = []
|
||||
for pos in item['pos']:
|
||||
if pos in passages:
|
||||
relevant_indices.append(passages.index(pos))
|
||||
if relevant_indices:
|
||||
relevant_docs[query] = relevant_indices
|
||||
|
||||
# Evaluate
|
||||
if passages and relevant_docs:
|
||||
metrics = evaluator.evaluate_retrieval_pipeline(
|
||||
queries=queries,
|
||||
corpus=passages,
|
||||
relevant_docs=relevant_docs,
|
||||
batch_size=args.batch_size
|
||||
)
|
||||
else:
|
||||
logger.warning("No valid corpus or relevance data found. Running basic evaluation.")
|
||||
# Create dummy evaluation
|
||||
metrics = {"note": "Limited evaluation due to data format"}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def evaluate_reranker(args, reranker_model, tokenizer, eval_data):
|
||||
"""Evaluate the reranker model."""
|
||||
logger.info("Evaluating reranker model...")
|
||||
|
||||
evaluator = RetrievalEvaluator(
|
||||
model=reranker_model,
|
||||
tokenizer=tokenizer,
|
||||
device=args.device,
|
||||
k_values=args.k_values
|
||||
)
|
||||
|
||||
# Prepare data for reranker evaluation
|
||||
total_queries = 0
|
||||
total_correct = 0
|
||||
total_mrr = 0
|
||||
|
||||
for item in eval_data:
|
||||
if 'pos' not in item or 'neg' not in item:
|
||||
continue
|
||||
|
||||
query = item['query']
|
||||
positives = item['pos']
|
||||
negatives = item['neg']
|
||||
|
||||
if not positives or not negatives:
|
||||
continue
|
||||
|
||||
# Create candidates (1 positive + negatives)
|
||||
candidates = positives[:1] + negatives
|
||||
|
||||
# Score all candidates
|
||||
pairs = [(query, candidate) for candidate in candidates]
|
||||
scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
|
||||
|
||||
# Check if positive is ranked first
|
||||
best_idx = np.argmax(scores)
|
||||
if best_idx == 0: # First is positive
|
||||
total_correct += 1
|
||||
total_mrr += 1.0
|
||||
else:
|
||||
# Find rank of positive (it's always at index 0)
|
||||
sorted_indices = np.argsort(scores)[::-1]
|
||||
positive_rank = np.where(sorted_indices == 0)[0][0] + 1
|
||||
total_mrr += 1.0 / positive_rank
|
||||
|
||||
total_queries += 1
|
||||
|
||||
if total_queries > 0:
|
||||
metrics = {
|
||||
'accuracy@1': total_correct / total_queries,
|
||||
'mrr': total_mrr / total_queries,
|
||||
'total_queries': total_queries
|
||||
}
|
||||
else:
|
||||
metrics = {"note": "No valid query-passage pairs found for reranker evaluation"}
|
||||
|
||||
return metrics
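# Minimal sketch of the ranking protocol used above, with hypothetical scores:
# the positive passage always sits at candidate index 0, accuracy@1 checks
# whether it receives the highest score, and MRR is 1 / (rank of the positive).
def _example_rank_metrics():
    scores = np.array([0.31, 0.62, 0.18, 0.05])  # made-up reranker scores, positive first
    sorted_indices = np.argsort(scores)[::-1]    # best score first
    positive_rank = int(np.where(sorted_indices == 0)[0][0]) + 1
    accuracy_at_1 = 1.0 if positive_rank == 1 else 0.0
    mrr = 1.0 / positive_rank
    return accuracy_at_1, mrr  # here (0.0, 0.5): the positive is ranked second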
|
||||
|
||||
|
||||
def evaluate_pipeline(args, retriever_model, reranker_model, tokenizer, eval_data):
|
||||
"""Evaluate the full retrieval + reranking pipeline."""
|
||||
logger.info("Evaluating full pipeline...")
|
||||
|
||||
# Prepare corpus
|
||||
all_passages = set()
|
||||
for item in eval_data:
|
||||
if 'pos' in item:
|
||||
all_passages.update(item['pos'])
|
||||
if 'neg' in item:
|
||||
all_passages.update(item['neg'])
|
||||
|
||||
passages = list(all_passages)
|
||||
|
||||
if not passages:
|
||||
return {"note": "No passages found for pipeline evaluation"}
|
||||
|
||||
# Pre-encode corpus with retriever
|
||||
logger.info("Encoding corpus...")
|
||||
corpus_embeddings = retriever_model.encode(
|
||||
passages,
|
||||
batch_size=args.batch_size,
|
||||
return_numpy=True,
|
||||
device=args.device
|
||||
)
|
||||
|
||||
total_queries = 0
|
||||
pipeline_metrics = {
|
||||
'retrieval_recall@10': [],
|
||||
'rerank_accuracy@1': [],
|
||||
'rerank_mrr': []
|
||||
}
|
||||
|
||||
for item in eval_data:
|
||||
if 'pos' not in item or not item['pos']:
|
||||
continue
|
||||
|
||||
query = item['query']
|
||||
relevant_passages = item['pos']
|
||||
|
||||
# Step 1: Retrieval
|
||||
query_embedding = retriever_model.encode(
|
||||
[query],
|
||||
return_numpy=True,
|
||||
device=args.device
|
||||
)
|
||||
|
||||
# Compute similarities
|
||||
similarities = np.dot(query_embedding, corpus_embeddings.T).squeeze()
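# Note (assumption): treating the raw dot product as a similarity presumes
# encode() returns L2-normalized embeddings, which BGE models typically do.
# If a checkpoint does not normalize, normalize first, for example:
#   corpus_norm = corpus_embeddings / np.linalg.norm(corpus_embeddings, axis=1, keepdims=True)
#   query_norm = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)
#   similarities = np.dot(query_norm, corpus_norm.T).squeeze()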
|
||||
|
||||
# Get top-k for retrieval
|
||||
top_k_indices = np.argsort(similarities)[::-1][:args.retrieval_top_k]
|
||||
retrieved_passages = [passages[i] for i in top_k_indices]
|
||||
|
||||
# Check retrieval recall
|
||||
retrieved_relevant = sum(1 for p in retrieved_passages[:10] if p in relevant_passages)
|
||||
retrieval_recall = retrieved_relevant / min(len(relevant_passages), 10)
|
||||
pipeline_metrics['retrieval_recall@10'].append(retrieval_recall)
|
||||
|
||||
# Step 2: Reranking (if reranker available)
|
||||
if reranker_model is not None:
|
||||
# Rerank top candidates
|
||||
rerank_candidates = retrieved_passages[:args.rerank_top_k]
|
||||
pairs = [(query, candidate) for candidate in rerank_candidates]
|
||||
|
||||
if pairs:
|
||||
rerank_scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
|
||||
reranked_indices = np.argsort(rerank_scores)[::-1]
|
||||
reranked_passages = [rerank_candidates[i] for i in reranked_indices]
|
||||
|
||||
# Check reranking performance
|
||||
if reranked_passages[0] in relevant_passages:
|
||||
pipeline_metrics['rerank_accuracy@1'].append(1.0)
|
||||
pipeline_metrics['rerank_mrr'].append(1.0)
|
||||
else:
|
||||
pipeline_metrics['rerank_accuracy@1'].append(0.0)
|
||||
# Find MRR
|
||||
for rank, passage in enumerate(reranked_passages, 1):
|
||||
if passage in relevant_passages:
|
||||
pipeline_metrics['rerank_mrr'].append(1.0 / rank)
|
||||
break
|
||||
else:
|
||||
pipeline_metrics['rerank_mrr'].append(0.0)
|
||||
|
||||
total_queries += 1
|
||||
|
||||
# Average metrics
|
||||
final_metrics = {}
|
||||
for metric, values in pipeline_metrics.items():
|
||||
if values:
|
||||
final_metrics[metric] = np.mean(values)
|
||||
|
||||
final_metrics['total_queries'] = total_queries
|
||||
|
||||
return final_metrics
|
||||
|
||||
|
||||
def run_interactive_evaluation(args, retriever_model, reranker_model, tokenizer, eval_data):
|
||||
"""Run interactive evaluation."""
|
||||
logger.info("Starting interactive evaluation...")
|
||||
|
||||
# Prepare corpus
|
||||
all_passages = set()
|
||||
for item in eval_data:
|
||||
if 'pos' in item:
|
||||
all_passages.update(item['pos'])
|
||||
if 'neg' in item:
|
||||
all_passages.update(item['neg'])
|
||||
|
||||
passages = list(all_passages)
|
||||
|
||||
if not passages:
|
||||
print("No passages found for interactive evaluation")
|
||||
return
|
||||
|
||||
# Create interactive evaluator
|
||||
evaluator = InteractiveEvaluator(
|
||||
retriever=retriever_model,
|
||||
reranker=reranker_model,
|
||||
tokenizer=tokenizer,
|
||||
corpus=passages,
|
||||
device=args.device
|
||||
)
|
||||
|
||||
print("\n=== Interactive Evaluation ===")
|
||||
print("Enter queries to test the retrieval system.")
|
||||
print("Type 'quit' to exit, 'sample' to test with sample queries.")
|
||||
|
||||
# Load some sample queries
|
||||
sample_queries = [item['query'] for item in eval_data[:5]]
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\nEnter query: ").strip()
|
||||
|
||||
if user_input.lower() in ['quit', 'exit', 'q']:
|
||||
break
|
||||
elif user_input.lower() == 'sample':
|
||||
print("\nSample queries:")
|
||||
for i, query in enumerate(sample_queries):
|
||||
print(f"{i + 1}. {query}")
|
||||
continue
|
||||
elif user_input.isdigit() and 1 <= int(user_input) <= len(sample_queries):
|
||||
query = sample_queries[int(user_input) - 1]
|
||||
else:
|
||||
query = user_input
|
||||
|
||||
if not query:
|
||||
continue
|
||||
|
||||
# Search
|
||||
results = evaluator.search(query, top_k=10, rerank=reranker_model is not None)
|
||||
|
||||
print(f"\nResults for: {query}")
|
||||
print("-" * 50)
|
||||
for i, (idx, score, text) in enumerate(results):
|
||||
print(f"{i + 1}. (Score: {score:.3f}) {text[:100]}...")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting interactive evaluation...")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run evaluation with robust error handling and config-driven defaults."""
|
||||
parser = argparse.ArgumentParser(description="Evaluate BGE models.")
|
||||
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--retriever_model", type=str, required=True,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
|
||||
help="Base retriever model for comparison")
|
||||
parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Base reranker model for comparison")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument("--eval_data", type=str, required=True,
|
||||
help="Path to evaluation data file")
|
||||
parser.add_argument("--corpus_data", type=str, default=None,
|
||||
help="Path to corpus data for retrieval evaluation")
|
||||
parser.add_argument("--data_format", type=str, default="auto",
|
||||
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
|
||||
help="Format of input data")
|
||||
|
||||
# Evaluation arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./evaluation_results",
|
||||
help="Directory to save evaluation results")
|
||||
parser.add_argument("--batch_size", type=int, default=32,
|
||||
help="Batch size for evaluation")
|
||||
parser.add_argument("--max_length", type=int, default=512,
|
||||
help="Maximum sequence length")
|
||||
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
|
||||
help="K values for evaluation metrics")
|
||||
|
||||
# Evaluation modes
|
||||
parser.add_argument("--eval_retriever", action="store_true",
|
||||
help="Evaluate retriever model")
|
||||
parser.add_argument("--eval_reranker", action="store_true",
|
||||
help="Evaluate reranker model")
|
||||
parser.add_argument("--eval_pipeline", action="store_true",
|
||||
help="Evaluate full retrieval pipeline")
|
||||
parser.add_argument("--compare_baseline", action="store_true",
|
||||
help="Compare with baseline models")
|
||||
parser.add_argument("--interactive", action="store_true",
|
||||
help="Run interactive evaluation")
|
||||
|
||||
# Pipeline settings
|
||||
parser.add_argument("--retrieval_top_k", type=int, default=100,
|
||||
help="Top-k for initial retrieval")
|
||||
parser.add_argument("--rerank_top_k", type=int, default=10,
|
||||
help="Top-k for reranking")
|
||||
|
||||
# Other
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load config and set defaults
|
||||
from utils.config_loader import get_config
|
||||
config = get_config()
|
||||
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
|
||||
|
||||
# Setup logging
|
||||
setup_logging(
|
||||
output_dir=args.output_dir,
|
||||
log_level=logging.INFO,
|
||||
log_to_console=True,
|
||||
log_to_file=True
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("Starting model evaluation")
|
||||
logger.info(f"Arguments: {args}")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Set device
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
elif device == "npu":
|
||||
npu = getattr(torch, "npu", None)
|
||||
if npu is not None and hasattr(npu, "is_available") and npu.is_available():
|
||||
device = "npu"
|
||||
else:
|
||||
logger.warning("NPU (Ascend) is not available. Falling back to CPU.")
|
||||
device = "cpu"
|
||||
# otherwise keep the explicitly requested device (e.g. "cpu" or "cuda") as-is
|
||||
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Load models
|
||||
logger.info("Loading models...")
|
||||
model_source = get_config().get('model_paths.source', 'huggingface')
|
||||
# Load retriever
|
||||
try:
|
||||
retriever_model = BGEM3Model.from_pretrained(args.retriever_model)
|
||||
retriever_tokenizer = None # Already loaded inside BGEM3Model if needed
|
||||
retriever_model.to(device)
|
||||
retriever_model.eval()
|
||||
logger.info(f"Loaded retriever from {args.retriever_model} (source: {model_source})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load retriever: {e}")
|
||||
return
|
||||
|
||||
# Load reranker (optional)
|
||||
reranker_model = None
|
||||
reranker_tokenizer = None
|
||||
if args.reranker_model:
|
||||
try:
|
||||
reranker_model = BGERerankerModel(args.reranker_model)
|
||||
reranker_tokenizer = None # Already loaded inside BGERerankerModel if needed
|
||||
reranker_model.to(device)
|
||||
reranker_model.eval()
|
||||
logger.info(f"Loaded reranker from {args.reranker_model} (source: {model_source})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load reranker: {e}")
|
||||
|
||||
# Load evaluation data
|
||||
logger.info("Loading evaluation data...")
|
||||
eval_data = load_evaluation_data(args.eval_data, args.data_format)
|
||||
logger.info(f"Loaded {len(eval_data)} evaluation examples")
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Run evaluations
|
||||
all_results = {}
|
||||
|
||||
if args.eval_retriever:
|
||||
retriever_metrics = evaluate_retriever(args, retriever_model, retriever_tokenizer, eval_data)
|
||||
all_results['retriever'] = retriever_metrics
|
||||
logger.info(f"Retriever metrics: {retriever_metrics}")
|
||||
|
||||
if args.eval_reranker and reranker_model:
|
||||
reranker_metrics = evaluate_reranker(args, reranker_model, reranker_tokenizer, eval_data)
|
||||
all_results['reranker'] = reranker_metrics
|
||||
logger.info(f"Reranker metrics: {reranker_metrics}")
|
||||
|
||||
if args.eval_pipeline:
|
||||
pipeline_metrics = evaluate_pipeline(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
|
||||
all_results['pipeline'] = pipeline_metrics
|
||||
logger.info(f"Pipeline metrics: {pipeline_metrics}")
|
||||
|
||||
# Compare with baseline if requested
|
||||
if args.compare_baseline:
|
||||
logger.info("Loading baseline models for comparison...")
|
||||
try:
|
||||
base_retriever = BGEM3Model(args.base_retriever)
|
||||
base_retriever.to(device)
|
||||
base_retriever.eval()
|
||||
|
||||
base_retriever_metrics = evaluate_retriever(args, base_retriever, retriever_tokenizer, eval_data)
|
||||
all_results['baseline_retriever'] = base_retriever_metrics
|
||||
logger.info(f"Baseline retriever metrics: {base_retriever_metrics}")
|
||||
|
||||
if reranker_model:
|
||||
base_reranker = BGERerankerModel(args.base_reranker)
|
||||
base_reranker.to(device)
|
||||
base_reranker.eval()
|
||||
|
||||
base_reranker_metrics = evaluate_reranker(args, base_reranker, reranker_tokenizer, eval_data)
|
||||
all_results['baseline_reranker'] = base_reranker_metrics
|
||||
logger.info(f"Baseline reranker metrics: {base_reranker_metrics}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load baseline models: {e}")
|
||||
|
||||
# Save results
|
||||
results_file = os.path.join(args.output_dir, "evaluation_results.json")
|
||||
with open(results_file, 'w') as f:
|
||||
json.dump(all_results, f, indent=2)
|
||||
|
||||
logger.info(f"Evaluation results saved to {results_file}")
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print("EVALUATION SUMMARY")
|
||||
print("=" * 50)
|
||||
for model_type, metrics in all_results.items():
|
||||
print(f"\n{model_type.upper()}:")
|
||||
for metric, value in metrics.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.4f}")
|
||||
else:
|
||||
print(f" {metric}: {value}")
|
||||
|
||||
# Interactive evaluation
|
||||
if args.interactive:
|
||||
run_interactive_evaluation(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
624
scripts/quick_validation.py
Normal file
@@ -0,0 +1,624 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick Validation Script for BGE Fine-tuned Models
|
||||
|
||||
This script provides fast validation to quickly check if fine-tuned models
|
||||
are working correctly and performing better than baselines.
|
||||
|
||||
Usage:
|
||||
# Quick test both models
|
||||
python scripts/quick_validation.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \\
|
||||
--reranker_model ./output/bge-reranker/final_model
|
||||
|
||||
# Test only retriever
|
||||
python scripts/quick_validation.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from utils.config_loader import get_config
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuickValidator:
|
||||
"""Quick validation for BGE models"""
|
||||
|
||||
def __init__(self):
|
||||
self.device = self._get_device()
|
||||
self.test_queries = self._get_test_queries()
|
||||
|
||||
def _get_device(self):
|
||||
"""Get optimal device"""
|
||||
try:
|
||||
config = get_config()
|
||||
device_type = config.get('hardware', {}).get('device', 'cuda')
|
||||
|
||||
# NPU support should not be gated on CUDA availability; probe torch_npu directly
if device_type == 'npu':
try:
import torch_npu  # noqa: F401 (registers the torch.npu backend)
if torch.npu.is_available():
return torch.device("npu:0")
except ImportError:
pass
|
||||
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return torch.device("cpu")
|
||||
|
||||
def _get_test_queries(self):
|
||||
"""Get standard test queries for validation"""
|
||||
return [
|
||||
{
|
||||
"query": "什么是深度学习?",
|
||||
"relevant": [
|
||||
"深度学习是机器学习的一个子领域,使用多层神经网络来学习数据表示",
|
||||
"深度学习通过神经网络的多个层次来学习数据的抽象特征"
|
||||
],
|
||||
"irrelevant": [
|
||||
"今天天气很好,阳光明媚",
|
||||
"我喜欢吃苹果和香蕉"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "如何制作红烧肉?",
|
||||
"relevant": [
|
||||
"红烧肉的制作需要五花肉、生抽、老抽、冰糖等材料,先炒糖色再炖煮",
|
||||
"制作红烧肉的关键是掌握火候和糖色的调制"
|
||||
],
|
||||
"irrelevant": [
|
||||
"Python是一种编程语言",
|
||||
"机器学习需要大量的数据"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Python编程入门",
|
||||
"relevant": [
|
||||
"Python是一种易学易用的编程语言,适合初学者入门",
|
||||
"学习Python需要掌握基本语法、数据类型和控制结构"
|
||||
],
|
||||
"irrelevant": [
|
||||
"红烧肉是一道经典的中式菜肴",
|
||||
"深度学习使用神经网络进行训练"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
def validate_retriever(self, model_path, baseline_path="BAAI/bge-m3"):
|
||||
"""Quick validation of retriever model"""
|
||||
print(f"\n🔍 Validating Retriever Model")
|
||||
print("=" * 50)
|
||||
|
||||
results = {
|
||||
'model_path': model_path,
|
||||
'baseline_path': baseline_path,
|
||||
'test_results': {},
|
||||
'summary': {}
|
||||
}
|
||||
|
||||
try:
|
||||
# Test fine-tuned model
|
||||
print("Testing fine-tuned model...")
|
||||
ft_scores = self._test_retriever_model(model_path, "Fine-tuned")
|
||||
|
||||
# Test baseline model
|
||||
print("Testing baseline model...")
|
||||
baseline_scores = self._test_retriever_model(baseline_path, "Baseline")
|
||||
|
||||
# Compare results
|
||||
comparison = self._compare_retriever_results(baseline_scores, ft_scores)
|
||||
|
||||
results['test_results'] = {
|
||||
'finetuned': ft_scores,
|
||||
'baseline': baseline_scores,
|
||||
'comparison': comparison
|
||||
}
|
||||
|
||||
# Print summary
|
||||
print(f"\n📊 Retriever Validation Summary:")
|
||||
print(f" Baseline Avg Relevance Score: {baseline_scores['avg_relevant_score']:.4f}")
|
||||
print(f" Fine-tuned Avg Relevance Score: {ft_scores['avg_relevant_score']:.4f}")
|
||||
print(f" Improvement: {comparison['relevance_improvement']:.2f}%")
|
||||
|
||||
if comparison['relevance_improvement'] > 0:
|
||||
print(" ✅ Fine-tuned model shows improvement!")
|
||||
else:
|
||||
print(" ❌ Fine-tuned model shows degradation!")
|
||||
|
||||
results['summary'] = {
|
||||
'status': 'improved' if comparison['relevance_improvement'] > 0 else 'degraded',
|
||||
'improvement_pct': comparison['relevance_improvement']
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Retriever validation failed: {e}")
|
||||
return None
|
||||
|
||||
def validate_reranker(self, model_path, baseline_path="BAAI/bge-reranker-base"):
|
||||
"""Quick validation of reranker model"""
|
||||
print(f"\n🏆 Validating Reranker Model")
|
||||
print("=" * 50)
|
||||
|
||||
results = {
|
||||
'model_path': model_path,
|
||||
'baseline_path': baseline_path,
|
||||
'test_results': {},
|
||||
'summary': {}
|
||||
}
|
||||
|
||||
try:
|
||||
# Test fine-tuned model
|
||||
print("Testing fine-tuned reranker...")
|
||||
ft_scores = self._test_reranker_model(model_path, "Fine-tuned")
|
||||
|
||||
# Test baseline model
|
||||
print("Testing baseline reranker...")
|
||||
baseline_scores = self._test_reranker_model(baseline_path, "Baseline")
|
||||
|
||||
# Compare results
|
||||
comparison = self._compare_reranker_results(baseline_scores, ft_scores)
|
||||
|
||||
results['test_results'] = {
|
||||
'finetuned': ft_scores,
|
||||
'baseline': baseline_scores,
|
||||
'comparison': comparison
|
||||
}
|
||||
|
||||
# Print summary
|
||||
print(f"\n📊 Reranker Validation Summary:")
|
||||
print(f" Baseline Accuracy: {baseline_scores['accuracy']:.4f}")
|
||||
print(f" Fine-tuned Accuracy: {ft_scores['accuracy']:.4f}")
|
||||
print(f" Improvement: {comparison['accuracy_improvement']:.2f}%")
|
||||
|
||||
if comparison['accuracy_improvement'] > 0:
|
||||
print(" ✅ Fine-tuned reranker shows improvement!")
|
||||
else:
|
||||
print(" ❌ Fine-tuned reranker shows degradation!")
|
||||
|
||||
results['summary'] = {
|
||||
'status': 'improved' if comparison['accuracy_improvement'] > 0 else 'degraded',
|
||||
'improvement_pct': comparison['accuracy_improvement']
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Reranker validation failed: {e}")
|
||||
return None
|
||||
|
||||
def _test_retriever_model(self, model_path, model_type):
|
||||
"""Test retriever model with predefined queries"""
|
||||
try:
|
||||
# Load model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGEM3Model.from_pretrained(model_path)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
print(f" ✅ {model_type} model loaded successfully")
|
||||
|
||||
relevant_scores = []
|
||||
irrelevant_scores = []
|
||||
ranking_accuracy = []
|
||||
|
||||
with torch.no_grad():
|
||||
for test_case in self.test_queries:
|
||||
query = test_case["query"]
|
||||
relevant_docs = test_case["relevant"]
|
||||
irrelevant_docs = test_case["irrelevant"]
|
||||
|
||||
# Encode query
|
||||
query_inputs = tokenizer(
|
||||
query,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
).to(self.device)
|
||||
|
||||
query_outputs = model(**query_inputs, return_dense=True)
|
||||
query_emb = query_outputs['dense']
|
||||
|
||||
# Test relevant documents
|
||||
for doc in relevant_docs:
|
||||
doc_inputs = tokenizer(
|
||||
doc,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
).to(self.device)
|
||||
|
||||
doc_outputs = model(**doc_inputs, return_dense=True)
|
||||
doc_emb = doc_outputs['dense']
|
||||
|
||||
# Compute similarity
|
||||
similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
|
||||
relevant_scores.append(similarity)
|
||||
|
||||
# Test irrelevant documents
|
||||
for doc in irrelevant_docs:
|
||||
doc_inputs = tokenizer(
|
||||
doc,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
).to(self.device)
|
||||
|
||||
doc_outputs = model(**doc_inputs, return_dense=True)
|
||||
doc_emb = doc_outputs['dense']
|
||||
|
||||
similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
|
||||
irrelevant_scores.append(similarity)
|
||||
|
||||
# Check ranking accuracy (relevant should score higher than irrelevant).
# Reuse the similarities already computed above for this query instead of
# re-encoding every document a second time.
query_relevant_avg = np.mean(relevant_scores[-len(relevant_docs):])
query_irrelevant_avg = np.mean(irrelevant_scores[-len(irrelevant_docs):])
|
||||
|
||||
ranking_accuracy.append(1.0 if query_relevant_avg > query_irrelevant_avg else 0.0)
|
||||
|
||||
return {
|
||||
'avg_relevant_score': np.mean(relevant_scores),
|
||||
'avg_irrelevant_score': np.mean(irrelevant_scores),
|
||||
'ranking_accuracy': np.mean(ranking_accuracy),
|
||||
'score_separation': np.mean(relevant_scores) - np.mean(irrelevant_scores)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing {model_type} retriever: {e}")
|
||||
return None
|
||||
|
||||
def _test_reranker_model(self, model_path, model_type):
|
||||
"""Test reranker model with predefined query-document pairs"""
|
||||
try:
|
||||
# Load model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGERerankerModel.from_pretrained(model_path)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
print(f" ✅ {model_type} reranker loaded successfully")
|
||||
|
||||
correct_rankings = 0
|
||||
total_pairs = 0
|
||||
all_scores = []
|
||||
all_labels = []
|
||||
|
||||
with torch.no_grad():
|
||||
for test_case in self.test_queries:
|
||||
query = test_case["query"]
|
||||
relevant_docs = test_case["relevant"]
|
||||
irrelevant_docs = test_case["irrelevant"]
|
||||
|
||||
# Test relevant documents (positive examples)
|
||||
for doc in relevant_docs:
|
||||
pair_text = f"{query} [SEP] {doc}"
|
||||
inputs = tokenizer(
|
||||
pair_text,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
).to(self.device)
|
||||
|
||||
outputs = model(**inputs)
|
||||
score = outputs.logits.item()
|
||||
all_scores.append(score)
|
||||
all_labels.append(1) # Relevant
|
||||
|
||||
# Check if score indicates relevance (positive score)
|
||||
if score > 0:
|
||||
correct_rankings += 1
|
||||
total_pairs += 1
|
||||
|
||||
# Test irrelevant documents (negative examples)
|
||||
for doc in irrelevant_docs:
|
||||
pair_text = f"{query} [SEP] {doc}"
|
||||
inputs = tokenizer(
|
||||
pair_text,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
).to(self.device)
|
||||
|
||||
outputs = model(**inputs)
|
||||
score = outputs.logits.item()
|
||||
all_scores.append(score)
|
||||
all_labels.append(0) # Irrelevant
|
||||
|
||||
# Check if score indicates irrelevance (negative score)
|
||||
if score <= 0:
|
||||
correct_rankings += 1
|
||||
total_pairs += 1
|
||||
|
||||
return {
|
||||
'accuracy': correct_rankings / total_pairs if total_pairs > 0 else 0,
|
||||
'avg_positive_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 1]),
|
||||
'avg_negative_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 0]),
|
||||
'total_pairs': total_pairs
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing {model_type} reranker: {e}")
|
||||
return None
|
||||
|
||||
def _compare_retriever_results(self, baseline, finetuned):
|
||||
"""Compare retriever results"""
|
||||
if not baseline or not finetuned:
|
||||
return {}
|
||||
|
||||
relevance_improvement = ((finetuned['avg_relevant_score'] - baseline['avg_relevant_score']) /
|
||||
baseline['avg_relevant_score'] * 100) if baseline['avg_relevant_score'] != 0 else 0
|
||||
|
||||
separation_improvement = ((finetuned['score_separation'] - baseline['score_separation']) /
|
||||
baseline['score_separation'] * 100) if baseline['score_separation'] != 0 else 0
|
||||
|
||||
ranking_improvement = ((finetuned['ranking_accuracy'] - baseline['ranking_accuracy']) /
|
||||
baseline['ranking_accuracy'] * 100) if baseline['ranking_accuracy'] != 0 else 0
|
||||
|
||||
return {
|
||||
'relevance_improvement': relevance_improvement,
|
||||
'separation_improvement': separation_improvement,
|
||||
'ranking_improvement': ranking_improvement
|
||||
}
|
||||
|
||||
def _compare_reranker_results(self, baseline, finetuned):
|
||||
"""Compare reranker results"""
|
||||
if not baseline or not finetuned:
|
||||
return {}
|
||||
|
||||
accuracy_improvement = ((finetuned['accuracy'] - baseline['accuracy']) /
|
||||
baseline['accuracy'] * 100) if baseline['accuracy'] != 0 else 0
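# Worked example (hypothetical numbers): a baseline accuracy of 0.80 and a
# fine-tuned accuracy of 0.88 give (0.88 - 0.80) / 0.80 * 100 = 10.0%, i.e.
# a 10% relative improvement as reported by this comparison.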
|
||||
|
||||
return {
|
||||
'accuracy_improvement': accuracy_improvement
|
||||
}
|
||||
|
||||
def validate_pipeline(self, retriever_path, reranker_path):
|
||||
"""Quick validation of full pipeline"""
|
||||
print(f"\n🚀 Validating Full Pipeline")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
# Load models
|
||||
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
|
||||
retriever = BGEM3Model.from_pretrained(retriever_path)
|
||||
retriever.to(self.device)
|
||||
retriever.eval()
|
||||
|
||||
reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
|
||||
reranker = BGERerankerModel.from_pretrained(reranker_path)
|
||||
reranker.to(self.device)
|
||||
reranker.eval()
|
||||
|
||||
print(" ✅ Pipeline models loaded successfully")
|
||||
|
||||
total_queries = 0
|
||||
pipeline_accuracy = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for test_case in self.test_queries:
|
||||
query = test_case["query"]
|
||||
relevant_docs = test_case["relevant"]
|
||||
irrelevant_docs = test_case["irrelevant"]
|
||||
|
||||
# Create candidate pool
|
||||
all_docs = relevant_docs + irrelevant_docs
|
||||
doc_labels = [1] * len(relevant_docs) + [0] * len(irrelevant_docs)
|
||||
|
||||
# Step 1: Retrieval
|
||||
query_inputs = retriever_tokenizer(
|
||||
query, return_tensors='pt', padding=True,
|
||||
truncation=True, max_length=512
|
||||
).to(self.device)
|
||||
query_emb = retriever(**query_inputs, return_dense=True)['dense']
|
||||
|
||||
# Score all documents with retriever
|
||||
doc_scores = []
|
||||
for doc in all_docs:
|
||||
doc_inputs = retriever_tokenizer(
|
||||
doc, return_tensors='pt', padding=True,
|
||||
truncation=True, max_length=512
|
||||
).to(self.device)
|
||||
doc_emb = retriever(**doc_inputs, return_dense=True)['dense']
|
||||
score = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
|
||||
doc_scores.append(score)
|
||||
|
||||
# Get top-k for reranking (all docs in this case)
|
||||
retrieval_indices = sorted(range(len(doc_scores)),
|
||||
key=lambda i: doc_scores[i], reverse=True)
|
||||
|
||||
# Step 2: Reranking
|
||||
rerank_scores = []
|
||||
rerank_labels = []
|
||||
|
||||
for idx in retrieval_indices:
|
||||
doc = all_docs[idx]
|
||||
label = doc_labels[idx]
|
||||
|
||||
pair_text = f"{query} [SEP] {doc}"
|
||||
pair_inputs = reranker_tokenizer(
|
||||
pair_text, return_tensors='pt', padding=True,
|
||||
truncation=True, max_length=512
|
||||
).to(self.device)
|
||||
|
||||
rerank_output = reranker(**pair_inputs)
|
||||
rerank_score = rerank_output.logits.item()
|
||||
|
||||
rerank_scores.append(rerank_score)
|
||||
rerank_labels.append(label)
|
||||
|
||||
# Check if top-ranked document is relevant
|
||||
best_idx = np.argmax(rerank_scores)
|
||||
if rerank_labels[best_idx] == 1:
|
||||
pipeline_accuracy += 1
|
||||
|
||||
total_queries += 1
|
||||
|
||||
pipeline_acc = pipeline_accuracy / total_queries if total_queries > 0 else 0
|
||||
|
||||
print(f"\n📊 Pipeline Validation Summary:")
|
||||
print(f" Pipeline Accuracy: {pipeline_acc:.4f}")
|
||||
print(f" Queries Tested: {total_queries}")
|
||||
|
||||
if pipeline_acc > 0.5:
|
||||
print(" ✅ Pipeline performing well!")
|
||||
else:
|
||||
print(" ❌ Pipeline needs improvement!")
|
||||
|
||||
return {
|
||||
'accuracy': pipeline_acc,
|
||||
'total_queries': total_queries,
|
||||
'status': 'good' if pipeline_acc > 0.5 else 'needs_improvement'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline validation failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Quick validation of BGE fine-tuned models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument("--retriever_model", type=str, default=None,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
|
||||
help="Path to baseline retriever model")
|
||||
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Path to baseline reranker model")
|
||||
parser.add_argument("--test_pipeline", action="store_true",
|
||||
help="Test full retrieval + reranking pipeline")
|
||||
parser.add_argument("--output_file", type=str, default=None,
|
||||
help="Save results to JSON file")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main validation function"""
|
||||
args = parse_args()
|
||||
|
||||
if not args.retriever_model and not args.reranker_model:
|
||||
print("❌ Error: At least one model must be specified")
|
||||
print(" Use --retriever_model or --reranker_model")
|
||||
return 1
|
||||
|
||||
print("🚀 BGE Quick Validation")
|
||||
print("=" * 60)
|
||||
print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
validator = QuickValidator()
|
||||
results = {}
|
||||
|
||||
# Validate retriever
|
||||
if args.retriever_model:
|
||||
if os.path.exists(args.retriever_model):
|
||||
retriever_results = validator.validate_retriever(
|
||||
args.retriever_model, args.retriever_baseline
|
||||
)
|
||||
if retriever_results:
|
||||
results['retriever'] = retriever_results
|
||||
else:
|
||||
print(f"❌ Retriever model not found: {args.retriever_model}")
|
||||
|
||||
# Validate reranker
|
||||
if args.reranker_model:
|
||||
if os.path.exists(args.reranker_model):
|
||||
reranker_results = validator.validate_reranker(
|
||||
args.reranker_model, args.reranker_baseline
|
||||
)
|
||||
if reranker_results:
|
||||
results['reranker'] = reranker_results
|
||||
else:
|
||||
print(f"❌ Reranker model not found: {args.reranker_model}")
|
||||
|
||||
# Validate pipeline
|
||||
if (args.test_pipeline and args.retriever_model and args.reranker_model and
|
||||
os.path.exists(args.retriever_model) and os.path.exists(args.reranker_model)):
|
||||
pipeline_results = validator.validate_pipeline(
|
||||
args.retriever_model, args.reranker_model
|
||||
)
|
||||
if pipeline_results:
|
||||
results['pipeline'] = pipeline_results
|
||||
|
||||
# Print final summary
|
||||
print(f"\n{'='*60}")
|
||||
print("🎯 FINAL VALIDATION SUMMARY")
|
||||
print(f"{'='*60}")
|
||||
|
||||
overall_status = "✅ SUCCESS"
|
||||
issues = []
|
||||
|
||||
for model_type, model_results in results.items():
|
||||
if 'summary' in model_results:
|
||||
status = model_results['summary'].get('status', 'unknown')
|
||||
improvement = model_results['summary'].get('improvement_pct', 0)
|
||||
|
||||
if status == 'improved':
|
||||
print(f"✅ {model_type.title()}: Improved by {improvement:.2f}%")
|
||||
elif status == 'degraded':
|
||||
print(f"❌ {model_type.title()}: Degraded by {abs(improvement):.2f}%")
|
||||
issues.append(model_type)
|
||||
overall_status = "⚠️ ISSUES DETECTED"
|
||||
elif model_type == 'pipeline' and 'status' in model_results:
|
||||
if model_results['status'] == 'good':
|
||||
print(f"✅ Pipeline: Working well (accuracy: {model_results['accuracy']:.3f})")
|
||||
else:
|
||||
print(f"❌ Pipeline: Needs improvement (accuracy: {model_results['accuracy']:.3f})")
|
||||
issues.append('pipeline')
|
||||
overall_status = "⚠️ ISSUES DETECTED"
|
||||
|
||||
print(f"\n{overall_status}")
|
||||
if issues:
|
||||
print(f"⚠️ Issues detected in: {', '.join(issues)}")
|
||||
print("💡 Consider running comprehensive validation for detailed analysis")
|
||||
|
||||
# Save results if requested
|
||||
if args.output_file:
|
||||
results['timestamp'] = datetime.now().isoformat()
|
||||
results['overall_status'] = overall_status
|
||||
results['issues'] = issues
|
||||
|
||||
with open(args.output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
print(f"📊 Results saved to: {args.output_file}")
|
||||
|
||||
return 0 if overall_status == "✅ SUCCESS" else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
696
scripts/run_validation_suite.py
Normal file
@@ -0,0 +1,696 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
BGE Validation Suite Runner
|
||||
|
||||
This script runs a complete validation suite for BGE fine-tuned models,
|
||||
including multiple validation approaches and generating comprehensive reports.
|
||||
|
||||
Usage:
|
||||
# Auto-discover models and run full suite
|
||||
python scripts/run_validation_suite.py --auto-discover
|
||||
|
||||
# Run with specific models
|
||||
python scripts/run_validation_suite.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \\
|
||||
--reranker_model ./output/bge-reranker/final_model
|
||||
|
||||
# Run only specific validation types
|
||||
python scripts/run_validation_suite.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \\
|
||||
--validation_types quick comprehensive comparison
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import subprocess
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from scripts.validation_utils import ModelDiscovery, TestDataManager, ValidationReportAnalyzer
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationSuiteRunner:
|
||||
"""Orchestrates comprehensive validation of BGE models"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.results_dir = Path(config.get('output_dir', './validation_suite_results'))
|
||||
self.results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize components
|
||||
self.model_discovery = ModelDiscovery(config.get('workspace_root', '.'))
|
||||
self.data_manager = TestDataManager(config.get('data_dir', 'data/datasets'))
|
||||
|
||||
# Suite results
|
||||
self.suite_results = {
|
||||
'start_time': datetime.now().isoformat(),
|
||||
'config': config,
|
||||
'validation_results': {},
|
||||
'summary': {},
|
||||
'issues': [],
|
||||
'recommendations': []
|
||||
}
|
||||
|
||||
def run_suite(self) -> Dict[str, Any]:
|
||||
"""Run the complete validation suite"""
|
||||
logger.info("🚀 Starting BGE Validation Suite")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
# Phase 1: Discovery and Setup
|
||||
self._phase_discovery()
|
||||
|
||||
# Phase 2: Quick Validation
|
||||
if 'quick' in self.config.get('validation_types', ['quick']):
|
||||
self._phase_quick_validation()
|
||||
|
||||
# Phase 3: Comprehensive Validation
|
||||
if 'comprehensive' in self.config.get('validation_types', ['comprehensive']):
|
||||
self._phase_comprehensive_validation()
|
||||
|
||||
# Phase 4: Comparison Benchmarks
|
||||
if 'comparison' in self.config.get('validation_types', ['comparison']):
|
||||
self._phase_comparison_benchmarks()
|
||||
|
||||
# Phase 5: Analysis and Reporting
|
||||
self._phase_analysis()
|
||||
|
||||
# Phase 6: Generate Final Report
|
||||
self._generate_final_report()
|
||||
|
||||
self.suite_results['end_time'] = datetime.now().isoformat()
|
||||
self.suite_results['status'] = 'completed'
|
||||
|
||||
logger.info("✅ Validation Suite Completed Successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Validation Suite Failed: {e}")
|
||||
self.suite_results['status'] = 'failed'
|
||||
self.suite_results['error'] = str(e)
|
||||
raise
|
||||
|
||||
return self.suite_results
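# Usage sketch (paths are hypothetical): build a config dict with the keys
# read by __init__ and run_suite above, then run the whole suite.
#
#   runner = ValidationSuiteRunner({
#       'retriever_model': './output/bge-m3-enhanced/final_model',
#       'reranker_model': './output/bge-reranker/final_model',
#       'validation_types': ['quick', 'comprehensive', 'comparison'],
#       'output_dir': './validation_suite_results',
#   })
#   suite_results = runner.run_suite()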
|
||||
|
||||
def _phase_discovery(self):
|
||||
"""Phase 1: Discover models and datasets"""
|
||||
logger.info("📡 Phase 1: Discovery and Setup")
|
||||
|
||||
# Auto-discover models if requested
|
||||
if self.config.get('auto_discover', False):
|
||||
discovered_models = self.model_discovery.find_trained_models()
|
||||
|
||||
# Set models from discovery if not provided
|
||||
if not self.config.get('retriever_model'):
|
||||
if discovered_models['retrievers']:
|
||||
model_name = list(discovered_models['retrievers'].keys())[0]
|
||||
self.config['retriever_model'] = discovered_models['retrievers'][model_name]['path']
|
||||
logger.info(f" 📊 Auto-discovered retriever: {model_name}")
|
||||
|
||||
if not self.config.get('reranker_model'):
|
||||
if discovered_models['rerankers']:
|
||||
model_name = list(discovered_models['rerankers'].keys())[0]
|
||||
self.config['reranker_model'] = discovered_models['rerankers'][model_name]['path']
|
||||
logger.info(f" 🏆 Auto-discovered reranker: {model_name}")
|
||||
|
||||
self.suite_results['discovered_models'] = discovered_models
|
||||
|
||||
# Discover test datasets
|
||||
discovered_datasets = self.data_manager.find_test_datasets()
|
||||
self.suite_results['available_datasets'] = discovered_datasets
|
||||
|
||||
# Validate models exist
|
||||
issues = []
|
||||
if self.config.get('retriever_model'):
|
||||
if not os.path.exists(self.config['retriever_model']):
|
||||
issues.append(f"Retriever model not found: {self.config['retriever_model']}")
|
||||
|
||||
if self.config.get('reranker_model'):
|
||||
if not os.path.exists(self.config['reranker_model']):
|
||||
issues.append(f"Reranker model not found: {self.config['reranker_model']}")
|
||||
|
||||
if issues:
|
||||
for issue in issues:
|
||||
logger.error(f" ❌ {issue}")
|
||||
self.suite_results['issues'].extend(issues)
|
||||
return
|
||||
|
||||
logger.info(" ✅ Discovery phase completed")
|
||||
|
||||
def _phase_quick_validation(self):
|
||||
"""Phase 2: Quick validation tests"""
|
||||
logger.info("⚡ Phase 2: Quick Validation")
|
||||
|
||||
try:
|
||||
# Run quick validation script
|
||||
cmd = ['python', 'scripts/quick_validation.py']
|
||||
|
||||
if self.config.get('retriever_model'):
|
||||
cmd.extend(['--retriever_model', self.config['retriever_model']])
|
||||
|
||||
if self.config.get('reranker_model'):
|
||||
cmd.extend(['--reranker_model', self.config['reranker_model']])
|
||||
|
||||
if self.config.get('retriever_model') and self.config.get('reranker_model'):
|
||||
cmd.append('--test_pipeline')
|
||||
|
||||
# Save results to file
|
||||
quick_results_file = self.results_dir / 'quick_validation_results.json'
|
||||
cmd.extend(['--output_file', str(quick_results_file)])
|
||||
|
||||
logger.info(f" Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(" ✅ Quick validation completed")
|
||||
|
||||
# Load results
|
||||
if quick_results_file.exists():
|
||||
with open(quick_results_file, 'r') as f:
|
||||
quick_results = json.load(f)
|
||||
self.suite_results['validation_results']['quick'] = quick_results
|
||||
else:
|
||||
logger.error(f" ❌ Quick validation failed: {result.stderr}")
|
||||
self.suite_results['issues'].append(f"Quick validation failed: {result.stderr}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ Error in quick validation: {e}")
|
||||
self.suite_results['issues'].append(f"Quick validation error: {e}")
|
||||
|
||||
def _phase_comprehensive_validation(self):
|
||||
"""Phase 3: Comprehensive validation"""
|
||||
logger.info("🔬 Phase 3: Comprehensive Validation")
|
||||
|
||||
try:
|
||||
# Run comprehensive validation script
|
||||
cmd = ['python', 'scripts/comprehensive_validation.py']
|
||||
|
||||
if self.config.get('retriever_model'):
|
||||
cmd.extend(['--retriever_finetuned', self.config['retriever_model']])
|
||||
|
||||
if self.config.get('reranker_model'):
|
||||
cmd.extend(['--reranker_finetuned', self.config['reranker_model']])
|
||||
|
||||
# Use specific datasets if provided
|
||||
if self.config.get('test_datasets'):
|
||||
cmd.extend(['--test_datasets'] + self.config['test_datasets'])
|
||||
|
||||
# Set output directory
|
||||
comprehensive_dir = self.results_dir / 'comprehensive'
|
||||
cmd.extend(['--output_dir', str(comprehensive_dir)])
|
||||
|
||||
logger.info(f" Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(" ✅ Comprehensive validation completed")
|
||||
|
||||
# Load results
|
||||
results_file = comprehensive_dir / 'validation_results.json'
|
||||
if results_file.exists():
|
||||
with open(results_file, 'r') as f:
|
||||
comp_results = json.load(f)
|
||||
self.suite_results['validation_results']['comprehensive'] = comp_results
|
||||
else:
|
||||
logger.error(f" ❌ Comprehensive validation failed: {result.stderr}")
|
||||
self.suite_results['issues'].append(f"Comprehensive validation failed: {result.stderr}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ Error in comprehensive validation: {e}")
|
||||
self.suite_results['issues'].append(f"Comprehensive validation error: {e}")
|
||||
|
||||
def _phase_comparison_benchmarks(self):
|
||||
"""Phase 4: Run comparison benchmarks"""
|
||||
logger.info("📊 Phase 4: Comparison Benchmarks")
|
||||
|
||||
# Find suitable datasets for comparison
|
||||
datasets = self.data_manager.find_test_datasets()
|
||||
|
||||
try:
|
||||
# Run retriever comparison if model exists
|
||||
if self.config.get('retriever_model'):
|
||||
self._run_retriever_comparison(datasets)
|
||||
|
||||
# Run reranker comparison if model exists
|
||||
if self.config.get('reranker_model'):
|
||||
self._run_reranker_comparison(datasets)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ Error in comparison benchmarks: {e}")
|
||||
self.suite_results['issues'].append(f"Comparison benchmarks error: {e}")
|
||||
|
||||
def _run_retriever_comparison(self, datasets: Dict[str, List[str]]):
|
||||
"""Run retriever comparison benchmarks"""
|
||||
logger.info(" 🔍 Running retriever comparison...")
|
||||
|
||||
# Find suitable datasets for retriever testing
|
||||
test_datasets = datasets.get('retriever_datasets', []) + datasets.get('general_datasets', [])
|
||||
|
||||
if not test_datasets:
|
||||
logger.warning(" ⚠️ No suitable datasets found for retriever comparison")
|
||||
return
|
||||
|
||||
for dataset_path in test_datasets[:2]: # Limit to 2 datasets for speed
|
||||
if not os.path.exists(dataset_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
'python', 'scripts/compare_retriever.py',
|
||||
'--finetuned_model_path', self.config['retriever_model'],
|
||||
'--baseline_model_path', self.config.get('retriever_baseline', 'BAAI/bge-m3'),
|
||||
'--data_path', dataset_path,
|
||||
'--batch_size', str(self.config.get('batch_size', 16)),
|
||||
'--max_samples', str(self.config.get('max_samples', 500)),
|
||||
'--output', str(self.results_dir / f'retriever_comparison_{Path(dataset_path).stem}.txt')
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f" ✅ Retriever comparison on {Path(dataset_path).name}")
|
||||
else:
|
||||
logger.warning(f" ⚠️ Retriever comparison failed on {Path(dataset_path).name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f" ⚠️ Error comparing retriever on {dataset_path}: {e}")
|
||||
|
||||
def _run_reranker_comparison(self, datasets: Dict[str, List[str]]):
|
||||
"""Run reranker comparison benchmarks"""
|
||||
logger.info(" 🏆 Running reranker comparison...")
|
||||
|
||||
# Find suitable datasets for reranker testing
|
||||
test_datasets = datasets.get('reranker_datasets', []) + datasets.get('general_datasets', [])
|
||||
|
||||
if not test_datasets:
|
||||
logger.warning(" ⚠️ No suitable datasets found for reranker comparison")
|
||||
return
|
||||
|
||||
for dataset_path in test_datasets[:2]: # Limit to 2 datasets for speed
|
||||
if not os.path.exists(dataset_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
'python', 'scripts/compare_reranker.py',
|
||||
'--finetuned_model_path', self.config['reranker_model'],
|
||||
'--baseline_model_path', self.config.get('reranker_baseline', 'BAAI/bge-reranker-base'),
|
||||
'--data_path', dataset_path,
|
||||
'--batch_size', str(self.config.get('batch_size', 16)),
|
||||
'--max_samples', str(self.config.get('max_samples', 500)),
|
||||
'--output', str(self.results_dir / f'reranker_comparison_{Path(dataset_path).stem}.txt')
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f" ✅ Reranker comparison on {Path(dataset_path).name}")
|
||||
else:
|
||||
logger.warning(f" ⚠️ Reranker comparison failed on {Path(dataset_path).name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f" ⚠️ Error comparing reranker on {dataset_path}: {e}")
|
||||
|
||||
def _phase_analysis(self):
|
||||
"""Phase 5: Analyze all results"""
|
||||
logger.info("📈 Phase 5: Results Analysis")
|
||||
|
||||
try:
|
||||
# Analyze comprehensive results if available
|
||||
comprehensive_dir = self.results_dir / 'comprehensive'
|
||||
if comprehensive_dir.exists():
|
||||
analyzer = ValidationReportAnalyzer(str(comprehensive_dir))
|
||||
analysis = analyzer.analyze_results()
|
||||
self.suite_results['analysis'] = analysis
|
||||
|
||||
logger.info(" ✅ Results analysis completed")
|
||||
else:
|
||||
logger.warning(" ⚠️ No comprehensive results found for analysis")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ Error in results analysis: {e}")
|
||||
self.suite_results['issues'].append(f"Analysis error: {e}")
|
||||
|
||||
def _generate_final_report(self):
|
||||
"""Generate final comprehensive report"""
|
||||
logger.info("📝 Phase 6: Generating Final Report")
|
||||
|
||||
try:
|
||||
# Create summary
|
||||
self._create_suite_summary()
|
||||
|
||||
# Save suite results
|
||||
suite_results_file = self.results_dir / 'validation_suite_results.json'
|
||||
with open(suite_results_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.suite_results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Generate HTML report
|
||||
self._generate_html_report()
|
||||
|
||||
logger.info(f" 📊 Reports saved to: {self.results_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ Error generating final report: {e}")
|
||||
self.suite_results['issues'].append(f"Report generation error: {e}")
|
||||
|
||||
def _create_suite_summary(self):
|
||||
"""Create suite summary"""
|
||||
summary = {
|
||||
'total_validations': 0,
|
||||
'successful_validations': 0,
|
||||
'failed_validations': 0,
|
||||
'overall_verdict': 'unknown',
|
||||
'key_findings': [],
|
||||
'critical_issues': []
|
||||
}
|
||||
|
||||
# Count validations
|
||||
for validation_type, results in self.suite_results.get('validation_results', {}).items():
|
||||
summary['total_validations'] += 1
|
||||
if results:
|
||||
summary['successful_validations'] += 1
|
||||
else:
|
||||
summary['failed_validations'] += 1
|
||||
|
||||
# Determine overall verdict
|
||||
if summary['successful_validations'] > 0 and summary['failed_validations'] == 0:
|
||||
if self.suite_results.get('analysis', {}).get('overall_status') == 'excellent':
|
||||
summary['overall_verdict'] = 'excellent'
|
||||
elif self.suite_results.get('analysis', {}).get('overall_status') in ['good', 'mixed']:
|
||||
summary['overall_verdict'] = 'good'
|
||||
else:
|
||||
summary['overall_verdict'] = 'fair'
|
||||
elif summary['successful_validations'] > summary['failed_validations']:
|
||||
summary['overall_verdict'] = 'partial'
|
||||
else:
|
||||
summary['overall_verdict'] = 'poor'
|
||||
|
||||
# Extract key findings from analysis
|
||||
if 'analysis' in self.suite_results:
|
||||
analysis = self.suite_results['analysis']
|
||||
if 'recommendations' in analysis:
|
||||
summary['key_findings'] = analysis['recommendations'][:5] # Top 5 recommendations
|
||||
|
||||
# Extract critical issues
|
||||
summary['critical_issues'] = self.suite_results.get('issues', [])
|
||||
|
||||
self.suite_results['summary'] = summary
|
||||
|
||||
def _generate_html_report(self):
|
||||
"""Generate HTML summary report"""
|
||||
html_file = self.results_dir / 'validation_suite_report.html'
|
||||
|
||||
summary = self.suite_results.get('summary', {})
|
||||
|
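||||
# Map each overall verdict to the colour used for the banner text and border in the HTML report
|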
||||
verdict_colors = {
|
||||
'excellent': '#28a745',
|
||||
'good': '#17a2b8',
|
||||
'fair': '#ffc107',
|
||||
'partial': '#fd7e14',
|
||||
'poor': '#dc3545',
|
||||
'unknown': '#6c757d'
|
||||
}
|
||||
|
||||
verdict_color = verdict_colors.get(summary.get('overall_verdict', 'unknown'), '#6c757d')
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>BGE Validation Suite Report</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 20px; line-height: 1.6; }}
|
||||
.header {{ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white; padding: 30px; border-radius: 10px; text-align: center; }}
|
||||
.verdict {{ font-size: 24px; font-weight: bold; color: {verdict_color};
|
||||
margin: 20px 0; text-align: center; padding: 15px;
|
||||
border: 3px solid {verdict_color}; border-radius: 10px; }}
|
||||
.section {{ margin: 30px 0; }}
|
||||
.metrics {{ display: flex; justify-content: space-around; margin: 20px 0; }}
|
||||
.metric {{ text-align: center; padding: 15px; background: #f8f9fa;
|
||||
border-radius: 8px; min-width: 120px; }}
|
||||
.metric-value {{ font-size: 2em; font-weight: bold; color: #495057; }}
|
||||
.metric-label {{ color: #6c757d; }}
|
||||
.findings {{ background: #e3f2fd; padding: 20px; border-radius: 8px;
|
||||
border-left: 4px solid #2196f3; }}
|
||||
.issues {{ background: #ffebee; padding: 20px; border-radius: 8px;
|
||||
border-left: 4px solid #f44336; }}
|
||||
.timeline {{ margin: 20px 0; }}
|
||||
.timeline-item {{ margin: 10px 0; padding: 10px; background: #f8f9fa;
|
||||
border-left: 3px solid #007bff; border-radius: 0 5px 5px 0; }}
|
||||
ul {{ padding-left: 20px; }}
|
||||
li {{ margin: 8px 0; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>🚀 BGE Validation Suite Report</h1>
|
||||
<p>Comprehensive Model Performance Analysis</p>
|
||||
<p>Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||
</div>
|
||||
|
||||
<div class="verdict">
|
||||
Overall Verdict: {summary.get('overall_verdict', 'UNKNOWN').upper()}
|
||||
</div>
|
||||
|
||||
<div class="metrics">
|
||||
<div class="metric">
|
||||
<div class="metric-value">{summary.get('total_validations', 0)}</div>
|
||||
<div class="metric-label">Total Validations</div>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<div class="metric-value">{summary.get('successful_validations', 0)}</div>
|
||||
<div class="metric-label">Successful</div>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<div class="metric-value">{summary.get('failed_validations', 0)}</div>
|
||||
<div class="metric-label">Failed</div>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Add key findings
|
||||
if summary.get('key_findings'):
|
||||
html_content += """
|
||||
<div class="section">
|
||||
<h2>🔍 Key Findings</h2>
|
||||
<div class="findings">
|
||||
<ul>
|
||||
"""
|
||||
for finding in summary['key_findings']:
|
||||
html_content += f" <li>{finding}</li>\n"
|
||||
|
||||
html_content += """
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Add critical issues if any
|
||||
if summary.get('critical_issues'):
|
||||
html_content += """
|
||||
<div class="section">
|
||||
<h2>⚠️ Critical Issues</h2>
|
||||
<div class="issues">
|
||||
<ul>
|
||||
"""
|
||||
for issue in summary['critical_issues']:
|
||||
html_content += f" <li>{issue}</li>\n"
|
||||
|
||||
html_content += """
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Add validation timeline
|
||||
html_content += f"""
|
||||
<div class="section">
|
||||
<h2>⏱️ Validation Timeline</h2>
|
||||
<div class="timeline">
|
||||
<div class="timeline-item">
|
||||
<strong>Suite Started:</strong> {self.suite_results.get('start_time', 'Unknown')}
|
||||
</div>
|
||||
"""
|
||||
|
||||
for validation_type in self.suite_results.get('validation_results', {}).keys():
|
||||
html_content += f"""
|
||||
<div class="timeline-item">
|
||||
<strong>{validation_type.title()} Validation:</strong> Completed
|
||||
</div>
|
||||
"""
|
||||
|
||||
html_content += f"""
|
||||
<div class="timeline-item">
|
||||
<strong>Suite Completed:</strong> {self.suite_results.get('end_time', 'In Progress')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>📁 Detailed Results</h2>
|
||||
<p>For detailed validation results, check the following files in <code>{self.results_dir}</code>:</p>
|
||||
<ul>
|
||||
<li><code>validation_suite_results.json</code> - Complete suite results</li>
|
||||
<li><code>quick_validation_results.json</code> - Quick validation results</li>
|
||||
<li><code>comprehensive/</code> - Comprehensive validation reports</li>
|
||||
<li><code>*_comparison_*.txt</code> - Model comparison results</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
with open(html_file, 'w', encoding='utf-8') as f:
|
||||
f.write(html_content)
|
||||
|
||||
logger.info(f" 📊 HTML report: {html_file}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run comprehensive BGE validation suite",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Auto-discover models and run full suite
|
||||
python scripts/run_validation_suite.py --auto_discover
|
||||
|
||||
# Run with specific models
|
||||
python scripts/run_validation_suite.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \\
|
||||
--reranker_model ./output/bge-reranker/final_model
|
||||
|
||||
# Run only specific validation types
|
||||
python scripts/run_validation_suite.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \\
|
||||
--validation_types quick comprehensive
|
||||
"""
|
||||
)
|
||||
|
||||
# Model paths
|
||||
parser.add_argument("--retriever_model", type=str, default=None,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
parser.add_argument("--auto_discover", action="store_true",
|
||||
help="Auto-discover trained models in workspace")
|
||||
|
||||
# Baseline models
|
||||
parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
|
||||
help="Baseline retriever model")
|
||||
parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Baseline reranker model")
|
||||
|
||||
# Validation configuration
|
||||
parser.add_argument("--validation_types", type=str, nargs="+",
|
||||
default=["quick", "comprehensive", "comparison"],
|
||||
choices=["quick", "comprehensive", "comparison"],
|
||||
help="Types of validation to run")
|
||||
parser.add_argument("--test_datasets", type=str, nargs="+", default=None,
|
||||
help="Specific test datasets to use")
|
||||
|
||||
# Performance settings
|
||||
parser.add_argument("--batch_size", type=int, default=16,
|
||||
help="Batch size for evaluation")
|
||||
parser.add_argument("--max_samples", type=int, default=1000,
|
||||
help="Maximum samples per dataset for speed")
|
||||
|
||||
# Directories
|
||||
parser.add_argument("--workspace_root", type=str, default=".",
|
||||
help="Workspace root directory")
|
||||
parser.add_argument("--data_dir", type=str, default="data/datasets",
|
||||
help="Test datasets directory")
|
||||
parser.add_argument("--output_dir", type=str, default="./validation_suite_results",
|
||||
help="Output directory for all results")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
args = parse_args()
|
||||
|
||||
# Validate input
|
||||
if not args.auto_discover and not args.retriever_model and not args.reranker_model:
|
||||
print("❌ Error: Must specify models or use --auto_discover")
|
||||
print(" Use --retriever_model, --reranker_model, or --auto_discover")
|
||||
return 1
|
||||
|
||||
# Create configuration
|
||||
config = {
|
||||
'retriever_model': args.retriever_model,
|
||||
'reranker_model': args.reranker_model,
|
||||
'auto_discover': args.auto_discover,
|
||||
'retriever_baseline': args.retriever_baseline,
|
||||
'reranker_baseline': args.reranker_baseline,
|
||||
'validation_types': args.validation_types,
|
||||
'test_datasets': args.test_datasets,
|
||||
'batch_size': args.batch_size,
|
||||
'max_samples': args.max_samples,
|
||||
'workspace_root': args.workspace_root,
|
||||
'data_dir': args.data_dir,
|
||||
'output_dir': args.output_dir
|
||||
}
|
||||
|
||||
# Run validation suite
|
||||
try:
|
||||
runner = ValidationSuiteRunner(config)
|
||||
results = runner.run_suite()
|
||||
|
||||
# Print final summary
|
||||
summary = results.get('summary', {})
|
||||
verdict = summary.get('overall_verdict', 'unknown')
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("🎯 VALIDATION SUITE SUMMARY")
|
||||
print("=" * 80)
|
||||
|
||||
verdict_emojis = {
|
||||
'excellent': '🌟',
|
||||
'good': '✅',
|
||||
'fair': '👌',
|
||||
'partial': '⚠️',
|
||||
'poor': '❌',
|
||||
'unknown': '❓'
|
||||
}
|
||||
|
||||
print(f"{verdict_emojis.get(verdict, '❓')} Overall Verdict: {verdict.upper()}")
|
||||
print(f"📊 Validations: {summary.get('successful_validations', 0)}/{summary.get('total_validations', 0)} successful")
|
||||
|
||||
if summary.get('key_findings'):
|
||||
print(f"\n🔍 Key Findings:")
|
||||
for i, finding in enumerate(summary['key_findings'][:3], 1):
|
||||
print(f" {i}. {finding}")
|
||||
|
||||
if summary.get('critical_issues'):
|
||||
print(f"\n⚠️ Critical Issues:")
|
||||
for issue in summary['critical_issues'][:3]:
|
||||
print(f" • {issue}")
|
||||
|
||||
print(f"\n📁 Detailed results: {config['output_dir']}")
|
||||
print("🌐 Open validation_suite_report.html in your browser for full report")
|
||||
|
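||||
# Exit non-zero unless the overall verdict is at least 'good', so CI pipelines can gate on the suite result
|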
||||
return 0 if verdict in ['excellent', 'good'] else 1
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n❌ Validation suite interrupted by user")
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"\n❌ Validation suite failed: {e}")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
401
scripts/setup.py
Normal file
401
scripts/setup.py
Normal file
@ -0,0 +1,401 @@
|
||||
"""
|
||||
Setup script for BGE fine-tuning environment
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import argparse
|
||||
import shutil
|
||||
import urllib.request
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
# Make the project root importable so `utils` resolves when running `python scripts/setup.py` directly
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.config_loader import get_config
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_command(command, check=True):
|
||||
"""Run a shell command"""
|
||||
logger.info(f"Running: {command}")
|
||||
result = subprocess.run(command, shell=True, capture_output=True, text=True)
|
||||
|
||||
if check and result.returncode != 0:
|
||||
logger.error(f"Command failed: {command}")
|
||||
logger.error(f"Error: {result.stderr}")
|
||||
sys.exit(1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def check_python_version():
|
||||
"""Check Python version"""
|
||||
if sys.version_info < (3, 8):
|
||||
logger.error("Python 3.8 or higher is required")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info(f"Python version: {sys.version}")
|
||||
|
||||
|
||||
def check_gpu():
|
||||
"""Check GPU availability"""
|
||||
try:
|
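||||
# nvidia-smi exits with code 0 only when the NVIDIA driver and at least one GPU are available
|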
||||
result = run_command("nvidia-smi", check=False)
|
||||
if result.returncode == 0:
|
||||
logger.info("NVIDIA GPU detected")
|
||||
return "cuda"
|
||||
else:
|
||||
logger.info("No NVIDIA GPU detected")
|
||||
return "cpu"
|
||||
except Exception:
|
||||
logger.info("nvidia-smi not available")
|
||||
return "cpu"
|
||||
|
||||
|
||||
def install_pytorch(device_type="cuda"):
|
||||
"""Install PyTorch"""
|
||||
logger.info("Installing PyTorch...")
|
||||
|
||||
if device_type == "cuda":
|
||||
# Install a CUDA 12.8 nightly build (required for very new GPUs such as the RTX 5090)
|
||||
pip_command = "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128"
|
||||
else:
|
||||
# Install CPU version
|
||||
pip_command = "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
||||
run_command(pip_command)
|
||||
|
||||
# Verify installation
|
||||
try:
|
||||
import torch
|
||||
logger.info(f"PyTorch version: {torch.__version__}")
|
||||
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
logger.info(f"CUDA version: {torch.version.cuda}")
|
||||
logger.info(f"GPU count: {torch.cuda.device_count()}")
|
||||
for i in range(torch.cuda.device_count()):
|
||||
logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
|
||||
except ImportError:
|
||||
logger.error("Failed to import PyTorch")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def install_requirements():
|
||||
"""Install requirements"""
|
||||
logger.info("Installing requirements...")
|
||||
|
||||
requirements_file = Path(__file__).parent.parent / "requirements.txt"
|
||||
|
||||
if requirements_file.exists():
|
||||
run_command(f"pip install -r {requirements_file}")
|
||||
else:
|
||||
logger.warning("requirements.txt not found, installing core packages...")
|
||||
|
||||
# Core packages
|
||||
packages = [
|
||||
"transformers>=4.36.0",
|
||||
"datasets>=2.14.0",
|
||||
"tokenizers>=0.15.0",
|
||||
"numpy>=1.24.0",
|
||||
"scipy>=1.10.0",
|
||||
"scikit-learn>=1.3.0",
|
||||
"pandas>=2.0.0",
|
||||
"FlagEmbedding>=1.2.0",
|
||||
"faiss-gpu>=1.7.4",
|
||||
"rank-bm25>=0.2.2",
|
||||
"accelerate>=0.25.0",
|
||||
"tensorboard>=2.14.0",
|
||||
"tqdm>=4.66.0",
|
||||
"pyyaml>=6.0",
|
||||
"jsonlines>=3.1.0",
|
||||
"huggingface-hub>=0.19.0",
|
||||
"jieba>=0.42.1",
|
||||
"toml>=0.10.2",
|
||||
"requests>=2.28.0"
|
||||
]
|
||||
|
||||
for package in packages:
|
||||
run_command(f"pip install {package}")
|
||||
|
||||
|
||||
def setup_directories():
|
||||
"""Setup project directories"""
|
||||
logger.info("Setting up directories...")
|
||||
|
||||
base_dir = Path(__file__).parent.parent
|
||||
directories = [
|
||||
"models",
|
||||
"models/bge-m3",
|
||||
"models/bge-reranker-base",
|
||||
"data/raw",
|
||||
"data/processed",
|
||||
"output",
|
||||
"logs",
|
||||
"cache/data"
|
||||
]
|
||||
|
||||
for directory in directories:
|
||||
dir_path = base_dir / directory
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Created directory: {dir_path}")
|
||||
|
||||
|
||||
def download_models(download_models_flag=False):
|
||||
"""Download base models from HuggingFace or ModelScope based on config.toml"""
|
||||
if not download_models_flag:
|
||||
logger.info("Skipping model download. Use --download-models to download base models.")
|
||||
return
|
||||
|
||||
logger.info("Downloading base models...")
|
||||
base_dir = Path(__file__).parent.parent
|
||||
config = get_config()
|
||||
model_source = config.get('model_paths.source', 'huggingface')
|
||||
|
||||
if model_source == 'huggingface':
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
# Download BGE-M3
|
||||
logger.info("Downloading BGE-M3 from HuggingFace...")
|
||||
snapshot_download(
|
||||
repo_id="BAAI/bge-m3",
|
||||
local_dir=base_dir / "models/bge-m3",
|
||||
local_dir_use_symlinks=False
|
||||
)
|
||||
# Download BGE-Reranker
|
||||
logger.info("Downloading BGE-Reranker from HuggingFace...")
|
||||
snapshot_download(
|
||||
repo_id="BAAI/bge-reranker-base",
|
||||
local_dir=base_dir / "models/bge-reranker-base",
|
||||
local_dir_use_symlinks=False
|
||||
)
|
||||
logger.info("Models downloaded successfully from HuggingFace!")
|
||||
except ImportError:
|
||||
logger.error("huggingface_hub not available. Install it first: pip install huggingface_hub")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download models from HuggingFace: {e}")
|
||||
logger.info("You can download models manually or use the Hugging Face CLI:")
|
||||
logger.info(" huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3")
|
||||
logger.info(" huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base")
|
||||
elif model_source == 'modelscope':
|
||||
try:
|
||||
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
|
||||
# Download BGE-M3
|
||||
logger.info("Downloading BGE-M3 from ModelScope...")
|
||||
ms_snapshot_download(
|
||||
model_id="damo/nlp_corom_sentence-embedding_chinese-base",
|
||||
cache_dir=str(base_dir / "models/bge-m3")
|
||||
)
|
||||
# Download BGE-Reranker
|
||||
logger.info("Downloading BGE-Reranker from ModelScope...")
|
||||
ms_snapshot_download(
|
||||
model_id="damo/nlp_corom_cross-encoder_chinese-base",
|
||||
cache_dir=str(base_dir / "models/bge-reranker-base")
|
||||
)
|
||||
logger.info("Models downloaded successfully from ModelScope!")
|
||||
except ImportError:
|
||||
logger.error("modelscope not available. Install it first: pip install modelscope")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download models from ModelScope: {e}")
|
||||
logger.info("You can download models manually from ModelScope website or CLI.")
|
||||
else:
|
||||
logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
|
||||
|
||||
|
||||
def create_sample_config():
|
||||
"""Create sample configuration files"""
|
||||
logger.info("Creating sample configuration...")
|
||||
|
||||
base_dir = Path(__file__).parent.parent
|
||||
config_file = base_dir / "config.toml"
|
||||
|
||||
if not config_file.exists():
|
||||
# Create a default config that mirrors the project's config.toml template
|
||||
sample_config = """# Configuration file for BGE fine-tuning project
|
||||
|
||||
[api]
|
||||
# OpenAI-compatible API configuration (e.g., DeepSeek or OpenAI)
|
||||
base_url = "https://api.deepseek.com/v1"
|
||||
api_key = "your-api-key-here" # Replace with your actual API key
|
||||
model = "deepseek-chat"
|
||||
timeout = 30
|
||||
|
||||
[translation]
|
||||
# Translation settings
|
||||
enable_api = true
|
||||
target_languages = ["en", "zh", "ja", "fr"]
|
||||
max_retries = 3
|
||||
fallback_to_simulation = true
|
||||
|
||||
[training]
|
||||
# Training defaults
|
||||
default_batch_size = 4
|
||||
default_learning_rate = 1e-5
|
||||
default_epochs = 3
|
||||
default_warmup_ratio = 0.1
|
||||
|
||||
[model_paths]
|
||||
# Default model paths
|
||||
bge_m3 = "./models/bge-m3"
|
||||
bge_reranker = "./models/bge-reranker-base"
|
||||
|
||||
[data]
|
||||
# Data processing settings
|
||||
max_query_length = 512
|
||||
max_passage_length = 8192
|
||||
min_passage_length = 10
|
||||
validation_split = 0.1
|
||||
"""
|
||||
|
||||
with open(config_file, 'w') as f:
|
||||
f.write(sample_config)
|
||||
|
||||
logger.info(f"Created sample config: {config_file}")
|
||||
logger.info("Please edit config.toml to add your API keys if needed.")
|
||||
|
||||
|
||||
def create_sample_data():
|
||||
"""Create sample data files"""
|
||||
logger.info("Creating sample data...")
|
||||
|
||||
base_dir = Path(__file__).parent.parent
|
||||
sample_file = base_dir / "data/raw/sample_data.jsonl"
|
||||
|
||||
if not sample_file.exists():
|
||||
sample_data = [
|
||||
{
|
||||
"query": "什么是机器学习?",
|
||||
"pos": ["机器学习是人工智能的一个分支,它使用算法让计算机能够从数据中学习和做出决策。"],
|
||||
"neg": ["深度学习是神经网络的一种形式。", "自然语言处理是计算机科学的一个领域。"]
|
||||
},
|
||||
{
|
||||
"query": "How does neural network work?",
|
||||
"pos": [
|
||||
"Neural networks are computing systems inspired by biological neural networks that constitute the brain."],
|
||||
"neg": ["Machine learning algorithms can be supervised or unsupervised.",
|
||||
"Data preprocessing is an important step in ML."]
|
||||
},
|
||||
{
|
||||
"query": "Python编程语言的特点",
|
||||
"pos": ["Python是一种高级编程语言,具有简洁的语法和强大的功能,广泛用于数据科学和AI开发。"],
|
||||
"neg": ["JavaScript是一种脚本语言。", "Java是一种面向对象的编程语言。"]
|
||||
}
|
||||
]
|
||||
|
||||
with open(sample_file, 'w', encoding='utf-8') as f:
|
||||
for item in sample_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
logger.info(f"Created sample data: {sample_file}")
|
||||
|
||||
|
||||
def verify_installation():
|
||||
"""Verify installation"""
|
||||
logger.info("Verifying installation...")
|
||||
|
||||
try:
|
||||
# Test imports
|
||||
import torch
|
||||
import transformers
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
logger.info("✓ Core packages imported successfully")
|
||||
|
||||
# Test GPU
|
||||
if torch.cuda.is_available():
|
||||
logger.info("✓ CUDA is available")
|
||||
device = torch.device("cuda")
|
||||
x = torch.randn(10, 10).to(device)
|
||||
logger.info("✓ GPU tensor operations work")
|
||||
else:
|
||||
logger.info("✓ CPU-only setup")
|
||||
|
||||
# Test model loading (if models exist)
|
||||
base_dir = Path(__file__).parent.parent
|
||||
if (base_dir / "models/bge-m3").exists():
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(str(base_dir / "models/bge-m3"))
|
||||
logger.info("✓ BGE-M3 model files are accessible")
|
||||
except Exception as e:
|
||||
logger.warning(f"BGE-M3 model check failed: {e}")
|
||||
|
||||
logger.info("Installation verification completed!")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Import error: {e}")
|
||||
logger.error("Installation may be incomplete")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def print_usage_info():
|
||||
"""Print usage information"""
|
||||
print("\n" + "=" * 60)
|
||||
print("BGE FINE-TUNING SETUP COMPLETED!")
|
||||
print("=" * 60)
|
||||
print("\nQuick Start:")
|
||||
print("1. Edit config.toml to add your API keys (optional)")
|
||||
print("2. Prepare your training data in JSONL format")
|
||||
print("3. Run training:")
|
||||
print(" python scripts/train_m3.py --train_data data/raw/your_data.jsonl")
|
||||
print(" python scripts/train_reranker.py --train_data data/raw/your_data.jsonl")
|
||||
print("4. Evaluate models:")
|
||||
print(" python scripts/evaluate.py --retriever_model output/model --eval_data data/raw/test.jsonl")
|
||||
print("\nSample data created in: data/raw/sample_data.jsonl")
|
||||
print("Configuration file: config.toml")
|
||||
print("\nFor more information, see the README.md file.")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup BGE fine-tuning environment")
|
||||
parser.add_argument("--download-models", action="store_true",
|
||||
help="Download base models from Hugging Face")
|
||||
parser.add_argument("--cpu-only", action="store_true",
|
||||
help="Install CPU-only version of PyTorch")
|
||||
parser.add_argument("--skip-pytorch", action="store_true",
|
||||
help="Skip PyTorch installation")
|
||||
parser.add_argument("--skip-requirements", action="store_true",
|
||||
help="Skip requirements installation")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info("Starting BGE fine-tuning environment setup...")
|
||||
|
||||
# Check Python version
|
||||
check_python_version()
|
||||
|
||||
# Detect device type
|
||||
if args.cpu_only:
|
||||
device_type = "cpu"
|
||||
else:
|
||||
device_type = check_gpu()
|
||||
|
||||
# Install PyTorch
|
||||
if not args.skip_pytorch:
|
||||
install_pytorch(device_type)
|
||||
|
||||
# Install requirements
|
||||
if not args.skip_requirements:
|
||||
install_requirements()
|
||||
|
||||
# Setup directories
|
||||
setup_directories()
|
||||
|
||||
# Download models
|
||||
download_models(args.download_models)
|
||||
|
||||
# Create sample files
|
||||
create_sample_config()
|
||||
create_sample_data()
|
||||
|
||||
# Verify installation
|
||||
verify_installation()
|
||||
|
||||
# Print usage info
|
||||
print_usage_info()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
91
scripts/test_converted_dataset.py
Normal file
91
scripts/test_converted_dataset.py
Normal file
@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify that the converted dataset can be loaded by the BGE training system
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Make the project root importable so `data` resolves when running this script directly
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from data.dataset import BGEDataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_dataset_loading(dataset_path: str):
|
||||
"""Test that the dataset can be loaded properly"""
|
||||
|
||||
# Check if dataset file exists
|
||||
if not os.path.exists(dataset_path):
|
||||
logger.error(f"Dataset file not found: {dataset_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Load tokenizer
|
||||
logger.info("Loading tokenizer...")
|
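||||
# Assumes the base model has already been downloaded locally (e.g. via `python scripts/setup.py --download-models`)
|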
||||
tokenizer = AutoTokenizer.from_pretrained('models/bge-m3')
|
||||
|
||||
# Test loading the converted dataset
|
||||
logger.info("Loading dataset...")
|
||||
dataset = BGEDataset(
|
||||
data_path=dataset_path,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=64,
|
||||
passage_max_length=512,
|
||||
format_type='triplets'
|
||||
)
|
||||
|
||||
logger.info(f"✅ Dataset loaded successfully!")
|
||||
logger.info(f"📊 Total samples: {len(dataset)}")
|
||||
logger.info(f"🔍 Detected format: {dataset.detected_format}")
|
||||
|
||||
# Test getting a sample
|
||||
logger.info("Testing sample retrieval...")
|
||||
sample = dataset[0]
|
||||
|
||||
logger.info(f"🎯 Sample keys: {list(sample.keys())}")
|
||||
logger.info(f"🔤 Query shape: {sample['query_input_ids'].shape}")
|
||||
logger.info(f"📄 Passages shape: {sample['passage_input_ids'].shape}")
|
||||
logger.info(f"🏷️ Labels shape: {sample['labels'].shape}")
|
||||
logger.info(f"📊 Scores shape: {sample['scores'].shape}")
|
||||
|
||||
# Show sample content
|
||||
logger.info(f"🔍 Query text: {sample['query_text'][:100]}...")
|
||||
logger.info(f"📝 First passage: {sample['passage_texts'][0][:100]}...")
|
||||
logger.info(f"🏷️ Labels: {sample['labels']}")
|
||||
logger.info(f"📊 Scores: {sample['scores']}")
|
||||
|
||||
# Test a few more samples
|
||||
logger.info("Testing multiple samples...")
|
||||
for i in [1, 10, 100]:
|
||||
if i < len(dataset):
|
||||
sample_i = dataset[i]
|
||||
logger.info(f"Sample {i}: Query length={len(sample_i['query_text'])}, Passages={len([p for p in sample_i['passage_texts'] if p])}")
|
||||
|
||||
logger.info("✅ All tests passed! Dataset is ready for BGE-M3 training.")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Dataset loading failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def main():
|
||||
dataset_path = "data/datasets/三国演义/m3_train_converted.jsonl"
|
||||
|
||||
logger.info("🧪 Testing converted BGE-M3 dataset...")
|
||||
logger.info(f"📁 Dataset path: {dataset_path}")
|
||||
|
||||
success = test_dataset_loading(dataset_path)
|
||||
|
||||
if success:
|
||||
logger.info("🎉 Dataset is ready for training!")
|
||||
logger.info("💡 You can now use this dataset with:")
|
||||
logger.info(" python scripts/train_m3.py --data_path data/datasets/三国演义/m3_train_converted.jsonl")
|
||||
else:
|
||||
logger.error("💥 Dataset testing failed!")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
497
scripts/train_joint.py
Normal file
497
scripts/train_joint.py
Normal file
@ -0,0 +1,497 @@
|
||||
"""
|
||||
Joint training of BGE-M3 and BGE-Reranker using RocketQAv2 approach
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
# Set tokenizer parallelism to avoid warnings
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime
|
||||
import torch
|
||||
from transformers import AutoTokenizer, set_seed
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from config.m3_config import BGEM3Config
|
||||
from config.reranker_config import BGERerankerConfig
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from data.dataset import BGEM3Dataset, BGERerankerDataset, JointDataset
|
||||
from data.preprocessing import DataPreprocessor
|
||||
from training.joint_trainer import JointTrainer
|
||||
from utils.logging import setup_logging, get_logger
|
||||
from utils.distributed import setup_distributed, is_main_process, set_random_seed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Joint training of BGE-M3 and BGE-Reranker")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--retriever_model", type=str, default="BAAI/bge-m3",
|
||||
help="Path to retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Path to reranker model")
|
||||
parser.add_argument("--cache_dir", type=str, default="./cache/data",
|
||||
help="Directory to cache downloaded models")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument("--train_data", type=str, required=True,
|
||||
help="Path to training data file")
|
||||
parser.add_argument("--eval_data", type=str, default=None,
|
||||
help="Path to evaluation data file")
|
||||
parser.add_argument("--data_format", type=str, default="auto",
|
||||
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
|
||||
help="Format of input data")
|
||||
|
||||
# Training arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./output/joint-training",
|
||||
help="Directory to save the models")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=3,
|
||||
help="Number of training epochs")
|
||||
parser.add_argument("--per_device_train_batch_size", type=int, default=4,
|
||||
help="Batch size per GPU/CPU for training")
|
||||
parser.add_argument("--per_device_eval_batch_size", type=int, default=8,
|
||||
help="Batch size per GPU/CPU for evaluation")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=4,
|
||||
help="Number of updates steps to accumulate")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-5,
|
||||
help="Initial learning rate for retriever")
|
||||
parser.add_argument("--reranker_learning_rate", type=float, default=6e-5,
|
||||
help="Initial learning rate for reranker")
|
||||
parser.add_argument("--warmup_ratio", type=float, default=0.1,
|
||||
help="Ratio of warmup steps")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01,
|
||||
help="Weight decay")
|
||||
|
||||
# Joint training specific
|
||||
parser.add_argument("--joint_loss_weight", type=float, default=0.3,
|
||||
help="Weight for joint distillation loss")
|
||||
parser.add_argument("--update_retriever_steps", type=int, default=100,
|
||||
help="Update retriever every N steps")
|
||||
parser.add_argument("--use_dynamic_distillation", action="store_true",
|
||||
help="Use dynamic listwise distillation")
|
||||
|
||||
# Model configuration
|
||||
parser.add_argument("--retriever_train_group_size", type=int, default=8,
|
||||
help="Number of passages per query for retriever")
|
||||
parser.add_argument("--reranker_train_group_size", type=int, default=16,
|
||||
help="Number of passages per query for reranker")
|
||||
parser.add_argument("--temperature", type=float, default=0.02,
|
||||
help="Temperature for InfoNCE loss")
|
||||
|
||||
# BGE-M3 specific
|
||||
parser.add_argument("--use_dense", action="store_true", default=True,
|
||||
help="Use dense embeddings")
|
||||
parser.add_argument("--use_sparse", action="store_true",
|
||||
help="Use sparse embeddings")
|
||||
parser.add_argument("--use_colbert", action="store_true",
|
||||
help="Use ColBERT embeddings")
|
||||
parser.add_argument("--query_max_length", type=int, default=512,
|
||||
help="Maximum query length")
|
||||
parser.add_argument("--passage_max_length", type=int, default=8192,
|
||||
help="Maximum passage length")
|
||||
|
||||
# Reranker specific
|
||||
parser.add_argument("--reranker_max_length", type=int, default=512,
|
||||
help="Maximum sequence length for reranker")
|
||||
|
||||
# Performance
|
||||
parser.add_argument("--fp16", action="store_true",
|
||||
help="Use mixed precision training")
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||
help="Use gradient checkpointing")
|
||||
parser.add_argument("--dataloader_num_workers", type=int, default=4,
|
||||
help="Number of workers for data loading")
|
||||
|
||||
# Checkpointing
|
||||
parser.add_argument("--save_steps", type=int, default=1000,
|
||||
help="Save checkpoint every N steps")
|
||||
parser.add_argument("--eval_steps", type=int, default=1000,
|
||||
help="Evaluate every N steps")
|
||||
parser.add_argument("--logging_steps", type=int, default=100,
|
||||
help="Log every N steps")
|
||||
parser.add_argument("--save_total_limit", type=int, default=3,
|
||||
help="Maximum number of checkpoints to keep")
|
||||
parser.add_argument("--resume_from_checkpoint", type=str, default=None,
|
||||
help="Path to checkpoint to resume from")
|
||||
|
||||
# Distributed training
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="Local rank for distributed training")
|
||||
|
||||
# Other
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
parser.add_argument("--overwrite_output_dir", action="store_true",
|
||||
help="Overwrite the output directory")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate arguments
|
||||
if os.path.exists(args.output_dir) and not args.overwrite_output_dir:
|
||||
raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
# Parse arguments
|
||||
args = parse_args()
|
||||
|
||||
# Setup logging
|
||||
setup_logging(
|
||||
output_dir=args.output_dir,
|
||||
log_level=logging.INFO,
|
||||
log_to_console=True,
|
||||
log_to_file=True
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("Starting joint training of BGE-M3 and BGE-Reranker")
|
||||
logger.info(f"Arguments: {args}")
|
||||
|
||||
# Set random seed
|
||||
set_random_seed(args.seed, deterministic=False)
|
||||
|
||||
# Setup distributed training if needed
|
||||
device = None
|
||||
if args.local_rank != -1:
|
||||
setup_distributed()
|
||||
from utils.config_loader import get_config
|
||||
config = get_config()
|
||||
device_type = config.get('ascend.device_type', None) # type: ignore[attr-defined]
|
||||
if device_type == 'npu':
|
||||
try:
|
||||
import torch_npu
|
||||
torch_npu.npu.set_device(args.local_rank)
|
||||
device = torch.device("npu", args.local_rank)
|
||||
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
|
||||
except ImportError:
|
||||
logger.warning("torch_npu not installed, falling back to CPU")
|
||||
device = torch.device("cpu")
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Distributed training enabled: using CPU")
|
||||
else:
|
||||
from utils.config_loader import get_config
|
||||
config = get_config()
|
||||
device_type = config.get('ascend.device_type', None) # type: ignore[attr-defined]
|
||||
if device_type == 'npu':
|
||||
try:
|
||||
import torch_npu
|
||||
device = torch.device("npu")
|
||||
logger.info("Using Ascend NPU")
|
||||
except ImportError:
|
||||
logger.warning("torch_npu not installed, falling back to CPU")
|
||||
device = torch.device("cpu")
|
||||
elif torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Using CPU")
|
||||
|
||||
# Create configurations
|
||||
retriever_config = BGEM3Config(
|
||||
model_name_or_path=args.retriever_model,
|
||||
cache_dir=args.cache_dir,
|
||||
train_data_path=args.train_data,
|
||||
eval_data_path=args.eval_data,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
output_dir=os.path.join(args.output_dir, "retriever"),
|
||||
num_train_epochs=args.num_train_epochs,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
learning_rate=args.learning_rate,
|
||||
warmup_ratio=args.warmup_ratio,
|
||||
weight_decay=args.weight_decay,
|
||||
train_group_size=args.retriever_train_group_size,
|
||||
temperature=args.temperature,
|
||||
use_dense=args.use_dense,
|
||||
use_sparse=args.use_sparse,
|
||||
use_colbert=args.use_colbert,
|
||||
fp16=args.fp16,
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
dataloader_num_workers=args.dataloader_num_workers,
|
||||
save_steps=args.save_steps,
|
||||
eval_steps=args.eval_steps,
|
||||
logging_steps=args.logging_steps,
|
||||
save_total_limit=args.save_total_limit,
|
||||
seed=args.seed,
|
||||
local_rank=args.local_rank,
|
||||
device=str(device)
|
||||
)
|
||||
|
||||
reranker_config = BGERerankerConfig(
|
||||
model_name_or_path=args.reranker_model,
|
||||
cache_dir=args.cache_dir,
|
||||
train_data_path=args.train_data,
|
||||
eval_data_path=args.eval_data,
|
||||
max_seq_length=args.reranker_max_length,
|
||||
output_dir=os.path.join(args.output_dir, "reranker"),
|
||||
num_train_epochs=args.num_train_epochs,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
learning_rate=args.reranker_learning_rate,
|
||||
warmup_ratio=args.warmup_ratio,
|
||||
weight_decay=args.weight_decay,
|
||||
train_group_size=args.reranker_train_group_size,
|
||||
joint_training=True, # type: ignore[call-arg]
|
||||
joint_loss_weight=args.joint_loss_weight, # type: ignore[call-arg]
|
||||
use_dynamic_listwise_distillation=args.use_dynamic_distillation, # type: ignore[call-arg]
|
||||
update_retriever_steps=args.update_retriever_steps, # type: ignore[call-arg]
|
||||
fp16=args.fp16,
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
dataloader_num_workers=args.dataloader_num_workers,
|
||||
save_steps=args.save_steps,
|
||||
eval_steps=args.eval_steps,
|
||||
logging_steps=args.logging_steps,
|
||||
save_total_limit=args.save_total_limit,
|
||||
seed=args.seed,
|
||||
local_rank=args.local_rank,
|
||||
device=str(device)
|
||||
)
|
||||
|
||||
# Create joint config (use retriever config as base)
|
||||
joint_config = retriever_config
|
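||||
# Note: this is an alias (not a copy), so the assignments below also mutate retriever_config
|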
||||
joint_config.output_dir = args.output_dir
|
||||
joint_config.joint_loss_weight = args.joint_loss_weight # type: ignore[attr-defined]
|
||||
joint_config.use_dynamic_listwise_distillation = args.use_dynamic_distillation # type: ignore[attr-defined]
|
||||
joint_config.update_retriever_steps = args.update_retriever_steps # type: ignore[attr-defined]
|
||||
|
||||
# Save configurations
|
||||
if is_main_process():
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
joint_config.save(os.path.join(args.output_dir, "joint_config.json"))
|
||||
retriever_config.save(os.path.join(args.output_dir, "retriever_config.json"))
|
||||
reranker_config.save(os.path.join(args.output_dir, "reranker_config.json"))
|
||||
|
||||
# Load tokenizers
|
||||
logger.info("Loading tokenizers...")
|
||||
retriever_tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.retriever_model,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
reranker_tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.reranker_model,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
|
||||
# Prepare data
|
||||
logger.info("Preparing data...")
|
||||
preprocessor = DataPreprocessor(
|
||||
tokenizer=retriever_tokenizer, # Use retriever tokenizer for preprocessing
|
||||
max_query_length=args.query_max_length,
|
||||
max_passage_length=args.passage_max_length
|
||||
)
|
||||
|
||||
# Preprocess training data
|
||||
train_path, val_path = preprocessor.preprocess_file(
|
||||
input_path=args.train_data,
|
||||
output_path=os.path.join(args.output_dir, "processed_train.jsonl"),
|
||||
file_format=args.data_format,
|
||||
validation_split=0.1 if args.eval_data is None else 0.0
|
||||
)
|
||||
|
||||
# Use provided eval data or split validation
|
||||
eval_path = args.eval_data if args.eval_data else val_path
|
||||
|
||||
# Create datasets
|
||||
logger.info("Creating datasets...")
|
||||
|
||||
# Retriever dataset
|
||||
retriever_train_dataset = BGEM3Dataset(
|
||||
data_path=train_path,
|
||||
tokenizer=retriever_tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
train_group_size=args.retriever_train_group_size,
|
||||
is_train=True
|
||||
)
|
||||
|
||||
# Reranker dataset
|
||||
reranker_train_dataset = BGERerankerDataset(
|
||||
data_path=train_path,
|
||||
tokenizer=reranker_tokenizer,
|
||||
max_seq_length=args.reranker_max_length, # type: ignore[call-arg]
|
||||
train_group_size=args.reranker_train_group_size,
|
||||
is_train=True
|
||||
)
|
||||
|
||||
# Joint dataset
|
||||
joint_train_dataset = JointDataset( # type: ignore[call-arg]
|
||||
retriever_dataset=retriever_train_dataset, # type: ignore[call-arg]
|
||||
reranker_dataset=reranker_train_dataset # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
# Evaluation datasets
|
||||
joint_eval_dataset = None
|
||||
if eval_path:
|
||||
retriever_eval_dataset = BGEM3Dataset(
|
||||
data_path=eval_path,
|
||||
tokenizer=retriever_tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
train_group_size=args.retriever_train_group_size,
|
||||
is_train=False
|
||||
)
|
||||
|
||||
reranker_eval_dataset = BGERerankerDataset(
|
||||
data_path=eval_path,
|
||||
tokenizer=reranker_tokenizer,
|
||||
max_seq_length=args.reranker_max_length, # type: ignore[call-arg]
|
||||
train_group_size=args.reranker_train_group_size,
|
||||
is_train=False
|
||||
)
|
||||
|
||||
joint_eval_dataset = JointDataset( # type: ignore[call-arg]
|
||||
retriever_dataset=retriever_eval_dataset, # type: ignore[call-arg]
|
||||
reranker_dataset=reranker_eval_dataset # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
# Initialize models
|
||||
logger.info("Loading models...")
|
||||
retriever = BGEM3Model(
|
||||
model_name_or_path=args.retriever_model,
|
||||
use_dense=args.use_dense,
|
||||
use_sparse=args.use_sparse,
|
||||
use_colbert=args.use_colbert,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
reranker = BGERerankerModel(
|
||||
model_name_or_path=args.reranker_model,
|
||||
cache_dir=args.cache_dir,
|
||||
use_fp16=args.fp16
|
||||
)
|
||||
from utils.config_loader import get_config
|
||||
logger.info("Retriever model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface'))
|
||||
logger.info("Reranker model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface'))
|
||||
|
||||
# Initialize optimizers
|
||||
from torch.optim import AdamW
|
||||
|
||||
retriever_optimizer = AdamW(
|
||||
retriever.parameters(),
|
||||
lr=args.learning_rate,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
reranker_optimizer = AdamW(
|
||||
reranker.parameters(),
|
||||
lr=args.reranker_learning_rate,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
# Initialize joint trainer
|
||||
logger.info("Initializing joint trainer...")
|
||||
trainer = JointTrainer(
|
||||
config=joint_config,
|
||||
retriever=retriever,
|
||||
reranker=reranker,
|
||||
retriever_optimizer=retriever_optimizer,
|
||||
reranker_optimizer=reranker_optimizer,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Create data loaders
|
||||
from torch.utils.data import DataLoader
|
||||
from utils.distributed import create_distributed_dataloader
|
||||
|
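||||
# When launched for distributed training (local_rank set by the launcher), build rank-aware data loaders; otherwise use plain DataLoaders
|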
||||
if args.local_rank != -1:
|
||||
train_dataloader = create_distributed_dataloader(
|
||||
joint_train_dataset,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
seed=args.seed
|
||||
)
|
||||
|
||||
eval_dataloader = None
|
||||
if joint_eval_dataset:
|
||||
eval_dataloader = create_distributed_dataloader(
|
||||
joint_eval_dataset,
|
||||
batch_size=args.per_device_eval_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
seed=args.seed
|
||||
)
|
||||
else:
|
||||
train_dataloader = DataLoader(
|
||||
joint_train_dataset,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
eval_dataloader = None
|
||||
if joint_eval_dataset:
|
||||
eval_dataloader = DataLoader(
|
||||
joint_eval_dataset,
|
||||
batch_size=args.per_device_eval_batch_size,
|
||||
num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
# Resume from checkpoint if specified
|
||||
if args.resume_from_checkpoint:
|
||||
trainer.load_checkpoint(args.resume_from_checkpoint)
|
||||
|
||||
# Train
|
||||
logger.info("Starting joint training...")
|
||||
try:
|
||||
trainer.train(train_dataloader, eval_dataloader, args.num_train_epochs)
|
||||
|
||||
# Save final models
|
||||
if is_main_process():
|
||||
logger.info("Saving final models...")
|
||||
final_dir = os.path.join(args.output_dir, "final_models")
|
||||
trainer.save_models(final_dir)
|
||||
|
||||
# Save individual models
|
||||
retriever_path = os.path.join(args.output_dir, "final_retriever")
|
||||
reranker_path = os.path.join(args.output_dir, "final_reranker")
|
||||
|
||||
os.makedirs(retriever_path, exist_ok=True)
|
||||
os.makedirs(reranker_path, exist_ok=True)
|
||||
|
||||
retriever.save_pretrained(retriever_path)
|
||||
retriever_tokenizer.save_pretrained(retriever_path)
|
||||
|
||||
reranker.model.save_pretrained(reranker_path)
|
||||
reranker_tokenizer.save_pretrained(reranker_path)
|
||||
|
||||
# Save training summary
|
||||
summary = {
|
||||
"training_args": vars(args),
|
||||
"retriever_config": retriever_config.to_dict(),
|
||||
"reranker_config": reranker_config.to_dict(),
|
||||
"joint_config": joint_config.to_dict(),
|
||||
"final_metrics": getattr(trainer, 'best_metric', None),
|
||||
"total_steps": getattr(trainer, 'global_step', 0),
|
||||
"training_completed": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
logger.info(f"Joint training completed! Models saved to {args.output_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Joint training failed with error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
614
scripts/train_m3.py
Normal file
614
scripts/train_m3.py
Normal file
@ -0,0 +1,614 @@
|
||||
"""
|
||||
Enhanced BGE-M3 embedding model training script with integrated dataset pipeline
|
||||
Supports both the legacy nested format and the new embedding schema, with automatic format detection
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
# Set tokenizer parallelism to avoid warnings
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime
|
||||
import torch
|
||||
from transformers import AutoTokenizer, set_seed
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import AdamW
|
||||
from transformers.optimization import get_linear_schedule_with_warmup
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
try:
|
||||
from config.m3_config import BGEM3Config
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from data.dataset import BGEM3Dataset, create_embedding_dataset, collate_embedding_batch, migrate_nested_to_flat
|
||||
from data.preprocessing import DataPreprocessor
|
||||
from data.augmentation import QueryAugmenter, HardNegativeAugmenter
|
||||
from data.hard_negative_mining import ANCEHardNegativeMiner
|
||||
from training.m3_trainer import BGEM3Trainer
|
||||
from evaluation.evaluator import RetrievalEvaluator
|
||||
from utils.logging import setup_logging, get_logger
|
||||
from utils.distributed import setup_distributed, is_main_process, set_random_seed
|
||||
from utils.config_loader import get_config, load_config_for_training, resolve_model_path
|
||||
except ImportError as e:
|
||||
print(f"Import error: {e}")
|
||||
print("Please ensure you're running from the project root directory")
|
||||
sys.exit(1)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Enhanced BGE-M3 Embedding Model Training")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--model_name_or_path", type=str, default=None,
|
||||
help="Path to pretrained model or model identifier")
|
||||
parser.add_argument("--cache_dir", type=str, default="./cache/data",
|
||||
help="Directory to cache downloaded models")
|
||||
parser.add_argument("--config_path", type=str, default="./config.toml",
|
||||
help="Path to configuration file")
|
||||
|
||||
# Data arguments - support both legacy and new schemas
|
||||
parser.add_argument("--train_data", type=str, required=True,
|
||||
help="Path to training data file (supports both embedding and legacy nested formats)")
|
||||
parser.add_argument("--eval_data", type=str, default=None,
|
||||
help="Path to evaluation data file")
|
||||
parser.add_argument("--data_format", type=str, default="auto",
|
||||
choices=["auto", "embedding", "legacy_nested", "jsonl", "json", "csv", "tsv"],
|
||||
help="Format of input data (auto-detects by default)")
|
||||
parser.add_argument("--legacy_data", type=str, default=None,
|
||||
help="Path to additional legacy nested data (for backward compatibility)")
|
||||
parser.add_argument("--migrate_legacy", action="store_true",
|
||||
help="Migrate legacy nested format to flat format")
|
||||
|
||||
# Sequence length arguments
|
||||
parser.add_argument("--max_seq_length", type=int, default=None,
|
||||
help="Maximum sequence length")
|
||||
parser.add_argument("--query_max_length", type=int, default=None,
|
||||
help="Maximum query length")
|
||||
parser.add_argument("--passage_max_length", type=int, default=None,
|
||||
help="Maximum passage length")
|
||||
parser.add_argument("--max_length", type=int, default=8192,
|
||||
help="Maximum total input length (multi-granularity)")
|
||||
|
||||
# Training arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./output/bge-m3-enhanced",
|
||||
help="Directory to save the model")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=None,
|
||||
help="Number of training epochs")
|
||||
parser.add_argument("--per_device_train_batch_size", type=int, default=None,
|
||||
help="Batch size per GPU/CPU for training")
|
||||
parser.add_argument("--per_device_eval_batch_size", type=int, default=None,
|
||||
help="Batch size per GPU/CPU for evaluation")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass")
|
||||
parser.add_argument("--learning_rate", type=float, default=None,
|
||||
help="Initial learning rate")
|
||||
parser.add_argument("--warmup_ratio", type=float, default=None,
|
||||
help="Warmup steps as a ratio of total steps")
|
||||
parser.add_argument("--weight_decay", type=float, default=None,
|
||||
help="Weight decay")
|
||||
|
||||
# Model specific arguments - enhanced features
|
||||
parser.add_argument("--train_group_size", type=int, default=8,
|
||||
help="Number of passages per query in training")
|
||||
parser.add_argument("--temperature", type=float, default=0.1,
|
||||
help="Temperature for InfoNCE loss")
|
||||
parser.add_argument("--use_hard_negatives", action="store_true",
|
||||
help="Enable hard negative mining")
|
||||
parser.add_argument("--use_self_distill", action="store_true",
|
||||
help="Enable self-knowledge distillation")
|
||||
parser.add_argument("--use_score_weighted_sampling", action="store_true", default=True,
|
||||
help="Enable score-weighted positive sampling")
|
||||
parser.add_argument("--query_instruction", type=str, default="",
|
||||
help="Query instruction prefix")
|
||||
|
||||
# Enhanced sampling features
|
||||
parser.add_argument("--multiple_positives", action="store_true", default=True,
|
||||
help="Use multiple positives per batch for richer contrastive signals")
|
||||
parser.add_argument("--hard_negative_ratio", type=float, default=0.7,
|
||||
help="Ratio of hard to random negatives")
|
||||
parser.add_argument("--difficulty_threshold", type=float, default=0.1,
|
||||
help="Minimum difficulty for hard negatives")
|
||||
|
||||
# Optimization and performance
|
||||
parser.add_argument("--fp16", action="store_true",
|
||||
help="Use mixed precision training")
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||
help="Use gradient checkpointing to save memory")
|
||||
parser.add_argument("--dataloader_num_workers", type=int, default=4,
|
||||
help="Number of workers for data loading")
|
||||
|
||||
# Checkpointing
|
||||
parser.add_argument("--save_steps", type=int, default=500,
|
||||
help="Save checkpoint every N steps")
|
||||
parser.add_argument("--eval_steps", type=int, default=500,
|
||||
help="Evaluate every N steps")
|
||||
parser.add_argument("--logging_steps", type=int, default=100,
|
||||
help="Log every N steps")
|
||||
parser.add_argument("--save_total_limit", type=int, default=3,
|
||||
help="Maximum number of checkpoints to keep")
|
||||
parser.add_argument("--resume_from_checkpoint", type=str, default=None,
|
||||
help="Path to checkpoint to resume from")
|
||||
|
||||
# Distributed training
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="Local rank for distributed training")
|
||||
|
||||
# Other
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
parser.add_argument("--overwrite_output_dir", action="store_true",
|
||||
help="Overwrite the output directory")
|
||||
parser.add_argument("--log_level", type=str, default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
help="Logging level")
|
||||
|
||||
# Backward compatibility flags
|
||||
parser.add_argument("--legacy_support", action="store_true", default=True,
|
||||
help="Enable legacy nested format support")
|
||||
parser.add_argument("--use_new_dataset_api", action="store_true", default=True,
|
||||
help="Use new integrated dataset API with format detection")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate arguments
|
||||
if os.path.exists(args.output_dir) and not args.overwrite_output_dir:
|
||||
raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.")
|
||||
|
||||
return args
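# Example invocation (hedged illustration: the script path and the --train_data flag name are
# assumptions inferred from this file's use of args.train_data; the remaining flags are defined
# in parse_args above):
#
#   python scripts/train_m3.py \
#       --train_data data/datasets/embedding_data.jsonl \
#       --output_dir ./output/bge-m3-enhanced \
#       --use_hard_negatives --fp16 --overwrite_output_dir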
|
||||
|
||||
|
||||
def create_config_from_args(args):
|
||||
"""Create configuration from command line arguments and config file
|
||||
|
||||
Parameter Precedence (highest to lowest):
|
||||
1. Command-line arguments (CLI)
|
||||
2. Model-specific config ([m3], [reranker])
|
||||
3. [hardware] section in config.toml
|
||||
4. [training] section in config.toml
|
||||
5. Hardcoded defaults in code
|
||||
"""
|
||||
# Load base configuration from TOML file
|
||||
config_dict = load_config_for_training(
|
||||
config_path=args.config_path,
|
||||
override_params={}
|
||||
)
|
||||
|
||||
# Get the centralized config instance
|
||||
central_config = get_config()
|
||||
|
||||
# Create override parameters from command line arguments
|
||||
override_params = {}
|
||||
|
||||
# Model configuration - use new resolve_model_path function
|
||||
resolved_model_path, resolved_cache_dir = resolve_model_path(
|
||||
model_key='bge_m3',
|
||||
config_instance=central_config,
|
||||
model_name_or_path=args.model_name_or_path,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
|
||||
args.model_name_or_path = resolved_model_path
|
||||
args.cache_dir = resolved_cache_dir
|
||||
|
||||
# Data configuration
|
||||
data_config = config_dict.get('data', {})
|
||||
if args.max_seq_length is None:
|
||||
args.max_seq_length = data_config.get('max_seq_length', 512)
|
||||
if args.query_max_length is None:
|
||||
args.query_max_length = data_config.get('max_query_length', 64)
|
||||
if args.passage_max_length is None:
|
||||
args.passage_max_length = data_config.get('max_passage_length', 512)
|
||||
|
||||
# Get validation split from config
|
||||
validation_split = data_config.get('validation_split', 0.1)
|
||||
|
||||
# Training configuration
|
||||
training_config = config_dict.get('training', {})
|
||||
if args.num_train_epochs is None:
|
||||
args.num_train_epochs = training_config.get('default_num_epochs', 3)
|
||||
# Batch sizes: per the precedence note above, the [hardware] section outranks the
# generic [training] default when the CLI did not set a value.
hardware_config = config_dict.get('hardware', {})
if args.per_device_train_batch_size is None:
    args.per_device_train_batch_size = hardware_config.get(
        'per_device_train_batch_size', training_config.get('default_batch_size', 4))
if args.per_device_eval_batch_size is None:
    args.per_device_eval_batch_size = hardware_config.get(
        'per_device_eval_batch_size', args.per_device_train_batch_size * 2)
|
||||
if args.learning_rate is None:
|
||||
args.learning_rate = training_config.get('default_learning_rate', 1e-5)
|
||||
if args.warmup_ratio is None:
|
||||
args.warmup_ratio = training_config.get('default_warmup_ratio', 0.1)
|
||||
if args.weight_decay is None:
|
||||
args.weight_decay = training_config.get('default_weight_decay', 0.01)
|
||||
|
||||
# Model-specific (M3) configuration - values set in [m3] override the CLI defaults above
|
||||
m3_config = config_dict.get('m3', {})
|
||||
train_group_size = m3_config.get('train_group_size', args.train_group_size)
|
||||
temperature = m3_config.get('temperature', args.temperature)
|
||||
use_hard_negatives = m3_config.get('use_hard_negatives', args.use_hard_negatives)
|
||||
use_self_distill = m3_config.get('use_self_distill', args.use_self_distill)
|
||||
query_instruction_for_retrieval = m3_config.get('query_instruction_for_retrieval', args.query_instruction)
|
||||
|
||||
# Hardware configuration for batch sizes is applied alongside the [training] fallbacks above.
|
||||
|
||||
# Create BGE-M3 configuration
|
||||
config = BGEM3Config(
|
||||
model_name_or_path=args.model_name_or_path,
|
||||
cache_dir=args.cache_dir,
|
||||
train_data_path=args.train_data,
|
||||
eval_data_path=args.eval_data,
|
||||
max_seq_length=args.max_seq_length,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
max_length=args.max_length,
|
||||
output_dir=args.output_dir,
|
||||
num_train_epochs=args.num_train_epochs,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
learning_rate=args.learning_rate,
|
||||
warmup_ratio=args.warmup_ratio,
|
||||
weight_decay=args.weight_decay,
|
||||
train_group_size=train_group_size,
|
||||
temperature=temperature,
|
||||
use_hard_negatives=use_hard_negatives,
|
||||
use_self_distill=use_self_distill,
|
||||
query_instruction_for_retrieval=query_instruction_for_retrieval,
|
||||
fp16=args.fp16,
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
dataloader_num_workers=args.dataloader_num_workers,
|
||||
save_steps=args.save_steps,
|
||||
eval_steps=args.eval_steps,
|
||||
logging_steps=args.logging_steps,
|
||||
save_total_limit=args.save_total_limit,
|
||||
seed=args.seed,
|
||||
local_rank=args.local_rank,
|
||||
overwrite_output_dir=args.overwrite_output_dir
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def setup_model_and_tokenizer(args):
|
||||
"""Setup BGE-M3 model and tokenizer"""
|
||||
logger.info(f"Loading BGE-M3 model from {args.model_name_or_path}")
|
||||
|
||||
# Detect offline mode if the path is local or HF_HUB_OFFLINE is set
|
||||
local_files_only = os.path.isdir(args.model_name_or_path) or bool(os.environ.get("HF_HUB_OFFLINE"))
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
cache_dir=args.cache_dir,
|
||||
trust_remote_code=True,
|
||||
local_files_only=local_files_only
|
||||
)
|
||||
|
||||
# Load model
|
||||
try:
|
||||
model = BGEM3Model(
|
||||
model_name_or_path=args.model_name_or_path,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
logger.info("Model loaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def setup_datasets(args, tokenizer):
|
||||
"""Setup datasets with enhanced format detection and migration support"""
|
||||
logger.info("Setting up datasets with enhanced format detection...")
|
||||
|
||||
# Handle legacy data migration if needed
|
||||
if args.legacy_data and args.migrate_legacy:
|
||||
logger.info(f"Migrating legacy data from {args.legacy_data}")
|
||||
legacy_base = os.path.splitext(args.legacy_data)[0]
|
||||
migrated_file = f"{legacy_base}_migrated_embedding.jsonl"
|
||||
|
||||
# For embedding, we don't need to migrate to flat pairs, but we can optimize the format
|
||||
# Copy the legacy data to work with if needed
|
||||
import shutil
|
||||
shutil.copy2(args.legacy_data, migrated_file)
|
||||
logger.info(f"Legacy data prepared for training: {migrated_file}")
|
||||
|
||||
# Setup training dataset
|
||||
if args.use_new_dataset_api:
|
||||
# Use new integrated dataset API with format detection
|
||||
train_dataset = create_embedding_dataset(
|
||||
data_path=args.train_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=True
|
||||
)
|
||||
|
||||
logger.info(f"Loaded {len(train_dataset)} training examples from {args.train_data}")
|
||||
logger.info(f"Detected format: {train_dataset.detected_format}")
|
||||
|
||||
# Add legacy data if provided
|
||||
if args.legacy_data and args.legacy_support:
|
||||
logger.info(f"Loading additional legacy data from {args.legacy_data}")
|
||||
legacy_dataset = create_embedding_dataset(
|
||||
data_path=args.legacy_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=True
|
||||
)
|
||||
logger.info(f"Loaded {len(legacy_dataset)} legacy examples")
|
||||
|
||||
# Combine datasets
|
||||
train_dataset.data.extend(legacy_dataset.data)
|
||||
logger.info(f"Total training examples: {len(train_dataset)}")
|
||||
else:
|
||||
# Use original dataset API for backward compatibility
|
||||
train_dataset = BGEM3Dataset(
|
||||
data_path=args.train_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
train_group_size=args.train_group_size,
|
||||
is_train=True
|
||||
)
|
||||
logger.info(f"Loaded {len(train_dataset)} training examples (legacy API)")
|
||||
|
||||
# Setup evaluation dataset
|
||||
eval_dataset = None
|
||||
if args.eval_data:
|
||||
if args.use_new_dataset_api:
|
||||
eval_dataset = create_embedding_dataset(
|
||||
data_path=args.eval_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=False
|
||||
)
|
||||
else:
|
||||
eval_dataset = BGEM3Dataset(
|
||||
data_path=args.eval_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
train_group_size=args.train_group_size,
|
||||
is_train=False
|
||||
)
|
||||
logger.info(f"Loaded {len(eval_dataset)} evaluation examples")
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def setup_data_loaders(args, train_dataset, eval_dataset):
|
||||
"""Setup data loaders with appropriate collate functions"""
|
||||
|
||||
# Choose the collate function; both the triplets format and the standard flat format
# go through the same embedding collate function, so no branching is needed here.
collate_fn = collate_embedding_batch
logger.info("Using embedding collate function")
|
||||
|
||||
# FastM3Dataset does not work well with worker subprocesses (pickle-based loading),
# so fall back to single-process loading for that dataset type.
workers = 0 if type(train_dataset).__name__ == 'FastM3Dataset' else args.dataloader_num_workers
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=workers,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=torch.cuda.is_available()
|
||||
)
|
||||
|
||||
# Evaluation data loader
|
||||
eval_dataloader = None
|
||||
if eval_dataset:
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=args.per_device_eval_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=torch.cuda.is_available()
|
||||
)
|
||||
|
||||
return train_dataloader, eval_dataloader
|
||||
|
||||
|
||||
def main():
|
||||
# Parse arguments
|
||||
args = parse_args()
|
||||
|
||||
# Create configuration
|
||||
config = create_config_from_args(args)
|
||||
|
||||
# Setup logging
|
||||
setup_logging(
|
||||
output_dir=args.output_dir,
|
||||
log_level=getattr(logging, args.log_level),
|
||||
log_to_console=True,
|
||||
log_to_file=True
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("🚀 Starting Enhanced BGE-M3 Fine-tuning")
|
||||
logger.info(f"Arguments: {args}")
|
||||
logger.info(f"Configuration: {config.to_dict()}")
|
||||
|
||||
# Set random seed
|
||||
set_random_seed(args.seed, deterministic=False)
|
||||
|
||||
# Setup distributed training if needed
|
||||
device = None
|
||||
if args.local_rank != -1:
|
||||
setup_distributed()
|
||||
# Device selection from config
|
||||
from utils.config_loader import get_config
|
||||
config_instance = get_config()
|
||||
device_type = config_instance.get('ascend.device_type', 'cuda')
|
||||
if device_type == 'npu':
|
||||
try:
|
||||
import torch_npu # type: ignore
|
||||
torch_npu.npu.set_device(args.local_rank)
|
||||
device = torch.device("npu", args.local_rank)
|
||||
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
|
||||
except ImportError:
|
||||
logger.warning("torch_npu not installed, falling back to CPU")
|
||||
device = torch.device("cpu")
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Distributed training enabled: using CPU")
|
||||
else:
|
||||
# Non-distributed device selection
|
||||
from utils.config_loader import get_config
|
||||
config_instance = get_config()
|
||||
device_type = config_instance.get('ascend.device_type', 'cuda')
|
||||
if device_type == 'npu':
|
||||
try:
|
||||
import torch_npu # type: ignore
|
||||
device = torch.device("npu")
|
||||
logger.info("Using Ascend NPU")
|
||||
except ImportError:
|
||||
logger.warning("torch_npu not installed, falling back to CPU")
|
||||
device = torch.device("cpu")
|
||||
elif torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Using CPU")
|
||||
|
||||
# Update config with device
|
||||
config.device = str(device)
|
||||
|
||||
# Save configuration
|
||||
if is_main_process():
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
config.save(os.path.join(args.output_dir, "config.json"))
|
||||
|
||||
# Setup model and tokenizer
|
||||
model, tokenizer = setup_model_and_tokenizer(args)
|
||||
|
||||
# Setup datasets with enhanced format detection
|
||||
train_dataset, eval_dataset = setup_datasets(args, tokenizer)
|
||||
|
||||
# Setup data loaders
|
||||
train_dataloader, eval_dataloader = setup_data_loaders(args, train_dataset, eval_dataset)
|
||||
|
||||
# Initialize trainer
|
||||
logger.info("Initializing enhanced trainer...")
|
||||
try:
|
||||
trainer = BGEM3Trainer(
|
||||
config=config,
|
||||
model=model,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
logger.info("Trainer initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize trainer: {e}")
|
||||
raise
|
||||
|
||||
# Resume from checkpoint if specified
|
||||
if args.resume_from_checkpoint:
|
||||
logger.info(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
|
||||
try:
|
||||
trainer.load_checkpoint(args.resume_from_checkpoint)
|
||||
logger.info("Checkpoint loaded successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load checkpoint: {e}")
|
||||
raise
|
||||
|
||||
# Train
|
||||
logger.info("🎯 Starting enhanced training...")
|
||||
try:
|
||||
# Handle type compatibility for trainer
|
||||
from typing import cast
|
||||
trainer_train_dataset = cast(BGEM3Dataset, train_dataset) if train_dataset else None
|
||||
trainer_eval_dataset = cast(BGEM3Dataset, eval_dataset) if eval_dataset else None
|
||||
|
||||
if not trainer_train_dataset:
|
||||
raise ValueError("Training dataset is required but not provided")
|
||||
|
||||
trainer.train(
|
||||
train_dataset=trainer_train_dataset,
|
||||
eval_dataset=trainer_eval_dataset,
|
||||
resume_from_checkpoint=args.resume_from_checkpoint
|
||||
)
|
||||
|
||||
# Save final model
|
||||
if is_main_process():
|
||||
logger.info("Saving final model...")
|
||||
final_path = os.path.join(args.output_dir, "final_model")
|
||||
os.makedirs(final_path, exist_ok=True)
|
||||
|
||||
# Save model
|
||||
model.save_pretrained(final_path)
|
||||
tokenizer.save_pretrained(final_path)
|
||||
|
||||
# Save training summary
|
||||
summary = {
|
||||
"training_args": vars(args),
|
||||
"model_config": config.to_dict(),
|
||||
"final_metrics": getattr(trainer, 'best_metric', None),
|
||||
"total_steps": getattr(trainer, 'global_step', 0),
|
||||
"training_completed": datetime.now().isoformat(),
|
||||
"device_used": str(device),
|
||||
"dataset_format": getattr(train_dataset, 'detected_format', 'unknown'),
|
||||
"enhanced_features": {
|
||||
"use_new_dataset_api": args.use_new_dataset_api,
|
||||
"use_hard_negatives": args.use_hard_negatives,
|
||||
"use_self_distill": args.use_self_distill,
|
||||
"use_score_weighted_sampling": args.use_score_weighted_sampling,
|
||||
"multiple_positives": args.multiple_positives,
|
||||
"legacy_support": args.legacy_support
|
||||
}
|
||||
}
|
||||
|
||||
with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
logger.info(f"✅ Enhanced training completed! Model saved to {final_path}")
|
||||
|
||||
# Show training summary
|
||||
logger.info("📊 Training Summary:")
|
||||
logger.info(f" Total examples: {len(train_dataset)}")
|
||||
if hasattr(train_dataset, 'detected_format'):
|
||||
logger.info(f" Format detected: {train_dataset.detected_format}") # type: ignore[attr-defined]
|
||||
logger.info(f" Epochs: {args.num_train_epochs}")
|
||||
logger.info(f" Learning rate: {args.learning_rate}")
|
||||
logger.info(f" Batch size: {args.per_device_train_batch_size}")
|
||||
logger.info(f" Temperature: {args.temperature}")
|
||||
logger.info(f" Enhanced features enabled: {args.use_new_dataset_api}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed with error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if hasattr(trainer, 'hard_negative_miner') and trainer.hard_negative_miner:
|
||||
trainer.hard_negative_miner.stop_async_updates()
|
||||
|
||||
if args.local_rank != -1:
|
||||
from utils.distributed import cleanup_distributed
|
||||
cleanup_distributed()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
398
scripts/train_reranker.py
Normal file
@ -0,0 +1,398 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
BGE-Reranker Cross-Encoder Model Training Script
|
||||
Uses the refactored dataset pipeline with flat pairs schema
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
|
||||
from torch.optim import AdamW  # transformers' bundled AdamW is deprecated; the torch optimizer is a drop-in here
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from data.dataset import create_reranker_dataset, collate_reranker_batch, migrate_nested_to_flat
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from training.reranker_trainer import BGERerankerTrainer
|
||||
from utils.logging import setup_logging, get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments for reranker training"""
|
||||
parser = argparse.ArgumentParser(description="Train BGE-Reranker Cross-Encoder Model")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--model_name_or_path", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Path to pretrained BGE-Reranker model")
|
||||
parser.add_argument("--cache_dir", type=str, default="./cache/data",
|
||||
help="Directory to cache downloaded models")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument("--reranker_data", type=str, required=True,
|
||||
help="Path to reranker training data (reranker_data.jsonl)")
|
||||
parser.add_argument("--eval_data", type=str, default=None,
|
||||
help="Path to evaluation data")
|
||||
parser.add_argument("--legacy_data", type=str, default=None,
|
||||
help="Path to legacy nested data (will be migrated)")
|
||||
|
||||
# Training arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./output/bge-reranker",
|
||||
help="Directory to save the trained model")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=3,
|
||||
help="Number of training epochs")
|
||||
parser.add_argument("--per_device_train_batch_size", type=int, default=8,
|
||||
help="Batch size per GPU/CPU for training")
|
||||
parser.add_argument("--per_device_eval_batch_size", type=int, default=16,
|
||||
help="Batch size per GPU/CPU for evaluation")
|
||||
parser.add_argument("--learning_rate", type=float, default=6e-5,
|
||||
help="Learning rate")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01,
|
||||
help="Weight decay")
|
||||
parser.add_argument("--warmup_ratio", type=float, default=0.1,
|
||||
help="Warmup ratio")
|
||||
|
||||
# Model specific arguments
|
||||
parser.add_argument("--query_max_length", type=int, default=64,
|
||||
help="Maximum query length")
|
||||
parser.add_argument("--passage_max_length", type=int, default=448,
|
||||
help="Maximum passage length")
|
||||
parser.add_argument("--loss_type", type=str, default="cross_entropy",
|
||||
choices=["cross_entropy", "binary_cross_entropy", "listwise_ce"],
|
||||
help="Loss function type")
|
||||
|
||||
# Enhanced features
|
||||
parser.add_argument("--use_score_weighting", action="store_true",
|
||||
help="Use score weighting in loss calculation")
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.0,
|
||||
help="Label smoothing for cross entropy loss")
|
||||
|
||||
# System arguments
|
||||
parser.add_argument("--fp16", action="store_true",
|
||||
help="Use 16-bit floating point precision")
|
||||
parser.add_argument("--dataloader_num_workers", type=int, default=4,
|
||||
help="Number of workers for data loading")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
parser.add_argument("--log_level", type=str, default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
help="Logging level")
|
||||
|
||||
# Backward compatibility
|
||||
parser.add_argument("--legacy_support", action="store_true", default=True,
|
||||
help="Enable legacy nested format support")
|
||||
parser.add_argument("--migrate_legacy", action="store_true",
|
||||
help="Migrate legacy nested format to flat format")
|
||||
|
||||
return parser.parse_args()
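# Example invocation (illustrative paths; every flag shown is defined in parse_args above):
#
#   python scripts/train_reranker.py \
#       --reranker_data data/datasets/reranker_data.jsonl \
#       --output_dir ./output/bge-reranker \
#       --num_train_epochs 3 --fp16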
|
||||
|
||||
|
||||
def setup_model_and_tokenizer(args):
|
||||
"""Setup BGE-Reranker model and tokenizer"""
|
||||
logger.info(f"Loading BGE-Reranker model from {args.model_name_or_path}")
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
cache_dir=args.cache_dir,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Load model
|
||||
model = BGERerankerModel(
|
||||
model_name_or_path=args.model_name_or_path,
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def migrate_legacy_data(args):
|
||||
"""Migrate legacy nested data to flat format if needed"""
|
||||
if args.legacy_data and args.migrate_legacy:
|
||||
logger.info("Migrating legacy nested data to flat format...")
|
||||
|
||||
# Create output file name
|
||||
legacy_base = os.path.splitext(args.legacy_data)[0]
|
||||
migrated_file = f"{legacy_base}_migrated_flat.jsonl"
|
||||
|
||||
# Migrate the data
|
||||
migrate_nested_to_flat(args.legacy_data, migrated_file)
|
||||
|
||||
logger.info(f"✅ Migrated legacy data to {migrated_file}")
|
||||
return migrated_file
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def setup_datasets(args, tokenizer):
|
||||
"""Setup reranker datasets"""
|
||||
logger.info("Setting up reranker datasets...")
|
||||
|
||||
# Migrate legacy data if needed
|
||||
migrated_file = migrate_legacy_data(args)
|
||||
|
||||
# Primary reranker data
|
||||
train_dataset = create_reranker_dataset(
|
||||
data_path=args.reranker_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=True,
|
||||
legacy_support=args.legacy_support
|
||||
)
|
||||
|
||||
logger.info(f"Loaded {len(train_dataset)} training examples from {args.reranker_data}")
|
||||
logger.info(f"Detected format: {train_dataset.format_type}")
|
||||
|
||||
# Add migrated legacy data if available
|
||||
if migrated_file:
|
||||
logger.info(f"Loading migrated legacy data from {migrated_file}")
|
||||
migrated_dataset = create_reranker_dataset(
|
||||
data_path=migrated_file,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=True,
|
||||
legacy_support=False # Migrated data is in flat format
|
||||
)
|
||||
logger.info(f"Loaded {len(migrated_dataset)} migrated examples")
|
||||
|
||||
# Combine datasets
|
||||
train_dataset.data.extend(migrated_dataset.data)
|
||||
logger.info(f"Total training examples: {len(train_dataset)}")
|
||||
|
||||
# Evaluation dataset
|
||||
eval_dataset = None
|
||||
if args.eval_data:
|
||||
eval_dataset = create_reranker_dataset(
|
||||
data_path=args.eval_data,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=args.query_max_length,
|
||||
passage_max_length=args.passage_max_length,
|
||||
is_train=False,
|
||||
legacy_support=args.legacy_support
|
||||
)
|
||||
logger.info(f"Loaded {len(eval_dataset)} evaluation examples")
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def setup_loss_function(args):
|
||||
"""Setup cross-encoder loss function"""
|
||||
if args.loss_type == "cross_entropy":
|
||||
if args.label_smoothing > 0:
|
||||
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
else:
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
elif args.loss_type == "binary_cross_entropy":
|
||||
loss_fn = torch.nn.BCEWithLogitsLoss()
|
||||
elif args.loss_type == "listwise_ce":
|
||||
# Custom listwise cross-entropy loss
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
else:
|
||||
raise ValueError(f"Unknown loss type: {args.loss_type}")
|
||||
|
||||
logger.info(f"Using loss function: {args.loss_type}")
|
||||
return loss_fn
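# Hedged sketch, not called by the training loop below: a true listwise cross-entropy treats
# all candidate scores for one query as a single softmax distribution and applies cross-entropy
# against the index of the positive passage. The [num_queries, num_candidates] grouping is an
# assumption; the current collate function yields flat query-passage pairs instead.
def listwise_cross_entropy(scores: torch.Tensor, positive_idx: torch.Tensor) -> torch.Tensor:
    """scores: [num_queries, num_candidates]; positive_idx: [num_queries] index of the positive."""
    log_probs = F.log_softmax(scores, dim=-1)  # normalize over the candidates of each query
    return -log_probs.gather(1, positive_idx.unsqueeze(1)).squeeze(1).mean()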
|
||||
|
||||
|
||||
def train_epoch(model, dataloader, optimizer, scheduler, loss_fn, device, args):
|
||||
"""Train for one epoch"""
|
||||
model.train()
|
||||
total_loss = 0
|
||||
num_batches = 0
|
||||
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
# Move batch to device
|
||||
input_ids = batch['input_ids'].to(device)
|
||||
attention_mask = batch['attention_mask'].to(device)
|
||||
labels = batch['labels'].to(device)
|
||||
|
||||
# Forward pass
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
|
||||
|
||||
# Calculate loss
|
||||
if args.loss_type == "binary_cross_entropy":
|
||||
# For binary classification
|
||||
loss = loss_fn(logits.squeeze(), labels.float())
|
||||
else:
|
||||
# For multi-class classification
|
||||
loss = loss_fn(logits, labels)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
# Log progress
|
||||
if batch_idx % 100 == 0:
|
||||
logger.info(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
|
||||
|
||||
return total_loss / num_batches
|
||||
|
||||
|
||||
def evaluate(model, dataloader, loss_fn, device, args):
|
||||
"""Evaluate the model"""
|
||||
model.eval()
|
||||
total_loss = 0
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
input_ids = batch['input_ids'].to(device)
|
||||
attention_mask = batch['attention_mask'].to(device)
|
||||
labels = batch['labels'].to(device)
|
||||
|
||||
# Forward pass
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
|
||||
|
||||
# Calculate loss
|
||||
if args.loss_type == "binary_cross_entropy":
|
||||
loss = loss_fn(logits.squeeze(), labels.float())
|
||||
# Calculate accuracy for binary classification
|
||||
predictions = (torch.sigmoid(logits.squeeze()) > 0.5).long()
|
||||
else:
|
||||
loss = loss_fn(logits, labels)
|
||||
# Calculate accuracy for multi-class classification
|
||||
predictions = torch.argmax(logits, dim=-1)
|
||||
|
||||
total_loss += loss.item()
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
|
||||
accuracy = correct / total if total > 0 else 0
|
||||
return total_loss / len(dataloader), accuracy
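# Hedged sketch, not called above: pointwise accuracy is only a rough proxy for a reranker,
# whose real job is ordering candidates per query. Given scores regrouped per query, MRR is a
# common ranking metric; the grouping shape below is an assumption, not something the current
# flat-pair dataloader provides.
def mean_reciprocal_rank(scores: torch.Tensor, positive_idx: torch.Tensor) -> float:
    """scores: [num_queries, num_candidates]; positive_idx: [num_queries]."""
    pos_scores = scores.gather(1, positive_idx.unsqueeze(1))       # [num_queries, 1]
    ranks = 1 + (scores > pos_scores).sum(dim=1).float()           # rank of each positive passage
    return (1.0 / ranks).mean().item()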
|
||||
|
||||
|
||||
def main():
|
||||
"""Main training function"""
|
||||
args = parse_args()
|
||||
|
||||
# Setup logging
|
||||
setup_logging(
|
||||
output_dir=args.output_dir,
|
||||
log_level=args.log_level,
|
||||
log_to_file=True,
|
||||
log_to_console=True
|
||||
)
|
||||
|
||||
logger.info("🚀 Starting BGE-Reranker Training")
|
||||
logger.info(f"Arguments: {vars(args)}")
|
||||
|
||||
# Set random seed
|
||||
torch.manual_seed(args.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
# Setup device (mixed precision needs a GPU, so --fp16 must not force a CPU fallback)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.fp16 and device.type != "cuda":
    logger.warning("--fp16 requested but no CUDA device is available; running in fp32 instead")
logger.info(f"Using device: {device}")
|
||||
|
||||
# Setup model and tokenizer
|
||||
model, tokenizer = setup_model_and_tokenizer(args)
|
||||
model.to(device)
|
||||
|
||||
# Setup datasets
|
||||
train_dataset, eval_dataset = setup_datasets(args, tokenizer)
|
||||
|
||||
# Setup data loaders
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
collate_fn=collate_reranker_batch,
|
||||
pin_memory=torch.cuda.is_available()
|
||||
)
|
||||
|
||||
eval_dataloader = None
|
||||
if eval_dataset:
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=args.per_device_eval_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
collate_fn=collate_reranker_batch,
|
||||
pin_memory=torch.cuda.is_available()
|
||||
)
|
||||
|
||||
# Setup optimizer and scheduler
|
||||
optimizer = AdamW(
|
||||
model.parameters(),
|
||||
lr=args.learning_rate,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
total_steps = len(train_dataloader) * args.num_train_epochs
|
||||
warmup_steps = int(total_steps * args.warmup_ratio)
|
||||
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=warmup_steps,
|
||||
num_training_steps=total_steps
|
||||
)
|
||||
|
||||
# Setup loss function
|
||||
loss_fn = setup_loss_function(args)
|
||||
|
||||
# Training loop
|
||||
logger.info("🎯 Starting training...")
|
||||
|
||||
best_eval_loss = float('inf')
|
||||
|
||||
for epoch in range(args.num_train_epochs):
|
||||
logger.info(f"Epoch {epoch + 1}/{args.num_train_epochs}")
|
||||
|
||||
# Train
|
||||
train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, loss_fn, device, args)
|
||||
logger.info(f"Training loss: {train_loss:.4f}")
|
||||
|
||||
# Evaluate
|
||||
if eval_dataloader:
|
||||
eval_loss, eval_accuracy = evaluate(model, eval_dataloader, loss_fn, device, args)
|
||||
logger.info(f"Evaluation loss: {eval_loss:.4f}, Accuracy: {eval_accuracy:.4f}")
|
||||
|
||||
# Save best model
|
||||
if eval_loss < best_eval_loss:
|
||||
best_eval_loss = eval_loss
|
||||
best_model_path = os.path.join(args.output_dir, "best_model")
|
||||
os.makedirs(best_model_path, exist_ok=True)
|
||||
model.save_pretrained(best_model_path)
|
||||
tokenizer.save_pretrained(best_model_path)
|
||||
logger.info(f"Saved best model to {best_model_path}")
|
||||
|
||||
# Save final model
|
||||
final_model_path = os.path.join(args.output_dir, "final_model")
|
||||
os.makedirs(final_model_path, exist_ok=True)
|
||||
model.save_pretrained(final_model_path)
|
||||
tokenizer.save_pretrained(final_model_path)
|
||||
|
||||
logger.info(f"✅ Training completed! Final model saved to {final_model_path}")
|
||||
|
||||
# Show training summary
|
||||
logger.info("📊 Training Summary:")
|
||||
logger.info(f" Total examples: {len(train_dataset)}")
|
||||
logger.info(f" Format detected: {train_dataset.format_type}")
|
||||
logger.info(f" Epochs: {args.num_train_epochs}")
|
||||
logger.info(f" Learning rate: {args.learning_rate}")
|
||||
logger.info(f" Batch size: {args.per_device_train_batch_size}")
|
||||
logger.info(f" Loss function: {args.loss_type}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
122
scripts/validate_m3.py
Normal file
@ -0,0 +1,122 @@
|
||||
"""
|
||||
Quick validation script for the fine-tuned BGE-M3 model
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Add project root to Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
|
||||
def test_model_embeddings(model_path, test_queries=None):
|
||||
"""Test if the model can generate embeddings"""
|
||||
|
||||
if test_queries is None:
|
||||
test_queries = [
|
||||
"什么是深度学习?", # Original training query
|
||||
"什么是机器学习?", # Related query
|
||||
"今天天气怎么样?" # Unrelated query
|
||||
]
|
||||
|
||||
print(f"🧪 Testing model from: {model_path}")
|
||||
|
||||
try:
|
||||
# Load tokenizer and model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGEM3Model.from_pretrained(model_path)
|
||||
model.eval()
|
||||
|
||||
print("✅ Model and tokenizer loaded successfully")
|
||||
|
||||
# Test embeddings generation
|
||||
embeddings = []
|
||||
|
||||
for i, query in enumerate(test_queries):
|
||||
print(f"\n📝 Testing query {i+1}: '{query}'")
|
||||
|
||||
# Tokenize
|
||||
inputs = tokenizer(
|
||||
query,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
)
|
||||
|
||||
# Get embeddings
|
||||
with torch.no_grad():
|
||||
outputs = model.forward(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask'],
|
||||
return_dense=True,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
if 'dense' in outputs:
|
||||
embedding = outputs['dense'].cpu().numpy() # type: ignore[index]
|
||||
embeddings.append(embedding)
|
||||
|
||||
print(f" ✅ Embedding shape: {embedding.shape}")
|
||||
print(f" 📊 Embedding norm: {np.linalg.norm(embedding):.4f}")
|
||||
print(f" 📈 Mean activation: {embedding.mean():.6f}")
|
||||
print(f" 📉 Std activation: {embedding.std():.6f}")
|
||||
else:
|
||||
print(f" ❌ No dense embedding returned")
|
||||
return False
|
||||
|
||||
# Compare embeddings
|
||||
if len(embeddings) >= 2:
|
||||
print(f"\n🔗 Similarity Analysis:")
|
||||
for i in range(len(embeddings)):
|
||||
for j in range(i+1, len(embeddings)):
|
||||
sim = np.dot(embeddings[i].flatten(), embeddings[j].flatten()) / (
|
||||
np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
|
||||
)
|
||||
print(f" Query {i+1} vs Query {j+1}: {sim:.4f}")
|
||||
|
||||
print(f"\n🎉 Model validation successful!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Model validation failed: {e}")
|
||||
return False
|
||||
|
||||
def compare_models(original_path, finetuned_path):
|
||||
"""Compare original vs fine-tuned model"""
|
||||
print("🔄 Comparing original vs fine-tuned model...")
|
||||
|
||||
test_query = "什么是深度学习?"
|
||||
|
||||
# Test original model
|
||||
print("\n📍 Original Model:")
|
||||
orig_success = test_model_embeddings(original_path, [test_query])
|
||||
|
||||
# Test fine-tuned model
|
||||
print("\n📍 Fine-tuned Model:")
|
||||
ft_success = test_model_embeddings(finetuned_path, [test_query])
|
||||
|
||||
return orig_success and ft_success
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 BGE-M3 Model Validation")
|
||||
print("=" * 50)
|
||||
|
||||
# Test fine-tuned model
|
||||
finetuned_model_path = "./test_m3_output/final_model"
|
||||
success = test_model_embeddings(finetuned_model_path)
|
||||
|
||||
# Optional: Compare with original model
|
||||
original_model_path = "./models/bge-m3"
|
||||
compare_models(original_model_path, finetuned_model_path)
|
||||
|
||||
if success:
|
||||
print("\n✅ Fine-tuning validation: PASSED")
|
||||
print(" The model can generate embeddings successfully!")
|
||||
else:
|
||||
print("\n❌ Fine-tuning validation: FAILED")
|
||||
print(" The model has issues generating embeddings!")
|
||||
150
scripts/validate_reranker.py
Normal file
@ -0,0 +1,150 @@
|
||||
"""
|
||||
Quick validation script for the fine-tuned BGE Reranker model
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Add project root to Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
|
||||
def test_reranker_scoring(model_path, test_queries=None):
|
||||
"""Test if the reranker can score query-passage pairs"""
|
||||
|
||||
if test_queries is None:
|
||||
# Test query with relevant and irrelevant passages
|
||||
test_queries = [
|
||||
{
|
||||
"query": "什么是深度学习?",
|
||||
"passages": [
|
||||
"深度学习是机器学习的一个子领域", # Relevant (should score high)
|
||||
"机器学习包含多种算法", # Somewhat relevant
|
||||
"今天天气很好", # Irrelevant (should score low)
|
||||
"深度学习使用神经网络进行学习" # Very relevant
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
print(f"🧪 Testing reranker from: {model_path}")
|
||||
|
||||
try:
|
||||
# Load tokenizer and model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGERerankerModel.from_pretrained(model_path)
|
||||
model.eval()
|
||||
|
||||
print("✅ Reranker model and tokenizer loaded successfully")
|
||||
|
||||
for query_idx, test_case in enumerate(test_queries):
|
||||
query = test_case["query"]
|
||||
passages = test_case["passages"]
|
||||
|
||||
print(f"\n📝 Test Case {query_idx + 1}")
|
||||
print(f"Query: '{query}'")
|
||||
print(f"Passages to rank:")
|
||||
|
||||
scores = []
|
||||
|
||||
for i, passage in enumerate(passages):
|
||||
print(f" {i+1}. '{passage}'")
|
||||
|
||||
# Tokenize the query-passage pair; passing both texts lets the tokenizer insert
# the model's own separator and segment handling instead of a hand-built string
inputs = tokenizer(
    query,
    passage,
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=512
)
|
||||
|
||||
# Get relevance score
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Extract score (logits)
|
||||
if hasattr(outputs, 'logits'):
|
||||
score = outputs.logits.item()
|
||||
elif isinstance(outputs, torch.Tensor):
|
||||
score = outputs.item()
|
||||
else:
|
||||
score = outputs[0].item()
|
||||
|
||||
scores.append((i, passage, score))
|
||||
|
||||
# Sort by relevance score (highest first)
|
||||
scores.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
print(f"\n🏆 Ranking Results:")
|
||||
for rank, (orig_idx, passage, score) in enumerate(scores, 1):
|
||||
print(f" Rank {rank}: Score {score:.4f} - '{passage[:50]}{'...' if len(passage) > 50 else ''}'")
|
||||
|
||||
# Check if most relevant passages ranked highest
|
||||
expected_relevant = ["深度学习是机器学习的一个子领域", "深度学习使用神经网络进行学习"]
|
||||
top_passages = [item[1] for item in scores[:2]]
|
||||
|
||||
relevant_in_top = any(rel in top_passages for rel in expected_relevant)
|
||||
irrelevant_in_top = "今天天气很好" in top_passages
|
||||
|
||||
if relevant_in_top and not irrelevant_in_top:
|
||||
print(f" ✅ Good ranking - relevant passages ranked higher")
|
||||
elif relevant_in_top:
|
||||
print(f" ⚠️ Partial ranking - relevant high but irrelevant also high")
|
||||
else:
|
||||
print(f" ❌ Poor ranking - irrelevant passages ranked too high")
|
||||
|
||||
print(f"\n🎉 Reranker validation successful!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Reranker validation failed: {e}")
|
||||
return False
|
||||
|
||||
def compare_rerankers(original_path, finetuned_path):
|
||||
"""Compare original vs fine-tuned reranker"""
|
||||
print("🔄 Comparing original vs fine-tuned reranker...")
|
||||
|
||||
test_case = {
|
||||
"query": "什么是深度学习?",
|
||||
"passages": [
|
||||
"深度学习是机器学习的一个子领域",
|
||||
"今天天气很好"
|
||||
]
|
||||
}
|
||||
|
||||
# Test original model
|
||||
print("\n📍 Original Reranker:")
|
||||
orig_success = test_reranker_scoring(original_path, [test_case])
|
||||
|
||||
# Test fine-tuned model
|
||||
print("\n📍 Fine-tuned Reranker:")
|
||||
ft_success = test_reranker_scoring(finetuned_path, [test_case])
|
||||
|
||||
return orig_success and ft_success
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 BGE Reranker Model Validation")
|
||||
print("=" * 50)
|
||||
|
||||
# Test fine-tuned model
|
||||
finetuned_model_path = "./test_reranker_output/final_model"
|
||||
success = test_reranker_scoring(finetuned_model_path)
|
||||
|
||||
# Optional: Compare with original model
|
||||
original_model_path = "./models/bge-reranker-base"
|
||||
compare_rerankers(original_model_path, finetuned_model_path)
|
||||
|
||||
if success:
|
||||
print("\n✅ Reranker validation: PASSED")
|
||||
print(" The reranker can score query-passage pairs successfully!")
|
||||
else:
|
||||
print("\n❌ Reranker validation: FAILED")
|
||||
print(" The reranker has issues with scoring!")
|
||||
641
scripts/validation_utils.py
Normal file
@ -0,0 +1,641 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Validation Utilities for BGE Fine-tuning
|
||||
|
||||
This module provides utilities for model validation including:
|
||||
- Automatic model discovery
|
||||
- Test data preparation and validation
|
||||
- Results analysis and reporting
|
||||
- Benchmark dataset creation
|
||||
|
||||
Usage:
|
||||
python scripts/validation_utils.py --discover-models
|
||||
python scripts/validation_utils.py --prepare-test-data
|
||||
python scripts/validation_utils.py --analyze-results ./validation_results
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import glob
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelDiscovery:
|
||||
"""Discover trained models in the workspace"""
|
||||
|
||||
def __init__(self, workspace_root: str = "."):
|
||||
self.workspace_root = Path(workspace_root)
|
||||
|
||||
def find_trained_models(self) -> Dict[str, Dict[str, str]]:
|
||||
"""Find all trained BGE models in the workspace"""
|
||||
models = {
|
||||
'retrievers': {},
|
||||
'rerankers': {},
|
||||
'joint': {}
|
||||
}
|
||||
|
||||
# Common output directories to search
|
||||
search_patterns = [
|
||||
"output/*/final_model",
|
||||
"output/*/best_model",
|
||||
"output/*/*/final_model",
|
||||
"output/*/*/best_model",
|
||||
"test_*_output/final_model",
|
||||
"training_output/*/final_model"
|
||||
]
|
||||
|
||||
for pattern in search_patterns:
|
||||
for model_path in glob.glob(str(self.workspace_root / pattern)):
|
||||
model_path = Path(model_path)
|
||||
|
||||
if self._is_valid_model_directory(model_path):
|
||||
model_info = self._analyze_model_directory(model_path)
|
||||
|
||||
if model_info['type'] == 'retriever':
|
||||
models['retrievers'][model_info['name']] = model_info
|
||||
elif model_info['type'] == 'reranker':
|
||||
models['rerankers'][model_info['name']] = model_info
|
||||
elif model_info['type'] == 'joint':
|
||||
models['joint'][model_info['name']] = model_info
|
||||
|
||||
return models
|
||||
|
||||
def _is_valid_model_directory(self, path: Path) -> bool:
|
||||
"""Check if directory contains a valid trained model"""
|
||||
required_files = ['config.json', 'pytorch_model.bin']
|
||||
alternative_files = ['model.safetensors', 'tokenizer.json']
|
||||
|
||||
has_required = any((path / f).exists() for f in required_files)
|
||||
has_alternative = any((path / f).exists() for f in alternative_files)
|
||||
|
||||
return has_required or has_alternative
|
||||
|
||||
def _analyze_model_directory(self, path: Path) -> Dict[str, str]:
|
||||
"""Analyze model directory to determine type and metadata"""
|
||||
info = {
|
||||
'path': str(path),
|
||||
'name': self._generate_model_name(path),
|
||||
'type': 'unknown',
|
||||
'created': 'unknown',
|
||||
'training_info': {}
|
||||
}
|
||||
|
||||
# Try to determine model type from path
|
||||
path_str = str(path).lower()
|
||||
if 'm3' in path_str or 'retriever' in path_str:
|
||||
info['type'] = 'retriever'
|
||||
elif 'reranker' in path_str:
|
||||
info['type'] = 'reranker'
|
||||
elif 'joint' in path_str:
|
||||
info['type'] = 'joint'
|
||||
|
||||
# Get creation time
|
||||
if path.exists():
|
||||
info['created'] = datetime.fromtimestamp(path.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
# Try to read training info
|
||||
training_summary_path = path.parent / 'training_summary.json'
|
||||
if training_summary_path.exists():
|
||||
try:
|
||||
with open(training_summary_path, 'r') as f:
|
||||
info['training_info'] = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
    # a malformed or unreadable summary should not break model discovery
    pass
|
||||
|
||||
# Try to read config
|
||||
config_path = path / 'config.json'
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
if 'model_type' in config:
|
||||
info['type'] = config['model_type']
|
||||
except (OSError, json.JSONDecodeError, KeyError):
    pass
|
||||
|
||||
return info
|
||||
|
||||
def _generate_model_name(self, path: Path) -> str:
|
||||
"""Generate a readable name for the model"""
|
||||
parts = path.parts
|
||||
|
||||
# Try to find meaningful parts
|
||||
meaningful_parts = []
|
||||
for part in parts[-3:]: # Look at last 3 parts
|
||||
if part not in ['final_model', 'best_model', 'output']:
|
||||
meaningful_parts.append(part)
|
||||
|
||||
if meaningful_parts:
|
||||
return '_'.join(meaningful_parts)
|
||||
else:
|
||||
return f"model_{path.parent.name}"
|
||||
|
||||
def print_discovered_models(self):
|
||||
"""Print discovered models in a formatted way"""
|
||||
models = self.find_trained_models()
|
||||
|
||||
print("\n🔍 DISCOVERED TRAINED MODELS")
|
||||
print("=" * 60)
|
||||
|
||||
for model_type, model_dict in models.items():
|
||||
if model_dict:
|
||||
print(f"\n📊 {model_type.upper()}:")
|
||||
for name, info in model_dict.items():
|
||||
print(f" • {name}")
|
||||
print(f" Path: {info['path']}")
|
||||
print(f" Created: {info['created']}")
|
||||
if info['training_info']:
|
||||
training_info = info['training_info']
|
||||
if 'total_steps' in training_info:
|
||||
print(f" Training Steps: {training_info['total_steps']}")
|
||||
if 'final_metrics' in training_info and training_info['final_metrics']:
|
||||
print(f" Final Metrics: {training_info['final_metrics']}")
|
||||
|
||||
if not any(models.values()):
|
||||
print("❌ No trained models found!")
|
||||
print("💡 Make sure you have trained some models first.")
|
||||
|
||||
return models
|
||||
|
||||
|
||||
class TestDataManager:
|
||||
"""Manage test datasets for validation"""
|
||||
|
||||
def __init__(self, data_dir: str = "data/datasets"):
|
||||
self.data_dir = Path(data_dir)
|
||||
|
||||
def find_test_datasets(self) -> Dict[str, List[str]]:
|
||||
"""Find available test datasets"""
|
||||
datasets = {
|
||||
'retriever_datasets': [],
|
||||
'reranker_datasets': [],
|
||||
'general_datasets': []
|
||||
}
|
||||
|
||||
# Search for JSONL files in data directory
|
||||
for jsonl_file in self.data_dir.rglob("*.jsonl"):
|
||||
file_path = str(jsonl_file)
|
||||
|
||||
# Categorize based on filename/path
|
||||
if 'reranker' in file_path.lower():
|
||||
datasets['reranker_datasets'].append(file_path)
|
||||
elif 'embedding' in file_path.lower() or 'm3' in file_path.lower():
|
||||
datasets['retriever_datasets'].append(file_path)
|
||||
else:
|
||||
datasets['general_datasets'].append(file_path)
|
||||
|
||||
# Also search for JSON files
|
||||
for json_file in self.data_dir.rglob("*.json"):
|
||||
file_path = str(json_file)
|
||||
if 'test' in file_path.lower() or 'dev' in file_path.lower():
|
||||
if 'reranker' in file_path.lower() or 'AFQMC' in file_path:
|
||||
datasets['reranker_datasets'].append(file_path)
|
||||
else:
|
||||
datasets['general_datasets'].append(file_path)
|
||||
|
||||
return datasets
|
||||
|
||||
def validate_dataset(self, dataset_path: str) -> Dict[str, Any]:
|
||||
"""Validate a dataset file and analyze its format"""
|
||||
validation_result = {
|
||||
'path': dataset_path,
|
||||
'exists': False,
|
||||
'format': 'unknown',
|
||||
'sample_count': 0,
|
||||
'has_queries': False,
|
||||
'has_passages': False,
|
||||
'has_labels': False,
|
||||
'issues': [],
|
||||
'sample_data': {}
|
||||
}
|
||||
|
||||
if not os.path.exists(dataset_path):
|
||||
validation_result['issues'].append(f"File does not exist: {dataset_path}")
|
||||
return validation_result
|
||||
|
||||
validation_result['exists'] = True
|
||||
|
||||
try:
|
||||
# Read first few lines to analyze format
|
||||
samples = []
|
||||
with open(dataset_path, 'r', encoding='utf-8') as f:
|
||||
for i, line in enumerate(f):
|
||||
if i >= 5: # Only analyze first 5 lines
|
||||
break
|
||||
try:
|
||||
sample = json.loads(line.strip())
|
||||
samples.append(sample)
|
||||
except json.JSONDecodeError:
|
||||
validation_result['issues'].append(f"Invalid JSON at line {i+1}")
|
||||
|
||||
if samples:
|
||||
validation_result['sample_count'] = len(samples)
|
||||
validation_result['sample_data'] = samples[0] if samples else {}
|
||||
|
||||
# Analyze format
|
||||
first_sample = samples[0]
|
||||
|
||||
# Check for common fields
|
||||
if 'query' in first_sample:
|
||||
validation_result['has_queries'] = True
|
||||
|
||||
if any(field in first_sample for field in ['passage', 'pos', 'neg', 'candidates']):
|
||||
validation_result['has_passages'] = True
|
||||
|
||||
if any(field in first_sample for field in ['label', 'labels', 'score']):
|
||||
validation_result['has_labels'] = True
|
||||
|
||||
# Determine format
|
||||
if 'pos' in first_sample and 'neg' in first_sample:
|
||||
validation_result['format'] = 'triplets'
|
||||
elif 'candidates' in first_sample:
|
||||
validation_result['format'] = 'candidates'
|
||||
elif 'passage' in first_sample and 'label' in first_sample:
|
||||
validation_result['format'] = 'pairs'
|
||||
elif 'sentence1' in first_sample and 'sentence2' in first_sample:
|
||||
validation_result['format'] = 'sentence_pairs'
|
||||
else:
|
||||
validation_result['format'] = 'custom'
|
||||
|
||||
except Exception as e:
|
||||
validation_result['issues'].append(f"Error reading file: {e}")
|
||||
|
||||
return validation_result
|
||||
|
||||
def print_dataset_analysis(self):
|
||||
"""Print analysis of available datasets"""
|
||||
datasets = self.find_test_datasets()
|
||||
|
||||
print("\n📁 AVAILABLE TEST DATASETS")
|
||||
print("=" * 60)
|
||||
|
||||
for category, dataset_list in datasets.items():
|
||||
if dataset_list:
|
||||
print(f"\n📊 {category.upper().replace('_', ' ')}:")
|
||||
for dataset_path in dataset_list:
|
||||
validation = self.validate_dataset(dataset_path)
|
||||
status = "✅" if validation['exists'] and not validation['issues'] else "⚠️"
|
||||
print(f" {status} {Path(dataset_path).name}")
|
||||
print(f" Path: {dataset_path}")
|
||||
print(f" Format: {validation['format']}")
|
||||
if validation['issues']:
|
||||
for issue in validation['issues']:
|
||||
print(f" ⚠️ {issue}")
|
||||
|
||||
if not any(datasets.values()):
|
||||
print("❌ No test datasets found!")
|
||||
print("💡 Place your test data in the data/datasets/ directory")
|
||||
|
||||
return datasets
|
||||
|
||||
def create_sample_test_data(self, output_dir: str = "data/test_samples"):
|
||||
"""Create sample test datasets for validation"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Create sample retriever test data
|
||||
retriever_test = [
|
||||
{
|
||||
"query": "什么是人工智能?",
|
||||
"pos": [
|
||||
"人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
|
||||
"AI技术包括机器学习、深度学习、自然语言处理等多个领域"
|
||||
],
|
||||
"neg": [
|
||||
"今天的天气非常好,阳光明媚",
|
||||
"我喜欢听音乐和看电影"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "如何学习编程?",
|
||||
"pos": [
|
||||
"学习编程需要选择合适的语言,如Python、Java等,然后通过实践项目提升技能",
|
||||
"编程学习的关键是多写代码、多调试、多思考算法逻辑"
|
||||
],
|
||||
"neg": [
|
||||
"烹饪是一门艺术,需要掌握火候和调味",
|
||||
"运动对健康很重要,每天应该锻炼30分钟"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Create sample reranker test data
|
||||
reranker_test = [
|
||||
{
|
||||
"query": "什么是人工智能?",
|
||||
"passage": "人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
|
||||
"label": 1
|
||||
},
|
||||
{
|
||||
"query": "什么是人工智能?",
|
||||
"passage": "今天的天气非常好,阳光明媚",
|
||||
"label": 0
|
||||
},
|
||||
{
|
||||
"query": "如何学习编程?",
|
||||
"passage": "学习编程需要选择合适的语言,如Python、Java等,然后通过实践项目提升技能",
|
||||
"label": 1
|
||||
},
|
||||
{
|
||||
"query": "如何学习编程?",
|
||||
"passage": "烹饪是一门艺术,需要掌握火候和调味",
|
||||
"label": 0
|
||||
}
|
||||
]
|
||||
|
||||
# Save retriever test data
|
||||
retriever_path = os.path.join(output_dir, "retriever_test.jsonl")
|
||||
with open(retriever_path, 'w', encoding='utf-8') as f:
|
||||
for item in retriever_test:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
# Save reranker test data
|
||||
reranker_path = os.path.join(output_dir, "reranker_test.jsonl")
|
||||
with open(reranker_path, 'w', encoding='utf-8') as f:
|
||||
for item in reranker_test:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
print(f"✅ Sample test datasets created in: {output_dir}")
|
||||
print(f" 📁 Retriever test: {retriever_path}")
|
||||
print(f" 📁 Reranker test: {reranker_path}")
|
||||
|
||||
return {
|
||||
'retriever_test': retriever_path,
|
||||
'reranker_test': reranker_path
|
||||
}
|
||||
|
||||
|
||||
class ValidationReportAnalyzer:
|
||||
"""Analyze validation results and generate insights"""
|
||||
|
||||
def __init__(self, results_dir: str):
|
||||
self.results_dir = Path(results_dir)
|
||||
|
||||
def analyze_results(self) -> Dict[str, Any]:
|
||||
"""Analyze validation results"""
|
||||
analysis = {
|
||||
'overall_status': 'unknown',
|
||||
'model_performance': {},
|
||||
'recommendations': [],
|
||||
'summary_stats': {}
|
||||
}
|
||||
|
||||
# Look for results files
|
||||
json_results = self.results_dir / "validation_results.json"
|
||||
|
||||
if not json_results.exists():
|
||||
analysis['recommendations'].append("No validation results found. Run validation first.")
|
||||
return analysis
|
||||
|
||||
try:
|
||||
with open(json_results, 'r') as f:
|
||||
results = json.load(f)
|
||||
|
||||
# Analyze each model type
|
||||
improvements = 0
|
||||
degradations = 0
|
||||
|
||||
for model_type in ['retriever', 'reranker', 'pipeline']:
|
||||
if model_type in results and 'comparisons' in results[model_type]:
|
||||
model_analysis = self._analyze_model_results(results[model_type])
|
||||
analysis['model_performance'][model_type] = model_analysis
|
||||
|
||||
improvements += model_analysis['total_improvements']
|
||||
degradations += model_analysis['total_degradations']
|
||||
|
||||
# Overall status
|
||||
if improvements > degradations * 1.5:
|
||||
analysis['overall_status'] = 'excellent'
|
||||
elif improvements > degradations:
|
||||
analysis['overall_status'] = 'good'
|
||||
elif improvements == degradations:
|
||||
analysis['overall_status'] = 'mixed'
|
||||
else:
|
||||
analysis['overall_status'] = 'poor'
|
||||
|
||||
# Generate recommendations
|
||||
analysis['recommendations'] = self._generate_recommendations(analysis)
|
||||
|
||||
# Summary statistics
|
||||
analysis['summary_stats'] = {
|
||||
'total_improvements': improvements,
|
||||
'total_degradations': degradations,
|
||||
'net_improvement': improvements - degradations,
|
||||
'improvement_ratio': improvements / (improvements + degradations) if (improvements + degradations) > 0 else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
analysis['recommendations'].append(f"Error analyzing results: {e}")
|
||||
|
||||
return analysis
|
||||
|
||||
def _analyze_model_results(self, model_results: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Analyze results for a specific model type"""
|
||||
analysis = {
|
||||
'total_improvements': 0,
|
||||
'total_degradations': 0,
|
||||
'best_dataset': None,
|
||||
'worst_dataset': None,
|
||||
'avg_improvement': 0,
|
||||
'issues': []
|
||||
}
|
||||
|
||||
dataset_scores = {}
|
||||
|
||||
for dataset_name, comparison in model_results.get('comparisons', {}).items():
|
||||
improvements = len(comparison.get('improvements', {}))
|
||||
degradations = len(comparison.get('degradations', {}))
|
||||
|
||||
analysis['total_improvements'] += improvements
|
||||
analysis['total_degradations'] += degradations
|
||||
|
||||
# Score dataset (improvements - degradations)
|
||||
dataset_score = improvements - degradations
|
||||
dataset_scores[dataset_name] = dataset_score
|
||||
|
||||
if dataset_scores:
|
||||
analysis['best_dataset'] = max(dataset_scores.items(), key=lambda x: x[1])
|
||||
analysis['worst_dataset'] = min(dataset_scores.items(), key=lambda x: x[1])
|
||||
analysis['avg_improvement'] = sum(dataset_scores.values()) / len(dataset_scores)
|
||||
|
||||
# Identify issues
|
||||
if analysis['total_degradations'] > analysis['total_improvements']:
|
||||
analysis['issues'].append("More degradations than improvements")
|
||||
|
||||
if analysis['avg_improvement'] < 0:
|
||||
analysis['issues'].append("Negative average improvement")
|
||||
|
||||
return analysis
|
||||
|
||||
def _generate_recommendations(self, analysis: Dict[str, Any]) -> List[str]:
|
||||
"""Generate recommendations based on analysis"""
|
||||
recommendations = []
|
||||
|
||||
overall_status = analysis['overall_status']
|
||||
|
||||
if overall_status == 'excellent':
|
||||
recommendations.append("🎉 Great job! Your fine-tuned models show significant improvements.")
|
||||
recommendations.append("Consider deploying these models to production.")
|
||||
|
||||
elif overall_status == 'good':
|
||||
recommendations.append("✅ Good progress! Models show overall improvement.")
|
||||
recommendations.append("Consider further fine-tuning on specific datasets where degradation occurred.")
|
||||
|
||||
elif overall_status == 'mixed':
|
||||
recommendations.append("⚠️ Mixed results. Some improvements, some degradations.")
|
||||
recommendations.append("Analyze which datasets/metrics are degrading and adjust training data.")
|
||||
recommendations.append("Consider different hyperparameters or longer training.")
|
||||
|
||||
elif overall_status == 'poor':
|
||||
recommendations.append("❌ Models show more degradations than improvements.")
|
||||
recommendations.append("Review training data quality and format.")
|
||||
recommendations.append("Check hyperparameters (learning rate might be too high).")
|
||||
recommendations.append("Consider starting with different base models.")
|
||||
|
||||
# Specific model recommendations
|
||||
for model_type, model_perf in analysis.get('model_performance', {}).items():
|
||||
if model_perf['issues']:
|
||||
recommendations.append(f"🔧 {model_type.title()} issues: {', '.join(model_perf['issues'])}")
|
||||
|
||||
if model_perf['best_dataset']:
|
||||
best_dataset, best_score = model_perf['best_dataset']
|
||||
recommendations.append(f"📊 {model_type.title()} performs best on: {best_dataset} (score: {best_score})")
|
||||
|
||||
if model_perf['worst_dataset']:
|
||||
worst_dataset, worst_score = model_perf['worst_dataset']
|
||||
recommendations.append(f"📉 {model_type.title()} needs improvement on: {worst_dataset} (score: {worst_score})")
|
||||
|
||||
return recommendations
|
||||
|
||||
def print_analysis(self):
|
||||
"""Print comprehensive analysis of validation results"""
|
||||
analysis = self.analyze_results()
|
||||
|
||||
print(f"\n📊 VALIDATION RESULTS ANALYSIS")
|
||||
print("=" * 60)
|
||||
|
||||
# Overall status
|
||||
status_emoji = {
|
||||
'excellent': '🌟',
|
||||
'good': '✅',
|
||||
'mixed': '⚠️',
|
||||
'poor': '❌',
|
||||
'unknown': '❓'
|
||||
}
|
||||
|
||||
print(f"{status_emoji.get(analysis['overall_status'], '❓')} Overall Status: {analysis['overall_status'].upper()}")
|
||||
|
||||
# Summary stats
|
||||
if 'summary_stats' in analysis and analysis['summary_stats']:
|
||||
stats = analysis['summary_stats']
|
||||
print(f"\n📈 Summary Statistics:")
|
||||
print(f" Total Improvements: {stats['total_improvements']}")
|
||||
print(f" Total Degradations: {stats['total_degradations']}")
|
||||
print(f" Net Improvement: {stats['net_improvement']:+d}")
|
||||
print(f" Improvement Ratio: {stats['improvement_ratio']:.2%}")
|
||||
|
||||
# Model performance
|
||||
if analysis['model_performance']:
|
||||
print(f"\n📊 Model Performance:")
|
||||
for model_type, perf in analysis['model_performance'].items():
|
||||
print(f" {model_type.title()}:")
|
||||
print(f" Improvements: {perf['total_improvements']}")
|
||||
print(f" Degradations: {perf['total_degradations']}")
|
||||
print(f" Average Score: {perf['avg_improvement']:.2f}")
|
||||
|
||||
if perf['best_dataset']:
|
||||
print(f" Best Dataset: {perf['best_dataset'][0]} (score: {perf['best_dataset'][1]})")
|
||||
|
||||
if perf['issues']:
|
||||
print(f" Issues: {', '.join(perf['issues'])}")
|
||||
|
||||
# Recommendations
|
||||
if analysis['recommendations']:
|
||||
print(f"\n💡 Recommendations:")
|
||||
for i, rec in enumerate(analysis['recommendations'], 1):
|
||||
print(f" {i}. {rec}")
|
||||
|
||||
return analysis
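# For reference, analyze_results() expects <results_dir>/validation_results.json to be
# shaped roughly as below (illustrative values; only the keys read above are assumed):
#
#   {
#     "retriever": {
#       "comparisons": {
#         "some_dataset": {"improvements": {"recall@10": 0.03}, "degradations": {}}
#       }
#     },
#     "reranker": {...},
#     "pipeline": {...}
#   }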
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Validation utilities for BGE models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument("--discover-models", action="store_true",
|
||||
help="Discover trained models in workspace")
|
||||
parser.add_argument("--analyze-datasets", action="store_true",
|
||||
help="Analyze available test datasets")
|
||||
parser.add_argument("--create-sample-data", action="store_true",
|
||||
help="Create sample test datasets")
|
||||
parser.add_argument("--analyze-results", type=str, default=None,
|
||||
help="Analyze validation results in specified directory")
|
||||
parser.add_argument("--workspace-root", type=str, default=".",
|
||||
help="Root directory of workspace")
|
||||
parser.add_argument("--data-dir", type=str, default="data/datasets",
|
||||
help="Directory containing test datasets")
|
||||
parser.add_argument("--output-dir", type=str, default="data/test_samples",
|
||||
help="Output directory for sample data")
|
||||
|
||||
return parser.parse_args()
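# Example invocations (the script name is assumed here; adjust to the actual file path):
#
#   python validation_utils.py --discover-models --workspace-root .
#   python validation_utils.py --create-sample-data --output-dir data/test_samples
#   python validation_utils.py --analyze-results output/validation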
|
||||
|
||||
|
||||
def main():
|
||||
"""Main utility function"""
|
||||
args = parse_args()
|
||||
|
||||
print("🔧 BGE Validation Utilities")
|
||||
print("=" * 50)
|
||||
|
||||
if args.discover_models:
|
||||
print("\n🔍 Discovering trained models...")
|
||||
discovery = ModelDiscovery(args.workspace_root)
|
||||
models = discovery.print_discovered_models()
|
||||
|
||||
if any(models.values()):
|
||||
print("\n💡 Use these model paths in your validation commands:")
|
||||
for model_type, model_dict in models.items():
|
||||
for name, info in model_dict.items():
|
||||
print(f" --{model_type[:-1]}_model {info['path']}")
|
||||
|
||||
if args.analyze_datasets:
|
||||
print("\n📁 Analyzing test datasets...")
|
||||
data_manager = TestDataManager(args.data_dir)
|
||||
datasets = data_manager.print_dataset_analysis()
|
||||
|
||||
if args.create_sample_data:
|
||||
print("\n📝 Creating sample test data...")
|
||||
data_manager = TestDataManager()
|
||||
sample_paths = data_manager.create_sample_test_data(args.output_dir)
|
||||
|
||||
print("\n💡 Use these sample datasets for testing:")
|
||||
print(f" --test_datasets {sample_paths['retriever_test']} {sample_paths['reranker_test']}")
|
||||
|
||||
if args.analyze_results:
|
||||
print(f"\n📊 Analyzing results from: {args.analyze_results}")
|
||||
analyzer = ValidationReportAnalyzer(args.analyze_results)
|
||||
analysis = analyzer.print_analysis()
|
||||
|
||||
if not any([args.discover_models, args.analyze_datasets, args.create_sample_data, args.analyze_results]):
|
||||
print("❓ No action specified. Use --help to see available options.")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
99
tests/test_data_edge_cases.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
# Add project root to Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from data.preprocessing import DataPreprocessor
|
||||
from data.augmentation import QueryAugmenter
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_jsonl_file():
|
||||
# Create a malformed JSONL file
|
||||
fd, path = tempfile.mkstemp(suffix='.jsonl')
|
||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
||||
f.write('{"query": "valid", "pos": ["a"], "neg": ["b"]}\n')
|
||||
f.write('{invalid json}\n')
|
||||
f.write('\n')  # blank line edge case
|
||||
yield path
|
||||
os.remove(path)
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_csv_file():
|
||||
fd, path = tempfile.mkstemp(suffix='.csv')
|
||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
||||
f.write('query,pos,neg\n')
|
||||
f.write('valid,"a","b"\n')
|
||||
f.write('malformed line\n')
|
||||
yield path
|
||||
os.remove(path)
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_tsv_file():
|
||||
fd, path = tempfile.mkstemp(suffix='.tsv')
|
||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
||||
f.write('query\tpos\tneg\n')
|
||||
f.write('valid\ta\tb\n')
|
||||
f.write('malformed line\n')
|
||||
yield path
|
||||
os.remove(path)
|
||||
|
||||
def test_jsonl_malformed_lines(tmp_jsonl_file):
|
||||
preprocessor = DataPreprocessor(tokenizer=None)
|
||||
# Should skip malformed line and not raise
|
||||
data = preprocessor._load_jsonl(tmp_jsonl_file)
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 1
|
||||
assert data[0]['query'] == 'valid'
|
||||
|
||||
def test_csv_malformed_lines(tmp_csv_file):
|
||||
preprocessor = DataPreprocessor(tokenizer=None)
|
||||
# Should skip malformed line and not raise
|
||||
data = preprocessor._load_csv(tmp_csv_file)
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 1
|
||||
assert data[0]['query'] == 'valid'
|
||||
|
||||
def test_tsv_malformed_lines(tmp_tsv_file):
|
||||
preprocessor = DataPreprocessor(tokenizer=None)
|
||||
# Should skip malformed line and not raise
|
||||
data = preprocessor._load_tsv(tmp_tsv_file)
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 1
|
||||
assert data[0]['query'] == 'valid'
|
||||
|
||||
def test_augmentation_methods_config():
|
||||
# Should use config-driven augmentation methods
|
||||
augmenter = QueryAugmenter(augmentation_methods=None, config_path='config.toml')
|
||||
assert isinstance(augmenter.augmentation_methods, list)
|
||||
assert 'paraphrase' in augmenter.augmentation_methods
|
||||
|
||||
def test_config_list_parsing():
|
||||
"""Test that config loader parses comma/semicolon-separated lists correctly."""
|
||||
from utils.config_loader import ConfigLoader
|
||||
loader = ConfigLoader()
|
||||
loader._config = {'section': {'list_param': 'a, b; c'}}
|
||||
value = loader.get('section', 'list_param', default=[])
|
||||
assert isinstance(value, list)
|
||||
assert set(value) == {'a', 'b', 'c'}
|
||||
|
||||
def test_config_param_source_logging(caplog):
|
||||
"""Test that config loader logs parameter source correctly."""
|
||||
from utils.config_loader import ConfigLoader
|
||||
loader = ConfigLoader()
|
||||
loader._config = {'section': {'param': 42}}
|
||||
with caplog.at_level('INFO'):
|
||||
value = loader.get('section', 'param', default=0)
|
||||
assert 'source: config.toml' in caplog.text
|
||||
value = loader.get('section', 'missing', default=99)
|
||||
assert 'source: default' in caplog.text
|
||||
|
||||
def test_config_fallback_default():
|
||||
"""Test that config loader falls back to default if key is missing."""
|
||||
from utils.config_loader import ConfigLoader
|
||||
loader = ConfigLoader()
|
||||
loader._config = {'section': {}}
|
||||
value = loader.get('section', 'not_found', default='fallback')
|
||||
assert value == 'fallback'
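

# The malformed-line tests above assume that DataPreprocessor's loaders skip blank or
# unparseable lines instead of raising. A minimal, self-contained sketch of that
# behaviour for JSONL (illustrative only -- not the actual data.preprocessing code):
import json


def _load_jsonl_tolerant(path):
    """Load a JSONL file, silently skipping blank or malformed lines."""
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                continue  # skip malformed lines
    return records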
|
||||
121
tests/test_full_training.py
Normal file
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python
|
||||
"""Test script to validate full training pipeline with learning rate scheduling"""
|
||||
|
||||
import torch
|
||||
import json
|
||||
import tempfile
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from config.m3_config import BGEM3Config
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from data.dataset import BGEM3Dataset, collate_embedding_batch
|
||||
from training.m3_trainer import BGEM3Trainer
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def create_test_data():
|
||||
"""Create minimal test data"""
|
||||
data = [
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"pos": ["Machine learning is a subset of AI"],
|
||||
"neg": ["The weather is sunny today"]
|
||||
},
|
||||
{
|
||||
"query": "How does Python work?",
|
||||
"pos": ["Python is a programming language"],
|
||||
"neg": ["Cats are cute animals"]
|
||||
},
|
||||
{
|
||||
"query": "What is deep learning?",
|
||||
"pos": ["Deep learning uses neural networks"],
|
||||
"neg": ["Pizza is delicious food"]
|
||||
}
|
||||
]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
|
||||
for item in data:
|
||||
f.write(json.dumps(item) + '\n')
|
||||
return f.name
|
||||
|
||||
|
||||
def main():
|
||||
print("🧪 Testing BGE-M3 training with learning rate scheduling...")
|
||||
|
||||
# Create test data
|
||||
train_file = create_test_data()
|
||||
print(f"✅ Created test data: {train_file}")
|
||||
|
||||
# Initialize config with small dataset settings
|
||||
config = BGEM3Config(
|
||||
model_name_or_path="./models/bge-m3",
|
||||
train_data_path=train_file,
|
||||
output_dir="output/test-lr-scheduler",
|
||||
num_train_epochs=2,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=1, # Small for testing
|
||||
learning_rate=1e-4, # Higher for testing
|
||||
logging_steps=1,
|
||||
save_steps=10,
|
||||
fp16=False,
|
||||
use_amp=False,
|
||||
dataloader_num_workers=0,
|
||||
overwrite_output_dir=True,
|
||||
use_self_distill=False,
|
||||
use_hard_negatives=False,
|
||||
temperature=0.1 # Fixed temperature
|
||||
)
|
||||
|
||||
try:
|
||||
# Load model and tokenizer
|
||||
print("📦 Loading model and tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
|
||||
model = BGEM3Model.from_pretrained(config.model_name_or_path)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
print(f"🖥️ Using device: {device}")
|
||||
|
||||
# Create dataset
|
||||
dataset = BGEM3Dataset(
|
||||
data_path=train_file,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=config.query_max_length,
|
||||
passage_max_length=config.passage_max_length,
|
||||
is_train=True
|
||||
)
|
||||
print(f"✅ Dataset size: {len(dataset)}")
|
||||
|
||||
# Initialize trainer
|
||||
trainer = BGEM3Trainer(
|
||||
config=config,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Test training for a few steps
|
||||
print("\n🏋️ Starting training test...")
|
||||
trainer.train(dataset)
|
||||
|
||||
print("✅ Training completed successfully!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Training failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
# Clean up
|
||||
Path(train_file).unlink()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
712
tests/test_installation.py
Normal file
@@ -0,0 +1,712 @@
|
||||
"""
|
||||
Comprehensive testing script for BGE fine-tuning project
|
||||
Tests all components, configurations, and integrations
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import tempfile
|
||||
import json
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
class ComprehensiveTestSuite:
|
||||
"""Comprehensive test suite for BGE fine-tuning project"""
|
||||
|
||||
def __init__(self):
|
||||
self.test_results = []
|
||||
self.failed_tests = []
|
||||
|
||||
def run_test(self, test_name, test_func, *args, **kwargs):
|
||||
"""Run a single test and record results"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Testing: {test_name}")
|
||||
print('='*60)
|
||||
|
||||
try:
|
||||
result = test_func(*args, **kwargs)
|
||||
if result:
|
||||
print(f"✅ {test_name} PASSED")
|
||||
self.test_results.append((test_name, "PASSED", None))
|
||||
return True
|
||||
else:
|
||||
print(f"❌ {test_name} FAILED")
|
||||
self.test_results.append((test_name, "FAILED", "Test returned False"))
|
||||
self.failed_tests.append(test_name)
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ {test_name} FAILED with exception: {e}")
|
||||
traceback.print_exc()
|
||||
self.test_results.append((test_name, "FAILED", str(e)))
|
||||
self.failed_tests.append(test_name)
|
||||
return False
|
||||
|
||||
def test_imports(self):
|
||||
"""Test all critical imports"""
|
||||
print("Testing critical imports...")
|
||||
|
||||
# Core dependencies
|
||||
try:
|
||||
import torch
|
||||
print(f"✓ PyTorch {torch.__version__}")
|
||||
print(f" CUDA available: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
print(f" CUDA version: {torch.version.cuda}")
|
||||
print(f" GPU count: {torch.cuda.device_count()}")
|
||||
except ImportError as e:
|
||||
print(f"✗ PyTorch: {e}")
|
||||
return False
|
||||
|
||||
# Test Ascend NPU support
|
||||
try:
|
||||
import torch_npu
|
||||
if torch.npu.is_available(): # type: ignore[attr-defined]
|
||||
print(f"✓ Ascend NPU available: {torch.npu.device_count()} devices") # type: ignore[attr-defined]
|
||||
else:
|
||||
print("! Ascend NPU library installed but no devices available")
|
||||
except ImportError:
|
||||
print("- Ascend NPU library not installed (optional)")
|
||||
|
||||
# Transformers and related
|
||||
try:
|
||||
import transformers
|
||||
print(f"✓ Transformers {transformers.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"✗ Transformers: {e}")
|
||||
return False
|
||||
|
||||
# Scientific computing
|
||||
try:
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import scipy
|
||||
import sklearn
|
||||
print("✓ Scientific computing libraries")
|
||||
except ImportError as e:
|
||||
print(f"✗ Scientific libraries: {e}")
|
||||
return False
|
||||
|
||||
# BGE specific
|
||||
try:
|
||||
import faiss
|
||||
print(f"✓ FAISS {faiss.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"✗ FAISS: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
print("✓ BM25")
|
||||
except ImportError as e:
|
||||
print(f"✗ BM25: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
import jieba
|
||||
print("✓ Jieba (Chinese tokenization)")
|
||||
except ImportError as e:
|
||||
print(f"✗ Jieba: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
import toml
|
||||
print("✓ TOML")
|
||||
except ImportError as e:
|
||||
print(f"✗ TOML: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def test_project_imports(self):
|
||||
"""Test project-specific imports"""
|
||||
print("Testing project imports...")
|
||||
|
||||
# Configuration
|
||||
try:
|
||||
from utils.config_loader import get_config, load_config_for_training
|
||||
from config.base_config import BaseConfig
|
||||
from config.m3_config import BGEM3Config
|
||||
from config.reranker_config import BGERerankerConfig
|
||||
print("✓ Configuration classes")
|
||||
except ImportError as e:
|
||||
print(f"✗ Configuration imports: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Models
|
||||
try:
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss
|
||||
print("✓ Model classes")
|
||||
except ImportError as e:
|
||||
print(f"✗ Model imports: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Data processing
|
||||
try:
|
||||
from data.dataset import BGEM3Dataset, BGERerankerDataset
|
||||
from data.preprocessing import DataPreprocessor
|
||||
from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter
|
||||
from data.hard_negative_mining import ANCEHardNegativeMiner
|
||||
print("✓ Data processing classes")
|
||||
except ImportError as e:
|
||||
print(f"✗ Data imports: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Training
|
||||
try:
|
||||
from training.trainer import BaseTrainer
|
||||
from training.m3_trainer import BGEM3Trainer
|
||||
from training.reranker_trainer import BGERerankerTrainer
|
||||
from training.joint_trainer import JointTrainer
|
||||
print("✓ Training classes")
|
||||
except ImportError as e:
|
||||
print(f"✗ Training imports: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Evaluation
|
||||
try:
|
||||
from evaluation.evaluator import RetrievalEvaluator
|
||||
from evaluation.metrics import compute_retrieval_metrics
|
||||
print("✓ Evaluation classes")
|
||||
except ImportError as e:
|
||||
print(f"✗ Evaluation imports: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Utilities
|
||||
try:
|
||||
from utils.logging import setup_logging
|
||||
from utils.checkpoint import CheckpointManager
|
||||
from utils.distributed import setup_distributed, is_main_process
|
||||
print("✓ Utility classes")
|
||||
except ImportError as e:
|
||||
print(f"✗ Utility imports: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def test_configuration_system(self):
|
||||
"""Test the centralized configuration system"""
|
||||
print("Testing configuration system...")
|
||||
|
||||
try:
|
||||
from utils.config_loader import get_config, load_config_for_training
|
||||
|
||||
# Test singleton behavior
|
||||
config1 = get_config()
|
||||
config2 = get_config()
|
||||
assert config1 is config2, "Config should be singleton"
|
||||
print("✓ Singleton pattern working")
|
||||
|
||||
# Test configuration loading
|
||||
test_value = config1.get('training.default_batch_size', None) # type: ignore[arg-type]
|
||||
print(f"✓ Config loading: batch_size = {test_value}")
|
||||
|
||||
# Test section loading
|
||||
training_section = config1.get_section('training')
|
||||
assert isinstance(training_section, dict), "Should return dict"
|
||||
print("✓ Section loading working")
|
||||
|
||||
# Test training configuration loading
|
||||
training_config = load_config_for_training()
|
||||
assert 'training' in training_config
|
||||
assert 'model_paths' in training_config
|
||||
print("✓ Training config loading working")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Configuration system error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_configuration_integration(self):
|
||||
"""Test configuration integration in classes"""
|
||||
print("Testing configuration integration...")
|
||||
|
||||
try:
|
||||
from config.base_config import BaseConfig
|
||||
from config.m3_config import BGEM3Config
|
||||
|
||||
# Test BaseConfig with config.toml integration
|
||||
base_config = BaseConfig()
|
||||
print(f"✓ BaseConfig created with batch_size: {base_config.per_device_train_batch_size}")
|
||||
|
||||
# Test M3Config
|
||||
m3_config = BGEM3Config()
|
||||
print(f"✓ M3Config created with learning_rate: {m3_config.learning_rate}")
|
||||
|
||||
# Test device detection
|
||||
print(f"✓ Device detected: {base_config.device}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Configuration integration error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_data_processing(self):
|
||||
"""Test data processing functionality"""
|
||||
print("Testing data processing...")
|
||||
|
||||
try:
|
||||
from data.preprocessing import DataPreprocessor
|
||||
|
||||
# Create test data
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
|
||||
test_data = [
|
||||
{"query": "test query 1", "pos": ["positive passage 1"], "neg": ["negative passage 1"]},
|
||||
{"query": "test query 2", "pos": ["positive passage 2"], "neg": ["negative passage 2"]}
|
||||
]
|
||||
for item in test_data:
|
||||
f.write(json.dumps(item) + '\n')
|
||||
temp_file = f.name
|
||||
|
||||
# Test preprocessing
|
||||
preprocessor = DataPreprocessor()
|
||||
data = preprocessor._load_jsonl(temp_file)
|
||||
assert len(data) == 2
|
||||
print("✓ JSONL loading working")
|
||||
|
||||
# Test validation
|
||||
cleaned = preprocessor._validate_and_clean(data)
|
||||
assert len(cleaned) == 2
|
||||
print("✓ Data validation working")
|
||||
|
||||
# Cleanup
|
||||
os.unlink(temp_file)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Data processing error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_data_augmentation(self):
|
||||
"""Test data augmentation functionality"""
|
||||
print("Testing data augmentation...")
|
||||
|
||||
try:
|
||||
from data.augmentation import QueryAugmenter, PassageAugmenter, HardNegativeAugmenter
|
||||
|
||||
# Test QueryAugmenter
|
||||
augmenter = QueryAugmenter(augmentation_methods=['paraphrase', 'synonym'])
|
||||
test_data = [
|
||||
{"query": "什么是机器学习?", "pos": ["机器学习是AI的分支"], "neg": ["深度学习不同"]}
|
||||
]
|
||||
|
||||
augmented = augmenter.augment_queries(test_data, num_augmentations=1, keep_original=True)
|
||||
assert len(augmented) >= len(test_data)
|
||||
print("✓ Query augmentation working")
|
||||
|
||||
# Test PassageAugmenter
|
||||
passage_aug = PassageAugmenter()
|
||||
passage_augmented = passage_aug.augment_passages(test_data, augmentation_prob=1.0)
|
||||
print("✓ Passage augmentation working")
|
||||
|
||||
# Test HardNegativeAugmenter
|
||||
hn_aug = HardNegativeAugmenter()
|
||||
hn_augmented = hn_aug.create_hard_negatives(test_data, num_hard_negatives=2)
|
||||
print("✓ Hard negative augmentation working")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Data augmentation error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_model_functionality(self):
|
||||
"""Test basic model functionality"""
|
||||
print("Testing model functionality...")
|
||||
|
||||
try:
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from models.losses import InfoNCELoss, ListwiseCrossEntropyLoss
|
||||
|
||||
# Test loss functions
|
||||
infonce = InfoNCELoss()
|
||||
query_emb = torch.randn(2, 128)
|
||||
pos_emb = torch.randn(2, 128)
|
||||
neg_emb = torch.randn(2, 5, 128)
|
||||
|
||||
loss = infonce(query_emb, pos_emb, neg_emb)
|
||||
assert loss.item() > 0
|
||||
print("✓ InfoNCE loss working")
|
||||
|
||||
# Test listwise loss
|
||||
listwise = ListwiseCrossEntropyLoss()
|
||||
scores = torch.randn(2, 5)
|
||||
labels = torch.zeros(2, 5)
|
||||
labels[:, 0] = 1 # First is positive
|
||||
|
||||
loss = listwise(scores, labels, [5, 5])
|
||||
assert loss.item() > 0
|
||||
print("✓ Listwise loss working")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Model functionality error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_hard_negative_mining(self):
|
||||
"""Test hard negative mining"""
|
||||
print("Testing hard negative mining...")
|
||||
|
||||
try:
|
||||
from data.hard_negative_mining import ANCEHardNegativeMiner
|
||||
|
||||
# Create miner
|
||||
miner = ANCEHardNegativeMiner(
|
||||
embedding_dim=128,
|
||||
num_hard_negatives=5,
|
||||
use_gpu=False # Use CPU for testing
|
||||
)
|
||||
|
||||
# Add some dummy embeddings
|
||||
embeddings = np.random.rand(10, 128).astype(np.float32)
|
||||
passages = [f"passage_{i}" for i in range(10)]
|
||||
|
||||
miner._add_to_index(embeddings, passages)
|
||||
print("✓ Index creation working")
|
||||
|
||||
# Test mining
|
||||
queries = ["query1", "query2"]
|
||||
query_embeddings = np.random.rand(2, 128).astype(np.float32)
|
||||
positive_passages = [["passage_0"], ["passage_5"]]
|
||||
|
||||
hard_negs = miner.mine_hard_negatives(
|
||||
queries, query_embeddings, positive_passages
|
||||
)
|
||||
|
||||
assert len(hard_negs) == 2
|
||||
print("✓ Hard negative mining working")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Hard negative mining error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_evaluation_metrics(self):
|
||||
"""Test evaluation metrics"""
|
||||
print("Testing evaluation metrics...")
|
||||
|
||||
try:
|
||||
from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics
|
||||
|
||||
# Create dummy data
|
||||
query_embs = np.random.rand(5, 128)
|
||||
passage_embs = np.random.rand(20, 128)
|
||||
labels = np.zeros((5, 20))
|
||||
labels[0, 0] = 1 # First query, first passage is relevant
|
||||
labels[1, 5] = 1 # Second query, sixth passage is relevant
|
||||
|
||||
metrics = compute_retrieval_metrics(query_embs, passage_embs, labels)
|
||||
assert 'recall@1' in metrics
|
||||
assert 'precision@1' in metrics
|
||||
print("✓ Retrieval metrics working")
|
||||
|
||||
# Test reranker metrics
|
||||
scores = np.random.rand(10)
|
||||
labels = np.random.randint(0, 2, 10)
|
||||
reranker_metrics = compute_reranker_metrics(scores, labels)
|
||||
assert 'accuracy' in reranker_metrics or 'mrr' in reranker_metrics
|
||||
print("✓ Reranker metrics working")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Evaluation metrics error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_data_formats(self):
|
||||
"""Test multiple data format support"""
|
||||
print("Testing data format support...")
|
||||
|
||||
try:
|
||||
from data.preprocessing import DataPreprocessor
|
||||
|
||||
preprocessor = DataPreprocessor()
|
||||
|
||||
# Test JSONL format
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
|
||||
f.write('{"query": "test", "pos": ["pos1"], "neg": ["neg1"]}\n')
|
||||
f.write('{"query": "test2", "pos": ["pos2"], "neg": ["neg2"]}\n')
|
||||
temp_file = f.name
|
||||
|
||||
data = preprocessor._load_jsonl(temp_file)
|
||||
assert len(data) == 2
|
||||
os.unlink(temp_file)
|
||||
print("✓ JSONL format working")
|
||||
|
||||
# Test CSV format
|
||||
import pandas as pd
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
|
||||
df = pd.DataFrame({
|
||||
'query': ['test1', 'test2'],
|
||||
'positive': ['pos1', 'pos2'],
|
||||
'negative': ['neg1', 'neg2']
|
||||
})
|
||||
df.to_csv(f.name, index=False)
|
||||
temp_file = f.name
|
||||
|
||||
data = preprocessor._load_csv(temp_file)
|
||||
assert len(data) >= 1 # At least one valid record
|
||||
os.unlink(temp_file)
|
||||
print("✓ CSV format working")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Data format support error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_device_compatibility(self):
|
||||
"""Test device compatibility (CUDA/Ascend/CPU)"""
|
||||
print("Testing device compatibility...")
|
||||
|
||||
try:
|
||||
from config.base_config import BaseConfig
|
||||
|
||||
# Test device detection
|
||||
config = BaseConfig()
|
||||
print(f"✓ Device detected: {config.device}")
|
||||
print(f"✓ Number of devices: {config.n_gpu}")
|
||||
|
||||
# Test tensor operations
|
||||
if config.device == "cuda":
|
||||
x = torch.randn(10, 10).cuda()
|
||||
y = torch.matmul(x, x.t())
|
||||
assert y.device.type == "cuda"
|
||||
print("✓ CUDA tensor operations working")
|
||||
elif config.device == "npu":
|
||||
print("✓ Ascend NPU detected")
|
||||
# Note: Would need actual NPU to test operations
|
||||
else:
|
||||
x = torch.randn(10, 10)
|
||||
y = torch.matmul(x, x.t())
|
||||
print("✓ CPU tensor operations working")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Device compatibility error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_checkpoint_functionality(self):
|
||||
"""Test checkpointing functionality"""
|
||||
print("Testing checkpoint functionality...")
|
||||
|
||||
try:
|
||||
from utils.checkpoint import CheckpointManager
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
manager = CheckpointManager(temp_dir)
|
||||
|
||||
# Test saving
|
||||
state_dict = {
|
||||
'model_state_dict': torch.randn(10, 10),
|
||||
'optimizer_state_dict': {'lr': 0.001},
|
||||
'step': 100
|
||||
}
|
||||
|
||||
checkpoint_path = manager.save(state_dict, step=100)
|
||||
assert os.path.exists(checkpoint_path)
|
||||
print("✓ Checkpoint saving working")
|
||||
|
||||
# Test loading
|
||||
loaded_state = manager.load_checkpoint(checkpoint_path)
|
||||
assert 'model_state_dict' in loaded_state
|
||||
print("✓ Checkpoint loading working")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Checkpoint functionality error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_logging_system(self):
|
||||
"""Test logging system"""
|
||||
print("Testing logging system...")
|
||||
|
||||
try:
|
||||
from utils.logging import setup_logging, get_logger, MetricsLogger
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Test setup
|
||||
setup_logging(temp_dir, log_to_file=True)
|
||||
logger = get_logger("test")
|
||||
logger.info("Test log message")
|
||||
print("✓ Basic logging working")
|
||||
|
||||
# Test metrics logger
|
||||
metrics_logger = MetricsLogger(temp_dir)
|
||||
metrics_logger.log(step=1, metrics={'loss': 0.5}, phase='train')
|
||||
print("✓ Metrics logging working")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Logging system error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_training_script_config_integration(self):
|
||||
"""Test training script configuration integration"""
|
||||
print("Testing training script configuration integration...")
|
||||
|
||||
try:
|
||||
# Import the main training script functions
|
||||
from scripts.train_m3 import create_config_from_args
|
||||
|
||||
# Create mock arguments
|
||||
class MockArgs:
|
||||
def __init__(self):
|
||||
self.config_path = None
|
||||
self.model_name_or_path = None
|
||||
self.cache_dir = None
|
||||
self.train_data = "dummy"
|
||||
self.eval_data = None
|
||||
self.max_seq_length = None
|
||||
self.query_max_length = None
|
||||
self.passage_max_length = None
|
||||
self.max_length = 8192
|
||||
self.output_dir = "/tmp/test"
|
||||
self.num_train_epochs = None
|
||||
self.per_device_train_batch_size = None
|
||||
self.per_device_eval_batch_size = None
|
||||
self.gradient_accumulation_steps = None
|
||||
self.learning_rate = None
|
||||
self.warmup_ratio = None
|
||||
self.weight_decay = None
|
||||
self.train_group_size = None
|
||||
self.temperature = None
|
||||
self.use_hard_negatives = False
|
||||
self.num_hard_negatives = None
|
||||
self.use_self_distill = False
|
||||
self.use_query_instruction = False
|
||||
self.query_instruction = None
|
||||
self.use_dense = True
|
||||
self.use_sparse = False
|
||||
self.use_colbert = False
|
||||
self.unified_finetuning = False
|
||||
self.augment_queries = False
|
||||
self.augmentation_methods = None
|
||||
self.create_hard_negatives = False
|
||||
self.fp16 = False
|
||||
self.bf16 = False
|
||||
self.gradient_checkpointing = False
|
||||
self.dataloader_num_workers = None
|
||||
self.save_steps = None
|
||||
self.eval_steps = None
|
||||
self.logging_steps = None
|
||||
self.save_total_limit = None
|
||||
self.resume_from_checkpoint = None
|
||||
self.local_rank = -1
|
||||
self.seed = None
|
||||
self.overwrite_output_dir = True
|
||||
|
||||
args = MockArgs()
|
||||
config = create_config_from_args(args)
|
||||
|
||||
assert config.learning_rate > 0
|
||||
assert config.per_device_train_batch_size > 0
|
||||
print("✓ Training script config integration working")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Training script config integration error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def run_all_tests(self):
|
||||
"""Run all tests and generate report"""
|
||||
print("\n" + "="*80)
|
||||
print("BGE FINE-TUNING COMPREHENSIVE TEST SUITE")
|
||||
print("="*80)
|
||||
|
||||
tests = [
|
||||
("Core Dependencies Import", self.test_imports),
|
||||
("Project Components Import", self.test_project_imports),
|
||||
("Configuration System", self.test_configuration_system),
|
||||
("Configuration Integration", self.test_configuration_integration),
|
||||
("Data Processing", self.test_data_processing),
|
||||
("Data Augmentation", self.test_data_augmentation),
|
||||
("Model Functionality", self.test_model_functionality),
|
||||
("Hard Negative Mining", self.test_hard_negative_mining),
|
||||
("Evaluation Metrics", self.test_evaluation_metrics),
|
||||
("Data Format Support", self.test_data_formats),
|
||||
("Device Compatibility", self.test_device_compatibility),
|
||||
("Checkpoint Functionality", self.test_checkpoint_functionality),
|
||||
("Logging System", self.test_logging_system),
|
||||
("Training Script Integration", self.test_training_script_config_integration),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
total = len(tests)
|
||||
|
||||
for test_name, test_func in tests:
|
||||
if self.run_test(test_name, test_func):
|
||||
passed += 1
|
||||
|
||||
# Generate report
|
||||
print("\n" + "="*80)
|
||||
print("TEST RESULTS SUMMARY")
|
||||
print("="*80)
|
||||
print(f"Tests passed: {passed}/{total}")
|
||||
print(f"Success rate: {passed/total*100:.1f}%")
|
||||
|
||||
if passed == total:
|
||||
print("\n🎉 ALL TESTS PASSED! Project is ready for use.")
|
||||
print("\nNext steps:")
|
||||
print("1. Configure your API keys in config.toml if needed")
|
||||
print("2. Prepare your training data in JSONL format")
|
||||
print("3. Run training with: python scripts/train_m3.py --train_data your_data.jsonl")
|
||||
print("4. Evaluate with: python scripts/evaluate.py --retriever_model output/model")
|
||||
else:
|
||||
print(f"\n❌ {total - passed} tests failed. Please review the errors above.")
|
||||
print("\nFailed tests:")
|
||||
for test_name in self.failed_tests:
|
||||
print(f" - {test_name}")
|
||||
|
||||
print("\nDetailed test results:")
|
||||
for test_name, status, error in self.test_results:
|
||||
status_emoji = "✅" if status == "PASSED" else "❌"
|
||||
print(f" {status_emoji} {test_name}: {status}")
|
||||
if error:
|
||||
print(f" Error: {error}")
|
||||
|
||||
return passed == total
|
||||
|
||||
|
||||
def main():
|
||||
"""Run comprehensive test suite"""
|
||||
test_suite = ComprehensiveTestSuite()
|
||||
success = test_suite.run_all_tests()
|
||||
return 0 if success else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
299
tests/test_training_pipeline.py
Normal file
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python
|
||||
"""Unified test script to validate both BGE-M3 and BGE-Reranker training pipelines"""
|
||||
|
||||
import torch
|
||||
import json
|
||||
import tempfile
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from config.m3_config import BGEM3Config
|
||||
from config.reranker_config import BGERerankerConfig
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from data.dataset import BGEM3Dataset, BGERerankerDataset, collate_embedding_batch, collate_reranker_batch
|
||||
from training.m3_trainer import BGEM3Trainer
|
||||
from training.reranker_trainer import BGERerankerTrainer
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def create_embedding_test_data():
|
||||
"""Create test data for embedding model (triplets format)"""
|
||||
data = [
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"pos": ["Machine learning is a subset of artificial intelligence that enables systems to learn from data."],
|
||||
"neg": ["The weather today is sunny and warm."]
|
||||
},
|
||||
{
|
||||
"query": "How does deep learning work?",
|
||||
"pos": ["Deep learning uses neural networks with multiple layers to learn patterns."],
|
||||
"neg": ["Pizza is a popular Italian dish."]
|
||||
}
|
||||
]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
|
||||
for item in data:
|
||||
f.write(json.dumps(item) + '\n')
|
||||
return f.name
|
||||
|
||||
|
||||
def create_reranker_test_data():
|
||||
"""Create test data for reranker model (pairs format)"""
|
||||
data = [
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"passage": "Machine learning is a subset of artificial intelligence that enables systems to learn from data.",
|
||||
"label": 1,
|
||||
"score": 0.95,
|
||||
"qid": "q1",
|
||||
"pid": "p1"
|
||||
},
|
||||
{
|
||||
"query": "What is machine learning?",
|
||||
"passage": "The weather today is sunny and warm.",
|
||||
"label": 0,
|
||||
"score": 0.15,
|
||||
"qid": "q1",
|
||||
"pid": "p2"
|
||||
},
|
||||
{
|
||||
"query": "How does deep learning work?",
|
||||
"passage": "Deep learning uses neural networks with multiple layers to learn patterns.",
|
||||
"label": 1,
|
||||
"score": 0.92,
|
||||
"qid": "q2",
|
||||
"pid": "p3"
|
||||
},
|
||||
{
|
||||
"query": "How does deep learning work?",
|
||||
"passage": "Pizza is a popular Italian dish.",
|
||||
"label": 0,
|
||||
"score": 0.12,
|
||||
"qid": "q2",
|
||||
"pid": "p4"
|
||||
}
|
||||
]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
|
||||
for item in data:
|
||||
f.write(json.dumps(item) + '\n')
|
||||
return f.name
|
||||
|
||||
|
||||
def test_bge_m3():
|
||||
"""Test BGE-M3 embedding model"""
|
||||
print("\n🎯 Testing BGE-M3 Embedding Model...")
|
||||
|
||||
# Create test data
|
||||
train_file = create_embedding_test_data()
|
||||
print(f"✅ Created M3 test data: {train_file}")
|
||||
|
||||
# Initialize config
|
||||
config = BGEM3Config(
|
||||
model_name_or_path="./models/bge-m3",
|
||||
train_data_path=train_file,
|
||||
output_dir="output/test-m3",
|
||||
num_train_epochs=1,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=1e-5,
|
||||
logging_steps=1,
|
||||
save_steps=10,
|
||||
fp16=False,
|
||||
use_amp=False,
|
||||
dataloader_num_workers=0,
|
||||
overwrite_output_dir=True,
|
||||
use_self_distill=False, # Disable for testing
|
||||
use_hard_negatives=False # Disable for testing
|
||||
)
|
||||
|
||||
try:
|
||||
# Load model and tokenizer
|
||||
print("📦 Loading M3 model and tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
|
||||
model = BGEM3Model.from_pretrained(config.model_name_or_path)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
print(f"🖥️ Using device: {device}")
|
||||
|
||||
# Create dataset
|
||||
dataset = BGEM3Dataset(
|
||||
data_path=train_file,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=config.query_max_length,
|
||||
passage_max_length=config.passage_max_length,
|
||||
is_train=True
|
||||
)
|
||||
print(f"✅ M3 Dataset size: {len(dataset)}")
|
||||
|
||||
# Create dataloader
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=collate_embedding_batch
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = BGEM3Trainer(
|
||||
config=config,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Test one batch
|
||||
for batch in dataloader:
|
||||
print(f"M3 Batch keys: {list(batch.keys())}")
|
||||
print(f"Query shape: {batch['query_input_ids'].shape}")
|
||||
print(f"Passage shape: {batch['passage_input_ids'].shape}")
|
||||
print(f"Labels shape: {batch['labels'].shape}")
|
||||
|
||||
# Move batch to device
|
||||
batch = trainer._prepare_batch(batch)
|
||||
|
||||
# Compute loss
|
||||
with torch.no_grad():
|
||||
loss = trainer.compute_loss(batch)
|
||||
|
||||
print(f"✅ M3 Loss computed successfully: {loss.item():.4f}")
|
||||
break
|
||||
|
||||
print("✅ BGE-M3 test passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ BGE-M3 test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
# Clean up
|
||||
Path(train_file).unlink()
|
||||
|
||||
|
||||
def test_bge_reranker():
|
||||
"""Test BGE-Reranker model"""
|
||||
print("\n🎯 Testing BGE-Reranker Model...")
|
||||
|
||||
# Create test data
|
||||
train_file = create_reranker_test_data()
|
||||
print(f"✅ Created Reranker test data: {train_file}")
|
||||
|
||||
# Initialize config
|
||||
config = BGERerankerConfig(
|
||||
model_name_or_path="./models/bge-reranker-base",
|
||||
train_data_path=train_file,
|
||||
output_dir="output/test-reranker",
|
||||
num_train_epochs=1,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=1e-5,
|
||||
logging_steps=1,
|
||||
save_steps=10,
|
||||
fp16=False,
|
||||
use_amp=False,
|
||||
dataloader_num_workers=0,
|
||||
overwrite_output_dir=True,
|
||||
loss_type="binary_cross_entropy"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load model and tokenizer
|
||||
print("📦 Loading Reranker model and tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
|
||||
model = BGERerankerModel(
|
||||
model_name_or_path=config.model_name_or_path,
|
||||
num_labels=1,
|
||||
use_fp16=config.fp16,
|
||||
gradient_checkpointing=config.gradient_checkpointing
|
||||
)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
print(f"🖥️ Using device: {device}")
|
||||
|
||||
# Create dataset
|
||||
dataset = BGERerankerDataset(
|
||||
data_path=train_file,
|
||||
tokenizer=tokenizer,
|
||||
query_max_length=config.query_max_length,
|
||||
passage_max_length=config.passage_max_length,
|
||||
is_train=True
|
||||
)
|
||||
print(f"✅ Reranker Dataset size: {len(dataset)}")
|
||||
|
||||
# Create dataloader
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=collate_reranker_batch
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = BGERerankerTrainer(
|
||||
config=config,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Test one batch
|
||||
for batch in dataloader:
|
||||
print(f"Reranker Batch keys: {list(batch.keys())}")
|
||||
print(f"Input IDs shape: {batch['input_ids'].shape}")
|
||||
print(f"Attention mask shape: {batch['attention_mask'].shape}")
|
||||
print(f"Labels shape: {batch['labels'].shape}")
|
||||
print(f"Labels values: {batch['labels']}")
|
||||
|
||||
# Move batch to device
|
||||
batch = trainer._prepare_batch(batch)
|
||||
|
||||
# Compute loss
|
||||
with torch.no_grad():
|
||||
loss = trainer.compute_loss(batch)
|
||||
|
||||
print(f"✅ Reranker Loss computed successfully: {loss.item():.4f}")
|
||||
break
|
||||
|
||||
print("✅ BGE-Reranker test passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ BGE-Reranker test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
# Clean up
|
||||
Path(train_file).unlink()
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("🚀 Starting comprehensive BGE training pipeline tests...")
|
||||
|
||||
m3_success = test_bge_m3()
|
||||
reranker_success = test_bge_reranker()
|
||||
|
||||
print("\n📊 Test Results Summary:")
|
||||
print(f" BGE-M3: {'✅ PASS' if m3_success else '❌ FAIL'}")
|
||||
print(f" BGE-Reranker: {'✅ PASS' if reranker_success else '❌ FAIL'}")
|
||||
|
||||
if m3_success and reranker_success:
|
||||
print("\n🎉 All tests passed! Both models are working correctly.")
|
||||
return 0
|
||||
else:
|
||||
print("\n⚠️ Some tests failed. Please check the error messages above.")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
40
training/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Training modules for BGE models
|
||||
"""
|
||||
|
||||
def get_trainer(trainer_type: str):
|
||||
"""Lazy loading of trainers to avoid importing all at once"""
|
||||
if trainer_type == 'base':
|
||||
from .trainer import BaseTrainer
|
||||
return BaseTrainer
|
||||
elif trainer_type == 'm3':
|
||||
from .m3_trainer import BGEM3Trainer
|
||||
return BGEM3Trainer
|
||||
elif trainer_type == 'reranker':
|
||||
from .reranker_trainer import BGERerankerTrainer
|
||||
return BGERerankerTrainer
|
||||
elif trainer_type == 'joint':
|
||||
from .joint_trainer import JointTrainer
|
||||
return JointTrainer
|
||||
else:
|
||||
raise ValueError(f"Unknown trainer type: {trainer_type}")
|
||||
|
||||
# For backward compatibility, import the commonly used ones
|
||||
from .trainer import BaseTrainer
|
||||
from .m3_trainer import BGEM3Trainer
|
||||
|
||||
# Only import reranker and joint trainers when explicitly requested
|
||||
def __getattr__(name):
|
||||
if name == 'BGERerankerTrainer':
|
||||
from .reranker_trainer import BGERerankerTrainer
|
||||
return BGERerankerTrainer
|
||||
elif name == 'JointTrainer':
|
||||
from .joint_trainer import JointTrainer
|
||||
return JointTrainer
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
__all__ = [
|
||||
'BaseTrainer',
|
||||
'BGEM3Trainer',
|
||||
'get_trainer'
|
||||
]
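# Example: resolve a trainer lazily so that importing the package does not pull in every
# trainer module (arguments shown are the ones used by the test scripts in tests/):
#
#   TrainerCls = get_trainer('reranker')  # imports training.reranker_trainer on demand
#   trainer = TrainerCls(config=config, model=model, tokenizer=tokenizer, device=device)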
|
||||
419
training/joint_trainer.py
Normal file
@@ -0,0 +1,419 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import AdamW
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from .trainer import BaseTrainer
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from models.losses import (
|
||||
InfoNCELoss, ListwiseCrossEntropyLoss,
|
||||
DynamicListwiseDistillationLoss, MultiTaskLoss
|
||||
)
|
||||
from data.dataset import JointDataset
|
||||
from config.base_config import BaseConfig
|
||||
from utils.distributed import is_main_process, get_world_size, reduce_dict
|
||||
from utils.logging import get_logger
|
||||
|
||||
# Import autocast and GradScaler with fallback
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.cuda.amp.grad_scaler import GradScaler
|
||||
# Scheduler import
|
||||
try:
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
except ImportError:
|
||||
get_linear_schedule_with_warmup = None
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class JointTrainer(BaseTrainer):
|
||||
"""
|
||||
Joint trainer for BGE-M3 and BGE-Reranker using the RocketQAv2 approach
|
||||
|
||||
Notes:
|
||||
- All parameters (config, retriever, reranker, optimizers, device, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseConfig,
|
||||
retriever: BGEM3Model,
|
||||
reranker: BGERerankerModel,
|
||||
retriever_optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
reranker_optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
device: Optional[torch.device] = None
|
||||
):
|
||||
# Create default optimizer if none provided
|
||||
if retriever_optimizer is None:
|
||||
retriever_optimizer = AdamW(
|
||||
retriever.parameters(),
|
||||
lr=config.learning_rate,
|
||||
weight_decay=config.weight_decay
|
||||
)
|
||||
|
||||
# We'll use the retriever as the main model for the base class
|
||||
super().__init__(config, retriever, retriever_optimizer, device=device)
|
||||
|
||||
self.retriever = self.model # Alias for clarity
|
||||
self.reranker = reranker
|
||||
self.reranker.to(self.device)
|
||||
|
||||
# Setup reranker optimizer if not provided
|
||||
if reranker_optimizer is None:
|
||||
self.reranker_optimizer = AdamW(
|
||||
self.reranker.parameters(),
|
||||
lr=config.learning_rate * 2, # Higher LR for reranker
|
||||
weight_decay=config.weight_decay
|
||||
)
|
||||
else:
|
||||
self.reranker_optimizer = reranker_optimizer
|
||||
|
||||
# Initialize losses
|
||||
self.retriever_loss_fn = InfoNCELoss(
|
||||
temperature=config.temperature if hasattr(config, 'temperature') else 0.02,
|
||||
use_in_batch_negatives=True # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
self.reranker_loss_fn = ListwiseCrossEntropyLoss(
|
||||
temperature=1.0,
|
||||
label_smoothing=0.0 # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
self.distillation_loss_fn = DynamicListwiseDistillationLoss(
|
||||
temperature=1.0,
|
||||
alpha=0.5 # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
# Joint training specific parameters
|
||||
self.update_retriever_interval = getattr(config, 'update_retriever_steps', 100)
|
||||
self.joint_loss_weight = getattr(config, 'joint_loss_weight', 0.3)
|
||||
self.use_dynamic_distillation = getattr(config, 'use_dynamic_listwise_distillation', True)
|
||||
|
||||
# Tracking metrics
|
||||
self.retriever_metrics = []
|
||||
self.reranker_metrics = []
|
||||
|
||||
def compute_loss(self, batch: Tuple[Dict, Dict]) -> torch.Tensor:
|
||||
"""
|
||||
Compute joint loss for both retriever and reranker
|
||||
|
||||
Args:
|
||||
batch: Tuple of (retriever_batch, reranker_batch)
|
||||
"""
|
||||
retriever_batch, reranker_batch = batch
|
||||
|
||||
# 1. Compute retriever loss
|
||||
retriever_loss = self._compute_retriever_loss(retriever_batch)
|
||||
|
||||
# 2. Compute reranker loss
|
||||
reranker_loss = self._compute_reranker_loss(reranker_batch)
|
||||
|
||||
# 3. Compute distillation losses if enabled
|
||||
if self.use_dynamic_distillation:
|
||||
# Get scores from both models
|
||||
with torch.no_grad():
|
||||
retriever_scores = self._get_retriever_scores(retriever_batch)
|
||||
reranker_scores = self._get_reranker_scores(reranker_batch)
|
||||
|
||||
# Compute mutual distillation
|
||||
retriever_distill_loss, reranker_distill_loss = self.distillation_loss_fn(
|
||||
retriever_scores,
|
||||
reranker_scores,
|
||||
labels=reranker_batch.get('labels')
|
||||
)
|
||||
|
||||
# Add distillation to main losses
|
||||
retriever_loss = retriever_loss + self.joint_loss_weight * retriever_distill_loss
|
||||
reranker_loss = reranker_loss + self.joint_loss_weight * reranker_distill_loss
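# Sketch of the combined objective, with w = joint_loss_weight:
#   L_retriever = L_contrastive + w * L_distill(retriever <- reranker)
#   L_reranker  = L_rank        + w * L_distill(reranker <- retriever)
# Each model is later stepped against its own loss; the sum returned below is for logging.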
|
||||
|
||||
# Combine losses (we'll handle separate backward passes)
|
||||
self.current_retriever_loss = retriever_loss
|
||||
self.current_reranker_loss = reranker_loss
|
||||
|
||||
# Return combined loss for logging
|
||||
return retriever_loss + reranker_loss
|
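# Note: the sum above is only reported for logging; _train_epoch backpropagates
# self.current_retriever_loss and self.current_reranker_loss separately, each
# through its own optimizer.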
||||
|
||||
def _compute_retriever_loss(self, batch: Dict) -> torch.Tensor:
|
||||
"""Compute retriever (bi-encoder) loss"""
|
||||
# Get embeddings
|
||||
query_outputs = self.retriever(
|
||||
input_ids=batch['query_input_ids'],
|
||||
attention_mask=batch['query_attention_mask'],
|
||||
return_dense=True
|
||||
)
|
||||
|
||||
passage_outputs = self.retriever(
|
||||
input_ids=batch['passage_input_ids'],
|
||||
attention_mask=batch['passage_attention_mask'],
|
||||
return_dense=True
|
||||
)
|
||||
|
||||
# Compute contrastive loss
|
||||
loss = self.retriever_loss_fn(
|
||||
query_outputs['dense'],
|
||||
passage_outputs['dense'],
|
||||
labels=batch.get('labels')
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def _compute_reranker_loss(self, batch: Dict) -> torch.Tensor:
|
||||
"""Compute reranker (cross-encoder) loss"""
|
||||
outputs = self.reranker(
|
||||
input_ids=batch['input_ids'],
|
||||
attention_mask=batch['attention_mask']
|
||||
)
|
||||
|
||||
logits = outputs['logits']
|
||||
labels = batch['labels']
|
||||
|
||||
# Reshape for listwise loss if needed
|
||||
if logits.dim() == 1:
|
||||
# Binary classification
|
||||
loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
|
||||
else:
|
||||
# Listwise ranking
|
||||
loss = self.reranker_loss_fn(logits, labels)
|
||||
|
||||
return loss
|
||||
|
||||
def _get_retriever_scores(self, batch: Dict) -> torch.Tensor:
|
||||
"""Get retriever similarity scores"""
|
||||
with torch.no_grad():
|
||||
query_outputs = self.retriever(
|
||||
input_ids=batch['query_input_ids'],
|
||||
attention_mask=batch['query_attention_mask'],
|
||||
return_dense=True
|
||||
)
|
||||
|
||||
passage_outputs = self.retriever(
|
||||
input_ids=batch['passage_input_ids'],
|
||||
attention_mask=batch['passage_attention_mask'],
|
||||
return_dense=True
|
||||
)
|
||||
|
||||
# Compute similarities
|
||||
scores = torch.matmul(
|
||||
query_outputs['dense'],
|
||||
passage_outputs['dense'].transpose(-1, -2)
|
||||
)
|
||||
|
||||
return scores
|
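# scores has shape [num_queries, num_passages]: entry (i, j) is the dot product of
# query i's dense embedding with passage j's, which equals cosine similarity when
# the encoder normalizes its embeddings.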
||||
|
||||
def _get_reranker_scores(self, batch: Dict) -> torch.Tensor:
|
||||
"""Get reranker scores"""
|
||||
outputs = self.reranker(
|
||||
input_ids=batch['input_ids'],
|
||||
attention_mask=batch['attention_mask']
|
||||
)
|
||||
return outputs['logits']
|
||||
|
||||
def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
|
||||
"""Train both models for one epoch"""
|
||||
self.retriever.train()
|
||||
self.reranker.train()
|
||||
|
||||
epoch_metrics = {
|
||||
'retriever_loss': 0.0,
|
||||
'reranker_loss': 0.0,
|
||||
'total_loss': 0.0
|
||||
}
|
||||
num_batches = 0
|
||||
|
||||
# Set epoch for distributed sampler
|
||||
if hasattr(dataloader, 'sampler') and hasattr(dataloader.sampler, 'set_epoch'):
|
||||
dataloader.sampler.set_epoch(epoch) # type: ignore[attr-defined]
|
||||
|
||||
progress_bar = tqdm(
|
||||
dataloader,
|
||||
desc=f"Joint Training Epoch {epoch + 1}",
|
||||
disable=not is_main_process()
|
||||
)
|
||||
|
||||
for step, batch in enumerate(progress_bar):
|
||||
# Prepare batch
|
||||
retriever_batch, reranker_batch = batch
|
||||
retriever_batch = self._prepare_batch(retriever_batch)
|
||||
reranker_batch = self._prepare_batch(reranker_batch)
|
||||
batch = (retriever_batch, reranker_batch)
|
||||
|
||||
# Forward pass
|
||||
total_loss = self.compute_loss(batch)
|
||||
|
||||
# Scale losses for gradient accumulation
|
||||
retriever_loss = self.current_retriever_loss / self.config.gradient_accumulation_steps
|
||||
reranker_loss = self.current_reranker_loss / self.config.gradient_accumulation_steps
|
||||
|
||||
# Backward passes (separate for each model)
|
||||
if self.scaler:
|
||||
# Retriever backward
|
||||
self.scaler.scale(retriever_loss).backward(retain_graph=True) # type: ignore[attr-defined]
|
||||
# Reranker backward
|
||||
self.scaler.scale(reranker_loss).backward() # type: ignore[attr-defined]
|
||||
else:
|
||||
retriever_loss.backward(retain_graph=True)
|
||||
reranker_loss.backward()
|
||||
|
||||
# Update weights
|
||||
if (step + 1) % self.config.gradient_accumulation_steps == 0:
|
||||
# Gradient clipping
|
||||
if self.scaler:
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
self.scaler.unscale_(self.reranker_optimizer)
|
||||
|
||||
torch.nn.utils.clip_grad_norm_( # type: ignore[attr-defined]
|
||||
self.retriever.parameters(),
|
||||
self.config.max_grad_norm
|
||||
)
|
||||
torch.nn.utils.clip_grad_norm_( # type: ignore[attr-defined]
|
||||
self.reranker.parameters(),
|
||||
self.config.max_grad_norm
|
||||
)
|
||||
|
||||
# Optimizer steps
|
||||
if self.scaler:
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.step(self.reranker_optimizer)
|
||||
self.scaler.update()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
self.reranker_optimizer.step()
|
||||
|
||||
# Scheduler steps
|
||||
if self.scheduler:
|
||||
self.scheduler.step()
|
||||
# Note: Add reranker scheduler if needed
|
||||
|
||||
# Zero gradients
|
||||
self.optimizer.zero_grad()
|
||||
self.reranker_optimizer.zero_grad()
|
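# AMP note: a single GradScaler serves both optimizers. Each loss is scaled and
# backpropagated separately, unscale_ is called once per optimizer before clipping,
# both optimizers step, and scaler.update() runs exactly once per update, mirroring
# the multi-optimizer recipe in PyTorch's AMP examples.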
||||
|
||||
self.global_step += 1
|
||||
|
||||
# Update hard negatives periodically
|
||||
if self.global_step % self.update_retriever_interval == 0:
|
||||
self._update_hard_negatives()
|
||||
|
||||
# Logging
|
||||
if self.global_step % self.config.logging_steps == 0:
|
||||
step_metrics = {
|
||||
'retriever_loss': retriever_loss.item() * self.config.gradient_accumulation_steps,
|
||||
'reranker_loss': reranker_loss.item() * self.config.gradient_accumulation_steps,
|
||||
'total_loss': (
|
||||
retriever_loss.item() + reranker_loss.item()) * self.config.gradient_accumulation_steps,
|
||||
'retriever_lr': self.optimizer.param_groups[0]['lr'],
|
||||
'reranker_lr': self.reranker_optimizer.param_groups[0]['lr']
|
||||
}
|
||||
|
||||
if self.is_distributed:
|
||||
step_metrics = reduce_dict(step_metrics)
|
||||
|
||||
progress_bar.set_postfix({
|
||||
'ret_loss': f"{step_metrics['retriever_loss']:.4f}",
|
||||
'rank_loss': f"{step_metrics['reranker_loss']:.4f}"
|
||||
})
|
||||
|
||||
if is_main_process():
|
||||
for key, value in step_metrics.items():
|
||||
self.tb_logger.log_scalar(f"joint/{key}", value, self.global_step)
|
||||
# Flush logs periodically
|
||||
if self.global_step % (self.config.logging_steps * 10) == 0:
|
||||
self.tb_logger.flush()
|
||||
|
||||
# Accumulate metrics
|
||||
epoch_metrics['retriever_loss'] += retriever_loss.item()
|
||||
epoch_metrics['reranker_loss'] += reranker_loss.item()
|
||||
epoch_metrics['total_loss'] += (retriever_loss.item() + reranker_loss.item())
|
||||
num_batches += 1
|
||||
|
||||
# Average metrics
|
||||
for key in epoch_metrics:
|
||||
epoch_metrics[key] = epoch_metrics[key] / num_batches * self.config.gradient_accumulation_steps
|
||||
|
||||
if self.is_distributed:
|
||||
epoch_metrics = reduce_dict(epoch_metrics)
|
||||
|
||||
return epoch_metrics
|
||||
|
||||
def evaluate_step(self, batch: Tuple[Dict, Dict]) -> Dict[str, float]:
|
||||
"""Evaluate both models on a batch"""
|
||||
retriever_batch, reranker_batch = batch
|
||||
|
||||
metrics = {}
|
||||
|
||||
# Evaluate retriever
|
||||
with torch.no_grad():
|
||||
retriever_loss = self._compute_retriever_loss(retriever_batch)
|
||||
metrics['retriever_loss'] = retriever_loss.item()
|
||||
|
||||
# Compute retrieval metrics
|
||||
retriever_scores = self._get_retriever_scores(retriever_batch)
|
||||
labels = retriever_batch['labels']
|
||||
|
||||
# Simple accuracy: is the positive passage ranked first?
|
||||
top_indices = retriever_scores.argmax(dim=-1)
|
||||
correct = (top_indices == labels.argmax(dim=-1)).float().mean()
|
||||
metrics['retriever_accuracy'] = correct.item()
|
||||
|
||||
# Evaluate reranker
|
||||
with torch.no_grad():
|
||||
reranker_loss = self._compute_reranker_loss(reranker_batch)
|
||||
metrics['reranker_loss'] = reranker_loss.item()
|
||||
|
||||
# Compute reranking metrics
|
||||
reranker_scores = self._get_reranker_scores(reranker_batch)
|
||||
labels = reranker_batch['labels']
|
||||
|
||||
if reranker_scores.dim() == 1:
|
||||
# Binary accuracy
|
||||
predictions = (reranker_scores > 0).float()
|
||||
correct = (predictions == labels).float().mean()
|
||||
else:
|
||||
# Ranking accuracy
|
||||
top_indices = reranker_scores.argmax(dim=-1)
|
||||
correct = (top_indices == labels.argmax(dim=-1)).float().mean()
|
||||
|
||||
metrics['reranker_accuracy'] = correct.item()
|
||||
|
||||
metrics['total_loss'] = metrics['retriever_loss'] + metrics['reranker_loss']
|
||||
|
||||
return metrics
|
||||
|
||||
def _update_hard_negatives(self):
|
||||
"""Update hard negatives using current models (simplified version)"""
|
||||
# This would implement the dynamic hard negative mining
|
||||
# For now, just log that we would update
|
||||
if is_main_process():
|
||||
logger.info(f"Would update hard negatives at step {self.global_step}")
|
||||
|
||||
def save_models(self, output_dir: str):
|
||||
"""Save both models"""
|
||||
if not is_main_process():
|
||||
return
|
||||
|
||||
# Save retriever
|
||||
retriever_path = os.path.join(output_dir, "retriever")
|
||||
os.makedirs(retriever_path, exist_ok=True)
|
||||
|
||||
if hasattr(self.retriever, 'module'):
|
||||
self.retriever.module.save_pretrained(retriever_path) # type: ignore[attr-defined]
|
||||
else:
|
||||
self.retriever.save_pretrained(retriever_path) # type: ignore[attr-defined]
|
||||
|
||||
# Save reranker
|
||||
reranker_path = os.path.join(output_dir, "reranker")
|
||||
os.makedirs(reranker_path, exist_ok=True)
|
||||
|
||||
if hasattr(self.reranker, 'module'):
|
||||
self.reranker.module.save_pretrained(reranker_path) # type: ignore[attr-defined]
|
||||
else:
|
||||
self.reranker.save_pretrained(reranker_path) # type: ignore[attr-defined]
|
||||
|
||||
logger.info(f"Saved models to {output_dir}")
|
||||
894
training/m3_trainer.py
Normal file
894
training/m3_trainer.py
Normal file
@ -0,0 +1,894 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import AdamW
|
||||
from transformers import AutoTokenizer
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
import numpy as np
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.cuda.amp.grad_scaler import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import torch.distributed as dist
|
||||
|
||||
from .trainer import BaseTrainer
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.losses import InfoNCELoss, M3KnowledgeDistillationLoss, MultiTaskLoss
|
||||
from data.dataset import BGEM3Dataset
|
||||
|
||||
# Optional hard negative mining
|
||||
try:
|
||||
from data.hard_negative_mining import ANCEHardNegativeMiner
|
||||
ANCE_AVAILABLE = True
|
||||
except ImportError:
|
||||
ANCE_AVAILABLE = False
|
||||
ANCEHardNegativeMiner = None
|
||||
|
||||
from utils.checkpoint import CheckpointManager
|
||||
from utils.logging import get_logger
|
||||
from evaluation.metrics import compute_retrieval_metrics
|
||||
from config.m3_config import BGEM3Config
|
||||
|
||||
# Import scheduler with fallback
|
||||
try:
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
except ImportError:
|
||||
get_linear_schedule_with_warmup = None
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BGEM3Trainer(BaseTrainer):
|
||||
"""
|
||||
Trainer for BGE-M3 embedding model with all SOTA techniques
|
||||
|
||||
Notes:
|
||||
- All parameters (config, model, tokenizer, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BGEM3Config,
|
||||
model: Optional[BGEM3Model] = None,
|
||||
tokenizer: Optional[AutoTokenizer] = None,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
device: Optional[torch.device] = None
|
||||
):
|
||||
# Store config first
|
||||
self.config = config
|
||||
|
||||
# Initialize tokenizer first
|
||||
if tokenizer is None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
config.model_name_or_path,
|
||||
cache_dir=config.cache_dir,
|
||||
trust_remote_code=True
|
||||
)
|
||||
else:
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# Initialize model
|
||||
if model is None:
|
||||
self.model = BGEM3Model(
|
||||
model_name_or_path=config.model_name_or_path,
|
||||
use_dense=config.use_dense,
|
||||
use_sparse=config.use_sparse,
|
||||
use_colbert=config.use_colbert,
|
||||
pooling_method=config.sentence_pooling_method,
|
||||
normalize_embeddings=config.normalize_embeddings,
|
||||
colbert_dim=config.colbert_dim,
|
||||
sparse_top_k=config.sparse_top_k,
|
||||
cache_dir=config.cache_dir
|
||||
)
|
||||
else:
|
||||
self.model = model
|
||||
|
||||
# Initialize optimizer if not provided
|
||||
if optimizer is None:
|
||||
optimizer = self._create_optimizer()
|
||||
|
||||
# Initialize scheduler if not provided and get_linear_schedule_with_warmup is available
|
||||
if scheduler is None and get_linear_schedule_with_warmup is not None:
|
||||
# We'll create it later when we know the total training steps
|
||||
pass
|
||||
|
||||
# Initialize parent class
|
||||
super().__init__(config, self.model, optimizer, scheduler, device)
|
||||
|
||||
# Initialize losses
|
||||
self.infonce_loss = InfoNCELoss(
|
||||
temperature=config.temperature,
|
||||
reduction='mean'
|
||||
)
|
||||
|
||||
self.kd_loss = M3KnowledgeDistillationLoss(
|
||||
temperature=config.self_distill_temperature,
|
||||
alpha=config.self_distill_alpha,
|
||||
distill_types=['dense', 'sparse', 'colbert']
|
||||
)
|
||||
|
||||
self.multi_task_loss = MultiTaskLoss(
|
||||
loss_weights={
|
||||
'dense': config.dense_weight,
|
||||
'sparse': config.sparse_weight,
|
||||
'colbert': config.colbert_weight,
|
||||
'distillation': config.distillation_weight
|
||||
}
|
||||
)
|
||||
|
||||
# Initialize hard negative miner
|
||||
if config.use_hard_negatives and ANCE_AVAILABLE:
|
||||
self.hard_negative_miner = ANCEHardNegativeMiner(
|
||||
embedding_dim=self.model.config.hidden_size, # type: ignore[attr-defined]
|
||||
num_hard_negatives=config.num_hard_negatives,
|
||||
negative_range=config.ance_negative_range,
|
||||
margin=config.ance_margin,
|
||||
index_refresh_interval=config.hard_negative_mining_steps,
|
||||
use_gpu="cuda" in str(self.device)  # matches both "cuda" and "cuda:N"
|
||||
)
|
||||
self.hard_negative_miner.start_async_updates()
|
||||
else:
|
||||
if config.use_hard_negatives and not ANCE_AVAILABLE:
|
||||
logger.warning("Hard negative mining requested but ANCEHardNegativeMiner not available")
|
||||
self.hard_negative_miner = None
|
||||
|
||||
logger.info(f"Initialized BGE-M3 trainer with config: {config.to_dict()}")
|
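# Usage sketch (illustrative only; the exact config fields are defined in
# config/m3_config.py and are an assumption here):
#
#   config = BGEM3Config(model_name_or_path="BAAI/bge-m3")
#   trainer = BGEM3Trainer(config)
#   trainer.train(train_dataset, eval_dataset)
#   trainer.save_model("outputs/bge-m3")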
||||
|
||||
def _create_optimizer(self) -> torch.optim.Optimizer:
|
||||
"""Create optimizer with parameter groups"""
|
||||
# Separate parameters for different learning rates if needed
|
||||
param_groups = [
|
||||
{
|
||||
'params': [p for n, p in self.model.named_parameters()
|
||||
if 'colbert_linear' not in n and 'sparse_linear' not in n],
|
||||
'lr': self.config.learning_rate
|
||||
}
|
||||
]
|
||||
|
||||
# Different LR for task-specific layers
|
||||
if self.config.use_colbert and hasattr(self.model, 'colbert_linear'):
|
||||
param_groups.append({
|
||||
'params': self.model.colbert_linear.parameters(),
|
||||
'lr': self.config.learning_rate * 2 # Higher LR for new layers
|
||||
})
|
||||
|
||||
if self.config.use_sparse and hasattr(self.model, 'sparse_linear'):
|
||||
param_groups.append({
|
||||
'params': self.model.sparse_linear.parameters(),
|
||||
'lr': self.config.learning_rate * 2
|
||||
})
|
||||
|
||||
optimizer = AdamW(
|
||||
param_groups,
|
||||
eps=self.config.adam_epsilon,
|
||||
weight_decay=self.config.weight_decay,
|
||||
betas=(self.config.adam_beta1, self.config.adam_beta2)
|
||||
)
|
||||
|
||||
return optimizer
|
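# Parameter-group summary: the shared encoder trains at config.learning_rate, while
# the newly initialized colbert_linear / sparse_linear heads get twice that rate so
# the task-specific projections can adapt faster than the pretrained backbone.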
||||
|
||||
def create_scheduler(self, num_training_steps: int):
|
||||
"""Create learning rate scheduler"""
|
||||
if get_linear_schedule_with_warmup is None:
|
||||
logger.warning("transformers not available, using constant learning rate")
|
||||
return None
|
||||
|
||||
num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)
|
||||
|
||||
self.scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps
|
||||
)
|
||||
return self.scheduler
|
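# Worked example: with warmup_ratio=0.1 and 10,000 total update steps, the learning
# rate rises linearly from 0 to config.learning_rate over the first 1,000 steps and
# then decays linearly back to 0 at step 10,000, which is the standard behaviour of
# get_linear_schedule_with_warmup.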
||||
|
||||
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Compute M3 loss with multiple objectives"""
|
||||
|
||||
# Update hard negative mining if hook is available
|
||||
if hasattr(self, '_mining_hook'):
|
||||
self._mining_hook()
|
||||
|
||||
# Detect batch format and handle accordingly
|
||||
if 'pos_input_ids' in batch and 'neg_input_ids' in batch:
|
||||
# This is triplet format from FastM3Dataset
|
||||
return self._compute_triplet_loss(batch)
|
||||
elif 'passage_input_ids' in batch:
|
||||
# This is standard embedding format
|
||||
return self._compute_embedding_loss(batch)
|
||||
else:
|
||||
raise ValueError(f"Unknown batch format. Keys: {list(batch.keys())}")
|
||||
|
||||
def _compute_triplet_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Compute loss for triplet format data (FastM3Dataset)"""
|
||||
# Get query embeddings
|
||||
query_outputs = self.model.forward(
|
||||
input_ids=batch['query_input_ids'],
|
||||
attention_mask=batch['query_attention_mask'],
|
||||
return_dense=self.config.use_dense,
|
||||
return_sparse=self.config.use_sparse,
|
||||
return_colbert=self.config.use_colbert,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Get positive embeddings
|
||||
pos_outputs = self.model.forward(
|
||||
input_ids=batch['pos_input_ids'],
|
||||
attention_mask=batch['pos_attention_mask'],
|
||||
return_dense=self.config.use_dense,
|
||||
return_sparse=self.config.use_sparse,
|
||||
return_colbert=self.config.use_colbert,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Get negative embeddings
|
||||
neg_outputs = self.model.forward(
|
||||
input_ids=batch['neg_input_ids'],
|
||||
attention_mask=batch['neg_attention_mask'],
|
||||
return_dense=self.config.use_dense,
|
||||
return_sparse=self.config.use_sparse,
|
||||
return_colbert=self.config.use_colbert,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
losses = {}
|
||||
|
||||
# Dense loss (InfoNCE with explicit pos/neg)
|
||||
if 'dense' in query_outputs:
|
||||
query_dense = query_outputs['dense'].float() # type: ignore[index]
|
||||
pos_dense = pos_outputs['dense'].float() # type: ignore[index]
|
||||
neg_dense = neg_outputs['dense'].float() # type: ignore[index]
|
||||
|
||||
# For triplet format, reshape negative embeddings to expected 3D format
|
||||
# InfoNCE expects negative_embeddings: [batch_size, num_negatives, embedding_dim]
|
||||
# But we have single negatives: [batch_size, embedding_dim]
|
||||
# Reshape to [batch_size, 1, embedding_dim]
|
||||
neg_dense_3d = neg_dense.unsqueeze(1) # [batch_size, 1, embedding_dim]
|
||||
|
||||
dense_loss = self.infonce_loss(
|
||||
query_dense,
|
||||
pos_dense,
|
||||
neg_dense_3d # Now properly shaped 3D tensor
|
||||
)
|
||||
losses['dense'] = dense_loss
|
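# Shape note: with single negatives the InfoNCE call receives
#   query:     [batch_size, dim]
#   positives: [batch_size, dim]
#   negatives: [batch_size, 1, dim]
# and computes, roughly (an assumption about the exact implementation):
#   L = -log( exp(s(q, p+)/t) / (exp(s(q, p+)/t) + sum_j exp(s(q, p-_j)/t)) )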
||||
|
||||
# Sparse loss
|
||||
if self.config.use_sparse and 'sparse' in query_outputs:
|
||||
from models.losses import SparseLoss
|
||||
sparse_loss_fn = SparseLoss(
|
||||
flops_loss_weight=1e-3,
|
||||
sparse_reg_weight=1e-4
|
||||
)
|
||||
|
||||
# Combine pos and neg for sparse loss
|
||||
combined_sparse = torch.cat([pos_outputs['sparse'], neg_outputs['sparse']], dim=0) # type: ignore[index]
|
||||
query_sparse_repeated = query_outputs['sparse'].repeat(2, 1) # type: ignore[index]
|
||||
labels = torch.cat([torch.ones(len(pos_outputs['sparse'])), # type: ignore[index]
|
||||
torch.zeros(len(neg_outputs['sparse']))], dim=0).to(query_sparse_repeated.device) # type: ignore[index]
|
||||
|
||||
sparse_loss, _ = sparse_loss_fn(
|
||||
query_sparse_repeated,
|
||||
combined_sparse,
|
||||
labels
|
||||
)
|
||||
losses['sparse'] = sparse_loss
|
||||
|
||||
# ColBERT loss
|
||||
if self.config.use_colbert and 'colbert' in query_outputs:
|
||||
from models.losses import ColBERTLoss
|
||||
colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)
|
||||
|
||||
# Combine pos and neg for ColBERT loss
|
||||
combined_colbert = torch.cat([pos_outputs['colbert'], neg_outputs['colbert']], dim=0) # type: ignore[index]
|
||||
query_colbert_repeated = query_outputs['colbert'].repeat(2, 1, 1) # type: ignore[index]
|
||||
labels = torch.cat([torch.ones(len(pos_outputs['colbert'])), # type: ignore[index]
|
||||
torch.zeros(len(neg_outputs['colbert']))], dim=0).to(query_colbert_repeated.device) # type: ignore[index]
|
||||
|
||||
colbert_loss = colbert_loss_fn(
|
||||
query_colbert_repeated,
|
||||
combined_colbert,
|
||||
labels
|
||||
)
|
||||
losses['colbert'] = colbert_loss
|
||||
|
||||
# Combine losses
|
||||
total_loss, _ = self.multi_task_loss(losses)
|
||||
return total_loss
|
||||
|
||||
def _compute_embedding_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Compute loss for standard embedding format data"""
|
||||
# Get query embeddings
|
||||
query_outputs = self.model.forward(
|
||||
input_ids=batch['query_input_ids'],
|
||||
attention_mask=batch['query_attention_mask'],
|
||||
return_dense=self.config.use_dense,
|
||||
return_sparse=self.config.use_sparse,
|
||||
return_colbert=self.config.use_colbert,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Get passage embeddings - handle both single and multiple passages
|
||||
if batch['passage_input_ids'].dim() == 3:
|
||||
# Multiple passages per query [batch, num_passages, seq_len]
|
||||
batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
|
||||
passage_input_ids = batch['passage_input_ids'].view(batch_size * num_passages, seq_len)
|
||||
passage_attention_mask = batch['passage_attention_mask'].view(batch_size * num_passages, seq_len)
|
||||
multiple_passages = True
|
||||
else:
|
||||
# Single passage per query [batch, seq_len]
|
||||
passage_input_ids = batch['passage_input_ids']
|
||||
passage_attention_mask = batch['passage_attention_mask']
|
||||
batch_size = passage_input_ids.shape[0]
|
||||
num_passages = 1
|
||||
multiple_passages = False
|
||||
|
||||
passage_outputs = self.model.forward(
|
||||
input_ids=passage_input_ids,
|
||||
attention_mask=passage_attention_mask,
|
||||
return_dense=True,
|
||||
return_sparse=self.config.use_sparse,
|
||||
return_colbert=self.config.use_colbert,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Compute losses for different components
|
||||
losses = {}
|
||||
|
||||
# Dense loss (InfoNCE)
|
||||
if 'dense' in query_outputs and 'dense' in passage_outputs:
|
||||
# Get embeddings and ensure they're on the same device
|
||||
query_dense = query_outputs['dense'].float() # type: ignore[index] # [batch_size, embedding_dim]
|
||||
passage_dense = passage_outputs['dense'].float() # type: ignore[index] # [batch_size * num_passages, embedding_dim]
|
||||
|
||||
# Handle multiple passages per query differently for contrastive learning
|
||||
if multiple_passages:
|
||||
# Reshape passage embeddings back to [batch_size, num_passages, embedding_dim]
|
||||
passage_dense = passage_dense.view(batch_size, num_passages, -1)
|
||||
|
||||
# For contrastive learning, we need to identify positive passages
|
||||
# Use the labels to find the first positive passage for each query
|
||||
labels = batch.get('labels') # [batch_size, num_passages]
|
||||
|
||||
if labels is not None:
|
||||
# Find the first positive passage for each query
|
||||
positive_indices = []
|
||||
negative_passages = []
|
||||
|
||||
for i in range(batch_size):
|
||||
query_labels = labels[i] # [num_passages]
|
||||
positive_mask = query_labels > 0
|
||||
|
||||
if positive_mask.any():
|
||||
# Get first positive passage
|
||||
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
|
||||
positive_indices.append(pos_idx)
|
||||
|
||||
# Collect negative passages
|
||||
neg_mask = ~positive_mask
|
||||
if neg_mask.any():
|
||||
neg_embeddings = passage_dense[i][neg_mask] # [num_negatives, embedding_dim]
|
||||
negative_passages.append(neg_embeddings)
|
||||
else:
|
||||
negative_passages.append(None)
|
||||
else:
|
||||
# No positive passage, use first passage as positive (fallback)
|
||||
positive_indices.append(0)
|
||||
negative_passages.append(passage_dense[i][1:])
|
||||
|
||||
# Extract positive embeddings
|
||||
positive_embeddings = torch.stack([
|
||||
passage_dense[i, positive_indices[i]] for i in range(batch_size)
|
||||
]) # [batch_size, embedding_dim]
|
||||
|
||||
# Stack negative embeddings (if available)
|
||||
max_neg = max(neg.shape[0] if neg is not None else 0 for neg in negative_passages)
|
||||
if max_neg > 0:
|
||||
# Pad negative embeddings to same size
|
||||
padded_negatives = []
|
||||
for neg in negative_passages:
|
||||
if neg is not None and neg.shape[0] > 0:
|
||||
if neg.shape[0] < max_neg:
|
||||
# Pad with zeros
|
||||
padding = torch.zeros(max_neg - neg.shape[0], neg.shape[1], device=neg.device)
|
||||
neg = torch.cat([neg, padding], dim=0)
|
||||
padded_negatives.append(neg)
|
||||
else:
|
||||
# No negatives, create zero tensor
|
||||
padded_negatives.append(torch.zeros(max_neg, passage_dense.shape[-1], device=passage_dense.device))
|
||||
|
||||
negative_embeddings = torch.stack(padded_negatives) # [batch_size, max_neg, embedding_dim]
|
||||
else:
|
||||
negative_embeddings = None
|
||||
|
||||
# Compute InfoNCE loss with explicit positives and negatives
|
||||
dense_loss = self.infonce_loss(
|
||||
query_dense,
|
||||
positive_embeddings,
|
||||
negative_embeddings
|
||||
)
|
||||
else:
|
||||
# No labels provided, use first passage as positive, rest as negatives
|
||||
positive_embeddings = passage_dense[:, 0, :] # [batch_size, embedding_dim]
|
||||
if num_passages > 1:
|
||||
negative_embeddings = passage_dense[:, 1:, :] # [batch_size, num_passages-1, embedding_dim]
|
||||
else:
|
||||
negative_embeddings = None
|
||||
|
||||
dense_loss = self.infonce_loss(
|
||||
query_dense,
|
||||
positive_embeddings,
|
||||
negative_embeddings
|
||||
)
|
||||
else:
|
||||
# Single passage per query - use standard in-batch negatives
|
||||
# InfoNCE will create contrastive pairs using in-batch negatives
|
||||
dense_loss = self.infonce_loss(
|
||||
query_dense,
|
||||
passage_dense,
|
||||
None # No explicit negatives, will use in-batch
|
||||
)
|
||||
|
||||
losses['dense'] = dense_loss
|
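# Note on the padding above: queries with fewer real negatives are padded with zero
# vectors so everything stacks into [batch_size, max_neg, dim]. Unless the loss
# masks them, those zero rows add a constant exp(0 / t) term to the InfoNCE
# denominator instead of being ignored.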
||||
|
||||
# Sparse loss (if using sparse)
|
||||
if self.config.use_sparse and 'sparse' in query_outputs and 'sparse' in passage_outputs:
|
||||
from models.losses import SparseLoss
|
||||
sparse_loss_fn = SparseLoss(
|
||||
flops_loss_weight=1e-3,
|
||||
sparse_reg_weight=1e-4
|
||||
)
|
||||
|
||||
query_sparse = query_outputs['sparse'] # type: ignore[index] # [batch_size, query_seq_len]
|
||||
passage_sparse = passage_outputs['sparse'] # type: ignore[index] # [batch_size * num_passages, doc_seq_len]
|
||||
|
||||
# For sparse loss, we need to handle the contrastive learning differently
|
||||
# Since sparse embeddings have different sequence lengths, we'll use the same approach as dense
|
||||
if multiple_passages:
|
||||
# Use only the first positive passage for each query
|
||||
labels = batch.get('labels')
|
||||
if labels is not None:
|
||||
positive_indices = []
|
||||
for i in range(batch_size):
|
||||
query_labels = labels[i]
|
||||
positive_mask = query_labels > 0
|
||||
if positive_mask.any():
|
||||
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
|
||||
else:
|
||||
pos_idx = 0
|
||||
positive_indices.append(pos_idx + i * num_passages)
|
||||
|
||||
# Extract positive sparse embeddings
|
||||
positive_sparse = passage_sparse[positive_indices] # [batch_size, doc_seq_len]
|
||||
else:
|
||||
# Use first passage as positive
|
||||
positive_sparse = passage_sparse[::num_passages] # [batch_size, doc_seq_len]
|
||||
|
||||
# Create contrastive labels for sparse loss
|
||||
contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)
|
||||
|
||||
sparse_loss, _ = sparse_loss_fn(
|
||||
query_sparse,
|
||||
positive_sparse,
|
||||
contrastive_labels
|
||||
)
|
||||
else:
|
||||
# Single passage per query
|
||||
contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)
|
||||
sparse_loss, _ = sparse_loss_fn(
|
||||
query_sparse,
|
||||
passage_sparse,
|
||||
contrastive_labels
|
||||
)
|
||||
|
||||
losses['sparse'] = sparse_loss
|
||||
|
||||
# ColBERT loss (if using ColBERT)
|
||||
if self.config.use_colbert and 'colbert' in query_outputs and 'colbert' in passage_outputs:
|
||||
from models.losses import ColBERTLoss
|
||||
colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)
|
||||
|
||||
query_colbert = query_outputs['colbert'] # type: ignore[index] # [batch_size, query_seq_len, dim]
|
||||
passage_colbert = passage_outputs['colbert'] # type: ignore[index] # [batch_size * num_passages, doc_seq_len, dim]
|
||||
|
||||
# For ColBERT, we need pairwise interactions
|
||||
if multiple_passages:
|
||||
# Use only positive passages
|
||||
labels = batch.get('labels')
|
||||
if labels is not None:
|
||||
positive_indices = []
|
||||
for i in range(batch_size):
|
||||
query_labels = labels[i]
|
||||
positive_mask = query_labels > 0
|
||||
if positive_mask.any():
|
||||
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
|
||||
else:
|
||||
pos_idx = 0
|
||||
positive_indices.append(pos_idx + i * num_passages)
|
||||
|
||||
positive_colbert = passage_colbert[positive_indices]
|
||||
else:
|
||||
positive_colbert = passage_colbert[::num_passages]
|
||||
|
||||
# Binary labels for ColBERT (1 for positive pairs)
|
||||
colbert_labels = torch.ones(batch_size, device=query_colbert.device)
|
||||
else:
|
||||
positive_colbert = passage_colbert
|
||||
colbert_labels = torch.ones(batch_size, device=query_colbert.device)
|
||||
|
||||
colbert_loss = colbert_loss_fn(
|
||||
query_colbert,
|
||||
positive_colbert,
|
||||
colbert_labels
|
||||
)
|
||||
losses['colbert'] = colbert_loss
|
||||
|
||||
# Knowledge distillation loss (only if explicitly enabled)
|
||||
if (getattr(self.config, 'use_self_distill', False) and
|
||||
self.global_step >= getattr(self.config, 'self_distill_start_step', 0)):
|
||||
try:
|
||||
# Compute teacher scores (weighted combination)
|
||||
teacher_scores = self._compute_teacher_scores(query_outputs, passage_outputs) # type: ignore[arg-type]
|
||||
student_scores = self.model.compute_similarity(
|
||||
query_outputs, passage_outputs, mode='dense'
|
||||
)
|
||||
|
||||
kd_loss, _ = self.kd_loss(
|
||||
{'dense': student_scores},
|
||||
{'dense': teacher_scores}
|
||||
)
|
||||
losses['kd'] = kd_loss
|
||||
except Exception as e:
|
||||
logger.warning(f"Knowledge distillation failed: {e}")
|
||||
|
||||
# Combine losses
|
||||
total_loss, _ = self.multi_task_loss(losses)
|
||||
|
||||
return total_loss
|
||||
|
||||
def _compute_teacher_scores(
|
||||
self,
|
||||
query_outputs: Dict[str, torch.Tensor],
|
||||
passage_outputs: Dict[str, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
"""Compute teacher scores from multiple signals"""
|
||||
scores = []
|
||||
weights = []
|
||||
|
||||
# Get original batch size for reshaping
|
||||
query_batch_size = query_outputs['dense'].shape[0] if 'dense' in query_outputs else 1
|
||||
passage_batch_size = passage_outputs['dense'].shape[0] if 'dense' in passage_outputs else 1
|
||||
|
||||
# Check if we have multiple passages per query
|
||||
multiple_passages = passage_batch_size > query_batch_size
|
||||
|
||||
if multiple_passages:
|
||||
# For multiple passages, we need to compute scores only for positive pairs
|
||||
# Use only the first passage for each query as a simplified teacher score
|
||||
num_passages = passage_batch_size // query_batch_size
|
||||
|
||||
# Dense scores (use only first passage per query)
|
||||
if 'dense' in query_outputs and 'dense' in passage_outputs:
|
||||
query_dense = query_outputs['dense'] # [batch_size, dim]
|
||||
passage_dense = passage_outputs['dense'] # [batch_size * num_passages, dim]
|
||||
|
||||
# Extract first passage for each query
|
||||
positive_passages = passage_dense[::num_passages] # [batch_size, dim]
|
||||
|
||||
try:
|
||||
dense_scores = self.model.compute_similarity(
|
||||
{'dense': query_dense},
|
||||
{'dense': positive_passages},
|
||||
mode='dense'
|
||||
)
|
||||
scores.append(dense_scores)
|
||||
weights.append(self.config.dense_weight)
|
||||
except Exception as e:
|
||||
logger.warning(f"Dense similarity computation failed: {e}")
|
||||
|
||||
# Sparse scores (use only first passage per query)
|
||||
if 'sparse' in query_outputs and 'sparse' in passage_outputs:
|
||||
query_sparse = query_outputs['sparse'] # [batch_size, query_seq_len]
|
||||
passage_sparse = passage_outputs['sparse'] # [batch_size * num_passages, doc_seq_len]
|
||||
|
||||
# Extract first passage for each query
|
||||
positive_passages = passage_sparse[::num_passages] # [batch_size, doc_seq_len]
|
||||
|
||||
try:
|
||||
sparse_scores = self.model.compute_similarity(
|
||||
{'sparse': query_sparse},
|
||||
{'sparse': positive_passages},
|
||||
mode='sparse'
|
||||
)
|
||||
scores.append(sparse_scores)
|
||||
weights.append(self.config.sparse_weight)
|
||||
except Exception as e:
|
||||
logger.warning(f"Sparse similarity computation failed: {e}")
|
||||
|
||||
# ColBERT scores (use only first passage per query)
|
||||
if 'colbert' in query_outputs and 'colbert' in passage_outputs:
|
||||
query_colbert = query_outputs['colbert'] # [batch_size, query_seq_len, dim]
|
||||
passage_colbert = passage_outputs['colbert'] # [batch_size * num_passages, doc_seq_len, dim]
|
||||
|
||||
# Extract first passage for each query
|
||||
positive_passages = passage_colbert[::num_passages] # [batch_size, doc_seq_len, dim]
|
||||
|
||||
try:
|
||||
colbert_scores = self.model.compute_similarity(
|
||||
{'colbert': query_colbert},
|
||||
{'colbert': positive_passages},
|
||||
mode='colbert'
|
||||
)
|
||||
scores.append(colbert_scores)
|
||||
weights.append(self.config.colbert_weight)
|
||||
except Exception as e:
|
||||
logger.warning(f"ColBERT similarity computation failed: {e}")
|
||||
else:
|
||||
# Single passage per query - original logic
|
||||
# Dense scores
|
||||
if 'dense' in query_outputs:
|
||||
try:
|
||||
dense_scores = self.model.compute_similarity(
|
||||
{'dense': query_outputs['dense']},
|
||||
{'dense': passage_outputs['dense']},
|
||||
mode='dense'
|
||||
)
|
||||
scores.append(dense_scores)
|
||||
weights.append(self.config.dense_weight)
|
||||
except Exception as e:
|
||||
logger.warning(f"Dense similarity computation failed: {e}")
|
||||
|
||||
# Sparse scores
|
||||
if 'sparse' in query_outputs:
|
||||
try:
|
||||
sparse_scores = self.model.compute_similarity(
|
||||
{'sparse': query_outputs['sparse']},
|
||||
{'sparse': passage_outputs['sparse']},
|
||||
mode='sparse'
|
||||
)
|
||||
scores.append(sparse_scores)
|
||||
weights.append(self.config.sparse_weight)
|
||||
except Exception as e:
|
||||
logger.warning(f"Sparse similarity computation failed: {e}")
|
||||
|
||||
# ColBERT scores
|
||||
if 'colbert' in query_outputs:
|
||||
try:
|
||||
colbert_scores = self.model.compute_similarity(
|
||||
{'colbert': query_outputs['colbert']},
|
||||
{'colbert': passage_outputs['colbert']},
|
||||
mode='colbert'
|
||||
)
|
||||
scores.append(colbert_scores)
|
||||
weights.append(self.config.colbert_weight)
|
||||
except Exception as e:
|
||||
logger.warning(f"ColBERT similarity computation failed: {e}")
|
||||
|
||||
# Weighted combination
|
||||
if scores:
|
||||
weights = torch.tensor(weights, device=scores[0].device)
|
||||
weights = weights / weights.sum()
|
||||
teacher_scores = torch.stack([w * s for w, s in zip(weights, scores)]).sum(dim=0)
|
||||
return teacher_scores
|
||||
|
||||
# Fallback to zero scores if no valid similarity computation
|
||||
fallback_size = query_batch_size if not multiple_passages else query_batch_size
|
||||
return torch.zeros(fallback_size, device=query_outputs[list(query_outputs.keys())[0]].device)
|
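# Teacher score recap, as combined above:
#   s_teacher = (w_dense * s_dense + w_sparse * s_sparse + w_colbert * s_colbert)
#               / (w_dense + w_sparse + w_colbert)
# using only the signals whose similarity computation succeeded; if none succeeded,
# a zero tensor is returned as a fallback target.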
||||
|
||||
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
|
||||
"""Evaluate a single batch"""
|
||||
# Get query embeddings
|
||||
query_outputs = self.model.forward(
|
||||
input_ids=batch['query_input_ids'],
|
||||
attention_mask=batch['query_attention_mask'],
|
||||
return_dense=True,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Get passage embeddings
|
||||
passage_outputs = self.model.forward(
|
||||
input_ids=batch['passage_input_ids'],
|
||||
attention_mask=batch['passage_attention_mask'],
|
||||
return_dense=True,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Compute loss
|
||||
loss = self.compute_loss(batch)
|
||||
|
||||
# Compute similarity scores
|
||||
scores = self.model.compute_similarity(
|
||||
query_outputs, passage_outputs, mode='dense'
|
||||
)
|
||||
|
||||
# Basic metrics
|
||||
labels = batch.get('labels', torch.zeros(scores.shape[0]))
|
||||
if labels.dim() > 1:
|
||||
labels = labels.view(-1)
|
||||
|
||||
predictions = (scores > 0.5).float()
|
||||
accuracy = (predictions == labels.float()).float().mean()
|
||||
|
||||
return {
|
||||
'loss': loss.item(),
|
||||
'accuracy': accuracy.item(),
|
||||
'mean_score': scores.mean().item()
|
||||
}
|
||||
|
||||
def train(
|
||||
self,
|
||||
train_dataset: BGEM3Dataset,
|
||||
eval_dataset: Optional[BGEM3Dataset] = None,
|
||||
resume_from_checkpoint: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Main training loop for BGEM3Trainer with robust error handling, config-driven parameters, and progress tracking.
|
||||
"""
|
||||
try:
|
||||
train_dataloader = self._create_dataloader(train_dataset, is_train=True)
|
||||
eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None
|
||||
|
||||
# Calculate total training steps
|
||||
num_update_steps_per_epoch = max(1, len(train_dataloader) // self.config.gradient_accumulation_steps) # type: ignore[arg-type]
|
||||
total_training_steps = num_update_steps_per_epoch * self.config.num_train_epochs
|
||||
|
||||
# Create scheduler with total steps
|
||||
self.create_scheduler(total_training_steps)
|
||||
|
||||
if resume_from_checkpoint:
|
||||
self.load_checkpoint(resume_from_checkpoint)
|
||||
|
||||
logger.info(f"Starting training for {self.config.num_train_epochs} epochs")
|
||||
logger.info(f"Total training steps: {total_training_steps}")
|
||||
|
||||
# Use parent's train method
|
||||
super().train(train_dataloader, eval_dataloader, self.config.num_train_epochs) # type: ignore[arg-type]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"M3 training failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Cleanup
|
||||
if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
|
||||
self.hard_negative_miner.stop_async_updates()
|
||||
|
||||
def _create_dataloader(
|
||||
self,
|
||||
dataset: Optional[BGEM3Dataset],
|
||||
is_train: bool = True
|
||||
) -> Optional[DataLoader]:
|
||||
"""Create data loader"""
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
# Import both collate functions
|
||||
from data.dataset import collate_embedding_batch
|
||||
|
||||
# Choose appropriate collate function based on dataset type
|
||||
collate_fn = collate_embedding_batch
|
||||
num_workers = self.config.dataloader_num_workers
|
||||
|
||||
batch_size = self.config.per_device_train_batch_size if is_train else self.config.per_device_eval_batch_size
|
||||
|
||||
# Setup distributed sampler if needed
|
||||
sampler = None
|
||||
if self.is_distributed:
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
sampler = DistributedSampler(dataset, shuffle=is_train)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=is_train and sampler is None,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
pin_memory=self.config.dataloader_pin_memory,
|
||||
drop_last=is_train and len(dataset) > batch_size, # Only drop last if we have multiple batches
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
|
||||
return dataloader
|
||||
|
||||
def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
|
||||
"""Train one epoch - use parent's method with hard negative mining integration"""
|
||||
# Add hard negative mining updates before calling parent's method
|
||||
if self.hard_negative_miner:
|
||||
# Set up a hook to update hard negatives during training
|
||||
def mining_hook():
|
||||
if self.hard_negative_miner:
|
||||
self.hard_negative_miner.update_index_if_needed(self.global_step, self.model)
|
||||
|
||||
# Store the hook for use in compute_loss
|
||||
self._mining_hook = mining_hook
|
||||
|
||||
# Use parent's robust training loop
|
||||
return super().train_epoch(dataloader, epoch)
|
||||
|
||||
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
|
||||
"""Evaluate the model"""
|
||||
self.model.eval()
|
||||
|
||||
all_query_embeddings = []
|
||||
all_passage_embeddings = []
|
||||
all_labels = []
|
||||
|
||||
with torch.no_grad():
|
||||
with tqdm(dataloader, desc="Evaluating") as tbar:
|
||||
for batch in tbar:
|
||||
try:
|
||||
batch = self._prepare_batch(batch)
|
||||
|
||||
query_outputs = self.model.forward(
|
||||
input_ids=batch['query_input_ids'],
|
||||
attention_mask=batch['query_attention_mask'],
|
||||
return_dense=True,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
passage_outputs = self.model.forward(
|
||||
input_ids=batch['passage_input_ids'],
|
||||
attention_mask=batch['passage_attention_mask'],
|
||||
return_dense=True,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
all_query_embeddings.append(query_outputs['dense'].cpu()) # type: ignore[index]
|
||||
all_passage_embeddings.append(passage_outputs['dense'].cpu()) # type: ignore[index]
|
||||
all_labels.append(batch['labels'].cpu())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in M3 evaluation step: {e}")
|
||||
continue
|
||||
|
||||
if not all_query_embeddings:
|
||||
logger.warning("No valid evaluation batches processed")
|
||||
return {'loss': float('inf'), 'accuracy': 0.0}
|
||||
|
||||
# Concatenate all embeddings
|
||||
all_query_embeddings = torch.cat(all_query_embeddings, dim=0)
|
||||
all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0)
|
||||
all_labels = torch.cat(all_labels, dim=0)
|
||||
|
||||
# Compute metrics
|
||||
try:
|
||||
metrics = compute_retrieval_metrics(
|
||||
all_query_embeddings.numpy(),
|
||||
all_passage_embeddings.numpy(),
|
||||
all_labels.numpy(),
|
||||
k_values=[1, 3, 5, 10]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error computing retrieval metrics: {e}")
|
||||
metrics = {'loss': float('inf'), 'accuracy': 0.0}
|
||||
|
||||
return metrics
|
||||
|
||||
def save_model(self, save_path: str):
|
||||
"""Save the trained model"""
|
||||
try:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
# Save model
|
||||
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
|
||||
model_to_save.save_pretrained(save_path) # type: ignore[attr-defined]
|
||||
|
||||
# Save tokenizer
|
||||
self.tokenizer.save_pretrained(save_path) # type: ignore[attr-defined]
|
||||
|
||||
# Save config
|
||||
self.config.save(os.path.join(save_path, 'training_config.json'))
|
||||
|
||||
logger.info(f"Model saved to {save_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model: {e}")
|
||||
raise
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup when trainer is destroyed"""
|
||||
if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
|
||||
self.hard_negative_miner.stop_async_updates()
|
||||
489
training/reranker_trainer.py
Normal file
489
training/reranker_trainer.py
Normal file
@ -0,0 +1,489 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import AdamW
|
||||
from typing import Dict, Optional, List, Tuple, Any
|
||||
import logging
|
||||
import numpy as np
|
||||
from torch.nn import functional as F
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
from .trainer import BaseTrainer
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from models.losses import ListwiseCrossEntropyLoss, PairwiseMarginLoss
|
||||
from data.dataset import BGERerankerDataset
|
||||
from config.reranker_config import BGERerankerConfig
|
||||
from evaluation.metrics import compute_reranker_metrics
|
||||
from utils.distributed import is_main_process
|
||||
|
||||
# AMP imports
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.cuda.amp.grad_scaler import GradScaler
|
||||
|
||||
# Scheduler import
|
||||
try:
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
except ImportError:
|
||||
get_linear_schedule_with_warmup = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BGERerankerTrainer(BaseTrainer):
|
||||
"""
|
||||
Trainer for BGE-Reranker cross-encoder model
|
||||
|
||||
Notes:
|
||||
- All parameters (config, model, tokenizer, optimizer, scheduler, device, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BGERerankerConfig,
|
||||
model: Optional[BGERerankerModel] = None,
|
||||
tokenizer=None,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
device: Optional[torch.device] = None
|
||||
):
|
||||
# Initialize model if not provided
|
||||
if model is None:
|
||||
model = BGERerankerModel(
|
||||
model_name_or_path=config.model_name_or_path,
|
||||
num_labels=config.num_labels,
|
||||
cache_dir=config.cache_dir,
|
||||
use_fp16=config.fp16,
|
||||
gradient_checkpointing=config.gradient_checkpointing
|
||||
)
|
||||
|
||||
# Initialize optimizer if not provided
|
||||
if optimizer is None:
|
||||
optimizer = self._create_optimizer(model, config)
|
||||
|
||||
# Initialize parent class
|
||||
super().__init__(config, model, optimizer, scheduler, device)
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.config = config
|
||||
|
||||
# Initialize losses based on config
|
||||
if config.loss_type == "listwise_ce":
|
||||
self.loss_fn = ListwiseCrossEntropyLoss(
|
||||
temperature=getattr(config, 'temperature', 1.0)
|
||||
)
|
||||
elif config.loss_type == "pairwise_margin":
|
||||
self.loss_fn = PairwiseMarginLoss(margin=config.margin)
|
||||
elif config.loss_type == "combined":
|
||||
self.listwise_loss = ListwiseCrossEntropyLoss(
|
||||
temperature=getattr(config, 'temperature', 1.0)
|
||||
)
|
||||
self.pairwise_loss = PairwiseMarginLoss(margin=config.margin)
|
||||
else:
|
||||
self.loss_fn = nn.BCEWithLogitsLoss()
|
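# Summary of the mapping above:
#   loss_type == "listwise_ce"     -> ListwiseCrossEntropyLoss over score groups
#   loss_type == "pairwise_margin" -> PairwiseMarginLoss with config.margin
#   loss_type == "combined"        -> weighted sum of both (see compute_loss)
#   anything else                  -> plain BCEWithLogitsLoss on per-pair logits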
||||
|
||||
# Knowledge distillation setup
|
||||
self.use_kd = config.use_knowledge_distillation
|
||||
if self.use_kd:
|
||||
self.kd_loss_fn = nn.KLDivLoss(reduction='batchmean')
|
||||
self.kd_temperature = config.distillation_temperature
|
||||
self.kd_loss_weight = config.distillation_alpha
|
||||
|
||||
# Hard negative configuration
|
||||
self.hard_negative_ratio = config.hard_negative_ratio
|
||||
self.use_denoising = config.use_denoising
|
||||
self.denoise_threshold = getattr(config, 'denoise_confidence_threshold', 0.5)
|
||||
|
||||
def _create_optimizer(self, model: nn.Module, config: BGERerankerConfig) -> torch.optim.Optimizer:
|
||||
"""Create optimizer with proper parameter groups"""
|
||||
# Separate parameters
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
'params': [p for n, p in model.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)],
|
||||
'weight_decay': config.weight_decay,
|
||||
},
|
||||
{
|
||||
'params': [p for n, p in model.named_parameters()
|
||||
if any(nd in n for nd in no_decay)],
|
||||
'weight_decay': 0.0,
|
||||
}
|
||||
]
|
||||
|
||||
optimizer = AdamW(
|
||||
optimizer_grouped_parameters,
|
||||
lr=config.learning_rate,
|
||||
eps=config.adam_epsilon,
|
||||
betas=(config.adam_beta1, config.adam_beta2)
|
||||
)
|
||||
|
||||
return optimizer
|
||||
|
||||
def create_scheduler(self, num_training_steps: int) -> Optional[object]:
|
||||
"""Create learning rate scheduler"""
|
||||
if get_linear_schedule_with_warmup is None:
|
||||
logger.warning("transformers not available, using constant learning rate")
|
||||
return None
|
||||
|
||||
num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)
|
||||
|
||||
self.scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps
|
||||
)
|
||||
return self.scheduler
|
||||
|
||||
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Compute reranker loss"""
|
||||
# Get model outputs
|
||||
outputs = self.model.forward(
|
||||
input_ids=batch['input_ids'],
|
||||
attention_mask=batch['attention_mask'],
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
if isinstance(outputs, dict):
|
||||
logits = outputs['logits']
|
||||
else:
|
||||
logits = outputs.scores
|
||||
|
||||
labels = batch['labels']
|
||||
|
||||
# Handle padding labels
|
||||
if labels.dim() > 1:
|
||||
# Remove padding (-1 labels)
|
||||
valid_mask = labels != -1
|
||||
if valid_mask.any():
|
||||
logits = logits[valid_mask]
|
||||
labels = labels[valid_mask]
|
||||
else:
|
||||
# If all labels are padding, return zero loss
|
||||
return torch.tensor(0.0, device=logits.device, requires_grad=True)
|
||||
|
||||
# Handle different loss types
|
||||
if self.config.loss_type == "listwise_ce":
|
||||
# Listwise loss expects grouped format
|
||||
try:
|
||||
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
|
||||
if groups:
|
||||
loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups)
|
||||
else:
|
||||
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
|
||||
except Exception as e:
|
||||
# Fallback to BCE with specific error logging
|
||||
logger.warning(f"Listwise CE loss failed: {e}. Falling back to BCE.")
|
||||
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
|
||||
|
||||
elif self.config.loss_type == "pairwise_margin":
|
||||
# Extract positive and negative scores
|
||||
pos_mask = labels == 1
|
||||
neg_mask = labels == 0
|
||||
|
||||
if pos_mask.any() and neg_mask.any():
|
||||
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
|
||||
try:
|
||||
loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups)
|
||||
except Exception as e:
|
||||
# Fallback to BCE with specific error logging
|
||||
logger.warning(f"Pairwise margin loss failed: {e}. Falling back to BCE.")
|
||||
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
|
||||
else:
|
||||
# Fallback to BCE if no proper pairs
|
||||
logger.warning("No positive-negative pairs found for pairwise loss. Using BCE.")
|
||||
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
|
||||
|
||||
elif self.config.loss_type == "combined":
|
||||
# Combine listwise and pairwise losses
|
||||
try:
|
||||
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
|
||||
if groups:
|
||||
listwise_loss = self.listwise_loss(logits.squeeze(-1), labels.float(), groups)
|
||||
pairwise_loss = self.pairwise_loss(logits.squeeze(-1), labels.float(), groups)
|
||||
loss = self.config.listwise_weight * listwise_loss + self.config.pairwise_weight * pairwise_loss
|
||||
else:
|
||||
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
|
||||
except Exception as e:
|
||||
# Fallback to BCE with specific error logging
|
||||
logger.warning(f"Combined loss computation failed: {e}. Falling back to BCE.")
|
||||
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
|
||||
|
||||
else:
|
||||
# Default binary cross-entropy
|
||||
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
|
||||
|
||||
# Add knowledge distillation loss if enabled
|
||||
if self.use_kd and 'teacher_scores' in batch:
|
||||
teacher_scores = batch['teacher_scores']
|
||||
student_probs = F.log_softmax(logits / self.kd_temperature, dim=-1)
|
||||
teacher_probs = F.softmax(teacher_scores / self.kd_temperature, dim=-1)
|
||||
|
||||
kd_loss = self.kd_loss_fn(student_probs, teacher_probs) * (self.kd_temperature ** 2)
|
||||
loss = (1 - self.kd_loss_weight) * loss + self.kd_loss_weight * kd_loss
|
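# Knowledge-distillation blend, as computed above:
#   L = (1 - alpha) * L_task + alpha * T^2 * KL(softmax(teacher/T) || softmax(student/T))
# with alpha = config.distillation_alpha and T = config.distillation_temperature.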
||||
|
||||
# Apply denoising if enabled
|
||||
if self.use_denoising and 'teacher_scores' in batch:
|
||||
# Denoise by confidence thresholding
|
||||
confidence_scores = torch.sigmoid(batch['teacher_scores'])
|
||||
confidence_mask = confidence_scores > self.denoise_threshold
|
||||
if confidence_mask.any():
|
||||
loss = loss * confidence_mask.float().mean()
|
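# Denoising note: rather than dropping low-confidence pairs, the batch loss is
# scaled by the fraction of examples whose teacher confidence exceeds
# denoise_confidence_threshold, so noisier batches contribute proportionally less.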
||||
|
||||
return loss
|
||||
|
||||
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
|
||||
"""Evaluate a single batch"""
|
||||
with torch.no_grad():
|
||||
# Get model outputs
|
||||
outputs = self.model.forward(
|
||||
input_ids=batch['input_ids'],
|
||||
attention_mask=batch['attention_mask'],
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
if isinstance(outputs, dict):
|
||||
logits = outputs['logits']
|
||||
else:
|
||||
logits = outputs.scores
|
||||
|
||||
labels = batch['labels']
|
||||
|
||||
# Handle padding labels
|
||||
if labels.dim() > 1:
|
||||
valid_mask = labels != -1
|
||||
if valid_mask.any():
|
||||
logits = logits[valid_mask]
|
||||
labels = labels[valid_mask]
|
||||
else:
|
||||
return {'loss': 0.0, 'accuracy': 0.0}
|
||||
|
||||
# Compute loss
|
||||
loss = self.compute_loss(batch)
|
||||
|
||||
# Compute metrics based on group size
|
||||
if hasattr(self.config, 'train_group_size') and self.config.train_group_size > 1:
|
||||
# Listwise evaluation
|
||||
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
|
||||
if groups:
|
||||
try:
|
||||
metrics = compute_reranker_metrics(
|
||||
logits.squeeze(-1).cpu().numpy(),
|
||||
labels.cpu().numpy(),
|
||||
groups=groups,
|
||||
k_values=[1, 3, 5, 10]
|
||||
)
|
||||
metrics['loss'] = loss.item()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error computing reranker metrics: {e}")
|
||||
metrics = {'loss': loss.item(), 'accuracy': 0.0}
|
||||
else:
|
||||
metrics = {'loss': loss.item(), 'accuracy': 0.0}
|
||||
else:
|
||||
# Binary evaluation
|
||||
predictions = (torch.sigmoid(logits.squeeze(-1)) > 0.5).float()
|
||||
accuracy = (predictions == labels.float()).float().mean()
|
||||
|
||||
metrics = {
|
||||
'accuracy': accuracy.item(),
|
||||
'loss': loss.item()
|
||||
}
|
||||
|
||||
# Add AUC if we have both positive and negative examples
|
||||
if len(torch.unique(labels)) > 1:
|
||||
try:
|
||||
from sklearn.metrics import roc_auc_score
|
||||
auc = roc_auc_score(labels.cpu().numpy(), torch.sigmoid(logits.squeeze(-1)).cpu().numpy())
|
||||
metrics['auc'] = float(auc)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error computing AUC: {e}")
|
||||
|
||||
return metrics
|
||||
|
||||
def train_with_hard_negative_mining(
|
||||
self,
|
||||
train_dataset: BGERerankerDataset,
|
||||
eval_dataset: Optional[BGERerankerDataset] = None,
|
||||
retriever_model: Optional[nn.Module] = None
|
||||
):
|
||||
"""
|
||||
Train with periodic hard negative mining from retriever
|
||||
"""
|
||||
if retriever_model is None:
|
||||
logger.warning("No retriever model provided for hard negative mining. Using standard training.")
|
||||
train_dataloader = self._create_dataloader(train_dataset, is_train=True)
|
||||
eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None
|
||||
|
||||
# Only proceed if we have a valid training dataloader
|
||||
if train_dataloader is not None:
|
||||
return super().train(train_dataloader, eval_dataloader)
|
||||
else:
|
||||
logger.error("No valid training dataloader created")
|
||||
return
|
||||
|
||||
# Training loop with hard negative updates
|
||||
for epoch in range(self.config.num_train_epochs):
|
||||
# Mine hard negatives at the beginning of each epoch
|
||||
if epoch > 0:
|
||||
logger.info(f"Mining hard negatives for epoch {epoch + 1}")
|
||||
self._mine_hard_negatives(train_dataset, retriever_model)
|
||||
|
||||
# Create data loader with new negatives
|
||||
train_dataloader = self._create_dataloader(train_dataset, is_train=True)
|
||||
eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None
|
||||
|
||||
# Only proceed if we have a valid training dataloader
|
||||
if train_dataloader is not None:
|
||||
# Train one epoch
|
||||
super().train(train_dataloader, eval_dataloader, num_epochs=1)
|
||||
else:
|
||||
logger.error(f"No valid training dataloader for epoch {epoch + 1}")
|
||||
continue
|
||||
|
||||
# Update current epoch
|
||||
self.current_epoch = epoch
|
||||
|
||||
def _create_dataloader(self, dataset, is_train: bool = True) -> Optional[DataLoader]:
|
||||
"""Create data loader"""
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
from data.dataset import collate_reranker_batch
|
||||
|
||||
batch_size = self.config.per_device_train_batch_size if is_train else self.config.per_device_eval_batch_size
|
||||
|
||||
# Setup distributed sampler if needed
|
||||
sampler = None
|
||||
if self.is_distributed:
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
sampler = DistributedSampler(dataset, shuffle=is_train)
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=is_train and sampler is None,
|
||||
sampler=sampler,
|
||||
num_workers=self.config.dataloader_num_workers,
|
||||
pin_memory=self.config.dataloader_pin_memory,
|
||||
drop_last=is_train,
|
||||
collate_fn=collate_reranker_batch
|
||||
)
|
||||
|
||||
def _mine_hard_negatives(
|
||||
self,
|
||||
dataset: BGERerankerDataset,
|
||||
retriever_model: Any
|
||||
):
|
||||
"""Mine hard negatives using the retriever model and update the dataset's negatives."""
|
||||
logger.info("Starting hard negative mining using retriever model...")
|
||||
num_hard_negatives = getattr(self.config, 'num_hard_negatives', 10)
|
||||
retriever_batch_size = getattr(self.config, 'retriever_batch_size', 32)
|
||||
mined_count = 0
|
||||
|
||||
# Collect all unique candidate passages from the dataset
|
||||
all_passages = set()
|
||||
for item in dataset.data:
|
||||
all_passages.update(item.get('pos', []))
|
||||
all_passages.update(item.get('neg', []))
|
||||
all_passages = list(all_passages)
|
||||
|
||||
if not all_passages:
|
||||
logger.warning("No candidate passages found in dataset for hard negative mining.")
|
||||
return
|
||||
|
||||
# Encode all candidate passages
|
||||
try:
|
||||
passage_embeddings = retriever_model.encode(
|
||||
all_passages,
|
||||
batch_size=retriever_batch_size,
|
||||
convert_to_numpy=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error encoding passages for hard negative mining: {e}")
|
||||
return
|
||||
|
||||
# For each query, mine hard negatives
|
||||
for idx in range(len(dataset)):
|
||||
item = dataset.data[idx]
|
||||
query = item['query']
|
||||
positives = set(item.get('pos', []))
|
||||
existing_negs = set(item.get('neg', []))
|
||||
|
||||
try:
|
||||
query_emb = retriever_model.encode([query], batch_size=1, convert_to_numpy=True)[0]
|
||||
|
||||
# Compute similarity
|
||||
scores = np.dot(passage_embeddings, query_emb)
|
||||
|
||||
# Rank passages by similarity
|
||||
ranked_indices = np.argsort(scores)[::-1]
|
||||
|
||||
new_negs = []
|
||||
for i in ranked_indices:
|
||||
passage = all_passages[int(i)] # Convert to int for indexing
|
||||
if passage not in positives and passage not in existing_negs:
|
||||
new_negs.append(passage)
|
||||
if len(new_negs) >= num_hard_negatives:
|
||||
break
|
||||
|
||||
if new_negs:
|
||||
item['neg'] = list(existing_negs) + new_negs
|
||||
mined_count += 1
|
||||
logger.debug(f"Query idx {idx}: Added {len(new_negs)} hard negatives.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error mining hard negatives for query idx {idx}: {e}")
|
||||
|
||||
logger.info(f"Hard negative mining complete. Updated negatives for {mined_count} queries.")
|
||||
|
||||
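# The mining loop above is plain dense-retrieval ranking: dot-product similarity of
# each candidate passage against the query embedding, sorted descending, with known
# positives and existing negatives filtered out. A tiny sketch with made-up shapes
# (real embeddings come from retriever_model.encode; known_positive_ids is hypothetical):
#
#   import numpy as np
#   passage_embeddings = np.random.rand(100, 768)   # 100 candidate passages
#   query_emb = np.random.rand(768)
#   scores = passage_embeddings @ query_emb
#   ranked = np.argsort(scores)[::-1]               # most similar (hardest) first
#   hard_negatives = [int(i) for i in ranked if int(i) not in known_positive_ids][:10]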
def save_model(self, save_path: str):
|
||||
"""Save the trained model"""
|
||||
try:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
# Save model
|
||||
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
|
||||
if hasattr(model_to_save, 'save_pretrained'):
|
||||
model_to_save.save_pretrained(save_path) # type: ignore
|
||||
else:
|
||||
# Fallback to torch.save
|
||||
torch.save(model_to_save.state_dict(), os.path.join(save_path, 'pytorch_model.bin')) # type: ignore
|
||||
|
||||
# Save tokenizer if available
|
||||
if self.tokenizer and hasattr(self.tokenizer, 'save_pretrained'):
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
|
||||
# Save config
|
||||
if hasattr(self.config, 'save'):
|
||||
self.config.save(os.path.join(save_path, 'training_config.json'))
|
||||
|
||||
logger.info(f"Model saved to {save_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model: {e}")
|
||||
raise
|
||||
|
||||
def export_for_inference(self, output_path: str):
|
||||
"""Export model for efficient inference"""
|
||||
if is_main_process():
|
||||
logger.info(f"Exporting model for inference to {output_path}")
|
||||
|
||||
# Save model in inference mode
|
||||
self.model.eval()
|
||||
|
||||
# Save model state
|
||||
torch.save({
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'config': self.config.to_dict(),
|
||||
'model_type': 'bge_reranker'
|
||||
}, output_path)
|
||||
|
||||
# Also save in HuggingFace format if possible
|
||||
try:
|
||||
hf_path = output_path.replace('.pt', '_hf')
|
||||
self.save_model(hf_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save in HuggingFace format: {e}")
|
||||
|
||||
logger.info(f"Model exported successfully")
|
||||
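# A sketch of consuming the file written by export_for_inference: the keys mirror
# the torch.save call above; the file name here is hypothetical.
#
#   import torch
#   ckpt = torch.load("exported_reranker.pt", map_location="cpu")
#   state_dict = ckpt["model_state_dict"]
#   config_dict = ckpt["config"]
#   assert ckpt["model_type"] == "bge_reranker"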
534
training/trainer.py
Normal file
@ -0,0 +1,534 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from typing import Dict, Optional, List, Tuple, Any, Union, cast
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from abc import ABC, abstractmethod
|
||||
import time
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.cuda.amp.grad_scaler import GradScaler
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel as DDP
|
||||
from torch.nn.utils.clip_grad import clip_grad_norm_
|
||||
from collections import defaultdict
|
||||
|
||||
from utils.checkpoint import CheckpointManager
|
||||
from utils.logging import MetricsLogger, TensorBoardLogger, ProgressLogger
|
||||
from utils.distributed import (
|
||||
setup_distributed, cleanup_distributed, is_main_process,
|
||||
reduce_dict, synchronize, get_rank, get_world_size
|
||||
)
|
||||
from config.base_config import BaseConfig
|
||||
|
||||
try:
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
except ImportError:
|
||||
get_linear_schedule_with_warmup = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseTrainer(ABC):
|
||||
"""
|
||||
Abstract base trainer class for BGE models
|
||||
|
||||
Notes:
|
||||
- All parameters (config, model, optimizer, scheduler, device, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseConfig,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
scheduler: Optional[_LRScheduler] = None,
|
||||
device: Optional[torch.device] = None
|
||||
):
|
||||
self.config = config
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.scheduler = scheduler
|
||||
|
||||
# Setup device
|
||||
if device is None:
|
||||
self.device = torch.device(config.device)
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
# Move model to device
|
||||
self.model.to(self.device)
|
||||
|
||||
# Setup distributed training if needed
|
||||
self.is_distributed = config.local_rank != -1
|
||||
if self.is_distributed:
|
||||
setup_distributed()
|
||||
self.model = self._setup_distributed_model()
|
||||
|
||||
# Setup mixed precision training
|
||||
self.use_amp = config.fp16 or config.bf16
|
||||
if self.use_amp and config.bf16:
|
||||
# Use bfloat16 for better numerical stability on RTX5090/A100
|
||||
self.amp_dtype = torch.bfloat16
|
||||
else:
|
||||
self.amp_dtype = torch.float16
|
||||
self.scaler = GradScaler() if self.use_amp and not config.bf16 else None
|
||||
|
||||
# Memory optimization settings
|
||||
self.empty_cache_steps = getattr(config, 'empty_cache_steps', 100)
|
||||
self.gradient_checkpointing = getattr(config, 'gradient_checkpointing', False)
|
||||
|
||||
# Setup logging
|
||||
self.metrics_logger = MetricsLogger(config.output_dir) if is_main_process() else None
|
||||
self.tb_logger = TensorBoardLogger(
|
||||
os.path.join(config.logging_dir, 'tensorboard'),
|
||||
enabled=is_main_process()
|
||||
)
|
||||
self.progress_logger = None # Will be initialized in train()
|
||||
|
||||
# Setup checkpointing
|
||||
self.checkpoint_manager = CheckpointManager(
|
||||
output_dir=config.output_dir,
|
||||
save_total_limit=config.save_total_limit
|
||||
) if is_main_process() else None
|
||||
|
||||
# Training state
|
||||
self.global_step = 0
|
||||
self.current_epoch = 0
|
||||
self.best_metric = float('-inf') if config.greater_is_better else float('inf')
|
||||
self.training_history = []
|
||||
|
||||
def _setup_distributed_model(self) -> nn.Module:
|
||||
"""Setup model for distributed training"""
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
model = DDP(
|
||||
self.model,
|
||||
device_ids=[self.config.local_rank],
|
||||
output_device=self.config.local_rank,
|
||||
find_unused_parameters=True
|
||||
)
|
||||
return model
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Compute loss for a batch - to be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
|
||||
"""Evaluate a single batch - to be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
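# A minimal sketch of what a concrete subclass has to supply: compute_loss and
# evaluate_step, both taking an already device-placed batch dict. The field names
# and the bare BCE loss below are hypothetical, not the project's actual trainers.
#
#   class ToyTrainer(BaseTrainer):
#       def compute_loss(self, batch):
#           logits = self.model(batch["input_ids"], batch["attention_mask"]).squeeze(-1)
#           return nn.functional.binary_cross_entropy_with_logits(logits, batch["labels"].float())
#
#       def evaluate_step(self, batch):
#           loss = self.compute_loss(batch)
#           return {"loss": loss.item()}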
def train(
|
||||
self,
|
||||
train_dataloader: DataLoader,
|
||||
eval_dataloader: Optional[DataLoader] = None,
|
||||
num_epochs: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Main training loop with robust error handling, config-driven parameters, and progress tracking.
|
||||
"""
|
||||
try:
|
||||
num_epochs = num_epochs or self.config.num_train_epochs
|
||||
total_steps = len(train_dataloader) * num_epochs
|
||||
|
||||
self.progress_logger = ProgressLogger(
|
||||
total_steps=total_steps,
|
||||
log_interval=self.config.logging_steps
|
||||
)
|
||||
|
||||
if is_main_process():
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataloader.dataset)}") # type: ignore[attr-defined]
|
||||
logger.info(f" Num epochs = {num_epochs}")
|
||||
logger.info(f" Batch size = {self.config.per_device_train_batch_size}")
|
||||
logger.info(f" Gradient accumulation steps = {self.config.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {total_steps}")
|
||||
logger.info(f" Device = {self.device}")
|
||||
if self.is_distributed:
|
||||
logger.info(f" Distributed training with {get_world_size()} GPUs")
|
||||
|
||||
for epoch in range(self.current_epoch, num_epochs):
|
||||
self.current_epoch = epoch
|
||||
epoch_start_time = time.time()
|
||||
|
||||
train_metrics = self.train_epoch(train_dataloader, epoch)
|
||||
|
||||
# Add metrics to training history
|
||||
epoch_metrics = {
|
||||
'epoch': epoch + 1,
|
||||
'train_metrics': train_metrics,
|
||||
'global_step': self.global_step,
|
||||
'learning_rate': self.optimizer.param_groups[0]['lr'] if self.optimizer else 0,
|
||||
}
|
||||
|
||||
if eval_dataloader is not None and (epoch + 1) % self.config.eval_steps == 0:
|
||||
eval_metrics = self.evaluate(eval_dataloader)
|
||||
epoch_metrics['eval_metrics'] = eval_metrics
|
||||
current_metric = eval_metrics.get(self.config.metric_for_best_model, 0)
|
||||
is_best = self._is_better_metric(current_metric)
|
||||
if is_best:
|
||||
self.best_metric = current_metric
|
||||
if is_main_process() and self.checkpoint_manager:
|
||||
self._save_checkpoint(is_best=is_best, metrics=eval_metrics)
|
||||
if is_main_process():
|
||||
logger.info(f"Epoch {epoch + 1} evaluation: {eval_metrics}")
|
||||
|
||||
# Add to training history
|
||||
self.training_history.append(epoch_metrics)
|
||||
|
||||
if is_main_process() and (epoch + 1) % self.config.save_steps == 0:
|
||||
self._save_checkpoint(metrics=train_metrics)
|
||||
|
||||
epoch_time = time.time() - epoch_start_time
|
||||
if is_main_process():
|
||||
logger.info(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds")
|
||||
logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")
|
||||
|
||||
if is_main_process():
|
||||
self._save_checkpoint(metrics=train_metrics)
|
||||
self._save_training_summary()
|
||||
# Close TensorBoard logger
|
||||
self.tb_logger.close()
|
||||
|
||||
if self.is_distributed:
|
||||
cleanup_distributed()
|
||||
|
||||
return self.training_history # type: ignore[return-value]
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed: {e}")
|
||||
raise
|
||||
|
||||
def train_epoch(
|
||||
self,
|
||||
dataloader: DataLoader,
|
||||
epoch: int
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Train for one epoch with comprehensive error handling and recovery
|
||||
|
||||
Args:
|
||||
dataloader: Training data loader
|
||||
epoch: Current epoch number
|
||||
|
||||
Returns:
|
||||
Dictionary of training metrics for the epoch
|
||||
|
||||
Raises:
|
||||
RuntimeError: If training fails and cannot recover
|
||||
"""
|
||||
self.model.train()
|
||||
epoch_loss = 0.0
|
||||
epoch_metrics = defaultdict(float)
|
||||
num_batches = len(dataloader)
|
||||
|
||||
# Track failed batches for recovery
|
||||
failed_batches = []
|
||||
max_failures = 3
|
||||
|
||||
# Training loop with error recovery
|
||||
progress_bar = tqdm(
|
||||
dataloader,
|
||||
desc=f"Epoch {epoch + 1}",
|
||||
disable=not is_main_process()
|
||||
)
|
||||
|
||||
for step, batch in enumerate(progress_bar):
|
||||
try:
|
||||
# Move batch to device with error handling
|
||||
try:
|
||||
batch = self._prepare_batch(batch)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to prepare batch {step}: {e}")
|
||||
failed_batches.append((step, 'prepare', str(e)))
|
||||
if len(failed_batches) > max_failures:
|
||||
raise RuntimeError(f"Too many batch preparation failures: {len(failed_batches)}")
|
||||
continue
|
||||
|
||||
# Compute loss with error recovery
|
||||
loss: torch.Tensor
|
||||
try:
|
||||
if self.use_amp:
|
||||
with autocast(enabled=True, dtype=self.amp_dtype):
|
||||
loss = self.compute_loss(batch)
|
||||
loss = cast(torch.Tensor, loss)
|
||||
else:
|
||||
loss = self.compute_loss(batch)
|
||||
loss = cast(torch.Tensor, loss)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to compute loss for batch {step}: {e}")
|
||||
failed_batches.append((step, 'forward', str(e)))
|
||||
if len(failed_batches) > max_failures:
|
||||
raise RuntimeError(f"Too many forward pass failures: {len(failed_batches)}")
|
||||
continue
|
||||
|
||||
# Scale loss for gradient accumulation
|
||||
loss = loss / self.config.gradient_accumulation_steps
|
||||
|
||||
# Backward pass with error handling
|
||||
try:
|
||||
if self.scaler:
|
||||
self.scaler.scale(loss).backward() # type: ignore[attr-defined]
|
||||
else:
|
||||
loss.backward() # type: ignore[attr-defined]
|
||||
except Exception as e:
|
||||
logger.warning(f"Backward pass failed for batch {step}: {e}")
|
||||
# Clear gradients and continue
|
||||
self.optimizer.zero_grad()
|
||||
failed_batches.append((step, 'backward', str(e)))
|
||||
if len(failed_batches) > max_failures:
|
||||
raise RuntimeError(f"Too many backward pass failures: {len(failed_batches)}")
|
||||
continue
|
||||
|
||||
# Update weights with error handling
|
||||
if (step + 1) % self.config.gradient_accumulation_steps == 0:
|
||||
try:
|
||||
# Gradient clipping
|
||||
if self.scaler:
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
|
||||
clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.config.max_grad_norm
|
||||
)
|
||||
|
||||
# Optimizer step
|
||||
if self.scaler:
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
|
||||
# Scheduler step
|
||||
if self.scheduler:
|
||||
self.scheduler.step()
|
||||
|
||||
# Update global step
|
||||
self.global_step += 1
|
||||
|
||||
# Clear cache periodically for memory efficiency
|
||||
if self.empty_cache_steps > 0 and self.global_step % self.empty_cache_steps == 0:
|
||||
self._clear_device_cache()
|
||||
|
||||
# Zero gradients
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Optimizer update failed at step {step}: {e}")
|
||||
# Reset optimizer state
|
||||
self.optimizer.zero_grad()
|
||||
if self.scaler:
|
||||
self.scaler.update()
|
||||
failed_batches.append((step, 'optimizer', str(e)))
|
||||
continue
|
||||
|
||||
# Logging with error handling
|
||||
if self.global_step % self.config.logging_steps == 0:
|
||||
try:
|
||||
current_loss = loss.item() * self.config.gradient_accumulation_steps
|
||||
epoch_loss += current_loss
|
||||
|
||||
# Update progress bar
|
||||
progress_bar.set_postfix({
|
||||
'loss': f'{current_loss:.4f}',
|
||||
'lr': f'{self.scheduler.get_last_lr()[0]:.2e}' if self.scheduler else 'N/A',
|
||||
'failures': len(failed_batches)
|
||||
})
|
||||
|
||||
# Log to TensorBoard
|
||||
if self.tb_logger:
|
||||
self.tb_logger.log_scalar('train/loss', current_loss, self.global_step)
|
||||
self.tb_logger.log_scalar('train/learning_rate',
|
||||
self.scheduler.get_last_lr()[0] if self.scheduler else self.optimizer.param_groups[0]['lr'],
|
||||
self.global_step)
|
||||
except Exception as e:
|
||||
logger.debug(f"Logging failed: {e}")
|
||||
|
||||
# Checkpoint saving with error handling
|
||||
if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0 and is_main_process():
|
||||
try:
|
||||
if self.checkpoint_manager:
|
||||
state_dict = {
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
|
||||
'global_step': self.global_step,
|
||||
'current_epoch': epoch,
|
||||
'config': vars(self.config)
|
||||
}
|
||||
self.checkpoint_manager.save(state_dict, self.global_step)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save checkpoint at step {self.global_step}: {e}")
|
||||
# Continue training even if checkpoint fails
|
||||
|
||||
except Exception as e:
|
||||
# Catch-all for any unexpected errors
|
||||
logger.error(f"Unexpected error in training step {step}: {e}")
|
||||
failed_batches.append((step, 'unknown', str(e)))
|
||||
if len(failed_batches) > max_failures * 2:
|
||||
raise RuntimeError(f"Training failed with too many errors: {len(failed_batches)}")
|
||||
continue
|
||||
|
||||
# Log epoch summary
|
||||
if failed_batches:
|
||||
logger.warning(f"Epoch {epoch + 1} completed with {len(failed_batches)} failed batches")
|
||||
for step, stage, error in failed_batches[:5]: # Show first 5 failures
|
||||
logger.debug(f" Batch {step} failed at {stage}: {error}")
|
||||
|
||||
# Compute epoch metrics
|
||||
epoch_metrics['loss'] = epoch_loss / max(num_batches - len(failed_batches), 1)
|
||||
epoch_metrics['failed_batches'] = len(failed_batches)
|
||||
epoch_metrics['success_rate'] = (num_batches - len(failed_batches)) / num_batches if num_batches > 0 else 0.0
|
||||
|
||||
return epoch_metrics
|
||||
|
||||
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
|
||||
"""Evaluate the model"""
|
||||
self.model.eval()
|
||||
|
||||
total_loss = 0
|
||||
num_batches = 0
|
||||
all_metrics = {}
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc="Evaluating", disable=not is_main_process()):
|
||||
batch = self._prepare_batch(batch)
|
||||
|
||||
# Compute metrics
|
||||
metrics = self.evaluate_step(batch)
|
||||
|
||||
# Accumulate metrics
|
||||
for key, value in metrics.items():
|
||||
if key not in all_metrics:
|
||||
all_metrics[key] = 0
|
||||
all_metrics[key] += value
|
||||
|
||||
num_batches += 1
|
||||
|
||||
# Average metrics
|
||||
for key in all_metrics:
|
||||
all_metrics[key] /= max(num_batches, 1)
|
||||
|
||||
# Gather metrics across devices
|
||||
if self.is_distributed:
|
||||
all_metrics = reduce_dict(all_metrics)
|
||||
|
||||
return all_metrics
|
||||
|
||||
def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
    """Move tensor values in the batch dict to the target device; leave non-tensors untouched."""
    moved = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            moved[k] = v.to(self.device, non_blocking=True)
        else:
            moved[k] = v  # keep lists / strings / ints as-is
    return moved
|
||||
|
||||
def _clear_device_cache(self):
|
||||
"""Clear device memory cache with support for CUDA and Ascend NPU"""
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# Try Ascend support if available
|
||||
try:
|
||||
from utils.ascend_support import clear_ascend_cache
|
||||
clear_ascend_cache()
|
||||
except ImportError:
|
||||
# Fallback to manual NPU check
|
||||
try:
|
||||
npu = getattr(torch, 'npu', None)
|
||||
if npu is not None and hasattr(npu, 'is_available') and npu.is_available():
|
||||
npu.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to clear device cache: {e}")
|
||||
|
||||
def _is_better_metric(self, current: float) -> bool:
    """Check whether the current metric beats the best seen so far."""
    if self.config.greater_is_better:
        return current > self.best_metric
    return current < self.best_metric
|
||||
|
||||
def _save_checkpoint(self, is_best: bool = False, metrics: Optional[Dict[str, float]] = None):
|
||||
"""Save model checkpoint"""
|
||||
if not is_main_process() or not self.checkpoint_manager:
|
||||
return
|
||||
|
||||
checkpoint = {
|
||||
'model_state_dict': self.model.module.state_dict() if self.is_distributed else self.model.state_dict(), # type: ignore[attr-defined]
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
|
||||
'config': self.config.to_dict(),
|
||||
'global_step': self.global_step,
|
||||
'current_epoch': self.current_epoch,
|
||||
'best_metric': self.best_metric
|
||||
}
|
||||
|
||||
if self.scaler:
|
||||
checkpoint['scaler_state_dict'] = self.scaler.state_dict()
|
||||
|
||||
self.checkpoint_manager.save(
|
||||
checkpoint,
|
||||
step=self.global_step,
|
||||
is_best=is_best,
|
||||
metrics=metrics
|
||||
)
|
||||
|
||||
def _save_training_summary(self):
|
||||
"""Save training summary"""
|
||||
if not is_main_process():
|
||||
return
|
||||
|
||||
summary = {
|
||||
'config': self.config.to_dict(),
|
||||
'final_metrics': self.training_history[-1] if self.training_history else {},
|
||||
'best_metric': self.best_metric,
|
||||
'total_steps': self.global_step,
|
||||
'total_epochs': self.current_epoch,
|
||||
'training_time': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
summary_path = os.path.join(self.config.output_dir, 'training_summary.json')
|
||||
with open(summary_path, 'w') as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
def load_checkpoint(self, checkpoint_path: str):
|
||||
"""Load checkpoint"""
|
||||
logger.info(f"Loading checkpoint from {checkpoint_path}")
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||
|
||||
# Load model state
|
||||
if self.is_distributed:
|
||||
self.model.module.load_state_dict(checkpoint['model_state_dict']) # type: ignore[attr-defined]
|
||||
else:
|
||||
self.model.load_state_dict(checkpoint['model_state_dict']) # type: ignore[attr-defined]
|
||||
|
||||
# Load optimizer state
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
|
||||
# Load scheduler state
|
||||
if self.scheduler and checkpoint.get('scheduler_state_dict'):
|
||||
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
|
||||
# Load scaler state
|
||||
if self.scaler and checkpoint.get('scaler_state_dict'):
|
||||
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
||||
|
||||
# Load training state
|
||||
self.global_step = checkpoint.get('global_step', 0)
|
||||
self.current_epoch = checkpoint.get('current_epoch', 0)
|
||||
self.best_metric = checkpoint.get('best_metric', float('-inf') if self.config.greater_is_better else float('inf'))
|
||||
|
||||
logger.info(f"Resumed from step {self.global_step}, epoch {self.current_epoch}")
|
||||
127
utils/__init__.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""
|
||||
Utility modules for logging, checkpointing, and distributed training
|
||||
"""
|
||||
from .logging import (
|
||||
setup_logging,
|
||||
get_logger,
|
||||
MetricsLogger,
|
||||
TensorBoardLogger,
|
||||
ProgressLogger,
|
||||
log_model_info,
|
||||
log_training_summary
|
||||
)
|
||||
from .checkpoint import CheckpointManager, save_model_for_inference
|
||||
from .distributed import (
|
||||
setup_distributed,
|
||||
cleanup_distributed,
|
||||
is_main_process,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
get_local_rank,
|
||||
synchronize,
|
||||
all_reduce,
|
||||
all_gather,
|
||||
broadcast,
|
||||
reduce_dict,
|
||||
setup_model_for_distributed,
|
||||
create_distributed_dataloader,
|
||||
set_random_seed,
|
||||
print_distributed_info,
|
||||
DistributedMetricAggregator,
|
||||
save_on_master,
|
||||
log_on_master
|
||||
)
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def optional_import(module_name, feature_desc=None, warn_only=True):
|
||||
"""Try to import an optional module. Log a warning if not found. Returns the module or None."""
|
||||
import importlib
|
||||
try:
|
||||
return importlib.import_module(module_name)
|
||||
except ImportError:
|
||||
msg = f"Optional dependency '{module_name}' not found."
|
||||
if feature_desc:
|
||||
msg += f" {feature_desc} will be unavailable."
|
||||
if warn_only:
|
||||
logger.warning(msg)
|
||||
else:
|
||||
logger.error(msg)
|
||||
return None
|
||||
|
||||
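# Typical use of optional_import: grab the module when present, otherwise degrade
# gracefully. The 'yaml' dependency is only an example, not a requirement of this package.
#
#   yaml = optional_import('yaml', feature_desc='YAML config loading')
#   if yaml is not None:
#       with open('config.yaml') as f:
#           data = yaml.safe_load(f)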
# Import guard for optional dependencies
|
||||
|
||||
def _check_optional_dependencies():
|
||||
"""Check for optional dependencies and log warnings if missing when features are likely to be used."""
|
||||
# Check for torch_npu if running on Ascend
|
||||
torch = optional_import('torch', warn_only=True)
|
||||
if torch and (hasattr(torch, 'npu') or os.environ.get('DEVICE', '').lower() == 'npu'):
|
||||
optional_import('torch_npu', feature_desc='Ascend NPU support')
|
||||
|
||||
# Check for wandb if Weights & Biases logging is enabled
|
||||
config = None
try:
|
||||
from utils.config_loader import get_config
|
||||
config = get_config()
|
||||
use_wandb = False
|
||||
try:
|
||||
use_wandb = config.get('logging.use_wandb', False) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
pass
|
||||
if use_wandb:
|
||||
optional_import('wandb', feature_desc='Weights & Biases logging')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check for tensorboard if tensorboard logging is likely to be used
|
||||
tensorboard_dir = None
|
||||
try:
|
||||
tensorboard_dir = config.get('logging.tensorboard_dir', None) if config else None  # type: ignore[arg-type]
|
||||
except Exception:
|
||||
pass
|
||||
if tensorboard_dir:
|
||||
optional_import('tensorboard', feature_desc='TensorBoard logging')
|
||||
|
||||
# Check for pytest if running tests
|
||||
import sys
|
||||
if any('pytest' in arg for arg in sys.argv):
|
||||
optional_import('pytest', feature_desc='Some tests may not run')
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Logging
|
||||
'setup_logging',
|
||||
'get_logger',
|
||||
'MetricsLogger',
|
||||
'TensorBoardLogger',
|
||||
'ProgressLogger',
|
||||
'log_model_info',
|
||||
'log_training_summary',
|
||||
|
||||
# Checkpointing
|
||||
'CheckpointManager',
|
||||
'save_model_for_inference',
|
||||
|
||||
# Distributed training
|
||||
'setup_distributed',
|
||||
'cleanup_distributed',
|
||||
'is_main_process',
|
||||
'get_rank',
|
||||
'get_world_size',
|
||||
'get_local_rank',
|
||||
'synchronize',
|
||||
'all_reduce',
|
||||
'all_gather',
|
||||
'broadcast',
|
||||
'reduce_dict',
|
||||
'setup_model_for_distributed',
|
||||
'create_distributed_dataloader',
|
||||
'set_random_seed',
|
||||
'print_distributed_info',
|
||||
'DistributedMetricAggregator',
|
||||
'save_on_master',
|
||||
'log_on_master'
|
||||
]
|
||||
271
utils/ascend_support.py
Normal file
@ -0,0 +1,271 @@
|
||||
"""
|
||||
Huawei Ascend NPU support utilities for BGE fine-tuning
|
||||
|
||||
This module provides compatibility layer for running BGE models on Ascend NPUs.
|
||||
Requires torch_npu to be installed: pip install torch_npu
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global flag to track Ascend availability
|
||||
_ASCEND_AVAILABLE = False
|
||||
_ASCEND_CHECK_DONE = False
|
||||
|
||||
|
||||
def check_ascend_availability() -> bool:
|
||||
"""
|
||||
Check if Ascend NPU is available
|
||||
|
||||
Returns:
|
||||
bool: True if Ascend NPU is available and properly configured
|
||||
"""
|
||||
global _ASCEND_AVAILABLE, _ASCEND_CHECK_DONE
|
||||
|
||||
if _ASCEND_CHECK_DONE:
|
||||
return _ASCEND_AVAILABLE
|
||||
|
||||
_ASCEND_CHECK_DONE = True
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
|
||||
# Check if NPU is available
|
||||
if hasattr(torch, 'npu') and torch.npu.is_available():
|
||||
_ASCEND_AVAILABLE = True
|
||||
npu_count = torch.npu.device_count()
|
||||
logger.info(f"Ascend NPU available with {npu_count} devices")
|
||||
|
||||
# Log NPU information
|
||||
for i in range(npu_count):
|
||||
props = torch.npu.get_device_properties(i)
|
||||
logger.info(f"NPU {i}: {props.name}, Total Memory: {props.total_memory / 1024**3:.2f} GB")
|
||||
else:
|
||||
logger.info("Ascend NPU not available")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("torch_npu not installed. Ascend NPU support disabled.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking Ascend availability: {e}")
|
||||
|
||||
return _ASCEND_AVAILABLE
|
||||
|
||||
|
||||
def get_ascend_backend() -> str:
|
||||
"""
|
||||
Get the appropriate backend for Ascend NPU
|
||||
|
||||
Returns:
|
||||
str: 'hccl' for Ascend, 'nccl' for CUDA, 'gloo' for CPU
|
||||
"""
|
||||
if check_ascend_availability():
|
||||
return 'hccl'
|
||||
elif torch.cuda.is_available():
|
||||
return 'nccl'
|
||||
else:
|
||||
return 'gloo'
|
||||
|
||||
|
||||
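# get_ascend_backend() is intended to feed torch.distributed setup so the same launch
# code works on NPU (hccl), CUDA (nccl) and CPU (gloo). A sketch, assuming rank and
# world size come from the launcher environment:
#
#   import torch.distributed as dist
#   dist.init_process_group(
#       backend=get_ascend_backend(),
#       rank=int(os.environ.get("RANK", 0)),
#       world_size=int(os.environ.get("WORLD_SIZE", 1)),
#   )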
def setup_ascend_device(local_rank: int = 0) -> torch.device:
|
||||
"""
|
||||
Setup Ascend NPU device
|
||||
|
||||
Args:
|
||||
local_rank: Local rank for distributed training
|
||||
|
||||
Returns:
|
||||
torch.device: Configured device
|
||||
"""
|
||||
if not check_ascend_availability():
|
||||
raise RuntimeError("Ascend NPU not available")
|
||||
|
||||
try:
|
||||
# Set NPU device
|
||||
torch.npu.set_device(local_rank)
|
||||
device = torch.device(f'npu:{local_rank}')
|
||||
logger.info(f"Using Ascend NPU device: {device}")
|
||||
return device
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup Ascend device: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def convert_model_for_ascend(model: torch.nn.Module) -> torch.nn.Module:
|
||||
"""
|
||||
Convert model for Ascend NPU compatibility
|
||||
|
||||
Args:
|
||||
model: PyTorch model
|
||||
|
||||
Returns:
|
||||
Model configured for Ascend
|
||||
"""
|
||||
if not check_ascend_availability():
|
||||
return model
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
|
||||
# Apply Ascend-specific optimizations
|
||||
model = torch_npu.contrib.transfer_to_npu(model)
|
||||
|
||||
# Enable Ascend graph mode if available
|
||||
if hasattr(torch_npu, 'enable_graph_mode'):
|
||||
torch_npu.enable_graph_mode()
|
||||
logger.info("Enabled Ascend graph mode")
|
||||
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply Ascend optimizations: {e}")
|
||||
return model
|
||||
|
||||
|
||||
def optimize_for_ascend(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Optimize configuration for Ascend NPU
|
||||
|
||||
Args:
|
||||
config: Training configuration
|
||||
|
||||
Returns:
|
||||
Optimized configuration
|
||||
"""
|
||||
if not check_ascend_availability():
|
||||
return config
|
||||
|
||||
# Create a copy to avoid modifying original
|
||||
config = config.copy()
|
||||
|
||||
# Ascend-specific optimizations
|
||||
ascend_config = config.get('ascend', {})
|
||||
|
||||
# Set backend to hccl for distributed training
|
||||
if ascend_config.get('backend') == 'hccl':
|
||||
config['distributed'] = config.get('distributed', {})
|
||||
config['distributed']['backend'] = 'hccl'
|
||||
|
||||
# Enable mixed precision if specified
|
||||
if ascend_config.get('mixed_precision', False):
|
||||
config['optimization'] = config.get('optimization', {})
|
||||
config['optimization']['fp16'] = True
|
||||
logger.info("Enabled mixed precision for Ascend")
|
||||
|
||||
# Set visible devices
|
||||
if 'visible_devices' in ascend_config:
|
||||
os.environ['ASCEND_RT_VISIBLE_DEVICES'] = ascend_config['visible_devices']
|
||||
logger.info(f"Set ASCEND_RT_VISIBLE_DEVICES={ascend_config['visible_devices']}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_ascend_memory_info(device_id: int = 0) -> Tuple[float, float]:
|
||||
"""
|
||||
Get Ascend NPU memory information
|
||||
|
||||
Args:
|
||||
device_id: NPU device ID
|
||||
|
||||
Returns:
|
||||
Tuple of (used_memory_gb, total_memory_gb)
|
||||
"""
|
||||
if not check_ascend_availability():
|
||||
return 0.0, 0.0
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
|
||||
# Get memory info
|
||||
used = torch_npu.memory_allocated(device_id) / 1024**3
|
||||
total = torch_npu.get_device_properties(device_id).total_memory / 1024**3
|
||||
|
||||
return used, total
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get Ascend memory info: {e}")
|
||||
return 0.0, 0.0
|
||||
|
||||
|
||||
def clear_ascend_cache():
|
||||
"""Clear Ascend NPU memory cache"""
|
||||
if not check_ascend_availability():
|
||||
return
|
||||
|
||||
try:
|
||||
torch.npu.empty_cache()
|
||||
logger.debug("Cleared Ascend NPU cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear Ascend cache: {e}")
|
||||
|
||||
|
||||
def is_ascend_available() -> bool:
|
||||
"""
|
||||
Simple check for Ascend availability
|
||||
|
||||
Returns:
|
||||
bool: True if Ascend NPU is available
|
||||
"""
|
||||
return check_ascend_availability()
|
||||
|
||||
|
||||
class AscendDataLoader:
|
||||
"""
|
||||
Wrapper for DataLoader with Ascend-specific optimizations
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader, device):
|
||||
self.dataloader = dataloader
|
||||
self.device = device
|
||||
|
||||
def __iter__(self):
|
||||
for batch in self.dataloader:
|
||||
# Move batch to NPU device
|
||||
if isinstance(batch, dict):
|
||||
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in batch.items()}
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
batch = [v.to(self.device) if isinstance(v, torch.Tensor) else v
|
||||
for v in batch]
|
||||
else:
|
||||
batch = batch.to(self.device)
|
||||
yield batch
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataloader)
|
||||
|
||||
|
||||
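# AscendDataLoader wraps an existing DataLoader and moves every yielded batch onto
# the NPU. A usage sketch (the dataset and local_rank are hypothetical):
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(my_dataset, batch_size=16)
#   npu_loader = AscendDataLoader(loader, setup_ascend_device(local_rank=0))
#   for batch in npu_loader:
#       ...  # tensors in `batch` are already on npu:0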
# Compatibility functions
|
||||
def npu_synchronize():
|
||||
"""Synchronize NPU device (compatibility wrapper)"""
|
||||
if check_ascend_availability():
|
||||
try:
|
||||
torch.npu.synchronize()
|
||||
except Exception as e:
|
||||
logger.warning(f"NPU synchronize failed: {e}")
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def get_device_name(device: torch.device) -> str:
|
||||
"""
|
||||
Get device name with Ascend support
|
||||
|
||||
Args:
|
||||
device: PyTorch device
|
||||
|
||||
Returns:
|
||||
str: Device name
|
||||
"""
|
||||
if device.type == 'npu':
|
||||
if check_ascend_availability():
|
||||
try:
|
||||
props = torch.npu.get_device_properties(device.index or 0)
|
||||
return f"Ascend {props.name}"
|
||||
except Exception:
|
||||
return "Ascend NPU"
|
||||
return "NPU (unavailable)"
|
||||
elif device.type == 'cuda':
|
||||
return torch.cuda.get_device_name(device.index or 0)
|
||||
else:
|
||||
return str(device)
|
||||
465
utils/checkpoint.py
Normal file
@ -0,0 +1,465 @@
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import shutil
|
||||
from typing import Dict, Any, Optional, List
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import glob
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CheckpointManager:
|
||||
"""
|
||||
Manages model checkpoints with automatic cleanup and best model tracking
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_dir: str,
|
||||
save_total_limit: int = 3,
|
||||
best_model_dir: str = "best_model",
|
||||
save_optimizer: bool = True,
|
||||
save_scheduler: bool = True
|
||||
):
|
||||
self.output_dir = output_dir
|
||||
self.save_total_limit = save_total_limit
|
||||
self.best_model_dir = os.path.join(output_dir, best_model_dir)
|
||||
self.save_optimizer = save_optimizer
|
||||
self.save_scheduler = save_scheduler
|
||||
|
||||
# Create directories
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
os.makedirs(self.best_model_dir, exist_ok=True)
|
||||
|
||||
# Track checkpoints
|
||||
self.checkpoints = []
|
||||
self._load_existing_checkpoints()
|
||||
|
||||
def _load_existing_checkpoints(self):
|
||||
"""Load list of existing checkpoints"""
|
||||
checkpoint_dirs = glob.glob(os.path.join(self.output_dir, "checkpoint-*"))
|
||||
|
||||
# Sort by step number
|
||||
self.checkpoints = []
|
||||
for cp_dir in checkpoint_dirs:
|
||||
try:
|
||||
step = int(cp_dir.split("-")[-1])
|
||||
self.checkpoints.append((step, cp_dir))
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
self.checkpoints.sort(key=lambda x: x[0])
|
||||
|
||||
def save(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
step: int,
|
||||
is_best: bool = False,
|
||||
metrics: Optional[Dict[str, float]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Save checkpoint with atomic writes and error handling
|
||||
|
||||
Args:
|
||||
state_dict: State dictionary containing model, optimizer, etc.
|
||||
step: Current training step
|
||||
is_best: Whether this is the best model so far
|
||||
metrics: Optional metrics to save with checkpoint
|
||||
|
||||
Returns:
|
||||
Path to saved checkpoint
|
||||
"""
|
||||
checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
|
||||
temp_dir = os.path.join(self.output_dir, f"checkpoint-{step}-tmp")
|
||||
|
||||
try:
|
||||
# Create temporary directory for atomic write
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
# Save individual components
|
||||
checkpoint_files = {}
|
||||
|
||||
# Save model with error handling
|
||||
if 'model_state_dict' in state_dict:
|
||||
model_path = os.path.join(temp_dir, 'pytorch_model.bin')
|
||||
self._safe_torch_save(state_dict['model_state_dict'], model_path)
|
||||
checkpoint_files['model'] = 'pytorch_model.bin'
|
||||
|
||||
# Save optimizer
|
||||
if self.save_optimizer and 'optimizer_state_dict' in state_dict:
|
||||
optimizer_path = os.path.join(temp_dir, 'optimizer.pt')
|
||||
self._safe_torch_save(state_dict['optimizer_state_dict'], optimizer_path)
|
||||
checkpoint_files['optimizer'] = 'optimizer.pt'
|
||||
|
||||
# Save scheduler
|
||||
if self.save_scheduler and 'scheduler_state_dict' in state_dict:
|
||||
scheduler_path = os.path.join(temp_dir, 'scheduler.pt')
|
||||
self._safe_torch_save(state_dict['scheduler_state_dict'], scheduler_path)
|
||||
checkpoint_files['scheduler'] = 'scheduler.pt'
|
||||
|
||||
# Save training state
|
||||
training_state = {
|
||||
'step': step,
|
||||
'checkpoint_files': checkpoint_files,
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Add additional state info
|
||||
for key in ['global_step', 'current_epoch', 'best_metric']:
|
||||
if key in state_dict:
|
||||
training_state[key] = state_dict[key]
|
||||
|
||||
if metrics:
|
||||
training_state['metrics'] = metrics
|
||||
|
||||
# Save config if provided
|
||||
if 'config' in state_dict:
|
||||
config_path = os.path.join(temp_dir, 'config.json')
|
||||
self._safe_json_save(state_dict['config'], config_path)
|
||||
|
||||
# Save training state
|
||||
state_path = os.path.join(temp_dir, 'training_state.json')
|
||||
self._safe_json_save(training_state, state_path)
|
||||
|
||||
# Atomic rename from temp to final directory
|
||||
if os.path.exists(checkpoint_dir):
|
||||
# Backup existing checkpoint
|
||||
backup_dir = f"{checkpoint_dir}.backup"
|
||||
if os.path.exists(backup_dir):
|
||||
shutil.rmtree(backup_dir)
|
||||
shutil.move(checkpoint_dir, backup_dir)
|
||||
|
||||
# Move temp directory to final location
|
||||
shutil.move(temp_dir, checkpoint_dir)
|
||||
|
||||
# Remove backup if successful
|
||||
if os.path.exists(f"{checkpoint_dir}.backup"):
|
||||
shutil.rmtree(f"{checkpoint_dir}.backup")
|
||||
|
||||
# Add to checkpoint list
|
||||
self.checkpoints.append((step, checkpoint_dir))
|
||||
|
||||
# Save best model
|
||||
if is_best:
|
||||
self._save_best_model(checkpoint_dir, step, metrics)
|
||||
|
||||
# Cleanup old checkpoints
|
||||
self._cleanup_checkpoints()
|
||||
|
||||
logger.info(f"Saved checkpoint to {checkpoint_dir}")
|
||||
return checkpoint_dir
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save checkpoint: {e}")
|
||||
# Cleanup temp directory
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir)
|
||||
# Restore backup if exists
|
||||
backup_dir = f"{checkpoint_dir}.backup"
|
||||
if os.path.exists(backup_dir):
|
||||
if os.path.exists(checkpoint_dir):
|
||||
shutil.rmtree(checkpoint_dir)
|
||||
shutil.move(backup_dir, checkpoint_dir)
|
||||
raise RuntimeError(f"Checkpoint save failed: {e}")
|
||||
|
||||
def _save_best_model(
|
||||
self,
|
||||
checkpoint_dir: str,
|
||||
step: int,
|
||||
metrics: Optional[Dict[str, float]] = None
|
||||
):
|
||||
"""Save the best model"""
|
||||
# Copy checkpoint to best model directory
|
||||
for filename in os.listdir(checkpoint_dir):
|
||||
src = os.path.join(checkpoint_dir, filename)
|
||||
dst = os.path.join(self.best_model_dir, filename)
|
||||
if os.path.isfile(src):
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
# Save best model info
|
||||
best_info = {
|
||||
'step': step,
|
||||
'source_checkpoint': checkpoint_dir,
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
if metrics:
|
||||
best_info['metrics'] = metrics
|
||||
|
||||
best_info_path = os.path.join(self.best_model_dir, 'best_model_info.json')
|
||||
with open(best_info_path, 'w') as f:
|
||||
json.dump(best_info, f, indent=2)
|
||||
|
||||
logger.info(f"Saved best model from step {step}")
|
||||
|
||||
def _cleanup_checkpoints(self):
|
||||
"""Remove old checkpoints to maintain save_total_limit"""
|
||||
if self.save_total_limit is not None and len(self.checkpoints) > self.save_total_limit:
|
||||
# Remove oldest checkpoints
|
||||
checkpoints_to_remove = self.checkpoints[:-self.save_total_limit]
|
||||
|
||||
for step, checkpoint_dir in checkpoints_to_remove:
|
||||
if os.path.exists(checkpoint_dir):
|
||||
shutil.rmtree(checkpoint_dir)
|
||||
logger.info(f"Removed old checkpoint: {checkpoint_dir}")
|
||||
|
||||
# Update checkpoint list
|
||||
self.checkpoints = self.checkpoints[-self.save_total_limit:]
|
||||
|
||||
def load_checkpoint(
|
||||
self,
|
||||
checkpoint_path: Optional[str] = None,
|
||||
load_best: bool = False,
|
||||
map_location: str = 'cpu',
|
||||
strict: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load checkpoint with robust error handling
|
||||
|
||||
Args:
|
||||
checkpoint_path: Path to specific checkpoint
|
||||
load_best: Whether to load best model
|
||||
map_location: Device to map tensors to
|
||||
strict: Whether to strictly enforce state dict matching
|
||||
|
||||
Returns:
|
||||
Loaded state dictionary
|
||||
|
||||
Raises:
|
||||
RuntimeError: If checkpoint loading fails
|
||||
"""
|
||||
try:
|
||||
if load_best:
|
||||
checkpoint_path = self.best_model_dir
|
||||
elif checkpoint_path is None:
|
||||
# Load latest checkpoint
|
||||
if self.checkpoints:
|
||||
checkpoint_path = self.checkpoints[-1][1]
|
||||
else:
|
||||
raise ValueError("No checkpoints found")
|
||||
|
||||
# Ensure checkpoint_path is a string
|
||||
if checkpoint_path is None:
|
||||
raise ValueError("Checkpoint path is None")
|
||||
|
||||
checkpoint_path = str(checkpoint_path) # Ensure it's a string
|
||||
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
|
||||
|
||||
logger.info(f"Loading checkpoint from {checkpoint_path}")
|
||||
|
||||
# Load training state with error handling
|
||||
state_dict = {}
|
||||
state_path = os.path.join(checkpoint_path, 'training_state.json')
|
||||
if os.path.exists(state_path):
|
||||
try:
|
||||
with open(state_path, 'r') as f:
|
||||
training_state = json.load(f)
|
||||
state_dict.update(training_state)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load training state: {e}")
|
||||
training_state = {}
|
||||
else:
|
||||
training_state = {}
|
||||
|
||||
# Load model state
|
||||
model_path = os.path.join(checkpoint_path, 'pytorch_model.bin')
|
||||
if os.path.exists(model_path):
|
||||
try:
|
||||
state_dict['model_state_dict'] = torch.load(
|
||||
model_path,
|
||||
map_location=map_location,
|
||||
weights_only=True # Security: prevent arbitrary code execution
|
||||
)
|
||||
except Exception as e:
|
||||
# Try legacy loading without weights_only for backward compatibility
|
||||
try:
|
||||
state_dict['model_state_dict'] = torch.load(
|
||||
model_path,
|
||||
map_location=map_location
|
||||
)
|
||||
logger.warning("Loaded checkpoint using legacy mode (weights_only=False)")
|
||||
except Exception as legacy_e:
|
||||
raise RuntimeError(f"Failed to load model state: {e}, Legacy attempt: {legacy_e}")
|
||||
else:
|
||||
logger.warning(f"Model state not found at {model_path}")
|
||||
|
||||
# Load optimizer state
|
||||
optimizer_path = os.path.join(checkpoint_path, 'optimizer.pt')
|
||||
if os.path.exists(optimizer_path):
|
||||
try:
|
||||
state_dict['optimizer_state_dict'] = torch.load(
|
||||
optimizer_path,
|
||||
map_location=map_location,
|
||||
weights_only=True
|
||||
)
|
||||
except Exception as e:
|
||||
# Try legacy loading
|
||||
try:
|
||||
state_dict['optimizer_state_dict'] = torch.load(
|
||||
optimizer_path,
|
||||
map_location=map_location
|
||||
)
|
||||
except Exception as legacy_e:
|
||||
logger.warning(f"Failed to load optimizer state: {e}")
|
||||
|
||||
# Load scheduler state
|
||||
scheduler_path = os.path.join(checkpoint_path, 'scheduler.pt')
|
||||
if os.path.exists(scheduler_path):
|
||||
try:
|
||||
state_dict['scheduler_state_dict'] = torch.load(
|
||||
scheduler_path,
|
||||
map_location=map_location,
|
||||
weights_only=True
|
||||
)
|
||||
except Exception as e:
|
||||
# Try legacy loading
|
||||
try:
|
||||
state_dict['scheduler_state_dict'] = torch.load(
|
||||
scheduler_path,
|
||||
map_location=map_location
|
||||
)
|
||||
except Exception as legacy_e:
|
||||
logger.warning(f"Failed to load scheduler state: {e}")
|
||||
|
||||
# Load config if available
|
||||
config_path = os.path.join(checkpoint_path, 'config.json')
|
||||
if os.path.exists(config_path):
|
||||
try:
|
||||
with open(config_path, 'r') as f:
|
||||
state_dict['config'] = json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load config: {e}")
|
||||
|
||||
logger.info(f"Successfully loaded checkpoint from {checkpoint_path}")
|
||||
return state_dict
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load checkpoint: {e}")
|
||||
raise RuntimeError(f"Checkpoint loading failed: {e}")
|
||||
|
||||
def get_latest_checkpoint(self) -> Optional[str]:
|
||||
"""Get path to latest checkpoint"""
|
||||
if self.checkpoints:
|
||||
return self.checkpoints[-1][1]
|
||||
return None
|
||||
|
||||
def get_best_checkpoint(self) -> str:
|
||||
"""Get path to best model checkpoint"""
|
||||
return self.best_model_dir
|
||||
|
||||
def checkpoint_exists(self, step: int) -> bool:
|
||||
"""Check if checkpoint exists for given step"""
|
||||
checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
|
||||
return os.path.exists(checkpoint_dir)
|
||||
|
||||
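# An end-to-end sketch of the manager's public surface, matching the methods defined
# in this class; the directory, step, metric name and `model` object are hypothetical.
#
#   manager = CheckpointManager(output_dir="outputs/run1", save_total_limit=3)
#   manager.save({"model_state_dict": model.state_dict(), "global_step": 500},
#                step=500, is_best=True, metrics={"ndcg@10": 0.71})
#   latest = manager.get_latest_checkpoint()          # e.g. "outputs/run1/checkpoint-500"
#   state = manager.load_checkpoint(load_best=True)   # reads from the best_model/ dir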
def _safe_torch_save(self, obj: Any, path: str):
|
||||
"""Save torch object with error handling"""
|
||||
temp_path = f"{path}.tmp"
|
||||
try:
|
||||
torch.save(obj, temp_path)
|
||||
# Atomic rename
|
||||
os.replace(temp_path, path)
|
||||
except Exception as e:
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
raise RuntimeError(f"Failed to save {path}: {e}")
|
||||
|
||||
def _safe_json_save(self, obj: Any, path: str):
|
||||
"""Save JSON with atomic write"""
|
||||
temp_path = f"{path}.tmp"
|
||||
try:
|
||||
with open(temp_path, 'w') as f:
|
||||
json.dump(obj, f, indent=2)
|
||||
# Atomic rename
|
||||
os.replace(temp_path, path)
|
||||
except Exception as e:
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
raise RuntimeError(f"Failed to save {path}: {e}")
|
||||
|
||||
|
||||
def save_model_for_inference(
|
||||
model: torch.nn.Module,
|
||||
tokenizer: Any,
|
||||
output_dir: str,
|
||||
model_name: str = "model",
|
||||
save_config: bool = True,
|
||||
additional_files: Optional[Dict[str, str]] = None
|
||||
):
|
||||
"""
|
||||
Save model in a format ready for inference
|
||||
|
||||
Args:
|
||||
model: The model to save
|
||||
tokenizer: Tokenizer to save
|
||||
output_dir: Output directory
|
||||
model_name: Name for the model
|
||||
save_config: Whether to save model config
|
||||
additional_files: Additional files to include
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save model
|
||||
model_path = os.path.join(output_dir, f"{model_name}.pt")
|
||||
torch.save(model.state_dict(), model_path)
|
||||
logger.info(f"Saved model to {model_path}")
|
||||
|
||||
# Save tokenizer
|
||||
if hasattr(tokenizer, 'save_pretrained'):
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
logger.info(f"Saved tokenizer to {output_dir}")
|
||||
|
||||
# Save config if available
|
||||
if save_config and hasattr(model, 'config'):
|
||||
config_path = os.path.join(output_dir, 'config.json')
|
||||
model.config.save_pretrained(output_dir) # type: ignore[attr-defined]
|
||||
|
||||
# Save additional files
|
||||
if additional_files:
|
||||
for filename, content in additional_files.items():
|
||||
filepath = os.path.join(output_dir, filename)
|
||||
with open(filepath, 'w') as f:
|
||||
if isinstance(content, dict):
|
||||
json.dump(content, f, indent=2)
|
||||
else:
|
||||
f.write(str(content))
|
||||
|
||||
# Create README
|
||||
readme_content = f"""# {model_name} Model
|
||||
|
||||
This directory contains a fine-tuned model ready for inference.
|
||||
|
||||
## Files:
|
||||
- `{model_name}.pt`: Model weights
|
||||
- `tokenizer_config.json`: Tokenizer configuration
|
||||
- `special_tokens_map.json`: Special tokens mapping
|
||||
- `vocab.txt` or `tokenizer.json`: Vocabulary file
|
||||
|
||||
## Loading the model:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained('{output_dir}')
|
||||
|
||||
# Load model
|
||||
model = YourModelClass() # Initialize your model class
|
||||
model.load_state_dict(torch.load('{model_path}'))
|
||||
model.eval()
|
||||
```
|
||||
|
||||
## Model Information:
|
||||
- Model type: {model.__class__.__name__}
|
||||
- Saved on: {datetime.now().isoformat()}
|
||||
"""
|
||||
|
||||
readme_path = os.path.join(output_dir, 'README.md')
|
||||
with open(readme_path, 'w') as f:
|
||||
f.write(readme_content)
|
||||
|
||||
logger.info(f"Model saved for inference in {output_dir}")
|
||||
518
utils/config_loader.py
Normal file
@ -0,0 +1,518 @@
"""
Centralized configuration loader for BGE fine-tuning project
"""
import os
import toml
import logging
from typing import Dict, Any, Optional, Tuple
from pathlib import Path

logger = logging.getLogger(__name__)


def resolve_model_path(
    model_key: str,
    config_instance: 'ConfigLoader',
    model_name_or_path: Optional[str] = None,
    cache_dir: Optional[str] = None
) -> Tuple[str, str]:
    """
    Resolve model path with automatic local/remote fallback

    Args:
        model_key: Key in model_paths config (e.g., 'bge_m3', 'bge_reranker')
        config_instance: ConfigLoader instance
        model_name_or_path: Optional override from CLI
        cache_dir: Optional cache directory

    Returns:
        Tuple of (resolved_model_path, resolved_cache_dir)
    """
    # Default remote model names for each model type
    remote_models = {
        'bge_m3': 'BAAI/bge-m3',
        'bge_reranker': 'BAAI/bge-reranker-base',
    }

    # Get model source (huggingface or modelscope)
    model_source = config_instance.get('model_paths', 'source', 'huggingface')

    # If ModelScope is used, adjust model names
    if model_source == 'modelscope':
        remote_models = {
            'bge_m3': 'damo/nlp_gte_sentence-embedding_chinese-base',
            'bge_reranker': 'damo/nlp_corom_passage-ranking_english-base',
        }

    # 1. Check CLI override first (only if explicitly provided)
    if model_name_or_path is not None:
        resolved_model_path = model_name_or_path
        logger.info(f"[CLI OVERRIDE] Using model path: {resolved_model_path}")
    else:
        # 2. Check local path from config
        local_path_raw = config_instance.get('model_paths', model_key, None)

        # Handle list/string conversion
        if isinstance(local_path_raw, list):
            local_path = local_path_raw[0] if local_path_raw else None
        else:
            local_path = local_path_raw

        if local_path and isinstance(local_path, str) and os.path.exists(local_path):
            # Check if it's a valid model directory
            if _is_valid_model_directory(local_path):
                resolved_model_path = local_path
                logger.info(f"[LOCAL] Using local model: {resolved_model_path}")
            else:
                logger.warning(f"[LOCAL] Invalid model directory: {local_path}, falling back to remote")
                resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
        else:
            # 3. Fall back to remote model
            resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
            logger.info(f"[REMOTE] Using remote model: {resolved_model_path}")

            # If local path was configured but doesn't exist, we'll download to that location
            if local_path and isinstance(local_path, str):
                logger.info(f"[DOWNLOAD] Will download to configured path: {local_path}")
                # Create directory if it doesn't exist
                os.makedirs(local_path, exist_ok=True)

    # Resolve cache directory
    if cache_dir is not None:
        resolved_cache_dir = cache_dir
    else:
        cache_dir_raw = config_instance.get('model_paths', 'cache_dir', './cache/data')
        if isinstance(cache_dir_raw, list):
            resolved_cache_dir = cache_dir_raw[0] if cache_dir_raw else './cache/data'
        else:
            resolved_cache_dir = cache_dir_raw if cache_dir_raw else './cache/data'

    # Ensure cache directory exists
    if isinstance(resolved_cache_dir, str):
        os.makedirs(resolved_cache_dir, exist_ok=True)

    return resolved_model_path, resolved_cache_dir

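# Illustrative usage sketch (added for clarity, not part of the committed module);
# the CLI override value below is a hypothetical example:
#
#   >>> cfg = ConfigLoader()
#   >>> model_path, cache_dir = resolve_model_path('bge_m3', cfg)
#   >>> override_path, _ = resolve_model_path('bge_m3', cfg, model_name_or_path='BAAI/bge-m3')
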
def _is_valid_model_directory(path: str) -> bool:
    """Check if a directory contains a valid model"""
    required_files = ['config.json']
    optional_files = ['tokenizer_config.json', 'vocab.txt', 'tokenizer.json']
    model_files = ['pytorch_model.bin', 'model.safetensors', 'pytorch_model.bin.index.json']

    path_obj = Path(path)
    if not path_obj.exists() or not path_obj.is_dir():
        return False

    # Check for required files
    for file in required_files:
        if not (path_obj / file).exists():
            return False

    # Check for at least one model file
    has_model_file = any((path_obj / file).exists() for file in model_files)
    if not has_model_file:
        return False

    # Check for at least one tokenizer file (for embedding models)
    has_tokenizer_file = any((path_obj / file).exists() for file in optional_files)
    if not has_tokenizer_file:
        logger.warning(f"Model directory {path} missing tokenizer files, but proceeding anyway")

    return True


def download_model_if_needed(
    model_name_or_path: str,
    local_path: Optional[str] = None,
    cache_dir: Optional[str] = None,
    model_source: str = 'huggingface'
) -> str:
    """
    Download model if it's a remote identifier and local path is specified

    Args:
        model_name_or_path: Model identifier or path
        local_path: Target local path for download
        cache_dir: Cache directory
        model_source: 'huggingface' or 'modelscope'

    Returns:
        Final model path (local if downloaded, original if already local)
    """
    # If it's already a local path, return as is
    if os.path.exists(model_name_or_path):
        return model_name_or_path

    # If no local path specified, return original (will be handled by transformers)
    if not local_path:
        return model_name_or_path

    # Check if local path already exists and is valid
    if os.path.exists(local_path) and _is_valid_model_directory(local_path):
        logger.info(f"Model already exists locally: {local_path}")
        return local_path

    # Download the model
    try:
        logger.info(f"Downloading model {model_name_or_path} to {local_path}")

        if model_source == 'huggingface':
            from transformers import AutoModel, AutoTokenizer

            # Download model and tokenizer
            model = AutoModel.from_pretrained(
                model_name_or_path,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            tokenizer = AutoTokenizer.from_pretrained(
                model_name_or_path,
                cache_dir=cache_dir,
                trust_remote_code=True
            )

            # Save to local path
            os.makedirs(local_path, exist_ok=True)
            model.save_pretrained(local_path)
            tokenizer.save_pretrained(local_path)

            logger.info(f"Model downloaded successfully to {local_path}")
            return local_path

        elif model_source == 'modelscope':
            try:
                from modelscope import snapshot_download  # type: ignore

                # Download using ModelScope
                downloaded_path = snapshot_download(
                    model_name_or_path,
                    cache_dir=cache_dir,
                    local_dir=local_path
                )

                logger.info(f"Model downloaded successfully from ModelScope to {downloaded_path}")
                return downloaded_path

            except ImportError:
                logger.warning("ModelScope not available, falling back to HuggingFace")
                return download_model_if_needed(
                    model_name_or_path, local_path, cache_dir, 'huggingface'
                )

    except Exception as e:
        logger.error(f"Failed to download model {model_name_or_path}: {e}")
        logger.info("Returning original model path - transformers will handle caching")

    # Fallback return
    return model_name_or_path

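# Illustrative sketch (added for clarity; the local mirror directory is an assumption):
#
#   >>> final_path = download_model_if_needed('BAAI/bge-m3', local_path='./models/bge-m3')
#   >>> # On success final_path is './models/bge-m3'; otherwise the original identifier
#   >>> # is returned and transformers resolves it through its own cache.
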
class ConfigLoader:
    """Centralized configuration management for BGE fine-tuning

    Parameter Precedence (highest to lowest):
    1. Command-line arguments (CLI)
    2. Model-specific config ([m3], [reranker])
    3. [hardware] section in config.toml
    4. [training] section in config.toml
    5. Hardcoded defaults in code
    """

    _instance = None
    _config = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ConfigLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if self._config is None:
            self._load_config()

    def _find_config_file(self) -> str:
        """Find config.toml file in project hierarchy"""
        # Check environment variable FIRST (highest priority)
        if 'BGE_CONFIG_PATH' in os.environ:
            config_path = Path(os.environ['BGE_CONFIG_PATH'])
            if config_path.exists():
                return str(config_path)
            else:
                logger.warning(f"BGE_CONFIG_PATH set to {config_path} but file doesn't exist, falling back to default")

        current_dir = Path(__file__).parent

        # Search up the directory tree for config.toml
        for _ in range(5):  # Limit search depth
            config_path = current_dir / "config.toml"
            if config_path.exists():
                return str(config_path)
            current_dir = current_dir.parent

        # Default location
        return os.path.join(os.getcwd(), "config.toml")

    def _load_config(self):
        """Load configuration from TOML file"""
        config_path = self._find_config_file()

        try:
            if os.path.exists(config_path):
                with open(config_path, 'r', encoding='utf-8') as f:
                    loaded = toml.load(f)
                    if loaded is None:
                        loaded = {}
                    self._config = dict(loaded)
                logger.info(f"Loaded configuration from {config_path}")
            else:
                logger.warning(f"Config file not found at {config_path}, using defaults")
                self._config = self._get_default_config()
                self._save_default_config(config_path)
                logger.warning("Default config created. Please check your configuration file.")
        except Exception as e:
            logger.error(f"Error loading config: {e}")
            self._config = self._get_default_config()

        # Validation for unset/placeholder values
        self._validate_config()

    def _validate_config(self):
        """Validate config for unset/placeholder values"""
        api_key = self.get('api', 'api_key', None)
        if not api_key or api_key == 'your-api-key-here':
            logger.warning("API key is unset or placeholder. Set it in config.toml or via the BGE_API_KEY environment variable.")

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration"""
        return {
            'api': {
                'base_url': 'https://api.deepseek.com/v1',
                'api_key': 'your-api-key-here',
                'model': 'deepseek-chat',
                'max_retries': 3,
                'timeout': 30,
                'temperature': 0.7
            },
            'training': {
                'default_batch_size': 4,
                'default_learning_rate': 1e-5,
                'default_num_epochs': 3,
                'default_warmup_ratio': 0.1,
                'default_weight_decay': 0.01,
                'gradient_accumulation_steps': 4
            },
            'model_paths': {
                'bge_m3': './models/bge-m3',
                'bge_reranker': './models/bge-reranker-base',
                'cache_dir': './cache/data'
            },
            'data': {
                'max_query_length': 64,
                'max_passage_length': 512,
                'max_seq_length': 512,
                'validation_split': 0.1,
                'random_seed': 42
            },
            'augmentation': {
                'enable_query_augmentation': False,
                'enable_passage_augmentation': False,
                'augmentation_methods': ['paraphrase', 'back_translation'],
                'num_augmentations_per_sample': 2,
                'augmentation_temperature': 0.8
            },
            'hard_negative_mining': {
                'enable_mining': True,
                'num_hard_negatives': 15,
                'negative_range_min': 10,
                'negative_range_max': 100,
                'margin': 0.1,
                'index_refresh_interval': 1000,
                'index_type': 'IVF',
                'nlist': 100
            },
            'evaluation': {
                'eval_batch_size': 32,
                'k_values': [1, 3, 5, 10, 20, 50, 100],
                'metrics': ['recall', 'precision', 'map', 'mrr', 'ndcg'],
                'save_predictions': True,
                'predictions_dir': './output/predictions'
            },
            'logging': {
                'log_level': 'INFO',
                'log_to_file': True,
                'log_dir': './logs',
                'tensorboard_dir': './logs/tensorboard',
                'wandb_project': 'bge-finetuning',
                'wandb_entity': '',
                'use_wandb': False
            },
            'distributed': {
                'backend': 'nccl',
                'init_method': 'env://',
                'find_unused_parameters': True
            },
            'optimization': {
                'gradient_checkpointing': True,
                'fp16': True,
                'bf16': False,
                'fp16_opt_level': 'O1',
                'max_grad_norm': 1.0,
                'adam_epsilon': 1e-8,
                'adam_beta1': 0.9,
                'adam_beta2': 0.999
            },
            'hardware': {
                'cuda_visible_devices': '0,1,2,3',
                'dataloader_num_workers': 4,
                'dataloader_pin_memory': True,
                'per_device_train_batch_size': 4,
                'per_device_eval_batch_size': 8
            },
            'checkpoint': {
                'save_strategy': 'steps',
                'save_steps': 1000,
                'save_total_limit': 3,
                'load_best_model_at_end': True,
                'metric_for_best_model': 'eval_recall@10',
                'greater_is_better': True,
                'resume_from_checkpoint': ''
            },
            'export': {
                'export_format': ['pytorch', 'onnx'],
                'optimize_for_inference': True,
                'quantize': False,
                'quantization_bits': 8
            }
        }

    def _save_default_config(self, config_path: str):
        """Save default configuration to file"""
        try:
            os.makedirs(os.path.dirname(config_path), exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as f:
                config_to_save = self._config if isinstance(self._config, dict) else {}
                toml.dump(config_to_save, f)
            logger.info(f"Created default config at {config_path}")
        except Exception as e:
            logger.error(f"Error saving default config: {e}")

    def get(self, section: str, key: str, default=None):
        """Get a config value, logging its source (CLI, config, default)."""
        # CLI overrides are handled outside this loader; here we check config and default
        value = default
        source = 'default'
        if isinstance(self._config, dict) and section in self._config and key in self._config[section]:
            value = self._config[section][key]
            source = 'config.toml'
        # Robust list/tuple conversion
        if isinstance(value, str) and (',' in value or ';' in value):
            try:
                value = [v.strip() for v in value.replace(';', ',').split(',') if v.strip()]
                source += ' (parsed as list)'
            except Exception as e:
                logging.warning(f"Failed to parse list for {section}.{key}: {e}")
        log_func = logging.info if source != 'default' else logging.warning
        log_func(f"Config param {section}.{key} = {value} (source: {source})")
        return value

    def get_section(self, section: str) -> Dict[str, Any]:
        """Get entire configuration section"""
        if isinstance(self._config, dict):
            return self._config.get(section, {})
        return {}

    def update(self, key: str, value: Any):
        """Update configuration value. Logs a warning if overriding an existing value."""
        if self._config is None:
            self._config = {}

        keys = key.split('.')
        config = self._config

        # Navigate to the parent of the final key
        for k in keys[:-1]:
            if k not in config:
                config[k] = {}
            config = config[k]

        # Log if overriding
        old_value = config.get(keys[-1], None)
        if old_value is not None and old_value != value:
            logger.warning(f"Parameter '{key}' overridden: {old_value} -> {value} (higher-precedence source)")
        config[keys[-1]] = value

    def reload(self):
        """Reload configuration from file"""
        self._config = None
        self._load_config()


# Global instance
config = ConfigLoader()


def get_config() -> ConfigLoader:
    """Get global configuration instance"""
    return config

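# Illustrative sketch of the precedence described in the class docstring (added for
# clarity; assumes a config.toml with [training] and [hardware] sections):
#
#   >>> cfg = get_config()
#   >>> lr = cfg.get('training', 'default_learning_rate', 1e-5)      # config.toml value, else default
#   >>> bs = cfg.get('hardware', 'per_device_train_batch_size', 4)
#   >>> cfg.update('training.default_learning_rate', 2e-5)           # e.g. applying a CLI override
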
def load_config_for_training(
    config_path: Optional[str] = None,
    override_params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Load configuration specifically for training scripts

    Args:
        config_path: Optional path to config file
        override_params: Parameters to override from command line

    Returns:
        Dictionary with all configuration parameters
    """
    global config

    # Reload with specific path if provided
    if config_path:
        old_path = os.environ.get('BGE_CONFIG_PATH')
        os.environ['BGE_CONFIG_PATH'] = config_path
        config.reload()
        if old_path:
            os.environ['BGE_CONFIG_PATH'] = old_path
        else:
            del os.environ['BGE_CONFIG_PATH']

    # Get all configuration
    training_config = {
        'api': config.get_section('api'),
        'training': config.get_section('training'),
        'model_paths': config.get_section('model_paths'),
        'data': config.get_section('data'),
        'augmentation': config.get_section('augmentation'),
        'hard_negative_mining': config.get_section('hard_negative_mining'),
        'evaluation': config.get_section('evaluation'),
        'logging': config.get_section('logging'),
        'distributed': config.get_section('distributed'),
        'optimization': config.get_section('optimization'),
        'hardware': config.get_section('hardware'),
        'checkpoint': config.get_section('checkpoint'),
        'export': config.get_section('export'),
        'm3': config.get_section('m3'),
        'reranker': config.get_section('reranker'),
    }

    # Apply overrides if provided
    if override_params:
        for key, value in override_params.items():
            config.update(key, value)
            # Also update the returned dict
            keys = key.split('.')
            current = training_config
            for k in keys[:-1]:
                if k not in current:
                    current[k] = {}
                current = current[k]
            current[keys[-1]] = value

    return training_config
425
utils/distributed.py
Normal file
@@ -0,0 +1,425 @@
import os
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import logging
from typing import Optional, Any, List
import random
import numpy as np
from .config_loader import get_config

# Import Ascend support utilities
try:
    from .ascend_support import (
        check_ascend_availability,
        get_ascend_backend,
        setup_ascend_device,
        clear_ascend_cache
    )
    ASCEND_SUPPORT_AVAILABLE = True
except ImportError:
    ASCEND_SUPPORT_AVAILABLE = False

logger = logging.getLogger(__name__)

def setup_distributed(config=None):
    """Setup distributed training environment with proper error handling."""
    try:
        # Check if we're in a distributed environment
        world_size = int(os.environ.get('WORLD_SIZE', 1))
        rank = int(os.environ.get('RANK', 0))
        local_rank = int(os.environ.get('LOCAL_RANK', 0))

        if world_size == 1:
            logger.info("Single GPU/CPU training mode")
            return

        # Determine backend based on hardware availability
        # First check for Ascend NPU support
        if ASCEND_SUPPORT_AVAILABLE and check_ascend_availability():
            backend = 'hccl'
            try:
                setup_ascend_device(local_rank)
                logger.info("Using Ascend NPU with HCCL backend")
            except Exception as e:
                logger.warning(f"Failed to setup Ascend device: {e}, falling back")
                backend = 'gloo'
        elif torch.cuda.is_available():
            backend = 'nccl'
            # Set CUDA device
            torch.cuda.set_device(local_rank)
        else:
            backend = 'gloo'

        # Override backend from config if provided
        # (ConfigLoader.get takes section and key separately)
        if config is not None:
            backend = config.get('distributed', 'backend', backend)

        # Initialize process group
        init_method = os.environ.get('MASTER_ADDR', None)
        if init_method:
            init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ.get('MASTER_PORT', '29500')}"
        else:
            init_method = 'env://'

        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank
        )

        logger.info(f"Distributed training initialized: backend={backend}, world_size={world_size}, rank={rank}, local_rank={local_rank}")

    except Exception as e:
        logger.warning(f"Failed to initialize distributed training: {e}")
        logger.info("Falling back to single GPU/CPU training")
        # Set environment variables to indicate single GPU mode
        os.environ['WORLD_SIZE'] = '1'
        os.environ['RANK'] = '0'
        os.environ['LOCAL_RANK'] = '0'

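# Illustrative launch sketch (added for clarity; assumes launching with torchrun,
# which sets WORLD_SIZE, RANK, LOCAL_RANK, MASTER_ADDR and MASTER_PORT; the script
# name train.py is hypothetical):
#
#   torchrun --nproc_per_node=4 train.py
#
#   >>> setup_distributed(get_config())   # inside the training script, before model setup
#   >>> ...
#   >>> cleanup_distributed()
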
def cleanup_distributed():
    """Cleanup distributed training environment with error handling."""
    try:
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info("Distributed process group destroyed.")
    except Exception as e:
        logger.warning(f"Error during distributed cleanup: {e}")
        # Continue execution even if cleanup fails


def is_main_process() -> bool:
    """Check if this is the main process"""
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0


def get_rank() -> int:
    """Get current process rank"""
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_world_size() -> int:
    """Get world size"""
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_local_rank() -> int:
    """Get local rank (GPU index on current node)"""
    if not dist.is_initialized():
        return 0
    return int(os.environ.get('LOCAL_RANK', 0))


def synchronize():
    """Synchronize all processes"""
    if dist.is_initialized():
        dist.barrier()


def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """
    All-reduce a tensor across all processes

    Args:
        tensor: Tensor to reduce
        op: Reduction operation

    Returns:
        Reduced tensor
    """
    if not dist.is_initialized():
        return tensor

    dist.all_reduce(tensor, op=op)
    return tensor


def all_gather(tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    All-gather tensors from all processes

    Args:
        tensor: Tensor to gather

    Returns:
        List of tensors from all processes
    """
    if not dist.is_initialized():
        return [tensor]

    world_size = get_world_size()
    tensors = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensors, tensor)
    return tensors


def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """
    Broadcast tensor from source to all processes

    Args:
        tensor: Tensor to broadcast
        src: Source rank

    Returns:
        Broadcasted tensor
    """
    if not dist.is_initialized():
        return tensor

    dist.broadcast(tensor, src=src)
    return tensor


def reduce_dict(input_dict: dict, average: bool = True) -> dict:
    """
    Reduce a dictionary of values across all processes

    Args:
        input_dict: Dictionary to reduce
        average: Whether to average the values

    Returns:
        Reduced dictionary
    """
    if not dist.is_initialized():
        return input_dict

    world_size = get_world_size()
    if world_size < 2:
        return input_dict

    with torch.no_grad():
        names = []
        values = []

        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])

        values = torch.stack(values, dim=0)
        dist.all_reduce(values)

        if average:
            values /= world_size

        reduced_dict = {k: v.item() for k, v in zip(names, values)}

    return reduced_dict


def setup_model_for_distributed(
    model: torch.nn.Module,
    device_ids: Optional[List[int]] = None,
    output_device: Optional[int] = None,
    find_unused_parameters: bool = True,
    broadcast_buffers: bool = True
) -> torch.nn.Module:
    """
    Wrap model for distributed training

    Args:
        model: Model to wrap
        device_ids: GPU indices to use
        output_device: Device where outputs are gathered
        find_unused_parameters: Whether to find unused parameters
        broadcast_buffers: Whether to broadcast buffers

    Returns:
        Distributed model
    """
    if not dist.is_initialized():
        return model

    if device_ids is None:
        device_ids = [get_local_rank()]

    if output_device is None:
        output_device = device_ids[0]

    model = DDP(
        model,
        device_ids=device_ids,
        output_device=output_device,
        find_unused_parameters=find_unused_parameters,
        broadcast_buffers=broadcast_buffers
    )

    logger.info(f"Wrapped model for distributed training on devices {device_ids}")

    return model


def create_distributed_dataloader(
    dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = True,
    drop_last: bool = False,
    seed: int = 42
) -> DataLoader:
    """
    Create a distributed dataloader

    Args:
        dataset: Dataset to load
        batch_size: Batch size per GPU
        shuffle: Whether to shuffle
        num_workers: Number of data loading workers
        pin_memory: Whether to pin memory
        drop_last: Whether to drop last incomplete batch
        seed: Random seed for shuffling

    Returns:
        Distributed dataloader
    """
    if dist.is_initialized():
        sampler = DistributedSampler(
            dataset,
            num_replicas=get_world_size(),
            rank=get_rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last
        )
        shuffle = False  # Sampler handles shuffling
    else:
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last
    )

    return dataloader

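# Illustrative sketch (added for clarity; train_dataset is a placeholder):
#
#   >>> loader = create_distributed_dataloader(train_dataset, batch_size=4, num_workers=4)
#   >>> # When torch.distributed is initialized each rank sees a disjoint shard; call
#   >>> # loader.sampler.set_epoch(epoch) at the start of every epoch to reshuffle.
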
def set_random_seed(seed: int, deterministic: bool = False):
    """
    Set random seed for reproducibility

    Args:
        seed: Random seed
        deterministic: Whether to use deterministic algorithms
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic and torch.cuda.is_available():
        # Configure cuDNN determinism settings
        cudnn.deterministic = True
        cudnn.benchmark = False

    logger.info(f"Set random seed to {seed} (deterministic={deterministic})")


def print_distributed_info():
    """Print distributed training information"""
    if not dist.is_initialized():
        logger.info("Distributed training not initialized")
        return

    logger.info(f"Distributed Info:")
    logger.info(f"  Backend: {dist.get_backend()}")
    logger.info(f"  World size: {get_world_size()}")
    logger.info(f"  Rank: {get_rank()}")
    logger.info(f"  Local rank: {get_local_rank()}")
    logger.info(f"  Master addr: {os.environ.get('MASTER_ADDR', 'N/A')}")
    logger.info(f"  Master port: {os.environ.get('MASTER_PORT', 'N/A')}")


class DistributedMetricAggregator:
    """Helper class to aggregate metrics across distributed processes"""

    def __init__(self):
        self.metrics = {}

    def update(self, **kwargs):
        """Update metrics"""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.detach()
            else:
                v = torch.tensor(v, dtype=torch.float32)

            if torch.cuda.is_available():
                v = v.cuda()

            self.metrics[k] = v

    def aggregate(self, average: bool = True) -> dict:
        """Aggregate metrics across all processes"""
        if not dist.is_initialized():
            return {k: v.item() if isinstance(v, torch.Tensor) else v
                    for k, v in self.metrics.items()}

        aggregated = {}
        world_size = get_world_size()

        for k, v in self.metrics.items():
            # All-reduce
            all_reduce(v)

            # Average if requested
            if average:
                v = v / world_size

            aggregated[k] = v.item()

        return aggregated

    def reset(self):
        """Reset metrics"""
        self.metrics = {}

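# Illustrative sketch (added for clarity; the metric values are placeholders):
#
#   >>> agg = DistributedMetricAggregator()
#   >>> agg.update(loss=0.42, accuracy=0.91)
#   >>> metrics = agg.aggregate(average=True)   # mean across ranks when initialized
#   >>> agg.reset()
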
def save_on_master(*args, **kwargs):
    """Decorator to only save on master process"""

    def decorator(func):
        def wrapper(*args, **kwargs):
            if is_main_process():
                return func(*args, **kwargs)
            else:
                # Return None on non-master processes
                return None

        return wrapper

    return decorator


def log_on_master(logger: logging.Logger):
    """Get a logger that only logs on master process"""

    class MasterOnlyLogger:
        def __init__(self, logger):
            self.logger = logger

        def __getattr__(self, name):
            if is_main_process():
                return getattr(self.logger, name)
            else:
                return lambda *args, **kwargs: None

    return MasterOnlyLogger(logger)
344
utils/logging.py
Normal file
@@ -0,0 +1,344 @@
import logging
import os
import sys
from datetime import datetime
from typing import Optional, Dict, Any
import json
import torch
from tqdm import tqdm
import numpy as np


def create_progress_bar(
    iterable,
    desc: str = "",
    total: Optional[int] = None,
    disable: bool = False,
    position: int = 0,
    leave: bool = True,
    ncols: Optional[int] = None,
    unit: str = "it",
    dynamic_ncols: bool = True,
):
    """
    Create a tqdm progress bar with consistent settings across the project.

    Args:
        iterable: Iterable to wrap with progress bar
        desc: Description for the progress bar
        total: Total number of iterations (if None, will try to get from iterable)
        disable: Whether to disable the progress bar
        position: Position of the progress bar (for multiple bars)
        leave: Whether to leave the progress bar after completion
        ncols: Width of the progress bar
        unit: Unit of iteration (e.g., "batch", "step")
        dynamic_ncols: Automatically adjust width based on terminal size

    Returns:
        tqdm progress bar object
    """
    # Check if we're in distributed training
    try:
        import torch.distributed as dist
        if dist.is_initialized() and dist.get_rank() != 0:
            # Disable progress bar for non-main processes
            disable = True
    except Exception:
        pass

    # Create progress bar with consistent settings
    return tqdm(
        iterable,
        desc=desc,
        total=total,
        disable=disable,
        position=position,
        leave=leave,
        ncols=ncols,
        unit=unit,
        dynamic_ncols=dynamic_ncols,
        # Use custom write method to avoid conflicts with logging
        file=sys.stdout,
        # Ensure smooth updates
        smoothing=0.1,
        # Format with consistent style
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
    )

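# Illustrative sketch (added for clarity; train_loader is a placeholder DataLoader):
#
#   >>> for batch in create_progress_bar(train_loader, desc="train", unit="batch"):
#   ...     pass
#   >>> # On non-zero ranks the bar is disabled automatically.
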
class TqdmLoggingHandler(logging.Handler):
    """Custom logging handler that plays nicely with tqdm progress bars"""

    def emit(self, record):
        try:
            msg = self.format(record)
            # Use tqdm.write to avoid breaking progress bars
            tqdm.write(msg, file=sys.stdout)
            self.flush()
        except Exception:
            self.handleError(record)


def setup_logging(
    output_dir: Optional[str] = None,
    log_to_file: bool = False,
    log_level: int = logging.INFO,
    log_to_console: bool = True,
    use_tqdm_handler: bool = True,
    rank: int = 0,
    world_size: int = 1,
):
    """
    Setup logging with optional file and console handlers.

    Args:
        output_dir: Directory for log files
        log_to_file: Whether to log to file
        log_level: Logging level
        log_to_console: Whether to log to console
        use_tqdm_handler: Use tqdm-compatible handler for console logging
        rank: Process rank for distributed training
        world_size: World size for distributed training
    """
    handlers: list[logging.Handler] = []
    log_path = None

    # Only log to console from main process in distributed training
    if log_to_console and (world_size == 1 or rank == 0):
        if use_tqdm_handler:
            # Use tqdm-compatible handler
            console_handler = TqdmLoggingHandler()
        else:
            console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(log_level)
        handlers.append(console_handler)

    if log_to_file:
        if output_dir is None:
            output_dir = "./logs"
        os.makedirs(output_dir, exist_ok=True)

        # Add rank to filename for distributed training
        if world_size > 1:
            log_filename = f"train_rank{rank}.log"
        else:
            log_filename = "train.log"

        log_path = os.path.join(output_dir, log_filename)
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(log_level)
        handlers.append(file_handler)

    # Create formatter
    if world_size > 1:
        # Include rank in log format for distributed training
        formatter = logging.Formatter(
            f'[%(asctime)s][Rank {rank}] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
    else:
        formatter = logging.Formatter(
            '[%(asctime)s] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    # Apply formatter to all handlers
    for handler in handlers:
        handler.setFormatter(formatter)

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)

    # Remove existing handlers to avoid duplicates
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)

    # Add new handlers
    for handler in handlers:
        root_logger.addHandler(handler)

    # Log initialization message
    if rank == 0:
        logging.info(
            f"Logging initialized. Level: {logging.getLevelName(log_level)} | "
            f"Log file: {log_path if log_to_file else 'none'} | "
            f"World size: {world_size}"
        )

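# Illustrative sketch (added for clarity; single-process run shown, pass rank and
# world_size in distributed runs):
#
#   >>> setup_logging(output_dir="./logs", log_to_file=True)
#   >>> get_logger(__name__).info("logging ready")
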
def get_logger(name: str) -> logging.Logger:
    """Get a logger with the specified name"""
    return logging.getLogger(name)


class MetricsLogger:
    """Logger specifically for training metrics"""

    def __init__(self, output_dir: str):
        self.output_dir = output_dir
        self.metrics_file = os.path.join(output_dir, 'metrics.jsonl')
        os.makedirs(output_dir, exist_ok=True)

    def log(self, step: int, metrics: Dict[str, Any], phase: str = 'train'):
        """Log metrics to file"""
        entry = {
            'step': step,
            'phase': phase,
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics
        }

        with open(self.metrics_file, 'a') as f:
            f.write(json.dumps(entry) + '\n')

    def log_config(self, config: Dict[str, Any]):
        """Log configuration"""
        config_file = os.path.join(self.output_dir, 'config.json')
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)

    def log_summary(self, summary: Dict[str, Any]):
        """Log training summary"""
        summary_file = os.path.join(self.output_dir, 'summary.json')
        with open(summary_file, 'w') as f:
            json.dump(summary, f, indent=2)

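# Illustrative sketch (added for clarity; the output directory is a placeholder):
#
#   >>> mlog = MetricsLogger('./output/run1')
#   >>> mlog.log(step=100, metrics={'loss': 0.37}, phase='train')
#   >>> # appends one JSON object per line to ./output/run1/metrics.jsonl
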
class TensorBoardLogger:
    """TensorBoard logging wrapper"""

    def __init__(self, log_dir: str, enabled: bool = True):
        self.enabled = enabled
        if self.enabled:
            try:
                from torch.utils.tensorboard.writer import SummaryWriter
                self.writer = SummaryWriter(log_dir)
            except ImportError:
                get_logger(__name__).warning(
                    "TensorBoard not available. Install with: pip install tensorboard"
                )
                self.enabled = False

    def log_scalar(self, tag: str, value: float, step: int):
        """Log a scalar value"""
        if self.enabled:
            self.writer.add_scalar(tag, value, step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: Dict[str, float], step: int):
        """Log multiple scalars"""
        if self.enabled:
            self.writer.add_scalars(main_tag, tag_scalar_dict, step)

    def log_histogram(self, tag: str, values: np.ndarray, step: int):
        """Log a histogram"""
        if self.enabled:
            self.writer.add_histogram(tag, values, step)

    def log_embedding(self, embeddings: torch.Tensor, metadata: Optional[list] = None, step: int = 0):
        """Log embeddings for visualization"""
        if self.enabled:
            self.writer.add_embedding(embeddings, metadata=metadata, global_step=step)

    def flush(self):
        """Flush the writer"""
        if self.enabled:
            self.writer.flush()

    def close(self):
        """Close the writer"""
        if self.enabled:
            self.writer.close()


class ProgressLogger:
    """Helper for logging training progress"""

    def __init__(self, total_steps: int, log_interval: int = 100):
        self.total_steps = total_steps
        self.log_interval = log_interval
        self.logger = get_logger(__name__)
        self.start_time = datetime.now()

    def log_step(self, step: int, loss: float, metrics: Optional[Dict[str, float]] = None):
        """Log progress for a training step"""
        if step % self.log_interval == 0:
            elapsed = (datetime.now() - self.start_time).total_seconds()
            steps_per_sec = step / elapsed if elapsed > 0 else 0
            eta_seconds = (self.total_steps - step) / steps_per_sec if steps_per_sec > 0 else 0

            msg = f"Step {step}/{self.total_steps} | Loss: {loss:.4f}"

            if metrics:
                metric_str = " | ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
                msg += f" | {metric_str}"

            msg += f" | Speed: {steps_per_sec:.2f} steps/s | ETA: {self._format_time(eta_seconds)}"

            self.logger.info(msg)

    def _format_time(self, seconds: float) -> str:
        """Format seconds to human readable time"""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        seconds = int(seconds % 60)

        if hours > 0:
            return f"{hours}h {minutes}m {seconds}s"
        elif minutes > 0:
            return f"{minutes}m {seconds}s"
        else:
            return f"{seconds}s"


def log_model_info(model: torch.nn.Module, logger: Optional[logging.Logger] = None):
    """Log model information"""
    if logger is None:
        logger = get_logger(__name__)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.info(f"Model: {model.__class__.__name__}")
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Non-trainable parameters: {total_params - trainable_params:,}")

    # Log parameter shapes
    logger.debug("Parameter shapes:")
    for name, param in model.named_parameters():
        logger.debug(f"  {name}: {list(param.shape)}")


def log_training_summary(
    config: Dict[str, Any],
    metrics: Dict[str, float],
    elapsed_time: float,
    output_dir: str
):
    """Log comprehensive training summary"""
    logger = get_logger(__name__)

    summary = {
        'config': config,
        'final_metrics': metrics,
        'training_time': elapsed_time,
        'timestamp': datetime.now().isoformat()
    }

    # Save to file
    summary_file = os.path.join(output_dir, 'training_summary.json')
    with open(summary_file, 'w') as f:
        json.dump(summary, f, indent=2)

    # Log to console
    logger.info("=" * 50)
    logger.info("TRAINING SUMMARY")
    logger.info("=" * 50)
    logger.info(f"Training time: {elapsed_time:.2f} seconds")
    logger.info("Final metrics:")
    for metric, value in metrics.items():
        logger.info(f"  {metric}: {value:.4f}")
    logger.info(f"Summary saved to: {summary_file}")
    logger.info("=" * 50)