Remove legacy constructor calls from pytorch codebase. (pytorch#54142)
Summary:
Follow up from pytorch#53889
Related to pytorch#47112

Removes every occurrence of the legacy `torch.Tensor(...)` constructor call present in PyTorch at (a short sketch of the replacement pattern follows this list):
- _docs_
- _benchmarks_
- _test_
- _caffe2_
- _CONTRIBUTING.md_
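
For context, a minimal sketch of the behavior difference this commit addresses (illustrative only, not part of the diff; the variable names are made up):

```python
import torch

# Legacy constructor: the result is always torch.float32, and the meaning of the
# arguments depends on their type (ints are treated as sizes, sequences as data).
uninit_legacy = torch.Tensor(2, 3)        # uninitialized 2x3 float32 tensor
data_legacy = torch.Tensor([1, 2, 3])     # float32 tensor, even though the inputs are ints

# Replacements used throughout this commit:
uninit = torch.empty(2, 3)                # uninitialized tensor of a given size
from_data = torch.tensor([1, 2, 3])       # dtype inferred from the data (int64 here)
from_floats = torch.tensor([1., 2., 3.])  # float data stays float32, matching the old behavior
```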

Pull Request resolved: pytorch#54142

Reviewed By: ngimel

Differential Revision: D27699450

Pulled By: mruberry

fbshipit-source-id: 530aa3f5746cc8bc1407d5d51b2bbd8075e30546
ysiraichi authored and facebook-github-bot committed Apr 11, 2021
1 parent fa29a64 commit 93bf0ae
Showing 40 changed files with 350 additions and 351 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -885,7 +885,7 @@ If you are working on the CUDA code, here are some useful CUDA debugging tips:
nbytes_read_write = 4 # this is number of bytes read + written by a kernel. Change this to fit your kernel.

for i in range(10):
- a=torch.Tensor(size).cuda().uniform_()
+ a=torch.empty(size).cuda().uniform_()
torch.cuda.synchronize()
start = time.time()
# dry run to alloc
10 changes: 5 additions & 5 deletions benchmarks/overrides_benchmark/bench.py
@@ -16,8 +16,8 @@ def bench(t1, t2):
torch.add(t1, t2)
bench_times.append(time.time() - time_start)

- bench_time = float(torch.min(torch.Tensor(bench_times))) / 1000
- bench_std = float(torch.std(torch.Tensor(bench_times))) / 1000
+ bench_time = float(torch.min(torch.tensor(bench_times))) / 1000
+ bench_std = float(torch.std(torch.tensor(bench_times))) / 1000

return bench_time, bench_std

@@ -48,11 +48,11 @@ def main():
NUM_REPEATS = args.nreps
NUM_REPEAT_OF_REPEATS = args.nrepreps

- types = torch.Tensor, SubTensor, WithTorchFunction, SubWithTorchFunction
+ types = torch.tensor, SubTensor, WithTorchFunction, SubWithTorchFunction

for t in types:
- tensor_1 = t(1)
- tensor_2 = t(2)
+ tensor_1 = t([1.])
+ tensor_2 = t([2.])

bench_min, bench_std = bench(tensor_1, tensor_2)
print(
2 changes: 1 addition & 1 deletion benchmarks/overrides_benchmark/common.py
@@ -14,7 +14,7 @@ def __init__(self, data, requires_grad=False):
self._tensor = data
return

- self._tensor = torch.Tensor(data, requires_grad)
+ self._tensor = torch.tensor(data, requires_grad=requires_grad)

def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
6 changes: 3 additions & 3 deletions benchmarks/overrides_benchmark/pyspybench.py
@@ -2,7 +2,7 @@
import argparse
from common import SubTensor, WithTorchFunction, SubWithTorchFunction # noqa: F401

- Tensor = torch.Tensor
+ Tensor = torch.tensor

NUM_REPEATS = 1000000

@@ -21,8 +21,8 @@
TensorClass = globals()[args.tensor_class]
NUM_REPEATS = args.nreps

- t1 = TensorClass(1)
- t2 = TensorClass(2)
+ t1 = TensorClass([1.])
+ t2 = TensorClass([2.])

for _ in range(NUM_REPEATS):
torch.add(t1, t2)
4 changes: 2 additions & 2 deletions caffe2/python/operator_test/heatmap_max_keypoint_op_test.py
@@ -30,8 +30,8 @@ def heatmap_approx_keypoint_ref(maps, rois):

def c10_op_ref(maps, rois):
keypoints = torch.ops._caffe2.HeatmapMaxKeypoint(
- torch.Tensor(maps),
- torch.Tensor(rois),
+ torch.tensor(maps),
+ torch.tensor(rois),
should_output_softmax=True,
)
return [keypoints.numpy()]
2 changes: 1 addition & 1 deletion caffe2/python/operator_test/layer_norm_op_test.py
@@ -344,7 +344,7 @@ def jit_layer_norm(
expected_norm, expected_mean, expected_std = \
_layer_norm_with_affine_ref(axis, eps, X, gamma, beta)
actual_norm, actual_mean, actual_std = jit_layer_norm(
- torch.Tensor(X), torch.tensor(gamma), torch.tensor(beta),
+ torch.tensor(X), torch.tensor(gamma), torch.tensor(beta),
axis, eps, elementwise_affine)
else:
expected_norm, expected_mean, expected_std = _layer_norm_ref(
24 changes: 12 additions & 12 deletions caffe2/python/operator_test/torch_integration_test.py
@@ -558,8 +558,8 @@ def roi_align_ref(_feature, _rois):

roi_feature_ref = roi_align_ref(feature, rois)
roi_feature = torch.ops._caffe2.RoIAlign(
- torch.Tensor(feature).to(device),
- torch.Tensor(rois).to(device),
+ torch.tensor(feature).to(device),
+ torch.tensor(rois).to(device),
order="NCHW",
spatial_scale=1.0,
pooled_h=3,
@@ -615,8 +615,8 @@ def roi_align_ref(_feature, _rois):

roi_feature_ref = roi_align_ref(feature, rois)
roi_feature = torch.ops._caffe2.RoIAlignRotated(
- torch.Tensor(feature).to(device),
- torch.Tensor(rois).to(device),
+ torch.tensor(feature).to(device),
+ torch.tensor(rois).to(device),
order="NCHW",
spatial_scale=1.0,
pooled_h=3,
@@ -639,7 +639,7 @@ def test_collect_and_distribute_fpn_rpn_proposals_op(self, roi_counts):
im_dims = np.random.randint(100, 600, batch_size)
rpn_rois_and_scores = []
for i in range(5):
- rpn_rois_and_scores.append(torch.Tensor(generate_rois(roi_counts, im_dims)))
+ rpn_rois_and_scores.append(torch.tensor(generate_rois(roi_counts, im_dims)))
for i in range(5):
rpn_rois_and_scores.append(torch.rand(sum(roi_counts)))

@@ -842,16 +842,16 @@ def _piecewise_linear_ref(X):

def test_alias_with_name_is_in_place(self):
device = "cuda" if workspace.has_cuda_support else "cpu"
- x = torch.Tensor([3, 42]).to(device)
+ x = torch.tensor([3., 42.]).to(device=device)
y = torch.ops._caffe2.AliasWithName(x, "new_name")
x[1] = 6
- torch.testing.assert_allclose(x, torch.Tensor([3, 6]).to(device))
+ torch.testing.assert_allclose(x, torch.tensor([3., 6.]).to(device=device))
# y should also change because y is alias of x
- torch.testing.assert_allclose(y, torch.Tensor([3, 6]).to(device))
+ torch.testing.assert_allclose(y, torch.tensor([3., 6.]).to(device=device))

@unittest.skipIf(not workspace.has_cuda_support, "No cuda support")
def test_copy_between_cpu_and_gpu(self):
- x_cpu_ref = torch.Tensor([1, 2, 3])
+ x_cpu_ref = torch.tensor([1., 2., 3.])
x_gpu_ref = x_cpu_ref.to("cuda")

x_gpu = torch.ops._caffe2.CopyCPUToGPU(x_cpu_ref)
@@ -923,8 +923,8 @@ def _percentile_ref(original_values, value_to_pct, lengths):
expected_output = _percentile_ref(original_values, value_to_pct, lengths)
actual_output = torch.ops._caffe2.Percentile(
torch.tensor(original_values),
- torch.Tensor(value_to_pct),
- torch.Tensor(lengths).int(),
+ torch.tensor(value_to_pct),
+ torch.tensor(lengths),
)
torch.testing.assert_allclose(expected_output, actual_output.cpu())

Expand All @@ -945,7 +945,7 @@ def _batch_bucket_one_hot_ref(data, lengths, boundaries):

expected_output = _batch_bucket_one_hot_ref(data, lengths, boundaries)
actual_output = torch.ops._caffe2.BatchBucketOneHot(
- torch.tensor(data), torch.Tensor(lengths).int(), torch.Tensor(boundaries)
+ torch.tensor(data), torch.tensor(lengths), torch.tensor(boundaries)
)
torch.testing.assert_allclose(expected_output, actual_output.cpu())

2 changes: 1 addition & 1 deletion docs/source/jit_language_reference_v2.rst
@@ -1333,7 +1333,7 @@ The above code results in the below RuntimeError
a : torch.jit.final[Bool] = True

if a:
- return torch.Tensor(2,3)
+ return torch.empty(2,3)
else:
return []

8 changes: 4 additions & 4 deletions docs/source/notes/extending.rst
@@ -223,9 +223,9 @@ This is how a ``Linear`` module can be implemented::
# won't be converted when e.g. .cuda() is called. You can use
# .register_buffer() to register buffers.
# nn.Parameters require gradients by default.
- self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
+ self.weight = nn.Parameter(torch.empty(output_features, input_features))
if bias:
- self.bias = nn.Parameter(torch.Tensor(output_features))
+ self.bias = nn.Parameter(torch.empty(output_features))
else:
# You should always register all possible parameters, but the
# optional ones can be None if you want.
@@ -494,7 +494,7 @@ will return subclass instances instead of ``torch.Tensor`` instances::
... pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
- >>> type(torch.add(SubTensor([0]), torch.Tensor([1]))).__name__
+ >>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor'

If multiple subclasses exist, the lowest one in the hierarchy will be chosen by
@@ -503,7 +503,7 @@ default. If there is no unique way to determine such a case, then a

>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
- >>> type(torch.add(SubTensor2([0]), torch.Tensor([1]))).__name__
+ >>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
2 changes: 1 addition & 1 deletion docs/source/tensor_attributes.rst
@@ -224,7 +224,7 @@ to perform many tensor operations efficiently.

Example::

- >>> x = torch.Tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
+ >>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
>>> x.stride()
(5, 1)

2 changes: 1 addition & 1 deletion test/benchmark_utils/test_benchmark_utils.py
@@ -990,7 +990,7 @@ def test_fuzzer(self):
for i, (tensors, _, _) in enumerate(fuzzer.take(2)):
x = tensors["x"]
self.assertEqual(
- x, torch.Tensor(expected_results[i]), rtol=1e-3, atol=1e-3)
+ x, torch.tensor(expected_results[i]), rtol=1e-3, atol=1e-3)


if __name__ == '__main__':
2 changes: 1 addition & 1 deletion test/cpp/api/optim_baseline.py
@@ -75,7 +75,7 @@ def run(optimizer_name, iterations, sample_every):
loss.backward()

def closure():
- return torch.Tensor([10])
+ return torch.tensor([10.])
optimizer.step(closure)

if i % sample_every == 0:
2 changes: 1 addition & 1 deletion test/jit/test_isinstance.py
@@ -56,7 +56,7 @@ def list_tensor_test(x: Any):
assert torch.jit.isinstance(x, List[torch.Tensor])
assert not torch.jit.isinstance(x, Tuple[int])

- x = [torch.Tensor([1]), torch.Tensor([2]), torch.Tensor([3])]
+ x = [torch.tensor([1]), torch.tensor([2]), torch.tensor([3])]
self.checkScript(list_tensor_test, (x,))

def test_dict(self):
2 changes: 1 addition & 1 deletion test/jit/test_tracer.py
@@ -2108,7 +2108,7 @@ def test_trace_parameter(self):
class Param(nn.Module):
def __init__(self):
super(Param, self).__init__()
- self.register_parameter("bias", nn.Parameter(torch.Tensor(4, 4)))
+ self.register_parameter("bias", nn.Parameter(torch.empty(4, 4)))

def forward(self, x):
return x
4 changes: 2 additions & 2 deletions test/onnx/test_models.py
@@ -166,14 +166,14 @@ def test_densenet(self):
def test_dcgan_netD(self):
netD = _netD(1)
netD.apply(weights_init)
- input = Variable(torch.Tensor(bsz, 3, imgsz, imgsz).normal_(0, 1))
+ input = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1))
self.exportTest(toC(netD), toC(input))

@disableScriptTest()
def test_dcgan_netG(self):
netG = _netG(1)
netG.apply(weights_init)
- input = Variable(torch.Tensor(bsz, nz, 1, 1).normal_(0, 1))
+ input = Variable(torch.empty(bsz, nz, 1, 1).normal_(0, 1))
self.exportTest(toC(netG), toC(input))

@skipIfUnsupportedMinOpsetVersion(10)
2 changes: 1 addition & 1 deletion test/onnx/test_pytorch_onnx_caffe2.py
@@ -1779,7 +1779,7 @@ def forward(self, feature, im_info, anchors):
)
return output

- feature = torch.Tensor(img_count, A, H, W)
+ feature = torch.empty(img_count, A, H, W)
im_info = torch.ones(img_count, 3, dtype=torch.float32)
anchors = torch.ones(A, 4, dtype=torch.float32)
inputs = (feature, im_info, anchors)
14 changes: 7 additions & 7 deletions test/quantization/test_numeric_suite_fx.py
@@ -49,7 +49,7 @@
class LinearReluFunctional(nn.Module):
def __init__(self):
super().__init__()
- self.w1 = nn.Parameter(torch.Tensor(4, 4))
+ self.w1 = nn.Parameter(torch.empty(4, 4))
self.b1 = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))

@@ -82,7 +82,7 @@ def test_simple_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
- self.w = nn.Parameter(torch.Tensor(1, 4))
+ self.w = nn.Parameter(torch.empty(1, 4))
self.b = nn.Parameter(torch.zeros(1))
torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))

@@ -449,7 +449,7 @@ def test_extract_weights_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
- self.w = nn.Parameter(torch.Tensor(4, 4))
+ self.w = nn.Parameter(torch.empty(4, 4))
self.b = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))

@@ -482,9 +482,9 @@ def test_match_activations_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
- self.w1 = nn.Parameter(torch.Tensor(4, 4))
+ self.w1 = nn.Parameter(torch.empty(4, 4))
self.b1 = nn.Parameter(torch.zeros(4))
- self.w2 = nn.Parameter(torch.Tensor(4, 4))
+ self.w2 = nn.Parameter(torch.empty(4, 4))
self.b2 = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))
@@ -518,9 +518,9 @@ def test_add_shadow_loggers_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
- self.w1 = nn.Parameter(torch.Tensor(4, 4))
+ self.w1 = nn.Parameter(torch.empty(4, 4))
self.b1 = nn.Parameter(torch.zeros(4))
- self.w2 = nn.Parameter(torch.Tensor(4, 4))
+ self.w2 = nn.Parameter(torch.empty(4, 4))
self.b2 = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))
6 changes: 3 additions & 3 deletions test/quantization/test_qat_module.py
@@ -43,8 +43,8 @@ def __init__(self,
self.momentum = momentum
self.freeze_bn = freeze_bn if self.training else True
self.num_features = out_channels
- self.gamma = nn.Parameter(torch.Tensor(out_channels))
- self.beta = nn.Parameter(torch.Tensor(out_channels))
+ self.gamma = nn.Parameter(torch.empty(out_channels))
+ self.beta = nn.Parameter(torch.empty(out_channels))
self.affine = True
self.track_running_stats = True
self.register_buffer('running_mean', torch.zeros(out_channels))
@@ -53,7 +53,7 @@ def __init__(self,
self.activation_post_process = self.qconfig.activation()
self.weight_fake_quant = self.qconfig.weight()
if bias:
- self.bias = nn.Parameter(torch.Tensor(out_channels))
+ self.bias = nn.Parameter(torch.empty(out_channels))
else:
self.register_parameter('bias', None)
self.reset_bn_parameters()
2 changes: 1 addition & 1 deletion test/quantization/test_workflow_module.py
@@ -519,7 +519,7 @@ def test_zero_numel(self):
obs = obs_cls(0.1, 0)
else:
obs = obs_cls()
- x = torch.Tensor()
+ x = torch.tensor([])
# verify no crash
x = obs(x)

16 changes: 8 additions & 8 deletions test/test_autograd.py
@@ -1497,7 +1497,7 @@ def compare(x, y, idx, indexed_tensor, indexed_var):
self.assertEqual(indexed_tensor, indexed_var_t)

indexed_var.sum().backward()
- expected_grad = torch.Tensor(x.size()).fill_(0)
+ expected_grad = torch.empty(x.size()).fill_(0)
expected_grad[idx] = 1
self.assertEqual(y.grad, expected_grad)

@@ -1595,18 +1595,18 @@ def test_indexing_duplicates(self):
y = Variable(x, requires_grad=True)
idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]]
y[idx].sum().backward()
- expected_grad = torch.Tensor([[0, 2, 0, 0],
- [1, 0, 0, 0],
- [0, 1, 0, 0],
- [0, 0, 0, 0]])
+ expected_grad = torch.tensor([[0., 2., 0., 0.],
+ [1., 0., 0., 0.],
+ [0., 1., 0., 0.],
+ [0., 0., 0., 0.]])
self.assertEqual(y.grad, expected_grad)

x = torch.arange(1., 65).view(4, 4, 4)
y = Variable(x, requires_grad=True)

idx = [[1, 1, 1], slice(None), slice(None)]
y[idx].sum().backward()
- expected_grad = torch.Tensor(4, 4, 4).zero_()
+ expected_grad = torch.empty(4, 4, 4).zero_()
expected_grad[1].fill_(3)
self.assertEqual(y.grad, expected_grad)

@@ -1849,7 +1849,7 @@ def test_inplace(self):
r.backward(torch.ones(5, 5), retain_graph=True)
self.assertEqual(x.grad, torch.ones(5, 5) / 2)
w.backward(torch.ones(5, 5), retain_graph=True)
- self.assertEqual(x.grad, torch.Tensor(5, 5).fill_((1 + math.e) / 2))
+ self.assertEqual(x.grad, torch.empty(5, 5).fill_((1 + math.e) / 2))
self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

leaf = torch.ones(5, 5, requires_grad=True)
@@ -4311,7 +4311,7 @@ def test_checkpointing(self):

feat_combined = []
for r in range(num_inp):
- data_r = torch.Tensor(1, nz_inp)
+ data_r = torch.empty(1, nz_inp)
data_r.uniform_()
data_r.requires_grad = True
feat_r = checkpoint(module, data_r)