From 2165eda22bb17c663332d0260b67a74951b36672 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 1 Apr 2024 16:23:20 -0700 Subject: [PATCH 1/2] make garbage collection interval configurable --- olmo/config.py | 5 +++++ olmo/train.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/olmo/config.py b/olmo/config.py index e3dad46d7..96bc78b4d 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -987,6 +987,11 @@ class TrainConfig(BaseConfig): How often to log to the console. """ + gen1_gc_interval: int = 1 + """ + How often (in steps) to run generation 1 garbage collection. + """ + compile: Optional[CompilerConfig] = None """ Settings for compiling the model with ``torch.compile()``. diff --git a/olmo/train.py b/olmo/train.py index 4454786e3..baba207a9 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -1155,7 +1155,8 @@ def on_trace_ready(p): break # Run generation 1 garbage collection. - gc.collect(1) + if self.global_step % self.cfg.gen1_gc_interval == 0: + gc.collect(1) # Python Profiler stuff # We do this now, at the bottom of this loop, so we capture the work of getting the next batch. From 989f799235a4cba9922a120254afed7e51036933 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 1 Apr 2024 16:50:12 -0700 Subject: [PATCH 2/2] Allow disabling manual gc --- olmo/config.py | 3 ++- olmo/train.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index 96bc78b4d..6244994a3 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -987,9 +987,10 @@ class TrainConfig(BaseConfig): How often to log to the console. """ - gen1_gc_interval: int = 1 + gen1_gc_interval: Optional[int] = 1 """ How often (in steps) to run generation 1 garbage collection. + Set to ``None`` to use automatic garbage collection (i.e. we don't mess with it). """ compile: Optional[CompilerConfig] = None diff --git a/olmo/train.py b/olmo/train.py index baba207a9..71a45312e 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -948,7 +948,8 @@ def fit(self): self._gc_init_state = gc.isenabled() # cache if garbage collection is enabled, reset on close. # Disable automatic garbage collection, FSDP doesn't work well with it. - gc.disable() + if self.cfg.gen1_gc_interval is not None: + gc.disable() if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load: eval_metrics = self.eval() @@ -1155,7 +1156,7 @@ def on_trace_ready(p): break # Run generation 1 garbage collection. - if self.global_step % self.cfg.gen1_gc_interval == 0: + if self.cfg.gen1_gc_interval is not None and self.global_step % self.cfg.gen1_gc_interval == 0: gc.collect(1) # Python Profiler stuff