Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(re)enable torch.compile in the pytorch trainer for train, predict, and eval #18569

Merged
merged 1 commit into from
Oct 15, 2023

Conversation

kiukchung
Copy link
Contributor

@kiukchung kiukchung commented Oct 6, 2023

Change Summary

  1. Updates torch to torch>=2.1.0 (which has many improvements to dynamo)
  2. Wraps the underlying function with torch.compile when jit_compile=True for train, eval, and predict
  3. Updated docs for model.fit() explaining that jit_compile="auto" defaults to eager for the torch backend (torch.compile only kicks in if the user explicitly sets jit_compile=True).
  4. Adds setUp() to clear_session() in testing.TestCase (required for dynamo)
  5. Fixes a few functions to make the codebase dynamo friendly
  6. Fix incorrect assertion in naming_test.py:test_uniquify_already_uniquified_name()
  7. Use jit_compile="auto" (versus jit_compile=True) in keras.testing.test_case.TestCase.run_layer_test.run_training_step() so that the backends are tested in their "default" jitted mode (jit for tf and jax and eager for torch).

Note On Dynamo

Currently there are two caveats to running torch backend with jit_compile=True

  1. (performance) It is slower than eager because of too many graph breaks, which is mainly due to the usage of tree in the function (dynamo will not trace through tree, see skipfiles) any_symbolic_tensors(), which in turn is called by pretty much all ops (e.g. numpy, layer, activation, etc). Therefore, no "deep-graph" can be captured and hence no opportunities for optimizations such as op-fusion. This can be fixed by not using tree.flatten in any_symbolic_tensors()

  2. (overhead) torch.core.convert_to_tensor needs to be simplified to just calling torch.as_tensor(x, dtype, device) rather than using x.to(dtype, device). This won't make things compile better but reduces frame eval overhead since convert_to_tensor is called for each op and tracing through many branches is less than ideal.

  3. (compatibility) There are cases where primitive operators can be traced by dynamo, but when a sequence of them are used as a higher order operator such as a layer (e.g. up_sampling_2d), causes guard failures on the primitive ops, which in turn makes dynamo trace with dynamic shapes via symbolic variables rather than concretized values, which can often lead to tracing failures due to "missing methods".

Testing

CI for unittests

Manual testing on examples/keras_io/vision/mnist_convnet.py by explicitly enabling jit_compile.
Observations:

  1. No significant speedup
  2. First 1,2 epochs are slow due to (re)compilation
  3. Later epochs are still slower: 19ms (eager) vs 28ms (compiled) on CPU, have not tried on GPU.
  4. (3) mostly due to recompilation / graph breaks since some functions (e.g. convert_to_tensor) are highly dynamic in the input types (python types).

@kiukchung kiukchung force-pushed the master branch 2 times, most recently from 8212c7d to 22ec236 Compare October 6, 2023 23:03
@codecov-commenter
Copy link

codecov-commenter commented Oct 6, 2023

Codecov Report

All modified lines are covered by tests ✅

Comparison is base (d026dfd) 78.11% compared to head (d65c494) 78.30%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18569      +/-   ##
==========================================
+ Coverage   78.11%   78.30%   +0.18%     
==========================================
  Files         334      334              
  Lines       32477    32484       +7     
  Branches     6339     6342       +3     
==========================================
+ Hits        25371    25438      +67     
+ Misses       5539     5482      -57     
+ Partials     1567     1564       -3     
Flag Coverage Δ
keras 78.19% <91.30%> (+0.17%) ⬆️
keras-jax 63.58% <34.78%> (+0.16%) ⬆️
keras-numpy 57.94% <30.43%> (+0.13%) ⬆️
keras-tensorflow 64.42% <34.78%> (+0.11%) ⬆️
keras-torch 65.31% <91.30%> (+0.22%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Coverage Δ
keras/backend/common/global_state.py 96.29% <100.00%> (+0.46%) ⬆️
keras/backend/torch/core.py 91.45% <100.00%> (ø)
keras/backend/torch/numpy.py 95.73% <100.00%> (+0.01%) ⬆️
keras/backend/torch/random.py 91.30% <100.00%> (+0.09%) ⬆️
keras/backend/torch/trainer.py 90.04% <100.00%> (+1.01%) ⬆️
keras/layers/reshaping/flatten.py 100.00% <100.00%> (ø)
keras/layers/reshaping/up_sampling2d.py 95.45% <100.00%> (-0.30%) ⬇️
keras/testing/test_case.py 86.20% <100.00%> (+0.18%) ⬆️
keras/trainers/epoch_iterator.py 90.74% <ø> (ø)
keras/trainers/trainer.py 84.73% <100.00%> (-0.24%) ⬇️

... and 5 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

keras/backend/torch/core.py Outdated Show resolved Hide resolved
keras/backend/torch/trainer.py Outdated Show resolved Hide resolved
keras/trainers/trainer_test.py Outdated Show resolved Hide resolved
@kiukchung kiukchung force-pushed the master branch 19 times, most recently from 91ddb48 to e166beb Compare October 14, 2023 05:18
@kiukchung kiukchung marked this pull request as ready for review October 14, 2023 05:19
@fchollet
Copy link
Collaborator

which is mainly due to the usage of tree in the function (dynamo will not trace through tree, see skipfiles

This is actually very fixable. Instead of using tree we can use e.g. keras.utils.tree, and in those functions we can route between the actual tree functions (which are C++ based) or a simple pure Python implementation (which would be Dynamo compatible) depending on whether we're in a Dynamo context. What do you think?

The reason tree is not Dynamo compatible is presumably because it isn't Python based (for performance reasons -- which is good when in eager execution).

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM -- I think doing the tree conversion would likely unlock the performance benefits here.

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Oct 14, 2023
@@ -2,8 +2,8 @@
tf-nightly==2.15.0.dev20231009 # Pin a working nightly until rc0.

# Torch.
torch>=2.0.1
torchvision>=0.15.1
torch>=2.1.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@grasskin - FYI. I remember Gabriel wanting to keep the requirements as torch 2.0.1. So wanted him to take a look or be in the loop.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @sampathweb, @grasskin, let me know if we have a good reason to stay at 2.0.1. I'd like to update to 2.1 if possible since it has a bunch of fixes (especially to torch.compile)

@kiukchung
Copy link
Contributor Author

which is mainly due to the usage of tree in the function (dynamo will not trace through tree, see skipfiles

This is actually very fixable. Instead of using tree we can use e.g. keras.utils.tree, and in those functions we can route between the actual tree functions (which are C++ based) or a simple pure Python implementation (which would be Dynamo compatible) depending on whether we're in a Dynamo context. What do you think?

The reason tree is not Dynamo compatible is presumably because it isn't Python based (for performance reasons -- which is good when in eager execution).

which is mainly due to the usage of tree in the function (dynamo will not trace through tree, see skipfiles

This is actually very fixable. Instead of using tree we can use e.g. keras.utils.tree, and in those functions we can route between the actual tree functions (which are C++ based) or a simple pure Python implementation (which would be Dynamo compatible) depending on whether we're in a Dynamo context. What do you think?

The reason tree is not Dynamo compatible is presumably because it isn't Python based (for performance reasons -- which is good when in eager execution).

Yep I created an issue for this (#18614). I can do this in a fast-follow PR since this one is getting big and the torch backend defaults to eager right now.

@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Oct 15, 2023
@fchollet
Copy link
Collaborator

Happy to merge this now since CI is passing and we can do the rest in future PRs. If the updated torch version is an issue we can revert that part later.

@fchollet fchollet merged commit 1c0d997 into keras-team:master Oct 15, 2023
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants