diff --git a/discussion/adaptive_malt/adaptive_malt.py b/discussion/adaptive_malt/adaptive_malt.py
index 3952b04d09..7720e96470 100644
--- a/discussion/adaptive_malt/adaptive_malt.py
+++ b/discussion/adaptive_malt/adaptive_malt.py
@@ -1085,7 +1085,7 @@ def meads_step(meads_state: MeadsState,
   def refold(x, perm):
     return x.reshape((num_chains,) + x.shape[2:])[perm].reshape(x.shape)
 
-  phmc_state = jax.tree_map(functools.partial(refold, perm=perm), phmc_state)
+  phmc_state = jax.tree.map(functools.partial(refold, perm=perm), phmc_state)
 
   if vector_step_size is None:
     vector_step_size = phmc_state.state.std(1, keepdims=True)
@@ -1135,7 +1135,7 @@ def rejoin_folds(updated, original):
       ], 0), fold_to_skip, 0)
 
   active_fold_state, phmc_extra = fun_mc.prefab.persistent_hamiltonian_monte_carlo_step(
-      jax.tree_map(select_folds, phmc_state),
+      jax.tree.map(select_folds, phmc_state),
       target_log_prob_fn=target_log_prob_fn,
       step_size=select_folds(scalar_step_size[:, jnp.newaxis, jnp.newaxis] *
                              rolled_vector_step_size),
@@ -1143,10 +1143,10 @@
       noise_fraction=select_folds(noise_fraction)[:, jnp.newaxis, jnp.newaxis],
       mh_drift=select_folds(mh_drift)[:, jnp.newaxis],
       seed=phmc_seed)
-  phmc_state = jax.tree_map(rejoin_folds, active_fold_state, phmc_state)
+  phmc_state = jax.tree.map(rejoin_folds, active_fold_state, phmc_state)
 
   # Revert the ordering of the walkers.
