Distributed SaveLoad implementation for semi-auto strategy #59659
Merged · +1,162 −0 · 45 commits · changes below are shown from 1 commit
Commits (45, all by pangengzheng):

fc3b3c0  exclude xpu
e291552  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
7a13c0b  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
d81f305  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cd6e4fb  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
9d27f27  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
5037694  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
ef695ee  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
23aa6ff  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
f7615b7  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
6605dff  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
767835d  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
f756bc6  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
2ffd709  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
738f5d5  demo of running dygraph distributed save load
f3d4bb2  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
7134583  support save cross mesh state_dict
9e2094a  polish
786a318  fix compute overlap bug
ef4f374  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
058d5fe  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
8f64e81  test save load in dp_mp unittest
250b1b7  fix get local file bug and test
bd9348f  delete useless files, and rename var
ecee68b  polish
a8491b9  format codes
867726d  merge develop
2bf30c5  test use_dist
b46042c  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
160552c  fix test
c5394c5  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
baf2b74  info to debug
968d611  fix test
170fd81  fix
e0d0690  fix coverage ci
18298b9  fix docstring codes
13b1d07  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
1dcd0a7  rename and codestyle
00df8ba  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
c728400  get rid of use_dist argument
a3125c0  fix copyright
0543d1f  polish doc
e4c72cd  polish
0561180  polish
4df7f76  use tmp file path
Changes shown from commit baf2b745f6e740cf796735cbd3e01ad3f1a39b62 ("info to debug"):
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import os
 from dataclasses import dataclass
 from typing import Tuple
@@ -69,7 +70,7 @@ def get_rank_to_files(path, state_dict, process_group, use_dist):
     for files in global_data_files:
         tmp += files
     global_data_files_set = set(tmp)
-    logger.info(
+    logger.debug(
         f"necessary_data_files_set:{necessary_data_files_set}, global_data_files_set:{global_data_files_set}"
     )
     # check neccesary files in global_data_files
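These info-to-debug demotions are what the commit title ("info to debug") refers to: the per-rank file mappings are developer diagnostics, not user-facing events. A minimal stdlib-only sketch of the effect — Paddle's actual logger configuration is not visible in this diff, so the setup below is an assumption:

```python
import logging

# Hypothetical stand-in for the checkpoint module's logger; how Paddle
# actually configures its logger is not shown in this diff.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("checkpoint_demo")

logger.info("emitted: INFO passes an INFO-level configuration")
logger.debug("suppressed: DEBUG is below the INFO threshold")

# Developers chasing a load-balancing issue can opt back in:
logger.setLevel(logging.DEBUG)
logger.debug("emitted now that this logger allows DEBUG records")
```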
@@ -78,7 +79,10 @@ def get_rank_to_files(path, state_dict, process_group, use_dist):
         == necessary_data_files_set
     ), f"The checkpoint files are not complete. Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_data_files_set:{necessary_data_files_set}"
     missing_keys = set(state_dict.keys()) - set(tensor_id_list)
-    logger.info(f"missing_keys:{missing_keys}")
+    if len(missing_keys) > 0:
+        logger.warning(
+            f"Missing keys:{missing_keys}, check whether the checkpoint is complete."
+        )
 
     rank_to_files = {}
     for rank, local_files in enumerate(global_data_files):
@@ -87,7 +91,7 @@ def get_rank_to_files(path, state_dict, process_group, use_dist):
             f for f in local_files if f in necessary_data_files_set
         ]
        rank_to_files[rank] = local_files
-    logger.info(f"mapping rank_to_files:{rank_to_files}")
+    logger.debug(f"mapping rank_to_files:{rank_to_files}")
     return rank_to_files
@@ -111,17 +115,18 @@ def get_local_load_files(rank_to_files):
         if file not in file_to_ranks:
             file_to_ranks[file] = []
         file_to_ranks[file].append(rank)
-    rank_to_read_files = {rank: [] for rank in rank_to_files.keys()}
+    rank_to_not_read_files = copy.copy(rank_to_files)
+    rank_to_read_files = {rank: [] for rank in rank_to_not_read_files.keys()}
     for file, ranks in file_to_ranks.items():
         if len(ranks) == 1:
             rank = ranks[0]
             rank_to_read_files[rank].append(file)
-            rank_to_files[rank].remove(file)
-            if len(rank_to_files[rank]) == 0:
-                rank_to_files.pop(rank)
+            rank_to_not_read_files[rank].remove(file)
+            if len(rank_to_not_read_files[rank]) == 0:
+                rank_to_not_read_files.pop(rank)
 
-    logger.info(
-        f"start rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}"
+    logger.debug(
+        f"rank_to_read_files:{rank_to_read_files}, rank_to_not_read_files:{rank_to_not_read_files}"
     )
 
     def get_least_read_files_ranks(rank_to_read_files):
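The key fix in this hunk is `copy.copy(rank_to_files)`: the balancing bookkeeping now consumes a copy, so popping fully-assigned ranks no longer destroys the caller's mapping (this is also why `import copy` was added in the first hunk). Note that `copy.copy` is shallow, so the list values are still shared with the original; a small self-contained illustration of both behaviors:

```python
import copy

rank_to_files = {0: ["f1", "f2"], 1: ["f2"]}

# Shallow copy: a new dict object, but the same list objects as values.
not_read = copy.copy(rank_to_files)

not_read.pop(1)              # removing a key only affects the copy...
assert 1 in rank_to_files

not_read[0].remove("f1")     # ...but mutating a shared list affects both
assert rank_to_files[0] == ["f2"]

# copy.deepcopy would isolate the nested lists as well:
isolated = copy.deepcopy(rank_to_files)
isolated[0].remove("f2")
assert rank_to_files[0] == ["f2"]
```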
@@ -132,55 +137,51 @@ def get_least_read_files_ranks(rank_to_read_files):
         ranks = [rank for rank, num in nums if num == nums[0][1]]
         return ranks
 
-    def get_read_rank_file(rank_to_files, ranks):
-        if len(rank_to_files) == 0:
+    def get_read_rank_file(rank_to_not_read_files, ranks):
+        if len(rank_to_not_read_files) == 0:
             return (None, None)
         nums = [
             (rank, len(files))
-            for rank, files in rank_to_files.items()
+            for rank, files in rank_to_not_read_files.items()
             if rank in ranks
         ]
         nums = sorted(nums, key=lambda x: x[1])
         rank = nums[0][0]
-        return (rank, rank_to_files[rank][0])
+        return (rank, rank_to_not_read_files[rank][0])
 
-    def update(rank_to_read_files, rank_to_files, rank_file):
+    def update(rank_to_read_files, rank_to_not_read_files, rank_file):
         rank, file = rank_file
         if rank is None and file is None:
             return
         if rank not in rank_to_read_files:
             rank_to_read_files[rank] = []
         rank_to_read_files[rank].append(file)
-        # update rank_to_files
+        # update rank_to_not_read_files
         file_to_ranks = {}
-        for r, files in rank_to_files.items():
+        for r, files in rank_to_not_read_files.items():
             for f in files:
                 if f not in file_to_ranks:
                     file_to_ranks[f] = []
                 file_to_ranks[f].append(r)
-        logger.info(f"file_to_ranks:{file_to_ranks}")
[Inline review thread on the deleted logger.info line]
Reviewer: Will this series of logger debug messages be cleaned up later? If not, it would be good to standardize their format.
Author: The plan is to clean them all up before the final merge; if they are to be standardized instead, is there a specified format?
         if file in file_to_ranks:
             for r in file_to_ranks[file]:
-                rank_to_files[r].remove(file)
-                if len(rank_to_files[r]) == 0:
-                    rank_to_files.pop(r)
+                rank_to_not_read_files[r].remove(file)
+                if len(rank_to_not_read_files[r]) == 0:
+                    rank_to_not_read_files.pop(r)
 
-    while len(rank_to_files) > 0:
+    while len(rank_to_not_read_files) > 0:
         ranks = get_least_read_files_ranks(rank_to_read_files)
-        rank_file = get_read_rank_file(rank_to_files, ranks)
-        update(rank_to_read_files, rank_to_files, rank_file)
-        logger.info(
-            f"update rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}, ranks:{ranks}, rank_file:{rank_file}"
+        rank_file = get_read_rank_file(rank_to_not_read_files, ranks)
+        update(rank_to_read_files, rank_to_not_read_files, rank_file)
+        logger.debug(
+            f"update rank_to_read_files:{rank_to_read_files}, rank_to_not_read_files:{rank_to_not_read_files}, ranks:{ranks}, rank_file:{rank_file}"
         )
-    logger.info(f"rank_to_read_files:{rank_to_read_files}")
     cur_rank = paddle.distributed.get_rank()
     if cur_rank in rank_to_read_files:
         logger.info(
             f"cur_rank:{cur_rank}, rank_to_read_files[cur_rank]:{rank_to_read_files[cur_rank]}"
         )
         return rank_to_read_files[cur_rank]
     else:
-        logger.info(f"rank:{cur_rank} does not need to load checkpoint")
+        logger.warning(f"rank:{cur_rank} does not need to load checkpoint")
         return []
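Stepping back, the loop above implements a greedy balancing strategy: files that only one rank can read are pinned to that rank, and every remaining file is handed out, one at a time, to a least-loaded rank. The sketch below re-implements that idea in plain Python so it can be run without Paddle; it is a simplification of the diff above (the real code also breaks ties by how many unread files a rank still has), not the exact Paddle implementation:

```python
def assign_files(rank_to_files):
    """Greedy sketch: a file readable by a single rank goes to that rank;
    every other file goes, one at a time, to the least-loaded rank."""
    file_to_ranks = {}
    for rank, files in rank_to_files.items():
        for f in files:
            file_to_ranks.setdefault(f, []).append(rank)

    assigned = {rank: [] for rank in rank_to_files}
    pending = {r: list(fs) for r, fs in rank_to_files.items()}

    def mark_read(f):
        # Once any rank reads a file, no other rank needs to read it.
        for p in pending.values():
            if f in p:
                p.remove(f)

    # Pass 1: files with exactly one possible reader are forced choices.
    for f, ranks in file_to_ranks.items():
        if len(ranks) == 1:
            assigned[ranks[0]].append(f)
            mark_read(f)

    # Pass 2: give one still-unread file to the least-loaded rank.
    while any(pending.values()):
        rank = min(
            (r for r in pending if pending[r]),
            key=lambda r: len(assigned[r]),
        )
        f = pending[rank][0]
        assigned[rank].append(f)
        mark_read(f)
    return assigned


# "a" and "c" each have a single reader; the shared "b" then goes to a
# least-loaded rank (rank 0 wins this tie).
print(assign_files({0: ["a", "b"], 1: ["b", "c"]}))
# -> {0: ['a', 'b'], 1: ['c']}
```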
@@ -285,7 +286,7 @@ def get_read_items(path, state_dict, process_group, use_dist):
         storage_state_dict_metadata[tensor_id] = []
         storage_state_dict_metadata[tensor_id] += local_tensor_metadata
     read_items = []
-    logger.info(f"storage_state_dict_metadata:{storage_state_dict_metadata}")
+    logger.debug(f"storage_state_dict_metadata:{storage_state_dict_metadata}")
     for tensor_id, val in state_dict.items():
         if isinstance(val, paddle.Tensor):
             if val.is_dist():
@@ -389,7 +390,7 @@ def load_state_dict(
     # slice the storage local tensor in (storage_offsets, lengths) to assign the current tensor in (cur_offsets, lengths) in rank.
     read_items = get_read_items(path, state_dict, process_group, use_dist)
     storage_file_to_state_dict = {}
-    logger.info(
+    logger.debug(
         f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}"
     )
     for item in read_items:
@@ -474,6 +475,6 @@ def load_state_dict(
         if use_dist
         else state_dict
     )
-    logger.info(
+    logger.debug(
         f"after load, local_state_dict:{local_state_dict} \n state_dict:{state_dict}"
     )
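For orientation, `load_state_dict` here is the loading half of the distributed save/load pair this PR introduces for the semi-auto strategy. A hypothetical end-to-end sketch of how the pair would be used — the import path, signatures, and sharding helpers below are assumptions based on the function names in this diff, not a confirmed excerpt of the PR's final API:

```python
# Hypothetical usage sketch; names and signatures are assumptions.
import paddle
import paddle.distributed as dist

# Two-rank data-parallel mesh for a semi-auto program.
mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
w = dist.shard_tensor(paddle.randn([8, 4]), mesh, [dist.Shard(0)])
state_dict = {"w": w}

# Save: each rank persists the shards it owns under the checkpoint path.
dist.save_state_dict(state_dict, "./checkpoint")

# Load: get_local_load_files decides which checkpoint files each rank
# reads, then slices are exchanged to rebuild every local shard in place.
dist.load_state_dict(state_dict, "./checkpoint")
```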
Review thread on get_rank_to_files:

Reviewer: When does local_files differ from necessary_data_files_set?

Author: necessary_data_files_set is the set of every file needed by the keys of the current state_dict, and those files may be spread across other ranks. local_files here is a list that does cover the union of files readable by all ranks, but that union is not guaranteed to be limited to the data files the state_dict actually needs to read, so a filtering step is applied to handle only the files that are needed.

Reviewer: If the union is larger, should that raise a warning, or is it expected?

Author: A larger union is fine and needs no warning, because it does not affect loading of the current parameters.
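The filtering the author describes corresponds to the list comprehension in the `get_rank_to_files` hunk above: per rank, readable files that no key of the current state_dict needs are simply dropped. A tiny illustration, with hypothetical file names, of why a superset is harmless:

```python
# Hypothetical file names; the real checkpoint naming scheme may differ.
necessary_data_files_set = {"0_0.distcp", "1_0.distcp"}

# This rank can also read a leftover file that no state_dict key needs.
local_files = ["0_0.distcp", "stale_extra.distcp"]

# Keep only what the state_dict needs; the extra file is ignored rather
# than warned about, since it cannot affect loading the current params.
local_files = [f for f in local_files if f in necessary_data_files_set]
assert local_files == ["0_0.distcp"]
```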