402 lines
13 KiB
Python
402 lines
13 KiB
Python
"""
|
||
Setup script for BGE fine-tuning environment
|
||
"""
|
||
import argparse
import importlib
import json
import logging
import os
import shlex
import shutil
import subprocess
import sys
import urllib.request
from pathlib import Path

from utils.config_loader import get_config


# Module-wide logger used by every setup step below. basicConfig gives the
# root logger a handler at INFO level (it is a no-op if the root logger is
# already configured), so progress messages are visible on the console.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def run_command(command, check=True):
    """Run an external command and return its ``CompletedProcess``.

    Args:
        command: Either a command string, which is executed through the shell
            (as before), or a sequence of arguments, which is executed
            directly without any shell parsing, so arguments need no quoting.
            With a sequence, a missing executable raises FileNotFoundError.
        check: When true, a non-zero exit status is logged and the script is
            terminated with exit status 1.

    Returns:
        The ``subprocess.CompletedProcess``; ``stdout`` and ``stderr`` are
        captured as text.
    """
    use_shell = isinstance(command, str)
    # Render argument lists as a shell-quoted string so the log line can be
    # copied and re-run by hand.
    display = command if use_shell else shlex.join(str(arg) for arg in command)
    logger.info(f"Running: {display}")
    result = subprocess.run(command, shell=use_shell, capture_output=True, text=True)

    if check and result.returncode != 0:
        logger.error(f"Command failed: {display}")
        # Some tools write their diagnostics to stdout; fall back to it so the
        # failure reason is not lost when stderr is empty.
        logger.error(f"Error: {result.stderr or result.stdout}")
        sys.exit(1)

    return result


def check_python_version():
    """Log the interpreter version, or exit with status 1 if it is older than 3.8."""
    minimum = (3, 8)
    if sys.version_info >= minimum:
        logger.info(f"Python version: {sys.version}")
        return

    logger.error("Python 3.8 or higher is required")
    sys.exit(1)


def check_gpu():
    """Return ``"cuda"`` if an NVIDIA GPU is detected, otherwise ``"cpu"``.

    Detection relies on ``nvidia-smi``: it must be on ``PATH`` and exit with
    status 0.
    """
    # Without this check a missing binary only shows up as the shell's
    # "command not found" exit status and is misreported as "no GPU".
    if shutil.which("nvidia-smi") is None:
        logger.info("nvidia-smi not available")
        return "cpu"

    try:
        result = run_command("nvidia-smi", check=False)
    except (OSError, subprocess.SubprocessError):
        # Narrowed from a bare ``except:`` so Ctrl-C and SystemExit propagate.
        logger.info("nvidia-smi not available")
        return "cpu"

    if result.returncode == 0:
        logger.info("NVIDIA GPU detected")
        return "cuda"

    logger.info("No NVIDIA GPU detected")
    return "cpu"


def install_pytorch(device_type="cuda"):
    """Install PyTorch with pip and log details of the installed build.

    Args:
        device_type: ``"cuda"`` installs the nightly CUDA 12.8 wheels
            (selected for RTX 5090 support); any other value installs the
            CPU-only wheels.

    Exits the script with status 1 if pip fails (via ``run_command``) or if
    ``torch`` cannot be imported afterwards.
    """
    logger.info("Installing PyTorch...")

    # Run pip through the current interpreter so the wheels land in the same
    # environment that imports torch below; a bare ``pip3`` on PATH may belong
    # to a different Python. The double quotes keep an interpreter path that
    # contains spaces intact in both POSIX sh and cmd.exe.
    pip_install = f'"{sys.executable}" -m pip install'
    if device_type == "cuda":
        # Nightly CUDA 12.8 wheels, chosen for RTX 5090 support.
        pip_command = (
            f"{pip_install} --pre torch torchvision torchaudio "
            "--index-url https://download.pytorch.org/whl/nightly/cu128"
        )
    else:
        pip_command = (
            f"{pip_install} torch torchvision torchaudio "
            "--index-url https://download.pytorch.org/whl/cpu"
        )

    run_command(pip_command)

    # The import system caches directory listings; reset them so a package
    # installed after interpreter start-up can be found.
    importlib.invalidate_caches()
    try:
        import torch
    except ImportError:
        logger.error("Failed to import PyTorch")
        sys.exit(1)

    logger.info(f"PyTorch version: {torch.__version__}")
    cuda_available = torch.cuda.is_available()
    logger.info(f"CUDA available: {cuda_available}")
    if cuda_available:
        logger.info(f"CUDA version: {torch.version.cuda}")
        gpu_count = torch.cuda.device_count()
        logger.info(f"GPU count: {gpu_count}")
        for i in range(gpu_count):
            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")


