"""Tests for malformed-input handling in the data loaders and for config parsing."""

import os
import sys
import tempfile

import pytest

# Add the project root (the parent of this file's directory) to the Python
# path so the project-local packages below can be imported.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.preprocessing import DataPreprocessor  # noqa: E402
from data.augmentation import QueryAugmenter  # noqa: E402


def _write_temp_file(suffix, lines):
    """Create a UTF-8 temporary file and return its path.

    Args:
        suffix: File name suffix passed to ``tempfile.mkstemp``.
        lines: Strings written to the file verbatim, in order.

    Returns:
        The path of the created file. The caller is responsible for
        removing it.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    with os.fdopen(fd, 'w', encoding='utf-8') as f:
        f.writelines(lines)
    return path


@pytest.fixture
def tmp_jsonl_file():
    """Yield the path of a JSONL file with one valid and one malformed line.

    The file is removed after the test finishes.
    """
    path = _write_temp_file('.jsonl', [
        '{"query": "valid", "pos": ["a"], "neg": ["b"]}\n',
        '{invalid json}\n',
        # NOTE(review): an empty string writes nothing. If a blank line was
        # meant to be part of the fixture, this needs to be '\n' -- confirm.
        '',
    ])
    yield path
    os.remove(path)


@pytest.fixture
def tmp_csv_file():
    """Yield the path of a CSV file with a header, one valid row and one bad row.

    The file is removed after the test finishes.
    """
    path = _write_temp_file('.csv', [
        'query,pos,neg\n',
        'valid,"a","b"\n',
        'malformed line\n',
    ])
    yield path
    os.remove(path)


@pytest.fixture
def tmp_tsv_file():
    """Yield the path of a TSV file with a header, one valid row and one bad row.

    The file is removed after the test finishes.
    """
    path = _write_temp_file('.tsv', [
        'query\tpos\tneg\n',
        'valid\ta\tb\n',
        'malformed line\n',
    ])
    yield path
    os.remove(path)


def test_jsonl_malformed_lines(tmp_jsonl_file):
    """Expect ``_load_jsonl`` to return only the valid record without raising."""
    preprocessor = DataPreprocessor(tokenizer=None)
    # Should skip malformed line and not raise
    data = preprocessor._load_jsonl(tmp_jsonl_file)
    assert isinstance(data, list)
    assert len(data) == 1
    assert data[0]['query'] == 'valid'


def test_csv_malformed_lines(tmp_csv_file):
    """Expect ``_load_csv`` to return only the valid row without raising."""
    preprocessor = DataPreprocessor(tokenizer=None)
    # Should skip malformed line and not raise
    data = preprocessor._load_csv(tmp_csv_file)
    assert isinstance(data, list)
    assert len(data) == 1
    assert data[0]['query'] == 'valid'


def test_tsv_malformed_lines(tmp_tsv_file):
    """Expect ``_load_tsv`` to return only the valid row without raising."""
    preprocessor = DataPreprocessor(tokenizer=None)
    # Should skip malformed line and not raise
    data = preprocessor._load_tsv(tmp_tsv_file)
    assert isinstance(data, list)
    assert len(data) == 1
    assert data[0]['query'] == 'valid'


def test_augmentation_methods_config():
    """Expect augmentation methods to come from config when none are passed."""
    # Should use config-driven augmentation methods
    # NOTE(review): 'config.toml' is a relative path, so this presumably
    # depends on the working directory the tests are run from -- confirm.
    augmenter = QueryAugmenter(augmentation_methods=None, config_path='config.toml')
    assert isinstance(augmenter.augmentation_methods, list)
    assert 'paraphrase' in augmenter.augmentation_methods


def test_config_list_parsing():
    """Test that config loader parses comma/semicolon-separated lists correctly."""
    from utils.config_loader import ConfigLoader

    loader = ConfigLoader()
    loader._config = {'section': {'list_param': 'a, b; c'}}
    value = loader.get('section', 'list_param', default=[])
    assert isinstance(value, list)
    assert set(value) == {'a', 'b', 'c'}


def test_config_param_source_logging(caplog):
    """Test that config loader logs parameter source correctly."""
    from utils.config_loader import ConfigLoader

    loader = ConfigLoader()
    loader._config = {'section': {'param': 42}}
    # Both lookups run inside the INFO-level context so that their log
    # records are captured; only the logged text is checked, so the return
    # values are not kept.
    with caplog.at_level('INFO'):
        loader.get('section', 'param', default=0)
        assert 'source: config.toml' in caplog.text
        loader.get('section', 'missing', default=99)
        assert 'source: default' in caplog.text


def test_config_fallback_default():
    """Test that config loader falls back to default if key is missing."""
    from utils.config_loader import ConfigLoader

    loader = ConfigLoader()
    loader._config = {'section': {}}
    value = loader.get('section', 'not_found', default='fallback')
    assert value == 'fallback'