Skip to content

Commit

Permalink
Merge branch 'keras-team:main' into Increase-test-coverage-keras_core-saving
Browse files (browse the repository at this point in the history)
  • Loading branch information
Faisal-Alsrheed authored Sep 16, 2023
2 parents a9a0a6b + 74a4e7f commit 22a8123
Show file tree
Hide file tree
Showing 7 changed files with 395 additions and 38 deletions.
2 changes: 1 addition & 1 deletion examples/keras_io/vision/cct.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def call(self, x, training=None):
if training:
keep_prob = 1 - self.drop_prob
shape = (keras.ops.shape(x)[0],) + (1,) * (len(x.shape) - 1)
random_tensor = keep_prob + keras.random.random.uniform(shape, 0, 1)
random_tensor = keep_prob + keras.random.uniform(shape, 0, 1)
random_tensor = keras.ops.floor(random_tensor)
return (x / keep_prob) * random_tensor
return x
Expand Down
62 changes: 46 additions & 16 deletions keras_core/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from keras_core import utils
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.legacy.saving import legacy_h5_format
from keras_core.models.variable_mapping import map_trackable_variables
from keras_core.saving import saving_api
from keras_core.saving import saving_lib
Expand Down Expand Up @@ -263,7 +264,7 @@ def summary(
)

@traceback_utils.filter_traceback
def save(self, filepath, overwrite=True, save_format="keras"):
def save(self, filepath, overwrite=True, **kwargs):
"""Saves a model as a `.keras` file.
Args:
Expand Down Expand Up @@ -301,22 +302,30 @@ def save(self, filepath, overwrite=True, save_format="keras"):
Thus models can be reinstantiated in the exact same state.
"""
if save_format in ["h5", "tf"]:
include_optimizer = kwargs.pop("include_optimizer", True)
save_format = kwargs.pop("save_format", None)
if kwargs:
raise ValueError(
"`'h5'` and `'t5'` formats are no longer supported via the "
"`save_format` option. Please use the new `'keras'` format. "
f"Received: save_format={save_format}"
)
if save_format not in ["keras", "keras_v3"]:
raise ValueError(
"Unknown `save_format` value. Only the `'keras'` format is "
f"currently supported. Received: save_format={save_format}"
)
if not str(filepath).endswith(".keras"):
raise ValueError(
"The filename must end in `.keras`. "
f"Received: filepath={filepath}"
"The following argument(s) are not supported: "
f"{list(kwargs.keys())}"
)
if save_format:
if str(filepath).endswith((".h5", ".hdf5")) or str(
filepath
).endswith(".keras"):
warnings.warn(
"The `save_format` argument is deprecated in Keras Core. "
"We recommend removing this argument as it can be inferred "
"from the file path. "
f"Received: save_format={save_format}"
)
else:
raise ValueError(
"The `save_format` argument is deprecated in Keras Core. "
"Please remove this argument and pass a file path with "
"either `.keras` or `.h5` extension."
f"Received: save_format={save_format}"
)
try:
exists = os.path.exists(filepath)
except TypeError:
Expand All @@ -325,7 +334,28 @@ def save(self, filepath, overwrite=True, save_format="keras"):
proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
if not proceed:
return
saving_lib.save_model(self, filepath)
if str(filepath).endswith(".keras"):
saving_lib.save_model(self, filepath)
elif str(filepath).endswith((".h5", ".hdf5")):
# Deprecation warnings
warnings.warn(
"You are saving your model as an HDF5 file via `model.save()`. "
"This file format is considered legacy. "
"We recommend using instead the native Keras format, "
"e.g. `model.save('my_model.keras')`."
)
legacy_h5_format.save_model_to_hdf5(
self, filepath, overwrite, include_optimizer
)
else:
raise ValueError(
"Invalid filepath extension for saving. "
"Please add either a `.keras` extension for the native Keras "
f"format (recommended) or a `.h5` extension. "
"Use `tf.saved_model.save()` if you want to export a "
"SavedModel for use with TFLite/TFServing/etc. "
f"Received: filepath={filepath}."
)

