
RuntimeError: Input, output and indices must be on the current device #5

Open · cpuyyp opened this issue Dec 18, 2020 · 9 comments
Labels: bug (Something isn't working)

cpuyyp commented Dec 18, 2020

Hi Robin,

I ran into an error when testing with the pre-trained sentiment model from Flair. I simply load it with

classifier = TextClassifier.load('sentiment')
flair_model_wrapper = ModelWrapper(classifier)

The rest is the same as in your example. I get this error when I call the function interpret_sentence.

I tested and printed out the device: it turns out that the variable input_ids in interpret_sentence is on the CPU. My clumsy solution is to add

input_ids = input_ids.to(device)

after line

input_ids = flair_model_wrapper.tokenizer.encode(...)

There might be other internal solutions.
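The patch described above can be sketched as follows. This is a hypothetical, minimal illustration: the tensor below stands in for the output of flair_model_wrapper.tokenizer.encode(...), which is not reproduced here, and the CPU fallback lets the snippet run on machines without a GPU.

```python
import torch

# Pick the device the model lives on (CPU fallback if no GPU is present).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy token IDs standing in for flair_model_wrapper.tokenizer.encode(...).
input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])

# The fix: move the encoded input to the same device as the model
# before the forward pass, so all tensors agree.
input_ids = input_ids.to(device)
print(input_ids.device.type)
```

On a CPU-only machine this prints `cpu`; on a GPU machine it prints `cuda`, which is exactly the mismatch the RuntimeError complains about.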

BTW, this work helps a lot!

@robinvanschaik (Owner)

Hi @cpuyyp,

Thanks for using the repo and raising the issue.
I may have overlooked it because my MacBook does not have a GPU.
I will probably adjust it this weekend, and you could test it afterwards.

Kind regards

@robinvanschaik (Owner)

Hi @cpuyyp,

I tested this, and it does indeed work on a GPU notebook on Google Cloud.

# Store the encoding on the GPU.
input_ids = input_ids.to(flair_model_wrapper.device)

Not sure how to tackle CUDA memory management, e.g. if you rerun the function multiple times you might get OOM errors.
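One hedged mitigation for the OOM concern above is to release PyTorch's cached CUDA blocks between repeated calls with torch.cuda.empty_cache(). Note this does not free tensors you still hold references to, so dropping references to large intermediate results matters too. The helper and the doubling function below are illustrative stand-ins, not part of the repo.

```python
import torch

def run_repeatedly(fn, inputs, n=3):
    # Call an interpretation-style function several times, releasing
    # unused cached CUDA memory after each call to reduce OOM risk.
    results = []
    for _ in range(n):
        results.append(fn(inputs))
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # return cached blocks to the driver
    return results

# Stand-in workload: doubling a tensor instead of running attribution.
outputs = run_repeatedly(lambda x: x * 2, torch.ones(2))
print(len(outputs))
```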

Will test this on my laptop tomorrow just to see if I will not break anything on a CPU-only machine.

@robinvanschaik robinvanschaik added the bug Something isn't working label Dec 18, 2020
@robinvanschaik robinvanschaik self-assigned this Dec 18, 2020
@robinvanschaik robinvanschaik pinned this issue Dec 18, 2020
@krzysztoffiok (Contributor)

In my case your explanation helped (I had to edit the function manually as you proposed after cloning your repo, so I guess this fix is not in the repo yet?), but I also had to set flair.device explicitly (flair.device = 'cuda') before starting everything.

It's a great thing you've achieved, @robinvanschaik! I had been looking for a solution like this for some time. Thanks a lot! Question: if I use your code/repo, is there any research publication you would like me to cite?
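The device-pinning workaround mentioned above can be sketched like this. With the real library the line would be `import flair; flair.device = torch.device("cuda")`, run before loading any model; here `fake_flair` is a stand-in namespace so the snippet stays self-contained, and the CPU fallback keeps it runnable without a GPU.

```python
import types
import torch

# Stand-in for the real flair module, which exposes a module-level
# `device` attribute that model loading and inference read from.
fake_flair = types.SimpleNamespace()

# Pin the device explicitly *before* anything is loaded.
fake_flair.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(isinstance(fake_flair.device, torch.device))
```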

@robinvanschaik (Owner)

> In my case your explanation helped (I had to edit the function manually as you proposed after cloning your repo, so I guess this solution is not in the repo yet?), but I still needed to state clearly before the type of flair.device (flair.device = 'cuda') before starting everything.
>
> It's a great thing you achieved @robinvanschaik ! I was looking for a solution like this for some time now already. Thanks a lot! Question: if I am to use your code/repo, is there any research publication you would like me to cite?

Hi @krzysztoffiok,

I am glad to hear that you find this repo useful! :)

You are right. For some reason I never got around to actually pushing this to master.
I guess life got the best of me.

I might have some free time soon. In the meantime, I would definitely welcome a pull request to solve this issue!

Regarding citing this repo: I am not affiliated with any academic institution, nor do I write (academic) papers.

In that regard this was a hobby project.
I am definitely standing on the shoulders of the Captum & Flair teams, but I appreciate that you are checking in about citing this repo. :)

Is there a way I can help you make sure this work is properly cited?

Then I will add a snippet to the markdown file on the front page.

Cheers

@krzysztoffiok (Contributor)

Hi @robinvanschaik ,

Thank you for the very quick merge :) I have created another pull request that forces the user to state explicitly which device they will be using. That helped in my case.

@robinvanschaik (Owner)

Hi @krzysztoffiok,

Thanks for contributing your pull requests! Keep them coming.

Are you willing to reflect your changes in the tutorial in the readme.md as well?
You probably have the snippet at hand.

I believe we can close this issue after the readme has been updated.

Afterwards I can create a new release.

The generated DOI will reflect the new release as well, which should help with citing the code.

Cheers.

@krzysztoffiok (Contributor)

@robinvanschaik

OK I will do that.

BTW, I have also noticed that, presumably with some slightly older model versions (that is my guess at the cause), there is a new error (see below). It happened with other models that I fine-tuned ~9 months ago, and not only ALBERT but also BERT and RoBERTa fine-tuned at the same time.

If I'm correct that this is a package version issue, I guess interpret-flair should clearly state which versions of Hugging Face transformers, flair, and captum to use.

AttributeError Traceback (most recent call last)
in
6 n_steps=500,
7 estimation_method="gausslegendre",
----> 8 internal_batch_size=3)

~/env/flair06/respect/data/models/respect_5k_final_respect_values_0/interpretation_package/interpret_flair.py in interpret_sentence(flair_model_wrapper, lig, sentence, target_label, visualization_list, n_steps, estimation_method, internal_batch_size)
62 # Thus we calculate the softmax afterwards.
63 # For now, I take the first dimension and run this sentence, per sentence.
---> 64 model_outputs = flair_model_wrapper(input_ids)
65
66 softmax = torch.nn.functional.softmax(model_outputs[0], dim=0)

~/env/flair06/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),

~/env/flair06/respect/data/models/respect_5k_final_respect_values_0/interpretation_package/flair_model_wrapper.py in forward(self, input_ids)
44 # Run the input embeddings through all the layers.
45 # Return the hidden states of the model.
---> 46 hidden_states = self.model(input_ids=input_ids)[-1]
47
48 # BERT has an initial CLS token.

~/env/flair06/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),

~/env/flair06/lib/python3.7/site-packages/transformers/modeling_albert.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
656 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
657 )
--> 658 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
659
660 if input_ids is not None and inputs_embeds is not None:

~/env/flair06/lib/python3.7/site-packages/transformers/configuration_utils.py in use_return_dict(self)
221 """
222 # If torchscript is set, force return_dict=False to avoid jit errors
--> 223 return self.return_dict and not self.torchscript
224
225     @property

AttributeError: 'AlbertConfig' object has no attribute 'return_dict'
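A hedged illustration of why the traceback above ends in this AttributeError: a config object serialized by an older transformers version simply never saved the newer `return_dict` attribute, so the new code's attribute access fails. `OldConfig` below is a stand-in class, not the real AlbertConfig.

```python
# Stand-in for a config object pickled by an older library version:
# it saved `torchscript` but `return_dict` did not exist back then.
class OldConfig:
    torchscript = False

cfg = OldConfig()

# What the newer transformers code effectively does (configuration_utils.py):
try:
    _ = cfg.return_dict and not cfg.torchscript
except AttributeError as exc:
    print(type(exc).__name__)

# A defensive read would fall back to a default instead of raising:
print(getattr(cfg, "return_dict", True))
```

This also suggests a workaround when re-training is not an option: rebuild the config from the model name with the current library version instead of deserializing the old one.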

@robinvanschaik (Owner)

@krzysztoffiok You are definitely right.

I should have added a requirements.txt to the repository with pinned versions.

This would make it more reproducible.
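One hedged way to produce those pinned versions is to read them from the installed environment. importlib.metadata is in the standard library (Python 3.8+); the package list below is an assumption based on the dependencies discussed in this thread.

```python
from importlib import metadata

def pin_versions(packages):
    # Build requirements.txt-style lines like "flair==0.6.1" from whatever
    # is installed; note packages that are absent instead of failing.
    lines = []
    for name in packages:
        try:
            lines.append(f"{name}=={metadata.version(name)}")
        except metadata.PackageNotFoundError:
            lines.append(f"# {name}: not installed in this environment")
    return lines

for line in pin_versions(["flair", "transformers", "captum", "torch"]):
    print(line)
```

Writing the output to a requirements.txt and committing it would pin the exact versions the repo was tested against.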

@krzysztoffiok (Contributor)

@robinvanschaik, do you think we could meet online about the functionality of the interpret-flair package? I'm not that familiar with the practical use of the IG method and its various parameters, so if you had the time to clarify some aspects, that would greatly help me get proper results. I have tried for a while, and it didn't work as straightforwardly as I expected.

Please contact me at krzysztof.fiok at gmail.com if you agree.
