diff --git a/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py b/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py index baf22913..a86b705d 100644 --- a/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py +++ b/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py @@ -51,7 +51,7 @@ def __init__(self, model_name_or_path, **kwargs): self.model, logger, disable=self._target_device.type == "mps" ) - if self.device.type == "cuda" and not os.environ.get( + if self._target_device.type == "cuda" and not os.environ.get( "INFINITY_DISABLE_HALF", "" ): logger.info(