init commit
This commit is contained in:
401
scripts/setup.py
Normal file
401
scripts/setup.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
Setup script for BGE fine-tuning environment
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import argparse
|
||||
import shutil
|
||||
import urllib.request
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from utils.config_loader import get_config
|
||||
|
||||
# Configure the root logger once so every setup step reports at INFO level.
logging.basicConfig(
    level=logging.INFO,
)
# Module-level logger shared by all helpers in this script.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_command(command, check=True):
    """Run an external command and return its completed-process result.

    Args:
        command: Either a shell command line (``str``), which is executed
            through the shell exactly as before, or a sequence of program
            arguments, which is executed directly without a shell so that
            arguments containing spaces or shell metacharacters need no
            quoting.
        check: When true, a non-zero exit status is logged together with the
            captured stderr and the script terminates with exit code 1.

    Returns:
        The ``subprocess.CompletedProcess`` holding the captured text
        ``stdout`` and ``stderr``.

    Raises:
        OSError: If ``command`` is an argument sequence and the program
            cannot be started (for a shell string the shell reports this
            through the return code instead).
    """
    # Only plain strings go through the shell; argv sequences bypass it.
    use_shell = isinstance(command, str)
    display = command if use_shell else " ".join(str(part) for part in command)

    logger.info(f"Running: {display}")
    result = subprocess.run(command, shell=use_shell, capture_output=True, text=True)

    if check and result.returncode != 0:
        logger.error(f"Command failed: {display}")
        logger.error(f"Error: {result.stderr}")
        sys.exit(1)

    return result
|
||||
|
||||
|
||||
def check_python_version():
    """Abort the setup unless the running interpreter is Python 3.8 or newer.

    Logs the full interpreter version string when the check passes; logs an
    error and exits with status 1 otherwise.
    """
    minimum_version = (3, 8)
    interpreter_ok = sys.version_info >= minimum_version

    if not interpreter_ok:
        logger.error("Python 3.8 or higher is required")
        sys.exit(1)

    logger.info(f"Python version: {sys.version}")
|
||||
|
||||
|
||||
def check_gpu():
    """Detect an NVIDIA GPU by probing the ``nvidia-smi`` command.

    Returns:
        ``"cuda"`` when ``nvidia-smi`` exits with status 0, otherwise
        ``"cpu"`` (including when the probe itself cannot be run).
    """
    try:
        result = run_command("nvidia-smi", check=False)
    except Exception:
        # Best-effort probe: any failure to run the tool means "no GPU".
        # Unlike a bare ``except:``, this lets KeyboardInterrupt and
        # SystemExit propagate so the user can still abort the setup.
        logger.info("nvidia-smi not available")
        return "cpu"

    if result.returncode == 0:
        logger.info("NVIDIA GPU detected")
        return "cuda"

    logger.info("No NVIDIA GPU detected")
    return "cpu"
