Skip to content

Commit

Permalink
MantaRay
Browse files Browse the repository at this point in the history
  • Loading branch information
RissyRan committed Jun 19, 2024
1 parent d9db136 commit 5823a14
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions MaxText/layers/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ def unpermute(self, intermediate, sorted_selected_experts, weights):

def call_gmm(self, inputs, gate_logits, config, w0_kernel, w1_kernel, wo_kernel):
# TODO(ranran): update the static default tile_size
tile_size = (512, 512, 512)
# tile_size = None
# tile_size = (512, 512, 512)
tile_size = None
# replicated_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(None))

def gmm(inputs, kernel, group_sizes):
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ transformers
git+https://github.com/mlperf/logging.git
google-jetstream
jsonlines
google-cloud-aiplatform==1.50.0
shapely==1.8.5.post1

0 comments on commit 5823a14

Please sign in to comment.