From ed31b9d15b026e3713c179451841973d946ef7c7 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 23 Oct 2023 09:53:27 +0100 Subject: [PATCH 1/7] Add opchecks --- test/test_ops.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 6d80f037b88..39a5fd07c5c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -117,8 +117,9 @@ class RoIOpTester(ABC): torch.float32, torch.float64, ), - ids=str, + # ids=str, ) + @pytest.mark.opcheck_only_one() def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs): if device == "mps" and x_dtype is torch.float64: pytest.skip("MPS does not support float64") @@ -186,6 +187,7 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) + @pytest.mark.opcheck_only_one() def test_backward(self, seed, device, contiguous, deterministic=False): atol = self.mps_backward_atol if device == "mps" else 1e-05 dtype = self.mps_dtype if device == "mps" else self.dtype @@ -228,6 +230,7 @@ def func(z): @needs_cuda @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) + @pytest.mark.opcheck_only_one() def test_autocast(self, x_dtype, rois_dtype): with torch.cuda.amp.autocast(): self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype) @@ -369,6 +372,15 @@ def test_boxes_shape(self): self._helper_boxes_shape(ops.ps_roi_pool) +optests.generate_opcheck_tests( + testcase=TestPSRoIPool, + namespaces=["torchvision"], + failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"), + additional_decorators=[], + test_utils=OPTESTS, +) + + def bilinear_interpolate(data, y, x, snap_border=False): height, width = data.shape From 
0a4112152ca816472daa7041ccc972be369aa05e Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Mon, 23 Oct 2023 10:01:36 +0100
Subject: [PATCH 2/7] Add Meta implem

---
 torchvision/_meta_registrations.py | 34 ++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py
