Revised init commit

This commit is contained in:
ldy
2025-07-22 17:49:16 +08:00
parent 36003b83e2
commit ae0f405a9b
4 changed files with 219 additions and 18 deletions

View File

@@ -184,7 +184,11 @@ class BaseTrainer(ABC):
# Add to training history
self.training_history.append(epoch_metrics)
-if is_main_process() and (epoch + 1) % self.config.save_steps == 0:
+# Save checkpoint at end of epoch (if save_steps allows and we've done training)
+if (is_main_process() and
+self.config.save_steps > 0 and
+self.global_step > 0 and
+(epoch + 1) % self.config.save_steps == 0):
self._save_checkpoint(metrics=train_metrics)
epoch_time = time.time() - epoch_start_time
@@ -193,7 +197,9 @@ class BaseTrainer(ABC):
logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")
if is_main_process():
-self._save_checkpoint(metrics=train_metrics)
+# Only save final checkpoint if we've actually trained (global_step > 0)
+if self.global_step > 0:
+self._save_checkpoint(metrics=train_metrics)
self._save_training_summary()
# Close TensorBoard logger
self.tb_logger.close()
@@ -351,8 +357,11 @@ class BaseTrainer(ABC):
except Exception as e:
logger.debug(f"Logging failed: {e}")
-# Checkpoint saving with error handling
-if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0 and is_main_process():
+# Checkpoint saving with error handling - avoid saving at step 0
+if (self.config.save_steps > 0 and
+self.global_step > 0 and # Don't save at step 0
+self.global_step % self.config.save_steps == 0 and
+is_main_process()):
try:
if self.checkpoint_manager:
state_dict = {