#!/usr/bin/env python3
"""
Test script to verify the converted dataset can be loaded by the BGE training system.
"""
import logging
import os
import sys

from transformers import AutoTokenizer

from data.dataset import BGEDataset
# Configure root logging once at import time: timestamped, level-tagged lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
|
|
|
|
def test_dataset_loading(dataset_path: str, tokenizer_path: str = 'models/bge-m3') -> bool:
    """Smoke-test that a converted JSONL dataset loads through BGEDataset.

    Loads the tokenizer, builds a BGEDataset in 'triplets' mode, retrieves a
    few samples, and logs their shapes and contents for manual inspection.

    Args:
        dataset_path: Path to the converted JSONL dataset file.
        tokenizer_path: Tokenizer model directory. Defaults to the local
            BGE-M3 checkout, preserving the previous hard-coded behavior.

    Returns:
        True when loading and sample retrieval succeed, False otherwise
        (missing file, or any error raised while loading/sampling).
    """
    # Fail fast if the file is missing — no point loading the tokenizer.
    if not os.path.exists(dataset_path):
        logger.error("Dataset file not found: %s", dataset_path)
        return False

    try:
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        logger.info("Loading dataset...")
        dataset = BGEDataset(
            data_path=dataset_path,
            tokenizer=tokenizer,
            query_max_length=64,
            passage_max_length=512,
            format_type='triplets',
        )

        logger.info("✅ Dataset loaded successfully!")
        logger.info("📊 Total samples: %s", len(dataset))
        logger.info("🔍 Detected format: %s", dataset.detected_format)

        # Pull one sample and report the shapes the trainer will see.
        # NOTE(review): assumes BGEDataset items expose `.shape`-bearing
        # tensors under these keys — confirm against data/dataset.py.
        logger.info("Testing sample retrieval...")
        sample = dataset[0]

        logger.info("🎯 Sample keys: %s", list(sample.keys()))
        logger.info("🔤 Query shape: %s", sample['query_input_ids'].shape)
        logger.info("📄 Passages shape: %s", sample['passage_input_ids'].shape)
        logger.info("🏷️ Labels shape: %s", sample['labels'].shape)
        logger.info("📊 Scores shape: %s", sample['scores'].shape)

        # Show truncated sample content so data quality can be eyeballed.
        logger.info("🔍 Query text: %s...", sample['query_text'][:100])
        logger.info("📝 First passage: %s...", sample['passage_texts'][0][:100])
        logger.info("🏷️ Labels: %s", sample['labels'])
        logger.info("📊 Scores: %s", sample['scores'])

        # Spot-check a few more indices, skipping any past the end.
        logger.info("Testing multiple samples...")
        for i in (1, 10, 100):
            if i < len(dataset):
                sample_i = dataset[i]
                logger.info(
                    "Sample %d: Query length=%d, Passages=%d",
                    i,
                    len(sample_i['query_text']),
                    len([p for p in sample_i['passage_texts'] if p]),
                )

        logger.info("✅ All tests passed! Dataset is ready for BGE-M3 training.")
        return True

    except Exception as e:
        # Broad catch is deliberate at this smoke-test boundary:
        # logger.exception logs the message AND the full traceback,
        # replacing the previous manual traceback.print_exc().
        logger.exception("❌ Dataset loading failed: %s", e)
        return False
|
|
|
|
def main():
    """Entry point: smoke-test the converted BGE-M3 dataset.

    Logs usage hints on success; exits with status 1 on failure.
    """
    dataset_path = "data/datasets/三国演义/m3_train_converted.jsonl"

    logger.info("🧪 Testing converted BGE-M3 dataset...")
    logger.info(f"📁 Dataset path: {dataset_path}")

    # Guard clause: bail out with a non-zero exit code when loading fails.
    if not test_dataset_loading(dataset_path):
        logger.error("💥 Dataset testing failed!")
        sys.exit(1)

    logger.info("🎉 Dataset is ready for training!")
    logger.info("💡 You can now use this dataset with:")
    logger.info("   python scripts/train_m3.py --data_path data/datasets/三国演义/m3_train_converted.jsonl")
|
|
|
|
# Run the smoke test when this file is executed directly.
if __name__ == '__main__':
    main()