Added Pre-trained-colab #1

Open · wants to merge 8 commits into base: main
2,452 changes: 2,452 additions & 0 deletions Sample_pretrained_chair_3dsnet.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions auxiliary/sampling_and_meshing/O-CNN/randomizePointCloud.py
@@ -1,7 +1,8 @@
import argparse
from os import listdir
from os.path import isfile, join
import pymesh
# import pymesh
import trimesh
import numpy as np
import copy
import joblib
@@ -15,7 +16,7 @@
def shuffle_pc(file, output_path, limit=None, count=None):
try:
if not os.path.exists(output_path + ".npy"):
mesh = pymesh.load_mesh(file)
mesh = trimesh.load_mesh(file)

vertices = copy.deepcopy(mesh.vertices)
permutation = np.random.permutation(len(vertices))
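Note on the pymesh → trimesh swaps in this file and in the loaders below: trimesh.load_mesh() is not a drop-in replacement in every case, since trimesh may parse a file as a Trimesh, a PointCloud (typical for vertex-only scanner outputs like these), or a Scene. A defensive wrapper, as a minimal sketch (the load_vertices helper is illustrative, not part of this PR):

```python
import numpy as np
import trimesh

def load_vertices(path):
    """Return an (N, 3) vertex array for meshes, point clouds, or scenes."""
    loaded = trimesh.load(path)
    if isinstance(loaded, trimesh.Scene):
        # Flatten a multi-geometry scene into a single mesh first.
        loaded = trimesh.util.concatenate(tuple(loaded.geometry.values()))
    # Trimesh and PointCloud objects both expose .vertices.
    return np.asarray(loaded.vertices, dtype=np.float64)
```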
1 change: 0 additions & 1 deletion auxiliary/sampling_and_meshing/O-CNN/sample30kpoints.py
@@ -1,7 +1,6 @@
import argparse
from os import listdir
from os.path import isfile, join
import pymesh
import numpy as np
import copy
import joblib
9 changes: 5 additions & 4 deletions auxiliary/sampling_and_meshing/Shuffle/parallel_shuffle.py
@@ -4,7 +4,8 @@
import argparse
from os import listdir
from os.path import isfile, join
import pymesh
# import pymesh
import trimesh
import numpy as np
import copy
import joblib
@@ -13,18 +14,18 @@


def shuffle_pc(file, output_path):
mesh = pymesh.load_mesh(file)
mesh = trimesh.load_mesh(file)
vertices = copy.deepcopy(mesh.vertices)
permutation = np.random.permutation(len(vertices))
vertices = vertices[permutation]
new_mesh = pymesh.meshio.form_mesh(vertices, mesh.faces)
new_mesh = trimesh.Trimesh(vertices, faces=mesh.faces)
# pymesh's add_attribute/set_attribute API has no trimesh equivalent; assuming
# the vertex_nx/ny/nz attributes are per-vertex normals, permuting the stored
# normal array keeps them aligned with the shuffled vertices.
new_mesh.vertex_normals = mesh.vertex_normals[permutation]
pymesh.save_mesh(output_path, new_mesh, ascii=True, anonymous=True, use_float=True, *new_mesh.get_attribute_names())
new_mesh.export(output_path)


def main():
9 changes: 5 additions & 4 deletions auxiliary/sampling_and_meshing/Shuffle/shuffle.py
@@ -4,7 +4,8 @@
import argparse
from os import listdir
from os.path import isfile, join
import pymesh
# import pymesh
import trimesh
import numpy as np
import copy

@@ -13,18 +14,18 @@ def shuffle_pc(file, output_path):
"""
Function to shuffle a point cloud produced by virtual scanner.
"""
mesh = pymesh.load_mesh(file)
mesh = trimesh.load_mesh(file)
vertices = copy.deepcopy(mesh.vertices)
permutation = np.random.permutation(len(vertices))
vertices = vertices[permutation]
new_mesh = pymesh.meshio.form_mesh(vertices, mesh.faces)
new_mesh = trimesh.Trimesh(vertices, faces=mesh.faces)
# pymesh's add_attribute/set_attribute API has no trimesh equivalent; assuming
# the vertex_nx/ny/nz attributes are per-vertex normals, permuting the stored
# normal array keeps them aligned with the shuffled vertices.
new_mesh.vertex_normals = mesh.vertex_normals[permutation]
pymesh.save_mesh(output_path, new_mesh, ascii=True, anonymous=True, use_float=True, *new_mesh.get_attribute_names())
new_mesh.export(output_path)


def main():
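Both shuffle scripts now round-trip through trimesh, so a quick migration check is that the exported cloud holds exactly the same points as the input, only reordered. A minimal sketch with hypothetical file names:

```python
import numpy as np
import trimesh

def sorted_rows(a):
    a = np.asarray(a)
    return a[np.lexsort(a.T)]  # deterministic row order for comparison

orig = trimesh.load('scan.ply')           # hypothetical input file
shuf = trimesh.load('scan_shuffled.ply')  # output written by shuffle_pc
print('point set preserved:',
      np.array_equal(sorted_rows(orig.vertices), sorted_rows(shuf.vertices)))
```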
211 changes: 211 additions & 0 deletions dataset/dataset_caesar.py
@@ -0,0 +1,211 @@
from copy import deepcopy
from easydict import EasyDict
import numpy as np
import os
import pickle
from termcolor import colored
import torch
import torch.utils.data as data

import auxiliary.my_utils as my_utils
import dataset.pointcloud_processor as pointcloud_processor


class Caesar(data.Dataset):
"""
Caesar Dataloader
Uses caesar dataset
Make sure to respect dataset Licence.
"""

def __init__(self, opt, unused_category, subcategory, unused_svr=False, train=True):
self.opt = opt
self.num_sample = opt.number_points if train else 2500

self.train = train
self.mode = 'training' if train else 'validation'
self.subcategory = subcategory

self.id2names = {0: 'person'}
self.names2id = {'person': 0}

# Initialize pointcloud normalization functions
self.init_normalization()

if not opt.demo or not opt.use_default_demo_samples:
if len(opt.class_choice) > 0 and len(opt.class_choice) == 2:
print('Initializing {} dataset for class {}.'.format(
self.mode, subcategory))
else:
raise ValueError('Argument class_choice must contain exactly two classes.')

my_utils.red_print('Create Caesar Dataset...')
# Define core path array
self.dataset_path = os.path.join(opt.data_dir, 'Caesar')

# Create Cache path
self.path_dataset = os.path.join(self.dataset_path, 'cache')
if not os.path.exists(self.path_dataset):
os.makedirs(self.path_dataset)
self.path_dataset = os.path.join(self.path_dataset, '_'.join((self.opt.normalization, self.mode)))
self.cache_file = self.path_dataset + self.subcategory + '_info.pkl'

if not os.path.exists(self.cache_file):
# Compile list of pointcloud path by selected categories
dir_pointcloud = os.path.join(self.dataset_path, self.subcategory, self.mode)
list_pointcloud = sorted(os.listdir(dir_pointcloud))
print(
' subcategory '
+ colored(self.names2id[self.subcategory], 'yellow')
+ ' '
+ colored(self.subcategory, 'cyan')
+ ' Number Files: '
+ colored(str(len(list_pointcloud)), 'yellow')
)

if len(list_pointcloud) != 0:
self.datapath = []
for pointcloud in list_pointcloud:
pointcloud_path = os.path.join(dir_pointcloud, pointcloud)
self.datapath.append((pointcloud_path, pointcloud, self.subcategory))

# Preprocess and cache files
self.preprocess()

def preprocess(self):
if os.path.exists(self.cache_file):
# Reload dataset
my_utils.red_print('Reload dataset : {}'.format(self.cache_file))
with open(self.cache_file, 'rb') as fp:
self.data_metadata = pickle.load(fp)

self.data_points = torch.load(self.path_dataset + self.subcategory + '_points.pth')
else:
# Preprocess dataset and put in cache for future fast reload
my_utils.red_print('Preprocess dataset...')
self.datas = [self._getitem(i) for i in range(len(self.datapath))]

# Concatenate all processed files
self.data_points = [data[0] for data in self.datas]
# TODO(msegu): consider adding option to randomly select num_samples if we want to train with less samples
self.data_points = torch.cat(self.data_points, 0)

self.data_metadata = [{'pointcloud_path': data[1], 'name': data[2], 'subcategory': data[3]}
for data in self.datas]

# Save in cache
with open(self.cache_file, 'wb') as fp: # Pickling
pickle.dump(self.data_metadata, fp)
torch.save(self.data_points, self.path_dataset + self.subcategory + '_points.pth')

my_utils.red_print('Dataset Size: ' + str(len(self.data_metadata)))

def init_normalization(self):
if not self.opt.demo:
my_utils.red_print('Dataset normalization : ' + self.opt.normalization)

if self.opt.normalization == 'UnitBall':
self.normalization_function = pointcloud_processor.Normalization.normalize_unitL2ball_functional
elif self.opt.normalization == 'BoundingBox':
self.normalization_function = pointcloud_processor.Normalization.normalize_bounding_box_functional
else:
self.normalization_function = pointcloud_processor.Normalization.identity_functional

def _getitem(self, index):

pointcloud_path, pointcloud, subcategory = self.datapath[index]
points = self.load(pointcloud_path)['points'][0]
points[:, :3] = self.normalization_function(points[:, :3])
return points.unsqueeze(0), pointcloud_path, pointcloud, subcategory

def __getitem__(self, index):
return_dict = deepcopy(self.data_metadata[index])
# Point processing
points = self.data_points[index]
points = points.clone()
if self.opt.sample:
choice = np.random.choice(points.size(0), self.num_sample, replace=True)
points = points[choice, :]
points = points[:, :3].contiguous()

return_dict = {'points': points,
'pointcloud_path': return_dict['pointcloud_path'],
'subcategory': return_dict['subcategory']}
return return_dict

def __len__(self):
return len(self.data_metadata)

@staticmethod
def int2str(N):
if N < 10:
return '0' + str(N)
else:
return str(N)

def load(self, path):
ext = path.split('.')[-1]
if ext == 'npy' or ext == 'ply' or ext == 'obj':
return self.load_point_input(path)
else:
raise IOError("File extension .{} not supported. Must be one of '.npy', '.ply' or '.obj'.".format(ext))

def load_point_input(self, path):
ext = path.split('.')[-1]
if ext == 'npy':
points = np.load(path)
elif ext == 'ply' or ext == 'obj':
# import pymesh
import trimesh
points = trimesh.load_mesh(path).vertices
else:
raise IOError("File extension .{} not supported. Must be one of '.npy', '.ply' or '.obj'.".format(ext))

points = torch.from_numpy(points.copy()).float()
operation = pointcloud_processor.Normalization(points, keep_track=True)
if self.opt.normalization == 'UnitBall':
operation.normalize_unitL2ball()
elif self.opt.normalization == 'BoundingBox':
operation.normalize_bounding_box()
else:
pass
return_dict = {
'points': points,
'operation': operation,
'path': path,
}
return return_dict


if __name__ == '__main__':
print('Testing Caesar dataset')
opt = EasyDict({'normalization': 'UnitBall', 'class_choice': ['cats', 'male'], 'sample': True, 'npoints': 2500,
'num_epochs': 5})
dataset_a = Caesar(opt, None, subcategory=opt.class_choice[0], train=False)
dataset_b = Caesar(opt, None, subcategory=opt.class_choice[1], train=False)

print(dataset_a[1])
a = len(dataset_a)
b = len(dataset_b)

# Check that random pairwise loading works as expected
dataloader_a = torch.utils.data.DataLoader(dataset_a, batch_size=1, shuffle=True)

dataloader_b = torch.utils.data.DataLoader(dataset_b, batch_size=1, shuffle=True)

for epoch in range(opt.num_epochs):
for i, (data_a, data_b) in enumerate(zip(dataloader_a, dataloader_b)):
if i == 2: break
data_a = EasyDict(data_a)
data_b = EasyDict(data_b)
print(data_a.pointcloud_path, data_b.pointcloud_path)
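For reference, the os.path.join calls above (data_dir/Caesar, then subcategory/mode, plus the cache directory) imply an on-disk layout roughly like the following; the subcategory names are placeholders, since the loader itself only hard-codes the person label:

```
<data_dir>/
└── Caesar/
    ├── <subcategory>/
    │   ├── training/      # *.npy, *.ply or *.obj point clouds
    │   └── validation/
    └── cache/             # auto-generated: <normalization>_<mode><subcategory>_info.pkl
                           #                 <normalization>_<mode><subcategory>_points.pth
```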
5 changes: 3 additions & 2 deletions dataset/dataset_shapenet.py
@@ -253,8 +253,9 @@ def load_point_input(self, path):
if ext == 'npy':
points = np.load(path)
elif ext == 'ply' or ext == 'obj':
import pymesh
points = pymesh.load_mesh(path).vertices
# import pymesh
import trimesh
points = trimesh.load_mesh(path).vertices
else:
print('invalid file extension')

5 changes: 3 additions & 2 deletions dataset/dataset_smxl.py
@@ -155,8 +155,9 @@ def load_point_input(self, path):
if ext == 'npy':
points = np.load(path)
elif ext == 'ply' or ext == 'obj':
import pymesh
points = pymesh.load_mesh(path).vertices
# import pymesh
import trimesh
points = trimesh.load_mesh(path).vertices
else:
print('invalid file extension')

3 changes: 1 addition & 2 deletions dataset/download_shapenet_pointclouds.sh
@@ -5,8 +5,7 @@ function gdrive_download () {
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$CONFIRM&id=$1" -O $2
rm -rf /tmp/cookies.txt
}
cd dataset
mkdir data

gdrive_download 1MMCYOqSalz77dduKahqDEQKFP9aCvUCy data/ShapeNetV1PointCloud.zip
cd data
unzip ShapeNetV1PointCloud.zip
3 changes: 1 addition & 2 deletions dataset/mesh_processor.py
@@ -1,4 +1,3 @@
import pymesh
import numpy as np
from os.path import join, dirname

@@ -34,4 +33,4 @@ def save(mesh, path, colormap):
mesh.set_attribute("vertex_blue", colormap.colormap[vertex_sources][:, 2])
except:
pass
pymesh.save_mesh(path[:-3] + "ply", mesh, *mesh.get_attribute_names(), ascii=True)
mesh.export(path[:-3] + "ply")
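The surrounding save() still sets vertex_red/green/blue through pymesh's attribute API, which the bare except silently skips once mesh is a trimesh object. In trimesh, per-vertex colors live on mesh.visual instead; a minimal sketch of the equivalent, assuming the colormap lookup yields an (N, 3) array as above:

```python
import numpy as np
import trimesh

def save_colored(mesh, path, vertex_colors):
    # trimesh keeps per-vertex RGB(A) on the visual object rather than in
    # named attributes; values are expected as uint8 in [0, 255].
    mesh.visual.vertex_colors = np.asarray(vertex_colors, dtype=np.uint8)
    mesh.export(path[:-3] + "ply")
```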
4 changes: 4 additions & 0 deletions dataset/trainer_dataset.py
@@ -1,6 +1,8 @@
import torch
import dataset.dataset_shapenet as dataset_shapenet
import dataset.dataset_smxl as dataset_smxl
import dataset.dataset_caesar as dataset_caesar
import dataset
import dataset.augmenter as augmenter
from easydict import EasyDict

@@ -16,6 +18,8 @@ def __init__(self, opt):
self.dataset_class = dataset_smxl.SMXL
elif opt.dataset == 'ShapeNet':
self.dataset_class = dataset_shapenet.ShapeNet
elif opt.dataset == 'Caesar':
self.dataset_class = dataset_caesar.Caesar

def build_dataset(self):
"""
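The dispatch in __init__ keys the loader class off opt.dataset. The same routing can be written as a lookup table; a sketch with names taken from the imports in this diff:

```python
import dataset.dataset_caesar as dataset_caesar
import dataset.dataset_shapenet as dataset_shapenet
import dataset.dataset_smxl as dataset_smxl

DATASET_CLASSES = {
    'SMXL': dataset_smxl.SMXL,
    'ShapeNet': dataset_shapenet.ShapeNet,
    'Caesar': dataset_caesar.Caesar,
}

def resolve_dataset_class(name):
    try:
        return DATASET_CLASSES[name]
    except KeyError:
        raise ValueError('Unknown dataset: {}'.format(name))
```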
7 changes: 4 additions & 3 deletions model/atlasnet.py
@@ -61,7 +61,7 @@ def forward(self, latent_vector, train=True):

def generate_mesh(self, latent_vector):
assert latent_vector.size(0)==1, "input should have batch size 1!"
import pymesh
# import pymesh
import trimesh
input_points = [self.template[i].get_regular_points(self.nb_pts_in_primitive, latent_vector.device)
for i in range(self.opt.nb_primitives)]
input_points = [input_points[i] for i in range(self.opt.nb_primitives)]
@@ -70,12 +70,13 @@ def generate_mesh(self, latent_vector):
output_points = [self.decoder[i](input_points[i], latent_vector.unsqueeze(2)).squeeze() for i in
range(0, self.opt.nb_primitives)]

output_meshes = [pymesh.form_mesh(vertices=output_points[i].transpose(1, 0).contiguous().cpu().numpy(),
output_meshes = [trimesh.Trimesh(vertices=output_points[i].transpose(1, 0).contiguous().cpu().numpy(),
faces=self.template[i].mesh.faces)
for i in range(self.opt.nb_primitives)]

# Deform return the deformed pointcloud
mesh = pymesh.merge_meshes(output_meshes)
mesh = trimesh.util.concatenate(output_meshes)
# mesh = pymesh.merge_meshes(output_meshes)

return mesh
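trimesh.util.concatenate is a close stand-in for pymesh.merge_meshes: it stacks the vertex arrays and re-indexes the faces of each part without welding anything. A standalone sketch, with boxes standing in for the decoded primitives:

```python
import trimesh

parts = [trimesh.creation.box(), trimesh.creation.box()]
merged = trimesh.util.concatenate(parts)

# Vertices are stacked and the second part's faces are re-indexed,
# so the counts are additive.
assert len(merged.vertices) == sum(len(p.vertices) for p in parts)
assert len(merged.faces) == sum(len(p.faces) for p in parts)
```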

2 changes: 0 additions & 2 deletions model/model.py
@@ -10,8 +10,6 @@
from model.meshflow import NeuralMeshFlow
import torch
import torch.nn as nn
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes


class StyleNetBase(nn.Module):