weight initialization issue #4
Comments
I am not sure what you mean in your comment about the bias on line 628 of cplx.py. The bias is added to the output of the convolution, not to the weight, as this check shows:

```python
import torch
import torch.nn.functional as F

from cplxmodule import cplx

def randn(shape):
    return torch.randn(shape, dtype=torch.double)

sx = 32, 12, 31, 47
sw = 7, 12, 7, 11
x = cplx.Cplx(randn(sx), randn(sx))
w = cplx.Cplx(randn(sw), randn(sw))
b = cplx.Cplx(randn(sw[0]), randn(sw[0]))

# do the 2d convolution manually
re = F.conv2d(x.real, w.real, bias=b.real) \
   - F.conv2d(x.imag, w.imag, bias=None)
im = F.conv2d(x.real, w.imag, bias=b.imag) \
   + F.conv2d(x.imag, w.real, bias=None)

# use the function from cplx
cc = cplx.conv2d(x, w, bias=b)

assert torch.allclose(cc.real, re) and torch.allclose(cc.imag, im)
```
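(Editor's note: the same identity should hold in one dimension. Below is a minimal sketch of the analogous conv1d check, assuming `cplx.conv1d` is exposed with the same signature as `cplx.conv2d`; the sizes are made up.)

```python
import torch
import torch.nn.functional as F

from cplxmodule import cplx

def randn(shape):
    return torch.randn(shape, dtype=torch.double)

# batch, in_channels, length / out_channels, in_channels, kernel
sx = 32, 12, 47
sw = 7, 12, 11
x = cplx.Cplx(randn(sx), randn(sx))
w = cplx.Cplx(randn(sw), randn(sw))
b = cplx.Cplx(randn(sw[0]), randn(sw[0]))

# manual 1d complex convolution, with the bias folded into the output
re = F.conv1d(x.real, w.real, bias=b.real) \
   - F.conv1d(x.imag, w.imag, bias=None)
im = F.conv1d(x.real, w.imag, bias=b.imag) \
   + F.conv1d(x.imag, w.real, bias=None)

cc = cplx.conv1d(x, w, bias=b)
assert torch.allclose(cc.real, re) and torch.allclose(cc.imag, im)
```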
Yes, by default I decided to go with `init.cplx_kaiming_uniform_`. However, nothing prevents you from applying the initialization of Trabelsi et al. (2018) yourself, via the standard `module.apply` mechanism:

```python
import torch
import numpy as np

from cplxmodule import Cplx
from cplxmodule.nn import init, CplxLinear, CplxConv2d

def cplx_trabelsi_independent_(mod):
    if not hasattr(mod, 'weight'):
        return

    # Trabelsi orthogonal weight initializer
    if isinstance(mod.weight, Cplx):
        init.cplx_trabelsi_independent_(mod.weight)

# a model with some structure
module = torch.nn.ModuleDict({
    'linear': CplxLinear(11, 17),
    'conv': CplxConv2d(13, 19, 5),
}).double()

# standard torch functionality `module.apply`
module.apply(cplx_trabelsi_independent_)

# according to Trabelsi et al. (2018) the reshaped weight must be an almost unitary matrix
w = module['conv'].weight
m = w.reshape(w.shape[:2].numel(), w.shape[2:].numel()).detach().numpy()
mHm = m.conjugate().T @ m
assert np.allclose(mHm, np.diag(mHm.diagonal()))
```

The initializer …
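(Side note, not from the thread: since the helper above just calls the initializer on `mod.weight`, for a single layer you can skip `module.apply` and call `init.cplx_trabelsi_independent_(module['linear'].weight)` directly.)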
Thanks a lot for replying so quickly, I like your project very much.
I did not use …

You rightfully point out that, according to the original research by Trabelsi et al. (2018), there is indeed no reason NOT to use the initialization proposed therein. Unfortunately, I am currently fully engaged with writing and experimenting for my current research, which is unrelated to complex-valued networks (this is why I took so long to reply). Could you please conduct some experiments on a small toy dataset and post the results here, showing the benefits of the Trabelsi initialization?
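(Editor's note: not a definitive benchmark, but roughly the kind of toy comparison being asked for might look like the sketch below: two identical `CplxLinear` stacks, one left with the default kaiming init from `reset_parameters` and one re-initialized with the Trabelsi scheme, trained on random regression data. The modulus nonlinearity, the data, the layer sizes, and the assumption that the complex parameters are picked up by `model.parameters()` are all hypothetical additions; treat this as a sketch, not as results.)

```python
import torch
import torch.nn.functional as F

from cplxmodule import Cplx
from cplxmodule.nn import init, CplxLinear

torch.manual_seed(0)

# hypothetical toy regression data: arbitrary positive targets
n_samples, n_features = 512, 16
x = Cplx(torch.randn(n_samples, n_features).double(),
         torch.randn(n_samples, n_features).double())
y = torch.rand(n_samples, 1).double()

def make_model():
    return torch.nn.ModuleDict({
        'lin1': CplxLinear(n_features, 32),
        'lin2': CplxLinear(32, 1),
    }).double()

def modulus(z, eps=1e-12):
    # |z| built only from the .real/.imag accessors used in the thread
    return torch.sqrt(z.real ** 2 + z.imag ** 2 + eps)

def forward(model, z):
    z = model['lin1'](z)
    scale = torch.tanh(modulus(z)) / modulus(z)  # modulus-squashing nonlinearity
    z = Cplx(z.real * scale, z.imag * scale)
    return modulus(model['lin2'](z))

def cplx_trabelsi_independent_(mod):
    if isinstance(getattr(mod, 'weight', None), Cplx):
        init.cplx_trabelsi_independent_(mod.weight)

def train(model, n_steps=200):
    # assumes the real/imag tensors of Cplx weights appear in .parameters()
    optim = torch.optim.Adam(model.parameters(), lr=1e-2)
    for _ in range(n_steps):
        optim.zero_grad()
        loss = F.mse_loss(forward(model, x), y)
        loss.backward()
        optim.step()
    return float(loss)

default_init = make_model()                      # kaiming-uniform via reset_parameters
trabelsi_init = make_model()
trabelsi_init.apply(cplx_trabelsi_independent_)  # overwrite with the Trabelsi scheme

print('default (kaiming):', train(default_init))
print('trabelsi         :', train(trabelsi_init))
```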
In `cplxmodule/nn/modules/conv.py`, line 43, the `reset_parameters` function initializes the weight with `init.cplx_kaiming_uniform_` rather than `cplx_trabelsi_independent_`, while `init.cplx_uniform_independent_` is used for the bias.
Furthermore, in `cplxmodule/cplx.py`, line 628, in the `convnd` function, the bias is added to the output rather than to the weight.
Is that right? And how does one initialize the conv1d weight with `init.cplx_uniform_independent_`, as in the paper?
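(Editor's note: for reference, here is what the `apply`-based answer above boils down to for a single conv1d layer, assuming `CplxConv1d` is exported from `cplxmodule.nn` alongside `CplxConv2d`; the layer sizes are made up.)

```python
import torch

from cplxmodule.nn import init, CplxConv1d

conv = CplxConv1d(13, 19, 5).double()

# replace the default kaiming-uniform weights set by reset_parameters
# with the initialization of Trabelsi et al. (2018)
init.cplx_trabelsi_independent_(conv.weight)
```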