Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly render the "layers" tree view #281

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ pytest
pytest-cov
pre-commit
transformers
compressai
types-tqdm
types-setuptools
compressai
10 changes: 9 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
from torchinfo.torchinfo import clear_cached_forward_pass


def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line(
"markers", "no_verify_capsys: skip the verify_capsys fixture"
)


def pytest_addoption(parser: pytest.Parser) -> None:
"""This allows us to check for these params in sys.argv."""
parser.addoption("--overwrite", action="store_true", default=False)
Expand All @@ -25,6 +31,8 @@ def verify_capsys(
clear_cached_forward_pass()
if "--no-output" in sys.argv:
return
if "no_verify_capsys" in request.keywords:
return

test_name = request.node.name.replace("test_", "")
if sys.version_info < (3, 8) and test_name == "tmva_net_column_totals":
Expand Down Expand Up @@ -74,7 +82,7 @@ def get_column_value_for_row(line: str, offset: int) -> int:
if (
not col_value
or col_value in ("--", "(recursive)")
or col_value.startswith(("└─", "├─"))
or col_value.startswith(("└─", "├─", "'--", "|--"))
):
return 0
return int(col_value.replace(",", "").replace("(", "").replace(")", ""))
Expand Down
42 changes: 42 additions & 0 deletions tests/formatting_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# pylint: skip-file

from __future__ import annotations

import pytest

from torchinfo.formatting import make_layer_tree


class MockLayer:
def __init__(self, depth: int) -> None:
self.depth = depth

def __repr__(self) -> str:
return f"L({self.depth})"


L = MockLayer


@pytest.mark.no_verify_capsys
@pytest.mark.parametrize(
"layers, expected",
[
([], []),
([a := L(0)], [a]),
([a := L(0), b := L(0)], [a, b]),
([a := L(0), b := L(1)], [a, [b]]),
([a := L(0), b := L(0), c := L(0)], [a, b, c]),
([a := L(0), b := L(1), c := L(0)], [a, [b], c]),
([a := L(0), b := L(1), c := L(1)], [a, [b, c]]),
([a := L(0), b := L(1), c := L(2)], [a, [b, [c]]]),
(
# If this ever happens, there's probably a bug elsewhere, but
# we still want to format things as best as possible.
[a := L(1), b := L(0)],
[[a], b],
),
],
)
def test_make_layer_tree(layers, expected): # type: ignore [no-untyped-def]
assert make_layer_tree(layers) == expected
138 changes: 69 additions & 69 deletions tests/test_output/ascii_only.out
Original file line number Diff line number Diff line change
Expand Up @@ -2,75 +2,75 @@
Layer (type) Output Shape Param #
==========================================================================================
ResNet [1, 1000] --
+ Conv2d [1, 64, 32, 32] 9,408
+ BatchNorm2d [1, 64, 32, 32] 128
+ ReLU [1, 64, 32, 32] --
+ MaxPool2d [1, 64, 16, 16] --
+ Sequential [1, 64, 16, 16] --
| + BasicBlock [1, 64, 16, 16] --
| | + Conv2d [1, 64, 16, 16] 36,864
| | + BatchNorm2d [1, 64, 16, 16] 128
| | + ReLU [1, 64, 16, 16] --
| | + Conv2d [1, 64, 16, 16] 36,864
| | + BatchNorm2d [1, 64, 16, 16] 128
| | + ReLU [1, 64, 16, 16] --
| + BasicBlock [1, 64, 16, 16] --
| | + Conv2d [1, 64, 16, 16] 36,864
| | + BatchNorm2d [1, 64, 16, 16] 128
| | + ReLU [1, 64, 16, 16] --
| | + Conv2d [1, 64, 16, 16] 36,864
| | + BatchNorm2d [1, 64, 16, 16] 128
| | + ReLU [1, 64, 16, 16] --
+ Sequential [1, 128, 8, 8] --
| + BasicBlock [1, 128, 8, 8] --
| | + Conv2d [1, 128, 8, 8] 73,728
| | + BatchNorm2d [1, 128, 8, 8] 256
| | + ReLU [1, 128, 8, 8] --
| | + Conv2d [1, 128, 8, 8] 147,456
| | + BatchNorm2d [1, 128, 8, 8] 256
| | + Sequential [1, 128, 8, 8] 8,448
| | + ReLU [1, 128, 8, 8] --
| + BasicBlock [1, 128, 8, 8] --
| | + Conv2d [1, 128, 8, 8] 147,456
| | + BatchNorm2d [1, 128, 8, 8] 256
| | + ReLU [1, 128, 8, 8] --
| | + Conv2d [1, 128, 8, 8] 147,456
| | + BatchNorm2d [1, 128, 8, 8] 256
| | + ReLU [1, 128, 8, 8] --
+ Sequential [1, 256, 4, 4] --
| + BasicBlock [1, 256, 4, 4] --
| | + Conv2d [1, 256, 4, 4] 294,912
| | + BatchNorm2d [1, 256, 4, 4] 512
| | + ReLU [1, 256, 4, 4] --
| | + Conv2d [1, 256, 4, 4] 589,824
| | + BatchNorm2d [1, 256, 4, 4] 512
| | + Sequential [1, 256, 4, 4] 33,280
| | + ReLU [1, 256, 4, 4] --
| + BasicBlock [1, 256, 4, 4] --
| | + Conv2d [1, 256, 4, 4] 589,824
| | + BatchNorm2d [1, 256, 4, 4] 512
| | + ReLU [1, 256, 4, 4] --
| | + Conv2d [1, 256, 4, 4] 589,824
| | + BatchNorm2d [1, 256, 4, 4] 512
| | + ReLU [1, 256, 4, 4] --
+ Sequential [1, 512, 2, 2] --
| + BasicBlock [1, 512, 2, 2] --
| | + Conv2d [1, 512, 2, 2] 1,179,648
| | + BatchNorm2d [1, 512, 2, 2] 1,024
| | + ReLU [1, 512, 2, 2] --
| | + Conv2d [1, 512, 2, 2] 2,359,296
| | + BatchNorm2d [1, 512, 2, 2] 1,024
| | + Sequential [1, 512, 2, 2] 132,096
| | + ReLU [1, 512, 2, 2] --
| + BasicBlock [1, 512, 2, 2] --
| | + Conv2d [1, 512, 2, 2] 2,359,296
| | + BatchNorm2d [1, 512, 2, 2] 1,024
| | + ReLU [1, 512, 2, 2] --
| | + Conv2d [1, 512, 2, 2] 2,359,296
| | + BatchNorm2d [1, 512, 2, 2] 1,024
| | + ReLU [1, 512, 2, 2] --
+ AdaptiveAvgPool2d [1, 512, 1, 1] --
+ Linear [1, 1000] 513,000
|--Conv2d [1, 64, 32, 32] 9,408
|--BatchNorm2d [1, 64, 32, 32] 128
|--ReLU [1, 64, 32, 32] --
|--MaxPool2d [1, 64, 16, 16] --
|--Sequential [1, 64, 16, 16] --
| |--BasicBlock [1, 64, 16, 16] --
| | |--Conv2d [1, 64, 16, 16] 36,864
| | |--BatchNorm2d [1, 64, 16, 16] 128
| | |--ReLU [1, 64, 16, 16] --
| | |--Conv2d [1, 64, 16, 16] 36,864
| | |--BatchNorm2d [1, 64, 16, 16] 128
| | '--ReLU [1, 64, 16, 16] --
| '--BasicBlock [1, 64, 16, 16] --
| |--Conv2d [1, 64, 16, 16] 36,864
| |--BatchNorm2d [1, 64, 16, 16] 128
| |--ReLU [1, 64, 16, 16] --
| |--Conv2d [1, 64, 16, 16] 36,864
| |--BatchNorm2d [1, 64, 16, 16] 128
| '--ReLU [1, 64, 16, 16] --
|--Sequential [1, 128, 8, 8] --
| |--BasicBlock [1, 128, 8, 8] --
| | |--Conv2d [1, 128, 8, 8] 73,728
| | |--BatchNorm2d [1, 128, 8, 8] 256
| | |--ReLU [1, 128, 8, 8] --
| | |--Conv2d [1, 128, 8, 8] 147,456
| | |--BatchNorm2d [1, 128, 8, 8] 256
| | |--Sequential [1, 128, 8, 8] 8,448
| | '--ReLU [1, 128, 8, 8] --
| '--BasicBlock [1, 128, 8, 8] --
| |--Conv2d [1, 128, 8, 8] 147,456
| |--BatchNorm2d [1, 128, 8, 8] 256
| |--ReLU [1, 128, 8, 8] --
| |--Conv2d [1, 128, 8, 8] 147,456
| |--BatchNorm2d [1, 128, 8, 8] 256
| '--ReLU [1, 128, 8, 8] --
|--Sequential [1, 256, 4, 4] --
| |--BasicBlock [1, 256, 4, 4] --
| | |--Conv2d [1, 256, 4, 4] 294,912
| | |--BatchNorm2d [1, 256, 4, 4] 512
| | |--ReLU [1, 256, 4, 4] --
| | |--Conv2d [1, 256, 4, 4] 589,824
| | |--BatchNorm2d [1, 256, 4, 4] 512
| | |--Sequential [1, 256, 4, 4] 33,280
| | '--ReLU [1, 256, 4, 4] --
| '--BasicBlock [1, 256, 4, 4] --
| |--Conv2d [1, 256, 4, 4] 589,824
| |--BatchNorm2d [1, 256, 4, 4] 512
| |--ReLU [1, 256, 4, 4] --
| |--Conv2d [1, 256, 4, 4] 589,824
| |--BatchNorm2d [1, 256, 4, 4] 512
| '--ReLU [1, 256, 4, 4] --
|--Sequential [1, 512, 2, 2] --
| |--BasicBlock [1, 512, 2, 2] --
| | |--Conv2d [1, 512, 2, 2] 1,179,648
| | |--BatchNorm2d [1, 512, 2, 2] 1,024
| | |--ReLU [1, 512, 2, 2] --
| | |--Conv2d [1, 512, 2, 2] 2,359,296
| | |--BatchNorm2d [1, 512, 2, 2] 1,024
| | |--Sequential [1, 512, 2, 2] 132,096
| | '--ReLU [1, 512, 2, 2] --
| '--BasicBlock [1, 512, 2, 2] --
| |--Conv2d [1, 512, 2, 2] 2,359,296
| |--BatchNorm2d [1, 512, 2, 2] 1,024
| |--ReLU [1, 512, 2, 2] --
| |--Conv2d [1, 512, 2, 2] 2,359,296
| |--BatchNorm2d [1, 512, 2, 2] 1,024
| '--ReLU [1, 512, 2, 2] --
|--AdaptiveAvgPool2d [1, 512, 1, 1] --
'--Linear [1, 1000] 513,000
==========================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Expand Down
8 changes: 4 additions & 4 deletions tests/test_output/autoencoder.out
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ Layer (type:depth-idx) Output Shape Param #
===================================================================================================================
AutoEncoder [1, 3, 64, 64] -- --
├─Sequential: 1-1 [1, 16, 64, 64] -- --
─Conv2d: 2-1 [1, 16, 64, 64] 448 [3, 3]
─Conv2d: 2-1 [1, 16, 64, 64] 448 [3, 3]
│ └─ReLU: 2-2 [1, 16, 64, 64] -- --
├─MaxPool2d: 1-2 [1, 16, 32, 32] -- 2
├─MaxUnpool2d: 1-3 [1, 16, 64, 64] -- [2, 2]
─Sequential: 1-4 [1, 3, 64, 64] -- --
─Conv2d: 2-3 [1, 3, 64, 64] 435 [3, 3]
└─ReLU: 2-4 [1, 3, 64, 64] -- --
─Sequential: 1-4 [1, 3, 64, 64] -- --
─Conv2d: 2-3 [1, 3, 64, 64] 435 [3, 3]
└─ReLU: 2-4 [1, 3, 64, 64] -- --
===================================================================================================================
Total params: 883
Trainable params: 883
Expand Down
2 changes: 1 addition & 1 deletion tests/test_output/basic_summary.out
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ SingleInputNet --
├─Conv2d: 1-2 5,020
├─Dropout2d: 1-3 --
├─Linear: 1-4 16,050
─Linear: 1-5 510
─Linear: 1-5 510
=================================================================
Total params: 21,840
Trainable params: 21,840
Expand Down
2 changes: 1 addition & 1 deletion tests/test_output/batch_size_optimization.out
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ SingleInputNet [1, 10] --
├─Conv2d: 1-2 [1, 20, 8, 8] 5,020
├─Dropout2d: 1-3 [1, 20, 8, 8] --
├─Linear: 1-4 [1, 50] 16,050
─Linear: 1-5 [1, 10] 510
─Linear: 1-5 [1, 10] 510
==========================================================================================
Total params: 21,840
Trainable params: 21,840
Expand Down
38 changes: 19 additions & 19 deletions tests/test_output/bert.out
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,28 @@ Layer (type:depth-idx) Output Shape Par
====================================================================================================
BertModel [2, 768] --
├─BertEmbeddings: 1-1 [2, 512, 768] --
─Embedding: 2-1 [2, 512, 768] 23,440,896
─Embedding: 2-2 [2, 512, 768] 1,536
─Embedding: 2-3 [1, 512, 768] 393,216
─LayerNorm: 2-4 [2, 512, 768] 1,536
─Embedding: 2-1 [2, 512, 768] 23,440,896
─Embedding: 2-2 [2, 512, 768] 1,536
─Embedding: 2-3 [1, 512, 768] 393,216
─LayerNorm: 2-4 [2, 512, 768] 1,536
│ └─Dropout: 2-5 [2, 512, 768] --
├─BertEncoder: 1-2 [2, 512, 768] --
│ └─ModuleList: 2-6 -- --
─BertLayer: 3-1 [2, 512, 768] 7,087,872
─BertLayer: 3-2 [2, 512, 768] 7,087,872
─BertLayer: 3-3 [2, 512, 768] 7,087,872
─BertLayer: 3-4 [2, 512, 768] 7,087,872
─BertLayer: 3-5 [2, 512, 768] 7,087,872
─BertLayer: 3-6 [2, 512, 768] 7,087,872
─BertLayer: 3-7 [2, 512, 768] 7,087,872
─BertLayer: 3-8 [2, 512, 768] 7,087,872
─BertLayer: 3-9 [2, 512, 768] 7,087,872
─BertLayer: 3-10 [2, 512, 768] 7,087,872
─BertLayer: 3-11 [2, 512, 768] 7,087,872
└─BertLayer: 3-12 [2, 512, 768] 7,087,872
─BertPooler: 1-3 [2, 768] --
─Linear: 2-7 [2, 768] 590,592
└─Tanh: 2-8 [2, 768] --
─BertLayer: 3-1 [2, 512, 768] 7,087,872
─BertLayer: 3-2 [2, 512, 768] 7,087,872
─BertLayer: 3-3 [2, 512, 768] 7,087,872
─BertLayer: 3-4 [2, 512, 768] 7,087,872
─BertLayer: 3-5 [2, 512, 768] 7,087,872
─BertLayer: 3-6 [2, 512, 768] 7,087,872
─BertLayer: 3-7 [2, 512, 768] 7,087,872
─BertLayer: 3-8 [2, 512, 768] 7,087,872
─BertLayer: 3-9 [2, 512, 768] 7,087,872
─BertLayer: 3-10 [2, 512, 768] 7,087,872
─BertLayer: 3-11 [2, 512, 768] 7,087,872
└─BertLayer: 3-12 [2, 512, 768] 7,087,872
─BertPooler: 1-3 [2, 768] --
─Linear: 2-7 [2, 768] 590,592
└─Tanh: 2-8 [2, 768] --
====================================================================================================
Total params: 109,482,240
Trainable params: 109,482,240
Expand Down
Loading
Loading