Skip to content

Commit

Permalink
KL to batchmean
Browse files Browse the repository at this point in the history
  • Loading branch information
um1 committed Dec 28, 2023
1 parent 2955b31 commit e580e97
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
outputs2 = part[0] + part[1] + part[2] + part[3] + part[4] + part[5]

mean_pred = sm(outputs1 + outputs2)
kl_loss = nn.KLDivLoss(size_average=False)
kl_loss = nn.KLDivLoss(reduction='batchmean')
reg= (kl_loss(log_sm(outputs2) , mean_pred) + kl_loss(log_sm(outputs1) , mean_pred))/2
loss += 0.01*reg
del inputs1, inputs2
Expand Down

0 comments on commit e580e97

Please sign in to comment.