diff --git a/framework/prediction.py b/framework/prediction.py index ebfa30c..f069dd0 100644 --- a/framework/prediction.py +++ b/framework/prediction.py @@ -155,4 +155,4 @@ def init_weight(self): nn.init.uniform_(self.fc.bias, a=0, b=0.2) def forward(self, feature, *args, **kwargs): - return F.relu(self.fc(feature)[0]) + return F.relu(self.fc(feature))