Commit

root committed Jun 10, 2022
1 parent fc54130 commit 6861ce1
Showing 2 changed files with 3 additions and 8 deletions.
7 changes: 1 addition & 6 deletions bmtrain/block_layer.py
@@ -14,6 +14,7 @@
def round_up(x, d):
return (x + d - 1) // d * d

+ # This flag controls the ZeRO level: 0 means normal ZeRO-3, 1 means forward without releasing parameters, 2 means backward without gathering parameters
class OpCheckpointBlock(torch.autograd.Function):
@staticmethod
def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len_args, *args):
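
For context (not part of this diff): round_up rounds x up to the nearest multiple of d. A quick sanity check of the formula above:

assert round_up(10, 4) == 12  # (10 + 3) // 4 * 4 = 12
assert round_up(8, 4) == 8    # already a multiple of 4, unchanged
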
@@ -423,7 +424,6 @@ def __init__(self, inner_module : torch.nn.Module):
self._param_info[-1]["end"] = (to_offset_end - to_offset_st,)
param.data[:] = \
torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
- # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
del contiguous_param
else:
param.data = torch.tensor([], dtype=param.dtype, device=param.device)
@@ -506,7 +506,6 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
d_device = self._storage_params[kw_name].device
torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
- # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
del contiguous_param
elif strict:
missing_keys.append(key)
@@ -563,10 +562,8 @@ def init_parameters(self):
# PyTorch 1.11 changed the API of storage.__getitem__
d_dtype = self._storage_params[kw_name].dtype
d_device = self._storage_params[kw_name].device
- # param.data=torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))
param.data[:] = \
torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:]
- # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(tmp_tensor.storage()[offset_st: offset_end])
del tmp_tensor

def _named_members(self, get_members_fn, prefix='', recurse=True):
@@ -838,5 +835,3 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock:
def forward(self, hidden_state, *args):
placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
return OpTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args)
- # def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
- # return super().named_modules(memo, prefix, remove_duplicate)
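
The retained assignments above all use the same storage-aliasing idiom: bind an empty tensor to a window of a shared storage with Tensor.set_() instead of slicing the storage object directly, which is what the deleted comment lines did and what the "PyTorch 1.11 changed the API of storage.__getitem__" note refers to. A minimal self-contained sketch of that idiom, with a made-up buffer size and offsets:

import torch

# Flat buffer standing in for one of the block's contiguous storage parameters.
storage_param = torch.zeros(16, dtype=torch.float32)

offset_st, offset_end = 4, 10  # hypothetical element range of one parameter

# Alias a window of the shared storage instead of indexing the storage itself.
view = torch.tensor([], dtype=storage_param.dtype, device=storage_param.device)
view.set_(storage_param.storage(), offset_st, (offset_end - offset_st,))

view[:] = 1.0  # writes go straight into the shared buffer
assert storage_param[offset_st:offset_end].eq(1.0).all()
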
4 changes: 2 additions & 2 deletions bmtrain/init.py
@@ -17,7 +17,7 @@ def init_distributed(
):
"""Initialize distributed training.
This function will initialize the distributed training, set the random seed and global configurations.
- It must be called before any other distributed functions.siz
+ It must be called before any other distributed functions.
Args:
seed (int): The random seed.
@@ -46,7 +46,7 @@ def init_distributed(
world_size = int(os.environ["WORLD_SIZE"])
local_size = int(os.environ["LOCAL_WORLD_SIZE"])
master = os.environ["MASTER_ADDR"] + ":" + os.environ["MASTER_PORT"]
- timeout = datetime.timedelta(seconds=1800)
+ timeout = datetime.timedelta(seconds=1800)
rendezvous_iterator = dist.rendezvous(
init_method, rank, world_size, timeout=timeout
)
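
For reference, this hunk follows the standard torch.distributed launcher contract: MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and LOCAL_WORLD_SIZE are injected by torchrun (or an equivalent launcher), and the 1800-second timedelta bounds the rendezvous. A minimal sketch of the same pattern outside bmtrain; the RANK read and the tcp:// init_method are assumptions, not shown in this hunk:

import os
import datetime
import torch.distributed as dist

rank = int(os.environ["RANK"])  # assumed: set by the launcher
world_size = int(os.environ["WORLD_SIZE"])
local_size = int(os.environ["LOCAL_WORLD_SIZE"])  # read to mirror the hunk; unused in this sketch
master = os.environ["MASTER_ADDR"] + ":" + os.environ["MASTER_PORT"]

timeout = datetime.timedelta(seconds=1800)  # 30-minute rendezvous window
init_method = "tcp://" + master  # assumed form of the init_method string

# Same low-level entry point the diff calls; the handler is a generator that
# yields a (store, rank, world_size) triple once all ranks have checked in.
store, rank, world_size = next(
    dist.rendezvous(init_method, rank, world_size, timeout=timeout)
)
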