@traceback_utils.filter_traceback
def save_weights(self, filepath, overwrite=True):
Expand Down
48 changes: 38 additions & 10 deletions keras_core/saving/saving_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,15 +402,6 @@ def test_metadata(self):
# self.assertIn(str(temp_filepath), mock_re_match.call_args.args)
# self.assertIn(str(temp_filepath), mock_copy.call_args.args)

def test_load_model_api_endpoint(self):
temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.keras"))
model = _get_basic_functional_model()
ref_input = np.random.random((2, 4))
ref_output = model.predict(ref_input)
model.save(temp_filepath)
model = keras_core.saving.load_model(temp_filepath)
self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)

def test_save_load_weights_only(self):
temp_filepath = Path(
os.path.join(self.get_temp_dir(), "mymodel.weights.h5")
Expand Down Expand Up @@ -530,7 +521,10 @@ def test_partial_load(self):
np.array(new_model.layers[2].kernel), new_layer_kernel_value
)

def test_api_errors(self):

@pytest.mark.requires_trainable_backend
class SavingAPITest(testing.TestCase):
def test_saving_api_errors(self):
from keras_core.saving import saving_api

model = _get_basic_functional_model()
Expand Down Expand Up @@ -559,6 +553,40 @@ def test_api_errors(self):
with self.assertRaisesRegex(ValueError, "File format not supported"):
_ = saving_api.load_model(temp_filepath)

def test_model_api_endpoint(self):
    """Round-trip a model through `model.save()` and `load_model()`.

    The model is saved to a path ending in `.keras`, reloaded through
    `keras_core.saving.load_model`, and the reloaded model's predictions
    are compared against the predictions made before saving.
    """
    save_path = Path(os.path.join(self.get_temp_dir(), "mymodel.keras"))
    original = _get_basic_functional_model()
    inputs = np.random.random((2, 4))
    expected = original.predict(inputs)

    original.save(save_path)
    restored = keras_core.saving.load_model(save_path)

    self.assertAllClose(restored.predict(inputs), expected, atol=1e-6)

def test_model_api_endpoint_h5(self):
    """Round-trip a model through `model.save()` using an `.h5` path.

    Same check as the `.keras` round-trip: predictions of the reloaded
    model must match the predictions made before saving.
    """
    save_path = Path(os.path.join(self.get_temp_dir(), "mymodel.h5"))
    original = _get_basic_functional_model()
    inputs = np.random.random((2, 4))
    expected = original.predict(inputs)

    original.save(save_path)
    restored = keras_core.saving.load_model(save_path)

    self.assertAllClose(restored.predict(inputs), expected, atol=1e-6)

def test_model_api_errors(self):
    """Check that `model.save()` rejects invalid calls with `ValueError`.

    Each case is (file name, expected message pattern, extra kwargs):
    a `save_format` argument with an extension-less path, an unrecognized
    file extension, and an unknown keyword argument.
    """
    model = _get_basic_functional_model()

    # Saving API errors
    failing_calls = (
        ("mymodel", "argument is deprecated", {"save_format": "keras"}),
        ("mymodel.notkeras", "Invalid filepath extension", {}),
        ("mymodel.keras", "are not supported", {"invalid_arg": "hello"}),
    )
    for file_name, message_pattern, extra_kwargs in failing_calls:
        temp_filepath = os.path.join(self.get_temp_dir(), file_name)
        with self.assertRaisesRegex(ValueError, message_pattern):
            model.save(temp_filepath, **extra_kwargs)


# def test_safe_mode(self):
# temp_filepath = os.path.join(self.get_temp_dir(), "unsafe_model.keras")
Expand Down
142 changes: 142 additions & 0 deletions keras_core/utils/code_stats_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import os
import sys
from io import StringIO

from keras_core.testing import test_case
from keras_core.utils.code_stats import count_loc


