From 3b7b26043d415aff6cab9e6866ac408bf1142c51 Mon Sep 17 00:00:00 2001 From: Dark Knight Date: Fri, 16 May 2025 12:58:18 -0700 Subject: [PATCH] Revert D74293458 Summary: This diff reverts D74293458 T224582561 broke tests Reviewed By: PoojaAg18 Differential Revision: D74902138 --- torchrec/metrics/metric_module.py | 1 - torchrec/metrics/metrics_config.py | 3 --- torchrec/metrics/rec_metric.py | 42 ------------------------------ 3 files changed, 46 deletions(-) diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index e46563c69..8ca849152 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -370,7 +370,6 @@ def _generate_rec_metrics( kwargs = metric_def.arguments kwargs["enable_pt2_compile"] = metrics_config.enable_pt2_compile - kwargs["should_clone_update_inputs"] = metrics_config.should_clone_update_inputs rec_tasks: List[RecTaskInfo] = [] if metric_def.rec_tasks and metric_def.rec_task_indices: diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index 91039ead3..0428e2412 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -170,8 +170,6 @@ class MetricsConfig: update if the inputs are invalid. Invalid inputs include the case where all examples have 0 weights for a batch. enable_pt2_compile (bool): whether to enable PT2 compilation for metrics. - should_clone_update_inputs (bool): whether to clone the inputs of update(). This - prevents CUDAGraph error on overwritting tensor outputs by subsequent runs. """ rec_tasks: List[RecTaskInfo] = field(default_factory=list) @@ -186,7 +184,6 @@ class MetricsConfig: compute_on_all_ranks: bool = False should_validate_update: bool = False enable_pt2_compile: bool = False - should_clone_update_inputs: bool = False DefaultTaskInfo = RecTaskInfo( diff --git a/torchrec/metrics/rec_metric.py b/torchrec/metrics/rec_metric.py index 8b9bf0e5c..53fbfa3b5 100644 --- a/torchrec/metrics/rec_metric.py +++ b/torchrec/metrics/rec_metric.py @@ -384,14 +384,6 @@ def __init__( if "enable_pt2_compile" in kwargs: del kwargs["enable_pt2_compile"] - # pyre-fixme[8]: Attribute has type `bool`; used as `Union[bool, - # Dict[str, Any]]`. - self._should_clone_update_inputs: bool = kwargs.get( - "should_clone_update_inputs", False - ) - if "should_clone_update_inputs" in kwargs: - del kwargs["should_clone_update_inputs"] - if self._window_size < self._batch_size: raise ValueError( f"Local window size must be larger than batch size. Got local window size {self._window_size} and batch size {self._batch_size}." @@ -549,35 +541,6 @@ def _create_default_weights(self, predictions: torch.Tensor) -> torch.Tensor: def _check_nonempty_weights(self, weights: torch.Tensor) -> torch.Tensor: return torch.gt(torch.count_nonzero(weights, dim=-1), 0) - def clone_update_inputs( - self, - predictions: RecModelOutput, - labels: RecModelOutput, - weights: Optional[RecModelOutput], - **kwargs: Dict[str, Any], - ) -> tuple[ - RecModelOutput, RecModelOutput, Optional[RecModelOutput], Dict[str, Any] - ]: - def clone_rec_model_output( - rec_model_output: RecModelOutput, - ) -> RecModelOutput: - if isinstance(rec_model_output, torch.Tensor): - return rec_model_output.clone() - else: - return {k: v.clone() for k, v in rec_model_output.items()} - - predictions = clone_rec_model_output(predictions) - labels = clone_rec_model_output(labels) - if weights is not None: - weights = clone_rec_model_output(weights) - - if "required_inputs" in kwargs: - kwargs["required_inputs"] = { - k: v.clone() for k, v in kwargs["required_inputs"].items() - } - - return predictions, labels, weights, kwargs - def _update( self, *, @@ -587,11 +550,6 @@ def _update( **kwargs: Dict[str, Any], ) -> None: with torch.no_grad(): - if self._should_clone_update_inputs: - predictions, labels, weights, kwargs = self.clone_update_inputs( - predictions, labels, weights, **kwargs - ) - if self._compute_mode in [ RecComputeMode.FUSED_TASKS_COMPUTATION, RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,