added support for checkpoint recovery tests
zhenghh04 committed Feb 20, 2025
1 parent bf2fdc3 commit 517d678
Showing 15 changed files with 141 additions and 41 deletions.
37 changes: 35 additions & 2 deletions dlio_benchmark/checkpointing/base_checkpointing.py
@@ -149,6 +149,10 @@ def get_tensor(self, length, datatype="int8"):
def save_state(self, suffix, state, fsync=False):
pass

@abstractmethod
def load_state(self, suffix, state):
pass

def get_name(self, suffix):
return os.path.join(self.args.checkpoint_folder, f"{suffix}.{self.ext}")

@@ -248,7 +252,7 @@ def get_layer_index(self):
return start_layer, end_layer

@abstractmethod
def checkpoint(self, epoch, step_number):
def save_checkpoint(self, epoch, step_number):
my_rank = DLIOMPI.get_instance().rank()
start_layer, end_layer = self.get_layer_index()
# create a specific folder for each step
@@ -262,7 +266,7 @@ def checkpoint(self, epoch, step_number):
self.save_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_optim_states", state=self.optimization_state, fsync = self.args.checkpoint_fsync)

if self.layer_state:
if self.args.zero_stage < 3:
if self.args.zero_stage < 3 and self.args.zero_stage > 0:
# if pp is turned on, we assume that the model is sharded across the pipeline stages
if self.data_parallelism_rank == 0 and self.args.num_layers > 0:
# in this case, model is saved layer by layer
@@ -276,6 +280,35 @@ def checkpoint(self, epoch, step_number):
assert(self.args.pipeline_parallelism == 1)
self.save_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_model_states", state=self.layer_state, fsync = self.args.checkpoint_fsync)

@abstractmethod
def load_checkpoint(self, epoch, step_number):
my_rank = (DLIOMPI.get_instance().rank() + self.args.checkpoint_load_rank_shift) % DLIOMPI.get_instance().size()
start_layer, end_layer = self.get_layer_index()
# create a specific folder for each step
checkpoint_id = f"global_epoch{epoch}_step{step_number}"
self.checkpoint_storage.create_node(checkpoint_id, exist_ok=True)
if self.rank_to_checkpoint == my_rank:
if self.model_state:
self.load_state(suffix=f"{checkpoint_id}/model_states-{my_rank}", state=self.model_state, fsync = self.args.checkpoint_fsync)

if self.optimization_state:
self.load_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_optim_states", state=self.optimization_state)

if self.layer_state:
if self.args.zero_stage < 3 and self.args.zero_stage > 0:
# if pp is turned on, we assume that the model is sharded across the pipeline stages
if self.data_parallelism_rank == 0 and self.args.num_layers > 0:
# in this case, the model is loaded layer by layer
if self.args.pipeline_parallelism > 1:
for layer_index in range(start_layer, end_layer + 1):
self.load_state(suffix=f"{checkpoint_id}/layer_{layer_index}-model_{self.model_parallelism_rank}_model_states", state=self.layer_state[str(layer_index)])
else:
self.load_state(suffix=f"{checkpoint_id}/model_{self.model_parallelism_rank}_model_states", state=self.layer_state, fsync = self.args.checkpoint_fsync)
else:
# in this case, model is sharded across the data parallel ranks
assert(self.args.pipeline_parallelism == 1)
self.load_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_model_states", state=self.layer_state)

@abstractmethod
def finalize(self):
pass
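Note on the load path above: load_checkpoint shifts the reading rank by checkpoint_load_rank_shift modulo the world size, so with a nonzero shift every rank reads the files that were written by a different rank (which keeps the recovery read from being served by a warm local page cache). A minimal sketch of that mapping with illustrative values, not the DLIO classes themselves:

# Minimal sketch of the rank-shift idea used by load_checkpoint (illustrative
# values; not the DLIO implementation). With a nonzero shift, each rank loads
# the checkpoint files written by a different rank.
def shifted_read_rank(my_rank: int, world_size: int, shift: int) -> int:
    return (my_rank + shift) % world_size

if __name__ == "__main__":
    world_size, shift = 8, 1
    for rank in range(world_size):
        src = shifted_read_rank(rank, world_size, shift)
        print(f"rank {rank} reads the checkpoint written by rank {src}")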
13 changes: 11 additions & 2 deletions dlio_benchmark/checkpointing/pytorch_checkpointing.py
@@ -71,8 +71,17 @@ def save_state(self, suffix, state, fsync = False):
os.fsync(f.fileno())

@dlp.log
def checkpoint(self, epoch, step_number):
super().checkpoint(epoch, step_number)
def load_state(self, suffix, state):
name = self.get_name(suffix)
state = torch.load(name)

@dlp.log
def save_checkpoint(self, epoch, step_number):
super().save_checkpoint(epoch, step_number)

@dlp.log
def load_checkpoint(self, epoch, step_number):
super().load_checkpoint(epoch, step_number)

@dlp.log
def finalize(self):
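For reference, a self-contained sketch of the write/read round trip that the new load_state exercises in the PyTorch backend; the file name, tensor size, and map_location argument are assumptions for the example, and the loaded object is discarded because only the I/O is being measured:

# Self-contained sketch of the write/read round trip behind save_state /
# load_state (assumptions: a plain tensor state and a temporary local path;
# this is not the DLIO code itself).
import os
import tempfile

import torch

state = {"layer_0": torch.zeros(1024, dtype=torch.int8)}

with tempfile.TemporaryDirectory() as folder:
    name = os.path.join(folder, "model_states-0.pt")
    # Write path (mirrors save_state): serialize and fsync to force the data out.
    with open(name, "wb") as f:
        torch.save(state, f)
        f.flush()
        os.fsync(f.fileno())
    # Read path (mirrors load_state): deserialize; the result is discarded since
    # the benchmark only measures the read itself.
    loaded = torch.load(name, map_location="cpu")
    assert loaded["layer_0"].numel() == 1024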
13 changes: 11 additions & 2 deletions dlio_benchmark/checkpointing/tf_checkpointing.py
@@ -69,8 +69,17 @@ def save_state(self, suffix, state, fsync = False):
checkpoint.save(name)

@dlp.log
def checkpoint(self, epoch, step_number):
super().checkpoint(epoch, step_number)
def load_state(self, suffix, state):
name = self.get_name(suffix)
state = tf.train.load_checkpoint(name)

@dlp.log
def save_checkpoint(self, epoch, step_number):
super().save_checkpoint(epoch, step_number)

@dlp.log
def load_checkpoint(self, epoch, step_number):
super().load_checkpoint(epoch, step_number)

@dlp.log
def finalize(self):
1 change: 1 addition & 0 deletions dlio_benchmark/configs/workload/llama_1t.yaml
@@ -45,3 +45,4 @@ checkpoint:
checkpoint_folder: checkpoints/llama_405b
time_between_checkpoints: 5
num_checkpoints: 10
recovery_after_steps: 2
1 change: 1 addition & 0 deletions dlio_benchmark/configs/workload/llama_405b.yaml
@@ -46,3 +46,4 @@ checkpoint:
checkpoint_folder: checkpoints/llama_405b
time_between_checkpoints: 5
num_checkpoints: 10
recovery_after_steps: 2
1 change: 1 addition & 0 deletions dlio_benchmark/configs/workload/llama_70b.yaml
@@ -46,3 +46,4 @@ checkpoint:
checkpoint_folder: checkpoints/llama_70b
time_between_checkpoints: 5
num_checkpoints: 10
recovery_after_steps: 2
1 change: 1 addition & 0 deletions dlio_benchmark/configs/workload/llama_70b_zero3.yaml
@@ -46,3 +46,4 @@ checkpoint:
checkpoint_folder: checkpoints/llama_70b
time_between_checkpoints: 5
num_checkpoints: 10
recovery_after_steps: 2
1 change: 1 addition & 0 deletions dlio_benchmark/configs/workload/llama_7b.yaml
@@ -46,3 +46,4 @@ checkpoint:
checkpoint_folder: checkpoints/llama_7b
time_between_checkpoints: 5
num_checkpoints: 10
recovery_after_steps: 2
1 change: 1 addition & 0 deletions dlio_benchmark/configs/workload/llama_7b_zero3.yaml
@@ -46,3 +46,4 @@ checkpoint:
checkpoint_folder: checkpoints/llama_7b_zero3
time_between_checkpoints: 5
num_checkpoints: 10
recovery_after_steps: 2
3 changes: 2 additions & 1 deletion dlio_benchmark/configs/workload/llama_8b_zero3.yaml
@@ -23,7 +23,7 @@ workflow:
checkpoint: True

dataset:
data_folder: data/llama_7b/
data_folder: data/llama_8b/
format: mmap_indexed_binary
num_files_train: 1
num_samples_per_file: 1048576
@@ -45,3 +45,4 @@ checkpoint:
checkpoint_folder: checkpoints/llama_8b_zero3
time_between_checkpoints: 5
num_checkpoints: 10
recovery_after_steps: 2
24 changes: 14 additions & 10 deletions dlio_benchmark/main.py
@@ -251,11 +251,16 @@ def _checkpoint(self):
for i in range(self.args.num_checkpoints):
self.stats.start_block(epoch, block)
# We still make sure that the checkpoint is done after allreduce; therefore, allreduce here is required.
self.comm.barrier()
self.stats.start_ckpt(epoch, block, overall_step)
self.checkpointing_mechanism.checkpoint(epoch, overall_step)
self.framework.compute(None, epoch, block_step, self.args.time_between_checkpoints)
self.stats.end_ckpt(epoch, block)
self.comm.barrier()
self.stats.start_save_ckpt(epoch, block, overall_step)
self.checkpointing_mechanism.save_checkpoint(epoch, overall_step)
self.stats.end_save_ckpt(epoch, block)
if self.args.checkpoint_recovery_after_steps > 0 and (i + 1) % self.args.checkpoint_recovery_after_steps == 0:
self.comm.barrier()
self.stats.start_load_ckpt(epoch, block, overall_step)
self.checkpointing_mechanism.load_checkpoint(epoch, overall_step)
self.stats.end_load_ckpt(epoch, block)
block = block+1
overall_step = overall_step + 1
@dlp.log
@@ -294,9 +299,9 @@ def _train(self, epoch):
if self.do_checkpoint and (
self.steps_between_checkpoints >= 0) and overall_step == self.next_checkpoint_step:
self.stats.end_block(epoch, block, block_step)
self.stats.start_ckpt(epoch, block, overall_step)
self.checkpointing_mechanism.checkpoint(epoch, overall_step)
self.stats.end_ckpt(epoch, block)
self.stats.start_save_ckpt(epoch, block, overall_step)
self.checkpointing_mechanism.save_checkpoint(epoch, overall_step)
self.stats.end_save_ckpt(epoch, block)
block += 1
# Reset the number of steps after every checkpoint to mark the start of a new block
block_step = 1
@@ -308,10 +313,9 @@ def _train(self, epoch):
t0 = time()
if self.do_checkpoint and (self.steps_between_checkpoints < 0) and (epoch == self.next_checkpoint_epoch):
self.stats.end_block(epoch, block, block_step)
self.stats.start_ckpt(epoch, block, overall_step)
self.stats.start_save_ckpt(epoch, block, overall_step)
self.checkpointing_mechanism.save_checkpoint(epoch, overall_step)
self.comm.barrier()
self.stats.end_ckpt(epoch, block)
self.stats.end_save_ckpt(epoch, block)
self.next_checkpoint_epoch += self.epochs_between_checkpoints
return overall_step

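The new recovery cadence in _checkpoint boils down to: write num_checkpoints checkpoints and, after every checkpoint_recovery_after_steps of them, read one back. A stripped-down sketch of that loop, with callables standing in for the benchmark's save/load machinery (names here are illustrative, not DLIO APIs):

# Stripped-down sketch of the checkpoint-only loop's recovery cadence
# (stand-in callables, not the DLIO classes). A non-positive
# recovery_after_steps disables the recovery reads entirely.
def run_checkpoint_loop(num_checkpoints, recovery_after_steps, save, load):
    for i in range(num_checkpoints):
        save(step=i)
        if recovery_after_steps > 0 and (i + 1) % recovery_after_steps == 0:
            load(step=i)

run_checkpoint_loop(
    num_checkpoints=10,
    recovery_after_steps=2,
    save=lambda step: print(f"save checkpoint at step {step}"),
    load=lambda step: print(f"load checkpoint at step {step}"),
)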
6 changes: 6 additions & 0 deletions dlio_benchmark/utils/config.py
@@ -106,6 +106,8 @@ class ConfigArguments:
optimizer_datatype: str = "fp32"
checkpoint_fsync: bool = False
checkpoint_only: bool = False
checkpoint_load_rank_shift: int = 0
checkpoint_recovery_after_steps: int = -1
time_between_checkpoints: float = -1
num_checkpoints: int = -1
model_size: int = 10240
@@ -610,6 +612,10 @@ def LoadConfig(args, config):
args.time_between_checkpoints = config['checkpoint']['time_between_checkpoints']
if 'num_checkpoints' in config['checkpoint']:
args.num_checkpoints = config['checkpoint']['num_checkpoints']
if 'load_rank_shift' in config['checkpoint']:
args.checkpoint_load_rank_shift = config['checkpoint']['load_rank_shift']
if 'recovery_after_steps' in config['checkpoint']:
args.checkpoint_recovery_after_steps = config['checkpoint']['recovery_after_steps']

if 'model' in config:
if 'name' in config['model']:
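The two new keys live under the checkpoint section of a workload config and default to "disabled" values, matching ConfigArguments above. A hedged sketch of reading them (assumes PyYAML; the inline YAML, including the load_rank_shift value, stands in for a file such as llama_8b_zero3.yaml):

# Sketch of reading the two new checkpoint options, mirroring the defaults in
# ConfigArguments (-1 disables recovery, 0 means no rank shift). Assumption:
# PyYAML is installed; the inline YAML stands in for a workload config file.
import yaml

config_text = """
checkpoint:
  checkpoint_folder: checkpoints/llama_8b_zero3
  num_checkpoints: 10
  recovery_after_steps: 2
  load_rank_shift: 1
"""

checkpoint_cfg = yaml.safe_load(config_text).get("checkpoint", {})
recovery_after_steps = checkpoint_cfg.get("recovery_after_steps", -1)
load_rank_shift = checkpoint_cfg.get("load_rank_shift", 0)
print(recovery_after_steps, load_rank_shift)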
71 changes: 51 additions & 20 deletions dlio_benchmark/utils/statscounter.py
@@ -135,17 +135,27 @@ def start_run(self):
def end_run(self):
self.end_run_timestamp = time()
if self.args.do_checkpoint and self.my_rank == 0:
duration = []
io = []
duration_save = []
io_save = []
duration_load = []
io_load = []
for e in self.per_epoch_stats:
for t in self.per_epoch_stats[e]:
if t.find("ckpt")!=-1:
duration.append(float(self.per_epoch_stats[e][t]['duration']))
io.append(self.per_epoch_stats[e][t]['throughput'])
self.summary['metric']['checkpoint_io_mean_GB_per_second'] = np.mean(io)
self.summary['metric']['checkpoint_io_stdev_GB_per_second'] = np.std(io)
self.summary['metric']['checkpoint_duration_mean_seconds'] = np.mean(duration)
self.summary['metric']['checkpoint_duration_stdev_seconds'] = np.std(duration)
if t.find("save_ckpt")!=-1:
duration_save.append(float(self.per_epoch_stats[e][t]['duration']))
io_save.append(self.per_epoch_stats[e][t]['throughput'])
elif t.find("load_ckpt")!=-1:
duration_load.append(float(self.per_epoch_stats[e][t]['duration']))
io_load.append(self.per_epoch_stats[e][t]['throughput'])
self.summary['metric']['save_checkpoint_io_mean_GB_per_second'] = np.mean(io_save)
self.summary['metric']['save_checkpoint_io_stdev_GB_per_second'] = np.std(io_save)
self.summary['metric']['save_checkpoint_duration_mean_seconds'] = np.mean(duration_save)
self.summary['metric']['save_checkpoint_duration_stdev_seconds'] = np.std(duration_save)
if len(io_load) > 0:
self.summary['metric']['load_checkpoint_io_mean_GB_per_second'] = np.mean(io_load)
self.summary['metric']['load_checkpoint_io_stdev_GB_per_second'] = np.std(io_load)
self.summary['metric']['load_checkpoint_duration_mean_seconds'] = np.mean(duration_load)
self.summary['metric']['load_checkpoint_duration_stdev_seconds'] = np.std(duration_load)
self.summary['metric']['checkpoint_size_GB'] = self.checkpoint_size
if not self.args.generate_only:
total_elapsed_time = self.end_run_timestamp - self.start_run_timestamp
@@ -183,15 +193,19 @@ def end_run(self):
self.summary['metric']['eval_io_stdev_MB_per_second'] = np.std(eval_throughput)*self.record_size/1024./1024.
if self.my_rank==0:
logging.info(f"{utcnow()} Saved outputs in {self.output_folder}")
metric="Averaged metric over all epochs\n[METRIC] ==========================================================\n"
metric="Averaged metric over all steps/epochs\n[METRIC] ==========================================================\n"
metric = metric + f"[METRIC] Number of Simulated Accelerators: {self.comm_size} \n"
if self.args.do_train:
metric = metric + f"[METRIC] Training Accelerator Utilization [AU] (%): {np.mean(train_au):.4f} ({np.std(train_au):.4f})\n"
metric = metric + f"[METRIC] Training Throughput (samples/second): {np.mean(train_throughput):.4f} ({np.std(train_throughput):.4f})\n"
metric = metric + f"[METRIC] Training I/O Throughput (MB/second): {np.mean(train_throughput)*self.record_size/1024/1024:.4f} ({np.std(train_throughput)*self.record_size/1024/1024:.4f})\n"
metric = metric + f"[METRIC] train_au_meet_expectation: {self.summary['metric']['train_au_meet_expectation']}\n"
if self.args.do_checkpoint:
metric = metric + f"[METRIC] Checkpoint I/O Throughput (GB/second): {self.summary['metric']['checkpoint_io_mean_GB_per_second']:.4f} ({self.summary['metric']['checkpoint_io_stdev_GB_per_second']:.4f})\n"
metric = metric + f"[METRIC] Checkpoint save duration (seconds): {self.summary['metric']['save_checkpoint_duration_mean_seconds']:.4f} ({self.summary['metric']['save_checkpoint_duration_stdev_seconds']:.4f})\n"
metric = metric + f"[METRIC] Checkpoint save I/O Throughput (GB/second): {self.summary['metric']['save_checkpoint_io_mean_GB_per_second']:.4f} ({self.summary['metric']['save_checkpoint_io_stdev_GB_per_second']:.4f})\n"
if 'load_checkpoint_io_mean_GB_per_second' in self.summary['metric']:
metric = metric + f"[METRIC] Checkpoint load duration (seconds): {self.summary['metric']['load_checkpoint_duration_mean_seconds']:.4f} ({self.summary['metric']['load_checkpoint_duration_stdev_seconds']:.4f})\n"
metric = metric + f"[METRIC] Checkpoint load I/O Throughput (GB/second): {self.summary['metric']['load_checkpoint_io_mean_GB_per_second']:.4f} ({self.summary['metric']['load_checkpoint_io_stdev_GB_per_second']:.4f})\n"

if self.args.do_eval:
metric = metric + f"[METRIC] Eval Accelerator Utilization [AU] (%): {np.mean(eval_au):.4f} ({np.std(eval_au):.4f})\n"
@@ -312,21 +326,38 @@ def end_block(self, epoch, block, steps_taken):
logging.info(f"{utcnow()} Epoch {epoch} - Block {block} [Training] Accelerator Utilization [AU] (%): {self.output[epoch]['au'][f'block{block}']:.4f}")
logging.info(f"{utcnow()} Epoch {epoch} - Block {block} [Training] Throughput (samples/second): {self.output[epoch]['throughput'][f'block{block}']*self.comm_size:.4f}")

def start_ckpt(self, epoch, block, steps_taken):
def start_save_ckpt(self, epoch, block, steps_taken):
ts = utcnow()
if self.my_rank == 0:
logging.info(f"{ts} Starting checkpoint {block} after total step {steps_taken} for epoch {epoch}")
self.per_epoch_stats[epoch][f'ckpt{block}'] = {
logging.info(f"{ts} Starting saving checkpoint {block} after total step {steps_taken} for epoch {epoch}")
self.per_epoch_stats[epoch][f'save_ckpt{block}'] = {
'start': ts
}
def end_ckpt(self, epoch, block):
def end_save_ckpt(self, epoch, block):
ts = utcnow()
duration = pd.to_datetime(ts) - pd.to_datetime(self.per_epoch_stats[epoch][f'ckpt{block}']['start'])
self.per_epoch_stats[epoch][f'ckpt{block}']['end'] = ts
self.per_epoch_stats[epoch][f'ckpt{block}']['duration'] = float(duration.total_seconds())
self.per_epoch_stats[epoch][f'ckpt{block}']['throughput'] = self.checkpoint_size / float(duration.total_seconds())
duration = pd.to_datetime(ts) - pd.to_datetime(self.per_epoch_stats[epoch][f'save_ckpt{block}']['start'])
self.per_epoch_stats[epoch][f'save_ckpt{block}']['end'] = ts
self.per_epoch_stats[epoch][f'save_ckpt{block}']['duration'] = float(duration.total_seconds())
self.per_epoch_stats[epoch][f'save_ckpt{block}']['throughput'] = self.checkpoint_size / float(duration.total_seconds())
if self.my_rank == 0:
logging.info(f"{ts} Finished checkpoint {block} for epoch {epoch} in {duration.total_seconds():.4f} s; Throughput: {self.per_epoch_stats[epoch][f'ckpt{block}']['throughput']:.4f} GB/s")
logging.info(f"{ts} Finished saving checkpoint {block} for epoch {epoch} in {duration.total_seconds():.4f} s; Throughput: {self.per_epoch_stats[epoch][f'save_ckpt{block}']['throughput']:.4f} GB/s")


def start_load_ckpt(self, epoch, block, steps_taken):
ts = utcnow()
if self.my_rank == 0:
logging.info(f"{ts} Starting loading checkpoint {block} after total step {steps_taken} for epoch {epoch}")
self.per_epoch_stats[epoch][f'load_ckpt{block}'] = {
'start': ts
}
def end_load_ckpt(self, epoch, block):
ts = utcnow()
duration = pd.to_datetime(ts) - pd.to_datetime(self.per_epoch_stats[epoch][f'load_ckpt{block}']['start'])
self.per_epoch_stats[epoch][f'load_ckpt{block}']['end'] = ts
self.per_epoch_stats[epoch][f'load_ckpt{block}']['duration'] = float(duration.total_seconds())
self.per_epoch_stats[epoch][f'load_ckpt{block}']['throughput'] = self.checkpoint_size / float(duration.total_seconds())
if self.my_rank == 0:
logging.info(f"{ts} Finished loading checkpoint {block} for epoch {epoch} in {duration.total_seconds():.4f} s; Throughput: {self.per_epoch_stats[epoch][f'load_ckpt{block}']['throughput']:.4f} GB/s")


def batch_loaded(self, epoch, step, block, t0):
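The per-block load/save bookkeeping above reduces to simple arithmetic: throughput is the checkpoint size (already tracked in GB) divided by the elapsed seconds, and the run summary reports the mean and standard deviation across blocks. A tiny numeric sketch with made-up values:

# Numeric sketch of the summary math behind end_save_ckpt / end_load_ckpt
# (made-up values; checkpoint_size is in GB, so size / seconds gives GB/s).
import numpy as np

checkpoint_size_gb = 94.0
load_durations_s = [41.8, 43.1, 40.9]   # one entry per recovery read

throughputs = [checkpoint_size_gb / d for d in load_durations_s]
print(f"load duration mean (stdev) seconds: "
      f"{np.mean(load_durations_s):.4f} ({np.std(load_durations_s):.4f})")
print(f"load throughput mean (stdev) GB/s: "
      f"{np.mean(throughputs):.4f} ({np.std(throughputs):.4f})")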
3 changes: 3 additions & 0 deletions docs/source/config.rst
@@ -404,6 +404,9 @@ checkpoint
* - num_checkpoints
- -1
- How many checkpoints to write; this parameter is used only when workflow.train=False
* - recovery_after_steps
- -1
- Number of checkpoint writes between recovery reads; a read of the checkpoint is issued after every recovery_after_steps writes. -1 means recovery is never performed. For example, with num_checkpoints: 10 and recovery_after_steps: 2, a recovery read follows checkpoints 2, 4, 6, 8, and 10.

.. note::

6 changes: 2 additions & 4 deletions tests/dlio_benchmark_test.py
@@ -270,10 +270,8 @@ def test_checkpoint_epoch(framework, model_size, optimizers, num_layers, layer_p
nranks = comm.size
num_model_files = 1
num_optimizer_files = 1
if zero_stage != 3:
num_layer_files = num_layers
else:
num_layer_files = 1
# Set num_layer_files to 1 because pipeline parallelism is not used in this test.
num_layer_files = 1
files_per_checkpoint = (num_model_files + num_optimizer_files + num_layer_files) * nranks
if framework == "tensorflow":
file_per_ckp = 2
