diff --git a/keras_nlp/__init__.py b/keras_nlp/__init__.py index 30f8a53b16..7ddcefd52f 100644 --- a/keras_nlp/__init__.py +++ b/keras_nlp/__init__.py @@ -26,5 +26,6 @@ from keras_nlp import samplers from keras_nlp import tokenizers from keras_nlp import utils +from keras_nlp.utils import upload_preset from keras_nlp.version_utils import __version__ from keras_nlp.version_utils import version diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index bfdc8207ad..1ad9177aed 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -17,6 +17,7 @@ from keras_nlp.backend import keras from keras_nlp.utils.preset_utils import check_preset_class from keras_nlp.utils.preset_utils import load_from_preset +from keras_nlp.utils.preset_utils import save_to_preset from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.python_utils import format_docstring @@ -141,6 +142,12 @@ def from_preset( config_overrides=kwargs, ) + def save_to_preset( + self, + preset, + ): + save_to_preset(self, preset) + def __init_subclass__(cls, **kwargs): # Use __init_subclass__ to setup a correct docstring for from_preset. super().__init_subclass__(**kwargs) diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index 16a65e57c2..c825767ac8 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -18,6 +18,7 @@ ) from keras_nlp.utils.preset_utils import check_preset_class from keras_nlp.utils.preset_utils import load_from_preset +from keras_nlp.utils.preset_utils import save_to_preset from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.python_utils import format_docstring @@ -96,6 +97,13 @@ def from_preset( ) return cls(tokenizer=tokenizer, **kwargs) + def save_to_preset( + self, + preset, + config_filename="tokenizer.json", + ): + save_to_preset(self, preset, config_filename=config_filename) + def __init_subclass__(cls, **kwargs): # Use __init_subclass__ to setup a correct docstring for from_preset. super().__init_subclass__(**kwargs) diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 9957f6546f..96d94cac28 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -23,6 +23,7 @@ from keras_nlp.utils.pipeline_model import PipelineModel from keras_nlp.utils.preset_utils import check_preset_class from keras_nlp.utils.preset_utils import load_from_preset +from keras_nlp.utils.preset_utils import save_to_preset from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.python_utils import format_docstring @@ -253,6 +254,12 @@ def from_preset( config_overrides=kwargs, ) + def save_to_preset( + self, + preset, + ): + save_to_preset(self, preset) + def __init_subclass__(cls, **kwargs): # Use __init_subclass__ to setup a correct docstring for from_preset. super().__init_subclass__(**kwargs) diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index 7da1e9d7b1..4c7d1d387b 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -18,6 +18,7 @@ from keras_nlp.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) +from keras_nlp.utils.preset_utils import save_to_preset @keras_nlp_export("keras_nlp.tokenizers.Tokenizer") @@ -121,5 +122,12 @@ def token_to_id(self, token: str) -> int: f"{self.__class__.__name__}." ) + def save_to_preset( + self, + preset, + config_filename="tokenizer.json", + ): + save_to_preset(self, preset, config_filename=config_filename) + def call(self, inputs, *args, training=None, **kwargs): return self.tokenize(inputs, *args, **kwargs) diff --git a/keras_nlp/utils/__init__.py b/keras_nlp/utils/__init__.py index ba0c2545e4..be5af11363 100644 --- a/keras_nlp/utils/__init__.py +++ b/keras_nlp/utils/__init__.py @@ -11,3 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from keras_nlp.utils.preset_utils import upload_preset diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index 01c11a3db1..3d62d46d6d 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -155,6 +155,20 @@ def save_to_preset( metadata_file.write(json.dumps(metadata, indent=4)) +def upload_preset( + preset, + uri, +): + if uri.startswith(KAGGLE_PREFIX): + kaggle_handle = uri.removeprefix(KAGGLE_PREFIX) + kagglehub.model_upload(kaggle_handle, preset) + else: + raise ValueError( + f"Unexpected URI `'{uri}'`. " + f"URI prefix should be one of `'{','.join([KAGGLE_PREFIX])}'`." + ) + + def load_from_preset( preset, load_weights=True,