From 0b2dae0053605c6e76ac83b17cc1f36211248146 Mon Sep 17 00:00:00 2001 From: dogukankeceli <73390037+dogukankeceli@users.noreply.github.com> Date: Sun, 25 Oct 2020 03:39:46 +0300 Subject: [PATCH] =?UTF-8?q?Create=20Derin=20=C3=B6=C4=9Frenme?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- "Derin \303\266\304\237renme" | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 "Derin \303\266\304\237renme" diff --git "a/Derin \303\266\304\237renme" "b/Derin \303\266\304\237renme" new file mode 100644 index 0000000..268f426 --- /dev/null +++ "b/Derin \303\266\304\237renme" @@ -0,0 +1,24 @@ +import my_function +@pytest.mark.unit +@pytest.mark.parametrize('use_tf_function', [True, False]) +def test_my_function(use_tf_function): + + def test(arg1, arg2): + # REQUEST: Check to make sure this is not being retraced many times. + return my_function(arg1, arg2) + + # Sometimes test in Eager mode for debugging, sometimes test in graph mode. + test_func = test + if use_tf_function: + test_func = tf.function(test_func) + + #####Test 1 + arg1, arg2 = #some setup stuff + test_func(arg1, arg2) # create the graph (tracing happens). + results = test_func(arg1, arg2) # hopefully tracing does not happen a second time. + assert results + + #####Test 2 + arg4, arg3 = #some setup stuff + results = test_func(arg3, arg4) # hopefully tracing does not happen again (if the inputs are tensors and the shapes do not change) + assert results