From ae1c5d76e7b31c658adfe52c739c278df2ce4c49 Mon Sep 17 00:00:00 2001 From: Christian Versloot Date: Mon, 30 Nov 2020 21:28:58 +0100 Subject: [PATCH] Fix #8 --- extra_keras_datasets/iris.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/extra_keras_datasets/iris.py b/extra_keras_datasets/iris.py index de98f65..ec6e13a 100644 --- a/extra_keras_datasets/iris.py +++ b/extra_keras_datasets/iris.py @@ -80,10 +80,10 @@ def load_data(path="iris.npz", test_split=0.2): testing_data = samples[:num_test_samples] # Split into inputs and targets - input_train = [i[0:4] for i in training_data] - input_test = [i[0:4] for i in testing_data] - target_train = [i[4] for i in training_data] - target_test = [i[4] for i in testing_data] + input_train = np.array([i[0:4] for i in training_data]) + input_test = np.array([i[0:4] for i in testing_data]) + target_train = np.array([i[4] for i in training_data]) + target_test = np.array([i[4] for i in testing_data]) # Warn about citation warn_citation()