diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 8706117f6a8..6cb20d920c4 100644 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -508,3 +508,17 @@ def atexit_hook(hook): if oneflow._oneflow_internal.flags.with_mlir(): oneflow_internal_path = oneflow._oneflow_internal.__file__ oneflow._oneflow_internal.ir.load_jit_shared_lib(oneflow_internal_path) + + +def flow_ones(self, *args, **kwargs): + return ones(*args, **kwargs, device=self.device, dtype=self.dtype) + + +Tensor.new_ones = flow_ones + + +def flow_zeros(self, *args, **kwargs): + return zeros(*args, **kwargs, device=self.device, dtype=self.dtype) + + +Tensor.new = flow_zeros