Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve names and docs #23

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ for i in eachindex(x1,x2,x3,x4,y1,y2)
end
```

With this package, we can apply `@fused` to reduce the number of reads and preserve the memory layout:
With this package, we can apply `@fused_direct` to reduce the number of reads and preserve the memory layout:

```julia
import MultiBroadcastFusion as MBF
Expand All @@ -62,7 +62,7 @@ y1 = rand(3,3)
y2 = rand(3,3)

# 4 reads, 2 writes
MBF.@fused begin
MBF.@fused_direct begin
@. y1 = x1 * x2 + x3 * x4
@. y2 = x1 * x3 + x2 * x4
end
Expand All @@ -76,10 +76,11 @@ Users can write custom implementations, using the `@make_type` and `@make_fused`

```julia
import MultiBroadcastFusion as MBF
import MultiBroadcastFusion: fused_direct

MBF.@make_type MyFusedMultiBroadcast
MBF.@make_fused MBF.fused_pairs MyFusedMultiBroadcast my_fused
# Now, `@fused` will call `Base.copyto!(::MyFusedMultiBroadcast)`. Let's define it:
MBF.@make_fused fused_direct MyFusedMultiBroadcast my_fused
# Now, `@fused_direct` will call `Base.copyto!(::MyFusedMultiBroadcast)`. Let's define it:
function Base.copyto!(fmb::MyFusedMultiBroadcast)
pairs = fmb.pairs
destinations = map(x->x.first, pairs)
Expand Down Expand Up @@ -117,7 +118,7 @@ end
macro get_fused_multi_broadcast(expr)
_pairs = gensym()
quote
$_pairs = $(esc(MBF.fused_pairs(expr)))
$_pairs = $(esc(MBF.fused_direct(expr)))
FusedMultiBroadcast($_pairs)
end
end
Expand Down
2 changes: 1 addition & 1 deletion perf/flame.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Y = get_arrays(:y, arr_size, AType)
function perf_kernel_fused!(X, Y)
(; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X
(; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y
@fused begin
@fused_direct begin
@. y1 = x1 + x2 + x3 + x4
@. y2 = x2 + x3 + x4 + x5
@. y3 = x3 + x4 + x5 + x6
Expand Down
4 changes: 2 additions & 2 deletions src/MultiBroadcastFusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ module MultiBroadcastFusion
include(joinpath("collection", "utils.jl"))
include(joinpath("collection", "macros.jl"))
include(joinpath("collection", "code_lowered_single_expression.jl"))
include(joinpath("collection", "fused_pairs.jl"))
include(joinpath("collection", "fused_pairs_flexible.jl"))
include(joinpath("collection", "fused_direct.jl"))
include(joinpath("collection", "fused_assemble.jl"))

include(joinpath("execution", "fused_kernels.jl"))

Expand Down
Original file line number Diff line number Diff line change
@@ -1,55 +1,91 @@
#####
##### Complex/flexible version
##### Fused assemble
#####

# General case: do nothing (identity)
transform_flex(x, sym) = x
transform_flex(s::Symbol, sym) = s
# Expression: recursively transform_flex for Expr
function transform_flex(e::Expr, sym)
transform_assemble(x, sym) = x
transform_assemble(s::Symbol, sym) = s
# Expression: recursively transform_assemble for Expr
function transform_assemble(e::Expr, sym)
if e.head == :macrocall && e.args[1] == Symbol("@__dot__")
se = code_lowered_single_expression(e)
margs = materialize_args(se)
subexpr = :($sym = ($sym..., Pair($(margs[1]), $(margs[2]))))
subexpr
else
Expr(transform_flex(e.head, sym), transform_flex.(e.args, sym)...)
Expr(
transform_assemble(e.head, sym),
transform_assemble.(e.args, sym)...,
)
end
end

"""
fused_pairs_flexible
fused_assemble(expr::Expr)

Function that fuses broadcast expressions
that stride flow control logic. For example:
Transforms the input expressions
into a runtime assembly of a tuple
of `Pair`s, containing (firsts)
the destination of broadcast expressions
and (seconds) the broadcasted objects.

For example:
```julia
import MultiBroadcastFusion as MBF
MBF.@make_type MyFusedMultiBroadcast
MBF.@make_fused fused_pairs_flexible MyFusedMultiBroadcast fused_flexible
using Test
expr_in = quote
@. y1 = x1 + x2 + x3 + x4
@. y2 = x2 + x3 + x4 + x5
end

