diff --git a/src/setfit/exporters/onnx.py b/src/setfit/exporters/onnx.py index cd05c464..1156cb10 100644 --- a/src/setfit/exporters/onnx.py +++ b/src/setfit/exporters/onnx.py @@ -47,11 +47,21 @@ def __init__( self.model_head = model_head def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor): - hidden_states = self.model_body(input_ids, attention_mask, token_type_ids) + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + hidden_states = self.model_body(**inputs) + hidden_states = {"token_embeddings": hidden_states[0], "attention_mask": attention_mask} embeddings = self.pooler(hidden_states) + # Just to enforce that the token_type_ids will be included in the ONNX graph. + embeddings = embeddings + 0 * token_type_ids.sum() + # If the model_head is none we are using a sklearn head and only output # the embeddings from the setfit model if self.model_head is None: @@ -60,6 +70,7 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_t # If head is set then we have a fully torch based model and make the final predictions # with the head. out = self.model_head(embeddings) + return out diff --git a/tests/exporters/test_onnx.py b/tests/exporters/test_onnx.py index 6c132d43..27f843ab 100644 --- a/tests/exporters/test_onnx.py +++ b/tests/exporters/test_onnx.py @@ -61,9 +61,8 @@ def test_export_onnx_sklearn_head(model_path, input_text): os.remove(output_path) -@pytest.mark.skip("ONNX exporting of SetFit model with Torch head not yet supported.") @pytest.mark.parametrize("out_features", [1, 2, 3]) -def test_export_onnx_torch_head(out_features): +def test_export_onnx_torch_head_model_accepts_token_type_ids(out_features): """Test that the exported `ONNX` model returns the same predictions as the original model.""" dataset = get_templated_dataset(reference_dataset="SetFit/SentEval-CR") model_path = "sentence-transformers/paraphrase-albert-small-v2" @@ -71,6 +70,74 @@ def test_export_onnx_torch_head(out_features): model_path, use_differentiable_head=True, head_params={"out_features": out_features} ) + trainer = SetFitTrainer( + model=model, + train_dataset=dataset, + eval_dataset=dataset, + num_iterations=15, + column_mapping={"text": "text", "label": "label"}, + ) + # Train and evaluate + trainer.freeze() # Freeze the head + trainer.train() # Train only the body + # Unfreeze the head and unfreeze the body -> end-to-end training + trainer.unfreeze(keep_body_frozen=False) + trainer.train( + num_epochs=20, + batch_size=16, + body_learning_rate=1e-5, + learning_rate=1e-2, + l2_weight=0.0, + ) + + # Export the sklearn based model + output_path = "model.onnx" + try: + export_onnx(model.model_body, model.model_head, opset=12, output_path=output_path) + + # Check that the model was saved. + assert output_path in os.listdir(), "Model not saved to output_path" + + # Run inference using the original model. + input_text = ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"] + pytorch_preds = model(input_text) + + # Run inference using the exported onnx model. + tokenizer = AutoTokenizer.from_pretrained(model_path) + inputs = tokenizer( + input_text, + padding=True, + truncation=True, + return_attention_mask=True, + return_token_type_ids=True, + return_tensors="np", + ) + # Map inputs to int64 from int32 + inputs = {key: value.astype("int64") for key, value in inputs.items()} + + session = onnxruntime.InferenceSession(output_path) + + onnx_preds = session.run(None, dict(inputs))[0] + onnx_preds = onnx_preds / (1 + 1e-5) + onnx_preds_soft = np.exp(onnx_preds) / sum(np.exp(onnx_preds)) + onnx_preds_argmax = np.argmax(onnx_preds_soft, axis=1) + # Compare the results and ensure that we get the same predictions. + assert np.array_equal(onnx_preds_argmax, pytorch_preds) + + finally: + # Cleanup the model. + os.remove(output_path) + + +@pytest.mark.parametrize("out_features", [3]) +def test_export_onnx_torch_head_model_not_accepts_token_type_ids(out_features): + """Test that the exported `ONNX` model returns the same predictions as the original model.""" + dataset = get_templated_dataset(reference_dataset="SetFit/SentEval-CR") + model_path = "sentence-transformers/paraphrase-mpnet-base-v2" + model = SetFitModel.from_pretrained( + model_path, use_differentiable_head=True, head_params={"out_features": out_features} + ) + trainer = SetFitTrainer( model=model, train_dataset=dataset, @@ -119,9 +186,12 @@ def test_export_onnx_torch_head(out_features): session = onnxruntime.InferenceSession(output_path) onnx_preds = session.run(None, dict(inputs))[0] + onnx_preds = onnx_preds / (1 + 1e-5) + onnx_preds_soft = np.exp(onnx_preds) / sum(np.exp(onnx_preds)) + onnx_preds_argmax = np.argmax(onnx_preds_soft, axis=1) # Compare the results and ensure that we get the same predictions. - assert np.array_equal(onnx_preds, pytorch_preds) + assert np.array_equal(onnx_preds_argmax, pytorch_preds) finally: # Cleanup the model.