
Changing batch size and using multiple gpu makes Incompatible shapes issue. #9

hyeonjinXZ opened this issue Oct 25, 2021 · 15 comments

@hyeonjinXZ

hyeonjinXZ commented Oct 25, 2021

I ran into an out-of-memory issue, so would it be better to change the batch size?

I changed only two things in the configuration:

batch_size = 56 -> 4
eval_batch_size = 7 -> 4

But this produces a dimension error, shown below.

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [1,8,17,768] but got [1,32,17,768]. [Op:IteratorGetNext]

  1. What causes this error, and what should I do in this case?
  2. Or could you give me better tips for solving the out-of-memory problem?

My devices are two GTX 1080 Ti GPUs (11 GB each).

@woctezuma

woctezuma commented Oct 25, 2021

Your error occurs because you feed [1,32,17,768] (notice the value 32) when [1,8,17,768] is expected (see the value 8).
It would be useful to know the full error message, with the exact line where the error occurs.

The lines you edited must have been:

config.batch_size = 56
config.eval_batch_size = 7

@hyeonjinXZ
Author

Thank you for your reply.
Yes, I edited only the two lines you mentioned.

config.batch_size = 56
config.eval_batch_size = 7

The full error message is below.
[screenshot of the error]

@woctezuma

woctezuma commented Oct 25, 2021

Ok, so the error occurs at:

batch = jax.tree_map(np.asarray, next(train_iter))

after a call to:

train_utils.train(FLAGS.config, FLAGS.workdir)

@hyeonjinXZ
Author

batch = jax.tree_map(np.asarray, next(train_iter))

Yes, right. And when I print next(train_iter) before line 421, it raises the error below.
[screenshot of the error]


@woctezuma

woctezuma commented Oct 25, 2021

If I were you, I would actually try to change fewer parameters. Try to get the code running without changing the eval batch size.

For instance:

 config.batch_size = 28
 config.eval_batch_size = 7 

or

 config.batch_size = 14
 config.eval_batch_size = 7 

Indeed:

  • going from a batch size of 56 to 4 might have been too drastic.
  • simultaneously changing the eval batch size a bit (from 7 to 4) makes it harder to know where the issue comes from.

Moreover, I see in the README that:

By default, the configs/coco_xmc.py config is used, which runs an experiment for 128px images. This is able to accommodate a batch size of 8 on each GPU, and achieves an FID of around 10.5 - 11.0 with the EMA weights.

I think you have 2 GPUs. Maybe try with a batch size of at least 16? Maybe 4 was too small.

Finally, I see code like the following, which could be where the error arises, though I am not sure. I would like to know where the value 32 in [1,32,17,768] comes from in your error message.

if config.batch_size % jax.device_count() != 0:
  raise ValueError(f"Batch size ({config.batch_size}) must be divisible by "
                   f"the number of devices ({jax.device_count()}).")
per_device_batch_size = config.batch_size // jax.device_count()
per_device_batch_size_train = per_device_batch_size * config.d_step_per_g_step

where:

config.d_step_per_g_step = 2
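The arithmetic in this snippet can be sketched directly (the helper name here is mine, not from the repo; the formula follows the quoted lines):

```python
# Hedged sketch of the per-device batch derivation quoted above:
# per_device_batch_size_train = (batch_size // device_count) * d_step_per_g_step
def per_device_train_batch(batch_size, device_count, d_step_per_g_step):
    if batch_size % device_count != 0:
        raise ValueError(
            f"Batch size ({batch_size}) must be divisible by "
            f"the number of devices ({device_count})."
        )
    per_device = batch_size // device_count
    return per_device * d_step_per_g_step

# With batch_size=4 on 1 visible GPU and the default d_step_per_g_step=2,
# this yields the 8 in the original error's expected shape [1,8,17,768]:
print(per_device_train_batch(4, 1, 2))  # 8
```

The same formula with batch_size=16 on 2 GPUs and d_step_per_g_step=4 gives 32, which is consistent with the shapes reported later in this thread.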

As a side-note, there is no check that eval_batch_size is divisible by the number of GPUs, so you should be able to leave it equal to 7.

eval_num_batches = None
eval_batch_size_per_replica = config.eval_batch_size // jax.device_count()

@hyeonjinXZ
Author

@woctezuma When I use one GPU with the configuration below, it works :) Thank you for your kind reply.
config.batch_size = 8
config.d_step_per_g_step = 4

But when I use two GPUs and a batch size of 16, I get the error below. The batch dimension changed from 1 to 2.
[screenshot of the error]

  1. Why does this error occur, and how can I fix it?
  2. Can changing batch_size and d_step_per_g_step affect performance, making it lower or higher?

@woctezuma

woctezuma commented Oct 26, 2021

It is hard to say, but I believe the error that you see comes from a line like this one:

batch_dims=[jax.local_device_count(), per_device_batch_size_train],

where you have the first batch dim that is the number of GPUs.

In your error message, it seems the code expects data chunked for 2 GPUs but receives data for 1 GPU.

I wonder if there is an option to toggle on support for multiple GPUs.
That being said, it could be something else; I don't have the expertise here.
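To make the mismatch concrete (shapes copied from the error message; these tuples are stand-ins, not the repo's pipeline):

```python
# A pmap-style training step expects batches with a leading axis equal to
# jax.local_device_count() (2 here), but the iterator yields one device's worth.
expected_shape = (2, 32, 17, 768)  # what the step expects (2 local devices)
received_shape = (1, 32, 17, 768)  # what the iterator produced
print(expected_shape == received_shape)  # False
```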

@kohjingyu
Contributor

@Hyeonjin1989 I believe multiple GPUs should be supported natively by JAX? I did not have to do anything when running it on > 1 GPU. Can you print jax.local_device_count() to see what the value returned is?

As for your second question: We did find that performance is quite sensitive to batch size. I've never run the model on 2 GPUs, so I suggest that you do some quick hyperparameter sweeps if possible to find the best performance.

@hyeonjinXZ
Author

@woctezuma @kohjingyu Thank you for your kind help :)

I printed jax.local_device_count() and per_device_batch_size_train before the line below, and they are 2 and 32.

train_ds = deterministic_data.create_dataset(

  • Details
    • command line: CUDA_VISIBLE_DEVICES=0,1 python -m xmcgan.main --config="xmcgan/configs/coco_xmc.py" --mode="train" --workdir=./exp/
    • error message: tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [2,32,17,768] but got [1,32,17,768]. [Op:IteratorGetNext]
    • modified configuration
      • config.batch_size = 16
      • config.d_step_per_g_step = 4
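For what it's worth, the printed values line up with the snippet quoted earlier in the thread; a sketch of the arithmetic using the numbers from this comment:

```python
batch_size = 16           # config.batch_size
device_count = 2          # jax.local_device_count() printed above
d_step_per_g_step = 4     # config.d_step_per_g_step

per_device_batch_size = batch_size // device_count                       # 8
per_device_batch_size_train = per_device_batch_size * d_step_per_g_step  # 32

# Leading axis = device count, matching the expected shape in the error:
expected = (device_count, per_device_batch_size_train, 17, 768)
print(expected)  # (2, 32, 17, 768)
```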

@kohjingyu
Contributor

Can you paste your full error log?

@hyeonjinXZ
Author

hyeonjinXZ commented Oct 27, 2021

@kohjingyu This is my full error log.
error message: tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [2,32,17,768] but got [1,32,17,768]. [Op:IteratorGetNext]
[screenshot of the full error log]

I also printed jax.local_device_count() and per_device_batch_size_train; please find them in the log.
[screenshot of the log]

@hyeonjinXZ hyeonjinXZ changed the title To solve the out of memory, is it good to change batch size? Changing batch size and using multiple gpu makes Incompatible shapes issue. Nov 7, 2021
@hyeonjinXZ
Author

hyeonjinXZ commented Nov 7, 2021

Using multiple GPUs causes the issue.
I found a similar issue which said that the input data size must be a multiple of the number of GPUs.
How can I shave down the data size?

Error:
" tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [2,32,17,768] but got [1,32,17,768]."
[screenshot of the error]

  • details
    • config.batch_size = 16
    • config.d_step_per_g_step = 4
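If the input pipeline really must produce a leading dimension that is a multiple of the GPU count, one way to "shave down" the data is to drop the remainder before sharding. A minimal sketch with a hypothetical helper (not from the repo), assuming host batches are NumPy arrays:

```python
import numpy as np

def shard_for_devices(batch, device_count):
    # Trim the leading axis to a multiple of device_count, then reshape to
    # [device_count, per_device, ...] as a pmap-style step expects.
    usable = (batch.shape[0] // device_count) * device_count
    trimmed = batch[:usable]  # drop the remainder examples
    return trimmed.reshape(
        (device_count, usable // device_count) + batch.shape[1:]
    )

x = np.zeros((7, 17, 768), dtype=np.float32)
print(shard_for_devices(x, 2).shape)  # (2, 3, 17, 768)
```

This silently discards up to device_count - 1 examples per batch, which is usually acceptable for training but worth keeping in mind for evaluation.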

@woctezuma

woctezuma commented Nov 7, 2021

To be clear, the StackOverflow answer comes from kuza55/keras-extras#7 (comment).
I am not sure it is relevant, as you can see that it has received as many thumbs up as thumbs down.

Are you forced to edit d_step_per_g_step? What was the error which prompted this change?

Also, is it normal that you have these lines in your log (reminiscent of #8)?

[collapsed log]

@hyeonjinXZ
Author

hyeonjinXZ commented Nov 8, 2021

  1. Changing only batch_size=14 while keeping config.d_step_per_g_step=2 causes an 'Incompatible shapes' error, and the code below may be the cause. If I set batch_size=14 and d_step_per_g_step=8, it works fine on one GPU without the 'Incompatible shapes' error.

per_device_batch_size_train = per_device_batch_size * config.d_step_per_g_step

  2. I don't have a TPU, so the lines below appear, which is normal.

I1107 21:05:37.584076 47723213264704 xla_bridge.py:212] Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
I1107 21:05:37.818347 47723213264704 xla_bridge.py:212] Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
/localscratch/xianzhen.18653261.0/env_xmc_gan_v1/lib/python3.8/site-packages/jax/lib/xla_bridge.py:368: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.

@hyeonjinXZ hyeonjinXZ reopened this Dec 2, 2021