diff --git a/nflib/flows.py b/nflib/flows.py index cd70a1d..5168053 100644 --- a/nflib/flows.py +++ b/nflib/flows.py @@ -143,7 +143,7 @@ def __init__(self, dim, parity, net_class=MLP, nh=24): self.layers = nn.ModuleDict() self.layers[str(0)] = LeafParam(2) for i in range(1, dim): - self.layers[str(i)] = MLP(i, 2, nh) + self.layers[str(i)] = net_class(i, 2, nh) self.order = list(range(dim)) if parity else list(range(dim))[::-1] def forward(self, x):