TTTensor class #23089
Conversation
Thanks for contributing to Ivy! 😊👏
Hey @mobley-trent! Thanks for working on this, it's a great PR. Would you please add some tests from TensorLy for the TTTensor class in ivy_tests/test_ivy/test_misc, like we have added for the other factorized classes?
I have no idea what caused the changes to the
Hi,
Hey @mobley-trent, thanks for making the changes. Just left a few minor comments - let's also make sure that all the relevant tests are passing on the CI. Happy to review again once you have made the changes. Thanks :)
```python
rank = draw(helpers.ints(min_value=1, max_value=len(shape)))
dtype = draw(
    helpers.get_dtypes("float", full=False).filter(
        lambda x: x not in ["bfloat16", "float16"]
    )
)
```
Are these data-types causing the tests to fail?
I think there is an internal error related to these dtypes. I remember Ved mentioning this in a sync sometime back.
I'm not aware of the internal error related to these dtypes, but we shouldn't be adding this condition in the tests. It will automatically be fixed once the internal error is fixed. We shouldn't worry about this if this is the only reason why the tests are failing :)
```python
if full:
    reconstructed_tensor = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw)
    reconstructed_tensor_gt = helpers.flatten_and_to_np(
        ret=ret_from_gt_np, backend=test_flags.ground_truth_backend
    )
    for x, x_gt in zip(reconstructed_tensor, reconstructed_tensor_gt):
        assert np.prod(shape) == np.prod(x.shape)
        assert np.prod(shape) == np.prod(x_gt.shape)
else:
    weights = helpers.flatten_and_to_np(ret=ret_np[0], backend=backend_fw)
    factors = helpers.flatten_and_to_np(ret=ret_np[1], backend=backend_fw)
    weights_gt = helpers.flatten_and_to_np(
        ret=ret_from_gt_np[0], backend=test_flags.ground_truth_backend
    )
    factors_gt = helpers.flatten_and_to_np(
        ret=ret_from_gt_np[1], backend=test_flags.ground_truth_backend
    )

    for w, w_gt in zip(weights, weights_gt):
        assert w.shape[-1] == rank
        assert w_gt.shape[-1] == rank

    for f, f_gt in zip(factors, factors_gt):
        assert np.prod(f.shape) == np.prod(f_gt.shape)
```
This doesn't seem to be right. TTTensor only contains factor matrices and doesn't contain weights :)
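To illustrate the point above: in the TT (tensor-train) format, a tensor is stored purely as a list of order-3 factor cores, with no separate weights vector (unlike the CP/Kruskal format). Below is a minimal NumPy sketch, not the Ivy or TensorLy implementation; the function name `tt_reconstruct` and the chosen shapes/ranks are hypothetical, for illustration only.

```python
import numpy as np

def tt_reconstruct(factors):
    """Contract a list of TT cores G_k, each of shape (r_{k-1}, n_k, r_k)
    with boundary ranks r_0 = r_d = 1, back into the full tensor."""
    full = factors[0]  # shape (1, n_1, r_1)
    for core in factors[1:]:
        # Merge the trailing rank index of `full` with the leading
        # rank index of the next core.
        full = np.tensordot(full, core, axes=([-1], [0]))
    # Drop the boundary rank axes r_0 = r_d = 1.
    return full.reshape(full.shape[1:-1])

# Random TT cores for a (4, 5, 6) tensor with ranks (1, 2, 3, 1):
# note there is nothing weights-like here, only the cores themselves.
rng = np.random.default_rng(0)
shape, ranks = (4, 5, 6), (1, 2, 3, 1)
factors = [
    rng.standard_normal((ranks[k], shape[k], ranks[k + 1]))
    for k in range(len(shape))
]
full = tt_reconstruct(factors)  # shape (4, 5, 6)
```

So a shape-based test for TT would assert on the core shapes rather than on a weights array.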
```python
# TODO: This test fails even for the native implementation
# rank = ivy.TTTensor.validate_tt_rank(tensor_shape, coef, rounding="floor")
# n_param = ivy.TTTensor._tt_n_param(tensor_shape, rank)
# np.testing.assert_(n_param >= n_param_tensor * coef)
```
I'm not sure I follow this. If our implementation matches TensorLy exactly, this should ideally pass as well. What error are you getting when you run this?
I'm getting an assertion error. This test fails for TensorLy as well, e.g.:

```python
import ivy
import numpy as np
import tensorly as tl

coef = 0.2
tensor_shape = tuple(np.random.randint(5, 10, size=4))
n_param_tensor = np.prod(tensor_shape)

# ivy
rank = ivy.TTTensor.validate_tt_rank(tensor_shape, coef, rounding="floor")
n_param = ivy.TTTensor._tt_n_param(tensor_shape, rank)
np.testing.assert_(n_param >= n_param_tensor * coef)

# tensorly
rank = tl.tt_tensor.validate_tt_rank(tensor_shape, coef, rounding="floor")
n_param = tl.tt_tensor._tt_n_param(tensor_shape, rank)
np.testing.assert_(n_param >= n_param_tensor * coef)
```
Only ceil mode is passing.
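That ceil passes while floor fails is expected behaviour of the rounding itself, not necessarily a bug in either implementation. Here is a minimal self-contained sketch (plain Python, not the ivy/tensorly code; `tt_n_param` is a hypothetical reimplementation of the standard uniform-rank TT parameter count) showing why rounding the real-valued rank down can undershoot the target parameter budget:

```python
import numpy as np

def tt_n_param(shape, rank):
    # Parameter count of a TT tensor with uniform internal rank `rank`
    # and boundary ranks r_0 = r_d = 1: sum_k r_{k-1} * n_k * r_k.
    ranks = [1] + [rank] * (len(shape) - 1) + [1]
    return sum(ranks[k] * shape[k] * ranks[k + 1] for k in range(len(shape)))

shape = (7, 8, 9, 6)
coef = 0.2
target = coef * np.prod(shape)  # 0.2 * 3024 = 604.8

# Find where the (integer) parameter count crosses the target; this is
# effectively what floor vs ceil rounding of the real-valued rank does.
r = 1
while tt_n_param(shape, r + 1) <= target:
    r += 1
floor_rank, ceil_rank = r, r + 1

# floor: 17*25 + 13*5 = 490  < 604.8  -> the `n_param >= target` bound fails
# ceil:  17*36 + 13*6 = 690 >= 604.8  -> the bound holds by construction
```

So an assertion like `n_param >= n_param_tensor * coef` can only be guaranteed with `rounding="ceil"`, which matches the behaviour observed above.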
```python
tensor_shape = tuple(ivy.random.randint(5, 10, shape=(4,)))
n_param_tensor = ivy.prod(tensor_shape)

# TODO: This test fails even for the native implementation
```
Let's add the tensorly issue link here so that it's easy to keep a track :)
PR Compliance Checks
Thank you for your Pull Request! We have run several checks on this pull request in order to make sure it's suitable for merging into this project. The results are listed in the following section.
Conventional Commit PR Title
In order to be considered for merging, the pull request title must match the specification in conventional commits. You can edit the title in order for this check to pass.
Most often, our PR titles are something like one of these:
- docs: correct typo in README
- feat: implement dark mode
- fix: correct remove button behavior
Linting Errors
- Found type "null", must be one of "feat","fix","docs","style","refactor","perf","test","build","ci","chore","revert"
- No subject found
Closes #22188