-
Notifications
You must be signed in to change notification settings - Fork 394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add PyTorch 2.4.0 to CI #1063
Add PyTorch 2.4.0 to CI #1063
Conversation
Also: - Remove 2.0.1 - Updgrade 2.3.0 to 2.3.1 - Use index https://download.pytorch.org/whl/torch, as torch_stable does not have 2.4.0 (yet?)
Okay, so I investigated the failing tests with PyTorch 2.4 a bit further. Right now, we get a warning:
The reason why this warning leads to an error is pure coincidence: In the given tests, we have a filter to catch warnings for different reasons, which is what is triggered by this new skorch/skorch/tests/test_hf.py Lines 715 to 719 in 346d705
Anyway, it's good that we got an early indicator that this will break in the future. However, fixing the problem is not trivial. Here is why: Since PyTorch 1.13, there is an option in Right now, the default is to set 1. Defaulting to
|
In the long term, I'll want a way to allow
Concretely, whenever we call default_load_kwargs = {"weights_only": True}
torch_load_kwargs = {**default_load_kwargs, **self.torch_load_kwargs}
torch.load(..., **torch_load_kwargs) |
Thanks for the input, this sounds reasonable. It's not pretty, but since we cannot directly pass arguments to As to the default: WDYT about using "auto" and then switching to whatever the default is for the given PyTorch version? I found that there is also a context manager Edit: Planned release is v2.6.0. |
I like this. It would mean that we expose a way of handling model loading security to the user while keeping pytorch's defaults. Since this is a long-standing security issue I'd say we should at least follow the pytorch default as soon as they deem the ecosystem to be ready for it. We could simply use the pytorch release version as a default indicator (might be better than using inspect?) I assume that the need for a class variable for the
I was going to say that it might be beneficial to have the tests look as close to user code where possible so that we have approximately the same issues (in terms of functionality but also in terms of 'design') as our users do. The context manager + whitelisting generic classes is probably a good middle-ground. |
I'm happy with an "auto" option. |
See discussion in #1063 Starting from PyTorch 2.4, there is a warning when torch.load is called without setting the weights_only argument. This is because in the future, the default will switch from False to True, which can result in a lot of errors when trying to load torch files (which are pickle files and thus insecure). In this PR, we add a possibility for the user to influence the kwargs passed to torch.load so that they can control that behavior. If not further indicated by the user, we will use the same defaults as the installed torch version. Therefore, users will only encounter this issue via skorch if they would have encountered it via torch anyway. Since it's not 100% certain if the default will switch in torch 2.6.0, we may have to adjust the version check in the future. Besides directly testing the kwargs being passed on, a test was also added that net.load_params does not give any warnings. This is already indirectly tested through some accelerate tests that are currently failing with torch 2.4, but it's better to have an explicit test. After this is merged, the CI should pass when using torch 2.4.0.
* Fix warning from torch.load starting in torch 2.4 See discussion in #1063 Starting from PyTorch 2.4, there is a warning when torch.load is called without setting the weights_only argument. This is because in the future, the default will switch from False to True, which can result in a lot of errors when trying to load torch files (which are pickle files and thus insecure). In this PR, we add a possibility for the user to influence the kwargs passed to torch.load so that they can control that behavior. If not further indicated by the user, we will use the same defaults as the installed torch version. Therefore, users will only encounter this issue via skorch if they would have encountered it via torch anyway. Since it's not 100% certain if the default will switch in torch 2.6.0, we may have to adjust the version check in the future. Besides directly testing the kwargs being passed on, a test was also added that net.load_params does not give any warnings. This is already indirectly tested through some accelerate tests that are currently failing with torch 2.4, but it's better to have an explicit test. After this is merged, the CI should pass when using torch 2.4.0. * Reviewer feedback: return kwargs directly * Reviewer feedback: One more test w/o monkeypatch Instead, rely on the installed torch version and skip if it doesn't fit. * Reviewer feedback: rename function, fix typo
Also: