diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index deee638b7..523dbec51 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -5179,3 +5179,14 @@ def test_argmax( backend=backend, converter_input_type=converter_input_type, ) + + +class TestCopy(TorchBaseTest): + @pytest.mark.parametrize( + "backend, rank", itertools.product(backends, list(range(1, 6))), + ) + def test_copy(self, backend, rank): + input_shape = tuple(np.random.randint(low=2, high=6, size=rank)) + + model = ModuleWrapper(function=lambda x: x.copy_()) + self.run_compare_torch(input_shape, model, backend=backend)