Skip to content

Commit

Permalink
Clean up synthetic weights/inputs generation integration in pulp-nnx
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescoConti committed Aug 21, 2024
1 parent 973cd49 commit 13dd71f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 28 deletions.
47 changes: 34 additions & 13 deletions test/NnxTestClasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class NnxTestConf(BaseModel):
has_norm_quant: bool
has_bias: bool
has_relu: bool
synthetic_weights: bool
synthetic_inputs: bool

@model_validator(mode="after") # type: ignore
def check_valid_depthwise_channels(self) -> NnxTestConf:
Expand Down Expand Up @@ -116,6 +118,8 @@ def __init__(
scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
global_shift: Optional[torch.Tensor] = torch.Tensor([0]),
synthetic_weights: Optional[bool] = False,
synthetic_inputs: Optional[bool] = False,
) -> None:
self.conf = conf
self.input = input
Expand All @@ -124,6 +128,8 @@ def __init__(
self.scale = scale
self.bias = bias
self.global_shift = global_shift
self.synthetic_weights = synthetic_weights
self.synthetic_inputs = synthetic_inputs

def is_valid(self) -> bool:
return all(
Expand Down Expand Up @@ -243,20 +249,30 @@ def from_conf(
bias_shape = (1, conf.out_channel, 1, 1)

if input is None:
input = NnxTestGenerator._random_data(
_type=conf.in_type,
shape=input_shape,
)
if conf.synthetic_inputs:
inputs = torch.zeros((1, conf.in_channel, conf.in_height, conf.in_width), dtype=torch.int64)
for i in range(conf.in_channel):
inputs[:, i,0,0] = i
else:
input = NnxTestGenerator._random_data(
_type=conf.in_type,
shape=input_shape,
)

if weight is None:
weight_mean = NnxTestGenerator._DEFAULT_WEIGHT_MEAN
weight_std = NnxTestGenerator._DEFAULT_WEIGHT_STDEV * (1<<(conf.weight_type._bits-1)-1)
weight = NnxTestGenerator._random_data_normal(
mean = weight_mean,
std = weight_std,
_type=conf.weight_type,
shape=weight_shape,
)
if conf.synthetic_weights:
weight = torch.zeros((conf.out_channel, 1 if conf.depthwise else conf.in_channel, conf.kernel_shape.height, conf.kernel_shape.width), dtype=torch.int64)
for i in range(0, min(weight.shape[0], weight.shape[1])):
weight[i,i,0,0] = 1
else:
weight_mean = NnxTestGenerator._DEFAULT_WEIGHT_MEAN
weight_std = NnxTestGenerator._DEFAULT_WEIGHT_STDEV * (1<<(conf.weight_type._bits-1)-1)
weight = NnxTestGenerator._random_data_normal(
mean = weight_mean,
std = weight_std,
_type=conf.weight_type,
shape=weight_shape,
)

if conf.has_norm_quant:
if scale is None:
Expand Down Expand Up @@ -306,6 +322,8 @@ def from_conf(
scale=scale,
bias=bias,
global_shift=global_shift,
synthetic_inputs=conf.synthetic_inputs,
synthetic_weights=conf.synthetic_weights,
)

@staticmethod
Expand Down Expand Up @@ -361,7 +379,10 @@ def generate(self, test_name: str, test: NnxTest):
weight_type = test.conf.weight_type
weight_bits = weight_type._bits
assert weight_bits > 1 and weight_bits <= 8
weight_offset = -(2 ** (weight_bits - 1))
if test.synthetic_weights:
weight_offset = 0
else:
weight_offset = -(2 ** (weight_bits - 1))
weight_out_ch, weight_in_ch, weight_ks_h, weight_ks_w = test.weight.shape
weight_data: np.ndarray = test.weight.numpy() - weight_offset
weight_init = self.weightEncode(
Expand Down
16 changes: 1 addition & 15 deletions test/testgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,7 @@ def test_gen(
exit(-1)

test_conf = nnxTestConfCls.model_validate(test_conf_dict)
if test_conf_dict['synthetic_weights']:
import torch
weight = torch.zeros((test_conf.out_channel, 1 if test_conf.depthwise else test_conf.in_channel, test_conf.kernel_shape.height, test_conf.kernel_shape.width), dtype=torch.int64)
for i in range(0, min(weight.shape[0], weight.shape[1])):
weight[i,i,0,0] = 1
else:
weight = None
if test_conf_dict['synthetic_inputs']:
import torch
inputs = torch.zeros((1, test_conf.in_channel, test_conf.in_height, test_conf.in_width), dtype=torch.int64)
for i in range(test_conf.in_channel):
inputs[:, i,0,0] = i
else:
inputs = None
test = NnxTestGenerator.from_conf(test_conf, verbose=args.print_tensors, weight=weight, input=inputs)
test = NnxTestGenerator.from_conf(test_conf, verbose=args.print_tensors)
if not args.skip_save:
test.save(args.test_dir)
if args.headers:
Expand Down

0 comments on commit 13dd71f

Please sign in to comment.