class TestCountLoc(test_case.TestCase):
    """Tests for `count_loc` using small files written to a scratch directory.

    Every test writes its fixture files under `self.test_dir`, which is
    created in `setUp` and completely removed again in `tearDown`.
    """

    def setUp(self):
        # Scratch directory, relative to the current working directory.
        self.test_dir = "test_directory"
        os.makedirs(self.test_dir, exist_ok=True)

    def tearDown(self):
        # Walk bottom-up so every directory is already empty when removed.
        for root, dirs, files in os.walk(self.test_dir, topdown=False):
            for name in files:
                os.remove(os.path.join(root, name))
            for name in dirs:
                os.rmdir(os.path.join(root, name))
        # The walk only removes the *contents*; drop the scratch directory
        # itself so no empty `test_directory` is left in the working dir.
        if os.path.isdir(self.test_dir):
            os.rmdir(self.test_dir)

    def create_file(self, filename, content):
        """Write `content` to `filename` (relative to the scratch directory)."""
        with open(
            os.path.join(self.test_dir, filename), "w", encoding="utf-8"
        ) as f:
            f.write(content)

    def test_count_loc_valid_python(self):
        self.create_file(
            "sample.py", "# This is a test file\n\nprint('Hello')\n"
        )
        loc = count_loc(self.test_dir)
        self.assertEqual(loc, 1)

    def test_exclude_test_files(self):
        self.create_file("sample_test.py", "print('Hello')\n")
        loc = count_loc(self.test_dir, exclude=("_test",))
        self.assertEqual(loc, 0)

    def test_other_extensions(self):
        self.create_file("sample.txt", "Hello\n")
        loc = count_loc(self.test_dir, extensions=(".py",))
        self.assertEqual(loc, 0)

    def test_comment_lines(self):
        self.create_file(
            "sample.py", "# Comment\nprint('Hello')\n# Another comment\n"
        )
        loc = count_loc(self.test_dir)
        self.assertEqual(loc, 1)

    def test_empty_file(self):
        self.create_file("empty.py", "")
        loc = count_loc(self.test_dir)
        self.assertEqual(loc, 0)

    def test_whitespace_only(self):
        self.create_file("whitespace.py", " \n\t\n")
        loc = count_loc(self.test_dir)
        self.assertEqual(loc, 0)

    def test_inline_comments_after_code(self):
        content = 'print("Hello") # This is an inline comment'
        self.create_file("inline_comment_sample.py", content)
        loc = count_loc(self.test_dir)
        self.assertEqual(loc, 1)  # The comment shouldn't affect the count

    def test_directory_structure(self):
        content1 = 'print("Hello from file1")'
        content2 = 'print("Hello from file2")'
        os.mkdir(os.path.join(self.test_dir, "subdir"))
        self.create_file("sample1.py", content1)
        self.create_file(os.path.join("subdir", "sample2.py"), content2)
        loc = count_loc(self.test_dir)
        self.assertEqual(loc, 2)  # Both files should be counted

    def test_normal_directory_name(self):
        content = 'print("Hello from a regular directory")'
        os.makedirs(os.path.join(self.test_dir, "some_test_dir"))
        self.create_file(os.path.join("some_test_dir", "sample.py"), content)
        loc = count_loc(self.test_dir)
        self.assertEqual(loc, 1)  # Should count normally

    def test_exclude_directory_name(self):
        content = 'print("Hello from an excluded directory")'
        os.makedirs(os.path.join(self.test_dir, "dir_test"))
        self.create_file(os.path.join("dir_test", "sample.py"), content)
        loc = count_loc(self.test_dir)
        # Shouldn't count the file in dir_test due to the exclusion pattern
        self.assertEqual(loc, 0)

    def test_verbose_output(self):
        content = 'print("Hello")'
        self.create_file("sample.py", content)
        original_stdout = sys.stdout
        sys.stdout = StringIO()
        try:
            count_loc(self.test_dir, verbose=1)
            output = sys.stdout.getvalue()
        finally:
            # Always restore stdout, even if `count_loc` raises, so a
            # failure here cannot swallow the output of later tests.
            sys.stdout = original_stdout
        self.assertIn("Count LoCs in", output)

    def test_multiline_string_same_line(self):
        content = (
            '"""This is a multiline string ending on the same line"""\n'
            'print("Outside string")'
        )
        self.create_file("same_line_multiline.py", content)
        loc = count_loc(self.test_dir)
        self.assertEqual(loc, 1)  # Only the print statement should count

    def test_multiline_string_ends_on_same_line(self):
        content = '"""a multiline string end on same line"""\nprint("Outstr")'
        self.create_file("same_line_multiline.py", content)
        loc = count_loc(self.test_dir)
        self.assertEqual(loc, 1)  # Only the print statement should count

    def test_multiline_string_ends_in_middle_of_line(self):
        content = (
            'print("Start")\n'
            '"""This is a multiline string ending in the middle of a line"""\n'
            '"""This is another multiline string."""\n'
            'print("End")'
        )
        self.create_file("multiline_in_middle.py", content)
        loc = count_loc(self.test_dir)
        self.assertEqual(loc, 2)  # Both print statements should count

    def test_line_starting_with_triple_quotes_not_ending(self):
        content = '"""\nThis is a multiline string\n'
        self.create_file("test_file_2.py", content)
        # NOTE(review): unlike the tests above, this passes a *file* path to
        # `count_loc` rather than the directory -- confirm that is supported,
        # otherwise the `0` below may be passing trivially.
        path = os.path.join(self.test_dir, "test_file_2.py")
        # Because it's part of a multiline string
        self.assertEqual(count_loc(path), 0)

    def test_line_starting_and_ending_with_triple_quotes(self):
        content = '"""This is a one-liner docstring."""\n'
        self.create_file("test_file_3.py", content)
        # NOTE(review): file path, not a directory -- see the note in
        # `test_line_starting_with_triple_quotes_not_ending`; confirm.
        path = os.path.join(self.test_dir, "test_file_3.py")
        # This is still considered a comment/docstring
        self.assertEqual(count_loc(path), 0)

    def test_string_open_true_line_starting_with_triple_quotes(self):
        content = '"""\nEnd of the multiline string."""\n'
        self.create_file("test_file_4.py", content)
        # NOTE(review): file path, not a directory -- see the note in
        # `test_line_starting_with_triple_quotes_not_ending`; confirm.
        path = os.path.join(self.test_dir, "test_file_4.py")
        # Entire content is a multiline string/comment
        self.assertEqual(count_loc(path), 0)
