diff --git a/mambavision/models/mamba_vision.py b/mambavision/models/mamba_vision.py
index 1523ee9..f397969 100644
--- a/mambavision/models/mamba_vision.py
+++ b/mambavision/models/mamba_vision.py
@@ -30,6 +30,14 @@
 
 
 def _cfg(url='', **kwargs):
+    """
+    Generates a configuration dictionary for model initialization.
+    Args:
+        url (str): URL for the pre-trained model weights.
+        **kwargs: Additional keyword arguments to customize the configuration.
+    Returns:
+        dict: Configuration parameters for the model.
+    """
     return {'url': url,
             'num_classes': 1000,
             'input_size': (3, 224, 224),
@@ -265,7 +273,8 @@ def forward(self, x):
 
 
 class ConvBlock(nn.Module):
-    def __init__(self, dim,
+    def __init__(self,
+                 dim,
                  drop_path=0.,
                  layer_scale=None,
                  kernel_size=3):
@@ -298,25 +307,24 @@ def forward(self, x):
 
 
 class MambaVisionMixer(nn.Module):
-    def __init__(
-        self,
-        d_model,
-        d_state=16,
-        d_conv=4,
-        expand=2,
-        dt_rank="auto",
-        dt_min=0.001,
-        dt_max=0.1,
-        dt_init="random",
-        dt_scale=1.0,
-        dt_init_floor=1e-4,
-        conv_bias=True,
-        bias=False,
-        use_fast_path=True,
-        layer_idx=None,
-        device=None,
-        dtype=None,
-    ):
+    def __init__(self,
+                 d_model,
+                 d_state=16,
+                 d_conv=4,
+                 expand=2,
+                 dt_rank="auto",
+                 dt_min=0.001,
+                 dt_max=0.1,
+                 dt_init="random",
+                 dt_scale=1.0,
+                 dt_init_floor=1e-4,
+                 conv_bias=True,
+                 bias=False,
+                 use_fast_path=True,
+                 layer_idx=None,
+                 device=None,
+                 dtype=None):
+
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.d_model = d_model
@@ -411,16 +419,15 @@ def forward(self, hidden_states):
 
 
 class Attention(nn.Module):
-    def __init__(
-            self,
-            dim,
-            num_heads=8,
-            qkv_bias=False,
-            qk_norm=False,
-            attn_drop=0.,
-            proj_drop=0.,
-            norm_layer=nn.LayerNorm,
-    ):
+    def __init__(self,
+                 dim,
+                 num_heads=8,
+                 qkv_bias=False,
+                 qk_norm=False,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 norm_layer=nn.LayerNorm):
+
         super().__init__()
         assert dim % num_heads == 0
         self.num_heads = num_heads
@@ -474,8 +481,8 @@ def __init__(self,
                  act_layer=nn.GELU,
                  norm_layer=nn.LayerNorm,
                  Mlp_block=Mlp,
-                 layer_scale=None,
-                 ):
+                 layer_scale=None):
+
         super().__init__()
         self.norm1 = norm_layer(dim)
         if counter in transformer_blocks:
@@ -529,8 +536,7 @@ def __init__(self,
                  drop_path=0.,
                  layer_scale=None,
                  layer_scale_conv=None,
-                 transformer_blocks = [],
-                 ):
+                 transformer_blocks = []):
         """
         Args:
             dim: feature size dimension.
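
Usage note (outside the diff): the newly documented _cfg helper just assembles a configuration dict, so a call like the sketch below should return the defaults visible in the first hunk. This is a minimal illustration, not code from the PR; the checkpoint URL and the crop_pct override are hypothetical, and the import path simply mirrors the file path in the diff header.

    # Minimal sketch of calling _cfg (hypothetical values marked below).
    from mambavision.models.mamba_vision import _cfg

    cfg = _cfg(url='https://example.com/mambavision_weights.pth.tar',  # hypothetical checkpoint URL
               crop_pct=0.875)                                         # illustrative extra kwarg

    assert cfg['num_classes'] == 1000           # default visible in the hunk above
    assert cfg['input_size'] == (3, 224, 224)   # default visible in the hunk above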