Skip to content

Commit 564b82d

Browse files
committed
v1.7.2
1 parent 01fa0ce commit 564b82d

File tree

4 files changed

+6
-4
lines changed

4 files changed

+6
-4
lines changed

.pre-commit-config.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ repos:
1919
- id: isort
2020

2121
- repo: https://github.com/psf/black
22-
rev: 22.12.0
22+
rev: 23.1.0
2323
hooks:
2424
- id: black
2525
args: [-C]
@@ -36,7 +36,7 @@ repos:
3636
]
3737

3838
- repo: https://github.com/PyCQA/pylint
39-
rev: v2.16.0b1
39+
rev: v2.16.1
4040
hooks:
4141
- id: pylint
4242
args: ["--disable=import-error"]

tests/fixtures/models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ def __init__(self) -> None:
376376
self.linear = nn.Linear(3, 1)
377377

378378
def forward(self, input_list: dict[str, torch.Tensor]) -> dict[str, IntWithGetitem]:
379-
x = input_list["foo"] if input_list else torch.ones(3)
379+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
380+
x = input_list["foo"] if input_list else torch.ones(3).to(device)
380381
x = self.linear(x)
381382
return {"foo": IntWithGetitem(x)}
382383

torchinfo/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
"Units",
1212
"Verbosity",
1313
)
14-
__version__ = "1.7.1"
14+
__version__ = "1.7.2"

torchinfo/model_statistics.py

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727

2828
# TODO: Figure out why the below functions using max() are ever 0
2929
# (they should always be non-negative), and remove the call to max().
30+
# Investigation: https://github.com/TylerYep/torchinfo/pull/195
3031
for layer_info in summary_list:
3132
if layer_info.is_leaf_layer:
3233
self.total_mult_adds += layer_info.macs

0 commit comments

Comments
 (0)