Skip to content

Commit

Permalink
Move cell list management to a new struct
Browse files Browse the repository at this point in the history
  • Loading branch information
efaulhaber committed May 12, 2024
1 parent 77d2951 commit 3b02cd1
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 53 deletions.
1 change: 1 addition & 0 deletions src/PointNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Polyester: @batch
include("util.jl")
include("neighborhood_search.jl")
include("trivial_nhs.jl")
include("cell_lists.jl")
include("grid_nhs.jl")

export for_particle_neighbor
Expand Down
54 changes: 54 additions & 0 deletions src/cell_lists.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
struct DictionaryCellList{NDIMS}
hashtable :: Dict{NTuple{NDIMS, Int}, Vector{Int}}
empty_vector :: Vector{Int} # Just an empty vector (used in `eachneighbor`)

function DictionaryCellList{NDIMS}() where {NDIMS}
hashtable = Dict{NTuple{NDIMS, Int}, Vector{Int}}()
empty_vector = Int[]

new{NDIMS}(hashtable, empty_vector)
end
end

function Base.empty!(cell_list::DictionaryCellList)
Base.empty!(cell_list.hashtable)

return cell_list
end

function push_cell!(cell_list::DictionaryCellList, cell, particle)
(; hashtable) = cell_list

if haskey(hashtable, cell)
append!(hashtable[cell], particle)
else
hashtable[cell] = [particle]
end

return cell_list
end

function deleteat_cell!(cell_list::DictionaryCellList, cell, i)
(; hashtable) = cell_list

# This works for `i::Integer`, `i::AbstractVector`, and even `i::Base.Generator`
if length(hashtable[cell]) <= count(_ -> true, i)
delete_cell!(cell_list, cell)
else
deleteat!(hashtable[cell], i)
end
end

function delete_cell!(cell_list, cell)
delete!(cell_list.hashtable, cell)
end

@inline eachcell(cell_list::DictionaryCellList) = keys(cell_list.hashtable)

@inline function Base.getindex(cell_list::DictionaryCellList, cell)
(; hashtable, empty_vector) = cell_list

