"""
Setup script for BGE fine-tuning environment
"""

import argparse
import importlib
import json
import logging
import os
import shlex
import shutil
import subprocess
import sys
import urllib.request
from pathlib import Path

try:
    from utils.config_loader import get_config
except ImportError:
    # This script bootstraps the environment, so the project's config loader
    # may not be importable yet when the script starts. download_models()
    # retries the import lazily, after the requirements have been installed.
    get_config = None

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Contents written to config.toml by create_sample_config() when the file
# does not exist yet.
_SAMPLE_CONFIG = """# Configuration file for BGE fine-tuning project

[api]
# OpenAI-compatible API configuration (e.g., DeepSeek, OpenAI, etc.)
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here" # Replace with your actual API key
model = "deepseek-chat"
timeout = 30

[translation]
# Translation settings
enable_api = true
target_languages = ["en", "zh", "ja", "fr"]
max_retries = 3
fallback_to_simulation = true

[training]
# Training defaults
default_batch_size = 4
default_learning_rate = 1e-5
default_epochs = 3
default_warmup_ratio = 0.1

[model_paths]
# Default model paths
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"

[data]
# Data processing settings
max_query_length = 512
max_passage_length = 8192
min_passage_length = 10
validation_split = 0.1
"""


def run_command(command, check=True):
    """Run an external command and capture its output.

    Args:
        command: Either a shell command string (executed through the shell,
            as before) or a sequence of arguments (executed directly, without
            a shell, so arguments containing spaces need no quoting).
        check: When true, a non-zero exit status is logged and the process
            terminates with exit code 1.

    Returns:
        The ``subprocess.CompletedProcess`` with ``stdout``/``stderr`` as text.
    """
    use_shell = isinstance(command, str)
    if use_shell:
        to_run = command
        display = command
    else:
        to_run = [str(part) for part in command]
        display = shlex.join(to_run)

    logger.info(f"Running: {display}")
    try:
        result = subprocess.run(to_run, shell=use_shell, capture_output=True, text=True)
    except OSError as exc:
        # Without a shell a missing executable raises instead of returning a
        # status; report it the way a shell would (127 = command not found).
        result = subprocess.CompletedProcess(to_run, 127, stdout="", stderr=str(exc))

    if check and result.returncode != 0:
        logger.error(f"Command failed: {display}")
        logger.error(f"Error: {result.stderr}")
        sys.exit(1)

    return result


def _pip_install(*args):
    """Run ``pip install`` with *args* using the interpreter running this script."""
    # "python -m pip" guarantees packages land in this interpreter's
    # environment rather than in whichever "pip" happens to be first on PATH.
    return run_command([sys.executable, "-m", "pip", "install", *args])


def check_python_version():
    """Exit with status 1 unless running on Python 3.8 or newer."""
    if sys.version_info < (3, 8):
        logger.error("Python 3.8 or higher is required")
        sys.exit(1)

    logger.info(f"Python version: {sys.version}")


def check_gpu():
    """Detect whether an NVIDIA GPU is usable.

    Returns:
        ``"cuda"`` when ``nvidia-smi`` exists and exits successfully,
        otherwise ``"cpu"``.
    """
    if shutil.which("nvidia-smi") is None:
        logger.info("nvidia-smi not available")
        return "cpu"

    result = run_command(["nvidia-smi"], check=False)
    if result.returncode == 0:
        logger.info("NVIDIA GPU detected")
        return "cuda"

    logger.info("No NVIDIA GPU detected")
    return "cpu"


def install_pytorch(device_type="cuda"):
    """Install PyTorch for the given device type and verify it imports.

    Args:
        device_type: ``"cuda"`` installs the nightly cu128 build; any other
            value installs the CPU build.
    """
    logger.info("Installing PyTorch...")

    if device_type == "cuda":
        # Install CUDA version for RTX 5090
        _pip_install(
            "--pre", "torch", "torchvision", "torchaudio",
            "--index-url", "https://download.pytorch.org/whl/nightly/cu128",
        )
    else:
        # Install CPU version
        _pip_install(
            "torch", "torchvision", "torchaudio",
            "--index-url", "https://download.pytorch.org/whl/cpu",
        )

    # Packages installed while this process is running may be hidden by the
    # import system's cached directory listings; refresh them before importing.
    importlib.invalidate_caches()

    # Verify installation
    try:
        import torch
    except ImportError:
        logger.error("Failed to import PyTorch")
        sys.exit(1)

    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"GPU count: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")


