Skip to content

Commit

Permalink
adjusted decay rate for InverseTimeDecay LR scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
kyle-woodward committed May 9, 2024
1 parent 1645cc4 commit 15b4786
Show file tree
Hide file tree
Showing 15 changed files with 77 additions and 202 deletions.
4 changes: 2 additions & 2 deletions fao_models/model_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def main():
model_name = config_data["model_name"]
total_examples = config_data["total_examples"]
data_dir = config_data["data_dir"]
# val_data_dir = config_data["val_data_dir"]
test_split = config_data["test_split"]
val_split = config_data["val_split"]
seed = config_data["seed"]
epochs = config_data["epochs"]
learning_rate = config_data["learning_rate"]
decay_rate = config_data["decay_rate"]
batch_size = config_data["batch_size"]
buffer_size = config_data["buffer_size"]
optimizer = config_data["optimizer"]
Expand All @@ -80,7 +80,7 @@ def main():
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
initial_learning_rate=learning_rate,
decay_steps=steps_per_epoch * epochs,
decay_rate=1,
decay_rate=decay_rate,
staircase=False,
)
logger.info(
Expand Down
73 changes: 55 additions & 18 deletions fao_models/plotting/learning_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,39 +59,76 @@ def __init__(self, lr, lr_decay, lr_decay_step, step=0, decay_fn='inverse_time_d
test_split = 0.2
batch_size = 64
steps_per_epoch = total_examples * test_split // batch_size
print('Steps per epoch:', steps_per_epoch)

#%%
# Inverse Time Decay
# for ITD decay rate needs to be more aggressive the more epochs we have
lr_decay = 1
# orginal decay rate constant - decay steps proportional to epochs
lr = 0.001
for epochs in [5,50,100]:
decay_steps = int(steps_per_epoch * epochs)
# lr_decay = base_decay_rate * (1/5*epochs)
lr_decay = 1.5
for epochs in [5,10,15,30,50,100]:
decay_steps = (steps_per_epoch * epochs)
lr_decay ** 1/5 if lr_decay > 1 else lr_decay
lr_vals = []
for i in range(1,decay_steps):
for i in range(1,int(decay_steps)):
lr_val = InverseTimeDecayScheduler(lr, lr_decay, decay_steps, step = i).lr_step
lr_vals.append(lr_val)
print(lr_vals[:10])
plt.plot(lr_vals, label=f'{epochs} epochs')
plt.plot(lr_vals, label=f'{epochs} epochs - decay_steps={decay_steps} - lr_decay={lr_decay}')
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.legend()
plt.title(f'Inverse Time Decay - lr={lr}, decay_rate={lr_decay}')
# %%
# Exponential Decay
lr_decay = 0.9
plt.title(f'Inverse Time Decay - lr={lr} - original')
#%%
# Inverse Time Decay
# for ITD decay rate needs to be more aggressive the more epochs we have
# ((steps_per_epoch * epochs) ** 1/5) - modify lr_Decay as well
lr = 0.001
for epochs in [5,50,100]:
decay_steps = int(steps_per_epoch * epochs)
lr_decay = 1.5
for epochs in [5,10,15,30,50,100]:
decay_steps = ((steps_per_epoch * epochs) ** 1/5)
lr_decay = lr_decay ** 1/5 if lr_decay > 1 else lr_decay
lr_vals = []
for i in range(1,decay_steps):
lr_val = ExponentialDecayScheduler(lr, lr_decay, decay_steps, step = i).lr_step
for i in range(1,int(decay_steps)):
lr_val = InverseTimeDecayScheduler(lr, lr_decay, decay_steps, step = i).lr_step
lr_vals.append(lr_val)
plt.plot(lr_vals, label=f'{epochs} epochs - decay_steps={decay_steps} - lr_decay={lr_decay}')
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.legend()
plt.title(f'Inverse Time Decay - lr={lr} - decay steps ** 1/5')
#%%
# Inverse Time Decay
# for ITD decay rate needs to be more aggressive the more epochs we have
# ((steps_per_epoch * epochs) ** 1/5) * 2
lr = 0.001
lr_decay = 1.5
for epochs in [5,10,15,30,50,100]:
decay_steps = ((steps_per_epoch * epochs) ** 1/5)*2
lr_decay = lr_decay ** 1/5 if lr_decay > 1 else lr_decay
lr_vals = []
for i in range(1,int(decay_steps)):
lr_val = InverseTimeDecayScheduler(lr, lr_decay, decay_steps, step = i).lr_step
lr_vals.append(lr_val)
print(lr_vals[:10])
plt.plot(lr_vals, label=f'{epochs} epochs')
plt.plot(lr_vals, label=f'{epochs} epochs - decay_steps={decay_steps} - lr_decay={lr_decay}')
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.legend()
plt.title(f'Exponential Decay - lr={lr}, decay_rate={lr_decay}')
plt.title(f'Inverse Time Decay - lr={lr} - (decay steps ** 1/5)*2')
# %%
# # Exponential Decay
# lr_decay = 0.9
# lr = 0.001
# for epochs in [5,50,100]:
# decay_steps = int(steps_per_epoch * epochs)
# lr_vals = []
# for i in range(1,int(decay_steps)):
# lr_val = ExponentialDecayScheduler(lr, lr_decay, decay_steps, step = i).lr_step
# lr_vals.append(lr_val)
# print(lr_vals[:10])
# plt.plot(lr_vals, label=f'{epochs} epochs')
# plt.xlabel('Step')
# plt.ylabel('Learning Rate')
# plt.legend()
# plt.title(f'Exponential Decay - lr={lr}, decay_rate={lr_decay}')
# # %%
16 changes: 0 additions & 16 deletions runc-mobilenetv3small-kdw_v1.yml

This file was deleted.

16 changes: 0 additions & 16 deletions runc-mobilenetv3small-kdw_v2.yml

This file was deleted.

16 changes: 0 additions & 16 deletions runc-mobilenetv3small-kdw_v3.yml

This file was deleted.

16 changes: 0 additions & 16 deletions runc-mobilenetv3small-kdw_v4.yml

This file was deleted.

16 changes: 0 additions & 16 deletions runc-resnet-diceloss-jjd.yml

This file was deleted.

20 changes: 20 additions & 0 deletions runc-resnet-epochs15-batch64-lr001-seed5-lrdecay5.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
experiment_name: resnet-epochs5-batch64-lr001-seed5-lrdecay5
model_name: resnet
data_dir: tfrecords/all
checkpoint: saved_models/resnet-epochs5-batch64-lr001-seed5-lrdecay5/best_model.h5
total_examples: 76992 # number of geotiffs not tfrecords
test_split: 0.2 # float or null
val_split: 0.1
seed: 5
decay_rate: 5

optimizer: adam
optimizer_use_lr_schedular: true
loss_function: binary_crossentropy

epochs: 15
learning_rate: 0.001
batch_size: 64
buffer_size: 76992

early_stopping_patience: 5 # null or int
16 changes: 0 additions & 16 deletions runc-resnet-jjd.yml

This file was deleted.

11 changes: 0 additions & 11 deletions runc1_kdw.yml

This file was deleted.

11 changes: 0 additions & 11 deletions runc2_kdw.yml

This file was deleted.

16 changes: 0 additions & 16 deletions runc3_kdw.yml

This file was deleted.

16 changes: 0 additions & 16 deletions runc4_kdw.yml

This file was deleted.

16 changes: 0 additions & 16 deletions runc5_kdw.yml

This file was deleted.

16 changes: 0 additions & 16 deletions runc6_kdw.yml

This file was deleted.

0 comments on commit 15b4786

Please sign in to comment.