How to define whether search (or train flag) is enabled? #18

Open
albertz opened this issue Aug 6, 2021 · 20 comments

albertz (Member) commented Aug 6, 2021

How to define whether search is enabled? Just not do any special handling on this level at all and leave it to the current RETURNN behavior? Or make it more explicit?

From what I can see, the RETURNN handling should be enough. Assuming this is used with automation, changes in the config can be done elsewhere, I think.

Note that this is not just about what is enough to be able to define all networks. Or maybe I'm not sure how you mean it.

It's about being straightforward and clear, i.e. at no time should it be unclear to the user when search is used.

We do not have to follow the behavior of RETURNN exactly. There are also multiple ways in RETURNN; we can restrict it to one clean way. We can also change it, or introduce a simpler variant here.

I tend towards making it explicit, but I'm not sure.

PyTorch also has a similar concept for the train flag (which we also have in RETURNN), i.e. some PyTorch modules behave differently depending on whether they are in train or eval mode (e.g. Dropout). We have exactly the same in RETURNN. And search is a flag just like train.

The difference is how these flags are set:

  • In RETURNN, this is all global, and for the search flag there are some additional (maybe unintuitive) ways to overwrite it. The flags are implied automatically in RETURNN, depending e.g. on the task, and the user does not have much control over them. It is quite hidden.

  • In PyTorch, there are no implicit, automatically implied global flags. Every module has its own flag, and it is set explicitly (and easily recursively for all submodules). Every module initially has the train flag set, and you can disable it explicitly. So it is always clear to the user how the flags are set, because the user sets them; there is no automatic behavior. The user explicitly writes model.train() or model.eval().

Maybe again, here in returnn-common, we can follow the PyTorch style a bit for this, and also copy it for the search flag? Not sure...

Originally posted by @albertz in #16 (comment)
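For illustration, a minimal plain-PyTorch sketch of the explicit flag handling described above (standard torch.nn usage, not returnn-common code):

import torch

drop = torch.nn.Dropout(p=0.5)  # behaves differently depending on the training flag
x = torch.ones(4)

drop.train()    # the default: training=True, elements are randomly zeroed and rescaled
print(drop(x))  # e.g. tensor([2., 0., 2., 0.])
drop.eval()     # training=False: dropout becomes the identity
print(drop(x))  # tensor([1., 1., 1., 1.])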

Atticus1806 (Contributor) commented:

> Maybe again, here in returnn-common, we can follow the PyTorch style a bit for this, and also copy it for the search flag? Not sure...

If I understand this correctly, each module has an internal flag which defines whether this (and only this) module is in train or search mode. I like this more than a global flag. It makes checking within the module easy, to apply different versions of the network. In my Tacotron2 version right now I am passing a flag when calling the root module, but this is quite ugly, and I would prefer a way that is actually part of returnn_common.

That being said, here are some questions that come to mind right now:

  • Is it only a binary flag? RETURNN has more flags than just search and train, but do we need to differentiate between them here? Without a deeper understanding of how modes like forward work in RETURNN, I would tend to say yes, but depending on the internal handling maybe no.
  • What functions would the Module class need to implement? get, set, set_recursively (for all Modules used within this one), maybe also something like a lock preventing external modification (via e.g. set_recursively)? See the sketch after this list.
  • Would this also give us the chance to change the number of inputs for the forward function? This might change in some cases when doing search (some information that is given as data during training has to be generated by the model itself during search; I am sure there are examples).
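A purely hypothetical sketch of the kind of per-module flag API asked about above. None of these names exist in returnn_common; they just illustrate get/set/set_recursively plus an optional lock:

class Module:
  def __init__(self):
    self.search = False        # per-module flag, analogous to PyTorch's `training`
    self._flag_locked = False  # optional lock against external modification

  def children(self):
    """Yield direct child modules stored as attributes of this module."""
    for value in vars(self).values():
      if isinstance(value, Module):
        yield value

  def set_search(self, enabled: bool = True, recursive: bool = True):
    """Set the search flag, optionally recursively for all child modules."""
    if not self._flag_locked:
      self.search = enabled
    if recursive:
      for child in self.children():
        child.set_search(enabled, recursive=True)
    return self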

albertz (Member, Author) commented Oct 21, 2021

On PyTorch: there is a training: bool flag on each Module, which is True by default. There is the Module.train(mode=True) function, which sets self.training = mode and recursively also calls module.train(mode) for all child modules. There is the Module.eval() function, which just calls self.train(False). As far as I know, the only PyTorch base modules which make use of this flag are Dropout and BatchNorm, but other third-party modules might make use of it as well. It is standard that you call model.eval() first when you do evaluation.

So, that's PyTorch.
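A minimal sketch of that mechanism (plain PyTorch, only illustrating the per-module flag and its recursive setting described above):

import torch
from torch import nn

class MyModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(10, 10)
    self.dropout = nn.Dropout(p=0.5)  # reads self.training internally

  def forward(self, x):
    x = self.linear(x)
    if self.training:                     # custom train-only behavior
      x = x + 0.01 * torch.randn_like(x)  # e.g. add noise only during training
    return self.dropout(x)

model = MyModel()
print(model.training, model.dropout.training)  # True True (the default)
model.eval()                                   # same as model.train(False), applied recursively
print(model.training, model.dropout.training)  # False False
y = model(torch.randn(2, 10))                  # dropout is now the identity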

RETURNN has more flags, currently: train_flag: Union[bool, tf.Tensor], eval_flag: bool, search_flag: bool.

  • We do not necessarily have train_flag == not eval_flag.
  • eval_flag means that RETURNN calculates the losses (does evaluation).
  • train_flag has influence on the behavior of dropout and maybe other layers (just like in PyTorch).
  • train_flag implies that we do training, so it implies that we need to calculate the losses, so it implies eval_flag=True.
  • train_flag can be dynamic (a tf.Tensor, a placeholder), such that we have a single computation graph for both training and evaluation.
  • search_flag enables the search (ChoiceLayer, the search choices logic, etc.).

Those flags are per network, not per layer. But we can easily extend that in RETURNN if we want to.

The question is, what do we do here in returnn-common?

Also, I think there is some discussion on the PyTorch side whether the training flag is really so useful. I think most user code explicitly constructs the model either for training, evaluation, or whatever else anyway, and just doesn't use this flag for this. Because, as you say, it's often not just these two (binary) cases but other cases as well. And this is hard to generalize: different applications, models and methods might have very different needs, different kinds of flags or options.

Atticus1806 (Contributor) commented:

So you would actually argue that using such a flag would either be really complicated in usage (so not straightforward, as we want returnn_common to be) or too simple in its logic (not catching all possibilities), and in both cases users would most likely default to writing some simpler logic for their use case? I can actually see that. But maybe it makes sense to provide a simple, straightforward logic for the base cases, and then leave more complex cases to the user?

albertz (Member, Author) commented Oct 21, 2021

I'm unsure.

We could adopt the training flag just the same way as in PyTorch.
(The eval_flag is anyway not really needed separately, as we handle losses more explicitly here.)
Or not, and just leave that logic to RETURNN.

And then not introduce other flags in the base Module, but leave that as specific options to modules.
Or introduce a search flag with similar logic as the training flag. Although maybe this recursive setting of the flag on all child modules is not so useful or necessary for this?

Remember, our goal is that it is straightforward for the user, with priority on understanding and reading existing code, but writing new code should be straightforward as well.

It's hard to say which way would be the most straightforward.

We should maybe do a bit of research and look for common usage patterns in other frameworks. How do they deal with the difference between training, search, forwarding, and maybe other cases? E.g. look at Fairseq, ESPnet, others.

albertz (Member, Author) commented Jan 5, 2022

Note that the RETURNN train_flag is also used to differentiate between standard training and cross validation. This is an important distinction. Currently RETURNN constructs one single network (computation graph) for both training and cross validation.

If we make everything completely independent of the RETURNN train_flag here, then we need to extend RETURNN to be able to define a separate network for cross validation. Maybe that's a good idea anyway, to have that explicit?

If we want to keep the RETURNN train_flag logic, we can also access the train_flag dynamically (via a RETURNN layer) and then use the standard conditional logic (#24) to define custom behavior depending on this flag (e.g. for batch norm or dropout).
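A rough sketch of what that could look like, assuming the nn.train_flag() helper (added later in this thread) and a Cond-style construct for the conditional logic of #24; the names and signatures here are assumptions and may differ from the actual returnn-common API:

from returnn_common import nn  # assumed import style

def block_with_train_only_dropout(x: nn.Tensor) -> nn.Tensor:
  """Sketch only: switch behavior at runtime based on the dynamic train flag."""
  with nn.Cond(nn.train_flag()) as cond:  # assumed conditional construct (#24)
    cond.true = nn.dropout(x, 0.1, axis=x.feature_dim)  # e.g. dropout only in training
    cond.false = x                                       # identity otherwise
  return cond.result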

albertz (Member, Author) commented Feb 4, 2022

TransformerDecoder is one example where search is used.

albertz (Member, Author) commented Feb 7, 2022

Note, we now have train_flag() -> nn.Tensor, via TrainFlagLayer.

albertz (Member, Author) commented Feb 8, 2022

In case we need it at some later point, I guess we can add train_flag as an option to SubnetworkLayer, and that could be made explicit via an option to nn.scoped or so, or via a context manager. train_flag() is already using the current network, so that should just work.

Functions like dropout() or stochastic_depth() could still maybe have a specific option like apply_always or so to ignore the train flag.

So, the current solution for the train flag is probably fine, and can be extended when needed.

albertz (Member, Author) commented Feb 8, 2022

I tend to keep it like that, i.e. only having the train flag implicit, and all other aspects explicit.

There would not be any special handling for the eval flag.

And the search flag would always be an explicit argument, e.g. to TransformerDecoder.__call__.
Technically, on the RETURNN side, we might always enable the search flag, and then we control it explicitly as an argument to ChoiceLayer.

albertz added a commit to rwth-i6/returnn that referenced this issue Feb 14, 2022
Use net search flag only as default fall back
when search option is not explicitly provided for ChoiceLayer.
DecideLayer and co will always apply.

ChoiceLayer, new add_to_beam_scores option
to control whether the score should be added to the beam
when not doing search.

Fix #946.

Related:
rwth-i6/returnn_common#18
albertz (Member, Author) commented Feb 14, 2022

Now with rwth-i6/returnn#947, I think it is easy to make the search flag always an explicit option for nn.choice. We will never set the net search flag here in returnn-common and will not make use of the extra net logic. (Though, if the user uses the config option task == "search", this will imply that the search flag is set. That should not have an influence, except maybe for add_to_beam_scores, so we maybe should not allow NotSpecified but require either True or False.)

albertz added a commit that referenced this issue Feb 15, 2022
Also related to search flag: #18
albertz added a commit that referenced this issue Feb 16, 2022
Also related to search flag: #18
albertz (Member, Author) commented Feb 17, 2022

The Transformer.__call__ looks like this now:

@nn.scoped
def __call__(self, source: nn.Tensor, *,
             source_spatial_axis: nn.Dim,
             target: Optional[nn.Tensor] = None,
             initial_state: Optional[nn.LayerState] = None,
             search: bool,
             beam_size: Optional[int] = None,
             max_seq_len: Optional[Union[nn.Tensor, int]] = None,
             ) -> Tuple[nn.Tensor, nn.LayerState]:
  """
  Forward step of Transformer
  """
  memory = self.encoder(source, axis=source_spatial_axis)
  loop = nn.Loop(max_seq_len=max_seq_len)
  loop.state = initial_state if initial_state else self.default_initial_state()
  with loop:
    prev_target_embed = self.target_embedding(loop.state.target)
    output, loop.state.decoder = self.decoder(
      prev_target_embed, axis=nn.single_step_dim,
      memory=memory, memory_spatial_axis=source_spatial_axis, state=loop.state.decoder)
    logits = self.output_projection(output)
    target = loop.unstack(target) if target is not None else None
    if search:
      loop.state.target = nn.choice(logits, input_type="logits", target=target, search=True, beam_size=beam_size)
      loop.end(loop.state.target == self.target_eos_symbol, include_eos=False)
    else:
      assert target is not None
      loop.state.target = target
    outputs = loop.stack(loop.state.target)
  return outputs, loop.state

See specifically the logic regarding the search option.

I'm still not sure if this is the best way. Note that nn.choice has more options, which might be relevant for certain cases, e.g. cheating, etc.

It would look bad, though, to add all the nn.choice options here to Transformer.__call__. You might argue that even right now, beam_size and max_seq_len already look a bit ugly there, as these are only relevant for search.
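For orientation, a hedged usage sketch for the signature above; data, time_dim, targets, beam_size=12 and max_seq_len=100 are placeholders for the user's input tensor, its spatial dim, the ground-truth labels and example search settings:

transformer = nn.Transformer(...)  # constructor options omitted

# Training / teacher forcing: provide the ground-truth target, no search.
out_train, _ = transformer(
  data, source_spatial_axis=time_dim, target=targets, search=False)

# Recognition: enable search and provide a beam size and a maximum sequence length.
out_search, _ = transformer(
  data, source_spatial_axis=time_dim, search=True, beam_size=12, max_seq_len=100)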

albertz (Member, Author) commented Feb 17, 2022

Maybe target could be either an nn.Tensor or a callable. In the case of a callable, it would be called with logits and loop.

Then the user could write this to enable search:

transformer = nn.Transformer(...)
out, _ = transformer(
  data, ...,
  target=lambda logits, loop: nn.choice(logits, input_type="logits", search=True, beam_size=12))

Or instead of a callable, which is maybe too ambiguous, we can have a special Search object and interface. It would also provide max_seq_len.

albertz added a commit that referenced this issue Feb 17, 2022
albertz (Member, Author) commented Feb 17, 2022

I have now implemented an abstract SearchFuncInterface and a specific SearchFunc, where SearchFunc.choice basically just calls nn.choice with search=True.
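For orientation, a rough sketch of how these two could relate; the signatures here are assumptions for illustration, not the actual returnn_common definitions:

from typing import Optional
from returnn_common import nn  # assumed import style

class SearchFuncInterface:
  """Sketch: decides the next output label inside a decoder loop."""
  def choice(self, *, output: nn.Tensor) -> nn.Tensor:
    raise NotImplementedError

class SearchFunc(SearchFuncInterface):
  """Sketch: beam search, essentially just nn.choice with search=True."""
  def __init__(self, *, beam_size: int, max_seq_len: Optional[nn.Tensor] = None):
    self.beam_size = beam_size
    self.max_seq_len = max_seq_len

  def choice(self, *, output: nn.Tensor) -> nn.Tensor:
    return nn.choice(
      output, input_type="logits", target=None, search=True, beam_size=self.beam_size)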

The Transformer.__call__ only has these args now:

@nn.scoped
def __call__(self, source: nn.Tensor, *,
             source_spatial_axis: nn.Dim,
             target: Optional[Union[nn.Tensor, nn.SearchFuncInterface]] = None,
             initial_state: Optional[nn.LayerState] = None,
             ) -> Tuple[nn.Tensor, nn.LayerState]:
  ...

And an example call looks like:

    transformer = nn.Transformer(...)
    out, _ = transformer(
      data, source_spatial_axis=time_dim,
      target=nn.SearchFunc(
        beam_size=3,
        max_seq_len=nn.reduce(nn.length(data, axis=time_dim), mode="max", axis=nn.batch_dim)))

albertz (Member, Author) commented Feb 17, 2022

I think we have everything ready now. Whether the current solutions are good, we can only tell after a bit of usage. I think we can close this for now.

albertz closed this as completed Feb 17, 2022
albertz (Member, Author) commented Aug 22, 2022

While coming back to the search (currently only Transformer implements it, but it should be part of the generic Decoder as well), I found this a bit confusing.

max_seq_len seems quite specific to me, and at first it was a bit confusing to find this in nn.SearchFunc.

The Transformer initial_state is inconsistent with other recurrent logic code, where it is just called state (I think).

albertz reopened this Aug 22, 2022
albertz (Member, Author) commented Aug 22, 2022

I reopened this issue because I think we should improve and maybe redesign this.

albertz (Member, Author) commented Aug 22, 2022

In the Transformer code, we have this:

    if search:
      beam = search.get_beam()
      beam.name = f"{nn.NameCtx.current_ctx().get_abs_name()}/target"
      beam.dependency = beam.copy_as_prev_frame()
      for x in loop.state.deep_tensors():
        x.data.beam = beam.dependency

There are multiple problems:

  • This is unintuitive. Do you always need to put this when you write your own loop with a search variable? Why is this really needed, and why is it not automatic?
  • I think it's not totally correct. Not all state deps might depend on the search variable, although it might be correct for this Transformer code.

albertz (Member, Author) commented Aug 22, 2022

The current nn.SearchFunc usage (Transformer) makes use of the RETURNN search logic (SearchBeam, Data.beam, etc.). There was also the idea of whether this is really needed, or whether we could actually reimplement the logic purely on the returnn-common side, cleaner, simpler and more straightforward than before. Maybe more explicit, with less automatic magic.

albertz added a commit that referenced this issue Aug 22, 2022
albertz (Member, Author) commented Aug 22, 2022

I have now just removed the beam logic (SearchFuncInterface.get_beam), and it seems to work without any special care for that. Actually, I expected that I would need to add some automatic logic somewhere, e.g. in _LoopState.assign, to adapt the Data.beam of all related tensors, but as it seems to work fine now without that, I just leave it as is.

albertz (Member, Author) commented Aug 22, 2022

SearchFuncInterface is also relevant for defining the generic high-level decoder interface (#49).

Transformer/TransformerDecoder (#55) should probably be renamed/adapted to fit the generic decoder interface with search.
