
Commit

fixed checkpoint mechanism bug
zhenghh04 committed Feb 20, 2025
1 parent f82fa90 · commit bf2fdc3
Showing 2 changed files with 7 additions and 5 deletions.
dlio_benchmark/checkpointing/base_checkpointing.py: 7 changes (5 additions, 2 deletions)
@@ -266,8 +266,11 @@ def checkpoint(self, epoch, step_number):
         # if pp is turned on, we assume that the model is sharded across the pipeline stages
         if self.data_parallelism_rank == 0 and self.args.num_layers > 0:
             # in this case, model is saved layer by layer
-            for layer_index in range(start_layer, end_layer + 1):
-                self.save_state(suffix=f"{checkpoint_id}/layer_{layer_index}-model_{self.model_parallelism_rank}_model_states", state=self.layer_state[str(layer_index)], fsync = self.args.checkpoint_fsync)
+            if self.args.pipeline_parallelism > 1:
+                for layer_index in range(start_layer, end_layer + 1):
+                    self.save_state(suffix=f"{checkpoint_id}/layer_{layer_index}-model_{self.model_parallelism_rank}_model_states", state=self.layer_state[str(layer_index)], fsync = self.args.checkpoint_fsync)
+            else:
+                self.save_state(suffix=f"{checkpoint_id}/model_{self.model_parallelism_rank}_model_states", state=self.layer_state, fsync = self.args.checkpoint_fsync)
         else:
             # in this case, model is sharded across the data parallel ranks
             assert(self.args.pipeline_parallelism == 1)
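
For context, the fixed branch writes one file per layer only when pipeline parallelism is enabled; otherwise the whole layer_state dictionary for the model-parallel rank is consolidated into a single model_{rank}_model_states file. Below is a minimal standalone sketch of that decision, assuming a plain text write in place of the benchmark's save_state(); write_state, out_dir, and the flat argument list are hypothetical stand-ins, not DLIO APIs.

import os


def write_state(path, state):
    # Placeholder for the benchmark's save_state(); just serializes to text.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        f.write(repr(state))


def save_model_states(checkpoint_id, layer_state, model_parallelism_rank,
                      pipeline_parallelism, start_layer, end_layer, out_dir="."):
    if pipeline_parallelism > 1:
        # Pipeline parallelism: each stage owns a slice of layers, so every
        # layer in [start_layer, end_layer] is written to its own file.
        for layer_index in range(start_layer, end_layer + 1):
            write_state(
                os.path.join(out_dir, f"{checkpoint_id}/layer_{layer_index}-model_{model_parallelism_rank}_model_states"),
                layer_state[str(layer_index)])
    else:
        # No pipeline parallelism: the full layer_state dict for this
        # model-parallel rank goes into one consolidated file.
        write_state(
            os.path.join(out_dir, f"{checkpoint_id}/model_{model_parallelism_rank}_model_states"),
            layer_state)


# Example: two layers, no pipeline parallelism, so a single file is written.
save_model_states("checkpoint_0", {"0": [0.1], "1": [0.2]},
                  model_parallelism_rank=0, pipeline_parallelism=1,
                  start_layer=0, end_layer=1)
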
dlio_benchmark/main.py: 5 changes (2 additions, 3 deletions)
@@ -218,9 +218,8 @@ def initialize(self):
                 file_list_eval = file_list_eval[:self.num_files_eval]
         self.args.derive_configurations(file_list_train, file_list_eval)
         self.args.validate()
-        if self.args.do_checkpoint:
-            self.checkpointing_mechanism = CheckpointingFactory().get_mechanism(self.args.checkpoint_mechanism)
-            self.stats.checkpoint_size = self.checkpointing_mechanism.checkpoint_size
+        self.checkpointing_mechanism = CheckpointingFactory().get_mechanism(self.args.checkpoint_mechanism)
+        self.stats.checkpoint_size = self.checkpointing_mechanism.checkpoint_size
         self.comm.barrier()

     @dlp.log
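
The main.py change drops the do_checkpoint guard, so the checkpointing mechanism is constructed on every run and stats.checkpoint_size is always populated. A minimal sketch of that initialization pattern follows, using hypothetical DummyMechanism and Benchmark stand-ins rather than the benchmark's CheckpointingFactory and stats classes.

class DummyMechanism:
    # Stands in for whatever CheckpointingFactory().get_mechanism(...) returns.
    checkpoint_size = 0.0


class Benchmark:
    def __init__(self, do_checkpoint):
        self.do_checkpoint = do_checkpoint

    def initialize(self):
        # After the fix: the mechanism is built regardless of do_checkpoint,
        # so checkpoint_size is defined even when checkpointing never runs.
        self.checkpointing_mechanism = DummyMechanism()
        self.checkpoint_size = self.checkpointing_mechanism.checkpoint_size


bench = Benchmark(do_checkpoint=False)
bench.initialize()
print(bench.checkpoint_size)  # 0.0, rather than an attribute that was never set
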
