diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index b6ea09beb..b417832f3 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -171,6 +171,10 @@ local_checkpoint_period: 0 # Jax cache directory jax_cache_dir: "~/jax_cache" +# Jax persistent cache thresholds +jax_persistent_cache_min_entry_size_bytes: 0 +jax_persistent_cache_min_compile_time_secs: 1 + # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 7ff4442c1..fe10f0eb1 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -278,6 +278,12 @@ def __init__(self, argv: list[str], **kwargs): if raw_keys["jax_cache_dir"]: compilation_cache.set_cache_dir(os.path.expanduser(raw_keys["jax_cache_dir"])) + if raw_keys["jax_persistent_cache_min_entry_size_bytes"]: + jax.config.update("jax_persistent_cache_min_entry_size_bytes", os.path.expanduser(raw_keys["jax_persistent_cache_min_entry_size_bytes"])) + + if raw_keys["jax_persistent_cache_min_compile_time_secs"]: + jax.config.update("jax_persistent_cache_min_compile_time_secs", os.path.expanduser(raw_keys["jax_persistent_cache_min_compile_time_secs"])) + if raw_keys["model_name"] == "gpt3-175b": _HyperParameters.configure_gpt3_task(raw_keys) _HyperParameters.user_init(raw_keys)