From 6a09f419835ccc27dce93fa8c69545645f642bd3 Mon Sep 17 00:00:00 2001
From: zgy
Date: Tue, 14 Jun 2022 04:57:31 +0000
Subject: [PATCH] FIX: optimizer load_state_dict error

---
 bmtrain/optim/adam.py         |  65 ++++++++++++++++++++-
 bmtrain/optim/adam_offload.py | 107 ++++++++++++++++++++++++++++++++++
 2 files changed, 169 insertions(+), 3 deletions(-)

diff --git a/bmtrain/optim/adam.py b/bmtrain/optim/adam.py
index 12bbf9e6..03830fdd 100644
--- a/bmtrain/optim/adam.py
+++ b/bmtrain/optim/adam.py
@@ -5,6 +5,10 @@
 from .. import nccl
 import inspect
 
+from copy import deepcopy
+from itertools import chain
+from collections import defaultdict
+
 class AdamOptimizer(torch.optim.Optimizer):
     """
     Adam optimizer
@@ -97,15 +101,15 @@ def step(self, closure=None):
                         state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device)   # on device
 
                         if p.dtype == torch.half:
-                            state['param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device=p.device)   # on device
-                            state['param_fp32'].copy_(p)
+                            state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device=p.device)   # on device
+                            state['_param_fp32'].copy_(p)
 
                     # update the steps for each param group update
                     state['step'] += 1
 
                     if p.dtype == torch.half:
                         C.f_adam(
-                            state["param_fp32"],    # fp32
+                            state["_param_fp32"],   # fp32
                             p,                      # fp16
                             p.grad,                 # fp16
                             state['exp_avg'],       # fp16: m
@@ -144,3 +148,58 @@ def loss_scale(self, loss : torch.Tensor) -> torch.Tensor:
         Backward with loss scale.
         """
         return loss * (self.scale / config['world_size'])
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""Loads the optimizer state.
+
+        Args:
+            state_dict (dict): optimizer state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = deepcopy(state_dict)
+        # Validate the state_dict
+        groups = self.param_groups
+        saved_groups = state_dict['param_groups']
+
+        if len(groups) != len(saved_groups):
+            raise ValueError("loaded state dict has a different number of "
+                             "parameter groups")
+        param_lens = (len(g['params']) for g in groups)
+        saved_lens = (len(g['params']) for g in saved_groups)
+        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+            raise ValueError("loaded state dict contains a parameter group "
+                             "that doesn't match the size of optimizer's group")
+
+        # Update the state
+        id_map = {old_id: p for old_id, p in
+                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
+                      chain.from_iterable((g['params'] for g in groups)))}
+
+        # Copy state assigned to params (and cast tensors to appropriate types).
+        # State that is not assigned to params is copied as is (needed for
+        # backward compatibility).
+        state = defaultdict(dict)
+        for k, v in state_dict['state'].items():
+            if k in id_map:
+                param = id_map[k]
+
+                if param.dtype == torch.half and "_param_fp32" not in v:
+                    v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device=param.device)
+                    v["_param_fp32"].copy_(param)
+
+                for name, dtype in [("exp_avg", param.dtype), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
+                    if name in v:
+                        v[name] = v[name].to(param.device).to(dtype)
+
+                state[param] = v
+            else:
+                state[k] = v
+
+        # Update parameter groups, setting their 'params' value
+        def update_group(group, new_group):
+            new_group['params'] = group['params']
+            return new_group
+        param_groups = [
+            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+        self.__setstate__({'state': state, 'param_groups': param_groups})
\ No newline at end of file
diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py
index 45fa726f..9d19c042 100644
--- a/bmtrain/optim/adam_offload.py
+++ b/bmtrain/optim/adam_offload.py
@@ -6,6 +6,10 @@
 import torch.optim._functional as F
 import inspect
 
+from copy import deepcopy
+from itertools import chain
+from collections import defaultdict
+
 class AdamOffloadOptimizer(torch.optim.Optimizer):
     """
     Adam optimizer
@@ -172,3 +176,106 @@ def loss_scale(self, loss : torch.Tensor) -> torch.Tensor:
         Backward with loss scale.
         """
         return loss * (self.scale / config['world_size'])
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""Loads the optimizer state.
+
+        Args:
+            state_dict (dict): optimizer state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = deepcopy(state_dict)
+        # Validate the state_dict
+        groups = self.param_groups
+        saved_groups = state_dict['param_groups']
+
+        if len(groups) != len(saved_groups):
+            raise ValueError("loaded state dict has a different number of "
+                             "parameter groups")
+        param_lens = (len(g['params']) for g in groups)
+        saved_lens = (len(g['params']) for g in saved_groups)
+        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+            raise ValueError("loaded state dict contains a parameter group "
+                             "that doesn't match the size of optimizer's group")
+
+        # Update the state
+        id_map = {old_id: p for old_id, p in
+                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
+                      chain.from_iterable((g['params'] for g in groups)))}
+
+        # Copy state assigned to params (and cast tensors to appropriate types).
+        # State that is not assigned to params is copied as is (needed for
+        # backward compatibility).
+        state = defaultdict(dict)
+        for k, v in state_dict['state'].items():
+            if k in id_map:
+                param = id_map[k]
+
+                if "_param_fp32" not in v:
+                    v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu")
+                    v["_param_fp32"].copy_(param)
+
+                for name, dtype in [("exp_avg", torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
+                    if name in v:
+                        v[name] = v[name].to("cpu").to(dtype)
+
+                state[param] = v
+                if param.dtype == torch.half:
+                    # initialize placeholders
+                    state[param]["_param_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True)   # on host
+                    state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True)    # on host
+                else:
+                    state[param]["_param_fp32"] = state[param]["_param_fp32"].pin_memory()
+
+                    # initialize placeholders
+                    state[param]["_grad_fp32"] = torch.empty(param.size(), dtype=torch.float32, pin_memory=True)    # on host
+            else:
+                state[k] = v
+
+        # Update parameter groups, setting their 'params' value
+        def update_group(group, new_group):
+            new_group['params'] = group['params']
+            return new_group
+        param_groups = [
+            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+        self.__setstate__({'state': state, 'param_groups': param_groups})
+
+    def state_dict(self) -> dict:
+        r"""Returns the state of the optimizer as a :class:`dict`.
+
+        It contains two entries:
+
+        * state - a dict holding current optimization state. Its content
+            differs between optimizer classes.
+        * param_groups - a list containing all parameter groups where each
+            parameter group is a dict
+        """
+        # Save order indices instead of Tensors
+        param_mappings = {}
+        start_index = 0
+
+        def pack_group(group):
+            nonlocal start_index
+            packed = {k: v for k, v in group.items() if k != 'params'}
+            param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
+                                   if id(p) not in param_mappings})
+            packed['params'] = [param_mappings[id(p)] for p in group['params']]
+            start_index += len(packed['params'])
+            return packed
+
+        def cut_states(state):
+            return {
+                "step": state["step"],
+                "exp_avg": state["exp_avg"],
+                "exp_avg_sq": state["exp_avg_sq"],
+                "_param_fp32": state["_param_fp32"],
+            }
+        param_groups = [pack_group(g) for g in self.param_groups]
+        # Remap state to use order indices as keys
+        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): cut_states(v)
+                        for k, v in self.state.items()}
+        return {
+            'state': packed_state,
+            'param_groups': param_groups,
+        }
\ No newline at end of file
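
Usage note (editorial, not part of the patch): a minimal sketch of the checkpoint
round trip these methods are intended to support. The names `model` and `ckpt.pt`
and the surrounding setup are placeholders assuming a typical BMTrain training
script; only the torch.optim-style state_dict() / load_state_dict() pair added
above is exercised.

    import torch
    import bmtrain as bmt

    # assumes bmt.init_distributed() has been called and `model` holds the parameters
    optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), lr=1e-3)

    # ... after some training steps, save the optimizer state
    # (state_dict() keeps only step / exp_avg / exp_avg_sq / _param_fp32 via cut_states)
    torch.save(optimizer.state_dict(), "ckpt.pt")

    # later: rebuild the optimizer and restore its state; load_state_dict()
    # re-creates the pinned host placeholders (_param_fp16 / _grad_fp16 / _grad_fp32)
    # that were dropped when the state was saved
    optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), lr=1e-3)
    optimizer.load_state_dict(torch.load("ckpt.pt"))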