From 5c3b7edb885693a11d6335d70c43dbc91e2f334a Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 15 Feb 2022 11:05:59 +0100 Subject: [PATCH] test_nn_transformer_search Also related to search flag: #18 --- tests/test_nn_transformer.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 tests/test_nn_transformer.py diff --git a/tests/test_nn_transformer.py b/tests/test_nn_transformer.py new file mode 100644 index 00000000..09687f36 --- /dev/null +++ b/tests/test_nn_transformer.py @@ -0,0 +1,28 @@ +""" +Test nn.transformer. +""" + +from __future__ import annotations + +from . import _setup_test_env # noqa +from .returnn_helpers import dummy_run_net, config_net_dict_via_serialized +import typing + +if typing.TYPE_CHECKING: + from .. import nn +else: + from returnn_common import nn # noqa + + +def test_nn_transformer_search(): + with nn.NameCtx.new_root() as name_ctx: + time_dim = nn.SpatialDim("time") + input_dim = nn.FeatureDim("input", 4) + data = nn.get_extern_data(nn.Data("data", dim_tags=[nn.batch_dim, time_dim, input_dim])) + transformer = nn.Transformer() + out, _ = transformer(data, source_spatial_axis=time_dim, search=True, beam_size=3, eos_symbol=0, name=name_ctx) + out.mark_as_default_output() + + config_code = name_ctx.get_returnn_config_serialized() + config, net_dict = config_net_dict_via_serialized(config_code) + dummy_run_net(config)