From 2f4730781ac35a5afc0a78bae774eb1a402f3682 Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Wed, 30 Oct 2024 15:45:34 +0100
Subject: [PATCH] [Conformance][TorchFX] GPU quantization support

---
 .../smooth_quant/torch_fx_backend.py        |  3 +-
 tests/post_training/conftest.py             |  5 ++
 .../data/ptq_reference_data.yaml            |  8 ++++
 tests/post_training/model_scope.py          | 15 ++++--
 tests/post_training/pipelines/base.py       | 13 +++++-
 .../pipelines/image_classification_base.py  | 46 +++++++++++++++----
 .../image_classification_torchvision.py     | 20 ++++++--
 .../test_quantize_conformance.py            |  7 +++
 8 files changed, 95 insertions(+), 22 deletions(-)

diff --git a/nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py
index 6dee64469ed..2584d0af9cb 100644
--- a/nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py
+++ b/nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py
@@ -42,7 +42,8 @@ class FXSQMultiply(torch.nn.Module):
     def __init__(self, scale: torch.Tensor):
         super().__init__()
-        self._scale_value = scale
+        self.register_buffer("_scale_value", scale)
+        self._scale_value: torch.Tensor
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return torch.mul(x, self._scale_value)
diff --git a/tests/post_training/conftest.py b/tests/post_training/conftest.py
index a05d20d0d5c..6bdae739ff6 100644
--- a/tests/post_training/conftest.py
+++ b/tests/post_training/conftest.py
@@ -19,6 +19,11 @@ def pytest_addoption(parser):
     parser.addoption("--fp32", action="store_true", help="Test original model")
     parser.addoption("--cuda", action="store_true", help="Enable CUDA_TORCH backend")
     parser.addoption("--benchmark", action="store_true", help="Run benchmark_app")
+    parser.addoption(
+        "--validate-in-backend",
+        action="store_true",
+        help="Validate the quantized model in its native backend instead of OpenVINO.",
+    )
     parser.addoption(
         "--extra-columns",
         action="store_true",
diff --git a/tests/post_training/data/ptq_reference_data.yaml b/tests/post_training/data/ptq_reference_data.yaml
index 94f70f3a931..d4d63db76e1 100644
--- a/tests/post_training/data/ptq_reference_data.yaml
+++ b/tests/post_training/data/ptq_reference_data.yaml
@@ -38,6 +38,8 @@ torchvision/resnet18_backend_CUDA_TORCH:
   metric_value: 0.69152
 torchvision/resnet18_backend_FX_TORCH:
   metric_value: 0.6946
+torchvision/resnet18_backend_CUDA_FX_TORCH:
+  metric_value: 0.6946
 torchvision/mobilenet_v3_small_BC_backend_FP32:
   metric_value: 0.6766
 torchvision/mobilenet_v3_small_BC_backend_OV:
@@ -46,18 +48,24 @@ torchvision/mobilenet_v3_small_BC_backend_ONNX:
   metric_value: 0.6679
 torchvision/mobilenet_v3_small_BC_backend_FX_TORCH:
   metric_value: 0.6679
+torchvision/mobilenet_v3_small_BC_backend_CUDA_FX_TORCH:
+  metric_value: 0.6679
 torchvision/vit_b_16_backend_FP32:
   metric_value: 0.8107
 torchvision/vit_b_16_backend_OV:
   metric_value: 0.80948
 torchvision/vit_b_16_backend_FX_TORCH:
   metric_value: 0.80922
+torchvision/vit_b_16_backend_CUDA_FX_TORCH:
+  metric_value: 0.80922
 torchvision/swin_v2_s_backend_FP32:
   metric_value: 0.83712
 torchvision/swin_v2_s_backend_OV:
   metric_value: 0.83638
 torchvision/swin_v2_s_backend_FX_TORCH:
   metric_value: 0.8360
+torchvision/swin_v2_s_backend_CUDA_FX_TORCH:
+  metric_value: 0.8360
 timm/crossvit_9_240_backend_CUDA_TORCH:
   metric_value: 0.7275
 timm/crossvit_9_240_backend_FP32:
diff --git a/tests/post_training/model_scope.py b/tests/post_training/model_scope.py
index 54b49d63a21..9cdb4bda608 100644
--- a/tests/post_training/model_scope.py
+++ b/tests/post_training/model_scope.py
@@ -87,7 +87,14 @@
         "model_id": "resnet18",
         "pipeline_cls": ImageClassificationTorchvision,
         "compression_params": {},
-        "backends": [BackendType.FX_TORCH, BackendType.TORCH, BackendType.CUDA_TORCH, BackendType.OV, BackendType.ONNX],
+        "backends": [
+            BackendType.FX_TORCH,
+            BackendType.CUDA_FX_TORCH,
+            BackendType.TORCH,
+            BackendType.CUDA_TORCH,
+            BackendType.OV,
+            BackendType.ONNX,
+        ],
         "batch_size": 128,
     },
     {
@@ -98,7 +105,7 @@
             "fast_bias_correction": False,
             "preset": QuantizationPreset.MIXED,
         },
-        "backends": [BackendType.FX_TORCH, BackendType.OV, BackendType.ONNX],
+        "backends": [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH, BackendType.OV, BackendType.ONNX],
         "batch_size": 128,
     },
     {
@@ -109,7 +116,7 @@
             "model_type": ModelType.TRANSFORMER,
             "advanced_parameters": AdvancedQuantizationParameters(smooth_quant_alpha=0.15),
         },
-        "backends": [BackendType.FX_TORCH, BackendType.OV],
+        "backends": [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH, BackendType.OV],
         "batch_size": 1,
     },
     {
@@ -120,7 +127,7 @@
             "model_type": ModelType.TRANSFORMER,
             "advanced_parameters": AdvancedQuantizationParameters(smooth_quant_alpha=0.5),
         },
-        "backends": [BackendType.FX_TORCH, BackendType.OV],
+        "backends": [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH, BackendType.OV],
         "batch_size": 1,
     },
     # Timm models
diff --git a/tests/post_training/pipelines/base.py b/tests/post_training/pipelines/base.py
index 41e9bcf17e4..8912a5697c2 100644
--- a/tests/post_training/pipelines/base.py
+++ b/tests/post_training/pipelines/base.py
@@ -44,6 +44,7 @@ class BackendType(Enum):
     TORCH = "TORCH"
     CUDA_TORCH = "CUDA_TORCH"
     FX_TORCH = "FX_TORCH"
+    CUDA_FX_TORCH = "CUDA_FX_TORCH"
     ONNX = "ONNX"
     OV = "OV"
     OPTIMUM = "OPTIMUM"
@@ -52,6 +53,7 @@
 NNCF_PTQ_BACKENDS = [BackendType.TORCH, BackendType.CUDA_TORCH, BackendType.ONNX, BackendType.OV]
 ALL_PTQ_BACKENDS = NNCF_PTQ_BACKENDS
 PT_BACKENDS = [BackendType.TORCH, BackendType.CUDA_TORCH]
+FX_BACKENDS = [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH]
 OV_BACKENDS = [BackendType.OV, BackendType.OPTIMUM]
 
 LIMIT_LENGTH_OF_STATUS = 120
@@ -211,6 +213,7 @@ def __init__(
         reference_data: dict,
         no_eval: bool,
         run_benchmark_app: bool,
+        validate_in_backend: bool = False,
         params: dict = None,
         batch_size: int = 1,
         memory_monitor: bool = False,
@@ -227,6 +230,7 @@
         self.memory_monitor = memory_monitor
         self.no_eval = no_eval
         self.run_benchmark_app = run_benchmark_app
+        self.validate_in_backend = validate_in_backend
         self.output_model_dir: Path = self.output_dir / self.reported_name / self.backend.value
         self.output_model_dir.mkdir(parents=True, exist_ok=True)
         self.model_name = f"{self.reported_name}_{self.backend.value}"
@@ -405,11 +409,16 @@ def save_compressed_model(self) -> None:
             )
             self.path_compressed_ir = self.output_model_dir / "model.xml"
             ov.serialize(ov_model, self.path_compressed_ir)
-        elif self.backend == BackendType.FX_TORCH:
-            exported_model = torch.export.export(self.compressed_model, (self.dummy_tensor,))
+        elif self.backend in FX_BACKENDS:
+            exported_model = torch.export.export(self.compressed_model.cpu(), (self.dummy_tensor.cpu(),))
             ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor.cpu(), input=self.input_size)
             self.path_compressed_ir = self.output_model_dir / "model.xml"
             ov.serialize(ov_model, self.path_compressed_ir)
+
+            if self.backend is BackendType.CUDA_FX_TORCH:
+                self.model = self.model.cuda()
+                self.dummy_tensor = self.dummy_tensor.cuda()
+
         elif self.backend == BackendType.ONNX:
             onnx_path = self.output_model_dir / "model.onnx"
             onnx.save(self.compressed_model, str(onnx_path))
diff --git a/tests/post_training/pipelines/image_classification_base.py b/tests/post_training/pipelines/image_classification_base.py
index 22e60a5ae3b..d509a0aac1b 100644
--- a/tests/post_training/pipelines/image_classification_base.py
+++ b/tests/post_training/pipelines/image_classification_base.py
@@ -21,6 +21,7 @@
 import nncf
 from nncf.common.logging.track_progress import track
 from tests.post_training.pipelines.base import DEFAULT_VAL_THREADS
+from tests.post_training.pipelines.base import FX_BACKENDS
 from tests.post_training.pipelines.base import PTQTestPipeline
 
 
@@ -33,18 +34,15 @@ def prepare_calibration_dataset(self):
 
         self.calibration_dataset = nncf.Dataset(loader, self.get_transform_calibration_fn())
 
-    def _validate(self):
-        val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
-        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False)
-
-        dataset_size = len(val_loader)
-
-        # Initialize result tensors for async inference support.
-        predictions = np.zeros((dataset_size))
-        references = -1 * np.ones((dataset_size))
+    def _validate_ov(
+        self,
+        val_loader: torch.utils.data.DataLoader,
+        predictions: np.ndarray,
+        references: np.ndarray,
+        dataset_size: int,
+    ):
         core = ov.Core()
-
         if os.environ.get("INFERENCE_NUM_THREADS"):
             # Set CPU_THREADS_NUM for OpenVINO inference
             inference_num_threads = os.environ.get("INFERENCE_NUM_THREADS")
@@ -73,6 +71,34 @@ def process_result(request, userdata):
             references[i] = target
 
         infer_queue.wait_all()
+        return predictions, references
+
+    def _validate_torch_compile(
+        self, val_loader: torch.utils.data.DataLoader, predictions: np.ndarray, references: np.ndarray
+    ):
+        compiled_model = torch.compile(self.compressed_model.cpu(), backend="openvino")
+        for i, (images, target) in enumerate(val_loader):
+            # W/A for memory leaks when using torch DataLoader and OpenVINO
+            pred = compiled_model(images)
+            pred = torch.argmax(pred, dim=1)
+            predictions[i] = pred.numpy()
+            references[i] = target.numpy()
+        return predictions, references
+
+    def _validate(self):
+        val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
+        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False)
+
+        dataset_size = len(val_loader)
+
+        # Initialize result tensors for async inference support.
+        predictions = np.zeros((dataset_size))
+        references = -1 * np.ones((dataset_size))
+
+        if self.validate_in_backend and self.backend in FX_BACKENDS:
+            predictions, references = self._validate_torch_compile(val_loader, predictions, references)
+        else:
+            predictions, references = self._validate_ov(val_loader, predictions, references, dataset_size)
 
         acc_top1 = accuracy_score(predictions, references)
diff --git a/tests/post_training/pipelines/image_classification_torchvision.py b/tests/post_training/pipelines/image_classification_torchvision.py
index b675bd6e814..de73aa09512 100644
--- a/tests/post_training/pipelines/image_classification_torchvision.py
+++ b/tests/post_training/pipelines/image_classification_torchvision.py
@@ -20,6 +20,7 @@
 from torchvision import models
 
 from nncf.torch import disable_patching
+from tests.post_training.pipelines.base import FX_BACKENDS
 from tests.post_training.pipelines.base import PT_BACKENDS
 from tests.post_training.pipelines.base import BackendType
 from tests.post_training.pipelines.image_classification_base import ImageClassificationBase
@@ -75,9 +76,12 @@ def prepare_model(self) -> None:
         if self.batch_size > 1:  # Dynamic batch_size shape export
             self.input_size[0] = -1
 
-        if self.backend == BackendType.FX_TORCH:
+        if self.backend in FX_BACKENDS:
             with torch.no_grad():
                 with disable_patching():
+                    if self.backend is BackendType.CUDA_FX_TORCH:
+                        model = model.cuda()
+                        self.dummy_tensor = self.dummy_tensor.cuda()
                     self.model = self.model_params.export_fn(model, (self.dummy_tensor,))
 
         elif self.backend in PT_BACKENDS:
@@ -121,11 +125,15 @@ def _dump_model_fp32(self) -> None:
             )
             ov.serialize(ov_model, self.fp32_model_dir / "model_fp32.xml")
 
-        if self.backend == BackendType.FX_TORCH:
-            exported_model = torch.export.export(self.model, (self.dummy_tensor,))
-            ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor, input=self.input_size)
+        if self.backend in FX_BACKENDS:
+            exported_model = torch.export.export(self.model.cpu(), (self.dummy_tensor.cpu(),))
+            ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor.cpu(), input=self.input_size)
             ov.serialize(ov_model, self.fp32_model_dir / "fx_model_fp32.xml")
 
+            if self.backend is BackendType.CUDA_FX_TORCH:
+                self.model = self.model.cuda()
+                self.dummy_tensor = self.dummy_tensor.cuda()
+
         if self.backend in [BackendType.FP32, BackendType.OV]:
             ov.serialize(self.model, self.fp32_model_dir / "model_fp32.xml")
@@ -133,8 +141,10 @@ def prepare_preprocessor(self) -> None:
         self.transform = self.model_params.weights.transforms()
 
     def get_transform_calibration_fn(self):
-        if self.backend in [BackendType.FX_TORCH] + PT_BACKENDS:
-            device = torch.device("cuda" if self.backend == BackendType.CUDA_TORCH else "cpu")
+        if self.backend in FX_BACKENDS + PT_BACKENDS:
+            device = torch.device(
+                "cuda" if self.backend in [BackendType.CUDA_TORCH, BackendType.CUDA_FX_TORCH] else "cpu"
+            )
 
             def transform_fn(data_item):
                 images, _ = data_item
diff --git a/tests/post_training/test_quantize_conformance.py b/tests/post_training/test_quantize_conformance.py
index 20504a8b086..075b838e6d7 100644
--- a/tests/post_training/test_quantize_conformance.py
+++ b/tests/post_training/test_quantize_conformance.py
@@ -75,6 +75,11 @@ def fixture_run_benchmark_app(pytestconfig):
     return pytestconfig.getoption("benchmark")
 
 
+@pytest.fixture(scope="session", name="validate_in_backend")
+def fixture_validate_in_backend(pytestconfig):
+    return pytestconfig.getoption("validate_in_backend")
+
+
 @pytest.fixture(scope="session", name="extra_columns")
 def fixture_extra_columns(pytestconfig):
     return pytestconfig.getoption("extra_columns")
@@ -266,6 +271,7 @@ def test_ptq_quantization(
     run_torch_cuda_backend: bool,
     subset_size: Optional[int],
     run_benchmark_app: bool,
+    validate_in_backend: bool,
     capsys: pytest.CaptureFixture,
     extra_columns: bool,
     memory_monitor: bool,
@@ -293,6 +299,7 @@
         "data_dir": data_dir,
         "no_eval": no_eval,
         "run_benchmark_app": run_benchmark_app,
+        "validate_in_backend": validate_in_backend,
         "batch_size": batch_size,
         "memory_monitor": memory_monitor,
     }
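
Background on the torch_fx_backend.py hunk: register_buffer() makes the SmoothQuant
scale part of the module's state, so Module.cuda()/Module.to(device) relocate it
together with the module's parameters; a plain attribute would stay on its original
device and break once the quantized graph runs on GPU. A minimal standalone sketch of
that behavior (assumes only torch and an available CUDA device; the class mirrors the
patched FXSQMultiply):

    import torch

    class FXSQMultiply(torch.nn.Module):
        def __init__(self, scale: torch.Tensor):
            super().__init__()
            # Registered buffers travel with the module on .to()/.cuda()
            # and are saved in the state dict, unlike plain attributes.
            self.register_buffer("_scale_value", scale)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.mul(x, self._scale_value)

    module = FXSQMultiply(torch.ones(3)).cuda()
    assert module._scale_value.is_cuda  # a plain attribute would remain on CPU

With the new backend and flag in place, a GPU run would be selected through the
existing entry point, e.g. "pytest tests/post_training/test_quantize_conformance.py
--cuda --validate-in-backend" (assuming --cuda also gates the new CUDA_FX_TORCH
backend, which is handled outside the hunks shown; dataset and output options are
unchanged).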