Skip to content

Commit 8b3ae72

Browse files
richardtmlTylerYep
andauthored
Add params and MACs units specifiers (#188)
* Add params and MACs units specifiers * Use enums * Add test case Co-authored-by: Tyler Yep <[email protected]>
1 parent 0c1ccff commit 8b3ae72

15 files changed

+119
-24
lines changed

tests/test_output/dict_parameters_1.out

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ DictParameter [10, 1] --
66
Total params: 0
77
Trainable params: 0
88
Non-trainable params: 0
9-
Total mult-adds (M): 0.00
9+
Total mult-adds (M): 0
1010
==========================================================================================
1111
Input size (MB): 0.00
1212
Forward/backward pass size (MB): 0.00

tests/test_output/dict_parameters_2.out

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ DictParameter [10, 1] --
66
Total params: 0
77
Trainable params: 0
88
Non-trainable params: 0
9-
Total mult-adds (M): 0.00
9+
Total mult-adds (M): 0
1010
==========================================================================================
1111
Input size (MB): 0.00
1212
Forward/backward pass size (MB): 0.00

tests/test_output/dict_parameters_3.out

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ DictParameter [10, 1] --
66
Total params: 0
77
Trainable params: 0
88
Non-trainable params: 0
9-
Total mult-adds (M): 0.00
9+
Total mult-adds (M): 0
1010
==========================================================================================
1111
Input size (MB): 0.00
1212
Forward/backward pass size (MB): 0.00
+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
==========================================================================================
2+
Layer (type:depth-idx) Output Shape Param #
3+
==========================================================================================
4+
SingleInputNet [16, 10] --
5+
├─Conv2d: 1-1 [16, 10, 24, 24] 260
6+
├─Conv2d: 1-2 [16, 20, 8, 8] 5,020
7+
├─Dropout2d: 1-3 [16, 20, 8, 8] --
8+
├─Linear: 1-4 [16, 50] 16,050
9+
├─Linear: 1-5 [16, 10] 510
10+
==========================================================================================
11+
Total params: 21,840
12+
Trainable params: 21,840
13+
Non-trainable params: 0
14+
Total mult-adds: 7,801,600
15+
==========================================================================================
16+
Input size (MB): 0.05
17+
Forward/backward pass size (MB): 0.91
18+
Params size (MB): 0.09
19+
Estimated Total Size (MB): 1.05
20+
==========================================================================================
21+
==========================================================================================
22+
Layer (type:depth-idx) Output Shape Param #
23+
==========================================================================================
24+
SingleInputNet [16, 10] --
25+
├─Conv2d: 1-1 [16, 10, 24, 24] 260
26+
├─Conv2d: 1-2 [16, 20, 8, 8] 5,020
27+
├─Dropout2d: 1-3 [16, 20, 8, 8] --
28+
├─Linear: 1-4 [16, 50] 16,050
29+
├─Linear: 1-5 [16, 10] 510
30+
==========================================================================================
31+
Total params (T): 0.00
32+
Trainable params (T): 0.00
33+
Non-trainable params (T): 0
34+
Total mult-adds (T): 0.00
35+
==========================================================================================
36+
Input size (MB): 0.05
37+
Forward/backward pass size (MB): 0.91
38+
Params size (MB): 0.09
39+
Estimated Total Size (MB): 1.05
40+
==========================================================================================

tests/test_output/jit.out

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ LinearModel -- --
3333
Total params: 33,153
3434
Trainable params: 33,153
3535
Non-trainable params: 0
36-
Total mult-adds (M): 0.00
36+
Total mult-adds (M): 0
3737
==========================================================================================
3838
Input size (MB): 0.03
3939
Forward/backward pass size (MB): 0.00

tests/test_output/namedtuple.out

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ NamedTuple [2, 1, 28, 28] --
66
Total params: 0
77
Trainable params: 0
88
Non-trainable params: 0
9-
Total mult-adds (M): 0.00
9+
Total mult-adds (M): 0
1010
==========================================================================================
1111
Input size (MB): 0.01
1212
Forward/backward pass size (MB): 0.00

tests/test_output/parameter_list.out

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ ParameterListModel -- [100, 100]
66
Total params: 30,000
77
Trainable params: 30,000
88
Non-trainable params: 0
9-
Total mult-adds (M): 0.00
9+
Total mult-adds (M): 0
1010
================================================================================================================================================================
1111
Input size (MB): 0.04
1212
Forward/backward pass size (MB): 0.00

tests/test_output/parameters_with_other_layers.out

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ ParameterFCNet [3, 64] 8,256
2828
Total params: 8,256
2929
Trainable params: 8,256
3030
Non-trainable params: 0
31-
Total mult-adds (M): 0.00
31+
Total mult-adds (M): 0
3232
==========================================================================================
3333
Input size (MB): 0.00
3434
Forward/backward pass size (MB): 0.00

tests/test_output/partial_jit.out

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ PartialJITModel -- --
1111
Total params: 21,840
1212
Trainable params: 21,840
1313
Non-trainable params: 0
14-
Total mult-adds (M): 0.00
14+
Total mult-adds (M): 0
1515
==========================================================================================
1616
Input size (MB): 0.01
1717
Forward/backward pass size (MB): 0.00

tests/test_output/uninitialized_tensor.out

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ UninitializedParameterModel [2, 2] 128
1515
Total params: 128
1616
Trainable params: 128
1717
Non-trainable params: 0
18-
Total mult-adds (M): 0.00
18+
Total mult-adds (M): 0
1919
==========================================================================================
2020
Input size (MB): 0.00
2121
Forward/backward pass size (MB): 0.00

tests/torchinfo_test.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
UninitializedParameterModel,
4242
)
4343
from torchinfo import ColumnSettings, summary
44-
from torchinfo.enums import Verbosity
44+
from torchinfo.enums import Units, Verbosity
4545

4646

4747
def test_basic_summary() -> None:
@@ -175,6 +175,18 @@ def test_row_settings() -> None:
175175
summary(model, input_size=(16, 1, 28, 28), row_settings=("var_names",))
176176

177177

178+
def test_formatting_options() -> None:
179+
model = SingleInputNet()
180+
181+
results = summary(model, input_size=(16, 1, 28, 28), verbose=0)
182+
results.formatting.macs_units = Units.NONE
183+
print(results)
184+
185+
results.formatting.params_units = Units.TERABYTES
186+
results.formatting.macs_units = Units.TERABYTES
187+
print(results)
188+
189+
178190
def test_jit() -> None:
179191
model = LinearModel()
180192
model_jit = torch.jit.script(model)

torchinfo/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .enums import ColumnSettings, Mode, RowSettings, Verbosity
1+
from .enums import ColumnSettings, Mode, RowSettings, Units, Verbosity
22
from .model_statistics import ModelStatistics
33
from .torchinfo import summary
44

@@ -8,6 +8,7 @@
88
"Mode",
99
"ModelStatistics",
1010
"RowSettings",
11+
"Units",
1112
"Verbosity",
1213
)
1314
__version__ = "1.7.1"

torchinfo/enums.py

+11
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ class ColumnSettings(str, Enum):
3333
TRAINABLE = "trainable"
3434

3535

36+
@unique
37+
class Units(str, Enum):
38+
"""Enum containing all available bytes units."""
39+
40+
AUTO = "auto"
41+
MEGABYTES = "M"
42+
GIGABYTES = "G"
43+
TERABYTES = "T"
44+
NONE = ""
45+
46+
3647
@unique
3748
class Verbosity(IntEnum):
3849
"""Contains verbosity levels."""

torchinfo/formatting.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import math
44
from typing import Any
55

6-
from .enums import ColumnSettings, RowSettings, Verbosity
6+
from .enums import ColumnSettings, RowSettings, Units, Verbosity
77
from .layer_info import LayerInfo
88

99
HEADER_TITLES = {
@@ -14,6 +14,12 @@
1414
ColumnSettings.MULT_ADDS: "Mult-Adds",
1515
ColumnSettings.TRAINABLE: "Trainable",
1616
}
17+
CONVERSION_FACTORS = {
18+
Units.TERABYTES: 1e12,
19+
Units.GIGABYTES: 1e9,
20+
Units.MEGABYTES: 1e6,
21+
Units.NONE: 1,
22+
}
1723

