Compare commits

..

2 Commits

Author SHA1 Message Date
ldy
bace84f9e9 Revised saving 2025-07-23 15:17:14 +08:00
ldy
8a7a011cb1 Added split dataset 2025-07-23 15:06:30 +08:00
9 changed files with 57639 additions and 2 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,24 @@
{
"input_dir": "data/datasets/三国演义",
"output_dir": "data/datasets/三国演义/splits",
"split_ratios": {
"train": 0.8,
"validation": 0.1,
"test": 0.1
},
"seed": 42,
"datasets": {
"bge_m3": {
"total_samples": 31370,
"train_samples": 25096,
"val_samples": 3137,
"test_samples": 3137
},
"reranker": {
"total_samples": 26238,
"train_samples": 20990,
"val_samples": 2623,
"test_samples": 2625
}
}
}

View File

@@ -525,7 +525,12 @@ class ComprehensiveTestSuite:
# Test saving
state_dict = {
-'model_state_dict': torch.randn(10, 10),
+'model_state_dict': {
+'layer1.weight': torch.randn(10, 5),
+'layer1.bias': torch.randn(10),
+'layer2.weight': torch.randn(3, 10),
+'layer2.bias': torch.randn(3)
+},
'optimizer_state_dict': {'lr': 0.001},
'step': 100
}

View File

@@ -86,7 +86,7 @@ class CheckpointManager:
model_path = os.path.join(temp_dir, 'pytorch_model.bin')
# Add extra safety check for model state dict
model_state = state_dict['model_state_dict']
-if not model_state:
+if model_state is None or (isinstance(model_state, dict) and len(model_state) == 0):
raise RuntimeError("Model state dict is empty")
self._safe_torch_save(model_state, model_path)