Skip to content

Commit

Permalink
[checkpointio] fix async io (#6155)
Browse files Browse the repository at this point in the history
  • Loading branch information
flybird11111 authored Dec 16, 2024
1 parent de3d371 commit e994c64
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
3 changes: 1 addition & 2 deletions colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import torch.nn as nn
from torch.optim import Optimizer

from colossalai.utils.safetensors import move_and_save

from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
Expand Down Expand Up @@ -54,6 +52,7 @@ def save_unsharded_model(
pass

if use_async:
from colossalai.utils.safetensors import move_and_save

if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
Expand Down
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
to_global,
to_global_for_customized_distributed_tensor,
)
from colossalai.utils.safetensors import move_and_save

SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
Expand Down Expand Up @@ -289,6 +288,7 @@ def async_save_state_dict_shards(
Returns:
int: the total size of shards
"""
from colossalai.utils.safetensors import move_and_save

total_size = 0
shard_filenames = []
Expand Down

0 comments on commit e994c64

Please sign in to comment.