Weights shape not validated against kernel, channels #2506

Open
BioTurboNick opened this issue Oct 25, 2024 · 4 comments

@BioTurboNick
Contributor

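using Flux

# 3-D weight array, although a 2-D conv with a (3, 3) kernel and 1 => 1 channels is requested
weights = Flux.kaiming_normal()(3, 3, 1)
Conv((3, 3), 1 => 1; pad = (1, 1), init = (_...) -> weights)
# Conv((3,), 3 => 1, pad=1)  # 10 parameters

# 4-D weight array with the expected (3, 3, 1, 1) shape
weights = Flux.kaiming_normal()(3, 3, 1, 1)
Conv((3, 3), 1 => 1; pad = (1, 1), init = (_...) -> weights)
# Conv((3, 3), 1 => 1, pad=1)  # 10 parameters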

I wanted to specify the weight init exactly for testing, but got this odd result: the 3-D array silently produces a 1-D Conv((3,), 3 => 1) layer instead of an error. I think there should be validation that the weight shape matches the requested kernel size and channels, with an error if there is a mismatch.
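The odd display appears to follow from the layer reading its kernel size and channels back off the weight array: every dimension but the last two is treated as the kernel, the second-to-last as input channels and the last as output channels. A minimal sketch of that reading in plain Julia, assuming this (kernel..., in, out) layout and groups = 1; describe_conv_weight is a hypothetical helper, not a Flux function:

```julia
# Hypothetical helper (not part of Flux): reads kernel size and channels off a
# weight array, assuming the (kernel..., in, out) layout and groups = 1.
function describe_conv_weight(w::AbstractArray)
    N = ndims(w)
    kernel = size(w)[1:N-2]   # all but the last two dims are spatial
    ch_in = size(w, N - 1)    # second-to-last dim: input channels
    ch_out = size(w, N)       # last dim: output channels
    return kernel, ch_in => ch_out
end

describe_conv_weight(zeros(3, 3, 1))     # ((3,), 3 => 1): read as a 1-D conv
describe_conv_weight(zeros(3, 3, 1, 1))  # ((3, 3), 1 => 1)
```

With the 3-D array, the requested (3, 3) and 1 => 1 play no part in this reading, which matches the Conv((3,), 3 => 1) shown above.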

@CarloLucibello
Member

Yes, those sizes should definitely be validated.

@mcabbott
Member

You might be looking for Conv(weights; pad=(1,1))? That is, there's a method which accepts weights::Array for exactly this purpose. It does not take (3, 3), 1 => 1 since, as you note, these are implied by the array size.
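A minimal usage sketch of the array-accepting method, assuming Flux is installed and the weights use the (kernel..., in channels, out channels) layout; the bias is left at its default:

```julia
using Flux

# 4-D weight array: (kernel height, kernel width, in channels, out channels)
weights = Flux.kaiming_normal()(3, 3, 1, 1)

# Kernel size and channels come from size(weights); only keywords are passed
layer = Conv(weights; pad = (1, 1))

x = rand(Float32, 5, 5, 1, 1)   # (width, height, channels, batch)
y = layer(x)                    # pad = 1 with a 3x3 kernel keeps the 5x5 spatial size
```

Since no (3, 3), 1 => 1 is passed here, there is nothing for the array to disagree with; a 3-D array would simply give a 1-D layer.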

The methods which accept an init function certainly assume size(init(s...)) == s. Maybe they can all be made to check somehow, but it does seem a somewhat strange path.
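Until such a check exists, the assumption size(init(s...)) == s can be enforced from the caller's side by wrapping the init function. A minimal sketch in plain Julia; checked_init is a hypothetical name, not part of Flux:

```julia
# Hypothetical wrapper (not part of Flux): returns an init function that
# forwards the requested size to `init` and checks the array that comes back.
function checked_init(init)
    return function (dims::Integer...)
        W = init(dims...)
        size(W) == dims || throw(DimensionMismatch(
            "init returned an array of size $(size(W)), but $dims was requested"))
        return W
    end
end

fixed = zeros(Float32, 3, 3, 1)           # one dimension short for a 2-D conv
wrapped = checked_init((_...) -> fixed)
# wrapped(3, 3, 1, 1) would throw a DimensionMismatch instead of returning `fixed`
```

Passing init = checked_init((_...) -> weights) in the first Conv call of the report should then fail loudly, assuming Conv calls init(3, 3, 1, 1) for a (3, 3), 1 => 1 layer.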

The initialisation of the weight matrix is W = init(out, in), calling the function given to keyword init.

@BioTurboNick
Contributor Author

Ah, that's fair, thanks. Didn't think to look for a different method signature for this.

@mcabbott
Member

If we do want a check, it could look like this. Maybe it's not as messy as I first pictured:

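function _sizecheck(f, sz::Integer...)
  W = f(sz...)  # call the init function with the requested size
  # error with a readable message if the returned array has a different shape
  size(W) == sz || throw(DimensionMismatch("init returned an array of size $(size(W)), expected $sz"))
  W
end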

function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity;
               init = glorot_uniform, bias = true)
  # Dense(init(out, in), bias, σ)  # current code
  Dense(_sizecheck(init, out, in), bias, σ)  # with new check
end
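The same idea carries over to Conv, where the report hit the problem. A minimal sketch, assuming the filter is requested as init(kernel..., in ÷ groups, out); checked_filter is a hypothetical stand-in for the filter-building step, not Flux's actual code, and _sizecheck is repeated (throwing DimensionMismatch) so the block runs on its own:

```julia
# Version of _sizecheck, throwing a DimensionMismatch with a readable message.
function _sizecheck(f, sz::Integer...)
    W = f(sz...)
    size(W) == sz || throw(DimensionMismatch("init returned size $(size(W)), expected $sz"))
    return W
end

# Hypothetical stand-in for the filter-building step of Conv (not Flux's code):
# request a (kernel..., in ÷ groups, out) array from init and check its size.
function checked_filter(k::NTuple{N,Integer}, (cin, cout)::Pair{<:Integer,<:Integer};
                        init, groups::Integer = 1) where {N}
    return _sizecheck(init, k..., cin ÷ groups, cout)
end

w = checked_filter((3, 3), 1 => 1; init = (d...) -> zeros(Float32, d...))
# checked_filter((3, 3), 1 => 1; init = (_...) -> zeros(Float32, 3, 3, 1))
# would throw instead of yielding a 1-D filter
```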

(When you pass an array bias = [1,2,3.] to layer constructors, I think it is always checked for size.)
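For comparison, a size check on an array bias amounts to something like the following minimal sketch. checked_bias is hypothetical and only assumes that true means "create a zero bias" and false means "no bias"; it is not how Flux implements this:

```julia
# Hypothetical helper (not Flux's implementation): `false` means no bias,
# `true` means a zero bias of length nout, and an array must have length nout.
function checked_bias(bias, nout::Integer)
    bias === false && return false
    bias === true && return zeros(Float32, nout)
    length(bias) == nout || throw(DimensionMismatch(
        "bias has length $(length(bias)), expected $nout"))
    return bias
end

b = checked_bias([1, 2, 3.0], 3)   # accepted: length matches 3 output channels
# checked_bias([1, 2, 3.0], 2) would throw a DimensionMismatch
```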
