diff --git a/src/cellflow/model/_cellflow.py b/src/cellflow/model/_cellflow.py index 1ae3e3b0..3b1cccf6 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/cellflow/model/_cellflow.py @@ -505,7 +505,8 @@ def prepare_model( else: raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(self.solver)}") - self._trainer = CellFlowTrainer(solver=self.solver, predict_kwargs=self.validation_data["predict_kwargs"]) # type: ignore[arg-type] + predict_kwargs = self.validation_data.get("predict_kwargs", {}) + self._trainer = CellFlowTrainer(solver=self.solver, predict_kwargs=predict_kwargs) # type: ignore[arg-type] def train( self, diff --git a/src/cellflow/solvers/_genot.py b/src/cellflow/solvers/_genot.py index 7270ad7f..f2cc91d1 100644 --- a/src/cellflow/solvers/_genot.py +++ b/src/cellflow/solvers/_genot.py @@ -284,6 +284,12 @@ def predict( pred_targets = batched_predict(src_inputs, batched_conditions) return {k: pred_targets[i] for i, k in enumerate(keys)} + elif isinstance(x, dict): + return jax.tree.map( + functools.partial(self._predict_jit, rng=rng, **kwargs), + x, + condition, # type: ignore[attr-defined] + ) else: x_pred = self._predict_jit(x, condition, rng, rng_genot, **kwargs) return np.array(x_pred)