add mixed precision support to deepxde #1650

Status: Open (wants to merge 2 commits into base: master)
deepxde/config.py (15 additions, 1 deletion)

@@ -74,7 +74,7 @@ def set_default_float(value):
     The default floating point type is 'float32'.
Owner commented:
The default floating point type is 'float32'. Mixed precision uses the method in the paper: `J. Hayford, J. Goldman-Wetzler, E. Wang, & L. Lu. Speeding up and reducing memory usage for scientific machine learning via mixed precision. Computer Methods in Applied Mechanics and Engineering, 428, 117093, 2024 <https://doi.org/10.1016/j.cma.2024.117093>`_.


     Args:
-        value (String): 'float16', 'float32', or 'float64'.
+        value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision in https://arxiv.org/abs/2401.16645).
Owner commented:
value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision).

"""
if value == "float16":
print("Set the default float type to float16")
Expand All @@ -85,6 +85,20 @@ def set_default_float(value):
elif value == "float64":
print("Set the default float type to float64")
real.set_float64()
+    elif value == "mixed":
+        print("Set the float type to mixed precision of float16 and float32")
+        real.set_mixed()
Owner commented:
This code is confusing. Here you do real.set_mixed(), but later you do either real.set_float16() or real.set_float32(). It seems you only need a mixed flag, which you can set after line 42.

+        if backend_name == "tensorflow":
+            real.set_float16()
+            tf.keras.mixed_precision.set_global_policy("mixed_float16")
+            return  # don't try to set it again below
+        if backend_name == "pytorch":
+            # Use float16 during the forward and backward passes, but store in float32
+            real.set_float32()
+        else:
+            raise ValueError(
+                f"{backend_name} backend does not currently support mixed precision."
+            )
     else:
         raise ValueError(f"{value} not supported in deepXDE")
     if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
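For reference, a minimal sketch of the flag-based flow the Owner's comment above seems to suggest. This is a hedged reading of the review, not the PR's final code: the function name, the explicit `backend_name`/`real` arguments, and the branch order are all illustrative assumptions.

```python
def set_default_float_sketch(value, backend_name, real):
    """Sketch only: record one `mixed` flag, then reuse the existing setters."""
    if value != "mixed":
        raise NotImplementedError("this sketch covers only the 'mixed' branch")
    real.mixed = True  # single flag, checked later by the training loop
    if backend_name == "tensorflow":
        import tensorflow as tf  # local import keeps the sketch self-contained

        real.set_float16()  # float16 compute dtype
        tf.keras.mixed_precision.set_global_policy("mixed_float16")
    elif backend_name == "pytorch":
        # Parameters stay float32; autocast handles float16 compute per step.
        real.set_float32()
    else:
        raise ValueError(
            f"{backend_name} backend does not currently support mixed precision."
        )
```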
deepxde/model.py (5 additions, 2 deletions)

@@ -357,9 +357,12 @@ def closure():
             total_loss = torch.sum(losses)
             self.opt.zero_grad()
             total_loss.backward()
-            return total_loss

-        self.opt.step(closure)
+        def closure_mixed():
+            with torch.autocast(device_type="cuda", dtype=torch.float16):
+                closure()
+
+        self.opt.step(closure if not config.real.mixed else closure_mixed)
         if self.lr_scheduler is not None:
             self.lr_scheduler.step()

Owner (@lululxvi) commented on Jun 26, 2024:
Why delete line 360?
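As a standalone illustration of the closure wrapping above, here is a hedged sketch with an assumed toy model and data (not DeepXDE code). Unlike the PR's `closure_mixed`, it returns the loss from the autocast closure, which `torch.optim.LBFGS` requires and which relates to the review question about the deleted `return`; like the PR, it runs the backward pass inside autocast as well.

```python
import torch

# Assumed toy problem, for illustration only.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(2, 1).to(device)
opt = torch.optim.LBFGS(model.parameters())
x = torch.randn(8, 2, device=device)
y = torch.randn(8, 1, device=device)

def closure():
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss  # LBFGS requires the closure to return the loss

def closure_mixed():
    # Mirror the PR: run the whole closure under autocast, so the forward
    # (and here also the backward) pass computes in float16 on CUDA.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        return closure()

use_mixed = device == "cuda"  # float16 autocast targets CUDA here
opt.step(closure_mixed if use_mixed else closure)
```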

deepxde/real.py (4 additions)

@@ -7,6 +7,7 @@ class Real:
     def __init__(self, precision):
         self.precision = None
         self.reals = None
+        self.mixed = False
         if precision == 16:
             self.set_float16()
         elif precision == 32:
@@ -28,3 +29,6 @@ def set_float32(self):
     def set_float64(self):
         self.precision = 64
         self.reals = {np: np.float64, bkd.lib: bkd.float64}
+
+    def set_mixed(self):
+        self.mixed = True
docs/user/faq.rst (2 additions)

@@ -10,6 +10,8 @@ General usage
   | **A**: `#5`_
 - | **Q**: By default, DeepXDE uses ``float32``. How can I use ``float64``?
   | **A**: `#28`_
+- | **Q**: How can I use mixed precision training?
+  | **A**: Use ``dde.config.set_default_float("mixed")`` with the ``tensorflow`` or ``pytorch`` backends. See `this paper <https://doi.org/10.1016/j.cma.2024.117093>`_ for more information.
 - | **Q**: I want to set the global random seeds.
   | **A**: `#353`_
 - | **Q**: GPU.
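For completeness, a minimal usage sketch of the FAQ entry added above. It assumes this PR has been merged (so that `"mixed"` is accepted) and that the backend is selected via the `DDE_BACKEND` environment variable; the surrounding problem setup is left out.

```python
# Assumes this PR is merged and the pytorch backend is selected,
# e.g. DDE_BACKEND=pytorch python script.py
import deepxde as dde

dde.config.set_default_float("mixed")
# Build geometry, PDE, and dde.Model as usual; training then takes the
# float16/float32 mixed-precision path introduced by this PR.
```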