Skip to content

Commit

Permalink
Fix a mypy issue related to torch optimizer step (#3018)
Browse files Browse the repository at this point in the history
* Fix a mypy issue

* Fix a flake issue in test_lr_finder
  • Loading branch information
sadra-barikbin authored Aug 8, 2023
1 parent 1506066 commit 2f7246c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions ignite/distributed/auto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Callable, Iterator, List, Optional, Union
from typing import Any, Iterator, List, Optional, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -351,5 +351,5 @@ def __init__(self, optimizer: Optimizer) -> None:
super(self.__class__, self).__init__(optimizer.param_groups) # type: ignore[call-arg]
self.wrapped_optimizer = optimizer

def step(self, closure: Optional[Callable] = None) -> None:
def step(self, closure: Any = None) -> Any:
xm.optimizer_step(self.wrapped_optimizer, barrier=True)
4 changes: 2 additions & 2 deletions tests/ignite/handlers/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,13 @@ def test_engine_output_type(lr_finder, dummy_engine, optimizer):
lr_finder._history = {"lr": [], "loss": []}
lr_finder._log_lr_and_loss(dummy_engine, output_transform=lambda x: x, smooth_f=0, diverge_th=1)
loss = lr_finder._history["loss"][-1]
assert type(loss) == float
assert type(loss) is float

dummy_engine.state.output = torch.tensor([10.0], dtype=torch.float32)
lr_finder._history = {"lr": [], "loss": []}
lr_finder._log_lr_and_loss(dummy_engine, output_transform=lambda x: x, smooth_f=0, diverge_th=1)
loss = lr_finder._history["loss"][-1]
assert type(loss) == float
assert type(loss) is float


def test_lr_suggestion_unexpected_curve(lr_finder, to_save, dummy_engine, dataloader):
Expand Down

0 comments on commit 2f7246c

Please sign in to comment.