Revised init commit
This commit is contained in:
@@ -184,7 +184,11 @@ class BaseTrainer(ABC):
|
||||
# Add to training history
|
||||
self.training_history.append(epoch_metrics)
|
||||
|
||||
if is_main_process() and (epoch + 1) % self.config.save_steps == 0:
|
||||
# Save checkpoint at end of epoch (if save_steps allows and we've done training)
|
||||
if (is_main_process() and
|
||||
self.config.save_steps > 0 and
|
||||
self.global_step > 0 and
|
||||
(epoch + 1) % self.config.save_steps == 0):
|
||||
self._save_checkpoint(metrics=train_metrics)
|
||||
|
||||
epoch_time = time.time() - epoch_start_time
|
||||
@@ -193,7 +197,9 @@ class BaseTrainer(ABC):
|
||||
logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")
|
||||
|
||||
if is_main_process():
|
||||
self._save_checkpoint(metrics=train_metrics)
|
||||
# Only save final checkpoint if we've actually trained (global_step > 0)
|
||||
if self.global_step > 0:
|
||||
self._save_checkpoint(metrics=train_metrics)
|
||||
self._save_training_summary()
|
||||
# Close TensorBoard logger
|
||||
self.tb_logger.close()
|
||||
@@ -351,8 +357,11 @@ class BaseTrainer(ABC):
|
||||
except Exception as e:
|
||||
logger.debug(f"Logging failed: {e}")
|
||||
|
||||
# Checkpoint saving with error handling
|
||||
if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0 and is_main_process():
|
||||
# Checkpoint saving with error handling - avoid saving at step 0
|
||||
if (self.config.save_steps > 0 and
|
||||
self.global_step > 0 and # Don't save at step 0
|
||||
self.global_step % self.config.save_steps == 0 and
|
||||
is_main_process()):
|
||||
try:
|
||||
if self.checkpoint_manager:
|
||||
state_dict = {
|
||||
|
||||
Reference in New Issue
Block a user