From c956029469406f2c2b04542ef3054f2f26e41771 Mon Sep 17 00:00:00 2001 From: Felix Hirwa Nshuti Date: Sat, 23 Sep 2023 22:16:45 +0530 Subject: [PATCH] fix(torch-frontend): Added unsupported dtype for `torch.nn.functional.linear` (#25966) --- ivy/functional/backends/paddle/linear_algebra.py | 1 + .../frontends/torch/nn/functional/linear_functions.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/ivy/functional/backends/paddle/linear_algebra.py b/ivy/functional/backends/paddle/linear_algebra.py index 324779bc14592..3a6145297d844 100644 --- a/ivy/functional/backends/paddle/linear_algebra.py +++ b/ivy/functional/backends/paddle/linear_algebra.py @@ -227,6 +227,7 @@ def matmul( paddle.uint8, paddle.float16, paddle.bool, + paddle.bfloat16, ]: x1, x2 = x1.cast("float32"), x2.cast("float32") diff --git a/ivy/functional/frontends/torch/nn/functional/linear_functions.py b/ivy/functional/frontends/torch/nn/functional/linear_functions.py index 28cc206370d51..57322d401d0f2 100644 --- a/ivy/functional/frontends/torch/nn/functional/linear_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/linear_functions.py @@ -1,8 +1,10 @@ # local import ivy +from ivy.func_wrapper import with_unsupported_dtypes from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back +@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, "torch") @to_ivy_arrays_and_back def linear(input, weight, bias=None): return ivy.linear(input, weight, bias=bias)