diff --git a/gns/train.py b/gns/train.py index 1ce329d..9498ba2 100644 --- a/gns/train.py +++ b/gns/train.py @@ -411,7 +411,6 @@ def setup_tensorboard(cfg, metadata): writer.add_hparams(hparam_dict, metric_dict) return writer - def prepare_data(example, device_id): features, labels = example @@ -421,13 +420,21 @@ def prepare_data(example, device_id): position, particle_type, n_particles_per_example = features material_property = None - # Convert numpy arrays to tensors - position = torch.from_numpy(position).float().to(device_id) - particle_type = torch.from_numpy(particle_type).long().to(device_id) + # Function to convert to tensor if needed and move to device + def to_tensor(x, dtype=torch.float): + if isinstance(x, np.ndarray): + return torch.from_numpy(x).to(dtype).to(device_id) + elif isinstance(x, torch.Tensor): + return x.to(dtype).to(device_id) + else: + return torch.tensor(x, dtype=dtype, device=device_id) + + position = to_tensor(position) + particle_type = to_tensor(particle_type, dtype=torch.long) if material_property is not None: - material_property = torch.from_numpy(np.array(material_property)).float().to(device_id) - n_particles_per_example = torch.tensor([n_particles_per_example], device=device_id).long() - labels = torch.from_numpy(labels).float().to(device_id) + material_property = to_tensor(material_property) + n_particles_per_example = to_tensor(n_particles_per_example, dtype=torch.long) + labels = to_tensor(labels) return position, particle_type, material_property, n_particles_per_example, labels @@ -1129,14 +1136,18 @@ def get_batch_for_material(train_dl, target_material_id, device_id): return dataset[random_idx] def get_unique_material_ids(train_dl, device_id): + """ + Get a list of unique material IDs from the dataloader. + """ unique_ids = set() for batch in train_dl: _, _, material_property, _, _ = prepare_data(batch, device_id) - if material_property.numel() == 1: - unique_ids.add(material_property.item()) - else: - unique_ids.update(material_property.cpu().numpy()) + if material_property is not None: + if material_property.numel() == 1: + unique_ids.add(material_property.item()) + else: + unique_ids.update(material_property.cpu().numpy()) return list(unique_ids)