From 22ffccf03f3e207731a481e3e42bdb564ceebb69 Mon Sep 17 00:00:00 2001 From: Gaetan Lepage Date: Thu, 9 Nov 2023 10:18:42 +0100 Subject: [PATCH] fix: adapt to new jax API --- src/dalle_mini/model/modeling.py | 4 ++-- src/dalle_mini/model/partitions.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dalle_mini/model/modeling.py b/src/dalle_mini/model/modeling.py index 53dd7ce96..733ecc2ad 100644 --- a/src/dalle_mini/model/modeling.py +++ b/src/dalle_mini/model/modeling.py @@ -1599,8 +1599,8 @@ def prepare_inputs_for_generation( self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/dalle_mini/model/partitions.py b/src/dalle_mini/model/partitions.py index 0bcd89d10..6d08081ab 100644 --- a/src/dalle_mini/model/partitions.py +++ b/src/dalle_mini/model/partitions.py @@ -2,7 +2,7 @@ from flax.core.frozen_dict import freeze from flax.traverse_util import flatten_dict, unflatten_dict -from jax.experimental import PartitionSpec as P +from jax.sharding import PartitionSpec as P # utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py # Sentinels