10
10
11
11
12
12
class BaseMultiTaskSegModel (nn .ModuleDict ):
13
+ def forward_encoder (self , x : torch .Tensor ) -> List [torch .Tensor ]:
14
+ """Forward the model encoder."""
15
+ self ._check_input_shape (x )
16
+ feats = self .encoder (x )
17
+
18
+ return feats
19
+
20
+ def forward_style (self , feat : torch .Tensor ) -> torch .Tensor :
21
+ """Forward the style domain adaptation layer.
22
+
23
+ NOTE: returns None if style channels are not given at model init.
24
+ """
25
+ style = None
26
+ if self .make_style is not None :
27
+ style = self .make_style (feat )
28
+
29
+ return style
30
+
13
31
def forward_dec_features (
14
32
self , feats : List [torch .Tensor ], style : torch .Tensor = None
15
- ) -> Dict [str , torch .Tensor ]:
16
- """Forward pass of the decoders in a multi-task seg model."""
33
+ ) -> Dict [str , List [torch .Tensor ]]:
34
+ """Forward pass of all the decoder features mappings in the model.
35
+
36
+ NOTE: returns all the features from diff decoder stages in a list.
37
+ """
17
38
res = {}
18
39
decoders = [k for k in self .keys () if "decoder" in k ]
19
40
20
41
for dec in decoders :
21
- x = self [dec ](* feats , style = style )
42
+ featlist = self [dec ](* feats , style = style )
22
43
branch = dec .split ("_" )[0 ]
23
- res [branch ] = x
44
+ res [branch ] = featlist
24
45
25
46
return res
26
47
@@ -30,10 +51,9 @@ def forward_heads(
30
51
"""Forward pass of the seg heads in a multi-task seg model."""
31
52
res = {}
32
53
heads = [k for k in self .keys () if "head" in k ]
33
-
34
54
for head in heads :
35
55
branch = head .split ("_" )[0 ]
36
- x = self [head ](dec_feats [branch ])
56
+ x = self [head ](dec_feats [branch ][ - 1 ]) # the last decoder stage feat map
37
57
res [branch ] = x
38
58
39
59
return res
0 commit comments