Skip to content

Commit

Permalink
Address review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Mar 22, 2024
1 parent 42f43d9 commit 5beff57
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 39 deletions.
7 changes: 0 additions & 7 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from keras_nlp.utils.pipeline_model import PipelineModel
from keras_nlp.utils.preset_utils import check_preset_class
from keras_nlp.utils.preset_utils import load_from_preset
from keras_nlp.utils.preset_utils import save_to_preset
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.python_utils import format_docstring

Expand Down Expand Up @@ -254,12 +253,6 @@ def from_preset(
config_overrides=kwargs,
)

def save_to_preset(
self,
preset,
):
save_to_preset(self, preset)

def __init_subclass__(cls, **kwargs):
# Use __init_subclass__ to setup a correct docstring for from_preset.
super().__init_subclass__(**kwargs)
Expand Down
3 changes: 1 addition & 2 deletions keras_nlp/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,8 @@ def token_to_id(self, token: str) -> int:
def save_to_preset(
self,
preset,
config_filename="tokenizer.json",
):
save_to_preset(self, preset, config_filename=config_filename)
save_to_preset(self, preset, config_filename="tokenizer.json")

def call(self, inputs, *args, training=None, **kwargs):
return self.tokenize(inputs, *args, **kwargs)
1 change: 0 additions & 1 deletion keras_nlp/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.utils.preset_utils import upload_preset
79 changes: 50 additions & 29 deletions keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
KAGGLE_PREFIX = "kaggle://"
GS_PREFIX = "gs://"
TOKENIZER_ASSET_DIR = "assets/tokenizer"
CONFIG_FILE = "config.json"
TOKENIZER_CONFIG_FILE = "tokenizer.json"


def get_file(preset, path):
Expand Down Expand Up @@ -158,10 +160,8 @@ def save_to_preset(
metadata_file.write(json.dumps(metadata, indent=4))


def _validate_tokenizer(
preset, config_file="tokenizer.json", allow_incomplete=False
):
config_path = get_file(preset, config_file)
def _validate_tokenizer(preset, allow_incomplete=False):
config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
if not os.path.exists(config_path):
if allow_incomplete:
logging.warning(
Expand All @@ -177,23 +177,28 @@ def _validate_tokenizer(
config = json.load(config_file)
layer = keras.saving.deserialize_keras_object(config)

if not config["assets"]:
raise ValueError(
            f"Tokenizer config file {config_path} is missing `assets`."
)

for asset in config["assets"]:
asset_path = os.path.join(preset, asset)
if not os.path.exists(asset_path):
raise FileNotFoundError(
                f"Asset {asset} doesn't exist in the preset directory {preset}."
)
config_dir = os.path.dirname(config_path)
asset_dir = os.path.join(config_dir, TOKENIZER_ASSET_DIR)

tokenizer = get_tokenizer(layer)
if tokenizer and config["assets"]:
for asset in config["assets"]:
asset_path = os.path.join(preset, asset)
if not os.path.exists(asset_path):
raise FileNotFoundError(
                    f"Asset {asset} doesn't exist in the preset directory {preset}."
)
config_dir = os.path.dirname(config_path)
asset_dir = os.path.join(config_dir, TOKENIZER_ASSET_DIR)
tokenizer.load_assets(asset_dir)
else:
raise ValueError("Tokenizer or its asset are missing or invalid.")
if not tokenizer:
raise ValueError(f"Model or layer {layer} is missing tokenizer.")
tokenizer.load_assets(asset_dir)


def _validate_backbone(preset, config_file="config.json"):
config_path = get_file(preset, config_file)
def _validate_backbone(preset):
config_path = get_file(preset, CONFIG_FILE)
with open(config_path) as config_file:
config = json.load(config_file)
# Check if backbone is deserializable.
Expand All @@ -203,21 +208,25 @@ def _validate_backbone(preset, config_file="config.json"):
weights_path = os.path.join(preset, config["weights"])
if not os.path.exists(weights_path):
raise FileNotFoundError(
f"The weights file doesn't exist in preset directory {preset} ."
f"The weights file doesn't exist in preset directory `{preset}`."
)
else:
raise ValueError("there is no wieghts config!")
raise ValueError(
"No weights listed in `config.json`. Make sure to use "
"`save_to_preset()` which adds additional data to a serialized "
"Keras object."
)


def _validate_files(preset, backbone_config_file, tokenizer_config_file):
def _validate_files(preset):
# TODO: check if file sizes are reasonable.
# TODO: validate asset files.
backbone_config_path = get_file(preset, backbone_config_file)
backbone_config_path = get_file(preset, CONFIG_FILE)
with open(backbone_config_path) as config_file:
backbone_config = json.load(config_file)
valid_files = [
backbone_config_file,
tokenizer_config_file,
CONFIG_FILE,
TOKENIZER_CONFIG_FILE,
"metadata.json",
"assets",
backbone_config["weights"],
Expand All @@ -232,19 +241,31 @@ def _validate_files(preset, backbone_config_file, tokenizer_config_file):
def upload_preset(
uri,
preset,
config_file="config.json",
tokenizer_config_file="tokenizer.json",
allow_incomplete=False,
):
"""Upload a preset directory to a model hub.
Args:
        uri: The URI identifying the model to upload to.
URIs with format
`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`
will be uploaded to Kaggle Hub.
preset: The path to the local model preset directory.
allow_incomplete: If True, allows the upload of presets without
a tokenizer configuration. Otherwise, a tokenizer
is required.
"""

# Check if preset directory exists.
if not os.path.exists(preset):
raise FileNotFoundError(f"The preset directory {preset} doesn't exist.")

_validate_backbone(preset)
_validate_tokenizer(preset, allow_incomplete)
_validate_files(preset)

if uri.startswith(KAGGLE_PREFIX):
kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
_validate_backbone(preset, config_file)
_validate_tokenizer(preset, tokenizer_config_file, allow_incomplete)
_validate_files(preset, config_file, tokenizer_config_file)
kagglehub.model_upload(kaggle_handle, preset)
else:
raise ValueError(
Expand Down

0 comments on commit 5beff57

Please sign in to comment.