Skip to content

Commit

Permalink
Merge pull request #143 from JuliaML/subtypes-fix
Browse files Browse the repository at this point in the history
Fix for Julia version < v1.3
  • Loading branch information
juliohm authored Jun 10, 2021
2 parents 1dc8c29 + 5dc8388 commit 7f4d604
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 1 deletion.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
version = "0.7.1"

[deps]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Expand Down
1 change: 1 addition & 0 deletions src/LossFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Base.*
using Base.Cartesian
using Markdown
using RecipesBase
using InteractiveUtils: subtypes

import LearnBase.AggMode
import LearnBase.ObsDim
Expand Down
9 changes: 8 additions & 1 deletion src/supervised.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,11 @@ for FUN in (:value, :deriv, :deriv2)
end

# convenient functor interface
(loss::SupervisedLoss)(target::AbstractArray, output::AbstractArray) = value(loss, target, output)
if VERSION v"1.3.0"
(loss::SupervisedLoss)(target::AbstractArray, output::AbstractArray) = value(loss, target, output)
else
# add method manually to all subtypes
for L in Iterators.flatten([subtypes(DistanceLoss), subtypes(MarginLoss)])
(loss::L)(target::AbstractArray, output::AbstractArray) = value(loss, target, output)
end
end

0 comments on commit 7f4d604

Please sign in to comment.