def install_requirements(device_type="cuda"):
    """Install project requirements.

    Uses ``requirements.txt`` from the project root when present; otherwise
    installs a built-in list of core packages one by one.

    Args:
        device_type: Only used for the built-in package list: ``"cuda"``
            selects ``faiss-gpu``, anything else selects ``faiss-cpu``.
    """
    logger.info("Installing requirements...")

    requirements_file = Path(__file__).parent.parent / "requirements.txt"

    if requirements_file.exists():
        _pip_install("-r", str(requirements_file))
        return

    logger.warning("requirements.txt not found, installing core packages...")

    faiss_package = "faiss-gpu>=1.7.4" if device_type == "cuda" else "faiss-cpu>=1.7.4"

    # Core packages
    packages = [
        "transformers>=4.36.0",
        "datasets>=2.14.0",
        "tokenizers>=0.15.0",
        "numpy>=1.24.0",
        "scipy>=1.10.0",
        "scikit-learn>=1.3.0",
        "pandas>=2.0.0",
        "FlagEmbedding>=1.2.0",
        faiss_package,
        "rank-bm25>=0.2.2",
        "accelerate>=0.25.0",
        "tensorboard>=2.14.0",
        "tqdm>=4.66.0",
        "pyyaml>=6.0",
        "jsonlines>=3.1.0",
        "huggingface-hub>=0.19.0",
        "jieba>=0.42.1",
        "toml>=0.10.2",
        "requests>=2.28.0",
    ]

    for package in packages:
        _pip_install(package)


def setup_directories():
    """Create the project's working directories under the project root."""
    logger.info("Setting up directories...")

    base_dir = Path(__file__).parent.parent
    directories = [
        "models",
        "models/bge-m3",
        "models/bge-reranker-base",
        "data/raw",
        "data/processed",
        "output",
        "logs",
        "cache/data",
    ]

    for directory in directories:
        dir_path = base_dir / directory
        dir_path.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created directory: {dir_path}")


def _get_model_source(default="huggingface"):
    """Return the configured model source, or *default* if config is unavailable."""
    loader = get_config
    if loader is None:
        # The import failed at startup; requirements may have been installed
        # since then, so try once more before giving up.
        importlib.invalidate_caches()
        try:
            from utils.config_loader import get_config as loader
        except ImportError as exc:
            logger.warning(
                f"Could not import utils.config_loader ({exc}); "
                f"defaulting model source to '{default}'."
            )
            return default

    config = loader()
    return config.get('model_paths.source', default)


def download_models(download_models_flag=False):
    """Download base models from HuggingFace or ModelScope based on config.toml.

    Download failures are logged, not raised, so the rest of the setup can
    continue.

    Args:
        download_models_flag: When false, nothing is downloaded.
    """
    if not download_models_flag:
        logger.info("Skipping model download. Use --download-models to download base models.")
        return

    logger.info("Downloading base models...")
    base_dir = Path(__file__).parent.parent
    model_source = _get_model_source()

    if model_source == 'huggingface':
        try:
            from huggingface_hub import snapshot_download

            # Download BGE-M3
            logger.info("Downloading BGE-M3 from HuggingFace...")
            snapshot_download(
                repo_id="BAAI/bge-m3",
                local_dir=base_dir / "models/bge-m3",
                local_dir_use_symlinks=False
            )

            # Download BGE-Reranker
            logger.info("Downloading BGE-Reranker from HuggingFace...")
            snapshot_download(
                repo_id="BAAI/bge-reranker-base",
                local_dir=base_dir / "models/bge-reranker-base",
                local_dir_use_symlinks=False
            )

            logger.info("Models downloaded successfully from HuggingFace!")
        except ImportError:
            logger.error("huggingface_hub not available. Install it first: pip install huggingface_hub")
        except Exception as e:
            logger.error(f"Failed to download models from HuggingFace: {e}")
            logger.info("You can download models manually or use the Hugging Face CLI:")
            logger.info("  huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3")
            logger.info("  huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base")
    elif model_source == 'modelscope':
        try:
            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download

            # NOTE(review): the model ids below name CoROM models, yet they are
            # stored under models/bge-m3 and models/bge-reranker-base. Confirm
            # these are the intended ModelScope equivalents.

            # Download BGE-M3
            logger.info("Downloading BGE-M3 from ModelScope...")
            ms_snapshot_download(
                model_id="damo/nlp_corom_sentence-embedding_chinese-base",
                cache_dir=str(base_dir / "models/bge-m3")
            )

            # Download BGE-Reranker
            logger.info("Downloading BGE-Reranker from ModelScope...")
            ms_snapshot_download(
                model_id="damo/nlp_corom_cross-encoder_chinese-base",
                cache_dir=str(base_dir / "models/bge-reranker-base")
            )

            logger.info("Models downloaded successfully from ModelScope!")
        except ImportError:
            logger.error("modelscope not available. Install it first: pip install modelscope")
        except Exception as e:
            logger.error(f"Failed to download models from ModelScope: {e}")
            logger.info("You can download models manually from ModelScope website or CLI.")
    else:
        logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")


def create_sample_config():
    """Write a default ``config.toml`` to the project root if none exists."""
    logger.info("Creating sample configuration...")

    base_dir = Path(__file__).parent.parent
    config_file = base_dir / "config.toml"

    if not config_file.exists():
        # Explicit encoding so the result does not depend on the platform locale.
        with open(config_file, 'w', encoding='utf-8') as f:
            f.write(_SAMPLE_CONFIG)

        logger.info(f"Created sample config: {config_file}")
        logger.info("Please edit config.toml to add your API keys if needed.")


