diff --git a/torchinfo/model_statistics.py b/torchinfo/model_statistics.py index 8217e69..40b58c2 100644 --- a/torchinfo/model_statistics.py +++ b/torchinfo/model_statistics.py @@ -75,18 +75,18 @@ def __repr__(self) -> str: macs = ModelStatistics.format_output_num( self.total_mult_adds, self.formatting.macs_units ) - input_size = self.to_megabytes(self.total_input) - output_bytes = self.to_megabytes(self.total_output_bytes) - param_bytes = self.to_megabytes(self.total_param_bytes) - total_bytes = self.to_megabytes( + input_size = self.to_readable(self.total_input) + output_bytes = self.to_readable(self.total_output_bytes) + param_bytes = self.to_readable(self.total_param_bytes) + total_bytes = self.to_readable( self.total_input + self.total_output_bytes + self.total_param_bytes ) summary_str += ( f"Total mult-adds{macs}\n{divider}\n" - f"Input size (MB): {input_size:0.2f}\n" - f"Forward/backward pass size (MB): {output_bytes:0.2f}\n" - f"Params size (MB): {param_bytes:0.2f}\n" - f"Estimated Total Size (MB): {total_bytes:0.2f}\n" + f"Input size ({input_size[0]}B): {input_size[1]:0.2f}\n" + f"Forward/backward pass size ({output_bytes[0]}B): {output_bytes[1]:0.2f}\n" + f"Params size ({param_bytes[0]}B): {param_bytes[1]:0.2f}\n" + f"Estimated Total Size ({total_bytes[0]}B): {total_bytes[1]:0.2f}\n" ) summary_str += divider return summary_str