From 9b72d679e64d0abf93665a1c197a44ca03888d20 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Wed, 22 Nov 2023 11:49:52 -0500 Subject: [PATCH] Removing unnecessary params model_uses_forward and loss_uses_forward --- dwi_ml/models/main_models.py | 17 ++------- dwi_ml/models/projects/learn2track_model.py | 9 ----- dwi_ml/models/projects/transformer_models.py | 12 +++--- dwi_ml/tracking/tracker.py | 5 +-- dwi_ml/training/trainers.py | 37 +++++++------------ dwi_ml/training/with_generation/trainer.py | 5 +-- .../utils/data_and_models_for_tests.py | 18 +++------ 7 files changed, 29 insertions(+), 74 deletions(-) diff --git a/dwi_ml/models/main_models.py b/dwi_ml/models/main_models.py index 14576222..f6f06914 100644 --- a/dwi_ml/models/main_models.py +++ b/dwi_ml/models/main_models.py @@ -69,10 +69,6 @@ def __init__(self, experiment_name: str, self.device = None - # To tell our trainer what to send to the forward / loss methods. - self.forward_uses_streamlines = False - self.loss_uses_streamlines = False - # To tell our batch loader how to resample streamlines during training # (should also be the step size during tractography). if step_size and compress_lines: @@ -205,10 +201,10 @@ def _load_state(cls, model_dir): return model_state - def forward(self, *inputs, **kw): + def forward(self, inputs, streamlines): raise NotImplementedError - def compute_loss(self, *model_outputs, **kw): + def compute_loss(self, model_outputs, target_streamlines): raise NotImplementedError @@ -369,10 +365,6 @@ def __init__(self, nb_previous_dirs: int = 0, self.prev_dirs_embedded_size = None self.prev_dirs_embedding = None - # To tell our trainer what to send to the forward / loss methods. - if nb_previous_dirs > 0: - self.forward_uses_streamlines = True - @staticmethod def add_args_model_with_pd(p): # CNN embedding makes no sense for previous dir @@ -411,7 +403,7 @@ def params_for_checkpoint(self): }) return p - def forward(self, inputs, target_streamlines: List[torch.tensor], **kw): + def forward(self, inputs, target_streamlines: List[torch.tensor]): """ Params ------ @@ -710,9 +702,6 @@ def __init__(self, dg_key: str = 'cosine-regression', raise ValueError("Direction getter choice not understood: {}" .format(self.positional_encoding_key)) - # To tell our trainer what to send to the forward / loss methods. - self.loss_uses_streamlines = True - def set_context(self, context): assert context in ['training', 'tracking', 'visu'] self._context = context diff --git a/dwi_ml/models/projects/learn2track_model.py b/dwi_ml/models/projects/learn2track_model.py index 68dba721..00d57332 100644 --- a/dwi_ml/models/projects/learn2track_model.py +++ b/dwi_ml/models/projects/learn2track_model.py @@ -188,15 +188,6 @@ def __init__(self, experiment_name, # 4. Direction getter: self.instantiate_direction_getter(self.rnn_model.output_size) - # If multiple inheritance goes well, these params should be set - # correctly - if nb_previous_dirs > 0: - assert self.forward_uses_streamlines - assert self.loss_uses_streamlines - - if self.start_from_copy_prev: - self.forward_uses_streamlines = True - def set_context(self, context): assert context in ['training', 'validation', 'tracking', 'visu', 'preparing_backward'] diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py index eb1a418a..7986743f 100644 --- a/dwi_ml/models/projects/transformer_models.py +++ b/dwi_ml/models/projects/transformer_models.py @@ -224,8 +224,6 @@ def __init__(self, # the nb of features. self.instantiate_direction_getter(self.d_model) - assert self.loss_uses_streamlines - @property def d_model(self): raise NotImplementedError @@ -367,9 +365,11 @@ def forward(self, inputs: List[torch.tensor], raise ValueError("Please set context before usage.") # ----------- Checks - if self.forward_uses_streamlines: - # Reminder. In all cases, len(each input) == len(each streamline). - # Correct interpolation and management of points should be done before. + if input_streamlines is not None: + # If streamlines are necessary (depending on child class): + # In all cases, len(each input) == len(each streamline). + # Correct interpolation and management of points should be done + # before. assert np.all([len(i) == len(s) for i, s in zip(inputs, input_streamlines)]) @@ -618,8 +618,6 @@ def __init__(self, self.embedding_layer_t = cls_t(self.target_features, self.target_embedded_size) - self.forward_uses_streamlines = True - @property def params_for_checkpoint(self): """ diff --git a/dwi_ml/tracking/tracker.py b/dwi_ml/tracking/tracker.py index 9f0bd582..eb6894a3 100644 --- a/dwi_ml/tracking/tracker.py +++ b/dwi_ml/tracking/tracker.py @@ -494,10 +494,7 @@ def _prepare_inputs_at_pos(self, last_pos): def _call_model_forward(self, inputs, lines): with self.grad_context: - if self.model.forward_uses_streamlines: - model_outputs = self.model(inputs, lines) - else: - model_outputs = self.model(inputs) + model_outputs = self.model(inputs, lines) return model_outputs def update_memory_after_removing_lines(self, can_continue: np.ndarray, diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 2d989a7a..7d8b45a5 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -1092,33 +1092,22 @@ def run_one_batch(self, data): batch_inputs = self.batch_loader.load_batch_inputs( streamlines_f, ids_per_subj) - # Possibly add noise to inputs here. logger.debug('*** Computing forward propagation') - if self.model.forward_uses_streamlines: - # Now possibly add noise to streamlines (training / valid) - streamlines_f = self.batch_loader.add_noise_streamlines_forward( - streamlines_f, self.device) - - # Possibly computing directions twice (during forward and loss) - # but ok, shouldn't be too heavy. Easier to deal with multiple - # projects' requirements by sending whole streamlines rather - # than only directions. - model_outputs = self.model(batch_inputs, streamlines_f) - del streamlines_f - else: - del streamlines_f - model_outputs = self.model(batch_inputs) + # todo Possibly add noise to inputs here. Not ready + # Now add noise to streamlines for the forward pass + # (batch loader will do it depending on training / valid) + streamlines_f = self.batch_loader.add_noise_streamlines_forward( + streamlines_f, self.device) + model_outputs = self.model(batch_inputs, streamlines_f) + del streamlines_f logger.debug('*** Computing loss') - if self.model.loss_uses_streamlines: - targets = self.batch_loader.add_noise_streamlines_loss( - targets, self.device) - - results = self.model.compute_loss(model_outputs, targets, - average_results=True) - else: - results = self.model.compute_loss(model_outputs, - average_results=True) + # Add noise to targets. + # (batch loader will do it depending on training / valid) + targets = self.batch_loader.add_noise_streamlines_loss(targets, + self.device) + results = self.model.compute_loss(model_outputs, targets, + average_results=True) if self.use_gpu: log_gpu_memory_usage(logger) diff --git a/dwi_ml/training/with_generation/trainer.py b/dwi_ml/training/with_generation/trainer.py index ef7f76f1..d516b16e 100644 --- a/dwi_ml/training/with_generation/trainer.py +++ b/dwi_ml/training/with_generation/trainer.py @@ -364,10 +364,7 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): batch_inputs = self.batch_loader.load_batch_inputs( n_last_pos, ids_per_subj) - if self.model.forward_uses_streamlines: - model_outputs = self.model(batch_inputs, n_last_pos) - else: - model_outputs = self.model(batch_inputs) + model_outputs = self.model(batch_inputs, n_last_pos) next_dirs = self.model.get_tracking_directions( model_outputs, algo='det', eos_stopping_thresh=0.5) diff --git a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py index 0467bf60..74c27fc4 100644 --- a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py +++ b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py @@ -85,11 +85,12 @@ def compute_loss(self, model_outputs, target_streamlines=None, else: return torch.zeros(n, device=self.device) - def forward(self, x: list): + def forward(self, inputs: list, streamlines): + # Not using streamlines. Pretending to use inputs. _ = self.fake_parameter regressed_dir = torch.as_tensor([1., 1., 1.]) - return [regressed_dir for _ in x] + return [regressed_dir for _ in inputs] class TrackingModelForTestWithPD(ModelWithPreviousDirections, @@ -119,12 +120,6 @@ def __init__(self, experiment_name: str = 'test', # For super MainModelForTracking: dg_key=dg_key, dg_args=dg_args) - # If multiple inheritance goes well, these params should be set - # correctly - if nb_previous_dirs > 0: - assert self.forward_uses_streamlines - assert self.loss_uses_streamlines - self.instantiate_direction_getter(dg_input_size) def compute_loss(self, model_outputs: List[torch.Tensor], @@ -137,7 +132,8 @@ def compute_loss(self, model_outputs: List[torch.Tensor], return self.direction_getter.compute_loss( model_outputs, target_streamlines, average_results) - def get_tracking_directions(self, regressed_dirs, algo): + def get_tracking_directions(self, regressed_dirs, algo, + eos_stopping_thresh): if algo == 'det': return regressed_dirs elif algo == 'prob': @@ -148,9 +144,7 @@ def get_tracking_directions(self, regressed_dirs, algo): raise ValueError("'algo' should be 'det' or 'prob'.") def forward(self, inputs: List[torch.tensor], - target_streamlines: List[torch.tensor] = None, - hidden_reccurent_states: tuple = None, - return_state: bool = False) -> List[torch.tensor]: + target_streamlines: List[torch.tensor]): # Previous dirs if self.nb_previous_dirs > 0: target_dirs = compute_directions(target_streamlines)