
FX: optimizer load state dict error
a710128 committed Jun 14, 2022
1 parent 87fb21b commit 6a09f41
Showing 2 changed files with 169 additions and 3 deletions.
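For context, both files below override the standard torch.optim.Optimizer checkpointing hooks. A minimal usage sketch of the round-trip this commit fixes (assuming AdamOptimizer takes a parameter iterable and an lr keyword like torch.optim.Adam; the model and file name are illustrative, not from the repository):

import torch
from bmtrain.optim.adam import AdamOptimizer

model = torch.nn.Linear(16, 16)                      # stand-in module
optimizer = AdamOptimizer(model.parameters(), lr=1e-3)

torch.save(optimizer.state_dict(), "optim.pt")       # serialize optimizer state
optimizer.load_state_dict(torch.load("optim.pt"))    # restore; the path this commit repairs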
65 changes: 62 additions & 3 deletions bmtrain/optim/adam.py
@@ -5,6 +5,10 @@
from .. import nccl
import inspect

from copy import deepcopy
from itertools import chain
from collections import defaultdict

class AdamOptimizer(torch.optim.Optimizer):
"""
Adam optimizer
@@ -97,15 +101,15 @@ def step(self, closure=None):
state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device

if p.dtype == torch.half:
- state['param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device=p.device) # on device
- state['param_fp32'].copy_(p)
+ state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device=p.device) # on device
+ state['_param_fp32'].copy_(p)

# update the steps for each param group update
state['step'] += 1

if p.dtype == torch.half:
C.f_adam(
state["param_fp32"], # fp32
state["_param_fp32"], # fp32
p, # fp16
p.grad, # fp16
state['exp_avg'], # fp16: m
@@ -144,3 +148,58 @@ def loss_scale(self, loss : torch.Tensor) -> torch.Tensor:
Backward with loss scale.
"""
return loss * (self.scale / config['world_size'])

def load_state_dict(self, state_dict: dict) -> None:
r"""Loads the optimizer state.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']

if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of "
"parameter groups")
param_lens = (len(g['params']) for g in groups)
saved_lens = (len(g['params']) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group")

# Update the state
id_map = {old_id: p for old_id, p in
zip(chain.from_iterable((g['params'] for g in saved_groups)),
chain.from_iterable((g['params'] for g in groups)))}

# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
for k, v in state_dict['state'].items():
if k in id_map:
param = id_map[k]

if param.dtype == torch.half and "_param_fp32" not in v:
v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device=param.device)
v["_param_fp32"].copy_(param)

for name, dtype in [("exp_avg", param.dtype), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
if name in v:
v[name] = v[name].to(param.device).to(dtype)

state[param] = v
else:
state[k] = v

# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
new_group['params'] = group['params']
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})
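
The per-parameter restore logic above is the heart of the adam.py change: rebuild the fp32 master weights when an older checkpoint lacks the _param_fp32 key, then move and cast each state tensor to match the live parameter (exp_avg follows the parameter dtype, exp_avg_sq and _param_fp32 stay fp32). Restated as a standalone sketch for clarity (the helper name is ours, not part of the library):

import torch

def _restore_param_state(param: torch.Tensor, v: dict) -> dict:
    # Recreate the fp32 master copy from the live parameter if the
    # checkpoint predates the '_param_fp32' key.
    if param.dtype == torch.half and "_param_fp32" not in v:
        v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device=param.device)
        v["_param_fp32"].copy_(param)
    # First moment follows the parameter dtype; second moment and master
    # weights are always kept in fp32, all on the parameter's device.
    for name, dtype in [("exp_avg", param.dtype), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
        if name in v:
            v[name] = v[name].to(param.device).to(dtype)
    return v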
107 changes: 107 additions & 0 deletions bmtrain/optim/adam_offload.py
@@ -6,6 +6,10 @@
import torch.optim._functional as F
import inspect

from copy import deepcopy
from itertools import chain
from collections import defaultdict

class AdamOffloadOptimizer(torch.optim.Optimizer):
"""
Adam optimizer
@@ -172,3 +176,106 @@ def loss_scale(self, loss : torch.Tensor) -> torch.Tensor:
Backward with loss scale.
"""
return loss * (self.scale / config['world_size'])

def load_state_dict(self, state_dict: dict) -> None:
r"""Loads the optimizer state.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']

if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of "
"parameter groups")
param_lens = (len(g['params']) for g in groups)
saved_lens = (len(g['params']) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group")

# Update the state
id_map = {old_id: p for old_id, p in
zip(chain.from_iterable((g['params'] for g in saved_groups)),
chain.from_iterable((g['params'] for g in groups)))}

# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
for k, v in state_dict['state'].items():
if k in id_map:
param = id_map[k]

if "_param_fp32" not in v:
v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu")
v["_param_fp32"].copy_(param)

for name, dtype in [("exp_avg", torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
if name in v:
v[name] = v[name].to("cpu").to(dtype)

state[param] = v
if param.dtype == torch.half:
# initialize placeholders
state[param]["_param_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host
state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host
else:
state[param]["_param_fp32"] = state[param]["_param_fp32"].pin_memory()

# initialize placeholders
state[param]["_grad_fp32"] = torch.empty(param.size(), dtype=torch.float32, pin_memory=True) # on host
else:
state[k] = v

# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
new_group['params'] = group['params']
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})

def state_dict(self) -> dict:
r"""Returns the state of the optimizer as a :class:`dict`.
It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a list containing all parameter groups where each
parameter group is a dict
"""
# Save order indices instead of Tensors
param_mappings = {}
start_index = 0

def pack_group(group):
nonlocal start_index
packed = {k: v for k, v in group.items() if k != 'params'}
param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
if id(p) not in param_mappings})
packed['params'] = [param_mappings[id(p)] for p in group['params']]
start_index += len(packed['params'])
return packed

def cut_states(state):
return {
"step": state["step"],
"exp_avg": state["exp_avg"],
"exp_avg_sq": state["exp_avg_sq"],
"_param_fp32": state["_param_fp32"],
}
param_groups = [pack_group(g) for g in self.param_groups]
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): cut_states(v)
for k, v in self.state.items()}
return {
'state': packed_state,
'param_groups': param_groups,
}
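
Together, the two overrides above define what an AdamOffloadOptimizer checkpoint contains: cut_states() keeps only step, exp_avg, exp_avg_sq, and _param_fp32, while the pinned host buffers (_param_fp16, _grad_fp16, _grad_fp32) are recreated by load_state_dict rather than serialized. A round-trip sketch under assumed setup (one fp16 CUDA parameter; the constructor call mirrors other torch optimizers, and the state is only populated after the usual step() flow):

import torch
from bmtrain.optim.adam_offload import AdamOffloadOptimizer

param = torch.nn.Parameter(torch.zeros(128, dtype=torch.half, device="cuda"))
opt = AdamOffloadOptimizer([param])
# ... training steps populate opt.state ...

ckpt = opt.state_dict()
for packed in ckpt["state"].values():
    # Only the durable keys survive; pinned host placeholders are omitted.
    assert set(packed) == {"step", "exp_avg", "exp_avg_sq", "_param_fp32"}

fresh = AdamOffloadOptimizer([param])
fresh.load_state_dict(ckpt)   # re-pins _param_fp32 and rebuilds the fp16/fp32 placeholders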
