Skip to content

Commit

Permalink
fix shape returned from keplerlib.propagate
Browse files Browse the repository at this point in the history
  • Loading branch information
Tontyna committed Jul 23, 2024
1 parent b270bcb commit ca97e12
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion skyfield/keplerlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numpy import (
abs, amax, amin, arange, arccos, arctan, array, atleast_1d,
clip, copy, copyto, cos, cosh, exp, full_like, log, ndarray, newaxis,
pi, power, repeat, sign, sin, sinh, squeeze, sqrt, sum,
pi, power, repeat, reshape, sign, sin, sinh, squeeze, sqrt, sum,
tan, tanh, zeros_like,
)
from skyfield.constants import AU_KM, DAY_S, DEG2RAD
Expand Down Expand Up @@ -459,6 +459,11 @@ def propagate(position, velocity, t0, t1, gm):
gm : float
Gravitational parameter in units that match the other arguments
"""
if getattr(t1, 'shape', 0) and t1.shape[0] == 1:
expected_shape= (3, t1.shape[0])
else:
expected_shape= None

gm = atleast_1d(gm)
if (gm <= 0).any():
raise ValueError("'gm' should be positive")
Expand Down Expand Up @@ -618,4 +623,7 @@ def kepler_1d(x, orb_inds):
position_prop = pc[newaxis, :, :]*position[:, :, newaxis] + vc[newaxis, :, :]*velocity[:, :, newaxis]
velocity_prop = pcdot[newaxis, :, :]*position[:, :, newaxis] + vcdot[newaxis, :, :]*velocity[:, :, newaxis]

if expected_shape:
return reshape(squeeze(position_prop), expected_shape ), reshape(squeeze(velocity_prop), expected_shape )

return squeeze(position_prop), squeeze(velocity_prop)

0 comments on commit ca97e12

Please sign in to comment.