Skip to content

Commit

Permalink
Prepare data and get mat ids
Browse files Browse the repository at this point in the history
  • Loading branch information
kks32 committed Jul 15, 2024
1 parent 901e45b commit 465c470
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,6 @@ def setup_tensorboard(cfg, metadata):
writer.add_hparams(hparam_dict, metric_dict)
return writer


def prepare_data(example, device_id):
features, labels = example

Expand All @@ -421,13 +420,21 @@ def prepare_data(example, device_id):
position, particle_type, n_particles_per_example = features
material_property = None

# Convert numpy arrays to tensors
position = torch.from_numpy(position).float().to(device_id)
particle_type = torch.from_numpy(particle_type).long().to(device_id)
# Function to convert to tensor if needed and move to device
def to_tensor(x, dtype=torch.float):
if isinstance(x, np.ndarray):
return torch.from_numpy(x).to(dtype).to(device_id)
elif isinstance(x, torch.Tensor):
return x.to(dtype).to(device_id)
else:
return torch.tensor(x, dtype=dtype, device=device_id)

position = to_tensor(position)
particle_type = to_tensor(particle_type, dtype=torch.long)
if material_property is not None:
material_property = torch.from_numpy(np.array(material_property)).float().to(device_id)
n_particles_per_example = torch.tensor([n_particles_per_example], device=device_id).long()
labels = torch.from_numpy(labels).float().to(device_id)
material_property = to_tensor(material_property)
n_particles_per_example = to_tensor(n_particles_per_example, dtype=torch.long)
labels = to_tensor(labels)

return position, particle_type, material_property, n_particles_per_example, labels

Expand Down Expand Up @@ -1129,14 +1136,18 @@ def get_batch_for_material(train_dl, target_material_id, device_id):
return dataset[random_idx]

def get_unique_material_ids(train_dl, device_id):
"""
Get a list of unique material IDs from the dataloader.
"""
unique_ids = set()
for batch in train_dl:
_, _, material_property, _, _ = prepare_data(batch, device_id)

if material_property.numel() == 1:
unique_ids.add(material_property.item())
else:
unique_ids.update(material_property.cpu().numpy())
if material_property is not None:
if material_property.numel() == 1:
unique_ids.add(material_property.item())
else:
unique_ids.update(material_property.cpu().numpy())

return list(unique_ids)

Expand Down

0 comments on commit 465c470

Please sign in to comment.