Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add binary_cross_entropy in functional.frontends.torch #2310

Merged
merged 21 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
07e8f0a
add frontend.torch.loss_functions and BCE
whitepurple Jul 26, 2022
d0a9034
add test_loss_functions and edit bce
whitepurple Jul 26, 2022
5baa5fd
Merge branch 'master' into frontend_torch_loss_functions
whitepurple Jul 28, 2022
3607cf5
revert nn.loss_functions to loss_functions
whitepurple Jul 31, 2022
3ce5cf7
Edit formating and edit test code
whitepurple Jul 31, 2022
3f081d8
Edit formatting
whitepurple Jul 31, 2022
3d3cc55
Merge branch 'master' into frontend_torch_loss_functions
whitepurple Aug 7, 2022
92395b8
Merge branch 'master' into frontend_torch_loss_functions
whitepurple Aug 15, 2022
9252e80
Update test_loss_functions.py
whitepurple Aug 15, 2022
bb02fa5
Delete statistical.py
whitepurple Aug 15, 2022
6118e72
Update test_loss_functions.py
whitepurple Aug 15, 2022
db7d861
Merge branch 'frontend_torch_loss_functions' of https://github.com/wh…
whitepurple Aug 15, 2022
6d7c67d
Revert "Delete statistical.py"
whitepurple Aug 15, 2022
14d46f2
Update loss_fuctions and test code
whitepurple Aug 15, 2022
992c699
Merge branch 'master' into frontend_torch_loss_functions
whitepurple Aug 15, 2022
31f8ee9
Update loss_functions formating
whitepurple Aug 15, 2022
e0a68e4
Update test exclude_min and max
whitepurple Aug 15, 2022
c172bf6
Merge branch 'master' into frontend_torch_loss_functions
whitepurple Aug 21, 2022
9f163ff
Update reviewed code and test code
whitepurple Aug 22, 2022
753fa0e
Merge remote-tracking branch 'origin/master' into frontend_torch_loss…
whitepurple Aug 22, 2022
bcefc0c
Merge branch 'unifyai:master' into frontend_torch_loss_functions
whitepurple Aug 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions ivy/functional/frontends/torch/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,43 @@
import ivy


def _get_reduction_func(reduction):
if reduction == 'none':
ret = lambda x : x
elif reduction == 'mean':
ret = ivy.mean
elif reduction == 'elementwise_mean':
whitepurple marked this conversation as resolved.
Show resolved Hide resolved
ret = ivy.mean
elif reduction == 'sum':
ret = ivy.sum
else:
raise ValueError("{} is not a valid value for reduction".format(reduction))
return ret


def _legacy_get_string(size_average, reduce):
if size_average is None:
size_average = True
if reduce is None:
reduce = True
if size_average and reduce:
ret = 'mean'
elif reduce:
ret = 'sum'
else:
ret = 'none'
return ret


def _get_reduction(reduction,
size_average=None,
reduce=None):
if size_average is not None or reduce is not None:
return _get_reduction_func(_legacy_get_string(size_average, reduce))
else:
return _get_reduction_func(reduction)


