Skip to content

Commit

Permalink
Merge pull request #218 from EmmaRenauld/remove_model_uses_streamlines
Browse files Browse the repository at this point in the history
Removing unnecessary params model_uses_forward and loss_uses_forward
  • Loading branch information
EmmaRenauld authored Nov 22, 2023
2 parents f470467 + 9b72d67 commit 6a4aa6d
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 74 deletions.
17 changes: 3 additions & 14 deletions dwi_ml/models/main_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ def __init__(self, experiment_name: str,

self.device = None

# To tell our trainer what to send to the forward / loss methods.
self.forward_uses_streamlines = False
self.loss_uses_streamlines = False

# To tell our batch loader how to resample streamlines during training
# (should also be the step size during tractography).
if step_size and compress_lines:
Expand Down Expand Up @@ -205,10 +201,10 @@ def _load_state(cls, model_dir):

return model_state

def forward(self, *inputs, **kw):
def forward(self, inputs, streamlines):
raise NotImplementedError

def compute_loss(self, *model_outputs, **kw):
def compute_loss(self, model_outputs, target_streamlines):
raise NotImplementedError


Expand Down Expand Up @@ -369,10 +365,6 @@ def __init__(self, nb_previous_dirs: int = 0,
self.prev_dirs_embedded_size = None
self.prev_dirs_embedding = None

# To tell our trainer what to send to the forward / loss methods.
if nb_previous_dirs > 0:
self.forward_uses_streamlines = True

@staticmethod
def add_args_model_with_pd(p):
# CNN embedding makes no sense for previous dir
Expand Down Expand Up @@ -411,7 +403,7 @@ def params_for_checkpoint(self):
})
return p

def forward(self, inputs, target_streamlines: List[torch.tensor], **kw):
def forward(self, inputs, target_streamlines: List[torch.tensor]):
"""
Params
------
Expand Down Expand Up @@ -710,9 +702,6 @@ def __init__(self, dg_key: str = 'cosine-regression',
raise ValueError("Direction getter choice not understood: {}"
.format(self.positional_encoding_key))

# To tell our trainer what to send to the forward / loss methods.
self.loss_uses_streamlines = True

def set_context(self, context):
assert context in ['training', 'tracking', 'visu']
self._context = context
Expand Down
9 changes: 0 additions & 9 deletions dwi_ml/models/projects/learn2track_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,6 @@ def __init__(self, experiment_name,
# 4. Direction getter:
self.instantiate_direction_getter(self.rnn_model.output_size)

# If multiple inheritance goes well, these params should be set
# correctly
if nb_previous_dirs > 0:
assert self.forward_uses_streamlines
assert self.loss_uses_streamlines

if self.start_from_copy_prev:
self.forward_uses_streamlines = True

def set_context(self, context):
assert context in ['training', 'validation', 'tracking', 'visu',
'preparing_backward']
Expand Down
12 changes: 5 additions & 7 deletions dwi_ml/models/projects/transformer_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,6 @@ def __init__(self,
# the nb of features.
self.instantiate_direction_getter(self.d_model)

assert self.loss_uses_streamlines

@property
def d_model(self):
raise NotImplementedError
Expand Down Expand Up @@ -367,9 +365,11 @@ def forward(self, inputs: List[torch.tensor],
raise ValueError("Please set context before usage.")

# ----------- Checks
if self.forward_uses_streamlines:
# Reminder. In all cases, len(each input) == len(each streamline).
# Correct interpolation and management of points should be done before.
if input_streamlines is not None:
# If streamlines are necessary (depending on child class):
# In all cases, len(each input) == len(each streamline).
# Correct interpolation and management of points should be done
# before.
assert np.all([len(i) == len(s) for i, s in
zip(inputs, input_streamlines)])

Expand Down Expand Up @@ -618,8 +618,6 @@ def __init__(self,
self.embedding_layer_t = cls_t(self.target_features,
self.target_embedded_size)

self.forward_uses_streamlines = True

@property
def params_for_checkpoint(self):
"""
Expand Down
5 changes: 1 addition & 4 deletions dwi_ml/tracking/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,10 +494,7 @@ def _prepare_inputs_at_pos(self, last_pos):

def _call_model_forward(self, inputs, lines):
with self.grad_context:
if self.model.forward_uses_streamlines:
model_outputs = self.model(inputs, lines)
else:
model_outputs = self.model(inputs)
model_outputs = self.model(inputs, lines)
return model_outputs