1824

1925
class FormattingOptions:
@@ -32,6 +38,8 @@ def __init__(
3238
self.col_names = col_names
3339
self.col_width = col_width
3440
self.row_settings = row_settings
41+
self.params_units = Units.NONE
42+
self.macs_units = Units.AUTO
3543

3644
self.layer_name_width = 40
3745
self.ascii_only = RowSettings.ASCII_ONLY in self.row_settings

torchinfo/model_statistics.py

+35-12
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from typing import Any
44

5-
from .formatting import FormattingOptions
5+
from .enums import Units
6+
from .formatting import CONVERSION_FACTORS, FormattingOptions
67
from .layer_info import LayerInfo
78

89

@@ -46,24 +47,35 @@ def __init__(
4647
def __repr__(self) -> str:
4748
"""Print results of the summary."""
4849
divider = "=" * self.formatting.get_total_width()
50+
total_params = ModelStatistics.format_output_num(
51+
self.total_params, self.formatting.params_units
52+
)
53+
trainable_params = ModelStatistics.format_output_num(
54+
self.trainable_params, self.formatting.params_units
55+
)
56+
non_trainable_params = ModelStatistics.format_output_num(
57+
self.total_params - self.trainable_params, self.formatting.params_units
58+
)
4959
summary_str = (
5060
f"{divider}\n"
5161
f"{self.formatting.header_row()}{divider}\n"
5262
f"{self.formatting.layers_to_str(self.summary_list)}{divider}\n"
53-
f"Total params: {self.total_params:,}\n"
54-
f"Trainable params: {self.trainable_params:,}\n"
55-
f"Non-trainable params: {self.total_params - self.trainable_params:,}\n"
63+
f"Total params{total_params}\n"
64+
f"Trainable params{trainable_params}\n"
65+
f"Non-trainable params{non_trainable_params}\n"
5666
)
5767
if self.input_size:
58-
unit, macs = self.to_readable(self.total_mult_adds)
68+
macs = ModelStatistics.format_output_num(
69+
self.total_mult_adds, self.formatting.macs_units
70+
)
5971
input_size = self.to_megabytes(self.total_input)
6072
output_bytes = self.to_megabytes(self.total_output_bytes)
6173
param_bytes = self.to_megabytes(self.total_param_bytes)
6274
total_bytes = self.to_megabytes(
6375
self.total_input + self.total_output_bytes + self.total_param_bytes
6476
)
6577
summary_str += (
66-
f"Total mult-adds ({unit}): {macs:0.2f}\n{divider}\n"
78+
f"Total mult-adds{macs}\n{divider}\n"
6779
f"Input size (MB): {input_size:0.2f}\n"
6880
f"Forward/backward pass size (MB): {output_bytes:0.2f}\n"
6981
f"Params size (MB): {param_bytes:0.2f}\n"
@@ -83,10 +95,21 @@ def to_megabytes(num: int) -> float:
8395
return num / 1e6
8496

8597
@staticmethod
86-
def to_readable(num: int) -> tuple[str, float]:
98+
def to_readable(num: int, units: Units = Units.AUTO) -> tuple[Units, float]:
8799
"""Converts a number to millions, billions, or trillions."""
88-
if num >= 1e12:
89-
return "T", num / 1e12
90-
if num >= 1e9:
91-
return "G", num / 1e9
92-
return "M", num / 1e6
100+
if units == Units.AUTO:
101+
if num >= 1e12:
102+
return Units.TERABYTES, num / 1e12
103+
if num >= 1e9:
104+
return Units.GIGABYTES, num / 1e9
105+
return Units.MEGABYTES, num / 1e6
106+
return units, num / CONVERSION_FACTORS[units]
107+
108+
@staticmethod
109+
def format_output_num(num: int, units: Units) -> str:
110+
units_used, converted_num = ModelStatistics.to_readable(num, units)
111+
if converted_num.is_integer():
112+
converted_num = int(converted_num)
113+
units_display = "" if units_used == Units.NONE else f" ({units_used})"
114+
fmt = "d" if isinstance(converted_num, int) else ".2f"
115+
return f"{units_display}: {converted_num:,{fmt}}"

0 commit comments

Comments
 (0)