Skip to content

Commit

Permalink
add periodic boundary
Browse files Browse the repository at this point in the history
  • Loading branch information
SteffenPL committed Dec 6, 2023
1 parent ea4c10a commit 4bf833f
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 10 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
StaticArrays = "1"
Adapt = "3"
StaticArrays = "1"
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[targets]
test = ["Test"]
test = ["Test","LinearAlgebra"]
22 changes: 22 additions & 0 deletions benchmarks/scripts/periodic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Revise
using SpatialHashTables
using StaticArrays
using LinearAlgebra
using Test

using SpatialHashTables: hashindex, gridindices

const SVec2 = SVector{2, Float64}

X = [SVec2(0.01,0.01), SVec2(0.99, 0.99)]
ht = BoundedHashTable(X, 0.1, [1.0, 1.0])
d = 0.0
for i in eachindex(X)
for (j, offset) in periodic_neighbours(ht, X[i], 0.1)
if i != j
#@show (i,j) offset
d += norm(X[j] - offset - X[i])
end
end
end
@test d 2 * sqrt(2*0.02^2)
2 changes: 1 addition & 1 deletion src/SpatialHashTables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ include("core.jl")
include("adapt.jl")

export SpatialHashTable, BoundedHashTable, AbstractSpatialHashTable
export updatetable!, resize!, neighbours, iterate_box, dimension, inside, hashposition
export updatetable!, resize!, neighbours, iterate_box, dimension, inside, hashposition, periodic_neighbours
end
6 changes: 4 additions & 2 deletions src/boundedhashtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ struct BoundedHashTable{Dim,VT<:AbstractVector,FT<:AbstractFloat,IT<:Integer} <:

domainstart::SVector{Dim,FT}
domainend::SVector{Dim,FT}
domainsize::SVector{Dim,FT}

inv_cellsize::SVector{Dim,FT}

Expand All @@ -31,10 +32,11 @@ function BoundedHashTable(N::Integer, grid::Tuple, domainstart::SVector, domaine
cellcount = Vector{typeof(N)}(undef, prod(grid) + 1)
particlemap = Vector{typeof(N)}(undef, N)

inv_cellsize = grid ./ (domainend - domainstart)
domainsize = domainend - domainstart
inv_cellsize = grid ./ domainsize
strides = (oneunit(eltype(grid)), cumprod(grid[1:end-1])...)

return BoundedHashTable(cellcount, particlemap, domainstart, domainend, inv_cellsize, strides, grid)
return BoundedHashTable(cellcount, particlemap, domainstart, domainend, domainsize, inv_cellsize, strides, grid)
end


Expand Down
29 changes: 25 additions & 4 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ function neighbouring_boxes(ht::AbstractSpatialHashTable, gridpos, r)
IT = inttype(ht)
Dim = IT(dimension(ht))
widths = @. ceil(IT, r * ht.inv_cellsize)
int_offsets = CartesianIndices(ntuple(i -> -widths[i]:widths[i], Dim))
offsets = (gridpos .+ Tuple(i) for i in int_offsets)
return (hashindex(ht, offset) for offset in offsets if insidegrid(ht, offset))
neighbour_indices = CartesianIndices(ntuple(i -> -widths[i]:widths[i], Dim))
neighbour_reps = (gridpos .+ Tuple(i) for i in neighbour_indices)
return (hashindex(ht, rep) for rep in neighbour_reps if insidegrid(ht, rep))
end

"""
Expand All @@ -76,9 +76,30 @@ This is the main method of this package and is used to find the neighbours of a
"""
function neighbours(ht::AbstractSpatialHashTable, pos, r)
gridpos = gridindices(ht, pos)
return (k for boxhash in neighbouring_boxes(ht, gridpos, r) for k in iterate_box(ht, boxhash))
return (k for boxhash in neighbouring_boxes(ht, gridpos, r)
for k in iterate_box(ht, boxhash))
end

@inline function wrap_index(ht::BoundedHashTable, gridpos)
rep = @. mod(gridpos - 1, ht.gridsize) + 1
offset = @. ceil(Int64, (rep - gridpos) / ht.gridsize) * ht.domainsize
return (; rep, offset)
end

function periodic_neighbouring_boxes(ht::BoundedHashTable, gridpos, r)
IT = inttype(ht)
Dim = IT(dimension(ht))
widths = @. ceil(IT, r * ht.inv_cellsize)
neighbour_indices = CartesianIndices(ntuple(i -> -widths[i]:widths[i], Dim))
neighbour_reps = (wrap_index(ht, gridpos .+ Tuple(i)) for i in neighbour_indices)
return ( (hashindex(ht, rep), offset) for (; rep, offset) in neighbour_reps)
end

function periodic_neighbours(ht::BoundedHashTable, pos, r)
gridpos = gridindices(ht, pos)
return ((k, offset) for (boxhash, offset) in periodic_neighbouring_boxes(ht, gridpos, r)
for k in iterate_box(ht, boxhash))
end

# The following code deals with hash index collisions which could result
# in the same index being returned multiple times.
Expand Down
16 changes: 15 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using SpatialHashTables
using StaticArrays
using Test
using Test, LinearAlgebra

using SpatialHashTables: hashindex, gridindices

Expand Down Expand Up @@ -69,4 +69,18 @@ end
ht = SpatialHashTable(X, 0.1, 2)

@test allunique(neighbours(ht, X[1], 0.1))
end

@testset "Periodic boundary" begin
X = [SVec2(0.01,0.01), SVec2(0.99, 0.99)]
ht = BoundedHashTable(X, 0.1, [1.0, 1.0])
d = 0.0
for i in eachindex(X)
for (j, offset) in periodic_neighbours(ht, X[i], 0.1)
if i != j
d += norm(X[j] - offset - X[i])
end
end
end
@test d 2 * sqrt(2*0.02^2)
end

0 comments on commit 4bf833f

Please sign in to comment.