Fix handling input_size with multi-input #166

Open
wants to merge 1 commit into
base: master

ahmedhshahin commented Mar 11, 2021

With multiple inputs, the current implementation computes the number of input elements as the product of all dimensions across all inputs. I believe this is inaccurate.
For example, given input1 with dimensions [1,5,5] and input2 with dimensions [1,10,10]:
Current implementation: number of elements = 1 * 5 * 5 * 1 * 10 * 10 = 2500 elements
Expected: number of elements = (1 * 5 * 5) + (1 * 10 * 10) = 125 elements
The two are separate inputs, so their element counts should be added, not multiplied.
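
To make the arithmetic above concrete, here is a minimal sketch comparing the two ways of counting. It uses `math.prod` from the standard library in place of `np.prod` so that it runs without NumPy, and the function names `flattened_product` and `per_input_sum` are hypothetical, chosen only for this illustration.

```python
from math import prod


def flattened_product(input_sizes):
    # Concatenate every shape into one tuple, then multiply all dimensions.
    # This mirrors the behaviour described above as the current one.
    return prod(sum(input_sizes, ()))


def per_input_sum(input_sizes):
    # Multiply the dimensions of each input separately, then add the results.
    return sum(prod(shape) for shape in input_sizes)


input_sizes = [(1, 5, 5), (1, 10, 10)]
old_count = flattened_product(input_sizes)
new_count = per_input_sum(input_sizes)
print(old_count, new_count)  # 2500 125
```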

scratchmex commented Sep 4, 2021

When the inputs have different lengths, for example [(1, 28, 28), (1,)], the current implementation throws TypeError: can't multiply sequence by non-int of type 'tuple' and emits this warning:

numpy\core\fromnumeric.py:87: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray

I think this PR solves this.
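
As a rough check of that claim, the sketch below applies a per-input product-and-sum to the ragged example from this comment. It does not try to reproduce the TypeError or the NumPy warning. It only shows that shapes of different lengths are unproblematic when each shape is reduced on its own. `math.prod` stands in for `np.prod`, and `count_elements` is a hypothetical helper name.

```python
from math import prod


def count_elements(input_sizes):
    # Each shape is reduced independently, so shapes of different
    # lengths never have to be packed into a single array.
    return sum(prod(shape) for shape in input_sizes)


ragged_sizes = [(1, 28, 28), (1,)]
total = count_elements(ragged_sizes)
print(total)  # 784 + 1 = 785
```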

@sksq96

@@ -98,7 +98,9 @@ def hook(module, input, output):
summary_str += line_new + "\n"

# assume 4 bytes/number (float on cuda).
total_input_size = abs(np.prod(sum(input_size, ()))
# to handle the case of multi-input: prod(input1) + prod(input2) + ...
n_input_size = np.array([np.prod(i) for i in input_size]).sum() if isinstance(input_size, list) else np.prod(input_size)

We don't need to check whether input_size is a list, because it is already normalized here:

# multiple inputs to the network
if isinstance(input_size, tuple):
    input_size = [input_size]

ahmedhshahin commented Sep 4, 2021

I wrote this a long time ago, but I think I added the check so that the code also works for the single-input case. For a single input, it returns the product of the input dims; for multi-input (a list of inputs), it returns prod(input1) + prod(input2) + ...
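
One way to reconcile the two views in this thread, sketched under the assumption that the tuple-to-list normalization quoted by the reviewer runs before the count: once a single shape is wrapped in a list, the same sum-of-products expression covers both cases without an isinstance check at the counting step. `count_input_elements` is a hypothetical name, and `math.prod` is used in place of `np.prod` to keep the example dependency-free.

```python
from math import prod


def count_input_elements(input_size):
    # Normalization step quoted in the review: wrap a single shape in a list.
    if isinstance(input_size, tuple):
        input_size = [input_size]
    # After normalization, one expression handles single and multiple inputs.
    return sum(prod(shape) for shape in input_size)


single = count_input_elements((1, 28, 28))
multi = count_input_elements([(1, 5, 5), (1, 10, 10)])
print(single, multi)  # 784 125
```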
