Skip to content

Commit

Permalink
Return to built-in relayout.
Browse files Browse the repository at this point in the history
  • Loading branch information
orenbenkiki committed Mar 29, 2024
1 parent b4a6e9d commit 5cdf380
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 158 deletions.
2 changes: 1 addition & 1 deletion docs/v0.1.0/.documenter-siteinfo.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"documenter":{"julia_version":"1.10.2","generation_timestamp":"2024-03-29T14:38:37","documenter_version":"1.3.0"}}
{"documenter":{"julia_version":"1.10.2","generation_timestamp":"2024-03-29T16:47:26","documenter_version":"1.3.0"}}
158 changes: 1 addition & 157 deletions src/matrix_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,7 @@ function relayout!(destination::SparseMatrixCSC, source::AbstractMatrix)::Sparse
end
base_from = base_sparse_matrix(source)
transpose_base_from = transpose(base_from)
if transpose_base_from isa SparseMatrixCSC && nnz(transpose_base_from) == nnz(destination)
return sparse_transpose_in_memory!(destination, transpose_base_from)
else
return transpose!(destination, transpose_base_from) # untested
end
return transpose!(destination, transpose_base_from)
end

function relayout!(destination::DenseMatrix, source::AbstractMatrix)::DenseMatrix
Expand Down Expand Up @@ -367,156 +363,4 @@ function base_sparse_matrix(matrix::AbstractMatrix)::AbstractMatrix # untested
return error("unsupported relayout sparse matrix: $(typeof(matrix))")
end

mutable struct FromPosition
from_row::Int
from_column::Int
from_nnz::Int
end

function get_from_position_key(from_position::FromPosition)::Tuple{Int, Int}
return (from_position.from_row, from_position.from_column)
end

function sparse_transpose_in_memory!(destination::SparseMatrixCSC, source::SparseMatrixCSC)::SparseMatrixCSC
both_nnz = nnz(source)
from_nrows, from_ncols = size(source)
into_nrows, into_ncols = size(destination)
@assert nnz(destination) == nnz(source)
@assert into_nrows == from_ncols
@assert into_ncols == from_nrows
find_position_step = Int(ceil(from_nrows * (from_ncols / both_nnz)))
@assert find_position_step > 0

from_positions = Vector{FromPosition}(undef, from_ncols)
for from_column in 1:from_ncols
from_nnz = source.colptr[from_column]
if from_nnz == source.colptr[from_column + 1]
from_row = from_nrows + 1
from_nnz = both_nnz + 1
else
from_row = source.rowval[from_nnz]
end
from_positions[from_column] = FromPosition(from_row, from_column, from_nnz)
end

sort!(from_positions; by = get_from_position_key)

into_column = 0
for nnz_index in 1:both_nnz
from_position = from_positions[1]
into_column = insert_into_entry!(destination, into_column, nnz_index, source, from_position)
update_from_position!(source, from_position, from_nrows, both_nnz)

if from_ncols > 1
resort_from_positions!(from_position, from_positions, find_position_step)
end
end

@assert from_positions[1].from_row == from_nrows + 1

return destination
end

function insert_into_entry!(
destination::SparseMatrixCSC,
into_column::Int,
into_nnz::Int,
source::SparseMatrixCSC,
from_position::FromPosition,
)::Int
from_column = from_position.from_column
from_row = from_position.from_row
from_nnz = from_position.from_nnz

@assert into_column <= from_row
if into_column < from_row
destination.colptr[(into_column + 1):from_row] .= into_nnz # NOJET
into_column = from_row
end
destination.rowval[into_nnz] = from_column
destination.nzval[into_nnz] = source.nzval[from_nnz]

return into_column
end

function update_from_position!(
source::SparseMatrixCSC,
from_position::FromPosition,
from_nrows::Int,
both_nnz::Int,
)::Nothing
from_column = from_position.from_column
from_nnz = from_position.from_nnz

from_nnz += 1
if from_nnz == source.colptr[from_column + 1]
from_position.from_row = from_nrows + 1
from_position.from_nnz = both_nnz + 1
else
from_position.from_row = source.rowval[from_nnz]
from_position.from_nnz = from_nnz
end

return nothing
end

function resort_from_positions!(
from_position::FromPosition,
from_positions::Vector{FromPosition},
find_position_step::Int,
)::Nothing
from_position_key = get_from_position_key(from_position)
low_position_index = 2
low_position_key = get_from_position_key(from_positions[low_position_index])
if from_position_key < low_position_key
return nothing # untested
end

high_position_index =
find_high_position_index(from_positions, low_position_index, find_position_step, from_position_key)
if high_position_index == nothing
from_positions_length = length(from_positions)
copyto!(from_positions, 1, from_positions, 2, from_positions_length - 1)
from_positions[from_positions_length] = from_position
else
position_range = high_position_index - low_position_index
while position_range > 1
mid_position_index = low_position_index + div(position_range, 2)
mid_position_key = get_from_position_key(from_positions[mid_position_index])
if mid_position_key < from_position_key
low_position_index = mid_position_index
else
high_position_index = mid_position_index
end
position_range = high_position_index - low_position_index
end
copyto!(from_positions, 1, from_positions, 2, low_position_index - 1)
from_positions[low_position_index] = from_position
end

return nothing
end

function find_high_position_index(
from_positions::Vector{FromPosition},
low_position_index::Int,
find_position_step::Int,
from_position_key::Tuple{Int, Int},
)::Maybe{Int}
from_positions_length = length(from_positions)
high_position_index = low_position_index
while true
high_position_index = min(high_position_index + find_position_step, from_positions_length)
high_position_key = get_from_position_key(from_positions[high_position_index])

if high_position_key > from_position_key
return high_position_index
end

if high_position_index == from_positions_length
return nothing
end
end
end

end # module

0 comments on commit 5cdf380

Please sign in to comment.