100 lines
3.4 KiB
Python
100 lines
3.4 KiB
Python
import os
|
|
import sys
|
|
import tempfile
|
|
import pytest
|
|
|
|
# Add project root to Python path
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from data.preprocessing import DataPreprocessor
|
|
from data.augmentation import QueryAugmenter
|
|
|
|
@pytest.fixture
def tmp_jsonl_file():
    """Yield the path of a temporary, partly malformed JSONL file.

    The file contains one valid record, one line that is not valid JSON,
    and one blank line. It is deleted after the test finishes.
    """
    fd, path = tempfile.mkstemp(suffix='.jsonl')
    with os.fdopen(fd, 'w', encoding='utf-8') as f:
        f.write('{"query": "valid", "pos": ["a"], "neg": ["b"]}\n')
        f.write('{invalid json}\n')
        # A blank line is a common JSONL edge case; writing '' here would be
        # a no-op and leave it untested.
        f.write('\n')
    yield path
    os.remove(path)
|
|
|
|
@pytest.fixture
def tmp_csv_file():
    """Yield the path of a temporary CSV file with one good row and one bad row.

    The header is ``query,pos,neg``. The trailing ``malformed line`` row has
    no commas, so it forms a single field. The file is removed once the test
    is done.
    """
    rows = (
        'query,pos,neg\n',
        'valid,"a","b"\n',
        'malformed line\n',
    )
    descriptor, csv_path = tempfile.mkstemp(suffix='.csv')
    with os.fdopen(descriptor, 'w', encoding='utf-8') as handle:
        handle.writelines(rows)
    yield csv_path
    os.remove(csv_path)
|
|
|
|
@pytest.fixture
def tmp_tsv_file():
    """Yield the path of a temporary TSV file with one good row and one bad row.

    The header is ``query<TAB>pos<TAB>neg``. The trailing ``malformed line``
    row contains no tab separators. The file is removed after the test.
    """
    contents = 'query\tpos\tneg\n' + 'valid\ta\tb\n' + 'malformed line\n'
    descriptor, tsv_path = tempfile.mkstemp(suffix='.tsv')
    with os.fdopen(descriptor, 'w', encoding='utf-8') as out:
        out.write(contents)
    yield tsv_path
    os.remove(tsv_path)
|
|
|
|
def test_jsonl_malformed_lines(tmp_jsonl_file):
    """Loading JSONL should skip unparsable lines rather than raise."""
    records = DataPreprocessor(tokenizer=None)._load_jsonl(tmp_jsonl_file)

    # Only the single well-formed record should survive.
    assert isinstance(records, list)
    assert len(records) == 1
    assert records[0]['query'] == 'valid'
|
|
|
|
def test_csv_malformed_lines(tmp_csv_file):
    """Loading CSV should skip the malformed row rather than raise."""
    records = DataPreprocessor(tokenizer=None)._load_csv(tmp_csv_file)

    # Only the single well-formed row should survive.
    assert isinstance(records, list)
    assert len(records) == 1
    assert records[0]['query'] == 'valid'
|
|
|
|
def test_tsv_malformed_lines(tmp_tsv_file):
    """Loading TSV should skip the malformed row rather than raise."""
    records = DataPreprocessor(tokenizer=None)._load_tsv(tmp_tsv_file)

    # Only the single well-formed row should survive.
    assert isinstance(records, list)
    assert len(records) == 1
    assert records[0]['query'] == 'valid'
|
|
|
|
def test_augmentation_methods_config():
    """With no explicit methods, QueryAugmenter should read them from config."""
    # NOTE(review): 'config.toml' is a relative path, so this test presumably
    # only passes when pytest runs from the directory containing it — confirm.
    augmenter = QueryAugmenter(augmentation_methods=None, config_path='config.toml')
    methods = augmenter.augmentation_methods

    assert isinstance(methods, list)
    assert 'paraphrase' in methods
|
|
|
|
def test_config_list_parsing():
    """A comma/semicolon-separated string value should be parsed into a list."""
    from utils.config_loader import ConfigLoader

    cfg = ConfigLoader()
    # Overwrite the private mapping so the lookup sees known raw data.
    cfg._config = {'section': {'list_param': 'a, b; c'}}

    parsed = cfg.get('section', 'list_param', default=[])

    assert isinstance(parsed, list)
    assert set(parsed) == {'a', 'b', 'c'}
|
|
|
|
def test_config_param_source_logging(caplog):
    """Test that config loader logs parameter source correctly.

    A key present in the config should be reported as coming from
    ``config.toml``. A missing key should be reported as coming from the
    default. Each returned value is checked as well.
    """
    from utils.config_loader import ConfigLoader

    loader = ConfigLoader()
    loader._config = {'section': {'param': 42}}
    with caplog.at_level('INFO'):
        value = loader.get('section', 'param', default=0)
        assert value == 42
        assert 'source: config.toml' in caplog.text

        # caplog.text accumulates across calls; clear it so the next
        # assertion can only be satisfied by the missing-key lookup below.
        caplog.clear()
        value = loader.get('section', 'missing', default=99)
        assert value == 99
        assert 'source: default' in caplog.text
|
|
|
|
def test_config_fallback_default():
    """Looking up an absent key should return the caller-supplied default."""
    from utils.config_loader import ConfigLoader

    cfg = ConfigLoader()
    cfg._config = {'section': {}}  # the section exists but holds no keys

    assert cfg.get('section', 'not_found', default='fallback') == 'fallback'
|