expr_out = quote
tup = ()
tup = (tup..., Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)))
tup = (tup..., Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)))
tup
end

@test MBF.linefilter!(MBF.fused_assemble(expr_in, :tup)) ==
MBF.linefilter!(expr_out)
@test MBF.fused_assemble(expr_in, :tup) == expr_out
```

To use `MultiBroadcastFusion`'s `@fused_flexible` macro:
This can be used to make a custom kernel fusion macro:
```
import MultiBroadcastFusion as MBF
x = rand(1);y = rand(1);z = rand(1);
MBF.@fused_flexible begin
@. x += y
@. z += y
import MultiBroadcastFusion: fused_assemble
MBF.@make_type MyFusedBroadcast
MBF.@make_fused fused_assemble MyFusedBroadcast my_fused

Base.copyto!(fmb::MyFusedBroadcast) = println("You're ready to fuse!")

x1 = rand(3,3)
y1 = rand(3,3)
y2 = rand(3,3)

# 4 reads, 2 writes
@my_fused begin
for i in 1:3
@. y1 = x1
@. y2 = x1
end
end
```

Also see [`fused_direct`](@ref)
"""
function fused_pairs_flexible(expr::Expr, sym::Symbol)
check_restrictions_flexible(expr)
e = transform_flex(expr, sym)
fused_assemble(expr::Expr) = fused_assemble(expr, gensym())
function fused_assemble(expr::Expr, sym::Symbol)
check_restrictions_assemble(expr)
e = transform_assemble(expr, sym)
@assert e.head == :block
ex = Expr(:block, :($sym = ()), e.args..., sym)
# Filter out LineNumberNode, as this will not be valid due to prepending `tup = ()`
linefilter!(ex)
ex
end

function check_restrictions_flexible(expr::Expr)
function check_restrictions_assemble(expr::Expr)
for arg in expr.args
arg isa LineNumberNode && continue
s_error = if arg isa QuoteNode
Expand Down
46 changes: 35 additions & 11 deletions src/collection/fused_pairs.jl → src/collection/fused_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,52 @@ function transform(e::Expr)
end

"""
fused_pairs
fused_direct(expr::Expr)

Function that fuses broadcast expressions that
are immediately one after another. For example:
Directly transforms the input expression
into a tuple of `Pair`s, containing (firsts)
the destination of broadcast expressions and
(seconds) the broadcasted objects.

For example:
```julia
import MultiBroadcastFusion as MBF
MBF.@make_type MyFusedMultiBroadcast
MBF.@make_fused fused_pairs MyFusedMultiBroadcast fused
using Test
expr_in = quote
@. y1 = x1 + x2 + x3 + x4
@. y2 = x2 + x3 + x4 + x5
end

expr_out = :(tuple(
Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)),
Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)),
))
@test MBF.fused_direct(expr_in) == expr_out
```

To use `MultiBroadcastFusion`'s `@fused` macro:
This can be used to make a custom kernel fusion macro:
```
import MultiBroadcastFusion as MBF
x = rand(1);y = rand(1);z = rand(1);
MBF.@fused begin
@. x += y
@. z += y
import MultiBroadcastFusion: fused_direct
MBF.@make_type MyFusedBroadcast
MBF.@make_fused fused_direct MyFusedBroadcast my_fused

Base.copyto!(fmb::MyFusedBroadcast) = println("You're ready to fuse!")

x1 = rand(3,3)
y1 = rand(3,3)
y2 = rand(3,3)

# 4 reads, 2 writes
@my_fused begin
@. y1 = x1
@. y2 = x1
end
```

Also see [`fused_assemble`](@ref)
"""
function fused_pairs(expr::Expr)
function fused_direct(expr::Expr)
check_restrictions(expr)
e = transform(expr)
@assert e.head == :block
Expand Down
10 changes: 5 additions & 5 deletions src/collection/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ macro make_type(type_name)
end

"""
@make_fused fusion_type type_name fused_named
@make_fused fusion_style type_name fused_named

This macro
- Defines a type type_name
- Defines a macro, `@fused_name`, using the fusion type `fusion_type`
- Defines a macro, `@fused_name`, using the fusion type `fusion_style`

This allows users to flexibility
to customize their broadcast fusion.
Expand All @@ -27,7 +27,7 @@ to customize their broadcast fusion.
```julia
import MultiBroadcastFusion as MBF
MBF.@make_type MyFusedBroadcast
MBF.@make_fused MBF.fused_pairs MyFusedBroadcast my_fused
MBF.@make_fused MBF.fused_direct MyFusedBroadcast my_fused

