Skip to content

Commit

Permalink
Move preset-related logic of task saving and loading to preset_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 6, 2024
1 parent 88cc471 commit 616dabe
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 115 deletions.
172 changes: 59 additions & 113 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

from rich import console as rich_console
from rich import markup
from rich import table as rich_table
Expand All @@ -25,16 +22,15 @@
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import print_msg
from keras_nlp.utils.pipeline_model import PipelineModel
from keras_nlp.utils.preset_utils import CONFIG_FILE
from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.utils.preset_utils import TASK_CONFIG_FILE
from keras_nlp.utils.preset_utils import TASK_WEIGHTS_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_from_preset
from keras_nlp.utils.preset_utils import recursive_pop
from keras_nlp.utils.preset_utils import load_task_from_preset
from keras_nlp.utils.preset_utils import save_to_preset
from keras_nlp.utils.python_utils import classproperty


Expand Down Expand Up @@ -277,115 +273,75 @@ def from_preset(
f"Received: backbone={kwargs['backbone']}."
)

task_config_path = os.path.join(preset, TASK_CONFIG_FILE)
task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
backbone_preset_cls = check_config_class(preset)
backbone_config_path = get_file(preset, CONFIG_FILE)
with open(backbone_config_path) as config_file:
backbone_config = json.load(config_file)