index 7baece2ae2c..8f653661159 100644
--- a/torchvision/_meta_registrations.py
+++ b/torchvision/_meta_registrations.py
@@ -51,6 +51,40 @@ def meta_roi_align_backward(
     return grad.new_empty((batch_size, channels, height, width))
 
 
+@register_meta("ps_roi_pool")
+def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    channels = input.size(1)
+    torch._check(
+        channels % (pooled_height * pooled_width) == 0,
+        lambda: "input channels must be a multiple of pooling height * pooling width",
+    )
+    num_rois = rois.size(0)
+    out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
+    return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
+
+
+@register_meta("_ps_roi_pool_backward")
+def meta_ps_roi_pool_backward(
+    grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
 @torch._custom_ops.impl_abstract("torchvision::nms")
 def meta_nms(dets, scores, iou_threshold):
     torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")

From 
be83cf2969279d625759cc3585b50378f84f53b0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 23 Oct 2023 10:11:16 +0100 Subject: [PATCH 3/7] Add SymInt support --- .../csrc/ops/autograd/ps_roi_pool_kernel.cpp | 54 +++++++++---------- torchvision/csrc/ops/ps_roi_pool.cpp | 47 +++++++++++++++- torchvision/csrc/ops/ps_roi_pool.h | 19 +++++++ 3 files changed, 91 insertions(+), 29 deletions(-) diff --git a/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp index ddc37262382..d6225419325 100644 --- a/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp +++ b/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp @@ -15,15 +15,15 @@ class PSROIPoolFunction : public torch::autograd::Function { const torch::autograd::Variable& input, const torch::autograd::Variable& rois, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { + c10::SymInt pooled_height, + c10::SymInt pooled_width) { ctx->saved_data["spatial_scale"] = spatial_scale; ctx->saved_data["pooled_height"] = pooled_height; ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["input_shape"] = input.sizes(); + ctx->saved_data["input_shape"] = input.sym_sizes(); at::AutoDispatchBelowADInplaceOrView g; auto result = - ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width); + ps_roi_pool_symint(input, rois, spatial_scale, pooled_height, pooled_width); auto output = std::get<0>(result); auto channel_mapping = std::get<1>(result); @@ -40,18 +40,18 @@ class PSROIPoolFunction : public torch::autograd::Function { auto saved = ctx->get_saved_variables(); auto rois = saved[0]; auto channel_mapping = saved[1]; - auto input_shape = ctx->saved_data["input_shape"].toIntList(); - auto grad_in = detail::_ps_roi_pool_backward( + auto input_shape = ctx->saved_data["input_shape"].toList(); + auto grad_in = detail::_ps_roi_pool_backward_symint( grad_output[0], rois, channel_mapping, ctx->saved_data["spatial_scale"].toDouble(), - 
ctx->saved_data["pooled_height"].toInt(), - ctx->saved_data["pooled_width"].toInt(), - input_shape[0], - input_shape[1], - input_shape[2], - input_shape[3]); + ctx->saved_data["pooled_height"].toSymInt(), + ctx->saved_data["pooled_width"].toSymInt(), + input_shape[0].get().toSymInt(), + input_shape[1].get().toSymInt(), + input_shape[2].get().toSymInt(), + input_shape[3].get().toSymInt()); return { grad_in, @@ -72,14 +72,14 @@ class PSROIPoolBackwardFunction const torch::autograd::Variable& rois, const torch::autograd::Variable& channel_mapping, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = detail::_ps_roi_pool_backward( + auto grad_in = detail::_ps_roi_pool_backward_symint( grad, rois, channel_mapping, @@ -105,8 +105,8 @@ std::tuple ps_roi_pool_autograd( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { + c10::SymInt pooled_height, + c10::SymInt pooled_width) { auto result = PSROIPoolFunction::apply( input, rois, spatial_scale, pooled_height, pooled_width); @@ -118,12 +118,12 @@ at::Tensor ps_roi_pool_backward_autograd( const at::Tensor& rois, const at::Tensor& channel_mapping, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { return PSROIPoolBackwardFunction::apply( grad, rois, diff --git a/torchvision/csrc/ops/ps_roi_pool.cpp b/torchvision/csrc/ops/ps_roi_pool.cpp index c9f64661033..ff33d434de0 100644 --- 
a/torchvision/csrc/ops/ps_roi_pool.cpp +++ b/torchvision/csrc/ops/ps_roi_pool.cpp @@ -20,6 +20,20 @@ std::tuple ps_roi_pool( return op.call(input, rois, spatial_scale, pooled_height, pooled_width); } +std::tuple ps_roi_pool_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + + namespace detail { at::Tensor _ps_roi_pool_backward( @@ -50,13 +64,42 @@ at::Tensor _ps_roi_pool_backward( width); } +at::Tensor _ps_roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") + .typed(); + return op.call( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); +} + + } // namespace detail TORCH_LIBRARY_FRAGMENT(torchvision, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)")); + "torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor")); + 
"torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); } } // namespace ops diff --git a/torchvision/csrc/ops/ps_roi_pool.h b/torchvision/csrc/ops/ps_roi_pool.h index 20c2511e7aa..4a3cc54e0e5 100644 --- a/torchvision/csrc/ops/ps_roi_pool.h +++ b/torchvision/csrc/ops/ps_roi_pool.h @@ -13,6 +13,13 @@ VISION_API std::tuple ps_roi_pool( int64_t pooled_height, int64_t pooled_width); +VISION_API std::tuple ps_roi_pool_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width); + namespace detail { at::Tensor _ps_roi_pool_backward( @@ -27,6 +34,18 @@ at::Tensor _ps_roi_pool_backward( int64_t height, int64_t width); +at::Tensor _ps_roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width); + } // namespace detail } // namespace ops From 843f1a662f9e41cb5f031ae916171df40d13398e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 25 Oct 2023 14:22:00 +0100 Subject: [PATCH 4/7] Fix lint --- torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp index d6225419325..39b83819f94 100644 --- a/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp +++ b/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp @@ -22,8 +22,8 @@ class PSROIPoolFunction : public torch::autograd::Function { ctx->saved_data["pooled_width"] = pooled_width; ctx->saved_data["input_shape"] = input.sym_sizes(); at::AutoDispatchBelowADInplaceOrView g; - auto 
result = - ps_roi_pool_symint(input, rois, spatial_scale, pooled_height, pooled_width); + auto result = ps_roi_pool_symint( + input, rois, spatial_scale, pooled_height, pooled_width); auto output = std::get<0>(result); auto channel_mapping = std::get<1>(result); From 830c6387ef0914a9d3ee025871f815dfd6e3ec66 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 25 Oct 2023 17:37:21 +0100 Subject: [PATCH 5/7] ADded some stuffffff --- test/optests_failures_dict.json | 26 ++++++++++++++++++++++++++ torchvision/csrc/ops/ps_roi_pool.cpp | 2 -- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/test/optests_failures_dict.json b/test/optests_failures_dict.json index 2d01571374f..30f748cab91 100644 --- a/test/optests_failures_dict.json +++ b/test/optests_failures_dict.json @@ -2,6 +2,32 @@ "_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit", "_version": 1, "data": { + "torchvision::ps_roi_pool": { + "TestPSRoIPool.test_aot_dispatch_dynamic__test_mps_error_inputs": { + "comment": "RuntimeError: MPS does not support ps_roi_align backward with float16 inputs", + "status": "xfail" + }, + "TestPSRoIPool.test_autograd_registration__test_backward[True-mps-0]": { + "comment": "NotImplementedError: autograd_registration_check: NYI devices other than CPU/CUDA, got {'mps'}", + "status": "xfail" + }, + "TestPSRoIPool.test_autograd_registration__test_mps_error_inputs": { + "comment": "NotImplementedError: autograd_registration_check: NYI devices other than CPU/CUDA, got {'mps'}", + "status": "xfail" + }, + "TestPSRoIPool.test_faketensor__test_backward[True-mps-0]": { + "comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not equal!", + "status": "xfail" + }, + "TestPSRoIPool.test_faketensor__test_forward[x_dtype0-True-mps]": { + "comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not 
equal!", + "status": "xfail" + }, + "TestPSRoIPool.test_faketensor__test_mps_error_inputs": { + "comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not equal!", + "status": "xfail" + } + }, "torchvision::roi_align": { "TestRoIAlign.test_aot_dispatch_dynamic__test_mps_error_inputs": { "comment": "RuntimeError: MPS does not support roi_align backward with float16 inputs", diff --git a/torchvision/csrc/ops/ps_roi_pool.cpp b/torchvision/csrc/ops/ps_roi_pool.cpp index ff33d434de0..92469d5e380 100644 --- a/torchvision/csrc/ops/ps_roi_pool.cpp +++ b/torchvision/csrc/ops/ps_roi_pool.cpp @@ -33,7 +33,6 @@ std::tuple ps_roi_pool_symint( return op.call(input, rois, spatial_scale, pooled_height, pooled_width); } - namespace detail { at::Tensor _ps_roi_pool_backward( @@ -92,7 +91,6 @@ at::Tensor _ps_roi_pool_backward_symint( width); } - } // namespace detail TORCH_LIBRARY_FRAGMENT(torchvision, m) { From aadda3845822ef38fda4c740715b93dffbe27bee Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 27 Oct 2023 15:11:06 +0100 Subject: [PATCH 6/7] Remove optest as we can't xfail parametrized tests --- test/optests_failures_dict.json | 26 -------------------------- test/test_ops.py | 13 +------------ 2 files changed, 1 insertion(+), 38 deletions(-) diff --git a/test/optests_failures_dict.json b/test/optests_failures_dict.json index 30f748cab91..2d01571374f 100644 --- a/test/optests_failures_dict.json +++ b/test/optests_failures_dict.json @@ -2,32 +2,6 @@ "_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. 
For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit", "_version": 1, "data": { - "torchvision::ps_roi_pool": { - "TestPSRoIPool.test_aot_dispatch_dynamic__test_mps_error_inputs": { - "comment": "RuntimeError: MPS does not support ps_roi_align backward with float16 inputs", - "status": "xfail" - }, - "TestPSRoIPool.test_autograd_registration__test_backward[True-mps-0]": { - "comment": "NotImplementedError: autograd_registration_check: NYI devices other than CPU/CUDA, got {'mps'}", - "status": "xfail" - }, - "TestPSRoIPool.test_autograd_registration__test_mps_error_inputs": { - "comment": "NotImplementedError: autograd_registration_check: NYI devices other than CPU/CUDA, got {'mps'}", - "status": "xfail" - }, - "TestPSRoIPool.test_faketensor__test_backward[True-mps-0]": { - "comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not equal!", - "status": "xfail" - }, - "TestPSRoIPool.test_faketensor__test_forward[x_dtype0-True-mps]": { - "comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not equal!", - "status": "xfail" - }, - "TestPSRoIPool.test_faketensor__test_mps_error_inputs": { - "comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not equal!", - "status": "xfail" - } - }, "torchvision::roi_align": { "TestRoIAlign.test_aot_dispatch_dynamic__test_mps_error_inputs": { "comment": "RuntimeError: MPS does not support roi_align backward with float16 inputs", diff --git a/test/test_ops.py b/test/test_ops.py index 9430123c8d0..7773d4547c1 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -117,9 +117,8 @@ class RoIOpTester(ABC): torch.float32, torch.float64, ), - # ids=str, + ids=str, ) - @pytest.mark.opcheck_only_one() def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs): if device == "mps" and x_dtype is torch.float64: pytest.skip("MPS does not support float64") @@ -187,7 +186,6 @@ def test_torch_fx_trace(self, 
device, x_dtype=torch.float, rois_dtype=torch.floa @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) - @pytest.mark.opcheck_only_one() def test_backward(self, seed, device, contiguous, deterministic=False): atol = self.mps_backward_atol if device == "mps" else 1e-05 dtype = self.mps_dtype if device == "mps" else self.dtype @@ -372,15 +370,6 @@ def test_boxes_shape(self): self._helper_boxes_shape(ops.ps_roi_pool) -optests.generate_opcheck_tests( - testcase=TestPSRoIPool, - namespaces=["torchvision"], - failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"), - additional_decorators=[], - test_utils=OPTESTS, -) - - def bilinear_interpolate(data, y, x, snap_border=False): height, width = data.shape From 65cd14a1e0ec29b62cd2b713d0ee9bb17243715a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 30 Oct 2023 08:01:06 +0000 Subject: [PATCH 7/7] lint --- test/test_ops.py | 1 - torchvision/_meta_registrations.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index e36a0ef8c3e..f4d7c2840ba 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -230,7 +230,6 @@ def func(z): @needs_cuda @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) - @pytest.mark.opcheck_only_one() def test_autocast(self, x_dtype, rois_dtype): with torch.cuda.amp.autocast(): self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype) diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index 9cd4c30cc97..15513e538f5 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -160,7 +160,6 @@ def meta_ps_roi_pool_backward( return grad.new_empty((batch_size, channels, height, width)) - 
@torch._custom_ops.impl_abstract("torchvision::nms") def meta_nms(dets, scores, iou_threshold): torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") @@ -172,4 +171,4 @@ def meta_nms(dets, scores, iou_threshold): ) ctx = torch._custom_ops.get_ctx() num_to_keep = ctx.create_unbacked_symint() - return dets.new_empty(num_to_keep, dtype=torch.long) \ No newline at end of file + return dets.new_empty(num_to_keep, dtype=torch.long)