From 318af3e3eb4053416db5b2a13f2fb16acf3910aa Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:09:38 -0700 Subject: [PATCH 01/20] Don't show optimizer params if not optimizer or not built (#885) Most people will call `summary()` before `fit()`. Doing so will usually mean the optimizer would show up as 0 parameters, which is misleading. --- keras_core/utils/summary_utils.py | 13 ++++++++----- keras_core/utils/summary_utils_test.py | 1 + 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/keras_core/utils/summary_utils.py b/keras_core/utils/summary_utils.py index e7e5ddf94..06259fb4c 100644 --- a/keras_core/utils/summary_utils.py +++ b/keras_core/utils/summary_utils.py @@ -311,9 +311,11 @@ def print_layer(layer, nested_level=0): if model.compiled and model.optimizer and model.optimizer.built: optimizer_weight_count = count_params(model.optimizer.variables) optimizer_memory_size = weight_memory_size(model.optimizer.variables) + optimizer_built = True else: optimizer_weight_count = 0 optimizer_memory_size = 0 + optimizer_built = False total_count = trainable_count + non_trainable_count + optimizer_weight_count total_memory_size = ( @@ -349,11 +351,12 @@ def print_layer(layer, nested_level=0): + highlight_number(f"{non_trainable_count:,}") + f" ({readable_memory_size(non_trainable_memory_size)})" ) - console.print( - bold_text(" Optimizer params: ") - + highlight_number(f"{optimizer_weight_count:,}") - + f" ({readable_memory_size(optimizer_memory_size)})" - ) + if optimizer_built: + console.print( + bold_text(" Optimizer params: ") + + highlight_number(f"{optimizer_weight_count:,}") + + f" ({readable_memory_size(optimizer_memory_size)})" + ) # Output captured summary for non-interactive logging. if print_fn: diff --git a/keras_core/utils/summary_utils_test.py b/keras_core/utils/summary_utils_test.py index db2890ce7..4570fceca 100644 --- a/keras_core/utils/summary_utils_test.py +++ b/keras_core/utils/summary_utils_test.py @@ -37,5 +37,6 @@ def print_to_variable(text, line_break=False): self.assertIn("Total params: 9", summary_content) self.assertIn("Trainable params: 9", summary_content) self.assertIn("Non-trainable params: 0", summary_content) + self.assertNotIn("Optimizer params", summary_content) except ImportError: pass From fcdf2a456d4554fa3487c01bf16af9afc1a93a34 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:11:05 -0700 Subject: [PATCH 02/20] Export dtype utils in `keras_core.backend` (#886) Making multi backend dtype comparisons for code based on keras-core is tricky. Torch has dtype classes such that `torch_tensor.dtype is string_dtype` will always be false. We use `standardize_dtype` extensively in our own code and should expose it to clients. `is_float_dtype` and `is_int_dtype` we can include as useful helpers. 
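For context, a minimal sketch (not part of the diff below) of the backend-agnostic check these exports enable; the tensor here is just a placeholder built with `ops.zeros`:

```python
from keras_core import backend, ops

x = ops.zeros((2, 3), dtype="float32")
# `x.dtype is "float32"` is always False on the torch backend, because
# `x.dtype` is a torch dtype object rather than a string.
print(backend.standardize_dtype(x.dtype))  # "float32" on every backend
print(backend.is_float_dtype(x.dtype))     # True
print(backend.is_int_dtype("uint8"))       # True
```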
--- keras_core/backend/common/variables.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/keras_core/backend/common/variables.py b/keras_core/backend/common/variables.py index 9050a769f..0f0c1e23d 100644 --- a/keras_core/backend/common/variables.py +++ b/keras_core/backend/common/variables.py @@ -1,5 +1,6 @@ import numpy as np +from keras_core.api_export import keras_core_export from keras_core.backend import config from keras_core.backend.common import global_state from keras_core.backend.common.name_scope import current_path @@ -398,6 +399,7 @@ def initialize_all_variables(): } +@keras_core_export("keras_core.backend.standardize_dtype") def standardize_dtype(dtype): if dtype is None: return config.floatx() @@ -454,11 +456,13 @@ def shape_equal(a, b): return True +@keras_core_export("keras_core.backend.is_float_dtype") def is_float_dtype(dtype): dtype = standardize_dtype(dtype) return dtype.startswith("float") or dtype.startswith("bfloat") +@keras_core_export("keras_core.backend.is_int_dtype") def is_int_dtype(dtype): dtype = standardize_dtype(dtype) return dtype.startswith("int") or dtype.startswith("uint") From 94b5361f37351e060e7ec4eb16e268500d057c7e Mon Sep 17 00:00:00 2001 From: Pedro Kaj Kjellerup Nacht Date: Thu, 14 Sep 2023 19:53:14 -0300 Subject: [PATCH 03/20] Set actions.yml to run with read-only permissions (#882) Signed-off-by: Pedro Kaj Kjellerup Nacht --- .github/workflows/actions.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index ddf53f2ec..e08885550 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -6,6 +6,10 @@ on: pull_request: release: types: [created] + +permissions: + contents: read + jobs: build: strategy: From a94aea04eb5e10dd0ea1097c1b45defdaff004df Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:55:53 -0700 Subject: [PATCH 04/20] Fix `model.save` API for legacy H5 format (#891) * Add more H5 backwards compatibility tests * Fix API for legacy H5 format * Fix formatting * Reorganize, add tests, fix nits --- keras_core/models/model.py | 62 +++++++++++++++++++++------- keras_core/saving/saving_lib_test.py | 48 ++++++++++++++++----- 2 files changed, 84 insertions(+), 26 deletions(-) diff --git a/keras_core/models/model.py b/keras_core/models/model.py index c467bbc0a..f2f18868a 100644 --- a/keras_core/models/model.py +++ b/keras_core/models/model.py @@ -7,6 +7,7 @@ from keras_core import utils from keras_core.api_export import keras_core_export from keras_core.layers.layer import Layer +from keras_core.legacy.saving import legacy_h5_format from keras_core.models.variable_mapping import map_trackable_variables from keras_core.saving import saving_api from keras_core.saving import saving_lib @@ -263,7 +264,7 @@ def summary( ) @traceback_utils.filter_traceback - def save(self, filepath, overwrite=True, save_format="keras"): + def save(self, filepath, overwrite=True, **kwargs): """Saves a model as a `.keras` file. Args: @@ -301,22 +302,30 @@ def save(self, filepath, overwrite=True, save_format="keras"): Thus models can be reinstantiated in the exact same state. """ - if save_format in ["h5", "tf"]: + include_optimizer = kwargs.pop("include_optimizer", True) + save_format = kwargs.pop("save_format", None) + if kwargs: raise ValueError( - "`'h5'` and `'t5'` formats are no longer supported via the " - "`save_format` option. Please use the new `'keras'` format. 
" - f"Received: save_format={save_format}" - ) - if save_format not in ["keras", "keras_v3"]: - raise ValueError( - "Unknown `save_format` value. Only the `'keras'` format is " - f"currently supported. Received: save_format={save_format}" - ) - if not str(filepath).endswith(".keras"): - raise ValueError( - "The filename must end in `.keras`. " - f"Received: filepath={filepath}" + "The following argument(s) are not supported: " + f"{list(kwargs.keys())}" ) + if save_format: + if str(filepath).endswith((".h5", ".hdf5")) or str( + filepath + ).endswith(".keras"): + warnings.warn( + "The `save_format` argument is deprecated in Keras Core. " + "We recommend removing this argument as it can be inferred " + "from the file path. " + f"Received: save_format={save_format}" + ) + else: + raise ValueError( + "The `save_format` argument is deprecated in Keras Core. " + "Please remove this argument and pass a file path with " + "either `.keras` or `.h5` extension." + f"Received: save_format={save_format}" + ) try: exists = os.path.exists(filepath) except TypeError: @@ -325,7 +334,28 @@ def save(self, filepath, overwrite=True, save_format="keras"): proceed = io_utils.ask_to_proceed_with_overwrite(filepath) if not proceed: return - saving_lib.save_model(self, filepath) + if str(filepath).endswith(".keras"): + saving_lib.save_model(self, filepath) + elif str(filepath).endswith((".h5", ".hdf5")): + # Deprecation warnings + warnings.warn( + "You are saving your model as an HDF5 file via `model.save()`. " + "This file format is considered legacy. " + "We recommend using instead the native Keras format, " + "e.g. `model.save('my_model.keras')`." + ) + legacy_h5_format.save_model_to_hdf5( + self, filepath, overwrite, include_optimizer + ) + else: + raise ValueError( + "Invalid filepath extension for saving. " + "Please add either a `.keras` extension for the native Keras " + f"format (recommended) or a `.h5` extension. " + "Use `tf.saved_model.save()` if you want to export a " + "SavedModel for use with TFLite/TFServing/etc. " + f"Received: filepath={filepath}." 
+ ) @traceback_utils.filter_traceback def save_weights(self, filepath, overwrite=True): diff --git a/keras_core/saving/saving_lib_test.py b/keras_core/saving/saving_lib_test.py index f790ccb2d..06e0572eb 100644 --- a/keras_core/saving/saving_lib_test.py +++ b/keras_core/saving/saving_lib_test.py @@ -402,15 +402,6 @@ def test_metadata(self): # self.assertIn(str(temp_filepath), mock_re_match.call_args.args) # self.assertIn(str(temp_filepath), mock_copy.call_args.args) - def test_load_model_api_endpoint(self): - temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.keras")) - model = _get_basic_functional_model() - ref_input = np.random.random((2, 4)) - ref_output = model.predict(ref_input) - model.save(temp_filepath) - model = keras_core.saving.load_model(temp_filepath) - self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) - def test_save_load_weights_only(self): temp_filepath = Path( os.path.join(self.get_temp_dir(), "mymodel.weights.h5") @@ -530,7 +521,10 @@ def test_partial_load(self): np.array(new_model.layers[2].kernel), new_layer_kernel_value ) - def test_api_errors(self): + +@pytest.mark.requires_trainable_backend +class SavingAPITest(testing.TestCase): + def test_saving_api_errors(self): from keras_core.saving import saving_api model = _get_basic_functional_model() @@ -559,6 +553,40 @@ def test_api_errors(self): with self.assertRaisesRegex(ValueError, "File format not supported"): _ = saving_api.load_model(temp_filepath) + def test_model_api_endpoint(self): + temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.keras")) + model = _get_basic_functional_model() + ref_input = np.random.random((2, 4)) + ref_output = model.predict(ref_input) + model.save(temp_filepath) + model = keras_core.saving.load_model(temp_filepath) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + + def test_model_api_endpoint_h5(self): + temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.h5")) + model = _get_basic_functional_model() + ref_input = np.random.random((2, 4)) + ref_output = model.predict(ref_input) + model.save(temp_filepath) + model = keras_core.saving.load_model(temp_filepath) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + + def test_model_api_errors(self): + model = _get_basic_functional_model() + + # Saving API errors + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel") + with self.assertRaisesRegex(ValueError, "argument is deprecated"): + model.save(temp_filepath, save_format="keras") + + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.notkeras") + with self.assertRaisesRegex(ValueError, "Invalid filepath extension"): + model.save(temp_filepath) + + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras") + with self.assertRaisesRegex(ValueError, "are not supported"): + model.save(temp_filepath, invalid_arg="hello") + # def test_safe_mode(self): # temp_filepath = os.path.join(self.get_temp_dir(), "unsafe_model.keras") From 4cd003b987fcac04e7e1871309484b219b5aca25 Mon Sep 17 00:00:00 2001 From: gleize Date: Fri, 15 Sep 2023 15:26:14 -0700 Subject: [PATCH 05/20] Add missing dtypes. (#887) * Add missing dtypes. My model contains some `uint` non trainable variables that triggered an exception when calling the `model.summary()` method. I added the missing dtypes. * Fix linting issue. 
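As an illustration (assuming the internal `keras_core.utils.dtype_utils` module is imported directly), the table-based lookup below now covers sizes such as `uint16` that previously raised:

```python
from keras_core.utils.dtype_utils import dtype_size

print(dtype_size("uint16"))   # 16 -- raised ValueError before this change
print(dtype_size("float64"))  # 64
print(dtype_size("bool"))     # 1
```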
--- keras_core/utils/dtype_utils.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/keras_core/utils/dtype_utils.py b/keras_core/utils/dtype_utils.py index adba74015..9c2992f44 100644 --- a/keras_core/utils/dtype_utils.py +++ b/keras_core/utils/dtype_utils.py @@ -1,19 +1,20 @@ from keras_core import backend from keras_core import ops +DTYPE_TO_SIZE = { + **{f"float{i}": i for i in (16, 32, 64)}, + **{f"int{i}": i for i in (8, 16, 32, 64)}, + **{f"uint{i}": i for i in (8, 16, 32, 64)}, + "bfloat16": 16, + "bool": 1, +} + def dtype_size(dtype): - if dtype in ("bfloat16", "float16"): - return 16 - if dtype in ("float32", "int32"): - return 32 - if dtype in ("float64", "int64"): - return 64 - if dtype == "uint8": - return 8 - if dtype == "bool": - return 1 - raise ValueError(f"Invalid dtype: {dtype}") + size = DTYPE_TO_SIZE.get(dtype, None) + if size is None: + raise ValueError(f"Invalid dtype: {dtype}") + return size def is_float(dtype): From 7486af92e4d52a8c88aaaedeeebc8fa55143e8d0 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sat, 16 Sep 2023 01:26:59 +0300 Subject: [PATCH 06/20] Increase tests in `utils` (#880) * Increase tests in utils + Fix Bug in `count_loc` * Increase tests in utils + Fix Bug in `count_loc` * Increase tests in utils + Fix Bug in `count_loc` * Increase tests in utils + Fix Bug in `count_loc` * Increase tests in utils + Fix Bug in count_loc * Increase tests in utils * Increase tests in utils * Increase tests in `utils` * Increase tests in utils * Increase tests in utils * Increase tests in `utils` * Increase tests in utils * Increase tests in utils * Increase tests in utils * Increase tests in `utils`+ Improve `to_snake_case` * Increase tests in `utils`+ Improve `to_snake_case` * Increase tests in `utils`+ Improve `to_snake_case` * Increase tests in `utils` * Increase tests in `utils` * Increase tests in `utils` --- keras_core/utils/code_stats_test.py | 142 ++++++++++++++++++++++++++++ keras_core/utils/io_utils_test.py | 57 +++++++++++ keras_core/utils/naming_test.py | 99 +++++++++++++++++++ 3 files changed, 298 insertions(+) create mode 100644 keras_core/utils/code_stats_test.py create mode 100644 keras_core/utils/io_utils_test.py diff --git a/keras_core/utils/code_stats_test.py b/keras_core/utils/code_stats_test.py new file mode 100644 index 000000000..512303abc --- /dev/null +++ b/keras_core/utils/code_stats_test.py @@ -0,0 +1,142 @@ +import os +import sys +from io import StringIO + +from keras_core.testing import test_case +from keras_core.utils.code_stats import count_loc + + +class TestCountLoc(test_case.TestCase): + def setUp(self): + self.test_dir = "test_directory" + os.makedirs(self.test_dir, exist_ok=True) + + def tearDown(self): + for root, dirs, files in os.walk(self.test_dir, topdown=False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + + def create_file(self, filename, content): + with open( + os.path.join(self.test_dir, filename), "w", encoding="utf-8" + ) as f: + f.write(content) + + def test_count_loc_valid_python(self): + self.create_file( + "sample.py", "# This is a test file\n\nprint('Hello')\n" + ) + loc = count_loc(self.test_dir) + self.assertEqual(loc, 1) + + def test_exclude_test_files(self): + self.create_file("sample_test.py", "print('Hello')\n") + loc = count_loc(self.test_dir, exclude=("_test",)) + self.assertEqual(loc, 0) + + def test_other_extensions(self): + 
self.create_file("sample.txt", "Hello\n") + loc = count_loc(self.test_dir, extensions=(".py",)) + self.assertEqual(loc, 0) + + def test_comment_lines(self): + self.create_file( + "sample.py", "# Comment\nprint('Hello')\n# Another comment\n" + ) + loc = count_loc(self.test_dir) + self.assertEqual(loc, 1) + + def test_empty_file(self): + self.create_file("empty.py", "") + loc = count_loc(self.test_dir) + self.assertEqual(loc, 0) + + def test_whitespace_only(self): + self.create_file("whitespace.py", " \n\t\n") + loc = count_loc(self.test_dir) + self.assertEqual(loc, 0) + + def test_inline_comments_after_code(self): + content = 'print("Hello") # This is an inline comment' + self.create_file("inline_comment_sample.py", content) + loc = count_loc(self.test_dir) + self.assertEqual(loc, 1) # The comment shouldn't affect the count + + def test_directory_structure(self): + content1 = 'print("Hello from file1")' + content2 = 'print("Hello from file2")' + os.mkdir(os.path.join(self.test_dir, "subdir")) + self.create_file("sample1.py", content1) + self.create_file(os.path.join("subdir", "sample2.py"), content2) + loc = count_loc(self.test_dir) + self.assertEqual(loc, 2) # Both files should be counted + + def test_normal_directory_name(self): + content = 'print("Hello from a regular directory")' + os.makedirs(os.path.join(self.test_dir, "some_test_dir")) + self.create_file(os.path.join("some_test_dir", "sample.py"), content) + loc = count_loc(self.test_dir) + self.assertEqual(loc, 1) # Should count normally + + def test_exclude_directory_name(self): + content = 'print("Hello from an excluded directory")' + os.makedirs(os.path.join(self.test_dir, "dir_test")) + self.create_file(os.path.join("dir_test", "sample.py"), content) + loc = count_loc(self.test_dir) + self.assertEqual(loc, 0) + # Shouldn't count the file in dir_test due to the exclusion pattern + + def test_verbose_output(self): + content = 'print("Hello")' + self.create_file("sample.py", content) + original_stdout = sys.stdout + sys.stdout = StringIO() + count_loc(self.test_dir, verbose=1) + output = sys.stdout.getvalue() + sys.stdout = original_stdout + self.assertIn("Count LoCs in", output) + + def test_multiline_string_same_line(self): + content = '''"""This is a multiline string ending on the same line""" + print("Outside string")''' + self.create_file("same_line_multiline.py", content) + loc = count_loc(self.test_dir) + self.assertEqual(loc, 1) # Only the print statement should count + + def test_multiline_string_ends_on_same_line(self): + content = '"""a multiline string end on same line"""\nprint("Outstr")' + self.create_file("same_line_multiline.py", content) + loc = count_loc(self.test_dir) + self.assertEqual(loc, 1) # Only the print statement should count + + def test_multiline_string_ends_in_middle_of_line(self): + content = '''print("Start") + """This is a multiline string ending in the middle of a line""" + """This is another multiline string.""" + print("End")''' + self.create_file("multiline_in_middle.py", content) + loc = count_loc(self.test_dir) + self.assertEqual(loc, 2) # Both print statements should count + + def test_line_starting_with_triple_quotes_not_ending(self): + content = '"""\nThis is a multiline string\n' + self.create_file("test_file_2.py", content) + path = os.path.join(self.test_dir, "test_file_2.py") + self.assertEqual(count_loc(path), 0) + # Because it's part of a multiline string + + def test_line_starting_and_ending_with_triple_quotes(self): + content = '"""This is a one-liner docstring."""\n' + 
self.create_file("test_file_3.py", content) + path = os.path.join(self.test_dir, "test_file_3.py") + self.assertEqual(count_loc(path), 0) + # This is still considered a comment/docstring + + def test_string_open_true_line_starting_with_triple_quotes(self): + content = '"""\nEnd of the multiline string."""\n' + self.create_file("test_file_4.py", content) + path = os.path.join(self.test_dir, "test_file_4.py") + self.assertEqual(count_loc(path), 0) + # Entire content is a multiline string/comment diff --git a/keras_core/utils/io_utils_test.py b/keras_core/utils/io_utils_test.py new file mode 100644 index 000000000..bf57f5ca2 --- /dev/null +++ b/keras_core/utils/io_utils_test.py @@ -0,0 +1,57 @@ +from unittest.mock import patch + +from keras_core.testing import test_case +from keras_core.utils import io_utils + + +class TestIoUtils(test_case.TestCase): + def test_enable_interactive_logging(self): + io_utils.enable_interactive_logging() + self.assertTrue(io_utils.is_interactive_logging_enabled()) + + def test_disable_interactive_logging(self): + io_utils.disable_interactive_logging() + self.assertFalse(io_utils.is_interactive_logging_enabled()) + + def test_set_logging_verbosity_valid(self): + valid_levels = ["FATAL", "ERROR", "WARNING", "INFO", "DEBUG"] + for level in valid_levels: + io_utils.set_logging_verbosity(level) + + def test_set_logging_verbosity_invalid(self): + with self.assertRaises(ValueError): + io_utils.set_logging_verbosity("INVALID") + + @patch("builtins.input", side_effect=["y"]) + def test_ask_to_proceed_with_overwrite_yes(self, _): + self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) + + @patch("builtins.input", side_effect=["n"]) + def test_ask_to_proceed_with_overwrite_no(self, _): + self.assertFalse(io_utils.ask_to_proceed_with_overwrite("test_path")) + + @patch("sys.stdout.write") + def test_print_msg_interactive_with_line_break(self, mock_write): + io_utils.enable_interactive_logging() + io_utils.print_msg("Hello", line_break=True) + mock_write.assert_called_once_with("Hello\n") + + @patch("sys.stdout.write") + def test_print_msg_interactive_without_line_break(self, mock_write): + io_utils.enable_interactive_logging() + io_utils.print_msg("Hello", line_break=False) + mock_write.assert_called_once_with("Hello") + + @patch("absl.logging.info") + def test_print_msg_non_interactive(self, mock_logging): + io_utils.disable_interactive_logging() + io_utils.print_msg("Hello") + mock_logging.assert_called_once_with("Hello") + + @patch("builtins.input", side_effect=["invalid", "invalid", "y"]) + def test_ask_to_proceed_with_overwrite_invalid_then_yes(self, _): + self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) + + @patch("builtins.input", side_effect=["invalid", "n"]) + def test_ask_to_proceed_with_overwrite_invalid_then_no(self, _): + self.assertFalse(io_utils.ask_to_proceed_with_overwrite("test_path")) diff --git a/keras_core/utils/naming_test.py b/keras_core/utils/naming_test.py index b6e9552a9..40e9fa8fe 100644 --- a/keras_core/utils/naming_test.py +++ b/keras_core/utils/naming_test.py @@ -3,6 +3,11 @@ class NamingUtilsTest(test_case.TestCase): + def test_uniquify_unique_name(self): + name = "the_unique_name" + unique_name = naming.uniquify(name) + self.assertEqual(unique_name, name) + def test_auto_name(self): self.assertEqual(naming.auto_name("unique_name"), "unique_name") self.assertEqual(naming.auto_name("unique_name"), "unique_name_1") @@ -12,3 +17,97 @@ def test_get_uid(self): self.assertEqual(naming.get_uid("very_unique_name"), 1) 
self.assertEqual(naming.get_uid("very_unique_name"), 2) self.assertEqual(naming.get_uid("very_unique_name"), 3) + + def test_uniquify_non_unique_name(self): + name = "non_unique_name" + naming.uniquify(name) + unique_name = naming.uniquify(name) + self.assertEqual(unique_name, name + "_1") + + def test_to_snake_case_snake_case_name(self): + name = "snake_case_name" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, name) + + def test_get_uid_existing_prefix(self): + prefix = "existing_prefix" + naming.get_uid(prefix) + uid = naming.get_uid(prefix) + self.assertEqual(uid, 2) + + def test_reset_uids(self): + naming.get_uid("unique_name") + naming.reset_uids() + uid = naming.get_uid("unique_name") + self.assertEqual(uid, 1) + + def test_get_object_name_no_name_attribute(self): + class ObjectWithoutName: + __name__ = "ObjectWithoutName" + + obj = ObjectWithoutName() + object_name = naming.get_object_name(obj) + self.assertEqual(object_name, "object_without_name") + + def test_get_object_name_no_name_or_class_attribute(self): + class ObjectWithoutNameOrClass: + pass + + obj = ObjectWithoutNameOrClass() + object_name = naming.get_object_name(obj) + self.assertEqual(object_name, "object_without_name_or_class") + + def test_uniquify_already_uniquified_name(self): + name = "unique_name" + unique_name = naming.uniquify(name) + new_unique_name = naming.uniquify(unique_name) + self.assertEqual(new_unique_name, unique_name) + + def test_to_snake_case_capital_after_any_character(self): + name = "myVariableNameHere" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, "my_variable_name_here") + + def test_to_snake_case_lower_before_upper(self): + name = "convertTHIS" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, "convert_this") + + def test_to_snake_case_already_snake_cased(self): + name = "already_snake_cased" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, name) + + def test_to_snake_case_no_changes(self): + name = "lowercase" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, name) + + def test_to_snake_case_single_uppercase_word(self): + name = "UPPERCASE" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, "uppercase") + + def test_get_object_name_for_keras_objects(self): + class MockKerasObject: + name = "mock_object" + + obj = MockKerasObject() + result = naming.get_object_name(obj) + self.assertEqual( + result, "mock_object", f"Expected 'mock_object' but got {result}" + ) + + # Test for function objects that have a `__name__` attribute. + def test_get_object_name_for_functions(self): + def mock_function(): + pass + + result = naming.get_object_name(mock_function) + # Assumes to_snake_case works correctly. 
+ expected_name = naming.to_snake_case(mock_function.__name__) + self.assertEqual( + result, + expected_name, + f"Expected '{expected_name}' but got {result}", + ) From 74a4e7fb23ee1553fd6808907a8f59f1386ef3a6 Mon Sep 17 00:00:00 2001 From: Muhammad Anas Raza <63569834+anas-rz@users.noreply.github.com> Date: Sat, 16 Sep 2023 01:14:58 -0400 Subject: [PATCH 07/20] syntax fix (#898) --- examples/keras_io/vision/cct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/keras_io/vision/cct.py b/examples/keras_io/vision/cct.py index c1b192e1f..59a32200b 100644 --- a/examples/keras_io/vision/cct.py +++ b/examples/keras_io/vision/cct.py @@ -256,7 +256,7 @@ def call(self, x, training=None): if training: keep_prob = 1 - self.drop_prob shape = (keras.ops.shape(x)[0],) + (1,) * (len(x.shape) - 1) - random_tensor = keep_prob + keras.random.random.uniform(shape, 0, 1) + random_tensor = keep_prob + keras.random.uniform(shape, 0, 1) random_tensor = keras.ops.floor(random_tensor) return (x / keep_prob) * random_tensor return x From c663efd6560a6cdc0a44ada5e588a7b4da0cb801 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Sun, 17 Sep 2023 00:26:32 +0800 Subject: [PATCH 08/20] Use `ops.rsqrt`, improve normalization layers and enable ops fusion in tflite (#892) * Add `rsqrt` to numpy backend * Improve normalization * Fix order bug * Update LayerNormalization * Improve unit test coverage * Use np native --- keras_core/backend/numpy/math.py | 4 ++ .../normalization/batch_normalization.py | 31 ++++++++----- .../normalization/group_normalization.py | 33 +++++--------- .../normalization/group_normalization_test.py | 45 +++++++++++++++++++ .../normalization/layer_normalization.py | 38 +++++++--------- .../normalization/layer_normalization_test.py | 10 +++++ .../spectral_normalization_test.py | 26 +++++++++++ .../normalization/unit_normalization.py | 2 +- .../normalization/unit_normalization_test.py | 10 +++++ keras_core/ops/math_test.py | 4 -- 10 files changed, 143 insertions(+), 60 deletions(-) diff --git a/keras_core/backend/numpy/math.py b/keras_core/backend/numpy/math.py index fb51f44e2..2e1c2cfcd 100644 --- a/keras_core/backend/numpy/math.py +++ b/keras_core/backend/numpy/math.py @@ -298,3 +298,7 @@ def istft( else: end = expected_output_len return x[..., start:end] + + +def rsqrt(x): + return 1.0 / np.sqrt(x) diff --git a/keras_core/layers/normalization/batch_normalization.py b/keras_core/layers/normalization/batch_normalization.py index 0812f6843..5bb3a4eb7 100644 --- a/keras_core/layers/normalization/batch_normalization.py +++ b/keras_core/layers/normalization/batch_normalization.py @@ -201,21 +201,21 @@ def call(self, inputs, training=None, mask=None): mean, variance = ops.moments( inputs, axes=self._reduction_axes, keepdims=True ) - outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon) - mean = ops.squeeze(mean, self._reduction_axes) - variance = ops.squeeze(variance, self._reduction_axes) moving_mean = ops.cast(self.moving_mean, inputs.dtype) moving_variance = ops.cast(self.moving_variance, inputs.dtype) self.moving_mean.assign( ops.cast( - moving_mean * self.momentum + mean * (1.0 - self.momentum), + moving_mean * self.momentum + + ops.squeeze(mean, self._reduction_axes) + * (1.0 - self.momentum), inputs.dtype, ) ) self.moving_variance.assign( ops.cast( moving_variance * self.momentum - + variance * (1.0 - self.momentum), + + ops.squeeze(variance, self._reduction_axes) + * (1.0 - self.momentum), inputs.dtype, ) ) @@ -224,17 
+224,24 @@ def call(self, inputs, training=None, mask=None): moving_variance = ops.cast(self.moving_variance, inputs.dtype) moving_mean = ops.reshape(moving_mean, broadcast_shape) moving_variance = ops.reshape(moving_variance, broadcast_shape) - outputs = (inputs - moving_mean) / ops.sqrt( - moving_variance + self.epsilon - ) + mean = moving_mean + variance = moving_variance + + inv = ops.rsqrt(variance + self.epsilon) if self.scale: gamma = ops.reshape(self.gamma, broadcast_shape) - gamma = ops.cast(gamma, outputs.dtype) - outputs = outputs * gamma + gamma = ops.cast(gamma, inputs.dtype) + inv = inv * gamma + + res = -mean * inv if self.center: beta = ops.reshape(self.beta, broadcast_shape) - beta = ops.cast(beta, outputs.dtype) - outputs = outputs + beta + beta = ops.cast(beta, inputs.dtype) + res = res + beta + + # Note: Folding BatchNormalization depends on the precise order of ops + # that are generated by the expression below + outputs = inputs * inv + res return ops.cast(outputs, input_dtype) def get_config(self): diff --git a/keras_core/layers/normalization/group_normalization.py b/keras_core/layers/normalization/group_normalization.py index 94b56b05f..7931f3fb9 100644 --- a/keras_core/layers/normalization/group_normalization.py +++ b/keras_core/layers/normalization/group_normalization.py @@ -171,37 +171,26 @@ def _apply_normalization(self, reshaped_inputs, input_shape): axis = -2 if self.axis == -1 else self.axis - 1 group_reduction_axes.pop(axis) + broadcast_shape = self._create_broadcast_shape(input_shape) mean, variance = ops.moments( reshaped_inputs, axes=group_reduction_axes, keepdims=True ) - gamma, beta = self._get_reshaped_weights(input_shape) # Compute the batch normalization. - inv = 1 / ops.sqrt(variance + self.epsilon) - - if gamma is not None: - inv = ops.multiply(inv, gamma) - - if beta is not None: - x = beta - ops.multiply(mean, inv) - else: - x = -ops.multiply(mean, inv) - - normalized_inputs = reshaped_inputs * ops.cast( - inv, reshaped_inputs.dtype - ) + ops.cast(x, reshaped_inputs.dtype) - normalized_inputs = ops.cast(normalized_inputs, reshaped_inputs.dtype) - return normalized_inputs - - def _get_reshaped_weights(self, input_shape): - broadcast_shape = self._create_broadcast_shape(input_shape) - gamma = None - beta = None + inv = ops.rsqrt(variance + self.epsilon) if self.scale: gamma = ops.reshape(self.gamma, broadcast_shape) + gamma = ops.cast(gamma, reshaped_inputs.dtype) + inv = inv * gamma + + res = -mean * inv if self.center: beta = ops.reshape(self.beta, broadcast_shape) - return gamma, beta + beta = ops.cast(beta, reshaped_inputs.dtype) + res = res + beta + + normalized_inputs = reshaped_inputs * inv + res + return normalized_inputs def _create_broadcast_shape(self, input_shape): broadcast_shape = [1] * len(input_shape) diff --git a/keras_core/layers/normalization/group_normalization_test.py b/keras_core/layers/normalization/group_normalization_test.py index 59a65edd9..1780f62a7 100644 --- a/keras_core/layers/normalization/group_normalization_test.py +++ b/keras_core/layers/normalization/group_normalization_test.py @@ -41,6 +41,51 @@ def test_groupnorm(self): supports_masking=True, ) + def test_undefined_dim_error(self): + inputs = layers.Input(shape=(2, 2, 2, None)) + layer = layers.GroupNormalization() + with self.assertRaisesRegex( + ValueError, + ( + "input tensor should have a defined dimension but the layer " + "received an input with shape" + ), + ): + _ = layer(inputs) + + def test_groups_bigger_than_dim_error(self): + inputs = 
np.ones(shape=(2, 2, 2, 4)) + layer = layers.GroupNormalization(groups=5) + with self.assertRaisesRegex( + ValueError, + "cannot be more than the number of channels", + ): + _ = layer(inputs) + + def test_groups_not_a_multiple_of_dim_error(self): + inputs = np.ones(shape=(2, 2, 2, 4)) + layer = layers.GroupNormalization(groups=3) + with self.assertRaisesRegex( + ValueError, + "must be a multiple of the number of channels", + ): + _ = layer(inputs) + + def test_groups_instance_norm(self): + # GroupNormalization with groups=-1 will become InstanceNormalization + instance_norm_layer_1 = layers.GroupNormalization( + groups=-1, axis=-1, scale=False, center=False + ) + instance_norm_layer_2 = layers.GroupNormalization( + groups=4, axis=-1, scale=False, center=False + ) + inputs = np.array([[[-1.0, 1.0, 0, 2.0], [1.0, 3.0, -4, -2.0]]]) + + outputs_1 = instance_norm_layer_1(inputs) + outputs_2 = instance_norm_layer_2(inputs) + + self.assertAllClose(outputs_1, outputs_2) + def test_correctness_instance_norm(self): instance_norm_layer = layers.GroupNormalization( groups=4, axis=-1, scale=False, center=False diff --git a/keras_core/layers/normalization/layer_normalization.py b/keras_core/layers/normalization/layer_normalization.py index 7d8381c27..f4cb231d4 100644 --- a/keras_core/layers/normalization/layer_normalization.py +++ b/keras_core/layers/normalization/layer_normalization.py @@ -206,33 +206,29 @@ def _broadcast(v): if self.rms_scaling: # Calculate outputs with only variance and gamma if rms scaling # is enabled - # Calculate the variance along last axis (layer activations). + # Calculate the variance along self.axis (layer activations). variance = ops.var(inputs, axis=self.axis, keepdims=True) - inv = 1 / ops.sqrt(variance + self.epsilon) - outputs = inputs * ops.cast(inv, inputs.dtype) * self.gamma + inv = ops.rsqrt(variance + self.epsilon) + + outputs = inputs * inv * ops.cast(self.gamma, inputs.dtype) else: - # Calculate the mean & variance along last axis (layer activations). + # Calculate the mean & variance along self.axis (layer activations). mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True) - inv = 1 / ops.sqrt(variance + self.epsilon) - scale, offset = _broadcast(self.gamma), _broadcast(self.beta) - if scale is not None: - scale = ops.cast(scale, inputs.dtype) - inv = inv * scale - x = -mean * inv - if offset is not None: - offset = ops.cast(offset, inputs.dtype) - x = offset + x - - outputs = inputs * ops.cast(inv, inputs.dtype) + ops.cast( - x, inputs.dtype - ) + gamma, beta = _broadcast(self.gamma), _broadcast(self.beta) + + inv = ops.rsqrt(variance + self.epsilon) + if gamma is not None: + gamma = ops.cast(gamma, inputs.dtype) + inv = inv * gamma - outputs = ops.cast(outputs, input_dtype) + res = -mean * inv + if beta is not None: + beta = ops.cast(beta, inputs.dtype) + res = res + beta - # If some components of the shape got lost due to adjustments, fix that. 
- outputs = ops.reshape(outputs, ops.shape(inputs)) + outputs = inputs * inv + res - return outputs + return ops.cast(outputs, input_dtype) def compute_output_shape(self, input_shape): return input_shape diff --git a/keras_core/layers/normalization/layer_normalization_test.py b/keras_core/layers/normalization/layer_normalization_test.py index c5d66f89f..94b039db0 100644 --- a/keras_core/layers/normalization/layer_normalization_test.py +++ b/keras_core/layers/normalization/layer_normalization_test.py @@ -83,6 +83,16 @@ def test_ln_basics(self): supports_masking=True, ) + def test_invalid_axis(self): + with self.assertRaisesRegex( + TypeError, + ( + "Expected an int or a list/tuple of ints for the argument " + "'axis'" + ), + ): + layers.LayerNormalization(axis={"axis": -1}) + def test_correctness(self): layer = layers.LayerNormalization(dtype="float32") layer.build(input_shape=(2, 2, 2)) diff --git a/keras_core/layers/normalization/spectral_normalization_test.py b/keras_core/layers/normalization/spectral_normalization_test.py index ae923e379..488beae24 100644 --- a/keras_core/layers/normalization/spectral_normalization_test.py +++ b/keras_core/layers/normalization/spectral_normalization_test.py @@ -20,6 +20,32 @@ def test_basic_spectralnorm(self): expected_num_losses=0, supports_masking=False, ) + self.run_layer_test( + layers.SpectralNormalization, + init_kwargs={"layer": layers.Embedding(10, 4)}, + input_data=np.random.randint(10, size=(10,)), + expected_output_shape=(10, 4), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=1, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + run_training_check=False, + ) + + def test_invalid_power_iterations(self): + with self.assertRaisesRegex( + ValueError, "`power_iterations` should be greater than zero." + ): + layers.SpectralNormalization(layers.Dense(2), power_iterations=0) + + def test_invalid_layer(self): + layer = layers.SpectralNormalization(layers.ReLU()) + inputs = np.ones(shape=(4, 2)) + with self.assertRaisesRegex( + ValueError, "object has no attribute 'kernel' nor 'embeddings'" + ): + layer(inputs) def test_apply_layer(self): images = np.ones((1, 2, 2, 1)) diff --git a/keras_core/layers/normalization/unit_normalization.py b/keras_core/layers/normalization/unit_normalization.py index 7b44290fc..33553ada0 100644 --- a/keras_core/layers/normalization/unit_normalization.py +++ b/keras_core/layers/normalization/unit_normalization.py @@ -45,7 +45,7 @@ def call(self, inputs): x = ops.cast(inputs, self.compute_dtype) square_sum = ops.sum(ops.square(x), axis=self.axis, keepdims=True) - x_inv_norm = 1 / ops.sqrt(ops.maximum(square_sum, 1e-12)) + x_inv_norm = ops.rsqrt(ops.maximum(square_sum, 1e-12)) return ops.multiply(x, x_inv_norm) def compute_output_shape(self, input_shape): diff --git a/keras_core/layers/normalization/unit_normalization_test.py b/keras_core/layers/normalization/unit_normalization_test.py index 8a3e6b027..94235e855 100644 --- a/keras_core/layers/normalization/unit_normalization_test.py +++ b/keras_core/layers/normalization/unit_normalization_test.py @@ -29,6 +29,16 @@ def test_un_basics(self): supports_masking=True, ) + def test_invalid_axis(self): + with self.assertRaisesRegex( + TypeError, + ( + "Invalid value for `axis` argument: expected an int or a " + "list/tuple of ints." 
+ ), + ): + layers.UnitNormalization(axis={"axis": -1}) + def test_correctness(self): layer = layers.UnitNormalization(axis=-1) inputs = np.random.normal(size=(2, 3)) diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index dddb67103..9fdac5106 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -831,10 +831,6 @@ def test_istft( ref = ref[..., truncated_len:-truncated_len] self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5) - @pytest.mark.skipif( - backend.backend() == "numpy", - reason="Numpy does not support rsqrt.", - ) def test_rsqrt(self): x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32") self.assertAllClose(kmath.rsqrt(x), 1 / np.sqrt(x)) From 6ad426a2a7c2ced67680a736ddb86812c54abaa1 Mon Sep 17 00:00:00 2001 From: "Douglas K. G. Araujo" Date: Sat, 16 Sep 2023 18:27:48 +0200 Subject: [PATCH 09/20] fixes 550 (#853) --- keras_core/trainers/data_adapters/torch_data_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_core/trainers/data_adapters/torch_data_adapter.py b/keras_core/trainers/data_adapters/torch_data_adapter.py index 65e99fbb5..85fa6835b 100644 --- a/keras_core/trainers/data_adapters/torch_data_adapter.py +++ b/keras_core/trainers/data_adapters/torch_data_adapter.py @@ -22,7 +22,7 @@ def __init__(self, dataloader): def get_numpy_iterator(self): for batch in self._dataloader: - yield tuple(tree.map_structure(lambda x: x.numpy(), batch)) + yield tuple(tree.map_structure(lambda x: x.cpu().numpy(), batch)) def get_torch_dataloader(self): return self._dataloader From 22354534a0adfdf336201c69b59caddeb076e3af Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Sat, 16 Sep 2023 09:42:46 -0700 Subject: [PATCH 10/20] Ignore legacy files in code coverage report --- codecov.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/codecov.yml b/codecov.yml index 63ad8099c..09e96b806 100644 --- a/codecov.yml +++ b/codecov.yml @@ -29,6 +29,8 @@ flag_management: - name: keras_core paths: - keras_core + ignore: + - keras_core/legacy - name: keras_core.applications paths: - keras_core/applications From 33dc7e43d2c2baa96aa166a76de3b7bdfa32d903 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Sat, 16 Sep 2023 10:19:36 -0700 Subject: [PATCH 11/20] Fix coverage file --- codecov.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codecov.yml b/codecov.yml index 09e96b806..566830771 100644 --- a/codecov.yml +++ b/codecov.yml @@ -9,6 +9,9 @@ coverage: default: target:auto +ignore: + - keras_core/legacy + comment: layout: "header, reach, diff, flags, files" behavior: default @@ -29,8 +32,6 @@ flag_management: - name: keras_core paths: - keras_core - ignore: - - keras_core/legacy - name: keras_core.applications paths: - keras_core/applications From 4e272441464cbdb623f6ff305039ae847ef4162e Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Sat, 16 Sep 2023 10:42:18 -0700 Subject: [PATCH 12/20] Codecov fiddling --- codecov.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codecov.yml b/codecov.yml index 566830771..f5b103593 100644 --- a/codecov.yml +++ b/codecov.yml @@ -10,7 +10,7 @@ coverage: target:auto ignore: - - keras_core/legacy + - keras-core/keras_core/legacy comment: layout: "header, reach, diff, flags, files" From aa394eed3f4a92d41cb5a0a3bd76c3dabf3ace1a Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Sat, 16 Sep 2023 10:48:51 -0700 Subject: [PATCH 13/20] Remove ignored path in codecov --- codecov.yml | 3 --- 1 file changed, 3 deletions(-) diff --git 
a/codecov.yml b/codecov.yml index f5b103593..63ad8099c 100644 --- a/codecov.yml +++ b/codecov.yml @@ -9,9 +9,6 @@ coverage: default: target:auto -ignore: - - keras-core/keras_core/legacy - comment: layout: "header, reach, diff, flags, files" behavior: default From 233e22194bfc2ab5f235209e5a9933c4273b2853 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Sat, 16 Sep 2023 10:58:09 -0700 Subject: [PATCH 14/20] Remove codecov badge --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 1dc5a65d7..7d8bb0e5c 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,4 @@ [![](https://github.com/keras-team/keras-core/workflows/Tests/badge.svg?branch=main)](https://github.com/keras-team/keras-core/actions?query=workflow%3ATests+branch%3Amain) -[![](https://codecov.io/gh/keras-team/keras-core/branch/main/graph/badge.svg)](https://codecov.io/gh/keras-team/keras-core) [![](https://badge.fury.io/py/keras-core.svg)](https://badge.fury.io/py/keras-core) # Keras Core: A new multi-backend Keras From f25230ce5bf3b00c782e6b4d199cf47541637d3a Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Sat, 16 Sep 2023 18:16:23 -0700 Subject: [PATCH 15/20] Improve export/loading error messages, export TFSMLayer and move TorchModuleWrapper --- keras_core/export/export_lib.py | 19 +++++++++++++------ keras_core/saving/saving_api.py | 10 ++++++++-- keras_core/utils/torch_utils.py | 4 ++-- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/keras_core/export/export_lib.py b/keras_core/export/export_lib.py index e7d163210..4fa251214 100644 --- a/keras_core/export/export_lib.py +++ b/keras_core/export/export_lib.py @@ -491,6 +491,7 @@ def _get_save_spec(model): return specs +@keras_core_export("keras_core.layers.TFSMLayer") class TFSMLayer(Layer): """Reload a Keras model/layer that was saved via SavedModel / ExportArchive. @@ -557,9 +558,13 @@ def __init__( self.call_endpoint_fn = self._reloaded_obj.signatures[call_endpoint] else: raise ValueError( - f"The endpoint '{call_endpoint}' is neither an " - "attribute of the reloaded SavedModel, nor an entry " - "in the `signatures` field of the reloaded SavedModel. " + f"The endpoint '{call_endpoint}' " + "is neither an attribute of the reloaded SavedModel, " + "nor an entry in the `signatures` field of " + "the reloaded SavedModel. Select another endpoint via " + "the `call_endpoint` argument. Available endpoints for " + "this SavedModel: " + f"{list(self._reloaded_obj.signatures.keys())}" ) # Resolving the training function. @@ -574,10 +579,12 @@ def __init__( ] else: raise ValueError( - f"The endpoint '{call_training_endpoint}' is " - "neither an attribute of the reloaded SavedModel, " + f"The endpoint '{call_training_endpoint}' " + "is neither an attribute of the reloaded SavedModel, " "nor an entry in the `signatures` field of " - "the reloaded SavedModel. " + "the reloaded SavedModel. Available endpoints for " + "this SavedModel: " + f"{list(self._reloaded_obj.signatures.keys())}" ) # Add trainable and non-trainable weights from the call_endpoint_fn. diff --git a/keras_core/saving/saving_api.py b/keras_core/saving/saving_api.py index 6a1e3569e..a4c910590 100644 --- a/keras_core/saving/saving_api.py +++ b/keras_core/saving/saving_api.py @@ -194,7 +194,13 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): "Keras Core only supports V3 `.keras` files and " "legacy H5 format files (`.h5` extension). " "Note that the legacy SavedModel format is not " - "supported in Keras Core." 
+ "supported by `load_model()` in Keras Core. In " + "order to reload a TensorFlow SavedModel as an " + "inference-only layer in Keras Core, use " + "`keras_core.layers.TFSMLayer(" + f"{filepath}, call_endpoint='serving_default')` " + "(note that your `call_endpoint` " + "might have a different name)." ) @@ -232,5 +238,5 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs): raise ValueError( f"File format not supported: filepath={filepath}. " "Keras Core only supports V3 `.keras` and `.weights.h5` " - "files." + "files, or legacy V1/V2 `.h5` files." ) diff --git a/keras_core/utils/torch_utils.py b/keras_core/utils/torch_utils.py index cb92fa9c8..3adcf7314 100644 --- a/keras_core/utils/torch_utils.py +++ b/keras_core/utils/torch_utils.py @@ -2,7 +2,7 @@ from keras_core.layers import Layer -@keras_core_export("keras_core.utils.TorchModuleWrapper") +@keras_core_export("keras_core.layers.TorchModuleWrapper") class TorchModuleWrapper(Layer): """Torch module wrapper layer. @@ -27,7 +27,7 @@ class TorchModuleWrapper(Layer): import torch.nn.functional as F import keras_core - from keras_core.backend.torch import TorchModuleWrapper + from keras_core.layers import TorchModuleWrapper class Classifier(keras_core.Model): def __init__(self, *args, **kwargs): From a465816da34c9ceebe3baa87f0e973b4bb751f86 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Sat, 16 Sep 2023 18:37:30 -0700 Subject: [PATCH 16/20] Convert deit example to KC+TF --- examples/keras_io/tensorflow/vision/deit.py | 625 ++++++++++++++++++++ 1 file changed, 625 insertions(+) create mode 100644 examples/keras_io/tensorflow/vision/deit.py diff --git a/examples/keras_io/tensorflow/vision/deit.py b/examples/keras_io/tensorflow/vision/deit.py new file mode 100644 index 000000000..09b988c6e --- /dev/null +++ b/examples/keras_io/tensorflow/vision/deit.py @@ -0,0 +1,625 @@ +""" +Title: Distilling Vision Transformers +Author: [Sayak Paul](https://twitter.com/RisingSayak) +Date created: 2022/04/05 +Last modified: 2023/09/16 +Description: Distillation of Vision Transformers through attention. +Accelerator: GPU +""" +""" +## Introduction + +In the original *Vision Transformers* (ViT) paper +([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)), +the authors concluded that to perform on par with Convolutional Neural Networks (CNNs), +ViTs need to be pre-trained on larger datasets. The larger the better. This is mainly +due to the lack of inductive biases in the ViT architecture -- unlike CNNs, +they don't have layers that exploit locality. In a follow-up paper +([Steiner et al.](https://arxiv.org/abs/2106.10270)), +the authors show that it is possible to substantially improve the performance of ViTs +with stronger regularization and longer training. + +Many groups have proposed different ways to deal with the problem +of data-intensiveness of ViT training. +One such way was shown in the *Data-efficient image Transformers*, +(DeiT) paper ([Touvron et al.](https://arxiv.org/abs/2012.12877)). The +authors introduced a distillation technique that is specific to transformer-based vision +models. DeiT is among the first works to show that it's possible to train ViTs well +without using larger datasets. + +In this example, we implement the distillation recipe proposed in DeiT. This +requires us to slightly tweak the original ViT architecture and write a custom training +loop to implement the distillation recipe. + +To comfortably navigate through this example, you'll be expected to know how a ViT and +knowledge distillation work. 
The following are good resources in case you needed a +refresher: + +* [ViT on keras.io](https://keras.io/examples/vision/image_classification_with_vision_transformer) +* [Knowledge distillation on keras.io](https://keras.io/examples/vision/knowledge_distillation/) +""" + +""" +## Imports +""" + +import os + +os.environ["KERAS_BACKEND"] = "tensorflow" + +import tensorflow as tf +import tensorflow_datasets as tfds +import keras_core as keras +from keras_core import layers + +tfds.disable_progress_bar() +keras.utils.set_random_seed(42) + +""" +## Constants +""" + +# Model +MODEL_TYPE = "deit_distilled_tiny_patch16_224" +RESOLUTION = 224 +PATCH_SIZE = 16 +NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2 +LAYER_NORM_EPS = 1e-6 +PROJECTION_DIM = 192 +NUM_HEADS = 3 +NUM_LAYERS = 12 +MLP_UNITS = [ + PROJECTION_DIM * 4, + PROJECTION_DIM, +] +DROPOUT_RATE = 0.0 +DROP_PATH_RATE = 0.1 + +# Training +NUM_EPOCHS = 20 +BASE_LR = 0.0005 +WEIGHT_DECAY = 0.0001 + +# Data +BATCH_SIZE = 256 +AUTO = tf.data.AUTOTUNE +NUM_CLASSES = 5 + +""" +You probably noticed that `DROPOUT_RATE` has been set 0.0. Dropout has been used +in the implementation to keep it complete. For smaller models (like the one used in +this example), you don't need it, but for bigger models, using dropout helps. +""" + +""" +## Load the `tf_flowers` dataset and prepare preprocessing utilities + +The authors use an array of different augmentation techniques, including MixUp +([Zhang et al.](https://arxiv.org/abs/1710.09412)), +RandAugment ([Cubuk et al.](https://arxiv.org/abs/1909.13719)), +and so on. However, to keep the example simple to work through, we'll discard them. +""" + + +def preprocess_dataset(is_training=True): + def fn(image, label): + if is_training: + # Resize to a bigger spatial resolution and take the random + # crops. + image = tf.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20)) + image = tf.image.random_crop(image, (RESOLUTION, RESOLUTION, 3)) + image = tf.image.random_flip_left_right(image) + else: + image = tf.image.resize(image, (RESOLUTION, RESOLUTION)) + label = tf.one_hot(label, depth=NUM_CLASSES) + return image, label + + return fn + + +def prepare_dataset(dataset, is_training=True): + if is_training: + dataset = dataset.shuffle(BATCH_SIZE * 10) + dataset = dataset.map( + preprocess_dataset(is_training), num_parallel_calls=AUTO + ) + return dataset.batch(BATCH_SIZE).prefetch(AUTO) + + +train_dataset, val_dataset = tfds.load( + "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True +) +num_train = train_dataset.cardinality() +num_val = val_dataset.cardinality() +print(f"Number of training examples: {num_train}") +print(f"Number of validation examples: {num_val}") + +train_dataset = prepare_dataset(train_dataset, is_training=True) +val_dataset = prepare_dataset(val_dataset, is_training=False) + +""" +## Implementing the DeiT variants of ViT + +Since DeiT is an extension of ViT it'd make sense to first implement ViT and then extend +it to support DeiT's components. + +First, we'll implement a layer for Stochastic Depth +([Huang et al.](https://arxiv.org/abs/1603.09382)) +which is used in DeiT for regularization. +""" + + +# Referred from: github.com:rwightman/pytorch-image-models. 
+class StochasticDepth(layers.Layer): + def __init__(self, drop_prop, **kwargs): + super().__init__(**kwargs) + self.drop_prob = drop_prop + + def call(self, x, training=True): + if training: + keep_prob = 1 - self.drop_prob + shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) + random_tensor = tf.floor(random_tensor) + return (x / keep_prob) * random_tensor + return x + + +""" +Now, we'll implement the MLP and Transformer blocks. +""" + + +def mlp(x, dropout_rate: float, hidden_units): + """FFN for a Transformer block.""" + # Iterate over the hidden units and + # add Dense => Dropout. + for idx, units in enumerate(hidden_units): + x = layers.Dense( + units, + activation=tf.nn.gelu if idx == 0 else None, + )(x) + x = layers.Dropout(dropout_rate)(x) + return x + + +def transformer(drop_prob: float, name: str) -> keras.Model: + """Transformer block with pre-norm.""" + num_patches = ( + NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1 + ) + encoded_patches = layers.Input((num_patches, PROJECTION_DIM)) + + # Layer normalization 1. + x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches) + + # Multi Head Self Attention layer 1. + attention_output = layers.MultiHeadAttention( + num_heads=NUM_HEADS, + key_dim=PROJECTION_DIM, + dropout=DROPOUT_RATE, + )(x1, x1) + attention_output = ( + StochasticDepth(drop_prob)(attention_output) + if drop_prob + else attention_output + ) + + # Skip connection 1. + x2 = layers.Add()([attention_output, encoded_patches]) + + # Layer normalization 2. + x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2) + + # MLP layer 1. + x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE) + x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4 + + # Skip connection 2. + outputs = layers.Add()([x2, x4]) + + return keras.Model(encoded_patches, outputs, name=name) + + +""" +We'll now implement a `ViTClassifier` class building on top of the components we just +developed. Here we'll be following the original pooling strategy used in the ViT paper -- +use a class token and use the feature representations corresponding to it for +classification. +""" + + +class ViTClassifier(keras.Model): + """Vision Transformer base class.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Patchify + linear projection + reshaping. + self.projection = keras.Sequential( + [ + layers.Conv2D( + filters=PROJECTION_DIM, + kernel_size=(PATCH_SIZE, PATCH_SIZE), + strides=(PATCH_SIZE, PATCH_SIZE), + padding="VALID", + name="conv_projection", + ), + layers.Reshape( + target_shape=(NUM_PATCHES, PROJECTION_DIM), + name="flatten_projection", + ), + ], + name="projection", + ) + + # Positional embedding. + init_shape = ( + 1, + NUM_PATCHES + 1, + PROJECTION_DIM, + ) + self.positional_embedding = tf.Variable( + tf.zeros(init_shape), name="pos_embedding" + ) + + # Transformer blocks. + dpr = [x for x in tf.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)] + self.transformer_blocks = [ + transformer(drop_prob=dpr[i], name=f"transformer_block_{i}") + for i in range(NUM_LAYERS) + ] + + # CLS token. + initial_value = tf.zeros((1, 1, PROJECTION_DIM)) + self.cls_token = tf.Variable( + initial_value=initial_value, trainable=True, name="cls" + ) + + # Other layers. 
+ self.dropout = layers.Dropout(DROPOUT_RATE) + self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS) + self.head = layers.Dense( + NUM_CLASSES, + name="classification_head", + ) + + def call(self, inputs): + n = tf.shape(inputs)[0] + + # Create patches and project the patches. + projected_patches = self.projection(inputs) + + # Append class token if needed. + cls_token = tf.tile(self.cls_token, (n, 1, 1)) + cls_token = tf.cast(cls_token, projected_patches.dtype) + projected_patches = tf.concat([cls_token, projected_patches], axis=1) + + # Add positional embeddings to the projected patches. + encoded_patches = ( + self.positional_embedding + projected_patches + ) # (B, number_patches, projection_dim) + encoded_patches = self.dropout(encoded_patches) + + # Iterate over the number of layers and stack up blocks of + # Transformer. + for transformer_module in self.transformer_blocks: + # Add a Transformer block. + encoded_patches = transformer_module(encoded_patches) + + # Final layer normalization. + representation = self.layer_norm(encoded_patches) + + # Pool representation. + encoded_patches = representation[:, 0] + + # Classification head. + output = self.head(encoded_patches) + return output + + +""" +This class can be used standalone as ViT and is end-to-end trainable. Just remove the +`distilled` phrase in `MODEL_TYPE` and it should work with `vit_tiny = ViTClassifier()`. +Let's now extend it to DeiT. The following figure presents the schematic of DeiT (taken +from the DeiT paper): + +![](https://i.imgur.com/5lmg2Xs.png) + +Apart from the class token, DeiT has another token for distillation. During distillation, +the logits corresponding to the class token are compared to the true labels, and the +logits corresponding to the distillation token are compared to the teacher's predictions. +""" + + +class ViTDistilled(ViTClassifier): + def __init__(self, regular_training=False, **kwargs): + super().__init__(**kwargs) + self.num_tokens = 2 + self.regular_training = regular_training + + # CLS and distillation tokens, positional embedding. + init_value = tf.zeros((1, 1, PROJECTION_DIM)) + self.dist_token = tf.Variable(init_value, name="dist_token") + self.positional_embedding = tf.Variable( + tf.zeros( + ( + 1, + NUM_PATCHES + self.num_tokens, + PROJECTION_DIM, + ) + ), + name="pos_embedding", + ) + + # Head layers. + self.head = layers.Dense( + NUM_CLASSES, + name="classification_head", + ) + self.head_dist = layers.Dense( + NUM_CLASSES, + name="distillation_head", + ) + + def call(self, inputs, training=False): + n = tf.shape(inputs)[0] + + # Create patches and project the patches. + projected_patches = self.projection(inputs) + + # Append the tokens. + cls_token = tf.tile(self.cls_token, (n, 1, 1)) + dist_token = tf.tile(self.dist_token, (n, 1, 1)) + cls_token = tf.cast(cls_token, projected_patches.dtype) + dist_token = tf.cast(dist_token, projected_patches.dtype) + projected_patches = tf.concat( + [cls_token, dist_token, projected_patches], axis=1 + ) + + # Add positional embeddings to the projected patches. + encoded_patches = ( + self.positional_embedding + projected_patches + ) # (B, number_patches, projection_dim) + encoded_patches = self.dropout(encoded_patches) + + # Iterate over the number of layers and stack up blocks of + # Transformer. + for transformer_module in self.transformer_blocks: + # Add a Transformer block. + encoded_patches = transformer_module(encoded_patches) + + # Final layer normalization. 
+ representation = self.layer_norm(encoded_patches) + + # Classification heads. + x, x_dist = ( + self.head(representation[:, 0]), + self.head_dist(representation[:, 1]), + ) + + if not training or self.regular_training: + # During standard train / finetune, inference average the classifier + # predictions. + return (x + x_dist) / 2 + + elif training: + # Only return separate classification predictions when training in distilled + # mode. + return x, x_dist + + +""" +Let's verify if the `ViTDistilled` class can be initialized and called as expected. +""" + +deit_tiny_distilled = ViTDistilled() + +dummy_inputs = tf.ones((2, 224, 224, 3)) +outputs = deit_tiny_distilled(dummy_inputs, training=False) +print(f"output_shape: {outputs.shape}") + +""" +## Implementing the trainer + +Unlike what happens in standard knowledge distillation +([Hinton et al.](https://arxiv.org/abs/1503.02531)), +where a temperature-scaled softmax is used as well as KL divergence, +DeiT authors use the following loss function: + +![](https://i.imgur.com/bXdxsBq.png) + + +Here, + +* CE is cross-entropy +* `psi` is the softmax function +* Z_s denotes student predictions +* y denotes true labels +* y_t denotes teacher predictions +""" + + +class DeiT(keras.Model): + # Reference: + # https://keras.io/examples/vision/knowledge_distillation/ + def __init__(self, student, teacher, **kwargs): + super().__init__(**kwargs) + self.student = student + self.teacher = teacher + self.student_loss_tracker = keras.metrics.Mean(name="student_loss") + self.distillation_loss_tracker = keras.metrics.Mean( + name="distillation_loss" + ) + self.accuracy = keras.metrics.CategoricalAccuracy(name="accuracy") + + @property + def metrics(self): + return [ + self.accuracy, + self.student_loss_tracker, + self.distillation_loss_tracker, + ] + + def compile( + self, + optimizer, + student_loss_fn, + distillation_loss_fn, + run_eagerly=False, + jit_compile=False, + ): + super().compile( + optimizer=optimizer, + run_eagerly=run_eagerly, + jit_compile=jit_compile, + ) + self.student_loss_fn = student_loss_fn + self.distillation_loss_fn = distillation_loss_fn + + def train_step(self, data): + # Unpack data. + x, y = data + + # Forward pass of teacher + teacher_predictions = self.teacher(x)["dense"] + teacher_predictions = tf.nn.softmax(teacher_predictions, axis=-1) + + with tf.GradientTape() as tape: + # Forward pass of student. + cls_predictions, dist_predictions = self.student( + x / 255.0, training=True + ) + + # Compute losses. + student_loss = self.student_loss_fn(y, cls_predictions) + distillation_loss = self.distillation_loss_fn( + teacher_predictions, dist_predictions + ) + loss = (student_loss + distillation_loss) / 2 + + # Compute gradients. + trainable_vars = self.student.trainable_variables + gradients = tape.gradient(loss, trainable_vars) + + # Update weights. + self.optimizer.apply_gradients(zip(gradients, trainable_vars)) + + # Update the metrics configured in `compile()`. + student_predictions = (cls_predictions + dist_predictions) / 2 + self.student_loss_tracker.update_state(student_loss) + self.distillation_loss_tracker.update_state(distillation_loss) + self.accuracy.update_state(y, student_predictions) + + # Return a dict of performance. + return {m.name: m.result() for m in self.metrics} + + def test_step(self, data): + # Unpack the data. + x, y = data + + # Compute predictions. + y_prediction = self.student(x / 255.0) + + # Calculate the loss. + student_loss = self.student_loss_fn(y, y_prediction) + + # Update the metrics. 
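+        # Only the student loss and accuracy are updated here; the
+        # distillation loss is computed during training only.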
+        self.student_loss_tracker.update_state(student_loss)
+        self.accuracy.update_state(y, y_prediction)
+
+        # Return a dict of performance.
+        results = {m.name: m.result() for m in self.metrics}
+        return results
+
+    def call(self, inputs):
+        return self.student(inputs / 255.0)
+
+
+"""
+## Load the teacher model
+
+This model is based on the BiT family of ResNets
+([Kolesnikov et al.](https://arxiv.org/abs/1912.11370))
+fine-tuned on the `tf_flowers` dataset. You can refer to
+[this notebook](https://github.com/sayakpaul/deit-tf/blob/main/notebooks/bit-teacher.ipynb)
+to learn how the training was performed. The teacher model has about 212 million parameters,
+which is about **40x more** than the student.
+"""
+
+"""shell
+wget -q https://github.com/sayakpaul/deit-tf/releases/download/v0.1.0/bit_teacher_flowers.zip
+unzip -q bit_teacher_flowers.zip
+"""
+
+bit_teacher_flowers = keras.layers.TFSMLayer(
+    filepath="bit_teacher_flowers",
+    call_endpoint="serving_default",
+)
+
+
+"""
+## Training through distillation
+"""
+
+deit_tiny = ViTDistilled()
+deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)
+
+lr_scaled = (BASE_LR / 512) * BATCH_SIZE
+deit_distiller.compile(
+    optimizer=keras.optimizers.AdamW(
+        weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled
+    ),
+    student_loss_fn=keras.losses.CategoricalCrossentropy(
+        from_logits=True, label_smoothing=0.1
+    ),
+    distillation_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
+)
+_ = deit_distiller.fit(
+    train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS
+)
+
+"""
+If we had trained the same model (the `ViTClassifier`) from scratch with the exact same
+hyperparameters, the model would have scored about 59% accuracy. You can adapt the following code
+to reproduce this result:
+
+```
+vit_tiny = ViTClassifier()
+
+inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
+x = keras.layers.Rescaling(scale=1./255)(inputs)
+outputs = vit_tiny(x)
+model = keras.Model(inputs, outputs)
+
+model.compile(...)
+model.fit(...)
+```
+"""
+
+"""
+## Notes
+
+* Through the use of distillation, we're effectively transferring the inductive biases of
+a CNN-based teacher model.
+* Interestingly enough, as shown in the paper, this distillation strategy works better with
+a CNN than with a Transformer as the teacher model.
+* The use of regularization to train DeiT models is very important.
+* ViT models are initialized with a combination of different initializers including
+truncated normal, random normal, Glorot uniform, etc. If you're looking for
+end-to-end reproduction of the original results, don't forget to initialize the ViTs well.
+* If you want to explore the pre-trained DeiT models in TensorFlow and Keras with code
+for fine-tuning, [check out these models on TF-Hub](https://tfhub.dev/sayakpaul/collections/deit/1).
+
+## Acknowledgements
+
+* Ross Wightman for keeping
+[`timm`](https://github.com/rwightman/pytorch-image-models)
+updated with readable implementations. I referred to the ViT and DeiT implementations
+a lot while implementing them in TensorFlow.
+* [Aritra Roy Gosthipaty](https://github.com/ariG23498),
+who implemented some portions of the `ViTClassifier` in another project.
+* [Google Developers Experts](https://developers.google.com/programs/experts/)
+program for supporting me with GCP credits, which were used to run experiments for this
+example.
+""" From 176d12060f4031584c7072b328f4759ba779af93 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com> Date: Mon, 18 Sep 2023 14:40:01 -0700 Subject: [PATCH 17/20] Make jax2tf import conditional (#911) --- keras_core/export/export_lib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_core/export/export_lib.py b/keras_core/export/export_lib.py index 4fa251214..5dae2602c 100644 --- a/keras_core/export/export_lib.py +++ b/keras_core/export/export_lib.py @@ -1,7 +1,5 @@ """Library for exporting inference-only Keras models/layers.""" -from jax.experimental import jax2tf - from keras_core import backend from keras_core.api_export import keras_core_export from keras_core.layers import Layer @@ -441,6 +439,8 @@ def _filter_and_track_resources(self): self._tf_trackable._misc_assets.append(trackable) def _convert_jax2tf_function(self, fn, input_signature): + from jax.experimental import jax2tf + shapes = [] for spec in input_signature: shapes.append(self._spec_to_poly_shape(spec)) From 9d39e9a9999a7ae6ad4c5469656c0db4961c792e Mon Sep 17 00:00:00 2001 From: Qianli Scott Zhu Date: Mon, 18 Sep 2023 14:40:25 -0700 Subject: [PATCH 18/20] Separate the metrics variables from non-trainable variables. (#910) * Seprate the metrics variables from non-trainable variables. * Fix unit test for jax trainer. --- keras_core/layers/layer.py | 20 ++++++++++++--- keras_core/layers/layer_test.py | 40 +++++++++++++++++++++++++++++ keras_core/trainers/trainer_test.py | 13 +++++----- 3 files changed, 63 insertions(+), 10 deletions(-) diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py index e556abedf..6d3c487b1 100644 --- a/keras_core/layers/layer.py +++ b/keras_core/layers/layer.py @@ -514,10 +514,13 @@ def trainable(self, value): @property def variables(self): - """List of all layer state, including metric variables and random seeds. + """List of all layer state, including random seeds. This extends `layer.weights` to include all state used by the layer - including state for metrics and `SeedGenerator`s. + including `SeedGenerator`s. + + Note that metrics variables are not included here, use + `metrics_variables` to visit all the metric variables. """ # Return all `Variables` associate with the layer including metrics # and random seeds. Also deduplicate them. 
@@ -527,8 +530,6 @@ def variables(self): if id(v) not in seen_ids: variables.append(v) seen_ids.add(id(v)) - for m in self._metrics: - variables.extend(m.variables) for sg in self._seed_generators: variables.append(sg.state) for layer in self._layers: @@ -602,6 +603,17 @@ def non_trainable_weights(self): return self.weights return [v for v in self.weights if not v.trainable] + @property + def metrics_variables(self): + """List of all metric variables.""" + vars = [] + for metric in self._metrics: + vars.extend(metric.variables) + for layer in self._layers: + for metric in layer._metrics: + vars.extend(metric.variables) + return vars + def get_weights(self): """Return the values of `layer.weights` as a list of NumPy arrays.""" return [v.numpy() for v in self.weights] diff --git a/keras_core/layers/layer_test.py b/keras_core/layers/layer_test.py index 1243a9f75..72429b1ef 100644 --- a/keras_core/layers/layer_test.py +++ b/keras_core/layers/layer_test.py @@ -3,6 +3,7 @@ from keras_core import backend from keras_core import layers +from keras_core import metrics from keras_core import models from keras_core import ops from keras_core import testing @@ -214,6 +215,45 @@ def call(self, x): self.assertLen(layer.inner_layer.weights, 8) self.assertLen(layer.weights, 8) + def test_metric_tracking(self): + class LayerWithMetric(layers.Layer): + def __init__(self, units): + super().__init__() + self.dense = layers.Dense(units) + self.metric = metrics.MeanSquaredError(name="my_metric") + + def build(self, input_shape): + self.dense.build(input_shape) + + def call(self, x): + return self.dense(x) + + class NestedLayerWithMetric(layers.Layer): + def __init__(self, units): + super().__init__() + self.layer_with_metric = LayerWithMetric(units) + self.metric = metrics.MeanSquaredError(name="my_metric") + + def build(self, input_shape): + self.layer_with_metric.build(input_shape) + + def call(self, x): + return self.layer_with_metric(x) + + layer = LayerWithMetric(3) + layer.build((1, 3)) + + self.assertLen(layer.metrics_variables, 2) + self.assertLen(layer.trainable_variables, 2) + self.assertLen(layer.non_trainable_variables, 0) + + layer = NestedLayerWithMetric(3) + layer.build((1, 3)) + + self.assertLen(layer.metrics_variables, 4) + self.assertLen(layer.trainable_variables, 2) + self.assertLen(layer.non_trainable_variables, 0) + def test_build_on_call(self): class LayerWithUnbuiltState(layers.Layer): def __init__(self, units): diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py index 5028b7b28..6d2a2de43 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -29,7 +29,7 @@ # A model is just a layer mixed in with a Trainer. 
-class ExampleModel(layers.Dense, Trainer): +class ExampleModel(Trainer, layers.Dense): def __init__(self, units): layers.Dense.__init__( self, @@ -40,7 +40,7 @@ def __init__(self, units): Trainer.__init__(self) -class StructModel(layers.Layer, Trainer): +class StructModel(Trainer, layers.Layer): def __init__(self, units): layers.Layer.__init__(self) Trainer.__init__(self) @@ -62,7 +62,7 @@ def call(self, x): } -class ListModel(layers.Layer, Trainer): +class ListModel(Trainer, layers.Layer): def __init__(self, units): layers.Layer.__init__(self) Trainer.__init__(self) @@ -82,7 +82,7 @@ def call(self, x): return self.dense_1(x[0]) + self.dense_2(x[1]) -class TrainingTestingLayer(layers.Layer, Trainer): +class TrainingTestingLayer(Trainer, layers.Layer): def __init__(self, **kwargs): layers.Layer.__init__(self, **kwargs) Trainer.__init__(self) @@ -96,7 +96,7 @@ def call(self, x, training=False): class TestTrainer(testing.TestCase, parameterized.TestCase): @pytest.mark.requires_trainable_backend def test_metric_tracking(self): - class ModelWithMetric(layers.Dense, Trainer): + class ModelWithMetric(Trainer, layers.Dense): def __init__(self, units): layers.Dense.__init__( self, @@ -132,6 +132,7 @@ def __init__(self, units): # And those weights are tracked at the model level self.assertEqual(len(model.metrics_variables), 6) + self.assertLen(model.non_trainable_variables, 0) # Models with only weighted_metrics should have the same 3 metrics model_weighted = ModelWithMetric(units=3) @@ -361,7 +362,7 @@ def test_adds_loss_scaling_optimizer(self): reason="half precision unsupported on torch CPU.", ) def test_loss_scaling_prevents_underflow(self): - class DeepModel(layers.Layer, Trainer): + class DeepModel(Trainer, layers.Layer): def __init__(self): layers.Layer.__init__(self, dtype="mixed_float16") Trainer.__init__(self) From 2dfe47537ed23086c8d017e60800acf6bcc1ab07 Mon Sep 17 00:00:00 2001 From: Qianli Scott Zhu Date: Mon, 18 Sep 2023 15:06:48 -0700 Subject: [PATCH 19/20] Update jax trainer function to save memory buffer. (#897) * Update jax trainer function to save memory buffer. * Address format issu. --- keras_core/backend/jax/trainer.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/keras_core/backend/jax/trainer.py b/keras_core/backend/jax/trainer.py index 77c59eb36..7b605f40d 100644 --- a/keras_core/backend/jax/trainer.py +++ b/keras_core/backend/jax/trainer.py @@ -1,3 +1,5 @@ +from functools import partial + import jax import numpy as np import tree @@ -237,8 +239,11 @@ def multi_train_steps(state, data): train_step = one_train_step if not self.run_eagerly and self.jit_compile: - - @jax.jit + # Note that we mark the state and data to be donated to jax, + # so that jax will reuse the memory buffer for outputs. + # This will reduce the memory usage of the training function by + # half. + @partial(jax.jit, donate_argnames="state") def compiled_train_step(state, data): return train_step(state, data) @@ -266,8 +271,11 @@ def multi_test_steps(state, data): test_step = one_test_step if not self.run_eagerly and self.jit_compile: - - @jax.jit + # Note that we mark the state and data to be donated to jax, + # so that jax will reuse the memory buffer for outputs. + # This will reduce the memory usage of the training function by + # half. 
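+            # Donated buffers are invalidated once the jitted call returns,
+            # so the previous `state` must not be reused by the caller.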
+ @partial(jax.jit, donate_argnames="state") def compiled_test_step(state, data): return test_step(state, data) @@ -578,15 +586,18 @@ def evaluate( ) data = self._distribute_data(data) logs, state = self.test_function(state, data) - # Note that trainable variables are not returned since they're - # immutable here. - _, non_trainable_variables, metrics_variables = state + ( + trainable_variables, + non_trainable_variables, + metrics_variables, + ) = state # Setting _jax_state enables callbacks to force a state sync # if they need to. self._jax_state = { # I wouldn't recommend modifying non-trainable model state # during evaluate(), but it's allowed. + "trainable_variables": trainable_variables, "non_trainable_variables": non_trainable_variables, "metrics_variables": metrics_variables, } @@ -764,8 +775,9 @@ def test_on_batch( logs, state = self.test_function(state, [data]) # State sync - _, non_trainable_variables, metrics_variables = state + trainable_variables, non_trainable_variables, metrics_variables = state self._jax_state = { + "trainable_variables": trainable_variables, "non_trainable_variables": non_trainable_variables, "metrics_variables": metrics_variables, } From f38e2e6014aba7bae8511de3865fea216ec112ed Mon Sep 17 00:00:00 2001 From: kiukchung <43595115+kiukchung@users.noreply.github.com> Date: Mon, 18 Sep 2023 22:25:56 +0000 Subject: [PATCH 20/20] [GHA] fix several codecov issues (#912) * [ci/cd] exclude keras_core/legacy from coverage.run|report, add backend to codecov flags, move project config to pyproject.toml * GHA codecov upload, fail ci if error --- .github/workflows/actions.yml | 6 ++++-- pyproject.toml | 36 +++++++++++++++++++++++++++++++++++ setup.cfg | 31 ------------------------------ 3 files changed, 40 insertions(+), 33 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index e08885550..185c0aaee 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -59,8 +59,9 @@ jobs: uses: codecov/codecov-action@v3 with: env_vars: PYTHON,KERAS_BACKEND - flags: keras_core.applications + flags: keras_core.applications,keras_core.applications-${{ matrix.backend }} files: apps-coverage.xml + fail_ci_if_error: true - name: Test with pytest run: | pytest keras_core --ignore keras_core/applications --cov=keras_core @@ -69,8 +70,9 @@ jobs: uses: codecov/codecov-action@v3 with: env_vars: PYTHON,KERAS_BACKEND - flags: keras_core + flags: keras_core,keras_core-${{ matrix.backend }} files: core-coverage.xml + fail_ci_if_error: true format: name: Check the code format diff --git a/pyproject.toml b/pyproject.toml index e1350f59b..e016bb363 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,3 +17,39 @@ known_first_party = ["keras_core", "tests"] default_section = "THIRDPARTY" line_length = 80 extend_skip_glob=["examples/*", "guides/*"] + +[tool.pytest.ini_options] +filterwarnings = [ + "error", + "ignore::DeprecationWarning", + "ignore::ImportWarning", + "ignore::RuntimeWarning", + "ignore::PendingDeprecationWarning", + "ignore::FutureWarning", + "ignore::UserWarning", + # Ignore a spurious warning on tf-nightly related to save model changes. 
+ "ignore:Custom mask layers require a config", +] +addopts = "-vv" + +# Do not run tests in the `build` folders +norecursedirs = ["build"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "@abstract", + "raise NotImplementedError", +] +omit = [ + "*/*_test.py", + "keras_core/legacy/*", +] + +[tool.coverage.run] +branch = true +omit = [ + "*/*_test.py", + "keras_core/legacy/*", +] + diff --git a/setup.cfg b/setup.cfg index 566293607..df2bc8fe3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,34 +1,3 @@ -[tool:pytest] -filterwarnings = - error - ignore::DeprecationWarning - ignore::ImportWarning - ignore::RuntimeWarning - ignore::PendingDeprecationWarning - ignore::FutureWarning - ignore::UserWarning - # Ignore a spurious warning on tf-nightly related to save model changes. - ignore:Custom mask layers require a config - -addopts=-vv - -# Do not run tests in the `build` folders -norecursedirs = build - -[coverage:report] -exclude_lines = - pragma: no cover - @abstract - raise NotImplementedError -omit = - */*_test.py - -[coverage:run] -omit = - */*_test.py - -branch = True - [flake8] ignore = # Conflicts with black