Skip to content

Commit

Permalink
Avoid broadcast-related type instabilities with huber_loss (#2306)
Browse files Browse the repository at this point in the history
* Avoid broadcast-related type instabilities with huber_loss

* Add test case

* Fix function name

* Fix test
  • Loading branch information
jeremiahpslewis authored Aug 7, 2023
1 parent bf9da7f commit 9e5851d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
24 changes: 14 additions & 10 deletions src/losses/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 )
end

# Scalar Huber term for one absolute error `abs_error` and threshold `δ`.
#
# Written with plain (non-dotted) arithmetic, so it is meant to be applied to one
# element at a time (the caller broadcasts it). The quadratic and linear pieces are
# blended with a 0/1 mask instead of an `if`, so both are always evaluated.
function _huber_metric(abs_error, δ)
    #TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
    below_threshold = Zygote.ignore_derivatives(abs_error .< δ)
    half = ofeltype(abs_error, 0.5)
    # Active when abs_error < δ: 0.5 * abs_error^2
    quadratic_term = ((abs_error * abs_error) * below_threshold) * half
    # Active otherwise: δ * (abs_error - 0.5 * δ)
    linear_term = δ * (abs_error - half * δ) * (1 - below_threshold)
    return quadratic_term + linear_term
end

"""
huber_loss(ŷ, y; delta = 1, agg = mean)
Expand All @@ -94,17 +101,14 @@ julia> Flux.huber_loss(ŷ, 1:3, delta=0.05) # changes behaviour as |ŷ - y| >
0.003750000000000005
```
"""
function huber_loss(ŷ, y; agg = mean, delta::Real = 1, δ = nothing)
    # `δ` is an alternative spelling of the `delta` keyword. The helper (defined
    # elsewhere) presumably resolves which value to use and warns about the
    # deprecated Greek name — confirm against its definition.
    delta_tmp = _greek_ascii_depwarn(δ => delta, :huber_loss, "δ" => "delta")
    δ = ofeltype(ŷ, delta_tmp)
    _check_sizes(ŷ, y)
    abs_error = abs.(ŷ .- y)
    #TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
    temp = Zygote.ignore_derivatives(abs_error .< δ)
    x = ofeltype(ŷ, 0.5)
    # `temp` masks the quadratic branch (abs_error < δ); `1 .- temp` masks the
    # linear branch. The elementwise result is then reduced with `agg`.
    agg(((abs_error .^ 2) .* temp) .* x .+ δ * (abs_error .- x * δ) .* (1 .- temp))
end
function huber_loss(ŷ, y; agg = mean, delta::Real = 1, δ = nothing)
    # `δ` is an alternative spelling of the `delta` keyword. The helper (defined
    # elsewhere) presumably resolves which value to use and warns about the
    # deprecated Greek name — confirm against its definition.
    delta_tmp = _greek_ascii_depwarn(δ => delta, :huber_loss, "δ" => "delta")
    δ = ofeltype(ŷ, delta_tmp)
    _check_sizes(ŷ, y)
    abs_error = abs.(ŷ .- y)

    # Apply the scalar Huber term to each element, then reduce with `agg`.
    return agg(_huber_metric.(abs_error, δ))
end
"""
label_smoothing(y::Union{Number, AbstractArray}, α; dims::Int=1)
Expand Down
12 changes: 12 additions & 0 deletions test/ext_metal/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,15 @@ include("test_utils.jl")
# Run the contents of basic.jl grouped under a test set named "Basic".
@testset "Basic" begin
include("basic.jl")
end

# Gradient check for huber_loss on arrays moved to the device with Flux.gpu.
@testset "Huber Loss test" begin
    # Every element has |a - b| == 1, matching the default delta = 1.
    X = Flux.gpu(Float32[0, 1])
    Y = Flux.gpu(Float32[1, 0])

    grad = Flux.gradient(X, Y) do a, b
        Flux.Losses.huber_loss(a, b)
    end

    # Compare with ≈ rather than ==: the gradients are floating-point values
    # computed on the device, so exact equality is needlessly brittle.
    @test Flux.cpu(grad[1]) ≈ [-0.5, 0.5]
    @test Flux.cpu(grad[2]) ≈ [0.5, -0.5]
end

0 comments on commit 9e5851d

Please sign in to comment.