Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Defining arguments names to avoid issues with positional args #446

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/setfit/exporters/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,21 @@ def __init__(
self.model_head = model_head

def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor):
hidden_states = self.model_body(input_ids, attention_mask, token_type_ids)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}

hidden_states = self.model_body(**inputs)

hidden_states = {"token_embeddings": hidden_states[0], "attention_mask": attention_mask}

embeddings = self.pooler(hidden_states)

# Just to enforce that the token_type_ids will be included in the ONNX graph.
embeddings = embeddings + 0 * token_type_ids.sum()

# If the model_head is none we are using a sklearn head and only output
# the embeddings from the setfit model
if self.model_head is None:
Expand All @@ -60,6 +70,7 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_t
# If head is set then we have a fully torch based model and make the final predictions
# with the head.
out = self.model_head(embeddings)

return out


Expand Down
76 changes: 73 additions & 3 deletions tests/exporters/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,83 @@ def test_export_onnx_sklearn_head(model_path, input_text):
os.remove(output_path)


@pytest.mark.skip("ONNX exporting of SetFit model with Torch head not yet supported.")
@pytest.mark.parametrize("out_features", [1, 2, 3])
def test_export_onnx_torch_head(out_features):
def test_export_onnx_torch_head_model_accepts_token_type_ids(out_features):
"""Test that the exported `ONNX` model returns the same predictions as the original model."""
dataset = get_templated_dataset(reference_dataset="SetFit/SentEval-CR")
model_path = "sentence-transformers/paraphrase-albert-small-v2"
model = SetFitModel.from_pretrained(
model_path, use_differentiable_head=True, head_params={"out_features": out_features}
)

trainer = SetFitTrainer(
model=model,
train_dataset=dataset,
eval_dataset=dataset,
num_iterations=15,
column_mapping={"text": "text", "label": "label"},
)
# Train and evaluate
trainer.freeze() # Freeze the head
trainer.train() # Train only the body
# Unfreeze the head and unfreeze the body -> end-to-end training
trainer.unfreeze(keep_body_frozen=False)
trainer.train(
num_epochs=20,
batch_size=16,
body_learning_rate=1e-5,
learning_rate=1e-2,
l2_weight=0.0,
)

# Export the sklearn based model
output_path = "model.onnx"
try:
export_onnx(model.model_body, model.model_head, opset=12, output_path=output_path)

# Check that the model was saved.
assert output_path in os.listdir(), "Model not saved to output_path"

# Run inference using the original model.
input_text = ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]
pytorch_preds = model(input_text)

# Run inference using the exported onnx model.
tokenizer = AutoTokenizer.from_pretrained(model_path)
inputs = tokenizer(
input_text,
padding=True,
truncation=True,
return_attention_mask=True,
return_token_type_ids=True,
return_tensors="np",
)
# Map inputs to int64 from int32
inputs = {key: value.astype("int64") for key, value in inputs.items()}

session = onnxruntime.InferenceSession(output_path)

onnx_preds = session.run(None, dict(inputs))[0]
onnx_preds = onnx_preds / (1 + 1e-5)
onnx_preds_soft = np.exp(onnx_preds) / sum(np.exp(onnx_preds))
onnx_preds_argmax = np.argmax(onnx_preds_soft, axis=1)
# Compare the results and ensure that we get the same predictions.
assert np.array_equal(onnx_preds_argmax, pytorch_preds)

finally:
# Cleanup the model.
os.remove(output_path)


@pytest.mark.parametrize("out_features", [3])
def test_export_onnx_torch_head_model_not_accepts_token_type_ids(out_features):
"""Test that the exported `ONNX` model returns the same predictions as the original model."""
dataset = get_templated_dataset(reference_dataset="SetFit/SentEval-CR")
model_path = "sentence-transformers/paraphrase-mpnet-base-v2"
model = SetFitModel.from_pretrained(
model_path, use_differentiable_head=True, head_params={"out_features": out_features}
)

trainer = SetFitTrainer(
model=model,
train_dataset=dataset,
Expand Down Expand Up @@ -119,9 +186,12 @@ def test_export_onnx_torch_head(out_features):
session = onnxruntime.InferenceSession(output_path)

onnx_preds = session.run(None, dict(inputs))[0]
onnx_preds = onnx_preds / (1 + 1e-5)
onnx_preds_soft = np.exp(onnx_preds) / sum(np.exp(onnx_preds))
onnx_preds_argmax = np.argmax(onnx_preds_soft, axis=1)

# Compare the results and ensure that we get the same predictions.
assert np.array_equal(onnx_preds, pytorch_preds)
assert np.array_equal(onnx_preds_argmax, pytorch_preds)

finally:
# Cleanup the model.
Expand Down
Loading