Skip to content

Commit

Permalink
Numpy choice (ivy-llc#22440)
Browse files Browse the repository at this point in the history
Very nice work @VedantPol ! Merged ;D
  • Loading branch information
VedantPol authored and druvdub committed Oct 14, 2023
1 parent 007cbf4 commit 8441061
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
19 changes: 19 additions & 0 deletions ivy/functional/frontends/numpy/random/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ def chisquare(df, size=None):
return ivy.gamma(df / 2, 2, dtype="float64")


@to_ivy_arrays_and_back
@from_zero_dim_arrays_to_scalar
def choice(a, size=None, replace=True, p=None):
sc_size = 1
if isinstance(size, int):
sc_size = size
elif size is not None:
# If the given shape is, e.g., (m, n, k)
# then m * n * k samples are drawn. As per numpy docs
sc_size = 1
for s in size:
if s is not None:
sc_size *= s
if isinstance(a, int):
a = ivy.arange(a)
index = ivy.multinomial(len(a), sc_size, replace=replace, probs=p)
return a[index]


@to_ivy_arrays_and_back
@from_zero_dim_arrays_to_scalar
def dirichlet(alpha, size=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,37 @@ def test_numpy_chisquare(
)


@handle_frontend_test(
fn_tree="numpy.random.choice",
dtypes=helpers.get_dtypes("float", full=False),
a=helpers.ints(min_value=2, max_value=10),
size=helpers.get_shape(allow_none=True),
)
def test_numpy_choice(
dtypes,
size,
frontend,
test_flags,
backend_fw,
fn_tree,
on_device,
a,
):
helpers.test_frontend_function(
input_dtypes=dtypes,
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
test_values=False,
a=a,
size=size,
replace=True,
p=np.array([1 / a] * a, dtype=dtypes[0]),
)


# dirichlet
@handle_frontend_test(
fn_tree="numpy.random.dirichlet",
Expand Down

0 comments on commit 8441061

Please sign in to comment.