From e1aa2b9a8ea8cbe42c80a89363b2212511ef6a93 Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Tue, 20 Feb 2024 14:47:31 -0800 Subject: [PATCH] =?UTF-8?q?Revert=20"Fix=20UserWarning:=20The=20torch.cuda?= =?UTF-8?q?.*DtypeTensor=20constructors=20are=20no=20long=E2=80=A6=20(#501?= =?UTF-8?q?8)"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 177dc14331a64e61f6dcce2c4b8071576bcb22db. --- accelerator/cuda_accelerator.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 3d5e9c168c16..2030f36631e9 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -3,7 +3,6 @@ # DeepSpeed Team -import functools import os import pkgutil import importlib @@ -261,31 +260,31 @@ def replay_graph(self, graph): @property def BFloat16Tensor(self): - return functools.partial(torch.tensor, dtype=torch.bfloat16, device='cuda') + return torch.cuda.BFloat16Tensor @property def ByteTensor(self): - return functools.partial(torch.tensor, dtype=torch.uint8, device='cuda') + return torch.cuda.ByteTensor @property def DoubleTensor(self): - return functools.partial(torch.tensor, dtype=torch.double, device='cuda') + return torch.cuda.DoubleTensor @property def FloatTensor(self): - return functools.partial(torch.tensor, dtype=torch.float, device='cuda') + return torch.cuda.FloatTensor @property def HalfTensor(self): - return functools.partial(torch.tensor, dtype=torch.half, device='cuda') + return torch.cuda.HalfTensor @property def IntTensor(self): - return functools.partial(torch.tensor, dtype=torch.int, device='cuda') + return torch.cuda.IntTensor @property def LongTensor(self): - return functools.partial(torch.tensor, dtype=torch.long, device='cuda') + return torch.cuda.LongTensor def pin_memory(self, tensor, align_bytes=1): return tensor.pin_memory()