diff --git a/train.py b/train.py index f398155..f7516f5 100644 --- a/train.py +++ b/train.py @@ -30,6 +30,7 @@ BATCH_SIZE = 10 EPOCH = 100 EPOCH_LOSSES = [] +SHUFFLED_INDECES = [] # ============================== END ============================== # if __name__ == '__main__': @@ -43,12 +44,17 @@ BATCH = int(len(keys) / BATCH_SIZE) def next_batch(): - global buff + global buff, BATCH_SIZE ,SHUFFLED_INDECES mini_batch = [] actual_data = [] - indicies = np.random.choice(len(keys), BATCH_SIZE) - for idx in indicies: + if 0 == len(SHUFFLED_INDECES): + SHUFFLED_INDECES = list(np.random.permutation(len(keys))) + + indices = SHUFFLED_INDECES[:min(BATCH_SIZE, len(SHUFFLED_INDECES))] + del SHUFFLED_INDECES[:min(BATCH_SIZE, len(SHUFFLED_INDECES))] + + for idx in indices: # make images mini batch img = load_image('voc2007/'+keys[idx]) @@ -103,12 +109,14 @@ def draw_marker(image_name, save): # saver.restore(sess, './checkpoints/params.ckpt') + SHUFFLED_INDECES = list(np.random.permutation(len(keys))) + print('\nSTART LEARNING') print('==================== '+str(datetime.datetime.now())+' ====================') for _ in range(5): next_batch() - + for ep in range(EPOCH): BATCH_LOSSES = [] for ba in trange(BATCH):