diff --git a/classy_vision/optim/adamw_mt.py b/classy_vision/optim/adamw_mt.py
index 917e320f6..e511ae444 100644
--- a/classy_vision/optim/adamw_mt.py
+++ b/classy_vision/optim/adamw_mt.py
@@ -7,6 +7,7 @@
 from typing import Any, Dict, Tuple
 
 import torch.optim
+from torch.optim import _multi_tensor
 
 from . import ClassyOptimizer, register_optimizer
 
@@ -30,7 +31,7 @@ def __init__(
         self._amsgrad = amsgrad
 
     def prepare(self, param_groups) -> None:
-        self.optimizer = torch.optim._multi_tensor.AdamW(
+        self.optimizer = _multi_tensor.AdamW(
             param_groups,
             lr=self._lr,
             betas=self._betas,
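
For context, here is a minimal sketch of how the imported _multi_tensor.AdamW can be constructed and stepped in isolation. It assumes a PyTorch build that still ships the private torch.optim._multi_tensor namespace (roughly the versions this diff targets); the torch.nn.Linear model and the hyperparameter values are illustrative only, and the constructor signature mirrors torch.optim.AdamW, as the call in prepare() above relies on.

import torch
from torch.optim import _multi_tensor

# Illustrative model; any iterable of parameters or param groups works here.
model = torch.nn.Linear(16, 4)

# Same constructor signature as torch.optim.AdamW, but using the
# multi-tensor ("foreach") implementation of the update step.
optimizer = _multi_tensor.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-2,
    amsgrad=False,
)

# One illustrative optimization step.
loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()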