diff --git a/MaxText/input_pipeline.py b/MaxText/input_pipeline.py index 106e7d92b..9d9e1faa0 100644 --- a/MaxText/input_pipeline.py +++ b/MaxText/input_pipeline.py @@ -222,7 +222,7 @@ def preprocessing_pipeline_pygrain( # Shift inputs for teacher-forced training if shift: - operations.append(pygrain.MapOperation(map_function=pygrain_operations.ShiftData(axis=0,segmented=pack_examples))) + operations.append(pygrain.MapOperation(map_function=pygrain_operations.ShiftData(axis=1,segmented=pack_examples))) index_sampler = pygrain.IndexSampler( num_records=len(dataset),