diff --git a/extra_keras_datasets/iris.py b/extra_keras_datasets/iris.py index de98f65..ec6e13a 100644 --- a/extra_keras_datasets/iris.py +++ b/extra_keras_datasets/iris.py @@ -80,10 +80,10 @@ def load_data(path="iris.npz", test_split=0.2): testing_data = samples[:num_test_samples] # Split into inputs and targets - input_train = [i[0:4] for i in training_data] - input_test = [i[0:4] for i in testing_data] - target_train = [i[4] for i in training_data] - target_test = [i[4] for i in testing_data] + input_train = np.array([i[0:4] for i in training_data]) + input_test = np.array([i[0:4] for i in testing_data]) + target_train = np.array([i[4] for i in training_data]) + target_test = np.array([i[4] for i in testing_data]) # Warn about citation warn_citation()