From 5877192b786ae0f20ef28386c4fabd2b389dd6af Mon Sep 17 00:00:00 2001
From: Rob <62107751+robsdavis@users.noreply.github.com>
Date: Wed, 8 Jan 2025 15:13:30 +0000
Subject: [PATCH] stabilise tab_ddpm internal functions (#317)

---
 .../tabular_ddpm/gaussian_multinomial_diffsuion.py | 13 ++++++++-----
 .../plugins/core/models/tabular_ddpm/utils.py      | 11 +++++++++--
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py
index 6414a2af..f1f54ae2 100644
--- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py
+++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py
@@ -4,6 +4,7 @@
 - https://github.com/ehoogeboom/multinomial_diffusion
 - https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281
 """
+
 # stdlib
 import math
 from typing import Any, Optional, Tuple
@@ -457,10 +458,10 @@ def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor:
         log_alpha_t = perm_and_expand(self.log_alpha, t, log_x_t.shape)
         log_1_min_alpha_t = perm_and_expand(self.log_1_min_alpha, t, log_x_t.shape)
 
-        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
+        # Clamp before log_add_exp to prevent numerical issues
         log_probs = log_add_exp(
             log_x_t + log_alpha_t,
-            log_1_min_alpha_t - torch.log(self.num_classes_expanded),
+            log_1_min_alpha_t - torch.log(self.num_classes_expanded + 1e-10),
         )
 
         return log_probs
@@ -475,7 +476,7 @@ def q_pred(self, log_x_start: Tensor, t: Tensor) -> Tensor:
 
         log_probs = log_add_exp(
             log_x_start + log_cumprod_alpha_t,
-            log_1_min_cumprod_alpha - torch.log(self.num_classes_expanded),
+            log_1_min_cumprod_alpha - torch.log(self.num_classes_expanded + 1e-10),
         )
 
         return log_probs
@@ -541,9 +542,11 @@ def log_sample_categorical(self, logits: Tensor) -> Tensor:
         full_sample = []
         for i in range(len(self.num_classes)):
             one_class_logits = logits[:, self.slices_for_classes[i]]
-            uniform = torch.rand_like(one_class_logits)
+            # Clamp logits to prevent overflow in Gumbel noise
+            one_class_logits_clamped = torch.clamp(one_class_logits, max=50)
+            uniform = torch.rand_like(one_class_logits_clamped)
             gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
-            sample = (gumbel_noise + one_class_logits).argmax(dim=1)
+            sample = (gumbel_noise + one_class_logits_clamped).argmax(dim=1)
             full_sample.append(sample.unsqueeze(1))
         full_sample = torch.cat(full_sample, dim=1)
         log_sample = index_to_log_onehot(full_sample, self.num_classes)
diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py
index d8fc5008..6040c054 100644
--- a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py
+++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py
@@ -149,9 +149,16 @@ def index_to_log_onehot(x: Tensor, num_classes: np.ndarray) -> Tensor:
 
 
 @torch.jit.script
-def log_sub_exp(a: Tensor, b: Tensor) -> Tensor:
+def log_sub_exp(a: Tensor, b: Tensor, epsilon: float = 1e-10) -> Tensor:
     m = torch.maximum(a, b)
-    return torch.log(torch.exp(a - m) - torch.exp(b - m)) + m
+    # Compute the exponentials safely
+    exp_diff = torch.exp(a - m) - torch.exp(b - m)
+    # Ensure that exp_diff is greater than epsilon
+    exp_diff_clamped = torch.clamp(exp_diff, min=epsilon)
+    # Where a <= b, set the result to -inf or another appropriate value
+    valid = a > b
+    log_result = torch.log(exp_diff_clamped) + m
+    return torch.where(valid, log_result, torch.full_like(log_result, -float("inf")))
 
 
 @torch.jit.script
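
Note (illustrative, not part of the patch): a minimal standalone sketch of how the stabilised log_sub_exp behaves, assuming only PyTorch is installed. The helper name log_sub_exp_stable is hypothetical; it mirrors the patched function above.

# Illustrative sketch -- mirrors the stabilised log_sub_exp introduced by this patch.
import torch
from torch import Tensor

def log_sub_exp_stable(a: Tensor, b: Tensor, epsilon: float = 1e-10) -> Tensor:
    m = torch.maximum(a, b)
    # Difference of exponentials, clamped so the log never sees a non-positive value
    exp_diff = torch.exp(a - m) - torch.exp(b - m)
    exp_diff_clamped = torch.clamp(exp_diff, min=epsilon)
    # Where a <= b the true log-difference is undefined; return -inf explicitly
    valid = a > b
    log_result = torch.log(exp_diff_clamped) + m
    return torch.where(valid, log_result, torch.full_like(log_result, -float("inf")))

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([2.0, 2.0, 1.0])
print(log_sub_exp_stable(a, b))  # finite only where a > b; -inf elsewhere, no nan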