diff --git a/Project.toml b/Project.toml index 582e71162..c7dbb04e6 100644 --- a/Project.toml +++ b/Project.toml @@ -67,11 +67,11 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" TensorMarket = "8b7d4fe7-0b45-4d0d-9dd8-5cc9b23b4b77" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" [targets] test = ["ReTestItems", "Test", "ArgParse", "LinearAlgebra", "Random", "SparseArrays", "Graphs", "SimpleWeightedGraphs", "HDF5", "NPZ", "Pkg", "TensorMarket", "Documenter"] diff --git a/docs/src/docs/internals/virtualization.md b/docs/src/docs/internals/virtualization.md index 0218d992c..f3a9181ad 100644 --- a/docs/src/docs/internals/virtualization.md +++ b/docs/src/docs/internals/virtualization.md @@ -49,9 +49,9 @@ result to clean it up): ```jldoctest example1; filter=r"Finch\.FinchNotation\." julia> (@macroexpand @finch (C .= 0; for i=_; C[i] = A[i] * B[i] end)) |> Finch.striplines |> Finch.regensym quote - _res_1 = (Finch.execute)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.declare_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(0)), begin + _res_1 = (Finch.execute)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.declare_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(0), (Finch.FinchNotation.literal_instance)(Finch.auto)), begin let i = index_instance(i) - (Finch.FinchNotation.loop_instance)(i, Finch.FinchNotation.Dimensionless(), (Finch.FinchNotation.assign_instance)((Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(Finch.FinchNotation.Updater()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.literal_instance)(Finch.FinchNotation.initwrite), (Finch.FinchNotation.call_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:*), (Finch.FinchNotation.finch_leaf_instance)(*)), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:A), (Finch.FinchNotation.finch_leaf_instance)(A)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:B), (Finch.FinchNotation.finch_leaf_instance)(B)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i)))))) + (Finch.FinchNotation.loop_instance)(i, Finch.FinchNotation.Auto(), (Finch.FinchNotation.assign_instance)((Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), (Finch.FinchNotation.updater_instance)((Finch.FinchNotation.literal_instance)(initwrite)), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.literal_instance)(initwrite), (Finch.FinchNotation.call_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:*), (Finch.FinchNotation.finch_leaf_instance)(*)), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:A), (Finch.FinchNotation.finch_leaf_instance)(A)), (Finch.FinchNotation.reader_instance)(), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:B), (Finch.FinchNotation.finch_leaf_instance)(B)), (Finch.FinchNotation.reader_instance)(), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i)))))) end end), (Finch.FinchNotation.yieldbind_instance)(variable_instance(:C))); ) begin @@ -76,7 +76,7 @@ julia> using Finch: @finch_program_instance julia> prgm = Finch.@finch_program_instance (C .= 0; for i=_; C[i] = A[i] * B[i] end; return C) Finch program instance: begin tag(C, Tensor(SparseList(Element(0)))) .= 0 - for i = Dimensionless() + for i = Auto() tag(C, Tensor(SparseList(Element(0))))[tag(i, i)] <>= tag(*, *)(tag(A, Tensor(SparseList(Element(0))))[tag(i, i)], tag(B, Tensor(Dense(Element(0))))[tag(i, i)]) end return (tag(C, Tensor(SparseList(Element(0))))) @@ -91,7 +91,7 @@ different inputs, but the same program type. We can run our program using ```jldoctest example1; filter=r"Finch\.FinchNotation\." julia> typeof(prgm) -Finch.FinchNotation.BlockInstance{Tuple{Finch.FinchNotation.DeclareInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{0}}, Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Updater()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.LiteralInstance{initwrite}, Finch.FinchNotation.CallInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:*}, Finch.FinchNotation.LiteralInstance{*}}, Tuple{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:B}, Tensor{DenseLevel{Int64, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}}}, Finch.FinchNotation.YieldBindInstance{Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}}}}} +Finch.FinchNotation.BlockInstance{Tuple{Finch.FinchNotation.DeclareInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{0}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Auto()}}, Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Auto, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.UpdaterInstance{Finch.FinchNotation.LiteralInstance{initwrite}}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.LiteralInstance{initwrite}, Finch.FinchNotation.CallInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:*}, Finch.FinchNotation.LiteralInstance{*}}, Tuple{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.ReaderInstance, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:B}, Tensor{DenseLevel{Int64, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.ReaderInstance, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}}}, Finch.FinchNotation.YieldBindInstance{Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}}}}} julia> C = Finch.execute(prgm).C 5 Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}: @@ -164,15 +164,15 @@ julia> inst = Finch.@finch_program_instance begin s[] += A[i] end end -Finch program instance: for i = Dimensionless() +Finch program instance: for i = Auto() tag(s, Scalar{0, Int64})[] <>= tag(A, Tensor(SparseList(Element(0))))[tag(i, i)] end julia> typeof(inst) -Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:s}, Scalar{0, Int64}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Updater()}, Tuple{}}, Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:+}, Finch.FinchNotation.LiteralInstance{+}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}} +Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Auto, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:s}, Scalar{0, Int64}}, Finch.FinchNotation.UpdaterInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:+}, Finch.FinchNotation.LiteralInstance{+}}}, Tuple{}}, Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:+}, Finch.FinchNotation.LiteralInstance{+}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.ReaderInstance, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}} julia> Finch.virtualize(Finch.JuliaContext(), :inst, typeof(inst)) -Finch program: for i = virtual(Finch.FinchNotation.Dimensionless) +Finch program: for i = virtual(Finch.FinchNotation.Auto) tag(s, virtual(Finch.VirtualScalar))[] <>= tag(A, virtual(Finch.VirtualFiber{Finch.VirtualSparseListLevel}))[tag(i, i)] end @@ -289,10 +289,10 @@ julia> prgm_inst = Finch.@finch_program_instance for i = _ end; julia> println(prgm_inst) -loop_instance(index_instance(i), Finch.FinchNotation.Dimensionless(), assign_instance(access_instance(tag_instance(variable_instance(:s), Scalar{0, Int64}(0)), literal_instance(Finch.FinchNotation.Updater())), tag_instance(variable_instance(:+), literal_instance(+)), access_instance(tag_instance(variable_instance(:A), Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal_instance(Finch.FinchNotation.Reader()), tag_instance(variable_instance(:i), index_instance(i))))) +loop_instance(index_instance(i), Finch.FinchNotation.Auto(), assign_instance(access_instance(tag_instance(variable_instance(:s), Scalar{0, Int64}(0)), updater_instance(tag_instance(variable_instance(:+), literal_instance(+)))), tag_instance(variable_instance(:+), literal_instance(+)), access_instance(tag_instance(variable_instance(:A), Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), reader_instance(), tag_instance(variable_instance(:i), index_instance(i))))) julia> prgm_inst -Finch program instance: for i = Dimensionless() +Finch program instance: for i = Auto() tag(s, Scalar{0, Int64})[] <>= tag(A, Tensor(SparseList(Element(0))))[tag(i, i)] end @@ -300,12 +300,11 @@ julia> prgm = Finch.@finch_program for i = _ s[] += A[i] end; - julia> println(prgm) -loop(index(i), virtual(Finch.FinchNotation.Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Finch.FinchNotation.Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Finch.FinchNotation.Reader()), index(i)))) +loop(index(i), virtual(Finch.FinchNotation.Auto()), assign(access(literal(Scalar{0, Int64}(0)), updater(literal(+))), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), reader(), index(i)))) julia> prgm -Finch program: for i = virtual(Finch.FinchNotation.Dimensionless) +Finch program: for i = virtual(Finch.FinchNotation.Auto) Scalar{0, Int64}(0)[] <<+>>= Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))[i] end @@ -319,9 +318,8 @@ representations, so you can use the standard `operation`, `arguments`, `istree`, ```jldoctest example2; setup = :(using Finch, AbstractTrees, SyntaxInterface, RewriteTools) julia> using Finch.FinchNotation; - julia> PostOrderDFS(prgm) -PostOrderDFS{FinchNode}(loop(index(i), virtual(Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Reader()), index(i))))) +PostOrderDFS{FinchNode}(loop(index(i), virtual(Auto()), assign(access(literal(Scalar{0, Int64}(0)), updater(literal(+))), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), reader(), index(i))))) julia> (@capture prgm loop(~idx, ~ext, ~val)) true diff --git a/docs/src/docs/language/dimensionalization.md b/docs/src/docs/language/dimensionalization.md index a3dacdbd4..fa15b99bc 100644 --- a/docs/src/docs/language/dimensionalization.md +++ b/docs/src/docs/language/dimensionalization.md @@ -53,5 +53,5 @@ The rules of declaration dimensionalization are as follows: - The new dimensions of the declared tensor are used when the tensor is on the right hand side (reading) access. ```@docs -Finch.FinchNotation.Dimensionless +Finch.FinchNotation.Auto ``` \ No newline at end of file diff --git a/docs/src/docs/language/finch_language.md b/docs/src/docs/language/finch_language.md index 015719e19..ef1d62e3f 100644 --- a/docs/src/docs/language/finch_language.md +++ b/docs/src/docs/language/finch_language.md @@ -80,6 +80,8 @@ Tensors must enter and exit scope in read mode. Finch inserts Tensor lifecycle statements consist of: ```@docs +reader +updater declare freeze thaw diff --git a/ext/SparseArraysExt.jl b/ext/SparseArraysExt.jl index 5059704c5..8ba444693 100644 --- a/ext/SparseArraysExt.jl +++ b/ext/SparseArraysExt.jl @@ -220,7 +220,7 @@ end #Finch.is_atomic(ctx, tns::VirtualSparseMatrixCSCColumn) = is_atomic(ctx, tns.mtx)[1] #Finch.is_concurrent(ctx, tns::VirtualSparseMatrixCSCColumn) = is_concurrent(ctx, tns.mtx)[1] -function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) tag = arr.ex Unfurled( arr = arr, @@ -230,7 +230,7 @@ function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, m ) end -function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)}) +function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)}) tag = arr.ex Unfurled( arr = arr, @@ -383,7 +383,7 @@ function Finch.thaw!(ctx::AbstractCompiler, arr::VirtualSparseVector) return arr end -function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseVector, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseVector, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) tag = arr.ex Ti = arr.Ti my_i = freshen(ctx, tag, :_i) @@ -439,7 +439,7 @@ function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseVector, ext, mode ) end -function Finch.unfurl(ctx, arr::VirtualSparseVector, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)}) +function Finch.unfurl(ctx, arr::VirtualSparseVector, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)}) tag = arr.ex Tp = arr.Ti qos = freshen(ctx, tag, :_qos) diff --git a/src/Finch.jl b/src/Finch.jl index 9cf486b49..1b6fee820 100644 --- a/src/Finch.jl +++ b/src/Finch.jl @@ -56,7 +56,7 @@ export choose, minby, maxby, overwrite, initwrite, filterop, d export fill_value, AsArray, expanddims, tensor_tree export parallelAnalysis, ParallelAnalysisResults -export parallel, realextent, extent, dimless +export parallel, realextent, extent, auto export CPU, CPULocalVector, CPULocalMemory export Limit, Eps diff --git a/src/FinchNotation/FinchNotation.jl b/src/FinchNotation/FinchNotation.jl index 21455857e..5f75532d4 100644 --- a/src/FinchNotation/FinchNotation.jl +++ b/src/FinchNotation/FinchNotation.jl @@ -14,7 +14,7 @@ module FinchNotation export tag export call export cached - export reader, Reader, updater, Updater, access + export reader, updater, access export define, declare, thaw, freeze export block export protocol @@ -38,7 +38,7 @@ module FinchNotation export getval, getname - export overwrite, initwrite, Dimensionless, dimless, extent, realextent + export overwrite, initwrite, Auto, auto, extent, realextent export d diff --git a/src/FinchNotation/instances.jl b/src/FinchNotation/instances.jl index fbc4b17de..7d522fbb1 100644 --- a/src/FinchNotation/instances.jl +++ b/src/FinchNotation/instances.jl @@ -4,15 +4,17 @@ struct LiteralInstance{val} <: FinchNodeInstance end struct IndexInstance{name} <: FinchNodeInstance end struct VariableInstance{name} <: FinchNodeInstance end struct DefineInstance{Lhs, Rhs, Body} <: FinchNodeInstance lhs::Lhs; rhs::Rhs; body::Body end -struct DeclareInstance{Tns, Init} <: FinchNodeInstance tns::Tns; init::Init end -struct FreezeInstance{Tns} <: FinchNodeInstance tns::Tns end -struct ThawInstance{Tns} <: FinchNodeInstance tns::Tns end +struct DeclareInstance{Tns, Init, Op} <: FinchNodeInstance tns::Tns; init::Init; op::Op end +struct FreezeInstance{Tns, Op} <: FinchNodeInstance tns::Tns; op::Op end +struct ThawInstance{Tns, Op} <: FinchNodeInstance tns::Tns; op::Op end struct BlockInstance{Bodies} <: FinchNodeInstance bodies::Bodies end struct LoopInstance{Idx, Ext, Body} <: FinchNodeInstance idx::Idx; ext::Ext; body::Body end struct SieveInstance{Cond, Body} <: FinchNodeInstance cond::Cond; body::Body end struct AssignInstance{Lhs, Op, Rhs} <: FinchNodeInstance lhs::Lhs; op::Op; rhs::Rhs end struct CallInstance{Op, Args<:Tuple} <: FinchNodeInstance op::Op; args::Args end struct AccessInstance{Tns, Mode, Idxs} <: FinchNodeInstance tns::Tns; mode::Mode; idxs::Idxs end +struct ReaderInstance{} <: FinchNodeInstance end +struct UpdaterInstance{Op} <: FinchNodeInstance op::Op end struct TagInstance{Var, Bind} <: FinchNodeInstance var::Var; bind::Bind end struct YieldBindInstance{Args} <: FinchNodeInstance args::Args end @@ -24,9 +26,9 @@ Base.getproperty(::VariableInstance{val}, name::Symbol) where {val} = name == :n @inline index_instance(name) = IndexInstance{name}() @inline variable_instance(name) = VariableInstance{name}() @inline define_instance(lhs, rhs, body) = DefineInstance(lhs, rhs, body) -@inline declare_instance(tns, init) = DeclareInstance(tns, init) -@inline freeze_instance(tns) = FreezeInstance(tns) -@inline thaw_instance(tns) = ThawInstance(tns) +@inline declare_instance(tns, init, op) = DeclareInstance(tns, init, op) +@inline freeze_instance(tns, op) = FreezeInstance(tns, op) +@inline thaw_instance(tns, op) = ThawInstance(tns, op) @inline block_instance(bodies...) = BlockInstance(bodies) @inline loop_instance(idx, ext, body) = LoopInstance(idx, ext, body) @inline loop_instance(body) = body @@ -36,14 +38,14 @@ Base.getproperty(::VariableInstance{val}, name::Symbol) where {val} = name == :n @inline assign_instance(lhs, op, rhs) = AssignInstance(lhs, op, rhs) @inline call_instance(op, args...) = CallInstance(op, args) @inline access_instance(tns, mode, idxs...) = AccessInstance(tns, mode, idxs) +@inline reader_instance() = ReaderInstance() +@inline updater_instance(op) = UpdaterInstance(op) @inline tag_instance(var, bind) = TagInstance(var, bind) @inline yieldbind_instance(args...) = YieldBindInstance(args) @inline finch_leaf_instance(arg::Type) = literal_instance(arg) @inline finch_leaf_instance(arg::Function) = literal_instance(arg) @inline finch_leaf_instance(arg::FinchNodeInstance) = arg -@inline finch_leaf_instance(arg::Reader) = literal_instance(arg) -@inline finch_leaf_instance(arg::Updater) = literal_instance(arg) @inline finch_leaf_instance(arg) = arg SyntaxInterface.istree(node::FinchNodeInstance) = Int(operation(node)) & IS_TREE != 0 @@ -63,6 +65,8 @@ instance_ctrs = Dict( assign => assign_instance, call => call_instance, access => access_instance, + reader => reader_instance, + updater => updater_instance, variable => variable_instance, tag => tag_instance, yieldbind => yieldbind_instance, @@ -84,12 +88,14 @@ SyntaxInterface.operation(::SieveInstance) = sieve SyntaxInterface.operation(::AssignInstance) = assign SyntaxInterface.operation(::CallInstance) = call SyntaxInterface.operation(::AccessInstance) = access +SyntaxInterface.operation(::ReaderInstance) = reader +SyntaxInterface.operation(::UpdaterInstance) = updater SyntaxInterface.operation(::VariableInstance) = variable SyntaxInterface.operation(::TagInstance) = tag SyntaxInterface.operation(::YieldBindInstance) = yieldbind SyntaxInterface.arguments(node::DefineInstance) = [node.lhs, node.rhs, node.body] -SyntaxInterface.arguments(node::DeclareInstance) = [node.tns, node.init] +SyntaxInterface.arguments(node::DeclareInstance) = [node.tns, node.init, node.op] SyntaxInterface.arguments(node::FreezeInstance) = [node.tns] SyntaxInterface.arguments(node::ThawInstance) = [node.tns] SyntaxInterface.arguments(node::BlockInstance) = node.bodies @@ -98,6 +104,8 @@ SyntaxInterface.arguments(node::SieveInstance) = [node.cond, node.body] SyntaxInterface.arguments(node::AssignInstance) = [node.lhs, node.op, node.rhs] SyntaxInterface.arguments(node::CallInstance) = [node.op, node.args...] SyntaxInterface.arguments(node::AccessInstance) = [node.tns, node.mode, node.idxs...] +SyntaxInterface.arguments(node::ReaderInstance) = [] +SyntaxInterface.arguments(node::UpdaterInstance) = [node.op] SyntaxInterface.arguments(node::TagInstance) = [node.var, node.bind] SyntaxInterface.arguments(node::YieldBindInstance) = node.args diff --git a/src/FinchNotation/nodes.jl b/src/FinchNotation/nodes.jl index 71bf4f1fb..11a2c377c 100644 --- a/src/FinchNotation/nodes.jl +++ b/src/FinchNotation/nodes.jl @@ -1,9 +1,3 @@ -struct Reader end -struct Updater end - -const reader = Reader() -const updater = Updater() - const IS_TREE = 1 const IS_STATEFUL = 2 const IS_CONST = 4 @@ -18,6 +12,8 @@ const ID = 8 tag = 5ID | IS_TREE call = 6ID | IS_TREE access = 7ID | IS_TREE + reader = 8ID | IS_TREE + updater = 9ID | IS_TREE cached = 10ID | IS_TREE assign = 11ID | IS_TREE | IS_STATEFUL loop = 12ID | IS_TREE | IS_STATEFUL @@ -97,6 +93,23 @@ access is in-place. """ access +""" + reader() + +Finch AST expression representing a read-only mode for a tensor access. Declare, +freeze, and thaw statements can change the mode of a tensor. +""" +reader + +""" + updater(op) + +Finch AST expression representing an update-only mode for a tensor access, using +the reduction operator `op`. Declare, freeze, and thaw statements can change +the mode of a tensor. +""" +updater + """ cached(val, ref) @@ -138,23 +151,25 @@ A new scope is introduced to evaluate `body`. define """ - declare(tns, init) + declare(tns, init, op) -Finch AST statement that declares `tns` with an initial value `init` in the current scope. +Finch AST statement that declares `tns` with an initial value `init` reduced with `op` in the current scope. """ declare """ - freeze(tns) + freeze(tns, op) -Finch AST statement that freezes `tns` in the current scope. +Finch AST statement that freezes `tns` in the current scope after modifications +with `op`, moving the tensor from update-only mode to read-only mode. """ freeze """ - thaw(tns) + thaw(tns, op) -Finch AST statement that thaws `tns` in the current scope. +Finch AST statement that thaws `tns` in the current scope, moving the tensor from +read-only mode to update-only mode with a reduction operator `op`. """ thaw @@ -259,6 +274,8 @@ function FinchNode(kind::FinchNodeKind, args::Vector) elseif (kind === value || kind === literal || kind === index || kind === variable || kind === virtual) && length(args) == 2 return FinchNode(kind, args[1], args[2], FinchNode[]) elseif (kind === cached && length(args) == 2) || + (kind === reader && length(args) == 0) || + (kind === updater && length(args) == 1) || (kind === access && length(args) >= 2) || (kind === tag && length(args) == 2) || (kind === call && length(args) >= 1) || @@ -266,9 +283,9 @@ function FinchNode(kind::FinchNodeKind, args::Vector) (kind === sieve && length(args) == 2) || (kind === assign && length(args) == 3) || (kind === define && length(args) == 3) || - (kind === declare && length(args) == 2) || - (kind === freeze && length(args) == 1) || - (kind === thaw && length(args) == 1) || + (kind === declare && length(args) == 3) || + (kind === freeze && length(args) == 2) || + (kind === thaw && length(args) == 2) || (kind === block) || (kind === yieldbind) return FinchNode(kind, nothing, nothing, args) @@ -288,6 +305,7 @@ function Base.getproperty(node::FinchNode, sym::Symbol) elseif node.kind === variable && sym === :name node.val::Symbol elseif node.kind === tag && sym === :var node.children[1] elseif node.kind === tag && sym === :bind node.children[2] + elseif node.kind === updater && sym === :op node.children[1] elseif node.kind === access && sym === :tns node.children[1] elseif node.kind === access && sym === :mode node.children[2] elseif node.kind === access && sym === :idxs @view node.children[3:end] @@ -308,8 +326,11 @@ function Base.getproperty(node::FinchNode, sym::Symbol) elseif node.kind === define && sym === :body node.children[3] elseif node.kind === declare && sym === :tns node.children[1] elseif node.kind === declare && sym === :init node.children[2] + elseif node.kind === declare && sym === :op node.children[3] elseif node.kind === freeze && sym === :tns node.children[1] + elseif node.kind === freeze && sym === :op node.children[2] elseif node.kind === thaw && sym === :tns node.children[1] + elseif node.kind === thaw && sym === :op node.children[2] elseif node.kind === block && sym === :bodies node.children elseif node.kind === yieldbind && sym === :args node.children else @@ -392,8 +413,6 @@ virtual. finch_leaf(arg) = literal(arg) finch_leaf(arg::Type) = literal(arg) finch_leaf(arg::Function) = literal(arg) -finch_leaf(arg::Reader) = literal(arg) -finch_leaf(arg::Updater) = literal(arg) finch_leaf(arg::FinchNode) = arg Base.convert(::Type{FinchNode}, x) = finch_leaf(x) diff --git a/src/FinchNotation/syntax.jl b/src/FinchNotation/syntax.jl index 9d10b766d..14c015d67 100644 --- a/src/FinchNotation/syntax.jl +++ b/src/FinchNotation/syntax.jl @@ -14,13 +14,13 @@ const program_nodes = ( call = call, access = access, yieldbind = yieldbind, - reader = literal(reader), - updater = literal(updater), + reader = reader, + updater = updater, variable = variable, tag = (ex) -> :(finch_leaf($(esc(ex)))), literal = literal, leaf = (ex) -> :(finch_leaf($(esc(ex)))), - dimless = :(finch_leaf(dimless)) + auto = :(finch_leaf(auto)) ) const instance_nodes = ( @@ -36,13 +36,13 @@ const instance_nodes = ( call = call_instance, access = access_instance, yieldbind = yieldbind_instance, - reader = literal_instance(reader), - updater = literal_instance(updater), + reader = reader_instance, + updater = updater_instance, variable = variable_instance, tag = (ex) -> :($tag_instance($(variable_instance(ex)), $finch_leaf_instance($(esc(ex))))), literal = literal_instance, leaf = (ex) -> :($finch_leaf_instance($(esc(ex)))), - dimless = :($finch_leaf_instance(dimless)) + auto = :($finch_leaf_instance(auto)) ) d() = 1 @@ -98,15 +98,15 @@ julia> x[] overwrite(l, r) = r """ - Dimensionless() + Auto() A singleton type representing the lack of a dimension. This is used in place of a dimension when we want to avoid dimensionality checks. In the `@finch` macro, -you can write `Dimensionless()` with an underscore as `for i = _`, allowing +you can write `Auto()` with an underscore as `for i = _`, allowing finch to pick up the loop bounds from the tensors automatically. """ -struct Dimensionless end -const dimless = Dimensionless() +struct Auto end +const auto = Auto() function extent end function realextent end @@ -116,7 +116,7 @@ end function (ctx::FinchParserVisitor)(ex::Symbol) if ex == :_ || ex == :(:) - return :($dimless) + return :($auto) elseif ex in evaluable_exprs return ctx.nodes.literal(@eval($ex)) else @@ -137,11 +137,13 @@ function (ctx::FinchParserVisitor)(ex::Expr) elseif @capture ex :elseif(~args...) throw(FinchSyntaxError("Finch does not support elseif.")) elseif @capture ex :(.=)(~tns, ~init) - return :($(ctx.nodes.declare)($(ctx(tns)), $(ctx(init)))) - elseif @capture ex :macrocall($(Symbol("@freeze")), ~ln::islinenum, ~tns) - return :($(ctx.nodes.freeze)($(ctx(tns)))) - elseif @capture ex :macrocall($(Symbol("@thaw")), ~ln::islinenum, ~tns) - return :($(ctx.nodes.thaw)($(ctx(tns)))) + return :($(ctx.nodes.declare)($(ctx(tns)), $(ctx(init)), $(ctx.nodes.literal)(auto))) + elseif @capture ex :macrocall($(Symbol("@declare")), ~ln::islinenum, ~tns, ~init, ~op) + return :($(ctx.nodes.declare)($(ctx(tns)), $(ctx(init)), $(ctx(op)))) + elseif @capture ex :macrocall($(Symbol("@freeze")), ~ln::islinenum, ~tns, ~op) + return :($(ctx.nodes.freeze)($(ctx(tns)), $(ctx(op)))) + elseif @capture ex :macrocall($(Symbol("@thaw")), ~ln::islinenum, ~tns, ~op) + return :($(ctx.nodes.thaw)($(ctx(tns)), $(ctx(op)))) elseif @capture ex :for(:block(), ~body) return ctx(body) elseif @capture ex :for(:block(:(=)(~idx, ~ext), ~tail...), ~body) @@ -199,17 +201,17 @@ function (ctx::FinchParserVisitor)(ex::Expr) return :($(ctx.nodes.yieldbind)($(ctx(arg)))) elseif @capture ex :ref(~tns, ~idxs...) mode = ctx.nodes.reader - return :($(ctx.nodes.access)($(ctx(tns)), $mode, $(map(ctx, idxs)...))) + return :($(ctx.nodes.access)($(ctx(tns)), $mode(), $(map(ctx, idxs)...))) elseif (@capture ex (~op)(~lhs, ~rhs)) && haskey(incs, op) return ctx(:($lhs << $(incs[op]) >>= $rhs)) elseif @capture ex :(=)(:ref(~tns, ~idxs...), ~rhs) mode = ctx.nodes.updater - lhs = :($(ctx.nodes.access)($(ctx(tns)), $mode, $(map(ctx, idxs)...))) op = :($(ctx.nodes.literal)($initwrite)) + lhs = :($(ctx.nodes.access)($(ctx(tns)), $mode($op), $(map(ctx, idxs)...))) return :($(ctx.nodes.assign)($lhs, $op, $(ctx(rhs)))) elseif @capture ex :>>=(:call(:<<, :ref(~tns, ~idxs...), ~op), ~rhs) mode = ctx.nodes.updater - lhs = :($(ctx.nodes.access)($(ctx(tns)), $mode, $(map(ctx, idxs)...))) + lhs = :($(ctx.nodes.access)($(ctx(tns)), $mode($(ctx(op))), $(map(ctx, idxs)...))) return :($(ctx.nodes.assign)($lhs, $(ctx(op)), $(ctx(rhs)))) elseif @capture ex :>>=(:call(:<<, ~lhs, ~op), ~rhs) error("Finch doesn't support incrementing definitions of variables") @@ -406,17 +408,31 @@ function display_statement(io, mime, node::Union{FinchNode, FinchNodeInstance}, print(io, ">>= ") display_expression(io, mime, node.rhs) elseif operation(node) === declare - print(io, " "^indent) - display_expression(io, mime, node.tns) - print(io, " .= ") - display_expression(io, mime, node.init) + if operation(node.op) === literal && node.op.val === auto + print(io, " "^indent) + display_expression(io, mime, node.tns) + print(io, " .= ") + display_expression(io, mime, node.init) + else + print(io, " "^indent * "@declare(") + display_expression(io, mime, node.tns) + print(io, ", ") + display_expression(io, mime, node.init) + print(io, ", ") + display_expression(io, mime, node.op) + print(io, ")") + end elseif operation(node) === freeze print(io, " "^indent * "@freeze(") display_expression(io, mime, node.tns) + print(io, ", ") + display_expression(io, mime, node.op) print(io, ")") elseif operation(node) === thaw print(io, " "^indent * "@thaw(") display_expression(io, mime, node.tns) + print(io, ", ") + display_expression(io, mime, node.op) print(io, ")") elseif operation(node) === yieldbind print(io, " "^indent * "return (") @@ -457,7 +473,7 @@ function finch_unparse_program(ctx, node::Union{FinchNode, FinchNodeInstance}) @assert operation(node.var) === variable node.var.name elseif operation(node) === virtual - if node.val == dimless + if node.val == auto :_ else ctx(node) @@ -496,10 +512,16 @@ function finch_unparse_program(ctx, node::Union{FinchNode, FinchNodeInstance}) elseif operation(node) === declare tns = finch_unparse_program(ctx, node.tns) init = finch_unparse_program(ctx, node.init) - :($tns .= $init) + if operation(node.op) === literal && node.op.val === auto + :($tns .= $init) + else + op = finch_unparse_program(ctx, node.op) + :(@declare($tns, $init, $op)) + end elseif operation(node) === freeze tns = finch_unparse_program(ctx, node.tns) - :(@freeze($tns)) + op = finch_unparse_program(ctx, node.op) + :(@freeze($tns, $op)) elseif operation(node) === thaw tns = finch_unparse_program(ctx, node.tns) :(@thaw($tns)) diff --git a/src/FinchNotation/virtualize.jl b/src/FinchNotation/virtualize.jl index 824c5bc2e..e546a5e89 100644 --- a/src/FinchNotation/virtualize.jl +++ b/src/FinchNotation/virtualize.jl @@ -7,9 +7,9 @@ function Finch.virtualize(ctx, ex, ::Type{FinchNotation.IndexInstance{name}}) wh index(name) end Finch.virtualize(ctx, ex, ::Type{FinchNotation.DefineInstance{Lhs, Rhs, Body}}) where {Lhs, Rhs, Body} = define(virtualize(ctx, :($ex.lhs), Lhs), virtualize(ctx, :($ex.rhs), Rhs), virtualize(ctx, :($ex.body), Body)) -Finch.virtualize(ctx, ex, ::Type{FinchNotation.DeclareInstance{Tns, Init}}) where {Tns, Init} = declare(virtualize(ctx, :($ex.tns), Tns), virtualize(ctx, :($ex.init), Init)) -Finch.virtualize(ctx, ex, ::Type{FinchNotation.FreezeInstance{Tns}}) where {Tns} = freeze(virtualize(ctx, :($ex.tns), Tns)) -Finch.virtualize(ctx, ex, ::Type{FinchNotation.ThawInstance{Tns}}) where {Tns} = thaw(virtualize(ctx, :($ex.tns), Tns)) +Finch.virtualize(ctx, ex, ::Type{FinchNotation.DeclareInstance{Tns, Init, Op}}) where {Tns, Init, Op} = declare(virtualize(ctx, :($ex.tns), Tns), virtualize(ctx, :($ex.init), Init), virtualize(ctx, :($ex.op), Op)) +Finch.virtualize(ctx, ex, ::Type{FinchNotation.FreezeInstance{Tns, Op}}) where {Tns, Op} = freeze(virtualize(ctx, :($ex.tns), Tns), virtualize(ctx, :($ex.op), Op)) +Finch.virtualize(ctx, ex, ::Type{FinchNotation.ThawInstance{Tns, Op}}) where {Tns, Op} = thaw(virtualize(ctx, :($ex.tns), Tns), virtualize(ctx, :($ex.op), Op)) function Finch.virtualize(ctx, ex, ::Type{FinchNotation.BlockInstance{Bodies}}) where {Bodies} bodies = map(enumerate(Bodies.parameters)) do (n, Body) virtualize(ctx, :($ex.bodies[$n]), Body) @@ -44,6 +44,11 @@ function Finch.virtualize(ctx, ex, ::Type{FinchNotation.AccessInstance{Tns, Mode end access(tns, virtualize(ctx, :($ex.mode), Mode), idxs...) end +Finch.virtualize(ctx, ex, ::Type{FinchNotation.ReaderInstance}) = reader() +function Finch.virtualize(ctx, ex, ::Type{FinchNotation.UpdaterInstance{Op}}) where {Op} + op = virtualize(ctx, :($ex.op), Op) + updater(op) +end Finch.virtualize(ctx, ex, ::Type{FinchNotation.VariableInstance{tag}}) where {tag} = variable(tag) function Finch.virtualize(ctx, ex, ::Type{FinchNotation.TagInstance{Var, Bind}}) where {Var, Bind} var = virtualize(ctx, :($ex.var), Var) diff --git a/src/Galley/ExecutionEngine/execution-engine.jl b/src/Galley/ExecutionEngine/execution-engine.jl index e8603ebd2..3189134f0 100644 --- a/src/Galley/ExecutionEngine/execution-engine.jl +++ b/src/Galley/ExecutionEngine/execution-engine.jl @@ -12,8 +12,7 @@ function initialize_access(tensor_id::Symbol, tensor, index_ids, protocols, inde return literal_instance(tensor) end - mode = read ? Reader() : Updater() - mode = literal_instance(mode) + mode = read ? reader_instance() : updater_instance(auto) index_expressions = [] for i in range(1, length(index_ids)) index = if cannonicalize @@ -127,12 +126,12 @@ function execute_query(alias_dict, q::PlanNode, verbose, cannonicalize, return_p read=false, cannonicalize=cannonicalize) dec_instance = declare_instance(variable_instance(output_name), - literal_instance(output_default)) + literal_instance(output_default), literal_instance(auto)) prgm_instance = assign_instance(output_access, literal_instance(agg_op), rhs_instance) loop_order = [cannonicalize ? index_instance(index_sym_dict[i]) : index_instance(i) for i in loop_order] for index in reverse(loop_order) - prgm_instance = loop_instance(index, Dimensionless(), prgm_instance) + prgm_instance = loop_instance(index, Auto(), prgm_instance) end prgm_instance = block_instance(dec_instance, prgm_instance) diff --git a/src/Galley/Galley.jl b/src/Galley/Galley.jl index d3f2ad06e..262a9255a 100644 --- a/src/Galley/Galley.jl +++ b/src/Galley/Galley.jl @@ -14,10 +14,10 @@ using Finch using Finch: Element, SparseListLevel, SparseDict, Dense, SparseCOO, fsparse_impl, compute_parse, isimmediate, set_options, flatten_plans using Finch.FinchNotation: index_instance, variable_instance, tag_instance, literal_instance, - access_instance, assign_instance, loop_instance, declare_instance, - block_instance, define_instance, call_instance, freeze_instance, - thaw_instance, finch_unparse_program, - Updater, Reader, Dimensionless + access_instance, reader_instance, updater_instance, assign_instance, + loop_instance, declare_instance, block_instance, define_instance, + call_instance, freeze_instance, thaw_instance, finch_unparse_program, + Auto using Finch.FinchLogic export galley diff --git a/src/Galley/utility-funcs.jl b/src/Galley/utility-funcs.jl index c7773f294..b46923e83 100644 --- a/src/Galley/utility-funcs.jl +++ b/src/Galley/utility-funcs.jl @@ -117,10 +117,10 @@ function get_sparsity_structure(tensor::Tensor) full_prgm = assign_instance(output_instance, literal_instance(initwrite(false)), tensor_instance) for index in indices - full_prgm = loop_instance(index_instance(index_sym_dict[index]), Dimensionless(), full_prgm) + full_prgm = loop_instance(index_instance(index_sym_dict[index]), Auto(), full_prgm) end - initializer = declare_instance(variable_instance(:output_tensor), literal_instance(false)) + initializer = declare_instance(variable_instance(:output_tensor), literal_instance(false), literal_instance(auto)) full_prgm = block_instance(initializer, full_prgm) Finch.execute(full_prgm) return output_tensor @@ -182,9 +182,9 @@ function one_off_reduce(op, full_prgm = assign_instance(output_access, op_instance, tensor_instance) for index in reverse(loop_index_instances) - full_prgm = loop_instance(index, Dimensionless(), full_prgm) + full_prgm = loop_instance(index, Auto(), full_prgm) end - initializer = declare_instance(output_variable, literal_instance(0.0)) + initializer = declare_instance(output_variable, literal_instance(0.0), literal_instance(auto)) full_prgm = block_instance(initializer, full_prgm) Finch.execute(full_prgm, mode=:fast) return output_tensor @@ -202,9 +202,9 @@ function count_non_default(A) prgm = assign_instance(count_access, literal_instance(+), prgm) loop_index_instances = [index_instance(index_sym_dict[idx]) for idx in reverse(indexes)] for idx in reverse(loop_index_instances) - prgm = loop_instance(idx, Dimensionless(), prgm) + prgm = loop_instance(idx, Auto(), prgm) end - prgm = block_instance(declare_instance(tag_instance(variable_instance(:count), count), literal_instance(0)), prgm) + prgm = block_instance(declare_instance(tag_instance(variable_instance(:count), count), literal_instance(0), literal_instance(auto)), prgm) Finch.execute(prgm, mode=:fast) return count[] end diff --git a/src/abstract_tensor.jl b/src/abstract_tensor.jl index 4e9448e41..483cfa9b6 100644 --- a/src/abstract_tensor.jl +++ b/src/abstract_tensor.jl @@ -58,7 +58,7 @@ function virtual_eltype end function virtual_resize!(ctx, tns, dims...) for (dim, ref) in zip(dims, virtual_size(ctx, tns)) - if dim !== dimless && ref !== dimless #TODO this should be a function like checkdim or something haha + if dim !== auto && ref !== auto #TODO this should be a function like checkdim or something haha push_preamble!(ctx, quote $(ctx(getstart(dim))) == $(ctx(getstart(ref))) || throw(DimensionMismatch("mismatched dimension start")) $(ctx(getstop(dim))) == $(ctx(getstop(ref))) || throw(DimensionMismatch("mismatched dimension stop")) diff --git a/src/dimensions.jl b/src/dimensions.jl index 00dbfcb9a..b3afad741 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -1,9 +1,9 @@ -FinchNotation.finch_leaf(x::Dimensionless) = virtual(x) -FinchNotation.finch_leaf_instance(x::Dimensionless) = value_instance(x) -virtualize(ctx, ex, ::Type{Dimensionless}) = dimless +FinchNotation.finch_leaf(x::Auto) = virtual(x) +FinchNotation.finch_leaf_instance(x::Auto) = value_instance(x) +virtualize(ctx, ex, ::Type{Auto}) = auto -getstart(::Dimensionless) = error("asked for start of dimensionless range") -getstop(::Dimensionless) = error("asked for stop of dimensionless range") +getstart(::Auto) = error("asked for start of dimensionless range") +getstop(::Auto) = error("asked for stop of dimensionless range") struct UnknownDimension end @@ -33,7 +33,7 @@ combinedim(ctx, ::B, ::A) """ combinedim(ctx, a, b) = UnknownDimension() -combinedim(ctx, a::Dimensionless, b) = b +combinedim(ctx, a::Auto, b) = b @kwdef struct Extent start @@ -85,7 +85,7 @@ combinedim(ctx, a::Extent, b::Extent) = stop = checklim(ctx, a.stop, b.stop) ) -combinedim(ctx, a::Dimensionless, b::Extent) = b +combinedim(ctx, a::Auto, b::Extent) = b struct SuggestedExtent ext @@ -97,7 +97,7 @@ Base.:(==)(a::SuggestedExtent, b::SuggestedExtent) = a.ext == b.ext suggest(ext) = SuggestedExtent(ext) suggest(ext::SuggestedExtent) = ext -suggest(ext::Dimensionless) = dimless +suggest(ext::Auto) = auto resolvedim(ext::Symbol) = error() resolvedim(ext::SuggestedExtent) = resolvedim(ext.ext) @@ -105,7 +105,7 @@ cache_dim!(ctx, tag, ext::SuggestedExtent) = SuggestedExtent(cache_dim!(ctx, tag combinedim(ctx, a::SuggestedExtent, b::Extent) = b -combinedim(ctx, a::SuggestedExtent, b::Dimensionless) = a +combinedim(ctx, a::SuggestedExtent, b::Auto) = a combinedim(ctx, a::SuggestedExtent, b::SuggestedExtent) = SuggestedExtent(combinedim(ctx, a.ext, b.ext)) @@ -184,7 +184,7 @@ function shiftdim(ext::ContinuousExtent, delta) end -shiftdim(ext::Dimensionless, delta) = dimless +shiftdim(ext::Auto, delta) = auto shiftdim(ext::ParallelDimension, delta) = ParallelDimension(ext, shiftdim(ext.ext, delta), ext.device) function shiftdim(ext::FinchNode, body) @@ -208,7 +208,7 @@ function scaledim(ext::ContinuousExtent, scale) ) end -scaledim(ext::Dimensionless, scale) = dimless +scaledim(ext::Auto, scale) = auto scaledim(ext::ParallelDimension, scale) = ParallelDimension(ext, scaledim(ext.ext, scale), ext.device) function scaledim(ext::FinchNode, body) @@ -227,9 +227,9 @@ function virtual_intersect(ctx, a, b) error() end -virtual_intersect(ctx, a::Dimensionless, b) = b -virtual_intersect(ctx, a, b::Dimensionless) = a -virtual_intersect(ctx, a::Dimensionless, b::Dimensionless) = b +virtual_intersect(ctx, a::Auto, b) = b +virtual_intersect(ctx, a, b::Auto) = a +virtual_intersect(ctx, a::Auto, b::Auto) = b function virtual_intersect(ctx, a::Extent, b::Extent) Extent( @@ -238,9 +238,9 @@ function virtual_intersect(ctx, a::Extent, b::Extent) ) end -virtual_union(ctx, a::Dimensionless, b) = b -virtual_union(ctx, a, b::Dimensionless) = a -virtual_union(ctx, a::Dimensionless, b::Dimensionless) = b +virtual_union(ctx, a::Auto, b) = b +virtual_union(ctx, a, b::Auto) = a +virtual_union(ctx, a::Auto, b::Auto) = b #virtual_union(ctx, a, b) = virtual_union(ctx, promote(a, b)...) function virtual_union(ctx, a::Extent, b::Extent) @@ -291,7 +291,7 @@ getstop(ext::FinchNode) = ext.kind === virtual ? getstop(ext.val) : ext measure(ext::ContinuousExtent) = call(-, ext.stop, ext.start) # TODO: Think carefully, Not quite sure! combinedim(ctx, a::ContinuousExtent, b::ContinuousExtent) = ContinuousExtent(checklim(ctx, a.start, b.start), checklim(ctx, a.stop, b.stop)) -combinedim(ctx, a::Dimensionless, b::ContinuousExtent) = b +combinedim(ctx, a::Auto, b::ContinuousExtent) = b combinedim(ctx, a::Extent, b::ContinuousExtent) = throw(ArgumentError("Extent and ContinuousExtent cannot interact ...yet")) combinedim(ctx, a::SuggestedExtent, b::ContinuousExtent) = b diff --git a/src/execute.jl b/src/execute.jl index c17bba248..c4c7f01bf 100644 --- a/src/execute.jl +++ b/src/execute.jl @@ -42,8 +42,8 @@ function (ctx::InstantiateTensors)(node::FinchNode) push!(ctx.escape, node.tns) node elseif (@capture node access(~tns, ~mode, ~idxs...)) && !(getroot(tns) in ctx.escape) - #@assert get(ctx.ctx.modes, tns, reader) === node.mode.val - tns_2 = instantiate(ctx.ctx, tns, mode.val) + #@assert get(ctx.ctx.modes, tns, reader()) === node.mode + tns_2 = instantiate(ctx.ctx, tns, mode) access(tns_2, mode, idxs...) elseif istree(node) return similarterm(node, operation(node), map(ctx, arguments(node))) diff --git a/src/interface/abstract_arrays.jl b/src/interface/abstract_arrays.jl index 5e8b40d95..d1be3783a 100644 --- a/src/interface/abstract_arrays.jl +++ b/src/interface/abstract_arrays.jl @@ -48,7 +48,7 @@ function unfurl(ctx, tns::VirtualAbstractArraySlice, ext, mode, proto) idx_2 = (i, idx...) if length(idx_2) == arr.ndims val = freshen(ctx, :val) - if mode === reader + if mode.kind === reader Thunk( preamble = quote $val = $(arr.ex)[$(map(ctx, idx_2)...)] @@ -76,7 +76,7 @@ end function instantiate(ctx::AbstractCompiler, arr::VirtualAbstractArray, mode) if arr.ndims == 0 val = freshen(ctx, :val) - if mode === reader + if mode.kind === reader Thunk( preamble = quote $val = $(arr.ex)[] diff --git a/src/interface/abstract_unit_ranges.jl b/src/interface/abstract_unit_ranges.jl index 22e402a4f..b3b344a23 100644 --- a/src/interface/abstract_unit_ranges.jl +++ b/src/interface/abstract_unit_ranges.jl @@ -22,21 +22,22 @@ end virtual_resize!(ctx::AbstractCompiler, arr::VirtualAbstractUnitRange, idx_dim) = arr -function unfurl(ctx, arr::VirtualAbstractUnitRange, ext, mode, proto) - Unfurled( - arr = arr, - body = Lookup( - body = (ctx, i) -> FillLeaf(value(:($(arr.ex)[$(ctx(i))]))) - ) - ) -end - function declare!(ctx::AbstractCompiler, arr::VirtualAbstractUnitRange, init) throw(FinchProtocolError("$(arr.arrtype) is not writeable")) end -unfurl(ctx::AbstractCompiler, arr::VirtualAbstractUnitRange, ext, mode::Updater, proto) = - throw(FinchProtocolError("$(arr.arrtype) is not writeable")) +function unfurl(ctx::AbstractCompiler, arr::VirtualAbstractUnitRange, ext, mode, proto) + if mode.kind === reader + Unfurled( + arr = arr, + body = Lookup( + body = (ctx, i) -> FillLeaf(value(:($(arr.ex)[$(ctx(i))]))) + ) + ) + else + throw(FinchProtocolError("$(arr.arrtype) is not writeable")) + end +end FinchNotation.finch_leaf(x::VirtualAbstractUnitRange) = virtual(x) diff --git a/src/interface/fsparse.jl b/src/interface/fsparse.jl index c40ac7fba..89168b18d 100644 --- a/src/interface/fsparse.jl +++ b/src/interface/fsparse.jl @@ -33,21 +33,21 @@ fsparse_parse(I, V::AbstractVector, m::Tuple; kwargs...) = fsparse_impl(I, V, m; fsparse_parse(I, V::AbstractVector, m::Tuple, combine; kwargs...) = fsparse_impl(I, V, m, combine; kwargs...) function fsparse_impl(I::Tuple, V::Vector, shape = map(maximum, I), combine = eltype(V) isa Bool ? (|) : (+); fill_value = zero(eltype(V))) C = map(tuple, reverse(I)...) - updater = false + dirty = false if !issorted(C) P = sortperm(C) C = C[P] V = V[P] - updater = true + dirty = true end if !allunique(C) P = unique(p -> C[p], 1:length(C)) C = C[P] push!(P, length(I[1]) + 1) V = map((start, stop) -> foldl(combine, @view V[start:stop - 1]), P[1:end - 1], P[2:end]) - updater = true + dirty = true end - if updater + if dirty I = map(i -> similar(i, length(C)), I) foreach(((p, c),) -> ntuple(n->I[n][p] = c[n], length(I)), enumerate(C)) I = reverse(I) diff --git a/src/looplets/jumpers.jl b/src/looplets/jumpers.jl index 1b81d20e5..42e37551f 100644 --- a/src/looplets/jumpers.jl +++ b/src/looplets/jumpers.jl @@ -33,13 +33,13 @@ combine_style(a::JumperStyle, b::PhaseStyle) = b jumper_seek(ctx, node::Jumper, ext) = node.seek(ctx, ext) jumper_seek(ctx, node, ext) = quote end -jumper_range(ctx, node, ext) = dimless +jumper_range(ctx, node, ext) = auto function jumper_range(ctx, node::FinchNode, ext) if @capture node access(~tns::isvirtual, ~i...) jumper_range(ctx, tns.val, ext) else - return dimless + return auto end end diff --git a/src/looplets/phases.jl b/src/looplets/phases.jl index 902994c6f..50e6d2819 100644 --- a/src/looplets/phases.jl +++ b/src/looplets/phases.jl @@ -17,11 +17,11 @@ function phase_range(ctx, node::FinchNode, ext) if @capture node access(~tns::isvirtual, ~i...) phase_range(ctx, tns.val, ext) else - return dimless + return auto end end -phase_range(ctx, node, ext) = dimless +phase_range(ctx, node, ext) = auto phase_range(ctx, node::Phase, ext) = node.range(ctx, ext) function phase_body(ctx, node::FinchNode, ext, ext_2) diff --git a/src/looplets/steppers.jl b/src/looplets/steppers.jl index 6a71279ca..cc6de81ed 100644 --- a/src/looplets/steppers.jl +++ b/src/looplets/steppers.jl @@ -35,13 +35,13 @@ combine_style(a::StepperStyle, b::PhaseStyle) = b stepper_seek(ctx, node::Stepper, ext) = node.seek(ctx, ext) stepper_seek(ctx, node, ext) = quote end -stepper_range(ctx, node, ext) = dimless +stepper_range(ctx, node, ext) = auto function stepper_range(ctx, node::FinchNode, ext) if @capture node access(~tns::isvirtual, ~i...) stepper_range(ctx, tns.val, ext) else - return dimless + return auto end end diff --git a/src/lower.jl b/src/lower.jl index da3982704..278ef8bfa 100644 --- a/src/lower.jl +++ b/src/lower.jl @@ -29,9 +29,9 @@ get_mode_flag(ctx::FinchCompiler) = ctx.mode get_binding(ctx::FinchCompiler, var) = get_binding(ctx.scope, var) has_binding(ctx::FinchCompiler, var) = has_binding(ctx.scope, var) set_binding!(ctx::FinchCompiler, var, val) = set_binding!(ctx.scope, var, val) -set_declared!(ctx::FinchCompiler, var, val) = set_declared!(ctx.scope, var, val) +set_declared!(ctx::FinchCompiler, var, val, op) = set_declared!(ctx.scope, var, val, op) set_frozen!(ctx::FinchCompiler, var, val) = set_frozen!(ctx.scope, var, val) -set_thawed!(ctx::FinchCompiler, var, val) = set_thawed!(ctx.scope, var, val) +set_thawed!(ctx::FinchCompiler, var, val, op) = set_thawed!(ctx.scope, var, val, op) get_tensor_mode(ctx::FinchCompiler, var) = get_tensor_mode(ctx.scope, var) function open_scope(f::F, ctx::FinchCompiler) where {F} open_scope(ctx.scope) do scope_2 @@ -143,13 +143,13 @@ function lower(ctx::AbstractCompiler, root::FinchNode, ::DefaultStyle) ctx(block(head.bodies..., body)) elseif head.kind === declare val_2 = declare!(ctx, get_binding(ctx, head.tns), head.init) - set_declared!(ctx, head.tns, val_2) + set_declared!(ctx, head.tns, val_2, head.op) elseif head.kind === freeze val_2 = freeze!(ctx, get_binding(ctx, head.tns)) set_frozen!(ctx, head.tns, val_2) elseif head.kind === thaw val_2 = thaw!(ctx, get_binding(ctx, head.tns)) - set_thawed!(ctx, head.tns, val_2) + set_thawed!(ctx, head.tns, val_2, head.op) else preamble = contain(ctx) do ctx_2 ctx_2(instantiate!(ctx_2, head)) @@ -181,8 +181,8 @@ function lower(ctx::AbstractCompiler, root::FinchNode, ::DefaultStyle) if length(root.idxs) > 0 throw(FinchCompileError("Finch failed to completely lower an access to $tns")) end - @assert root.mode.kind === literal - return lower_access(ctx, tns, root.mode.val) + @assert (root.mode.kind === reader) || (root.mode.kind === updater) + return lower_access(ctx, tns, root.mode) elseif root.kind === call root = simplify(ctx, root) if root.kind === call @@ -227,13 +227,13 @@ function lower(ctx::AbstractCompiler, root::FinchNode, ::DefaultStyle) ctx(root.val) elseif root.kind === assign @assert root.lhs.kind === access - @assert root.lhs.mode.val === updater + @assert root.lhs.mode.kind === updater if length(root.lhs.idxs) > 0 throw(FinchCompileError("Finch failed to completely lower an access to $tns")) end rhs = simplify(ctx, root.rhs) tns = resolve(ctx, root.lhs.tns) - return lower_assign(ctx, tns, root.lhs.mode.val, root.op, rhs) + return lower_assign(ctx, tns, root.lhs.mode, root.op, rhs) elseif root.kind === variable return ctx(get_binding(ctx, root)) elseif root.kind === yieldbind @@ -263,7 +263,7 @@ function lower_assign(ctx, tns, mode, op, rhs) end function lower_access(ctx, tns::Number, mode) - @assert node.mode.val === reader + @assert node.mode.kind === reader tns end @@ -281,7 +281,7 @@ function lower_loop(ctx, root, ext) contain(ctx) do ctx_2 root_2 = Rewrite(Postwalk(@rule access(~tns, ~mode, ~idxs...) => begin if !isempty(idxs) && root.idx == idxs[end] - tns_2 = unfurl(ctx_2, tns, root.ext.val, mode.val, (mode.val === reader ? defaultread : defaultupdate)) + tns_2 = unfurl(ctx_2, tns, root.ext.val, mode, (mode.kind === reader ? defaultread : defaultupdate)) access(Unfurled(resolve(ctx_2, tns), tns_2), mode, idxs...) end end))(root) @@ -298,7 +298,7 @@ function lower_parallel_loop(ctx, root, ext::ParallelDimension, device::VirtualC i = freshen(ctx, :i) decl_in_scope = unique(filter(!isnothing, map(node-> begin - if @capture(node, declare(~tns, ~init)) + if @capture(node, declare(~tns, ~init, ~op)) tns end end, PostOrderDFS(root.body)))) @@ -311,7 +311,7 @@ function lower_parallel_loop(ctx, root, ext::ParallelDimension, device::VirtualC root_2 = loop(tid, Extent(value(i, Int), value(i, Int)), loop(root.idx, ext.ext, - sieve(access(VirtualSplitMask(device.n), reader, root.idx, tid), + sieve(access(VirtualSplitMask(device.n), reader(), root.idx, tid), root.body ) ) diff --git a/src/scheduler/LogicInterpreter.jl b/src/scheduler/LogicInterpreter.jl index 11e861817..1319942f1 100644 --- a/src/scheduler/LogicInterpreter.jl +++ b/src/scheduler/LogicInterpreter.jl @@ -1,4 +1,4 @@ -using Finch.FinchNotation: block_instance, declare_instance, call_instance, loop_instance, index_instance, variable_instance, tag_instance, access_instance, assign_instance, literal_instance, yieldbind_instance +using Finch.FinchNotation: block_instance, declare_instance, call_instance, loop_instance, index_instance, variable_instance, tag_instance, access_instance, reader_instance, updater_instance, assign_instance, literal_instance, yieldbind_instance @kwdef struct PointwiseMachineLowerer ctx @@ -19,7 +19,7 @@ function (ctx::PointwiseMachineLowerer)(ex) idxs_3 = map(enumerate(idxs_1)) do (n, idx) idx in idxs_2 ? index_instance(idx.name) : first(axes(ctx.ctx.scope[arg])[n]) end - access_instance(tag_instance(variable_instance(arg.name), ctx.ctx.scope[arg]), literal_instance(reader), idxs_3...) + access_instance(tag_instance(variable_instance(arg.name), ctx.ctx.scope[arg]), reader_instance(), idxs_3...) elseif (@capture ex reorder(~arg::isimmediate, ~idxs...)) literal_instance(arg.val) elseif ex.kind === immediate @@ -47,17 +47,17 @@ function (ctx::LogicMachine)(ex) loop_idxs = withsubsequence(intersect(idxs_1, idxs_2), idxs_2) lhs_idxs = idxs_2 res = tag_instance(variable_instance(:res), tns.val) - lhs = access_instance(res, literal_instance(updater), map(idx -> index_instance(idx.name), lhs_idxs)...) + lhs = access_instance(res, updater_instance(auto), map(idx -> index_instance(idx.name), lhs_idxs)...) (rhs, rhs_idxs) = lower_pointwise_logic(ctx, reorder(relabel(arg, idxs_1...), idxs_2...)) body = assign_instance(lhs, literal_instance(initwrite(fill_value(tns.val))), rhs) for idx in loop_idxs if idx in rhs_idxs - body = loop_instance(index_instance(idx.name), dimless, body) + body = loop_instance(index_instance(idx.name), auto, body) elseif idx in lhs_idxs body = loop_instance(index_instance(idx.name), call_instance(literal_instance(extent), literal_instance(1), literal_instance(1)), body) end end - body = block_instance(declare_instance(res, literal_instance(fill_value(tns.val))), body, yieldbind_instance(res)) + body = block_instance(declare_instance(res, literal_instance(fill_value(tns.val)), literal_instance(auto)), body, yieldbind_instance(res)) if ctx.verbose print("Running: ") display(body) @@ -70,17 +70,17 @@ function (ctx::LogicMachine)(ex) loop_idxs = getfields(arg) lhs_idxs = setdiff(getfields(arg), idxs_1) res = tag_instance(variable_instance(:res), tns.val) - lhs = access_instance(res, literal_instance(updater), map(idx -> index_instance(idx.name), lhs_idxs)...) + lhs = access_instance(res, updater_instance(auto), map(idx -> index_instance(idx.name), lhs_idxs)...) (rhs, rhs_idxs) = lower_pointwise_logic(ctx, arg) body = assign_instance(lhs, literal_instance(op.val), rhs) for idx in loop_idxs if idx in rhs_idxs - body = loop_instance(index_instance(idx.name), dimless, body) + body = loop_instance(index_instance(idx.name), auto, body) elseif idx in lhs_idxs body = loop_instance(index_instance(idx.name), call_instance(literal_instance(extent), literal_instance(1), literal_instance(1)), body) end end - body = block_instance(declare_instance(res, literal_instance(fill_value(tns.val))), body, yieldbind_instance(res)) + body = block_instance(declare_instance(res, literal_instance(fill_value(tns.val)), literal_instance(auto)), body, yieldbind_instance(res)) if ctx.verbose print("Running: ") display(body) diff --git a/src/scopes.jl b/src/scopes.jl index c208ac235..ba30cb323 100644 --- a/src/scopes.jl +++ b/src/scopes.jl @@ -44,16 +44,16 @@ Get the binding of a variable in the context, or set it to a default value. get_binding!(ctx::AbstractCompiler, var, val) = has_binding(ctx, var) ? get_binding(ctx, var) : set_binding!(ctx, var, val) """ - set_declared!(ctx, var, val) + set_declared!(ctx, var, val, op) Mark a tensor variable as declared in the context. """ -function set_declared!(ctx::ScopeContext, var, val) +function set_declared!(ctx::ScopeContext, var, val, op) @assert var.kind === variable - @assert get(ctx.modes, var, reader) === reader + @assert get(ctx.modes, var, reader()).kind === reader push!(ctx.defs, var) set_binding!(ctx, var, val) - ctx.modes[var] = updater + ctx.modes[var] = updater(op) end """ @@ -63,28 +63,28 @@ Mark a tensor variable as frozen in the context. """ function set_frozen!(ctx::ScopeContext, var, val) @assert var.kind === variable - @assert ctx.modes[var] === updater + @assert ctx.modes[var].kind === updater set_binding!(ctx, var, val) - ctx.modes[var] = reader + ctx.modes[var] = reader() end """ - set_thawed!(ctx, var, val) + set_thawed!(ctx, var, val, op) Mark a tensor variable as thawed in the context. """ -function set_thawed!(ctx::ScopeContext, var, val) +function set_thawed!(ctx::ScopeContext, var, val, op) @assert var.kind === variable - @assert get(ctx.modes, var, reader) === reader + @assert get(ctx.modes, var, reader()).kind === reader set_binding!(ctx, var, val) - ctx.modes[var] = updater + ctx.modes[var] = updater(op) end """ get_tensor_mode(ctx, var) Get the mode of a tensor variable in the context. """ -get_tensor_mode(ctx::ScopeContext, var) = get(ctx.modes, var, reader) +get_tensor_mode(ctx::ScopeContext, var) = get(ctx.modes, var, reader()) """ open_scope(f, ctx) diff --git a/src/symbolic/simplify.jl b/src/symbolic/simplify.jl index 9975a1d92..cef5d5c25 100644 --- a/src/symbolic/simplify.jl +++ b/src/symbolic/simplify.jl @@ -77,7 +77,7 @@ function get_simplify_rules(alg, shash) (@rule call(==, ~a, ~a) => literal(true)), (@rule call(<=, ~a, ~a) => literal(true)), (@rule call(<, ~a, ~a) => literal(false)), - (@rule assign(access(~a, updater, ~i...), ~f, ~b) => if isidentity(alg, f, b) block() end), + (@rule assign(access(~a, updater(~g), ~i...), ~f, ~b) => if isidentity(alg, f, b) block() end), #updater(auto) (@rule assign(access(~a, ~m, ~i...), $(literal(missing))) => block()), (@rule assign(access(~a, ~m, ~i..., $(literal(missing)), ~j...), ~b) => block()), (@rule call(coalesce, ~a..., ~b, ~c...) => if isvalue(b) && !(Missing <: b.type) || isliteral(b) && !ismissing(b.val) @@ -143,15 +143,15 @@ function get_simplify_rules(alg, shash) (@rule loop(~idx, ~ext::isvirtual, ~body) => begin body_contain_idx = idx ∈ getunbound(body) if !body_contain_idx - decl_in_scope = filter(!isnothing, map(node-> if @capture(node, declare(~tns, ~init)) tns + decl_in_scope = filter(!isnothing, map(node-> if @capture(node, declare(~tns, ~init, ~op)) tns elseif @capture(node, define(~var, ~val, ~body_2)) var end, PostOrderDFS(body))) - Postwalk(@rule assign(access(~lhs, updater, ~j...), ~f, ~rhs) => begin - access_in_rhs = filter(!isnothing, map(node-> if @capture(node, access(~tns, reader, ~k...)) tns # TODO add getroot here? + Postwalk(@rule assign(access(~lhs, updater(~g), ~j...), ~f, ~rhs) => begin #updater(auto) + access_in_rhs = filter(!isnothing, map(node-> if @capture(node, access(~tns, reader(), ~k...)) tns # TODO add getroot here? elseif @capture(node, ~var::isvariable) var end, PostOrderDFS(rhs))) if !(lhs in decl_in_scope) && isempty(intersect(access_in_rhs, decl_in_scope)) - collapsed(alg, idx, ext.val, access(lhs, updater, j...), f, rhs) + collapsed(alg, idx, ext.val, access(lhs, updater(f), j...), f, rhs) end end)(body) end @@ -165,37 +165,37 @@ function get_simplify_rules(alg, shash) end), # Bottom-up reduction1 - (@rule loop(~idx, ~ext::isvirtual, assign(access(~lhs, updater, ~j...), ~f, ~rhs)) => begin + (@rule loop(~idx, ~ext::isvirtual, assign(access(~lhs, updater(~g), ~j...), ~f, ~rhs)) => begin #updater(auto) if idx ∉ j && idx ∉ getunbound(rhs) - collapsed(alg, idx, ext.val, access(lhs, updater, j...), f, rhs) + collapsed(alg, idx, ext.val, access(lhs, updater(f), j...), f, rhs) end end), ## Bottom-up reduction2 - (@rule loop(~idx, ~ext::isvirtual, block(~s1..., assign(access(~lhs, updater, ~j...), ~f, ~rhs), ~s2...)) => begin + (@rule loop(~idx, ~ext::isvirtual, block(~s1..., assign(access(~lhs, updater(~g), ~j...), ~f, ~rhs), ~s2...)) => begin #updater(auto) if ortho(getroot(lhs), s1) && ortho(getroot(lhs), s2) if idx ∉ j && idx ∉ getunbound(rhs) - body = block(s1..., assign(access(lhs, updater, j...), f, rhs), s2...) - decl_in_scope = filter(!isnothing, map(node-> if @capture(node, declare(~tns, ~init)) tns + body = block(s1..., assign(access(lhs, updater(f), j...), f, rhs), s2...) + decl_in_scope = filter(!isnothing, map(node-> if @capture(node, declare(~tns, ~init, ~op)) tns elseif @capture(node, define(~var, ~val, ~body_2)) var end, PostOrderDFS(body))) - access_in_rhs = filter(!isnothing, map(node-> if @capture(node, access(~tns, reader, ~k...)) tns + access_in_rhs = filter(!isnothing, map(node-> if @capture(node, access(~tns, reader(), ~k...)) tns elseif @capture(node, ~var::isvariable) var end, PostOrderDFS(rhs))) if !(lhs in decl_in_scope) && isempty(intersect(access_in_rhs, decl_in_scope)) - collapsed_body = collapsed(alg, idx, ext.val, access(lhs, updater, j...), f, rhs) + collapsed_body = collapsed(alg, idx, ext.val, access(lhs, updater(f), j...), f, rhs) block(collapsed_body, loop(idx, ext, block(s1..., s2...))) end end end end), - (@rule block(~s1..., thaw(~a::isvariable), ~s2..., freeze(~a), ~s3...) => if ortho(a, s2) + (@rule block(~s1..., thaw(~a::isvariable, ~f), ~s2..., freeze(~a, ~f), ~s3...) => if ortho(a, s2) block(s1..., s2..., s3...) end), - (@rule block(~s1..., freeze(~a::isvariable), ~s2..., thaw(~a), ~s3...) => if ortho(a, s2) + (@rule block(~s1..., freeze(~a::isvariable, ~f), ~s2..., thaw(~a, ~f), ~s3...) => if ortho(a, s2) block(s1..., s2..., s3...) end), ] diff --git a/src/tensors/combinators/offset.jl b/src/tensors/combinators/offset.jl index 97923fdb4..93e162aec 100644 --- a/src/tensors/combinators/offset.jl +++ b/src/tensors/combinators/offset.jl @@ -20,8 +20,10 @@ is_atomic(ctx, lvl::VirtualOffsetArray) = is_atomic(ctx, lvl.body) is_concurrent(ctx, lvl::VirtualOffsetArray) = is_concurrent(ctx, lvl.body) Base.show(io::IO, ex::VirtualOffsetArray) = Base.show(io, MIME"text/plain"(), ex) +Base.show(io::IO, mime::MIME"text/plain", ex::VirtualOffsetArray) = + print(io, "VirtualOffsetArray($(ex.body), $(ex.delta))") -Base.summary(io::IO, ex::VirtualOffsetArray) = print(io, "VOffset($(summary(ex.body)), $(ex.delta))") +Base.summary(io::IO, mime::MIME"text/plain", ex::VirtualOffsetArray) = print(io, "VOffset($(summary(ex.body)), $(ex.delta))") FinchNotation.finch_leaf(x::VirtualOffsetArray) = virtual(x) diff --git a/src/tensors/combinators/permissive.jl b/src/tensors/combinators/permissive.jl index 68fee5547..609e8b3b2 100644 --- a/src/tensors/combinators/permissive.jl +++ b/src/tensors/combinators/permissive.jl @@ -61,7 +61,7 @@ unwrap(ctx, arr::VirtualPermissiveArray, var) = call(permissive, unwrap(ctx, arr lower(ctx::AbstractCompiler, tns::VirtualPermissiveArray, ::DefaultStyle) = :(PermissiveArray($(ctx(tns.body)), $(tns.dims))) virtual_size(ctx::AbstractCompiler, arr::VirtualPermissiveArray) = - ifelse.(arr.dims, (dimless,), virtual_size(ctx, arr.body)) + ifelse.(arr.dims, (auto,), virtual_size(ctx, arr.body)) virtual_resize!(ctx::AbstractCompiler, arr::VirtualPermissiveArray, dims...) = virtual_resize!(ctx, arr.body, ifelse.(arr.dims, virtual_size(ctx, arr.body), dim)) @@ -137,8 +137,8 @@ getroot(tns::VirtualPermissiveArray) = getroot(tns.body) function unfurl(ctx, tns::VirtualPermissiveArray, ext, mode, proto) tns_2 = unfurl(ctx, tns.body, ext, mode, proto) dims = virtual_size(ctx, tns.body) - garb = (mode === reader) ? FillLeaf(literal(missing)) : FillLeaf(Null()) - if tns.dims[end] && dims[end] != dimless + garb = (mode.kind === reader) ? FillLeaf(literal(missing)) : FillLeaf(Null()) + if tns.dims[end] && dims[end] != auto VirtualPermissiveArray( Unfurled( tns, diff --git a/src/tensors/combinators/product.jl b/src/tensors/combinators/product.jl index 6429824f4..f802cf9b1 100644 --- a/src/tensors/combinators/product.jl +++ b/src/tensors/combinators/product.jl @@ -69,15 +69,15 @@ unwrap(ctx, arr::VirtualProductArray, var) = call(products, unwrap(ctx, arr.body lower(ctx::AbstractCompiler, tns::VirtualProductArray, ::DefaultStyle) = :(ProductArray($(ctx(tns.body)), $(tns.dim))) -#virtual_size(ctx::AbstractCompiler, arr::FillLeaf) = (dimless,) # this is needed for multidimensional convolution.. -#virtual_size(ctx::AbstractCompiler, arr::Simplify) = (dimless,) +#virtual_size(ctx::AbstractCompiler, arr::FillLeaf) = (auto,) # this is needed for multidimensional convolution.. +#virtual_size(ctx::AbstractCompiler, arr::Simplify) = (auto,) function virtual_size(ctx::AbstractCompiler, arr::VirtualProductArray) dims = virtual_size(ctx, arr.body) - return (dims[1:arr.dim - 1]..., dimless, dimless, dims[arr.dim + 1:end]...) + return (dims[1:arr.dim - 1]..., auto, auto, dims[arr.dim + 1:end]...) end function virtual_resize!(ctx::AbstractCompiler, arr::VirtualProductArray, dims...) - virtual_resize!(ctx, arr.body, dims[1:arr.dim - 1]..., dimless, dims[arr.dim + 2:end]...) + virtual_resize!(ctx, arr.body, dims[1:arr.dim - 1]..., auto, dims[arr.dim + 2:end]...) end instantiate(arr::VirtualProductArray, ctx, mode) = diff --git a/src/tensors/combinators/toeplitz.jl b/src/tensors/combinators/toeplitz.jl index 520f77b33..594ae7736 100644 --- a/src/tensors/combinators/toeplitz.jl +++ b/src/tensors/combinators/toeplitz.jl @@ -74,10 +74,10 @@ lower(ctx::AbstractCompiler, tns::VirtualToeplitzArray, ::DefaultStyle) = :(Toep function virtual_size(ctx::AbstractCompiler, arr::VirtualToeplitzArray) dims = virtual_size(ctx, arr.body) - return (dims[1:arr.dim - 1]..., dimless, dimless, dims[arr.dim + 1:end]...) + return (dims[1:arr.dim - 1]..., auto, auto, dims[arr.dim + 1:end]...) end virtual_resize!(ctx::AbstractCompiler, arr::VirtualToeplitzArray, dims...) = - virtual_resize!(ctx, arr.body, dims[1:arr.dim - 1]..., dimless, dims[arr.dim + 2:end]...) + virtual_resize!(ctx, arr.body, dims[1:arr.dim - 1]..., auto, dims[arr.dim + 2:end]...) instantiate(ctx, arr::VirtualToeplitzArray, mode) = VirtualToeplitzArray(instantiate(ctx, arr.body, mode), arr.dim) diff --git a/src/tensors/levels/atomic_element_levels.jl b/src/tensors/levels/atomic_element_levels.jl index 6393cbebe..d1cc423d7 100644 --- a/src/tensors/levels/atomic_element_levels.jl +++ b/src/tensors/levels/atomic_element_levels.jl @@ -155,26 +155,26 @@ function virtual_moveto_level(ctx::AbstractCompiler, lvl::VirtualAtomicElementLe end) end -function instantiate(ctx, fbr::VirtualSubFiber{VirtualAtomicElementLevel}, mode::Reader) +function instantiate(ctx, fbr::VirtualSubFiber{VirtualAtomicElementLevel}, mode) (lvl, pos) = (fbr.lvl, fbr.pos) - val = freshen(ctx.code, lvl.ex, :_val) - return Thunk( - preamble = quote - $val = $(lvl.val)[$(ctx(pos))] - end, - body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), val) - ) -end - -function instantiate(ctx, fbr::VirtualSubFiber{VirtualAtomicElementLevel}, mode::Updater) - fbr + if mode.kind === reader + val = freshen(ctx.code, lvl.ex, :_val) + return Thunk( + preamble = quote + $val = $(lvl.val)[$(ctx(pos))] + end, + body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), val) + ) + else + return fbr + end end -function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualAtomicElementLevel}, mode::Updater) +function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualAtomicElementLevel}, mode) fbr end -function lower_assign(ctx, fbr::VirtualSubFiber{VirtualAtomicElementLevel}, mode::Updater, op, rhs) +function lower_assign(ctx, fbr::VirtualSubFiber{VirtualAtomicElementLevel}, mode, op, rhs) (lvl, pos) = (fbr.lvl, fbr.pos) op = ctx(op) rhs = ctx(rhs) @@ -182,7 +182,7 @@ function lower_assign(ctx, fbr::VirtualSubFiber{VirtualAtomicElementLevel}, mode :(Finch.atomic_modify!($device, $(lvl.val), $(ctx(pos)), $op, $rhs)) end -function lower_assign(ctx, fbr::VirtualHollowSubFiber{VirtualAtomicElementLevel}, mode::Updater, op, rhs) +function lower_assign(ctx, fbr::VirtualHollowSubFiber{VirtualAtomicElementLevel}, mode, op, rhs) (lvl, pos) = (fbr.lvl, fbr.pos) push_preamble!(ctx, quote $(fbr.dirty) = true diff --git a/src/tensors/levels/dense_rle_levels.jl b/src/tensors/levels/dense_rle_levels.jl index 88b9d4b1f..ed1199e56 100644 --- a/src/tensors/levels/dense_rle_levels.jl +++ b/src/tensors/levels/dense_rle_levels.jl @@ -339,7 +339,7 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualRunListLevel, pos_stop check = VirtualScalar(:UNREACHABLE, Bool, false, :check, checkval) exts = virtual_level_size(ctx_2, lvl.buf) inds = [index(freshen(ctx_2, :i, n)) for n = 1:length(exts)] - prgm = assign(access(check, updater), and, call(isequal, access(left, reader, inds...), access(right, reader, inds...))) + prgm = assign(access(check, updater(and)), and, call(isequal, access(left, reader(), inds...), access(right, reader(), inds...))) for (ind, ext) in zip(inds, exts) prgm = loop(ind, ext, prgm) end @@ -360,7 +360,8 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualRunListLevel, pos_stop set_binding!(ctx_2, dst, virtual(VirtualSubFiber(lvl.lvl, value(q_2, Tp)))) exts = virtual_level_size(ctx_2, lvl.buf) inds = [index(freshen(ctx_2, :i, n)) for n = 1:length(exts)] - prgm = assign(access(dst, updater, inds...), initwrite(virtual_level_fill_value(lvl.lvl)), access(src, reader, inds...)) + op = initwrite(virtual_level_fill_value(lvl.lvl)) + prgm = assign(access(dst, updater(op), inds...), op, access(src, reader(), inds...)) for (ind, ext) in zip(inds, exts) prgm = loop(ind, ext, prgm) end @@ -412,7 +413,7 @@ function thaw_level!(ctx::AbstractCompiler, lvl::VirtualRunListLevel, pos_stop) =# end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualRunListLevel}, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualRunListLevel}, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -454,7 +455,7 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualRunListLevel}, ext, mode::Reade ) end -unfurl(ctx, fbr::VirtualSubFiber{VirtualRunListLevel}, ext, mode::Updater, proto) = +unfurl(ctx, fbr::VirtualSubFiber{VirtualRunListLevel}, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) = unfurl(ctx, VirtualHollowSubFiber(fbr.lvl, fbr.pos, freshen(ctx, :null)), ext, mode, proto) #Invariants of the level (Write Mode): @@ -463,7 +464,7 @@ unfurl(ctx, fbr::VirtualSubFiber{VirtualRunListLevel}, ext, mode::Updater, proto # 3. for all p in 1:prevpos-1, ptr[p] is the number of runs in that position # 4. qos_fill is the position of the last index written -function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualRunListLevel}, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)}) +function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualRunListLevel}, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) diff --git a/src/tensors/levels/element_levels.jl b/src/tensors/levels/element_levels.jl index 58f1f086b..9c8850af5 100644 --- a/src/tensors/levels/element_levels.jl +++ b/src/tensors/levels/element_levels.jl @@ -163,33 +163,33 @@ function virtual_moveto_level(ctx::AbstractCompiler, lvl::VirtualElementLevel, a end) end -function instantiate(ctx, fbr::VirtualSubFiber{VirtualElementLevel}, mode::Reader) +function instantiate(ctx, fbr::VirtualSubFiber{VirtualElementLevel}, mode) (lvl, pos) = (fbr.lvl, fbr.pos) - val = freshen(ctx, lvl.ex, :_val) - return Thunk( - preamble = quote - $val = $(lvl.val)[$(ctx(pos))] - end, - body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), val) - ) -end - -function instantiate(ctx, fbr::VirtualSubFiber{VirtualElementLevel}, mode::Updater) - (lvl, pos) = (fbr.lvl, fbr.pos) - VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))])) + if mode.kind === reader + val = freshen(ctx, lvl.ex, :_val) + return Thunk( + preamble = quote + $val = $(lvl.val)[$(ctx(pos))] + end, + body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), val) + ) + else + VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))])) + end end -function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualElementLevel}, mode::Updater) +function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualElementLevel}, mode) (lvl, pos) = (fbr.lvl, fbr.pos) + @assert mode.kind === updater VirtualSparseScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))]), fbr.dirty) end -function lower_assign(ctx, fbr::VirtualHollowSubFiber{VirtualElementLevel}, mode::Updater, op, rhs) +function lower_assign(ctx, fbr::VirtualHollowSubFiber{VirtualElementLevel}, mode, op, rhs) (lvl, pos) = (fbr.lvl, fbr.pos) lower_assign(ctx, VirtualSparseScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))]), fbr.dirty), mode, op, rhs) end -function lower_assign(ctx, fbr::VirtualSubFiber{VirtualElementLevel}, mode::Updater, op, rhs) +function lower_assign(ctx, fbr::VirtualSubFiber{VirtualElementLevel}, mode, op, rhs) (lvl, pos) = (fbr.lvl, fbr.pos) lower_assign(ctx, VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))])), mode, op, rhs) end \ No newline at end of file diff --git a/src/tensors/levels/mutex_levels.jl b/src/tensors/levels/mutex_levels.jl index b7737b77a..d91f79074 100644 --- a/src/tensors/levels/mutex_levels.jl +++ b/src/tensors/levels/mutex_levels.jl @@ -201,35 +201,39 @@ function virtual_moveto_level(ctx::AbstractCompiler, lvl::VirtualMutexLevel, arc virtual_moveto_level(ctx, lvl.lvl, arch) end -function instantiate(ctx, fbr::VirtualSubFiber{VirtualMutexLevel}, mode::Reader) +function instantiate(ctx, fbr::VirtualSubFiber{VirtualMutexLevel}, mode) (lvl, pos) = (fbr.lvl, fbr.pos) - instantiate(ctx, VirtualSubFiber(lvl.lvl, pos), mode) -end - -function unfurl(ctx, fbr::VirtualSubFiber{VirtualMutexLevel}, ext, mode::Reader, proto) - (lvl, pos) = (fbr.lvl, fbr.pos) - unfurl(ctx, VirtualSubFiber(lvl.lvl, pos), ext, mode, proto) + if mode.kind === reader + instantiate(ctx, VirtualSubFiber(lvl.lvl, pos), mode) + else + fbr + end end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualMutexLevel}, ext, mode::Updater, proto) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualMutexLevel}, ext, mode, proto) (lvl, pos) = (fbr.lvl, fbr.pos) - sym = freshen(ctx, lvl.ex, :after_atomic_lvl) - atomicData = freshen(ctx, lvl.ex, :atomicArraysAcc) - lockVal = freshen(ctx, lvl.ex, :lockVal) - dev = lower(ctx, virtual_get_device(ctx.code.task), DefaultStyle()) - push_preamble!(ctx, quote - $atomicData = Finch.get_lock($dev, $(lvl.locks), $(ctx(pos)), eltype($(lvl.AVal))) - $lockVal = Finch.aquire_lock!($dev, $atomicData) - end) - res = unfurl(ctx, VirtualSubFiber(lvl.lvl, pos), ext, mode, proto) - push_epilogue!(ctx, quote - Finch.release_lock!($dev, $atomicData) - end) - return res + if mode.kind === reader + return unfurl(ctx, VirtualSubFiber(lvl.lvl, pos), ext, mode, proto) + else + sym = freshen(ctx, lvl.ex, :after_atomic_lvl) + atomicData = freshen(ctx, lvl.ex, :atomicArraysAcc) + lockVal = freshen(ctx, lvl.ex, :lockVal) + dev = lower(ctx, virtual_get_device(ctx.code.task), DefaultStyle()) + push_preamble!(ctx, quote + $atomicData = Finch.get_lock($dev, $(lvl.locks), $(ctx(pos)), eltype($(lvl.AVal))) + $lockVal = Finch.aquire_lock!($dev, $atomicData) + end) + res = unfurl(ctx, VirtualSubFiber(lvl.lvl, pos), ext, mode, proto) + push_epilogue!(ctx, quote + Finch.release_lock!($dev, $atomicData) + end) + return res + end end -function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualMutexLevel}, ext, mode::Updater, proto) +function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualMutexLevel}, ext, mode, proto) (lvl, pos) = (fbr.lvl, fbr.pos) + @assert mode.kind === updater sym = freshen(ctx, lvl.ex, :after_atomic_lvl) atomicData = freshen(ctx, lvl.ex, :atomicArraysAcc) lockVal = freshen(ctx, lvl.ex, :lockVal) @@ -245,7 +249,7 @@ function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualMutexLevel}, ext, mode::U return res end -function lower_assign(ctx, fbr::VirtualSubFiber{VirtualMutexLevel}, mode::Updater, op, rhs) +function lower_assign(ctx, fbr::VirtualSubFiber{VirtualMutexLevel}, mode, op, rhs) (lvl, pos) = (fbr.lvl, fbr.pos) sym = freshen(ctx, lvl.ex, :after_atomic_lvl) atomicData = freshen(ctx, lvl.ex, :atomicArraysAcc) @@ -262,7 +266,7 @@ function lower_assign(ctx, fbr::VirtualSubFiber{VirtualMutexLevel}, mode::Update return res end -function lower_assign(ctx, fbr::VirtualHollowSubFiber{VirtualMutexLevel}, mode::Updater, op, rhs) +function lower_assign(ctx, fbr::VirtualHollowSubFiber{VirtualMutexLevel}, mode, op, rhs) (lvl, pos) = (fbr.lvl, fbr.pos) sym = freshen(ctx, lvl.ex, :after_atomic_lvl) atomicData = freshen(ctx, lvl.ex, :atomicArraysAcc) diff --git a/src/tensors/levels/pattern_levels.jl b/src/tensors/levels/pattern_levels.jl index b6fdab12a..fcdd24332 100644 --- a/src/tensors/levels/pattern_levels.jl +++ b/src/tensors/levels/pattern_levels.jl @@ -123,14 +123,17 @@ thaw_level!(ctx, lvl::VirtualPatternLevel, pos) = lvl assemble_level!(ctx, lvl::VirtualPatternLevel, pos_start, pos_stop) = quote end reassemble_level!(ctx, lvl::VirtualPatternLevel, pos_start, pos_stop) = quote end -instantiate(ctx, ::VirtualSubFiber{VirtualPatternLevel}, mode::Reader) = FillLeaf(true) - -function instantiate(ctx, fbr::VirtualSubFiber{VirtualPatternLevel}, mode::Updater) - val = freshen(ctx, :null) - push_preamble!(ctx, :($val = false)) - VirtualScalar(nothing, Bool, false, gensym(), val) +function instantiate(ctx, ::VirtualSubFiber{VirtualPatternLevel}, mode) + if mode.kind === reader + FillLeaf(true) + else + val = freshen(ctx, :null) + push_preamble!(ctx, :($val = false)) + VirtualScalar(nothing, Bool, false, gensym(), val) + end end -function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualPatternLevel}, mode::Updater) +function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualPatternLevel}, mode) + @assert mode.kind === updater VirtualScalar(nothing, Bool, false, gensym(), fbr.dirty) end \ No newline at end of file diff --git a/src/tensors/levels/separate_levels.jl b/src/tensors/levels/separate_levels.jl index 324283d0d..43dcc47ee 100644 --- a/src/tensors/levels/separate_levels.jl +++ b/src/tensors/levels/separate_levels.jl @@ -206,46 +206,44 @@ function thaw_level!(ctx::AbstractCompiler, lvl::VirtualSeparateLevel, pos) return lvl end -function instantiate(ctx, fbr::VirtualSubFiber{VirtualSeparateLevel}, mode::Reader) +function instantiate(ctx, fbr::VirtualSubFiber{VirtualSeparateLevel}, mode) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex - isnulltest = freshen(ctx, tag, :_nulltest) - Vf = level_fill_value(lvl.Lvl) sym = freshen(ctx, :pointer_to_lvl) - val = freshen(ctx, lvl.ex, :_val) - return Thunk( - body = (ctx) -> begin - lvl_2 = virtualize(ctx.code, :($(lvl.val)[$(ctx(pos))]), lvl.Lvl, sym) - instantiate(ctx, VirtualSubFiber(lvl_2, literal(1)), mode) - end, - ) + if mode.kind === reader + isnulltest = freshen(ctx, tag, :_nulltest) + Vf = level_fill_value(lvl.Lvl) + val = freshen(ctx, lvl.ex, :_val) + return Thunk( + body = (ctx) -> begin + lvl_2 = virtualize(ctx.code, :($(lvl.val)[$(ctx(pos))]), lvl.Lvl, sym) + instantiate(ctx, VirtualSubFiber(lvl_2, literal(1)), mode) + end, + ) + else + return Thunk( + body = (ctx) -> begin + lvl_2 = virtualize(ctx.code, :($(lvl.val)[$(ctx(pos))]), lvl.Lvl, sym) + lvl_2 = thaw_level!(ctx, lvl_2, literal(1)) + push_preamble!(ctx, assemble_level!(ctx, lvl_2, literal(1), literal(1))) + res = instantiate(ctx, VirtualSubFiber(lvl_2, literal(1)), mode) + push_epilogue!(ctx, + contain(ctx) do ctx_2 + lvl_2 = freeze_level!(ctx_2, lvl_2, literal(1)) + :($(lvl.val)[$(ctx_2(pos))] = $(ctx_2(lvl_2))) + end + ) + res + end + ) + end end -function instantiate(ctx, fbr::VirtualSubFiber{VirtualSeparateLevel}, mode::Updater) - (lvl, pos) = (fbr.lvl, fbr.pos) - tag = lvl.ex - sym = freshen(ctx, :pointer_to_lvl) - - return Thunk( - body = (ctx) -> begin - lvl_2 = virtualize(ctx.code, :($(lvl.val)[$(ctx(pos))]), lvl.Lvl, sym) - lvl_2 = thaw_level!(ctx, lvl_2, literal(1)) - push_preamble!(ctx, assemble_level!(ctx, lvl_2, literal(1), literal(1))) - res = instantiate(ctx, VirtualSubFiber(lvl_2, literal(1)), mode) - push_epilogue!(ctx, - contain(ctx) do ctx_2 - lvl_2 = freeze_level!(ctx_2, lvl_2, literal(1)) - :($(lvl.val)[$(ctx_2(pos))] = $(ctx_2(lvl_2))) - end - ) - res - end - ) -end -function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualSeparateLevel}, mode::Updater) +function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualSeparateLevel}, mode) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex sym = freshen(ctx, :pointer_to_lvl) + @assert mode.kind === updater return Thunk( body = (ctx) -> begin diff --git a/src/tensors/levels/sparse_band_levels.jl b/src/tensors/levels/sparse_band_levels.jl index 39663dbe9..55015aa29 100644 --- a/src/tensors/levels/sparse_band_levels.jl +++ b/src/tensors/levels/sparse_band_levels.jl @@ -255,7 +255,7 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, pos_s return lvl end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBandLevel}, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBandLevel}, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -300,9 +300,9 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBandLevel}, ext, mode::Re ) end -unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBandLevel}, ext, mode::Updater, proto) = +unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBandLevel}, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) = unfurl(ctx, VirtualHollowSubFiber(fbr.lvl, fbr.pos, freshen(ctx, :null)), ext, mode, proto) -function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseBandLevel}, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)}) +function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseBandLevel}, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) diff --git a/src/tensors/levels/sparse_bytemap_levels.jl b/src/tensors/levels/sparse_bytemap_levels.jl index 6060752a3..8e45f48a9 100644 --- a/src/tensors/levels/sparse_bytemap_levels.jl +++ b/src/tensors/levels/sparse_bytemap_levels.jl @@ -316,7 +316,7 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseByteMapLevel, po return lvl end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseByteMapLevel}, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseByteMapLevel}, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Ti = lvl.Ti @@ -370,7 +370,7 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseByteMapLevel}, ext, mode: ) end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseByteMapLevel}, ext, mode::Reader, ::typeof(gallop)) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseByteMapLevel}, ext, mode, ::typeof(gallop)) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Ti = lvl.Ti @@ -426,7 +426,7 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseByteMapLevel}, ext, mode: end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseByteMapLevel}, ext, mode::Reader, ::typeof(follow)) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseByteMapLevel}, ext, mode, ::typeof(follow)) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex my_q = freshen(ctx, tag, :_q) @@ -449,9 +449,9 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseByteMapLevel}, ext, mode: ) end -unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseByteMapLevel}, ext, mode::Updater, proto) = +unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseByteMapLevel}, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude), typeof(laminate)}) = unfurl(ctx, VirtualHollowSubFiber(fbr.lvl, fbr.pos, freshen(ctx, :null)), ext, mode, proto) -function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseByteMapLevel}, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude), typeof(laminate)}) +function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseByteMapLevel}, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude), typeof(laminate)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) diff --git a/src/tensors/levels/sparse_coo_levels.jl b/src/tensors/levels/sparse_coo_levels.jl index 71a8ec893..e65c7c3ab 100644 --- a/src/tensors/levels/sparse_coo_levels.jl +++ b/src/tensors/levels/sparse_coo_levels.jl @@ -295,7 +295,7 @@ struct SparseCOOWalkTraversal stop end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseCOOLevel}, ext, mode::Reader, proto) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseCOOLevel}, ext, mode, proto) (lvl, pos) = (fbr.lvl, fbr.pos) Tp = postype(lvl) start = value(:($(lvl.ptr)[$(ctx(pos))]), Tp) @@ -307,7 +307,7 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseCOOLevel}, ext, mode::Rea ) end -function unfurl(ctx, trv::SparseCOOWalkTraversal, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function unfurl(ctx, trv::SparseCOOWalkTraversal, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) (lvl, R, start, stop) = (trv.lvl, trv.R, trv.start, trv.stop) tag = lvl.ex TI = lvl.TI @@ -387,10 +387,10 @@ struct SparseCOOExtrudeTraversal prev_coord end -unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseCOOLevel}, ext, mode::Updater, proto) = - unfurl(ctx, VirtualHollowSubFiber(fbr.lvl, fbr.pos, freshen(ctx, :null)), ext, mode, proto) +unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseCOOLevel}, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) = + unfurl(ctx, VirtualHollowSubFiber(fbr.lvl, fbr.pos, freshen(ctx, :null)), ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) -function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseCOOLevel}, ext, mode::Updater, proto) +function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseCOOLevel}, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex TI = lvl.TI @@ -425,7 +425,7 @@ function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseCOOLevel}, ext, mod ) end -function unfurl(ctx, trv::SparseCOOExtrudeTraversal, ext, mode::Updater, proto)#::Union{typeof(defaultupdate), typeof(extrude)}) +function unfurl(ctx, trv::SparseCOOExtrudeTraversal, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) (lvl, qos, fbr_dirty, coords) = (trv.lvl, trv.qos, trv.fbr_dirty, trv.coords) TI = lvl.TI Tp = postype(lvl) diff --git a/src/tensors/levels/sparse_dict_levels.jl b/src/tensors/levels/sparse_dict_levels.jl index da24437af..e587a6918 100644 --- a/src/tensors/levels/sparse_dict_levels.jl +++ b/src/tensors/levels/sparse_dict_levels.jl @@ -320,7 +320,7 @@ function virtual_moveto_level(ctx::AbstractCompiler, lvl::VirtualSparseDictLevel virtual_moveto_level(ctx, lvl.lvl, arch) end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseDictLevel}, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseDictLevel}, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -372,7 +372,7 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseDictLevel}, ext, mode::Re ) end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseDictLevel}, ext, mode::Reader, ::typeof(follow)) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseDictLevel}, ext, mode, ::typeof(follow)) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -391,10 +391,10 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseDictLevel}, ext, mode::Re ) end -unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseDictLevel}, ext, mode::Updater, proto) = begin +unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseDictLevel}, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) = begin unfurl(ctx, VirtualHollowSubFiber(fbr.lvl, fbr.pos, freshen(ctx, :null)), ext, mode, proto) end -function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseDictLevel}, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)}) +function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseDictLevel}, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) diff --git a/src/tensors/levels/sparse_interval_levels.jl b/src/tensors/levels/sparse_interval_levels.jl index 156a7757b..c007fdc19 100644 --- a/src/tensors/levels/sparse_interval_levels.jl +++ b/src/tensors/levels/sparse_interval_levels.jl @@ -235,7 +235,7 @@ function thaw_level!(ctx::AbstractCompiler, lvl::VirtualSparseIntervalLevel, pos return lvl end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseIntervalLevel}, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseIntervalLevel}, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -266,10 +266,10 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseIntervalLevel}, ext, mode ) end -unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseIntervalLevel}, ext, mode::Updater, proto) = +unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseIntervalLevel}, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) = unfurl(ctx, VirtualHollowSubFiber(fbr.lvl, fbr.pos, freshen(ctx, :null)), ext, mode, proto) -function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseIntervalLevel}, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)}) +function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseIntervalLevel}, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) diff --git a/src/tensors/levels/sparse_list_levels.jl b/src/tensors/levels/sparse_list_levels.jl index edef249b1..7dc217a82 100644 --- a/src/tensors/levels/sparse_list_levels.jl +++ b/src/tensors/levels/sparse_list_levels.jl @@ -279,7 +279,7 @@ function virtual_moveto_level(ctx::AbstractCompiler, lvl::VirtualSparseListLevel virtual_moveto_level(ctx, lvl.lvl, arch) end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -327,7 +327,7 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, ext, mode::Re end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, ext, mode::Reader, ::typeof(follow)) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, ext, mode, ::typeof(follow)) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -355,7 +355,7 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, ext, mode::Re ) end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, ext, mode::Reader, ::typeof(gallop)) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, ext, mode, ::typeof(gallop)) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -405,10 +405,10 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, ext, mode::Re ) end -unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, ext, mode::Updater, proto) = begin +unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseListLevel}, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) = begin unfurl(ctx, VirtualHollowSubFiber(fbr.lvl, fbr.pos, freshen(ctx, :null)), ext, mode, proto) end -function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseListLevel}, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)}) +function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseListLevel}, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) diff --git a/src/tensors/levels/sparse_point_levels.jl b/src/tensors/levels/sparse_point_levels.jl index 2b638e30e..29c94773f 100644 --- a/src/tensors/levels/sparse_point_levels.jl +++ b/src/tensors/levels/sparse_point_levels.jl @@ -229,7 +229,7 @@ function virtual_moveto_level(ctx::AbstractCompiler, lvl::VirtualSparsePointLeve virtual_moveto_level(ctx, lvl.lvl, arch) end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparsePointLevel}, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparsePointLevel}, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -263,9 +263,9 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparsePointLevel}, ext, mode::R ) end -unfurl(ctx, fbr::VirtualSubFiber{VirtualSparsePointLevel}, ext, mode::Updater, proto) = +unfurl(ctx, fbr::VirtualSubFiber{VirtualSparsePointLevel}, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) = unfurl(ctx, VirtualHollowSubFiber(fbr.lvl, fbr.pos, freshen(ctx, :null)), ext, mode, proto) -function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparsePointLevel}, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)}) +function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparsePointLevel}, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex dirty = freshen(ctx, tag, :dirty) diff --git a/src/tensors/levels/sparse_rle_levels.jl b/src/tensors/levels/sparse_rle_levels.jl index 3089a982b..2c467aada 100644 --- a/src/tensors/levels/sparse_rle_levels.jl +++ b/src/tensors/levels/sparse_rle_levels.jl @@ -325,7 +325,7 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, po check = VirtualScalar(:UNREACHABLE, Bool, false, :check, checkval) exts = virtual_level_size(ctx_2, lvl.buf) inds = [index(freshen(ctx_2, :i, n)) for n = 1:length(exts)] - prgm = assign(access(check, updater), and, call(isequal, access(left, reader, inds...), access(right, reader, inds...))) + prgm = assign(access(check, updater(and)), and, call(isequal, access(left, reader(), inds...), access(right, reader(), inds...))) for (ind, ext) in zip(inds, exts) prgm = loop(ind, ext, prgm) end @@ -347,7 +347,8 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, po set_binding!(ctx_2, dst, virtual(VirtualSubFiber(lvl.lvl, value(q_2, Tp)))) exts = virtual_level_size(ctx_2, lvl.buf) inds = [index(freshen(ctx_2, :i, n)) for n = 1:length(exts)] - prgm = assign(access(dst, updater, inds...), initwrite(virtual_level_fill_value(lvl.lvl)), access(src, reader, inds...)) + op = initwrite(virtual_level_fill_value(lvl.lvl)) + prgm = assign(access(dst, updater(op), inds...), op, access(src, reader(), inds...)) for (ind, ext) in zip(inds, exts) prgm = loop(ind, ext, prgm) end @@ -400,7 +401,7 @@ function thaw_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, pos_ return lvl end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseRunListLevel}, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseRunListLevel}, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -462,10 +463,10 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseRunListLevel}, ext, mode: end -unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseRunListLevel}, ext, mode::Updater, proto) = +unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseRunListLevel}, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) = unfurl(ctx, VirtualHollowSubFiber(fbr.lvl, fbr.pos, freshen(ctx, :null)), ext, mode, proto) -function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseRunListLevel}, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)}) +function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseRunListLevel}, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) diff --git a/src/tensors/levels/sparse_vbl_levels.jl b/src/tensors/levels/sparse_vbl_levels.jl index 8c764efb0..fe3a4443a 100644 --- a/src/tensors/levels/sparse_vbl_levels.jl +++ b/src/tensors/levels/sparse_vbl_levels.jl @@ -282,7 +282,7 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseBlockListLevel, return lvl end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBlockListLevel}, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)}) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBlockListLevel}, ext, mode, ::Union{typeof(defaultread), typeof(walk)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -352,7 +352,7 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBlockListLevel}, ext, mod ) end -function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBlockListLevel}, ext, mode::Reader, ::typeof(gallop)) +function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBlockListLevel}, ext, mode, ::typeof(gallop)) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) @@ -419,9 +419,9 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBlockListLevel}, ext, mod ) end -unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBlockListLevel}, ext, mode::Updater, proto) = +unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBlockListLevel}, ext, mode, proto::Union{typeof(defaultupdate), typeof(extrude)}) = unfurl(ctx, VirtualHollowSubFiber(fbr.lvl, fbr.pos, freshen(ctx, :null)), ext, mode, proto) -function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseBlockListLevel}, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)}) +function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseBlockListLevel}, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)}) (lvl, pos) = (fbr.lvl, fbr.pos) tag = lvl.ex Tp = postype(lvl) diff --git a/src/tensors/masks.jl b/src/tensors/masks.jl index 0ed6cdafb..05050b8c2 100644 --- a/src/tensors/masks.jl +++ b/src/tensors/masks.jl @@ -17,7 +17,7 @@ struct VirtualDiagMask <: AbstractVirtualTensor end virtualize(ctx, ex, ::Type{DiagMask}) = VirtualDiagMask() FinchNotation.finch_leaf(x::VirtualDiagMask) = virtual(x) -Finch.virtual_size(ctx, ::VirtualDiagMask) = (dimless, dimless) +Finch.virtual_size(ctx, ::VirtualDiagMask) = (auto, auto) struct VirtualDiagMaskColumn j @@ -25,7 +25,7 @@ end FinchNotation.finch_leaf(x::VirtualDiagMaskColumn) = virtual(x) -function unfurl(ctx, arr::VirtualDiagMask, ext, mode::Reader, proto) +function unfurl(ctx, arr::VirtualDiagMask, ext, mode, proto::typeof(defaultread)) Unfurled( arr = arr, body = Lookup( @@ -34,7 +34,7 @@ function unfurl(ctx, arr::VirtualDiagMask, ext, mode::Reader, proto) ) end -function unfurl(ctx, arr::VirtualDiagMaskColumn, ext, mode::Reader, proto) +function unfurl(ctx, arr::VirtualDiagMaskColumn, ext, mode, proto::typeof(defaultread)) j = arr.j Sequence([ Phase( @@ -68,7 +68,7 @@ struct VirtualUpTriMask <: AbstractVirtualTensor end virtualize(ctx, ex, ::Type{UpTriMask}) = VirtualUpTriMask() FinchNotation.finch_leaf(x::VirtualUpTriMask) = virtual(x) -Finch.virtual_size(ctx, ::VirtualUpTriMask) = (dimless, dimless) +Finch.virtual_size(ctx, ::VirtualUpTriMask) = (auto, auto) struct VirtualUpTriMaskColumn j @@ -76,7 +76,7 @@ end FinchNotation.finch_leaf(x::VirtualUpTriMaskColumn) = virtual(x) -function unfurl(ctx, arr::VirtualUpTriMask, ext, mode::Reader, proto) +function unfurl(ctx, arr::VirtualUpTriMask, ext, mode, proto::typeof(defaultread)) Unfurled( arr = arr, body = Lookup( @@ -85,7 +85,7 @@ function unfurl(ctx, arr::VirtualUpTriMask, ext, mode::Reader, proto) ) end -function unfurl(ctx, arr::VirtualUpTriMaskColumn, ext, mode::Reader, proto) +function unfurl(ctx, arr::VirtualUpTriMaskColumn, ext, mode, proto::typeof(defaultread)) j = arr.j Sequence([ Phase( @@ -117,7 +117,7 @@ struct VirtualLoTriMask <: AbstractVirtualTensor end virtualize(ctx, ex, ::Type{LoTriMask}) = VirtualLoTriMask() FinchNotation.finch_leaf(x::VirtualLoTriMask) = virtual(x) -Finch.virtual_size(ctx, ::VirtualLoTriMask) = (dimless, dimless) +Finch.virtual_size(ctx, ::VirtualLoTriMask) = (auto, auto) struct VirtualLoTriMaskColumn j @@ -125,7 +125,7 @@ end FinchNotation.finch_leaf(x::VirtualLoTriMaskColumn) = virtual(x) -function unfurl(ctx, arr::VirtualLoTriMask, ext, mode::Reader, proto) +function unfurl(ctx, arr::VirtualLoTriMask, ext, mode, proto::typeof(defaultread)) Unfurled( arr = arr, body = Lookup( @@ -134,7 +134,7 @@ function unfurl(ctx, arr::VirtualLoTriMask, ext, mode::Reader, proto) ) end -function unfurl(ctx, arr::VirtualLoTriMaskColumn, ext, mode::Reader, proto) +function unfurl(ctx, arr::VirtualLoTriMaskColumn, ext, mode, proto::typeof(defaultread)) j = arr.j Sequence([ Phase( @@ -166,7 +166,7 @@ struct VirtualBandMask <: AbstractVirtualTensor end virtualize(ctx, ex, ::Type{BandMask}) = VirtualBandMask() FinchNotation.finch_leaf(x::VirtualBandMask) = virtual(x) -Finch.virtual_size(ctx, ::VirtualBandMask) = (dimless, dimless, dimless) +Finch.virtual_size(ctx, ::VirtualBandMask) = (auto, auto, auto) struct VirtualBandMaskSlice j_lo @@ -181,7 +181,7 @@ end FinchNotation.finch_leaf(x::VirtualBandMaskColumn) = virtual(x) -function unfurl(ctx, arr::VirtualBandMask, ext, mode, proto) +function unfurl(ctx, arr::VirtualBandMask, ext, mode, proto::typeof(defaultread)) Unfurled( arr = arr, body = Lookup( @@ -190,13 +190,13 @@ function unfurl(ctx, arr::VirtualBandMask, ext, mode, proto) ) end -function unfurl(ctx, arr::VirtualBandMaskSlice, ext, mode, proto) +function unfurl(ctx, arr::VirtualBandMaskSlice, ext, mode, proto::typeof(defaultread)) Lookup( body = (ctx, j_hi) -> VirtualBandMaskColumn(arr.j_lo, j_hi) ) end -function unfurl(ctx, arr::VirtualBandMaskColumn, ext, mode, proto) +function unfurl(ctx, arr::VirtualBandMaskColumn, ext, mode, proto::typeof(defaultread)) Sequence([ Phase( stop = (ctx, ext) -> value(:($(ctx(j)) - 1)), @@ -230,7 +230,7 @@ function virtualize(ctx, ex, ::Type{SplitMask}) end FinchNotation.finch_leaf(x::VirtualSplitMask) = virtual(x) -Finch.virtual_size(ctx, arr::VirtualSplitMask) = (dimless, Extent(literal(1), arr.P)) +Finch.virtual_size(ctx, arr::VirtualSplitMask) = (auto, Extent(literal(1), arr.P)) struct VirtualSplitMaskColumn P @@ -239,7 +239,7 @@ end FinchNotation.finch_leaf(x::VirtualSplitMaskColumn) = virtual(x) -function unfurl(ctx, arr::VirtualSplitMask, ext, mode, proto) +function unfurl(ctx, arr::VirtualSplitMask, ext, mode, proto::typeof(defaultread)) Unfurled( arr = arr, body = Lookup( @@ -248,7 +248,7 @@ function unfurl(ctx, arr::VirtualSplitMask, ext, mode, proto) ) end -function unfurl(ctx, arr::VirtualSplitMaskColumn, ext_2, mode, proto) +function unfurl(ctx, arr::VirtualSplitMaskColumn, ext_2, mode, proto::typeof(defaultread)) j = arr.j P = arr.P Sequence([ @@ -315,7 +315,7 @@ end FinchNotation.finch_leaf(x::VirtualChunkMaskColumn) = virtual(x) FinchNotation.finch_leaf(x::VirtualChunkMaskCleanupColumn) = virtual(x) -function unfurl(ctx, arr::VirtualChunkMask, ext, mode, proto) +function unfurl(ctx, arr::VirtualChunkMask, ext, mode, proto::typeof(defaultread)) Unfurled( arr = arr, body = Sequence([ @@ -334,7 +334,7 @@ function unfurl(ctx, arr::VirtualChunkMask, ext, mode, proto) ) end -function unfurl(ctx, arr::VirtualChunkMaskColumn, ext, mode, proto) +function unfurl(ctx, arr::VirtualChunkMaskColumn, ext, mode, proto::typeof(defaultread)) j = arr.j Sequence([ Phase( @@ -349,7 +349,7 @@ function unfurl(ctx, arr::VirtualChunkMaskColumn, ext, mode, proto) ]) end -function unfurl(ctx, arr::VirtualChunkMaskCleanupColumn, ext, mode, proto) +function unfurl(ctx, arr::VirtualChunkMaskCleanupColumn, ext, mode, proto::typeof(defaultread)) Sequence([ Phase( stop = (ctx, ext) -> call(*, call(fld, measure(arr.arr.dim), arr.arr.b), arr.arr.b), diff --git a/src/tensors/scalars.jl b/src/tensors/scalars.jl index 30c6923ce..cc2836316 100644 --- a/src/tensors/scalars.jl +++ b/src/tensors/scalars.jl @@ -154,11 +154,15 @@ function freeze!(ctx, tns::VirtualSparseScalar) return tns end -function instantiate(ctx, tns::VirtualSparseScalar, mode::Reader) - Switch( - tns.dirty => tns, - true => Simplify(FillLeaf(tns.Vf)), - ) +function instantiate(ctx, tns::VirtualSparseScalar, mode) + if mode.kind === reader + Switch( + tns.dirty => tns, + true => Simplify(FillLeaf(tns.Vf)), + ) + else + tns + end end FinchNotation.finch_leaf(x::VirtualSparseScalar) = virtual(x) @@ -329,11 +333,15 @@ function freeze!(ctx, tns::VirtualSparseShortCircuitScalar) return tns end -function instantiate(ctx, tns::VirtualSparseShortCircuitScalar, mode::Reader) - Switch([ - value(tns.dirty, Bool) => tns, - true => Simplify(FillLeaf(tns.Vf)), - ]) +function instantiate(ctx, tns::VirtualSparseShortCircuitScalar, mode) + if mode.kind === reader + Switch([ + value(tns.dirty, Bool) => tns, + true => Simplify(FillLeaf(tns.Vf)), + ]) + else + tns + end end FinchNotation.finch_leaf(x::VirtualSparseShortCircuitScalar) = virtual(x) diff --git a/src/transforms/concurrent.jl b/src/transforms/concurrent.jl index 133fa8eff..5a7074fc3 100644 --- a/src/transforms/concurrent.jl +++ b/src/transforms/concurrent.jl @@ -42,7 +42,7 @@ function ensure_concurrent(root, ctx) #get local definitions locals = Set(filter(!isnothing, map(PostOrderDFS(body)) do node - if @capture(node, declare(~tns, ~init)) tns end + if @capture(node, declare(~tns, ~init, ~op)) tns end end)) #get nonlocal assignments and group by root diff --git a/src/transforms/dimensionalize.jl b/src/transforms/dimensionalize.jl index 4d6b5bdd4..e268597b5 100644 --- a/src/transforms/dimensionalize.jl +++ b/src/transforms/dimensionalize.jl @@ -35,7 +35,7 @@ struct FinchCompileError msg end function (ctx::DeclareDimensions)(node::FinchNode) if node.kind === access @assert @capture node access(~tns, ~mode, ~idxs...) - if node.mode.val !== reader && haskey(ctx.hints, getroot(tns)) + if node.mode.kind !== reader && haskey(ctx.hints, getroot(tns)) shape = map(suggest, virtual_size(ctx.ctx, tns)) push!(ctx.hints[getroot(tns)], node) else @@ -45,7 +45,7 @@ function (ctx::DeclareDimensions)(node::FinchNode) length(idxs) < length(shape) && throw(DimensionMismatch("less indices than dimensions in $(sprint(show, MIME("text/plain"), node))")) idxs = map(zip(shape, idxs)) do (dim, idx) if isindex(idx) - ctx.dims[idx] = resultdim(ctx.ctx, dim, get(ctx.dims, idx, dimless)) + ctx.dims[idx] = resultdim(ctx.ctx, dim, get(ctx.dims, idx, auto)) idx else ctx(idx) #Probably not strictly necessary to preserve the result of this, since this expr can't contain a statement and so won't be modified @@ -58,7 +58,7 @@ function (ctx::DeclareDimensions)(node::FinchNode) end ctx.dims[node.idx] = node.ext.val body = ctx(node.body) - ctx.dims[node.idx] != dimless || throw(FinchCompileError("could not resolve dimension of index $(node.idx)")) + ctx.dims[node.idx] != auto || throw(FinchCompileError("could not resolve dimension of index $(node.idx)")) return loop(node.idx, cache_dim!(ctx.ctx, getname(node.idx), resolvedim(ctx.dims[node.idx])), body) elseif node.kind === block block(map(ctx, node.bodies)...) @@ -70,12 +70,12 @@ function (ctx::DeclareDimensions)(node::FinchNode) shape = virtual_size(ctx.ctx, node.tns) shape = map(suggest, shape) for hint in ctx.hints[node.tns] - @assert @capture hint access(~tns, updater, ~idxs...) + @assert @capture hint access(~tns, updater(~f), ~idxs...) shape = map(zip(shape, idxs)) do (dim, idx) if isindex(idx) resultdim(ctx.ctx, dim, ctx.dims[idx]) else - resultdim(ctx.ctx, dim, dimless) #TODO I can't think of a case where this doesn't equal `dim` + resultdim(ctx.ctx, dim, auto) #TODO I can't think of a case where this doesn't equal `dim` end end end diff --git a/src/transforms/enforce_lifecycles.jl b/src/transforms/enforce_lifecycles.jl index 71b0cdb21..a2c7ce9a4 100644 --- a/src/transforms/enforce_lifecycles.jl +++ b/src/transforms/enforce_lifecycles.jl @@ -27,8 +27,8 @@ end function close_scope(prgm, ctx::EnforceLifecyclesVisitor) prgm = ctx(prgm) for tns in getmodified(prgm) - if ctx.modes[tns] !== reader - prgm = block(prgm, freeze(tns)) + if ctx.modes[tns].kind !== reader + prgm = block(prgm, freeze(tns, ctx.modes[tns].op)) end end prgm @@ -41,17 +41,18 @@ A transformation which adds `freeze` and `thaw` statements automatically to tensor roots, depending on whether they appear on the left or right hand side. """ function enforce_lifecycles(prgm) + prgm = infer_declare_ops(prgm, Dict()) close_scope(prgm, EnforceLifecyclesVisitor()) end #assumes arguments to prgm have been visited already and their uses collected function open_stmt(prgm, ctx::EnforceLifecyclesVisitor) for (tns, mode) in ctx.uses - cur_mode = get(ctx.modes, tns, reader) - if mode === reader && cur_mode === updater - prgm = block(freeze(tns), prgm) - elseif mode === updater && cur_mode === reader - prgm = block(thaw(tns), prgm) + cur_mode = get(ctx.modes, tns, reader()) + if mode.kind === reader && cur_mode.kind === updater + prgm = block(freeze(tns, cur_mode.op), prgm) + elseif mode.kind === updater && cur_mode.kind === reader + prgm = block(thaw(tns, mode.op), prgm) end ctx.modes[tns] = mode end @@ -68,41 +69,63 @@ function (ctx::EnforceLifecyclesVisitor)(node::FinchNode) open_stmt(define(node.lhs, ctx(node.rhs), open_scope(ctx, node.body)), ctx) elseif node.kind === declare ctx.scoped_uses[node.tns] = ctx.uses - if get(ctx.modes, node.tns, reader) === updater - node = block(freeze(node.tns), node) + mode = get(ctx.modes, node.tns, reader()) + if mode.kind === updater + node = block(freeze(node.tns, mode.op), node) end - ctx.modes[node.tns] = updater + ctx.modes[node.tns] = updater(node.op) node elseif node.kind === freeze haskey(ctx.modes, node.tns) || throw(EnforceLifecyclesError("cannot freeze undefined $(node.tns)")) - ctx.modes[node.tns] === reader && return block() - ctx.modes[node.tns] = reader + ctx.modes[node.tns].kind === reader && return block() + ctx.modes[node.tns] = reader() node elseif node.kind === thaw - get(ctx.modes, node.tns, reader) === updater && return block() - ctx.modes[node.tns] = updater - node + mode = get(ctx.modes, node.tns, reader()) + ctx.modes[node.tns] = updater(node.op) + mode == updater(node.op) && return block() + mode == reader() && return node + #mode.kind === updater + return block(freeze(node.tns, mode.op), node) elseif node.kind === assign return open_stmt(assign(ctx(node.lhs), ctx(node.op), ctx(node.rhs)), ctx) elseif node.kind === access idxs = map(ctx, node.idxs) uses = get(ctx.scoped_uses, getroot(node.tns), ctx.global_uses) - get(uses, getroot(node.tns), node.mode.val) !== node.mode.val && + mode = get(uses, getroot(node.tns), node.mode) + mode.kind != node.mode.kind && throw(EnforceLifecyclesError("cannot mix reads and writes to $(node.tns) outside of defining scope (hint: perhaps add a declaration like `var .= 0` or use an updating operator like `var += 1`)")) - uses[getroot(node.tns)] = node.mode.val + mode.kind === updater && mode.op != node.mode.op && + throw(EnforceLifecyclesError("cannot mix reduction operators to $(node.tns) outside of defining scope (hint: perhaps add a declaration like `var .= 0` or use an updating operator like `var += 1`)")) + uses[getroot(node.tns)] = node.mode access(node.tns, node.mode, idxs...) elseif node.kind === yieldbind args_2 = map(node.args) do arg uses = get(ctx.scoped_uses, getroot(arg), ctx.global_uses) - get(uses, getroot(arg), reader) !== reader && + get(uses, getroot(arg), reader()).kind !== reader && throw(EnforceLifecyclesError("cannot return $(arg) outside of defining scope")) - uses[getroot(arg)] = reader + uses[getroot(arg)] = reader() ctx(arg) end open_stmt(yieldbind(args_2...), ctx) elseif istree(node) - return similarterm(node, operation(node), map(ctx, arguments(node))) + return similarterm(node, operation(node), simple_map(FinchNode, ctx, arguments(node))) else return node end end + +function infer_declare_ops(node, ops=Dict()) + if node.kind === declare + declare(node.tns, node.init, get(ops, getroot(node.tns), overwrite)) + else + if node.kind === access && node.mode === updater + ops[getroot(node.tns)] = node.mode.op + end + if istree(node) + similarterm(node, operation(node), reverse(simple_map(FinchNode, n->infer_declare_ops(n, ops), reverse(arguments(node))))) + else + node + end + end +end \ No newline at end of file diff --git a/src/transforms/enforce_scopes.jl b/src/transforms/enforce_scopes.jl index af99c989f..f7605951c 100644 --- a/src/transforms/enforce_scopes.jl +++ b/src/transforms/enforce_scopes.jl @@ -30,15 +30,15 @@ function (ctx::EnforceScopesVisitor)(node::FinchNode) loop(ctx(idx), ctx(ext), open_scope(ctx, body)) elseif @capture node sieve(~cond, ~body) sieve(ctx(cond), open_scope(ctx, body)) - elseif @capture node declare(~tns, ~init) + elseif @capture node declare(~tns, ~init, ~op) push!(ctx.scope, tns) - declare(ctx(tns), init) - elseif @capture node freeze(~tns) + declare(ctx(tns), init, op) + elseif @capture node freeze(~tns, ~op) node.tns in ctx.scope || ctx.scope === ctx.global_scope || throw(ScopeError("cannot freeze $tns not defined in this scope")) - freeze(ctx(tns)) - elseif @capture node thaw(~tns) + freeze(ctx(tns), op) + elseif @capture node thaw(~tns, ~op) node.tns in ctx.scope || ctx.scope === ctx.global_scope || throw(ScopeError("cannot thaw $tns not defined in this scope")) - thaw(ctx(tns)) + thaw(ctx(tns), op) elseif node.kind === variable if !(node in ctx.scope) push!(ctx.global_scope, node) diff --git a/src/transforms/wrapperize.jl b/src/transforms/wrapperize.jl index 1524d1a1b..1f328d39e 100644 --- a/src/transforms/wrapperize.jl +++ b/src/transforms/wrapperize.jl @@ -82,62 +82,62 @@ function get_wrapper_rules(ctx, depth, alg) end), (@rule call(<, ~i, ~j::isindex) => begin if depth(i) < depth(j) - access(VirtualLoTriMask(), reader, j, call(+, i, 1)) + access(VirtualLoTriMask(), reader(), j, call(+, i, 1)) end end), (@rule call(<, ~i::isindex, ~j) => begin if depth(i) > depth(j) - access(VirtualUpTriMask(), reader, i, call(-, j, 1)) + access(VirtualUpTriMask(), reader(), i, call(-, j, 1)) end end), (@rule call(<=, ~i, ~j::isindex) => begin if depth(i) < depth(j) - access(VirtualLoTriMask(), reader, j, i) + access(VirtualLoTriMask(), reader(), j, i) end end), (@rule call(<=, ~i::isindex, ~j) => begin if depth(i) > depth(j) - access(VirtualUpTriMask(), reader, i, j) + access(VirtualUpTriMask(), reader(), i, j) end end), (@rule call(>, ~i, ~j::isindex) => begin if depth(i) < depth(j) - access(VirtualUpTriMask(), reader, j, call(-, i, 1)) + access(VirtualUpTriMask(), reader(), j, call(-, i, 1)) end end), (@rule call(>, ~i::isindex, ~j) => begin if depth(i) > depth(j) - access(VirtualLoTriMask(), reader, i, call(+, j, 1)) + access(VirtualLoTriMask(), reader(), i, call(+, j, 1)) end end), (@rule call(>=, ~i, ~j::isindex) => begin if depth(i) < depth(j) - access(VirtualUpTriMask(), reader, j, i) + access(VirtualUpTriMask(), reader(), j, i) end end), (@rule call(>=, ~i::isindex, ~j) => begin if depth(i) > depth(j) - access(VirtualLoTriMask(), reader, i, j) + access(VirtualLoTriMask(), reader(), i, j) end end), (@rule call(==, ~i, ~j::isindex) => begin if depth(i) < depth(j) - access(VirtualDiagMask(), reader, j, i) + access(VirtualDiagMask(), reader(), j, i) end end), (@rule call(==, ~i::isindex, ~j) => begin if depth(i) > depth(j) - access(VirtualDiagMask(), reader, i, j) + access(VirtualDiagMask(), reader(), i, j) end end), (@rule call(!=, ~i, ~j::isindex) => begin if depth(i) < depth(j) - call(!, access(VirtualDiagMask(), reader, j, i)) + call(!, access(VirtualDiagMask(), reader(), j, i)) end end), (@rule call(!=, ~i::isindex, ~j) => begin if depth(i) > depth(j) - call(!, access(VirtualDiagMask(), reader, i, j)) + call(!, access(VirtualDiagMask(), reader(), i, j)) end end), (@rule call(toeplitz, call(swizzle, ~A, ~sigma...), ~dim...) => begin @@ -172,8 +172,8 @@ function get_wrapper_rules(ctx, depth, alg) A_3 = call(offset, A_2, [0 for _ in i1]..., call(-, getstart(I), 1), [0 for _ in i2]...) access(A_3, m, i1..., k, i2...) end), - (@rule assign(access(~a, updater, ~i...), initwrite, ~rhs) => begin - assign(access(a, updater, i...), call(initwrite, call(fill_value, a)), rhs) + (@rule assign(access(~a, updater(initwrite), ~i...), initwrite, ~rhs) => begin + assign(access(a, updater(call(initwrite, call(fill_value, a))), i...), call(initwrite, call(fill_value, a)), rhs) #updater(auto) end), (@rule call(swizzle, call(swizzle, ~A, ~sigma_1...), ~sigma_2...) => call(swizzle, A, sigma_1[getval.(sigma_2)]...)), @@ -212,17 +212,17 @@ function wrapperize(ctx::AbstractCompiler, root) (@rule loop(~idx, ~ext, ~body) => begin counts = OrderedDict() for node in PostOrderDFS(body) - if @capture(node, access(~tn, reader, ~idxs...)) + if @capture(node, access(~tn, reader(), ~idxs...)) counts[node] = get(counts, node, 0) + 1 end end applied = false for (node, count) in counts if depth(idx) == depth(node) - if @capture(node, access(~tn, reader, ~idxs...)) && count > 1 + if @capture(node, access(~tn, reader(), ~idxs...)) && count > 1 var = variable(Symbol(freshen(ctx, tn.val), "_", join([idx.val for idx in idxs]))) body = Postwalk(@rule node => var)(body) - body = define(var, access(tn, reader, idxs...), body) + body = define(var, access(tn, reader(), idxs...), body) applied = true end end @@ -256,11 +256,11 @@ function unwrap_roots(ctx, root) @info "Hi" (A) end getroot(A) - elseif @capture(node, declare(~A, ~i)) + elseif @capture(node, declare(~A, ~i, ~op)) A - elseif @capture(node, freeze(~A)) + elseif @capture(node, freeze(~A, ~op)) A - elseif @capture(node, thaw(~A)) + elseif @capture(node, thaw(~A, ~op)) A end end)) @@ -274,9 +274,9 @@ function unwrap_roots(ctx, root) #@info "Unwrapping" tns val val_2 root = Rewrite(Postwalk(@rule tns => val_2))(root) root = Rewrite(Postwalk(Chain([ - (@rule declare(val_2, ~i) => declare(tns, i)), - (@rule freeze(val_2) => freeze(tns)), - (@rule thaw(val_2) => thaw(tns)), + (@rule declare(val_2, ~i, ~op) => declare(tns, i, op)), + (@rule freeze(val_2, ~op) => freeze(tns, op)), + (@rule thaw(val_2, ~op) => thaw(tns, op)), ])))(root) end end diff --git a/test/suites/issue_tests.jl b/test/suites/issue_tests.jl index 3e2fcab49..dff28baf5 100644 --- a/test/suites/issue_tests.jl +++ b/test/suites/issue_tests.jl @@ -470,6 +470,8 @@ for i=_ C[i, j] *= beta end + end + for j=_ for k=_ let foo = alpha * B[k, j] for i=_ diff --git a/test/suites/parallel_tests.jl b/test/suites/parallel_tests.jl index f5d804edb..f54e28af0 100644 --- a/test/suites/parallel_tests.jl +++ b/test/suites/parallel_tests.jl @@ -503,7 +503,7 @@ x = Tensor(Dense(Element(0.0))) y = Tensor(Dense(Element(0.0))) - @test_throws Finch.FinchConcurrencyError begin + @test_throws Finch.EnforceLifecyclesError begin @finch_code begin y .= 0 for j = parallel(_)