Address some predict/transform type instabilities #969

Merged: 2 commits into dev from predict-type-instability on Apr 24, 2024

Conversation

@ablaom (Member) commented Apr 8, 2024

(edited) This PR addresses a type instability for operations (predict, transform, etc.) acting on machines, as identified in #959 (although this PR does not resolve the particular issue reported there).

I admit it is not clear to me that the performance gains here will significantly benefit many use cases, but having done the work to identify these instabilities, I see no harm in addressing them.

The type instability is not difficult to address for machines attached to ordinary models: it suffices to add a concrete annotation to a field of the Machine struct whose type is currently abstract. However, in the special case of a machine attached to a symbolic model (such models appear exclusively in learning networks), the type instability remains, and looks difficult to remove.
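For illustration, here is a minimal, hypothetical sketch (not the actual Machine definition; all names below are made up) of why the annotation matters: a field typed with an abstract supertype blocks inference at the call site, while carrying the concrete model type as a struct parameter lets the call be fully inferred.

abstract type SomeModel end

struct ToyModel <: SomeModel
    coef::Float64
end

toy_predict(model::ToyModel, x) = model.coef .* x

# Unstable: the field has an abstract type, so calls through `mach.model`
# dispatch dynamically and the return type cannot be inferred.
struct LooseMachine
    model::SomeModel
end

# Stable: the concrete model type is carried as a type parameter, so the
# same call is fully inferred.
struct TightMachine{M<:SomeModel}
    model::M
end

predict_via(mach, x) = toy_predict(mach.model, x)

# Compare inference with, e.g.:
#   @code_warntype predict_via(LooseMachine(ToyModel(2.0)), [1.0, 2.0])
#   @code_warntype predict_via(TightMachine(ToyModel(2.0)), [1.0, 2.0])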

Benchmarks

For 69 regression models, we compared the "high level" predict(::Machine, ...) method with the "low level" predict(::Model, ...) method (edit: plus reformat(::Model, ...)) implemented by third-party model providers. The benchmark code is given below:

using MLJTestInterface, MLJModels, MLJBase
using Tables
using Random
using BenchmarkTools
using Statistics
import DataFrames
import MLJModelInterface as MMI

# Select all supervised models with a continuous target (excluding MLJText):
const MODELS = models() do m
    !(m.package_name in ["MLJText"]) &&
        AbstractVector{Continuous} <: m.target_scitype &&
        m.is_supervised
end

# This is a way to load all needed model code:
MLJTestInterface.test(MODELS, mod=@__MODULE__, level=1, throw=true)

rng = Random.MersenneTwister(0)
Xmat = randn(rng, 30, 3)
X = Tables.table(Xmat)
y = @. cos(Xmat[:, 1] * 2.1 - 0.9) * Xmat[:, 2] - Xmat[:, 3]

# The low-level path: reformat the data and call `predict` directly on the model:
function predict_low(model, fitresult, X)
    Xraw = MMI.reformat(model, X)
    MMI.predict(model, fitresult, Xraw...)
end

# Benchmark the high-level (machine) and low-level (model) paths for each
# model, recording the ratios of median times and of allocation counts:
stats = []
for m in MODELS
    print("\rBenchmarking $(m.name) $(m.package_name).")
    model = eval(:(@load $(m.name) pkg=$(m.package_name) verbosity=0))()
    mach = machine(model, X, y)
    fit!(mach, verbosity=0)
    fitresult = mach.fitresult
    b_high = @benchmark predict($mach, $X)
    b_low = @benchmark predict_low($model, $fitresult, $X)
    slow_down = median(b_high.times)/median(b_low.times)
    bloat = b_high.allocs/b_low.allocs
    push!(stats, (; model=m.name, pkg=m.package_name, slow_down, bloat))
    print(" Done.                           ")
end

@show length(MODELS)
#length(MODELS) = 69

stats = DataFrames.DataFrame([stats...])
filter(stats) do row
    row.slow_down > 1.75 || row.bloat > 2.0
end

In the tables below:

  • "slow_down" is the ratio of elapsed time of "high" to "low"
  • "bloat" is the ratio of number of allocations of "high" to "low"

Only models with slow_down > 1.75 or bloat > 2 are reported.

Before this PR:

#  Row │ model                           pkg                           slow_down  bloat
#      │ String                          String                        Float64    Float64
# ─────┼──────────────────────────────────────────────────────────────────────────────────
#    1 │ ConstantRegressor               MLJModels                       9.67007    4.0
#    2 │ DeterministicConstantRegressor  MLJModels                      14.0756     4.0
#    3 │ ElasticNetRegressor             MLJLinearModels                 4.41319    2.0
#    4 │ HuberRegressor                  MLJLinearModels                 3.93931    2.0
#    5 │ LADRegressor                    MLJLinearModels                 3.97205    2.0
#    6 │ LassoRegressor                  MLJLinearModels                 4.20191    2.0
#    7 │ LinearRegressor                 MLJLinearModels                 3.95137    2.0
#    8 │ LinearRegressor                 MultivariateStats               4.92059    2.5
#    9 │ PLSRegressor                    PartialLeastSquaresRegressor    2.11492    1.375
#   10 │ QuantileRegressor               MLJLinearModels                 4.15291    2.0
#   11 │ RidgeRegressor                  MLJLinearModels                 3.99158    2.0
#   12 │ RidgeRegressor                  MultivariateStats               5.03699    2.5
#   13 │ RobustRegressor                 MLJLinearModels                 3.95992    2.0

After this PR:

#  Row │ model              pkg        slow_down  bloat
#      │ String             String     Float64    Float64
# ─────┼──────────────────────────────────────────────────
#    1 │ ConstantRegressor  MLJModels    1.61401      3.0

Note that machines serialised using #master cannot be deserialised after this PR, but I don't consider that this warrants a breaking release.

To do:

  • Run MLJ tests with integration tests switched on

@ablaom ablaom marked this pull request as draft April 8, 2024 21:07
@ablaom ablaom requested a review from OkonSamuel April 8, 2024 21:08
@ablaom ablaom marked this pull request as ready for review April 10, 2024 00:29
@OkonSamuel (Member) left a comment


Looks good to me!
@ablaom I think we should benchmark the compile time for this package with vs. without this PR.
But in general, I think the run-time performance gains we would get from this should outweigh any added compile time, because these operations (e.g. predict) are generally expected to be quite expensive.

@ablaom (Member, Author) commented Apr 24, 2024

Doesn't look like there's a significant difference.

Before this PR:

julia> @time_imports import MLJBase
    413.2 ms  MLJBase 23.21% compilation time

After this PR:

julia> @time_imports import MLJBase
    437.3 ms  MLJBase 22.21% compilation time

@ablaom ablaom merged commit 6e77d6a into dev Apr 24, 2024
3 checks passed
@ablaom ablaom deleted the predict-type-instability branch April 24, 2024 00:04