Skip to content

Commit

Permalink
fix: lost duplication of indices
Browse files Browse the repository at this point in the history
  • Loading branch information
arabian9ts committed Jan 25, 2018
1 parent 6814085 commit 1c66253
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
BATCH_SIZE = 10
EPOCH = 100
EPOCH_LOSSES = []
SHUFFLED_INDECES = []
# ============================== END ============================== #

if __name__ == '__main__':
Expand All @@ -43,12 +44,17 @@
BATCH = int(len(keys) / BATCH_SIZE)

def next_batch():
global buff
global buff, BATCH_SIZE ,SHUFFLED_INDECES
mini_batch = []
actual_data = []
indicies = np.random.choice(len(keys), BATCH_SIZE)

for idx in indicies:
if 0 == len(SHUFFLED_INDECES):
SHUFFLED_INDECES = list(np.random.permutation(len(keys)))

indices = SHUFFLED_INDECES[:min(BATCH_SIZE, len(SHUFFLED_INDECES))]
del SHUFFLED_INDECES[:min(BATCH_SIZE, len(SHUFFLED_INDECES))]

for idx in indices:
# make images mini batch

img = load_image('voc2007/'+keys[idx])
Expand Down Expand Up @@ -103,12 +109,14 @@ def draw_marker(image_name, save):

# saver.restore(sess, './checkpoints/params.ckpt')

SHUFFLED_INDECES = list(np.random.permutation(len(keys)))

print('\nSTART LEARNING')
print('==================== '+str(datetime.datetime.now())+' ====================')

for _ in range(5):
next_batch()

for ep in range(EPOCH):
BATCH_LOSSES = []
for ba in trange(BATCH):
Expand Down

0 comments on commit 1c66253

Please sign in to comment.