diff --git a/examples/cifar10_tensorflow/launcher.py b/examples/cifar10_tensorflow/launcher.py index f0a7922..f6e7249 100644 --- a/examples/cifar10_tensorflow/launcher.py +++ b/examples/cifar10_tensorflow/launcher.py @@ -30,6 +30,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string('tensorboard', None, 'Tensorboard instance.') +flags.DEFINE_integer('gpus_per_node', 2, 'Number of GPUs per node.') def main(_): @@ -76,7 +77,10 @@ def main(_): experiment.add( xm.Job( executable=executable, - executor=xm_local.Vertex(tensorboard=tensorboard_capability), + executor=xm_local.Vertex( + tensorboard=tensorboard_capability, + requirements=xm.JobRequirements(t4=FLAGS.gpus_per_node), + ), args=hyperparameters, ) )