[Question / Not sure if it's an issue] Suggested choice of hyperparameters feat_dim (N_a) == output_dim (N_d) leads to ValueError #14
Comments
There is a CUDA bug that occurs if N_d and N_a are the same. The code internally takes the last N_a − N_d dimensions of the transformed features for self-attention, so setting them equal produces a (0, X)-shaped tensor, which fails on GPU. To get a true N_d == N_a, set N_a = 2·N_d, and that works.
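For illustration, here is a minimal NumPy sketch of the split described above; the variable names are illustrative, and only the sliced expression is taken from the thread:

```python
import numpy as np

# Minimal sketch (plain NumPy) of why the split fails when
# feature_dim (documented as N_a) equals output_dim (N_d).
batch_size, feature_dim, output_dim = 32, 4, 4

# Stand-in for the output of a feature transformer block.
transformed = np.random.randn(batch_size, feature_dim)

# The implementation splits the transformed features: the first output_dim
# columns feed the decision output, the remainder feed the attentive
# transformer (cf. transform_f4[:, self.output_dim:] in tabnet.py).
decision_out = transformed[:, :output_dim]       # shape (32, 4)
features_for_coef = transformed[:, output_dim:]  # shape (32, 0) -- empty!

print(features_for_coef.shape)  # (32, 0): a zero-width slice, which breaks on GPU
```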
And yes, I should update that docstring.
You should update your examples as well.
Interestingly, for the train_embedding example I get much better performance when setting N_a = N_d + 1 than with N_a = 2·N_d. For example: feature_dim=5, output_dim=4 gives a 90% accuracy rate, while feature_dim=8, output_dim=4 gives a 60% accuracy rate.
Is anyone currently working on making this clearer? I found it very confusing that the default values do not run. The actual meaning should probably be noted in the comments/docstrings. Also, I wonder if it makes sense to allow inputs for
Move this line, line 270 in tabnet.py: features_for_coef = transform_f4[:, self.output_dim:]
Original issue:
Both in the docstring of the TabNet class and in the original article, N_a == N_d is suggested for most datasets (these are the dimensionalities of the hidden representations and of the outputs of each decision step). But in the code (tabnet.py:129) a ValueError is raised if N_a <= N_d. I'm not sure whether this is an issue or whether my comprehension of the code is incorrect. Could you please clarify this point?
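A hedged sketch of the workaround from the comments above: since the attentive branch receives feature_dim − output_dim columns, passing feature_dim = 2·output_dim yields an effective N_a == N_d. The import path and constructor call are assumptions; only the feature_dim/output_dim parameters come from the thread.

```python
# Hypothetical usage sketch; the import path is an assumption.
from tabnet import TabNet

output_dim = 4                # N_d: width of each decision step's output
feature_dim = 2 * output_dim  # leaves output_dim columns for attention, i.e. N_a == N_d

# Other constructor arguments, if any, are omitted for brevity.
model = TabNet(feature_dim=feature_dim, output_dim=output_dim)

# feature_dim == output_dim would instead trip the ValueError at tabnet.py:129,
# because the attentive transformer would get feature_dim - output_dim == 0 columns.
```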
P.S.
I'd like to thank you for your implementation of a very interesting paper.
I'm trying to use the tabnet module for a small POC with an imbalanced dataset containing ~20k samples, mostly categorical data.