Commit

fix lint
yitongh committed Jan 7, 2025
1 parent 3e8c98a commit 3ea5f90
Showing 5 changed files with 62 additions and 54 deletions.
2 changes: 1 addition & 1 deletion torch_xla/_dynamo/config.py
@@ -6,7 +6,7 @@
skip_input_data_check = False

# Whether to transform the FX graph into an XLA computation
# and creating a call node for that computation. This allows XLA to trace
# and creating a call node for that computation. This allows XLA to trace
# a more extensive computation graph, potentially leading to greater
# optimization opportunities.
use_call_computation = False
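
For context on the `use_call_computation` flag touched above: a minimal sketch, not part of this commit, of how such a module-level flag might be toggled before compiling. It assumes the import path `torch_xla/_dynamo/config.py` shown in this diff is importable as `torch_xla._dynamo.config` and that the standard `openxla` Dynamo backend is registered.

```python
# Minimal sketch, not part of this commit. Assumes torch_xla/_dynamo/config.py
# is importable as torch_xla._dynamo.config and that the "openxla" backend is
# registered, as in recent torch_xla releases.
import torch
import torch_xla.core.xla_model as xm
import torch_xla._dynamo.config as dynamo_config

# Route the traced FX graph through a single XLA computation call node
# (see the comment on use_call_computation above).
dynamo_config.use_call_computation = True

def fn(a, b):
    return torch.matmul(a, b) + b

device = xm.xla_device()
compiled = torch.compile(fn, backend="openxla")
out = compiled(torch.randn(4, 4, device=device),
               torch.randn(4, 4, device=device))
```
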
40 changes: 22 additions & 18 deletions torch_xla/_dynamo/dynamo_bridge.py
@@ -167,7 +167,7 @@ def _maybe_move_tensors_to_device(tensors: tuple,
# If the input cuda tensor requires gradient, we need to call detach. Otherwise, we'd get the error "RuntimeError: Can't export tensors that require gradient, use tensor.detach()"
moved_tensor = torch_xla_dlpack.from_dlpack(tensor.detach())
elif zero_copy_enabled and tensor.device.type == 'xla' and target_device.type == 'cuda':
# mark_step is need to make sure the pjrt buffer is valid.
# mark_step is need to make sure the pjrt buffer is valid.
xm.mark_step()
moved_tensor = torch_xla_dlpack.from_xla_cuda_to_cuda(tensor)
else:
@@ -260,9 +260,7 @@ def __init__(self, trace_inputs, trace_outputs,
self.deduped_trace_outputs = self.deduper.dedup(trace_outputs)

# record the output that is also a input
trace_inputs_id2pos = {
id(x): pos for pos, x in enumerate(trace_inputs)
}
trace_inputs_id2pos = {id(x): pos for pos, x in enumerate(trace_inputs)}
self.trace_outputs_pos_to_inputs_pos = []
for out_pos, out in enumerate(self.deduped_trace_outputs):
in_pos = trace_inputs_id2pos.get(id(out), None)
@@ -470,7 +468,8 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
graph_hash = None
if config.use_call_computation:
graph_hash = None
xla_computation = torch_xla._XLAC._xla_create_computation("dynamo_call", args_and_out_tensor_only)
xla_computation = torch_xla._XLAC._xla_create_computation(
"dynamo_call", args_and_out_tensor_only)
elif len(args_and_out_tensor_only) > 0:
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out_tensor_only)
# compiles and cache graph rooted at tensors in 'args_and_out_tensor_only'
@@ -506,7 +505,8 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
vars_to_return = (xla_args_sharding_spec, len(args_and_out), graph_hash,
arg_index_to_need_update_index, none_remover,
graph_input_matcher, special_return_handler,
xla_args_need_update, xla_args_dtype, xla_computation, arg_index_to_update_output_index)
xla_args_need_update, xla_args_dtype, xla_computation,
arg_index_to_update_output_index)
# populate the cache
sym_constants_to_graph_vars[sym_constants] = vars_to_return

@@ -538,10 +538,9 @@ def extract_internal(xla_model: torch.fx.GraphModule):

(xla_args_sharding_spec, len_args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
special_return_handler,
xla_args_need_update, xla_args_dtype, xla_computation,
arg_index_to_update_output_index) = extract_graph_helper(xla_model,
sym_constants_to_graph_vars)
special_return_handler, xla_args_need_update, xla_args_dtype,
xla_computation, arg_index_to_update_output_index) = extract_graph_helper(
xla_model, sym_constants_to_graph_vars)
skip_checking_input_sharding_threashold = xu.getenv_as(
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)

@@ -566,14 +565,15 @@ def optimized_mod(*args: tuple):
if sym_constants in sym_constants_to_graph_vars:
(xla_args_sharding_spec, len_args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
special_return_handler,
xla_args_need_update, xla_args_dtype, xla_computation,
arg_index_to_update_output_index) = sym_constants_to_graph_vars[sym_constants]
special_return_handler, xla_args_need_update, xla_args_dtype,
xla_computation, arg_index_to_update_output_index
) = sym_constants_to_graph_vars[sym_constants]
else:
xla_model.xla_args = args
(xla_args_sharding_spec, len_args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
special_return_handler, xla_args_need_update, xla_args_dtype, xla_computation,
special_return_handler, xla_args_need_update, xla_args_dtype,
xla_computation,
arg_index_to_update_output_index) = extract_graph_helper(
xla_model, sym_constants_to_graph_vars)
if hasattr(xla_model, 'xla_args'):
@@ -638,8 +638,9 @@ def optimized_mod(*args: tuple):
if config.use_call_computation:
assert not is_cuda_args
assert xla_computation is not None
res = torch_xla._XLAC._xla_call_computation("xla::_call_computation", graph_input,
xla_computation, xla_args_tensor_only, arg_index_to_update_output_index)
res = torch_xla._XLAC._xla_call_computation(
"xla::_call_computation", graph_input, xla_computation,
xla_args_tensor_only, arg_index_to_update_output_index)
else:
res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)
xm.wait_device_ops()
@@ -794,7 +795,6 @@ def move_xla_to_cuda(self, graph: torch.fx.Graph):
kwargs["device"] = "cuda"
node.kwargs = kwargs


def __call__(self, graph: torch.fx.Graph, move_xla_to_cuda=False) -> None:
if move_xla_to_cuda:
self.move_xla_to_cuda(graph)
@@ -870,7 +870,11 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:

last_node = list(reversed(partitioned_graph.graph.nodes))[0]
with partitioned_graph.graph.inserting_after(last_node):
partitioned_graph.graph.create_node(op='output', target='output', args=return_node.args, type_expr=return_node.type)
partitioned_graph.graph.create_node(
op='output',
target='output',
args=return_node.args,
type_expr=return_node.type)
partitioned_graph.graph.erase_node(return_node)
InputCollector(partitioned_graph).run(*xla_args)

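
The comments reformatted in `_maybe_move_tensors_to_device` above describe the DLPack-based zero-copy path between CUDA and XLA:GPU tensors. A minimal sketch of that pattern, assuming an XLA:GPU build where DLPack zero copy is available and using the `torch_xla.utils.dlpack` helpers as imported in this file:

```python
# Minimal sketch, not part of this commit. Assumes an XLA:GPU build where
# DLPack zero copy is available and torch_xla.utils.dlpack provides
# from_dlpack / from_xla_cuda_to_cuda as used in dynamo_bridge.py.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.utils.dlpack as torch_xla_dlpack

cuda_t = torch.randn(8, device="cuda", requires_grad=True)

# CUDA -> XLA: detach first, since tensors that require grad cannot be
# exported through DLPack.
xla_t = torch_xla_dlpack.from_dlpack(cuda_t.detach())

# XLA -> CUDA: run mark_step so the PJRT buffer backing the tensor is valid
# before exporting it.
xm.mark_step()
cuda_back = torch_xla_dlpack.from_xla_cuda_to_cuda(xla_t)
```
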
61 changes: 30 additions & 31 deletions torch_xla/distributed/fsdp/_exec_utils.py
@@ -1,35 +1,34 @@
class ExecState:
def __init__(self):
self._index_to_fsdp_module = {}
self._fsdp_module_to_index = {}
self._iter = 0
self._current_index = 0

@property
def is_first_iter(self) -> bool:
return self._iter == 0

def record_forward(self, fsdp_module):
if self.is_first_iter:
assert fsdp_module not in self._fsdp_module_to_index
self._index_to_fsdp_module[self._current_index] = fsdp_module
self._fsdp_module_to_index[fsdp_module] = self._current_index
self._current_index += 1
else:
assert fsdp_module in self._fsdp_module_to_index
assert self._fsdp_module_to_index[fsdp_module] == self._current_index, \
"FSDP module is not in the same execution order as first iteration."
self._current_index += 1


def get_prefetch_module(self):
if self.is_first_iter:
return None
if self._current_index >= len(self._index_to_fsdp_module):
return None
return self._index_to_fsdp_module[self._current_index]
def __init__(self):
self._index_to_fsdp_module = {}
self._fsdp_module_to_index = {}
self._iter = 0
self._current_index = 0

@property
def is_first_iter(self) -> bool:
return self._iter == 0

def next_iter(self):
self._iter += 1
self._current_index = 0
def record_forward(self, fsdp_module):
if self.is_first_iter:
assert fsdp_module not in self._fsdp_module_to_index
self._index_to_fsdp_module[self._current_index] = fsdp_module
self._fsdp_module_to_index[fsdp_module] = self._current_index
self._current_index += 1
else:
assert fsdp_module in self._fsdp_module_to_index
assert self._fsdp_module_to_index[fsdp_module] == self._current_index, \
"FSDP module is not in the same execution order as first iteration."
self._current_index += 1

def get_prefetch_module(self):
if self.is_first_iter:
return None
if self._current_index >= len(self._index_to_fsdp_module):
return None
return self._index_to_fsdp_module[self._current_index]

def next_iter(self):
self._iter += 1
self._current_index = 0
7 changes: 5 additions & 2 deletions torch_xla/distributed/fsdp/utils.py
@@ -10,6 +10,7 @@

_prev_early_sync_counter = 0


def exists_early_sync():
import torch_xla.debug.metrics as metrics
global _prev_early_sync_counter
@@ -128,7 +129,8 @@ def backward(ctx, grad_output):
grad_input = grad_input_flat
if torch.compiler.is_dynamo_compiling() or ctx.needs_input_grad[1]:
grad_weight = grad_output_flat.t().mm(input_flat)
if bias is not None and (torch.compiler.is_dynamo_compiling() or ctx.needs_input_grad[2]):
if bias is not None and (torch.compiler.is_dynamo_compiling() or
ctx.needs_input_grad[2]):
grad_bias = grad_output_flat.sum(0)

return grad_input, grad_weight, grad_bias
@@ -226,7 +228,8 @@ def forward(ctx, run_function, *args):

ctx.save_for_backward(*(tensor_inputs + tensor_outputs))
outputs = _apply_to_tensors(lambda t: t.clone().detach(), outputs)
if dynamo_config.mark_step_after_layer_if_early_sync and exists_early_sync():
if dynamo_config.mark_step_after_layer_if_early_sync and exists_early_sync(
):
ctx.mark_step = True
xm.mark_step(reset_scope=False)
else:
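
The backward reformatted above for the flattened linear layer uses the standard formulas `grad_weight = grad_output^T @ input` and `grad_bias = grad_output.sum(0)`. A small CPU check of those formulas against autograd:

```python
# Minimal sketch, not part of this commit: verifies the linear-backward
# formulas used above against PyTorch autograd on CPU.
import torch

x = torch.randn(5, 3, requires_grad=True)
w = torch.randn(4, 3, requires_grad=True)
b = torch.randn(4, requires_grad=True)

out = torch.nn.functional.linear(x, w, b)   # out = x @ w.T + b
grad_out = torch.randn_like(out)
out.backward(grad_out)

assert torch.allclose(w.grad, grad_out.t().mm(x), atol=1e-5)   # grad_weight
assert torch.allclose(b.grad, grad_out.sum(0), atol=1e-5)      # grad_bias
assert torch.allclose(x.grad, grad_out.mm(w), atol=1e-5)       # grad_input
```
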
6 changes: 4 additions & 2 deletions torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -336,7 +336,8 @@ def __init__(
is_forward_defined = (
hasattr(module, "forward") and hasattr(module.forward, "__func__") and
module.forward.__func__ != torch.nn.Module.forward)
if not is_forward_defined and not isinstance(module, torch._dynamo.OptimizedModule):
if not is_forward_defined and not isinstance(module,
torch._dynamo.OptimizedModule):
raise RuntimeError(
"The module wrapped by FSDP *must define a `forward` method and call it "
"during the module's forward pass for FSDP to work correctly.* "
@@ -996,7 +997,8 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._exec_state.record_forward(self)
next_module = self._exec_state.get_prefetch_module()
if next_module:
next_module._rebuild_full_params(apply_opt_barrier=self.optimization_barrier_in_forward)
next_module._rebuild_full_params(
apply_opt_barrier=self.optimization_barrier_in_forward)

# Start of a forward pass.
self.training_state = TrainingState.FORWARD
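
The `is_forward_defined` check reformatted at the top of this file verifies that the wrapped module overrides `forward` rather than inheriting `torch.nn.Module.forward`. A standalone reproduction of that check, using only plain `torch.nn` modules:

```python
# Minimal sketch, not part of this commit: reproduces the is_forward_defined
# check from xla_fully_sharded_data_parallel.py with plain torch.nn modules.
import torch

class WithForward(torch.nn.Module):
    def forward(self, x):
        return x + 1

class WithoutForward(torch.nn.Module):
    pass

def forward_is_defined(module: torch.nn.Module) -> bool:
    return (hasattr(module, "forward") and hasattr(module.forward, "__func__") and
            module.forward.__func__ != torch.nn.Module.forward)

assert forward_is_defined(WithForward())
assert not forward_is_defined(WithoutForward())
```
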
