Distributed SaveLoad implementation for semi-auto strategy #59659
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
paddle.distributed.get_world_size() > 1 or coordinator_rank != 0
):
    raise ValueError(
        f"use_dist is False, please set coordinator_rank to 0 and paddle.distributed.get_world_size() to 1, world_size:{paddle.distributed.get_world_size()}, coordinator_rank:{coordinator_rank}"
Why not allow use_dist=false and world_size > 1?
use_dist is aimed at the single-card case, but it probably doesn't need to be user-specified; internally it can be determined as use_dist = True if world_size > 1 else False. save_state_dict is designed to export the model under the distributed strategy of the current training run: if training is distributed it exports a distributed checkpoint, and if single-card it exports a single-card one. Exporting a single-card model directly from a distributed run is not supported; to get one, first define the single-card model, load it with load_state_dict, and then export it with save_state_dict.
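For illustration, a minimal sketch of the internal determination described above (a hypothetical helper, not the PR's actual code):

```python
import paddle

def _infer_use_dist():
    # use_dist need not be user-facing: a world size above 1 implies
    # the distributed path, a single card implies the plain path.
    return paddle.distributed.get_world_size() > 1
```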
return tuple(local_shape), tuple(global_offset)


def flatten_state_dict(state_dict):
Why return directly?
This is a TODO: it is meant to support the case state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}, but that is not implemented yet, so for now the passed-in state_dict is left untouched.
ok
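For reference, a minimal sketch of the nested flattening that TODO describes, assuming dot-separated keys (the convention is an assumption, not the PR's final design):

```python
def flatten_state_dict(state_dict, prefix=""):
    # Recursively flatten nested dicts such as
    # {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    # into one mapping with dot-separated keys.
    flat = {}
    for key, value in state_dict.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_state_dict(value, full_key))
        else:
            flat[full_key] = value
    return flat

# Example: prints {'model.w': 1, 'optimizer.lr': 0.1}
print(flatten_state_dict({"model": {"w": 1}, "optimizer": {"lr": 0.1}}))
```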
if coordinator_rank == paddle.distributed.get_rank():
    logger.debug(f"metadata:{metadata}")
    paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata"))
Why not save the metadata on all ranks?
The metadata is global and identical on every rank, so only one copy needs to be saved.
I understand, but wouldn't saving on every rank make debugging easier, so you don't always have to go find rank 0? The metadata doesn't take much space either.
That probably won't work: each machine has multiple cards, and multiple cards writing to the same file at once can go wrong and leave unexpected contents.
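If per-rank metadata copies were ever wanted for debugging, the concurrent-write problem could be avoided by putting the rank into the filename; a hypothetical sketch, not part of this PR:

```python
import os

import paddle

def save_metadata_per_rank(metadata, path, unique_id):
    # Hypothetical alternative: one metadata file per rank, so cards on
    # the same machine never write to the same file concurrently.
    rank = paddle.distributed.get_rank()
    paddle.save(metadata, os.path.join(path, f"{unique_id}_{rank}.metadata"))
```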
The identifier of a local tensor.
"""

tensor_id: str
tensor_name or tensor_key?
tensor_name doesn't seem quite right: this field is an identifier, namely the structure_name in dynamic semi-auto mode and the tensor's name in static semi-auto mode. tensor_key means roughly the same as tensor_id, so it would also work; if you feel tensor_key is better, it can be changed.
Right, in the state_dict it is the key.
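For context, a sketch of the kind of dataclass this field belongs to (the second field is an assumption based on the shard-offset discussion elsewhere in this PR, not the exact definition):

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class LocalTensorIndex:
    # Identifier of a local tensor: the structure name in dynamic
    # semi-auto mode, the tensor name in static semi-auto mode;
    # in either case it is the key of the state_dict entry.
    tensor_id: str
    # Offset of this shard within the global tensor (assumed field).
    global_offset: Tuple[int, ...] = ()
```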
local_tensor_index not in tensor_id_list
), f"Duplicate tensor_id:{local_tensor_index} found. Check whether the metadata_file:{metadata_file} contains the same tensor metadata."
tensor_id_list.append(local_tensor_index.tensor_id)
if local_tensor_index.tensor_id in state_dict:
Is this state_dict the local_state_dict?
This state_dict is the one each rank maintains itself; it is local.
for rank, local_files in enumerate(global_data_files):
    if len(local_files) > 0:
        local_files = [
            f for f in local_files if f in necessary_data_files_set
When does local_files differ from necessary_data_files_set?
necessary_data_files_set is the set of all files required by the keys of the current state_dict; those files may live on other ranks. local_files here is a list that does cover all files readable across ranks, but that union can be larger than the set of data files the state_dict actually needs, so this filter keeps only the files that will be used.
If it is larger, should that raise a warning, or is it expected?
A larger set is fine and needs no warning, since it does not affect loading the current parameters.
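A toy illustration of the filtering discussed above (the file names are made up):

```python
# Files actually required by the keys of the current state_dict.
necessary_data_files_set = {"0_0.distcp", "1_0.distcp"}
# Files readable per rank; rank 0 also sees a file no key needs.
global_data_files = [
    ["0_0.distcp", "stale_0.distcp"],
    ["1_0.distcp"],
]
for rank, local_files in enumerate(global_data_files):
    needed = [f for f in local_files if f in necessary_data_files_set]
    print(rank, needed)  # 0 ['0_0.distcp'] / 1 ['1_0.distcp']
```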
@@ -0,0 +1,21 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2019->2023
done
@@ -0,0 +1,497 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2022->2023
done
if f not in file_to_ranks:
    file_to_ranks[f] = []
file_to_ranks[f].append(r)
logger.info(f"file_to_ranks:{file_to_ranks}")
Will this family of logger debug messages be cleaned up later? If they are staying, I suggest standardizing them.
Yes, the plan is to clean them all up before the final merge. If they should be standardized, is there a prescribed format?
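One common convention (an assumption here, not a prescribed Paddle format) is to route such detail through logger.debug, so it vanishes at the default INFO level:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

file_to_ranks = {"0_0.distcp": [0, 1]}  # placeholder data
# Visible only when the level is lowered to DEBUG; silent by default.
logger.debug("file_to_ranks:%s", file_to_ranks)
```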
@@ -0,0 +1,42 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Please do a full pass over the PR; the copyright years are wrong in several places.
done
v._local_value().add_(paddle.ones_like(v._local_value()))
paddle.distributed.load_state_dict(state_dict, ckpt_path())
for k, v in state_dict.items():
    assert k in local_state_dict, k
What is the last k used for?
The final k is the message to print; the usage is assert condition, error_message.
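A one-line illustration of that form:

```python
local_state_dict = {"linear.weight": 1.0}
k = "linear.weight"
# Passes here; on failure the AssertionError message is the key itself.
assert k in local_state_dict, k
```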
assert k in local_state_dict, k
if v._is_initialized():
    self.check_tensor_eq(v._local_value(), local_state_dict[k])
os.system(f"rm -rf {ckpt_path()}")
Use tempfile.TemporaryDirectory(); you can find examples in other unit tests.
done
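A minimal sketch of the suggested pattern, assuming the save/load APIs from this PR (the tensor is a placeholder):

```python
import tempfile

import paddle
import paddle.distributed as dist

state_dict = {"w": paddle.to_tensor([1.0, 2.0])}

# The directory and everything in it is removed automatically on exit,
# replacing the manual `rm -rf` cleanup.
with tempfile.TemporaryDirectory() as ckpt_dir:
    dist.save_state_dict(state_dict, ckpt_dir)
    dist.load_state_dict(state_dict, ckpt_dir)
```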
Chinese API documentation PR: PaddlePaddle/docs#6355
LGTM
LGTM
LGTM
Unit test timeout setting
__all__ = [
    "save_state_dict",
    "load_state_dict",
]
Only add an API to __all__ at its recommended user path. Since we recommend using paddle.distributed.save_state_dict and paddle.distributed.load_state_dict, there is no need to add them to this list; the import above can be retained.
ok
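For illustration, the recommended call path (the state_dict and directory are placeholders):

```python
import paddle
import paddle.distributed as dist

state_dict = {"w": paddle.to_tensor([0.0])}
ckpt_dir = "./checkpoint"  # hypothetical output directory

# Use the recommended public path, not the defining submodule.
dist.save_state_dict(state_dict, ckpt_dir)
dist.load_state_dict(state_dict, ckpt_dir)
```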
def load_state_dict(
    state_dict,
    path,
    process_group=None,
    coordinator_rank=0,
) -> None:
I saw in the design document that there is a use_dist parameter, which is not implemented here. Do we need to implement it? If not, please explain the reason and update the design document.
LGTM
For the API docs, please follow the English template, and pay close attention to blank lines and indentation.
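A skeleton of the layout that template asks for, as a sketch (consult the actual template for the authoritative format; the parameter descriptions here are paraphrased):

```python
def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0):
    """
    Load the state_dict inplace from a checkpoint path.

    Args:
        state_dict (dict): The state_dict to load in place.
        path (str): The checkpoint directory.
        process_group (ProcessGroup, optional): The group used for communication. Default: None.
        coordinator_rank (int, optional): The rank used to coordinate the checkpoint. Default: 0.

    Examples:
        .. code-block:: python

            >>> # doctest: +SKIP('state dict not exist')
    """
```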
coordinator_rank(int): The rank used to save non-distributed values. Rank 0 is used by default.

Examples:
    .. code-block:: python
Examples:
    .. code-block:: python
        >>> # doctest: +SKIP('Save state dict.')
Suggested change:
- >>> # doctest: +SKIP('Save state dict.')
+ >>> # doctest: +SKIP('state dict not exist')

Please state the reason for skipping the check clearly, to keep it readable.
) -> None:
    """
    Load the state_dict inplace from a checkpoint path.
    Args:
Suggested change: insert a blank line before Args:.

Add blank lines between the description, the arguments, and the other docstring sections; otherwise the official site may render them incorrectly.
Example:
    .. code-block:: python
Suggested change: likewise, insert a blank line before Example: and before .. code-block:: python.
coordinator_rank(int): The rank used to coordinate the checkpoint. Rank 0 is used by default.
Example:
    .. code-block:: python
        >>> # doctest: +SKIP('Load state dict.')
Suggested change:
- >>> # doctest: +SKIP('Load state dict.')
+ >>> # doctest: +SKIP('state dict not exist')

Please state the reason clearly, to keep it readable.
LGTM. Merging this first; the related changes will follow.
PR types
Others
PR changes
Others
Description
card-78318
Design the save_state_dict and load_state_dict APIs to support saving and loading checkpoints for dynamic and static graph semi-auto distributed training.
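A sketch of the intended end-to-end usage, assembled from the docstring excerpts above (the mesh, shapes, and directory are placeholder assumptions):

```python
import paddle
import paddle.distributed as dist

# Shard a tensor over a two-card 1-D mesh (placeholder setup).
mesh = dist.ProcessMesh([0, 1])
w = dist.shard_tensor(paddle.zeros([4, 4]), mesh, [dist.Shard(0)])

state_dict = {"w": w}
ckpt_dir = "./semi_auto_ckpt"  # hypothetical directory

# Each rank saves its own shards; the coordinator rank writes the
# global metadata file.
dist.save_state_dict(state_dict, ckpt_dir)
# Reload in place under the current distributed strategy.
dist.load_state_dict(state_dict, ckpt_dir)
```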