
Commit

Merge pull request #308 from singh96aman/swae_loss_function
Issues #307 and #309 - Preserving Data Dimensionality and Adding SWAE Loss Function
exook authored Oct 13, 2023
2 parents 8d14be8 + e31743c commit 7b283f6
Showing 5 changed files with 140 additions and 32 deletions.
34 changes: 30 additions & 4 deletions baler/baler.py
@@ -92,12 +92,17 @@ def perform_training(output_path, config, verbose: bool):
Raises:
NameError: Baler currently only supports 1D (e.g. HEP) or 2D (e.g. CFD) data as inputs.
"""
train_set_norm, test_set_norm, normalization_features = helper.process(
(
train_set_norm,
test_set_norm,
normalization_features,
original_shape,
) = helper.process(
config.input_path,
config.custom_norm,
config.test_size,
config.apply_normalization,
config.convert_to_blocks,
config.convert_to_blocks if hasattr(config, "convert_to_blocks") else None,
)

if verbose:
@@ -111,8 +116,12 @@ def perform_training(output_path, config, verbose: bool):
)
config.number_of_columns = number_of_columns
elif config.data_dimension == 2:
number_of_rows = train_set_norm.shape[1]
number_of_columns = train_set_norm.shape[2]
if config.model_type == "dense":
number_of_rows = train_set_norm.shape[1]
number_of_columns = train_set_norm.shape[2]
else:
number_of_rows = original_shape[1]
number_of_columns = original_shape[2]
config.latent_space_size = ceil(
(number_of_rows * number_of_columns) / config.compression_ratio
)
@@ -329,6 +338,23 @@ def perform_decompression(output_path, config, verbose: bool):
if verbose:
print(f"Model used: {model_name}")

if config.convert_to_blocks:
data_before = np.load(config.input_path)["data"]
print(
"Converting Blocked Data into Standard Format. Old Shape - ",
decompressed.shape,
"Target Shape - ",
data_before.shape,
)
if config.model_type == "dense":
decompressed = decompressed.reshape(
data_before.shape[0], data_before.shape[1], data_before.shape[2]
)
else:
decompressed = decompressed.reshape(
data_before.shape[0], 1, data_before.shape[1], data_before.shape[2]
)

if config.apply_normalization:
print("Un-normalizing...")
normalization_features = np.load(
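As context for the baler.py changes above (not part of the commit): when the data was converted into blocks before compression, perform_decompression now reshapes the decoder output back to the original array layout. A minimal NumPy sketch of that reshape, using hypothetical shapes (4 samples of 64x64 split into 32x32 blocks):

import numpy as np

# Hypothetical shapes: 4 samples of 64x64, blocked into 16 tiles of 32x32.
data_before = np.zeros((4, 64, 64))    # original, un-blocked data
decompressed = np.zeros((16, 32, 32))  # blocked output from the decoder

# Dense models: restore the original (samples, rows, columns) layout.
restored_dense = decompressed.reshape(
    data_before.shape[0], data_before.shape[1], data_before.shape[2]
)

# Convolutional models: keep a singleton channel axis, (samples, 1, rows, columns).
restored_conv = decompressed.reshape(
    data_before.shape[0], 1, data_before.shape[1], data_before.shape[2]
)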
17 changes: 10 additions & 7 deletions baler/modules/helper.py
@@ -279,6 +279,7 @@ def process(input_path, custom_norm, test_size, apply_normalization, convert_to_
"""
loaded = np.load(input_path)
data = loaded["data"]
original_shape = data.shape

if convert_to_blocks:
data = data_processing.convert_to_blocks_util(convert_to_blocks, data)
@@ -295,11 +296,7 @@
data, test_size=test_size, random_state=1
)

return (
train_set,
test_set,
normalization_features,
)
return (train_set, test_set, normalization_features, original_shape)


def renormalize(data, true_min_list, feature_range_list):
@@ -456,6 +453,7 @@ def compress(model_path, config):
# Loads the data and applies normalization if config.apply_normalization = True
loaded = np.load(config.input_path)
data_before = loaded["data"]
original_shape = data_before.shape

if config.convert_to_blocks:
data_before = data_processing.convert_to_blocks_util(
@@ -478,8 +476,12 @@
)
config.number_of_columns = number_of_columns
elif config.data_dimension == 2:
number_of_rows = data.shape[1]
config.number_of_columns = data.shape[2]
if config.model_type == "dense":
number_of_rows = data.shape[1]
config.number_of_columns = data.shape[2]
else:
number_of_rows = original_shape[1]
config.number_of_columns = original_shape[2]
config.latent_space_size = ceil(
(number_of_rows * config.number_of_columns) / config.compression_ratio
)
@@ -695,6 +697,7 @@ def decompress(
decompressed = decompressed.reshape(
(len(decompressed), number_of_columns, number_of_columns)
)

return decompressed, names, normalization_features


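As a small worked example of the sizing logic above (not from the commit): for non-dense models, the latent space is now sized from the original, un-blocked shape rather than the blocked one. With a hypothetical original_shape of (1000, 64, 64) and compression_ratio = 2:

from math import ceil

original_shape = (1000, 64, 64)  # hypothetical (samples, rows, columns)
compression_ratio = 2
number_of_rows, number_of_columns = original_shape[1], original_shape[2]
# ceil((64 * 64) / 2) = 2048 latent dimensions
latent_space_size = ceil((number_of_rows * number_of_columns) / compression_ratio)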
10 changes: 0 additions & 10 deletions baler/modules/plotting.py
@@ -261,16 +261,6 @@ def plot_2D(project_path, config):
"data"
]

if config.convert_to_blocks:
if config.model_type == "dense":
data_decompressed = data_decompressed.reshape(
data.shape[0], data.shape[1], data.shape[2]
)
else:
data_decompressed = data_decompressed.reshape(
data.shape[0], 1, data.shape[1], data.shape[2]
)

if data.shape[0] > 1:
num_tiles = data.shape[0]
else:
42 changes: 31 additions & 11 deletions baler/modules/training.py
@@ -24,10 +24,20 @@
from ..modules import diagnostics
from ..modules import helper
from ..modules import utils
from torch.nn import functional as F


def fit(
model, train_dl, model_children, regular_param, optimizer, RHO, l1, n_dimensions
config,
model,
train_dl,
model_children,
regular_param,
optimizer,
latent_dim,
RHO,
l1,
n_dimensions,
):
"""This function trains the model on the train set. It computes the losses and does the backwards propagation, and updates the optimizer as well.
Args:
@@ -59,14 +69,23 @@ def fit(
# Compute the predicted outputs from the input data
reconstructions = model(inputs)

# Compute how far off the prediction is
loss, mse_loss, l1_loss = utils.mse_sum_loss_l1(
model_children=model_children,
true_data=inputs,
reconstructed_data=reconstructions,
reg_param=regular_param,
validate=True,
)
if (
hasattr(config, "custom_loss_function")
and config.custom_loss_function == "loss_function_swae"
):
z = model.encode(inputs)
loss, mse_loss, l1_loss = utils.loss_function_swae(
inputs, z, reconstructions, latent_dim
)
else:
# Compute how far off the prediction is
loss, mse_loss, l1_loss = utils.mse_sum_loss_l1(
model_children=model_children,
true_data=inputs,
reconstructed_data=reconstructions,
reg_param=regular_param,
validate=True,
)

# Compute the loss-gradient with
loss.backward()
@@ -267,16 +286,17 @@ def train(model, variables, train_data, test_data, project_path, config):
print(f"Epoch {epoch + 1} of {epochs}")

train_epoch_loss, mse_loss_fit, regularizer_loss_fit, trained_model = fit(
config=config,
model=model,
train_dl=train_dl,
model_children=model_children,
regular_param=reg_param,
optimizer=optimizer,
latent_dim=latent_space_size,
RHO=rho,
regular_param=reg_param,
l1=l1,
n_dimensions=config.data_dimension,
)

train_loss.append(train_epoch_loss)

if test_size:
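For context (not part of the diff): fit() only switches to the SWAE loss when the project config defines custom_loss_function = "loss_function_swae"; otherwise training keeps the existing MSE-sum + L1 loss. A sketch of the relevant config attributes — the set_config wrapper and the specific values are assumptions about the config-file layout, while the attribute names are taken from the code in this commit:

# Hypothetical excerpt from a Baler project config; only attributes
# referenced in this commit are shown.
def set_config(c):
    c.data_dimension = 2
    c.model_type = "dense"
    c.compression_ratio = 2.0
    c.apply_normalization = True
    c.convert_to_blocks = None
    # Optional: select the sliced-Wasserstein loss added in this commit.
    # When absent, fit() falls back to utils.mse_sum_loss_l1.
    c.custom_loss_function = "loss_function_swae"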
69 changes: 69 additions & 0 deletions baler/modules/utils.py
@@ -17,11 +17,80 @@
from scipy.stats import wasserstein_distance
from torch.nn import functional
from tqdm import tqdm
from torch.nn import functional as F
from torch import distributions as dist

factor = 0.5
min_lr = 1e-6


def loss_function_swae(
inputs,
z,
reconstructions,
latent_dim,
reg_weight=100,
wasserstein_deg=2.0,
num_projections=2000,
projection_dist="normal",
):
batch_size = inputs.shape[0]
bias_corr = batch_size * (batch_size - 1)
reg_weight = reg_weight / bias_corr

mse_sum = nn.MSELoss(reduction="sum")
mse_loss = mse_sum(reconstructions, inputs)
number_of_columns = inputs.shape[1]
mse_sum_loss = mse_loss / number_of_columns
# recons_loss_l1 = F.l1_loss(reconstructions, inputs)
recons_loss_l1 = 0
swd_loss = compute_swd(
z, wasserstein_deg, reg_weight, latent_dim, num_projections, projection_dist
)
loss = mse_sum_loss + recons_loss_l1 + swd_loss
SWD = swd_loss
return loss, mse_sum_loss, SWD


def compute_swd(z, p, reg_weight, latent_dim, num_projections, proj_dist):
prior_z = torch.randn_like(z) # [N x D]
device = z.device

proj_matrix = (
get_random_projections(proj_dist, latent_dim, num_samples=num_projections)
.transpose(0, 1)
.to(device)
)

latent_projections = z.matmul(proj_matrix) # [N x S]
prior_projections = prior_z.matmul(proj_matrix) # [N x S]

# The Wasserstein distance is computed by sorting the two projections
# across the batches and computing their element-wise distance
w_dist = (
torch.sort(latent_projections.t(), dim=1)[0]
- torch.sort(prior_projections.t(), dim=1)[0]
)
w_dist = w_dist.pow(p)
return reg_weight * w_dist.mean()


def get_random_projections(proj_dist, latent_dim, num_samples):
if proj_dist == "normal":
rand_samples = torch.randn(num_samples, latent_dim)
elif proj_dist == "cauchy":
rand_samples = (
dist.Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
.sample((num_samples, latent_dim))
.squeeze()
)
else:
raise ValueError("Unknown projection distribution.")

rand_proj = rand_samples / rand_samples.norm(dim=1).view(-1, 1)
return rand_proj # [S x D]


def mse_loss_emd_l1(model_children, true_data, reconstructed_data, reg_param, validate):
"""
Computes a sparse loss function consisting of three terms: the Earth Mover's Distance (EMD) loss between the
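For context (not part of the commit): the new SWAE loss drops the L1 term and adds a sliced-Wasserstein regularizer instead — the latent codes and samples from a standard-normal prior are projected onto random directions, each set of projections is sorted, and the mean p-th power of their element-wise differences is scaled by reg_weight. A standalone smoke test, assuming the repository root is on PYTHONPATH and 2D (batch, features) tensors as in the dense-model path:

import torch
from baler.modules import utils

batch, features, latent_dim = 32, 100, 10

inputs = torch.randn(batch, features)           # "true" data
reconstructions = torch.randn(batch, features)  # decoder output
z = torch.randn(batch, latent_dim)              # encoder output

# Returns (total loss, normalized MSE-sum term, sliced-Wasserstein term).
loss, mse_term, swd_term = utils.loss_function_swae(
    inputs, z, reconstructions, latent_dim
)
print(loss.item(), mse_term.item(), swd_term.item())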

