diff --git a/fasterai/_modidx.py b/fasterai/_modidx.py index 7e7d592..94d613c 100644 --- a/fasterai/_modidx.py +++ b/fasterai/_modidx.py @@ -198,6 +198,8 @@ 'fasterai/sparse/sparsifier.py'), 'fasterai.sparse.sparsifier.Sparsifier._save_weights': ( 'sparse.sparsifier.html#sparsifier._save_weights', 'fasterai/sparse/sparsifier.py'), + 'fasterai.sparse.sparsifier.Sparsifier.apply_nm_sparsity': ( 'sparse.sparsifier.html#sparsifier.apply_nm_sparsity', + 'fasterai/sparse/sparsifier.py'), 'fasterai.sparse.sparsifier.Sparsifier.print_sparsity': ( 'sparse.sparsifier.html#sparsifier.print_sparsity', 'fasterai/sparse/sparsifier.py'), 'fasterai.sparse.sparsifier.Sparsifier.save_model': ( 'sparse.sparsifier.html#sparsifier.save_model', diff --git a/fasterai/sparse/sparsifier.py b/fasterai/sparse/sparsifier.py index 50a6673..2212b9c 100644 --- a/fasterai/sparse/sparsifier.py +++ b/fasterai/sparse/sparsifier.py @@ -15,7 +15,8 @@ # %% ../../nbs/01_sparse.sparsifier.ipynb 5 class Sparsifier(): "Class providing sparsifying capabilities" - def __init__(self, model, granularity, context, criteria, layer_type=nn.Conv2d): + def __init__(self, model, granularity, context, criteria, nm=False, layer_type=nn.Conv2d): + if nm == True: print('Sparsity automatically set to 50%') store_attr() self._save_weights() # Save the original weights self._reset_threshold() @@ -116,6 +117,7 @@ def _compute_threshold(self, scores, sparsity, round_to): return self.threshold def _compute_mask(self, scores, threshold): + if self.nm == True: return self.apply_nm_sparsity(scores) if threshold > scores.max(): threshold = scores.max() # Make sure we don't remove every weight of a given layer return scores.ge(threshold).to(dtype=scores.dtype) @@ -123,3 +125,20 @@ def print_sparsity(self): for k,m in enumerate(self.model.modules()): if isinstance(m, self.layer_type): print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%") + + def apply_nm_sparsity(self, scores): + out_channels, in_channels, kernel_height, kernel_width = scores.shape + sparse_mask = torch.ones_like(scores) + if in_channels * kernel_height * kernel_width % 16 != 0: + print(f"Skipping 2:4 sparsity, Cin * Kh * Kw is not a multiple of 16") + return sparse_mask # Return weights unchanged if condition is not met + for out_ch in range(out_channels): + for h in range(kernel_height): + for w in range(kernel_width): + kernel_weights = scores[out_ch, :, h, w] + blocks = kernel_weights.view(-1, 4) # Flatten into blocks of 4 + _, indices = blocks.topk(2, dim=1, largest=True, sorted=False) # Retain top-2 absolute values in each block + mask = torch.zeros_like(blocks) + mask.scatter_(1, indices, 1) + sparse_mask[out_ch, :, h, w] = mask.view(-1) # Reshape and place the mask in the appropriate location + return sparse_mask diff --git a/fasterai/sparse/sparsify_callback.py b/fasterai/sparse/sparsify_callback.py index 8805c90..cca8a52 100644 --- a/fasterai/sparse/sparsify_callback.py +++ b/fasterai/sparse/sparsify_callback.py @@ -17,7 +17,7 @@ # %% ../../nbs/02_sparse.sparsify_callback.ipynb 4 class SparsifyCallback(Callback): "Sparsify model during training" - def __init__(self, sparsity, granularity, context, criteria, schedule, lth=False, rewind_epoch=0, reset_end=False, save_tickets=False, model=None, round_to=None, layer_type=nn.Conv2d): + def __init__(self, sparsity, granularity, context, criteria, schedule, lth=False, rewind_epoch=0, reset_end=False, save_tickets=False, model=None, round_to=None, nm=False, layer_type=nn.Conv2d): store_attr() self.sparsity = listify(self.sparsity) @@ -25,7 +25,7 @@ def before_fit(self): print(f'Pruning of {self.granularity} until a sparsity of {self.sparsity}%') assert self.schedule.start_pct*self.n_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process' model = self.model or self.learn.model - self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.layer_type) + self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.nm, self.layer_type) def before_epoch(self): if self.epoch == self.rewind_epoch: diff --git a/nbs/01_sparse.sparsifier.ipynb b/nbs/01_sparse.sparsifier.ipynb index 3afe38b..eff3161 100644 --- a/nbs/01_sparse.sparsifier.ipynb +++ b/nbs/01_sparse.sparsifier.ipynb @@ -30,8 +30,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/HubensN/miniconda3/envs/prune/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n", - " warn(f\"Failed to load image Python extension: {e}\")\n" + "/home/HubensN/miniconda3/envs/fasterai20/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], @@ -75,7 +75,8 @@ "#| export\n", "class Sparsifier():\n", " \"Class providing sparsifying capabilities\"\n", - " def __init__(self, model, granularity, context, criteria, layer_type=nn.Conv2d):\n", + " def __init__(self, model, granularity, context, criteria, nm=False, layer_type=nn.Conv2d):\n", + " if nm == True: print('Sparsity automatically set to 50%')\n", " store_attr()\n", " self._save_weights() # Save the original weights\n", " self._reset_threshold()\n", @@ -176,13 +177,31 @@ " return self.threshold\n", " \n", " def _compute_mask(self, scores, threshold):\n", + " if self.nm == True: return self.apply_nm_sparsity(scores)\n", " if threshold > scores.max(): threshold = scores.max() # Make sure we don't remove every weight of a given layer\n", " return scores.ge(threshold).to(dtype=scores.dtype)\n", " \n", " def print_sparsity(self):\n", " for k,m in enumerate(self.model.modules()):\n", " if isinstance(m, self.layer_type):\n", - " print(f\"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%\")" + " print(f\"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%\")\n", + "\n", + " def apply_nm_sparsity(self, scores):\n", + " out_channels, in_channels, kernel_height, kernel_width = scores.shape\n", + " sparse_mask = torch.ones_like(scores)\n", + " if in_channels * kernel_height * kernel_width % 16 != 0:\n", + " print(f\"Skipping 2:4 sparsity, Cin * Kh * Kw is not a multiple of 16\")\n", + " return sparse_mask # Return weights unchanged if condition is not met\n", + " for out_ch in range(out_channels):\n", + " for h in range(kernel_height):\n", + " for w in range(kernel_width):\n", + " kernel_weights = scores[out_ch, :, h, w]\n", + " blocks = kernel_weights.view(-1, 4) # Flatten into blocks of 4\n", + " _, indices = blocks.topk(2, dim=1, largest=True, sorted=False) # Retain top-2 absolute values in each block\n", + " mask = torch.zeros_like(blocks)\n", + " mask.scatter_(1, indices, 1)\n", + " sparse_mask[out_ch, :, h, w] = mask.view(-1) # Reshape and place the mask in the appropriate location\n", + " return sparse_mask" ] }, { @@ -195,22 +214,26 @@ "text/markdown": [ "---\n", "\n", + "[source](https://github.com/nathanhubens/fasterai/tree/master/blob/master/fasterai/sparse/sparsifier.py#L16){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### Sparsifier\n", "\n", - "> Sparsifier (model, granularity, context, criteria, layer_type= 'torch.nn.modules.conv.Conv2d'>)\n", + "> Sparsifier (model, granularity, context, criteria, nm=False,\n", + "> layer_type=)\n", "\n", - "Class providing sparsifying capabilities" + "*Class providing sparsifying capabilities*" ], "text/plain": [ "---\n", "\n", + "[source](https://github.com/nathanhubens/fasterai/tree/master/blob/master/fasterai/sparse/sparsifier.py#L16){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### Sparsifier\n", "\n", - "> Sparsifier (model, granularity, context, criteria, layer_type= 'torch.nn.modules.conv.Conv2d'>)\n", + "> Sparsifier (model, granularity, context, criteria, nm=False,\n", + "> layer_type=)\n", "\n", - "Class providing sparsifying capabilities" + "*Class providing sparsifying capabilities*" ] }, "execution_count": null, @@ -250,6 +273,8 @@ "text/markdown": [ "---\n", "\n", + "[source](https://github.com/nathanhubens/fasterai/tree/master/blob/master/fasterai/sparse/sparsifier.py#L23){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### Sparsifier.sparsify_layer\n", "\n", "> Sparsifier.sparsify_layer (m, sparsity, round_to=None)" @@ -257,6 +282,8 @@ "text/plain": [ "---\n", "\n", + "[source](https://github.com/nathanhubens/fasterai/tree/master/blob/master/fasterai/sparse/sparsifier.py#L23){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### Sparsifier.sparsify_layer\n", "\n", "> Sparsifier.sparsify_layer (m, sparsity, round_to=None)" @@ -288,6 +315,8 @@ "text/markdown": [ "---\n", "\n", + "[source](https://github.com/nathanhubens/fasterai/tree/master/blob/master/fasterai/sparse/sparsifier.py#L31){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### Sparsifier.sparsify_model\n", "\n", "> Sparsifier.sparsify_model (sparsity, round_to=None)" @@ -295,6 +324,8 @@ "text/plain": [ "---\n", "\n", + "[source](https://github.com/nathanhubens/fasterai/tree/master/blob/master/fasterai/sparse/sparsifier.py#L31){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### Sparsifier.sparsify_model\n", "\n", "> Sparsifier.sparsify_model (sparsity, round_to=None)" diff --git a/nbs/02_sparse.sparsify_callback.ipynb b/nbs/02_sparse.sparsify_callback.ipynb index aecfbde..5076c64 100644 --- a/nbs/02_sparse.sparsify_callback.ipynb +++ b/nbs/02_sparse.sparsify_callback.ipynb @@ -73,7 +73,7 @@ "#| export\n", "class SparsifyCallback(Callback):\n", " \"Sparsify model during training\"\n", - " def __init__(self, sparsity, granularity, context, criteria, schedule, lth=False, rewind_epoch=0, reset_end=False, save_tickets=False, model=None, round_to=None, layer_type=nn.Conv2d):\n", + " def __init__(self, sparsity, granularity, context, criteria, schedule, lth=False, rewind_epoch=0, reset_end=False, save_tickets=False, model=None, round_to=None, nm=False, layer_type=nn.Conv2d):\n", " store_attr()\n", " self.sparsity = listify(self.sparsity)\n", "\n", @@ -81,7 +81,7 @@ " print(f'Pruning of {self.granularity} until a sparsity of {self.sparsity}%')\n", " assert self.schedule.start_pct*self.n_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'\n", " model = self.model or self.learn.model\n", - " self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.layer_type)\n", + " self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.nm, self.layer_type)\n", "\n", " def before_epoch(self):\n", " if self.epoch == self.rewind_epoch:\n",