Use single forward pass in shared model architectures #156
Single forward pass
Motivation:
When the shared model is applied, the forward pass is called twice: once for the policy and once for the value. The inputs to both calls are identical, so the output can be cached to improve performance.
Note: the single forward pass also affects autograd graph construction, so a significant speedup occurs during the backward pass as well.
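For illustration, here is a minimal PyTorch sketch of the idea (not the code in this PR; the class and method names are hypothetical): the shared trunk runs once during the policy call, its output stays attached to the autograd graph, and the value call reuses it.

```python
import torch
import torch.nn as nn

class SharedActorCritic(nn.Module):
    """Hypothetical shared model: one trunk, separate policy and value heads."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 256):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.policy_head = nn.Linear(hidden, act_dim)
        self.value_head = nn.Linear(hidden, 1)
        self._latent = None  # cached trunk output, still part of the autograd graph

    def forward_policy(self, obs: torch.Tensor) -> torch.Tensor:
        # Single trunk forward; the latent is cached for the value call below.
        self._latent = self.trunk(obs)
        return self.policy_head(self._latent)

    def forward_value(self, obs: torch.Tensor) -> torch.Tensor:
        # Reuses the cached latent instead of recomputing the trunk.
        # Fragile by construction: it assumes forward_policy was just called
        # with the same obs (the exact risk discussed in the notes below).
        return self.value_head(self._latent)
```

Because the latent is cached inside the graph rather than recomputed, backward also traverses the trunk only once, which is where the extra speedup during the backward phase comes from.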
Speed eval:
- Big neural network (units: [2048, 1024, 1024, 512])
- 3840 steps
- Running on top of the OIGE env simulation (constant for each run)
- Mixed precision = True
Quality eval:
We trained a policy for our task multiple times with each configuration. We didn't observe any statistically significant difference in the quality of the final results.
Notice: the single-pass and double-pass runs would be identical in an ideal world, but because of finite floating-point precision and a different order of gradient computation, they diverge gradually.
Note:
- This implementation is minimalistic, but it is dangerous to generalise, as it requires that the value forward pass always follow the policy forward pass.
- To make it safer, we could cache the input and check whether the next input is the same (see the sketch after this list), either by:
  - a) checking whether the two inputs are references to the same object, or
  - b) comparing the input and cached input tensors directly. This adds some computational overhead, but it is negligible compared to the time saved.
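A sketch of that safer variant, assuming PyTorch; `CachedTrunk` and `compare_values` are hypothetical names introduced here for illustration. Option (a) maps to the identity check and option (b) to `torch.equal`:

```python
import torch
import torch.nn as nn

class CachedTrunk(nn.Module):
    """Wraps a trunk and reuses its output when the same input arrives twice in a row."""

    def __init__(self, trunk: nn.Module, compare_values: bool = False):
        super().__init__()
        self.trunk = trunk
        self.compare_values = compare_values
        self._cached_input = None
        self._cached_output = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._cached_input is not None:
            if self.compare_values:
                # (b) compare tensor values; small overhead, negligible next to
                # the trunk forward pass it avoids
                hit = torch.equal(x, self._cached_input)
            else:
                # (a) cheap check: is it literally the same tensor object?
                hit = x is self._cached_input
            if hit:
                return self._cached_output
        self._cached_input = x
        self._cached_output = self.trunk(x)
        return self._cached_output

    def reset_cache(self) -> None:
        # Call at step boundaries so a stale output (whose autograd graph has
        # already been freed by backward) is never reused.
        self._cached_input = None
        self._cached_output = None
```

One caveat with either check: the cached output is only valid within the current autograd graph, so the cache should be cleared after each backward/optimizer step.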