diff --git a/examples/demo.py b/examples/demo.py index ff589d0..135017e 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -67,14 +67,13 @@ writer.export_scalars_to_json("./all_scalars.json") dataset = datasets.MNIST('mnist', train=False, download=True) -images = dataset.test_data[:100].float() -label = dataset.test_labels[:100] +images = dataset.data[:100].float() +label = dataset.targets[:100] features = images.view(100, 784) writer.add_embedding(features, metadata=label, label_img=images.unsqueeze(1)) writer.add_embedding(features, global_step=1, tag='noMetadata') -dataset = datasets.MNIST('mnist', train=True, download=True) -images_train = dataset.train_data[:100].float() -labels_train = dataset.train_labels[:100] +images_train = dataset.data[100:200].float() +labels_train = dataset.targets[100:200] features_train = images_train.view(100, 784) all_features = torch.cat((features, features_train)) @@ -87,7 +86,7 @@ metadata_header=['digit', 'dataset'], global_step=2) # VIDEO -vid_images = dataset.train_data[:16 * 48] +vid_images = dataset.data[:16 * 48] vid = vid_images.view(16, 48, 1, 28, 28) # BxTxCxHxW writer.add_video('video', vid_tensor=vid)