# Load preprocessor from preset.
preprocessor_config_path = os.path.join(
preprocessor_preset_cls = check_config_class(
preset, PREPROCESSOR_CONFIG_FILE
)
if os.path.exists(preprocessor_config_path):
preprocessor_preset_cls = check_config_class(
preset, PREPROCESSOR_CONFIG_FILE
)
if not issubclass(preprocessor_preset_cls, Preprocessor):
raise ValueError(
f"`{PREPROCESSOR_CONFIG_FILE}` in `{preset}` should be a subclass of `Preprocessor`."
)
preprocessor = preprocessor_preset_cls.from_preset(preset)
elif "preprocessor" in kwargs:
preprocessor = kwargs.pop("preprocessor")
else:
tokenizer = load_from_preset(
preset,
config_file="tokenizer.json",
if not issubclass(preprocessor_preset_cls, Preprocessor):
raise ValueError(
f"`{PREPROCESSOR_CONFIG_FILE}` in `{preset}` should be a subclass of `Preprocessor`."
)
preprocessor = cls.preprocessor_cls(tokenizer=tokenizer)
preprocessor = preprocessor_preset_cls.from_preset(preset)

# Backbone case.
if not os.path.exists(task_config_path) or not issubclass(
task_preset_cls, cls
):
if backbone_preset_cls is not cls.backbone_cls:
subclasses = list_subclasses(cls)
subclasses = tuple(
filter(
lambda x: x.backbone_cls == backbone_preset_cls,
subclasses,
backbone_preset_cls = check_config_class(preset)
task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
try:
get_file(preset, TASK_CONFIG_FILE)
except FileNotFoundError:
if not issubclass(task_preset_cls, cls):
if backbone_preset_cls is not cls.backbone_cls:
subclasses = list_subclasses(cls)
subclasses = tuple(
filter(
lambda x: x.backbone_cls == backbone_preset_cls,
subclasses,
)
)
if len(subclasses) == 0:
raise ValueError(
f"No registered subclass of `{cls.__name__}` can load "
f"a `{backbone_preset_cls.__name__}`."
)
if len(subclasses) > 1:
names = ", ".join(f"`{x.__name__}`" for x in subclasses)
raise ValueError(
f"Ambiguous call to `{cls.__name__}.from_preset()`. "
f"Found multiple possible subclasses {names}. "
"Please call `from_preset` on a subclass directly."
)
cls = subclasses[0]
# Forward dtype to the backbone.
config_overrides = {}
if "dtype" in kwargs:
config_overrides["dtype"] = kwargs.pop("dtype")
backbone = backbone_preset_cls.from_preset(
preset,
load_weights=load_weights,
config_overrides=config_overrides,
)
return cls(
backbone=backbone, preprocessor=preprocessor, **kwargs
)
if len(subclasses) == 0:
raise ValueError(
f"No registered subclass of `{cls.__name__}` can load "
f"a `{backbone_preset_cls.__name__}`."
)
if len(subclasses) > 1:
names = ", ".join(f"`{x.__name__}`" for x in subclasses)
raise ValueError(
f"Ambiguous call to `{cls.__name__}.from_preset()`. "
f"Found multiple possible subclasses {names}. "
"Please call `from_preset` on a subclass directly."
)
cls = subclasses[0]
# Forward dtype to the backbone.
config_overrides = {}
if "dtype" in kwargs:
config_overrides["dtype"] = kwargs.pop("dtype")
backbone = load_from_preset(
preset,
load_weights=load_weights,
config_overrides=config_overrides,
)
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

# Load task from preset if it exists.
# TODO: I should probably move task loading logic to preset_utils.py?
if not issubclass(cls, Task):
raise ValueError(
"`{cls.__name__}` should be subclass of Task!"
) # TODO: update error message
task_config_class = check_config_class(
preset, config_file=TASK_CONFIG_FILE
)
if not issubclass(task_config_class, cls):
if not issubclass(task_preset_cls, cls):
raise ValueError(
f"`{TASK_CONFIG_FILE}` has type `{task_config_class.__name__}` "
f"`{TASK_CONFIG_FILE}` has type `{task_preset_cls.__name__}` "
f"which is not a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{task_config_class.__name__}` instead."
f"`from_preset` directly on `{task_preset_cls.__name__}` instead."
)

with open(task_config_path, "r") as config_file:
task_config = json.load(config_file)
# TODO: add back backbone and preprocessor config when save_to_preset removes them (rn, save_to_preset, doesn't remove them!).
# task_config.update(backbone_config)
# task_config.update(preprocessor_config)
task = keras.saving.deserialize_keras_object(task_config)
if load_weights:
if not task_config["weights"]:
raise ValueError(
f"`weights` config is missing from `{TASK_CONFIG_FILE}` in "
f"preset directory `{preset}`."
)
if not backbone_config["weights"]:
raise ValueError(
f"`weights` config is missing from `{CONFIG_FILE}` in "
f"preset directory `{preset}`."
)
task_weights_path = os.path.join(preset, task_config["weights"])
task.load_task_weights(task_weights_path)
backbone_weights_path = os.path.join(
preset, backbone_config["weights"]
)
task.backbone.load_weights(backbone_weights_path)
# TODO: is this assignment okay?
task.preprocessor = preprocessor
return task
task = load_task_from_preset(preset, TASK_CONFIG_FILE)
# TODO: should I avoid duplicating preprocessor memory too?
task.preprocessor = preprocessor
return task

def load_task_weights(self, filepath, skip_mismatch=False):
def load_weights(self, filepath, skip_mismatch=False):
"""Load only the tasks specific weights not in the backbone."""
if not str(filepath).endswith(".weights.h5"):
raise ValueError(
Expand All @@ -408,7 +364,7 @@ def load_task_weights(self, filepath, skip_mismatch=False):
)
weights_store.close()

def save_task_weights(self, filepath):
def save_weights(self, filepath):
"""Save only the tasks specific weights not in the backbone."""
if not str(filepath).endswith(".weights.h5"):
raise ValueError(
Expand All @@ -435,7 +391,6 @@ def save_task_weights(self, filepath):
)
weights_store.close()

# TODO: do we want to have a `save_weights` flag in this public save_to_preset? probably yes!
def save_to_preset(self, preset):
"""TODO: add docstring"""
if self.preprocessor is None:
Expand All @@ -445,21 +400,12 @@ def save_to_preset(self, preset):

self.preprocessor.save_to_preset(preset)
self.backbone.save_to_preset(preset)
weights_filename = TASK_WEIGHTS_FILE

# TODO: the serialization and saving logic should probably be moved to preset_utils.py
task_config_path = os.path.join(preset, TASK_CONFIG_FILE)
task_config = keras.saving.serialize_keras_object(self)
recursive_pop(task_config, "compile_config")
recursive_pop(task_config, "build_config")
# TODO: remove preprocessor and backbone from task.json to prevent redundancy in config files.
# recursive_pop(task_config, "preprocessor")
# recursive_pop(task_config, "backbone")
task_config["weights"] = weights_filename
with open(task_config_path, "w") as config_file:
config_file.write(json.dumps(task_config, indent=4))
task_weights_path = os.path.join(preset, weights_filename)
self.save_task_weights(task_weights_path)
save_to_preset(
self,
preset,
config_filename=TASK_CONFIG_FILE,
weights_filename=TASK_WEIGHTS_FILE,
)

@property
def layers(self):
Expand Down
46 changes: 44 additions & 2 deletions keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def list_subclasses(cls):

def get_file(preset, path):
"""Download a preset file in necessary and return the local path."""
# TODO: Through FileNotFoundError when the path doesn't exist.
if not isinstance(preset, str):
raise ValueError(
f"A preset identifier must be a string. Received: preset={preset}"
Expand Down Expand Up @@ -209,6 +210,10 @@ def save_to_preset(
config["weights"] = weights_filename if save_weights else None
recursive_pop(config, "compile_config")
recursive_pop(config, "build_config")
# Remove preprocessor and backbone from task.json to prevent redundancy in config files.
if config_filename == TASK_CONFIG_FILE:
recursive_pop(config, "preprocessor")
recursive_pop(config, "backbone")
with open(config_path, "w") as config_file:
config_file.write(json.dumps(config, indent=4))

Expand All @@ -217,7 +222,7 @@ def save_to_preset(
keras_version = keras.version() if hasattr(keras, "version") else None

# Save any associated metadata.
if config_filename == "config.json":
if config_filename == CONFIG_FILE:
metadata = {
"keras_version": keras_version,
"keras_nlp_version": keras_nlp_version,
Expand Down Expand Up @@ -402,9 +407,46 @@ def load_from_preset(
return layer


def load_task_from_preset(
preset,
load_weights=True,
task_config_file=TASK_CONFIG_FILE,
backbone_config_file=CONFIG_FILE,
config_overrides={},
):
# Load a serialized Keras object.
task_config = _get_config(preset, task_config_file)
task_config["config"] = {**task_config["config"], **config_overrides}
task = keras.saving.deserialize_keras_object(task_config)
backbone_config = _get_config(preset, backbone_config_file)
if load_weights:
if not task_config["weights"]:
raise ValueError(
f"`weights` config is missing from `{task_config_file}` in "
f"preset directory `{preset}`."
)
if not backbone_config["weights"]:
raise ValueError(
f"`weights` config is missing from `{backbone_config_file}` in "
f"preset directory `{preset}`."
)
task_weights_path = os.path.join(preset, task_config["weights"])
task.load_weights(task_weights_path)
backbone_weights_path = os.path.join(preset, backbone_config["weights"])
task.backbone.load_weights(backbone_weights_path)
return task


def _get_config(preset, config_file):
config_path = get_file(preset, config_file)
with open(config_path) as config_file:
config = json.load(config_file)
return config


def check_config_class(
preset,
config_file="config.json",
config_file=CONFIG_FILE,
):
"""Validate a preset is being loaded on the correct class."""
config_path = get_file(preset, config_file)
Expand Down

0 comments on commit 616dabe

Please sign in to comment.