Skip to content

Commit

Permalink
Merge pull request #337 from ego-thales/api-mode-none
Browse files Browse the repository at this point in the history
* api: summary  keeps current mode (#331)

* fix: forgot to rm 'Mode' from error msg

* dev: revert Mode Enum removal -> now 'same' in place of None

* readme: Update summary doc ('same' mode)
  • Loading branch information
TylerYep authored Dec 11, 2024
2 parents 1281f91 + 48ee0a7 commit 9ba0628
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def summary(
depth: int = 3,
device: Optional[torch.device] = None,
dtypes: Optional[List[torch.dtype]] = None,
mode: str | None = None,
mode: str = "same",
row_settings: Optional[Iterable[str]] = None,
verbose: int = 1,
**kwargs: Any,
Expand Down Expand Up @@ -198,9 +198,10 @@ Args:
Default: None
mode (str)
Either "train" or "eval", which determines whether we call
model.train() or model.eval() before calling summary().
Default: "eval".
Either "train", "eval" or "same", which determines whether we call
model.train() or model.eval() before calling summary(). In any case,
original model mode is restored at the end.
Default: "same".
row_settings (Iterable[str]):
Specify which features to show in a row. Currently supported: (
Expand Down
4 changes: 2 additions & 2 deletions tests/torchinfo_xl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_eval_order_doesnt_matter() -> None:
model2 = torchvision.models.resnet18(
weights=torchvision.models.ResNet18_Weights.DEFAULT
)
summary(model2, input_size=input_size)
summary(model2, input_size=input_size, mode="eval")
model2.eval()
with torch.inference_mode():
output2 = model2(input_tensor)
Expand Down Expand Up @@ -144,7 +144,7 @@ def test_tmva_net_column_totals() -> None:
def test_google() -> None:
google_net = torchvision.models.googlenet(init_weights=False)

summary(google_net, (1, 3, 112, 112), depth=7)
summary(google_net, (1, 3, 112, 112), depth=7, mode="eval")

# Check googlenet in training mode since InceptionAux layers are used in
# forward-prop in train mode but not in eval mode.
Expand Down
1 change: 1 addition & 0 deletions torchinfo/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class Mode(str, Enum):

TRAIN = "train"
EVAL = "eval"
SAME = "same"


@unique
Expand Down
16 changes: 7 additions & 9 deletions torchinfo/torchinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def summary(
depth: int = 3,
device: torch.device | str | None = None,
dtypes: list[torch.dtype] | None = None,
mode: str | None = None,
mode: str = "same",
row_settings: Iterable[str] | None = None,
verbose: int | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -156,9 +156,10 @@ class name as the key. If the forward pass is an expensive operation,
Default: None
mode (str)
Either "train" or "eval", which determines whether we call
model.train() or model.eval() before calling summary().
Default: "eval".
Either "train", "eval" or "same", which determines whether we call
model.train() or model.eval() before calling summary(). In any case,
original model mode is restored at the end.
Default: "same".
row_settings (Iterable[str]):
Specify which features to show in a row. Currently supported: (
Expand Down Expand Up @@ -198,10 +199,7 @@ class name as the key. If the forward pass is an expensive operation,
else:
rows = {RowSettings(name) for name in row_settings}

if mode is None:
model_mode = Mode.EVAL
else:
model_mode = Mode(mode)
model_mode = Mode(mode)

if verbose is None:
verbose = 0 if hasattr(sys, "ps1") and sys.ps1 else 1
Expand Down Expand Up @@ -286,7 +284,7 @@ def forward_pass(
model.train()
elif mode == Mode.EVAL:
model.eval()
else:
elif mode != Mode.SAME:
raise RuntimeError(
f"Specified model mode ({list(Mode)}) not recognized: {mode}"
)
Expand Down

0 comments on commit 9ba0628

Please sign in to comment.