Merge pull request #695 from finch-tensor/wma/mode_node
Wma/mode node
willow-ahrens authored Jan 30, 2025
2 parents c59267a + 3f8b3af commit f9da137
Showing 55 changed files with 509 additions and 413 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -67,11 +67,11 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TensorMarket = "8b7d4fe7-0b45-4d0d-9dd8-5cc9b23b4b77"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"

[targets]
test = ["ReTestItems", "Test", "ArgParse", "LinearAlgebra", "Random", "SparseArrays", "Graphs", "SimpleWeightedGraphs", "HDF5", "NPZ", "Pkg", "TensorMarket", "Documenter"]
26 changes: 12 additions & 14 deletions docs/src/docs/internals/virtualization.md
@@ -49,9 +49,9 @@ result to clean it up):
```jldoctest example1; filter=r"Finch\.FinchNotation\."
julia> (@macroexpand @finch (C .= 0; for i=_; C[i] = A[i] * B[i] end)) |> Finch.striplines |> Finch.regensym
quote
_res_1 = (Finch.execute)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.declare_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(0)), begin
_res_1 = (Finch.execute)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.declare_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(0), (Finch.FinchNotation.literal_instance)(Finch.auto)), begin
let i = index_instance(i)
(Finch.FinchNotation.loop_instance)(i, Finch.FinchNotation.Dimensionless(), (Finch.FinchNotation.assign_instance)((Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(Finch.FinchNotation.Updater()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.literal_instance)(Finch.FinchNotation.initwrite), (Finch.FinchNotation.call_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:*), (Finch.FinchNotation.finch_leaf_instance)(*)), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:A), (Finch.FinchNotation.finch_leaf_instance)(A)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:B), (Finch.FinchNotation.finch_leaf_instance)(B)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))))))
(Finch.FinchNotation.loop_instance)(i, Finch.FinchNotation.Auto(), (Finch.FinchNotation.assign_instance)((Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), (Finch.FinchNotation.updater_instance)((Finch.FinchNotation.literal_instance)(initwrite)), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.literal_instance)(initwrite), (Finch.FinchNotation.call_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:*), (Finch.FinchNotation.finch_leaf_instance)(*)), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:A), (Finch.FinchNotation.finch_leaf_instance)(A)), (Finch.FinchNotation.reader_instance)(), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:B), (Finch.FinchNotation.finch_leaf_instance)(B)), (Finch.FinchNotation.reader_instance)(), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))))))
end
end), (Finch.FinchNotation.yieldbind_instance)(variable_instance(:C))); )
begin
@@ -76,7 +76,7 @@ julia> using Finch: @finch_program_instance
julia> prgm = Finch.@finch_program_instance (C .= 0; for i=_; C[i] = A[i] * B[i] end; return C)
Finch program instance: begin
tag(C, Tensor(SparseList(Element(0)))) .= 0
for i = Dimensionless()
for i = Auto()
tag(C, Tensor(SparseList(Element(0))))[tag(i, i)] <<initwrite>>= tag(*, *)(tag(A, Tensor(SparseList(Element(0))))[tag(i, i)], tag(B, Tensor(Dense(Element(0))))[tag(i, i)])
end
return (tag(C, Tensor(SparseList(Element(0)))))
@@ -91,7 +91,7 @@ different inputs, but the same program type. We can run our program using

```jldoctest example1; filter=r"Finch\.FinchNotation\."
julia> typeof(prgm)
Finch.FinchNotation.BlockInstance{Tuple{Finch.FinchNotation.DeclareInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{0}}, Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Updater()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.LiteralInstance{initwrite}, Finch.FinchNotation.CallInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:*}, Finch.FinchNotation.LiteralInstance{*}}, Tuple{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:B}, Tensor{DenseLevel{Int64, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}}}, Finch.FinchNotation.YieldBindInstance{Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}}}}}
Finch.FinchNotation.BlockInstance{Tuple{Finch.FinchNotation.DeclareInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{0}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Auto()}}, Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Auto, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.UpdaterInstance{Finch.FinchNotation.LiteralInstance{initwrite}}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.LiteralInstance{initwrite}, Finch.FinchNotation.CallInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:*}, Finch.FinchNotation.LiteralInstance{*}}, Tuple{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.ReaderInstance, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:B}, Tensor{DenseLevel{Int64, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.ReaderInstance, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}}}, Finch.FinchNotation.YieldBindInstance{Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}}}}}
julia> C = Finch.execute(prgm).C
5 Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}:
@@ -164,15 +164,15 @@ julia> inst = Finch.@finch_program_instance begin
s[] += A[i]
end
end
Finch program instance: for i = Dimensionless()
Finch program instance: for i = Auto()
tag(s, Scalar{0, Int64})[] <<tag(+, +)>>= tag(A, Tensor(SparseList(Element(0))))[tag(i, i)]
end
julia> typeof(inst)
Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:s}, Scalar{0, Int64}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Updater()}, Tuple{}}, Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:+}, Finch.FinchNotation.LiteralInstance{+}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}
Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Auto, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:s}, Scalar{0, Int64}}, Finch.FinchNotation.UpdaterInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:+}, Finch.FinchNotation.LiteralInstance{+}}}, Tuple{}}, Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:+}, Finch.FinchNotation.LiteralInstance{+}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.ReaderInstance, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}
julia> Finch.virtualize(Finch.JuliaContext(), :inst, typeof(inst))
Finch program: for i = virtual(Finch.FinchNotation.Dimensionless)
Finch program: for i = virtual(Finch.FinchNotation.Auto)
tag(s, virtual(Finch.VirtualScalar))[] <<tag(+, +)>>= tag(A, virtual(Finch.VirtualFiber{Finch.VirtualSparseListLevel}))[tag(i, i)]
end
@@ -289,23 +289,22 @@ julia> prgm_inst = Finch.@finch_program_instance for i = _
end;
julia> println(prgm_inst)
loop_instance(index_instance(i), Finch.FinchNotation.Dimensionless(), assign_instance(access_instance(tag_instance(variable_instance(:s), Scalar{0, Int64}(0)), literal_instance(Finch.FinchNotation.Updater())), tag_instance(variable_instance(:+), literal_instance(+)), access_instance(tag_instance(variable_instance(:A), Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal_instance(Finch.FinchNotation.Reader()), tag_instance(variable_instance(:i), index_instance(i)))))
loop_instance(index_instance(i), Finch.FinchNotation.Auto(), assign_instance(access_instance(tag_instance(variable_instance(:s), Scalar{0, Int64}(0)), updater_instance(tag_instance(variable_instance(:+), literal_instance(+)))), tag_instance(variable_instance(:+), literal_instance(+)), access_instance(tag_instance(variable_instance(:A), Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), reader_instance(), tag_instance(variable_instance(:i), index_instance(i)))))
julia> prgm_inst
Finch program instance: for i = Dimensionless()
Finch program instance: for i = Auto()
tag(s, Scalar{0, Int64})[] <<tag(+, +)>>= tag(A, Tensor(SparseList(Element(0))))[tag(i, i)]
end
julia> prgm = Finch.@finch_program for i = _
s[] += A[i]
end;
julia> println(prgm)
loop(index(i), virtual(Finch.FinchNotation.Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Finch.FinchNotation.Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Finch.FinchNotation.Reader()), index(i))))
loop(index(i), virtual(Finch.FinchNotation.Auto()), assign(access(literal(Scalar{0, Int64}(0)), updater(literal(+))), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), reader(), index(i))))
julia> prgm
Finch program: for i = virtual(Finch.FinchNotation.Dimensionless)
Finch program: for i = virtual(Finch.FinchNotation.Auto)
Scalar{0, Int64}(0)[] <<+>>= Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))[i]
end
@@ -319,9 +318,8 @@ representations, so you can use the standard `operation`, `arguments`, `istree`,
```jldoctest example2; setup = :(using Finch, AbstractTrees, SyntaxInterface, RewriteTools)
julia> using Finch.FinchNotation;
julia> PostOrderDFS(prgm)
PostOrderDFS{FinchNode}(loop(index(i), virtual(Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Reader()), index(i)))))
PostOrderDFS{FinchNode}(loop(index(i), virtual(Auto()), assign(access(literal(Scalar{0, Int64}(0)), updater(literal(+))), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), reader(), index(i)))))
julia> (@capture prgm loop(~idx, ~ext, ~val))
true
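Editor's note on this hunk: the user-visible effect of the doc changes above is that the inferred loop extent now prints as `Auto()` instead of `Dimensionless()`. A minimal sketch of the workflow the updated docs walk through, with made-up tensor contents (the values below are illustrative assumptions, not taken from the docs):

```julia
using Finch

A = Tensor(SparseList(Element(0)), [0, 2, 0, 0, 3])  # illustrative values
B = Tensor(Dense(Element(0)), [1, 1, 1, 1, 1])
C = Tensor(SparseList(Element(0)))

# `_` leaves the loop extent for Finch to infer; after this change the inferred
# extent prints as `Auto()` rather than `Dimensionless()`.
prgm = Finch.@finch_program_instance (C .= 0; for i=_; C[i] = A[i] * B[i] end; return C)
C = Finch.execute(prgm).C
```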
2 changes: 1 addition & 1 deletion docs/src/docs/language/dimensionalization.md
@@ -53,5 +53,5 @@ The rules of declaration dimensionalization are as follows:
- The new dimensions of the declared tensor are used when the tensor appears in a right-hand-side (reading) access.

```@docs
Finch.FinchNotation.Dimensionless
Finch.FinchNotation.Auto
```
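A quick illustration of the rule quoted in this hunk (the declared tensor's new dimensions are reused when it is later read). This is a sketch with made-up tensors, not an excerpt from the docs page:

```julia
using Finch

x = Tensor(Dense(Element(0)), [1, 2, 3, 4, 5])  # illustrative values
y = Tensor(Dense(Element(0)))
s = Scalar(0)

@finch begin
    y .= 0              # declaration: y's new dimension is inferred from the writes below
    for i = _
        y[i] = x[i]
    end
    for i = _
        s[] += y[i]     # this read uses the dimension established by the declaration
    end
end
```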
2 changes: 2 additions & 0 deletions docs/src/docs/language/finch_language.md
@@ -80,6 +80,8 @@ Tensors must enter and exit scope in read mode. Finch inserts

Tensor lifecycle statements consist of:
```@docs
reader
updater
declare
freeze
thaw
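The two new `@docs` entries correspond to the `reader`/`updater` nodes that accesses now carry. To see them in a printed program, mirroring the example from the virtualization docs above (the tensor values are assumed, as there):

```julia
using Finch

s = Scalar(0)
A = Tensor(SparseList(Element(0)), [0, 2, 0, 0, 3])

prgm = Finch.@finch_program for i = _
    s[] += A[i]
end

# The printed program shows `reader()` on the read access of A and
# `updater(literal(+))` on the update access of s.
println(prgm)
```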
8 changes: 4 additions & 4 deletions ext/SparseArraysExt.jl
@@ -220,7 +220,7 @@ end
#Finch.is_atomic(ctx, tns::VirtualSparseMatrixCSCColumn) = is_atomic(ctx, tns.mtx)[1]
#Finch.is_concurrent(ctx, tns::VirtualSparseMatrixCSCColumn) = is_concurrent(ctx, tns.mtx)[1]

function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)})
function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode, ::Union{typeof(defaultread), typeof(walk)})
tag = arr.ex
Unfurled(
arr = arr,
@@ -230,7 +230,7 @@ function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, m
)
end

function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)})
function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)})
tag = arr.ex
Unfurled(
arr = arr,
@@ -383,7 +383,7 @@ function Finch.thaw!(ctx::AbstractCompiler, arr::VirtualSparseVector)
return arr
end

function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseVector, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)})
function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseVector, ext, mode, ::Union{typeof(defaultread), typeof(walk)})
tag = arr.ex
Ti = arr.Ti
my_i = freshen(ctx, tag, :_i)
@@ -439,7 +439,7 @@ function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseVector, ext, mode
)
end

function Finch.unfurl(ctx, arr::VirtualSparseVector, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)})
function Finch.unfurl(ctx, arr::VirtualSparseVector, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)})
tag = arr.ex
Tp = arr.Ti
qos = freshen(ctx, tag, :_qos)
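These signature changes drop the `mode::Reader` / `mode::Updater` constraints because, with this PR, the access mode reaches `unfurl` as a notation node rather than a `Reader()`/`Updater()` value. End-user code is unaffected; a hedged sketch of the extension in use (the random matrix is illustrative):

```julia
using Finch, SparseArrays

A = sprand(5, 5, 0.5)   # illustrative input
s = Scalar(0.0)

# The unfurl methods in this extension are what let @finch iterate a SparseMatrixCSC;
# reads lower to reader() accesses and updates to updater(op) accesses.
@finch for j = _, i = _
    s[] += A[i, j]
end
```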
2 changes: 1 addition & 1 deletion src/Finch.jl
@@ -56,7 +56,7 @@ export choose, minby, maxby, overwrite, initwrite, filterop, d
export fill_value, AsArray, expanddims, tensor_tree

export parallelAnalysis, ParallelAnalysisResults
export parallel, realextent, extent, dimless
export parallel, realextent, extent, auto
export CPU, CPULocalVector, CPULocalMemory

export Limit, Eps
4 changes: 2 additions & 2 deletions src/FinchNotation/FinchNotation.jl
@@ -14,7 +14,7 @@ module FinchNotation
export tag
export call
export cached
export reader, Reader, updater, Updater, access
export reader, updater, access
export define, declare, thaw, freeze
export block
export protocol
@@ -38,7 +38,7 @@ module FinchNotation

export getval, getname

export overwrite, initwrite, Dimensionless, dimless, extent, realextent
export overwrite, initwrite, Auto, auto, extent, realextent

export d

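For downstream users of `Finch.FinchNotation`, the export changes in this file amount to two renames plus a change of mode representation. A hedged sketch of the new spellings (the exact constructor arguments are assumptions based on the printed programs in the docs hunks above):

```julia
using Finch.FinchNotation

# Renames: Dimensionless -> Auto, dimless -> auto
ext = Auto()

# Modes: the Reader/Updater types are no longer exported; accesses now carry
# reader() / updater(op) nodes instead.
read_mode   = reader()
update_mode = updater(literal(+))
```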