diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 6f9980c006ac..009f38cb6414 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -453,21 +453,19 @@ def test_wrong_config(self): self.get_dummy_components(TorchAoConfig("int42")) -# This class is not to be run as a test by itself. See the tests that follow this class +# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch @require_torch_gpu @require_torchao_version_greater("0.6.0") class TorchAoSerializationTest(unittest.TestCase): model_name = "hf-internal-testing/tiny-flux-pipe" - quant_method, quant_method_kwargs = None, None - device = "cuda" def tearDown(self): gc.collect() torch.cuda.empty_cache() - def get_dummy_model(self, device=None): - quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs) + def get_dummy_model(self, quant_method, quant_method_kwargs, device=None): + quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs) quantized_model = FluxTransformer2DModel.from_pretrained( self.model_name, subfolder="transformer", @@ -503,15 +501,15 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0): "timestep": timestep, } - def test_original_model_expected_slice(self): - quantized_model = self.get_dummy_model(torch_device) + def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice): + quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device) inputs = self.get_dummy_tensor_inputs(torch_device) output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3)) + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) - def check_serialization_expected_slice(self, expected_slice): - quantized_model = self.get_dummy_model(self.device) + def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device): + quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device) with tempfile.TemporaryDirectory() as tmp_dir: quantized_model.save_pretrained(tmp_dir, safe_serialization=False) @@ -530,36 +528,33 @@ def check_serialization_expected_slice(self, expected_slice): ) self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) - def test_serialization_expected_slice(self): - self.check_serialization_expected_slice(self.serialized_expected_slice) - - -class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest): - quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} - expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) - serialized_expected_slice = expected_slice - device = "cuda" - - -class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest): - quant_method, quant_method_kwargs = "int8_weight_only", {} - expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) - serialized_expected_slice = expected_slice - device = "cuda" - - -class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest): - quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} - expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) - serialized_expected_slice = expected_slice - device = "cpu" - - -class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest): - quant_method, quant_method_kwargs = "int8_weight_only", {} - expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) - serialized_expected_slice = expected_slice - device = "cpu" + def test_int_a8w8_cuda(self): + quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + device = "cuda" + self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) + self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + + def test_int_a16w8_cuda(self): + quant_method, quant_method_kwargs = "int8_weight_only", {} + expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) + device = "cuda" + self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) + self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + + def test_int_a8w8_cpu(self): + quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + device = "cpu" + self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) + self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + + def test_int_a16w8_cpu(self): + quant_method, quant_method_kwargs = "int8_weight_only", {} + expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) + device = "cpu" + self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) + self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners