Skip to content

Commit

Permalink
Merge pull request #137 from IBM/fix/global_dtype
Browse files Browse the repository at this point in the history
Fix/global dtype
  • Loading branch information
Joao-L-S-Almeida authored May 12, 2023
2 parents b5c5608 + f276fc3 commit dfb8e8b
Show file tree
Hide file tree
Showing 16 changed files with 96 additions and 25 deletions.
11 changes: 6 additions & 5 deletions simulai/models/_pytorch_models/_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import torch

from simulai import ARRAY_DTYPE
from simulai.regression import ConvolutionalNetwork, DenseNetwork, Linear
from simulai.templates import (
NetworkTemplate,
Expand Down Expand Up @@ -508,7 +509,7 @@ def eval(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndarray
"""

if isinstance(input_data, np.ndarray):
input_data = torch.from_numpy(input_data.astype("float32"))
input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE))

input_data = input_data.to(self.device)

Expand Down Expand Up @@ -995,7 +996,7 @@ def predict(
"""
if isinstance(input_data, np.ndarray):
input_data = torch.from_numpy(input_data.astype("float32"))
input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE))

predictions = list()
latent = self.projection(input_data=input_data)
Expand Down Expand Up @@ -1694,7 +1695,7 @@ def project(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndar
>>> projected_data = autoencoder.project(input_data=input_data)
"""
if isinstance(input_data, np.ndarray):
input_data = torch.from_numpy(input_data.astype("float32"))
input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE))

input_data = input_data.to(self.device)

Expand Down Expand Up @@ -1725,7 +1726,7 @@ def reconstruct(
>>> reconstructed_data = autoencoder.reconstruct(input_data=input_data)
"""
if isinstance(input_data, np.ndarray):
input_data = torch.from_numpy(input_data.astype("float32"))
input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE))

input_data = input_data.to(self.device)

Expand Down Expand Up @@ -1754,7 +1755,7 @@ def eval(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndarray
>>> reconstructed_data = autoencoder.eval(input_data=input_data)
"""
if isinstance(input_data, np.ndarray):
input_data = torch.from_numpy(input_data.astype("float32"))
input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE))

input_data = input_data.to(self.device)

Expand Down
3 changes: 3 additions & 0 deletions tests/PINN/test_deep_operator_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

import numpy as np

from tests.config import configure_dtype
torch = configure_dtype()

from simulai.optimization import Optimizer
from simulai.residuals import SymbolicOperator

Expand Down
3 changes: 3 additions & 0 deletions tests/PINN/test_vanilla_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import matplotlib.pyplot as plt
import numpy as np

from tests.config import configure_dtype
torch = configure_dtype()

from simulai.optimization import Optimizer
from simulai.residuals import SymbolicOperator

Expand Down
40 changes: 40 additions & 0 deletions tests/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# (C) Copyright IBM Corp. 2019, 2020, 2021, 2022.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# (C) Copyright IBM Corporation 2017, 2018, 2019
# U.S. Government Users Restricted Rights: Use, duplication or disclosure restricted
# by GSA ADP Schedule Contract with IBM Corp.
#
# Author: Joao Lucas S. Almeida <[email protected]>

import os
import torch

def configure_dtype():

test_dtype_var = os.environ.get("TEST_DTYPE")

if test_dtype_var is not None:
test_dtype = getattr(torch, test_dtype_var)
else:
test_dtype = torch.float32

torch.set_default_dtype(test_dtype)

print(f"Using dtype {test_dtype} in tests.")

return torch



4 changes: 2 additions & 2 deletions tests/metrics/test_mahalanobis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest import TestCase

import torch
from tests.config import configure_dtype
torch = configure_dtype()

from simulai.metrics import MahalanobisDistance

Expand Down
4 changes: 3 additions & 1 deletion tests/metrics/test_pointwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from unittest import TestCase

import numpy as np
import torch

from tests.config import configure_dtype
torch = configure_dtype()

from simulai.metrics import PointwiseError

Expand Down
9 changes: 6 additions & 3 deletions tests/network/test_conv_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
from unittest import TestCase

import numpy as np
import torch
from tests.config import configure_dtype
torch = configure_dtype()

from utils import configure_device

from simulai import ARRAY_DTYPE
from simulai.file import SPFile
from simulai.optimization import Optimizer

Expand All @@ -34,8 +37,8 @@ def generate_data(
input_data = np.random.rand(n_samples, n_inputs, vector_size)
output_data = np.random.rand(n_samples, n_outputs)

return torch.from_numpy(input_data.astype("float32")), torch.from_numpy(
output_data.astype("float32")
return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy(
output_data.astype(ARRAY_DTYPE)
)


Expand Down
9 changes: 6 additions & 3 deletions tests/network/test_conv_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
from unittest import TestCase

import numpy as np
import torch
from tests.config import configure_dtype
torch = configure_dtype()

from utils import configure_device

from simulai import ARRAY_DTYPE
from simulai.file import SPFile
from simulai.optimization import Optimizer

Expand All @@ -34,8 +37,8 @@ def generate_data(
input_data = np.random.rand(n_samples, n_inputs, *image_size)
output_data = np.random.rand(n_samples, n_outputs)

return torch.from_numpy(input_data.astype("float32")), torch.from_numpy(
output_data.astype("float32")
return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy(
output_data.astype(ARRAY_DTYPE)
)


Expand Down
3 changes: 2 additions & 1 deletion tests/network/test_deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from unittest import TestCase

import numpy as np
import torch
from tests.config import configure_dtype
torch = configure_dtype()
from utils import configure_device

DEVICE = configure_device()
Expand Down
4 changes: 3 additions & 1 deletion tests/network/test_flexible_deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from unittest import TestCase

import numpy as np
import torch
from tests.config import configure_dtype
torch = configure_dtype()

from utils import configure_device

DEVICE = configure_device()
Expand Down
4 changes: 3 additions & 1 deletion tests/network/test_improved_deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from unittest import TestCase

import numpy as np
import torch
from tests.config import configure_dtype
torch = configure_dtype()

from utils import configure_device

DEVICE = configure_device()
Expand Down
4 changes: 3 additions & 1 deletion tests/network/test_residual_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

import matplotlib.pyplot as plt
import numpy as np
import torch

from tests.config import configure_dtype
torch = configure_dtype()

torch.autograd.set_detect_anomaly(True)

Expand Down
13 changes: 8 additions & 5 deletions tests/network/test_template_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
from unittest import TestCase

import numpy as np
import torch
from tests.config import configure_dtype
torch = configure_dtype()

from utils import configure_device

DEVICE = configure_device()

from simulai import ARRAY_DTYPE

def generate_data_2d(
n_samples: int = None,
Expand All @@ -31,8 +34,8 @@ def generate_data_2d(
input_data = np.random.rand(n_samples, n_inputs, *image_size)
output_data = np.random.rand(n_samples, n_outputs)

return torch.from_numpy(input_data.astype("float32")), torch.from_numpy(
output_data.astype("float32")
return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy(
output_data.astype(ARRAY_DTYPE)
)


Expand All @@ -45,8 +48,8 @@ def generate_data_1d(
input_data = np.random.rand(n_samples, n_inputs, vector_size)
output_data = np.random.rand(n_samples, n_outputs)

return torch.from_numpy(input_data.astype("float32")), torch.from_numpy(
output_data.astype("float32")
return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy(
output_data.astype(ARRAY_DTYPE)
)


Expand Down
3 changes: 2 additions & 1 deletion tests/network/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ def configure_device():
if not simulai_network_gpu:
device = "cpu"
else:
import torch
from tests.config import configure_dtype
torch = configure_dtype()

if not torch.cuda.is_available():
raise Exception("There is no gpu available to execute the tests.")
Expand Down
4 changes: 3 additions & 1 deletion tests/residuals/test_symbolicoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from unittest import TestCase

import numpy as np
import torch
from tests.config import configure_dtype
torch = configure_dtype()


from simulai.residuals import SymbolicOperator

Expand Down
3 changes: 3 additions & 0 deletions tests/rom/test_cnn_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import numpy as np

from tests.config import configure_dtype
torch = configure_dtype()

from simulai.file import SPFile
from simulai.optimization import Optimizer

Expand Down

0 comments on commit dfb8e8b

Please sign in to comment.