weight initialization issue #4

Open
bryent111 opened this issue Jul 27, 2020 · 4 comments
Labels: enhancement (New feature or request), question (Further information is requested)

@bryent111

In cplxmodule.nn.modules, conv.py, line 43, the reset_parameters function initializes the weight with init.cplx_kaiming_uniform_ rather than init.cplx_trabelsi_independent_, while init.cplx_uniform_independent_ is used for the bias.
Furthermore, in cplxmodule's cplx.py, line 628, the convnd function, the bias is added to the output rather than the weight.
Is that right? And how can I initialize the conv1d weight with init.cplx_uniform_independent_, as in the paper?

ivannz (Owner) commented Jul 27, 2020

I am not sure what you mean in your comment about the bias on line 628 of cplx.py. The following test seems to show that it is being added correctly:

import torch
import torch.nn.functional as F
from cplxmodule import cplx

def randn(shape):
    return torch.randn(shape, dtype=torch.double)

sx = 32, 12, 31, 47
sw = 7, 12, 7, 11

x = cplx.Cplx(randn(sx), randn(sx))
w = cplx.Cplx(randn(sw), randn(sw))
b = cplx.Cplx(randn(sw[0]), randn(sw[0]))

# do the 2d convolution manually: (a + ib)(c + id) = (ac - bd) + i(ad + bc)
re = F.conv2d(x.real, w.real, bias=b.real) \
    - F.conv2d(x.imag, w.imag, bias=None)
im = F.conv2d(x.real, w.imag, bias=b.imag) \
    + F.conv2d(x.imag, w.real, bias=None)

# use the function from cplx
cc = cplx.conv2d(x, w, bias=b)

assert torch.allclose(cc.real, re) and torch.allclose(cc.imag, im)
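
For the conv1d case raised in the question, the same check can be repeated. A minimal sketch, continuing from the snippet above (same imports and randn helper) and assuming cplx.conv1d mirrors cplx.conv2d:

# shapes for the 1d case: (batch, channels, length)
sx1 = 32, 12, 47
sw1 = 7, 12, 11

x1 = cplx.Cplx(randn(sx1), randn(sx1))
w1 = cplx.Cplx(randn(sw1), randn(sw1))
b1 = cplx.Cplx(randn(sw1[0]), randn(sw1[0]))

# manual complex 1d convolution, as in the 2d test above
re1 = F.conv1d(x1.real, w1.real, bias=b1.real) \
    - F.conv1d(x1.imag, w1.imag, bias=None)
im1 = F.conv1d(x1.real, w1.imag, bias=b1.imag) \
    + F.conv1d(x1.imag, w1.real, bias=None)

cc1 = cplx.conv1d(x1, w1, bias=b1)
assert torch.allclose(cc1.real, re1) and torch.allclose(cc1.imag, im1)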

ivannz (Owner) commented Jul 27, 2020

Yes, by default I decided to go with init.cplx_kaiming_uniform_ for complex-valued weights. In order to initialize your model with what was proposed by Trabelsi et al. (2018), you can use the .apply method of torch.nn.Module. For example,

import torch
import numpy as np

from cplxmodule import Cplx
from cplxmodule.nn import init, CplxLinear, CplxConv2d

def cplx_trabelsi_independent_(mod):
    if not hasattr(mod, 'weight'):
        return

    # Trabelsi orthogonal weight initializer
    if isinstance(mod.weight, Cplx):
        init.cplx_trabelsi_independent_(mod.weight)

# a model with some structure
module = torch.nn.ModuleDict({
    'linear': CplxLinear(11, 17),
    'conv': CplxConv2d(13, 19, 5),
}).double()

# standard torch functionality `module.apply`
module.apply(cplx_trabelsi_independent_)

# according to Trabelsi et al. (2018) the reshaped weight must be an almost unitary matrix
w = module['conv'].weight
m = w.reshape(w.shape[:2].numel(), w.shape[2:].numel()).detach().numpy()
mHm = m.conjugate().T @ m

assert np.allclose(mHm, np.diag(mHm.diagonal()))

The initializer init.cplx_uniform_independent_ in fact simply resets the real and imaginary components separately to a uniform random sample from [0,1]. I admit that it has a somewhat confusing name and lacks clarifying documentation. I shall fix this in the next update.
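
For reference, a minimal sketch of that behavior, assuming cplx_uniform_independent_ accepts a bare Cplx tensor just like cplx_trabelsi_independent_ above:

import torch
from cplxmodule import Cplx
from cplxmodule.nn import init

# re-initialize a complex tensor in-place
w = Cplx(torch.empty(3, 5), torch.empty(3, 5))
init.cplx_uniform_independent_(w)

# real and imaginary parts are each uniform samples from [0, 1]
assert (0 <= w.real).all() and (w.real <= 1).all()
assert (0 <= w.imag).all() and (w.imag <= 1).all()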

bryent111 (Author) commented

Thanks a lot for replying so quickly; I like your project very much.
My question was just why the weight is not initialized with init.cplx_trabelsi_independent_, and how to do so.
Now I know that, by default, the weight is initialized with init.cplx_kaiming_uniform_ and the bias with init.cplx_uniform_independent_, and that I can use the .apply method to initialize my model's weights.
And one more question: why not initialize the weight with cplx_trabelsi_independent_ (or cplx_trabelsi_standard_) by default, as in the paper?

ivannz (Owner) commented Aug 5, 2020

I did not make cplx_trabelsi_independent_ or cplx_trabelsi_standard_ the default for legacy reasons, dating back to when I was conducting experiments for the paper and the project this package was originally written for.

You rightfully point out that, according to the original research by Trabelsi et al. (2018), there is indeed no reason NOT to use the initialization proposed therein.

Unfortunately, I am currently fully engaged with writing and experiments for my current research, which is unrelated to complex-valued networks (this is why I took so long to reply). Could you please conduct some experiments on a small toy dataset and post the results here, showing the benefits of cplx_trabelsi_independent_ over cplx_trabelsi_standard_ and over the current default?
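
Such a comparison might look like the following hypothetical scaffold (not from this thread): it fits a small complex-valued regression under each initializer and prints the final training loss. The use of cplx.linear as a complex analogue of torch.nn.functional.linear and the exact initializer signatures are assumptions here; the .apply pattern is the one from the earlier comment.

import torch
from cplxmodule import Cplx, cplx
from cplxmodule.nn import init, CplxLinear

# toy data: a random complex linear map observed with a little noise
torch.manual_seed(0)
x = Cplx(torch.randn(256, 16), torch.randn(256, 16))
w_true = Cplx(torch.randn(4, 16), torch.randn(4, 16))
noise = Cplx(0.01 * torch.randn(256, 4), 0.01 * torch.randn(256, 4))
y = cplx.linear(x, w_true) + noise

def make_applier(init_fn):
    # same .apply pattern as in the earlier comment
    def _apply(mod):
        weight = getattr(mod, 'weight', None)
        if isinstance(weight, Cplx):
            init_fn(weight)
    return _apply

for name, init_fn in [
    ('cplx_kaiming_uniform_ (default)', init.cplx_kaiming_uniform_),
    ('cplx_trabelsi_standard_', init.cplx_trabelsi_standard_),
    ('cplx_trabelsi_independent_', init.cplx_trabelsi_independent_),
]:
    torch.manual_seed(1)  # identical conditions for every initializer
    model = CplxLinear(16, 4)
    model.apply(make_applier(init_fn))

    optim = torch.optim.SGD(model.parameters(), lr=0.05)
    for _ in range(500):
        optim.zero_grad()
        err = model(x) - y
        loss = (err.real ** 2 + err.imag ** 2).mean()
        loss.backward()
        optim.step()

    print(f'{name}: final loss {float(loss):.5f}')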

ivannz reopened this Aug 5, 2020
ivannz self-assigned this Aug 14, 2020
ivannz added the enhancement (New feature or request) and question (Further information is requested) labels Aug 14, 2020