diff --git a/stanza/models/lemma_classifier/train_model.py b/stanza/models/lemma_classifier/train_model.py index 2b313479bd..65c09c9ce1 100644 --- a/stanza/models/lemma_classifier/train_model.py +++ b/stanza/models/lemma_classifier/train_model.py @@ -95,7 +95,7 @@ def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping): total_samples = sum(counts.values()) for class_idx in counts: weights[class_idx] = total_samples / (counts[class_idx] * len(counts)) # weight_i = total / (# examples in class i * num classes) - weights = torch.tensor(weights) + weights = torch.tensor(weights) logging.info(f"Using weights {weights} for weighted loss.") self.criterion = nn.BCEWithLogitsLoss(weight=weights)