From e36faa481bce03306367ecb8abe1404ebbf02f38 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 4 Dec 2023 10:50:10 +0000 Subject: [PATCH] Fix (ptq/benchmark): better dataframe creation --- .../benchmark/ptq_benchmark_torchvision.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index b88db3e3a..75bbd1c1a 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -270,18 +270,16 @@ def ptq_torchvision_models(args): acc_diff = np.around(top1 - fp_accuracy, decimals=3) acc_ratio = np.around(top1 / fp_accuracy, decimals=3) - options_names = [k.replace('_', ' ').capitalize() for k in config_namespace.__dict__.keys()] - torchvision_df = pd.DataFrame( - columns=options_names + [ - 'Top 1% floating point accuracy', - 'Top 1% quant accuracy', - 'Floating point accuracy - quant accuracy', - 'Quant accuracy / floating point accuracy', - 'Calibration size', - 'Calibration batch size', - 'Torch version', - 'Brevitas version']) - torchvision_df.at[0, :] = [v for _, v in config_namespace.__dict__.items()] + [ + column_names = [k.replace('_', ' ').capitalize() for k in config_namespace.__dict__.keys()] + [ + 'Top 1% floating point accuracy', + 'Top 1% quant accuracy', + 'Floating point accuracy - quant accuracy', + 'Quant accuracy / floating point accuracy', + 'Calibration size', + 'Calibration batch size', + 'Torch version', + 'Brevitas version'] + values = [v for _, v in config_namespace.__dict__.items()] + [ fp_accuracy, top1, acc_diff, @@ -290,6 +288,8 @@ def ptq_torchvision_models(args): args.batch_size_calibration, torch_version, brevitas_version] + torchvision_df = pd.DataFrame([values], columns=column_names) + folder = './multirun/' + str(args.idx) os.makedirs(folder, exist_ok=True) torchvision_df.to_csv(os.path.join(folder, 'RESULTS_TORCHVISION.csv'), index=False)