From d1167ca9740780169ca60f18121ff561e7d9cacb Mon Sep 17 00:00:00 2001 From: Johannes Gasteiger Date: Fri, 21 Jun 2024 04:52:24 -0700 Subject: [PATCH] Add option to use tf_data_service_config for validation PiperOrigin-RevId: 645350562 --- tensorflow_gnn/runner/orchestration.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tensorflow_gnn/runner/orchestration.py b/tensorflow_gnn/runner/orchestration.py index 8489e735..cc5bfad0 100644 --- a/tensorflow_gnn/runner/orchestration.py +++ b/tensorflow_gnn/runner/orchestration.py @@ -379,6 +379,7 @@ def run(*, train_padding: Optional[GraphTensorPadding] = None, valid_padding: Optional[GraphTensorPadding] = None, tf_data_service_config: Optional[TFDataServiceConfig] = None, + use_data_service_for_validation: bool = False, steps_per_execution: Optional[int] = None, run_eagerly: bool = False): """Runs training (and validation) of a model on task(s) with the given data. @@ -474,6 +475,11 @@ def run(*, runtime reducing input bottlenecks for model training. Particularly for training on accelerators consider enabling it. For more info please see: https://www.tensorflow.org/api_docs/python/tf/data/experimental/service. + use_data_service_for_validation: Whether to use tf.data service for + validation, in addition to training. Use with caution! Many ShardingPolicy + values do not visit every sample exactly once, which is critical for + validation. Increasing `validation_freq` of the trainer is another way to + reduce the fraction of time spent on validation. steps_per_execution: The number of batches to run during each training iteration. If not set, for TPU strategy default to 100 and to `None` otherwise. @@ -561,11 +567,19 @@ def apply_fn( tf_data_service_config) if validate: + # TFTrainer doesn't support using a different tf_data_service_config for + # validation than for training (see b/346691297#comment5). We therefore use + # the same config for both or None. + valid_tf_data_service_config = ( + tf_data_service_config if use_data_service_for_validation else None + ) valid_ds_provider = _WrappedDatasetProvider( valid_apply_fn, valid_ds_provider, drop_remainder, - global_batch_size) + global_batch_size, + valid_tf_data_service_config, + ) def adapted_model_fn(): xs, *_ = preprocess_model.output