init commit
This commit is contained in:
105
__init__.py
Normal file
105
__init__.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
BGE Fine-tuning Package
|
||||
A comprehensive implementation for fine-tuning BAAI BGE models
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "ldy"
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mandatory dependency check
|
||||
def _check_dependencies():
|
||||
"""Ensure core dependencies are available."""
|
||||
try:
|
||||
_ = torch.__version__
|
||||
except Exception as exc: # pragma: no cover - should not happen in tests
|
||||
raise ImportError("PyTorch is required for bge_finetune") from exc
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
except Exception as exc: # pragma: no cover - should not happen in tests
|
||||
raise ImportError("transformers is required for bge_finetune") from exc
|
||||
|
||||
# Import guard for optional dependencies
|
||||
|
||||
def _check_optional_dependencies():
|
||||
"""Check for optional dependencies and log warnings if missing when features are likely to be used."""
|
||||
# Check for torch_npu if running on Ascend
|
||||
try:
|
||||
npu_available = hasattr(torch, 'npu') and torch.npu.is_available()
|
||||
except Exception:
|
||||
npu_available = False
|
||||
if npu_available or os.environ.get('DEVICE', '').lower() == 'npu':
|
||||
try:
|
||||
import torch_npu
|
||||
except ImportError:
|
||||
logger.warning("Optional dependency 'torch_npu' not found. Ascend NPU support will be unavailable.")
|
||||
|
||||
# Check for wandb if Weights & Biases logging is enabled
|
||||
try:
|
||||
from utils.config_loader import get_config
|
||||
config = get_config()
|
||||
use_wandb = False
|
||||
try:
|
||||
use_wandb = config.get('logging.use_wandb', False) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
pass
|
||||
if use_wandb:
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
logger.warning("Optional dependency 'wandb' not found. Weights & Biases logging will be unavailable.")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check for tensorboard if tensorboard logging is likely to be used
|
||||
tensorboard_dir = None
|
||||
try:
|
||||
tensorboard_dir = config.get('logging.tensorboard_dir', None) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
pass
|
||||
if tensorboard_dir:
|
||||
try:
|
||||
import tensorboard
|
||||
except ImportError:
|
||||
logger.warning("Optional dependency 'tensorboard' not found. TensorBoard logging will be unavailable.")
|
||||
|
||||
# Check for pytest if running tests
|
||||
if any('pytest' in arg for arg in sys.argv):
|
||||
try:
|
||||
import pytest
|
||||
except ImportError:
|
||||
logger.warning("Optional dependency 'pytest' not found. Some tests may not run.")
|
||||
|
||||
|
||||
# Check dependencies on import
|
||||
_check_dependencies()
|
||||
_check_optional_dependencies()
|
||||
|
||||
# Make key components available at package level
|
||||
try:
|
||||
from .config import BaseConfig, BGEM3Config, BGERerankerConfig
|
||||
from .models import BGEM3Model, BGERerankerModel
|
||||
from .utils import setup_logging, get_logger
|
||||
|
||||
__all__ = [
|
||||
'BaseConfig',
|
||||
'BGEM3Config',
|
||||
'BGERerankerConfig',
|
||||
'BGEM3Model',
|
||||
'BGERerankerModel',
|
||||
'setup_logging',
|
||||
'get_logger'
|
||||
]
|
||||
except ImportError as e:
|
||||
# Allow partial imports during development
|
||||
import warnings
|
||||
|
||||
warnings.warn(f"Some components not available: {e}")
|
||||
__all__ = []
|
||||
Reference in New Issue
Block a user