Skip to content

Commit

Permalink
Update deprecated PyTorch functions in fbcode/deeplearning (#473)
Browse files Browse the repository at this point in the history
Summary:
X-link: fairinternal/CrypTen#250

Pull Request resolved: #473

Update some deprecated PyTorch APIs:
ger -> outer
range -> arange
functorch -> torch.func

Reviewed By: malfet

Differential Revision: D46615036

fbshipit-source-id: 81d5dd58239dd23b49b24a082b618d7d108422c2
  • Loading branch information
kit1980 authored and facebook-github-bot committed Jun 13, 2023
1 parent 6ef1511 commit f579c2f
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion crypten/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def chebyshev_series(func, width, terms):
n_range = torch.arange(start=0, end=terms).float()
x = width * torch.cos((n_range + 0.5) * np.pi / terms)
y = func(x)
cos_term = torch.cos(torch.ger(n_range, n_range + 0.5) * np.pi / terms)
cos_term = torch.cos(torch.outer(n_range, n_range + 0.5) * np.pi / terms)
coeffs = (2 / terms) * torch.sum(y * cos_term, axis=1)
return coeffs

Expand Down
2 changes: 1 addition & 1 deletion examples/bandits/plain_contextual_bandits.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def online_learner(

# update linear least squares accumulators (using Sherman–Morrison formula):
A_inv_context = A_inv[selected_arm, :, :].mv(context)
numerator = torch.ger(A_inv_context, A_inv_context)
numerator = torch.outer(A_inv_context, A_inv_context)
denominator = A_inv_context.dot(context).add(1.0)
A_inv[selected_arm, :, :].sub_(numerator.div_(denominator))
b[selected_arm, :].add_(context.mul(reward))
Expand Down
2 changes: 1 addition & 1 deletion test/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def test_dot_ger(self):
tensor1 = get_random_test_tensor(is_float=True).squeeze()
tensor2 = get_random_test_tensor(is_float=True).squeeze()
dot_reference = tensor1.dot(tensor2)
ger_reference = torch.ger(tensor1, tensor2)
ger_reference = torch.outer(tensor1, tensor2)

tensor2 = tensor_type(tensor2)

Expand Down
2 changes: 1 addition & 1 deletion test/test_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def test_dot_ger(self):
tensor1 = self._get_random_test_tensor(is_float=True).squeeze()
tensor2 = self._get_random_test_tensor(is_float=True).squeeze()
dot_reference = tensor1.dot(tensor2)
ger_reference = torch.ger(tensor1, tensor2)
ger_reference = torch.outer(tensor1, tensor2)

tensor2 = tensor_type(tensor2)

Expand Down

0 comments on commit f579c2f

Please sign in to comment.