Skip to content

Commit

Permalink
Fix prepare data
Browse files Browse the repository at this point in the history
  • Loading branch information
kks32 committed Jul 15, 2024
1 parent edc9638 commit 901e45b
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,24 +413,26 @@ def setup_tensorboard(cfg, metadata):


def prepare_data(example, device_id):
"""Prepare data for training or validation."""
position = example[0][0].to(device_id)
particle_type = example[0][1].to(device_id)

if len(example[0]) == 4: # if data loader includes material_property
material_property = example[0][2].to(device_id)
n_particles_per_example = example[0][3].to(device_id)
elif len(example[0]) == 3:
material_property = None
n_particles_per_example = example[0][2].to(device_id)
features, labels = example

if len(features) == 4: # If material property is present
position, particle_type, material_property, n_particles_per_example = features
else:
raise ValueError("Unexpected number of elements in the data loader")

labels = example[1].to(device_id)

position, particle_type, n_particles_per_example = features
material_property = None

# Convert numpy arrays to tensors
position = torch.from_numpy(position).float().to(device_id)
particle_type = torch.from_numpy(particle_type).long().to(device_id)
if material_property is not None:
material_property = torch.from_numpy(np.array(material_property)).float().to(device_id)
n_particles_per_example = torch.tensor([n_particles_per_example], device=device_id).long()
labels = torch.from_numpy(labels).float().to(device_id)

return position, particle_type, material_property, n_particles_per_example, labels



def train(rank, cfg, world_size, device, verbose, use_dist):
"""Train the model.
Expand Down

0 comments on commit 901e45b

Please sign in to comment.