diff --git a/tensorflow_asr/utils/env_util.py b/tensorflow_asr/utils/env_util.py index 55bca0c57..3854ce80a 100644 --- a/tensorflow_asr/utils/env_util.py +++ b/tensorflow_asr/utils/env_util.py @@ -112,7 +112,7 @@ def setup_strategy( available_devices = setup_devices(devices) if len(available_devices) == 1: return tf.distribute.get_strategy() - return tf.distribute.MirroredStrategy(devices=[d.name for d in available_devices]) + return tf.distribute.MultiWorkerMirroredStrategy() def has_devices(