def cross_entropy(
input,
target,
Expand All @@ -16,3 +53,29 @@ def cross_entropy(


cross_entropy.unsupported_dtypes = ("float16",)


def binary_cross_entropy(
input,
target,
weight=None,
size_average=None,
reduce=None,
reduction='mean'
):
reduction = _get_reduction(reduction, size_average, reduce)
result = ivy.binary_cross_entropy(target, input, epsilon=0.0)

if weight is not None:
result = ivy.multiply(weight, result)
result = reduction(result)
return result


binary_cross_entropy.unsupported_dtypes = (
whitepurple marked this conversation as resolved.
Show resolved Hide resolved
'uint16',
'float16',
'uint64',
'float64',
'uint32'
)
148 changes: 118 additions & 30 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,34 @@
set(ivy_np.valid_float_dtypes).intersection(
set(ivy_torch.valid_float_dtypes)
)
),
min_value=0,
max_value=1,
allow_inf=False,
min_num_dims=2,
max_num_dims=2,
min_dim_size=1,
),
min_value=0,
max_value=1,
allow_inf=False,
min_num_dims=2,
max_num_dims=2,
min_dim_size=1,
),
dtype_and_target=helpers.dtype_and_values(
available_dtypes=tuple(
set(ivy_np.valid_float_dtypes).intersection(
set(ivy_torch.valid_float_dtypes)
)
),
min_value=1.0013580322265625e-05,
max_value=1,
allow_inf=False,
exclude_min=True,
exclude_max=True,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
),
as_variable=helpers.list_of_length(x=st.booleans(), length=2),
),
min_value=1.0013580322265625e-05,
max_value=1,
allow_inf=False,
exclude_min=True,
exclude_max=True,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
),
as_variable=helpers.list_of_length(x=st.booleans(), length=2),
num_positional_args=helpers.num_positional_args(
fn_name="ivy.functional.frontends.torch.cross_entropy"
),
native_array=helpers.list_of_length(x=st.booleans(), length=2),
),
native_array=helpers.list_of_length(x=st.booleans(), length=2),
)
def test_torch_cross_entropy(
dtype_and_input,
Expand All @@ -54,14 +54,102 @@ def test_torch_cross_entropy(
inputs_dtype, input = dtype_and_input
target_dtype, target = dtype_and_target
helpers.test_frontend_function(
input_dtypes=[inputs_dtype, target_dtype],
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
fw=fw,
frontend="torch",
fn_tree="nn.functional.cross_entropy",
input=np.asarray(input, dtype=inputs_dtype),
target=np.asarray(target, dtype=target_dtype),
input_dtypes=[inputs_dtype, target_dtype],
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
fw=fw,
frontend="torch",
fn_tree="nn.functional.cross_entropy",
input=np.asarray(input, dtype=inputs_dtype),
target=np.asarray(target, dtype=target_dtype),
)


# binary_cross_entropy
@given(
dtype_and_true=helpers.dtype_and_values(
available_dtypes=tuple(
set(ivy_np.valid_float_dtypes).intersection(
set(ivy_torch.valid_float_dtypes)
)
),
min_value=0,
max_value=1,
allow_inf=False,
exclude_min=True,
exclude_max=True,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
),
dtype_and_pred=helpers.dtype_and_values(
available_dtypes=tuple(
set(ivy_np.valid_float_dtypes).intersection(
set(ivy_torch.valid_float_dtypes)
)
),
min_value=1.0013580322265625e-05,
max_value=1.0,
allow_inf=False,
exclude_min=True,
exclude_max=True,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
),
dtype_and_weight=helpers.dtype_and_values(
available_dtypes=tuple(
set(ivy_np.valid_float_dtypes).intersection(
set(ivy_torch.valid_float_dtypes)
)
),
min_value=1.0013580322265625e-05,
max_value=1.0,
allow_inf=False,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
),
size_average=st.booleans(),
reduce=st.booleans(),
reduction=st.sampled_from(["mean", "none", "sum", None]),
as_variable=helpers.list_of_length(x=st.booleans(), length=3),
num_positional_args=helpers.num_positional_args(
fn_name="ivy.functional.frontends.torch.binary_cross_entropy"
),
native_array=helpers.list_of_length(x=st.booleans(), length=3),
)
def test_binary_cross_entropy(
dtype_and_true,
dtype_and_pred,
dtype_and_weight,
size_average,
reduce,
reduction,
as_variable,
num_positional_args,
native_array,
fw,
):
pred_dtype, pred = dtype_and_pred
true_dtype, true = dtype_and_true
weight_dtype, weight = dtype_and_weight

helpers.test_frontend_function(
input_dtypes=[pred_dtype, true_dtype, weight_dtype],
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
fw=fw,
frontend="torch",
fn_tree="nn.functional.binary_cross_entropy",
input=np.asarray(pred, dtype=pred_dtype),
target=np.asarray(true, dtype=true_dtype),
weight=np.asarray(weight, dtype=weight_dtype),
size_average=size_average,
reduce=reduce,
reduction=reduction,
)