Skip to content

Commit

Permalink
Add Python 3.12 support and add better error msg for missing tf_keras
Browse files Browse the repository at this point in the history
With TF 2.16, users of TFP-on-TF must install `tf-keras` in addition to `tensorflow` -- so this change adds a custom error message if `tf-keras` (or `tf-keras-nightly`) is not installed.

PiperOrigin-RevId: 614280621
  • Loading branch information
jburnim authored and tensorflower-gardener committed Mar 9, 2024
1 parent 9a14b9b commit 988f023
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/continuous-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9]
python-version: [3.12]
steps:
- name: Checkout
uses: actions/checkout@v1
Expand All @@ -38,7 +38,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9]
python-version: [3.12]
shard: [0, 1, 2, 3, 4]
env:
SHARD: ${{ matrix.shard }}
Expand Down
9 changes: 9 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@
else:
TFDS_PACKAGE = 'tfds-nightly'

if release:
TF_PACKAGE = 'tensorflow >= 2.15'
KERAS_PACKAGE = 'tf-keras >= 2.15'
else:
TF_PACKAGE = 'tf-nightly'
KERAS_PACKAGE = 'tf-keras-nightly'


class BinaryDistribution(Distribution):
"""This class is needed in order to create OS specific wheels."""
Expand Down Expand Up @@ -91,6 +98,7 @@ def has_ext_modules(self):
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
Expand All @@ -101,6 +109,7 @@ def has_ext_modules(self):
keywords='tensorflow probability statistics bayesian machine learning',
extras_require={ # e.g. `pip install tfp-nightly[jax]`
'jax': ['jax', 'jaxlib'],
'tf': [TF_PACKAGE, KERAS_PACKAGE],
'tfds': [TFDS_PACKAGE],
}
)
16 changes: 14 additions & 2 deletions tensorflow_probability/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _validate_tf_environment(package):
inadequate.
"""
try:
import tensorflow.compat.v1 as tf
import tensorflow as tf
except (ImportError, ModuleNotFoundError):
# Print more informative error message, then reraise.
print('\n\nFailed to import TensorFlow. Please note that TensorFlow is not '
Expand All @@ -51,7 +51,7 @@ def _validate_tf_environment(package):
#
# Update this whenever we need to depend on a newer TensorFlow release.
#
required_tensorflow_version = '2.14'
required_tensorflow_version = '2.15'
# required_tensorflow_version = '1.15' # Needed internally -- DisableOnExport

if (distutils.version.LooseVersion(tf.__version__) <
Expand All @@ -74,6 +74,18 @@ def _validate_tf_environment(package):
'For more detail, see https://github.com/tensorflow/community/pull/287.'
)

if required_tensorflow_version[0] == '2':
try:
import tf_keras # pylint: disable=unused-import
except (ImportError, ModuleNotFoundError):
# Print more informative error message, then reraise.
print('\n\nFailed to import TF-Keras. Please note that TF-Keras is not '
'installed by default when you install TensorFlow Probability. '
'This is so that JAX-only users do not have to install TensorFlow '
'or TF-Keras. To use TensorFlow Probability with TensorFlow, '
'please install the tf-keras or tf-keras-nightly package.\n\n')
raise


# Declare these explicitly to appease pytype, which otherwise misses them,
# presumably due to lazy loading.
Expand Down

0 comments on commit 988f023

Please sign in to comment.