Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize data paths for any data source #23

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,20 @@ export PYTHONPATH=$PWD:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
```

- Train on the unified generator on cars, motorbikes or chair (Improved generator in
Appendix):
- Train the unified generator on cars, motorbikes, or chairs (improved generator in the Appendix):

```bash
python train_3d.py --outdir=PATH_TO_LOG --data=PATH_TO_RENDER_IMG --camera_path PATH_TO_RENDER_CAMERA --gpus=8 --batch=32 --gamma=40 --data_camera_mode shapenet_car --dmtet_scale 1.0 --use_shapenet_split 1 --one_3d_generator 1 --fp32 0
python train_3d.py --outdir=PATH_TO_LOG --data=PATH_TO_RENDER_IMG --camera_path PATH_TO_RENDER_CAMERA --gpus=8 --batch=32 --gamma=80 --data_camera_mode shapenet_motorbike --dmtet_scale 1.0 --use_shapenet_split 1 --one_3d_generator 1 --fp32 0
python train_3d.py --outdir=PATH_TO_LOG --data=PATH_TO_RENDER_IMG --camera_path PATH_TO_RENDER_CAMERA --gpus=8 --batch=32 --gamma=400 --data_camera_mode shapenet_chair --dmtet_scale 0.8 --use_shapenet_split 1 --one_3d_generator 1 --fp32 0
python train_3d.py --outdir=./logs --data=./shapenet/img/02958343 --camera_path=./shapenet/camera --gpus=8 --batch=32 --gamma=40 --manifest_dir shapenet_car --dmtet_scale 1.0 --one_3d_generator 1 --fp32 0
python train_3d.py --outdir=./logs --data=./shapenet/img/03790512 --camera_path=./shapenet/camera --gpus=8 --batch=32 --gamma=80 --manifest_dir shapenet_motorbike --dmtet_scale 1.0 --one_3d_generator 1 --fp32 0
python train_3d.py --outdir=./logs --data=./shapenet/img/03001627 --camera_path=./shapenet/camera --gpus=8 --batch=32 --gamma=400 --manifest_dir shapenet_chair --dmtet_scale 0.8 --one_3d_generator 1 --fp32 0
```

- If want to train on seperate generators (main Figure in the paper):
- If you want to train separate generators (main figure in the paper):

```bash
python train_3d.py --outdir=PATH_TO_LOG --data=PATH_TO_RENDER_IMG --camera_path PATH_TO_RENDER_CAMERA --gpus=8 --batch=32 --gamma=40 --data_camera_mode shapenet_car --dmtet_scale 1.0 --use_shapenet_split 1 --one_3d_generator 0
python train_3d.py --outdir=PATH_TO_LOG --data=PATH_TO_RENDER_IMG --camera_path PATH_TO_RENDER_CAMERA --gpus=8 --batch=32 --gamma=80 --data_camera_mode shapenet_motorbike --dmtet_scale 1.0 --use_shapenet_split 1 --one_3d_generator 0
python train_3d.py --outdir=PATH_TO_LOG --data=PATH_TO_RENDER_IMG --camera_path PATH_TO_RENDER_CAMERA --gpus=8 --batch=32 --gamma=3200 --data_camera_mode shapenet_chair --dmtet_scale 0.8 --use_shapenet_split 1 --one_3d_generator 0
python train_3d.py --outdir=./logs --data=./shapenet/img/02958343 --camera_path=./shapenet/camera --gpus=8 --batch=32 --gamma=40 --manifest_dir shapenet_car --dmtet_scale 1.0 --one_3d_generator 0
python train_3d.py --outdir=./logs --data=./shapenet/img/03790512 --camera_path=./shapenet/camera --gpus=8 --batch=32 --gamma=80 --manifest_dir shapenet_motorbike --dmtet_scale 1.0 --one_3d_generator 0
python train_3d.py --outdir=./logs --data=./shapenet/img/03001627 --camera_path=./shapenet/camera --gpus=8 --batch=32 --gamma=400 --manifest_dir shapenet_chair --dmtet_scale 0.8 --one_3d_generator 0
```

If you want to debug the model first, reduce the number of GPUs to 1 and the batch size to 4 via:
Expand All @@ -123,9 +122,9 @@ If want to debug the model first, reduce the number of gpus to 1 and batch size
- Inference can run on a single GPU with 16 GB of memory.

```bash
python train_3d.py --outdir=save_inference_results/shapenet_car --gpus=1 --batch=4 --gamma=40 --data_camera_mode shapenet_car --dmtet_scale 1.0 --use_shapenet_split 1 --one_3d_generator 1 --fp32 0 --inference_vis 1 --resume_pretrain MODEL_PATH
python train_3d.py --outdir=save_inference_results/shapenet_chair --gpus=1 --batch=4 --gamma=40 --data_camera_mode shapenet_chair --dmtet_scale 0.8 --use_shapenet_split 1 --one_3d_generator 1 --fp32 0 --inference_vis 1 --resume_pretrain MODEL_PATH
python train_3d.py --outdir=save_inference_results/shapenet_motorbike --gpus=1 --batch=4 --gamma=40 --data_camera_mode shapenet_motorbike --dmtet_scale 1.0 --use_shapenet_split 1 --one_3d_generator 1 --fp32 0 --inference_vis 1 --resume_pretrain MODEL_PATH
python train_3d.py --outdir=save_inference_results/shapenet_car --gpus=1 --batch=4 --gamma=40 --manifest_dir shapenet_car --dmtet_scale 1.0 --one_3d_generator 1 --fp32 0 --inference_vis 1 --resume_pretrain MODEL_PATH
python train_3d.py --outdir=save_inference_results/shapenet_chair --gpus=1 --batch=4 --gamma=40 --manifest_dir shapenet_chair --dmtet_scale 0.8 --one_3d_generator 1 --fp32 0 --inference_vis 1 --resume_pretrain MODEL_PATH
python train_3d.py --outdir=save_inference_results/shapenet_motorbike --gpus=1 --batch=4 --gamma=40 --manifest_dir shapenet_motorbike --dmtet_scale 1.0 --one_3d_generator 1 --fp32 0 --inference_vis 1 --resume_pretrain MODEL_PATH
```

- To generate mesh with textures, add one option to the inference
Expand Down
37 changes: 13 additions & 24 deletions train_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,24 +108,14 @@ def launch_training(c, desc, outdir, dry_run):
# ----------------------------------------------------------------------------
def init_dataset_kwargs(data, opt=None):
try:
if opt.use_shapenet_split:
dataset_kwargs = dnnlib.EasyDict(
class_name='training.dataset.ImageFolderDataset',
path=data, use_labels=True, max_size=None, xflip=False,
resolution=opt.img_res,
data_camera_mode=opt.data_camera_mode,
add_camera_cond=opt.add_camera_cond,
camera_path=opt.camera_path,
split='test' if opt.inference_vis else 'train',
)
else:
dataset_kwargs = dnnlib.EasyDict(
class_name='training.dataset.ImageFolderDataset',
path=data, use_labels=True, max_size=None, xflip=False, resolution=opt.img_res,
data_camera_mode=opt.data_camera_mode,
add_camera_cond=opt.add_camera_cond,
camera_path=opt.camera_path,
)
dataset_kwargs = dnnlib.EasyDict(
class_name='training.dataset.ImageFolderDataset',
path=data, use_labels=True, max_size=None, xflip=False, resolution=opt.img_res,
manifest_dir=opt.manifest_dir,
add_camera_cond=opt.add_camera_cond,
camera_path=opt.camera_path,
split='test' if opt.inference_vis else 'train',
)
dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
dataset_kwargs.camera_path = opt.camera_path
dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
Expand Down Expand Up @@ -168,8 +158,7 @@ def parse_comma_separated_list(s):
@click.option('--data', help='Path to the Training data Images', metavar='[DIR]', type=str, default='./tmp')
@click.option('--camera_path', help='Path to the camera root', metavar='[DIR]', type=str, default='./tmp')
@click.option('--img_res', help='The resolution of image', metavar='INT', type=click.IntRange(min=1), default=1024)
@click.option('--data_camera_mode', help='The type of dataset we are using', type=str, default='shapenet_car', show_default=True)
@click.option('--use_shapenet_split', help='whether use the training split or all the data for training', metavar='BOOL', type=bool, default=False, show_default=False)
@click.option('--manifest_dir', help='Directory containing train.txt, test.txt and val.txt', type=str, default='shapenet_car', show_default=True)
### Configs for 3D generator##########
@click.option('--use_style_mixing', help='whether use style mixing for generation during inference', metavar='BOOL', type=bool, default=True, show_default=False)
@click.option('--one_3d_generator', help='whether we detach the gradient for empty object', metavar='BOOL', type=bool, default=True, show_default=True)
Expand Down Expand Up @@ -240,8 +229,8 @@ def main(**kwargs):
c.training_set_kwargs, dataset_name = init_dataset_kwargs(data=opts.data, opt=opts)
if opts.cond and not c.training_set_kwargs.use_labels:
raise click.ClickException('--cond=True requires labels specified in dataset.json')
c.training_set_kwargs.split = 'train' if opts.use_shapenet_split else 'all'
if opts.use_shapenet_split and opts.inference_vis:
c.training_set_kwargs.split = 'train'
if opts.inference_vis:
c.training_set_kwargs.split = 'test'
c.training_set_kwargs.use_labels = opts.cond
c.training_set_kwargs.xflip = False
Expand All @@ -260,7 +249,7 @@ def main(**kwargs):

c.G_kwargs.render_type = opts.render_type
c.G_kwargs.use_tri_plane = opts.use_tri_plane
c.D_kwargs.data_camera_mode = opts.data_camera_mode
# c.D_kwargs.manifest_dir = opts.manifest_dir
c.D_kwargs.add_camera_cond = opts.add_camera_cond

c.G_kwargs.tet_res = opts.tet_res
Expand All @@ -270,7 +259,7 @@ def main(**kwargs):
c.batch_size = opts.batch
c.batch_gpu = opts.batch_gpu or opts.batch // opts.gpus
# c.G_kwargs.geo_pos_enc = opts.geo_pos_enc
c.G_kwargs.data_camera_mode = opts.data_camera_mode
# c.G_kwargs.manifest_dir = opts.manifest_dir
c.G_kwargs.channel_base = c.D_kwargs.channel_base = opts.cbase
c.G_kwargs.channel_max = c.D_kwargs.channel_max = opts.cmax

Expand Down
159 changes: 56 additions & 103 deletions training/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,96 +154,58 @@ def __init__(
path, # Path to directory or zip.
camera_path, # Path to camera
resolution=None, # Ensure specific resolution, None = highest available.
data_camera_mode='shapenet_car',
manifest_dir='shapenet_car',
add_camera_cond=False,
split='all',
split='train',
**super_kwargs # Additional arguments for the Dataset base class.
):
self.data_camera_mode = data_camera_mode
self.manifest_dir = manifest_dir
self._path = path
self._zipfile = None
self.root = path
self.mask_list = None
self.add_camera_cond = add_camera_cond
root = self._path
self.camera_root = camera_path
if data_camera_mode == 'shapenet_car' or data_camera_mode == 'shapenet_chair' \
or data_camera_mode == 'renderpeople' or data_camera_mode == 'shapenet_motorbike' \
or data_camera_mode == 'ts_house' \
or data_camera_mode == 'ts_animal':
print('==> use shapenet dataset')
if not os.path.exists(root):
print('==> ERROR!!!! THIS SHOULD ONLY HAPPEN WHEN USING INFERENCE')
n_img = 1234
self._raw_shape = (n_img, 3, resolution, resolution)
self.img_size = resolution
self._type = 'dir'
self._all_fnames = [None for i in range(n_img)]
self._image_fnames = self._all_fnames
name = os.path.splitext(os.path.basename(path))[0]
print(
'==> use image path: %s, num images: %d' % (
self.root, len(self._all_fnames)))
super().__init__(name=name, raw_shape=self._raw_shape, **super_kwargs)
return
folder_list = sorted(os.listdir(root))
if data_camera_mode == 'shapenet_chair' or data_camera_mode == 'shapenet_car':
if data_camera_mode == 'shapenet_car':
split_name = './3dgan_data_split/shapenet_car/%s.txt' % (split)
if split == 'all':
split_name = './3dgan_data_split/shapenet_car.txt'
elif data_camera_mode == 'shapenet_chair':
split_name = './3dgan_data_split/shapenet_chair/%s.txt' % (split)
if split == 'all':
split_name = './3dgan_data_split/shapenet_chair.txt'
valid_folder_list = []
with open(split_name, 'r') as f:
all_line = f.readlines()
for l in all_line:
valid_folder_list.append(l.strip())
valid_folder_list = set(valid_folder_list)
useful_folder_list = set(folder_list).intersection(valid_folder_list)
folder_list = sorted(list(useful_folder_list))
if data_camera_mode == 'ts_animal':
split_name = './3dgan_data_split/ts_animals/%s.txt' % (split)
print('==> use ts animal split %s' % (split))
if split != 'all':
valid_folder_list = []
with open(split_name, 'r') as f:
all_line = f.readlines()
for l in all_line:
valid_folder_list.append(l.strip())
valid_folder_list = set(valid_folder_list)
useful_folder_list = set(folder_list).intersection(valid_folder_list)
folder_list = sorted(list(useful_folder_list))
elif data_camera_mode == 'shapenet_motorbike':
split_name = './3dgan_data_split/shapenet_motorbike/%s.txt' % (split)
print('==> use ts shapenet motorbike split %s' % (split))
if split != 'all':
valid_folder_list = []
with open(split_name, 'r') as f:
all_line = f.readlines()
for l in all_line:
valid_folder_list.append(l.strip())
valid_folder_list = set(valid_folder_list)
useful_folder_list = set(folder_list).intersection(valid_folder_list)
folder_list = sorted(list(useful_folder_list))
print('==> use shapenet folder number %s' % (len(folder_list)))
folder_list = [os.path.join(root, f) for f in folder_list]
all_img_list = []
all_mask_list = []

for folder in folder_list:
rgb_list = sorted(os.listdir(folder))
rgb_file_name_list = [os.path.join(folder, n) for n in rgb_list]
all_img_list.extend(rgb_file_name_list)
all_mask_list.extend(rgb_list)

self.img_list = all_img_list
self.mask_list = all_mask_list

else:
raise NotImplementedError
if not os.path.exists(root):
print('==> WARNING: Root path does not exist. If you are running inference this is OK: %s' % root)
n_img = 1234
self._raw_shape = (n_img, 3, resolution, resolution)
self.img_size = resolution
self._type = 'dir'
self._all_fnames = [None for i in range(n_img)]
self._image_fnames = self._all_fnames
name = os.path.splitext(os.path.basename(path))[0]
print(
'==> use image path: %s, num images: %d' % (
self.root, len(self._all_fnames)))
super().__init__(name=name, raw_shape=self._raw_shape, **super_kwargs)
return
folder_list = sorted(os.listdir(root))
split_name = './3dgan_data_split/' + manifest_dir + '/%s.txt' % (split)
valid_folder_list = []
with open(split_name, 'r') as f:
all_line = f.readlines()
for l in all_line:
valid_folder_list.append(l.strip())
valid_folder_list = set(valid_folder_list)
useful_folder_list = set(folder_list).intersection(valid_folder_list)
folder_list = sorted(list(useful_folder_list))
print('==> use shapenet folder number %s' % (len(folder_list)))
folder_list = [os.path.join(root, f) for f in folder_list]
all_img_list = []
all_mask_list = []

for folder in folder_list:
rgb_list = sorted(os.listdir(folder))
rgb_file_name_list = [os.path.join(folder, n) for n in rgb_list]
all_img_list.extend(rgb_file_name_list)
all_mask_list.extend(rgb_list)

self.img_list = all_img_list
self.mask_list = all_mask_list

self.img_size = resolution
self._type = 'dir'
self._all_fnames = self.img_list
Expand Down Expand Up @@ -283,31 +245,22 @@ def __getstate__(self):

def __getitem__(self, idx):
fname = self._image_fnames[self._raw_idx[idx]]
if self.data_camera_mode == 'shapenet_car' or self.data_camera_mode == 'shapenet_chair' \
or self.data_camera_mode == 'renderpeople' \
or self.data_camera_mode == 'shapenet_motorbike' or self.data_camera_mode == 'ts_house' or self.data_camera_mode == 'ts_animal' \
:
ori_img = cv2.imread(fname, cv2.IMREAD_UNCHANGED)
img = ori_img[:, :, :3][..., ::-1]
mask = ori_img[:, :, 3:4]
condinfo = np.zeros(2)
fname_list = fname.split('/')
img_idx = int(fname_list[-1].split('.')[0])
obj_idx = fname_list[-2]
syn_idx = fname_list[-3]

if self.data_camera_mode == 'shapenet_car' or self.data_camera_mode == 'shapenet_chair' \
or self.data_camera_mode == 'renderpeople' or self.data_camera_mode == 'shapenet_motorbike' \
or self.data_camera_mode == 'ts_house' or self.data_camera_mode == 'ts_animal':
if not os.path.exists(os.path.join(self.camera_root, syn_idx, obj_idx, 'rotation.npy')):
print('==> not found camera root')
else:
rotation_camera = np.load(os.path.join(self.camera_root, syn_idx, obj_idx, 'rotation.npy'))
elevation_camera = np.load(os.path.join(self.camera_root, syn_idx, obj_idx, 'elevation.npy'))
condinfo[0] = rotation_camera[img_idx] / 180 * np.pi
condinfo[1] = (90 - elevation_camera[img_idx]) / 180.0 * np.pi
ori_img = cv2.imread(fname, cv2.IMREAD_UNCHANGED)
img = ori_img[:, :, :3][..., ::-1]
mask = ori_img[:, :, 3:4]
condinfo = np.zeros(2)
fname_list = fname.split('/')
img_idx = int(fname_list[-1].split('.')[0])
obj_idx = fname_list[-2]
syn_idx = fname_list[-3]

if not os.path.exists(os.path.join(self.camera_root, syn_idx, obj_idx, 'rotation.npy')):
print('==> not found camera root')
else:
raise NotImplementedError
rotation_camera = np.load(os.path.join(self.camera_root, syn_idx, obj_idx, 'rotation.npy'))
elevation_camera = np.load(os.path.join(self.camera_root, syn_idx, obj_idx, 'elevation.npy'))
condinfo[0] = rotation_camera[img_idx] / 180 * np.pi
condinfo[1] = (90 - elevation_camera[img_idx]) / 180.0 * np.pi

resize_img = cv2.resize(img, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)
if not mask is None:
Expand Down
Loading