Skip to content

Commit

Permalink
Merge pull request #965 from JuliaAI/training-losses-patch
Browse files Browse the repository at this point in the history
Fix problem with training losses for pipelines
  • Loading branch information
ablaom authored Mar 17, 2024
2 parents 0725e90 + b417240 commit f01a03c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "1.2.0"
version = "1.2.1"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
14 changes: 3 additions & 11 deletions src/composition/models/pipelines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -602,20 +602,12 @@ MMI.target_scitype(p::SupervisedPipeline) = target_scitype(supervised_component(

# ## Training losses

const ERR_TRAINING_LOSSES1 = ErrorException(
"Looking for training losses in a model's report but cannot find any. "
)

const ERR_TRAINING_LOSSES2 = ErrorException(
"Composite model does not appear to support training losses. "
)

# If supervised model does not support training losses, we won't find an entry in the
# report and so we need to return `nothing` (and not throw an error).
function MMI.training_losses(pipe::SupervisedPipeline, pipe_report)
supports_training_losses(pipe::SupervisedPipeline) ||
throw(ERR_PIPE_TRAINING_LOSSES2)
supervised = MLJBase.supervised_component(pipe)
supervised_name = MLJBase.supervised_component_name(pipe)
supervised_name in propertynames(pipe_report) || throw(ERR_PIPE_TRAINING_LOSSES1)
supervised_name in propertynames(pipe_report) || return nothing
report = getproperty(pipe_report, supervised_name)
return training_losses(supervised, report)
end
Expand Down

0 comments on commit f01a03c

Please sign in to comment.