Enable / Disable showing settings #51

Open

wants to merge 2 commits into main
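This change gates all of Ranger21's console output (the settings printout at construction, the lookahead cache messages, the warmup/warmdown notices, and the param-size report in step()) behind a new verbose flag that defaults to True, so existing behavior is preserved.

A minimal usage sketch of the new flag, assuming the constructor call style shown in the upstream README; the model and schedule values below are hypothetical placeholders:

import torch
import ranger21

model = torch.nn.Linear(10, 2)  # hypothetical placeholder model

# verbose=False suppresses the settings dump at construction time as well
# as the runtime warmup/warmdown and cache messages gated in this diff.
optimizer = ranger21.Ranger21(
    model.parameters(),
    lr=1e-3,
    num_epochs=10,               # hypothetical schedule values
    num_batches_per_epoch=100,
    verbose=False,
)

With the default verbose=True, construction still calls show_settings() exactly as before, so quiet operation is strictly opt-in.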
28 changes: 17 additions & 11 deletions ranger21/ranger21.py
@@ -143,6 +143,7 @@ def __init__(
         warmup_type="linear",
         warmup_pct_default=0.22,
         logging_active=True,
+        verbose=True
     ):

         # todo - checks on incoming params
@@ -153,6 +154,7 @@ def __init__(

         # core
         self.logging = logging_active
+        self.verbose = verbose

         # engine
         self.use_madgrad = use_madgrad
@@ -294,8 +296,8 @@ def __init__(
         engine = "AdamW" if not self.use_madgrad else "MadGrad"

         # print out initial settings to make usage easier
-
-        self.show_settings()
+        if self.verbose:
+            self.show_settings()

     def __setstate__(self, state):
         super().__setstate__(state)
@@ -360,8 +362,8 @@ def show_settings(self):
     # lookahead functions
     def clear_cache(self):
         """clears the lookahead cached params """
-
-        print(f"clearing lookahead cache...")
+        if self.verbose:
+            print(f"clearing lookahead cache...")
         for group in self.param_groups:
             for p in group["params"]:
                 param_state = self.state[p]
@@ -373,7 +375,8 @@ def clear_cache(self):

                 if len(la_params):
                     param_state["lookahead_params"] = torch.zeros_like(p.data)
-        print(f"lookahead cache cleared")
+        if self.verbose:
+            print(f"lookahead cache cleared")

     def clear_and_load_backup(self):
         for group in self.param_groups:
@@ -449,7 +452,8 @@ def warmup_dampening(self, lr, step):
                 )

             self.warmup_complete = True
-            print(f"\n** Ranger21 update = Warmup complete - lr set to {lr}\n")
+            if self.verbose:
+                print(f"\n** Ranger21 update = Warmup complete - lr set to {lr}\n")
             return lr

         if style == "linear":
@@ -472,9 +476,10 @@ def get_warm_down(self, lr, iteration):
         if iteration > self.start_warm_down - 1:
             # print when starting
             if not self.warmdown_displayed:
-                print(
-                    f"\n** Ranger21 update: Warmdown starting now. Current iteration = {iteration}....\n"
-                )
+                if self.verbose:
+                    print(
+                        f"\n** Ranger21 update: Warmdown starting now. Current iteration = {iteration}....\n"
+                    )
                 self.warmdown_displayed = True

             warmdown_iteration = (
@@ -697,9 +702,10 @@ def step(self, closure=None):
         # we will run this first epoch only and then memoize
         if not self.param_size:
             self.param_size = param_size
-            print(f"params size saved")
-            print(f"total param groups = {i+1}")
-            print(f"total params in groups = {j+1}")
+            if self.verbose:
+                print(f"params size saved")
+                print(f"total param groups = {i+1}")
+                print(f"total params in groups = {j+1}")

         if not self.param_size:
             raise ValueError("failed to set param size")