From 045e3ed6d05e245c21776a92093c13b872a2b7d2 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Thu, 19 Dec 2024 04:10:32 +0900 Subject: [PATCH] fix unit test --- src/otx/algo/common/losses/cross_focal_loss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/otx/algo/common/losses/cross_focal_loss.py b/src/otx/algo/common/losses/cross_focal_loss.py index 7744a9e5117..bfec15c0c84 100644 --- a/src/otx/algo/common/losses/cross_focal_loss.py +++ b/src/otx/algo/common/losses/cross_focal_loss.py @@ -7,7 +7,6 @@ import torch import torch.nn.functional -from otx.utils.device import get_available_device from torch import Tensor, nn from torch.cuda.amp import custom_fwd @@ -80,7 +79,7 @@ def __init__( self.cls_criterion = cross_sigmoid_focal_loss - @custom_fwd(device_type=get_available_device(), cast_inputs=torch.float32) + @custom_fwd(cast_inputs=torch.float32) def forward( self, pred: Tensor,