2
2
3
3
from typing import Any
4
4
5
- from .formatting import FormattingOptions
5
+ from .enums import Units
6
+ from .formatting import CONVERSION_FACTORS , FormattingOptions
6
7
from .layer_info import LayerInfo
7
8
8
9
@@ -46,24 +47,35 @@ def __init__(
46
47
def __repr__ (self ) -> str :
47
48
"""Print results of the summary."""
48
49
divider = "=" * self .formatting .get_total_width ()
50
+ total_params = ModelStatistics .format_output_num (
51
+ self .total_params , self .formatting .params_units
52
+ )
53
+ trainable_params = ModelStatistics .format_output_num (
54
+ self .trainable_params , self .formatting .params_units
55
+ )
56
+ non_trainable_params = ModelStatistics .format_output_num (
57
+ self .total_params - self .trainable_params , self .formatting .params_units
58
+ )
49
59
summary_str = (
50
60
f"{ divider } \n "
51
61
f"{ self .formatting .header_row ()} { divider } \n "
52
62
f"{ self .formatting .layers_to_str (self .summary_list )} { divider } \n "
53
- f"Total params: { self . total_params :, } \n "
54
- f"Trainable params: { self . trainable_params :, } \n "
55
- f"Non-trainable params: { self . total_params - self . trainable_params :, } \n "
63
+ f"Total params{ total_params } \n "
64
+ f"Trainable params{ trainable_params } \n "
65
+ f"Non-trainable params{ non_trainable_params } \n "
56
66
)
57
67
if self .input_size :
58
- unit , macs = self .to_readable (self .total_mult_adds )
68
+ macs = ModelStatistics .format_output_num (
69
+ self .total_mult_adds , self .formatting .macs_units
70
+ )
59
71
input_size = self .to_megabytes (self .total_input )
60
72
output_bytes = self .to_megabytes (self .total_output_bytes )
61
73
param_bytes = self .to_megabytes (self .total_param_bytes )
62
74
total_bytes = self .to_megabytes (
63
75
self .total_input + self .total_output_bytes + self .total_param_bytes
64
76
)
65
77
summary_str += (
66
- f"Total mult-adds ( { unit } ): { macs :0.2f } \n { divider } \n "
78
+ f"Total mult-adds{ macs } \n { divider } \n "
67
79
f"Input size (MB): { input_size :0.2f} \n "
68
80
f"Forward/backward pass size (MB): { output_bytes :0.2f} \n "
69
81
f"Params size (MB): { param_bytes :0.2f} \n "
@@ -83,10 +95,21 @@ def to_megabytes(num: int) -> float:
83
95
return num / 1e6
84
96
85
97
@staticmethod
86
- def to_readable (num : int ) -> tuple [str , float ]:
98
+ def to_readable (num : int , units : Units = Units . AUTO ) -> tuple [Units , float ]:
87
99
"""Converts a number to millions, billions, or trillions."""
88
- if num >= 1e12 :
89
- return "T" , num / 1e12
90
- if num >= 1e9 :
91
- return "G" , num / 1e9
92
- return "M" , num / 1e6
100
+ if units == Units .AUTO :
101
+ if num >= 1e12 :
102
+ return Units .TERABYTES , num / 1e12
103
+ if num >= 1e9 :
104
+ return Units .GIGABYTES , num / 1e9
105
+ return Units .MEGABYTES , num / 1e6
106
+ return units , num / CONVERSION_FACTORS [units ]
107
+
108
+ @staticmethod
109
+ def format_output_num (num : int , units : Units ) -> str :
110
+ units_used , converted_num = ModelStatistics .to_readable (num , units )
111
+ if converted_num .is_integer ():
112
+ converted_num = int (converted_num )
113
+ units_display = "" if units_used == Units .NONE else f" ({ units_used } )"
114
+ fmt = "d" if isinstance (converted_num , int ) else ".2f"
115
+ return f"{ units_display } : { converted_num :,{fmt }} "
0 commit comments