Cannot export BERT #51

Open
toranb opened this issue May 24, 2023 · 0 comments

toranb commented May 24, 2023

When I pull down a BERT or RoBERTa model from Hugging Face, I can't export it today with 0.4 because the op in the Axon.Node doesn't match any of the available function heads: %Axon.Node{op: :container}.

[Screenshot: 2023-05-24 at 10:31:33 AM]

Here is the Axon.Node that misses the pattern match:

%Axon.Node{
  id: 999,
  name: #Function<66.122028880/2 in Axon.name/2>,
  mode: :both,
  parent: [%{attentions: 996, hidden_states: 997, logits: 998}],
  parameters: [],
  args: [:layer],
  op: :container,
  policy: p=f32 c=f32 o=f32,
  hooks: [],
  opts: [],
  op_name: :container,
  stacktrace: [
    {Axon, :layer, 3, [file: 'lib/axon.ex', line: 338]},
    {Bumblebee, :load_model, 2, [file: 'lib/bumblebee.ex', line: 460]},
    {MyCode.MyModule, :export_bert, 0, [file: 'lib/my_module/export.ex', line: 7]},
    ...
  ]
}
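
For context, my read is that this node comes from the model returning several outputs at once. A minimal sketch of how an op: :container node like this can arise (my assumption, heavily simplified; input and logits below are just placeholders, not Bumblebee's actual layers):

# Wrapping multiple outputs in a map with Axon.container/2 yields a node
# whose op and op_name are :container, with the map of outputs as its parent.
input = Axon.input("input_ids", shape: {nil, 50})
logits = Axon.dense(input, 8)

model = Axon.container(%{logits: logits, hidden_states: logits, attentions: logits})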

And here is the code I'm using to pull down BERT and export it:

  def export() do
    {:ok, spec} = Bumblebee.load_spec({:hf, "bert-base-cased"}, architecture: :for_sequence_classification)
    spec = Bumblebee.configure(spec, num_labels: 8)
    {:ok, bert} = Bumblebee.load_model({:hf, "bert-base-cased"}, spec: spec)
    %{model: the_model, params: params} = bert
    sequence_length = 50
    batch_size = 16

    input_template = %{
      "input_ids" => Nx.template({batch_size, sequence_length}, :f32),
      "attention_mask" => Nx.template({batch_size, sequence_length}, :f32),
      "token_type_ids" => Nx.template({batch_size, sequence_length}, :f32)
    }

    AxonOnnx.export(the_model, input_template, params, path: "tuned.onnx")
  end
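
One workaround I could imagine (a sketch only, untested, and export_logits_only is just a hypothetical helper name) is to select a single tensor out of the container with Axon.nx/2 before exporting, so the graph no longer ends in an op: :container node:

  # Sketch: pull :logits out of the container before handing the model
  # to AxonOnnx. The loaded params are reused unchanged.
  def export_logits_only(%{model: model, params: params}, input_template) do
    # Axon.nx/2 applies a plain Nx function to the container output
    # (a map of tensors), so the exported graph ends in a single tensor.
    logits_model = Axon.nx(model, fn %{logits: logits} -> logits end)

    # This may still hit other ops that 0.4 doesn't support.
    AxonOnnx.export(logits_model, input_template, params, path: "tuned.onnx")
  end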

Further detail: I've cut this down to something easy to reproduce, but my full example/motivation is as follows:

  • I pull down a RoBERTa or BERT model from Hugging Face
  • I then fine-tune this model with domain-specific data
  • I evaluate the tuned model and it's working as expected with high accuracy
  • I then want to export it so I can examine it further with tools like WeightWatcher