Skip to content

Commit

Permalink
Update TF parallel_for direct imports.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565418720
  • Loading branch information
jburnim authored and tensorflower-gardener committed Sep 14, 2023
1 parent f4836cc commit 6bba52b
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.math import linalg

from tensorflow.python.ops import parallel_for # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops.parallel_for import control_flow_ops # pylint: disable=g-direct-tensorflow-import

tfl = tf.linalg

Expand Down Expand Up @@ -694,8 +694,8 @@ def _build_model_spec_kwargs_for_parallel_fns(self,
sample_shape=(),
pass_covariance=False):
"""Builds a dict of model parameters across all timesteps."""
kwargs = parallel_for.pfor(self._get_time_varying_kwargs,
self.num_timesteps)
kwargs = control_flow_ops.pfor(self._get_time_varying_kwargs,
self.num_timesteps)

# If given a sample shape, encode it as additional batch dimension(s).
# It is sufficient to do this for one parameter (we use initial_mean),
Expand Down Expand Up @@ -1371,7 +1371,7 @@ def pfor_body(t):
t=self.initial_step + t,
latent_mean=tf.gather(latent_means, t),
latent_cov=tf.gather(latent_covs, t))
observation_means, observation_covs = parallel_for.pfor(
observation_means, observation_covs = control_flow_ops.pfor(
pfor_body, self._num_timesteps)

observation_means = distribution_util.move_dimension(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.python.internal.backend import numpy as nptf
from tensorflow_probability.python.internal.backend.numpy import functional_ops as np_pfor
from tensorflow.python.ops import parallel_for as tf_pfor # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops.parallel_for import control_flow_ops as tf_pfor_control_flow_ops # pylint: disable=g-direct-tensorflow-import


# Allows us to test low-level TF:XLA match.
Expand Down Expand Up @@ -1832,7 +1832,7 @@ def test_foldl_struct_in_alt_out(self):

def test_pfor(self):
self.assertAllEqual(
self.evaluate(tf_pfor.pfor(lambda x: tf.ones([]), 7)),
self.evaluate(tf_pfor_control_flow_ops.pfor(lambda x: tf.ones([]), 7)),
np_pfor.pfor(lambda x: nptf.ones([]), 7))

def test_pfor_with_closure(self):
Expand All @@ -1843,7 +1843,7 @@ def tf_fn(x):
def np_fn(x):
return nptf.gather(val, x)**2
self.assertAllEqual(
self.evaluate(tf_pfor.pfor(tf_fn, 7)),
self.evaluate(tf_pfor_control_flow_ops.pfor(tf_fn, 7)),
np_pfor.pfor(np_fn, 7))

def test_pfor_with_closure_multi_out(self):
Expand All @@ -1854,7 +1854,7 @@ def tf_fn(x):
def np_fn(x):
return nptf.gather(val, x)**2, nptf.gather(val, x)
self.assertAllEqual(
self.evaluate(tf_pfor.pfor(tf_fn, 7)),
self.evaluate(tf_pfor_control_flow_ops.pfor(tf_fn, 7)),
np_pfor.pfor(np_fn, 7))

def test_convert_variable_to_tensor(self):
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_probability/python/internal/vectorization_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.util import SeedStream
from tensorflow.python.ops import parallel_for # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops.parallel_for import control_flow_ops # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import

__all__ = [
Expand Down Expand Up @@ -99,7 +99,7 @@ def pfor_loop_body(i):
if static_n == 1:
draws = pfor_loop_body(0)
else:
draws = parallel_for.pfor(pfor_loop_body, n)
draws = control_flow_ops.pfor(pfor_loop_body, n)
return tf.nest.map_structure(unflatten, draws, expand_composites=True)

return iid_sample_fn
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_probability/substrates/meta/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@
'from tensorflow_probability.python.internal.backend.numpy.private',
'from tensorflow.python.ops.linalg':
'from tensorflow_probability.python.internal.backend.numpy.gen',
'from tensorflow.python.ops import parallel_for':
('from tensorflow.python.ops.parallel_for '
'import control_flow_ops'):
'from tensorflow_probability.python.internal.backend.numpy '
'import functional_ops as parallel_for',
'import functional_ops as control_flow_ops',
'from tensorflow.python.ops import control_flow_case':
'from tensorflow_probability.python.internal.backend.numpy '
'import control_flow as control_flow_case',
Expand Down

0 comments on commit 6bba52b

Please sign in to comment.