# Return an empty vector when `cell_index` is not a key of `hashtable` and
# reuse the empty vector to avoid allocations.
return get(hashtable, cell, empty_vector)
end
84 changes: 31 additions & 53 deletions src/grid_nhs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,23 @@ since not sorting makes our implementation a lot faster (although less paralleli
In: Computer Graphics Forum 30.1 (2011), pages 99–112.
[doi: 10.1111/J.1467-8659.2010.01832.X](https://doi.org/10.1111/J.1467-8659.2010.01832.X)
"""
struct GridNeighborhoodSearch{NDIMS, ELTYPE, PB}
hashtable :: Dict{NTuple{NDIMS, Int}, Vector{Int}}
struct GridNeighborhoodSearch{NDIMS, ELTYPE, CL, PB}
cell_list :: CL
search_radius :: ELTYPE
empty_vector :: Vector{Int} # Just an empty vector (used in `eachneighbor`)
periodic_box :: PB
n_cells :: NTuple{NDIMS, Int} # Required to calculate periodic cell index
cell_size :: NTuple{NDIMS, ELTYPE} # Required to calculate cell index
cell_buffer :: Array{NTuple{NDIMS, Int}, 2} # Multithreaded buffer for `update!`
cell_buffer_indices :: Vector{Int} # Store which entries of `cell_buffer` are initialized
periodic_box :: PB
n_cells :: NTuple{NDIMS, Int}
cell_size :: NTuple{NDIMS, ELTYPE}
threaded_nhs_update :: Bool

function GridNeighborhoodSearch{NDIMS}(search_radius, n_particles;
periodic_box_min_corner = nothing,
periodic_box_max_corner = nothing,
threaded_nhs_update = true) where {NDIMS}
ELTYPE = typeof(search_radius)
cell_list = DictionaryCellList{NDIMS}()

hashtable = Dict{NTuple{NDIMS, Int}, Vector{Int}}()
empty_vector = Int[]
cell_buffer = Array{NTuple{NDIMS, Int}, 2}(undef, n_particles, Threads.nthreads())
cell_buffer_indices = zeros(Int, Threads.nthreads())

Expand Down Expand Up @@ -119,10 +117,10 @@ struct GridNeighborhoodSearch{NDIMS, ELTYPE, PB}
"must either be both `nothing` or both an array or tuple"))
end

new{NDIMS, ELTYPE,
typeof(periodic_box)}(hashtable, search_radius, empty_vector,
cell_buffer, cell_buffer_indices,
periodic_box, n_cells, cell_size, threaded_nhs_update)
new{NDIMS, ELTYPE, typeof(cell_list),
typeof(periodic_box)}(cell_list, search_radius, periodic_box, n_cells,
cell_size, cell_buffer, cell_buffer_indices,
threaded_nhs_update)
end
end

Expand All @@ -147,20 +145,16 @@ function initialize!(neighborhood_search::GridNeighborhoodSearch, coords_fun1, c
end

function initialize_grid!(neighborhood_search::GridNeighborhoodSearch, coords_fun)
(; hashtable) = neighborhood_search
(; cell_list) = neighborhood_search

empty!(hashtable)
empty!(cell_list)

for particle in 1:nparticles(neighborhood_search)
# Get cell index of the particle's cell
cell = cell_coords(coords_fun(particle), neighborhood_search)

# Add particle to corresponding cell or create cell if it does not exist
if haskey(hashtable, cell)
append!(hashtable[cell], particle)
else
hashtable[cell] = [particle]
end
# Add particle to corresponding cell
push_cell!(cell_list, cell, particle)
end

return neighborhood_search
Expand Down Expand Up @@ -193,22 +187,21 @@ end

# Modify the existing hash table by moving particles into their new cells
function update_grid!(neighborhood_search::GridNeighborhoodSearch, coords_fun)
(; hashtable, cell_buffer, cell_buffer_indices, threaded_nhs_update) = neighborhood_search
(; cell_list, cell_buffer, cell_buffer_indices, threaded_nhs_update) = neighborhood_search

# Reset `cell_buffer` by moving all pointers to the beginning.
# Reset `cell_buffer` by moving all pointers to the beginning
cell_buffer_indices .= 0

# Find all cells containing particles that now belong to another cell.
# `collect` the keyset to be able to loop over it with `@threaded`.
mark_changed_cell!(neighborhood_search, hashtable, coords_fun,
# Find all cells containing particles that now belong to another cell
mark_changed_cell!(neighborhood_search, cell_list, coords_fun,
Val(threaded_nhs_update))

# Iterate over all marked cells and move the particles into their new cells.
for thread in 1:Threads.nthreads()
# Only the entries `1:cell_buffer_indices[thread]` are initialized for `thread`.
for i in 1:cell_buffer_indices[thread]
cell = cell_buffer[i, thread]
particles = hashtable[cell]
particles = cell_list[cell]

# Find all particles whose coordinates do not match this cell
moved_particle_indices = (i for i in eachindex(particles)
Expand All @@ -221,35 +214,28 @@ function update_grid!(neighborhood_search::GridNeighborhoodSearch, coords_fun)
new_cell_coords = cell_coords(coords_fun(particle), neighborhood_search)

# Add particle to corresponding cell or create cell if it does not exist
if haskey(hashtable, new_cell_coords)
append!(hashtable[new_cell_coords], particle)
else
hashtable[new_cell_coords] = [particle]
end
push_cell!(cell_list, new_cell_coords, particle)
end

# Remove moved particles from this cell or delete the cell if it is now empty
if count(_ -> true, moved_particle_indices) == length(particles)
delete!(hashtable, cell)
else
deleteat!(particles, moved_particle_indices)
end
# Remove moved particles from this cell
deleteat_cell!(cell_list, cell, moved_particle_indices)
end
end

return neighborhood_search
end

@inline function mark_changed_cell!(neighborhood_search, hashtable, coords_fun,
@inline function mark_changed_cell!(neighborhood_search, cell_list, coords_fun,
threaded_nhs_update::Val{true})
@threaded for cell in collect(keys(hashtable))
# `collect` the keyset to be able to loop over it with `@threaded`
@threaded for cell in collect(eachcell(cell_list))
mark_changed_cell!(neighborhood_search, cell, coords_fun)
end
end

@inline function mark_changed_cell!(neighborhood_search, hashtable, coords_fun,
@inline function mark_changed_cell!(neighborhood_search, cell_list, coords_fun,
threaded_nhs_update::Val{false})
for cell in collect(keys(hashtable))
for cell in eachcell(cell_list)
mark_changed_cell!(neighborhood_search, cell, coords_fun)
end
end
Expand All @@ -259,9 +245,9 @@ end
# Otherwise, `@threaded` does not work here with Julia ARM on macOS.
# See https://github.com/JuliaSIMD/Polyester.jl/issues/88.
@inline function mark_changed_cell!(neighborhood_search, cell, coords_fun)
(; hashtable, cell_buffer, cell_buffer_indices) = neighborhood_search
(; cell_list, cell_buffer, cell_buffer_indices) = neighborhood_search

for particle in hashtable[cell]
for particle in cell_list[cell]
if cell_coords(coords_fun(particle), neighborhood_search) != cell
# Mark this cell and continue with the next one.
#
Expand Down Expand Up @@ -311,12 +297,9 @@ end
end

@inline function particles_in_cell(cell_index, neighborhood_search)
(; hashtable, empty_vector) = neighborhood_search
(; cell_list) = neighborhood_search

# Return an empty vector when `cell_index` is not a key of `hashtable` and
# reuse the empty vector to avoid allocations
return get(hashtable, periodic_cell_index(cell_index, neighborhood_search),
empty_vector)
return cell_list[periodic_cell_index(cell_index, neighborhood_search)]
end

@inline function periodic_cell_index(cell_index, neighborhood_search)
Expand Down Expand Up @@ -385,8 +368,3 @@ function copy_neighborhood_search(nhs::GridNeighborhoodSearch, search_radius, x,

return search
end

# Create a copy of a neighborhood search but with a different search radius
function copy_neighborhood_search(nhs::TrivialNeighborhoodSearch, search_radius, x, y)
return nhs
end
5 changes: 5 additions & 0 deletions src/trivial_nhs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,8 @@ end
end

@inline eachneighbor(coords, search::TrivialNeighborhoodSearch) = search.eachparticle

# Create a copy of a neighborhood search but with a different search radius
function copy_neighborhood_search(nhs::TrivialNeighborhoodSearch, search_radius, x, y)
return nhs
end

0 comments on commit 3b02cd1

Please sign in to comment.