Commit 3dd4ec1

Remove main_process_first calls
1 parent b736ca9 commit 3dd4ec1

File tree

1 file changed (+29 lines, -30 lines)

examples/pytorch/language-modeling/run_clm.py

Lines changed: 29 additions & 30 deletions
@@ -506,22 +506,22 @@ def tokenize_function(examples):
             )
         return output
 
-    with training_args.main_process_first(desc="dataset map tokenization"):
-        if not data_args.streaming:
-            tokenized_datasets = raw_datasets.map(
-                tokenize_function,
-                batched=True,
-                num_proc=data_args.preprocessing_num_workers,
-                remove_columns=column_names,
-                load_from_cache_file=not data_args.overwrite_cache,
-                desc="Running tokenizer on dataset",
-            )
-        else:
-            tokenized_datasets = raw_datasets.map(
-                tokenize_function,
-                batched=True,
-                remove_columns=column_names,
-            )
+
+    if not data_args.streaming:
+        tokenized_datasets = raw_datasets.map(
+            tokenize_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on dataset",
+        )
+    else:
+        tokenized_datasets = raw_datasets.map(
+            tokenize_function,
+            batched=True,
+            remove_columns=column_names,
+        )
     if hasattr(config, "max_position_embeddings"):
         max_pos_embeddings = config.max_position_embeddings
     else:
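
For context, training_args.main_process_first() is a TrainingArguments context manager that lets the main process run a block, here the cached datasets .map() preprocessing, before the other distributed ranks enter it, so the cache is written once and then reused. A rough sketch of that barrier pattern, using torch.distributed as a simplified illustration rather than the actual transformers implementation:

# Simplified sketch of the "main process first" pattern; assumes a
# torch.distributed process group may or may not be initialized.
import contextlib

import torch.distributed as dist


@contextlib.contextmanager
def main_process_first(is_main_process: bool):
    if dist.is_available() and dist.is_initialized():
        try:
            if not is_main_process:
                dist.barrier()  # replicas wait while the main process maps and caches
            yield
        finally:
            if is_main_process:
                dist.barrier()  # main process is done; release the waiting replicas
    else:
        yield  # single-process run: nothing to synchronize
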
@@ -570,21 +570,20 @@ def group_texts(examples):
     # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
     # https://huggingface.co/docs/datasets/process#map
 
-    with training_args.main_process_first(desc="grouping texts together"):
-        if not data_args.streaming:
-            lm_datasets = tokenized_datasets.map(
-                group_texts,
-                batched=True,
-                num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=not data_args.overwrite_cache,
-                desc=f"Grouping texts in chunks of {block_size}",
-            )
-        else:
-            lm_datasets = tokenized_datasets.map(
-                group_texts,
-                batched=True,
-            )
 
+    if not data_args.streaming:
+        lm_datasets = tokenized_datasets.map(
+            group_texts,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc=f"Grouping texts in chunks of {block_size}",
+        )
+    else:
+        lm_datasets = tokenized_datasets.map(
+            group_texts,
+            batched=True,
+        )
     if training_args.do_train:
         if "train" not in tokenized_datasets:
             raise ValueError("--do_train requires a train dataset")
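
The second hunk's header references group_texts, the standard causal-LM preprocessing step that concatenates tokenized examples and slices them into block_size chunks. A paraphrased sketch, illustrative only and not copied from run_clm.py (block_size is passed explicitly here instead of being taken from the enclosing scope):

# Illustrative paraphrase of a group_texts-style function for causal LM data:
# concatenate every tokenized field, drop the trailing remainder, cut the
# result into block_size chunks, and use a copy of input_ids as labels.
def group_texts(examples, block_size=1024):
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[next(iter(examples.keys()))])
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result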
