Skip to content

Commit

Permalink
Fix re-compilation bugs (keras-team#1541)
Browse files Browse the repository at this point in the history
The primary bug this is trying to fix is to get rid of weird
re-compilation behaviors with our default compilation. E.g. create a
`GemmaCausalLM`, generate some text without specifying a sampler,
`compile()` again without a sampler, generation will have switched from
`"greedy"` -> `"top_k"`. Create a `BertClassifier`, `fit()`, `compile()`
again without specifying an optimizer, optimizer will have switch from
`"adam"` to `"rmsprop"`.

The way I am trying to fix this is by leaning a little more heavily on
the `"auto"` style option we introduced for `jit_compile`. KerasNLP
tasks will by default use `loss="auto"` and `optimizer="auto"`, which
resolve to a default for a given task.

Since we override compile with these in the signature, recompilation
will not silently change behavior.
  • Loading branch information
mattdangerw committed Apr 3, 2024
1 parent 9ac3335 commit 7d3c77c
Show file tree
Hide file tree
Showing 28 changed files with 211 additions and 316 deletions.
11 changes: 0 additions & 11 deletions keras_nlp/models/albert/albert_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,17 +186,6 @@ def __init__(
self.activation = keras.activations.get(activation)
self.dropout = dropout

# === Default compilation ===
logit_output = self.activation == keras.activations.linear
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(
from_logits=logit_output
),
optimizer=keras.optimizers.Adam(5e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)

def get_config(self):
config = super().get_config()
config.update(
Expand Down
8 changes: 0 additions & 8 deletions keras_nlp/models/albert/albert_masked_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,3 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
outputs=outputs,
**kwargs,
)

# === Default compilation ===
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)
11 changes: 1 addition & 10 deletions keras_nlp/models/bart/bart_seq_2_seq_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.models.bart.bart_backbone import BartBackbone
from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import (
Expand Down Expand Up @@ -200,14 +199,6 @@ def __init__(
**kwargs,
)

# === Default compilation ===
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(2e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)

def call_decoder_with_cache(
self,
encoder_hidden_states,
Expand Down Expand Up @@ -460,7 +451,7 @@ def repeat_tensor(x):
cache,
)

decoder_token_ids = self._sampler(
decoder_token_ids = self.sampler(
next=next,
prompt=decoder_token_ids,
cache=self_attention_cache,
Expand Down
11 changes: 0 additions & 11 deletions keras_nlp/models/bert/bert_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,6 @@ def __init__(
self.activation = keras.activations.get(activation)
self.dropout = dropout

# === Default compilation ===
logit_output = self.activation == keras.activations.linear
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(
from_logits=logit_output
),
optimizer=keras.optimizers.Adam(5e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)

def get_config(self):
config = super().get_config()
config.update(
Expand Down
10 changes: 0 additions & 10 deletions keras_nlp/models/bert/bert_masked_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,3 @@ def __init__(
outputs=outputs,
**kwargs,
)

# === Default compilation ===
self.backbone = backbone
self.preprocessor = preprocessor
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)
12 changes: 1 addition & 11 deletions keras_nlp/models/bloom/bloom_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.models.bloom.bloom_backbone import BloomBackbone
from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import (
Expand Down Expand Up @@ -167,15 +166,6 @@ def __init__(
**kwargs,
)

# === Default compilation ===
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(2e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
sampler="greedy",
jit_compile=True,
)

def call_with_cache(
self,
token_ids,
Expand Down Expand Up @@ -273,7 +263,7 @@ def next(prompt, cache, index):
cache,
)

token_ids = self._sampler(
token_ids = self.sampler(
next=next,
prompt=token_ids,
cache=cache,
Expand Down
77 changes: 64 additions & 13 deletions keras_nlp/models/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,74 @@ class CausalLM(Task):
```
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default compilation.
self.compile()

def compile(
self,
*args,
run_eagerly=False,
jit_compile=True,
optimizer="auto",
loss="auto",
*,
weighted_metrics="auto",
sampler="top_k",
**kwargs,
):
xla_compatible = True
"""Configures the `CausalLM` task for training and generation.
The `CausalLM` task extends the default compilation signature of
`keras.Model.compile` with defaults for `optimizer`, `loss`, and
`weighted_metrics`. To override these defaults, pass any value
to these arguments during compilation.
The `CausalLM` task adds a new `sampler` to `compile`, which can be used
to control the sampling strategy used with the `generate` function.
Note that because training inputs include padded tokens which are
excluded from the loss, it is almost always a good idea to compile with
`weighted_metrics` and not `metrics`.
Args:
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
instance. Defaults to `"auto"`, which uses the default optimizer
for the given model and task. See `keras.Model.compile` and
`keras.optimizers` for more info on possible `optimizer` values.
loss: `"auto"', a loss name, or a `keras.losses.Loss` instance.
Defaults to `"auto"`, where a
`keras.losses.SparseCategoricalCrossentropy` loss will be
applied for the token classification `CausalLM` task. See
`keras.Model.compile` and `keras.losses` for more info on
possible `loss` values.
weighted_metrics: `"auto"`, or a list of metrics to be evaluated by
the model during training and testing. Defaults to `"auto"`,
where a `keras.metrics.SparseCategoricalAccuracy` will be
applied to track the accuracy of the model at guessing masked
token values. See `keras.Model.compile` and `keras.metrics` for
more info on possible `weighted_metrics` values.
sampler: A sampler name, or a `keras_nlp.samplers.Sampler` instance.
Configures the sampling method used during `generate()` calls.
See `keras_nlp.samplers` for a full list of built-in sampling
strategies.
**kwargs: See `keras.Model.compile` for a full list of arguments
supported by the compile method.
"""
if optimizer == "auto":
optimizer = keras.optimizers.Adam(2e-5)
if loss == "auto":
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
if weighted_metrics == "auto":
weighted_metrics = [keras.metrics.SparseCategoricalAccuracy()]
# Keras 2 does not jit_compile by default.
if not config.keras_3():
kwargs["jit_compile"] = True
super().compile(
*args,
run_eagerly=run_eagerly,
# Only `jit_compile` if not eager and in a compatible environment.
jit_compile=jit_compile and xla_compatible and not run_eagerly,
optimizer=optimizer,
loss=loss,
weighted_metrics=weighted_metrics,
**kwargs,
)
self._sampler = get_sampler(sampler)
self.sampler = get_sampler(sampler)
# Clear the compiled generate function.
self.generate_function = None

Expand Down Expand Up @@ -127,7 +178,7 @@ def compiled_generate_function(inputs, stop_token_ids, state):
non_trainable_variables,
) = state
mapping = itertools.chain(
zip(self._sampler.variables, sampler_variables),
zip(self.sampler.variables, sampler_variables),
zip(self.trainable_variables, trainable_variables),
zip(self.non_trainable_variables, non_trainable_variables),
)
Expand All @@ -137,7 +188,7 @@ def compiled_generate_function(inputs, stop_token_ids, state):

# Get updated sampler variables from the stateless scope.
sampler_variables = []
for v in self._sampler.variables:
for v in self.sampler.variables:
new_v = scope.get_current_value(v)
sampler_variables.append(new_v if new_v is not None else v)
return outputs, sampler_variables
Expand All @@ -151,7 +202,7 @@ def wrapped_generate_function(

# Create an explicit tuple of all variable state.
state = (
self._sampler.variables,
self.sampler.variables,
# Use the explicit variable.value to preserve the
# sharding spec of distribution.
[v.value for v in self.trainable_variables],
Expand All @@ -165,7 +216,7 @@ def wrapped_generate_function(
)
# Only assign the sampler variables (random seeds), as other
# model variables should never be updated in generation.
for ref_v, v in zip(self._sampler.variables, sampler_variables):
for ref_v, v in zip(self.sampler.variables, sampler_variables):
ref_v.assign(v)
return outputs

Expand Down
61 changes: 61 additions & 0 deletions keras_nlp/models/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.models.task import Task


Expand Down Expand Up @@ -49,3 +51,62 @@ class Classifier(Task):
classifier.predict(["What an amazing movie!", "A total waste of my time."])
```
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default compilation.
self.compile()

def compile(
self,
optimizer="auto",
loss="auto",
*,
metrics="auto",
**kwargs,
):
"""Configures the `Classifier` task for training.
The `Classifier` task extends the default compilation signature of
`keras.Model.compile` with defaults for `optimizer`, `loss`, and
`metrics`. To override these defaults, pass any value
to these arguments during compilation.
Args:
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
instance. Defaults to `"auto"`, which uses the default optimizer
for the given model and task. See `keras.Model.compile` and
`keras.optimizers` for more info on possible `optimizer` values.
loss: `"auto"', a loss name, or a `keras.losses.Loss` instance.
Defaults to `"auto"`, where a
`keras.losses.SparseCategoricalCrossentropy` loss will be
applied for the classification task. See
`keras.Model.compile` and `keras.losses` for more info on
possible `loss` values.
metrics: `"auto"`, or a list of metrics to be evaluated by
the model during training and testing. Defaults to `"auto"`,
where a `keras.metrics.SparseCategoricalAccuracy` will be
applied to track the accuracy of the model during training.
See `keras.Model.compile` and `keras.metrics` for
more info on possible `metrics` values.
**kwargs: See `keras.Model.compile` for a full list of arguments
supported by the compile method.
"""
if optimizer == "auto":
optimizer = keras.optimizers.Adam(5e-5)
if loss == "auto":
activation = getattr(self, "activation", None)
activation = keras.activations.get(activation)
from_logits = activation != keras.activations.softmax
loss = keras.losses.SparseCategoricalCrossentropy(from_logits)
if metrics == "auto":
metrics = [keras.metrics.SparseCategoricalAccuracy()]
# Keras 2 does not jit_compile by default.
if not config.keras_3():
kwargs["jit_compile"] = True
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
**kwargs,
)
11 changes: 0 additions & 11 deletions keras_nlp/models/deberta_v3/deberta_v3_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,17 +212,6 @@ def __init__(
self.hidden_dim = hidden_dim
self.dropout = dropout

# === Default compilation ===
logit_output = self.activation == keras.activations.linear
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(
from_logits=logit_output
),
optimizer=keras.optimizers.Adam(5e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)

def get_config(self):
config = super().get_config()
config.update(
Expand Down
8 changes: 0 additions & 8 deletions keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,3 @@ def __init__(
outputs=outputs,
**kwargs,
)

# === Default compilation ===
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)
11 changes: 0 additions & 11 deletions keras_nlp/models/distil_bert/distil_bert_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,6 @@ def __init__(
self.hidden_dim = hidden_dim
self.dropout = dropout

# === Default compilation ===
logit_output = self.activation == keras.activations.linear
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(
from_logits=logit_output
),
optimizer=keras.optimizers.Adam(5e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)

def get_config(self):
config = super().get_config()
config.update(
Expand Down
8 changes: 0 additions & 8 deletions keras_nlp/models/distil_bert/distil_bert_masked_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,3 @@ def __init__(
outputs=outputs,
**kwargs,
)

# === Default compilation ===
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)
11 changes: 0 additions & 11 deletions keras_nlp/models/f_net/f_net_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,6 @@ def __init__(
self.activation = keras.activations.get(activation)
self.dropout = dropout

# === Default compilation ===
logit_output = self.activation == keras.activations.linear
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(
from_logits=logit_output
),
optimizer=keras.optimizers.Adam(5e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
jit_compile=True,
)

def get_config(self):
config = super().get_config()
config.update(
Expand Down
Loading

0 comments on commit 7d3c77c

Please sign in to comment.