diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py
index 8312dddd2fa6..31997177a262 100644
--- a/deepspeed/checkpoint/deepspeed_checkpoint.py
+++ b/deepspeed/checkpoint/deepspeed_checkpoint.py
@@ -4,6 +4,7 @@
 # DeepSpeed Team
 
 import os
+import re
 from typing import Dict
 import torch
 
@@ -21,6 +22,7 @@
 ARGS_KEY = 'args'
 CHECKPOINT_INFO_KEY = 'checkpoint_info'
 ITERATION_KEY = 'iteration'
+LAYER_FILE_PREFIX_PATTERN = r'layer_(\d+)-model_.*'
 
 SEQUENTIAL_LAYERS = [
     'input_layernorm.weight', 'input_layernorm.bias', 'self_attention.dense.bias', 'post_attention_layernorm.weight',
@@ -32,7 +34,13 @@
 
 class DeepSpeedCheckpoint(object):
 
-    def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
+    def __init__(self,
+                 dir,
+                 tp_degree=None,
+                 pp_degree=None,
+                 dp_degree=None,
+                 final_layer_norm_idx=FINAL_LAYER_NORM_INDEX):
+        self.final_layer_norm_idx = final_layer_norm_idx
         self.dir = dir
 
         pipeline_parallel = len(get_files_with_prefix(get_files(dir), LAYER_FILE_PREFIX)) > 0
@@ -73,7 +81,7 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
         self.pp_to_transformer_map = self._build_pp_transformer_map()
         self.transformer_file_map = self._build_transformer_file_map()
         self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
-        self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
+        self.tp_to_final_norm_map = self._build_tp_other_layer_map(self.final_layer_norm_idx)
         self._build_global_state()
 
     def is_change_tp_degree(self):
@@ -125,7 +133,7 @@ def get_embedding_layer_id(self):
         return self.layer_keys[EMBEDDING_LAYER_INDEX]
 
     def get_final_norm_layer_id(self):
-        return self.layer_keys[FINAL_LAYER_NORM_INDEX]
+        return self.layer_keys[self.final_layer_norm_idx]
 
     def get_iteration(self):
         if not ITERATION_KEY in self.global_state:
@@ -214,7 +222,7 @@ def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list:
     def _build_pp_transformer_map(self):
         data_map = {}
         if self.pp_degree > 0:
-            transformer_layers = self.layer_keys[1:-1]
+            transformer_layers = self.layer_keys[1:self.final_layer_norm_idx]
             layers_per_pp = len(transformer_layers) // self.pp_degree
             data_map = {
                 i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
@@ -229,7 +237,7 @@ def _dump_mapping(self, data_map, map_tag=None):
             print(f'{k} = {v}')
 
     def _build_transformer_file_map(self):
-        transformer_layer_keys = self.layer_keys[1:-1]
+        transformer_layer_keys = self.layer_keys[1:self.final_layer_norm_idx]
         file_map = {}
         # XXX: this is not guaranteed
         layers_per_pp = 1
@@ -238,7 +246,7 @@ def _build_transformer_file_map(self):
         #print(f"{transformer_layer_keys} {layers_per_pp}")
         for key_index, layer_key in enumerate(transformer_layer_keys):
             pp_index = key_index // layers_per_pp
-            layer_files = get_files_with_prefix(self.layer_files, layer_key)
+            layer_files = get_files_with_prefix(self.layer_files, layer_key + '-')
             layer_file_partitions = partition_data(layer_files, self.tp_degree)
             for tp_index in range(self.tp_degree):
                 map_key = (tp_index, pp_index)
@@ -263,11 +271,13 @@ def validate_files(self):
 
     def _get_layer_keys(self):
         key_set = set()
-        key_len = len(LAYER_FILE_PREFIX) + 2
         for file_path in self.layer_files:
             _, fname = os.path.split(file_path)
-            key_set.add(fname[:key_len])
-        return sorted(list(key_set))
+            layer_id = re.search(LAYER_FILE_PREFIX_PATTERN, fname).group(1)
+            key_set.add(layer_id)
+        sorted_ids = sorted(list(key_set), key=int)
+        layer_keys = [LAYER_FILE_PREFIX + str(layer_id) for layer_id in sorted_ids]
+        return layer_keys
 
     def _merge_state_dicts(self, sd_list):
         merged_sd = {}
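Note (not part of the patch): the old _get_layer_keys took a fixed-width slice of each filename, len(LAYER_FILE_PREFIX) + 2 characters, so a three-digit layer id like 'layer_100-...' was truncated to 'layer_10' and collided with the real layer 10; bare prefix matching on 'layer_10' likewise swept in layer 100's files, which the layer_key + '-' change disambiguates. Once full ids of varying width are extracted, a lexicographic sort would put 'layer_100' before 'layer_20', hence sorted(..., key=int). The final_layer_norm_idx parameter similarly replaces the hard-coded assumption that the final norm is the last layer file, while its FINAL_LAYER_NORM_INDEX default (-1) keeps the old slicing behavior. A minimal standalone sketch of both failure modes and the fixed behavior follows; the filenames are hypothetical examples of the layer_<id>-model_<rank>-model_states.pt convention the pattern assumes, and get_files_with_prefix is approximated with a plain str.startswith filter.

    # Sketch only: hypothetical filenames; startswith stands in for
    # get_files_with_prefix.
    import re

    LAYER_FILE_PREFIX = 'layer_'
    LAYER_FILE_PREFIX_PATTERN = r'layer_(\d+)-model_.*'

    fnames = [
        'layer_00-model_00-model_states.pt',
        'layer_10-model_00-model_states.pt',
        'layer_100-model_00-model_states.pt',
    ]

    # Old behavior: the fixed-width slice truncates 3-digit layer ids, so
    # 'layer_100-...' collapses into layer 10's key and is silently lost.
    key_len = len(LAYER_FILE_PREFIX) + 2
    print(sorted({f[:key_len] for f in fnames}))
    # -> ['layer_00', 'layer_10']  (layer 100 is gone)

    # Old behavior: bare prefix matching pulls layer 100's files into
    # layer 10's partition; appending '-' disambiguates the prefix.
    print([f for f in fnames if f.startswith('layer_10')])        # two files
    print([f for f in fnames if f.startswith('layer_10' + '-')])  # one file

    # New behavior: extract the full numeric id and sort numerically, since
    # a lexicographic sort would order 'layer_100' before 'layer_20'.
    ids = sorted({re.search(LAYER_FILE_PREFIX_PATTERN, f).group(1) for f in fnames}, key=int)
    print([LAYER_FILE_PREFIX + layer_id for layer_id in ids])
    # -> ['layer_00', 'layer_10', 'layer_100']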