init commit
This commit is contained in:
91
scripts/test_converted_dataset.py
Normal file
91
scripts/test_converted_dataset.py
Normal file
@@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify the converted dataset can be loaded by BGE training system
|
||||
"""
|
||||
|
||||
from data.dataset import BGEDataset
|
||||
from transformers import AutoTokenizer
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Configure root logging once at import time so every message from this
# script carries a timestamp and severity; module-level logger per convention.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_dataset_loading(dataset_path: str, tokenizer_path: str = 'models/bge-m3') -> bool:
    """Verify that a converted JSONL dataset loads through BGEDataset.

    Loads the tokenizer, builds a BGEDataset in 'triplets' mode, and
    exercises sample retrieval to confirm the expected tensors and text
    fields are present.

    Args:
        dataset_path: Path to the converted JSONL dataset file.
        tokenizer_path: HuggingFace model/tokenizer directory; defaults to
            the local BGE-M3 checkout this project ships with.

    Returns:
        True if every check passes, False otherwise (errors are logged,
        never raised to the caller).
    """
    # Fail fast with a clear message instead of an exception from open().
    if not os.path.exists(dataset_path):
        logger.error(f"Dataset file not found: {dataset_path}")
        return False

    try:
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        logger.info("Loading dataset...")
        dataset = BGEDataset(
            data_path=dataset_path,
            tokenizer=tokenizer,
            query_max_length=64,
            passage_max_length=512,
            format_type='triplets'
        )

        logger.info("✅ Dataset loaded successfully!")
        logger.info(f"📊 Total samples: {len(dataset)}")
        logger.info(f"🔍 Detected format: {dataset.detected_format}")

        # A single __getitem__ call exercises tokenization and tensor assembly.
        logger.info("Testing sample retrieval...")
        sample = dataset[0]

        logger.info(f"🎯 Sample keys: {list(sample.keys())}")
        logger.info(f"🔤 Query shape: {sample['query_input_ids'].shape}")
        logger.info(f"📄 Passages shape: {sample['passage_input_ids'].shape}")
        logger.info(f"🏷️ Labels shape: {sample['labels'].shape}")
        logger.info(f"📊 Scores shape: {sample['scores'].shape}")

        # Show (truncated) sample content for a quick eyeball check.
        logger.info(f"🔍 Query text: {sample['query_text'][:100]}...")
        logger.info(f"📝 First passage: {sample['passage_texts'][0][:100]}...")
        logger.info(f"🏷️ Labels: {sample['labels']}")
        logger.info(f"📊 Scores: {sample['scores']}")

        # Spot-check a few more indices, skipping any beyond the dataset size.
        logger.info("Testing multiple samples...")
        for i in (1, 10, 100):
            if i < len(dataset):
                sample_i = dataset[i]
                logger.info(f"Sample {i}: Query length={len(sample_i['query_text'])}, Passages={len([p for p in sample_i['passage_texts'] if p])}")

        logger.info("✅ All tests passed! Dataset is ready for BGE-M3 training.")
        return True

    except Exception as e:
        # logger.exception records the full traceback through the logging
        # system (replacing the former traceback.print_exc() to stderr).
        logger.exception("❌ Dataset loading failed: %s", e)
        return False
|
||||
|
||||
def main():
    """Run the dataset smoke test and exit non-zero on failure.

    The dataset path may be overridden by the first command-line argument;
    with no arguments it checks the project's default converted dataset.
    """
    # Optional CLI override; no-argument invocation behaves as before.
    if len(sys.argv) > 1:
        dataset_path = sys.argv[1]
    else:
        dataset_path = "data/datasets/三国演义/m3_train_converted.jsonl"

    logger.info("🧪 Testing converted BGE-M3 dataset...")
    logger.info(f"📁 Dataset path: {dataset_path}")

    success = test_dataset_loading(dataset_path)

    if success:
        logger.info("🎉 Dataset is ready for training!")
        logger.info("💡 You can now use this dataset with:")
        logger.info("   python scripts/train_m3.py --data_path data/datasets/三国演义/m3_train_converted.jsonl")
    else:
        logger.error("💥 Dataset testing failed!")
        sys.exit(1)
|
||||
|
||||
# Standard script guard: run only when executed directly, not on import.
if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user