diff --git a/net.py b/net.py index 09ec3b2b..5a3c4073 100644 --- a/net.py +++ b/net.py @@ -910,7 +910,8 @@ def forward(self, x): x = self.map_blocks[i](x) # We select just one output. For compatibility with older models. # All other outputs are ignored - return x[:, 0, 0] + # It is the same as if the last layer had one output. + return x[:, 0, x.shape[2] // 2] @MAPPINGS.register("MappingDNoStyle")