diff --git a/egg/zoo/signal_game/train.py b/egg/zoo/signal_game/train.py index e86888fbb..ffc4bd58f 100644 --- a/egg/zoo/signal_game/train.py +++ b/egg/zoo/signal_game/train.py @@ -67,7 +67,7 @@ def loss_nll( NLL loss - differentiable and can be used with both GS and Reinforce """ nll = F.nll_loss(receiver_output, labels, reduction="none") - acc = (labels == receiver_output.argmax(dim=1)).float().mean() + acc = (labels == receiver_output.argmax(dim=1)).float() return nll, {"acc": acc}