Merge pull request #8 from OpenBMB/FX_checkpointing_grad
FX: always requires_grad in checkpointing block
a710128 authored Mar 16, 2022
2 parents b5bc1ea + 32e49fb commit 4bde89a
Showing 1 changed file with 3 additions and 0 deletions.

bmtrain/block_layer.py
@@ -166,7 +166,10 @@ def enter(self):
             shape = param["shape"]
             param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape)
             if requires_grad and kw_name in self._grad_buffer:
+                param["parameter"].requires_grad_(True)
                 param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape)
+            else:
+                param["parameter"].requires_grad_(False)
 
 
     def __enter__(self):
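The patch makes the requires_grad flag explicit on both branches of enter(). Below is a minimal standalone sketch of why the reset is needed; the names param_buffer and grad_buffer are hypothetical stand-ins for self._param_buffer[kw_name] and self._grad_buffer[kw_name], not BMTrain's API. The key point is that Tensor.set_() swaps a parameter's storage but does not touch its requires_grad flag, so a parameter re-entered without a grad buffer would otherwise keep requires_grad=True from an earlier pass.

import torch

# Hypothetical stand-ins for BMTrain's internal buffers.
param = torch.nn.Parameter(torch.empty(0))
param_buffer = torch.zeros(4)   # stand-in for self._param_buffer[kw_name]
grad_buffer = torch.zeros(4)    # stand-in for self._grad_buffer[kw_name]

# Block entered with gradients tracked: alias the data into the shared
# parameter buffer, enable the flag, then alias the gradient into the
# shared grad buffer (mirrors the patched if-branch above).
param.data = torch.tensor([]).set_(param_buffer, 0, (4,))
param.requires_grad_(True)
param.grad = torch.tensor([]).set_(grad_buffer, 0, (4,))

# Block entered with gradients NOT tracked (no grad buffer for this
# parameter): before this patch, requires_grad stayed True from the
# previous enter(); the new else-branch clears it explicitly.
param.requires_grad_(False)
assert not param.requires_grad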
