s3dis_11g.yaml
# @package _global_

# to execute this experiment run:
# python train.py experiment=semantic/s3dis_11g

# This configuration allows training SPT on a single 11G GPU, with a
# training procedure comparable to the default
# experiment/semantic/s3dis configuration.

# Among the multiple ways of reducing the memory impact, we choose here to
#   - divide the dataset into smaller tiles (facilitates preprocessing
#     and inference on smaller GPUs)
#   - reduce the number of samples in each batch (facilitates training
#     on smaller GPUs)

# To keep the total number of training steps consistent with the default
# configuration, while keeping informative gradients despite the smaller
# batches, we use gradient accumulation and reduce the number of epochs.

# DISCLAIMER: the tiling procedure may increase the preprocessing time
# (more raw data reading steps) and slightly reduce model performance
# (less diversity in the spherical samples)
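
# Note: with sample_graph_k=2 and gradient accumulation over 2 batches, each
# optimizer step should still aggregate gradients from roughly 2 x 2 = 4
# spherical samples, i.e. about as many as the default configuration's
# 4-sample batches.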

defaults:
  - override /datamodule: semantic/s3dis.yaml
  - override /model: semantic/spt-2.yaml
  - override /trainer: gpu.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

datamodule:
  xy_tiling: 3  # split each cloud into xy_tiling^2=9 tiles, based on a regular XY grid. Reduces preprocessing- and inference-time GPU memory
  sample_graph_k: 2  # 2 spherical samples in each batch instead of 4. Reduces train-time GPU memory

callbacks:
  gradient_accumulator:
    scheduling:
      0:
        2  # accumulate gradient every 2 batches, to make up for reduced batch size
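    # Note: the scheduling mapping appears to follow PyTorch Lightning's
    # GradientAccumulationScheduler format, {start epoch: accumulation factor}
    # (assumption, not verified here). A hypothetical multi-stage schedule
    # with illustrative values could look like:
    #   scheduling:
    #     0: 4    # accumulate over 4 batches from epoch 0
    #     100: 2  # then over 2 batches from epoch 100 onward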

trainer:
  max_epochs: 500  # to keep same nb of steps: 8x more tiles, 2-step gradient accumulation -> epochs/4

model:
  optimizer:
    lr: 0.1
    weight_decay: 1e-2

logger:
  wandb:
    project: "spt_s3dis"
    name: "SPT-64"
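
# Example (hypothetical further overrides, illustrative values only): to fit
# on an even smaller GPU, the same levers can be pushed further from the
# command line, e.g.:
#   python train.py experiment=semantic/s3dis_11g \
#     datamodule.xy_tiling=4 datamodule.sample_graph_k=1
# Keep in mind that max_epochs would then need to be adjusted again to keep
# the total number of training steps comparable.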