From 47fa031f75175d641c89e4b482264fc2e27cefd5 Mon Sep 17 00:00:00 2001 From: rzou Date: Thu, 18 Apr 2024 11:06:13 -0700 Subject: [PATCH] Update usages of torch.library APIs We deprecated impl_abstract. This PR replaces it with the new API (register_fake). register_fake also (sometimes) requires a `set_python_module` in C++, so I add that as well. Test Plan: - existing tests --- torchvision/_meta_registrations.py | 2 +- torchvision/csrc/ops/nms.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index 4ff55280e89..f75bfb77a7f 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -160,7 +160,7 @@ def meta_ps_roi_pool_backward( return grad.new_empty((batch_size, channels, height, width)) -@torch._custom_ops.impl_abstract("torchvision::nms") +@torch.library.register_fake("torchvision::nms") def meta_nms(dets, scores, iou_threshold): torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}") diff --git a/torchvision/csrc/ops/nms.cpp b/torchvision/csrc/ops/nms.cpp index 07a934bce5a..5ecf8812f1b 100644 --- a/torchvision/csrc/ops/nms.cpp +++ b/torchvision/csrc/ops/nms.cpp @@ -19,6 +19,7 @@ at::Tensor nms( } TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.set_python_module("torchvision._meta_registrations"); m.def(TORCH_SELECTIVE_SCHEMA( "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")); }