|
||||
|
||||
|
||||
def install_pytorch(device_type="cuda"):
    """Install PyTorch for the requested device and verify that it imports.

    Args:
        device_type: ``"cuda"`` installs the pre-release nightly wheels from
            the ``cu128`` index; any other value installs the CPU wheels.

    The script exits with status 1 if the pip command fails (handled by
    ``run_command``) or if ``torch`` cannot be imported afterwards.
    """
    logger.info("Installing PyTorch...")

    # Run pip through the interpreter executing this script so the wheels are
    # installed into the same environment that the verification import below
    # uses; a bare ``pip3`` on PATH may belong to a different Python.
    pip = f'"{sys.executable}" -m pip'

    if device_type == "cuda":
        # Install CUDA version for RTX 5090
        pip_command = (
            f"{pip} install --pre torch torchvision torchaudio "
            "--index-url https://download.pytorch.org/whl/nightly/cu128"
        )
    else:
        # Install CPU version
        pip_command = (
            f"{pip} install torch torchvision torchaudio "
            "--index-url https://download.pytorch.org/whl/cpu"
        )

    run_command(pip_command)

    # The package was installed after this process started; refresh the import
    # system's directory caches so the new files can be found.
    import importlib
    importlib.invalidate_caches()

    # Verify installation
    try:
        import torch

        logger.info(f"PyTorch version: {torch.__version__}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            logger.info(f"CUDA version: {torch.version.cuda}")
            logger.info(f"GPU count: {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    except ImportError:
        logger.error("Failed to import PyTorch")
        sys.exit(1)
|
||||
|
||||
|
||||
def install_requirements():
    """Install the project's Python dependencies with pip.

    Uses ``requirements.txt`` located one directory above this script's
    folder when it exists; otherwise installs a built-in list of core
    packages one at a time. Any failing pip command terminates the script
    (handled by ``run_command``).
    """
    logger.info("Installing requirements...")

    # Run pip via the current interpreter so packages land in the environment
    # this script is running in.
    pip = f'"{sys.executable}" -m pip'
    requirements_file = Path(__file__).parent.parent / "requirements.txt"

    if requirements_file.exists():
        # Quote the path so directories containing spaces survive the shell.
        run_command(f'{pip} install -r "{requirements_file}"')
        return

    logger.warning("requirements.txt not found, installing core packages...")

    # Core packages
    packages = [
        "transformers>=4.36.0",
        "datasets>=2.14.0",
        "tokenizers>=0.15.0",
        "numpy>=1.24.0",
        "scipy>=1.10.0",
        "scikit-learn>=1.3.0",
        "pandas>=2.0.0",
        "FlagEmbedding>=1.2.0",
        "faiss-gpu>=1.7.4",
        "rank-bm25>=0.2.2",
        "accelerate>=0.25.0",
        "tensorboard>=2.14.0",
        "tqdm>=4.66.0",
        "pyyaml>=6.0",
        "jsonlines>=3.1.0",
        "huggingface-hub>=0.19.0",
        "jieba>=0.42.1",
        "toml>=0.10.2",
        "requests>=2.28.0",
    ]

    for package in packages:
        # The specifier must be quoted: unquoted, the shell parses ">=4.36.0"
        # as an output redirection, silently dropping the version constraint
        # and creating a stray file named "=4.36.0".
        run_command(f'{pip} install "{package}"')
|
||||
|
||||
|
||||
def setup_directories():
    """Create the project's working directories.

    All directories are created relative to the folder one level above this
    script's own directory; existing directories are left untouched.
    """
    logger.info("Setting up directories...")

    project_root = Path(__file__).parent.parent
    relative_dirs = (
        "models",
        "models/bge-m3",
        "models/bge-reranker-base",
        "data/raw",
        "data/processed",
        "output",
        "logs",
        "cache/data",
    )

    for target in (project_root / relative for relative in relative_dirs):
        target.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created directory: {target}")
|
||||
|
||||
|
||||
def download_models(download_models_flag=False):
    """Download the base models from the source selected in the configuration.

    Args:
        download_models_flag: When false, nothing is downloaded and a hint
            about ``--download-models`` is logged instead.

    The source is read from the ``model_paths.source`` config key and
    defaults to ``'huggingface'``; ``'modelscope'`` is the other supported
    value. Download failures are logged and do not abort the setup.
    """
    if not download_models_flag:
        logger.info("Skipping model download. Use --download-models to download base models.")
        return

    logger.info("Downloading base models...")
    project_root = Path(__file__).parent.parent
    config = get_config()
    model_source = config.get('model_paths.source', 'huggingface')

    if model_source == 'huggingface':
        # (progress message, repository id, target directory) per model,
        # processed in order: BGE-M3 first, then BGE-Reranker.
        hf_jobs = (
            ("Downloading BGE-M3 from HuggingFace...", "BAAI/bge-m3", "models/bge-m3"),
            ("Downloading BGE-Reranker from HuggingFace...", "BAAI/bge-reranker-base", "models/bge-reranker-base"),
        )
        try:
            from huggingface_hub import snapshot_download

            for message, repo, target in hf_jobs:
                logger.info(message)
                snapshot_download(
                    repo_id=repo,
                    local_dir=project_root / target,
                    local_dir_use_symlinks=False
                )
            logger.info("Models downloaded successfully from HuggingFace!")
        except ImportError:
            logger.error("huggingface_hub not available. Install it first: pip install huggingface_hub")
        except Exception as e:
            logger.error(f"Failed to download models from HuggingFace: {e}")
            logger.info("You can download models manually or use the Hugging Face CLI:")
            logger.info(" huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3")
            logger.info(" huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base")
        return

    if model_source == 'modelscope':
        # NOTE(review): the log messages call these BGE-M3 / BGE-Reranker, but
        # the model ids below are CoROM models - confirm this mapping is
        # intended.
        ms_jobs = (
            ("Downloading BGE-M3 from ModelScope...", "damo/nlp_corom_sentence-embedding_chinese-base", "models/bge-m3"),
            ("Downloading BGE-Reranker from ModelScope...", "damo/nlp_corom_cross-encoder_chinese-base", "models/bge-reranker-base"),
        )
        try:
            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download

            for message, model, target in ms_jobs:
                logger.info(message)
                ms_snapshot_download(
                    model_id=model,
                    cache_dir=str(project_root / target)
                )
            logger.info("Models downloaded successfully from ModelScope!")
        except ImportError:
            logger.error("modelscope not available. Install it first: pip install modelscope")
        except Exception as e:
            logger.error(f"Failed to download models from ModelScope: {e}")
            logger.info("You can download models manually from ModelScope website or CLI.")
        return

    logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
|
||||
|
||||
|
||||
def create_sample_config():
    """Write a default ``config.toml`` if the project does not have one yet.

    The file is created one directory above this script's folder. An existing
    ``config.toml`` is never overwritten, and nothing is logged about
    creation in that case.
    """
    logger.info("Creating sample configuration...")

    base_dir = Path(__file__).parent.parent
    config_file = base_dir / "config.toml"

    if config_file.exists():
        return

    # Create default config (this would match our config.toml artifact)
    sample_config = """# Configuration file for BGE fine-tuning project

[api]
# OpenAI-compatible API configuration (e.g., DeepSeek, OpenAI, etc.)
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here" # Replace with your actual API key
model = "deepseek-chat"
timeout = 30

[translation]
# Translation settings
enable_api = true
target_languages = ["en", "zh", "ja", "fr"]
max_retries = 3
fallback_to_simulation = true

[training]
# Training defaults
default_batch_size = 4
default_learning_rate = 1e-5
default_epochs = 3
default_warmup_ratio = 0.1

[model_paths]
# Default model paths
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"

[data]
# Data processing settings
max_query_length = 512
max_passage_length = 8192
min_passage_length = 10
validation_split = 0.1
"""

    # Write with an explicit encoding, like the sample data writer does, so
    # the result does not depend on the platform's default locale encoding.
    with open(config_file, 'w', encoding='utf-8') as f:
        f.write(sample_config)

    logger.info(f"Created sample config: {config_file}")
    logger.info("Please edit config.toml to add your API keys if needed.")
|
||||
|
||||
|
||||
def create_sample_data():
    """Write a small example training file if none exists yet.

    Creates ``data/raw/sample_data.jsonl`` (relative to the directory one
    level above this script's folder) containing three records, one JSON
    object per line, each with the keys ``query``, ``pos`` and ``neg``.
    An existing file is left untouched.
    """
    logger.info("Creating sample data...")

    base_dir = Path(__file__).parent.parent
    sample_file = base_dir / "data/raw/sample_data.jsonl"

    if sample_file.exists():
        return

    sample_data = [
        {
            "query": "什么是机器学习?",
            "pos": ["机器学习是人工智能的一个分支,它使用算法让计算机能够从数据中学习和做出决策。"],
            "neg": ["深度学习是神经网络的一种形式。", "自然语言处理是计算机科学的一个领域。"]
        },
        {
            "query": "How does neural network work?",
            "pos": [
                "Neural networks are computing systems inspired by biological neural networks that constitute the brain."],
            "neg": ["Machine learning algorithms can be supervised or unsupervised.",
                    "Data preprocessing is an important step in ML."]
        },
        {
            "query": "Python编程语言的特点",
            "pos": ["Python是一种高级编程语言,具有简洁的语法和强大的功能,广泛用于数据科学和AI开发。"],
            "neg": ["JavaScript是一种脚本语言。", "Java是一种面向对象的编程语言。"]
        }
    ]

    # Do not rely on another step having created data/raw beforehand;
    # opening the file would fail if the directory were missing.
    sample_file.parent.mkdir(parents=True, exist_ok=True)

    with open(sample_file, 'w', encoding='utf-8') as f:
        for item in sample_data:
            # ensure_ascii=False keeps the Chinese text readable in the file.
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    logger.info(f"Created sample data: {sample_file}")
|
||||
|
||||
|
||||
def verify_installation():
    """Smoke-test the installed environment.

    Imports the core packages, moves a small random tensor to the GPU when
    CUDA is available, and, if ``models/bge-m3`` exists, tries to load its
    tokenizer. A failing core import logs an error and exits with status 1;
    a failing tokenizer load only logs a warning.
    """
    logger.info("Verifying installation...")

    try:
        # Test imports
        import torch
        import transformers
        import numpy as np
        import pandas as pd

        logger.info("✓ Core packages imported successfully")

        # Test GPU
        if not torch.cuda.is_available():
            logger.info("✓ CPU-only setup")
        else:
            logger.info("✓ CUDA is available")
            device = torch.device("cuda")
            x = torch.randn(10, 10).to(device)
            logger.info("✓ GPU tensor operations work")

        # Test model loading (if models exist)
        base_dir = Path(__file__).parent.parent
        bge_m3_dir = base_dir / "models/bge-m3"
        if bge_m3_dir.exists():
            try:
                from transformers import AutoTokenizer
                tokenizer = AutoTokenizer.from_pretrained(str(base_dir / "models/bge-m3"))
                logger.info("✓ BGE-M3 model files are accessible")
            except Exception as e:
                logger.warning(f"BGE-M3 model check failed: {e}")

        logger.info("Installation verification completed!")

    except ImportError as e:
        logger.error(f"Import error: {e}")
        logger.error("Installation may be incomplete")
        sys.exit(1)
|
||||
|
||||
|
||||
def print_usage_info():
    """Print the post-setup banner and quick-start instructions to stdout."""
    rule = "=" * 60
    banner_lines = (
        "\n" + rule,
        "BGE FINE-TUNING SETUP COMPLETED!",
        rule,
        "\nQuick Start:",
        "1. Edit config.toml to add your API keys (optional)",
        "2. Prepare your training data in JSONL format",
        "3. Run training:",
        " python scripts/train_m3.py --train_data data/raw/your_data.jsonl",
        " python scripts/train_reranker.py --train_data data/raw/your_data.jsonl",
        "4. Evaluate models:",
        " python scripts/evaluate.py --retriever_model output/model --eval_data data/raw/test.jsonl",
        "\nSample data created in: data/raw/sample_data.jsonl",
        "Configuration file: config.toml",
        "\nFor more information, see the README.md file.",
        rule,
    )

    # One print call per line, in order, so the output matches line for line.
    for text in banner_lines:
        print(text)
|
||||
|
||||
|
||||
def main():
    """Parse the command-line flags and run every setup stage in order.

    Stages: Python version check, device detection, optional PyTorch and
    requirements installation, directory creation, optional model download,
    sample file creation, installation verification, and the usage banner.
    """
    parser = argparse.ArgumentParser(description="Setup BGE fine-tuning environment")
    # Every option is a simple on/off switch.
    switches = (
        ("--download-models", "Download base models from Hugging Face"),
        ("--cpu-only", "Install CPU-only version of PyTorch"),
        ("--skip-pytorch", "Skip PyTorch installation"),
        ("--skip-requirements", "Skip requirements installation"),
    )
    for flag, description in switches:
        parser.add_argument(flag, action="store_true", help=description)

    args = parser.parse_args()

    logger.info("Starting BGE fine-tuning environment setup...")

    # Check Python version
    check_python_version()

    # Detect device type; the GPU probe is skipped when CPU-only is forced.
    device_type = "cpu" if args.cpu_only else check_gpu()

    # Install PyTorch
    if not args.skip_pytorch:
        install_pytorch(device_type)

    # Install requirements
    if not args.skip_requirements:
        install_requirements()

    # Setup directories
    setup_directories()

    # Download models
    download_models(args.download_models)

    # Create sample files
    create_sample_config()
    create_sample_data()

    # Verify installation
    verify_installation()

    # Print usage info
    print_usage_info()
|
||||
|
||||
|
||||
# Run the setup workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user