def install_requirements():
    """Install the project's Python dependencies with pip.

    Uses ``requirements.txt`` from the directory above this script when it
    exists; otherwise installs a built-in list of core packages. A pip failure
    terminates the script (``run_command`` is called with its default
    ``check=True``).
    """
    logger.info("Installing requirements...")

    # Use the interpreter running this script so packages go into its
    # environment. Every argument is double-quoted: without quoting, the shell
    # parses a specifier such as ``transformers>=4.36.0`` as an output
    # redirection (installing ``transformers`` unconstrained and creating a
    # file named ``=4.36.0``), and paths containing spaces are split. Double
    # quotes have this meaning in both POSIX sh and cmd.exe.
    pip_install = f'"{sys.executable}" -m pip install'
    requirements_file = Path(__file__).parent.parent / "requirements.txt"

    if requirements_file.exists():
        run_command(f'{pip_install} -r "{requirements_file}"')
        return

    logger.warning("requirements.txt not found, installing core packages...")

    # Core packages
    # NOTE(review): faiss-gpu is requested even for --cpu-only setups;
    # confirm whether faiss-cpu should be installed in that case.
    packages = [
        "transformers>=4.36.0",
        "datasets>=2.14.0",
        "tokenizers>=0.15.0",
        "numpy>=1.24.0",
        "scipy>=1.10.0",
        "scikit-learn>=1.3.0",
        "pandas>=2.0.0",
        "FlagEmbedding>=1.2.0",
        "faiss-gpu>=1.7.4",
        "rank-bm25>=0.2.2",
        "accelerate>=0.25.0",
        "tensorboard>=2.14.0",
        "tqdm>=4.66.0",
        "pyyaml>=6.0",
        "jsonlines>=3.1.0",
        "huggingface-hub>=0.19.0",
        "jieba>=0.42.1",
        "toml>=0.10.2",
        "requests>=2.28.0",
    ]

    # A single pip invocation lets the resolver choose mutually compatible
    # versions instead of resolving each package in isolation.
    quoted_specs = " ".join(f'"{package}"' for package in packages)
    run_command(f"{pip_install} {quoted_specs}")


def setup_directories():
    """Create the project's working directories under the project root.

    Already-existing directories are left alone (``exist_ok=True``), so the
    step can be re-run safely.
    """
    logger.info("Setting up directories...")

    project_root = Path(__file__).parent.parent
    relative_dirs = (
        "models",
        "models/bge-m3",
        "models/bge-reranker-base",
        "data/raw",
        "data/processed",
        "output",
        "logs",
        "cache/data",
    )

    for rel in relative_dirs:
        target = project_root / rel
        target.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created directory: {target}")


def download_models(download_models_flag=False):
    """Download the base BGE models into ``models/`` when requested.

    Only a hint is logged unless *download_models_flag* is true. The hub is
    read from the ``model_paths.source`` config value, defaulting to
    ``'huggingface'``; ``'modelscope'`` is the only other accepted value.
    Failures during the download itself are logged rather than raised.
    """
    if download_models_flag:
        logger.info("Downloading base models...")
        project_root = Path(__file__).parent.parent
        source = get_config().get('model_paths.source', 'huggingface')

        if source == 'huggingface':
            _download_models_huggingface(project_root)
        elif source == 'modelscope':
            _download_models_modelscope(project_root)
        else:
            logger.error(f"Unknown model source: {source}. Supported: 'huggingface', 'modelscope'.")
    else:
        logger.info("Skipping model download. Use --download-models to download base models.")


def _download_models_huggingface(project_root):
    """Fetch BGE-M3 and BGE-Reranker from the Hugging Face Hub into *project_root*/models."""
    # (log message, repository id, destination relative to the project root)
    downloads = (
        ("Downloading BGE-M3 from HuggingFace...", "BAAI/bge-m3", "models/bge-m3"),
        ("Downloading BGE-Reranker from HuggingFace...", "BAAI/bge-reranker-base",
         "models/bge-reranker-base"),
    )
    try:
        from huggingface_hub import snapshot_download

        for message, repo_id, subdir in downloads:
            logger.info(message)
            snapshot_download(
                repo_id=repo_id,
                local_dir=project_root / subdir,
                local_dir_use_symlinks=False,
            )
        logger.info("Models downloaded successfully from HuggingFace!")
    except ImportError:
        logger.error("huggingface_hub not available. Install it first: pip install huggingface_hub")
    except Exception as e:
        logger.error(f"Failed to download models from HuggingFace: {e}")
        logger.info("You can download models manually or use the Hugging Face CLI:")
        logger.info(" huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3")
        logger.info(" huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base")


