From 7d3c77c641fe07c1d0456aea8214c241c71b359a Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Wed, 3 Apr 2024 13:01:11 -0700 Subject: [PATCH] Fix re-compilation bugs (#1541) The primary bug this is trying to fix is to get rid of weird re-compilation behaviors with our default compilation. E.g. create a `GemmaCausalLM`, generate some text without specifying a sampler, `compile()` again without a sampler, and generation will have switched from `"greedy"` -> `"top_k"`. Create a `BertClassifier`, `fit()`, `compile()` again without specifying an optimizer, and the optimizer will have switched from `"adam"` to `"rmsprop"`. The way I am trying to fix this is by leaning a little more heavily on the `"auto"` style option we introduced for `jit_compile`. KerasNLP tasks will by default use `loss="auto"` and `optimizer="auto"`, which resolve to a default for a given task. Since we override compile with these in the signature, recompilation will not silently change behavior. --- keras_nlp/models/albert/albert_classifier.py | 11 --- keras_nlp/models/albert/albert_masked_lm.py | 8 -- keras_nlp/models/bart/bart_seq_2_seq_lm.py | 11 +-- keras_nlp/models/bert/bert_classifier.py | 11 --- keras_nlp/models/bert/bert_masked_lm.py | 10 --- keras_nlp/models/bloom/bloom_causal_lm.py | 12 +-- keras_nlp/models/causal_lm.py | 77 +++++++++++++++---- keras_nlp/models/classifier.py | 61 +++++++++++++++ .../deberta_v3/deberta_v3_classifier.py | 11 --- .../models/deberta_v3/deberta_v3_masked_lm.py | 8 -- .../distil_bert/distil_bert_classifier.py | 11 --- .../distil_bert/distil_bert_masked_lm.py | 8 -- keras_nlp/models/f_net/f_net_classifier.py | 11 --- keras_nlp/models/f_net/f_net_masked_lm.py | 8 -- keras_nlp/models/gemma/gemma_causal_lm.py | 24 ++++-- keras_nlp/models/gpt2/gpt2_causal_lm.py | 10 +-- .../models/gpt_neo_x/gpt_neo_x_causal_lm.py | 11 +-- keras_nlp/models/llama/llama_causal_lm.py | 10 +-- keras_nlp/models/masked_lm.py | 62 +++++++++++++++ keras_nlp/models/mistral/mistral_causal_lm.py | 10 +-- keras_nlp/models/opt/opt_causal_lm.py | 11 +-- .../models/roberta/roberta_classifier.py | 11 --- keras_nlp/models/roberta/roberta_masked_lm.py | 8 -- .../models/roberta/roberta_masked_lm_test.py | 2 +- keras_nlp/models/task.py | 50 ------------ keras_nlp/models/task_test.py | 41 ---------- .../xlm_roberta/xlm_roberta_classifier.py | 11 --- .../xlm_roberta/xlm_roberta_masked_lm.py | 8 -- 28 files changed, 211 insertions(+), 316 deletions(-) diff --git a/keras_nlp/models/albert/albert_classifier.py b/keras_nlp/models/albert/albert_classifier.py index 7471393cc7..d732bc1973 100644 --- a/keras_nlp/models/albert/albert_classifier.py +++ b/keras_nlp/models/albert/albert_classifier.py @@ -186,17 +186,6 @@ def __init__( self.activation = keras.activations.get(activation) self.dropout = dropout - # === Default compilation === - logit_output = self.activation == keras.activations.linear - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy( - from_logits=logit_output - ), - optimizer=keras.optimizers.Adam(5e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py index 01892bdcab..752f20c2ff 100644 --- a/keras_nlp/models/albert/albert_masked_lm.py +++ b/keras_nlp/models/albert/albert_masked_lm.py @@ -124,11 +124,3 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
outputs=outputs, **kwargs, ) - - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm.py b/keras_nlp/models/bart/bart_seq_2_seq_lm.py index e13e74b769..23f65522ec 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm.py @@ -14,7 +14,6 @@ from keras_nlp.api_export import keras_nlp_export -from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.models.bart.bart_backbone import BartBackbone from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import ( @@ -200,14 +199,6 @@ def __init__( **kwargs, ) - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(2e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def call_decoder_with_cache( self, encoder_hidden_states, @@ -460,7 +451,7 @@ def repeat_tensor(x): cache, ) - decoder_token_ids = self._sampler( + decoder_token_ids = self.sampler( next=next, prompt=decoder_token_ids, cache=self_attention_cache, diff --git a/keras_nlp/models/bert/bert_classifier.py b/keras_nlp/models/bert/bert_classifier.py index 27bb076ea9..8b4a927af9 100644 --- a/keras_nlp/models/bert/bert_classifier.py +++ b/keras_nlp/models/bert/bert_classifier.py @@ -170,17 +170,6 @@ def __init__( self.activation = keras.activations.get(activation) self.dropout = dropout - # === Default compilation === - logit_output = self.activation == keras.activations.linear - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy( - from_logits=logit_output - ), - optimizer=keras.optimizers.Adam(5e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/models/bert/bert_masked_lm.py b/keras_nlp/models/bert/bert_masked_lm.py index 1166963625..5966c91cdb 100644 --- a/keras_nlp/models/bert/bert_masked_lm.py +++ b/keras_nlp/models/bert/bert_masked_lm.py @@ -128,13 +128,3 @@ def __init__( outputs=outputs, **kwargs, ) - - # === Default compilation === - self.backbone = backbone - self.preprocessor = preprocessor - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) diff --git a/keras_nlp/models/bloom/bloom_causal_lm.py b/keras_nlp/models/bloom/bloom_causal_lm.py index 914107f101..823b93fc32 100644 --- a/keras_nlp/models/bloom/bloom_causal_lm.py +++ b/keras_nlp/models/bloom/bloom_causal_lm.py @@ -14,7 +14,6 @@ from keras_nlp.api_export import keras_nlp_export -from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.models.bloom.bloom_backbone import BloomBackbone from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import ( @@ -167,15 +166,6 @@ def __init__( **kwargs, ) - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(2e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - sampler="greedy", - jit_compile=True, - ) - def call_with_cache( self, token_ids, @@ -273,7 +263,7 @@ def next(prompt, cache, index): cache, ) - token_ids = 
self._sampler( + token_ids = self.sampler( next=next, prompt=token_ids, cache=cache, diff --git a/keras_nlp/models/causal_lm.py b/keras_nlp/models/causal_lm.py index 98867e9ad2..660e17bf59 100644 --- a/keras_nlp/models/causal_lm.py +++ b/keras_nlp/models/causal_lm.py @@ -68,23 +68,74 @@ class CausalLM(Task): ``` """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + def compile( self, - *args, - run_eagerly=False, - jit_compile=True, + optimizer="auto", + loss="auto", + *, + weighted_metrics="auto", sampler="top_k", **kwargs, ): - xla_compatible = True + """Configures the `CausalLM` task for training and generation. + + The `CausalLM` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `weighted_metrics`. To override these defaults, pass any value + to these arguments during compilation. + + The `CausalLM` task adds a new `sampler` to `compile`, which can be used + to control the sampling strategy used with the `generate` function. + + Note that because training inputs include padded tokens which are + excluded from the loss, it is almost always a good idea to compile with + `weighted_metrics` and not `metrics`. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"', a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.SparseCategoricalCrossentropy` loss will be + applied for the token classification `CausalLM` task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + weighted_metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.SparseCategoricalAccuracy` will be + applied to track the accuracy of the model at guessing masked + token values. See `keras.Model.compile` and `keras.metrics` for + more info on possible `weighted_metrics` values. + sampler: A sampler name, or a `keras_nlp.samplers.Sampler` instance. + Configures the sampling method used during `generate()` calls. + See `keras_nlp.samplers` for a full list of built-in sampling + strategies. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.Adam(2e-5) + if loss == "auto": + loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + if weighted_metrics == "auto": + weighted_metrics = [keras.metrics.SparseCategoricalAccuracy()] + # Keras 2 does not jit_compile by default. + if not config.keras_3(): + kwargs["jit_compile"] = True super().compile( - *args, - run_eagerly=run_eagerly, - # Only `jit_compile` if not eager and in a compatible environment. - jit_compile=jit_compile and xla_compatible and not run_eagerly, + optimizer=optimizer, + loss=loss, + weighted_metrics=weighted_metrics, **kwargs, ) - self._sampler = get_sampler(sampler) + self.sampler = get_sampler(sampler) # Clear the compiled generate function. 
self.generate_function = None @@ -127,7 +178,7 @@ def compiled_generate_function(inputs, stop_token_ids, state): non_trainable_variables, ) = state mapping = itertools.chain( - zip(self._sampler.variables, sampler_variables), + zip(self.sampler.variables, sampler_variables), zip(self.trainable_variables, trainable_variables), zip(self.non_trainable_variables, non_trainable_variables), ) @@ -137,7 +188,7 @@ def compiled_generate_function(inputs, stop_token_ids, state): # Get updated sampler variables from the stateless scope. sampler_variables = [] - for v in self._sampler.variables: + for v in self.sampler.variables: new_v = scope.get_current_value(v) sampler_variables.append(new_v if new_v is not None else v) return outputs, sampler_variables @@ -151,7 +202,7 @@ def wrapped_generate_function( # Create an explicit tuple of all variable state. state = ( - self._sampler.variables, + self.sampler.variables, # Use the explicit variable.value to preserve the # sharding spec of distribution. [v.value for v in self.trainable_variables], @@ -165,7 +216,7 @@ def wrapped_generate_function( ) # Only assign the sampler variables (random seeds), as other # model variables should never be updated in generation. - for ref_v, v in zip(self._sampler.variables, sampler_variables): + for ref_v, v in zip(self.sampler.variables, sampler_variables): ref_v.assign(v) return outputs diff --git a/keras_nlp/models/classifier.py b/keras_nlp/models/classifier.py index f6c6a88720..c10c481172 100644 --- a/keras_nlp/models/classifier.py +++ b/keras_nlp/models/classifier.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import config +from keras_nlp.backend import keras from keras_nlp.models.task import Task @@ -49,3 +51,62 @@ class Classifier(Task): classifier.predict(["What an amazing movie!", "A total waste of my time."]) ``` """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `Classifier` task for training. + + The `Classifier` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"', a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.SparseCategoricalCrossentropy` loss will be + applied for the classification task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.SparseCategoricalAccuracy` will be + applied to track the accuracy of the model during training. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. 
+ """ + if optimizer == "auto": + optimizer = keras.optimizers.Adam(5e-5) + if loss == "auto": + activation = getattr(self, "activation", None) + activation = keras.activations.get(activation) + from_logits = activation != keras.activations.softmax + loss = keras.losses.SparseCategoricalCrossentropy(from_logits) + if metrics == "auto": + metrics = [keras.metrics.SparseCategoricalAccuracy()] + # Keras 2 does not jit_compile by default. + if not config.keras_3(): + kwargs["jit_compile"] = True + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py index e8cb7a60ed..848e4c342c 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py @@ -212,17 +212,6 @@ def __init__( self.hidden_dim = hidden_dim self.dropout = dropout - # === Default compilation === - logit_output = self.activation == keras.activations.linear - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy( - from_logits=logit_output - ), - optimizer=keras.optimizers.Adam(5e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py index 7bb613b96c..bad6f61d1e 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py @@ -130,11 +130,3 @@ def __init__( outputs=outputs, **kwargs, ) - - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) diff --git a/keras_nlp/models/distil_bert/distil_bert_classifier.py b/keras_nlp/models/distil_bert/distil_bert_classifier.py index f816e40a1b..cb6560eade 100644 --- a/keras_nlp/models/distil_bert/distil_bert_classifier.py +++ b/keras_nlp/models/distil_bert/distil_bert_classifier.py @@ -192,17 +192,6 @@ def __init__( self.hidden_dim = hidden_dim self.dropout = dropout - # === Default compilation === - logit_output = self.activation == keras.activations.linear - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy( - from_logits=logit_output - ), - optimizer=keras.optimizers.Adam(5e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py index 80b4c17bb0..e78a5bb85d 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py @@ -132,11 +132,3 @@ def __init__( outputs=outputs, **kwargs, ) - - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) diff --git a/keras_nlp/models/f_net/f_net_classifier.py b/keras_nlp/models/f_net/f_net_classifier.py index a5c0bf6525..d4c7f6260f 100644 --- a/keras_nlp/models/f_net/f_net_classifier.py +++ b/keras_nlp/models/f_net/f_net_classifier.py @@ -141,17 +141,6 @@ def __init__( self.activation = 
keras.activations.get(activation) self.dropout = dropout - # === Default compilation === - logit_output = self.activation == keras.activations.linear - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy( - from_logits=logit_output - ), - optimizer=keras.optimizers.Adam(5e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/models/f_net/f_net_masked_lm.py b/keras_nlp/models/f_net/f_net_masked_lm.py index 83c7e62719..8ef197df69 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm.py +++ b/keras_nlp/models/f_net/f_net_masked_lm.py @@ -129,11 +129,3 @@ def __init__( outputs=outputs, **kwargs, ) - - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py index 34b0a43126..26e9aad8a2 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm.py +++ b/keras_nlp/models/gemma/gemma_causal_lm.py @@ -168,13 +168,21 @@ def __init__( **kwargs, ) - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(2e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - sampler="greedy", - jit_compile=True, + def compile( + self, + optimizer="auto", + loss="auto", + *, + weighted_metrics="auto", + sampler="greedy", + **kwargs, + ): + super().compile( + optimizer=optimizer, + loss=loss, + weighted_metrics=weighted_metrics, + sampler=sampler, + **kwargs, ) def call_with_cache( @@ -274,7 +282,7 @@ def next(prompt, cache, index): cache, ) - token_ids = self._sampler( + token_ids = self.sampler( next=next, prompt=token_ids, cache=cache, diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 40d4787119..bdaeabeb45 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -170,14 +170,6 @@ def __init__( **kwargs, ) - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(2e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def call_with_cache( self, token_ids, @@ -279,7 +271,7 @@ def next(prompt, cache, index): cache, ) - token_ids = self._sampler( + token_ids = self.sampler( next=next, prompt=token_ids, cache=cache, diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py index 119e51cc75..0c1a5a583e 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py @@ -13,7 +13,6 @@ # limitations under the License. 
from keras_nlp.api_export import keras_nlp_export -from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.models.causal_lm import CausalLM from keras_nlp.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone @@ -69,14 +68,6 @@ def __init__( **kwargs, ) - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(2e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def call_with_cache( self, token_ids, @@ -175,7 +166,7 @@ def next(prompt, cache, index): cache, ) - token_ids = self._sampler( + token_ids = self.sampler( next=next, prompt=token_ids, cache=cache, diff --git a/keras_nlp/models/llama/llama_causal_lm.py b/keras_nlp/models/llama/llama_causal_lm.py index 48b5fdb4c2..b1e85d2925 100644 --- a/keras_nlp/models/llama/llama_causal_lm.py +++ b/keras_nlp/models/llama/llama_causal_lm.py @@ -61,14 +61,6 @@ def __init__(self, backbone, preprocessor=None, **kwargs): **kwargs, ) - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(2e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - @classproperty def backbone_cls(cls): return LlamaBackbone @@ -180,7 +172,7 @@ def next(prompt, cache, index): cache, ) - token_ids = self._sampler( + token_ids = self.sampler( next=next, prompt=token_ids, cache=cache, diff --git a/keras_nlp/models/masked_lm.py b/keras_nlp/models/masked_lm.py index 136dbf0b8e..731899b938 100644 --- a/keras_nlp/models/masked_lm.py +++ b/keras_nlp/models/masked_lm.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import config +from keras_nlp.backend import keras from keras_nlp.models.task import Task @@ -40,3 +42,63 @@ class MaskedLM(Task): masked_lm.fit(train_ds) ``` """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + def compile( + self, + optimizer="auto", + loss="auto", + *, + weighted_metrics="auto", + **kwargs, + ): + """Configures the `MaskedLM` task for training. + + The `MaskedLM` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `weighted_metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Note that because training inputs include padded tokens which are + excluded from the loss, it is almost always a good idea to compile with + `weighted_metrics` and not `metrics`. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"', a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.SparseCategoricalCrossentropy` loss will be + applied for the token classification `MaskedLM` task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + weighted_metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. 
Defaults to `"auto"`, + where a `keras.metrics.SparseCategoricalAccuracy` will be + applied to track the accuracy of the model at guessing masked + token values. See `keras.Model.compile` and `keras.metrics` for + more info on possible `weighted_metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.Adam(5e-5) + if loss == "auto": + loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + if weighted_metrics == "auto": + weighted_metrics = [keras.metrics.SparseCategoricalAccuracy()] + # Keras 2 does not jit_compile by default. + if not config.keras_3(): + kwargs["jit_compile"] = True + super().compile( + optimizer=optimizer, + loss=loss, + weighted_metrics=weighted_metrics, + **kwargs, + ) diff --git a/keras_nlp/models/mistral/mistral_causal_lm.py b/keras_nlp/models/mistral/mistral_causal_lm.py index 754c07d2a5..b56947ebfc 100644 --- a/keras_nlp/models/mistral/mistral_causal_lm.py +++ b/keras_nlp/models/mistral/mistral_causal_lm.py @@ -64,14 +64,6 @@ def __init__(self, backbone, preprocessor=None, **kwargs): **kwargs, ) - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(2e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def call_with_cache( self, token_ids, @@ -175,7 +167,7 @@ def next(prompt, cache, index): cache, ) - token_ids = self._sampler( + token_ids = self.sampler( next=next, prompt=token_ids, cache=cache, diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py index 1bb5bd1e87..5ec6bf416a 100644 --- a/keras_nlp/models/opt/opt_causal_lm.py +++ b/keras_nlp/models/opt/opt_causal_lm.py @@ -14,7 +14,6 @@ from keras_nlp.api_export import keras_nlp_export -from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.models.causal_lm import CausalLM from keras_nlp.models.opt.opt_backbone import OPTBackbone @@ -170,14 +169,6 @@ def __init__( **kwargs, ) - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(2e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def call_with_cache( self, token_ids, @@ -275,7 +266,7 @@ def next(prompt, cache, index): cache, ) - token_ids = self._sampler( + token_ids = self.sampler( next=next, prompt=token_ids, cache=cache, diff --git a/keras_nlp/models/roberta/roberta_classifier.py b/keras_nlp/models/roberta/roberta_classifier.py index 57f50f4e94..ba6d0e4902 100644 --- a/keras_nlp/models/roberta/roberta_classifier.py +++ b/keras_nlp/models/roberta/roberta_classifier.py @@ -191,17 +191,6 @@ def __init__( self.hidden_dim = hidden_dim self.dropout = dropout - # === Default compilation === - logit_output = self.activation == keras.activations.linear - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy( - from_logits=logit_output - ), - optimizer=keras.optimizers.Adam(2e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/models/roberta/roberta_masked_lm.py b/keras_nlp/models/roberta/roberta_masked_lm.py index ef6660f777..1429b367bd 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm.py +++ b/keras_nlp/models/roberta/roberta_masked_lm.py @@ -131,11 +131,3 @@ def 
__init__( outputs=outputs, **kwargs, ) - - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) diff --git a/keras_nlp/models/roberta/roberta_masked_lm_test.py b/keras_nlp/models/roberta/roberta_masked_lm_test.py index f4e410fa69..6d5ba7eb8d 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm_test.py +++ b/keras_nlp/models/roberta/roberta_masked_lm_test.py @@ -54,7 +54,7 @@ def setUp(self): "backbone": self.backbone, } self.train_data = ( - [" airplane at airport", " airplane_airport"], # Features. + [" airplane at airport", " airplane airport"], # Features. ) self.input_data = self.preprocessor(*self.train_data)[0] diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 7858b84709..41876f34cb 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -83,56 +83,6 @@ def filter_fn(attr): return filter(filter_fn, super().__dir__()) - def _check_for_loss_mismatch(self, loss): - """Check for a softmax/from_logits mismatch after compile. - - We cannot handle this in the general case, but we can handle this for - the extremely common case of a single `SparseCategoricalCrossentropy` - loss, and a `None` or `"softmax"` activation. - """ - # Only handle a single loss. - if isinstance(loss, (dict, list, tuple)): - return - # Only handle tasks with activation. - if not hasattr(self, "activation"): - return - - loss = keras.losses.get(loss) - activation = keras.activations.get(self.activation) - if isinstance(loss, keras.losses.SparseCategoricalCrossentropy): - from_logits = loss.get_config()["from_logits"] - elif loss == keras.losses.sparse_categorical_crossentropy: - from_logits = False - else: - # Only handle sparse categorical crossentropy. - return - - softmax_output = activation == keras.activations.softmax - logit_output = activation == keras.activations.linear - if softmax_output and from_logits: - raise ValueError( - "The `loss` passed to `compile()` expects logit output, but " - "the model is configured to output softmax probabilities " - "(`activation='softmax'`). This will not converge! Pass " - "`from_logits=False` to your loss, e.g. " - "`loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False)`. " - ) - if logit_output and not from_logits: - raise ValueError( - "The `loss` passed to `compile()` expects softmax probability " - "output, but the model is configured to output logits " - "(`activation=None`). This will not converge! Pass " - "`from_logits=True` to your loss, e.g. " - "`loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True)`. " - ) - - def compile(self, optimizer="rmsprop", loss=None, **kwargs): - # Temporarily disable jit compilation on torch. 
- if config.backend() == "torch": - kwargs["jit_compile"] = False - self._check_for_loss_mismatch(loss) - super().compile(optimizer=optimizer, loss=loss, **kwargs) - def preprocess_samples(self, x, y=None, sample_weight=None): if self.preprocessor is not None: return self.preprocessor(x, y=y, sample_weight=sample_weight) diff --git a/keras_nlp/models/task_test.py b/keras_nlp/models/task_test.py index 63fd189c12..d9af5c4902 100644 --- a/keras_nlp/models/task_test.py +++ b/keras_nlp/models/task_test.py @@ -72,44 +72,3 @@ def test_summary_without_preprocessor(self): summary = [] model.summary(print_fn=lambda x, line_break: summary.append(x)) self.assertNotRegex("\n".join(summary), "Preprocessor:") - - def test_mismatched_loss(self): - # Logit output. - model = SimpleTask(activation=None) - model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True) - ) - # Non-standard losses should not throw. - model.compile(loss="mean_squared_error") - with self.assertRaises(ValueError): - model.compile(loss="sparse_categorical_crossentropy") - with self.assertRaises(ValueError): - model.compile( - loss=keras.losses.SparseCategoricalCrossentropy( - from_logits=False - ) - ) - - # Probability output. - model = SimpleTask(activation="softmax") - model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False) - ) - model.compile(loss="sparse_categorical_crossentropy") - # Non-standard losses should not throw. - model.compile(loss="mean_squared_error") - with self.assertRaises(ValueError): - model.compile( - loss=keras.losses.SparseCategoricalCrossentropy( - from_logits=True - ) - ) - - # Non-standard activations should not throw. - model = SimpleTask(activation="tanh") - model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True) - ) - model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False) - ) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py index 14d41d233c..bf078862a9 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py @@ -207,17 +207,6 @@ def __init__( self.hidden_dim = hidden_dim self.dropout = dropout - # === Default compilation === - logit_output = self.activation == keras.activations.linear - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy( - from_logits=logit_output - ), - optimizer=keras.optimizers.Adam(5e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py index e687bac525..9abbbe5dc8 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py @@ -134,11 +134,3 @@ def __init__( outputs=outputs, **kwargs, ) - - # === Default compilation === - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - )
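For readers of this patch, here is a minimal sketch (not part of the change itself) of the recompilation behavior it addresses, assuming the `bert_tiny_en_uncased` preset can be downloaded:

```python
import keras_nlp

# Construction compiles the task with its own defaults (Adam at 5e-5, sparse
# categorical crossentropy, sparse categorical accuracy).
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased",
    num_classes=2,
)

# Before this patch, recompiling without an optimizer fell through to
# `keras.Model.compile`'s generic default and silently swapped Adam for
# "rmsprop". With the `"auto"` defaults in the overridden signature, a bare
# `compile()` re-resolves to the same task-level defaults instead.
classifier.compile()

# Explicit arguments still override the `"auto"` defaults as usual.
classifier.compile(optimizer="sgd")
```

The same pattern applies to the `sampler` argument on `CausalLM.compile`: a task that overrides the default in its own `compile` signature (e.g. `GemmaCausalLM` with `"greedy"`) keeps it across recompilation unless a sampler is passed explicitly.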