init commit
This commit is contained in:
40
training/__init__.py
Normal file
40
training/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Training modules for BGE models
|
||||
"""
|
||||
|
||||
def get_trainer(trainer_type: str):
|
||||
"""Lazy loading of trainers to avoid importing all at once"""
|
||||
if trainer_type == 'base':
|
||||
from .trainer import BaseTrainer
|
||||
return BaseTrainer
|
||||
elif trainer_type == 'm3':
|
||||
from .m3_trainer import BGEM3Trainer
|
||||
return BGEM3Trainer
|
||||
elif trainer_type == 'reranker':
|
||||
from .reranker_trainer import BGERerankerTrainer
|
||||
return BGERerankerTrainer
|
||||
elif trainer_type == 'joint':
|
||||
from .joint_trainer import JointTrainer
|
||||
return JointTrainer
|
||||
else:
|
||||
raise ValueError(f"Unknown trainer type: {trainer_type}")
|
||||
|
||||
# For backward compatibility, import the commonly used ones
|
||||
from .trainer import BaseTrainer
|
||||
from .m3_trainer import BGEM3Trainer
|
||||
|
||||
# Only import reranker and joint trainers when explicitly requested
|
||||
def __getattr__(name):
|
||||
if name == 'BGERerankerTrainer':
|
||||
from .reranker_trainer import BGERerankerTrainer
|
||||
return BGERerankerTrainer
|
||||
elif name == 'JointTrainer':
|
||||
from .joint_trainer import JointTrainer
|
||||
return JointTrainer
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
__all__ = [
|
||||
'BaseTrainer',
|
||||
'BGEM3Trainer',
|
||||
'get_trainer'
|
||||
]
|
||||
Reference in New Issue
Block a user