-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Convert outputs to dict #757
Conversation
batchflow/models/torch/base.py
Outdated
Each string defines a tensor to get and should be one of: | ||
- pre-defined tensors, which are `predictions`, `loss`, and `predictions_{i}` for multi-output models. | ||
- pre-defined operations, which are `softplus`, `sigmoid`, `sigmoid_uint8`, `sigmoid_int16`, | ||
`proba`, `labels`. Work only with len(predictions) == 1 | ||
- layer id, which describes how to access the layer through a series of `getattr` and `getitem` calls. | ||
Allows to get intermediate activations of a neural network. | ||
Each callable defines a function that should be applied to predictions. | ||
If outputs are dict, then keys are strings and they are considered as output_names. The values should be | ||
either callables, pre-defined tensors, pre-defined operations or layer id. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A tensor to get and should be one of:
- a string, which can be ...
- - ...
- - ...
- a callable, ...
- a sequence, where each item is one of the previous types. Result of this method is guaranteed to have the same order of elements;
- a dict, where each value is one of the previous types. Result of this method is a dictionary with the same keys and requested tensors as values.
batchflow/models/torch/base.py
Outdated
if 'predictions' in outputs_dict.values(): | ||
# in case there are multiple output_names with the same operation == `predictions`. Same for `loss` | ||
predictions_names_list = [output_name for output_name, operation in outputs_dict.items() \ | ||
if operation == 'predictions'] | ||
output_container.update({output_name: predictions for output_name in predictions_names_list}) | ||
if 'loss' in outputs_dict.values(): | ||
losses_names_list = [output_name for output_name, operation in outputs_dict.items() \ | ||
if operation == 'loss'] | ||
output_container.update({output_name: loss for output_name in losses_names_list}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for output_name, requested in output_dict.items():
if requested == 'prediction':
output_container[output_name] = predictions
elif requested == 'loss':
output_container[output_name] = loss
IMO, looks much simpler. Can be further reduced to 4 lines.
batchflow/models/torch/base.py
Outdated
elif isinstance(outputs, list): | ||
result = list(result.values()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what about tuple
/ set
?
batchflow/models/torch/base.py
Outdated
def compute_outputs(self, predictions): | ||
""" Produce additional outputs, defined in the config, from `predictions`. | ||
Also adds a number of aliases to predicted tensors. | ||
def compute_outputs(self, predictions, operations): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably, can move code for adding predictions
/ loss
here
batchflow/models/torch/base.py
Outdated
elif isinstance(operation, LayerHook): | ||
operation.close() | ||
result = operation.activation | ||
elif isinstance(operation, str) and re.match(r"predictions_[0-9]+", operation): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
either don't use regexp here or compile it once at the module / class level
batchflow/models/torch/base.py
Outdated
else: | ||
if isinstance(predictions, (tuple, list)) and not len(predictions) == 1: | ||
raise ValueError('Default operations can`t be applicable to multi output predictions.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can't be applied
, maybe?
batchflow/models/torch/base.py
Outdated
""" Add the hooks to all outputs that look like a layer id. """ | ||
result = [] | ||
for output_name in outputs: | ||
""" Add the hooks to all outputs that look like a layer id. Also convert outputs to dict""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Formatting
batchflow/models/torch/base.py
Outdated
elif callable(output): | ||
processed_outputs[output.__name__] = output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is that needed? callables are hashable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other than the naming, this PR looks great. If possible, make a better distinction between internal/external variables, and it is ready to merge.
Good job:)
batchflow/models/torch/base.py
Outdated
def compute_outputs(self, predictions): | ||
""" Produce additional outputs, defined in the config, from `predictions`. | ||
Also adds a number of aliases to predicted tensors. | ||
def compute_outputs(self, predictions, operations, targets=None, loss=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would say that renaming operations
to requested_outputs
would make things much easier to understand. Or maybe to outputs_dict
, to keep in line with the rest of the code
batchflow/models/torch/base.py
Outdated
elif targets is not None: | ||
targets = self.transfer_to_device(targets) | ||
loss = self.loss(predictions, targets) | ||
result[name] = loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think about keeping this part in the _predict
method?
batchflow/models/torch/base.py
Outdated
else: | ||
raise ValueError(f'Unknown type of operation `{operation}`!') | ||
name = operation | ||
return result, name | ||
return result | ||
|
||
|
||
def prepare_outputs(self, outputs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prepare_outputs
sounds very much alike to compute_outputs
, while the intent here is to `prepare user-passed argument "outputs" to a form that we internally use". Do you see any better names?
This PR makes it possible to work with outputs as dicts. The main points are: