"""
Setup script for BGE fine-tuning environment
"""
import sys
import subprocess
import argparse
import json
import logging
from pathlib import Path

from utils.config_loader import get_config

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def run_command(command, check=True):
    """Run a shell command"""
    logger.info(f"Running: {command}")
    result = subprocess.run(command, shell=True, capture_output=True, text=True)
    if check and result.returncode != 0:
        logger.error(f"Command failed: {command}")
        logger.error(f"Error: {result.stderr}")
        sys.exit(1)
    return result

def check_python_version():
    """Check Python version"""
    if sys.version_info < (3, 8):
        logger.error("Python 3.8 or higher is required")
        sys.exit(1)
    logger.info(f"Python version: {sys.version}")

def check_gpu():
    """Check GPU availability via nvidia-smi"""
    try:
        result = run_command("nvidia-smi", check=False)
        if result.returncode == 0:
            logger.info("NVIDIA GPU detected")
            return "cuda"
        logger.info("No NVIDIA GPU detected")
        return "cpu"
    except Exception:
        logger.info("nvidia-smi not available")
        return "cpu"

def install_pytorch(device_type="cuda"):
    """Install PyTorch"""
    logger.info("Installing PyTorch...")
    if device_type == "cuda":
        # RTX 5090 (Blackwell) requires CUDA 12.8 builds; at the time of
        # writing these ship via the PyTorch nightly index.
        pip_command = "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128"
    else:
        # Install CPU-only build
        pip_command = "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu"
    run_command(pip_command)
    # Verify installation by importing the freshly installed package
    try:
        import torch
        logger.info(f"PyTorch version: {torch.__version__}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            logger.info(f"CUDA version: {torch.version.cuda}")
            logger.info(f"GPU count: {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    except ImportError:
        logger.error("Failed to import PyTorch")
        sys.exit(1)

def install_requirements():
    """Install requirements"""
    logger.info("Installing requirements...")
    requirements_file = Path(__file__).parent.parent / "requirements.txt"
    if requirements_file.exists():
        run_command(f"pip install -r {requirements_file}")
    else:
        logger.warning("requirements.txt not found, installing core packages...")
        # Core packages
        packages = [
            "transformers>=4.36.0",
            "datasets>=2.14.0",
            "tokenizers>=0.15.0",
            "numpy>=1.24.0",
            "scipy>=1.10.0",
            "scikit-learn>=1.3.0",
            "pandas>=2.0.0",
            "FlagEmbedding>=1.2.0",
            # NOTE: faiss-gpu wheels on PyPI are unofficial and may not resolve
            # for newer Python/CUDA combinations; the FAISS project recommends
            # conda if this pip install fails.
            "faiss-gpu>=1.7.4",
            "rank-bm25>=0.2.2",
            "accelerate>=0.25.0",
            "tensorboard>=2.14.0",
            "tqdm>=4.66.0",
            "pyyaml>=6.0",
            "jsonlines>=3.1.0",
            "huggingface-hub>=0.19.0",
            "jieba>=0.42.1",
            "toml>=0.10.2",
            "requests>=2.28.0",
        ]
        for package in packages:
            run_command(f"pip install {package}")

def setup_directories():
    """Setup project directories"""
    logger.info("Setting up directories...")
    base_dir = Path(__file__).parent.parent
    directories = [
        "models",
        "models/bge-m3",
        "models/bge-reranker-base",
        "data/raw",
        "data/processed",
        "output",
        "logs",
        "cache/data",
    ]
    for directory in directories:
        dir_path = base_dir / directory
        dir_path.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created directory: {dir_path}")

def download_models(download_models_flag=False):
    """Download base models from HuggingFace or ModelScope based on config.toml"""
    if not download_models_flag:
        logger.info("Skipping model download. Use --download-models to download base models.")
        return
    logger.info("Downloading base models...")
    base_dir = Path(__file__).parent.parent
    config = get_config()
    model_source = config.get('model_paths.source', 'huggingface')
    if model_source == 'huggingface':
        try:
            from huggingface_hub import snapshot_download
            # Download BGE-M3
            logger.info("Downloading BGE-M3 from HuggingFace...")
            snapshot_download(
                repo_id="BAAI/bge-m3",
                local_dir=base_dir / "models/bge-m3",
                local_dir_use_symlinks=False  # deprecated/ignored on newer huggingface_hub releases
            )
            # Download BGE-Reranker
            logger.info("Downloading BGE-Reranker from HuggingFace...")
            snapshot_download(
                repo_id="BAAI/bge-reranker-base",
                local_dir=base_dir / "models/bge-reranker-base",
                local_dir_use_symlinks=False
            )
            logger.info("Models downloaded successfully from HuggingFace!")
        except ImportError:
            logger.error("huggingface_hub not available. Install it first: pip install huggingface_hub")
        except Exception as e:
            logger.error(f"Failed to download models from HuggingFace: {e}")
            logger.info("You can download models manually or use the Hugging Face CLI:")
            logger.info("  huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3")
            logger.info("  huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base")
    elif model_source == 'modelscope':
        try:
            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
            # Download BGE-M3
            logger.info("Downloading BGE-M3 from ModelScope...")
            ms_snapshot_download(
                model_id="damo/nlp_corom_sentence-embedding_chinese-base",
                cache_dir=str(base_dir / "models/bge-m3")
            )
            # Download BGE-Reranker
            logger.info("Downloading BGE-Reranker from ModelScope...")
            ms_snapshot_download(
                model_id="damo/nlp_corom_cross-encoder_chinese-base",
                cache_dir=str(base_dir / "models/bge-reranker-base")
            )
            logger.info("Models downloaded successfully from ModelScope!")
        except ImportError:
            logger.error("modelscope not available. Install it first: pip install modelscope")
        except Exception as e:
            logger.error(f"Failed to download models from ModelScope: {e}")
            logger.info("You can download models manually from the ModelScope website or CLI.")
    else:
        logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")

def create_sample_config():
    """Create sample configuration files"""
    logger.info("Creating sample configuration...")
    base_dir = Path(__file__).parent.parent
    config_file = base_dir / "config.toml"
    if not config_file.exists():
        # Create default config (this would match our config.toml artifact)
        sample_config = """# Configuration file for BGE fine-tuning project

[api]
# OpenAI-compatible API configuration (e.g., DeepSeek, OpenAI, etc.)
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here"  # Replace with your actual API key
model = "deepseek-chat"
timeout = 30

[translation]
# Translation settings
enable_api = true
target_languages = ["en", "zh", "ja", "fr"]
max_retries = 3
fallback_to_simulation = true

[training]
# Training defaults
default_batch_size = 4
default_learning_rate = 1e-5
default_epochs = 3
default_warmup_ratio = 0.1

[model_paths]
# Default model paths; `source` is read by download_models() in the setup script
source = "huggingface"  # or "modelscope"
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"

[data]
# Data processing settings
max_query_length = 512
max_passage_length = 8192
min_passage_length = 10
validation_split = 0.1
"""
        with open(config_file, 'w') as f:
            f.write(sample_config)
        logger.info(f"Created sample config: {config_file}")
        logger.info("Please edit config.toml to add your API keys if needed.")

def create_sample_data():
    """Create sample data files"""
    logger.info("Creating sample data...")
    base_dir = Path(__file__).parent.parent
    sample_file = base_dir / "data/raw/sample_data.jsonl"
    if not sample_file.exists():
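        # Each record is a {query, pos, neg} triplet in the JSONL layout used
        # by FlagEmbedding-style fine-tuning ("pos"/"neg" are lists of
        # passages; negatives should be plausible but non-matching).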
        sample_data = [
            {
                "query": "什么是机器学习?",
                "pos": ["机器学习是人工智能的一个分支,它使用算法让计算机能够从数据中学习和做出决策。"],
                "neg": ["深度学习是神经网络的一种形式。", "自然语言处理是计算机科学的一个领域。"]
            },
            {
                "query": "How does a neural network work?",
                "pos": ["Neural networks are computing systems inspired by the biological neural networks that constitute the brain."],
                "neg": ["Machine learning algorithms can be supervised or unsupervised.",
                        "Data preprocessing is an important step in ML."]
            },
            {
                "query": "Python编程语言的特点",
                "pos": ["Python是一种高级编程语言,具有简洁的语法和强大的功能,广泛用于数据科学和AI开发。"],
                "neg": ["JavaScript是一种脚本语言。", "Java是一种面向对象的编程语言。"]
            }
        ]
        with open(sample_file, 'w', encoding='utf-8') as f:
            for item in sample_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        logger.info(f"Created sample data: {sample_file}")

def verify_installation():
    """Verify installation"""
    logger.info("Verifying installation...")
    try:
        # Test imports
        import torch
        import transformers
        import numpy as np
        import pandas as pd
        logger.info("✓ Core packages imported successfully")
        # Test GPU
        if torch.cuda.is_available():
            logger.info("✓ CUDA is available")
            device = torch.device("cuda")
            x = torch.randn(10, 10).to(device)
            logger.info("✓ GPU tensor operations work")
        else:
            logger.info("✓ CPU-only setup")
        # Test model loading (if models exist)
        base_dir = Path(__file__).parent.parent
        if (base_dir / "models/bge-m3").exists():
            try:
                from transformers import AutoTokenizer
                tokenizer = AutoTokenizer.from_pretrained(str(base_dir / "models/bge-m3"))
                logger.info("✓ BGE-M3 model files are accessible")
            except Exception as e:
                logger.warning(f"BGE-M3 model check failed: {e}")
        logger.info("Installation verification completed!")
    except ImportError as e:
        logger.error(f"Import error: {e}")
        logger.error("Installation may be incomplete")
        sys.exit(1)

def print_usage_info():
    """Print usage information"""
    print("\n" + "=" * 60)
    print("BGE FINE-TUNING SETUP COMPLETED!")
    print("=" * 60)
    print("\nQuick Start:")
    print("1. Edit config.toml to add your API keys (optional)")
    print("2. Prepare your training data in JSONL format")
    print("3. Run training:")
    print("   python scripts/train_m3.py --train_data data/raw/your_data.jsonl")
    print("   python scripts/train_reranker.py --train_data data/raw/your_data.jsonl")
    print("4. Evaluate models:")
    print("   python scripts/validate.py quick --retriever_model output/model")
    print("\nSample data created in: data/raw/sample_data.jsonl")
    print("Configuration file: config.toml")
    print("\nFor more information, see the README.md file.")
    print("=" * 60)

def main():
    parser = argparse.ArgumentParser(description="Setup BGE fine-tuning environment")
    parser.add_argument("--download-models", action="store_true",
                        help="Download base models from Hugging Face or ModelScope (per config.toml)")
    parser.add_argument("--cpu-only", action="store_true",
                        help="Install CPU-only version of PyTorch")
    parser.add_argument("--skip-pytorch", action="store_true",
                        help="Skip PyTorch installation")
    parser.add_argument("--skip-requirements", action="store_true",
                        help="Skip requirements installation")
    args = parser.parse_args()
    logger.info("Starting BGE fine-tuning environment setup...")
    # Check Python version
    check_python_version()
    # Detect device type
    if args.cpu_only:
        device_type = "cpu"
    else:
        device_type = check_gpu()
    # Install PyTorch
    if not args.skip_pytorch:
        install_pytorch(device_type)
    # Install requirements
    if not args.skip_requirements:
        install_requirements()
    # Setup directories
    setup_directories()
    # Download models
    download_models(args.download_models)
    # Create sample files
    create_sample_config()
    create_sample_data()
    # Verify installation
    verify_installation()
    # Print usage info
    print_usage_info()

if __name__ == "__main__":
main()