Skip to content

Commit

Permalink
fix kwargs for torch.amp.autocast
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Yao <[email protected]>
  • Loading branch information
yaox12 committed Oct 21, 2024
1 parent 0987833 commit a2212c0
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,21 +251,27 @@ def in_fp8_activation_recompute_phase() -> bool:

def _get_active_autocast_contexts():
"""
Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
at the time of this function's execution.
Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast
state at the time of this function's execution.
"""
autocast_cached = torch.is_autocast_cache_enabled()

gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
gpu_autocast_ctx = torch.amp.autocast(
"cuda", gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
"cuda",
enabled=gpu_autocast_enabled,
dtype=gpu_autocast_dtype,
cache_enabled=autocast_cached,
)

cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
cpu_autocast_ctx = torch.amp.autocast(
"cpu", cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
"cpu",
enabled=cpu_autocast_enabled,
dtype=cpu_autocast_dtype,
cache_enabled=autocast_cached,
)

return gpu_autocast_ctx, cpu_autocast_ctx
Expand All @@ -274,8 +280,8 @@ def _get_active_autocast_contexts():

def _get_active_autocast_contexts():
"""
Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
at the time of this function's execution.
Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast
state at the time of this function's execution.
"""
autocast_cached = torch.is_autocast_cache_enabled()

Expand Down

0 comments on commit a2212c0

Please sign in to comment.