Skip to content

Commit

Permalink
test_nn_transformer_search
Browse files Browse the repository at this point in the history
Also related to search flag: #18
  • Loading branch information
albertz committed Feb 16, 2022
1 parent 95c201d commit 5c3b7ed
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/test_nn_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Test nn.transformer.
"""

from __future__ import annotations

from . import _setup_test_env # noqa
from .returnn_helpers import dummy_run_net, config_net_dict_via_serialized
import typing

if typing.TYPE_CHECKING:
from .. import nn
else:
from returnn_common import nn # noqa


def test_nn_transformer_search():
with nn.NameCtx.new_root() as name_ctx:
time_dim = nn.SpatialDim("time")
input_dim = nn.FeatureDim("input", 4)
data = nn.get_extern_data(nn.Data("data", dim_tags=[nn.batch_dim, time_dim, input_dim]))
transformer = nn.Transformer()
out, _ = transformer(data, source_spatial_axis=time_dim, search=True, beam_size=3, eos_symbol=0, name=name_ctx)
out.mark_as_default_output()

config_code = name_ctx.get_returnn_config_serialized()
config, net_dict = config_net_dict_via_serialized(config_code)
dummy_run_net(config)

0 comments on commit 5c3b7ed

Please sign in to comment.