bge_finetune/scripts/test_converted_dataset.py
2025-07-22 16:55:25 +08:00

91 lines
3.3 KiB
Python

#!/usr/bin/env python3
"""
Test script to verify that the converted dataset can be loaded by the BGE training system.
"""
from data.dataset import BGEDataset
from transformers import AutoTokenizer
import logging
import sys
import os
# Configure root logging once, at import time, so every message emitted by
# this script carries a timestamp and its level name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)
def _log_sample_details(sample) -> None:
    """Log the keys, tensor shapes and a short text preview of one sample.

    Assumes ``sample`` is the dict returned by ``BGEDataset.__getitem__``,
    containing the keys accessed below, with tensor-like values that expose
    ``.shape`` — TODO confirm against ``data.dataset``.
    """
    logger.info(f"🎯 Sample keys: {list(sample.keys())}")
    logger.info(f"🔤 Query shape: {sample['query_input_ids'].shape}")
    logger.info(f"📄 Passages shape: {sample['passage_input_ids'].shape}")
    logger.info(f"🏷️ Labels shape: {sample['labels'].shape}")
    logger.info(f"📊 Scores shape: {sample['scores'].shape}")
    # Preview only the first 100 characters so long texts don't flood the log.
    logger.info(f"🔍 Query text: {sample['query_text'][:100]}...")
    passages = sample['passage_texts']
    # A sample without passages should not abort the whole check with IndexError.
    first_passage = passages[0] if len(passages) > 0 else ''
    logger.info(f"📝 First passage: {first_passage[:100]}...")
    logger.info(f"🏷️ Labels: {sample['labels']}")
    logger.info(f"📊 Scores: {sample['scores']}")


def test_dataset_loading(
    dataset_path: str,
    tokenizer_path: str = 'models/bge-m3',
    query_max_length: int = 64,
    passage_max_length: int = 512,
    format_type: str = 'triplets',
    spot_check_indices: tuple = (1, 10, 100),
) -> bool:
    """Smoke-test that a converted dataset loads through ``BGEDataset``.

    Loads the tokenizer, builds a ``BGEDataset`` over ``dataset_path``, logs
    the shapes and content of the first sample, then retrieves a few more
    samples by index so ``__getitem__`` is exercised beyond index 0.

    Args:
        dataset_path: Path to the converted dataset file.
        tokenizer_path: Name or local path passed to
            ``AutoTokenizer.from_pretrained``.
        query_max_length: Query max length forwarded to ``BGEDataset``.
        passage_max_length: Passage max length forwarded to ``BGEDataset``.
        format_type: Dataset format forwarded to ``BGEDataset``.
        spot_check_indices: Extra sample indices to retrieve; indices outside
            ``[0, len(dataset))`` are skipped.

    Returns:
        ``True`` if every check passed, ``False`` otherwise. Failures are
        logged (with a traceback for unexpected exceptions), not raised.
    """
    if not os.path.exists(dataset_path):
        logger.error(f"Dataset file not found: {dataset_path}")
        return False

    try:
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        logger.info("Loading dataset...")
        dataset = BGEDataset(
            data_path=dataset_path,
            tokenizer=tokenizer,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            format_type=format_type,
        )
        num_samples = len(dataset)
        logger.info("✅ Dataset loaded successfully!")
        logger.info(f"📊 Total samples: {num_samples}")
        logger.info(f"🔍 Detected format: {dataset.detected_format}")

        # Without this check an empty dataset surfaces as an opaque
        # IndexError from dataset[0] below.
        if num_samples == 0:
            logger.error("❌ Dataset is empty: no samples to train on.")
            return False

        logger.info("Testing sample retrieval...")
        _log_sample_details(dataset[0])

        # Spot-check later indices; those past the end are simply skipped.
        logger.info("Testing multiple samples...")
        for idx in spot_check_indices:
            if not 0 <= idx < num_samples:
                continue
            sample = dataset[idx]
            non_empty_passages = sum(1 for p in sample['passage_texts'] if p)
            logger.info(
                f"Sample {idx}: Query length={len(sample['query_text'])}, "
                f"Passages={non_empty_passages}"
            )

        logger.info("✅ All tests passed! Dataset is ready for BGE-M3 training.")
        return True
    except Exception as e:
        # Top-level boundary of a diagnostic script: report any failure with
        # its traceback through logging and signal it via the return value.
        logger.exception(f"❌ Dataset loading failed: {e}")
        return False


# The name starts with ``test_`` and this file matches ``test_*.py``, so pytest
# would otherwise collect this function and fail on the missing
# ``dataset_path`` fixture.
test_dataset_loading.__test__ = False
# Dataset checked when no --data_path is given on the command line.
DEFAULT_DATASET_PATH = "data/datasets/三国演义/m3_train_converted.jsonl"


def parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with a ``data_path`` attribute.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Verify that a converted dataset can be loaded by the BGE training system.",
    )
    parser.add_argument(
        '--data_path',
        default=DEFAULT_DATASET_PATH,
        help="Path to the converted dataset file (default: %(default)s)",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Run the dataset smoke test; exit with status 1 if it fails.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``). With no
            arguments the default dataset path is checked.
    """
    dataset_path = parse_args(argv).data_path
    logger.info("🧪 Testing converted BGE-M3 dataset...")
    logger.info(f"📁 Dataset path: {dataset_path}")

    if not test_dataset_loading(dataset_path):
        logger.error("💥 Dataset testing failed!")
        sys.exit(1)

    logger.info("🎉 Dataset is ready for training!")
    logger.info("💡 You can now use this dataset with:")
    # Echo the path that was actually tested, not a hard-coded copy of it.
    logger.info(f" python scripts/train_m3.py --data_path {dataset_path}")


if __name__ == '__main__':
    main()