bge_finetune/__init__.py
2025-07-22 16:55:25 +08:00

106 lines
3.2 KiB
Python

"""BGE Fine-tuning Package.

A comprehensive implementation for fine-tuning BAAI BGE models.

Importing this package runs the mandatory dependency checks (PyTorch and
transformers), logs warnings for optional extras that appear to be needed but
are missing, and re-exports the main config, model and logging helpers when
they can be imported.
"""
__author__ = "ldy"
__version__ = "1.0.0"
import logging
import os
import sys
import torch
# Package-level logger; the optional-dependency checks report missing extras
# through it as warnings.
logger = logging.getLogger(__name__)
# Mandatory dependency check
def _check_dependencies():
"""Ensure core dependencies are available."""
try:
_ = torch.__version__
except Exception as exc: # pragma: no cover - should not happen in tests
raise ImportError("PyTorch is required for bge_finetune") from exc
try:
import transformers # noqa: F401
except Exception as exc: # pragma: no cover - should not happen in tests
raise ImportError("transformers is required for bge_finetune") from exc
# Import guard for optional dependencies
def _check_optional_dependencies():
"""Check for optional dependencies and log warnings if missing when features are likely to be used."""
# Check for torch_npu if running on Ascend
try:
npu_available = hasattr(torch, 'npu') and torch.npu.is_available()
except Exception:
npu_available = False
if npu_available or os.environ.get('DEVICE', '').lower() == 'npu':
try:
import torch_npu
except ImportError:
logger.warning("Optional dependency 'torch_npu' not found. Ascend NPU support will be unavailable.")
# Check for wandb if Weights & Biases logging is enabled
try:
from utils.config_loader import get_config
config = get_config()
use_wandb = False
try:
use_wandb = config.get('logging.use_wandb', False) # type: ignore[arg-type]
except Exception:
pass
if use_wandb:
try:
import wandb
except ImportError:
logger.warning("Optional dependency 'wandb' not found. Weights & Biases logging will be unavailable.")
except Exception:
pass
# Check for tensorboard if tensorboard logging is likely to be used
tensorboard_dir = None
try:
tensorboard_dir = config.get('logging.tensorboard_dir', None) # type: ignore[arg-type]
except Exception:
pass
if tensorboard_dir:
try:
import tensorboard
except ImportError:
logger.warning("Optional dependency 'tensorboard' not found. TensorBoard logging will be unavailable.")
# Check for pytest if running tests
if any('pytest' in arg for arg in sys.argv):
try:
import pytest
except ImportError:
logger.warning("Optional dependency 'pytest' not found. Some tests may not run.")
# Run the dependency checks at import time. Order matters: the mandatory
# check raises ImportError, aborting the package import, before the optional
# checks and the submodule re-exports below run.
_check_dependencies()
_check_optional_dependencies()
# Re-export the primary public API at package level.
try:
    from .config import BaseConfig, BGEM3Config, BGERerankerConfig
    from .models import BGEM3Model, BGERerankerModel
    from .utils import setup_logging, get_logger
except ImportError as exc:
    # Tolerate missing or broken submodules during development: warn with the
    # first import error and export nothing via ``__all__``. Names imported
    # before the failure stay bound on the package.
    import warnings

    warnings.warn(f"Some components not available: {exc}")
    __all__ = []
else:
    __all__ = [
        # configuration classes
        'BaseConfig',
        'BGEM3Config',
        'BGERerankerConfig',
        # model classes
        'BGEM3Model',
        'BGERerankerModel',
        # logging helpers
        'setup_logging',
        'get_logger',
    ]