diff --git a/ddpm/unet.py b/ddpm/unet.py index 57725e3..ef457ba 100644 --- a/ddpm/unet.py +++ b/ddpm/unet.py @@ -297,7 +297,7 @@ def __init__( time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups), ResidualBlock(out_channels, out_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups), - AttentionBlock(out_channels) if use_attn else nn.Identity, + AttentionBlock(out_channels) if use_attn else nn.Identity(), Downsample(out_channels, use_reflection_pad=use_reflection_pad) if not is_last else nn.Identity(), ])) @@ -316,7 +316,7 @@ def __init__( time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups), ResidualBlock(in_channels, in_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups), - AttentionBlock(in_channels) if use_attn else nn.Identity, + AttentionBlock(in_channels) if use_attn else nn.Identity(), Upsample(in_channels, align_corners=align_corners, use_reflection_pad=use_reflection_pad), ]))