Skip to content

Commit

Permalink
fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
kprokofi committed Dec 18, 2024
1 parent b44ebea commit 045e3ed
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions src/otx/algo/common/losses/cross_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import torch
import torch.nn.functional
from otx.utils.device import get_available_device
from torch import Tensor, nn
from torch.cuda.amp import custom_fwd

Expand Down Expand Up @@ -80,7 +79,7 @@ def __init__(

self.cls_criterion = cross_sigmoid_focal_loss

@custom_fwd(device_type=get_available_device(), cast_inputs=torch.float32)
@custom_fwd(cast_inputs=torch.float32)
def forward(
self,
pred: Tensor,
Expand Down

0 comments on commit 045e3ed

Please sign in to comment.