From 77f80b14063b498ccc3d47ab906e825a736b05ac Mon Sep 17 00:00:00 2001 From: EmmaRenauld <emmanuelle.renauld@usherbrooke.ca> Date: Thu, 23 May 2024 11:12:01 -0400 Subject: [PATCH 1/2] Updates from new torch, new scilpy --- .../scil_score_ismrm_Renauld2023.sh | 2 +- dwi_ml/models/projects/transformer_models.py | 31 +- .../models/projects/transformer_sublayers.py | 304 ++++++++++++++++ .../models/utils/transformers_from_torch.py | 337 ++++++++---------- requirements.txt | 7 - 5 files changed, 469 insertions(+), 212 deletions(-) create mode 100644 dwi_ml/models/projects/transformer_sublayers.py diff --git a/bash_utilities/scil_score_ismrm_Renauld2023.sh b/bash_utilities/scil_score_ismrm_Renauld2023.sh index cd19cffc..15b6ba14 100644 --- a/bash_utilities/scil_score_ismrm_Renauld2023.sh +++ b/bash_utilities/scil_score_ismrm_Renauld2023.sh @@ -31,7 +31,7 @@ fi echo '------------- SEGMENTATION ------------' -scil_score_tractogram.py $tractogram $config_file_segmentation $out_dir --no_empty \ +scil_tractogram_segment_and_score.py $tractogram $config_file_segmentation $out_dir --no_empty \ --gt_dir $scoring_data --reference $ref --json_prefix tmp_ --no_bbox_check; echo '------------- Merging CC sub-bundles ------------' diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py index 1ecc2664..e10609d4 100644 --- a/dwi_ml/models/projects/transformer_models.py +++ b/dwi_ml/models/projects/transformer_models.py @@ -574,8 +574,17 @@ def __init__(self, **kw): self.d_model, self.nheads, dim_feedforward=self.ffnn_hidden_size, dropout=self.dropout_rate, activation=self.activation, batch_first=True, norm_first=self.norm_first) + + # Receiving weird warning: enable_nested_tensor is True, + # but self.use_nested_tensor is False because encoder_layer.norm_first + # was True. + enable_nested = False if self.norm_first else True + + # Note about norm: this is a final normalization step. Not linked to + # the normalization decided with self.norm_first. self.modified_torch_transformer = ModifiedTransformerEncoder( - main_layer_encoder, self.n_layers_e, norm=None) + main_layer_encoder, self.n_layers_e, norm=None, + enable_nested_tensor=enable_nested) @property def d_model(self): @@ -613,7 +622,7 @@ def _run_main_layer_forward(self, inputs, masks, return_weights): # mask_future, mask_padding = masks outputs, sa_weights = self.modified_torch_transformer( src=inputs, mask=masks[0], src_key_padding_mask=masks[1], - return_weights=return_weights) + is_causal=True, return_weights=return_weights) return outputs, (sa_weights,) @@ -844,8 +853,17 @@ def __init__(self, input_embedded_size, n_layers_d: int, **kw): dim_feedforward=self.ffnn_hidden_size, dropout=self.dropout_rate, activation=self.activation, batch_first=True, norm_first=self.norm_first) - encoder = ModifiedTransformerEncoder(encoder_layer, self.n_layers_e, - norm=None) + + # Receiving weird warning: enable_nested_tensor is True, + # but self.use_nested_tensor is False because encoder_layer.norm_first + # was True. + enable_nested = False if self.norm_first else True + + # Note about norm: this is a final normalization step. Not linked to + # the normalization decided with self.norm_first. 
+ encoder = ModifiedTransformerEncoder( + encoder_layer, self.n_layers_e, norm=None, + enable_nested_tensor=enable_nested) # Decoder decoder_layer = ModifiedTransformerDecoderLayer( @@ -908,7 +926,8 @@ def _run_main_layer_forward(self, data, masks, return_weights): src=data[0], tgt=data[1], src_mask=masks[0], tgt_mask=masks[0], memory_mask=masks[0], src_key_padding_mask=masks[1], tgt_key_padding_mask=masks[1], - memory_key_padding_mask=masks[1], + memory_key_padding_mask=masks[1], src_is_causal=True, + tgt_is_causal=True, memory_is_causal=True, return_weights=return_weights) return outputs, (sa_weights_encoder, sa_weights_decoder, mha_weights) @@ -989,7 +1008,7 @@ def _run_main_layer_forward(self, concat_s_t, masks, return_weights): # mask_future, mask_padding = masks outputs, sa_weights = self.modified_torch_transformer( src=concat_s_t, mask=masks[0], src_key_padding_mask=masks[1], - return_weights=return_weights) + is_causal=True, return_weights=return_weights) return outputs, (sa_weights,) diff --git a/dwi_ml/models/projects/transformer_sublayers.py b/dwi_ml/models/projects/transformer_sublayers.py new file mode 100644 index 00000000..e2c65731 --- /dev/null +++ b/dwi_ml/models/projects/transformer_sublayers.py @@ -0,0 +1,304 @@ +""" +Child classes of Torch Transformers. Changes are: + +- EncoderLayer: Idem +- DecoderLayer: Idem + +""" +import logging +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import (TransformerDecoderLayer, TransformerEncoderLayer, + MultiheadAttention, Parameter) + +logger = logging.getLogger('model_logger') + + +def do_not_share_linear_weights(attn: MultiheadAttention, d_model): + """ + I added a request for this parameter to be accessible. + https://github.com/pytorch/pytorch/issues/92990 + + Copied from MultiheadAttention's init method + """ + + factory_kwargs = {'device': None, 'dtype': None} + + # Overriding some parameters in the self attention. + # Ugly but.... Torch does not have a parameter to NOT share linear + # weights. In their code, they only do NOT share weights when dimensions + # are not the same. This is not our case. This is saved in their + # parameter _qkv_same_embed_dim. By changing this, we change their + # forward call to the MultiHeadAttention in self.self_attn. + attn._qkv_same_embed_dim = False + attn.q_proj_weight = Parameter( + torch.empty((d_model, d_model), **factory_kwargs)) + attn.k_proj_weight = Parameter( + torch.empty((d_model, d_model), **factory_kwargs)) + attn.v_proj_weight = Parameter( + torch.empty((d_model, d_model), **factory_kwargs)) + attn.register_parameter('in_proj_weight', None) + attn._reset_parameters() + + +class ModifiedTransformerEncoderLayer(TransformerEncoderLayer): + def __init__(self, d_model, nhead, **kw): + super().__init__(d_model, nhead, **kw) + + do_not_share_linear_weights(self.self_attn, d_model) + + def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: bool = False, + # New args: + return_weights=False, average_heads=False): + """ + Copy-pasted from torch. Now returns weights. + """ + src_key_padding_mask = F._canonical_mask( + mask=src_key_padding_mask, + mask_name="src_key_padding_mask", + other_type=F._none_or_dtype(src_mask), + other_name="src_mask", + target_type=src.dtype + ) + + src_mask = F._canonical_mask( + mask=src_mask, + mask_name="src_mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + + # see Fig.
1 of https://arxiv.org/pdf/2002.04745v1.pdf + why_not_sparsity_fast_path = '' + if not src.dim() == 3: + why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}" + elif self.training: + why_not_sparsity_fast_path = "training is enabled" + elif not self.self_attn.batch_first: + why_not_sparsity_fast_path = "self_attn.batch_first was not True" + elif not self.self_attn._qkv_same_embed_dim: + why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True" + elif not self.activation_relu_or_gelu: + why_not_sparsity_fast_path = "activation_relu_or_gelu was not True" + elif not (self.norm1.eps == self.norm2.eps): + why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps" + elif src.is_nested and ( + src_key_padding_mask is not None or src_mask is not None): + why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input" + elif self.self_attn.num_heads % 2 == 1: + why_not_sparsity_fast_path = "num_head is odd" + elif torch.is_autocast_enabled(): + why_not_sparsity_fast_path = "autocast is enabled" + if not why_not_sparsity_fast_path: + tensor_args = ( + src, + self.self_attn.in_proj_weight, + self.self_attn.in_proj_bias, + self.self_attn.out_proj.weight, + self.self_attn.out_proj.bias, + self.norm1.weight, + self.norm1.bias, + self.norm2.weight, + self.norm2.bias, + self.linear1.weight, + self.linear1.bias, + self.linear2.weight, + self.linear2.bias, + ) + + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + _supported_device_type = ["cpu", "cuda", + torch.utils.backend_registration._privateuse1_backend_name] + if torch.overrides.has_torch_function(tensor_args): + why_not_sparsity_fast_path = "some Tensor argument has_torch_function" + elif not all((x.device.type in _supported_device_type) for x in + tensor_args): + why_not_sparsity_fast_path = ( + "some Tensor argument's device is neither one of " + f"{_supported_device_type}") + elif torch.is_grad_enabled() and any( + x.requires_grad for x in tensor_args): + why_not_sparsity_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad") + + if not why_not_sparsity_fast_path: + merged_mask, mask_type = self.self_attn.merge_masks(src_mask, + src_key_padding_mask, + src) + # MODIFIED: + if return_weights: + raise NotImplementedError( + "Did not expect to reach here. Not ready to return " + "weights. 
Please contact dwi_ml developers") + return torch._transformer_encoder_layer_fwd( + src, + self.self_attn.embed_dim, + self.self_attn.num_heads, + self.self_attn.in_proj_weight, + self.self_attn.in_proj_bias, + self.self_attn.out_proj.weight, + self.self_attn.out_proj.bias, + self.activation_relu_or_gelu == 2, + self.norm_first, + self.norm1.eps, + self.norm1.weight, + self.norm1.bias, + self.norm2.weight, + self.norm2.bias, + self.linear1.weight, + self.linear1.bias, + self.linear2.weight, + self.linear2.bias, + merged_mask, + mask_type, + ) + + x = src + if self.norm_first: + # Norm, SA, Add, Norm, FF, Add + sa, sa_weights = self._sa_block( + self.norm1(x), src_mask, src_key_padding_mask, + is_causal=is_causal, + # New args: + return_weights=return_weights, average_heads=average_heads) + x = x + sa + x = x + self._ff_block(self.norm2(x)) + else: + # SA, Add, Norm, FF, Add, Norm + sa, sa_weights = self._sa_block( + x, src_mask, src_key_padding_mask, is_causal=is_causal, + # New args: + return_weights=return_weights, average_heads=average_heads) + x = self.norm1(x + sa) + x = self.norm2(x + self._ff_block(x)) + + return x, sa_weights + + # self-attention block + def _sa_block(self, x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + is_causal: bool = False, + # New args: + return_weights=False, average_heads=False): + x, weights = self.self_attn( + x, x, x, + attn_mask=attn_mask, key_padding_mask=key_padding_mask, + is_causal=is_causal, + # Modified args: + need_weights=return_weights, average_attn_weights=average_heads) + + return self.dropout1(x), weights + + +class ModifiedTransformerDecoderLayer(TransformerDecoderLayer): + """ + Decoder Layer, in the case where we do not have a start of sequence (SOS) + token, and our mask contains only -inf for the first position. Output of + self-attention becomes nan after the softmax step. Setting to 0. + + Also, now returning attention weights. + """ + def __init__(self, d_model, nhead, **kw): + super().__init__(d_model, nhead, **kw) + + do_not_share_linear_weights(self.self_attn, d_model) + do_not_share_linear_weights(self.multihead_attn, d_model) + + def forward(self, tgt: Tensor, memory: Tensor, + tgt_mask: Tensor = None, memory_mask: Tensor = None, + tgt_key_padding_mask: Tensor = None, + memory_key_padding_mask: Tensor = None, + tgt_is_causal: bool = False, + memory_is_causal: bool = False, + # New args: + return_weights=False, average_heads=False): + """ + Copy-pasted from torch. Now returns weights + converts nan to 0. + Weights are None if return_weights is False. + """ + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + x = tgt + if self.norm_first: + # Norm, SA, Add, Norm, MHA, Add, Norm, FF, Add + sa, sa_weights = self._sa_block( + self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal, + # New args: + return_weights=return_weights, average_heads=average_heads) + x = x + sa + + mha, mha_weights = self._mha_block( + self.norm2(x), memory, memory_mask, memory_key_padding_mask, + memory_is_causal, + # New args: + return_weights=return_weights, average_heads=average_heads) + x = x + mha + x = x + self._ff_block(self.norm3(x)) + else: + # SA, Add, Norm, MHA, Add, Norm, FF, Add, Norm.
+ sa, sa_weights = self._sa_block( + x, tgt_mask, tgt_key_padding_mask, tgt_is_causal, + # New args: + return_weights=return_weights, average_heads=average_heads) + x = self.norm1(x + sa) + + mha, mha_weights = self._mha_block( + x, memory, memory_mask, memory_key_padding_mask, + memory_is_causal, + # New args: + return_weights=return_weights, average_heads=average_heads) + x = self.norm2(x + mha) + x = self.norm3(x + self._ff_block(x)) + + return x, mha_weights, sa_weights + + # self-attention block + def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + is_causal: bool = False, + # New args: + return_weights=False, average_heads=False): + """ + Copy-pasted from torch. Now returns weights. + """ + x, weights = self.self_attn( + x, x, x, + attn_mask=attn_mask, key_padding_mask=key_padding_mask, + is_causal=is_causal, + # Modified args: + need_weights=return_weights, average_attn_weights=average_heads) + + return self.dropout1(x), weights + + # multihead attention block + def _mha_block(self, x: Tensor, mem: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + is_causal: bool = False, + # New args: + return_weights=False, average_heads=False): + """ + Copy-pasted from torch. Can now use need_weights=True. + """ + x = self.multihead_attn( + x, mem, mem, + attn_mask=attn_mask, key_padding_mask=key_padding_mask, + is_causal=is_causal, + # Modified args: + need_weights=return_weights, average_attn_weights=average_heads) + + if return_weights: + x, weights = x + else: + x, weights = x[0], None + + return self.dropout2(x), weights diff --git a/dwi_ml/models/utils/transformers_from_torch.py b/dwi_ml/models/utils/transformers_from_torch.py index a9972c6b..7710b871 100644 --- a/dwi_ml/models/utils/transformers_from_torch.py +++ b/dwi_ml/models/utils/transformers_from_torch.py @@ -6,82 +6,129 @@ to decide if we want to share the linear weights for Q, K, V. - Encoder: Idem - Decoder: Idem -- EncoderLayer: Idem -- DecoderLayer: Idem """ import logging from typing import Optional import torch +import torch.nn.functional as F from torch import Tensor -from torch.nn import (Transformer, - TransformerDecoderLayer, TransformerDecoder, - TransformerEncoderLayer, TransformerEncoder, - MultiheadAttention, Parameter) +from torch.nn import Transformer, TransformerDecoder, TransformerEncoder +from torch.nn.modules.transformer import _get_seq_len, _detect_is_causal_mask from dwi_ml.experiment_utils.memory import log_gpu_memory_usage +from dwi_ml.models.projects.transformer_sublayers import \ + ModifiedTransformerDecoderLayer, ModifiedTransformerEncoderLayer logger = logging.getLogger('model_logger') -class ModifiedTransformer(Transformer): - def __init__(self, *args, **kw): - super().__init__(*args, **kw) - - def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor = None, - tgt_mask: Tensor = None, memory_mask: Tensor = None, - src_key_padding_mask: Tensor = None, - tgt_key_padding_mask: Tensor = None, - memory_key_padding_mask: Tensor = None, - return_weights=False, average_heads=False): - """ - Copy-pasted from torch. Now returns weights.
- """ - logger.debug("Entering main Transformer's forward.") - log_gpu_memory_usage(logger) - memory, sa_weights_encoder = self.encoder( - src, mask=src_mask, src_key_padding_mask=src_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - - output, sa_weights_decoder, mha_weights = self.decoder( - tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - - return output, sa_weights_encoder, sa_weights_decoder, mha_weights - - class ModifiedTransformerEncoder(TransformerEncoder): - def __init__(self, *args, **kw): - super().__init__(*args, **kw) + def __init__(self, encoder_layer, *args, **kw): + if not isinstance(encoder_layer, ModifiedTransformerEncoderLayer): + raise ValueError("Encoder layer should be of type {}. Got {}" + .format(ModifiedTransformerEncoderLayer.__name__, + type(encoder_layer))) + super().__init__(encoder_layer, *args, **kw) def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + is_causal: Optional[bool] = None, + # New args: return_weights=False, average_heads=False): """ Copy-pasted from torch. Now returns weights. - Erased all the fast-path check: it is not used anyway if it is - training, and if we supply both src_key_padding_mask and mask, which is - our case. - Layers must be TransformerEncoderLayerGetWeights layers. """ - if src_key_padding_mask is not None: - _skpm_dtype = src_key_padding_mask.dtype - if _skpm_dtype != torch.bool and not \ - torch.is_floating_point(src_key_padding_mask): - raise AssertionError("only bool and floating types of " - "key_padding_mask are supported") + src_key_padding_mask = F._canonical_mask( + mask=src_key_padding_mask, + mask_name="src_key_padding_mask", + other_type=F._none_or_dtype(mask), + other_name="mask", + target_type=src.dtype + ) + + mask = F._canonical_mask( + mask=mask, + mask_name="mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + output = src + convert_to_nested = False + first_layer = self.layers[0] src_key_padding_mask_for_layers = src_key_padding_mask + why_not_sparsity_fast_path = '' + str_first_layer = "self.layers[0]" + batch_first = first_layer.self_attn.batch_first + if not hasattr(self, "use_nested_tensor"): + why_not_sparsity_fast_path = "use_nested_tensor attribute not present" + elif not self.use_nested_tensor: + why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True" + elif first_layer.training: + why_not_sparsity_fast_path = f"{str_first_layer} was in training mode" + elif not src.dim() == 3: + why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}" + elif src_key_padding_mask is None: + why_not_sparsity_fast_path = "src_key_padding_mask was None" + elif (((not hasattr(self, "mask_check")) or self.mask_check) + and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())): + why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned" + elif output.is_nested: + why_not_sparsity_fast_path = "NestedTensor input is not supported" + elif mask is not None: + why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied" + elif torch.is_autocast_enabled(): + why_not_sparsity_fast_path = "autocast is enabled" + + if not why_not_sparsity_fast_path: + tensor_args = ( + src, + 
first_layer.self_attn.in_proj_weight, + first_layer.self_attn.in_proj_bias, + first_layer.self_attn.out_proj.weight, + first_layer.self_attn.out_proj.bias, + first_layer.norm1.weight, + first_layer.norm1.bias, + first_layer.norm2.weight, + first_layer.norm2.bias, + first_layer.linear1.weight, + first_layer.linear1.bias, + first_layer.linear2.weight, + first_layer.linear2.bias, + ) + _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name] + if torch.overrides.has_torch_function(tensor_args): + why_not_sparsity_fast_path = "some Tensor argument has_torch_function" + elif src.device.type not in _supported_device_type: + why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}" + elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args): + why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad") + + if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None): + convert_to_nested = True + output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False) + src_key_padding_mask_for_layers = None + + seq_len = _get_seq_len(src, batch_first) + is_causal = _detect_is_causal_mask(mask, is_causal, seq_len) + + # THIS IS THE MODIFIED PART sa_weights = [None] * len(self.layers) - for mod, i in zip(self.layers, range(len(self.layers))): output, sa_weights[i] = mod( - output, src_mask=mask, + output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers, + # New args: return_weights=return_weights, average_heads=average_heads) + # END OF MODIFIED PART + + if convert_to_nested: + output = output.to_padded_tensor(0., src.size()) if self.norm is not None: output = self.norm(output) @@ -90,30 +137,45 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None, class ModifiedTransformerDecoder(TransformerDecoder): - def __init__(self, *args, **kw): - super().__init__(*args, **kw) + + def __init__(self, decoder_layer, *args, **kw): + if not isinstance(decoder_layer, ModifiedTransformerDecoderLayer): + raise ValueError("Decoder layer should be of type {}. Got {}" + .format(ModifiedTransformerDecoderLayer.__name__, + type(decoder_layer))) + super().__init__(decoder_layer, *args, **kw) def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, + tgt_is_causal: Optional[bool] = None, + memory_is_causal: bool = False, + # New args: return_weights=False, average_heads=False): """ Copy-pasted from torch. Now returns weights. - Layers must be TransformerDecoderLayerGetWeightsNoSOS layers.
""" output = tgt + + seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first) + tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len) + + # THIS IS THE MODIFIED PART mha_weights = [None] * len(self.layers) sa_weights = [None] * len(self.layers) - for mod, i in zip(self.layers, range(len(self.layers))): output, mha_weights[i], sa_weights[i] = \ mod(output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, + tgt_is_causal=tgt_is_causal, + memory_is_causal=memory_is_causal, + # New args: return_weights=return_weights, average_heads=average_heads) + # END OF MODIFIED PART if self.norm is not None: output = self.norm(output) @@ -121,160 +183,39 @@ def forward(self, tgt: Tensor, memory: Tensor, return output, sa_weights, mha_weights -def do_not_share_linear_weights(attn: MultiheadAttention, d_model): - """ - I added a request for this parameter to be accessible. - https://github.com/pytorch/pytorch/issues/92990 - """ - factory_kwargs = {'device': None, 'dtype': None} - - # Overriding some parameters in the self attention. - # Ugly but.... Torch does not have a parameter to NOT share linear - # weights. In their code, their only NOT share weights when dimensions - # are not the same. This is not our case. This is saved in their - # parameter _qkv_same_embed_dim. By changing this, we change their - # forward call to the MultiHeadAttention in self.self_attn. - attn._qkv_same_embed_dim = False - attn.q_proj_weight = Parameter( - torch.empty((d_model, d_model), **factory_kwargs)) - attn.k_proj_weight = Parameter( - torch.empty((d_model, d_model), **factory_kwargs)) - attn.v_proj_weight = Parameter( - torch.empty((d_model, d_model), **factory_kwargs)) - attn.register_parameter('in_proj_weight', None) - attn._reset_parameters() - - -class ModifiedTransformerEncoderLayer(TransformerEncoderLayer): - def __init__(self, d_model, nhead, **kw): - super().__init__(d_model, nhead, **kw) - - do_not_share_linear_weights(self.self_attn, d_model) - - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - return_weights=False, average_heads=False): - """ - Copy-pasted from torch. Now returns weights. - Erased all the fast-track checks. - """ - x = src - if self.norm_first: - # Norm, SA, Add, Norm, FF, Add - sa, sa_weights = self._sa_block( - self.norm1(x), src_mask, src_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = x + sa - x = x + self._ff_block(self.norm2(x)) - else: - # SA, Add, Norm, FF, Add, Norm - sa, sa_weights = self._sa_block( - x, src_mask, src_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = self.norm1(x + sa) - x = self.norm2(x + self._ff_block(x)) - - return x, sa_weights - - # self-attention block - def _sa_block(self, x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], - return_weights=False, average_heads=False): - output = self.self_attn( - x, x, x, - attn_mask=attn_mask, key_padding_mask=key_padding_mask, - need_weights=return_weights, average_attn_weights=average_heads) - x, weights = output # if return_weights is False, weights is None - - return self.dropout1(x), weights - - -class ModifiedTransformerDecoderLayer(TransformerDecoderLayer): - """ - Decoder Layer, in the case where we do not have a start of sequence (SOS) - token, and our mask contains only -inf for the first position. 
Output of - self-attention becomes nan after the softmax step. Setting to 0. - - Also, now returning attention weights. - """ - def __init__(self, d_model, nhead, **kw): - super().__init__(d_model, nhead, **kw) +class ModifiedTransformer(Transformer): + encoder: ModifiedTransformerEncoder + decoder: ModifiedTransformerDecoder - do_not_share_linear_weights(self.self_attn, d_model) - do_not_share_linear_weights(self.multihead_attn, d_model) + def __init__(self, *args, **kw): + super().__init__(*args, **kw) - def forward(self, tgt: Tensor, memory: Tensor, + def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor = None, tgt_mask: Tensor = None, memory_mask: Tensor = None, + src_key_padding_mask: Tensor = None, tgt_key_padding_mask: Tensor = None, memory_key_padding_mask: Tensor = None, + src_is_causal: bool = None, tgt_is_causal: bool = None, + memory_is_causal: bool = False, + # New args: return_weights=False, average_heads=False): """ - Copy-pasted from torch. Now returns weights + converts nan to 0. - Weights are None if return_weights is False. - """ - # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf - x = tgt - if self.norm_first: - # Norm, SA, Add, Norm, MHA, Add, Norm, FF, Add - sa, sa_weights = self._sa_block( - self.norm1(x), tgt_mask, tgt_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = x + sa - - mha, mha_weights = self._mha_block( - self.norm2(x), memory, memory_mask, memory_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = x + mha - x = x + self._ff_block(self.norm3(x)) - else: - # SA, Add, Norm, MHA, Add, Norm, FF, Add, Norm. - sa, sa_weights = self._sa_block( - x, tgt_mask, tgt_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = self.norm1(x + sa) - - mha, mha_weights = self._mha_block( - x, memory, memory_mask, memory_key_padding_mask, - return_weights=return_weights, average_heads=average_heads) - x = self.norm2(x + mha) - x = self.norm3(x + self._ff_block(x)) - - return x, mha_weights, sa_weights - - # self-attention block - def _sa_block(self, x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], - return_weights=False, average_heads=False): - """ Copy-pasted from torch. Now returns weights. """ - output = self.self_attn( - x, x, x, - attn_mask=attn_mask, key_padding_mask=key_padding_mask, - need_weights=return_weights, average_attn_weights=average_heads) - - x, weights = output # If not return_weights, weights is None. - - return self.dropout1(x), weights - - # multihead attention block - def _mha_block(self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], - return_weights=False, average_heads=False): - """ - Copy-pasted from torch. Can now use need_weight = True. 
- """ - output = self.multihead_attn( - x, mem, mem, - attn_mask=attn_mask, key_padding_mask=key_padding_mask, - need_weights=return_weights, average_attn_weights=average_heads) + logger.debug("Entering main Transformer's forward.") + log_gpu_memory_usage(logger) + memory, sa_weights_encoder = self.encoder( + src, mask=src_mask, src_key_padding_mask=src_key_padding_mask, + is_causal=src_is_causal, + # New args: + return_weights=return_weights, average_heads=average_heads) - if return_weights: - x, weights = output - else: - x, weights = output, None + output, sa_weights_decoder, mha_weights = self.decoder( + tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal, + # New args: + return_weights=return_weights, average_heads=average_heads) - return self.dropout2(x[0]), weights + return output, sa_weights_encoder, sa_weights_decoder, mha_weights diff --git a/requirements.txt b/requirements.txt index d492885c..7739110b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,10 +34,3 @@ nibabel==5.2.* numpy==1.23.* scipy==1.9.* scikit-image==0.22.* - - -# --------------- Notes to developers -# If we upgrade torch, verify if code copied in -# models.projects.transformers_from_torch has changed. -# (current code copied from torch 1.13.1) -# ---------- From c809fbbe835e7d9ef31bf682f1c19f8aa4f7b21c Mon Sep 17 00:00:00 2001 From: EmmaRenauld <emmanuelle.renauld@usherbrooke.ca> Date: Thu, 23 May 2024 11:32:43 -0400 Subject: [PATCH 2/2] Use scilpy's add_verbose_arg --- dwi_ml/io_utils.py | 9 --------- dwi_ml/testing/projects/tt_visu_argparser.py | 6 +++--- dwi_ml/testing/visu_loss_utils.py | 4 ++-- scripts_python/dwiml_create_hdf5_dataset.py | 5 ++--- scripts_python/dwiml_visualize_logs.py | 4 ++-- scripts_python/dwiml_visualize_logs_correlation.py | 4 +--- .../l2t_resume_training_from_checkpoint.py | 3 ++- scripts_python/l2t_track_from_model.py | 4 +--- scripts_python/l2t_train_model.py | 5 +++-- scripts_python/l2t_update_deprecated_exp.py | 3 +-- .../tt_resume_training_from_checkpoint.py | 4 +++- scripts_python/tt_track_from_model.py | 4 ++-- scripts_python/tt_train_model.py | 14 ++++++++------ 13 files changed, 30 insertions(+), 39 deletions(-) diff --git a/dwi_ml/io_utils.py b/dwi_ml/io_utils.py index 4f969b15..b16e5baf 100644 --- a/dwi_ml/io_utils.py +++ b/dwi_ml/io_utils.py @@ -5,15 +5,6 @@ from scilpy.io.utils import add_processes_arg -def add_verbose_arg(p): - # Can eventually become scilpy.io.utils.add_verbose_arg - p.add_argument('-v', default="WARNING", const='INFO', nargs='?', - choices=['DEBUG', 'INFO', 'WARNING'], dest='verbose', - help='Produces verbose output depending on ' - 'the provided level. 
\nDefault level is warning, ' - 'default when using -v is info.') - - def add_resample_or_compress_arg(p: ArgumentParser): g = p.add_mutually_exclusive_group() g.add_argument( diff --git a/dwi_ml/testing/projects/tt_visu_argparser.py b/dwi_ml/testing/projects/tt_visu_argparser.py index cde26c57..9aadbac4 100644 --- a/dwi_ml/testing/projects/tt_visu_argparser.py +++ b/dwi_ml/testing/projects/tt_visu_argparser.py @@ -51,10 +51,10 @@ """ import argparse -from scilpy.io.utils import (add_overwrite_arg, add_reference_arg) +from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, + add_verbose_arg) -from dwi_ml.io_utils import (add_arg_existing_experiment_path, - add_verbose_arg, add_memory_args) +from dwi_ml.io_utils import add_arg_existing_experiment_path, add_memory_args from dwi_ml.testing.utils import add_args_testing_subj_hdf5 diff --git a/dwi_ml/testing/visu_loss_utils.py b/dwi_ml/testing/visu_loss_utils.py index ef81a594..9db78d52 100644 --- a/dwi_ml/testing/visu_loss_utils.py +++ b/dwi_ml/testing/visu_loss_utils.py @@ -4,11 +4,11 @@ import os.path from argparse import ArgumentParser -from scilpy.io.utils import (add_overwrite_arg, +from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, add_reference_arg, ranged_type) -from dwi_ml.io_utils import add_memory_args, add_verbose_arg +from dwi_ml.io_utils import add_memory_args def prepare_args_visu_loss(p: ArgumentParser): diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py index c36a14e7..7266cd15 100644 --- a/scripts_python/dwiml_create_hdf5_dataset.py +++ b/scripts_python/dwiml_create_hdf5_dataset.py @@ -25,8 +25,8 @@ import shutil from pathlib import Path -from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, - assert_outputs_exist) +from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, + assert_inputs_exist, assert_outputs_exist) from dipy.io.stateful_tractogram import set_sft_logger_level @@ -34,7 +34,6 @@ from dwi_ml.data.hdf5.utils import ( add_hdf5_creation_args, add_streamline_processing_args) from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_verbose_arg def _initialize_intermediate_subdir(hdf5_file, save_intermediate): diff --git a/scripts_python/dwiml_visualize_logs.py b/scripts_python/dwiml_visualize_logs.py index f218e4a0..e8807748 100644 --- a/scripts_python/dwiml_visualize_logs.py +++ b/scripts_python/dwiml_visualize_logs.py @@ -36,9 +36,9 @@ import matplotlib.pyplot as plt import numpy as np -from scilpy.io.utils import assert_outputs_exist, add_overwrite_arg +from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, + assert_outputs_exist) -from dwi_ml.io_utils import add_verbose_arg from dwi_ml.viz.logs_plots import visualize_logs diff --git a/scripts_python/dwiml_visualize_logs_correlation.py b/scripts_python/dwiml_visualize_logs_correlation.py index 78d92af6..b1ecfee2 100644 --- a/scripts_python/dwiml_visualize_logs_correlation.py +++ b/scripts_python/dwiml_visualize_logs_correlation.py @@ -13,9 +13,7 @@ import matplotlib.pyplot as plt import numpy as np -from scilpy.io.utils import add_overwrite_arg - -from dwi_ml.io_utils import add_verbose_arg +from scilpy.io.utils import add_overwrite_arg, add_verbose_arg def _build_arg_parser(): diff --git a/scripts_python/l2t_resume_training_from_checkpoint.py b/scripts_python/l2t_resume_training_from_checkpoint.py index 21573424..15739413 100644 --- a/scripts_python/l2t_resume_training_from_checkpoint.py 
+++ b/scripts_python/l2t_resume_training_from_checkpoint.py @@ -9,9 +9,10 @@ # Importing now to solve issues later. import comet_ml +from scilpy.io.utils import add_verbose_arg + from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_verbose_arg from dwi_ml.models.projects.learn2track_model import Learn2TrackModel from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler diff --git a/scripts_python/l2t_track_from_model.py b/scripts_python/l2t_track_from_model.py index eb0891de..92f65831 100644 --- a/scripts_python/l2t_track_from_model.py +++ b/scripts_python/l2t_track_from_model.py @@ -12,9 +12,8 @@ from dipy.io.utils import is_header_compatible import h5py import nibabel as nib -import torch -from scilpy.io.utils import (add_sphere_arg, +from scilpy.io.utils import (add_sphere_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, verify_compression_th) from scilpy.tracking.utils import (add_seeding_options, @@ -23,7 +22,6 @@ from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_verbose_arg from dwi_ml.models.projects.learn2track_model import Learn2TrackModel from dwi_ml.testing.utils import prepare_dataset_one_subj, \ find_hdf5_associated_to_experiment diff --git a/scripts_python/l2t_train_model.py b/scripts_python/l2t_train_model.py index eb1bdf36..68fca3df 100755 --- a/scripts_python/l2t_train_model.py +++ b/scripts_python/l2t_train_model.py @@ -14,12 +14,13 @@ import comet_ml import torch -from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist +from scilpy.io.utils import (add_verbose_arg, assert_inputs_exist, + assert_outputs_exist) from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_verbose_arg, add_memory_args +from dwi_ml.io_utils import add_memory_args from dwi_ml.models.projects.learn2track_model import Learn2TrackModel from dwi_ml.models.projects.learn2track_utils import add_model_args from dwi_ml.models.utils.direction_getters import check_args_direction_getter diff --git a/scripts_python/l2t_update_deprecated_exp.py b/scripts_python/l2t_update_deprecated_exp.py index 1f3304dd..a62edce7 100644 --- a/scripts_python/l2t_update_deprecated_exp.py +++ b/scripts_python/l2t_update_deprecated_exp.py @@ -12,11 +12,10 @@ import numpy as np import torch -from scilpy.io.utils import add_overwrite_arg +from scilpy.io.utils import add_overwrite_arg, add_verbose_arg from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.io_utils import add_verbose_arg from dwi_ml.models.projects.learn2track_model import Learn2TrackModel from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler diff --git a/scripts_python/tt_resume_training_from_checkpoint.py b/scripts_python/tt_resume_training_from_checkpoint.py index 191d0877..b7af9ce6 100644 --- a/scripts_python/tt_resume_training_from_checkpoint.py +++ b/scripts_python/tt_resume_training_from_checkpoint.py @@ -9,9 +9,11 @@ # Importing now to solve issues later. 
import comet_ml +from scilpy.io.utils import add_verbose_arg + from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_verbose_arg, verify_which_model_in_path +from dwi_ml.io_utils import verify_which_model_in_path from dwi_ml.models.projects.transformer_models import find_transformer_class from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler diff --git a/scripts_python/tt_track_from_model.py b/scripts_python/tt_track_from_model.py index a579cd8f..8ee0cb10 100644 --- a/scripts_python/tt_track_from_model.py +++ b/scripts_python/tt_track_from_model.py @@ -14,7 +14,7 @@ import h5py import nibabel as nib -from scilpy.io.utils import (add_sphere_arg, +from scilpy.io.utils import (add_sphere_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, verify_compression_th) from scilpy.tracking.utils import (add_seeding_options, @@ -23,7 +23,7 @@ from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_verbose_arg, verify_which_model_in_path +from dwi_ml.io_utils import verify_which_model_in_path from dwi_ml.models.projects.transformer_models import find_transformer_class from dwi_ml.testing.utils import prepare_dataset_one_subj, \ find_hdf5_associated_to_experiment diff --git a/scripts_python/tt_train_model.py b/scripts_python/tt_train_model.py index 00ce15ee..c8ba06da 100755 --- a/scripts_python/tt_train_model.py +++ b/scripts_python/tt_train_model.py @@ -14,14 +14,16 @@ import comet_ml import torch -from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist +from scilpy.io.utils import (add_verbose_arg, assert_inputs_exist, + assert_outputs_exist) from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_memory_args, add_verbose_arg -from dwi_ml.models.projects.transformer_models import \ - OriginalTransformerModel, TransformerSrcAndTgtModel, TransformerSrcOnlyModel +from dwi_ml.io_utils import add_memory_args +from dwi_ml.models.projects.transformer_models import ( + OriginalTransformerModel, TransformerSrcAndTgtModel, + TransformerSrcOnlyModel) from dwi_ml.models.projects.transformers_utils import ( add_transformers_model_args) from dwi_ml.models.utils.direction_getters import check_args_direction_getter @@ -32,8 +34,8 @@ prepare_batch_loader) from dwi_ml.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) -from dwi_ml.training.utils.trainer import add_training_args, run_experiment, \ - format_lr +from dwi_ml.training.utils.trainer import (add_training_args, run_experiment, + format_lr) def prepare_arg_parser():