From 6ae414577f48b02bd9a2540b2d548051133ec697 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BB=B5=E8=B5=9B=E9=A3=9E?= <1210063452@qq.com> Date: Mon, 16 Sep 2019 20:14:39 +0800 Subject: [PATCH] Update data_gen.py --- data_gen.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/data_gen.py b/data_gen.py index 4352979..761a844 100644 --- a/data_gen.py +++ b/data_gen.py @@ -155,12 +155,10 @@ def data_flow(train_data_dir, batch_size, num_classes, input_size): # need modi # 标签平滑 labels = smooth_labels(labels) - merge_paths_label = np.squeeze(np.dstack([img_paths, labels])) - train, validation = train_test_split(merge_paths_label, test_size=0.1, random_state=0, - stratify=merge_paths_label[:, -1]) - train_img_paths, train_labels = train[:, 0], train[:, -1] - validation_img_paths, validation_labels = validation[:, 0], validation[:, -1] - + train_img_paths, validation_img_paths, train_labels, validation_labels = \ + train_test_split(img_paths, labels, test_size=0.1, random_state=0) + print('total samples: %d, training samples: %d, validation samples: %d' % ( + len(img_paths), len(train_img_paths), len(validation_img_paths))) print('total samples: %d, training samples: %d, validation samples: %d' % (len(img_paths), len(train_img_paths), len(validation_img_paths)))