Skip to content

Commit

Permalink
feat: implement a separate TracedRNumber (#161)
Browse files Browse the repository at this point in the history
* feat: TracedRScalar

* feat: partial progress on getting scalars to work

* refactor: Scalar --> Number

* fix: batching

* fix: promote_rule and introduce union over primitive types

* chore: apply formatting

* feat: type-restrict arrays

* refactor: move scalar ops to a separate file

* feat: support Base.float

* fix: import ordering

* feat: handle `broadcast_preserving_zero_d` in a generic fashion

* refactor: move code a bit

* test: more test fixes

* chore: apply formatting

* fix: setindex with scalars

* fix: scalar broadcasting case

* feat: support BFloat16 from Core (if available)

* test: more native lux functionality unblocked

* refactor: use a union type for traced types

* fix: check for reactant primitives

* fix: missing import

* fix: correct semantics for Colon mapreduce

* fix: trace_type

* fix: minor fixes

* feat: support logsoftmax

* fix: bool promote rule

* fix: broadcasting of closures

* refactor: use TracedTypes

* Fix type of `preserved_args`

* Rename `TracedTypes` to `TracedType`

* small testset rename

* fix: special handling for concatenation of numbers

* Reenable tests

* Rename `ReactantPrimitives` to `ReactantPrimitive`

---------

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
3 people authored Oct 6, 2024
1 parent f2c0e8a commit 6e89952
Show file tree
Hide file tree
Showing 13 changed files with 557 additions and 252 deletions.
29 changes: 18 additions & 11 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
module ReactantNNlibExt

using NNlib
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber

for (jlop, hloop) in (
(:(NNlib.tanh_fast), :tanh),
(:(NNlib.sigmoid_fast), :logistic),
(:(NNlib.sigmoid), :logistic),
)
@eval function $(jlop)(x::TracedRArray{T,0}) where {T}
return TracedRArray{T,0}(
@eval function $(jlop)(x::TracedRNumber{T}) where {T}
return TracedRNumber{T}(
(),
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
),
(),
)
end
end

# Don't confuse our poor scalar arrays, we no like numbers we like 0D arrays
for nnlib_op in setdiff(Tuple(NNlib.ACTIVATIONS), (:tanh_fast, :sigmoid_fast, :sigmoid, ))
@eval function NNlib.$(nnlib_op)(x::TracedRArray{T,0}) where {T}
return invoke(NNlib.$(nnlib_op), Tuple{Any}, x)
end
end

# TODO handle non finite cases
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
max_ = NNlib.fast_maximum(x; dims)
Expand All @@ -39,6 +32,20 @@ function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where
return out ./= tmp
end

function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
max_ = NNlib.fast_maximum(x; dims)
# if all(isfinite, max_)
@fastmath out .= x .- max_
# else
# _zero, _minf, _inf = T(0), T(-Inf), T(Inf)
# @. out = ifelse(
# isequal(max_, _inf), ifelse(isequal(x, _inf), _zero, _minf), x - max_
# )
# end
@fastmath log_ = log.(sum(exp, out; dims))
return out .-= log_
end

function NNlib.conv(
x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
) where {T,N}
Expand Down
10 changes: 6 additions & 4 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import ..Reactant:
XLA,
ConcreteRArray,
TracedRArray,
TracedRNumber,
OrderedIdDict,
make_tracer,
TracedToConcrete,
append_path
append_path,
TracedType

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)

Expand Down Expand Up @@ -286,10 +288,10 @@ function compile_mlir!(mod, f, args; optimize=true)
)
end

preserved_args = Tuple{TracedRArray,Int}[]
preserved_args = Tuple{TracedType,Int}[]
results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
nresults = MLIR.IR.Value[]
linear_results2 = TracedRArray[]
linear_results2 = TracedType[]
for (i, op) in enumerate(results)
if !MLIR.IR.is_block_arg(op)
push!(nresults, op)
Expand Down Expand Up @@ -573,7 +575,7 @@ end
function compile_xla(f, args; client=nothing)
# register MLIR dialects
ctx = MLIR.IR.Context()
Base.append!(Reactant.registry[]; context=ctx)
append!(Reactant.registry[]; context=ctx)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

