bge_finetune/tests/test_data_edge_cases.py
2025-07-22 16:55:25 +08:00

100 lines
3.4 KiB
Python

import os
import sys
import tempfile
import pytest
# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.preprocessing import DataPreprocessor
from data.augmentation import QueryAugmenter
@pytest.fixture
def tmp_jsonl_file():
    """Yield the path of a temporary JSONL file mixing valid and malformed lines.

    The file holds one valid record, one line that is not valid JSON, and one
    blank line. The file is removed during fixture teardown.
    """
    fd, path = tempfile.mkstemp(suffix='.jsonl')
    with os.fdopen(fd, 'w', encoding='utf-8') as f:
        f.write('{"query": "valid", "pos": ["a"], "neg": ["b"]}\n')
        f.write('{invalid json}\n')
        # Blank line: writing '' would be a no-op, so emit an actual empty
        # line to exercise the loader's handling of empty input lines.
        f.write('\n')
    yield path
    os.remove(path)
@pytest.fixture
def tmp_csv_file():
    """Provide a temporary CSV file with a header, one good row and one bad row.

    Yields the file path; the file is deleted during fixture teardown.
    """
    handle, csv_path = tempfile.mkstemp(suffix='.csv')
    rows = [
        'query,pos,neg\n',
        'valid,"a","b"\n',
        'malformed line\n',  # single field: does not match the header
    ]
    with os.fdopen(handle, 'w', encoding='utf-8') as out:
        out.writelines(rows)
    yield csv_path
    os.remove(csv_path)
@pytest.fixture
def tmp_tsv_file():
    """Provide a temporary tab-separated file with a header, a good row and a bad row.

    Yields the file path; the file is deleted during fixture teardown.
    """
    handle, tsv_path = tempfile.mkstemp(suffix='.tsv')
    header = 'query\tpos\tneg\n'
    good_row = 'valid\ta\tb\n'
    bad_row = 'malformed line\n'  # no tabs: does not match the header
    with os.fdopen(handle, 'w', encoding='utf-8') as out:
        out.writelines((header, good_row, bad_row))
    yield tsv_path
    os.remove(tsv_path)
def test_jsonl_malformed_lines(tmp_jsonl_file):
    """_load_jsonl should keep the valid record and skip malformed lines without raising."""
    records = DataPreprocessor(tokenizer=None)._load_jsonl(tmp_jsonl_file)
    assert isinstance(records, list)
    assert len(records) == 1
    first = records[0]
    assert first['query'] == 'valid'
def test_csv_malformed_lines(tmp_csv_file):
    """_load_csv should keep the valid row and skip the malformed one without raising."""
    loader = DataPreprocessor(tokenizer=None)
    rows = loader._load_csv(tmp_csv_file)
    assert isinstance(rows, list)
    assert len(rows) == 1
    assert rows[0]['query'] == 'valid'
def test_tsv_malformed_lines(tmp_tsv_file):
    """_load_tsv should keep the valid row and skip the malformed one without raising."""
    rows = DataPreprocessor(tokenizer=None)._load_tsv(tmp_tsv_file)
    assert isinstance(rows, list)
    assert len(rows) == 1
    only_row = rows[0]
    assert only_row['query'] == 'valid'
def test_augmentation_methods_config():
    """QueryAugmenter should take its augmentation methods from config when none are passed."""
    # Resolve config.toml against the project root (the directory this module
    # puts on sys.path) instead of the current working directory, so the test
    # behaves the same no matter where pytest is launched from.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    config_path = os.path.join(project_root, 'config.toml')
    augmenter = QueryAugmenter(augmentation_methods=None, config_path=config_path)
    methods = augmenter.augmentation_methods
    assert isinstance(methods, list)
    assert 'paraphrase' in methods
def test_config_list_parsing():
    """Test that config loader parses comma/semicolon-separated lists correctly."""
    from utils.config_loader import ConfigLoader

    raw_value = 'a, b; c'
    expected_items = {'a', 'b', 'c'}

    loader = ConfigLoader()
    loader._config = {'section': {'list_param': raw_value}}
    parsed = loader.get('section', 'list_param', default=[])

    assert isinstance(parsed, list)
    assert set(parsed) == expected_items
def test_config_param_source_logging(caplog):
    """Test that config loader logs parameter source correctly.

    A key present in the loaded config should be reported as coming from
    config.toml. A missing key should be reported as coming from the default.
    """
    from utils.config_loader import ConfigLoader

    loader = ConfigLoader()
    loader._config = {'section': {'param': 42}}
    with caplog.at_level('INFO'):
        value = loader.get('section', 'param', default=0)
        assert value == 42
        assert 'source: config.toml' in caplog.text

        # caplog.text accumulates; clear it so the next assertion can only be
        # satisfied by the lookup that falls back to its default.
        caplog.clear()

        value = loader.get('section', 'missing', default=99)
        assert value == 99
        assert 'source: default' in caplog.text
def test_config_fallback_default():
    """Test that config loader falls back to default if key is missing."""
    from utils.config_loader import ConfigLoader

    sentinel = 'fallback'
    loader = ConfigLoader()
    loader._config = {'section': {}}  # section exists but has no keys
    assert loader.get('section', 'not_found', default=sentinel) == sentinel