From e994c64568585adc40928e91437c18becb903f37 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 16 Dec 2024 10:36:28 +0800 Subject: [PATCH] [checkpointio] fix async io (#6155) --- colossalai/checkpoint_io/general_checkpoint_io.py | 3 +-- colossalai/checkpoint_io/utils.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 3bb805131276..54da168e54d0 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -8,8 +8,6 @@ import torch.nn as nn from torch.optim import Optimizer -from colossalai.utils.safetensors import move_and_save - from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( @@ -54,6 +52,7 @@ def save_unsharded_model( pass if use_async: + from colossalai.utils.safetensors import move_and_save if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 5ef0bd354e93..ab599b556937 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -19,7 +19,6 @@ to_global, to_global_for_customized_distributed_tensor, ) -from colossalai.utils.safetensors import move_and_save SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -289,6 +288,7 @@ def async_save_state_dict_shards( Returns: int: the total size of shards """ + from colossalai.utils.safetensors import move_and_save total_size = 0 shard_filenames = []