bge_finetune/training/__init__.py
2025-07-22 16:55:25 +08:00

41 lines
1.2 KiB
Python

"""
Training modules for BGE models
"""
# Maps each ``trainer_type`` accepted by ``get_trainer`` to the submodule of
# this package that defines the trainer and the class name inside it.
_TRAINER_SPECS = {
    'base': ('trainer', 'BaseTrainer'),
    'm3': ('m3_trainer', 'BGEM3Trainer'),
    'reranker': ('reranker_trainer', 'BGERerankerTrainer'),
    'joint': ('joint_trainer', 'JointTrainer'),
}


def get_trainer(trainer_type: str) -> type:
    """Return the trainer class registered under ``trainer_type``.

    Only the submodule that defines the requested trainer is imported by
    this call (lazy loading), so asking for one trainer does not import the
    modules of the others.

    Args:
        trainer_type: One of ``'base'``, ``'m3'``, ``'reranker'`` or
            ``'joint'``.

    Returns:
        The trainer class itself (not an instance).

    Raises:
        ValueError: If ``trainer_type`` is not one of the supported names.
    """
    try:
        module_name, class_name = _TRAINER_SPECS[trainer_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list); the original
        # ``==`` comparison chain reported those as unknown types as well.
        valid = ', '.join(repr(key) for key in _TRAINER_SPECS)
        raise ValueError(
            f"Unknown trainer type: {trainer_type}. Expected one of: {valid}"
        ) from None

    from importlib import import_module

    # Resolve relative to this package, exactly like ``from .<module> import``.
    module = import_module(f'.{module_name}', __package__)
    return getattr(module, class_name)
# For backward compatibility, import the commonly used ones
from .trainer import BaseTrainer
from .m3_trainer import BGEM3Trainer
# Trainers exposed as package attributes but imported only on first access
# via the module-level ``__getattr__`` hook (PEP 562).
# Maps attribute name -> submodule of this package that defines it.
_LAZY_TRAINERS = {
    'BGERerankerTrainer': 'reranker_trainer',
    'JointTrainer': 'joint_trainer',
}


def __getattr__(name: str):
    """Resolve lazily-imported trainer classes on first attribute access.

    Python invokes this hook only when ``name`` is not already a module
    global. The resolved class is stored in the module globals, so later
    lookups of the same name no longer go through this function.

    Args:
        name: The attribute being looked up on this module.

    Returns:
        The trainer class named ``name``.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported trainers.
    """
    module_name = _LAZY_TRAINERS.get(name)
    if module_name is None:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    from importlib import import_module

    value = getattr(import_module(f'.{module_name}', __package__), name)
    # Cache as a real global so subsequent accesses skip this hook.
    globals()[name] = value
    return value
# Public names exported by ``from <this package> import *``.
# The lazily loaded trainers (reranker / joint) are reachable as attributes
# but are not listed: a star-import fetches every name in ``__all__``, which
# would trigger their import through ``__getattr__``. Presumably that is the
# reason for the omission — confirm before adding them.
__all__ = [
    'BaseTrainer',
    'BGEM3Trainer',
    'get_trainer',
]