Skip to content

Commit

Permalink
fix(cat): correct temperature scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
braun-steven committed Nov 6, 2024
1 parent a3f93f5 commit 8cabee2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion simple_einet/layers/distributions/abstract_leaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def dist_sample(distribution: dist.Distribution, ctx: SamplingContext = None) ->
elif type(distribution) == CustomNormal:
distribution = CustomNormal(mu=distribution.mu, sigma=distribution.sigma * np.sqrt(ctx.temperature_leaves))
elif type(distribution) == dist.Categorical:
distribution = dist.Categorical(logits=F.softmax(distribution.logits / ctx.temperature_leaves))
distribution = dist.Categorical(logits=distribution.logits / ctx.temperature_leaves)

samples = distribution.sample(sample_shape=(ctx.num_samples,)).float()

Expand Down

0 comments on commit 8cabee2

Please sign in to comment.