From 04bcd4797d623303353b1421c2cdc1b15248cce2 Mon Sep 17 00:00:00 2001
From: Jacob G-W
Date: Sun, 11 Feb 2024 17:11:35 -0500
Subject: [PATCH] add mixed precision support to deepxde

---
 deepxde/config.py | 15 ++++++++++++++-
 deepxde/model.py  | 18 ++++++++++++++----
 deepxde/real.py   |  4 ++++
 3 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/deepxde/config.py b/deepxde/config.py
index 6d87e67b6..caa816801 100644
--- a/deepxde/config.py
+++ b/deepxde/config.py
@@ -74,7 +74,7 @@ def set_default_float(value):
     The default floating point type is 'float32'.
 
     Args:
-        value (String): 'float16', 'float32', or 'float64'.
+        value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision in https://arxiv.org/abs/2401.16645).
     """
     if value == "float16":
         print("Set the default float type to float16")
@@ -85,6 +85,19 @@ def set_default_float(value):
     elif value == "float64":
         print("Set the default float type to float64")
         real.set_float64()
+    elif value == "mixed":
+        print("Set the float type to mixed precision of float16 and float32")
+        real.set_mixed()
+        if backend_name == "tensorflow":
+            real.set_float16()
+            tf.keras.mixed_precision.set_global_policy("mixed_float16")
+        elif backend_name == "pytorch":
+            # Use float16 during the forward and backward passes, but store in float32
+            real.set_float32()
+        else:
+            raise ValueError(
+                f"{backend_name} backend does not currently support mixed precision."
+            )
     else:
         raise ValueError(f"{value} not supported in deepXDE")
     if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
diff --git a/deepxde/model.py b/deepxde/model.py
index 4ebdf6859..e19f6d9b7 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -353,10 +353,20 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
 
         def train_step(inputs, targets, auxiliary_vars):
             def closure():
-                losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
-                total_loss = torch.sum(losses)
-                self.opt.zero_grad()
-                total_loss.backward()
+                if config.real.mixed:
+                    with torch.autocast(device_type="cuda", dtype=torch.float16):
+                        losses = outputs_losses_train(inputs, targets, auxiliary_vars)[
+                            1
+                        ]
+                        total_loss = torch.sum(losses)
+                        # we do the backprop in float16
+                        self.opt.zero_grad()
+                        total_loss.backward()
+                else:
+                    losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
+                    total_loss = torch.sum(losses)
+                    self.opt.zero_grad()
+                    total_loss.backward()
                 return total_loss
 
             self.opt.step(closure)
diff --git a/deepxde/real.py b/deepxde/real.py
index 1ceb2fd7a..d7268de6e 100644
--- a/deepxde/real.py
+++ b/deepxde/real.py
@@ -7,6 +7,7 @@ class Real:
     def __init__(self, precision):
         self.precision = None
         self.reals = None
+        self.mixed = False
         if precision == 16:
             self.set_float16()
         elif precision == 32:
@@ -28,3 +29,6 @@ def set_float32(self):
     def set_float64(self):
         self.precision = 64
         self.reals = {np: np.float64, bkd.lib: bkd.float64}
+
+    def set_mixed(self):
+        self.mixed = True