Skip to content

Commit

Permalink
fix: 🐛 fix some bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
ivaquero committed Sep 7, 2024
1 parent 45b2e88 commit 7813346
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 157 deletions.
12 changes: 6 additions & 6 deletions filters-bayes.ipynb

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions filters-ghk.ipynb

Large diffs are not rendered by default.

45 changes: 18 additions & 27 deletions filters-kf-basic.ipynb

Large diffs are not rendered by default.

50 changes: 25 additions & 25 deletions filters-kf-design.ipynb

Large diffs are not rendered by default.

42 changes: 21 additions & 21 deletions filters-kf-plus.ipynb

Large diffs are not rendered by default.

46 changes: 23 additions & 23 deletions filters-maneuver.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions filters-pf.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions filters-smoothers.ipynb

Large diffs are not rendered by default.

42 changes: 21 additions & 21 deletions filters-task-fusion.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions filters-task-tracking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
"source": [
"import math\n",
"\n",
"from numpy import random\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
"import numpy as np\n",
"from numpy import random"
]
},
{
Expand Down
4 changes: 3 additions & 1 deletion filters/kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def batch_filter(
Rs = [self.R] * n

# mean estimates from Kalman Filter
if self.x.ndim == 1:
if np.ndim(self.x) == 1:
means = np.zeros((n, self.dim_x))
means_p = np.zeros((n, self.dim_x))
else:
Expand All @@ -297,6 +297,8 @@ def batch_filter(
cov[i, :, :] = self.P

self.predict(u=u, G=G, F=F, Q=Q)
print()

means_p[i, :] = self.x
cov_p[i, :, :] = self.P

Expand Down
14 changes: 9 additions & 5 deletions filters/kalman_ukf.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(
self.sigmas_h = np.zeros((self._num_sigmas, self._dim_z))

self.K = np.zeros((dim_x, dim_z))
self.y = np.zeros((dim_z))
self.y = np.zeros((dim_z, dim_z))
self.z = np.array([[None] * dim_z]).T
self.S = np.zeros((dim_z, dim_z))
self.SI = np.zeros((dim_z, dim_z))
Expand Down Expand Up @@ -223,15 +223,18 @@ def batch_filter(self, zs, Rs=None, dts=None, UT=None, saver=None):
try:
z = zs[0]
except TypeError as e:
raise TypeError('zs must be list-like') from e
error_message = 'zs must be list-like'
raise TypeError(error_message) from e

if self._dim_z == 1:
if not np.isscalar(z) and (z.ndim != 1 or len(z) != 1):
raise TypeError('zs must be a list of scalars or 1D, 1 element arrays')
error_message = 'zs must be a list of scalars or 1D, 1 element arrays'
raise TypeError(error_message)
elif len(z) != self._dim_z:
raise TypeError(
error_message = (
f'each element in zs must be a 1D array of length {self._dim_z}'
)
raise TypeError(error_message)

z_n = len(zs)

Expand Down Expand Up @@ -268,7 +271,8 @@ def rts_smoother(self, Xs, Ps, Qs=None, dts=None, UT=None):
"""

if len(Xs) != len(Ps):
raise ValueError('Xs and Ps must have the same length')
error_message = 'Xs and Ps must have the same length'
raise ValueError(error_message)

n, dim_x = Xs.shape

Expand Down
5 changes: 3 additions & 2 deletions models/const_vel.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,10 @@ def FxCV(x, dt):
return FCV(len(x), dt) @ x


def KFCV1d(P, R, Q=0, dt=1, x=(0)):
if type(x) == list:
def KFCV1d(P, R, Q=0, dt=1, x=(0,)):
if type(x) == list | tuple:
x = np.array(x)

dim_x = len(x)
kf_cv = KalmanFilter(dim_x=dim_x, dim_z=1)
kf_cv.x = np.zeros(dim_x)
Expand Down
2 changes: 1 addition & 1 deletion plots/plot_sigmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def plot_sigmas_selection(ax, kappas=None, alphas=None, betas=None, var=None):
sigmas = points.sigma_points(x, P)
_plot_sigmas(ax, sigmas, points.Wc, alpha=1.0, facecolor='k')
plot_cov_ellipse(
ax, x, P, stds=np.sqrt(var), facecolor='b', alpha=0.3, title=False
ax, x, P, stds=np.sqrt(var), facecolor='b', alpha=0.3, show_title=False
)

ax.axis('equal')
Expand Down

0 comments on commit 7813346

Please sign in to comment.