Merge pull request #26 from florencejt/merge/unlistingoutputs
Merge/unlistingoutputs
florencejt authored May 15, 2024
2 parents ea4db6c + 7c9bca7 commit b7df77f
Showing 28 changed files with 654 additions and 621 deletions.
80 changes: 57 additions & 23 deletions fusilli/fusionmodels/base_model.py
@@ -122,7 +122,9 @@ def __init__(self, model, metrics_list=None):
         }
 
         if self.model.prediction_task not in ["binary", "multiclass", "regression"]:
-            raise ValueError(f"Unsupported prediction_task: {self.model.prediction_task}")
+            raise ValueError(
+                f"Unsupported prediction_task: {self.model.prediction_task}"
+            )
 
         self.metric_names_list = list(self.metrics.keys())
 
@@ -181,14 +183,20 @@ def set_metrics(self, metrics_list):
         # If the list is None, use the default metrics
         if metrics_list is None:
             if self.model.prediction_task == "binary":
-                self.metrics = {"AUROC": self.MetricsCalculator.auroc,
-                                "Accuracy": self.MetricsCalculator.accuracy}
+                self.metrics = {
+                    "AUROC": self.MetricsCalculator.auroc,
+                    "Accuracy": self.MetricsCalculator.accuracy,
+                }
             elif self.model.prediction_task == "multiclass":
-                self.metrics = {"AUROC": self.MetricsCalculator.auroc,
-                                "Accuracy": self.MetricsCalculator.accuracy}
+                self.metrics = {
+                    "AUROC": self.MetricsCalculator.auroc,
+                    "Accuracy": self.MetricsCalculator.accuracy,
+                }
             elif self.model.prediction_task == "regression":
-                self.metrics = {"R2": self.MetricsCalculator.r2,
-                                "MAE": self.MetricsCalculator.mae}
+                self.metrics = {
+                    "R2": self.MetricsCalculator.r2,
+                    "MAE": self.MetricsCalculator.mae,
+                }
 
         # Error if list length is less than 2
         else:
@@ -199,13 +207,23 @@ def set_metrics(self, metrics_list):
 
             # Error if any of the metrics are not supported
             for metric_string in metrics_list:
-                supported_metrics = [func for func in dir(self.MetricsCalculator) if
-                                     callable(getattr(self.MetricsCalculator, func)) and not func.startswith("__")]
-                if metric_string.lower() not in supported_metrics:  # change this to be accurate
-                    raise ValueError(f"Unsupported metric: {metric_string}. Please choose from: {supported_metrics}")
+                supported_metrics = [
+                    func
+                    for func in dir(self.MetricsCalculator)
+                    if callable(getattr(self.MetricsCalculator, func))
+                    and not func.startswith("__")
+                ]
+                if (
+                    metric_string.lower() not in supported_metrics
+                ):  # change this to be accurate
+                    raise ValueError(
+                        f"Unsupported metric: {metric_string}. Please choose from: {supported_metrics}"
+                    )
 
                 # Set the new metrics
-                self.metrics[metric_string] = getattr(self.MetricsCalculator, metric_string.lower())
+                self.metrics[metric_string] = getattr(
+                    self.MetricsCalculator, metric_string.lower()
+                )
 
     def get_data_from_batch(self, batch):
         """
@@ -255,19 +273,28 @@ def get_model_outputs(self, x):
         -------
         logits : tensor
             Logits.
-        reconstructions : tensor
-            Reconstructions (returned if the model has a custom loss function such as a subspace method)
+        reconstructions : tensor or None
+            Reconstructions (returned if the model has a custom loss function such as a subspace method). None if not provided.
 
         Note
         ----
-        if you get an error here, check that the forward output in fusion model is [out,] or [out, reconstructions]
+        If you get an error here, check that the forward output in fusion model is [out,] or [out, reconstructions].
         """
-        model_outputs = self.model(x)
-
-        logits, *reconstructions = model_outputs
-        logits = logits.squeeze(dim=1)
-
-        return logits, reconstructions
+        # changing for shap implementation
+        if isinstance(x, tuple) and self.model.fusion_type != "graph":
+            x1, x2 = x
+            model_outputs = self.model(x1, x2)
+        else:
+            model_outputs = self.model(x)
+
+        if isinstance(model_outputs, list):
+            logits, *reconstructions = model_outputs
+            logits = logits.squeeze(dim=1)
+            return logits, reconstructions[0] if reconstructions else []
+        else:
+            logits = model_outputs.squeeze(dim=1)
+            return logits, []
 
     def get_model_outputs_and_loss(self, x, y, train=True):
         """
@@ -293,7 +320,9 @@ def get_model_outputs_and_loss(self, x, y, train=True):
         """
         logits, reconstructions = self.get_model_outputs(x)
 
-        end_output = self.output_activation_functions[self.model.prediction_task](logits)
+        end_output = self.output_activation_functions[self.model.prediction_task](
+            logits
+        )
 
         # if we're doing graph-based fusion and train/test doesn't work the same as normal
         if hasattr(self, "train_mask"):
@@ -309,8 +338,9 @@ def get_model_outputs_and_loss(self, x, y, train=True):
             loss = self.loss_functions[self.model.prediction_task](logits, y)
 
         if reconstructions != [] and self.model.custom_loss is not None:
+            # changing reconstructions[0] to just reconstructions after changing the model inputs from tuple to two tensors
             added_loss = self.model.custom_loss(
-                reconstructions[0], x[-1]
+                reconstructions, x[-1]
             )  # x[-1] bc img is always last
 
             loss += added_loss
@@ -349,7 +379,9 @@ def training_step(self, batch, batch_idx):
         )
 
         for metric_name, metric_func in self.metrics.items():
-            if (self.safe_squeeze(end_output).shape[0] == 1) or (self.safe_squeeze(logits).shape[0] == 1):
+            if (self.safe_squeeze(end_output).shape[0] == 1) or (
+                self.safe_squeeze(logits).shape[0] == 1
+            ):
                 # if it's a single value, we can't calculate a metric
                 pass
 
@@ -435,7 +467,9 @@ def on_validation_epoch_end(self):
         try:
            self.train_reals = torch.cat(self.batch_train_reals, dim=-1)
            self.train_preds = torch.cat(self.batch_train_preds, dim=-1)
-        except RuntimeError:  # if we're doing graph-based fusion and train/test doesn't work the same as normal
+        except (
+            RuntimeError
+        ):  # if we're doing graph-based fusion and train/test doesn't work the same as normal
            pass
 
        for metric_name, metric_func in self.metrics.items():
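Note on the get_model_outputs change: non-graph fusion models are now called with two separate tensors instead of one tuple (the comment says this is for the SHAP implementation), and the method normalises both output styles, a bare logits tensor or a [logits, reconstructions] list, into one (logits, reconstructions) pair. A minimal sketch of that unpacking logic; the PlainFusion and SubspaceFusion stand-ins here are hypothetical, not part of fusilli:

import torch
import torch.nn as nn

class PlainFusion(nn.Module):
    """Hypothetical model that returns a bare logits tensor."""
    def forward(self, x1, x2):
        return torch.randn(x1.shape[0], 1)

class SubspaceFusion(nn.Module):
    """Hypothetical model that returns [logits, reconstructions] for a custom loss."""
    def forward(self, x1, x2):
        return [torch.randn(x1.shape[0], 1), torch.randn_like(x2)]

def unpack_outputs(model_outputs):
    # Mirrors the new logic in BaseModel.get_model_outputs:
    # list output -> (logits, reconstructions); tensor output -> (logits, []).
    if isinstance(model_outputs, list):
        logits, *rest = model_outputs
        return logits.squeeze(dim=1), rest[0] if rest else []
    return model_outputs.squeeze(dim=1), []

x1, x2 = torch.randn(4, 10), torch.randn(4, 15)
for model in (PlainFusion(), SubspaceFusion()):
    logits, recon = unpack_outputs(model(x1, x2))
    print(logits.shape, "no reconstructions" if isinstance(recon, list) else recon.shape)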
45 changes: 23 additions & 22 deletions fusilli/fusionmodels/tabularfusion/activation.py
@@ -11,9 +11,9 @@
 
 class ActivationFusion(ParentFusionModel, nn.Module):
     """
-    Performs an element wise product of the feature maps of the two tabular modalities, 
+    Performs an element wise product of the feature maps of the two tabular modalities,
     tanh activation function and sigmoid activation function. Afterwards the the first tabular modality feature
-    map is concatenated with the fused feature map. 
+    map is concatenated with the fused feature map.
 
     Attributes
     ----------
@@ -55,7 +55,9 @@ def __init__(self, prediction_task, data_dims, multiclass_dimensions):
         multiclass_dimensions : int
             Number of classes in the multiclass classification problem.
         """
-        ParentFusionModel.__init__(self, prediction_task, data_dims, multiclass_dimensions)
+        ParentFusionModel.__init__(
+            self, prediction_task, data_dims, multiclass_dimensions
+        )
 
         self.prediction_task = prediction_task
 
@@ -94,7 +96,8 @@ def calc_fused_layers(self):
         mod2_output_dim = list(self.mod2_layers.values())[-1][0].out_features
         if mod1_output_dim != mod2_output_dim:
             raise UserWarning(
-                "The number of output features of mod1_layers and mod2_layers must be the same for Activation fusion. Please change the final layers in the modality layers to have the same number of output features as each other.")
+                "The number of output features of mod1_layers and mod2_layers must be the same for Activation fusion. Please change the final layers in the modality layers to have the same number of output features as each other."
+            )
 
         self.get_fused_dim()
         self.fused_layers, out_dim = check_model_validity.check_fused_layers(
@@ -104,50 +107,48 @@ def calc_fused_layers(self):
         # setting final prediction layers with final out features of fused layers
         self.set_final_pred_layers(out_dim)
 
-    def forward(self, x):
+    def forward(self, x1, x2):
         """
         Forward pass of the model.
 
         Parameters
         ----------
-        x : tuple
-            Tuple containing the input data.
+        x1 : torch.Tensor
+            Input tensor of the first modality.
+        x2 : torch.Tensor
+            Input tensor of the second modality.
 
         Returns
         -------
-        list
-            List containing the output of the model.
+        out : torch.Tensor
+            Fused prediction.
         """
 
         # ~~ Checks ~~
-        check_model_validity.check_model_input(x)
-
-        x_tab1 = x[0]
-        x_tab2 = x[1]
+        check_model_validity.check_model_input(x1)
+        check_model_validity.check_model_input(x2)
 
         for layer in self.mod1_layers.values():
-            x_tab1 = layer(x_tab1)
+            x1 = layer(x1)
 
         for layer in self.mod2_layers.values():
-            x_tab2 = layer(x_tab2)
+            x2 = layer(x2)
 
-        x_tab1 = torch.squeeze(x_tab1, 1)
-        x_tab2 = torch.squeeze(x_tab2, 1)
+        x1 = torch.squeeze(x1, 1)
+        x2 = torch.squeeze(x2, 1)
 
-        out_fuse = torch.mul(x_tab1, x_tab2)
+        out_fuse = torch.mul(x1, x2)
 
         out_fuse = torch.tanh(out_fuse)
         out_fuse = torch.sigmoid(out_fuse)
 
-        out_fuse = torch.cat((out_fuse, x_tab1), dim=1)
+        out_fuse = torch.cat((out_fuse, x1), dim=1)
 
         out_fuse = self.fused_layers(out_fuse)
 
         out = self.final_prediction(out_fuse)
 
-        return [
-            out,
-        ]
+        return out
 
 
 """
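With the signature change above, ActivationFusion is invoked with the two tabular tensors directly and returns a prediction tensor rather than a single-element list. A hedged usage sketch; the data_dims layout and dimensions are assumptions inferred from the constructor signature in this diff, not a verified fusilli recipe:

import torch
from fusilli.fusionmodels.tabularfusion.activation import ActivationFusion

model = ActivationFusion(
    prediction_task="binary",
    data_dims=[10, 15, None],  # assumed [tab1_dim, tab2_dim, img_dim] layout
    multiclass_dimensions=None,
)

x1 = torch.randn(8, 10)  # first tabular modality
x2 = torch.randn(8, 15)  # second tabular modality

# Previously the call was model((x1, x2)) and the result came back as [out];
# now the tensors are passed separately and the prediction is unwrapped.
out = model(x1, x2)
print(out.shape)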
42 changes: 24 additions & 18 deletions fusilli/fusionmodels/tabularfusion/attention_and_activation.py
@@ -13,9 +13,9 @@
 class AttentionAndSelfActivation(ParentFusionModel, nn.Module):
     """
     Applies an attention mechanism on the second tabular modality features and performs an element wise product of
-    the feature maps of the two tabular modalities, 
+    the feature maps of the two tabular modalities,
     tanh activation function and sigmoid activation function. Afterwards the the first tabular modality feature
-    map is concatenated with the fused feature map. 
+    map is concatenated with the fused feature map.
 
     Attributes
     ----------
@@ -59,7 +59,9 @@ def __init__(self, prediction_task, data_dims, multiclass_dimensions):
         multiclass_dimensions : int
             Number of classes in the multiclass classification task.
         """
-        ParentFusionModel.__init__(self, prediction_task, data_dims, multiclass_dimensions)
+        ParentFusionModel.__init__(
+            self, prediction_task, data_dims, multiclass_dimensions
+        )
 
         self.prediction_task = prediction_task
 
@@ -100,7 +102,8 @@ def calc_fused_layers(self):
         mod2_output_dim = list(self.mod2_layers.values())[-1][0].out_features
         if mod1_output_dim != mod2_output_dim:
             raise UserWarning(
-                "The number of output features of mod1_layers and mod2_layers must be the same for ActivationandSelfAttention. Please change the final layers in the modality layers to have the same number of output features as each other.")
+                "The number of output features of mod1_layers and mod2_layers must be the same for ActivationandSelfAttention. Please change the final layers in the modality layers to have the same number of output features as each other."
+            )
 
         self.get_fused_dim()
         self.fused_layers, out_dim = check_model_validity.check_fused_layers(
@@ -110,32 +113,36 @@ def calc_fused_layers(self):
         # setting final prediction layers with final out features of fused layers
         self.set_final_pred_layers(out_dim)
 
-    def forward(self, x):
+    def forward(self, x1, x2):
         """
         Forward pass of the model.
 
         Parameters
         ----------
-        x : tuple
-            Tuple containing the input data.
+        x1 : torch.Tensor
+            Input tensor of the first modality.
+        x2 : torch.Tensor
+            Input tensor of the second modality.
 
         Returns
         -------
-        list
-            List containing the output of the model.
+        torch.Tensor
+            Output tensor.
         """
 
         # ~~ Checks ~~
-        check_model_validity.check_model_input(x)
+        check_model_validity.check_model_input(x1)
+        check_model_validity.check_model_input(x2)
 
-        x_tab1 = x[0]
-        x_tab2 = x[1]
+        x_tab1 = x1
+        x_tab2 = x2
 
         num_channels = x_tab2.size(1)
 
         # Channel attention
-        channel_attention = ChannelAttentionModule(num_features=num_channels,
-                                                   reduction_ratio=self.attention_reduction_ratio)
+        channel_attention = ChannelAttentionModule(
+            num_features=num_channels, reduction_ratio=self.attention_reduction_ratio
+        )
         x_tab2 = channel_attention(x_tab2)
 
         for layer in self.mod1_layers.values():
@@ -158,9 +165,7 @@ def forward(self, x):
 
         out = self.final_prediction(out_fuse)
 
-        return [
-            out,
-        ]
+        return out
 
 
 class ChannelAttentionModule(nn.Module):
@@ -193,7 +198,8 @@ def __init__(self, num_features, reduction_ratio=16):
 
         if num_features // reduction_ratio < 1:
             raise UserWarning(
-                "first tabular modality dimensions // attention_reduction_ratio < 1. This will cause an error in the model.")
+                "first tabular modality dimensions // attention_reduction_ratio < 1. This will cause an error in the model."
+            )
 
         self.fc1 = nn.Linear(num_features, num_features // reduction_ratio, bias=False)
         self.relu = nn.ReLU()
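The new UserWarning in ChannelAttentionModule guards the bottleneck width: nn.Linear(num_features, num_features // reduction_ratio) has zero output features whenever num_features < reduction_ratio. A standalone sketch of the squeeze-and-excitation-style bottleneck this module implements; the fc2 expansion and sigmoid gating shown here are assumptions, since this diff only shows fc1 and relu:

import torch
import torch.nn as nn

num_features, reduction_ratio = 32, 16

fc1 = nn.Linear(num_features, num_features // reduction_ratio, bias=False)  # squeeze: 32 -> 2
relu = nn.ReLU()
fc2 = nn.Linear(num_features // reduction_ratio, num_features, bias=False)  # expand: 2 -> 32 (assumed)

x = torch.randn(8, num_features)
weights = torch.sigmoid(fc2(relu(fc1(x))))  # per-feature attention weights in (0, 1)
x_attended = x * weights

# The guarded failure mode: with num_features=10, 10 // 16 == 0, so fc1 would be
# nn.Linear(10, 0), which is why __init__ now raises the UserWarning.
print(x_attended.shape)  # torch.Size([8, 32])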
[Diffs for the remaining 25 changed files are not shown here.]
