diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 16ba415e..74ef71b6 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -130,7 +130,9 @@ def load_prompts( # only keep the datapoints relevant to the current process if world_size > 1: # This prints to stdout which is slightly annoying - split = split_dataset_by_node(split, world_size, rank) + split = split_dataset_by_node( + dataset=split, rank=rank, world_size=world_size + ) raw_datasets.append(split) train_datasets.append(train_ds)