stabilise tab_ddpm internal functions (#317)
robsdavis authored Jan 8, 2025
1 parent 21f8e30 commit 5877192
Showing 2 changed files with 17 additions and 7 deletions.
@@ -4,6 +4,7 @@
- https://github.com/ehoogeboom/multinomial_diffusion
- https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281
"""

# stdlib
import math
from typing import Any, Optional, Tuple
@@ -457,10 +458,10 @@ def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor:
        log_alpha_t = perm_and_expand(self.log_alpha, t, log_x_t.shape)
        log_1_min_alpha_t = perm_and_expand(self.log_1_min_alpha, t, log_x_t.shape)

-        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
+        # Clamp before log_add_exp to prevent numerical issues
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
-            log_1_min_alpha_t - torch.log(self.num_classes_expanded),
+            log_1_min_alpha_t - torch.log(self.num_classes_expanded + 1e-10),
        )

        return log_probs
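
For reference, log_add_exp (called in both of the changed hunks but not itself part of this diff) is the usual two-argument log-sum-exp helper. A minimal sketch of that pattern, assuming the standard formulation; the real helper lives in tabular_ddpm/utils.py and may differ in detail:

# Sketch of the two-tensor log-add-exp trick the code above relies on.
import torch
from torch import Tensor

def log_add_exp_sketch(a: Tensor, b: Tensor) -> Tensor:
    m = torch.maximum(a, b)  # factor out the larger exponent for stability
    return m + torch.log(torch.exp(a - m) + torch.exp(b - m))
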
@@ -475,7 +476,7 @@ def q_pred(self, log_x_start: Tensor, t: Tensor) -> Tensor:

        log_probs = log_add_exp(
            log_x_start + log_cumprod_alpha_t,
-            log_1_min_cumprod_alpha - torch.log(self.num_classes_expanded),
+            log_1_min_cumprod_alpha - torch.log(self.num_classes_expanded + 1e-10),
        )

        return log_probs
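
The + 1e-10 added in both hunks above is a defensive guard: torch.log returns -inf for an exact zero argument, and that -inf can propagate as NaN through later arithmetic. A tiny illustration with made-up values:

import torch

x = torch.tensor([0.0, 2.0, 5.0])
print(torch.log(x))          # tensor([   -inf, 0.6931, 1.6094])
print(torch.log(x + 1e-10))  # tensor([-23.0259, 0.6931, 1.6094]) -- finite everywhere
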
@@ -541,9 +542,11 @@ def log_sample_categorical(self, logits: Tensor) -> Tensor:
        full_sample = []
        for i in range(len(self.num_classes)):
            one_class_logits = logits[:, self.slices_for_classes[i]]
-            uniform = torch.rand_like(one_class_logits)
+            # Clamp logits to prevent overflow in Gumbel noise
+            one_class_logits_clamped = torch.clamp(one_class_logits, max=50)
+            uniform = torch.rand_like(one_class_logits_clamped)
            gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
-            sample = (gumbel_noise + one_class_logits).argmax(dim=1)
+            sample = (gumbel_noise + one_class_logits_clamped).argmax(dim=1)
            full_sample.append(sample.unsqueeze(1))
        full_sample = torch.cat(full_sample, dim=1)
        log_sample = index_to_log_onehot(full_sample, self.num_classes)
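
The loop above draws each categorical feature with the Gumbel-max trick; clamping the logits at 50 bounds them before the noise is added, so extreme or infinite logits cannot poison the argmax. A standalone sketch of the same idea (gumbel_max_sample is a hypothetical helper, not part of the library):

import torch
from torch import Tensor

def gumbel_max_sample(logits: Tensor, clamp_max: float = 50.0) -> Tensor:
    logits = torch.clamp(logits, max=clamp_max)                # bound extreme logits
    uniform = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)   # Gumbel(0, 1) noise
    return (logits + gumbel).argmax(dim=1)                     # one category index per row

samples = gumbel_max_sample(torch.randn(4, 3))  # 4 rows over 3 categories
print(samples.shape)  # torch.Size([4])
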
11 changes: 9 additions & 2 deletions src/synthcity/plugins/core/models/tabular_ddpm/utils.py
@@ -149,9 +149,16 @@ def index_to_log_onehot(x: Tensor, num_classes: np.ndarray) -> Tensor:


@torch.jit.script
-def log_sub_exp(a: Tensor, b: Tensor) -> Tensor:
+def log_sub_exp(a: Tensor, b: Tensor, epsilon: float = 1e-10) -> Tensor:
    m = torch.maximum(a, b)
-    return torch.log(torch.exp(a - m) - torch.exp(b - m)) + m
+    # Compute the exponentials safely
+    exp_diff = torch.exp(a - m) - torch.exp(b - m)
+    # Ensure that exp_diff is greater than epsilon
+    exp_diff_clamped = torch.clamp(exp_diff, min=epsilon)
+    # Where a <= b, set the result to -inf or another appropriate value
+    valid = a > b
+    log_result = torch.log(exp_diff_clamped) + m
+    return torch.where(valid, log_result, torch.full_like(log_result, -float("inf")))


@torch.jit.script
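
A quick check of the new log_sub_exp behaviour with illustrative values (module path taken from the file header above): where a > b the result is unchanged, and where a <= b the patched helper now returns -inf instead of the NaN produced by taking the log of a negative difference.

import torch
from synthcity.plugins.core.models.tabular_ddpm.utils import log_sub_exp

a = torch.tensor([1.0, 0.0])
b = torch.tensor([0.0, 0.5])
print(log_sub_exp(a, b))  # tensor([0.5413, -inf]); before the patch the second entry was nan
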