[GLCC]Part-4: fix testcases for dy2st PyLayer (PaddlePaddle#57633)
* add testcase load_inference_model

* polish code for last pr

* delete annots

* fix docs of static_pylayer
MarioLulab authored Sep 25, 2023
1 parent 2998ba2 commit 8f5879f
Showing 4 changed files with 96 additions and 17 deletions.
12 changes: 6 additions & 6 deletions python/paddle/static/nn/static_pylayer.py
@@ -242,14 +242,14 @@ def static_pylayer(forward_fn, inputs, backward_fn=None, name=None):
the logic of ``forward_fn`` and ``backward_fn``, with the operator ``pylayer``
holding information about the two blocks.
-``forward_fn`` and ``backward_fn`` should return a nest structure of tensors.
-A nest structure of tensors in PaddlePaddle is tensor(s), or tuple of tensors, or
-list of tensors.
+``forward_fn`` and ``backward_fn`` should return a nest structure of Variables.
+A nest structure of Variables in PaddlePaddle is Variable(s), or tuple of Variables, or
+list of Variables.
Note:
-1. If ``backward_fn`` is not None, user needs to keep the number of inputs to ``forward_fn`` the same as the
-number of outputs to ``backward_fn``, and the number of outputs to ``forward_fn``
-the same as the number of inputs to ``backward_fn``.
+1. If ``backward_fn`` is not None, user needs to keep the number of `Variable` inputs to ``forward_fn`` the same as the
+number of `Variable` outputs to ``backward_fn``, and the number of `Variable` outputs to ``forward_fn``
+the same as the number of `Variable` inputs to ``backward_fn``.
2. If ``backward_fn`` is None, ``stop_gradient`` attr of all Variable in ``inputs`` is expected to be True.
Otherwise it might get unexpected results in backward propagation.
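
To make the arity rule in the Note concrete, here is a minimal hedged sketch (names and shapes are illustrative, not taken from this diff): ``forward_fn`` takes one Variable and returns one, so ``backward_fn`` must take exactly one Variable (the gradient of the output) and return exactly one (the gradient of the input).

import paddle

paddle.enable_static()

def forward_fn(x):
    return paddle.tanh(x)

def backward_fn(dy):
    # any rule with matching arity is accepted; this one just scales the grad
    return 3 * dy

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    data = paddle.static.data(name="X", shape=[1, 5], dtype="float32")
    data.stop_gradient = False
    out = paddle.static.nn.static_pylayer(forward_fn, [data], backward_fn=backward_fn)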
2 changes: 1 addition & 1 deletion test/dygraph_to_static/test_pylayer.py
@@ -243,7 +243,7 @@ def __init__(self, in_size, out_size):
def forward(self, x1, x2):
y1 = self.linear1(x1)
y2 = self.linear1(x2)
-out = cus_tanh_2.apply(y1, func1=paddle.tanh)
+out = cus_tanh_2.apply(y1, paddle.tanh)
out = out + y2
out = paddle.mean(out)
return out
91 changes: 85 additions & 6 deletions test/dygraph_to_static/test_save_inference_model.py
@@ -21,6 +21,7 @@

import paddle
from paddle import base
+from paddle.autograd import PyLayer
from paddle.jit.api import to_static
from paddle.jit.dy2static.partial_program import partial_program_from
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
@@ -45,6 +46,33 @@ def forward(self, x):
return out, y


class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
y = paddle.tanh(x)
ctx.save_for_backward(y)
return y

@staticmethod
def backward(ctx, dy):
(y,) = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad


class SimplePyLayerNet(paddle.nn.Layer):
def __init__(self, fc_size):
super().__init__()
self._linear = paddle.nn.Linear(fc_size, fc_size)

@to_static
def forward(self, x):
y = self._linear(x)
out = cus_tanh.apply(y)
loss = paddle.mean(out)
return loss, out
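
As a quick orientation, a hedged sketch (not part of this commit): ``cus_tanh.backward`` applies d tanh(x)/dx = 1 - tanh(x)^2 to the incoming gradient, reading the forward output saved with ``save_for_backward``, and ``SimplePyLayerNet`` exercises the PyLayer under ``@to_static``. In eager (dygraph) mode the gradient can be sanity-checked directly:

x = paddle.randn([3, 4])
x.stop_gradient = False
y = cus_tanh.apply(x)
y.sum().backward()
# x.grad should match 1 - paddle.tanh(x) ** 2 elementwise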


class TestDyToStaticSaveInferenceModel(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
@@ -94,8 +122,52 @@ def test_save_inference_model(self):
layer, [x_data], dygraph_out.numpy(), feed=[x]
)

@ast_only_test
def test_save_pylayer_model(self):
fc_size = 20
x_data = np.random.random((fc_size, fc_size)).astype('float32')
paddle.base.framework._set_expected_place(place)

base.default_startup_program().random_seed = SEED
base.default_main_program().random_seed = SEED
paddle.disable_static()
x = base.dygraph.to_variable(x_data)
layer = SimplePyLayerNet(fc_size)
sgd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=layer.parameters()
)

for _ in range(5):
loss, pred = layer(x)
loss.backward()
sgd.minimize(loss)
layer.clear_gradients()

infer_model_prefix = os.path.join(
self.temp_dir.name, "test_dy2stat_inference_in_guard/model_pylayer"
)
paddle.jit.save(
layer=layer,
path=infer_model_prefix,
input_spec=[x],
output_spec=[pred],
)
# Check the correctness of the inference
loss_out, _ = layer(x)

loss_out_numpy = float(loss_out)
self.check_save_inference_model(
layer, [x_data], loss_out_numpy, enable_new_ir=False
)
self.check_save_inference_model(
layer, [x_data], loss_out_numpy, fetch=[loss], enable_new_ir=False
)
self.check_save_inference_model(
layer, [x_data], loss_out_numpy, feed=[x], enable_new_ir=False
)
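
Note the ``paddle.jit.save`` call above: ``input_spec=[x]`` pins the feed signature of the exported program, while ``output_spec=[pred]`` keeps only ``pred`` as a fetch target, so the ``loss`` output is pruned from the saved inference graph.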

def check_save_inference_model(
-self, model, inputs, gt_out, feed=None, fetch=None
+self, model, inputs, gt_out, feed=None, fetch=None, enable_new_ir=True
):
expected_persistable_vars = {p.name for p in model.parameters()}

@@ -113,13 +185,20 @@ def check_save_inference_model(
input_spec=feed if feed else None,
output_spec=fetch if fetch else None,
)
-# Check the correctness of the inference
-infer_out = self.load_and_run_inference(
-infer_model_dir, model_filename, params_filename, inputs
-)
+if enable_new_ir:
+wrapped_load_and_run_inference = test_and_compare_with_new_ir(True)(
+self.load_and_run_inference
+)
+infer_out = wrapped_load_and_run_inference(
+infer_model_dir, model_filename, params_filename, inputs
+)
+else:
+infer_out = self.load_and_run_inference(
+infer_model_dir, model_filename, params_filename, inputs
+)
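
Wrapping at call time, instead of keeping the unconditional ``@test_and_compare_with_new_ir(True)`` decorator (its removal appears below), lets each caller opt out; the PyLayer checks above pass ``enable_new_ir=False``, presumably because the exported PyLayer program is not yet covered by the new-IR comparison path.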

np.testing.assert_allclose(gt_out, infer_out, rtol=1e-05)

-@test_and_compare_with_new_ir(True)
def load_and_run_inference(
self, model_path, model_filename, params_filename, inputs
):
8 changes: 4 additions & 4 deletions test/legacy_test/test_prune.py
@@ -211,7 +211,7 @@ def _mock_guard(mock):
base.Executor._prune_program = original


-def net1():
+def create_net1():
x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32')
x.desc.set_need_check_feed(False)
label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
@@ -238,7 +238,7 @@ def net1():
return x, y, label, loss1, loss2, w_param_attrs


-def net2():
+def create_net2():
x1 = paddle.static.data(name='x1', shape=[-1, 2], dtype='float32')
x1.desc.set_need_check_feed(False)
x2 = paddle.static.data(name='x2', shape=[-1, 2], dtype='float32')
@@ -286,8 +286,8 @@ def net2():

class TestExecutorRunAutoPrune(unittest.TestCase):
def setUp(self):
-self.net1 = net1
-self.net2 = net2
+self.net1 = create_net1
+self.net2 = create_net2

def test_not_prune(self):
"""
