Skip to content

Commit

Permalink
[torchlib] Unregister stft, var, var_mean, std, std_mean
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Sep 16, 2024
1 parent 377869a commit 2c2437d
Showing 1 changed file with 19 additions and 27 deletions.
46 changes: 19 additions & 27 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3974,8 +3974,6 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType:


# Do not register hstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918


def aten_hstack(tensors: Sequence[TTensor]) -> TTensor:
"""hstack(Tensor[] tensors) -> Tensor"""

Expand Down Expand Up @@ -7887,14 +7885,14 @@ def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrStr
return op.ConcatFromSequence(tensors, axis=dim, new_axis=1)


@torch_op("aten::std", trace_only=True)
# std is decomposed by PyTorch
def aten_std(self: TReal, unbiased: bool = True) -> TReal:
    """std(Tensor self, bool unbiased=True) -> Tensor"""
    # Standard deviation = sqrt of the variance; ``unbiased`` toggles the
    # Bessel correction (correction=1.0) versus the population formula (0.0).
    variance = _aten_var_onnx(self, correction=float(unbiased), keepdim=False)
    return op.Sqrt(variance)


@torch_op("aten::std.dim", trace_only=True)
# std_dim is decomposed by PyTorch
def aten_std_dim(
self: TReal,
dim: Sequence[int],
Expand All @@ -7907,7 +7905,7 @@ def aten_std_dim(
return op.Sqrt(var)


@torch_op("aten::var.correction", trace_only=True)
# std.correction is decomposed by PyTorch
def aten_std_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
Expand All @@ -7927,7 +7925,7 @@ def aten_std_correction(
return op.Sqrt(var)


@torch_op("aten::std_mean", trace_only=True)
# std_mean is decomposed by PyTorch
def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
"""std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""

Expand All @@ -7937,7 +7935,7 @@ def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
return op.Sqrt(var), mean


@torch_op("aten::std_mean.dim", trace_only=True)
# std_mean is decomposed by PyTorch
def aten_std_mean_dim(
self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
) -> Tuple[TReal, TReal]:
Expand All @@ -7951,7 +7949,7 @@ def aten_std_mean_dim(
return op.Sqrt(var), mean


@torch_op("aten::std_mean.correction", trace_only=True)
# std_mean is decomposed by PyTorch
def aten_std_mean_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
Expand All @@ -7973,7 +7971,6 @@ def aten_std_mean_correction(
return op.Sqrt(var), mean


@torch_op("aten::stft", private=True)
def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
signal_rank = Rank(self)
if signal_rank == 1:
Expand All @@ -7982,7 +7979,6 @@ def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT6
return op.Identity(self), signal_rank


@torch_op("aten::stft", private=True)
def _center_window_around_zeros_if_needed(
window: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
Expand All @@ -8004,7 +8000,6 @@ def _center_window_around_zeros_if_needed(
return window


@torch_op("aten::stft", private=True)
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
left = (n_fft - win_length) / 2

Expand All @@ -8019,14 +8014,12 @@ def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloa
return op.Concat(left_win, window_list, right_win, axis=0)


@torch_op("aten::stft", private=True)
def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
    """Build a rectangular (all-ones) window of length ``n_fft``."""
    # Turn the scalar length into a 1-D shape tensor, then broadcast a
    # constant 1 to that shape to produce the window.
    target_shape = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    return op.Expand(op.Constant(value_ints=[1]), target_shape)


@torch_op("aten::stft", private=True)
def _normalize_fft_result(
signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
Expand All @@ -8036,7 +8029,6 @@ def _normalize_fft_result(
return result


@torch_op("aten::stft", private=True)
def _aten_stft_onnx(
signal: TFloatOrBFloat16,
frame_step_const: INT64,
Expand All @@ -8054,7 +8046,8 @@ def _aten_stft_onnx(
return result


@torch_op("aten::stft", trace_only=True)
# aten::stft is decomposed. The implementation is not complete and is left for
# reference only.
def aten_stft(
self: TFloatOrBFloat16,
n_fft: int,
Expand Down Expand Up @@ -8738,7 +8731,7 @@ def aten_vander(
raise NotImplementedError()


@torch_op("aten::var", trace_only=True)
# var is decomposed by PyTorch
def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
"""var(Tensor self, bool unbiased=True) -> Tensor"""

Expand All @@ -8747,7 +8740,7 @@ def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
return _aten_var_onnx(self, correction=float(unbiased), keepdim=False)


@torch_op("aten::var.dim", trace_only=True)
# var is decomposed by PyTorch
def aten_var_dim(
self: TReal,
dim: Sequence[int],
Expand All @@ -8759,7 +8752,7 @@ def aten_var_dim(
return _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)


@torch_op("aten::var.correction", trace_only=True)
# var is decomposed by PyTorch
def aten_var_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
Expand All @@ -8779,7 +8772,7 @@ def aten_var_correction(
return var


@torch_op("aten::var", private=True, traceable=True)
# var is decomposed by PyTorch
def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TReal:
mean = op.ReduceMean(self, keepdims=keepdim)
sub_mean = op.Sub(self, mean)
Expand All @@ -8796,7 +8789,7 @@ def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TRe
return var


@torch_op("aten::var.dim", private=True, traceable=True)
# var is decomposed by PyTorch
def _aten_var_dim_onnx(
self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
) -> TReal:
Expand All @@ -8817,7 +8810,7 @@ def _aten_var_dim_onnx(
return var


@torch_op("aten::var_mean", trace_only=True)
# var_mean is decomposed by PyTorch
def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
"""var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""

Expand All @@ -8826,7 +8819,7 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
return _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False)


@torch_op("aten::var_mean.dim", trace_only=True)
# var_mean is decomposed by PyTorch
def aten_var_mean_dim(
self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
) -> Tuple[TReal, TReal]:
Expand All @@ -8837,7 +8830,7 @@ def aten_var_mean_dim(
return _aten_var_mean_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)


@torch_op("aten::var_mean.correction", trace_only=True)
# var_mean is decomposed by PyTorch
def aten_var_mean_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
Expand All @@ -8859,7 +8852,7 @@ def aten_var_mean_correction(
return var, mean


@torch_op("aten::var_mean", private=True)
# var_mean is decomposed by PyTorch
def _aten_var_mean_onnx(
self: TReal, correction: float = 1.0, keepdim: bool = False
) -> Tuple[TReal, TReal]:
Expand All @@ -8879,7 +8872,7 @@ def _aten_var_mean_onnx(
return var, mean


@torch_op("aten::var_mean.dim", private=True)
# var_mean is decomposed by PyTorch
def _aten_var_mean_dim_onnx(
self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
) -> Tuple[TReal, TReal]:
Expand Down Expand Up @@ -8977,8 +8970,6 @@ def aten_view_copy(self: TTensor, size: IntType) -> TTensor:


# Do not register vstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918


def aten_vstack(tensors: Sequence[TTensor]) -> TTensor:
"""vstack(Tensor[] tensors) -> Tensor"""

Expand All @@ -8998,6 +8989,7 @@ def reshape_to_2d(tensor):

@torch_op(
(
"aten::where",
"aten::where.Scalar",
"aten::where.ScalarSelf",
"aten::where.ScalarOther",
Expand Down

0 comments on commit 2c2437d

Please sign in to comment.