From ff64e344a1766a4b8fd7e6eecff31eab632c7b34 Mon Sep 17 00:00:00 2001
From: madsrehof <53825150+madsrehof@users.noreply.github.com>
Date: Thu, 19 Sep 2024 09:06:04 +0200
Subject: [PATCH 1/2] Update mamba_vision.py

Ensured consistent code formatting to match the style used across the repo.
No functional changes were made.
---
 mambavision/models/mamba_vision.py | 66 +++++++++++++++---------
 1 file changed, 32 insertions(+), 34 deletions(-)

diff --git a/mambavision/models/mamba_vision.py b/mambavision/models/mamba_vision.py
index 1523ee9..43164a8 100644
--- a/mambavision/models/mamba_vision.py
+++ b/mambavision/models/mamba_vision.py
@@ -265,7 +265,8 @@ def forward(self, x):
 
 class ConvBlock(nn.Module):
 
-    def __init__(self, dim,
+    def __init__(self,
+                 dim,
                  drop_path=0.,
                  layer_scale=None,
                  kernel_size=3):
@@ -298,25 +299,24 @@ def forward(self, x):
 
 
 class MambaVisionMixer(nn.Module):
-    def __init__(
-        self,
-        d_model,
-        d_state=16,
-        d_conv=4,
-        expand=2,
-        dt_rank="auto",
-        dt_min=0.001,
-        dt_max=0.1,
-        dt_init="random",
-        dt_scale=1.0,
-        dt_init_floor=1e-4,
-        conv_bias=True,
-        bias=False,
-        use_fast_path=True,
-        layer_idx=None,
-        device=None,
-        dtype=None,
-    ):
+    def __init__(self,
+                 d_model,
+                 d_state=16,
+                 d_conv=4,
+                 expand=2,
+                 dt_rank="auto",
+                 dt_min=0.001,
+                 dt_max=0.1,
+                 dt_init="random",
+                 dt_scale=1.0,
+                 dt_init_floor=1e-4,
+                 conv_bias=True,
+                 bias=False,
+                 use_fast_path=True,
+                 layer_idx=None,
+                 device=None,
+                 dtype=None):
+
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.d_model = d_model
@@ -411,16 +411,15 @@ def forward(self, hidden_states):
 
 class Attention(nn.Module):
 
-    def __init__(
-            self,
-            dim,
-            num_heads=8,
-            qkv_bias=False,
-            qk_norm=False,
-            attn_drop=0.,
-            proj_drop=0.,
-            norm_layer=nn.LayerNorm,
-    ):
+    def __init__(self,
+                 dim,
+                 num_heads=8,
+                 qkv_bias=False,
+                 qk_norm=False,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 norm_layer=nn.LayerNorm):
+
         super().__init__()
         assert dim % num_heads == 0
         self.num_heads = num_heads
@@ -474,8 +473,8 @@ def __init__(self,
                  act_layer=nn.GELU,
                  norm_layer=nn.LayerNorm,
                  Mlp_block=Mlp,
-                 layer_scale=None,
-                 ):
+                 layer_scale=None):
+
         super().__init__()
         self.norm1 = norm_layer(dim)
         if counter in transformer_blocks:
@@ -529,8 +528,7 @@ def __init__(self,
                  drop_path=0.,
                  layer_scale=None,
                  layer_scale_conv=None,
-                 transformer_blocks = [],
-                 ):
+                 transformer_blocks = []):
         """
         Args:
             dim: feature size dimension.

From 2c7801e1ff8b1ce9222edf823ccddcf2cf16f04f Mon Sep 17 00:00:00 2001
From: madsrehof <53825150+madsrehof@users.noreply.github.com>
Date: Thu, 19 Sep 2024 09:35:45 +0200
Subject: [PATCH 2/2] Update mamba_vision.py

Added docstrings to the _cfg function.
---
 mambavision/models/mamba_vision.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mambavision/models/mamba_vision.py b/mambavision/models/mamba_vision.py
index 43164a8..f397969 100644
--- a/mambavision/models/mamba_vision.py
+++ b/mambavision/models/mamba_vision.py
@@ -30,6 +30,14 @@
 
 
 def _cfg(url='', **kwargs):
+    """
+    Generates a configuration dictionary for model initialization.
+    Args:
+        url (str): URL for the pre-trained model weights.
+        **kwargs: Additional keyword arguments to customize the configuration.
+    Returns:
+        dict: Configuration parameters for the model.
+    """
     return {'url': url,
             'num_classes': 1000,
             'input_size': (3, 224, 224),