From 77f80b14063b498ccc3d47ab906e825a736b05ac Mon Sep 17 00:00:00 2001
From: EmmaRenauld <emmanuelle.renauld@usherbrooke.ca>
Date: Thu, 23 May 2024 11:12:01 -0400
Subject: [PATCH 1/2] Updates from new torch, new scilpy

---
 .../scil_score_ismrm_Renauld2023.sh           |   2 +-
 dwi_ml/models/projects/transformer_models.py  |  31 +-
 .../models/projects/transformer_sublayers.py  | 304 ++++++++++++++++
 .../models/utils/transformers_from_torch.py   | 337 ++++++++----------
 requirements.txt                              |   7 -
 5 files changed, 469 insertions(+), 212 deletions(-)
 create mode 100644 dwi_ml/models/projects/transformer_sublayers.py
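
Reviewer note (below the --- marker, so not part of the commit): a minimal
sketch of how the modified encoder could be built and called after this patch.
Hyper-parameters and tensor shapes are illustrative only.

    import torch
    from dwi_ml.models.projects.transformer_sublayers import \
        ModifiedTransformerEncoderLayer
    from dwi_ml.models.utils.transformers_from_torch import \
        ModifiedTransformerEncoder

    d_model, nhead, n_layers = 32, 4, 2
    layer = ModifiedTransformerEncoderLayer(
        d_model, nhead, dim_feedforward=64, dropout=0.1,
        batch_first=True, norm_first=True)

    # With norm_first=True, nested tensors are disabled to avoid torch's
    # warning, as done in transformer_models.py below.
    encoder = ModifiedTransformerEncoder(
        layer, n_layers, norm=None, enable_nested_tensor=False)

    src = torch.rand(2, 5, d_model)  # (batch, seq, d_model)
    future_mask = torch.nn.Transformer.generate_square_subsequent_mask(5)
    out, sa_weights = encoder(
        src, mask=future_mask, src_key_padding_mask=None,
        is_causal=True, return_weights=True)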

diff --git a/bash_utilities/scil_score_ismrm_Renauld2023.sh b/bash_utilities/scil_score_ismrm_Renauld2023.sh
index cd19cffc..15b6ba14 100644
--- a/bash_utilities/scil_score_ismrm_Renauld2023.sh
+++ b/bash_utilities/scil_score_ismrm_Renauld2023.sh
@@ -31,7 +31,7 @@ fi
 
 
 echo '------------- SEGMENTATION ------------'
-scil_score_tractogram.py $tractogram $config_file_segmentation $out_dir --no_empty \
+scil_tractogram_segment_and_score.py $tractogram $config_file_segmentation $out_dir --no_empty \
     --gt_dir $scoring_data --reference $ref --json_prefix tmp_ --no_bbox_check;
 
 echo '------------- Merging CC sub-bundles ------------'
diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py
index 1ecc2664..e10609d4 100644
--- a/dwi_ml/models/projects/transformer_models.py
+++ b/dwi_ml/models/projects/transformer_models.py
@@ -574,8 +574,17 @@ def __init__(self, **kw):
             self.d_model, self.nheads, dim_feedforward=self.ffnn_hidden_size,
             dropout=self.dropout_rate, activation=self.activation,
             batch_first=True, norm_first=self.norm_first)
+
+        # Torch warns that enable_nested_tensor is True but use_nested_tensor
+        # is set to False because encoder_layer.norm_first is True. Avoid that
+        # warning by disabling nested tensors when norm_first is used.
+        enable_nested = not self.norm_first
+
+        # Note about norm: this is a final normalization step, not linked to
+        # the normalization chosen with self.norm_first.
         self.modified_torch_transformer = ModifiedTransformerEncoder(
-            main_layer_encoder, self.n_layers_e, norm=None)
+            main_layer_encoder, self.n_layers_e, norm=None,
+            enable_nested_tensor=enable_nested)
 
     @property
     def d_model(self):
@@ -613,7 +622,7 @@ def _run_main_layer_forward(self, inputs, masks, return_weights):
         # mask_future, mask_padding = masks
         outputs, sa_weights = self.modified_torch_transformer(
             src=inputs, mask=masks[0], src_key_padding_mask=masks[1],
-            return_weights=return_weights)
+            is_causal=True, return_weights=return_weights)
 
         return outputs, (sa_weights,)
 
@@ -844,8 +853,17 @@ def __init__(self, input_embedded_size, n_layers_d: int, **kw):
             dim_feedforward=self.ffnn_hidden_size, dropout=self.dropout_rate,
             activation=self.activation, batch_first=True,
             norm_first=self.norm_first)
-        encoder = ModifiedTransformerEncoder(encoder_layer, self.n_layers_e,
-                                             norm=None)
+
+        # Torch warns that enable_nested_tensor is True but use_nested_tensor
+        # is set to False because encoder_layer.norm_first is True. Avoid that
+        # warning by disabling nested tensors when norm_first is used.
+        enable_nested = not self.norm_first
+
+        # Note about norm: this is a final normalization step, not linked to
+        # the normalization chosen with self.norm_first.
+        encoder = ModifiedTransformerEncoder(
+            encoder_layer, self.n_layers_e, norm=None,
+            enable_nested_tensor=enable_nested)
 
         # Decoder
         decoder_layer = ModifiedTransformerDecoderLayer(
@@ -908,7 +926,8 @@ def _run_main_layer_forward(self, data, masks, return_weights):
                 src=data[0], tgt=data[1],
                 src_mask=masks[0], tgt_mask=masks[0], memory_mask=masks[0],
                 src_key_padding_mask=masks[1], tgt_key_padding_mask=masks[1],
-                memory_key_padding_mask=masks[1],
+                memory_key_padding_mask=masks[1], src_is_causal=True,
+                tgt_is_causal=True, memory_is_causal=True,
                 return_weights=return_weights)
         return outputs, (sa_weights_encoder, sa_weights_decoder, mha_weights)
 
@@ -989,7 +1008,7 @@ def _run_main_layer_forward(self, concat_s_t, masks, return_weights):
         # mask_future, mask_padding = masks
         outputs, sa_weights = self.modified_torch_transformer(
             src=concat_s_t, mask=masks[0], src_key_padding_mask=masks[1],
-            return_weights=return_weights)
+            is_causal=True, return_weights=return_weights)
 
         return outputs, (sa_weights,)
 
diff --git a/dwi_ml/models/projects/transformer_sublayers.py b/dwi_ml/models/projects/transformer_sublayers.py
new file mode 100644
index 00000000..e2c65731
--- /dev/null
+++ b/dwi_ml/models/projects/transformer_sublayers.py
@@ -0,0 +1,304 @@
+"""
+Child classes of Torch's Transformer layers. Changes are:
+
+- EncoderLayer: returns the attention weights and offers the option not to
+  share the linear weights for Q, K, V.
+- DecoderLayer: idem.
+
+"""
+import logging
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import (TransformerDecoderLayer, TransformerEncoderLayer,
+                      MultiheadAttention, Parameter)
+
+logger = logging.getLogger('model_logger')
+
+
+def do_not_share_linear_weights(attn: MultiheadAttention, d_model):
+    """
+    I added a request for this parameter to be accessible.
+    https://github.com/pytorch/pytorch/issues/92990
+
+    Copied from MultiheadAttention's init method
+    """
+
+    factory_kwargs = {'device': None, 'dtype': None}
+
+    # Overriding some parameters in the self-attention.
+    # Ugly, but Torch does not offer a parameter to NOT share the linear
+    # weights. In their code, they only avoid sharing weights when the
+    # dimensions differ, which is not our case. This is stored in their
+    # parameter _qkv_same_embed_dim. By changing it, we change their
+    # forward call to the MultiheadAttention in self.self_attn.
+    attn._qkv_same_embed_dim = False
+    attn.q_proj_weight = Parameter(
+        torch.empty((d_model, d_model), **factory_kwargs))
+    attn.k_proj_weight = Parameter(
+        torch.empty((d_model, d_model), **factory_kwargs))
+    attn.v_proj_weight = Parameter(
+        torch.empty((d_model, d_model), **factory_kwargs))
+    attn.register_parameter('in_proj_weight', None)
+    attn._reset_parameters()
+
+
+class ModifiedTransformerEncoderLayer(TransformerEncoderLayer):
+    def __init__(self, d_model, nhead, **kw):
+        super().__init__(d_model, nhead, **kw)
+
+        do_not_share_linear_weights(self.self_attn, d_model)
+
+    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
+                src_key_padding_mask: Optional[Tensor] = None,
+                is_causal: bool = False,
+                # New args:
+                return_weights=False, average_heads=False):
+        """
+        Copy-pasted from torch. Now returns weights.
+        """
+        src_key_padding_mask = F._canonical_mask(
+            mask=src_key_padding_mask,
+            mask_name="src_key_padding_mask",
+            other_type=F._none_or_dtype(src_mask),
+            other_name="src_mask",
+            target_type=src.dtype
+        )
+
+        src_mask = F._canonical_mask(
+            mask=src_mask,
+            mask_name="src_mask",
+            other_type=None,
+            other_name="",
+            target_type=src.dtype,
+            check_other=False,
+        )
+
+        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
+        why_not_sparsity_fast_path = ''
+        if not src.dim() == 3:
+            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
+        elif self.training:
+            why_not_sparsity_fast_path = "training is enabled"
+        elif not self.self_attn.batch_first:
+            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
+        elif not self.self_attn._qkv_same_embed_dim:
+            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
+        elif not self.activation_relu_or_gelu:
+            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
+        elif not (self.norm1.eps == self.norm2.eps):
+            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
+        elif src.is_nested and (
+                src_key_padding_mask is not None or src_mask is not None):
+            why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
+        elif self.self_attn.num_heads % 2 == 1:
+            why_not_sparsity_fast_path = "num_head is odd"
+        elif torch.is_autocast_enabled():
+            why_not_sparsity_fast_path = "autocast is enabled"
+        if not why_not_sparsity_fast_path:
+            tensor_args = (
+                src,
+                self.self_attn.in_proj_weight,
+                self.self_attn.in_proj_bias,
+                self.self_attn.out_proj.weight,
+                self.self_attn.out_proj.bias,
+                self.norm1.weight,
+                self.norm1.bias,
+                self.norm2.weight,
+                self.norm2.bias,
+                self.linear1.weight,
+                self.linear1.bias,
+                self.linear2.weight,
+                self.linear2.bias,
+            )
+
+            # We have to use list comprehensions below because TorchScript does not support
+            # generator expressions.
+            _supported_device_type = ["cpu", "cuda",
+                                      torch.utils.backend_registration._privateuse1_backend_name]
+            if torch.overrides.has_torch_function(tensor_args):
+                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
+            elif not all((x.device.type in _supported_device_type) for x in
+                         tensor_args):
+                why_not_sparsity_fast_path = (
+                    "some Tensor argument's device is neither one of "
+                    f"{_supported_device_type}")
+            elif torch.is_grad_enabled() and any(
+                    x.requires_grad for x in tensor_args):
+                why_not_sparsity_fast_path = (
+                    "grad is enabled and at least one of query or the "
+                    "input/output projection weights or biases requires_grad")
+
+            if not why_not_sparsity_fast_path:
+                merged_mask, mask_type = self.self_attn.merge_masks(src_mask,
+                                                                    src_key_padding_mask,
+                                                                    src)
+                # MODIFIED:
+                if return_weights:
+                    raise NotImplementedError(
+                        "Did not expect to reach here. Not ready to return "
+                        "weights. Please contact dwi_ml developpers")
+                return torch._transformer_encoder_layer_fwd(
+                    src,
+                    self.self_attn.embed_dim,
+                    self.self_attn.num_heads,
+                    self.self_attn.in_proj_weight,
+                    self.self_attn.in_proj_bias,
+                    self.self_attn.out_proj.weight,
+                    self.self_attn.out_proj.bias,
+                    self.activation_relu_or_gelu == 2,
+                    self.norm_first,
+                    self.norm1.eps,
+                    self.norm1.weight,
+                    self.norm1.bias,
+                    self.norm2.weight,
+                    self.norm2.bias,
+                    self.linear1.weight,
+                    self.linear1.bias,
+                    self.linear2.weight,
+                    self.linear2.bias,
+                    merged_mask,
+                    mask_type,
+                )
+
+        x = src
+        if self.norm_first:
+            # Norm, SA, Add, Norm, FF, Add
+            sa, sa_weights = self._sa_block(
+                self.norm1(x), src_mask, src_key_padding_mask,
+                is_causal=is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = x + sa
+            x = x + self._ff_block(self.norm2(x))
+        else:
+            # SA, Add, Norm, FF, Add, Norm
+            sa, sa_weights = self._sa_block(
+                x, src_mask, src_key_padding_mask, is_causal=is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = self.norm1(x + sa)
+            x = self.norm2(x + self._ff_block(x))
+
+        return x, sa_weights
+
+    # self-attention block
+    def _sa_block(self, x: Tensor,
+                  attn_mask: Optional[Tensor],
+                  key_padding_mask: Optional[Tensor],
+                  is_causal: bool = False,
+                  # New args:
+                  return_weights=False, average_heads=False):
+        x, weights = self.self_attn(
+            x, x, x,
+            attn_mask=attn_mask, key_padding_mask=key_padding_mask,
+            is_causal=is_causal,
+            # Modified args:
+            need_weights=return_weights, average_attn_weights=average_heads)
+
+        return self.dropout1(x), weights
+
+
+class ModifiedTransformerDecoderLayer(TransformerDecoderLayer):
+    """
+    Decoder layer, for the case where we do not have a start-of-sequence (SOS)
+    token and our mask contains only -inf for the first position. The output
+    of self-attention then becomes NaN after the softmax step; it is set to 0.
+
+    Also, now returns attention weights.
+    """
+    def __init__(self, d_model, nhead, **kw):
+        super().__init__(d_model, nhead, **kw)
+
+        do_not_share_linear_weights(self.self_attn, d_model)
+        do_not_share_linear_weights(self.multihead_attn, d_model)
+
+    def forward(self, tgt: Tensor, memory: Tensor,
+                tgt_mask: Tensor = None, memory_mask: Tensor = None,
+                tgt_key_padding_mask: Tensor = None,
+                memory_key_padding_mask: Tensor = None,
+                tgt_is_causal: bool = False,
+                memory_is_causal: bool = False,
+                # New args:
+                return_weights=False, average_heads=False):
+        """
+        Copy-pasted from torch. Now returns weights + converts nan to 0.
+        Weights are None if return_weights is False.
+        """
+        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
+        x = tgt
+        if self.norm_first:
+            # Norm, SA, Add, Norm, MHA, Add, Norm, FF, Add
+            sa, sa_weights = self._sa_block(
+                self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = x + sa
+
+            mha, mha_weights = self._mha_block(
+                self.norm2(x), memory, memory_mask, memory_key_padding_mask,
+                memory_is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = x + mha
+            x = x + self._ff_block(self.norm3(x))
+        else:
+            # SA, Add, Norm, MHA, Add, Norm, FF, Add, Norm.
+            sa, sa_weights = self._sa_block(
+                x, tgt_mask, tgt_key_padding_mask, tgt_is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = self.norm1(x + sa)
+
+            mha, mha_weights = self._mha_block(
+                x, memory, memory_mask, memory_key_padding_mask,
+                memory_is_causal,
+                # New args:
+                return_weights=return_weights, average_heads=average_heads)
+            x = self.norm2(x + mha)
+            x = self.norm3(x + self._ff_block(x))
+
+        return x, mha_weights, sa_weights
+
+    # self-attention block
+    def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor],
+                  key_padding_mask: Optional[Tensor],
+                  is_causal: bool = False,
+                  # New args:
+                  return_weights=False, average_heads=False):
+        """
+        Copy-pasted from torch. Now returns weights.
+        """
+        x, weights = self.self_attn(
+            x, x, x,
+            attn_mask=attn_mask, key_padding_mask=key_padding_mask,
+            is_causal=is_causal,
+            # Modified args:
+            need_weights=return_weights, average_attn_weights=average_heads)
+
+        return self.dropout1(x), weights
+
+    # multihead attention block
+    def _mha_block(self, x: Tensor, mem: Tensor,
+                   attn_mask: Optional[Tensor],
+                   key_padding_mask: Optional[Tensor],
+                   is_causal: bool = False,
+                   # New args:
+                   return_weights=False, average_heads=False):
+        """
+        Copy-pasted from torch. Can now use need_weights=True.
+        """
+        # MultiheadAttention always returns (attn_output, weights); weights
+        # is None when need_weights is False.
+        x, weights = self.multihead_attn(
+            x, mem, mem,
+            attn_mask=attn_mask, key_padding_mask=key_padding_mask,
+            is_causal=is_causal,
+            # Modified args:
+            need_weights=return_weights, average_attn_weights=average_heads)
+
+        return self.dropout2(x), weights
diff --git a/dwi_ml/models/utils/transformers_from_torch.py b/dwi_ml/models/utils/transformers_from_torch.py
index a9972c6b..7710b871 100644
--- a/dwi_ml/models/utils/transformers_from_torch.py
+++ b/dwi_ml/models/utils/transformers_from_torch.py
@@ -6,82 +6,129 @@
   to decide if we want to share the linear weights for Q, K, V.
 - Encoder: Idem
 - Decoder: Idem
-- EncoderLayer: Idem
-- DecoderLayer: Idem
 
 """
 import logging
 from typing import Optional
 
 import torch
+import torch.nn.functional as F
 from torch import Tensor
-from torch.nn import (Transformer,
-                      TransformerDecoderLayer, TransformerDecoder,
-                      TransformerEncoderLayer, TransformerEncoder,
-                      MultiheadAttention, Parameter)
+from torch.nn import Transformer, TransformerDecoder, TransformerEncoder
+from torch.nn.modules.transformer import _get_seq_len, _detect_is_causal_mask
 
 from dwi_ml.experiment_utils.memory import log_gpu_memory_usage
+from dwi_ml.models.projects.transformer_sublayers import \
+    ModifiedTransformerDecoderLayer, ModifiedTransformerEncoderLayer
 
 logger = logging.getLogger('model_logger')
 
 
-class ModifiedTransformer(Transformer):
-    def __init__(self, *args, **kw):
-        super().__init__(*args, **kw)
-
-    def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor = None,
-                tgt_mask: Tensor = None, memory_mask: Tensor = None,
-                src_key_padding_mask: Tensor = None,
-                tgt_key_padding_mask: Tensor = None,
-                memory_key_padding_mask: Tensor = None,
-                return_weights=False, average_heads=False):
-        """
-        Copy-pasted from torch. Now returns weights.
-        """
-        logger.debug("Entering main Transformer's forward.")
-        log_gpu_memory_usage(logger)
-        memory, sa_weights_encoder = self.encoder(
-            src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
-            return_weights=return_weights, average_heads=average_heads)
-
-        output, sa_weights_decoder, mha_weights = self.decoder(
-            tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
-            tgt_key_padding_mask=tgt_key_padding_mask,
-            memory_key_padding_mask=memory_key_padding_mask,
-            return_weights=return_weights, average_heads=average_heads)
-
-        return output, sa_weights_encoder, sa_weights_decoder, mha_weights
-
-
 class ModifiedTransformerEncoder(TransformerEncoder):
-    def __init__(self, *args, **kw):
-        super().__init__(*args, **kw)
+    def __init__(self, encoder_layer, *args, **kw):
+        if not isinstance(encoder_layer, ModifiedTransformerEncoderLayer):
+            raise ValueError("Encoder layer should be of type {}. Got {}"
+                             .format(ModifiedTransformerEncoderLayer.__name__,
+                                     type(encoder_layer)))
+        super().__init__(encoder_layer, *args, **kw)
 
     def forward(self, src: Tensor, mask: Optional[Tensor] = None,
                 src_key_padding_mask: Optional[Tensor] = None,
+                is_causal: Optional[bool] = None,
+                # New args:
                 return_weights=False, average_heads=False):
         """
         Copy-pasted from torch. Now returns weights.
-        Erased all the fast-path check: it is not used anyway if it is
-        training, and if we supply both src_key_padding_mask and mask, which is
-        our case.
-        Layers must be TransformerEncoderLayerGetWeights layers.
         """
-        if src_key_padding_mask is not None:
-            _skpm_dtype = src_key_padding_mask.dtype
-            if _skpm_dtype != torch.bool and not \
-                    torch.is_floating_point(src_key_padding_mask):
-                raise AssertionError("only bool and floating types of "
-                                     "key_padding_mask are supported")
+        src_key_padding_mask = F._canonical_mask(
+            mask=src_key_padding_mask,
+            mask_name="src_key_padding_mask",
+            other_type=F._none_or_dtype(mask),
+            other_name="mask",
+            target_type=src.dtype
+        )
+
+        mask = F._canonical_mask(
+            mask=mask,
+            mask_name="mask",
+            other_type=None,
+            other_name="",
+            target_type=src.dtype,
+            check_other=False,
+        )
+
         output = src
+        convert_to_nested = False
+        first_layer = self.layers[0]
         src_key_padding_mask_for_layers = src_key_padding_mask
+        why_not_sparsity_fast_path = ''
+        str_first_layer = "self.layers[0]"
+        batch_first = first_layer.self_attn.batch_first
+        if not hasattr(self, "use_nested_tensor"):
+            why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
+        elif not self.use_nested_tensor:
+            why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True"
+        elif first_layer.training:
+            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
+        elif not src.dim() == 3:
+            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
+        elif src_key_padding_mask is None:
+            why_not_sparsity_fast_path = "src_key_padding_mask was None"
+        elif (((not hasattr(self, "mask_check")) or self.mask_check)
+                and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
+            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
+        elif output.is_nested:
+            why_not_sparsity_fast_path = "NestedTensor input is not supported"
+        elif mask is not None:
+            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
+        elif torch.is_autocast_enabled():
+            why_not_sparsity_fast_path = "autocast is enabled"
+
+        if not why_not_sparsity_fast_path:
+            tensor_args = (
+                src,
+                first_layer.self_attn.in_proj_weight,
+                first_layer.self_attn.in_proj_bias,
+                first_layer.self_attn.out_proj.weight,
+                first_layer.self_attn.out_proj.bias,
+                first_layer.norm1.weight,
+                first_layer.norm1.bias,
+                first_layer.norm2.weight,
+                first_layer.norm2.bias,
+                first_layer.linear1.weight,
+                first_layer.linear1.bias,
+                first_layer.linear2.weight,
+                first_layer.linear2.bias,
+            )
+            _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
+            if torch.overrides.has_torch_function(tensor_args):
+                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
+            elif src.device.type not in _supported_device_type:
+                why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}"
+            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
+                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
+                                              "input/output projection weights or biases requires_grad")
+
+            if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
+                convert_to_nested = True
+                output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
+                src_key_padding_mask_for_layers = None
+
+        seq_len = _get_seq_len(src, batch_first)
+        is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
+
+        # THIS IS THE MODIFIED PART
         sa_weights = [None] * len(self.layers)
-
         for mod, i in zip(self.layers, range(len(self.layers))):
             output, sa_weights[i] = mod(
-                output, src_mask=mask,
+                output, src_mask=mask, is_causal=is_causal,
                 src_key_padding_mask=src_key_padding_mask_for_layers,
+                # New args:
                 return_weights=return_weights, average_heads=average_heads)
+        # END OF MODIFIED PART
+
+        if convert_to_nested:
+            output = output.to_padded_tensor(0., src.size())
 
         if self.norm is not None:
             output = self.norm(output)
@@ -90,30 +137,45 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None,
 
 
 class ModifiedTransformerDecoder(TransformerDecoder):
-    def __init__(self, *args, **kw):
-        super().__init__(*args, **kw)
+
+    def __init__(self, decoder_layer, *args, **kw):
+        if not isinstance(decoder_layer, ModifiedTransformerDecoderLayer):
+            raise ValueError("Encoder layer should be of type {}. Got {}"
+                             .format(ModifiedTransformerEncoderLayer.__name__,
+                                     type(decoder_layer)))
+        super().__init__(decoder_layer, *args, **kw)
 
     def forward(self, tgt: Tensor, memory: Tensor,
                 tgt_mask: Optional[Tensor] = None,
                 memory_mask: Optional[Tensor] = None,
                 tgt_key_padding_mask: Optional[Tensor] = None,
                 memory_key_padding_mask: Optional[Tensor] = None,
+                tgt_is_causal: Optional[bool] = None,
+                memory_is_causal: bool = False,
+                # New args:
                 return_weights=False, average_heads=False):
         """
         Copy-pasted from torch. Now returns weights.
-        Layers must be TransformerDecoderLayerGetWeightsNoSOS layers.
         """
         output = tgt
+
+        seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
+        tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
+
+        # THIS IS THE MODIFIED PART
         mha_weights = [None] * len(self.layers)
         sa_weights = [None] * len(self.layers)
-
         for mod, i in zip(self.layers, range(len(self.layers))):
             output, mha_weights[i], sa_weights[i] = \
                 mod(output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                     tgt_key_padding_mask=tgt_key_padding_mask,
                     memory_key_padding_mask=memory_key_padding_mask,
+                    tgt_is_causal=tgt_is_causal,
+                    memory_is_causal=memory_is_causal,
+                    # New args:
                     return_weights=return_weights,
                     average_heads=average_heads)
+        # END OF MODIFIED PART
 
         if self.norm is not None:
             output = self.norm(output)
@@ -121,160 +183,39 @@ def forward(self, tgt: Tensor, memory: Tensor,
         return output, sa_weights, mha_weights
 
 
-def do_not_share_linear_weights(attn: MultiheadAttention, d_model):
-    """
-    I added a request for this parameter to be accessible.
-    https://github.com/pytorch/pytorch/issues/92990
-    """
-    factory_kwargs = {'device': None, 'dtype': None}
-
-    # Overriding some parameters in the self attention.
-    # Ugly but.... Torch does not have a parameter to NOT share linear
-    # weights. In their code, their only NOT share weights when dimensions
-    # are not the same. This is not our case. This is saved in their
-    # parameter _qkv_same_embed_dim. By changing this, we change their
-    # forward call to the MultiHeadAttention in self.self_attn.
-    attn._qkv_same_embed_dim = False
-    attn.q_proj_weight = Parameter(
-        torch.empty((d_model, d_model), **factory_kwargs))
-    attn.k_proj_weight = Parameter(
-        torch.empty((d_model, d_model), **factory_kwargs))
-    attn.v_proj_weight = Parameter(
-        torch.empty((d_model, d_model), **factory_kwargs))
-    attn.register_parameter('in_proj_weight', None)
-    attn._reset_parameters()
-
-
-class ModifiedTransformerEncoderLayer(TransformerEncoderLayer):
-    def __init__(self, d_model, nhead, **kw):
-        super().__init__(d_model, nhead, **kw)
-
-        do_not_share_linear_weights(self.self_attn, d_model)
-
-    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
-                src_key_padding_mask: Optional[Tensor] = None,
-                return_weights=False, average_heads=False):
-        """
-        Copy-pasted from torch. Now returns weights.
-        Erased all the fast-track checks.
-        """
-        x = src
-        if self.norm_first:
-            # Norm, SA, Add, Norm, FF, Add
-            sa, sa_weights = self._sa_block(
-                self.norm1(x), src_mask, src_key_padding_mask,
-                return_weights=return_weights, average_heads=average_heads)
-            x = x + sa
-            x = x + self._ff_block(self.norm2(x))
-        else:
-            # SA, Add, Norm, FF, Add, Norm
-            sa, sa_weights = self._sa_block(
-                x, src_mask, src_key_padding_mask,
-                return_weights=return_weights, average_heads=average_heads)
-            x = self.norm1(x + sa)
-            x = self.norm2(x + self._ff_block(x))
-
-        return x, sa_weights
-
-    # self-attention block
-    def _sa_block(self, x: Tensor,
-                  attn_mask: Optional[Tensor],
-                  key_padding_mask: Optional[Tensor],
-                  return_weights=False, average_heads=False):
-        output = self.self_attn(
-            x, x, x,
-            attn_mask=attn_mask, key_padding_mask=key_padding_mask,
-            need_weights=return_weights, average_attn_weights=average_heads)
-        x, weights = output  # if return_weights is False, weights is None
-
-        return self.dropout1(x), weights
-
-
-class ModifiedTransformerDecoderLayer(TransformerDecoderLayer):
-    """
-    Decoder Layer, in the case where we do not have a start of sequence (SOS)
-    token, and our mask contains only -inf for the first position. Output of
-    self-attention becomes nan after the softmax step. Setting to 0.
-
-    Also, now returning attention weights.
-    """
-    def __init__(self, d_model, nhead, **kw):
-        super().__init__(d_model, nhead, **kw)
+class ModifiedTransformer(Transformer):
+    encoder: ModifiedTransformerEncoder
+    decoder: ModifiedTransformerDecoder
 
-        do_not_share_linear_weights(self.self_attn, d_model)
-        do_not_share_linear_weights(self.multihead_attn, d_model)
+    def __init__(self, *args, **kw):
+        super().__init__(*args, **kw)
 
-    def forward(self, tgt: Tensor, memory: Tensor,
+    def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor = None,
                 tgt_mask: Tensor = None, memory_mask: Tensor = None,
+                src_key_padding_mask: Tensor = None,
                 tgt_key_padding_mask: Tensor = None,
                 memory_key_padding_mask: Tensor = None,
+                src_is_causal: bool = None, tgt_is_causal: bool = None,
+                memory_is_causal: bool = False,
+                # New args:
                 return_weights=False, average_heads=False):
         """
-        Copy-pasted from torch. Now returns weights + converts nan to 0.
-        Weights are None if return_weights is False.
-        """
-        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
-        x = tgt
-        if self.norm_first:
-            # Norm, SA, Add, Norm, MHA, Add, Norm, FF, Add
-            sa, sa_weights = self._sa_block(
-                self.norm1(x), tgt_mask, tgt_key_padding_mask,
-                return_weights=return_weights, average_heads=average_heads)
-            x = x + sa
-
-            mha, mha_weights = self._mha_block(
-                self.norm2(x), memory, memory_mask, memory_key_padding_mask,
-                return_weights=return_weights, average_heads=average_heads)
-            x = x + mha
-            x = x + self._ff_block(self.norm3(x))
-        else:
-            # SA, Add, Norm, MHA, Add, Norm, FF, Add, Norm.
-            sa, sa_weights = self._sa_block(
-                x, tgt_mask, tgt_key_padding_mask,
-                return_weights=return_weights, average_heads=average_heads)
-            x = self.norm1(x + sa)
-
-            mha, mha_weights = self._mha_block(
-                x, memory, memory_mask, memory_key_padding_mask,
-                return_weights=return_weights, average_heads=average_heads)
-            x = self.norm2(x + mha)
-            x = self.norm3(x + self._ff_block(x))
-
-        return x, mha_weights, sa_weights
-
-    # self-attention block
-    def _sa_block(self, x: Tensor,
-                  attn_mask: Optional[Tensor],
-                  key_padding_mask: Optional[Tensor],
-                  return_weights=False, average_heads=False):
-        """
         Copy-pasted from torch. Now returns weights.
         """
-        output = self.self_attn(
-            x, x, x,
-            attn_mask=attn_mask, key_padding_mask=key_padding_mask,
-            need_weights=return_weights, average_attn_weights=average_heads)
-
-        x, weights = output  # If not return_weights, weights is None.
-
-        return self.dropout1(x), weights
-
-    # multihead attention block
-    def _mha_block(self, x: Tensor, mem: Tensor,
-                   attn_mask: Optional[Tensor],
-                   key_padding_mask: Optional[Tensor],
-                   return_weights=False, average_heads=False):
-        """
-        Copy-pasted from torch. Can now use need_weight = True.
-        """
-        output = self.multihead_attn(
-            x, mem, mem,
-            attn_mask=attn_mask, key_padding_mask=key_padding_mask,
-            need_weights=return_weights, average_attn_weights=average_heads)
+        logger.debug("Entering main Transformer's forward.")
+        log_gpu_memory_usage(logger)
+        memory, sa_weights_encoder = self.encoder(
+            src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
+            is_causal=src_is_causal,
+            # New args:
+            return_weights=return_weights, average_heads=average_heads)
 
-        if return_weights:
-            x, weights = output
-        else:
-            x, weights = output, None
+        output, sa_weights_decoder, mha_weights = self.decoder(
+            tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+            tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal,
+            # New args:
+            return_weights=return_weights, average_heads=average_heads)
 
-        return self.dropout2(x[0]), weights
+        return output, sa_weights_encoder, sa_weights_decoder, mha_weights
diff --git a/requirements.txt b/requirements.txt
index d492885c..7739110b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,10 +34,3 @@ nibabel==5.2.*
 numpy==1.23.*
 scipy==1.9.*
 scikit-image==0.22.*
-
-
-# --------------- Notes to developers
-# If we upgrade torch, verify if code copied in
-#     models.projects.transformers_from_torch has changed.
-# (current code copied from torch 1.13.1)
-# ----------

From c809fbbe835e7d9ef31bf682f1c19f8aa4f7b21c Mon Sep 17 00:00:00 2001
From: EmmaRenauld <emmanuelle.renauld@usherbrooke.ca>
Date: Thu, 23 May 2024 11:32:43 -0400
Subject: [PATCH 2/2] Use scilpy's add_verbose_arg

---
 dwi_ml/io_utils.py                                 |  9 ---------
 dwi_ml/testing/projects/tt_visu_argparser.py       |  6 +++---
 dwi_ml/testing/visu_loss_utils.py                  |  4 ++--
 scripts_python/dwiml_create_hdf5_dataset.py        |  5 ++---
 scripts_python/dwiml_visualize_logs.py             |  4 ++--
 scripts_python/dwiml_visualize_logs_correlation.py |  4 +---
 .../l2t_resume_training_from_checkpoint.py         |  3 ++-
 scripts_python/l2t_track_from_model.py             |  4 +---
 scripts_python/l2t_train_model.py                  |  5 +++--
 scripts_python/l2t_update_deprecated_exp.py        |  3 +--
 .../tt_resume_training_from_checkpoint.py          |  4 +++-
 scripts_python/tt_track_from_model.py              |  4 ++--
 scripts_python/tt_train_model.py                   | 14 ++++++++------
 13 files changed, 30 insertions(+), 39 deletions(-)
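
Reviewer note (below the --- marker, so not part of the commit): the scripts
now take add_verbose_arg from scilpy instead of the local copy removed from
dwi_ml.io_utils. A minimal sketch of the expected call pattern, assuming
scilpy's helper keeps the same '-v' interface (DEBUG / INFO / WARNING,
dest='verbose') as the removed local version:

    import argparse
    import logging

    from scilpy.io.utils import add_verbose_arg

    def _build_arg_parser():
        p = argparse.ArgumentParser(description=__doc__)
        # ... script-specific arguments ...
        add_verbose_arg(p)
        return p

    def main():
        args = _build_arg_parser().parse_args()
        # args.verbose is a logging level name, e.g. 'INFO'.
        logging.getLogger().setLevel(args.verbose)

    if __name__ == '__main__':
        main()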

diff --git a/dwi_ml/io_utils.py b/dwi_ml/io_utils.py
index 4f969b15..b16e5baf 100644
--- a/dwi_ml/io_utils.py
+++ b/dwi_ml/io_utils.py
@@ -5,15 +5,6 @@
 from scilpy.io.utils import add_processes_arg
 
 
-def add_verbose_arg(p):
-    # Can eventually become scilpy.io.utils.add_verbose_arg
-    p.add_argument('-v', default="WARNING", const='INFO', nargs='?',
-                   choices=['DEBUG', 'INFO', 'WARNING'], dest='verbose',
-                   help='Produces verbose output depending on '
-                        'the provided level. \nDefault level is warning, '
-                        'default when using -v is info.')
-
-
 def add_resample_or_compress_arg(p: ArgumentParser):
     g = p.add_mutually_exclusive_group()
     g.add_argument(
diff --git a/dwi_ml/testing/projects/tt_visu_argparser.py b/dwi_ml/testing/projects/tt_visu_argparser.py
index cde26c57..9aadbac4 100644
--- a/dwi_ml/testing/projects/tt_visu_argparser.py
+++ b/dwi_ml/testing/projects/tt_visu_argparser.py
@@ -51,10 +51,10 @@
 """
 import argparse
 
-from scilpy.io.utils import (add_overwrite_arg, add_reference_arg)
+from scilpy.io.utils import (add_overwrite_arg, add_reference_arg,
+                             add_verbose_arg)
 
-from dwi_ml.io_utils import (add_arg_existing_experiment_path,
-                             add_verbose_arg, add_memory_args)
+from dwi_ml.io_utils import add_arg_existing_experiment_path, add_memory_args
 from dwi_ml.testing.utils import add_args_testing_subj_hdf5
 
 
diff --git a/dwi_ml/testing/visu_loss_utils.py b/dwi_ml/testing/visu_loss_utils.py
index ef81a594..9db78d52 100644
--- a/dwi_ml/testing/visu_loss_utils.py
+++ b/dwi_ml/testing/visu_loss_utils.py
@@ -4,11 +4,11 @@
 import os.path
 from argparse import ArgumentParser
 
-from scilpy.io.utils import (add_overwrite_arg,
+from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg,
                              assert_inputs_exist, assert_outputs_exist,
                              add_reference_arg, ranged_type)
 
-from dwi_ml.io_utils import add_memory_args, add_verbose_arg
+from dwi_ml.io_utils import add_memory_args
 
 
 def prepare_args_visu_loss(p: ArgumentParser):
diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py
index c36a14e7..7266cd15 100644
--- a/scripts_python/dwiml_create_hdf5_dataset.py
+++ b/scripts_python/dwiml_create_hdf5_dataset.py
@@ -25,8 +25,8 @@
 import shutil
 from pathlib import Path
 
-from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist,
-                             assert_outputs_exist)
+from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg,
+                             assert_inputs_exist, assert_outputs_exist)
 
 from dipy.io.stateful_tractogram import set_sft_logger_level
 
@@ -34,7 +34,6 @@
 from dwi_ml.data.hdf5.utils import (
     add_hdf5_creation_args, add_streamline_processing_args)
 from dwi_ml.experiment_utils.timer import Timer
-from dwi_ml.io_utils import add_verbose_arg
 
 
 def _initialize_intermediate_subdir(hdf5_file, save_intermediate):
diff --git a/scripts_python/dwiml_visualize_logs.py b/scripts_python/dwiml_visualize_logs.py
index f218e4a0..e8807748 100644
--- a/scripts_python/dwiml_visualize_logs.py
+++ b/scripts_python/dwiml_visualize_logs.py
@@ -36,9 +36,9 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
-from scilpy.io.utils import assert_outputs_exist, add_overwrite_arg
+from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg,
+                             assert_outputs_exist)
 
-from dwi_ml.io_utils import add_verbose_arg
 from dwi_ml.viz.logs_plots import visualize_logs
 
 
diff --git a/scripts_python/dwiml_visualize_logs_correlation.py b/scripts_python/dwiml_visualize_logs_correlation.py
index 78d92af6..b1ecfee2 100644
--- a/scripts_python/dwiml_visualize_logs_correlation.py
+++ b/scripts_python/dwiml_visualize_logs_correlation.py
@@ -13,9 +13,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
-from scilpy.io.utils import add_overwrite_arg
-
-from dwi_ml.io_utils import add_verbose_arg
+from scilpy.io.utils import add_overwrite_arg, add_verbose_arg
 
 
 def _build_arg_parser():
diff --git a/scripts_python/l2t_resume_training_from_checkpoint.py b/scripts_python/l2t_resume_training_from_checkpoint.py
index 21573424..15739413 100644
--- a/scripts_python/l2t_resume_training_from_checkpoint.py
+++ b/scripts_python/l2t_resume_training_from_checkpoint.py
@@ -9,9 +9,10 @@
 # Importing now to solve issues later.
 import comet_ml
 
+from scilpy.io.utils import add_verbose_arg
+
 from dwi_ml.data.dataset.utils import prepare_multisubjectdataset
 from dwi_ml.experiment_utils.timer import Timer
-from dwi_ml.io_utils import add_verbose_arg
 from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
 from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput
 from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
diff --git a/scripts_python/l2t_track_from_model.py b/scripts_python/l2t_track_from_model.py
index eb0891de..92f65831 100644
--- a/scripts_python/l2t_track_from_model.py
+++ b/scripts_python/l2t_track_from_model.py
@@ -12,9 +12,8 @@
 from dipy.io.utils import is_header_compatible
 import h5py
 import nibabel as nib
-import torch
 
-from scilpy.io.utils import (add_sphere_arg,
+from scilpy.io.utils import (add_sphere_arg, add_verbose_arg,
                              assert_inputs_exist, assert_outputs_exist,
                              verify_compression_th)
 from scilpy.tracking.utils import (add_seeding_options,
@@ -23,7 +22,6 @@
 
 from dwi_ml.experiment_utils.prints import format_dict_to_str
 from dwi_ml.experiment_utils.timer import Timer
-from dwi_ml.io_utils import add_verbose_arg
 from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
 from dwi_ml.testing.utils import prepare_dataset_one_subj, \
     find_hdf5_associated_to_experiment
diff --git a/scripts_python/l2t_train_model.py b/scripts_python/l2t_train_model.py
index eb1bdf36..68fca3df 100755
--- a/scripts_python/l2t_train_model.py
+++ b/scripts_python/l2t_train_model.py
@@ -14,12 +14,13 @@
 import comet_ml
 import torch
 
-from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist
+from scilpy.io.utils import (add_verbose_arg, assert_inputs_exist,
+                             assert_outputs_exist)
 
 from dwi_ml.data.dataset.utils import prepare_multisubjectdataset
 from dwi_ml.experiment_utils.prints import format_dict_to_str
 from dwi_ml.experiment_utils.timer import Timer
-from dwi_ml.io_utils import add_verbose_arg, add_memory_args
+from dwi_ml.io_utils import add_memory_args
 from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
 from dwi_ml.models.projects.learn2track_utils import add_model_args
 from dwi_ml.models.utils.direction_getters import check_args_direction_getter
diff --git a/scripts_python/l2t_update_deprecated_exp.py b/scripts_python/l2t_update_deprecated_exp.py
index 1f3304dd..a62edce7 100644
--- a/scripts_python/l2t_update_deprecated_exp.py
+++ b/scripts_python/l2t_update_deprecated_exp.py
@@ -12,11 +12,10 @@
 
 import numpy as np
 import torch
-from scilpy.io.utils import add_overwrite_arg
+from scilpy.io.utils import add_overwrite_arg, add_verbose_arg
 
 from dwi_ml.data.dataset.utils import prepare_multisubjectdataset
 from dwi_ml.experiment_utils.prints import format_dict_to_str
-from dwi_ml.io_utils import add_verbose_arg
 from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
 from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput
 from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
diff --git a/scripts_python/tt_resume_training_from_checkpoint.py b/scripts_python/tt_resume_training_from_checkpoint.py
index 191d0877..b7af9ce6 100644
--- a/scripts_python/tt_resume_training_from_checkpoint.py
+++ b/scripts_python/tt_resume_training_from_checkpoint.py
@@ -9,9 +9,11 @@
 # Importing now to solve issues later.
 import comet_ml
 
+from scilpy.io.utils import add_verbose_arg
+
 from dwi_ml.data.dataset.utils import prepare_multisubjectdataset
 from dwi_ml.experiment_utils.timer import Timer
-from dwi_ml.io_utils import add_verbose_arg, verify_which_model_in_path
+from dwi_ml.io_utils import verify_which_model_in_path
 from dwi_ml.models.projects.transformer_models import find_transformer_class
 from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput
 from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
diff --git a/scripts_python/tt_track_from_model.py b/scripts_python/tt_track_from_model.py
index a579cd8f..8ee0cb10 100644
--- a/scripts_python/tt_track_from_model.py
+++ b/scripts_python/tt_track_from_model.py
@@ -14,7 +14,7 @@
 import h5py
 import nibabel as nib
 
-from scilpy.io.utils import (add_sphere_arg,
+from scilpy.io.utils import (add_sphere_arg, add_verbose_arg,
                              assert_inputs_exist, assert_outputs_exist,
                              verify_compression_th)
 from scilpy.tracking.utils import (add_seeding_options,
@@ -23,7 +23,7 @@
 
 from dwi_ml.experiment_utils.prints import format_dict_to_str
 from dwi_ml.experiment_utils.timer import Timer
-from dwi_ml.io_utils import add_verbose_arg, verify_which_model_in_path
+from dwi_ml.io_utils import verify_which_model_in_path
 from dwi_ml.models.projects.transformer_models import find_transformer_class
 from dwi_ml.testing.utils import prepare_dataset_one_subj, \
     find_hdf5_associated_to_experiment
diff --git a/scripts_python/tt_train_model.py b/scripts_python/tt_train_model.py
index 00ce15ee..c8ba06da 100755
--- a/scripts_python/tt_train_model.py
+++ b/scripts_python/tt_train_model.py
@@ -14,14 +14,16 @@
 import comet_ml
 import torch
 
-from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist
+from scilpy.io.utils import (add_verbose_arg, assert_inputs_exist,
+                             assert_outputs_exist)
 
 from dwi_ml.data.dataset.utils import prepare_multisubjectdataset
 from dwi_ml.experiment_utils.prints import format_dict_to_str
 from dwi_ml.experiment_utils.timer import Timer
-from dwi_ml.io_utils import add_memory_args, add_verbose_arg
-from dwi_ml.models.projects.transformer_models import \
-    OriginalTransformerModel, TransformerSrcAndTgtModel, TransformerSrcOnlyModel
+from dwi_ml.io_utils import add_memory_args
+from dwi_ml.models.projects.transformer_models import (
+    OriginalTransformerModel, TransformerSrcAndTgtModel,
+    TransformerSrcOnlyModel)
 from dwi_ml.models.projects.transformers_utils import (
     add_transformers_model_args)
 from dwi_ml.models.utils.direction_getters import check_args_direction_getter
@@ -32,8 +34,8 @@
                                                  prepare_batch_loader)
 from dwi_ml.training.utils.experiment import (
     add_mandatory_args_experiment_and_hdf5_path)
-from dwi_ml.training.utils.trainer import add_training_args, run_experiment, \
-    format_lr
+from dwi_ml.training.utils.trainer import (add_training_args, run_experiment,
+                                           format_lr)
 
 
 def prepare_arg_parser():