def _download_models_modelscope(project_root):
    """Fetch the two base models from ModelScope into *project_root*/models."""
    # NOTE(review): these IDs name CoROM models in the ``damo`` namespace, not
    # the BAAI BGE checkpoints the log messages mention; confirm they are the
    # intended models. ``cache_dir`` may also place files in a per-model-id
    # subdirectory rather than directly in models/<name>; verify that the
    # downstream paths still resolve.
    downloads = (
        ("Downloading BGE-M3 from ModelScope...",
         "damo/nlp_corom_sentence-embedding_chinese-base", "models/bge-m3"),
        ("Downloading BGE-Reranker from ModelScope...",
         "damo/nlp_corom_cross-encoder_chinese-base", "models/bge-reranker-base"),
    )
    try:
        from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download

        for message, model_id, subdir in downloads:
            logger.info(message)
            ms_snapshot_download(
                model_id=model_id,
                cache_dir=str(project_root / subdir),
            )
        logger.info("Models downloaded successfully from ModelScope!")
    except ImportError:
        logger.error("modelscope not available. Install it first: pip install modelscope")
    except Exception as e:
        logger.error(f"Failed to download models from ModelScope: {e}")
        logger.info("You can download models manually from ModelScope website or CLI.")


def create_sample_config():
    """Write a default ``config.toml`` at the project root if none exists.

    An existing ``config.toml`` is never overwritten.
    """
    logger.info("Creating sample configuration...")

    base_dir = Path(__file__).parent.parent
    config_file = base_dir / "config.toml"

    if config_file.exists():
        return

    # Default config (this would match our config.toml artifact).
    # The commented ``source`` entry documents the ``model_paths.source``
    # setting that download_models() reads, which was previously undiscoverable.
    sample_config = """# Configuration file for BGE fine-tuning project

[api]
# OpenAI-compatible API configuration (e.g., DeepSeek, OpenAI, etc.)
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here" # Replace with your actual API key
model = "deepseek-chat"
timeout = 30

[translation]
# Translation settings
enable_api = true
target_languages = ["en", "zh", "ja", "fr"]
max_retries = 3
fallback_to_simulation = true

[training]
# Training defaults
default_batch_size = 4
default_learning_rate = 1e-5
default_epochs = 3
default_warmup_ratio = 0.1

[model_paths]
# Default model paths
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
# Model hub used by --download-models: "huggingface" (default) or "modelscope"
# source = "huggingface"

[data]
# Data processing settings
max_query_length = 512
max_passage_length = 8192
min_passage_length = 10
validation_split = 0.1
"""

    # TOML documents are UTF-8; do not depend on the platform's default encoding.
    config_file.write_text(sample_config, encoding='utf-8')

    logger.info(f"Created sample config: {config_file}")
    logger.info("Please edit config.toml to add your API keys if needed.")


def create_sample_data():
    """Write ``data/raw/sample_data.jsonl`` with three demo records if it is missing.

    Each line is one JSON object holding a ``query`` string plus ``pos`` and
    ``neg`` lists of passages; non-ASCII text is written as-is in UTF-8.
    """
    logger.info("Creating sample data...")

    target = Path(__file__).parent.parent / "data/raw/sample_data.jsonl"
    if target.exists():
        return

    records = [
        {
            "query": "什么是机器学习?",
            "pos": ["机器学习是人工智能的一个分支,它使用算法让计算机能够从数据中学习和做出决策。"],
            "neg": ["深度学习是神经网络的一种形式。", "自然语言处理是计算机科学的一个领域。"],
        },
        {
            "query": "How does neural network work?",
            "pos": [
                "Neural networks are computing systems inspired by biological neural networks that constitute the brain."],
            "neg": ["Machine learning algorithms can be supervised or unsupervised.",
                    "Data preprocessing is an important step in ML."],
        },
        {
            "query": "Python编程语言的特点",
            "pos": ["Python是一种高级编程语言,具有简洁的语法和强大的功能,广泛用于数据科学和AI开发。"],
            "neg": ["JavaScript是一种脚本语言。", "Java是一种面向对象的编程语言。"],
        },
    ]

    jsonl_lines = [json.dumps(record, ensure_ascii=False) + '\n' for record in records]
    with open(target, 'w', encoding='utf-8') as handle:
        handle.writelines(jsonl_lines)

    logger.info(f"Created sample data: {target}")