-  phmc_state = jax.tree_map(functools.partial(refold, perm=unperm), phmc_state)
+  phmc_state = jax.tree.map(functools.partial(refold, perm=unperm), phmc_state)
 
   meads_state = MeadsState(
       phmc_state=phmc_state,
@@ -1838,7 +1838,7 @@ def run_grid_element(mean_trajectory_length: jnp.ndarray,
   for i in range(num_replicas):
     with utils.delete_device_buffers():
       res.append(
-          jax.tree_map(
+          jax.tree.map(
               np.array,
               _run_grid_element_impl(
                   seed=jax.random.fold_in(seed, i),
@@ -1853,7 +1853,7 @@ def run_grid_element(mean_trajectory_length: jnp.ndarray,
                   jitter_style=jitter_style,
                   target_accept_prob=target_accept_prob,
               )))
-  res = jax.tree_map(lambda *x: np.stack(x, 0), *res)
+  res = jax.tree.map(lambda *x: np.stack(x, 0), *res)
   res['mean_trajectory_length'] = mean_trajectory_length
   res['damping'] = damping
 
@@ -1988,7 +1988,7 @@ def run_trial(
   for i in range(num_replicas):
     with utils.delete_device_buffers():
       res.append(
-          jax.tree_map(
+          jax.tree.map(
               np.array,
               _run_trial_impl(
                   seed=jax.random.fold_in(seed, i),
@@ -2006,5 +2006,5 @@ def run_trial(
                   trajectory_length_adaptation_rate_decay=trajectory_length_adaptation_rate_decay,
                   save_warmup=save_warmup,
               )))
-  res = jax.tree_map(lambda *x: np.stack(x, 0), *res)
+  res = jax.tree.map(lambda *x: np.stack(x, 0), *res)
   return res
diff --git a/discussion/meads/meads.ipynb b/discussion/meads/meads.ipynb
index 8067fac5da..db5660883e 100644
--- a/discussion/meads/meads.ipynb
+++ b/discussion/meads/meads.ipynb
@@ -233,7 +233,7 @@
     "  unperm = jnp.eye(num_chains)[perm].argmax(0)\n",
     "  def refold(x, perm):\n",
     "    return x.reshape((num_chains,) + x.shape[2:])[perm].reshape(x.shape)\n",
-    "  phmc_state = jax.tree_map(functools.partial(refold, perm=perm), phmc_state)\n",
+    "  phmc_state = jax.tree.map(functools.partial(refold, perm=perm), phmc_state)\n",
     "\n",
     "  if diagonal_preconditioning:\n",
     "    scale_estimates = phmc_state.state.std(1, keepdims=True)\n",
@@ -274,7 +274,7 @@
     "      fold_to_skip, 0)\n",
     "\n",
     "  active_fold_state, phmc_extra = fun_mc.prefab.persistent_hamiltonian_monte_carlo_step(\n",
-    "      jax.tree_map(select_folds, phmc_state),\n",
+    "      jax.tree.map(select_folds, phmc_state),\n",
     "      target_log_prob_fn=target_log_prob_fn,\n",
     "      step_size=select_folds(step_size[:, jnp.newaxis, jnp.newaxis] *\n",
     "                             rolled_scale_estimates),\n",
@@ -285,7 +285,7 @@
     "  phmc_state = jax.tree_multimap(rejoin_folds, active_fold_state, phmc_state)\n",
     "\n",
     "  # Revert the ordering of the walkers.\n",
-    "  phmc_state = jax.tree_map(functools.partial(refold, perm=unperm), phmc_state)\n",
+    "  phmc_state = jax.tree.map(functools.partial(refold, perm=unperm), phmc_state)\n",
     "\n",
     "  traced = {\n",
     "      'z_chain': phmc_state.state,\n",
@@ -315,7 +315,7 @@
     "@jit\n",
     "def update_step(x, adam_state):\n",
     "  def g_fn(x):\n",
-    "    return jax.tree_map(lambda x: -x, value_and_grad(target_log_prob_fn)(x))\n",
+    "    return jax.tree.map(lambda x: -x, value_and_grad(target_log_prob_fn)(x))\n",
     "  tlp, g = g_fn(x)\n",
     "  updates, adam_state = optimizer.update(g, adam_state)\n",
     "  return optax.apply_updates(x, updates), adam_state, tlp\n",
diff --git a/spinoffs/autobnn/autobnn/kernels_test.py b/spinoffs/autobnn/autobnn/kernels_test.py
index 550d51b742..b3b03e5986 100644
--- a/spinoffs/autobnn/autobnn/kernels_test.py
+++ b/spinoffs/autobnn/autobnn/kernels_test.py
@@ -45,7 +45,7 @@ def get_bnn_and_params(self):
     linear_bnn = kernels.OneLayerBNN(width=50)
     seed = jax.random.PRNGKey(0)
    init_params = linear_bnn.init(seed, x_train)
-    constant_params = jax.tree_map(
+    constant_params = jax.tree.map(
         lambda x: jnp.full(x.shape, 0.1), init_params)
     constant_params['params']['noise_scale'] = jnp.array([0.005 ** 0.5])
     return linear_bnn, constant_params, x_train, y_train
diff --git a/spinoffs/autobnn/autobnn/training_util.py b/spinoffs/autobnn/autobnn/training_util.py
index 55134bad58..295cd7e659 100644
--- a/spinoffs/autobnn/autobnn/training_util.py
+++ b/spinoffs/autobnn/autobnn/training_util.py
@@ -117,7 +117,7 @@ def _init(rand_seed):
   initial_state = jax.vmap(_init)(jax.random.split(seed, num_particles))
   # It is okay to reuse the initial_state[0] as the test point, as Bayeux
   # only uses it to figure out the treedef.
-  test_point = jax.tree_map(lambda t: t[0], initial_state)
+  test_point = jax.tree.map(lambda t: t[0], initial_state)
 
   if for_vi:
 
@@ -127,7 +127,7 @@ def log_density(params, *, seed=None):
       # of size [1] to the start, so we undo all of that.
       del seed
       return net.log_prob(
-          {'params': jax.tree_map(lambda x: x[0, ...], params)},
+          {'params': jax.tree.map(lambda x: x[0, ...], params)},
           data=x_train,
           observations=y_train)
 
@@ -189,9 +189,9 @@ def _filter_stuck_chains(params):
   halfway_to_zero = -0.5 * stds_mu / stds_scale
   unstuck = jnp.where(z_scores > halfway_to_zero)[0]
   if unstuck.shape[0] > 2:
-    return jax.tree_map(lambda x: x[unstuck], params)
+    return jax.tree.map(lambda x: x[unstuck], params)
   best_two = jnp.argsort(stds)[-2:]
-  return jax.tree_map(lambda x: x[best_two], params)
+  return jax.tree.map(lambda x: x[best_two], params)
 
 
 @jax.named_call
@@ -214,7 +214,7 @@ def fit_bnn_vi(
       seed=vi_seed,
       **vi_kwargs)
   params = surrogate_dist.sample(seed=draw_seed, sample_shape=num_draws)
-  params = jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), params)
+  params = jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:]), params)
   return params, {'loss': loss}
 
 
@@ -241,7 +241,7 @@ def fit_bnn_mcmc(
   # is the easiest way to determine where "stuck chains" occur, and it is
   # nice to return parameters with a single batch dimension.
   params = _filter_stuck_chains(params)
-  params = jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), params)
+  params = jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:]), params)
   return params, {'noise_scale': params['params'].get('noise_scale', None)}
 
 
