Skip to content

Commit

Permalink
Update data_gen.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wusaifei authored Sep 16, 2019
1 parent f73b00c commit 6ae4145
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,10 @@ def data_flow(train_data_dir, batch_size, num_classes, input_size): # need modi
# 标签平滑
labels = smooth_labels(labels)

merge_paths_label = np.squeeze(np.dstack([img_paths, labels]))
train, validation = train_test_split(merge_paths_label, test_size=0.1, random_state=0,
stratify=merge_paths_label[:, -1])
train_img_paths, train_labels = train[:, 0], train[:, -1]
validation_img_paths, validation_labels = validation[:, 0], validation[:, -1]

train_img_paths, validation_img_paths, train_labels, validation_labels = \
train_test_split(img_paths, labels, test_size=0.1, random_state=0)
print('total samples: %d, training samples: %d, validation samples: %d' % (
len(img_paths), len(train_img_paths), len(validation_img_paths)))

print('total samples: %d, training samples: %d, validation samples: %d' % (len(img_paths), len(train_img_paths), len(validation_img_paths)))

Expand Down

0 comments on commit 6ae4145

Please sign in to comment.