Skip to content

Commit

Permalink
Merge pull request #507 from willow-ahrens/wma/options
Browse files Browse the repository at this point in the history
Wma/options
  • Loading branch information
willow-ahrens authored Apr 25, 2024
2 parents dfc2a2a + 19fd760 commit 0ffa443
Show file tree
Hide file tree
Showing 16 changed files with 81 additions and 83 deletions.
6 changes: 3 additions & 3 deletions docs/examples/spgemm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ function spgemm_inner(A, B)
C = Tensor(Dense(SparseList(Element(z))))
w = Tensor(SparseDict(SparseDict(Element(z))))
AT = Tensor(Dense(SparseList(Element(z))))
@finch mode=fastfinch (w .= 0; for k=_, i=_; w[k, i] = A[i, k] end)
@finch mode=fastfinch (AT .= 0; for i=_, k=_; AT[k, i] = w[k, i] end)
@finch mode=:fast (w .= 0; for k=_, i=_; w[k, i] = A[i, k] end)
@finch mode=:fast (AT .= 0; for i=_, k=_; AT[k, i] = w[k, i] end)
@finch (C .= 0; for j=_, i=_, k=_; C[i, j] += AT[k, i] * B[k, j] end)
return C
end
Expand All @@ -14,7 +14,7 @@ function spgemm_outer(A, B)
C = Tensor(Dense(SparseList(Element(z))))
w = Tensor(SparseDict(SparseDict(Element(z))))
BT = Tensor(Dense(SparseList(Element(z))))
@finch mode=fastfinch (w .= 0; for j=_, k=_; w[j, k] = B[k, j] end)
@finch mode=:fast (w .= 0; for j=_, k=_; w[j, k] = B[k, j] end)
@finch (BT .= 0; for k=_, j=_; BT[j, k] = w[j, k] end)
@finch (w .= 0; for k=_, j=_, i=_; w[i, j] += A[i, k] * BT[j, k] end)
@finch (C .= 0; for j=_, i=_; C[i, j] = w[i, j] end)
Expand Down
2 changes: 0 additions & 2 deletions docs/src/guides/custom_operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ properties as follows:
```
struct MyAlgebra <: AbstractAlgebra end
Finch.virtualize(ctx, ex, ::Type{MyAlgebra}) = MyAlgebra()
Finch.isassociative(::MyAlgebra, ::typeof(gcd)) = true
Finch.iscommutative(::MyAlgebra, ::typeof(gcd)) = true
Finch.isannihilator(::MyAlgebra, ::typeof(gcd), x) = x == 1
Expand Down
6 changes: 3 additions & 3 deletions docs/src/interactive.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.6.7"
"version": "1.10.2"
},
"kernelspec": {
"name": "julia-1.6",
"display_name": "Julia 1.6.7",
"name": "julia-1.10",
"display_name": "Julia 1.10.2",
"language": "julia"
}
},
Expand Down
2 changes: 1 addition & 1 deletion docs/src/reference/advanced_implementation/internals.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ quote
let i = index_instance(i)
(Finch.FinchNotation.loop_instance)(i, Finch.FinchNotation.Dimensionless(), (Finch.FinchNotation.assign_instance)((Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(Finch.FinchNotation.Updater()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.literal_instance)(Finch.FinchNotation.initwrite), (Finch.FinchNotation.call_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:*), (Finch.FinchNotation.finch_leaf_instance)(*)), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:A), (Finch.FinchNotation.finch_leaf_instance)(A)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:B), (Finch.FinchNotation.finch_leaf_instance)(B)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))))))
end
end), (Finch.FinchNotation.yieldbind_instance)(variable_instance(:C))), (;))
end), (Finch.FinchNotation.yieldbind_instance)(variable_instance(:C))); )
begin
C = _res_1[:C]
end
Expand Down
2 changes: 0 additions & 2 deletions src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ using Distributions: Binomial, Normal, Poisson

export @finch, @finch_program, @finch_code, @finch_kernel, value

export fastfinch, safefinch, debugfinch

export Tensor
export SparseRLE, SparseRLELevel
export DenseRLE, DenseRLELevel
Expand Down
71 changes: 40 additions & 31 deletions src/execute.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
abstract type CompileMode end
struct DebugFinch <: CompileMode end
const debugfinch = DebugFinch()
virtualize(ctx, ex, ::Type{DebugFinch}) = DebugFinch()
struct SafeFinch <: CompileMode end
const safefinch = SafeFinch()
virtualize(ctx, ex, ::Type{SafeFinch}) = SafeFinch()
struct FastFinch <: CompileMode end
const fastfinch = FastFinch()
virtualize(ctx, ex, ::Type{FastFinch}) = FastFinch()

issafe(::DebugFinch) = true
issafe(::SafeFinch) = true
issafe(::FastFinch) = false
"""
    issafe(mode)

Return `true` when the compilation `mode` symbol (`:debug`, `:safe`, or
`:fast`) enables safety checks, `false` for `:fast`, and throw an
`ArgumentError` for any unrecognized mode.
"""
function issafe(mode)
    mode === :fast && return false
    (mode === :debug || mode === :safe) && return true
    throw(ArgumentError("Unknown mode: $mode"))
end

"""
instantiate!(ctx, prgm)
Expand Down Expand Up @@ -56,26 +53,35 @@ function (ctx::InstantiateTensors)(node::FinchNode)
end
end

execute(ex) = execute(ex, NamedTuple())
execute(ex; algebra = DefaultAlgebra(), mode = :safe) =
execute_impl(ex, Val(algebra), Val(mode))

@staged function execute(ex, opts)
contain(JuliaContext()) do ctx
code = execute_code(:ex, ex; virtualize(ctx, :opts, opts)...)
quote
# try
@inbounds @fastmath begin
# Unwrap the compile-time value `x` carried in the type domain by `Val{x}`.
function getvalue(::Type{Val{x}}) where {x}
    return x
end

@staged function execute_impl(ex, algebra, mode)
code = execute_code(:ex, ex; algebra=getvalue(algebra), mode=getvalue(mode))
if mode === :debug
return quote
try
begin
$(code |> unblock)
end
# catch
# println("Error executing code:")
# println($(QuoteNode(code |> unblock |> pretty |> unquote_literals)))
# rethrow()
#end
catch
println("Error executing code:")
println($(QuoteNode(code |> unblock |> pretty |> unquote_literals)))
rethrow()
end
end
else
return quote
@inbounds @fastmath begin
$(code |> unblock |> pretty |> unquote_literals)
end
end
end
end

function execute_code(ex, T; algebra = DefaultAlgebra(), mode = safefinch, ctx = LowerJulia(algebra = algebra, mode=mode))
function execute_code(ex, T; algebra = DefaultAlgebra(), mode = :safe, ctx = LowerJulia(algebra = algebra, mode=mode))
code = contain(ctx) do ctx_2
prgm = nothing
prgm = virtualize(ctx_2.code, ex, T)
Expand Down Expand Up @@ -154,8 +160,11 @@ sparsity information to reliably skip iterations when possible.
`options` are optional keyword arguments:
- `algebra`: the algebra to use for the program. The default is `DefaultAlgebra()`.
- `mode`: the optimization mode to use for the program. The default is `fastfinch`.
- `ctx`: the context to use for the program. The default is a `LowerJulia` context with the given options.
- `mode`: the optimization mode to use for the program. Possible modes are:
- `:debug`: run the program in debug mode, with bounds checking and better error handling.
- `:safe`: run the program in safe mode, with modest checks for performance and correctness.
- `:fast`: run the program in fast mode, with no checks or warnings, this mode is for power users.
The default is `:safe`.
See also: [`@finch_code`](@ref)
"""
Expand All @@ -173,7 +182,7 @@ macro finch(opts_ex...)
)
res = esc(:res)
thunk = quote
res = $execute($prgm, (;$(map(esc, opts)...),))
res = $execute($prgm, ;$(map(esc, opts)...),)
end
for tns in something(FinchNotation.finch_parse_yieldbind(ex), FinchNotation.finch_parse_default_yieldbind(ex))
push!(thunk.args, quote
Expand Down Expand Up @@ -219,7 +228,7 @@ type `prgm`. Here, `fname` is the name of the function and `args` is a
See also: [`@finch`](@ref)
"""
function finch_kernel(fname, args, prgm; algebra = DefaultAlgebra(), mode = safefinch, ctx = LowerJulia(algebra=algebra, mode=mode))
function finch_kernel(fname, args, prgm; algebra = DefaultAlgebra(), mode = :safe, ctx = LowerJulia(algebra=algebra, mode=mode))
maybe_typeof(x) = x isa Type ? x : typeof(x)
code = contain(ctx) do ctx_2
foreach(args) do (key, val)
Expand Down
2 changes: 1 addition & 1 deletion src/interface/copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
idxs = [Symbol(:i_, n) for n = 1:ndims(dst)]
exts = Expr(:block, (:($idx = _) for idx in reverse(idxs))...)
return quote
@finch mode=fastfinch begin
@finch mode=:fast begin
dst .= $(default(dst))
$(Expr(:for, exts, quote
dst[$(idxs...)] = src[$(idxs...)]
Expand Down
2 changes: 1 addition & 1 deletion src/lower.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
result = freshen(code, :result)
algebra = DefaultAlgebra()
bindings::Dict{FinchNode, FinchNode} = Dict{FinchNode, FinchNode}()
mode = fastfinch
mode = :fast
modes::Dict{Any, Any} = Dict()
scope = Set()
shash = StaticHash()
Expand Down
2 changes: 0 additions & 2 deletions src/symbolic/symbolic.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
abstract type AbstractAlgebra end
struct DefaultAlgebra<:AbstractAlgebra end

virtualize(ctx, ex, ::Type{DefaultAlgebra}) = DefaultAlgebra()

struct Chooser{D} end

(f::Chooser{D})(x) where {D} = x
Expand Down
4 changes: 2 additions & 2 deletions test/reference32/typical/typical_transpose_csc_to_coo.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Dense [:,1:3]
julia> B = Tensor(SparseDict(SparseDict(Element(0.0))))
Sparse (0.0) [:,1:0]

julia> @finch_code mode = fastfinch begin
julia> @finch_code mode = :fast begin
B .= 0
for j = _
for i = _
Expand Down Expand Up @@ -101,7 +101,7 @@ quote
result = (B = Tensor((SparseLevel){Int32}((SparseLevel){Int32}(B_lvl_3, A_lvl.shape, B_lvl_tbl_2), A_lvl_2.shape, B_lvl_tbl)),)
result
end
julia> @finch mode = fastfinch begin
julia> @finch mode = :fast begin
B .= 0
for j = _
for i = _
Expand Down
4 changes: 2 additions & 2 deletions test/reference64/typical/typical_transpose_csc_to_coo.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Dense [:,1:3]
julia> B = Tensor(SparseDict(SparseDict(Element(0.0))))
Sparse (0.0) [:,1:0]

julia> @finch_code mode = fastfinch begin
julia> @finch_code mode = :fast begin
B .= 0
for j = _
for i = _
Expand Down Expand Up @@ -101,7 +101,7 @@ quote
result = (B = Tensor((SparseLevel){Int64}((SparseLevel){Int64}(B_lvl_3, A_lvl.shape, B_lvl_tbl_2), A_lvl_2.shape, B_lvl_tbl)),)
result
end
julia> @finch mode = fastfinch begin
julia> @finch mode = :fast begin
B .= 0
for j = _
for i = _
Expand Down
4 changes: 0 additions & 4 deletions test/test_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ struct MyAlgebra2 <: Finch.AbstractAlgebra end

@test pattern!(w) == [1, 1, 1, 0, 1, 0, 1, 1, 1, 0]

Finch.virtualize(ctx, ex, ::Type{MyAlgebra}) = MyAlgebra()

Finch.virtualize(ctx, ex, ::Type{MyAlgebra2}) = MyAlgebra2()

Finch.isassociative(::MyAlgebra, ::typeof(gcd)) = true
Finch.iscommutative(::MyAlgebra, ::typeof(gcd)) = true
Finch.isannihilator(::MyAlgebra, ::typeof(gcd), x) = x == 1
Expand Down
14 changes: 7 additions & 7 deletions test/test_continuous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@
let
s1 = Scalar(0)
x = Tensor(SparseRLE{Limit{Float32}}(Element(0)), 10)
@finch mode=fastfinch (x[3] = 1)
@finch mode=fastfinch (for i=realextent(5+Eps,7-Eps); x[~i] = 1 end)
@finch mode=fastfinch (for i=realextent(8,9+Eps); x[~i] = 1 end)
@finch mode=:fast (x[3] = 1)
@finch mode=:fast (for i=realextent(5+Eps,7-Eps); x[~i] = 1 end)
@finch mode=:fast (for i=realextent(8,9+Eps); x[~i] = 1 end)

@finch mode=fastfinch (for i=_; s1[] += x[i] * d(i) end)
@finch mode=:fast (for i=_; s1[] += x[i] * d(i) end)
@test s1.val == 3
end

let
s1 = Scalar(0)
x = Tensor(SparseRLE{Limit{Float32}}(SparseList(Element(0))), 10, 10)
a = [1, 4, 8]
@finch mode=fastfinch (for i=realextent(2,4-Eps); for j=extent(1,3); x[a[j], ~i] = 1 end end)
@finch mode=fastfinch (for i=realextent(6+Eps,10-Eps); x[2, ~i] = 1 end)
@finch mode=:fast (for i=realextent(2,4-Eps); for j=extent(1,3); x[a[j], ~i] = 1 end end)
@finch mode=:fast (for i=realextent(6+Eps,10-Eps); x[2, ~i] = 1 end)

@finch mode=fastfinch (for i=_; for j=_; s1[] += x[j,i] * d(i) end end)
@finch mode=:fast (for i=_; for j=_; s1[] += x[j,i] * d(i) end end)
@test s1.val == 10
end

Expand Down
6 changes: 3 additions & 3 deletions test/test_continuousexamples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@

output = Tensor(SparseByteMap{Int64}(Pattern(), shape_id))

def = @finch_kernel mode=fastfinch function rangequery(output, box, points)
def = @finch_kernel mode=:fast function rangequery(output, box, points)
output .= false
for x=_, y=_, id=_
output[id] |= box[y,x] && points[id,y,x]
end
end

radius=ox=oy=0.0 #placeholder
def2 = @finch_kernel mode=fastfinch function radiusquery(output, points, radius, ox, oy)
def2 = @finch_kernel mode=:fast function radiusquery(output, points, radius, ox, oy)
output .= false
for x=realextent(ox-radius,ox+radius), y=realextent(oy-radius,oy+radius)
if (x-ox)^2 + (y-oy)^2 <= radius^2
Expand Down Expand Up @@ -164,7 +164,7 @@
ox=oy=oz=0.1

#Main Kernel
@finch mode=fastfinch begin
@finch mode=:fast begin
output .= 0
for t=_
if timeray[t]
Expand Down
33 changes: 16 additions & 17 deletions test/test_issues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ using SparseArrays
Tensor(Dense(Dense(Element(0)))),
]
E = deepcopy(D)
@finch mode=fastfinch begin
@finch mode=:fast begin
D .= 0
E .= 0
for j = _, i = _
Expand Down Expand Up @@ -127,7 +127,6 @@ using SparseArrays
return a+b+c
end
struct MyAlgebra115 <: Finch.AbstractAlgebra end
Finch.virtualize(::Finch.JuliaContext, ex, ::Type{MyAlgebra115}) = MyAlgebra115()
t = Tensor(SparseList(SparseList(Element(0.0))))
B = SparseMatrixCSC([0 0 0 0; -1 -1 -1 -1; -2 -2 -2 -2; -3 -3 -3 -3])
A = dropdefaults(copyto!(Tensor(SparseList(SparseList(Element(0.0)))), B))
Expand Down Expand Up @@ -372,7 +371,7 @@ using SparseArrays
let
C = Tensor(Dense(Dense(Element(0.0))), [1 0; 0 1])
w = Tensor(Dense(Dense(Element(0.0))), [0 0; 0 0])
@finch mode=fastfinch begin
@finch mode=:fast begin
for j = _, i = _
C[i, j] += 1
end
Expand All @@ -390,40 +389,40 @@ using SparseArrays
let
A = [1 2 3; 4 5 6; 7 8 9]
x = Scalar(0.0)
@finch mode=fastfinch for j=_, i=_; if i < j x[] += A[i, j] end end
@finch mode=:fast for j=_, i=_; if i < j x[] += A[i, j] end end
@test x[] == 11.0

@finch mode=fastfinch (x .= 0; for i=_, j=_; if i < j x[] += A[j, i] end end)
@finch mode=:fast (x .= 0; for i=_, j=_; if i < j x[] += A[j, i] end end)
@test x[] == 19.0

@finch mode=fastfinch (x .= 0; for j=_, i=_; if i <= j x[] += A[i, j] end end)
@finch mode=:fast (x .= 0; for j=_, i=_; if i <= j x[] += A[i, j] end end)
@test x[] == 26.0

@finch mode=fastfinch (x .= 0; for i=_, j=_; if i <= j x[] += A[j, i] end end)
@finch mode=:fast (x .= 0; for i=_, j=_; if i <= j x[] += A[j, i] end end)
@test x[] == 34.0

@finch mode=fastfinch (x .= 0; for j=_, i=_; if i > j x[] += A[i, j] end end)
@finch mode=:fast (x .= 0; for j=_, i=_; if i > j x[] += A[i, j] end end)
@test x[] == 19.0

@finch mode=fastfinch (x .= 0; for i=_, j=_; if i > j x[] += A[j, i] end end)
@finch mode=:fast (x .= 0; for i=_, j=_; if i > j x[] += A[j, i] end end)
@test x[] == 11.0

@finch mode=fastfinch (x .= 0; for j=_, i=_; if i >= j x[] += A[i, j] end end)
@finch mode=:fast (x .= 0; for j=_, i=_; if i >= j x[] += A[i, j] end end)
@test x[] == 34.0

@finch mode=fastfinch (x .= 0; for i=_, j=_; if i >= j x[] += A[j, i] end end)
@finch mode=:fast (x .= 0; for i=_, j=_; if i >= j x[] += A[j, i] end end)
@test x[] == 26.0

@finch mode=fastfinch (x .= 0; for j=_, i=_; if i == j x[] += A[i, j] end end)
@finch mode=:fast (x .= 0; for j=_, i=_; if i == j x[] += A[i, j] end end)
@test x[] == 15.0

@finch mode=fastfinch (x .= 0; for i=_, j=_; if i == j x[] += A[j, i] end end)
@finch mode=:fast (x .= 0; for i=_, j=_; if i == j x[] += A[j, i] end end)
@test x[] == 15.0

@finch mode=fastfinch (x .= 0; for j=_, i=_; if i != j x[] += A[i, j] end end)
@finch mode=:fast (x .= 0; for j=_, i=_; if i != j x[] += A[i, j] end end)
@test x[] == 30.0

@finch mode=fastfinch (x .= 0; for i=_, j=_; if i != j x[] += A[j, i] end end)
@finch mode=:fast (x .= 0; for i=_, j=_; if i != j x[] += A[j, i] end end)
@test x[] == 30.0
end

Expand All @@ -439,7 +438,7 @@ using SparseArrays
A = zeros(3, 3, 3)
C = zeros(3, 3, 3)
X = zeros(3, 3)
@test check_output("issues/issue288_concordize_let.jl", @finch_code mode=fastfinch begin
@test check_output("issues/issue288_concordize_let.jl", @finch_code mode=:fast begin
for k=_, j=_, i=_
let temp1 = X[i, j]
for l=_
Expand All @@ -452,7 +451,7 @@ using SparseArrays
end
end
end)
@test check_output("issues/issue288_concordize_double_let.jl", @finch_code mode=fastfinch begin
@test check_output("issues/issue288_concordize_double_let.jl", @finch_code mode=:fast begin
for k=_, j=_, i=_
let temp1 = X[i, j]
for l=_
Expand Down
Loading

0 comments on commit 0ffa443

Please sign in to comment.