Base.copyto!(fmb::MyFusedBroadcast) = println("You're ready to fuse!")

Expand All @@ -42,12 +42,12 @@ y2 = rand(3,3)
end
```
"""
macro make_fused(fusion_type, type_name, fused_name)
macro make_fused(fusion_style, type_name, fused_name)
t = esc(type_name)
f = esc(fused_name)
return quote
macro $f(expr)
_pairs = esc($(fusion_type)(expr))
_pairs = esc($(fusion_style)(expr))
t = $t
quote
Base.copyto!($t($_pairs))
Expand Down
4 changes: 2 additions & 2 deletions src/execution/fused_kernels.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@make_type FusedMultiBroadcast
@make_fused fused_pairs FusedMultiBroadcast fused
@make_fused fused_pairs_flexible FusedMultiBroadcast fused_flexible
@make_fused fused_direct FusedMultiBroadcast fused_direct
@make_fused fused_assemble FusedMultiBroadcast fused_assemble

struct CPU end
struct GPU end
Expand Down
2 changes: 1 addition & 1 deletion test/collection/expr_code_lowered_single_expression.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#=
using Revise; include(joinpath("test", "expr_code_lowered_single_expression.jl"))
using Revise; include(joinpath("test", "collection", "expr_code_lowered_single_expression.jl"))
=#
using Test
import MultiBroadcastFusion as MBF
Expand Down
20 changes: 10 additions & 10 deletions test/collection/expr_errors_and_edge_cases.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#=
using Revise; include(joinpath("test", "expr_errors_and_edge_cases.jl"))
using Revise; include(joinpath("test", "collection" "expr_errors_and_edge_cases.jl"))
=#
using Test
import MultiBroadcastFusion as MBF
Expand All @@ -21,7 +21,7 @@ import MultiBroadcastFusion as MBF
end
@. y1 = x1 + x2 + x3 + x4
end
@test_throws ErrorException("Loops are not allowed inside fused blocks") MBF.fused_pairs(
@test_throws ErrorException("Loops are not allowed inside fused blocks") MBF.fused_direct(
expr_in,
)
end
Expand All @@ -47,7 +47,7 @@ struct Foo end
end
@test_throws ErrorException(
"If-statements are not allowed inside fused blocks",
) MBF.fused_pairs(expr_in)
) MBF.fused_direct(expr_in)
end

bar() = nothing
Expand All @@ -61,7 +61,7 @@ bar() = nothing
end
@test_throws ErrorException(
"Function calls are not allowed inside fused blocks",
) MBF.fused_pairs(expr_in)
) MBF.fused_direct(expr_in)
end

@testset "Non-broadcast variable assignments" begin
Expand All @@ -74,7 +74,7 @@ end
end
@test_throws ErrorException(
"Non-broadcast assignments are not allowed inside fused blocks",
) MBF.fused_pairs(expr_in)
) MBF.fused_direct(expr_in)
end

@testset "No let-blocks" begin
Expand All @@ -87,7 +87,7 @@ end
end
@test_throws ErrorException(
"Let-blocks are not allowed inside fused blocks",
) MBF.fused_pairs(expr_in)
) MBF.fused_direct(expr_in)
end

@testset "Dangling symbols" begin
Expand All @@ -99,7 +99,7 @@ end
end
@test_throws ErrorException(
"Dangling symbols are not allowed inside fused blocks",
) MBF.fused_pairs(expr_in)
) MBF.fused_direct(expr_in)
end

@testset "quote" begin
Expand All @@ -110,7 +110,7 @@ end
quote end
@. y1 = x1 + x2 + x3 + x4
end
@test_throws ErrorException("Quotes are not allowed inside fused blocks") MBF.fused_pairs(
@test_throws ErrorException("Quotes are not allowed inside fused blocks") MBF.fused_direct(
expr_in,
)
end
Expand All @@ -127,10 +127,10 @@ end
Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)),
Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)),
))
@test MBF.fused_pairs(expr_in) == expr_out
@test MBF.fused_direct(expr_in) == expr_out
end

@testset "Empty" begin
expr_in = quote end
@test MBF.fused_pairs(expr_in) == :(tuple())
@test MBF.fused_direct(expr_in) == :(tuple())
end
Loading
Loading