def create_sample_data():
    """Write ``data/raw/sample_data.jsonl`` with a few examples if it is missing."""
    logger.info("Creating sample data...")

    base_dir = Path(__file__).parent.parent
    sample_file = base_dir / "data/raw/sample_data.jsonl"

    if not sample_file.exists():
        sample_data = [
            {
                "query": "什么是机器学习?",
                "pos": ["机器学习是人工智能的一个分支,它使用算法让计算机能够从数据中学习和做出决策。"],
                "neg": ["深度学习是神经网络的一种形式。", "自然语言处理是计算机科学的一个领域。"]
            },
            {
                "query": "How does neural network work?",
                "pos": [
                    "Neural networks are computing systems inspired by biological neural networks that constitute the brain."],
                "neg": ["Machine learning algorithms can be supervised or unsupervised.",
                        "Data preprocessing is an important step in ML."]
            },
            {
                "query": "Python编程语言的特点",
                "pos": ["Python是一种高级编程语言,具有简洁的语法和强大的功能,广泛用于数据科学和AI开发。"],
                "neg": ["JavaScript是一种脚本语言。", "Java是一种面向对象的编程语言。"]
            }
        ]

        # One JSON object per line; ensure_ascii=False keeps the Chinese text readable.
        with open(sample_file, 'w', encoding='utf-8') as f:
            for item in sample_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

        logger.info(f"Created sample data: {sample_file}")


def verify_installation():
    """Check that core packages import and, if present, the GPU and BGE-M3 files work.

    Exits with status 1 when a core package cannot be imported. A failing
    BGE-M3 tokenizer check only produces a warning.
    """
    logger.info("Verifying installation...")

    # Pick up anything installed earlier in this same process.
    importlib.invalidate_caches()

    try:
        # Test imports
        import torch
        import transformers  # noqa: F401  (import is the check)
        import numpy as np  # noqa: F401
        import pandas as pd  # noqa: F401

        logger.info("✓ Core packages imported successfully")

        # Test GPU
        if torch.cuda.is_available():
            logger.info("✓ CUDA is available")
            device = torch.device("cuda")
            torch.randn(10, 10).to(device)
            logger.info("✓ GPU tensor operations work")
        else:
            logger.info("✓ CPU-only setup")

        # Test model loading (if models exist)
        base_dir = Path(__file__).parent.parent
        if (base_dir / "models/bge-m3").exists():
            try:
                from transformers import AutoTokenizer
                AutoTokenizer.from_pretrained(str(base_dir / "models/bge-m3"))
                logger.info("✓ BGE-M3 model files are accessible")
            except Exception as e:
                logger.warning(f"BGE-M3 model check failed: {e}")

        logger.info("Installation verification completed!")

    except ImportError as e:
        logger.error(f"Import error: {e}")
        logger.error("Installation may be incomplete")
        sys.exit(1)


def print_usage_info():
    """Print the post-setup quick-start instructions to stdout."""
    print("\n" + "=" * 60)
    print("BGE FINE-TUNING SETUP COMPLETED!")
    print("=" * 60)

    print("\nQuick Start:")
    print("1. Edit config.toml to add your API keys (optional)")
    print("2. Prepare your training data in JSONL format")
    print("3. Run training:")
    print("   python scripts/train_m3.py --train_data data/raw/your_data.jsonl")
    print("   python scripts/train_reranker.py --train_data data/raw/your_data.jsonl")
    print("4. Evaluate models:")
    print("   python scripts/evaluate.py --retriever_model output/model --eval_data data/raw/test.jsonl")

    print("\nSample data created in: data/raw/sample_data.jsonl")
    print("Configuration file: config.toml")
    print("\nFor more information, see the README.md file.")
    print("=" * 60)


def main():
    """Parse command-line options and run every setup step in order."""
    parser = argparse.ArgumentParser(description="Setup BGE fine-tuning environment")
    parser.add_argument("--download-models", action="store_true",
                        help="Download base models from Hugging Face")
    parser.add_argument("--cpu-only", action="store_true",
                        help="Install CPU-only version of PyTorch")
    parser.add_argument("--skip-pytorch", action="store_true",
                        help="Skip PyTorch installation")
    parser.add_argument("--skip-requirements", action="store_true",
                        help="Skip requirements installation")

    args = parser.parse_args()

    logger.info("Starting BGE fine-tuning environment setup...")

    # Check Python version
    check_python_version()

    # Detect device type
    device_type = "cpu" if args.cpu_only else check_gpu()

    # Install PyTorch
    if not args.skip_pytorch:
        install_pytorch(device_type)

    # Install requirements
    if not args.skip_requirements:
        install_requirements(device_type)

    # Setup directories
    setup_directories()

    # Download models
    download_models(args.download_models)

    # Create sample files
    create_sample_config()
    create_sample_data()

    # Verify installation
    verify_installation()

    # Print usage info
    print_usage_info()


if __name__ == "__main__":
    main()