Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

start the tree stuff #434

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 107 additions & 15 deletions src/tensors/levels/sparselevels.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,110 @@
using DataStructures

struct TreeTable{Ti, Tp, Tbl}
tbl::Tbl
end

Base.:(==)(a::TreeTable, b::TreeTable) =
a.tbl == b.tbl

TreeTable{Ti, Tp}() where {Ti, Tp} =
TreeTable{Ti, Tp}(SortedDict{Tuple{Tp, Ti}, Tp}())
TreeTable{Ti, Tp}(tbl::Tbl) where {Ti, Tp, Tbl} =
TreeTable{Ti, Tp, Tbl}(tbl)

function table_coords(tbl::TreeTable{Ti, Tp}, pos) where {Ti, Tp}
searchequalrange(sc,inclusive(sc,st1,st2))
@view tbl.idx[tbl.ptr[pos]:tbl.ptr[pos + 1] - 1]
end

function declare_table!(tbl::TreeTable{Ti, Tp}, pos) where {Ti, Tp}
resize!(tbl.ptr, pos + Tp(1))
fill_range!(tbl.ptr, 0, pos + Tp(1), pos + Tp(1))
empty!(tbl.tbl)
return Tp(0)
end

function assemble_table!(tbl::TreeTable, pos_start, pos_stop)
resize_if_smaller!(tbl.ptr, pos_stop + 1)
fill_range!(tbl.ptr, 0, pos_start + 1, pos_stop + 1)
end

function freeze_table!(tbl::TreeTable, pos_stop)
srt = sort(collect(pairs(tbl.tbl)))
resize!(tbl.idx, length(srt))
resize!(tbl.val, length(srt))
for (q, ((p, i), v)) in enumerate(srt)
tbl.val[q] = v
tbl.idx[q] = i
end
resize!(tbl.ptr, pos_stop + 1)
tbl.ptr[1] = 1
for p = 2:pos_stop + 1
tbl.ptr[p] += tbl.ptr[p - 1]
end
tbl.ptr[pos_stop + 1] - 1
end

function thaw_table!(tbl::TreeTable, pos_stop)
qos_stop = tbl.ptr[pos_stop + 1] - 1
for p = pos_stop:-1:1
tbl.ptr[p + 1] -= tbl.ptr[p]
end
qos_stop
end

function table_length(tbl::TreeTable)
return length(tbl.ptr) - 1
end

function moveto(tbl::TreeTable, arch)
error(
"The table type $(typeof(tbl)) does not support moveto. ",
"Please use a table type that supports moveto."
)
end

table_isdefined(tbl::TreeTable{Ti, Tp}, p) where {Ti, Tp} = p + 1 <= length(tbl.ptr)

table_query(tbl::TreeTable{Ti, Tp}, p) where {Ti, Tp} = (p, tbl.ptr[p], tbl.ptr[p + 1])

subtable_init(tbl::TreeTable{Ti}, (p, start, stop)) where {Ti} = start < stop ? (tbl.idx[start], tbl.idx[stop - 1], start) : (Ti(1), Ti(0), start)

subtable_next(tbl::TreeTable, (p, start, stop), q) = q + 1

subtable_get(tbl::TreeTable, (p, start, stop), q) = (tbl.idx[q], tbl.val[q])

function subtable_seek(tbl, subtbl, state, i, j)
while i < j
state = subtable_next(tbl, subtbl, state)
(i, q) = subtable_get(tbl, subtbl, state)
end
return (i, state)
end

function subtable_seek(tbl::TreeTable, (p, start, stop), q, i, j)
q = Finch.scansearch(tbl.idx, j, q, stop)
return (tbl.idx[q], q)
end

function table_register(tbl::TreeTable, pos)
pos
end

function table_commit(tbl::TreeTable, pos)
end

function subtable_register(tbl::TreeTable, pos, idx)
return get(tbl.tbl, (pos, idx), length(tbl.tbl) + 1)
end

function subtable_commit(tbl::TreeTable, pos, qos, idx)
if qos > length(tbl.tbl)
tbl.tbl[(pos, idx)] = qos
tbl.ptr[pos + 1] += 1
end
end

struct DictTable{Ti, Tp, Ptr, Idx, Val, Tbl}
ptr::Ptr
idx::Idx
Expand All @@ -16,10 +123,6 @@ DictTable{Ti, Tp}() where {Ti, Tp} =
DictTable{Ti, Tp}(ptr::Ptr, idx::Idx, val::Val, tbl::Tbl) where {Ti, Tp, Ptr, Idx, Val, Tbl} =
DictTable{Ti, Tp, Ptr, Idx, Val, Tbl}(ptr, idx, val, tbl)

function table_coords(tbl::DictTable{Ti, Tp}, pos) where {Ti, Tp}
@view tbl.idx[tbl.ptr[pos]:tbl.ptr[pos + 1] - 1]
end

function declare_table!(tbl::DictTable{Ti, Tp}, pos) where {Ti, Tp}
resize!(tbl.ptr, pos + Tp(1))
fill_range!(tbl.ptr, 0, pos + Tp(1), pos + Tp(1))
Expand Down Expand Up @@ -230,17 +333,6 @@ end
@inline level_default(::Type{<:SparseLevel{Ti, Tbl, Lvl}}) where {Ti, Tbl, Lvl} = level_default(Lvl)
data_rep_level(::Type{<:SparseLevel{Ti, Tbl, Lvl}}) where {Ti, Tbl, Lvl} = SparseData(data_rep_level(Lvl))

(fbr::AbstractFiber{<:SparseLevel})() = fbr
function (fbr::SubFiber{<:SparseLevel{Ti}})(idxs...) where {Ti}
isempty(idxs) && return fbr
lvl = fbr.lvl
p = fbr.pos
crds = table_coords(lvl.tbl, p)
r = searchsorted(crds, idxs[end])
q = lvl.tbl.ptr[p] + first(r) - 1
length(r) == 0 ? default(fbr) : SubFiber(lvl.lvl, lvl.tbl.val[q])(idxs[1:end-1]...)
end

mutable struct VirtualSparseLevel <: AbstractVirtualLevel
lvl
ex
Expand Down
Loading