Support offloading encode, for generate() with much less VRAM #269
generate() from Transformers can take encoder outputs as kwargs instead of running the encoder itself. This PR extends that support to "super conditioning" sampling. It also allows providing only one "null sequence" per batch, either as inputs or as encoder state, since that prompt is normally constant.
How is this useful? We only need to run the encoder once per distinct prompt, which even on a household CPU takes 1-2 seconds for a single input (worst case, no batching, no reuse). With that step offloaded, generate() works without 2 or 4 gigabytes of encoder weights (mega-1-fp16 and mega-1, respectively) hogging VRAM.
That way, mega-1-fp16 can run on a 4GB GPU (batch size 1, without VQGAN, which is fast enough on CPU) and full-precision mega-1 can run on an 8GB GPU (batch size 1 with VQGAN, up to batch size 3 without).
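For concreteness, the intended flow looks roughly like the sketch below. This is a minimal sketch only: `model.encode()`, the `encoder_outputs=` kwarg, and the params key path for dropping the text encoder are assumptions standing in for whatever this PR actually exposes; the model id is just the usual mega checkpoint.

```python
# Minimal sketch (not this PR's actual code): encode once on the CPU, then
# sample on the GPU without the text-encoder weights resident in VRAM.
# `model.encode()`, `encoder_outputs=`, and the "model"/"encoder" key path
# below are assumptions; the exact names may differ.
import jax
import jax.numpy as jnp
from flax.core import unfreeze
from dalle_mini import DalleBart, DalleBartProcessor

cpu = jax.devices("cpu")[0]
gpu = jax.devices("gpu")[0]

model, params = DalleBart.from_pretrained(
    "dalle-mini/dalle-mega", dtype=jnp.float16, _do_init=False
)
processor = DalleBartProcessor.from_pretrained("dalle-mini/dalle-mega")
tokens = processor(["a red bicycle leaning against a wall"])

# Run the encoder once per distinct prompt, on the CPU where the full
# parameter tree already lives.
with jax.default_device(cpu):
    encoder_state = model.encode(**tokens, params=params)

# Ship only the (small) encoder activations to the GPU.
encoder_state = jax.device_put(encoder_state, gpu)

# Keep just the decoder-side weights in VRAM; the key path is a guess.
params = unfreeze(params)
params.get("model", params).pop("encoder", None)
decoder_params = jax.device_put(params, gpu)

# Sample with the precomputed encoder state; super conditioning still works.
encoded_images = model.generate(
    **tokens,
    encoder_outputs=encoder_state,
    prng_key=jax.random.PRNGKey(0),
    params=decoder_params,
    condition_scale=10.0,
)
```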
Specifically, without VQGAN, batch size 1 needs 3728 MiB in float16 and 6770 MiB in float32 this way. GPU-accelerating VQGAN adds 770 MiB, assuming we also

`del vqgan_params["encoder"]`

(we never need the VQGAN encoder for generating images) before `replicate(vqgan_params)` or the like. On systems that have enough memory anyway, up to 10 (fp32) or 20 (fp16) more items fit in a batch. Given the CPU encode cost, that's a few percent slower or faster in my experience (especially combined with other tricks in #247), depending on how much state is shared.
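For the VQGAN side, the trimming mentioned above would look roughly like this (a sketch following the usual dalle-mini inference setup; the repo id and the assumption that vqgan_params is a plain mutable dict are illustrative):

```python
# Sketch: drop the VQGAN encoder before replicating the params to the GPU(s).
from flax.jax_utils import replicate
from vqgan_jax.modeling_flax_vqgan import VQModel

vqgan, vqgan_params = VQModel.from_pretrained(
    "dalle-mini/vqgan_imagenet_f16_16384", _do_init=False
)

# The VQGAN encoder only maps pixels to tokens; decoding sampled tokens back
# into images needs just the quantizer and decoder, so the encoder can go.
del vqgan_params["encoder"]

# Replicate the trimmed params across local devices for pmapped decoding.
vqgan_params = replicate(vqgan_params)
```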