From 8358e0344a02e4e6b309b4daf4aea5de1021d254 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 29 Nov 2024 11:29:06 +0100 Subject: [PATCH 1/3] Reintroduced old validation loss calculation, let's see if this fixes something --- mala/network/trainer.py | 267 ++++++++++++++++++++++++++++++++++------ 1 file changed, 231 insertions(+), 36 deletions(-) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index b5eb0892..76cf5b55 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -675,46 +675,241 @@ def _validate_network(self, data_set_fractions, metrics): ) loader_id += 1 else: - with torch.no_grad(): - for snapshot_number in trange( - offset_snapshots, - number_of_snapshots + offset_snapshots, - desc="Validation", - disable=self.parameters_full.verbosity < 2, - ): - # Get optimal batch size and number of batches per snapshotss - grid_size = ( - self.data.parameters.snapshot_directories_list[ - snapshot_number - ].grid_size - ) + # If only the LDOS is in the validation metrics (as is the + # case for, e.g., distributed network trainings), we can + # use a faster (or at least better parallelizing) code + if ( + len(self.parameters.validation_metrics) == 1 + and self.parameters.validation_metrics[0] == "ldos" + ): + validation_loss_sum = torch.zeros( + 1, device=self.parameters._configuration["device"] + ) + with torch.no_grad(): + if self.parameters._configuration["gpu"]: + report_freq = self.parameters.training_log_interval + torch.cuda.synchronize( + self.parameters._configuration["device"] + ) + tsample = time.time() + batchid = 0 + for loader in data_loaders: + for x, y in loader: + x = x.to( + self.parameters._configuration[ + "device" + ], + non_blocking=True, + ) + y = y.to( + self.parameters._configuration[ + "device" + ], + non_blocking=True, + ) + + if ( + self.parameters.use_graphs + and self.validation_graph is None + ): + printout( + "Capturing CUDA graph for validation.", + min_verbosity=2, + ) + s = torch.cuda.Stream( + self.parameters._configuration[ + "device" + ] + ) + s.wait_stream( + torch.cuda.current_stream( + self.parameters._configuration[ + "device" + ] + ) + ) + # Warmup for graphs + with torch.cuda.stream(s): + for _ in range(20): + with torch.cuda.amp.autocast( + enabled=self.parameters.use_mixed_precision + ): + prediction = self.network( + x + ) + if ( + self.parameters_full.use_ddp + ): + loss = self.network.module.calculate_loss( + prediction, y + ) + else: + loss = self.network.calculate_loss( + prediction, y + ) + torch.cuda.current_stream( + self.parameters._configuration[ + "device" + ] + ).wait_stream(s) + + # Create static entry point tensors to graph + self.static_input_validation = ( + torch.empty_like(x) + ) + self.static_target_validation = ( + torch.empty_like(y) + ) + + # Capture graph + self.validation_graph = ( + torch.cuda.CUDAGraph() + ) + with torch.cuda.graph( + self.validation_graph + ): + with torch.cuda.amp.autocast( + enabled=self.parameters.use_mixed_precision + ): + self.static_prediction_validation = self.network( + self.static_input_validation + ) + if ( + self.parameters_full.use_ddp + ): + self.static_loss_validation = self.network.module.calculate_loss( + self.static_prediction_validation, + self.static_target_validation, + ) + else: + self.static_loss_validation = self.network.calculate_loss( + self.static_prediction_validation, + self.static_target_validation, + ) + + if self.validation_graph: + self.static_input_validation.copy_(x) + self.static_target_validation.copy_(y) + self.validation_graph.replay() + validation_loss_sum += ( + self.static_loss_validation + ) + else: + with torch.cuda.amp.autocast( + enabled=self.parameters.use_mixed_precision + ): + prediction = self.network(x) + if self.parameters_full.use_ddp: + loss = self.network.module.calculate_loss( + prediction, y + ) + else: + loss = self.network.calculate_loss( + prediction, y + ) + validation_loss_sum += loss + if ( + batchid != 0 + and (batchid + 1) % report_freq == 0 + ): + torch.cuda.synchronize( + self.parameters._configuration[ + "device" + ] + ) + sample_time = time.time() - tsample + avg_sample_time = ( + sample_time / report_freq + ) + avg_sample_tput = ( + report_freq + * x.shape[0] + / sample_time + ) + printout( + f"batch {batchid + 1}, " # /{total_samples}, " + f"validation avg time: {avg_sample_time} " + f"validation avg throughput: {avg_sample_tput}", + min_verbosity=2, + ) + tsample = time.time() + batchid += 1 + torch.cuda.synchronize( + self.parameters._configuration["device"] + ) + else: + batchid = 0 + for loader in data_loaders: + for x, y in loader: + x = x.to( + self.parameters._configuration[ + "device" + ] + ) + y = y.to( + self.parameters._configuration[ + "device" + ] + ) + prediction = self.network(x) + if self.parameters_full.use_ddp: + validation_loss_sum += ( + self.network.module.calculate_loss( + prediction, y + ).item() + ) + else: + validation_loss_sum += ( + self.network.calculate_loss( + prediction, y + ).item() + ) + batchid += 1 + + validation_loss = validation_loss_sum.item() / batchid + errors[data_set_type]["ldos"] = validation_loss - optimal_batch_size = self._correct_batch_size( - grid_size, self.parameters.mini_batch_size - ) - number_of_batches_per_snapshot = int( - grid_size / optimal_batch_size - ) + else: + with torch.no_grad(): + for snapshot_number in trange( + offset_snapshots, + number_of_snapshots + offset_snapshots, + desc="Validation", + disable=self.parameters_full.verbosity < 2, + ): + # Get optimal batch size and number of batches per snapshotss + grid_size = ( + self.data.parameters.snapshot_directories_list[ + snapshot_number + ].grid_size + ) - actual_outputs, predicted_outputs = ( - self._forward_entire_snapshot( - snapshot_number, - data_sets[0], - data_set_type[0:2], - number_of_batches_per_snapshot, - optimal_batch_size, + optimal_batch_size = self._correct_batch_size( + grid_size, self.parameters.mini_batch_size ) - ) - calculated_errors = self._calculate_errors( - actual_outputs, - predicted_outputs, - metrics, - snapshot_number, - ) - for metric in metrics: - errors[data_set_type][metric].append( - calculated_errors[metric] + number_of_batches_per_snapshot = int( + grid_size / optimal_batch_size + ) + + actual_outputs, predicted_outputs = ( + self._forward_entire_snapshot( + snapshot_number, + data_sets[0], + data_set_type[0:2], + number_of_batches_per_snapshot, + optimal_batch_size, + ) ) + calculated_errors = self._calculate_errors( + actual_outputs, + predicted_outputs, + metrics, + snapshot_number, + ) + for metric in metrics: + errors[data_set_type][metric].append( + calculated_errors[metric] + ) return errors def __prepare_to_train(self, optimizer_dict): From d3043e60cf5ee1aab5c8fa8e7ef133936aac55ee Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 29 Nov 2024 11:36:59 +0100 Subject: [PATCH 2/3] Forgot a renaming --- mala/network/trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index 76cf5b55..5407fdd7 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -710,7 +710,7 @@ def _validate_network(self, data_set_fractions, metrics): if ( self.parameters.use_graphs - and self.validation_graph is None + and self._validation_graph is None ): printout( "Capturing CUDA graph for validation.", @@ -762,11 +762,11 @@ def _validate_network(self, data_set_fractions, metrics): ) # Capture graph - self.validation_graph = ( + self._validation_graph = ( torch.cuda.CUDAGraph() ) with torch.cuda.graph( - self.validation_graph + self._validation_graph ): with torch.cuda.amp.autocast( enabled=self.parameters.use_mixed_precision @@ -787,10 +787,10 @@ def _validate_network(self, data_set_fractions, metrics): self.static_target_validation, ) - if self.validation_graph: + if self._validation_graph: self.static_input_validation.copy_(x) self.static_target_validation.copy_(y) - self.validation_graph.replay() + self._validation_graph.replay() validation_loss_sum += ( self.static_loss_validation ) From 20a06f2e0bc32348a24ea8e3bbd831028931d455 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 29 Nov 2024 11:46:07 +0100 Subject: [PATCH 3/3] Refactored code internally --- mala/network/trainer.py | 335 ++++++++++++++++++---------------------- 1 file changed, 150 insertions(+), 185 deletions(-) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index 5407fdd7..ccd0ab70 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -678,196 +678,17 @@ def _validate_network(self, data_set_fractions, metrics): # If only the LDOS is in the validation metrics (as is the # case for, e.g., distributed network trainings), we can # use a faster (or at least better parallelizing) code + if ( len(self.parameters.validation_metrics) == 1 and self.parameters.validation_metrics[0] == "ldos" ): - validation_loss_sum = torch.zeros( - 1, device=self.parameters._configuration["device"] - ) - with torch.no_grad(): - if self.parameters._configuration["gpu"]: - report_freq = self.parameters.training_log_interval - torch.cuda.synchronize( - self.parameters._configuration["device"] - ) - tsample = time.time() - batchid = 0 - for loader in data_loaders: - for x, y in loader: - x = x.to( - self.parameters._configuration[ - "device" - ], - non_blocking=True, - ) - y = y.to( - self.parameters._configuration[ - "device" - ], - non_blocking=True, - ) - if ( - self.parameters.use_graphs - and self._validation_graph is None - ): - printout( - "Capturing CUDA graph for validation.", - min_verbosity=2, - ) - s = torch.cuda.Stream( - self.parameters._configuration[ - "device" - ] - ) - s.wait_stream( - torch.cuda.current_stream( - self.parameters._configuration[ - "device" - ] - ) - ) - # Warmup for graphs - with torch.cuda.stream(s): - for _ in range(20): - with torch.cuda.amp.autocast( - enabled=self.parameters.use_mixed_precision - ): - prediction = self.network( - x - ) - if ( - self.parameters_full.use_ddp - ): - loss = self.network.module.calculate_loss( - prediction, y - ) - else: - loss = self.network.calculate_loss( - prediction, y - ) - torch.cuda.current_stream( - self.parameters._configuration[ - "device" - ] - ).wait_stream(s) - - # Create static entry point tensors to graph - self.static_input_validation = ( - torch.empty_like(x) - ) - self.static_target_validation = ( - torch.empty_like(y) - ) - - # Capture graph - self._validation_graph = ( - torch.cuda.CUDAGraph() - ) - with torch.cuda.graph( - self._validation_graph - ): - with torch.cuda.amp.autocast( - enabled=self.parameters.use_mixed_precision - ): - self.static_prediction_validation = self.network( - self.static_input_validation - ) - if ( - self.parameters_full.use_ddp - ): - self.static_loss_validation = self.network.module.calculate_loss( - self.static_prediction_validation, - self.static_target_validation, - ) - else: - self.static_loss_validation = self.network.calculate_loss( - self.static_prediction_validation, - self.static_target_validation, - ) - - if self._validation_graph: - self.static_input_validation.copy_(x) - self.static_target_validation.copy_(y) - self._validation_graph.replay() - validation_loss_sum += ( - self.static_loss_validation - ) - else: - with torch.cuda.amp.autocast( - enabled=self.parameters.use_mixed_precision - ): - prediction = self.network(x) - if self.parameters_full.use_ddp: - loss = self.network.module.calculate_loss( - prediction, y - ) - else: - loss = self.network.calculate_loss( - prediction, y - ) - validation_loss_sum += loss - if ( - batchid != 0 - and (batchid + 1) % report_freq == 0 - ): - torch.cuda.synchronize( - self.parameters._configuration[ - "device" - ] - ) - sample_time = time.time() - tsample - avg_sample_time = ( - sample_time / report_freq - ) - avg_sample_tput = ( - report_freq - * x.shape[0] - / sample_time - ) - printout( - f"batch {batchid + 1}, " # /{total_samples}, " - f"validation avg time: {avg_sample_time} " - f"validation avg throughput: {avg_sample_tput}", - min_verbosity=2, - ) - tsample = time.time() - batchid += 1 - torch.cuda.synchronize( - self.parameters._configuration["device"] - ) - else: - batchid = 0 - for loader in data_loaders: - for x, y in loader: - x = x.to( - self.parameters._configuration[ - "device" - ] - ) - y = y.to( - self.parameters._configuration[ - "device" - ] - ) - prediction = self.network(x) - if self.parameters_full.use_ddp: - validation_loss_sum += ( - self.network.module.calculate_loss( - prediction, y - ).item() - ) - else: - validation_loss_sum += ( - self.network.calculate_loss( - prediction, y - ).item() - ) - batchid += 1 - - validation_loss = validation_loss_sum.item() / batchid - errors[data_set_type]["ldos"] = validation_loss + errors[data_set_type]["ldos"] = ( + self.__calculate_validation_error_ldos_only( + data_loaders + ) + ) else: with torch.no_grad(): @@ -912,6 +733,150 @@ def _validate_network(self, data_set_fractions, metrics): ) return errors + def __calculate_validation_error_ldos_only(self, data_loaders): + validation_loss_sum = torch.zeros( + 1, device=self.parameters._configuration["device"] + ) + with torch.no_grad(): + if self.parameters._configuration["gpu"]: + report_freq = self.parameters.training_log_interval + torch.cuda.synchronize( + self.parameters._configuration["device"] + ) + tsample = time.time() + batchid = 0 + for loader in data_loaders: + for x, y in loader: + x = x.to( + self.parameters._configuration["device"], + non_blocking=True, + ) + y = y.to( + self.parameters._configuration["device"], + non_blocking=True, + ) + + if ( + self.parameters.use_graphs + and self._validation_graph is None + ): + printout( + "Capturing CUDA graph for validation.", + min_verbosity=2, + ) + s = torch.cuda.Stream( + self.parameters._configuration["device"] + ) + s.wait_stream( + torch.cuda.current_stream( + self.parameters._configuration["device"] + ) + ) + # Warmup for graphs + with torch.cuda.stream(s): + for _ in range(20): + with torch.cuda.amp.autocast( + enabled=self.parameters.use_mixed_precision + ): + prediction = self.network(x) + if self.parameters_full.use_ddp: + loss = self.network.module.calculate_loss( + prediction, y + ) + else: + loss = self.network.calculate_loss( + prediction, y + ) + torch.cuda.current_stream( + self.parameters._configuration["device"] + ).wait_stream(s) + + # Create static entry point tensors to graph + self.static_input_validation = torch.empty_like(x) + self.static_target_validation = torch.empty_like(y) + + # Capture graph + self._validation_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self._validation_graph): + with torch.cuda.amp.autocast( + enabled=self.parameters.use_mixed_precision + ): + self.static_prediction_validation = ( + self.network( + self.static_input_validation + ) + ) + if self.parameters_full.use_ddp: + self.static_loss_validation = self.network.module.calculate_loss( + self.static_prediction_validation, + self.static_target_validation, + ) + else: + self.static_loss_validation = self.network.calculate_loss( + self.static_prediction_validation, + self.static_target_validation, + ) + + if self._validation_graph: + self.static_input_validation.copy_(x) + self.static_target_validation.copy_(y) + self._validation_graph.replay() + validation_loss_sum += self.static_loss_validation + else: + with torch.cuda.amp.autocast( + enabled=self.parameters.use_mixed_precision + ): + prediction = self.network(x) + if self.parameters_full.use_ddp: + loss = self.network.module.calculate_loss( + prediction, y + ) + else: + loss = self.network.calculate_loss( + prediction, y + ) + validation_loss_sum += loss + if batchid != 0 and (batchid + 1) % report_freq == 0: + torch.cuda.synchronize( + self.parameters._configuration["device"] + ) + sample_time = time.time() - tsample + avg_sample_time = sample_time / report_freq + avg_sample_tput = ( + report_freq * x.shape[0] / sample_time + ) + printout( + f"batch {batchid + 1}, " # /{total_samples}, " + f"validation avg time: {avg_sample_time} " + f"validation avg throughput: {avg_sample_tput}", + min_verbosity=2, + ) + tsample = time.time() + batchid += 1 + torch.cuda.synchronize( + self.parameters._configuration["device"] + ) + else: + batchid = 0 + for loader in data_loaders: + for x, y in loader: + x = x.to(self.parameters._configuration["device"]) + y = y.to(self.parameters._configuration["device"]) + prediction = self.network(x) + if self.parameters_full.use_ddp: + validation_loss_sum += ( + self.network.module.calculate_loss( + prediction, y + ).item() + ) + else: + validation_loss_sum += self.network.calculate_loss( + prediction, y + ).item() + batchid += 1 + + return validation_loss_sum.item() / batchid + def __prepare_to_train(self, optimizer_dict): """Prepare everything for training.""" # Configure keyword arguments for DataSampler.