tensor size mismatch error #93

Open
AkankshaP0102 opened this issue Apr 1, 2024 · 2 comments

@AkankshaP0102

I am trying to train STEGO on a custom dataset, but when I provide labels for the corresponding images, training fails with the following error:
Traceback (most recent call last):
File "train_segmentation.py", line 598, in my_app
trainer.fit(model, train_loader, val_loader)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 741, in fit
self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
self._dispatch()
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
self.training_type_plugin.start_training(self)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
return self._run_train()
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1311, in _run_train
self._run_sanity_check(self.lightning_module)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1375, in _run_sanity_check
self._evaluation_loop.run()
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 239, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 219, in validation_step
return self.model.validation_step(*args, **kwargs)
File "train_segmentation.py", line 354, in validation_step
self.linear_metrics.update(linear_preds, label)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/torchmetrics/metric.py", line 405, in wrapped_func
raise err
File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/torchmetrics/metric.py", line 395, in wrapped_func
update(*args, **kwargs)
File "/media/2d46715b-293d-4478-acd4-5f000d443896/stego-studies/src/utils.py", line 240, in update
mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes)
RuntimeError: The size of tensor a (3276800) must match the size of tensor b (10240) at non-singleton dimension 0

Could you please help me with this? Thank you.
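
The failing line in utils.py builds an elementwise mask from the flattened labels and the flattened predictions, so it can only work when both have the same number of elements. The sketch below reproduces that failure mode with NumPy instead of PyTorch. The shapes and the `build_mask` helper are hypothetical and chosen only for illustration; they are not the shapes or code from the actual run.

```python
import numpy as np

def build_mask(actual: np.ndarray, preds: np.ndarray, n_classes: int) -> np.ndarray:
    """Hypothetical stand-in for the mask expression in the traceback."""
    actual = actual.reshape(-1)
    preds = preds.reshape(-1)
    if actual.shape != preds.shape:
        # Report both sizes up front instead of failing inside the comparison.
        raise ValueError(
            f"label has {actual.size} elements but prediction has {preds.size}"
        )
    return (actual >= 0) & (actual < n_classes) & (preds >= 0) & (preds < n_classes)

n_classes = 3
actual = np.zeros((2, 4, 4), dtype=np.int64)  # illustrative label batch
preds = np.zeros((2, 4), dtype=np.int64)      # illustrative prediction with a missing axis

try:
    build_mask(actual, preds, n_classes)
except ValueError as err:
    print("mask could not be built:", err)
```

Printing `label.shape` and `linear_preds.shape` just before the `self.linear_metrics.update(...)` call in the validation step is probably the quickest way to see which of the two tensors has the unexpected shape.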

@Ruhrozz

Ruhrozz commented May 16, 2024

I have found the solution.

You most likely have the same issue I had. In DirectoryDataset, the mask returned by Image.open("mask") has shape [H, W, 3]. The validation step then interpolates the predictions to the last two dimensions of the label (labels[-2:]), so the prediction ends up with a shape like [B, C, H, 3].

I solved the problem by converting the mask to grayscale, so that its shape is just [H, W]:

label = Image.open(join(self.label_dir, label_fn)).convert('L')

See this and this for additional information.
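
One caveat about the grayscale conversion, as a hedged note: `.convert('L')` computes a luminance value from the three channels. When R, G and B all hold the same class index, that value should come out unchanged, but a color-coded mask would be turned into brightness values rather than class indices. The sketch below is a minimal alternative that works on the array form of the mask and refuses color-coded input. `to_index_mask` is a hypothetical helper, not part of STEGO, and the tiny array stands in for `np.array(Image.open(...))`.

```python
import numpy as np

def to_index_mask(mask: np.ndarray) -> np.ndarray:
    """Collapse a label image to a 2-D [H, W] array of class indices.

    Hypothetical helper. Assumes an RGB(A) mask stores the class index
    identically in R, G and B; color-coded masks need an explicit
    color-to-index lookup instead.
    """
    if mask.ndim == 2:
        return mask  # already [H, W]
    if mask.ndim == 3 and mask.shape[-1] in (3, 4):
        rgb = mask[..., :3]
        # Every channel must equal the first one, otherwise the mask is color-coded.
        if not (rgb == rgb[..., :1]).all():
            raise ValueError("channels differ: mask looks color-coded, not index-coded")
        return rgb[..., 0]
    raise ValueError(f"unexpected mask shape {mask.shape}")

# Stand-in for np.array(Image.open(path)) on a mask that was saved as RGB.
rgb_mask = np.repeat(np.array([[0, 1], [2, 1]], dtype=np.uint8)[..., None], 3, axis=-1)
index_mask = to_index_mask(rgb_mask)
print(rgb_mask.shape, "->", index_mask.shape)  # (2, 2, 3) -> (2, 2)
```

If the helper raises on your masks, the labels are stored as colors, and each color has to be mapped to its class index before training.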

@AkankshaP0102

Thank you @Ruhrozz. I tried the solution you suggested. It fixed the tensor size error, but I am now facing an issue with the shape of the label image. Did you run into this error as well? Any help would be appreciated. I have pasted the error below:

Traceback (most recent call last):
File "train_segmentation.py", line 503, in my_app
trainer.fit(model, train_loader, val_loader)
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 741, in fit
self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
self._dispatch()
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
self.training_type_plugin.start_training(self)
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
return self._run_train()
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1311, in _run_train
self._run_sanity_check(self.lightning_module)
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1375, in _run_sanity_check
self._evaluation_loop.run()
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 151, in run
output = self.on_run_end()
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 131, in on_run_end
self._evaluation_epoch_end(outputs)
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 236, in _evaluation_epoch_end
model.validation_epoch_end(outputs)
File "train_segmentation.py", line 297, in validation_epoch_end
ax[1, i].imshow(self.label_cmap[output["label"][i]])
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/matplotlib/_api/deprecation.py", line 459, in wrapper
return func(*args, **kwargs)
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/matplotlib/__init__.py", line 1414, in inner
return func(ax, *map(sanitize_sequence, args), **kwargs)
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/matplotlib/axes/_axes.py", line 5487, in imshow
im.set_data(X)
File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/matplotlib/image.py", line 716, in set_data
.format(self._A.shape))
TypeError: Invalid shape (1, 320, 320, 3) for image data
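
A possible reading of this second error, offered as a sketch rather than a confirmed diagnosis: indexing a colormap of shape [n_classes, 3] with a label of shape [1, H, W] yields [1, H, W, 3], and `imshow` accepts only [H, W], [H, W, 3] or [H, W, 4]. The example below reproduces the shape with NumPy. The colormap values are made up, and where the extra leading axis comes from in the dataset pipeline is not shown in the thread.

```python
import numpy as np

# Hypothetical colormap: one RGB row per class, shape [n_classes, 3].
label_cmap = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0]], dtype=np.uint8)

# A single label that still carries a size-1 leading axis: [1, H, W].
label = np.zeros((1, 320, 320), dtype=np.int64)

# Fancy indexing appends the color axis to the label's own shape.
colored = label_cmap[label]
print(colored.shape)  # (1, 320, 320, 3), the shape imshow rejected

# Dropping the size-1 axis first gives an [H, W, 3] image.
colored_ok = label_cmap[np.squeeze(label, axis=0)]
print(colored_ok.shape)  # (320, 320, 3)
```

If this is what is happening, squeezing the label where it is loaded, or right before plotting in `validation_epoch_end`, would be the place to try a fix.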
