From 77391c5553ef171d54234287a88d8d8a45e790d5 Mon Sep 17 00:00:00 2001 From: Vivian Wu Date: Tue, 20 Aug 2024 23:17:32 +0000 Subject: [PATCH] edit jax cache variables --- MaxText/pyconfig.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index fe10f0eb1..d78961aba 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -279,10 +279,10 @@ def __init__(self, argv: list[str], **kwargs): compilation_cache.set_cache_dir(os.path.expanduser(raw_keys["jax_cache_dir"])) if raw_keys["jax_persistent_cache_min_entry_size_bytes"]: - jax.config.update("jax_persistent_cache_min_entry_size_bytes", os.path.expanduser(raw_keys["jax_persistent_cache_min_entry_size_bytes"])) + jax.config.update("jax_persistent_cache_min_entry_size_bytes", raw_keys["jax_persistent_cache_min_entry_size_bytes"]) if raw_keys["jax_persistent_cache_min_compile_time_secs"]: - jax.config.update("jax_persistent_cache_min_compile_time_secs", os.path.expanduser(raw_keys["jax_persistent_cache_min_compile_time_secs"])) + jax.config.update("jax_persistent_cache_min_compile_time_secs", raw_keys["jax_persistent_cache_min_compile_time_secs"]) if raw_keys["model_name"] == "gpt3-175b": _HyperParameters.configure_gpt3_task(raw_keys)