Skip to content

Commit

Permalink
add jax persistent cache thresholds
Browse files Browse the repository at this point in the history
  • Loading branch information
vivianrwu committed Aug 20, 2024
1 parent 15966fa commit a2e2893
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
4 changes: 4 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ local_checkpoint_period: 0
# Jax cache directory
jax_cache_dir: "~/jax_cache"

# Jax persistent cache thresholds
jax_persistent_cache_min_entry_size_bytes: 0
jax_persistent_cache_min_compile_time_secs: 1

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'

Expand Down
6 changes: 6 additions & 0 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,12 @@ def __init__(self, argv: list[str], **kwargs):
if raw_keys["jax_cache_dir"]:
compilation_cache.set_cache_dir(os.path.expanduser(raw_keys["jax_cache_dir"]))

if raw_keys["jax_persistent_cache_min_entry_size_bytes"]:
jax.config.update("jax_persistent_cache_min_entry_size_bytes", os.path.expanduser(raw_keys["jax_persistent_cache_min_entry_size_bytes"]))

if raw_keys["jax_persistent_cache_min_compile_time_secs"]:
jax.config.update("jax_persistent_cache_min_compile_time_secs", os.path.expanduser(raw_keys["jax_persistent_cache_min_compile_time_secs"]))

if raw_keys["model_name"] == "gpt3-175b":
_HyperParameters.configure_gpt3_task(raw_keys)
_HyperParameters.user_init(raw_keys)
Expand Down

0 comments on commit a2e2893

Please sign in to comment.