diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index 33ee5fbb4..92e3da2bf 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -31,110 +31,6 @@
 from brevitas.quant_tensor import _unpack_quant_tensor
 
 
-class gpfq_mode(gpxq_mode):
-    """
-    Apply GPFQ algorithm.
-
-    Args:
-        model (Module): The model to quantize with GPFQ
-        group_of_parallel_layers (Optional, List[str]): .List of lists where each inner list is a group
-            of layer names that can be optimized in parallel. Default: None
-        inplace (bool): Wheter to apply GPFQ inplace or perform a deepcopy. Default: True
-        create_weight_orig (bool): If True, store the original floating point weights before applying
-            gpfq. These weights will be used anytime quantization is disabled. Default: True
-        use_quant_activations (bool): Wheter to leave quantize activations enabled while performing
-            GPFQ. Default: False
-        p (float): The percentage of processed inputs to use. Default: 1.0
-        return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
-            forward call inside the context manager returns None. Default: False
-        act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False
-
-    Example:
-        >>> with torch.no_grad():
-        >>>     with gpfq_mode(model) as gpfq:
-        >>>         gpfq_model = gpfq.model
-        >>>         for i in tqdm(range(gpfq.num_layers)):
-        >>>             for img, t in calib_loader:
-        >>>                 img = img.cuda()
-        >>>                 gpfq_model(img)
-        >>>             gpfq.update()
-    """
-
-    def __init__(
-            self,
-            model: nn.Module,
-            group_of_parallel_layers: Optional[List[str]] = None,
-            inplace: bool = True,
-            create_weight_orig: bool = True,
-            use_quant_activations: bool = True,
-            p: float = 1.0,
-            return_forward_output: bool = False,
-            act_order: bool = False,
-            gpfq_class: Optional[nn.Module] = None) -> None:
-        if not inplace:
-            model = deepcopy(model)
-        super().__init__(
-            model,
-            group_of_parallel_layers,
-            inplace,
-            create_weight_orig,
-            use_quant_activations,
-            act_order,
-            return_forward_output)
-
-        self.p = p
-        if gpfq_class is None:
-            gpfq_class = GPFQ
-        self.gpfq_class = gpfq_class
-        assert issubclass(gpfq_class, GPxQ), \
-            "Error: expected `gpfq_class` to be derived from `brevitas.graph.gpxq.GPxQ`."
-
-    def catch_stopfwd(self, *args, **kwargs):
-        # Collect quant input
-        try:
-            self.orig_forward(*args, **kwargs)
-        except StopFwdException:
-            pass
-
-        # Disable quantization
-        self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
-        self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
-        self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
-        # Collect float input
-        try:
-            self.orig_forward(*args, **kwargs)
-        except StopFwdException:
-            pass
-
-        # Re-enable quantization. If activation quantization is disabled,
-        # we also disable bias quantization
-        self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
-        if self.use_quant_activations:
-            self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
-        else:
-            self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
-        restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
-
-        if self.return_forward_output:
-            # If we want to return the output of the network, we need to disable all hooks
-            for name, gpxq_class in self.gpxq_layers.items():
-                gpxq_class.disable_pre_forward_hook = True
-            out = self.orig_forward(*args, **kwargs)
-            for name, gpxq_class in self.gpxq_layers.items():
-                gpxq_class.disable_pre_forward_hook = False
-            return out
-
-    def initialize_module_optimizer(
-            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
-        return self.gpfq_class(
-            layer=layer,
-            name=name,
-            act_order=act_order,
-            len_parallel_layers=len_parallel_layers,
-            create_weight_orig=create_weight_orig,
-            p=self.p)
-
-
 class GPFQ(GPxQ):
     """
     Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
@@ -243,7 +139,11 @@ def single_layer_update(self):
         self.float_input = self.float_input.to(dev)
         self.quant_input = self.quant_input.to(dev)
         U = torch.zeros(
-            weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
+            weight.shape[0],
+            weight.shape[1],
+            self.float_input.shape[1],
+            device=dev,
+            dtype=torch.float32)
         # We don't need full Hessian, we just need the diagonal
         # Summing over batch dimension
         H_diag = self.quant_input.transpose(2, 1).square().sum(2)
@@ -261,7 +161,8 @@ def single_layer_update(self):
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 U[group_index] += torch.matmul(
-                    weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1),
+                    weight[group_index, :,
+                           permutation_list[group_index][t]].unsqueeze(1).to(torch.float32),
                     self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze(
                         0))  #[OC/Groups, 1] * [1, INSHAPE[1]]
                 norm = torch.linalg.norm(
@@ -272,11 +173,11 @@ def single_layer_update(self):
                 else:
                     q_arg = torch.zeros_like(U[group_index, :, 0])
 
-                weight[group_index, :, permutation_list[group_index][t]] = q_arg
+                weight[group_index, :, permutation_list[group_index][t]] = q_arg.to(dtype)
             q = self.get_quant_weights(t, 0, permutation_list)
             for group_index in range(self.groups):
                 U[group_index] -= torch.matmul(
-                    q[group_index].unsqueeze(1),
+                    q[group_index].unsqueeze(1).to(torch.float32),
                     self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0))
 
         del self.float_input
@@ -360,7 +261,7 @@ def update_batch(self, module, input, current_layer):
         # if quant is not enabled, then it is the float input; if it is a float input
         # then a quant input has already happened and we can update G
         if not is_quant_enabled:
-            # Computing the normalized H matrix using CPU buffer
+            # Computing the normalized G matrix using CPU buffer
             self.B.copy_(self.quant_input.bmm(inp_processed.transpose(2, 1)))
             self.G += self.B
             self.quant_input = None  # NOTE: set back to None now that we've used it
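For context, a rough sketch of what this hunk accumulates (the shapes and stand-in tensors below are assumptions inferred from the surrounding code, not part of the patch): `G` collects the cross-correlation between the quantized and the float inputs one calibration batch at a time, staged through the reusable buffer `B`.

```python
import torch

groups, columns, batch = 1, 64, 32
quant_input = torch.randn(groups, columns, batch)    # stand-in for self.quant_input
inp_processed = torch.randn(groups, columns, batch)  # stand-in for the float batch
G = torch.zeros(groups, columns, columns)            # running cross-correlation matrix
B = torch.empty_like(G)                              # reusable staging buffer (kept on CPU in the patch)

B.copy_(quant_input.bmm(inp_processed.transpose(2, 1)))  # [groups, columns, columns]
G += B
```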
@@ -401,6 +302,8 @@ def _get_permutation_list(self, weight: Tensor):
     def single_layer_update(self, percdamp: float = 0.01):
         assert not self.layer.weight_quant.requires_quant_input, \
             "Error: GPFQ does not support weight quantizers that require quantized inputs."
+        if hasattr(self.layer, "allocate_params"):
+            self.layer.allocate_params(self.layer)
         weight = self.layer.weight.data
         dev = weight.device
         dtype = weight.dtype
@@ -446,14 +349,17 @@ def single_layer_update(self, percdamp: float = 0.01):
         permutation_list = self._get_permutation_list(weight)
 
         U = torch.zeros(
-            weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev,
-            dtype=dtype)  # [Groups, OC/groups, Samples]
+            weight.shape[0],
+            weight.shape[1],
+            self.float_input.shape[1],
+            device=dev,
+            dtype=torch.float32)  # [Groups, OC/groups, Samples]
 
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 i = permutation_list[group_index][t]
                 U[group_index] += torch.matmul(
-                    weight[group_index, :, i].unsqueeze(1),
+                    weight[group_index, :, i].unsqueeze(1).to(torch.float32),
                     self.float_input[group_index, :, i].unsqueeze(0),
                 )  # [OC/Groups, 1] * [1, INSHAPE[1]]
                 norm = norms[group_index, i]
@@ -461,13 +367,116 @@ def single_layer_update(self, percdamp: float = 0.01):
                     q_arg = U[group_index].matmul(self.quant_input[group_index, :, i]) / norm
                 else:
                     q_arg = torch.zeros_like(U[group_index, :, 0])
-                weight[group_index, :, i] = q_arg
+                weight[group_index, :, i] = q_arg.to(dtype)
             q_groups = self.get_quant_weights(t, 0, permutation_list)
             for group_index in range(self.groups):
                 U[group_index] -= torch.matmul(
-                    q_groups[group_index].unsqueeze(1),
+                    q_groups[group_index].unsqueeze(1).to(torch.float32),
                     self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0),
                 )
-
+        if hasattr(self.layer, 'offload_params'):
+            self.layer.offload_params(self.layer)
         del self.float_input
         del self.quant_input
+
+
+class gpfq_mode(gpxq_mode):
+    """
+    Apply GPFQ algorithm.
+
+    Args:
+        model (Module): The model to quantize with GPFQ
+        group_of_parallel_layers (Optional, List[str]): List of lists where each inner list is a group
+            of layer names that can be optimized in parallel. Default: None
+        inplace (bool): Whether to apply GPFQ inplace or perform a deepcopy. Default: True
+        create_weight_orig (bool): If True, store the original floating point weights before applying
+            gpfq. These weights will be used anytime quantization is disabled. Default: True
+        use_quant_activations (bool): Whether to leave activation quantization enabled while performing
+            GPFQ. Default: True
+        p (float): The percentage of processed inputs to use. Default: 1.0
+        return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
+            forward call inside the context manager returns None. Default: False
+        act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False
+        gpfq_class (GPFQ): The uninitialized class used to perform GPFQ. Default: `brevitas.graph.gpfq.GPFQv2`,
+            which is the memory-efficient formulation
+
+    Example:
+        >>> with torch.no_grad():
+        >>>     with gpfq_mode(model) as gpfq:
+        >>>         gpfq_model = gpfq.model
+        >>>         for i in tqdm(range(gpfq.num_layers)):
+        >>>             for img, t in calib_loader:
+        >>>                 img = img.cuda()
+        >>>                 gpfq_model(img)
+        >>>             gpfq.update()
+    """
+
+    def __init__(
+            self,
+            model: nn.Module,
+            group_of_parallel_layers: Optional[List[str]] = None,
+            inplace: bool = True,
+            create_weight_orig: bool = True,
+            use_quant_activations: bool = True,
+            p: float = 1.0,
+            return_forward_output: bool = False,
+            act_order: bool = False,
+            gpfq_class: GPFQ = GPFQv2) -> None:
+        if not inplace:
+            model = deepcopy(model)
+        super().__init__(
+            model,
+            group_of_parallel_layers,
+            inplace,
+            create_weight_orig,
+            use_quant_activations,
+            act_order,
+            return_forward_output)
+
+        self.p = p
+        self.gpfq_class = gpfq_class
+
+    def catch_stopfwd(self, *args, **kwargs):
+        # Collect quant input
+        try:
+            self.orig_forward(*args, **kwargs)
+        except StopFwdException:
+            pass
+
+        # Disable quantization
+        self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
+        self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
+        self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
+        # Collect float input
+        try:
+            self.orig_forward(*args, **kwargs)
+        except StopFwdException:
+            pass
+
+        # Re-enable quantization. If activation quantization is disabled,
+        # we also disable bias quantization
+        self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
+        if self.use_quant_activations:
+            self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
+        else:
+            self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
+        restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
+
+        if self.return_forward_output:
+            # If we want to return the output of the network, we need to disable all hooks
+            for name, gpxq_class in self.gpxq_layers.items():
+                gpxq_class.disable_pre_forward_hook = True
+            out = self.orig_forward(*args, **kwargs)
+            for name, gpxq_class in self.gpxq_layers.items():
+                gpxq_class.disable_pre_forward_hook = False
+            return out
+
+    def initialize_module_optimizer(
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
+        return self.gpfq_class(
+            layer=layer,
+            name=name,
+            act_order=act_order,
+            len_parallel_layers=len_parallel_layers,
+            create_weight_orig=create_weight_orig,
+            p=self.p)
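The new `gpfq_class` hook is what lets callers swap in the accumulator-aware variant. A minimal usage sketch, mirroring how `ptq_common.py` wires it up later in this diff (`model` and `calib_loader` are assumed to exist, as in the docstring example above):

```python
from functools import partial

import torch

from brevitas.graph.gpfq import gpfq_mode
from brevitas_examples.common.axe import A2GPFQ

# Pre-bind the accumulator-aware options; gpfq_mode supplies the per-layer arguments.
gpfq_class = partial(A2GPFQ, max_accumulator_bit_width=16, max_accumulator_tile_size=128)

with torch.no_grad():
    with gpfq_mode(model, use_quant_activations=True, gpfq_class=gpfq_class) as gpfq:
        gpfq_model = gpfq.model
        for _ in range(gpfq.num_layers):
            for img, _ in calib_loader:
                gpfq_model(img)
            gpfq.update()
```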
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index a1380da4e..667e47d40 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -23,85 +23,6 @@
 import brevitas.nn as qnn
 
 
-class gptq_mode(gpxq_mode):
-    """
-    Apply GPTQ algorithm https://arxiv.org/abs/2210.17323.
-
-    Args:
-        model (Module): The model to quantize with GPTQ
-        group_of_parallel_layers (Optional, List[str]): .List of lists where each inner list is a group
-            of layer names that can be optimized in parallel. Default: None
-        inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True
-        create_weight_orig (bool): If True, store the original floating point weights before applying
-            gptq. These weights will be used anytime quantization is disabled. Default: True
-        use_quant_activations (bool): Wheter to leave quantize activations enabled while performing
-            GPTQ. Default: False
-        num_blocks (int): The number of sub-blocks to use to speed-up GPTQ computation. Default: 100
-        act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False
-        return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
-            forward call inside the context manager returns None. Default: False
-
-    Example:
-        >>> with torch.no_grad():
-        >>>     with gptq_mode(model) as gptq:
-        >>>         gptq_model = gptq.model
-        >>>         for i in tqdm(range(gptq.num_layers)):
-        >>>             for img, t in calib_loader:
-        >>>                 img = img.cuda()
-        >>>                 gptq_model(img)
-        >>>             gptq.update()
-    """
-
-    def __init__(
-            self,
-            model,
-            group_of_parallel_layers: Optional[List[str]] = None,
-            inplace: bool = True,
-            create_weight_orig: bool = True,
-            use_quant_activations: bool = True,
-            num_blocks: int = 100,
-            return_forward_output: bool = False,
-            act_order: bool = False) -> None:
-        if not inplace:
-            model = deepcopy(model)
-        super().__init__(
-            model,
-            group_of_parallel_layers,
-            inplace,
-            create_weight_orig,
-            use_quant_activations,
-            act_order,
-            return_forward_output)
-
-        # How many subblock to use during GPTQ for each layer
-        self.num_blocks = num_blocks
-
-    def catch_stopfwd(self, *args, **kwargs):
-        try:
-            self.orig_forward(*args, **kwargs)
-        except StopFwdException:
-            pass
-        finally:
-            if self.return_forward_output:
-                # If we want to return the output of the network, we need to disable all hooks
-                for name, gpxq_class in self.gpxq_layers.items():
-                    gpxq_class.disable_pre_forward_hook = True
-                out = self.orig_forward(*args, **kwargs)
-                for name, gpxq_class in self.gpxq_layers.items():
-                    gpxq_class.disable_pre_forward_hook = False
-                return out
-
-    def initialize_module_optimizer(
-            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
-        return GPTQ(
-            layer=layer,
-            name=name,
-            act_order=act_order,
-            len_parallel_layers=len_parallel_layers,
-            create_weight_orig=create_weight_orig,
-            num_blocks=self.num_blocks)
-
-
 class GPTQ(GPxQ):
     """
     Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE:
@@ -275,7 +196,7 @@ def single_layer_update(self, percdamp=.01):
                 q_groups = self.get_quant_weights(i, i1, permutation_list)  # [groups, OC/groups]
                 for group_index in range(self.groups):
                     perm = permutation_list[group_index]
-                    q = q_groups[group_index]  # [OC/groups]
+                    q = q_groups[group_index].to(torch.float32)  # [OC/groups]
                     w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32)  # [OC/groups]
                     d = h_inv_block[group_index, i, i]  # [1]
                     error = (w - q) / d  # [OC/groups]
@@ -292,3 +213,85 @@ def single_layer_update(self, percdamp=.01):
                                                           i2:].to(dev))).to(dtype)
         if hasattr(self.layer, 'offload_params'):
             self.layer.offload_params(self.layer)
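A condensed sketch of the mixed-precision error feedback this change touches (the shapes, the identity stand-in for the inverse-Hessian block, and the rounding stand-in for the weight quantizer are all placeholders): the per-column error is computed in float32 and the correction is cast back to the weight dtype only when written.

```python
import torch

oc, block = 8, 4
dtype = torch.float16
weight = torch.randn(oc, block, dtype=dtype)
h_inv_block = torch.eye(block)               # placeholder for the inverse-Hessian Cholesky block

for i in range(block):
    w = weight[:, i].to(torch.float32)
    q = torch.round(w)                       # placeholder for the Brevitas weight quantizer
    d = h_inv_block[i, i]
    error = (w - q) / d                      # error feedback, kept in float32
    weight[:, i:] -= (error.unsqueeze(1) @ h_inv_block[i, i:].unsqueeze(0)).to(dtype)
```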
+
+
+class gptq_mode(gpxq_mode):
+    """
+    Apply GPTQ algorithm https://arxiv.org/abs/2210.17323.
+
+    Args:
+        model (Module): The model to quantize with GPTQ
+        group_of_parallel_layers (Optional, List[str]): List of lists where each inner list is a group
+            of layer names that can be optimized in parallel. Default: None
+        inplace (bool): Whether to apply GPTQ inplace or perform a deepcopy. Default: True
+        create_weight_orig (bool): If True, store the original floating point weights before applying
+            gptq. These weights will be used anytime quantization is disabled. Default: True
+        use_quant_activations (bool): Whether to leave activation quantization enabled while performing
+            GPTQ. Default: True
+        num_blocks (int): The number of sub-blocks to use to speed-up GPTQ computation. Default: 100
+        act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False
+        return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
+            forward call inside the context manager returns None. Default: False
+        gptq_class (GPTQ): The uninitialized class used to perform GPTQ. Default: `brevitas.graph.gptq.GPTQ`
+
+    Example:
+        >>> with torch.no_grad():
+        >>>     with gptq_mode(model) as gptq:
+        >>>         gptq_model = gptq.model
+        >>>         for i in tqdm(range(gptq.num_layers)):
+        >>>             for img, t in calib_loader:
+        >>>                 img = img.cuda()
+        >>>                 gptq_model(img)
+        >>>             gptq.update()
+    """
+
+    def __init__(
+            self,
+            model,
+            group_of_parallel_layers: Optional[List[str]] = None,
+            inplace: bool = True,
+            create_weight_orig: bool = True,
+            use_quant_activations: bool = True,
+            num_blocks: int = 100,
+            return_forward_output: bool = False,
+            act_order: bool = False,
+            gptq_class: GPTQ = GPTQ) -> None:
+        if not inplace:
+            model = deepcopy(model)
+        super().__init__(
+            model,
+            group_of_parallel_layers,
+            inplace,
+            create_weight_orig,
+            use_quant_activations,
+            act_order,
+            return_forward_output)
+
+        # How many sub-blocks to use during GPTQ for each layer
+        self.num_blocks = num_blocks
+        self.gptq_class = gptq_class
+
+    def catch_stopfwd(self, *args, **kwargs):
+        try:
+            self.orig_forward(*args, **kwargs)
+        except StopFwdException:
+            pass
+        finally:
+            if self.return_forward_output:
+                # If we want to return the output of the network, we need to disable all hooks
+                for name, gpxq_class in self.gpxq_layers.items():
+                    gpxq_class.disable_pre_forward_hook = True
+                out = self.orig_forward(*args, **kwargs)
+                for name, gpxq_class in self.gpxq_layers.items():
+                    gpxq_class.disable_pre_forward_hook = False
+                return out
+
+    def initialize_module_optimizer(
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
+        return self.gptq_class(
+            layer=layer,
+            name=name,
+            act_order=act_order,
+            len_parallel_layers=len_parallel_layers,
+            create_weight_orig=create_weight_orig,
+            num_blocks=self.num_blocks)
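As with GPFQ, the new `gptq_class` argument accepts any `GPTQ` subclass, typically pre-bound with `functools.partial`; a short sketch (again with `model` assumed, and the same calibration loop as the docstring example):

```python
from functools import partial

from brevitas.graph.gptq import gptq_mode
from brevitas_examples.common.axe import A2GPTQ

# gptq_mode fills in layer, name, act_order, etc. via initialize_module_optimizer.
gptq_class = partial(A2GPTQ, max_accumulator_bit_width=16, max_accumulator_tile_size=128)

with gptq_mode(model, use_quant_activations=True, gptq_class=gptq_class) as gptq:
    ...  # same calibration loop as in the docstring example above
```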
diff --git a/src/brevitas_examples/common/axe.py b/src/brevitas_examples/common/axe.py
new file mode 100644
index 000000000..39e22535d
--- /dev/null
+++ b/src/brevitas_examples/common/axe.py
@@ -0,0 +1,436 @@
+# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+import math
+import warnings
+
+import numpy as np
+import torch
+from torch import Tensor
+
+try:
+    from torch.linalg import LinAlgError
+except ImportError:
+    LinAlgError = RuntimeError
+
+from brevitas.graph.gpfq import GPFQv2
+from brevitas.graph.gptq import GPTQ
+from brevitas.graph.gpxq import SUPPORTED_CONV_OP
+from brevitas.graph.gpxq import SUPPORTED_TCONV_OP
+
+
+def _get_average_of_nonzero_magnitudes(vec: np.ndarray, radius: float = 1.0):
+    assert radius > 0, "Error: radius needs to be strictly positive."
+    assert vec.ndim == 1, "Error: projection assumes a vector, not a matrix."
+    assert vec.min() >= 0, "Error: assuming a vector of non-negative numbers."
+    n_elems = vec.shape[0]
+    # if we are already within the simplex, then the best projection is itself
+    if vec.sum() <= radius:
+        return 0.0
+    # using algorithm detailed in "Efficient Projections onto the L1-Ball for Learning in High Dimensions"
+    v = vec
+    u = np.sort(v)[::-1]
+    cumsum_u = np.cumsum(u)
+    rho = np.nonzero(u * np.arange(1, n_elems + 1) > (cumsum_u - radius))[0][-1]
+    theta = float(cumsum_u[rho] - radius) / (rho + 1)
+    return theta
+
+
+def calc_average_nonzero_mag(weight: Tensor, lim: Tensor) -> Tensor:
+    thetas = torch.zeros(weight.shape[0], device=weight.device)
+    for i in range(weight.shape[0]):
+        l = lim[i].item() if lim.ndim > 0 else lim.item()
+        w = weight[i].cpu().detach().numpy()
+        t = _get_average_of_nonzero_magnitudes(np.abs(w), l)
+        thetas[i] = t
+    return thetas
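A small numeric check of the projection helpers above (example values only): soft-thresholding a non-negative vector by the returned `theta` lands it exactly on the l1-ball of the requested radius.

```python
import numpy as np

v = np.array([0.8, 0.6, 0.4, 0.2])                         # l1-norm is 2.0
theta = _get_average_of_nonzero_magnitudes(v, radius=1.0)
w = np.maximum(v - theta, 0.0)                             # soft-threshold by theta
print(round(theta, 4), round(w.sum(), 4))                  # 0.2667 1.0
```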
+
+
+def pad_tensor_with_zeros(tensor: Tensor, tile_size: int) -> Tensor:
+    pad_size = tile_size - (tensor.shape[1] % tile_size)
+    if pad_size == tile_size:
+        return tensor
+    padding = torch.zeros((tensor.shape[0], pad_size), device=tensor.device)
+    pad_tensor = torch.concat([tensor, padding], axis=1)
+    return pad_tensor
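For instance, a 100-column weight matrix padded to a 32-element tile size reshapes cleanly into accumulator tiles (the shapes below are arbitrary):

```python
import torch

w = torch.randn(64, 100)                      # [OC, IC]
tile_size = 32
w_pad = pad_tensor_with_zeros(w, tile_size)   # [64, 128]: 28 zero columns appended
tiles = w_pad.view(-1, tile_size)             # [OC * n_tiles, tile_size] = [256, 32]
```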
+
+
+class A2GPTQ(GPTQ):
+    """
+    Accumulator-aware GPTQ as proposed in https://arxiv.org/pdf/2409.17092
+    """
+
+    def __init__(
+            self,
+            layer,
+            name,
+            act_order,
+            len_parallel_layers,
+            create_weight_orig,
+            num_blocks,
+            max_accumulator_bit_width,
+            max_accumulator_tile_size) -> None:
+        super().__init__(
+            layer, name, act_order, len_parallel_layers, create_weight_orig, num_blocks)
+        self.max_accumulator_bit_width = max_accumulator_bit_width
+        self.max_accumulator_tile_size = max_accumulator_tile_size
+        if self.max_accumulator_tile_size is None:
+            self.max_accumulator_tile_size = self.columns
+        assert self.max_accumulator_tile_size > 2, "Error: accumulator tile size needs to be bigger than 2."
+        assert self.max_accumulator_bit_width > 2, "Error: accumulator bit width needs to be bigger than 2."
+
+    def single_layer_update(self, percdamp=0.01):
+        assert not self.layer.weight_quant.requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs."
+        if self.quant_metadata is None:
+            raise ValueError(
+                "Expected self.quant_metadata to calculate accumualtor bounds, but recevied None. "
+                "Make sure that either the input to the model is an IntQuantTensor or the layer has an input quant enabled. "
+                "Also, check if `use_quant_activations=True` in `gptq_mode` when `max_accumulator_bit_width` is specified. "
+            )
+        if hasattr(self.layer, "allocate_params"):
+            self.layer.allocate_params(self.layer)
+        weight = self.layer.weight.data
+        dev = weight.device
+
+        # Store the original dtype of the weights
+        # During computation, everything is converted to float32.
+        # When the weights are updated, we cast everything back to the original dtype
+        dtype = weight.dtype
+
+        if isinstance(self.layer, SUPPORTED_CONV_OP):
+            if isinstance(self.layer, SUPPORTED_TCONV_OP):
+                weight = weight.transpose(1, 0)  # This performs a view
+            weight = weight.flatten(1)
+
+        # TODO: add support for signed input activations
+        if self.quant_metadata.signed:
+            raise NotImplementedError("Signed inputs not yet supported.")
+
+        # TODO: currently assuming round-to-zero; need to handle other rounding functions
+        rounding_mode = self.layer.weight_quant.rounding_mode
+        if rounding_mode.lower() != "round":
+            raise NotImplementedError(f"{rounding_mode} not yet supported.")
+
+        n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size)
+        scales: Tensor = self.layer.weight_quant.scale()
+        if isinstance(self.layer, SUPPORTED_CONV_OP):
+            if isinstance(self.layer, SUPPORTED_TCONV_OP):
+                scales = scales.transpose(1, 0)  # This performs a view
+            scales = scales.flatten(1)
+        P = torch.tensor(self.max_accumulator_bit_width)
+        N = self.quant_metadata.bit_width
+        # NOTE: using sign-magnitude here, which is sufficient to support both
+        # sign-magnitude and 2s complement accumulators
+        self.upper_lim = (pow(2, P - 1) - 1) / float(pow(2, N) - 1)  # A
+        self.lower_lim = -self.upper_lim  # B
+        Z = (pow(2, P) - 2) / float(pow(2, N) - 1)  # l1-norm lim for zero-centered weight vector
+        # translating into the quantized range; need to pad to get these thresholds
+        wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view(
+            -1, self.max_accumulator_tile_size)  # [OC * Tiles, IC / Tiles]
+        thresholds = calc_average_nonzero_mag(
+            wT - wT.mean(axis=1, keepdim=True), Z)  # [Groups * OC * Tiles]
+        thresholds = thresholds.view(self.groups, -1,
+                                     n_tiles).transpose(1, 2)  # [Groups, Tiles, OC/Groups]
+        del wT
+        # supporting groupwise quantization where each tile has its own scaling factor
+        if self.layer.weight_quant.is_groupwise:
+            scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(
+                -1, self.max_accumulator_tile_size)  # [Groups, OC * Tiles, IC / Tiles]
+            scales = scales[:, 0]  # [Groups * OC * Tiles, 1]
+            scales = scales.view(self.groups, -1,
+                                 n_tiles).transpose(1, 2)  # [Groups, Tiles, OC/Groups]
+        # else each tile has the same scaling factor (per-tensor or per-channel)
+        else:
+            scales = scales.view(self.groups, 1, -1)  # [Groups, 1, OC/Groups]
+            scales = scales.repeat(1, n_tiles, 1)  # [Groups, Tiles, OC/Groups]
+        thresholds *= scales  # translating centers back to the float range
+        weight = weight.view(self.groups, -1, weight.shape[-1])  # [Groups, OC/Groups, IC]
+
+        # List with permutation tensors for the Hessian and weight matrix.
+        # If act_order is False, the tensors will be ordered indexes.
+        # For groupwise convolution, we have one tensor per group,
+        # thus len(permutation_list) is always equal to self.groups.
+        # We do not explicitly permute the weight matrix, only the Hessian.
+        permutation_list = []
+        weight = weight.view(self.groups, -1, weight.shape[-1])
+        # For groupwise convolution, these operations are groupwise so we iterate
+        for i in range(self.groups):
+            # If a diagonal element on the Hessian is zero, we can set to 0 the corresponding
+            # column in the weight matrix.
+            # The diagonal element is set to 1 to avoid division-by-zero
+            dead = torch.diag(self.H[i, :, :]) == 0
+            self.H[i, dead, dead] = 1
+            # If the diagonal of activations is zero, we set the weight to zero
+            weight[i, :, dead] = 0
+            if self.act_order:
+                # Re-order Hessian so that weights associated to
+                # higher magnitude activations are quantized first
+                perm = torch.argsort(torch.diag(self.H[i, :, :]), descending=True)
+                self.H[i, :, :] = self.H[i, perm, :][:, perm]
+            else:
+                # No permutation, permutation tensor is an ordered index
+                perm = torch.tensor(range(self.H.shape[-1]), device=dev)
+            permutation_list.append(perm)
+
+        # Try/Except in case the inverse Hessian cannot be computed
+        try:
+            for i in range(self.groups):
+                damp = percdamp * torch.mean(torch.diag(self.H[i, :, :]))
+                diag = torch.arange(self.columns, device='cpu')
+                self.H[i, diag, diag] += damp
+                self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :])
+                self.H[i, :, :] = torch.cholesky_inverse(self.H[i, :, :])
+                self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :], upper=True)
+            h_inv = self.H
+        except LinAlgError:
+            warnings.warn(
+                f'Failed to compute the inverse of the Hessian for layer {self.name} '
+                f'GPTQ will not be applied. '
+                f'Increasing the number of samples might fix this issue')
+            return
+        finally:
+            del self.H, self.B
+
+        # initialize cumulative l1-norm
+        a = torch.zeros_like(thresholds, device=dev)  # positive limits
+        b = torch.zeros_like(thresholds, device=dev)  # negative limits
+
+        for i1 in range(0, self.columns, self.blocksize):
+            i2 = min(i1 + self.blocksize, self.columns)
+            count = i2 - i1
+            error_block = torch.zeros_like(
+                weight[:, :, permutation_list[-1][i1:i2]],
+                dtype=torch.float32)  # [groups, OC/groups, i2-i1]
+
+            h_inv_block = h_inv[:, i1:i2, i1:i2]
+            for i in range(count):
+                # need to apply soft thresholding and clamping before quantization
+                for group_index in range(self.groups):
+                    perm = permutation_list[group_index]
+                    bx = perm[i1:i2][i] // self.max_accumulator_tile_size  # block index
+                    # calculate the q_max and q_min for the right group and right block
+                    q_max = scales[group_index, bx, :] * torch.clamp_min(
+                        self.upper_lim - a[group_index, bx, :] - 0.5, 0.0)  # [OC/groups]
+                    q_min = scales[group_index, bx, :] * torch.clamp_max(
+                        self.lower_lim - b[group_index, bx, :] + 0.5, 0.0)  # [OC/groups]
+                    q_arg = weight[group_index, :, perm[i1:i2][i]]  # [OC/groups]
+                    # soft thresholding then clamping
+                    q_arg = q_arg.sign() * torch.relu(
+                        q_arg.abs() - thresholds[group_index, bx])  # [OC/groups]
+                    q_arg.clamp_(q_min, q_max)  # clamping to bounds
+                    weight[group_index, :, perm[i1:i2][i]] = q_arg
+                q_groups = self.get_quant_weights(i, i1, permutation_list)  # [Groups, OC/groups]
+                for group_index in range(self.groups):
+                    perm = permutation_list[group_index]
+                    q = q_groups[group_index].to(torch.float32)  # [OC/groups]
+                    w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32)  # [OC/groups]
+                    d = h_inv_block[group_index, i, i]  # [1]
+                    error = (w - q) / d  # [OC/groups]
+                    error_block[group_index, :, i] = error
+                    # We need to update the original weights
+                    weight[group_index, :, perm[i1:i2][i:]] -= (
+                        error.unsqueeze(1).matmul(
+                            h_inv_block[group_index, i, i:].unsqueeze(0).to(dev))).to(dtype)
+                # update the tracking mechanisms
+                for group_index in range(self.groups):
+                    perm = permutation_list[group_index]
+                    bx = perm[i1:i2][i] // self.max_accumulator_tile_size  # block index
+                    q = q_groups[group_index] / scales[group_index, bx]  # [OC/groups]
+                    # increment cumulative l1-norm
+                    a[group_index, bx, q >= 0] += q[q >= 0]
+                    b[group_index, bx, q <= 0] += q[q <= 0]
+                    assert (a <= self.upper_lim).all() and (a >= 0).all()
+                    assert (b >= self.lower_lim).all() and (b <= 0).all()
+
+            for group_index in range(self.groups):
+                perm = permutation_list[group_index]
+                weight[group_index, :, perm[i2:]] -= (
+                    error_block[group_index].matmul(h_inv[group_index, i1:i2,
+                                                          i2:].to(dev))).to(dtype)
+        if hasattr(self.layer, "offload_params"):
+            self.layer.offload_params(self.layer)
+
+        del thresholds, scales  # memory management
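To make the accumulator bounds used above concrete, here is the arithmetic for a 16-bit accumulator with unsigned 8-bit inputs (a worked example, not part of the patch): `A` caps the positive and negative partial sums of the integer weights in each tile, and `Z` is the l1-norm budget from which the soft thresholds are derived.

```python
P, N = 16, 8                            # accumulator and input-activation bit widths
A = (2 ** (P - 1) - 1) / (2 ** N - 1)   # upper_lim ≈ 128.498; lower_lim = -A
Z = (2 ** P - 2) / (2 ** N - 1)         # per-tile l1-norm budget ≈ 256.996

# With inputs in [0, 2**N - 1], keeping the positive sum of integer weights below A
# bounds the dot product by (2**N - 1) * A = 2**(P - 1) - 1, the largest value a
# signed P-bit accumulator can hold (and symmetrically for the negative sum).
print(A, Z)
```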
+
+
+class A2GPFQ(GPFQv2):
+    """
+    Memory-efficient, accumulator-aware GPFQ as proposed in https://arxiv.org/pdf/2409.17092
+    """
+
+    def __init__(
+            self,
+            layer,
+            name,
+            act_order,
+            len_parallel_layers,
+            create_weight_orig,
+            p,
+            max_accumulator_bit_width,
+            max_accumulator_tile_size) -> None:
+        super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig, p)
+        self.max_accumulator_bit_width = max_accumulator_bit_width
+        self.max_accumulator_tile_size = max_accumulator_tile_size
+        if self.max_accumulator_tile_size is None:
+            self.max_accumulator_tile_size = self.columns
+        assert self.max_accumulator_tile_size > 2, "Error: accumulator tile size needs to be bigger than 2."
+        assert self.max_accumulator_bit_width > 2, "Error: accumulator bit width needs to be bigger than 2."
+
+    def single_layer_update(self, percdamp=0.01):
+        assert not self.layer.weight_quant.requires_quant_input, \
+            "Error: GPFQ does not support weight quantizers that require quantized inputs."
+        if self.quant_metadata is None:
+            raise ValueError(
+                "Expected self.quant_metadata to calculate accumualtor bounds, but recevied None. "
+                "Make sure that either the input to the model is an IntQuantTensor or the layer has an input quant enabled. "
+                "Also, check if `use_quant_activations=True` in `gpfq_mode` when `max_accumulator_bit_width` is specified. "
+            )
+        if hasattr(self.layer, "allocate_params"):
+            self.layer.allocate_params(self.layer)
+        weight: Tensor = self.layer.weight.data
+        dev = weight.device
+
+        # Store the original dtype of the weights
+        # During computation, everything is converted to float32.
+        # When the weights are updated, we cast everything back to the original dtype
+        dtype = weight.dtype
+
+        if isinstance(self.layer, SUPPORTED_CONV_OP):
+            if isinstance(self.layer, SUPPORTED_TCONV_OP):
+                weight = weight.transpose(1, 0)  # This performs a view
+            weight = weight.flatten(1)
+
+        # TODO: add support for signed input activations
+        if self.quant_metadata.signed:
+            raise NotImplementedError("Signed inputs not yet supported.")
+
+        # TODO: currently assuming round-to-zero; need to handle other rounding functions
+        rounding_mode = self.layer.weight_quant.rounding_mode
+        if rounding_mode.lower() != "round":
+            raise NotImplementedError(f"{rounding_mode} not yet supported.")
+
+        n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size)
+        scales: Tensor = self.layer.weight_quant.scale()
+        if isinstance(self.layer, SUPPORTED_CONV_OP):
+            if isinstance(self.layer, SUPPORTED_TCONV_OP):
+                scales = scales.transpose(1, 0)  # This performs a view
+            scales = scales.flatten(1)
+        P = torch.tensor(self.max_accumulator_bit_width)
+        N = self.quant_metadata.bit_width
+        # NOTE: using sign-magnitude here, which is sufficient to support both
+        # sign-magnitude and 2s complement accumulators
+        self.upper_lim = (pow(2, P - 1) - 1) / float(pow(2, N) - 1)  # A
+        self.lower_lim = -self.upper_lim  # B
+        Z = (pow(2, P) - 2) / float(pow(2, N) - 1)  # l1-norm lim for zero-centered weight vector
+        # translating into the quantized range; need to pad to get these thresholds
+        wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view(
+            -1, self.max_accumulator_tile_size)  # [OC * Tiles, IC / Tiles]
+        thresholds = calc_average_nonzero_mag(
+            wT - wT.mean(axis=1, keepdim=True), Z)  # [Groups * OC * Tiles]
+        thresholds = thresholds.view(self.groups, -1,
+                                     n_tiles).transpose(1, 2)  # [Groups, Tiles, OC/Groups]
+        del wT
+        # supporting groupwise quantization where each tile has its own scaling factor
+        if self.layer.weight_quant.is_groupwise:
+            scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(
+                -1, self.max_accumulator_tile_size)  # [Groups, OC * Tiles, IC / Tiles]
+            scales = scales[:, 0]  # [Groups * OC * Tiles, 1]
+            scales = scales.view(self.groups, -1,
+                                 n_tiles).transpose(1, 2)  # [Groups, Tiles, OC/Groups]
+        # else each tile has the same scaling factor (per-tensor or per-channel)
+        else:
+            scales = scales.view(self.groups, 1, -1)  # [Groups, 1, OC/Groups]
+            scales = scales.repeat(1, n_tiles, 1)  # [Groups, Tiles, OC/Groups]
+        thresholds *= scales  # translating centers back to the float range
+
+        weight = weight.view(self.groups, -1, weight.shape[-1])  # [Groups, OC/Groups, IC]
+
+        # initialize cumulative l1-norm
+        a = torch.zeros_like(thresholds, device=dev)  # positive limit
+        b = torch.zeros_like(thresholds, device=dev)  # negative limit
+
+        # Try/Except in case the square root of H cannot be computed
+        try:
+            norms = torch.zeros((self.groups, self.columns), device=dev, dtype=torch.float32)
+            self.H = self.H.to(dev)
+            diag = torch.arange(self.columns, device='cpu')
+            for i in range(self.groups):
+                # stabilize H with a dampening factor and then square root the matrix
+                damp = percdamp * self.H[i].diag().mean()
+                self.H[i, diag, diag] += damp
+                norms[i] = self.H[i].diag()  # set the norms post-dampening
+                eigvals, eigvecs = torch.linalg.eigh(self.H[i])
+                eigvals.clamp_min_(0.0).sqrt_()  # should be positive-definite
+                self.H[i] = eigvecs @ torch.diag(eigvals) @ eigvecs.t()
+            del eigvecs, eigvals, diag
+            self.quant_input = self.H  # NOTE: do this here for the `_get_permutation_list` function
+        except LinAlgError:
+            warnings.warn(
+                f'Failed to compute the matrix square root of H for layer {self.name} '
+                f'GPFQ will not be applied. '
+                f'Increasing the number of samples might fix this issue')
+            return
+
+        # Try/Except in case the inverse of H cannot be computed
+        try:
+            self.float_input = self.H.clone()  # going to calculate H^{-1} here
+            for i in range(self.groups):
+                # from our matrix sqrt, we know G is symmetric and positive-definite, so we
+                # can use Cholesky decomposition as an efficient, numerically stable inverse
+                L = torch.linalg.cholesky(self.float_input[i])
+                self.float_input[i] = torch.cholesky_inverse(L)
+            self.float_input = torch.bmm(self.float_input.to(dev), self.G.to(dev))
+            del L  # memory management
+        except LinAlgError:
+            warnings.warn(
+                f'Failed to compute the inverse of H for layer {self.name} '
+                f'GPFQ will not be applied. '
+                f'Increasing the number of samples might fix this issue')
+            return
+        finally:
+            del self.H, self.G, self.B  # memory management
+
+        permutation_list = self._get_permutation_list(weight)
+
+        U = torch.zeros(
+            weight.shape[0],
+            weight.shape[1],
+            self.float_input.shape[1],
+            device=dev,
+            dtype=torch.float32)  # [Groups, OC/groups, Samples]
+
+        for t in range(weight.shape[-1]):
+            for group_index in range(self.groups):
+                i = permutation_list[group_index][t]
+                U[group_index] += torch.matmul(
+                    weight[group_index, :, i].unsqueeze(1).to(torch.float32),
+                    self.float_input[group_index, :, i].unsqueeze(0))
+                norm = norms[group_index, i]
+                if norm > 0:
+                    q_arg = U[group_index].matmul(self.quant_input[group_index, :, i]) / norm
+                else:
+                    q_arg = torch.zeros_like(U[group_index, :, 0])
+                bx = i // self.max_accumulator_tile_size  # block index
+                q_arg = q_arg.sign() * torch.relu(
+                    q_arg.abs() - thresholds[group_index, bx, :])  # soft thresholding
+                q_max = scales[group_index, bx] * torch.clamp_min(
+                    self.upper_lim - a[group_index, bx, :] - 0.5, 0.0)
+                q_min = scales[group_index, bx] * torch.clamp_max(
+                    self.lower_lim - b[group_index, bx, :] + 0.5, 0.0)
+                q_arg.clamp_(q_min, q_max)
+                weight[group_index, :, i] = q_arg.to(dtype)
+            q_groups: Tensor = self.get_quant_weights(t, 0, permutation_list)
+            for group_index in range(self.groups):
+                i = permutation_list[group_index][t]
+                U[group_index] -= torch.matmul(
+                    q_groups[group_index].unsqueeze(1).to(torch.float32),
+                    self.quant_input[group_index, :, i].unsqueeze(0))
+                bx = i // self.max_accumulator_tile_size  # block index
+                q = q_groups[group_index] / scales[group_index, bx]  # [OC/groups]
+                # increment cumulative l1-norm
+                a[group_index, bx, q >= 0] += q[q >= 0]
+                b[group_index, bx, q <= 0] += q[q <= 0]
+                assert (a <= self.upper_lim).all() and (a >= 0).all()
+                assert (b >= self.lower_lim).all() and (b <= 0).all()
+
+        del self.quant_input, self.float_input
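An illustrative check of the guarantee the running `a`/`b` trackers enforce (the integer weights here are made up): if a tile's positive and negative integer-weight sums stay within `±A`, the worst-case unsigned `N`-bit input cannot push the dot product outside a signed `P`-bit accumulator.

```python
import torch

P, N = 16, 8
A = (2 ** (P - 1) - 1) / (2 ** N - 1)

q = torch.randint(-1, 2, (64,)).float()   # hypothetical integer weights for one tile
a = q.clamp(min=0).sum()                  # positive running l1-norm for this tile
b = q.clamp(max=0).sum()                  # negative running l1-norm for this tile
assert a <= A and b >= -A                 # the invariant A2GPTQ / A2GPFQ maintain

x_worst = torch.full_like(q, 2 ** N - 1)  # worst-case unsigned N-bit input
assert (q.clamp(min=0) * x_worst).sum() <= 2 ** (P - 1) - 1
assert (q.clamp(max=0) * x_worst).sum() >= -(2 ** (P - 1) - 1)
```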
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index 0151c9232..38ed85678 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -1,11 +1,10 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from copy import deepcopy
+from functools import partial
 import math
 
 import torch
-import torch.backends.cudnn as cudnn
 from tqdm import tqdm
 
 from brevitas.core.function_wrapper.shape import OverBatchOverTensorView
@@ -16,6 +15,8 @@
 from brevitas.graph.calibrate import norm_correction_mode
 from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.gpfq import gpfq_mode
+from brevitas.graph.gpfq import GPFQv2
+from brevitas.graph.gptq import GPTQ
 from brevitas.graph.gptq import gptq_mode
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.graph.quantize import quantize
@@ -60,7 +61,6 @@
 from brevitas.quant.scaled_int import Int32Bias
 from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFixedPoint
 from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
-from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatHQO
 from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO
@@ -68,6 +68,8 @@
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
+from brevitas_examples.common.axe import A2GPFQ
+from brevitas_examples.common.axe import A2GPTQ
 from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
 from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat
 from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator
@@ -574,12 +576,32 @@ def apply_act_equalization(model, calib_loader, layerwise):
                 model(images)
 
 
-def apply_gptq(calib_loader, model, act_order=False):
+def apply_gptq(
+        calib_loader,
+        model,
+        act_order=False,
+        use_quant_activations=False,
+        create_weight_orig=False,
+        max_accumulator_bit_width=None,
+        max_accumulator_tile_size=128):
+    if max_accumulator_bit_width is not None:
+        # Use accumulator-aware extension (AXE) framework
+        print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...")
+        gptq_class = partial(
+            A2GPTQ,
+            max_accumulator_bit_width=max_accumulator_bit_width,
+            max_accumulator_tile_size=max_accumulator_tile_size)
+    else:
+        gptq_class = GPTQ
     model.eval()
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
     with torch.no_grad():
-        with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq:
+        with gptq_mode(model,
+                       act_order=act_order,
+                       use_quant_activations=use_quant_activations,
+                       create_weight_orig=create_weight_orig,
+                       gptq_class=gptq_class) as gptq:
             gptq_model = gptq.model
             for i in tqdm(range(gptq.num_layers)):
                 for i, (images, target) in enumerate(calib_loader):
@@ -593,21 +615,27 @@ def apply_gpfq(
         calib_loader,
         model,
         act_order,
-        p=1.0,
-        use_gpfa2q=False,
-        accumulator_bit_width=None,
-        compression_rate=0.0):
+        create_weight_orig=False,
+        max_accumulator_bit_width=None,
+        max_accumulator_tile_size=128):
     model.eval()
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
+    if max_accumulator_bit_width is not None:
+        # Use accumulator-aware extension (AXE) framework
+        print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...")
+        gpfq_class = partial(
+            A2GPFQ,
+            max_accumulator_bit_width=max_accumulator_bit_width,
+            max_accumulator_tile_size=max_accumulator_tile_size)
+    else:
+        gpfq_class = GPFQv2
     with torch.no_grad():
         with gpfq_mode(model,
-                       p=p,
+                       create_weight_orig=create_weight_orig,
                        use_quant_activations=True,
                        act_order=act_order,
-                       use_gpfa2q=use_gpfa2q,
-                       accumulator_bit_width=accumulator_bit_width,
-                       compression_rate=compression_rate) as gpfq:
+                       gpfq_class=gpfq_class) as gpfq:
             gpfq_model = gpfq.model
             for i in tqdm(range(gpfq.num_layers)):
                 for i, (images, target) in enumerate(calib_loader):
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index 8a70e29ba..34bdfbc96 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -233,10 +233,15 @@ def validate_args(args):
     type=int,
     help='Exponent bit width used with float quantization for activations (default: 3)')
 parser.add_argument(
-    '--accumulator-bit-width',
+    '--gpxq-accumulator-bit-width',
     default=None,
     type=int,
-    help='Accumulator Bit Width for GPFA2Q (default: None)')
+    help='Accumulator Bit Width for GPxQ (default: None)')
+parser.add_argument(
+    '--gpxq-accumulator-tile-size',
+    default=None,
+    type=int,
+    help='Accumulator tile size for GPxQ (default: None)')
 parser.add_argument('--onnx-opset-version', default=None, type=int, help='ONNX opset version')
 parser.add_argument(
     '--channel-splitting-ratio',
@@ -245,17 +250,20 @@ def validate_args(args):
     help=
     'Split Ratio for Channel Splitting. When set to 0.0, Channel Splitting will not be applied. (default: 0.0)'
 )
-parser.add_argument(
-    '--compression-rate',
-    default=0.0,
-    type=float,
-    help='Specify compression rate < 1.0 for random projection. Default is 0.0 and does not use RP.'
-)
 add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
 add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
-add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
 add_bool_arg(
     parser, 'gpxq-act-order', default=False, help='GPxQ Act order heuristic (default: disabled)')
+add_bool_arg(
+    parser,
+    'gptq-use-quant-activations',
+    default=False,
+    help='Use quant activations for GPTQ (default: disabled)')
+add_bool_arg(
+    parser,
+    'gpxq-create-weight-orig',
+    default=False,
+    help='Maintain original weights for non-quant forward pass (default: disabled)')
 add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)')
 add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)')
 add_bool_arg(
@@ -270,7 +278,7 @@ def validate_args(args):
     help='Merge BN layers before quantizing the model (default: enabled)')
 add_bool_arg(
     parser,
-    'uint_sym_act_for_unsigned_values',
+    'uint-sym-act-for-unsigned-values',
     default=True,
     help='Use unsigned act quant when possible (default: enabled)')
 add_bool_arg(parser, 'compile', default=False, help='Use torch.compile (default: disabled)')
@@ -312,7 +320,6 @@ def main():
         f"w{args.weight_bit_width}_"
         f"{'gptq_' if args.gptq else ''}"
         f"{'gpfq_' if args.gpfq else ''}"
-        f"{'gpfa2q_' if args.gpfa2q else ''}"
         f"{'gpxq_act_order_' if args.gpxq_act_order else ''}"
         f"{'learned_round_' if args.learned_round else ''}"
         f"{'weight_narrow_range_' if args.weight_narrow_range else ''}"
@@ -335,10 +342,8 @@ def main():
         f"Weight bit width: {args.weight_bit_width} - "
         f"GPTQ: {args.gptq} - "
         f"GPFQ: {args.gpfq} - "
-        f"GPFA2Q: {args.gpfa2q} - "
-        f"GPFQ P: {args.gpfq_p} - "
         f"GPxQ Act Order: {args.gpxq_act_order} - "
-        f"GPFA2Q Accumulator Bit Width: {args.accumulator_bit_width} - "
+        f"GPxQ Accumulator Bit Width: {args.gpxq_accumulator_bit_width} - "
         f"Learned Round: {args.learned_round} - "
         f"Weight narrow range: {args.weight_narrow_range} - "
         f"Bias bit width: {args.bias_bit_width} - "
@@ -412,7 +417,9 @@ def main():
     if args.act_equalization is not None:
         print("Applying activation equalization:")
         apply_act_equalization(model, calib_loader, layerwise=args.act_equalization == 'layerwise')
+
     device = next(iter(model.parameters())).device
+
     # Define the quantized model
     quant_model = quantize_model(
         model,
@@ -452,24 +459,21 @@ def main():
         apply_gpfq(
             calib_loader,
             quant_model,
-            p=args.gpfq_p,
             act_order=args.gpxq_act_order,
-            compression_rate=args.compression_rate)
+            create_weight_orig=args.gpxq_create_weight_orig,
+            max_accumulator_bit_width=args.gpxq_accumulator_bit_width,
+            max_accumulator_tile_size=args.gpxq_accumulator_tile_size)
 
-    if args.gpfa2q:
-        print("Performing GPFA2Q:")
-        apply_gpfq(
+    if args.gptq:
+        print("Performing GPTQ:")
+        apply_gptq(
             calib_loader,
             quant_model,
-            p=args.gpfq_p,
             act_order=args.gpxq_act_order,
-            use_gpfa2q=args.gpfa2q,
-            accumulator_bit_width=args.accumulator_bit_width,
-            compression_rate=args.compression_rate)
-
-    if args.gptq:
-        print("Performing GPTQ:")
-        apply_gptq(calib_loader, quant_model, act_order=args.gpxq_act_order)
+            use_quant_activations=args.gptq_use_quant_activations,
+            create_weight_orig=args.gpxq_create_weight_orig,
+            max_accumulator_bit_width=args.gpxq_accumulator_bit_width,
+            max_accumulator_tile_size=args.gpxq_accumulator_tile_size)
 
     if args.learned_round:
         print("Applying Learned Round:")
diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index 457c74804..5cd067e64 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -34,6 +34,9 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--input-quant-granularity {per_tensor,per_row,per_group}]
                [--input-group-size INPUT_GROUP_SIZE]
                [--quantize-input-zero-point] [--quantize-last-layer] [--gptq]
+               [--gpfq] [--gpxq-act-order] [--gpxq-use-quant-activations] [--gpxq-create-weight-orig]
+               [--gpxq-max-accumulator-bit-width GPXQ_MAX_ACCUMULATOR_BIT_WIDTH]
+               [--gpxq-max-accumulator-tile-size GPXQ_MAX_ACCUMULATOR_TILE_SIZE]
                [--act-calibration] [--bias-corr] [--ln-affine-merge]
                [--no-quantize] [--no-float16] [--replace-mha]
                [--weight-equalization]
@@ -105,6 +108,16 @@ options:
   --quantize-last-layer
                         Quantize last nn.Linear layer.
   --gptq                Apply GPTQ.
+  --gpfq                Apply GPFQ.
+  --gpxq-act-order      Apply GPxQ activation ordering.
+  --gpxq-use-quant-activations
+                        Use quantized activations in GPxQ.
+  --gpxq-create-weight-orig
+                        Create weight_orig in GPxQ.
+  --gpxq-max-accumulator-bit-width GPXQ_MAX_ACCUMULATOR_BIT_WIDTH
+                        Maximum accumulator bit width for GPxQ using AXE.
+  --gpxq-max-accumulator-tile-size GPXQ_MAX_ACCUMULATOR_TILE_SIZE
+                        Maximum accumulator tile size for GPxQ using AXE.
   --act-calibration     Apply activation calibration.
   --bias-corr           Apply bias correction.
   --ln-affine-merge     Merge LN affine params.
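The options above map one-to-one onto `parse_args` in `src/brevitas_examples/llm/main.py` (see the hunks further down). As a hedged illustration of how they combine, something along these lines should parse and satisfy the new AXE validation, assuming `parse_args` accepts an argv-style list; the model id and the `--input-bit-width` spelling are assumptions, since only the GPxQ flags appear in this diff.

```python
# Sketch only: exercise the new GPxQ/AXE flags through the example's own parser.
# parse_args() and validate() come from brevitas_examples/llm/main.py; the
# '--model' and '--input-bit-width' values/spellings are illustrative assumptions.
from brevitas_examples.llm.main import parse_args
from brevitas_examples.llm.main import validate

args = parse_args([
    '--model', 'facebook/opt-125m',            # hypothetical model id
    '--input-bit-width', '8',                  # AXE requires quantized activations
    '--gpfq',
    '--gpxq-act-order',
    '--gpxq-create-weight-orig',
    '--gpxq-max-accumulator-bit-width', '16',
    '--gpxq-max-accumulator-tile-size', '128',
])
validate(args)  # checks the AXE pre-conditions added in this change
```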
diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py
index 44b99772f..5e61306d4 100644
--- a/src/brevitas_examples/llm/llm_quant/gpxq.py
+++ b/src/brevitas_examples/llm/llm_quant/gpxq.py
@@ -1,9 +1,8 @@
-"""
-Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
-"""
 
 from copy import deepcopy
+from functools import partial
 
 from accelerate.utils.operations import send_to_device
 import torch
@@ -13,9 +12,13 @@
 from brevitas.graph.calibrate import DisableEnableQuantization
 from brevitas.graph.calibrate import restore_return_quant_tensor
 from brevitas.graph.gpfq import gpfq_mode
+from brevitas.graph.gpfq import GPFQv2
+from brevitas.graph.gptq import GPTQ
 from brevitas.graph.gptq import gptq_mode
 from brevitas.graph.gpxq import StopFwdException
 from brevitas.utils.python_utils import recurse_getattr
+from brevitas_examples.common.axe import A2GPFQ
+from brevitas_examples.common.axe import A2GPTQ
 
 
 @torch.no_grad()
@@ -109,20 +112,33 @@ def apply_gptq(
         use_quant_activations=False,
         create_weight_orig=False,
         group_of_parallel_layers=None,
-        block_name=None):
+        block_name=None,
+        max_accumulator_bit_width=None,
+        max_accumulator_tile_size=128):
+    if max_accumulator_bit_width is not None:
+        # Use accumulator-aware extension (AXE) framework
+        print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...")
+        gptq_class = partial(
+            A2GPTQ,
+            max_accumulator_bit_width=max_accumulator_bit_width,
+            max_accumulator_tile_size=max_accumulator_tile_size)
+    else:
+        gptq_class = GPTQ
     if block_name is not None:
         context_manager_kwargs = {
             'act_order': act_order,
             'group_of_parallel_layers': group_of_parallel_layers,
             'create_weight_orig': create_weight_orig,
-            'use_quant_activations': use_quant_activations}
+            'use_quant_activations': use_quant_activations,
+            'gptq_class': gptq_class}
         block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs)
     else:
         with gptq_mode(model,
                        use_quant_activations=use_quant_activations,
                        group_of_parallel_layers=group_of_parallel_layers,
                        act_order=act_order,
-                       create_weight_orig=create_weight_orig) as gptq:
+                       create_weight_orig=create_weight_orig,
+                       gptq_class=gptq_class) as gptq:
             gptq_model = gptq.model
             for _ in tqdm(range(gptq.num_layers)):
                 for inps in dataloader:
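The pattern above is worth calling out: `gptq_mode` accepts the per-layer optimizer class through its `gptq_class` argument, and `functools.partial` pre-binds the AXE accumulator constraints onto `A2GPTQ` while keeping a class-like callable. A minimal sketch of the same pattern outside the helper, with placeholder `model`/`dataloader` arguments and a hypothetical 16-bit accumulator target:

```python
# Sketch only: inject an accumulator-aware GPTQ class into gptq_mode.
from functools import partial

import torch

from brevitas.graph.gptq import gptq_mode
from brevitas_examples.common.axe import A2GPTQ


def gptq_with_axe(model, dataloader, accumulator_bits=16, tile_size=128):
    # Pre-bind the AXE constraints; gptq_mode instantiates gptq_class per layer.
    gptq_class = partial(
        A2GPTQ,
        max_accumulator_bit_width=accumulator_bits,
        max_accumulator_tile_size=tile_size)
    with torch.no_grad():
        with gptq_mode(model, act_order=True, gptq_class=gptq_class) as gptq:
            gptq_model = gptq.model
            for _ in range(gptq.num_layers):
                for inps in dataloader:  # assumed: each batch is a dict of model kwargs
                    gptq_model(**inps)
                gptq.update()
```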
@@ -131,14 +147,36 @@ def apply_gptq(
 
 
 @torch.no_grad()
-def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
+def apply_gpfq(
+        model,
+        dataloader,
+        act_order=True,
+        group_of_parallel_layers=None,
+        block_name=None,
+        max_accumulator_bit_width=None,
+        max_accumulator_tile_size=128):
+    if max_accumulator_bit_width is not None:
+        # Use accumulator-aware extension (AXE) framework
+        print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...")
+        gpfq_class = partial(
+            A2GPFQ,
+            max_accumulator_bit_width=max_accumulator_bit_width,
+            max_accumulator_tile_size=max_accumulator_tile_size)
+    else:
+        gpfq_class = GPFQv2
     if block_name is not None:
-        raise RuntimeError("Block optimization not support for GPFQ at the moment")
+        context_manager_kwargs = {
+            'act_order': act_order,
+            'group_of_parallel_layers': group_of_parallel_layers,
+            'create_weight_orig': True,
+            'gpfq_class': gpfq_class}
+        block_optimization(model, dataloader, block_name, gpfq_mode, context_manager_kwargs)
     else:
         with gpfq_mode(model,
                        act_order=act_order,
                        group_of_parallel_layers=group_of_parallel_layers,
-                       create_weight_orig=True) as gpfq:
+                       create_weight_orig=True,
+                       gpfq_class=gpfq_class) as gpfq:
             gpfq_model = gpfq.model
             for _ in tqdm(range(gpfq.num_layers)):
                 for inps in dataloader:
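Two behavioural changes in `apply_gpfq` are easy to miss in the hunk above: block-wise optimization is no longer rejected (it now routes through `block_optimization`, as GPTQ does), and the per-layer class defaults to `GPFQv2` unless an accumulator bit width is given, in which case `A2GPFQ` is used. A hedged usage sketch, with a hypothetical block name and placeholder model/dataloader arguments:

```python
# Sketch only: block-wise GPFQ with an AXE accumulator target, via the updated
# helper in brevitas_examples/llm/llm_quant/gpxq.py.
from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq


def gpfq_blockwise(model, dataloader):
    apply_gpfq(
        model,
        dataloader,
        act_order=True,
        block_name="model.layers",       # hypothetical: optimize one decoder block at a time
        max_accumulator_bit_width=16,    # selects the A2GPFQ (AXE) path
        max_accumulator_tile_size=128)
```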
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index bf995a426..4a87f5a1a 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -1,7 +1,5 @@
-"""
-Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
-"""
 
 import argparse
 import sys
@@ -74,6 +72,20 @@ def validate(args):
     if not args.no_quantize:
         if args.gptq and args.gpfq:
             warn("Both GPTQ and GPFQ are enabled.")
+        if args.gpxq_max_accumulator_bit_width is not None:
+            assert args.weight_quant_format == 'int', "AXE only supports integer weight formats."
+            assert args.input_quant_format == 'int', "AXE only supports integer input formats."
+            assert args.input_bit_width is not None, \
+                "Specify input bit width; activation quantization is required to guarantee accumulator bounds."
+            if not (args.gptq or args.gpfq):
+                warn("Max accumulator bit width is specified, but no GPxQ is enabled.")
+            if args.gpxq_max_accumulator_tile_size is not None:
+                if args.weight_quant_granularity == 'per_group':
+                    assert args.gpxq_max_accumulator_tile_size == args.weight_group_size, \
+                        "Group size must be equal to tile size with per_group quantization."
+                if args.input_quant_granularity == 'per_group':
+                    assert args.gpxq_max_accumulator_tile_size == args.input_group_size, \
+                        "Group size must be equal to tile size with per_group quantization."
         if args.export_target is not None:
             assert args.input_quant_format == 'int', "Only integer quantization supported for export currently."
         if args.export_target is not None and args.input_bit_width is not None:
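The per_group checks above exist, presumably, so that each accumulator tile is covered by a single set of quantization parameters, which requires tiles and quantization groups to coincide. Restated as a standalone helper for readability (a sketch that mirrors the asserts above, not code from this change):

```python
# Sketch only: the AXE pre-conditions added to validate(), gathered in one place.
def check_axe_args(args):
    if args.gpxq_max_accumulator_bit_width is None:
        return  # AXE disabled, nothing to check
    assert args.weight_quant_format == 'int', "AXE only supports integer weight formats."
    assert args.input_quant_format == 'int', "AXE only supports integer input formats."
    assert args.input_bit_width is not None, \
        "Activation quantization is required to guarantee accumulator bounds."
    if args.gpxq_max_accumulator_tile_size is not None:
        if args.weight_quant_granularity == 'per_group':
            assert args.gpxq_max_accumulator_tile_size == args.weight_group_size
        if args.input_quant_granularity == 'per_group':
            assert args.gpxq_max_accumulator_tile_size == args.input_group_size
```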
@@ -158,8 +170,7 @@ def main(args):
         seed=args.seed,
         require_fx=require_fx,
         device=None,
-        fuse_sequences=args.fuse_sequences,
-    )
+        fuse_sequences=args.fuse_sequences)
 
     validation_loader = get_dataset_for_model(
         args.model,
@@ -171,8 +182,7 @@ def main(args):
         seed=args.seed,
         require_fx=require_fx,
         device=None,
-        fuse_sequences=args.fuse_sequences,
-    )
+        fuse_sequences=args.fuse_sequences)
 
     device = next(iter(model.parameters())).device
     print("Data loaded.")
@@ -287,7 +297,9 @@ def main(args):
             act_order=args.gpxq_act_order,
             use_quant_activations=args.gpxq_use_quant_activations,
             create_weight_orig=args.gpxq_create_weight_orig,
-            block_name=args.gpxq_block_name)
+            block_name=args.gpxq_block_name,
+            max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
+            max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
         print("GPTQ applied.")
 
     if args.gpfq:
@@ -296,7 +308,9 @@ def main(args):
             model,
             calibration_loader,
             act_order=args.gpxq_act_order,
-            block_name=args.gpxq_block_name)
+            block_name=args.gpxq_block_name,
+            max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
+            max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
         print("GPFQ applied.")
 
     if args.bias_corr:
@@ -304,7 +318,7 @@ def main(args):
         apply_bias_correction(model, calibration_loader)
         print("Bias correction applied.")
 
-    if args.eval:
+    if args.eval and not args.no_quantize:
         print("Model eval...")
         quant_ppl = compute_perplexity(
             model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
@@ -455,13 +469,23 @@ def parse_args(args):
     parser.add_argument('--gptq', action='store_true', help='Apply GPTQ.')
     parser.add_argument('--gpfq', action='store_true', help='Apply GPFQ.')
     parser.add_argument(
-        '--gpxq-act-order', action='store_true', help='Apply GPXQ activation ordering.')
+        '--gpxq-act-order', action='store_true', help='Apply GPxQ activation ordering.')
     parser.add_argument(
         '--gpxq-use-quant-activations',
         action='store_true',
-        help='Use quantized activations in GPXQ.')
+        help='Use quantized activations in GPxQ.')
     parser.add_argument(
-        '--gpxq-create-weight-orig', action='store_true', help='Create weight_orig in GPXQ.')
+        '--gpxq-create-weight-orig', action='store_true', help='Create weight_orig in GPxQ.')
+    parser.add_argument(
+        '--gpxq-max-accumulator-bit-width',
+        type=int,
+        default=None,
+        help='Maximum accumulator bit width for GPxQ using AXE.')
+    parser.add_argument(
+        '--gpxq-max-accumulator-tile-size',
+        type=int,
+        default=None,
+        help='Maximum accumulator tile size for GPxQ using AXE.')
     parser.add_argument(
         '--act-calibration', action='store_true', help='Apply activation calibration.')
     parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.')