106 lines
3.2 KiB
Python
106 lines
3.2 KiB
Python
"""BGE fine-tuning package.

Tooling for fine-tuning BAAI BGE models. Importing the package runs a check
for the required third-party libraries, warns about optional ones that look
needed but are missing, and then re-exports the main config and model classes.
"""

__version__ = "1.0.0"
__author__ = "ldy"

# Standard library.
import logging
import os
import sys

# Third-party (hard requirement).
import torch

# Package-wide logger; named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
# Mandatory dependency check
|
|
def _check_dependencies():
|
|
"""Ensure core dependencies are available."""
|
|
try:
|
|
_ = torch.__version__
|
|
except Exception as exc: # pragma: no cover - should not happen in tests
|
|
raise ImportError("PyTorch is required for bge_finetune") from exc
|
|
try:
|
|
import transformers # noqa: F401
|
|
except Exception as exc: # pragma: no cover - should not happen in tests
|
|
raise ImportError("transformers is required for bge_finetune") from exc
|
|
|
|
# Import guard for optional dependencies: warn (never fail the package import
# because of a *missing* optional package) when a feature looks enabled.
def _check_optional_dependencies():
    """Log a warning for each optional dependency that appears needed but is missing.

    Probes, in order:

    * ``torch_npu`` -- when ``torch.npu.is_available()`` is true or the
      ``DEVICE`` environment variable equals ``npu`` (case-insensitive).
    * ``wandb`` -- when the config value ``logging.use_wandb`` is truthy.
    * ``tensorboard`` -- when the config value ``logging.tensorboard_dir``
      is truthy.
    * ``pytest`` -- when any command-line argument contains ``pytest``.

    An ``ImportError`` from a probed package produces a warning. Failures
    while detecting the NPU or while loading/reading the project config are
    treated as "feature not in use"; config failures are logged at DEBUG.
    """
    # Ascend NPU. `torch.npu` only exists on some torch builds and
    # is_available() may itself raise, so any failure counts as "no NPU".
    try:
        npu_available = hasattr(torch, 'npu') and torch.npu.is_available()
    except Exception:
        npu_available = False

    if npu_available or os.environ.get('DEVICE', '').lower() == 'npu':
        try:
            import torch_npu  # noqa: F401
        except ImportError:
            logger.warning("Optional dependency 'torch_npu' not found. Ascend NPU support will be unavailable.")

    # Load the project config once and share it between the wandb and
    # tensorboard checks. `config` is bound before the try so a failed import
    # or get_config() call cannot leave it undefined for the later checks.
    config = None
    try:
        from utils.config_loader import get_config
        config = get_config()
    except Exception as exc:
        logger.debug("Config not available for optional dependency checks: %s", exc)

    def _config_value(key, default):
        """Return ``config.get(key, default)``; *default* if there is no config or the lookup fails."""
        if config is None:
            return default
        try:
            return config.get(key, default)  # type: ignore[arg-type]
        except Exception as exc:
            logger.debug("Could not read config key %r: %s", key, exc)
            return default

    # Weights & Biases logging.
    if _config_value('logging.use_wandb', False):
        try:
            import wandb  # noqa: F401
        except ImportError:
            logger.warning("Optional dependency 'wandb' not found. Weights & Biases logging will be unavailable.")
        except Exception as exc:
            # wandb is installed but fails to import; report it instead of
            # hiding it, but keep the package import best-effort.
            logger.warning(
                "Optional dependency 'wandb' failed to import (%s). Weights & Biases logging will be unavailable.",
                exc,
            )

    # TensorBoard logging.
    if _config_value('logging.tensorboard_dir', None):
        try:
            import tensorboard  # noqa: F401
        except ImportError:
            logger.warning("Optional dependency 'tensorboard' not found. TensorBoard logging will be unavailable.")

    # pytest, only when the current process looks like a test run.
    if any('pytest' in arg for arg in sys.argv):
        try:
            import pytest  # noqa: F401
        except ImportError:
            logger.warning("Optional dependency 'pytest' not found. Some tests may not run.")
|
|
|
|
|
# Validate the environment as a side effect of importing the package.
_check_dependencies()
_check_optional_dependencies()

# Public API: re-export the configuration classes, models and logging helpers
# at package level.
try:
    from .config import BaseConfig, BGEM3Config, BGERerankerConfig
    from .models import BGEM3Model, BGERerankerModel
    from .utils import setup_logging, get_logger
except ImportError as exc:
    # Tolerate missing submodules during development: warn and export nothing.
    import warnings

    warnings.warn(f"Some components not available: {exc}")
    __all__ = []
else:
    __all__ = [
        'BaseConfig',
        'BGEM3Config',
        'BGERerankerConfig',
        'BGEM3Model',
        'BGERerankerModel',
        'setup_logging',
        'get_logger',
    ]
|