Skip to content

Commit

Permalink
Docstrings for simulai.models.Transformer
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
João Lucas de Sousa Almeida authored and João Lucas de Sousa Almeida committed Oct 26, 2023
1 parent a33fe17 commit 5c1d0f6
Showing 1 changed file with 130 additions and 10 deletions.
140 changes: 130 additions & 10 deletions simulai/models/_pytorch_models/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,30 @@
class BaseTemplate(NetworkTemplate):

def __init__(self) -> None:
    """Base class shared by the transformer-style encoder and decoder.

    Holds the helper methods (such as activation resolution) that the
    concrete encoder/decoder children reuse.
    """
    super(BaseTemplate, self).__init__()

def _activation_getter(self, activation: Union[str, torch.nn.Module]) -> torch.nn.Module:
"""
It configures the activation functions for the transformer layers.
Parameters
----------
activation : Union[str, torch.nn.Module]
Activation function to be used in all the network layers
Returns
-------
A Module object for this activation function.
Raises
------
Exception :
When the activation function is not supported.
"""

if isinstance(activation, torch.nn.Module):
return encoder_activation
Expand All @@ -23,11 +43,25 @@ def _activation_getter(self, activation: Union[str, torch.nn.Module]) -> torch.n

class BasicEncoder(BaseTemplate):

def __init__(self, num_heads=1,
def __init__(self, num_heads:int=1,
activation:Union[str, torch.nn.Module]='relu',
mlp_layer:torch.nn.Module=None,
embed_dim:Union[int, Tuple]=None,
):
) -> None:
"""
Generic transformer encoder.
Parameters
----------
num_heads : int
Number of attention heads for the self-attention layers.
activation : Union[str, torch.nn.Module]
Activation function to be used in all the network layers
mlp_layer : torch.nn.Module
A Module object representing the MLP (Dense) operation.
embed_dim : Union[int, Tuple]
Dimension used for the transformer embedding.
"""

super(BasicEncoder, self).__init__()

Expand All @@ -53,6 +87,18 @@ def __init__(self, num_heads=1,

def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None
) -> torch.Tensor:
"""
Parameters
----------
input_data : Union[torch.Tensor, np.ndarray]
The input dataset.
Returns
-------
torch.Tensor
The output generated by the encoder.
"""

h = input_data
h1 = self.activation_1(h)
Expand All @@ -68,6 +114,20 @@ def __init__(self, num_heads:int=1,
activation:Union[str, torch.nn.Module]='relu',
mlp_layer:torch.nn.Module=None,
embed_dim:Union[int, Tuple]=None):
"""
Generic transformer decoder.
Parameters
----------
num_heads : int
Number of attention heads for the self-attention layers.
activation : Union[str, torch.nn.Module]
Activation function to be used in all the network layers
mlp_layer : torch.nn.Module
A Module object representing the MLP (Dense) operation.
embed_dim : Union[int, Tuple]
Dimension used for the transformer embedding.
"""

super(BasicDecoder, self).__init__()

Expand All @@ -94,6 +154,20 @@ def __init__(self, num_heads:int=1,
def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None,
encoder_output:torch.Tensor=None,
) -> torch.Tensor:
"""
Parameters
----------
input_data : Union[torch.Tensor, np.ndarray]
The input dataset (in principle, the same input used for the encoder).
encoder_output : torch.Tensor
The output provided by the encoder stage.
Returns
-------
torch.Tensor
The decoder output.
"""

h = input_data
h1 = self.activation_1(h)
Expand All @@ -115,6 +189,37 @@ def __init__(self, num_heads_encoder:int=1,
decoder_mlp_layer_config:dict=None,
number_of_encoders:int=1,
number_of_decoders:int=1) -> None:
"""
A classical encoder-decoder transformer:
U -> ( Encoder_1 -> Encoder_2 -> ... -> Encoder_N ) -> u_e
(u_e, U) -> ( Decoder_1 -> Decoder_2 -> ... Decoder_N ) -> V
Parameters
----------
num_heads_encoder : int
The number of heads for the self-attention layer of the encoder.
num_heads_decoder : int
The number of heads for the self-attention layer of the decoder.
embed_dim_encoder : int
The dimension of the embedding for the encoder.
embed_dim_decoder : int
The dimension of the embedding for the decoder.
encoder_activation : Union[str, torch.nn.Module]
The activation to be used in all the encoder layers.
decoder_activation : Union[str, torch.nn.Module]
The activation to be used in all the decoder layers.
encoder_mlp_layer_config : dict
A configuration dictionary to instantiate the encoder MLP layer.
decoder_mlp_layer_config : dict
A configuration dictionary to instantiate the decoder MLP layer.
number_of_encoders : int
The number of encoders to be used.
number_of_decoders : int
The number of decoders to be used.
"""

super(Transformer, self).__init__()

Expand Down Expand Up @@ -165,7 +270,6 @@ def __init__(self, num_heads_encoder:int=1,
]



self.weights = list()

for e, encoder_e in enumerate(self.EncoderStage):
Expand All @@ -179,15 +283,31 @@ def __init__(self, num_heads_encoder:int=1,
@as_tensor
def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None) -> torch.Tensor:

encoder_output = self.EncoderStage(input_data)
"""
Parameters
----------
input_data : Union[torch.Tensor, np.ndarray]
The input dataset.
Returns
-------
torch.Tensor
The transformer output.
"""

encoder_output = self.EncoderStage(input_data)

current_input = input_data
for decoder in self.DecoderStage:
output = decoder(input_data=current_input, encoder_output=encoder_output)
current_input = output
current_input = input_data
for decoder in self.DecoderStage:
output = decoder(input_data=current_input, encoder_output=encoder_output)
current_input = output

return output
return output

def summary(self):
"""
It prints a general view of the architecture.
"""

print(self)
print(self)

0 comments on commit 5c1d0f6

Please sign in to comment.