Skip to content

Commit

Permalink
fix axis in shift_data
Browse files Browse the repository at this point in the history
  • Loading branch information
aireenmei committed Dec 18, 2023
1 parent 01d092a commit 06d079e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion MaxText/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def preprocessing_pipeline_pygrain(

# Shift inputs for teacher-forced training
if shift:
operations.append(pygrain.MapOperation(map_function=pygrain_operations.ShiftData(axis=0,segmented=pack_examples)))
operations.append(pygrain.MapOperation(map_function=pygrain_operations.ShiftData(axis=1,segmented=pack_examples)))

index_sampler = pygrain.IndexSampler(
num_records=len(dataset),
Expand Down

0 comments on commit 06d079e

Please sign in to comment.