@@ -430,6 +430,6 @@ def debatchify_params(params: PyTree) -> List[Dict[str, Any]]:
   """Nested dict of rank n tensors -> a list of nested dicts of rank n-1's."""
   n = get_params_batch_length(params)
   def get_item(i):
-    return jax.tree_map(lambda x: x[i, ...], params)
+    return jax.tree.map(lambda x: x[i, ...], params)
   return [get_item(i) for i in range(n)]
 
diff --git a/spinoffs/autobnn/autobnn/util.py b/spinoffs/autobnn/autobnn/util.py
index b244221071..3491b0ccda 100644
--- a/spinoffs/autobnn/autobnn/util.py
+++ b/spinoffs/autobnn/autobnn/util.py
@@ -27,24 +27,24 @@ def make_transforms(
     net: bnn.BNN,
 ) -> Tuple[Callable[..., Any], Callable[..., Any], Callable[..., Any]]:
   """Returns unconstraining bijectors for all variables in the BNN."""
-  jb = jax.tree_map(
+  jb = jax.tree.map(
       lambda x: x.experimental_default_event_space_bijector(),
       net.get_all_distributions(),
       is_leaf=lambda x: isinstance(x, distribution_lib.Distribution),
   )
 
   def transform(params):
-    return {'params': jax.tree_map(lambda p, b: b(p), params['params'], jb)}
+    return {'params': jax.tree.map(lambda p, b: b(p), params['params'], jb)}
 
   def inverse_transform(params):
     return {
-        'params': jax.tree_map(lambda p, b: b.inverse(p), params['params'], jb)
+        'params': jax.tree.map(lambda p, b: b.inverse(p), params['params'], jb)
     }
 
   def inverse_log_det_jacobian(params):
     return jax.tree_util.tree_reduce(
         lambda a, b: a + b,
-        jax.tree_map(
+        jax.tree.map(
             lambda p, b: jnp.sum(b.inverse_log_det_jacobian(p)),
             params['params'],
             jb,
diff --git a/spinoffs/inference_gym/notebooks/inference_gym_tutorial.ipynb b/spinoffs/inference_gym/notebooks/inference_gym_tutorial.ipynb
index 578dd2bdc0..cf541af92e 100644
--- a/spinoffs/inference_gym/notebooks/inference_gym_tutorial.ipynb
+++ b/spinoffs/inference_gym/notebooks/inference_gym_tutorial.ipynb
@@ -184,7 +184,7 @@
     "  return targets\n",
     "\n",
     "def get_num_latents(target):\n",
-    "  return int(sum(map(np.prod, list(jax.tree_flatten(target.event_shape)[0]))))"
+    "  return int(sum(map(np.prod, list(jax.tree.flatten(target.event_shape)[0]))))"
    ]
   },
   {
diff --git a/tensorflow_probability/examples/jupyter_notebooks/Distributed_Inference_with_JAX.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Distributed_Inference_with_JAX.ipynb
index 8d93bad05e..a3464c20d1 100644
--- a/tensorflow_probability/examples/jupyter_notebooks/Distributed_Inference_with_JAX.ipynb
+++ b/tensorflow_probability/examples/jupyter_notebooks/Distributed_Inference_with_JAX.ipynb
@@ -1187,7 +1187,7 @@
     "  x = x.reshape((jax.device_count(), -1, *x.shape[1:]))\n",
     "  return jax.pmap(lambda x: x)(x) # pmap will physically place values on devices\n",
     "\n",
-    "shard = functools.partial(jax.tree_map, shard_value)"
+    "shard = functools.partial(jax.tree.map, shard_value)"
    ]
   },
   {
@@ -1322,7 +1322,7 @@
    "source": [
     "%%time\n",
     "output = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))\n",
-    "jax.tree_map(lambda x: x.block_until_ready(), output)"
+    "jax.tree.map(lambda x: x.block_until_ready(), output)"
    ]
   },
   {
@@ -1357,7 +1357,7 @@
    "source": [
     "%%time\n",
     "states, trace = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))\n",
-    "jax.tree_map(lambda x: x.block_until_ready(), trace)"
+    "jax.tree.map(lambda x: x.block_until_ready(), trace)"
    ]
   },
   {
@@ -1879,7 +1879,7 @@
     "%%time\n",
     "run = make_run(axis_name='data')\n",
     "output = run(random.PRNGKey(0), sharded_watch_matrix)\n",
-    "jax.tree_map(lambda x: x.block_until_ready(), output)"
+    "jax.tree.map(lambda x: x.block_until_ready(), output)"
    ]
   },
   {
@@ -1914,7 +1914,7 @@
    "source": [
     "%%time\n",
     "states, trace = run(random.PRNGKey(0), sharded_watch_matrix)\n",
-    "jax.tree_map(lambda x: x.block_until_ready(), trace)"
+    "jax.tree.map(lambda x: x.block_until_ready(), trace)"
    ]
   },
   {
@@ -2050,7 +2050,7 @@
     "  already_watched = set(jnp.arange(num_movies)[watch_matrix[user_id] == 1])\n",
     "  for i in range(500):\n",
     "    for j in range(2):\n",
-    "      sample = jax.tree_map(lambda x: x[i, j], samples)\n",
+    "      sample = jax.tree.map(lambda x: x[i, j], samples)\n",
     "      ranking = recommend(sample, user_id)\n",
     "      for movie_id in ranking:\n",
     "        if int(movie_id) not in already_watched:\n",
diff --git a/tensorflow_probability/python/experimental/fastgp/mbcg.py b/tensorflow_probability/python/experimental/fastgp/mbcg.py
index 8a52da62d3..28b43d9f3c 100644
--- a/tensorflow_probability/python/experimental/fastgp/mbcg.py
+++ b/tensorflow_probability/python/experimental/fastgp/mbcg.py
@@ -155,7 +155,7 @@ def loop_body(carry, _):
     new_off_diags = off_diags.at[:, j - 1].set(off_diag_update)
 
     # Only update if we are not within tolerance.
-    (preconditioned_errors, search_directions, alpha) = (jax.tree_map(
+    (preconditioned_errors, search_directions, alpha) = (jax.tree.map(
         lambda o, n: jnp.where(converged, o, n),
         (old_preconditioned_errors, old_search_directions, old_alpha),
         (preconditioned_errors, search_directions, safe_alpha)))