From 4572b5ac21a1942067c4b0cec290fea04480f2f7 Mon Sep 17 00:00:00 2001
From: sunyaqiang
Date: Fri, 22 Aug 2025 03:01:26 +0000
Subject: [PATCH 1/2] fix FutureWarning: torch.cuda.amp.autocast(args...) is
 deprecated

---
 detectron2/engine/train_loop.py | 37 +++++++++++++++++++++++++++----------
 1 file changed, 27 insertions(+), 10 deletions(-)

diff --git a/detectron2/engine/train_loop.py b/detectron2/engine/train_loop.py
index 738a69de94..ceaa61b468 100644
--- a/detectron2/engine/train_loop.py
+++ b/detectron2/engine/train_loop.py
@@ -469,9 +469,14 @@ def __init__(
             )
 
         if grad_scaler is None:
-            from torch.cuda.amp import GradScaler
+            if torch.__version__ >= "2.4.0":
+                from torch.amp import GradScaler
 
-            grad_scaler = GradScaler()
+                grad_scaler = GradScaler('cuda')
+            else:
+                from torch.cuda.amp import GradScaler
+
+                grad_scaler = GradScaler()
         self.grad_scaler = grad_scaler
         self.precision = precision
         self.log_grad_scaler = log_grad_scaler
@@ -482,7 +487,10 @@ def run_step(self):
         """
         assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
         assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
-        from torch.cuda.amp import autocast
+        if torch.__version__ >= "2.4.0":
+            from torch.amp import autocast
+        else:
+            from torch.cuda.amp import autocast
 
         start = time.perf_counter()
         data = next(self._data_loader_iter)
@@ -490,13 +498,22 @@ def run_step(self):
 
         if self.zero_grad_before_forward:
             self.optimizer.zero_grad()
-        with autocast(dtype=self.precision):
-            loss_dict = self.model(data)
-            if isinstance(loss_dict, torch.Tensor):
-                losses = loss_dict
-                loss_dict = {"total_loss": loss_dict}
-            else:
-                losses = sum(loss_dict.values())
+        if torch.__version__ >= "2.4.0":
+            with autocast('cuda', dtype=self.precision):
+                loss_dict = self.model(data)
+                if isinstance(loss_dict, torch.Tensor):
+                    losses = loss_dict
+                    loss_dict = {"total_loss": loss_dict}
+                else:
+                    losses = sum(loss_dict.values()) 
+        else:
+            with autocast(dtype=self.precision):
+                loss_dict = self.model(data)
+                if isinstance(loss_dict, torch.Tensor):
+                    losses = loss_dict
+                    loss_dict = {"total_loss": loss_dict}
+                else:
+                    losses = sum(loss_dict.values())
 
         if not self.zero_grad_before_forward:
             self.optimizer.zero_grad()

From 19755232e71d9b7cc1076cdc72df3a7112ed58bf Mon Sep 17 00:00:00 2001
From: sunyaqiang
Date: Fri, 22 Aug 2025 03:05:01 +0000
Subject: [PATCH 2/2] fit for linter

---
 detectron2/engine/train_loop.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/detectron2/engine/train_loop.py b/detectron2/engine/train_loop.py
index ceaa61b468..06731b7565 100644
--- a/detectron2/engine/train_loop.py
+++ b/detectron2/engine/train_loop.py
@@ -472,7 +472,7 @@ def __init__(
         if torch.__version__ >= "2.4.0":
             from torch.amp import GradScaler
 
-            grad_scaler = GradScaler('cuda')
+            grad_scaler = GradScaler("cuda")
         else:
             from torch.cuda.amp import GradScaler
 
@@ -499,13 +499,13 @@ def run_step(self):
         if self.zero_grad_before_forward:
             self.optimizer.zero_grad()
         if torch.__version__ >= "2.4.0":
-            with autocast('cuda', dtype=self.precision):
+            with autocast("cuda", dtype=self.precision):
                 loss_dict = self.model(data)
                 if isinstance(loss_dict, torch.Tensor):
                     losses = loss_dict
                     loss_dict = {"total_loss": loss_dict}
                 else:
-                    losses = sum(loss_dict.values()) 
+                    losses = sum(loss_dict.values())
         else:
             with autocast(dtype=self.precision):
                 loss_dict = self.model(data)
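
Note for reviewers: below is the version-gating pattern the two commits converge on, lifted out of AMPTrainer into a minimal standalone sketch. The amp_step helper, the hard-coded fp16 dtype, and the model/optimizer/data placeholders are illustrative assumptions, not detectron2 or patch code.

# Minimal sketch of the gate this patch applies inside AMPTrainer (not detectron2 code).
# Assumes a CUDA build of PyTorch; amp_step and its arguments are placeholders.
import functools

import torch

if torch.__version__ >= "2.4.0":
    # torch>=2.4 moved the AMP entry points to torch.amp and deprecated
    # the torch.cuda.amp variants, which now raise a FutureWarning.
    from torch.amp import GradScaler, autocast

    scaler = GradScaler("cuda")  # new API takes the device type explicitly
    amp_autocast = functools.partial(autocast, "cuda", dtype=torch.float16)
else:
    from torch.cuda.amp import GradScaler, autocast

    scaler = GradScaler()  # legacy CUDA-only API, no device argument
    amp_autocast = functools.partial(autocast, dtype=torch.float16)


def amp_step(model, optimizer, data):
    """One AMP training step using the gated autocast/GradScaler pair."""
    optimizer.zero_grad()
    with amp_autocast():
        loss = model(data)  # assume the model returns a scalar loss tensor
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)  # unscales gradients, skips the step on inf/nan
    scaler.update()  # adjust the scale factor for the next iteration

One point worth noting on the gate itself: comparing torch.__version__ against the string "2.4.0" is version-aware, not lexicographic, because torch.__version__ has been a TorchVersion (a str subclass that overrides the comparison operators) since PyTorch 1.10, so a future "2.10.0" still compares greater than "2.4.0".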