41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
"""Training modules for BGE models.

``BaseTrainer`` and ``BGEM3Trainer`` are imported eagerly when the package is
imported. ``BGERerankerTrainer`` and ``JointTrainer`` are resolved lazily on
first attribute access. ``get_trainer`` looks up a trainer class by one of the
short type names ``'base'``, ``'m3'``, ``'reranker'`` or ``'joint'``.
"""
|
|
|
|
def get_trainer(trainer_type: str):
    """Return the trainer class registered under ``trainer_type``.

    Trainers are loaded lazily: only the requested trainer's submodule is
    imported by this call.

    Args:
        trainer_type: One of ``'base'``, ``'m3'``, ``'reranker'`` or ``'joint'``.

    Returns:
        The trainer class itself (not an instance).

    Raises:
        ValueError: If ``trainer_type`` is not one of the supported names.
    """
    import importlib

    # Short type name -> (submodule of this package, class name in it).
    registry = {
        'base': ('trainer', 'BaseTrainer'),
        'm3': ('m3_trainer', 'BGEM3Trainer'),
        'reranker': ('reranker_trainer', 'BGERerankerTrainer'),
        'joint': ('joint_trainer', 'JointTrainer'),
    }
    try:
        module_name, class_name = registry[trainer_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list); report those as
        # unknown types too rather than leaking a lookup error.
        raise ValueError(
            f"Unknown trainer type: {trainer_type}. "
            f"Expected one of: {', '.join(registry)}"
        ) from None
    # Equivalent to ``from .<module_name> import <class_name>``.
    module = importlib.import_module(f".{module_name}", __package__)
    return getattr(module, class_name)
|
|
|
|
# For backward compatibility, import the commonly used ones
|
|
from .trainer import BaseTrainer
|
|
from .m3_trainer import BGEM3Trainer
|
|
|
|
# Reranker and joint trainers are only imported when explicitly requested,
# via module-level __getattr__ (PEP 562). Maps exported name -> submodule.
_LAZY_TRAINERS = {
    'BGERerankerTrainer': 'reranker_trainer',
    'JointTrainer': 'joint_trainer',
}


def __getattr__(name):
    """Import and return a lazily exported trainer class on first access.

    Python calls this only when normal attribute lookup on the package fails.
    The resolved class is stored in the package namespace, so later lookups
    find it directly and do not re-enter this hook.

    Raises:
        AttributeError: If ``name`` is not a lazily exported trainer.
    """
    submodule = _LAZY_TRAINERS.get(name)
    if submodule is None:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
    import importlib

    # Equivalent to ``from .<submodule> import <name>``.
    value = getattr(importlib.import_module(f".{submodule}", __package__), name)
    globals()[name] = value
    return value


def __dir__():
    """List the package namespace plus the lazily exported trainer names."""
    return sorted(set(globals()) | set(_LAZY_TRAINERS))
|
|
|
|
# Names exported by ``from <package> import *``. The lazily loaded trainers
# are not listed: star-import resolves every name in ``__all__``, which would
# import their modules eagerly.
__all__ = [
    'BaseTrainer',
    'BGEM3Trainer',
    'get_trainer',
]
|