23 changes: 12 additions & 11 deletions keras_core/utils/dtype_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from keras_core import backend
from keras_core import ops

# Size in bits of each supported dtype, keyed by the dtype's string name.
DTYPE_TO_SIZE = {
    **{f"float{i}": i for i in (16, 32, 64)},
    **{f"int{i}": i for i in (8, 16, 32, 64)},
    **{f"uint{i}": i for i in (8, 16, 32, 64)},
    "bfloat16": 16,
    "bool": 1,
}


def dtype_size(dtype):
    """Return the size in bits of `dtype`.

    Args:
        dtype: String name of the dtype, e.g. `"float32"` or `"uint8"`.

    Returns:
        The integer size recorded for `dtype` in `DTYPE_TO_SIZE`.

    Raises:
        ValueError: If `dtype` is not a key of `DTYPE_TO_SIZE`.
    """
    # Single table lookup; the table is the only source of truth, so every
    # dtype listed there (including `int8`/`int16` and the wider `uint*`
    # types) resolves instead of falling through to the error.
    size = DTYPE_TO_SIZE.get(dtype)
    if size is None:
        raise ValueError(f"Invalid dtype: {dtype}")
    return size


def is_float(dtype):
Expand Down
Loading

0 comments on commit 22a8123

Please sign in to comment.