Remove preset saving and loading tests.
SamanehSaadat committed Apr 10, 2024
1 parent c238621 commit f89d795
Showing 1 changed file with 0 additions and 71 deletions.
keras_nlp/utils/preset_utils_test.py: 71 changes (0 additions & 71 deletions)
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import pytest
@@ -21,83 +20,13 @@
from keras_nlp import upload_preset
from keras_nlp.models import AlbertClassifier
from keras_nlp.models import BertBackbone
from keras_nlp.models import BertClassifier
from keras_nlp.models import BertTokenizer
from keras_nlp.models import RobertaClassifier
from keras_nlp.tests.test_case import TestCase
from keras_nlp.utils.preset_utils import CONFIG_FILE
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import load_from_preset
from keras_nlp.utils.preset_utils import save_to_preset


class PresetUtilsTest(TestCase):
    @parameterized.parameters(
        (AlbertClassifier, "albert_base_en_uncased", "sentencepiece"),
        (RobertaClassifier, "roberta_base_en", "bytepair"),
        (BertClassifier, "bert_tiny_en_uncased", "wordpiece"),
    )
    @pytest.mark.keras_3_only
    @pytest.mark.large
    def test_preset_saving(self, cls, preset_name, tokenizer_type):
        save_dir = self.get_temp_dir()
        model = cls.from_preset(preset_name, num_classes=2)
        save_to_preset(model, save_dir)

        if tokenizer_type == "bytepair":
            vocab_filename = "assets/tokenizer/vocabulary.json"
            expected_assets = [
                "assets/tokenizer/vocabulary.json",
                "assets/tokenizer/merges.txt",
            ]
        elif tokenizer_type == "sentencepiece":
            vocab_filename = "assets/tokenizer/vocabulary.spm"
            expected_assets = ["assets/tokenizer/vocabulary.spm"]
        else:
            vocab_filename = "assets/tokenizer/vocabulary.txt"
            expected_assets = ["assets/tokenizer/vocabulary.txt"]

        # Check existence of files
        self.assertTrue(os.path.exists(os.path.join(save_dir, vocab_filename)))
        self.assertTrue(os.path.exists(os.path.join(save_dir, "config.json")))
        self.assertTrue(
            os.path.exists(os.path.join(save_dir, "model.weights.h5"))
        )
        self.assertTrue(os.path.exists(os.path.join(save_dir, "metadata.json")))

        # Check the model config (`config.json`)
        config_json = open(os.path.join(save_dir, "config.json"), "r").read()
        self.assertTrue(
            "build_config" not in config_json
        )  # Test on raw json to include nested keys
        self.assertTrue(
            "compile_config" not in config_json
        )  # Test on raw json to include nested keys
        config = json.loads(config_json)
        self.assertEqual(set(config["assets"]), set(expected_assets))
        self.assertEqual(config["weights"], "model.weights.h5")

        # Check that the saved preset resolves to the right config class
        self.assertEqual(cls, check_config_class(save_dir))

        # Try loading the model from preset directory
        restored_model = load_from_preset(save_dir)

        train_data = (
            ["the quick brown fox.", "the slow brown fox."],  # Features.
        )
        model_input_data = model.preprocessor(*train_data)
        restored_model_input_data = restored_model.preprocessor(*train_data)

        # Check that saved vocab is equal to the original preset vocab
        self.assertAllClose(model_input_data, restored_model_input_data)

        # Check model outputs
        self.assertAllEqual(
            model(model_input_data), restored_model(restored_model_input_data)
        )

    def test_preset_errors(self):
        with self.assertRaisesRegex(ValueError, "must be a string"):
            AlbertClassifier.from_preset(AlbertClassifier)