diff --git a/pina/plotter.py b/pina/plotter.py index cd3a0b7f..7eebf63c 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -197,7 +197,7 @@ def plot(self, truth_solution = getattr(solver.problem, 'truth_solution', None) if len(v) == 1: - self._1d_plot(pts, predicted_output, method, truth_solution, + self._1d_plot(pts.extract(v), predicted_output, method, truth_solution, **kwargs) elif len(v) == 2: self._2d_plot(pts, predicted_output, v, res, method, truth_solution, @@ -208,7 +208,6 @@ def plot(self, plt.savefig(filename) else: plt.show() - plt.close() def plot_loss(self, trainer,