From 346d7050117e0abd19f12831aa26e214b710716f Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 30 May 2024 16:31:49 +0200 Subject: [PATCH] FIX NeuralNetBinaryClassifier with torch.compile (#1058) * FIX NeuralNetBinaryClassifier with torch.compile Fixes #1057 NeuralNetBinaryClassifier was not working with torch.compile because the non-linearity was not correctly inferred. This inference depends on the instance type of the criterion. However, when using torch.compile, the criterion is wrapped, resulting in the isinstance check to miss. Now, we unwrap the criterion before checking the instance type. * Add entry to CHANGES.md --- CHANGES.md | 2 ++ skorch/tests/test_net.py | 33 +++++++++++++++++++++++++++++++++ skorch/utils.py | 2 ++ 3 files changed, 37 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index 9302faa40..cd493f25f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed ### Fixed +- Fix an issue with using `NeuralNetBinaryClassifier` with `torch.compile` (#1058) + ## [1.0.0] - 2024-05-27 The 1.0.0 release of skorch is here. We think that skorch is at a very stable point, which is why a 1.0.0 release is appropriate. There are no plans to add any breaking changes or major revisions in the future. Instead, our focus now is to keep skorch up-to-date with the latest versions of PyTorch and scikit-learn, and to fix any bugs that may arise. diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py index 7eec94e5a..17704e1bd 100644 --- a/skorch/tests/test_net.py +++ b/skorch/tests/test_net.py @@ -4159,3 +4159,36 @@ def test_fit_and_predict_with_compile(self, net_cls, module_cls, data): # compiled, we rely here on torch keeping this public attribute assert hasattr(net.module_, 'dynamo_ctx') assert hasattr(net.criterion_, 'dynamo_ctx') + + def test_binary_classifier_with_compile(self, data): + # issue 1057 the problem was that compile would wrap the optimizer, + # resulting in _infer_predict_nonlinearity to return the wrong result + # because of a failing isinstance check + from skorch import NeuralNetBinaryClassifier + + X, y = data[0], data[1].astype(np.float32) + + class MyNet(nn.Module): + def __init__(self): + super(MyNet, self).__init__() + self.linear = nn.Linear(20, 10) + self.output = nn.Linear(10, 1) + + def forward(self, input): + out = self.linear(input) + out = nn.functional.relu(out) + out = self.output(out) + return out.squeeze(-1) + + net = NeuralNetBinaryClassifier( + MyNet, + max_epochs=3, + compile=True, + ) + # check that no error is raised + net.fit(X, y) + + y_proba = net.predict_proba(X) + y_pred = net.predict(X) + assert y_proba.shape == (X.shape[0], 2) + assert y_pred.shape == (X.shape[0],) diff --git a/skorch/utils.py b/skorch/utils.py index 57ceaaa6f..de679ec35 100644 --- a/skorch/utils.py +++ b/skorch/utils.py @@ -660,6 +660,8 @@ def _infer_predict_nonlinearity(net): return _identity criterion = getattr(net, net._criteria[0] + '_') + # unwrap optimizer in case of torch.compile being used + criterion = getattr(criterion, '_orig_mod', criterion) if isinstance(criterion, CrossEntropyLoss): return partial(torch.softmax, dim=-1)