Skip to content

Commit

Permalink
Return to built-in relayout.
Browse files Browse the repository at this point in the history
  • Loading branch information
orenbenkiki committed Mar 29, 2024
1 parent b4a6e9d commit 3d8401f
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 159 deletions.
2 changes: 1 addition & 1 deletion docs/v0.1.0/.documenter-siteinfo.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"documenter":{"julia_version":"1.10.2","generation_timestamp":"2024-03-29T14:38:37","documenter_version":"1.3.0"}}
{"documenter":{"julia_version":"1.10.2","generation_timestamp":"2024-03-29T16:47:26","documenter_version":"1.3.0"}}
158 changes: 1 addition & 157 deletions src/matrix_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,7 @@ function relayout!(destination::SparseMatrixCSC, source::AbstractMatrix)::Sparse
end
base_from = base_sparse_matrix(source)
transpose_base_from = transpose(base_from)
if transpose_base_from isa SparseMatrixCSC && nnz(transpose_base_from) == nnz(destination)
return sparse_transpose_in_memory!(destination, transpose_base_from)
else
return transpose!(destination, transpose_base_from) # untested
end
return transpose!(destination, transpose_base_from)
end

function relayout!(destination::DenseMatrix, source::AbstractMatrix)::DenseMatrix
Expand Down Expand Up @@ -367,156 +363,4 @@ function base_sparse_matrix(matrix::AbstractMatrix)::AbstractMatrix # untested
return error("unsupported relayout sparse matrix: $(typeof(matrix))")
end

mutable struct FromPosition
from_row::Int
from_column::Int
from_nnz::Int
end

function get_from_position_key(from_position::FromPosition)::Tuple{Int, Int}
return (from_position.from_row, from_position.from_column)
end

function sparse_transpose_in_memory!(destination::SparseMatrixCSC, source::SparseMatrixCSC)::SparseMatrixCSC
both_nnz = nnz(source)
from_nrows, from_ncols = size(source)
into_nrows, into_ncols = size(destination)
@assert nnz(destination) == nnz(source)
@assert into_nrows == from_ncols
@assert into_ncols == from_nrows
find_position_step = Int(ceil(from_nrows * (from_ncols / both_nnz)))
@assert find_position_step > 0

from_positions = Vector{FromPosition}(undef, from_ncols)
for from_column in 1:from_ncols
from_nnz = source.colptr[from_column]
if from_nnz == source.colptr[from_column + 1]
from_row = from_nrows + 1
from_nnz = both_nnz + 1
else
from_row = source.rowval[from_nnz]
end
from_positions[from_column] = FromPosition(from_row, from_column, from_nnz)
end

sort!(from_positions; by = get_from_position_key)

into_column = 0
for nnz_index in 1:both_nnz
from_position = from_positions[1]
into_column = insert_into_entry!(destination, into_column, nnz_index, source, from_position)
update_from_position!(source, from_position, from_nrows, both_nnz)

if from_ncols > 1
resort_from_positions!(from_position, from_positions, find_position_step)
end
end

@assert from_positions[1].from_row == from_nrows + 1

return destination
end

function insert_into_entry!(
destination::SparseMatrixCSC,
into_column::Int,
into_nnz::Int,
source::SparseMatrixCSC,
from_position::FromPosition,
)::Int
from_column = from_position.from_column
from_row = from_position.from_row
from_nnz = from_position.from_nnz

@assert into_column <= from_row
if into_column < from_row
destination.colptr[(into_column + 1):from_row] .= into_nnz # NOJET
into_column = from_row
end
destination.rowval[into_nnz] = from_column
destination.nzval[into_nnz] = source.nzval[from_nnz]

return into_column
end

function update_from_position!(
source::SparseMatrixCSC,
from_position::FromPosition,
from_nrows::Int,
both_nnz::Int,
)::Nothing
from_column = from_position.from_column
from_nnz = from_position.from_nnz

from_nnz += 1
if from_nnz == source.colptr[from_column + 1]
from_position.from_row = from_nrows + 1
from_position.from_nnz = both_nnz + 1
else
from_position.from_row = source.rowval[from_nnz]
from_position.from_nnz = from_nnz
end

return nothing
end

function resort_from_positions!(
from_position::FromPosition,
from_positions::Vector{FromPosition},
find_position_step::Int,
)::Nothing
from_position_key = get_from_position_key(from_position)
low_position_index = 2
low_position_key = get_from_position_key(from_positions[low_position_index])
if from_position_key < low_position_key
return nothing # untested
end

high_position_index =
find_high_position_index(from_positions, low_position_index, find_position_step, from_position_key)
if high_position_index == nothing
from_positions_length = length(from_positions)
copyto!(from_positions, 1, from_positions, 2, from_positions_length - 1)
from_positions[from_positions_length] = from_position
else
position_range = high_position_index - low_position_index
while position_range > 1
mid_position_index = low_position_index + div(position_range, 2)
mid_position_key = get_from_position_key(from_positions[mid_position_index])
if mid_position_key < from_position_key
low_position_index = mid_position_index
else
high_position_index = mid_position_index
end
position_range = high_position_index - low_position_index
end
copyto!(from_positions, 1, from_positions, 2, low_position_index - 1)
from_positions[low_position_index] = from_position
end

return nothing
end

function find_high_position_index(
from_positions::Vector{FromPosition},
low_position_index::Int,
find_position_step::Int,
from_position_key::Tuple{Int, Int},
)::Maybe{Int}
from_positions_length = length(from_positions)
high_position_index = low_position_index
while true
high_position_index = min(high_position_index + find_position_step, from_positions_length)
high_position_key = get_from_position_key(from_positions[high_position_index])

if high_position_key > from_position_key
return high_position_index
end

if high_position_index == from_positions_length
return nothing
end
end
end

end # module
2 changes: 1 addition & 1 deletion test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.2"
manifest_format = "2.0"
project_hash = "041c52acd12e4a2733576beea2998d165b562f88"
project_hash = "82b221f04d8c49dc49a939f33b7ce1de34c81a0c"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037"
ExceptionUnwrapping = "460bff9d-24e4-43bc-9d9f-a8973cb893f4"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Muon = "446846d7-b4ce-489d-bf74-72da18fe3629"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
NestedTests = "c7752788-e2db-4f03-8189-845608b41f97"
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Daf.Generic
using ExceptionUnwrapping
using HDF5
using LinearAlgebra
using Logging
using Muon
using NamedArrays
using NestedTests
Expand Down

0 comments on commit 3d8401f

Please sign in to comment.