Skip to content

Commit

Permalink
handles multi material features
Browse files Browse the repository at this point in the history
  • Loading branch information
yjchoi1 committed Oct 23, 2024
1 parent b33f561 commit 166470c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ mode: train
# Data configuration
data:
path: ../gns-sample/WaterDropSample/dataset/
meta_data: metadata.json
batch_size: 2
noise_std: 6.7e-4
input_sequence_length: 6
Expand Down
14 changes: 9 additions & 5 deletions gns/learned_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def _encoder_preprocessor(
nparticles_per_example: Number of particles per example. Default is 2
examples per batch.
particle_types: Particle types with shape (nparticles).
material_property: Friction angle normalized by tan() with shape (nparticles)
material_property: Friction angle normalized by tan() with shape (nparticles, ).
Optionally, it can take multi material properties like (nparticles, n_material_properties)
"""
nparticles = position_sequence.shape[0]
most_recent_position = position_sequence[:, -1] # (n_nodes, 2)
Expand Down Expand Up @@ -187,10 +188,12 @@ def _encoder_preprocessor(

# Material property
if material_property is not None:
material_property = material_property.view(nparticles, 1)
n_material_props = 1 if len(material_property.shape) == 1 else material_property.shape[-1]
material_property = material_property.view(nparticles, n_material_props)
node_features.append(material_property)
# Final node_features shape (nparticles, 31) for 2D
# 31 = 10 (5 velocity sequences*dim) + 4 boundaries + 16 particle embedding + 1 material property
# Final node_features shape (nparticles, 30 + n_material_props) for 2D
# 30 + n_material_props =
# 10 (5 velocity sequences*dim) + 4 boundaries + 16 particle embedding + n_material_props

# Collect edge features.
edge_features = []
Expand Down Expand Up @@ -267,7 +270,8 @@ def predict_positions(
nparticles_per_example: Number of particles per example. Default is 2
examples per batch.
particle_types: Particle types with shape (nparticles).
material_property: Friction angle normalized by tan() with shape (nparticles)
material_property: Friction angle normalized by tan() with shape (nparticles, ).
Optionally, it can take multi material properties like (nparticles, n_material_properties)
Returns:
next_positions (torch.tensor): Next position of particles.
Expand Down
7 changes: 1 addition & 6 deletions gns/particle_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def _get_sample(self, idx):
n_particles_per_example = positions.shape[0]

if self.material_property_as_feature:
material_property = np.full(
positions.shape[0], self.data[trajectory_idx][2], dtype=float
)
material_property = self.data[trajectory_idx][2]
features = (
positions,
particle_type,
Expand All @@ -95,9 +93,6 @@ def _get_trajectory(self, idx):
positions, particle_type, material_property = self.data[idx]
positions = np.transpose(positions, (1, 0, 2))
particle_type = np.full(positions.shape[0], particle_type, dtype=int)
material_property = np.full(
positions.shape[0], material_property, dtype=float
)
n_particles_per_example = positions.shape[0]

trajectory = (
Expand Down

0 comments on commit 166470c

Please sign in to comment.