Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix weighted computations for non-real arrays #658

Closed
wants to merge 2 commits into from
Closed

fix weighted computations for non-real arrays #658

wants to merge 2 commits into from

Conversation

aplavin
Copy link
Contributor

@aplavin aplavin commented Feb 10, 2021

Update to #649 following suggestion by @nalimilan

Copy link
Member

@nalimilan nalimilan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. have you tried some benchmarking? dot doesn't seem to be faster anymore on a Julia 1.6, so I wonder whether we could always use broadcast when VERSION >= v"1.6.0" to be able to simplify code later.

julia> using LinearAlgebra, BenchmarkTools

julia> x = rand(10_000);

julia> xm = Vector{Union{Float64, Missing}}(x);

julia> w = rand(10_000);

julia> f(A, w) = sum(Base.Broadcast.instantiate(Base.Broadcast.Broadcasted(*, (A, w))))
f (generic function with 1 method)

julia> @btime dot(x, w);
  2.209 μs (1 allocation: 16 bytes)

julia> @btime f(x, w);
  2.472 μs (1 allocation: 16 bytes)

julia> @btime dot(xm, w);
  20.162 μs (1 allocation: 16 bytes)

julia> @btime f(xm, w);
  23.823 μs (1 allocation: 16 bytes)

src/weights.jl Show resolved Hide resolved
@aplavin
Copy link
Contributor Author

aplavin commented Feb 16, 2021

I may try to make it always go through the broadcasting implementation, as you say. But I'm not familiar with how to use the version check so that there is no runtime penalty, could you give an example?

@nalimilan
Copy link
Member

Sorry for the late reply. You can just do if VERSION > v"X.Y.Z" (if @static if if not in the global scope).

@aplavin
Copy link
Contributor Author

aplavin commented Nov 21, 2021

Bump without any changes...

The broadcast-based method doesn't work for empty collections:

Expression: wsum(Float64[], Float64[]) === 0.0
ArgumentError: reducing over an empty collection is not allowed

It also doesn't handle multidim arrays, see my comment above. So I didn't add the VERSION >= 1.6 condition as you suggested.

@nalimilan
Copy link
Member

Another option that I used at https://github.com/JuliaLang/Statistics.jl/pull/94 is to simply do vec(v)' * w.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants