Mismatch input type and weight type when training with precision fp16 #260

Open
hungvo304ml opened this issue Sep 16, 2023 · 7 comments
Labels
bug Something isn't working

Comments

@hungvo304ml

hungvo304ml commented Sep 16, 2023

Hi, thanks for making this project public.

I am trying to run training with fp16 and get the following error:

RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

With fp32, training gets past this point without the type error, but it then fails with an out-of-memory (OOM) error.

Traceback for error when using fp16:

Traceback (most recent call last):
  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/train/train.py", line 484, in <module>
    main()
  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/train/train.py", line 465, in main
    train_one_epoch(
  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/train/train_utils.py", line 111, in train_one_epoch
    loss_laion = model(
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo.py", line 108, in forward
    self._encode_vision_x(vision_x=vision_x)
  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo.py", line 195, in _encode_vision_x
    vision_x = self.vision_encoder(vision_x)[1]
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/open_clip/transformer.py", line 469, in forward
    x = self.conv1(x)  # shape = [*, width, grid, grid]
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
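
The error above says the convolution received a half-precision input while its weights were still float32. One way to find which parameters were left in a different dtype is to walk the model's parameters and compare dtypes. The following is a minimal sketch, assuming PyTorch is installed; the helper name `params_not_in_dtype` and the toy model are hypothetical and are not part of OpenFlamingo or open_clip.

```python
import torch
from torch import nn


def params_not_in_dtype(module: nn.Module, expected: torch.dtype) -> list[str]:
    """Return the names of parameters whose dtype differs from `expected`."""
    return [
        name
        for name, param in module.named_parameters()
        if param.dtype != expected
    ]


# Toy stand-in for a model where only part of the network was cast to fp16.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),  # left in float32, like conv1 in the traceback
    nn.Flatten(),
    nn.Linear(8, 4).half(),          # already cast to float16
)

mismatched = params_not_in_dtype(toy, torch.float16)
print(mismatched)  # names of the parameters still in float32
```

Casting the reported modules with `.half()`, or casting the input to the weight dtype before the forward pass, would make the two types agree. Which of these the training script intends is not stated in this thread.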

Environment

I am using python 3.9.17 with V100 GPUs.

open-clip-torch          2.16.0
torch                    2.0.1
torchvision              0.15.2
transformers             4.28.1
@hungvo304ml hungvo304ml added the bug Something isn't working label Sep 16, 2023
@anas-awadalla
Collaborator

Thanks for bringing this up! I will take a closer look later today. I do want to point out that we haven't gotten good performance with pure fp16 training. It would probably be better to stay in fp32 and use FSDP to shard the model state across your GPUs rather than reducing the precision.
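
The maintainers do not say why pure fp16 training performed poorly for them. One commonly cited difficulty with fp16 in general is its limited precision and range, which the sketch below illustrates with the standard library only. The values for `weight`, `update`, and `tiny_grad` are made up for illustration.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


weight = 1.0
update = 1e-4      # illustrative: learning rate times gradient
tiny_grad = 1e-8   # illustrative: a very small gradient value

# Near 1.0 the spacing between fp16 values is 2**-10 (about 0.00098),
# so an update of 1e-4 is rounded away entirely.
print(to_fp16(weight + update) == weight)  # True

# The smallest positive fp16 value is 2**-24 (about 6e-8); values well
# below that underflow to zero.
print(to_fp16(tiny_grad))  # 0.0
```

Keeping the weights in fp32, as suggested above, avoids both effects at the cost of more memory per parameter, which is where sharding with FSDP comes in.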

@hungvo304ml
Author

Thanks for clarifying. FSDP would be ideal. Still, I have problems training with FSDP. Namely, I am using MPT-1B and it does not have the get_output_embeddings and set_output_embeddings methods. I see there is a major refactor that is in progress. Looking forward to using it soon.
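
For illustration, the two missing accessors could look roughly like the sketch below. It assumes the output projection reuses the input embedding stored at `self.transformer.wte`, which is what the `F.linear(x, self.transformer.wte.weight, None)` line in the traceback later in this thread suggests. `TiedOutputEmbeddingsMixin` and `ToyLM` are hypothetical names, and this is not necessarily how the hf-style MPT checkpoint or the refactor implements them.

```python
from types import SimpleNamespace


class TiedOutputEmbeddingsMixin:
    """Expose the input embedding as the output embedding (weight tying).

    Assumes the model stores its token embedding at `self.transformer.wte`.
    """

    def get_output_embeddings(self):
        return self.transformer.wte

    def set_output_embeddings(self, new_embeddings):
        self.transformer.wte = new_embeddings


class ToyLM(TiedOutputEmbeddingsMixin):
    """Hypothetical stand-in for a language model; not the real MPT class."""

    def __init__(self):
        # A string stands in for the embedding module to keep the sketch small.
        self.transformer = SimpleNamespace(wte="original-embedding")


lm = ToyLM()
print(lm.get_output_embeddings())
lm.set_output_embeddings("resized-embedding")
print(lm.get_output_embeddings())
```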

@anas-awadalla
Collaborator

Got it. There is this version of MPT that I use for testing, if you want to give FSDP a shot before the new refactor is merged.

@hungvo304ml
Author

Great, thanks for pointing this out. I will try this model with FSDP.

@hungvo304ml
Author

hungvo304ml commented Sep 19, 2023

I tried FSDP with "mpt-1b-redpajama-200b-hf-style" and it got past the error above.

However, I now get a different error: the shape of the input embeddings (self.transformer.wte.weight) has been altered. I believe it should be a 2-D tensor of shape (:, 2048), but it is a 1-D tensor of shape (25743360), which causes the size mismatch when computing the logits. More details below:

File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo.py", line 111, in forward
    output = self.lang_encoder(
File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo_lm.py", line 157, in forward
    return super().forward(**kwargs)  # Call the other parent's forward method
File "/home/hqvo2/.cache/huggingface/modules/transformers_modules/anas-awadalla/mpt-1b-redpajama-200b-hf-style/f40a2c7f92621be8b12a01ac9214d3ed4ef50f60/mosaic_gpt.py", line 379, in forward
    logits = F.linear(x, self.transformer.wte.weight, None)
RuntimeError: size mismatch, got 8, 8x2048,25743360
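
A 2-D weight showing up as a 1-D tensor is consistent with how FSDP stores parameters: as far as I understand, it flattens them into a 1-D buffer and gives each rank one slice of it. The sketch below is a simplified pure-Python illustration of that idea, not FSDP's actual implementation; `flatten_and_shard` and the toy sizes are hypothetical.

```python
def flatten_and_shard(matrix, world_size):
    """Flatten a 2-D list row by row and split it into equal 1-D shards.

    The flat list is zero-padded so every rank gets the same number of elements.
    """
    flat = [value for row in matrix for value in row]
    shard_len = -(-len(flat) // world_size)  # ceiling division
    flat += [0] * (shard_len * world_size - len(flat))
    return [flat[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


# Toy embedding table: 6 rows (vocabulary) x 4 columns (hidden size).
embedding = [[row * 10 + col for col in range(4)] for row in range(6)]
shards = flatten_and_shard(embedding, world_size=3)

print(len(embedding), len(embedding[0]))  # the original table is 2-D
print([len(shard) for shard in shards])   # each rank holds a 1-D slice
```

Code that reads the weight directly while it is in this sharded form sees only the 1-D slice, so a matrix multiply that expects (rows, hidden size) cannot use it.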

@alyakin314

Did you resolve this? I get a very similar error when trying to use FSDP with OpenFlamingo 9B:

  File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/open_flamingo/src/flamingo.py", line 111, in forward
    output = self.lang_encoder(
  File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/open_flamingo/src/flamingo_lm.py", line 157, in forward
    return super().forward(**kwargs)  # Call the other parent's forward method
  File "/gpfs/data/oermannlab/users/alyaka01/.cache/huggingface/modules/transformers_modules/anas-awadalla/mpt-7b/b772e556c8e8a17d087db6935e7cd019e5eefb0f/modeling_mpt.py", line 258, in forward
    logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
RuntimeError: size mismatch, got 8192, 8192x4096,51486720
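
Both size-mismatch messages in this thread report a 1-D weight whose length is an exact multiple of the model's hidden size, which fits the picture of an evenly sharded embedding table. The arithmetic can be checked with the short sketch below. The sizes are copied from the two error messages; `ASSUMED_WORLD_SIZE` is an assumption, since neither report states how many GPUs were used.

```python
def rows_in_flat_weight(numel: int, hidden_size: int) -> int:
    """Number of full embedding rows that fit in a flat tensor of `numel` elements."""
    rows, remainder = divmod(numel, hidden_size)
    if remainder:
        raise ValueError("numel is not a multiple of hidden_size")
    return rows


# (flat weight size, hidden size) taken from the two error messages in this thread.
reports = {"MPT-1B": (25743360, 2048), "MPT-7B": (51486720, 4096)}

ASSUMED_WORLD_SIZE = 4  # assumption: the number of GPUs is not stated in the thread

for name, (numel, hidden) in reports.items():
    rows = rows_in_flat_weight(numel, hidden)
    # rows held by one rank, and the full table size under the assumed world size
    print(name, rows, rows * ASSUMED_WORLD_SIZE)
```

Both reports work out to 12570 rows per flat tensor, which would correspond to a 50280-row embedding table if it were split evenly across 4 ranks. That total is an inference from the numbers, not something confirmed in the thread.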

@alyakin314

related: #129 (comment)


3 participants