diff --git a/reactive-flows/cnf-combustion/gnns/ci/configs/gat_test.yaml b/reactive-flows/cnf-combustion/gnns/ci/configs/gat_test.yaml index 0b34f9d..2a3baa5 100644 --- a/reactive-flows/cnf-combustion/gnns/ci/configs/gat_test.yaml +++ b/reactive-flows/cnf-combustion/gnns/ci/configs/gat_test.yaml @@ -25,7 +25,7 @@ model: lr: .0001 data: - class_path: data.LitCombustionDataModule + class_path: data.R2DataModule init_args: batch_size: 1 num_workers: 0 diff --git a/reactive-flows/cnf-combustion/gnns/ci/configs/gcn_test.yaml b/reactive-flows/cnf-combustion/gnns/ci/configs/gcn_test.yaml index 905e14f..79432bd 100644 --- a/reactive-flows/cnf-combustion/gnns/ci/configs/gcn_test.yaml +++ b/reactive-flows/cnf-combustion/gnns/ci/configs/gcn_test.yaml @@ -24,7 +24,7 @@ model: lr: .0001 data: - class_path: data.LitCombustionDataModule + class_path: data.R2DataModule init_args: batch_size: 1 num_workers: 0 diff --git a/reactive-flows/cnf-combustion/gnns/ci/configs/gin_test.yaml b/reactive-flows/cnf-combustion/gnns/ci/configs/gin_test.yaml index ddbcbbb..b8bc8d8 100644 --- a/reactive-flows/cnf-combustion/gnns/ci/configs/gin_test.yaml +++ b/reactive-flows/cnf-combustion/gnns/ci/configs/gin_test.yaml @@ -23,7 +23,7 @@ model: lr: .0001 data: - class_path: data.LitCombustionDataModule + class_path: data.R2DataModule init_args: batch_size: 1 num_workers: 0 diff --git a/reactive-flows/cnf-combustion/gnns/ci/configs/gunet_test.yaml b/reactive-flows/cnf-combustion/gnns/ci/configs/gunet_test.yaml index f8339cc..b9a8567 100644 --- a/reactive-flows/cnf-combustion/gnns/ci/configs/gunet_test.yaml +++ b/reactive-flows/cnf-combustion/gnns/ci/configs/gunet_test.yaml @@ -23,7 +23,7 @@ model: lr: .0001 data: - class_path: data.LitCombustionDataModule + class_path: data.R2DataModule init_args: batch_size: 1 num_workers: 0 diff --git a/reactive-flows/cnf-combustion/gnns/configs/gat.yaml b/reactive-flows/cnf-combustion/gnns/configs/gat.yaml index e561da1..48ea43b 100644 --- 
a/reactive-flows/cnf-combustion/gnns/configs/gat.yaml +++ b/reactive-flows/cnf-combustion/gnns/configs/gat.yaml @@ -25,7 +25,7 @@ model: lr: .0001 data: - class_path: data.LitCombustionDataModule + class_path: data.R2DataModule init_args: batch_size: 1 num_workers: 0 diff --git a/reactive-flows/cnf-combustion/gnns/configs/gcn.yaml b/reactive-flows/cnf-combustion/gnns/configs/gcn.yaml index 917ebe2..d0d0776 100644 --- a/reactive-flows/cnf-combustion/gnns/configs/gcn.yaml +++ b/reactive-flows/cnf-combustion/gnns/configs/gcn.yaml @@ -24,7 +24,7 @@ model: lr: .0001 data: - class_path: data.LitCombustionDataModule + class_path: data.R2DataModule init_args: batch_size: 1 num_workers: 0 diff --git a/reactive-flows/cnf-combustion/gnns/configs/gin.yaml b/reactive-flows/cnf-combustion/gnns/configs/gin.yaml index 2e55621..fdee4bd 100644 --- a/reactive-flows/cnf-combustion/gnns/configs/gin.yaml +++ b/reactive-flows/cnf-combustion/gnns/configs/gin.yaml @@ -23,7 +23,7 @@ model: lr: .0001 data: - class_path: data.LitCombustionDataModule + class_path: data.R2DataModule init_args: batch_size: 1 num_workers: 0 diff --git a/reactive-flows/cnf-combustion/gnns/configs/gunet.yaml b/reactive-flows/cnf-combustion/gnns/configs/gunet.yaml index caaa801..cb65c2d 100644 --- a/reactive-flows/cnf-combustion/gnns/configs/gunet.yaml +++ b/reactive-flows/cnf-combustion/gnns/configs/gunet.yaml @@ -23,7 +23,7 @@ model: lr: .0001 data: - class_path: data.LitCombustionDataModule + class_path: data.R2DataModule init_args: batch_size: 1 num_workers: 0 diff --git a/reactive-flows/cnf-combustion/gnns/data.py b/reactive-flows/cnf-combustion/gnns/data.py index dfacfae..2172cfe 100755 --- a/reactive-flows/cnf-combustion/gnns/data.py +++ b/reactive-flows/cnf-combustion/gnns/data.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import os -from typing import List, Tuple +from typing import Dict, List, Tuple import h5py import lightning as pl import networkx as nx import numpy as np -import torch import torch_geometric as pyg import yaml +from torch import float as tfloat +from torch import tensor from torch.utils.data import random_split @@ -42,6 +44,7 @@ def __init__(self, root: str, y_normalizer: float = None) -> None: y_normalizer (str): normalizing value """ self.y_normalizer = y_normalizer + self.graph_topology = None super().__init__(root) @property @@ -71,14 +74,38 @@ def download(self) -> None: f"and move all files in file.tgz/DATA in {self.raw_dir}" ) + def _get_data(self, idx: int) -> Dict[str, np.array]: + """Return the dict of the feat and sigma of the corresponding data file. + + Returns: + (Dict[str, np.array]): the feat and sigma. + """ + raise NotImplementedError + def get(self, idx: int) -> pyg.data.Data: """Return the graph at the given index. Returns: (pyg.data.Data): Graph at the given index. """ - data = torch.load(os.path.join(self.processed_dir, f"data-{idx}.pt")) - return data + pyg_data = copy.copy(self.graph_topology) + data = self._get_data(idx) + pyg_data.x = tensor(data["feat"].reshape(-1, 1), dtype=tfloat) + pyg_data.y = tensor(data["sigma"].reshape(-1, 1), dtype=tfloat) + return pyg_data + + def create_graph_topo(self, grid_shape: Tuple[int, int, int]) -> None: + """Create the graph topology and store it in memory. + + Args: + grid_shape (Tuple[int, int, int]): the shape of the grid for the + z, y and x sorted dimensions. 
+ """ + g0 = nx.grid_graph(dim=grid_shape) + self.graph_topology = pyg.utils.convert.from_networkx(g0) + coordinates = list(g0.nodes()) + coordinates.reverse() + self.graph_topology.pos = tensor(np.stack(coordinates)) def len(self) -> int: """Return the total length of the dataset @@ -106,34 +133,27 @@ def process(self) -> None: Create a graph for each volume of data, and saves each graph in a separate file index by the order in the raw file names list. """ - i = 0 - for raw_path in self.raw_paths: - with h5py.File(raw_path, "r") as file: - feat = file["/c_filt"][:] - - sigma = file["/c_grad_filt"][:] - if self.y_normalizer: - sigma /= self.y_normalizer - - x_size, y_size, z_size = feat.shape - - grid_shape = (z_size, y_size, x_size) - - g0 = nx.grid_graph(dim=grid_shape) - graph = pyg.utils.convert.from_networkx(g0) - undirected_index = graph.edge_index - coordinates = list(g0.nodes()) - coordinates.reverse() - - data = pyg.data.Data( - x=torch.tensor(feat.reshape(-1, 1), dtype=torch.float), - edge_index=undirected_index.clone().detach().type(torch.LongTensor), - pos=torch.tensor(np.stack(coordinates)), - y=torch.tensor(sigma.reshape(-1, 1), dtype=torch.float), - ) + # Create graph from first file + with h5py.File(self.raw_paths[0], "r") as file: + feat = file["/c_filt"][:] + x_size, y_size, z_size = feat.shape + grid_shape = (z_size, y_size, x_size) + self.create_graph_topo(grid_shape) + + def _get_data(self, idx: int) -> Dict[str, np.array]: + """Return the dict of the feat and sigma of the corresponding data file. + + Returns: + (Dict[str, np.array]): the feat and sigma. 
+ """ + data = {} + with h5py.File(self.raw_paths[idx], "r") as file: + data["feat"] = file["/c_filt"][:] - torch.save(data, os.path.join(self.processed_dir, f"data-{i}.pt")) - i += 1 + data["sigma"] = file["/c_grad_filt"][:] + if self.y_normalizer: + data["sigma"] /= self.y_normalizer + return data class CnfDataset(CombustionDataset): @@ -153,33 +173,27 @@ def process(self) -> None: Create a graph for each volume of data, and saves each graph in a separate file index by the order in the raw file names list. """ - i = 0 - for raw_path in self.raw_paths: - with h5py.File(raw_path, "r") as file: - feat = file["/filt_8"][:] - - sigma = file["/filt_grad_8"][:] - if self.y_normalizer is not None: - sigma /= self.y_normalizer - - x_size, y_size, z_size = feat.shape - grid_shape = (z_size, y_size, x_size) - - g0 = nx.grid_graph(dim=grid_shape) - graph = pyg.utils.convert.from_networkx(g0) - undirected_index = graph.edge_index - coordinates = list(g0.nodes()) - coordinates.reverse() - - data = pyg.data.Data( - x=torch.tensor(feat.reshape(-1, 1), dtype=torch.float), - edge_index=undirected_index.type(torch.LongTensor), - pos=torch.tensor(np.stack(coordinates)), - y=torch.tensor(sigma.reshape(-1, 1), dtype=torch.float), - ) + # Create graph from first file + with h5py.File(self.raw_paths[0], "r") as file: + feat = file["/filt_8"][:] + x_size, y_size, z_size = feat.shape + grid_shape = (z_size, y_size, x_size) + self.create_graph_topo(grid_shape) - torch.save(data, os.path.join(self.processed_dir, f"data-{i}.pt")) - i += 1 + def _get_data(self, idx: int) -> Dict[str, np.array]: + """Return the dict of the feat and sigma of the corresponding data file. + + Returns: + (Dict[str, np.array]): the feat and sigma. 
+ """ + data = {} + with h5py.File(self.raw_paths[idx], "r") as file: + data["feat"] = file["/filt_8"][:] + + data["sigma"] = file["/filt_grad_8"][:] + if self.y_normalizer: + data["sigma"] /= self.y_normalizer + return data class LitCombustionDataModule(pl.LightningDataModule): @@ -222,6 +236,11 @@ def __init__( self.test_dataset = None self.train_dataset = None + @property + def dataset_class(self) -> pyg.data.Dataset: + # Set here the Dataset class you want to use in the datamodule + return NotImplementedError + def prepare_data(self) -> None: """Not used.""" CombustionDataset(self.data_path, self.y_normalizer) @@ -243,7 +262,9 @@ def setup( if self.source_raw_data_path: LinkRawData(self.source_raw_data_path, self.data_path) - dataset = R2Dataset(self.data_path, y_normalizer=self.y_normalizer) + dataset = self.dataset_class( + self.data_path, y_normalizer=self.y_normalizer + ).shuffle() tr, va, te = self.splitting_ratios if (tr + va + te) != 1: @@ -351,3 +372,19 @@ def rm_old_dataset(self): os.rmdir(file_location) else: pass + + +class R2DataModule(LitCombustionDataModule): + """Data module that loads data using R2Dataset.""" + + @property + def dataset_class(self) -> pyg.data.Dataset: + return R2Dataset + + +class CnfDataModule(LitCombustionDataModule): + """Data module that loads data using CnfDataset.""" + + @property + def dataset_class(self) -> pyg.data.Dataset: + return CnfDataset diff --git a/reactive-flows/cnf-combustion/gnns/models.py b/reactive-flows/cnf-combustion/gnns/models.py index f6a0907..ca07bed 100755 --- a/reactive-flows/cnf-combustion/gnns/models.py +++ b/reactive-flows/cnf-combustion/gnns/models.py @@ -126,9 +126,9 @@ def on_test_epoch_end(self) -> None: self.y_hats = self.all_gather(y_hats) # Reshape the outputs to the original grid shape plus the batch dimension - self.ys = self.ys.squeeze().view((-1,) + self.grid_shape).detach().numpy() + self.ys = self.ys.squeeze().view((-1,) + self.grid_shape).detach().cpu().numpy() self.y_hats = ( - 
self.y_hats.squeeze().view((-1,) + self.grid_shape).detach().numpy() + self.y_hats.squeeze().view((-1,) + self.grid_shape).detach().cpu().numpy() ) plots_path = os.path.join(self.trainer.log_dir, "plots") diff --git a/reactive-flows/cnf-combustion/gnns/tests/test_data.py b/reactive-flows/cnf-combustion/gnns/tests/test_data.py index e975246..13d62a2 100644 --- a/reactive-flows/cnf-combustion/gnns/tests/test_data.py +++ b/reactive-flows/cnf-combustion/gnns/tests/test_data.py @@ -22,8 +22,9 @@ import numpy as np import torch import yaml +from torch import LongTensor, Tensor -from data import CnfDataset, LinkRawData, LitCombustionDataModule +from data import CnfDataModule, CnfDataset, LinkRawData class TestData(unittest.TestCase): @@ -99,10 +100,12 @@ def test_process(self): self.assertTrue(os.path.exists(os.path.join(tempdir, "data", "processed"))) - # insert +2 to have transform and filter files - self.assertEqual( - len(os.listdir(os.path.join(tempdir, "data", "processed"))), - len(self.filenames) + 2, + # Check the pyg.data.Data object has edge_index and pos + self.assertTrue( + isinstance(data_test.graph_topology.edge_index, LongTensor), + ) + self.assertTrue( + isinstance(data_test.graph_topology.pos, Tensor), ) def test_get(self): @@ -122,7 +125,7 @@ def test_setup(self): init_param = copy(self.init_param) init_param.update({"data_path": os.path.join(tempdir, "data")}) - dataset_test = LitCombustionDataModule(**init_param) + dataset_test = CnfDataModule(**init_param) with self.assertRaises(ValueError) as context: dataset_test.setup(stage=None) @@ -143,7 +146,7 @@ def test_train_dataloader(self): init_param = copy(self.init_param) init_param.update({"data_path": os.path.join(tempdir, "data")}) - dataset_test = LitCombustionDataModule(**init_param) + dataset_test = CnfDataModule(**init_param) with self.assertRaises(ValueError): _ = dataset_test.setup(stage=None) @@ -159,7 +162,7 @@ def test_val_dataloader(self): init_param = copy(self.init_param) 
init_param.update({"data_path": os.path.join(tempdir, "data")}) - dataset_test = LitCombustionDataModule(**init_param) + dataset_test = CnfDataModule(**init_param) with self.assertRaises(ValueError): _ = dataset_test.setup(stage=None) @@ -175,7 +178,7 @@ def test_test_dataloader(self): init_param = copy(self.init_param) init_param.update({"data_path": os.path.join(tempdir, "data")}) - dataset_test = LitCombustionDataModule(**init_param) + dataset_test = CnfDataModule(**init_param) with self.assertRaises(ValueError): _ = dataset_test.setup(stage=None)