Remove extra vertical lines from the layer tree view
kalekundert committed Oct 24, 2023
1 parent 0f8e78a commit 6048021
Showing 70 changed files with 1,983 additions and 1,773 deletions.
3 changes: 3 additions & 0 deletions setup.cfg
@@ -64,3 +64,6 @@ profile = black
[flake8]
max-line-length = 88
extend-ignore = E203,F401

[tool:pytest]
python_files = *_test.py
10 changes: 9 additions & 1 deletion tests/conftest.py
@@ -10,6 +10,12 @@
from torchinfo.torchinfo import clear_cached_forward_pass


def pytest_configure(config: pytest.Config) -> None:
    config.addinivalue_line(
        "markers", "no_verify_capsys: skip the verify_capsys fixture"
    )


def pytest_addoption(parser: pytest.Parser) -> None:
    """This allows us to check for these params in sys.argv."""
    parser.addoption("--overwrite", action="store_true", default=False)
@@ -24,6 +30,8 @@ def verify_capsys(
    clear_cached_forward_pass()
    if "--no-output" in sys.argv:
        return
    if "no_verify_capsys" in request.keywords:
        return

    test_name = request.node.name.replace("test_", "")
    if sys.version_info < (3, 8) and test_name == "tmva_net_column_totals":
@@ -73,7 +81,7 @@ def get_column_value_for_row(line: str, offset: int) -> int:
    if (
        not col_value
        or col_value in ("--", "(recursive)")
        or col_value.startswith(("└─", "├─"))
        or col_value.startswith(("└─", "├─", "'--", "|--"))
    ):
        return 0
    return int(col_value.replace(",", "").replace("(", "").replace(")", ""))
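The broadened startswith check above exists because the ASCII layer tree now draws branches with "|--" and "'--" rather than the box-drawing "├─" and "└─"; a column value that begins with any of these markers is a layer name, not a number, and must count as zero. A minimal sketch of that classification, with illustrative sample values that are not taken from the test suite:

# Sketch only: the prefix tuple mirrors the check in get_column_value_for_row,
# the sample column values below are hypothetical.
tree_prefixes = ("└─", "├─", "'--", "|--")

for col_value in ["├─Conv2d", "|--Conv2d", "'--Linear", "9,408"]:
    if col_value.startswith(tree_prefixes):
        print(f"{col_value!r} is a tree marker, counted as 0")
    else:
        print(f"{col_value!r} parses to {int(col_value.replace(',', ''))}")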
42 changes: 42 additions & 0 deletions tests/formatting_test.py
@@ -0,0 +1,42 @@
# pylint: skip-file

from __future__ import annotations

import pytest

from torchinfo.formatting import make_layer_tree


class MockLayer:
    def __init__(self, depth: int) -> None:
        self.depth = depth

    def __repr__(self) -> str:
        return f"L({self.depth})"


L = MockLayer


@pytest.mark.no_verify_capsys
@pytest.mark.parametrize(
    "layers, expected",
    [
        ([], []),
        ([a := L(0)], [a]),
        ([a := L(0), b := L(0)], [a, b]),
        ([a := L(0), b := L(1)], [a, [b]]),
        ([a := L(0), b := L(0), c := L(0)], [a, b, c]),
        ([a := L(0), b := L(1), c := L(0)], [a, [b], c]),
        ([a := L(0), b := L(1), c := L(1)], [a, [b, c]]),
        ([a := L(0), b := L(1), c := L(2)], [a, [b, [c]]]),
        (
            # If this ever happens, there's probably a bug elsewhere, but
            # we still want to format things as best as possible.
            [a := L(1), b := L(0)],
            [[a], b],
        ),
    ],
)
def test_make_layer_tree(layers, expected):  # type: ignore [no-untyped-def]
    assert make_layer_tree(layers) == expected
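The parametrized cases above fully describe the expected behaviour of make_layer_tree: layers at the current depth stay at the top level, and each contiguous run of deeper layers is nested into a sub-list under the layer that precedes it. The following sketch is written only from those expected values, not from the actual torchinfo.formatting implementation; the name make_layer_tree_sketch and its signature are hypothetical:

from __future__ import annotations

from typing import Any


def make_layer_tree_sketch(layers: list[Any], depth: int = 0) -> list[Any]:
    """Nest a flat, depth-annotated layer list, e.g. [L(0), L(1)] -> [a, [b]]."""
    tree: list[Any] = []
    i = 0
    while i < len(layers):
        if layers[i].depth <= depth:
            # Layer belongs at this level of the tree.
            tree.append(layers[i])
            i += 1
        else:
            # Collect the contiguous run of deeper layers and nest it.
            j = i
            while j < len(layers) and layers[j].depth > depth:
                j += 1
            tree.append(make_layer_tree_sketch(layers[i:j], depth + 1))
            i = j
    return tree

Every parametrized case above passes against this sketch, including the irregular [L(1), L(0)] input, whose too-deep leading layer is nested as [[a], b].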
138 changes: 69 additions & 69 deletions tests/test_output/ascii_only.out
@@ -2,75 +2,75 @@
Layer (type) Output Shape Param #
==========================================================================================
ResNet [1, 1000] --
+ Conv2d [1, 64, 32, 32] 9,408
+ BatchNorm2d [1, 64, 32, 32] 128
+ ReLU [1, 64, 32, 32] --
+ MaxPool2d [1, 64, 16, 16] --
+ Sequential [1, 64, 16, 16] --
| + BasicBlock [1, 64, 16, 16] --
| | + Conv2d [1, 64, 16, 16] 36,864
| | + BatchNorm2d [1, 64, 16, 16] 128
| | + ReLU [1, 64, 16, 16] --
| | + Conv2d [1, 64, 16, 16] 36,864
| | + BatchNorm2d [1, 64, 16, 16] 128
| | + ReLU [1, 64, 16, 16] --
| + BasicBlock [1, 64, 16, 16] --
| | + Conv2d [1, 64, 16, 16] 36,864
| | + BatchNorm2d [1, 64, 16, 16] 128
| | + ReLU [1, 64, 16, 16] --
| | + Conv2d [1, 64, 16, 16] 36,864
| | + BatchNorm2d [1, 64, 16, 16] 128
| | + ReLU [1, 64, 16, 16] --
+ Sequential [1, 128, 8, 8] --
| + BasicBlock [1, 128, 8, 8] --
| | + Conv2d [1, 128, 8, 8] 73,728
| | + BatchNorm2d [1, 128, 8, 8] 256
| | + ReLU [1, 128, 8, 8] --
| | + Conv2d [1, 128, 8, 8] 147,456
| | + BatchNorm2d [1, 128, 8, 8] 256
| | + Sequential [1, 128, 8, 8] 8,448
| | + ReLU [1, 128, 8, 8] --
| + BasicBlock [1, 128, 8, 8] --
| | + Conv2d [1, 128, 8, 8] 147,456
| | + BatchNorm2d [1, 128, 8, 8] 256
| | + ReLU [1, 128, 8, 8] --
| | + Conv2d [1, 128, 8, 8] 147,456
| | + BatchNorm2d [1, 128, 8, 8] 256
| | + ReLU [1, 128, 8, 8] --
+ Sequential [1, 256, 4, 4] --
| + BasicBlock [1, 256, 4, 4] --
| | + Conv2d [1, 256, 4, 4] 294,912
| | + BatchNorm2d [1, 256, 4, 4] 512
| | + ReLU [1, 256, 4, 4] --
| | + Conv2d [1, 256, 4, 4] 589,824
| | + BatchNorm2d [1, 256, 4, 4] 512
| | + Sequential [1, 256, 4, 4] 33,280
| | + ReLU [1, 256, 4, 4] --
| + BasicBlock [1, 256, 4, 4] --
| | + Conv2d [1, 256, 4, 4] 589,824
| | + BatchNorm2d [1, 256, 4, 4] 512
| | + ReLU [1, 256, 4, 4] --
| | + Conv2d [1, 256, 4, 4] 589,824
| | + BatchNorm2d [1, 256, 4, 4] 512
| | + ReLU [1, 256, 4, 4] --
+ Sequential [1, 512, 2, 2] --
| + BasicBlock [1, 512, 2, 2] --
| | + Conv2d [1, 512, 2, 2] 1,179,648
| | + BatchNorm2d [1, 512, 2, 2] 1,024
| | + ReLU [1, 512, 2, 2] --
| | + Conv2d [1, 512, 2, 2] 2,359,296
| | + BatchNorm2d [1, 512, 2, 2] 1,024
| | + Sequential [1, 512, 2, 2] 132,096
| | + ReLU [1, 512, 2, 2] --
| + BasicBlock [1, 512, 2, 2] --
| | + Conv2d [1, 512, 2, 2] 2,359,296
| | + BatchNorm2d [1, 512, 2, 2] 1,024
| | + ReLU [1, 512, 2, 2] --
| | + Conv2d [1, 512, 2, 2] 2,359,296
| | + BatchNorm2d [1, 512, 2, 2] 1,024
| | + ReLU [1, 512, 2, 2] --
+ AdaptiveAvgPool2d [1, 512, 1, 1] --
+ Linear [1, 1000] 513,000
|--Conv2d [1, 64, 32, 32] 9,408
|--BatchNorm2d [1, 64, 32, 32] 128
|--ReLU [1, 64, 32, 32] --
|--MaxPool2d [1, 64, 16, 16] --
|--Sequential [1, 64, 16, 16] --
| |--BasicBlock [1, 64, 16, 16] --
| | |--Conv2d [1, 64, 16, 16] 36,864
| | |--BatchNorm2d [1, 64, 16, 16] 128
| | |--ReLU [1, 64, 16, 16] --
| | |--Conv2d [1, 64, 16, 16] 36,864
| | |--BatchNorm2d [1, 64, 16, 16] 128
| | '--ReLU [1, 64, 16, 16] --
| '--BasicBlock [1, 64, 16, 16] --
| |--Conv2d [1, 64, 16, 16] 36,864
| |--BatchNorm2d [1, 64, 16, 16] 128
| |--ReLU [1, 64, 16, 16] --
| |--Conv2d [1, 64, 16, 16] 36,864
| |--BatchNorm2d [1, 64, 16, 16] 128
| '--ReLU [1, 64, 16, 16] --
|--Sequential [1, 128, 8, 8] --
| |--BasicBlock [1, 128, 8, 8] --
| | |--Conv2d [1, 128, 8, 8] 73,728
| | |--BatchNorm2d [1, 128, 8, 8] 256
| | |--ReLU [1, 128, 8, 8] --
| | |--Conv2d [1, 128, 8, 8] 147,456
| | |--BatchNorm2d [1, 128, 8, 8] 256
| | |--Sequential [1, 128, 8, 8] 8,448
| | '--ReLU [1, 128, 8, 8] --
| '--BasicBlock [1, 128, 8, 8] --
| |--Conv2d [1, 128, 8, 8] 147,456
| |--BatchNorm2d [1, 128, 8, 8] 256
| |--ReLU [1, 128, 8, 8] --
| |--Conv2d [1, 128, 8, 8] 147,456
| |--BatchNorm2d [1, 128, 8, 8] 256
| '--ReLU [1, 128, 8, 8] --
|--Sequential [1, 256, 4, 4] --
| |--BasicBlock [1, 256, 4, 4] --
| | |--Conv2d [1, 256, 4, 4] 294,912
| | |--BatchNorm2d [1, 256, 4, 4] 512
| | |--ReLU [1, 256, 4, 4] --
| | |--Conv2d [1, 256, 4, 4] 589,824
| | |--BatchNorm2d [1, 256, 4, 4] 512
| | |--Sequential [1, 256, 4, 4] 33,280
| | '--ReLU [1, 256, 4, 4] --
| '--BasicBlock [1, 256, 4, 4] --
| |--Conv2d [1, 256, 4, 4] 589,824
| |--BatchNorm2d [1, 256, 4, 4] 512
| |--ReLU [1, 256, 4, 4] --
| |--Conv2d [1, 256, 4, 4] 589,824
| |--BatchNorm2d [1, 256, 4, 4] 512
| '--ReLU [1, 256, 4, 4] --
|--Sequential [1, 512, 2, 2] --
| |--BasicBlock [1, 512, 2, 2] --
| | |--Conv2d [1, 512, 2, 2] 1,179,648
| | |--BatchNorm2d [1, 512, 2, 2] 1,024
| | |--ReLU [1, 512, 2, 2] --
| | |--Conv2d [1, 512, 2, 2] 2,359,296
| | |--BatchNorm2d [1, 512, 2, 2] 1,024
| | |--Sequential [1, 512, 2, 2] 132,096
| | '--ReLU [1, 512, 2, 2] --
| '--BasicBlock [1, 512, 2, 2] --
| |--Conv2d [1, 512, 2, 2] 2,359,296
| |--BatchNorm2d [1, 512, 2, 2] 1,024
| |--ReLU [1, 512, 2, 2] --
| |--Conv2d [1, 512, 2, 2] 2,359,296
| |--BatchNorm2d [1, 512, 2, 2] 1,024
| '--ReLU [1, 512, 2, 2] --
|--AdaptiveAvgPool2d [1, 512, 1, 1] --
'--Linear [1, 1000] 513,000
==========================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
8 changes: 4 additions & 4 deletions tests/test_output/autoencoder.out
@@ -3,13 +3,13 @@ Layer (type:depth-idx) Output Shape Param #
===================================================================================================================
AutoEncoder [1, 3, 64, 64] -- --
├─Sequential: 1-1 [1, 16, 64, 64] -- --
─Conv2d: 2-1 [1, 16, 64, 64] 448 [3, 3]
─Conv2d: 2-1 [1, 16, 64, 64] 448 [3, 3]
│ └─ReLU: 2-2 [1, 16, 64, 64] -- --
├─MaxPool2d: 1-2 [1, 16, 32, 32] -- 2
├─MaxUnpool2d: 1-3 [1, 16, 64, 64] -- [2, 2]
─Sequential: 1-4 [1, 3, 64, 64] -- --
─Conv2d: 2-3 [1, 3, 64, 64] 435 [3, 3]
└─ReLU: 2-4 [1, 3, 64, 64] -- --
─Sequential: 1-4 [1, 3, 64, 64] -- --
─Conv2d: 2-3 [1, 3, 64, 64] 435 [3, 3]
└─ReLU: 2-4 [1, 3, 64, 64] -- --
===================================================================================================================
Total params: 883
Trainable params: 883
2 changes: 1 addition & 1 deletion tests/test_output/basic_summary.out
@@ -6,7 +6,7 @@ SingleInputNet --
├─Conv2d: 1-2 5,020
├─Dropout2d: 1-3 --
├─Linear: 1-4 16,050
─Linear: 1-5 510
─Linear: 1-5 510
=================================================================
Total params: 21,840
Trainable params: 21,840
2 changes: 1 addition & 1 deletion tests/test_output/batch_size_optimization.out
@@ -6,7 +6,7 @@ SingleInputNet [1, 10] --
├─Conv2d: 1-2 [1, 20, 8, 8] 5,020
├─Dropout2d: 1-3 [1, 20, 8, 8] --
├─Linear: 1-4 [1, 50] 16,050
─Linear: 1-5 [1, 10] 510
─Linear: 1-5 [1, 10] 510
==========================================================================================
Total params: 21,840
Trainable params: 21,840
38 changes: 19 additions & 19 deletions tests/test_output/bert.out
@@ -3,28 +3,28 @@ Layer (type:depth-idx) Output Shape Par
====================================================================================================
BertModel [2, 768] --
├─BertEmbeddings: 1-1 [2, 512, 768] --
─Embedding: 2-1 [2, 512, 768] 23,440,896
─Embedding: 2-2 [2, 512, 768] 1,536
─Embedding: 2-3 [1, 512, 768] 393,216
─LayerNorm: 2-4 [2, 512, 768] 1,536
─Embedding: 2-1 [2, 512, 768] 23,440,896
─Embedding: 2-2 [2, 512, 768] 1,536
─Embedding: 2-3 [1, 512, 768] 393,216
─LayerNorm: 2-4 [2, 512, 768] 1,536
│ └─Dropout: 2-5 [2, 512, 768] --
├─BertEncoder: 1-2 [2, 512, 768] --
│ └─ModuleList: 2-6 -- --
─BertLayer: 3-1 [2, 512, 768] 7,087,872
─BertLayer: 3-2 [2, 512, 768] 7,087,872
─BertLayer: 3-3 [2, 512, 768] 7,087,872
─BertLayer: 3-4 [2, 512, 768] 7,087,872
─BertLayer: 3-5 [2, 512, 768] 7,087,872
─BertLayer: 3-6 [2, 512, 768] 7,087,872
─BertLayer: 3-7 [2, 512, 768] 7,087,872
─BertLayer: 3-8 [2, 512, 768] 7,087,872
─BertLayer: 3-9 [2, 512, 768] 7,087,872
─BertLayer: 3-10 [2, 512, 768] 7,087,872
─BertLayer: 3-11 [2, 512, 768] 7,087,872
└─BertLayer: 3-12 [2, 512, 768] 7,087,872
─BertPooler: 1-3 [2, 768] --
─Linear: 2-7 [2, 768] 590,592
└─Tanh: 2-8 [2, 768] --
─BertLayer: 3-1 [2, 512, 768] 7,087,872
─BertLayer: 3-2 [2, 512, 768] 7,087,872
─BertLayer: 3-3 [2, 512, 768] 7,087,872
─BertLayer: 3-4 [2, 512, 768] 7,087,872
─BertLayer: 3-5 [2, 512, 768] 7,087,872
─BertLayer: 3-6 [2, 512, 768] 7,087,872
─BertLayer: 3-7 [2, 512, 768] 7,087,872
─BertLayer: 3-8 [2, 512, 768] 7,087,872
─BertLayer: 3-9 [2, 512, 768] 7,087,872
─BertLayer: 3-10 [2, 512, 768] 7,087,872
─BertLayer: 3-11 [2, 512, 768] 7,087,872
└─BertLayer: 3-12 [2, 512, 768] 7,087,872
─BertPooler: 1-3 [2, 768] --
─Linear: 2-7 [2, 768] 590,592
└─Tanh: 2-8 [2, 768] --
====================================================================================================
Total params: 109,482,240
Trainable params: 109,482,240