Skip to content

Commit

Permalink
life is hard
Browse files Browse the repository at this point in the history
  • Loading branch information
gobbleturk committed Aug 14, 2024
1 parent 8a0a2d8 commit 9a1f389
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion MaxText/layers/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import functools
from typing import Any

class Pipeline(nn.Module):
sdclass Pipeline(nn.Module):
"""Module that implements pipelining across stages.

This module will loop over microbatches and execute the main body with a vmap for both the inputs and weights.
Expand Down Expand Up @@ -325,6 +325,13 @@ def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positi
)
return vmap_func


def rotate_right_shmap(self, arr):
# Use lax.slice to avoid generating a gather.
last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0)
except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0)
return jnp.concatenate([last, except_last], axis=0)

def rotate_right(self, arr):
# Use lax.slice to avoid generating a gather.
last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0)
Expand Down

0 comments on commit 9a1f389

Please sign in to comment.