Skip to content

Commit

Permalink
Added torch frontend function diagflat (#22384)
Browse files Browse the repository at this point in the history
Co-authored-by: paulaehab<[email protected]>
  • Loading branch information
Darshan-H-E authored Sep 4, 2023
1 parent 131196e commit 8854a01
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
9 changes: 9 additions & 0 deletions ivy/functional/frontends/torch/miscellaneous_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,15 @@ def diag(input, diagonal=0, *, out=None):
return ivy.diag(input, k=diagonal)


@with_supported_dtypes(
{"2.0.1 and below": ("float32", "float64", "int32", "int64")}, "torch"
)
@to_ivy_arrays_and_back
def diagflat(x, offset=0, name=None):
arr = ivy.diagflat(x, offset=offset)
return arr


@with_unsupported_dtypes({"2.0.1 and below": ("float16", "bfloat16")}, "torch")
@to_ivy_arrays_and_back
def diagonal(input, offset=0, dim1=0, dim2=1):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,42 @@ def test_torch_diag(
)


# diagflat
@handle_frontend_test(
fn_tree="torch.diagflat",
dtype_and_values=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
min_num_dims=1,
max_num_dims=5,
min_dim_size=1,
max_dim_size=5,
),
offset=st.integers(min_value=-4, max_value=4),
test_with_out=st.just(False),
)
def test_torch_diagflat(
dtype_and_values,
offset,
test_flags,
backend_fw,
frontend,
fn_tree,
on_device,
):
input_dtype, x = dtype_and_values
helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
test_values=False,
x=x[0],
offset=offset,
)


@handle_frontend_test(
fn_tree="torch.diagonal",
dtype_and_values=helpers.dtype_and_values(
Expand Down

0 comments on commit 8854a01

Please sign in to comment.