The following changes have been made:
- The semi-orthogonal loss is now computed as the Frobenius norm of `P` (where `P = torch.mm(M, M.T)`), instead of the Frobenius norm of `(P - \alpha^2 I)`. This makes it consistent with the loss reporting in Kaldi.
- The `forward()` function in the `TDNNF` class now takes `semi_ortho_step` as an argument instead of `training`. This allows the calling function to decide whether or not to take the step towards semi-orthogonality.
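A sketch of the resulting calling convention, using a stand-in layer (the class body and the 1-in-4 schedule below are illustrative assumptions, loosely following Kaldi's randomized update schedule, not the actual code):

```python
import random

class DummyTDNNF:
    """Stand-in for the real TDNNF layer, showing only the new
    forward() signature with the semi_ortho_step argument."""

    def forward(self, x, semi_ortho_step=False):
        if semi_ortho_step:
            # The real layer would nudge its factor towards
            # semi-orthogonality here.
            pass
        return x

layer = DummyTDNNF()
random.seed(0)
for step in range(8):
    # The caller, not the layer, decides when to take the step,
    # e.g. on roughly one in four updates.
    take_step = random.random() < 0.25
    out = layer.forward([1.0, 2.0], semi_ortho_step=take_step)
```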
- The initialization of the `TDNN` layer now takes a `bias` argument, which specifies whether or not to use a bias in the `Conv1D` layer. When using the TDNN within the `SemiOrthogonalConv` class for `TDNNF`, we set `bias = False`, so that the matrix factorization checks out correctly.
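A minimal sketch of what such a constructor might look like (the `TDNN` signature here is assumed, and PyTorch's built-in `nn.Conv1d` stands in for the convolution). Disabling the bias matters because with bias terms the two factorized layers no longer compose into a single linear map `M = A B`:

```python
import torch
import torch.nn as nn

class TDNN(nn.Module):
    """Hypothetical TDNN wrapper that forwards a `bias` flag
    to its internal convolution."""

    def __init__(self, in_dim, out_dim, context_size=3, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(in_dim, out_dim,
                              kernel_size=context_size, bias=bias)

    def forward(self, x):
        # x: (batch, in_dim, time)
        return self.conv(x)

# Factorized use: no bias, so the layer is purely linear.
layer = TDNN(16, 8, bias=False)
```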