Skip to content

Commit

Permalink
edit jax cache variables
Browse files Browse the repository at this point in the history
  • Loading branch information
vivianrwu committed Aug 20, 2024
1 parent a2e2893 commit 77391c5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,10 @@ def __init__(self, argv: list[str], **kwargs):
compilation_cache.set_cache_dir(os.path.expanduser(raw_keys["jax_cache_dir"]))

if raw_keys["jax_persistent_cache_min_entry_size_bytes"]:
jax.config.update("jax_persistent_cache_min_entry_size_bytes", os.path.expanduser(raw_keys["jax_persistent_cache_min_entry_size_bytes"]))
jax.config.update("jax_persistent_cache_min_entry_size_bytes", raw_keys["jax_persistent_cache_min_entry_size_bytes"])

if raw_keys["jax_persistent_cache_min_compile_time_secs"]:
jax.config.update("jax_persistent_cache_min_compile_time_secs", os.path.expanduser(raw_keys["jax_persistent_cache_min_compile_time_secs"]))
jax.config.update("jax_persistent_cache_min_compile_time_secs", raw_keys["jax_persistent_cache_min_compile_time_secs"])

if raw_keys["model_name"] == "gpt3-175b":
_HyperParameters.configure_gpt3_task(raw_keys)
Expand Down

0 comments on commit 77391c5

Please sign in to comment.