Update low_level_optim.py
flybird11111 committed Nov 20, 2024
1 parent: ff04170 · commit: e324c8a
Showing 1 changed file with 0 additions and 6 deletions.
6 changes: 0 additions & 6 deletions colossalai/zero/low_level/low_level_optim.py
@@ -790,8 +790,6 @@ def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]]
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if pinned_state_dicts and k not in pinned_state_dicts[param]:
-                        pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu")
                     working_param = self.master_to_working_param[id(param)]
                     pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)

@@ -873,10 +871,6 @@ def state_dict_shard(
 
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
-                        pinned_state_dicts[param_idx][k] = torch.empty_like(
-                            working_param, pin_memory=True, device="cpu"
-                        )
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
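For context, both deleted hunks pre-allocated a pinned CPU buffer for each optimizer-state tensor before the gather. In the first hunk the allocation appears to read working_param one line before working_param is assigned, so on the very first pass the name is undefined and on later passes it can still hold the previous parameter's shape. A minimal, hypothetical sketch of the safer ordering (allocate the pinned buffer only after the source tensor exists) follows; the helper name, the cache dict, and the placeholder tensor are illustrative and not part of ColossalAI's API:

    import torch

    def copy_to_pinned(gathered: torch.Tensor, pinned_cache: dict, key: str) -> torch.Tensor:
        # Hypothetical helper (not ColossalAI's API): stage a gathered state
        # tensor into a reusable pinned CPU buffer for fast device-to-host copies.
        if key not in pinned_cache:
            # Size the pinned buffer from the tensor we actually hold, so its
            # shape and dtype cannot come from a stale or unassigned variable.
            pinned_cache[key] = torch.empty_like(gathered, device="cpu", pin_memory=True)
        pinned_cache[key].copy_(gathered, non_blocking=True)
        return pinned_cache[key]

    # Usage sketch: "exp_avg" stands in for an Adam state key.
    cache = {}
    gathered_state = torch.randn(1024)  # placeholder for the all-gathered tensor
    host_copy = copy_to_pinned(gathered_state, cache, "exp_avg")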
