
Add explicit parameters for torch.load #6751

Merged (11 commits, Nov 19, 2024)
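The pattern applied throughout this PR is to pass weights_only to torch.load explicitly instead of relying on the default. A minimal sketch of the two modes, using a hypothetical checkpoint path:

import torch

ckpt_path = "checkpoint.pt"  # hypothetical path, for illustration only

# weights_only=True restricts unpickling to tensors and other allow-listed
# types, which protects against malicious pickle payloads in untrusted files.
tensors_only = torch.load(ckpt_path, map_location="cpu", weights_only=True)

# weights_only=False keeps the legacy full-pickle behavior; it is required when
# a checkpoint stores arbitrary Python objects (argparse namespaces, optimizer
# metadata, ...) but can execute arbitrary code, so use it only on trusted files.
full_state = torch.load(ckpt_path, map_location="cpu", weights_only=False)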
17 changes: 10 additions & 7 deletions deepspeed/checkpoint/deepspeed_checkpoint.py
@@ -116,7 +116,7 @@ def show_transformer_file_map(self):
self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')

def _build_global_state(self):
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
Review comment (Contributor Author): Switch this value to True for safety where needed.
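For call sites that do need non-tensor objects, one way to move to weights_only=True later is to allow-list the specific classes the checkpoint contains. This is only an illustration of the PyTorch (>= 2.4) add_safe_globals API, not something this PR does; argparse.Namespace stands in for whatever the checkpoint actually stores:

import argparse
import torch
from torch.serialization import add_safe_globals

# Allow the weights-only unpickler to reconstruct this specific class.
add_safe_globals([argparse.Namespace])

# With the class allow-listed, the restricted loader can be kept.
sd = torch.load("model_states.pt", map_location="cpu", weights_only=True)  # hypothetical file name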

self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

@@ -137,14 +137,17 @@ def get_final_norm_layer_id(self):

def get_iteration(self):
if not ITERATION_KEY in self.global_state:
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)

return self.global_state[ITERATION_KEY]

def get_embedding_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_embedding_map.keys()
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
sd_list = [
torch.load(fname, map_location=torch.device('cpu'), weights_only=False)
for fname in self.tp_to_embedding_map[tp_index]
]
sd = self._merge_state_dicts(sd_list)
return sd

@@ -154,7 +157,7 @@ def get_embedding_files(self, tp_index: int) -> list:

def _get_checkpoint_value(self, key):
if not key in self.global_state:
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
self.global_state[key] = sd.get(key, None)

return self.global_state[key]
@@ -169,7 +172,7 @@ def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
assert tp_index < self.tp_degree
assert pp_index < self.pp_degree
fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]

merged_sd = None
for sd in sd_list:
@@ -185,7 +188,7 @@ def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
assert pp_index < self.pp_degree
t_list = []
for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]
sd = self._merge_state_dicts(sd_list)
t_list.append(sd)
return t_list
@@ -196,7 +199,7 @@ def get_pp_transformer_map(self, pp_index: int) -> list:

def get_final_norm_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_final_norm_map.keys()
sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'), weights_only=False)
return sd

def get_final_norm_files(self, tp_index: int) -> list:
14 changes: 7 additions & 7 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -150,7 +150,7 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):


def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
state_dict = torch.load(optim_files[dp_index], map_location='cpu')
state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)

flat_state = dict(
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
@@ -214,7 +214,7 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
raise ValueError(f"Cannot parse dp_rank from {p}")

paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
shards = [torch.load(p) for p in paths]
shards = [torch.load(p, weights_only=False) for p in paths]

if state == "step":
assert all(v == shards[0] for v in shards), "All shards must have the same step value"
@@ -404,7 +404,7 @@ def _zero_partitioned_param_info(unpartitioned_numel, world_size):


def _parse_model_states_stage3(files):
return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES]
return torch.load(files[0], map_location=torch.device('cpu'), weights_only=False)[PARAM_SHAPES]


def _save_optimizer_state(args, ds_checkpoint):
@@ -420,7 +420,7 @@ def _save_optimizer_state(args, ds_checkpoint):


def _save_optimizer_state_stage3(args, optim_files):
sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
sd = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
output_sd = sd[OPTIMIZER_STATE_DICT]
output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS]
zero_output_folder = os.path.join(args.output_folder, "zero")
@@ -446,15 +446,15 @@ def _get_checkpoint_files(checkpoint_dir, glob_pattern):


def _get_zero_stage(optim_files):
state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'))
state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
optimizer_state = state_dict[OPTIMIZER_STATE_DICT]
zero_stage = optimizer_state.get(ZERO_STAGE, 1)
return zero_stage


def _inject_missing_state(ds_checkpoint):
if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
if UNIVERSAL_CHECKPOINT_INFO not in sd:
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
@@ -488,7 +488,7 @@ def main(args):

slice_shapes = []
for mp_rank_file in ds_checkpoint.mp_rank_files:
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'), weights_only=False)
slice_shapes += mp_sd[PARAM_SHAPES]

# fix back to normal flat dict, merge duplicates for tp>1
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/universal_checkpoint.py
@@ -34,7 +34,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
step = None
for key in hp_keys:
ckpt_file = os.path.join(folder, f"{key}.pt")
ckpt_dict = torch.load(ckpt_file)
ckpt_dict = torch.load(ckpt_file, weights_only=False)

if key == "step":
step = ckpt_dict
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/zero_checkpoint.py
@@ -54,7 +54,7 @@ def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], st
state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
merged_sd = None
for state_file in state_file_list:
sd = torch.load(state_file, map_location=torch.device('cpu'))
sd = torch.load(state_file, map_location=torch.device('cpu'), weights_only=False)
for key in keys_to_ignore:
sd.pop(key, None)

4 changes: 2 additions & 2 deletions deepspeed/inference/engine.py
@@ -452,15 +452,15 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
checkpoint = sd_loader['checkpoints']

if type(checkpoint) is list:
self.sd = torch.load(checkpoint[0], map_location='cpu')
self.sd = torch.load(checkpoint[0], map_location='cpu', weights_only=False)
self.key_list = list(self.sd.keys())

self.load_model_with_checkpoint(self.module)

for i in range(1, len(checkpoint)):
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"loading checkpoint ({i})")
self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name())
self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name(), weights_only=False)
self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module)
else:
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -80,7 +80,7 @@ def model_has_safetensors(model_name_or_path: str) -> bool:
else:
model_param_json_fname = "pytorch_model.bin.index.json"
model_file_fname = "pytorch_model.bin"
self._checkpoint_load_fn = partial(torch.load, map_location="cpu")
self._checkpoint_load_fn = partial(torch.load, map_location="cpu", weights_only=False)

model_param_json = os.path.join(self._local_checkpoint_dir, model_param_json_fname)
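The checkpoint engine above pre-binds the load options with functools.partial so every shard is read the same way; a small sketch of that pattern, with assumed shard file names:

from functools import partial
import torch

# Bind map_location and weights_only once; each later call only supplies the path.
load_fn = partial(torch.load, map_location="cpu", weights_only=False)

shard_files = ["pytorch_model-00001.bin", "pytorch_model-00002.bin"]  # assumed names
state_dicts = [load_fn(f) for f in shard_files]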

@@ -205,7 +205,7 @@ def populate_model_parameters(self) -> None:
buffer_path = make_param_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
metadata_path = make_metadata_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)

buffer = torch.load(buffer_path)
buffer = torch.load(buffer_path, weights_only=False)
metadata = json.load(open(metadata_path, "r"))
metadata = ModelMetadata.parse_raw(metadata)

8 changes: 4 additions & 4 deletions deepspeed/module_inject/replace_module.py
@@ -415,7 +415,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

for i in range(len(checkpoint)):
sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu', weights_only=False)]
load_model_with_checkpoint(replaced_module,
sd,
mp_replace,
@@ -437,7 +437,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
for j in range(sd_count)
]
sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False) for ckpt_file in ckpt_files]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
@@ -457,7 +457,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
pbar.update(1)
ckpt_file = os.path.join(base_dir1,
checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
sds = [torch.load(ckpt_file, map_location='cpu')]
sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False)]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
@@ -624,7 +624,7 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=No
from safetensors.torch import load_file
sd = load_file(checkpoint)
else:
sd = torch.load(checkpoint, map_location='cpu')
sd = torch.load(checkpoint, map_location='cpu', weights_only=False)

policy = {}
if orig_class is not None:
2 changes: 1 addition & 1 deletion deepspeed/runtime/base_optimizer.py
@@ -22,7 +22,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
optim_sd = torch.load(optim_state_path)
optim_sd = torch.load(optim_state_path, weights_only=False)

self._load_global_state(optim_sd)

@@ -58,7 +58,7 @@ def load(self, path: str, map_location=None):
if not self.enable_nebula_load and first_load_flag:
self.tag_flag = tag
logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
partition = torch.load(path, map_location=map_location)
partition = torch.load(path, map_location=map_location, weights_only=False)
logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
return partition

@@ -25,7 +25,7 @@ def save(self, state_dict, path: str):

def load(self, path: str, map_location=None):
logger.info(f"[Torch] Loading checkpoint from {path}...")
partition = torch.load(path, map_location=map_location)
partition = torch.load(path, map_location=map_location, weights_only=False)
logger.info(f"[Torch] Loaded checkpoint from {path}.")
return partition

4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -2741,7 +2741,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'

optim_sd = torch.load(optim_state_path)
optim_sd = torch.load(optim_state_path, weights_only=False)
self._load_global_state_stage3(optim_sd)

key_list = ["fp32", "exp_avg", "exp_avg_sq"]
@@ -2799,7 +2799,7 @@ def load_hp_checkpoint_state(self, folder, key):
local_rank = dist.get_local_rank()

# Load tensors from files and reshape them to flat vectors
loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt")).view(-1)
loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)

# Partition the loaded data according to the local rank
world_size = dist.get_world_size(group=self.dp_process_group)
4 changes: 2 additions & 2 deletions deepspeed/utils/zero_to_fp32.py
@@ -100,7 +100,7 @@ def get_model_state_files(checkpoint_dir):
def parse_model_states(files):
zero_model_states = []
for file in files:
state_dict = torch.load(file, map_location=device)
state_dict = torch.load(file, map_location=device, weights_only=False)

if BUFFER_NAMES not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint")
@@ -147,7 +147,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files)
state_dicts = []
for f in files:
state_dict = torch.load(f, map_location=device)
state_dict = torch.load(f, map_location=device, weights_only=False)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
2 changes: 1 addition & 1 deletion tests/unit/checkpoint/common.py
@@ -218,7 +218,7 @@ def checkpoint_correctness_verification(config_dict,
for root, _, files in os.walk(save_folder):
for f in files:
if "_expert_" in f and "_model_states" in f:
expert = torch.load(os.path.join(root, f))
expert = torch.load(os.path.join(root, f), weights_only=True)
needed, storages = 0, {}
for name, tensor in expert.items():
needed += tensor.size().numel()
2 changes: 1 addition & 1 deletion tests/unit/checkpoint/test_universal_checkpoint.py
@@ -181,7 +181,7 @@ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
)

hidden_dim = 10
loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt")
loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=True)

ds_config["checkpoint"] = {"load_universal": True}
univ_model = SimpleModel(hidden_dim)
11 changes: 6 additions & 5 deletions tests/unit/checkpoint/test_zero_optimizer.py
@@ -264,7 +264,7 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l
model.load_checkpoint(tmpdir, load_optimizer_states=load_optim)

if load_optim:
saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file))
saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file), weights_only=False)
curr_sd = model.optimizer.optimizer.state_dict()
compare_opt_state_dicts(curr_sd, saved_sd, expected_mismatch_keys)

@@ -523,7 +523,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
all_ckpt_folder = os.path.join(tmpdir, 'all_params')
ds_engine.save_checkpoint(all_ckpt_folder)
all_params_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(all_ckpt_folder, 'global_step0'), '00')
loaded_all_param_model = torch.load(all_params_ckpt_file)['module']
loaded_all_param_model = torch.load(all_params_ckpt_file, weights_only=True)['module']
all_param_names = set([n for n, p in model.named_parameters()])
assert set(loaded_all_param_model.keys()) == all_param_names

@@ -536,7 +536,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
# Excluding frozen parameters should reduce checkpoint size
assert os.path.getsize(all_params_ckpt_file) > os.path.getsize(trainable_ckpt_file)

loaded_trainable_param_model = torch.load(trainable_ckpt_file)['module']
loaded_trainable_param_model = torch.load(trainable_ckpt_file, weights_only=True)['module']
frozen_param_names = set([n for n, p in model.named_parameters() if not p.requires_grad])
loaded_trainable_param_names = set(loaded_trainable_param_model.keys())
overlap_names = set.intersection(loaded_trainable_param_names, frozen_param_names)
@@ -575,7 +575,7 @@ def test_save_exclude_custom_frozen_weights(self, tmpdir, zero_stage):

custom_state_dict_ckpt_file = get_model_ckpt_name_for_rank(
os.path.join(custom_state_dict_ckpt_folder, 'global_step0'), '00')
loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file)['module']
loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file, weights_only=True)['module']
loaded_custom_state_dict_param_names = set(loaded_custom_state_dict_param_model.keys())

custom_state_dict_param_names = set([k for k, v in model.state_dict().items()])
@@ -618,7 +618,8 @@ def test_save_tensor_clone(self, tmpdir, zero_stage, use_cpu_device):
clone_ckpt_file = os.path.join(tmpdir, 'clone_ckpt.pt')
torch.save(clone_state_dict, clone_ckpt_file)

compare_state_dicts(torch.load(ref_ckpt_file), torch.load(clone_ckpt_file))
compare_state_dicts(torch.load(ref_ckpt_file, weights_only=False), torch.load(clone_ckpt_file,
                                                                               weights_only=False))

Review comment (Contributor Author): HPU tests failed when this was set to True:

FAILED unit/checkpoint/test_zero_optimizer.py::TestSaveTensorClone::test_save_tensor_clone[False-1]
FAILED unit/checkpoint/test_zero_optimizer.py::TestSaveTensorClone::test_save_tensor_clone[True-2]
FAILED unit/checkpoint/test_zero_optimizer.py::TestSaveTensorClone::test_save_tensor_clone[True-1]
FAILED unit/checkpoint/test_zero_optimizer.py::TestSaveTensorClone::test_save_tensor_clone[False-2]

Each failure raises the same error: _pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source. Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Unsupported operand 149. Check the documentation of torch.load to learn more about types accepted by default with weights_only: https://pytorch.org/docs/stable/generated/torch.load.html
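Where the restricted loader rejects a checkpoint (as in the HPU failures above), one common workaround is to fall back to a full load only for files the process itself produced; a hedged sketch of that idea, not the approach this test takes:

import pickle
import torch

def load_trusted_checkpoint(path, map_location="cpu"):
    # Prefer the restricted weights-only unpickler.
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except pickle.UnpicklingError:
        # Fall back to a full pickle load; acceptable only because the file
        # was written by this process (or another trusted source).
        return torch.load(path, map_location=map_location, weights_only=False)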


class TestZeRONonDistributed(DistributedTest):
@@ -170,7 +170,7 @@ def test(self, baseline_mp2, inputs, class_tmpdir):
test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
if dist.get_rank() == 0:
load_path = os.path.join(class_tmpdir, "output.pt")
baseline = torch.load(load_path)
baseline = torch.load(load_path, weights_only=True)
test = test.cpu()
assert torch.allclose(
baseline, test,
@@ -225,7 +225,7 @@ def _test(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resiz
assert torch.is_tensor(test[0][0])
test = test[0][0].cpu()
load_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt")
baseline = torch.load(load_path)
baseline = torch.load(load_path, weights_only=True)
assert torch.allclose(
baseline, test,
atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}"