diff --git a/tests/utils/test_distributed.py b/tests/utils/test_distributed.py
index cb444267cd..61b30cfbdb 100644
--- a/tests/utils/test_distributed.py
+++ b/tests/utils/test_distributed.py
@@ -187,12 +187,18 @@ def test_revert_sync_batchnorm(self) -> None:
         self.assertNotIsInstance(batch_norm, torch.nn.SyncBatchNorm)
         self.assertTrue(
             torch.equal(
-                batch_norm.running_mean, none_throws(original_batchnorm.running_mean)
+                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                #  `Union[Tensor, Module]`.
+                batch_norm.running_mean,
+                none_throws(original_batchnorm.running_mean),
             )
         )
         self.assertTrue(
             torch.equal(
-                batch_norm.running_var, none_throws(original_batchnorm.running_var)
+                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                #  `Union[Tensor, Module]`.
+                batch_norm.running_var,
+                none_throws(original_batchnorm.running_var),
             )
         )

diff --git a/torchtnt/framework/_loop_utils.py b/torchtnt/framework/_loop_utils.py
index 2316da9185..ca1b2444bd 100644
--- a/torchtnt/framework/_loop_utils.py
+++ b/torchtnt/framework/_loop_utils.py
@@ -94,16 +94,23 @@ def _set_module_training_mode(
         is_ddp = isinstance(module, DistributedDataParallel)

         if _EXPORT_UTILS_AVAIL and model_is_exported(
-            module.module if is_ddp else module
+            # pyre-fixme[6]: For 1st argument expected `Module` but got
+            #  `Union[Module, Tensor]`.
+            module.module
+            if is_ddp
+            else module
         ):
             move_fn = (
                 torch.ao.quantization.move_exported_model_to_train
                 if mode
                 else torch.ao.quantization.move_exported_model_to_eval
             )
+            # pyre-fixme[6]: For 1st argument expected `GraphModule` but got
+            #  `Union[Module, Tensor]`.
             move_fn(module.module if is_ddp else module)
             module.training = mode
             if is_ddp:
+                # pyre-fixme[16]: `Tensor` has no attribute `training`.
                 module.module.training = mode
         else:
             module.train(mode)
@@ -122,16 +129,23 @@ def _reset_module_training_mode(
         is_ddp = isinstance(module, DistributedDataParallel)

         if _EXPORT_UTILS_AVAIL and model_is_exported(
-            module.module if is_ddp else module
+            # pyre-fixme[6]: For 1st argument expected `Module` but got
+            #  `Union[Module, Tensor]`.
+            module.module
+            if is_ddp
+            else module
         ):
             move_fn = (
                 torch.ao.quantization.move_exported_model_to_train
                 if prior_modes[name]
                 else torch.ao.quantization.move_exported_model_to_eval
             )
+            # pyre-fixme[6]: For 1st argument expected `GraphModule` but got
+            #  `Union[Module, Tensor]`.
             move_fn(module.module if is_ddp else module)
             module.training = prior_modes[name]
             if is_ddp:
+                # pyre-fixme[16]: `Tensor` has no attribute `training`.
                 module.module.training = prior_modes[name]
         else:
             module.train(prior_modes[name])
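Note: the pyre-fixme[6] and pyre-fixme[16] suppressions above share one root cause. DistributedDataParallel stores the wrapped model in nn.Module._modules, so `ddp.module` resolves through nn.Module.__getattr__, which Pyre types as `Union[Tensor, Module]`. An alternative to stacking suppressions would be a small cast-based unwrap helper; the sketch below is only an illustration under that assumption, and `_unwrap` is a hypothetical name, not part of this patch.

    from typing import cast

    import torch
    from torch.nn.parallel import DistributedDataParallel

    def _unwrap(module: torch.nn.Module) -> torch.nn.Module:
        # Hypothetical helper, not in torchtnt: `ddp.module` goes through
        # nn.Module.__getattr__, which Pyre types as Union[Tensor, Module].
        # cast() narrows the union back to Module for the type checker only;
        # it has no runtime effect.
        if isinstance(module, DistributedDataParallel):
            return cast(torch.nn.Module, module.module)
        return module
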
diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 225e14e3c8..622a21e49d 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -638,6 +638,7 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         # https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
         # https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync
         maybe_no_sync = (
+            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
             module.no_sync()
             if not should_update_weights
             and (isinstance(module, DDP) or _is_fsdp_module(module))
diff --git a/torchtnt/utils/distributed.py b/torchtnt/utils/distributed.py
index 37bc1edafc..08c05ac158 100644
--- a/torchtnt/utils/distributed.py
+++ b/torchtnt/utils/distributed.py
@@ -436,6 +436,7 @@ def revert_sync_batchnorm(
         module_output.running_var = module.running_var
         module_output.num_batches_tracked = module.num_batches_tracked
         if hasattr(module, "qconfig"):
+            # pyre-fixme[16]: `_BatchNormXd` has no attribute `qconfig`.
             module_output.qconfig = module.qconfig
     for name, child in module.named_children():
         module_output.add_module(name, revert_sync_batchnorm(child, device))
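Note: the pyre-fixme[16] on `qconfig` is a slightly different case. `qconfig` is not a declared attribute of `_BatchNormXd` (or of nn.Module); it is attached dynamically by torch.ao.quantization workflows, hence the runtime hasattr guard plus a suppression for the checker. A minimal usage sketch of `revert_sync_batchnorm`, assuming a model first converted with the standard convert_sync_batchnorm API:

    import torch
    from torchtnt.utils.distributed import revert_sync_batchnorm

    # Sketch: convert BatchNorm layers to SyncBatchNorm (as done for
    # distributed training), then revert them so the model can run without
    # a process group. Running stats and any dynamically attached qconfig
    # are carried over by revert_sync_batchnorm.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
    sync_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    local_model = revert_sync_batchnorm(sync_model)
    assert not any(
        isinstance(m, torch.nn.SyncBatchNorm) for m in local_model.modules()
    )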