diff --git a/ssm/emissions.py b/ssm/emissions.py index c5f8b6c5..5b8d2de3 100644 --- a/ssm/emissions.py +++ b/ssm/emissions.py @@ -758,8 +758,9 @@ def log_likelihoods(self, data, input, mask, tag, x): return np.sum(lls * mask[:, None, :], axis=2) def invert(self, data, input=None, mask=None, tag=None): - pad = np.zeros((1, 1, self.N)) if self.single_subspace else np.zeros((1, self.K, self.N)) - resid = data - np.concatenate((pad, self.As[None, :, :] * data[:-1, None, :])) + assert self.single_subspace, "Can only invert with a single emission model" + pad = np.zeros((1, self.N)) + resid = data - np.concatenate((pad, self.As * data[:-1])) return self._invert(resid, input=input, mask=mask, tag=tag) def sample(self, z, x, input=None, tag=None):