Handle enable CUDA graph param in SD example (microsoft#246)

This PR updates how the `enable_cuda_graph` param is set depending on `world_size`: CUDA graphs should only be enabled when `world_size == 1`.
lekurile authored Mar 31, 2023
1 parent 127c7a1 commit 4c6bfc5
Showing 2 changed files with 7 additions and 3 deletions.
inference/huggingface/stable-diffusion/README.md (2 additions, 0 deletions)

```diff
@@ -11,6 +11,8 @@ pip install -r requirements.txt
 Examples can be run as follows:
 <pre>deepspeed --num_gpus [number of GPUs] test-[model].py</pre>
 
+NOTE: Local CUDA graphs for replaced SD modules will only be enabled when `mp_size==1`.
+
 # Example Output
 Command:
 <pre>
```
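Since the `deepspeed` launcher sets `WORLD_SIZE` to the number of processes it starts, the note above amounts to: run on a single GPU if you want the CUDA-graph path. As a usage example (the script name is the placeholder pattern from this README):

<pre>deepspeed --num_gpus 1 test-[model].py</pre>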
The second changed file (5 additions, 3 deletions):

```diff
@@ -9,7 +9,7 @@
 model = "prompthero/midjourney-v4-diffusion"
 local_rank = int(os.getenv("LOCAL_RANK", "0"))
 device = torch.device(f"cuda:{local_rank}")
-world_size = int(os.getenv('WORLD_SIZE', '4'))
+world_size = int(os.getenv('WORLD_SIZE', '1'))
 generator = torch.Generator(device=torch.cuda.current_device())
 
 pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.half)
@@ -19,12 +19,14 @@
 baseline_image = pipe(prompt, guidance_scale=7.5, generator=generator).images[0]
 baseline_image.save(f"baseline.png")
 
-# NOTE: DeepSpeed inference supports local CUDA graphs for replaced SD modules
+# NOTE: DeepSpeed inference supports local CUDA graphs for replaced SD modules.
+# Local CUDA graphs for replaced SD modules will only be enabled when `mp_size==1`
 pipe = deepspeed.init_inference(
     pipe,
     mp_size=world_size,
     dtype=torch.half,
     replace_with_kernel_inject=True,
-    enable_cuda_graph=True,
+    enable_cuda_graph=True if world_size==1 else False,
 )
 
 generator.manual_seed(0xABEDABE7)
```
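Putting the two hunks together, the resulting pattern looks like the following minimal sketch. It is an illustration, not the repository file: the prompt and output path are made up, and the gating is written as `(world_size == 1)`, which is equivalent to the ternary used in the diff.

```python
import os

import deepspeed
import torch
from diffusers import DiffusionPipeline

# The deepspeed launcher sets WORLD_SIZE; default to 1 for plain `python` runs.
world_size = int(os.getenv("WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))

pipe = DiffusionPipeline.from_pretrained(
    "prompthero/midjourney-v4-diffusion", torch_dtype=torch.half
)
pipe = pipe.to(f"cuda:{local_rank}")

# Local CUDA graphs for the injected SD modules are only supported without
# model parallelism, so gate them on world_size (i.e. mp_size) == 1.
pipe = deepspeed.init_inference(
    pipe,
    mp_size=world_size,
    dtype=torch.half,
    replace_with_kernel_inject=True,
    enable_cuda_graph=(world_size == 1),
)

image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("output.png")
```

Launched via `deepspeed --num_gpus 1 sketch.py` (a hypothetical filename), `WORLD_SIZE` is 1 and the CUDA-graph path is taken; with more GPUs it is skipped.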