def verify_installation():
    """Smoke-test the installed environment.

    The checks are:
    - import the core packages;
    - when CUDA is visible, run a small GPU computation;
    - if ``models/bge-m3`` contains files, load its tokenizer.

    Exits with status 1 if a core package cannot be imported. GPU and model
    problems are reported in the log but do not abort.
    """
    logger.info("Verifying installation...")

    try:
        import torch
        import transformers  # noqa: F401  (import check only)
        import numpy as np  # noqa: F401
        import pandas as pd  # noqa: F401
    except ImportError as e:
        logger.error(f"Import error: {e}")
        logger.error("Installation may be incomplete")
        sys.exit(1)

    logger.info("✓ Core packages imported successfully")

    if torch.cuda.is_available():
        logger.info("✓ CUDA is available")
        try:
            x = torch.randn(10, 10).to(torch.device("cuda"))
            # A host-to-device copy alone does not launch a compute kernel;
            # run a matmul and synchronize so kernel errors surface here.
            _ = x @ x
            torch.cuda.synchronize()
            logger.info("✓ GPU tensor operations work")
        except RuntimeError as e:
            logger.error(f"GPU tensor operations failed: {e}")
            logger.error("The installed PyTorch build may not support this GPU")
    else:
        logger.info("✓ CPU-only setup")

    # setup_directories() always creates models/bge-m3, so test for actual
    # files; an empty directory means the models were not downloaded.
    model_dir = Path(__file__).parent.parent / "models/bge-m3"
    if model_dir.is_dir() and any(model_dir.iterdir()):
        try:
            from transformers import AutoTokenizer

            AutoTokenizer.from_pretrained(str(model_dir))
            logger.info("✓ BGE-M3 model files are accessible")
        except Exception as e:
            logger.warning(f"BGE-M3 model check failed: {e}")

    logger.info("Installation verification completed!")


def print_usage_info():
    """Print the closing banner with quick-start instructions to stdout."""
    rule = "=" * 60
    banner_lines = (
        "\n" + rule,
        "BGE FINE-TUNING SETUP COMPLETED!",
        rule,
        "\nQuick Start:",
        "1. Edit config.toml to add your API keys (optional)",
        "2. Prepare your training data in JSONL format",
        "3. Run training:",
        " python scripts/train_m3.py --train_data data/raw/your_data.jsonl",
        " python scripts/train_reranker.py --train_data data/raw/your_data.jsonl",
        "4. Evaluate models:",
        " python scripts/validate.py quick --retriever_model output/model",
        "\nSample data created in: data/raw/sample_data.jsonl",
        "Configuration file: config.toml",
        "\nFor more information, see the README.md file.",
        rule,
    )
    for text in banner_lines:
        print(text)


def main():
    """Parse command-line flags and run the setup steps in order."""
    parser = argparse.ArgumentParser(description="Setup BGE fine-tuning environment")
    parser.add_argument("--download-models", action="store_true",
                        help="Download base models (hub chosen by model_paths.source "
                             "in config.toml; default: Hugging Face)")
    parser.add_argument("--cpu-only", action="store_true",
                        help="Install CPU-only version of PyTorch")
    parser.add_argument("--skip-pytorch", action="store_true",
                        help="Skip PyTorch installation")
    parser.add_argument("--skip-requirements", action="store_true",
                        help="Skip requirements installation")

    args = parser.parse_args()

    logger.info("Starting BGE fine-tuning environment setup...")

    check_python_version()

    # --cpu-only skips GPU probing entirely.
    device_type = "cpu" if args.cpu_only else check_gpu()

    if not args.skip_pytorch:
        install_pytorch(device_type)

    if not args.skip_requirements:
        install_requirements()

    setup_directories()

    # Write config.toml before downloading: download_models() reads the model
    # source from the config, and on a fresh setup the file does not exist
    # until create_sample_config() writes it.
    create_sample_config()
    download_models(args.download_models)
    create_sample_data()

    verify_installation()

    print_usage_info()


if __name__ == "__main__":
    main()