return MLIR.IR.context!(ctx) do
Expand Down
6 changes: 1 addition & 5 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ function Base.convert(::Type{T}, x::ConcreteRArray{T,0}) where {T}
return to_float(x)
end

function Base.promote_rule(::Type{<:RArray{T1,0}}, ::Type{T2}) where {T1,T2}
return Base.promote_rule(T1, T2)
end

for jlop in (:(Base.isless), :(Base.:+), :(Base.:-), :(Base.:*), :(Base.:/), :(Base.:^))
@eval begin
function $jlop(x::ConcreteRArray{T,0}, y::ConcreteRArray{U,0}) where {T,U}
Expand Down Expand Up @@ -158,7 +154,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
end

function mysetindex!(a, v, args::Vararg{Int,N}) where {N}
Base.setindex!(a, v, args...)
setindex!(a, v, args...)
return nothing
end

Expand Down
45 changes: 44 additions & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,45 @@ include("OrderedIdDict.jl")

using Enzyme

abstract type RArray{T,N} <: AbstractArray{T,N} end
@static if isdefined(Core, :BFloat16)
const ReactantPrimitive = Union{
Bool,
Int8,
UInt8,
Int16,
UInt16,
Int32,
UInt32,
Int64,
UInt64,
Float16,
Core.BFloat16,
Float32,
Float64,
Complex{Float32},
Complex{Float64},
}
else
const ReactantPrimitive = Union{
Bool,
Int8,
UInt8,
Int16,
UInt16,
Int32,
UInt32,
Int64,
UInt64,
Float16,
Float32,
Float64,
Complex{Float32},
Complex{Float64},
}
end

abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end
abstract type RNumber{T<:ReactantPrimitive} <: Number end

function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}})
return reshape(A, Base._reshape_uncolon(A, dims))
Expand Down Expand Up @@ -45,8 +83,13 @@ include("mlir/MLIR.jl")
include("XLA.jl")
include("Interpreter.jl")
include("utils.jl")

include("ConcreteRArray.jl")
include("TracedRNumber.jl")
include("TracedRArray.jl")

const TracedType = Union{TracedRArray,TracedRNumber}

include("Tracing.jl")
include("Compiler.jl")

Expand Down
Loading

1 comment on commit 6e89952

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: 6e89952 Previous: f2c0e8a Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1325630083 ns 1315729546 ns 1.01
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 217828012 ns 212083499 ns 1.03
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 6953042665 ns 5286469750 ns 1.32
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 19834605771 ns 23583347555 ns 0.84
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1251065036 ns 1254858296 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 8303675 ns 8478570 ns 0.98
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1644888663 ns 1636237670 ns 1.01
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 3346211607.5 ns 2376437823 ns 1.41
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1268605475.5 ns 1266018905 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 91911558 ns 84820407 ns 1.08
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2175800720 ns 2170879105 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 5163881653 ns 4675094299 ns 1.10
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1305929899 ns 1263496480 ns 1.03
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 8108212 ns 7782824 ns 1.04
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1478744083.5 ns 1467043032.5 ns 1.01
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1501162607 ns 1685775445 ns 0.89
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1344634647.5 ns 1306815930 ns 1.03
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 11617967 ns 11611908 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1762294312 ns 1752808523 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 2714323365 ns 2463987825.5 ns 1.10
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1258217588 ns 1325877558.5 ns 0.95
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 85973438 ns 90330187 ns 0.95
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2223618180 ns 2213119086 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 4021036919 ns 4023816395 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1322735409 ns 1270812264 ns 1.04
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 115423826 ns 113097539 ns 1.02
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 2991911664 ns 3042643080 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 15616144600 ns 8210106924.5 ns 1.90
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1316851458 ns 1324054039 ns 0.99
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 128824702.5 ns 127669686.5 ns 1.01
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3240917788 ns 3203794253 ns 1.01
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 6435316514 ns 11004907984 ns 0.58
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1330411921 ns 1299288245 ns 1.02
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 83811390 ns 96277750 ns 0.87
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 2022207539.5 ns 2155333265.5 ns 0.94
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 2432459556.5 ns 2863535293.5 ns 0.85

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.