Skip to content

Commit 3e45145

Browse files
committed
Fix hardcoded input dim in DiffusionModelEncoder
Signed-off-by: IamTingTing <[email protected]>
1 parent e499362 commit 3e45145

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

monai/networks/nets/diffusion_model_unet.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2005,7 +2005,7 @@ def __init__(
20052005

20062006
self.down_blocks.append(down_block)
20072007

2008-
self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))
2008+
self.out = None
20092009

20102010
def forward(
20112011
self,
@@ -2048,6 +2048,12 @@ def forward(
20482048
h, _ = downsample_block(hidden_states=h, temb=emb, context=context)
20492049

20502050
h = h.reshape(h.shape[0], -1)
2051+
2052+
# 5. out
2053+
if self.out is None:
2054+
self.out = nn.Sequential(
2055+
nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
2056+
)
20512057
output: torch.Tensor = self.out(h)
20522058

20532059
return output

0 commit comments

Comments
 (0)