def update_memory_after_removing_lines(self, can_continue: np.ndarray,
Expand Down
37 changes: 13 additions & 24 deletions dwi_ml/training/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,33 +1092,22 @@ def run_one_batch(self, data):
batch_inputs = self.batch_loader.load_batch_inputs(
streamlines_f, ids_per_subj)

# Possibly add noise to inputs here.
logger.debug('*** Computing forward propagation')
if self.model.forward_uses_streamlines:
# Now possibly add noise to streamlines (training / valid)
streamlines_f = self.batch_loader.add_noise_streamlines_forward(
streamlines_f, self.device)

# Possibly computing directions twice (during forward and loss)
# but ok, shouldn't be too heavy. Easier to deal with multiple
# projects' requirements by sending whole streamlines rather
# than only directions.
model_outputs = self.model(batch_inputs, streamlines_f)
del streamlines_f
else:
del streamlines_f
model_outputs = self.model(batch_inputs)
# todo Possibly add noise to inputs here. Not ready
# Now add noise to streamlines for the forward pass
# (batch loader will do it depending on training / valid)
streamlines_f = self.batch_loader.add_noise_streamlines_forward(
streamlines_f, self.device)
model_outputs = self.model(batch_inputs, streamlines_f)
del streamlines_f

logger.debug('*** Computing loss')
if self.model.loss_uses_streamlines:
targets = self.batch_loader.add_noise_streamlines_loss(
targets, self.device)

results = self.model.compute_loss(model_outputs, targets,
average_results=True)
else:
results = self.model.compute_loss(model_outputs,
average_results=True)
# Add noise to targets.
# (batch loader will do it depending on training / valid)
targets = self.batch_loader.add_noise_streamlines_loss(targets,
self.device)
results = self.model.compute_loss(model_outputs, targets,
average_results=True)

if self.use_gpu:
log_gpu_memory_usage(logger)
Expand Down
5 changes: 1 addition & 4 deletions dwi_ml/training/with_generation/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,7 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos):
batch_inputs = self.batch_loader.load_batch_inputs(
n_last_pos, ids_per_subj)

if self.model.forward_uses_streamlines:
model_outputs = self.model(batch_inputs, n_last_pos)
else:
model_outputs = self.model(batch_inputs)
model_outputs = self.model(batch_inputs, n_last_pos)

next_dirs = self.model.get_tracking_directions(
model_outputs, algo='det', eos_stopping_thresh=0.5)
Expand Down
18 changes: 6 additions & 12 deletions dwi_ml/unit_tests/utils/data_and_models_for_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,12 @@ def compute_loss(self, model_outputs, target_streamlines=None,
else:
return torch.zeros(n, device=self.device)

def forward(self, x: list):
def forward(self, inputs: list, streamlines):
# Not using streamlines. Pretending to use inputs.
_ = self.fake_parameter
regressed_dir = torch.as_tensor([1., 1., 1.])

return [regressed_dir for _ in x]
return [regressed_dir for _ in inputs]


class TrackingModelForTestWithPD(ModelWithPreviousDirections,
Expand Down Expand Up @@ -119,12 +120,6 @@ def __init__(self, experiment_name: str = 'test',
# For super MainModelForTracking:
dg_key=dg_key, dg_args=dg_args)

# If multiple inheritance goes well, these params should be set
# correctly
if nb_previous_dirs > 0:
assert self.forward_uses_streamlines
assert self.loss_uses_streamlines

self.instantiate_direction_getter(dg_input_size)

def compute_loss(self, model_outputs: List[torch.Tensor],
Expand All @@ -137,7 +132,8 @@ def compute_loss(self, model_outputs: List[torch.Tensor],
return self.direction_getter.compute_loss(
model_outputs, target_streamlines, average_results)

def get_tracking_directions(self, regressed_dirs, algo):
def get_tracking_directions(self, regressed_dirs, algo,
eos_stopping_thresh):
if algo == 'det':
return regressed_dirs
elif algo == 'prob':
Expand All @@ -148,9 +144,7 @@ def get_tracking_directions(self, regressed_dirs, algo):
raise ValueError("'algo' should be 'det' or 'prob'.")

def forward(self, inputs: List[torch.tensor],
target_streamlines: List[torch.tensor] = None,
hidden_reccurent_states: tuple = None,
return_state: bool = False) -> List[torch.tensor]:
target_streamlines: List[torch.tensor]):
# Previous dirs
if self.nb_previous_dirs > 0:
target_dirs = compute_directions(target_streamlines)
Expand Down

0 comments on commit 6a4aa6d

Please sign in to comment.