diff --git a/README.md b/README.md
index 07f8732..502aca4 100644
--- a/README.md
+++ b/README.md
@@ -86,6 +86,7 @@ mode: train
 # Data configuration
 data:
   path: ../gns-sample/WaterDropSample/dataset/
+  meta_data: metadata.json
   batch_size: 2
   noise_std: 6.7e-4
   input_sequence_length: 6
diff --git a/config.yaml b/config.yaml
index 8f1e9ba..a2bb607 100644
--- a/config.yaml
+++ b/config.yaml
@@ -14,6 +14,7 @@ mode: train
 # Data configuration
 data:
   path: ../gns-sample/WaterDropSample/dataset/
+  meta_data: metadata.json
   batch_size: 2
   noise_std: 6.7e-4
   input_sequence_length: 6
diff --git a/gns/learned_simulator.py b/gns/learned_simulator.py
index 8ef5a63..1f5043d 100644
--- a/gns/learned_simulator.py
+++ b/gns/learned_simulator.py
@@ -136,7 +136,8 @@ def _encoder_preprocessor(
           nparticles_per_example: Number of particles per example. Default is 2
             examples per batch.
           particle_types: Particle types with shape (nparticles).
-          material_property: Friction angle normalized by tan() with shape (nparticles)
+          material_property: Friction angle normalized by tan() with shape (nparticles,).
+            Optionally, it can take multiple material properties with shape (nparticles, n_material_properties)
         """
         nparticles = position_sequence.shape[0]
         most_recent_position = position_sequence[:, -1]  # (n_nodes, 2)
@@ -187,10 +188,12 @@ def _encoder_preprocessor(
 
         # Material property
         if material_property is not None:
-            material_property = material_property.view(nparticles, 1)
+            n_material_props = 1 if len(material_property.shape) == 1 else material_property.shape[-1]
+            material_property = material_property.view(nparticles, n_material_props)
             node_features.append(material_property)
-        # Final node_features shape (nparticles, 31) for 2D
-        # 31 = 10 (5 velocity sequences*dim) + 4 boundaries + 16 particle embedding + 1 material property
+        # Final node_features shape (nparticles, 30 + n_material_props) for 2D
+        # 30 + n_material_props =
+        # 10 (5 velocity sequences*dim) + 4 boundaries + 16 particle embedding + n_material_props
 
         # Collect edge features.
         edge_features = []
@@ -267,7 +270,8 @@ def predict_positions(
           nparticles_per_example: Number of particles per example. Default is 2
             examples per batch.
           particle_types: Particle types with shape (nparticles).
-          material_property: Friction angle normalized by tan() with shape (nparticles)
+          material_property: Friction angle normalized by tan() with shape (nparticles,).
+            Optionally, it can take multiple material properties with shape (nparticles, n_material_properties)
 
         Returns:
           next_positions (torch.tensor): Next position of particles.
diff --git a/gns/particle_data_loader.py b/gns/particle_data_loader.py
index e0e0c4f..9495742 100644
--- a/gns/particle_data_loader.py
+++ b/gns/particle_data_loader.py
@@ -74,9 +74,7 @@ def _get_sample(self, idx):
         n_particles_per_example = positions.shape[0]
 
         if self.material_property_as_feature:
-            material_property = np.full(
-                positions.shape[0], self.data[trajectory_idx][2], dtype=float
-            )
+            material_property = self.data[trajectory_idx][2]
             features = (
                 positions,
                 particle_type,
@@ -95,9 +93,6 @@ def _get_trajectory(self, idx):
         positions, particle_type, material_property = self.data[idx]
         positions = np.transpose(positions, (1, 0, 2))
         particle_type = np.full(positions.shape[0], particle_type, dtype=int)
-        material_property = np.full(
-            positions.shape[0], material_property, dtype=float
-        )
         n_particles_per_example = positions.shape[0]
 
         trajectory = (
diff --git a/gns/train.py b/gns/train.py
index a730e75..7d3c291 100644
--- a/gns/train.py
+++ b/gns/train.py
@@ -115,7 +115,7 @@ def predict(device: str, cfg: DictConfig):
     """
 
     # Read metadata
-    metadata = reading_utils.read_metadata(cfg.data.path, "rollout")
+    metadata = reading_utils.read_metadata(cfg.data.path, "rollout", cfg.data.meta_data)
     simulator = _get_simulator(
         metadata,
         cfg.data.num_particle_types,
@@ -200,8 +199,7 @@
             if cfg.mode == "rollout":
                 example_rollout["metadata"] = metadata
                 example_rollout["loss"] = loss.mean()
-                filename = f"{cfg.output.filename}_ex{example_i}.pkl"
-                filename_render = f"{cfg.output.filename}_ex{example_i}"
+                filename_render = f"{cfg.output.filename}_ex{example_i}.pkl"
                 filename = os.path.join(cfg.output.path, filename_render)
                 with open(filename, "wb") as f:
                     pickle.dump(example_rollout, f)
@@ -377,7 +376,7 @@ def initialize_training(cfg, rank, world_size, device, use_dist):
         device: torch device type.
         use_dist: use torch.distribute
     """
-    metadata = reading_utils.read_metadata(cfg.data.path, "train")
+    metadata = reading_utils.read_metadata(cfg.data.path, "train", cfg.data.meta_data)
    simulator, optimizer = setup_simulator_and_optimizer(
        cfg, metadata, rank, world_size, device, use_dist
    )