Skip to content

Commit

Permalink
Merge pull request #366 from adrn/reorder-nbody
Browse files Browse the repository at this point in the history
Fix inefficient nbody orbit reorder
  • Loading branch information
adrn authored Mar 20, 2024
2 parents a8be7d3 + 1a21825 commit 2067009
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Bug fixes
class to enable using the new (correct) parameter values, but the default will
continue to use the Gala modified values (for backwards compatibility).

- Improved internal efficiency of ``DirectNBody``.


API changes
-----------
Expand Down
10 changes: 2 additions & 8 deletions gala/dynamics/nbody/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,7 @@ def integrate_orbit(self, Integrator=None, Integrator_kwargs=dict(), **time_spec
frame=self.frame,
)

# Reorder orbits:
remap_idx = np.zeros((orbits.shape[-1], orbits.shape[-1]), dtype=int)
remap_idx[idx, np.arange(orbits.shape[-1])] = 1
_, undo_idx = np.where(remap_idx == 1)

return orbits[..., undo_idx]
remap_idx[idx, np.arange(orbits.shape[-1])] = 1
_, undo_idx = np.where(remap_idx == 1)
# Reorder orbits to original order:
undo_idx = np.argsort(idx)

return orbits[..., undo_idx]
44 changes: 35 additions & 9 deletions gala/dynamics/nbody/tests/test_nbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@
import numpy as np
import pytest

# Custom
from gala.potential import (
NullPotential,
NFWPotential,
HernquistPotential,
ConstantRotatingFrame,
StaticFrame,
)
from gala.dynamics import PhaseSpacePosition, combine
from gala.units import UnitSystem, galactic
from gala.integrate import (
DOPRI853Integrator,
LeapfrogIntegrator,
Ruth4Integrator,
)

# Custom
from gala.potential import (
ConstantRotatingFrame,
HernquistPotential,
NFWPotential,
NullPotential,
StaticFrame,
)
from gala.units import UnitSystem, galactic

# Project
from ..core import DirectNBody

Expand Down Expand Up @@ -220,3 +221,28 @@ def test_directnbody_integrate_rotframe(self, Integrator):

assert u.allclose(orbits_static.xyz, orbits_static.xyz)
assert u.allclose(orbits2.v_xyz, orbits2.v_xyz)

@pytest.mark.parametrize("Integrator", [DOPRI853Integrator])
def test_nbody_reorder(self, Integrator):
N = 16
rng = np.random.default_rng(seed=42)
w0 = PhaseSpacePosition(
pos=rng.normal(0, 5, size=(3, N)) * u.kpc,
vel=rng.normal(0, 50, size=(3, N)) * u.km / u.s,
)
pots = [
(
HernquistPotential(1e9 * u.Msun, 1.0 * u.pc, units=galactic)
if rng.uniform() > 0.5
else None
)
for _ in range(N)
]
sim = DirectNBody(
w0,
pots,
external_potential=HernquistPotential(1e12, 10, units=galactic),
units=galactic,
)
orbits = sim.integrate_orbit(dt=1.0 * u.Myr, t1=0, t2=100 * u.Myr)
assert np.allclose(orbits.pos[0].xyz, w0.pos.xyz)

0 comments on commit 2067009

Please sign in to comment.