diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py index 8675b825d..7f442a68a 100644 --- a/server/tests/utils/test_lora.py +++ b/server/tests/utils/test_lora.py @@ -1,4 +1,5 @@ from typing import List +from unittest import mock import pytest import torch @@ -8,6 +9,7 @@ from lorax_server.utils.sgmv import MIN_RANK_CUSTOM +@mock.patch("lorax_server.utils.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))) @pytest.mark.parametrize("lora_ranks", [ [8, 16], [32, 64],