Skip to content

Commit

Permalink
Feat (GPFA2Q): clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Dec 21, 2023
1 parent 376c759 commit 3aa98cb
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 11 deletions.
9 changes: 5 additions & 4 deletions src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,10 @@ def __init__(
layer,
name,
act_order,
len_parallel_layers=1,
create_weight_orig=True,
accumulator_bit_width=None,
p=1.0) -> None:
len_parallel_layers,
create_weight_orig,
accumulator_bit_width,
p) -> None:
GPFQ.__init__(
self,
layer=layer,
Expand All @@ -295,6 +295,7 @@ def __init__(
create_weight_orig=create_weight_orig,
p=p)
self.accumulator_bit_width = accumulator_bit_width
assert self.accumulator_bit_width is not None

def single_layer_update(self):
# raise error in case no quant-input is here
Expand Down
8 changes: 6 additions & 2 deletions src/brevitas/graph/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from typing import List, Optional, Set
import warnings

import torch

from brevitas.graph.calibrate import DisableEnableQuantization
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor
Expand Down Expand Up @@ -175,7 +177,8 @@ def process_input(self, inp):
if self.layer.weight_quant_requires_quant_input:
# Can minimize memory allocation by not storing actual values
self.quant_input = QuantTensor(
value=None,
value=torch.empty(
1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
scale=inp.scale,
zero_point=inp.zero_point,
bit_width=inp.bit_width,
Expand All @@ -184,7 +187,8 @@ def process_input(self, inp):
inp = inp.value
elif self.layer.is_input_quant_enabled:
self.quant_input = QuantTensor(
value=None,
value=torch.empty(
1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
scale=self.layer.quant_input_scale(),
zero_point=self.layer.quant_input_zero_point(),
bit_width=self.layer.quant_input_bit_width(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import os
import random
from types import SimpleNamespace
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -38,7 +37,6 @@
from brevitas_examples.imagenet_classification.utils import validate

config.IGNORE_MISSING_KEYS = True
warnings.filterwarnings("ignore")


def parse_type(v, default_type):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def apply_gptq(calib_loader, model, act_order=False):
gptq.update()


def apply_gpfq(calib_loader, model, act_order, p=1.0, use_gpfa2q=False, accumulator_bit_width=16):
def apply_gpfq(calib_loader, model, act_order, p=1.0, use_gpfa2q=False, accumulator_bit_width=None):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@
add_bool_arg(
parser,
'weight-narrow-range',
default=True,
help='Narrow range for weight quantization (default: enabled)')
default=False,
help='Narrow range for weight quantization (default: disabled)')
parser.add_argument('--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 1.0)')
parser.add_argument(
'--quant-format',
Expand Down

0 comments on commit 3aa98cb

Please sign in to comment.