
Yet another NeRF, with extensibility and scalability. Implemented in PyTorch.

This project is still under rapid development; the git commit history and API may change in the future.


Environment installation:

conda env create -f environment.yml
pre-commit install
pip install -e .

(Note: code changes only take effect through the installed package, so install it with pip install -e . as above.)

Alternatively, install the dependencies with pip:

pip install torch torchvision addict yapf pytest

Run tests:

pytest .


Install the development tools:

pip install bandit==1.7.4 black==22.3.0 flake8-docstrings==1.6.0 flake8==3.9.1 flynt==0.64 isort==5.8.0 mypy==0.902 pre-commit==2.13.0 pytest ipython
pre-commit install

Data Preparation

Download and extract the zip file to data/.

Dataset Links



CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m --nproc_per_node=8 scripts/ --config $config [--output_dir $OUTPUT_DIR] [--checkpoint $CHECKPOINT_PATH] [--device $DEVICE ("cuda" or "cpu")] [--test_only] [--debug] {--cfg_options "xxx=yyy"}
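The --cfg_options "xxx=yyy" argument appears to follow the MMCV convention of overriding config entries from the command line. The sketch below shows one way such dotted key=value overrides could be merged into a nested config dict. The helper merge_cfg_options, the example keys, and the rule of parsing values as Python literals are assumptions for illustration, not this repository's actual parser.

```python
import ast
from typing import Any, Dict, List


def _parse_value(raw: str) -> Any:
    """Parse literals such as 1, 5e-4, True or [1, 2]; otherwise keep the raw string."""
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return raw


def merge_cfg_options(cfg: Dict[str, Any], options: List[str]) -> Dict[str, Any]:
    """Apply "a.b.c=value" overrides to a nested dict in place (hypothetical helper)."""
    for opt in options:
        key, sep, raw = opt.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {opt!r}")
        *parents, leaf = key.strip().split(".")
        node = cfg
        for name in parents:
            node = node.setdefault(name, {})  # create intermediate levels on demand
        node[leaf] = _parse_value(raw.strip())
    return cfg


cfg = {"optimizer": {"lr": 5e-4}, "max_epochs": 100}
merge_cfg_options(cfg, ["optimizer.lr=1e-3", "max_epochs=200", "exp_name=lego_debug"])
print(cfg)  # {'optimizer': {'lr': 0.001}, 'max_epochs': 200, 'exp_name': 'lego_debug'}
```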


| Data | Config | Ckpt | PSNR reproduced (dB) | PSNR paper (dB) | Time reproduced | Time paper |
|---|---|---|---|---|---|---|
| Lego | lego.yml | lego.ckpt | 30.70 | 32.54 | ~4h (on 4 RTX3090) | >12h |
| Fern | fern.yml | fern.ckpt | 27.94 | 25.17 | ~2.5h (on 4 RTX3090) | >12h |

The Code Structure

Structure of the Codebase


  1. pipelines/

    • The shapes of gt_rgb & bg_rgb should both be (B, H, W, 3), to be compatible with the chunkify function and with the renderer.
    • [TODO]: global_codes is passed through the whole pipeline (pipeline, renderer, and network), but it is only used in the network.
    • Loss computation: to be compatible with distributed evaluation, per-sample losses are returned, and torch.mean is called in runner.apis.
    • Undefined args are captured by **kwargs and then fed into the feature_extractor.
    1. networks/

      • Converts a ray_bundle (origins, directions, lengths) to points.
      • Checks input dimensions.
      • The networks are hard to initialize, so stochastic sampling is needed to break out of bad initializations: mainly pipeline.ray_sampler.stratified_point_sampling_training, plus pipeline.renderer.density_noise_std_train.
      • Currently, networks only take in global_codes; undefined args are handled by **kwargs.
    2. renderer/

      • ray_point_finer, sample_pdf
      • background_deltas / background_opacity is set to 1e10, and the alpha mask is used to blend bg_color.
      • A dataclass wraps the outputs of the previous stage, and the render function is called recursively.
      • [FIXME]: the default bg_color is 0.0.
      • density_noise_std: is this in the original paper?
      • blend_output=False: the foreground mask is 1, but the predicted background mask is also used.
    3. ray_sampler/

      • Right-handed coordinates: the x-axis points right, the y-axis points down, and the z-axis points inward (into the scene).
      • camera: cam2world
      • Tensor shape: (batch_size, *spatial, -1), where spatial is [height, width] or [n_rays_per_image, 1].
      • directions are not normalized
      • Poses can have shape (..., 4, 4) or (..., 3, 4).
      • Supports custom min/max_depth, image_width and image_height; the xy_grid computed from image_width and image_height is cached with functools.lru_cache.
    4. feature_extractors/

      • Takes in only keyword args (the extra args passed to the pipeline) and returns a dict of keyword args (currently it must return global_codes).
      • There may be multiple feature_extractors, so undefined args are handled by **kwargs.
  2. dataset/

    • the shapes of gt_rgb & bg_rgb should both be (B, H, W, 3) (to be compatible with the chunkify function)
    • Images should be normalized to the range [0, 1] to be compatible with the sigmoid activation.
    • Define a dataset_bundle: NamedTuple in the Dataset; runner.apis wraps the data accordingly.
      • The keys of the arguments should be the same as those in pipeline, feature_extractor.
      • Currently, networks only take in global_codes
  3. runner/

    • Multiprocess loading is on CPU.
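
For the per-sample loss convention under pipelines/ above, here is a minimal sketch. It assumes an MSE photometric loss on (B, H, W, 3) tensors and torch.distributed for gathering; per_sample_mse and reduce_per_sample are hypothetical names, not functions from this codebase.

```python
import torch
import torch.distributed as dist


def per_sample_mse(pred_rgb: torch.Tensor, gt_rgb: torch.Tensor) -> torch.Tensor:
    """(B, H, W, 3) x 2 -> (B,): one loss value per sample, no reduction over the batch."""
    assert pred_rgb.shape == gt_rgb.shape and pred_rgb.shape[-1] == 3
    return ((pred_rgb - gt_rgb) ** 2).flatten(start_dim=1).mean(dim=1)


def reduce_per_sample(losses: torch.Tensor) -> torch.Tensor:
    """Average per-sample losses over every process, or locally when not distributed."""
    if dist.is_available() and dist.is_initialized():
        # all_gather needs the same number of samples on every rank
        gathered = [torch.empty_like(losses) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, losses)
        losses = torch.cat(gathered)
    return torch.mean(losses)


pred = torch.zeros(2, 2, 2, 3)
gt = torch.stack([torch.ones(2, 2, 3), torch.zeros(2, 2, 3)])
losses = per_sample_mse(pred, gt)
print(losses, reduce_per_sample(losses))  # tensor([1., 0.]) tensor(0.5000)
```

Returning one value per sample lets the final torch.mean weight every sample equally, instead of averaging batch means that may come from batches of different sizes.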
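
The **kwargs plumbing described for pipelines/, networks/ and feature_extractors/ can be sketched in plain Python. Everything here is hypothetical (the extractor, the toy network, and the scene_id / image_name arguments); it only illustrates how extra pipeline args flow into the feature extractors, and how their returned dict, which currently must contain global_codes, is passed on to the network.

```python
from typing import Any, Callable, Dict, List


def scene_code_extractor(*, scene_id: int = 0, **kwargs: Any) -> Dict[str, Any]:
    # Consumes only the keyword it knows; args meant for other extractors land in **kwargs.
    return {"global_codes": [float(scene_id)] * 4}


def toy_network(points: List[float], *, global_codes: List[float], **kwargs: Any) -> List[float]:
    # Networks currently read only global_codes; anything else is ignored via **kwargs.
    return [p + global_codes[0] for p in points]


def pipeline_forward(points: List[float],
                     feature_extractors: List[Callable[..., Dict[str, Any]]],
                     network: Callable[..., Any],
                     **extra: Any) -> Any:
    features: Dict[str, Any] = {}
    for extractor in feature_extractors:  # there may be several extractors
        features.update(extractor(**extra))
    if "global_codes" not in features:
        raise KeyError("feature extractors must currently return global_codes")
    return network(points, **features)


out = pipeline_forward([0.0, 1.0], [scene_code_extractor], toy_network,
                       scene_id=2, image_name="r_0.png")
print(out)  # [2.0, 3.0]
```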
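
The stochastic sampling mentioned under networks/ is presumably stratified sampling along each ray, as in the NeRF paper: the depth range is split into equal bins and, during training, one depth is drawn uniformly inside each bin. This is a sketch under that assumption; stratified_lengths and its arguments are hypothetical and not the signature of pipeline.ray_sampler.stratified_point_sampling_training.

```python
import torch


def stratified_lengths(min_depth: float, max_depth: float, n_pts: int,
                       batch_shape: tuple, perturb: bool) -> torch.Tensor:
    """Return sample depths of shape (*batch_shape, n_pts), one per equal-width bin."""
    edges = torch.linspace(min_depth, max_depth, n_pts + 1)
    lower, upper = edges[:-1], edges[1:]
    if perturb:  # training: uniform jitter inside each bin
        u = torch.rand(*batch_shape, n_pts)
    else:  # evaluation: deterministic bin midpoints
        u = torch.full((*batch_shape, n_pts), 0.5)
    return lower + (upper - lower) * u


train = stratified_lengths(2.0, 6.0, 4, (3,), perturb=True)
evals = stratified_lengths(2.0, 6.0, 4, (3,), perturb=False)
print(evals[0])  # tensor([2.5000, 3.5000, 4.5000, 5.5000])
```

Because each depth stays inside its own bin, the samples remain sorted along the ray while still covering the whole interval across iterations.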
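
The renderer/ notes (background_deltas = 1e10, alpha-mask blending with bg_color, density noise during training) match the standard NeRF volume-rendering quadrature. A minimal sketch, assuming nerf-pytorch-style conventions; composite and its arguments are hypothetical names, not this renderer's API.

```python
import torch


def composite(rgb, density, lengths, directions, bg_color=0.0, density_noise_std=0.0):
    """Alpha-composite samples along each ray.

    rgb: (..., n_pts, 3), density: (..., n_pts), lengths: (..., n_pts), directions: (..., 3).
    Returns the blended colour (..., 3), per-sample weights (..., n_pts) and opacity (..., 1).
    """
    deltas = lengths[..., 1:] - lengths[..., :-1]
    # The last sample gets a (practically) infinite interval, cf. background_deltas = 1e10.
    deltas = torch.cat([deltas, torch.full_like(deltas[..., :1], 1e10)], dim=-1)
    # Directions are not normalized, so scale by their norm to get distances.
    deltas = deltas * directions.norm(dim=-1, keepdim=True)
    if density_noise_std > 0:  # training-time noise, cf. density_noise_std_train
        density = density + torch.randn_like(density) * density_noise_std
    alpha = 1.0 - torch.exp(-torch.relu(density) * deltas)
    # Transmittance T_i = prod_{j<i} (1 - alpha_j), with T_0 = 1.
    ones = torch.ones_like(alpha[..., :1])
    transmittance = torch.cumprod(torch.cat([ones, 1.0 - alpha + 1e-10], dim=-1), dim=-1)[..., :-1]
    weights = alpha * transmittance
    acc = weights.sum(dim=-1, keepdim=True)  # foreground opacity, i.e. the alpha mask
    color = (weights[..., None] * rgb).sum(dim=-2) + (1.0 - acc) * bg_color
    return color, weights, acc
```

With the 1e10 interval, the last sample becomes opaque whenever its density is more than negligibly small, so in this sketch bg_color only shows through along rays whose densities are (near) zero everywhere.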
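
sample_pdf under renderer/ refers to hierarchical (coarse-to-fine) sampling: new depths are drawn by inverse transform sampling from the piecewise-constant distribution defined by the coarse weights. The following sketch follows the widely used NeRF formulation; its signature is an assumption, not necessarily this repository's.

```python
import torch


def sample_pdf(bins: torch.Tensor, weights: torch.Tensor, n_samples: int,
               det: bool = False) -> torch.Tensor:
    """Draw n_samples depths per ray from the piecewise-constant PDF given by weights.

    bins: (..., n_bins + 1) bin edges, weights: (..., n_bins). Returns (..., n_samples).
    """
    weights = weights + 1e-5  # keeps the PDF valid when every weight is zero
    pdf = weights / weights.sum(dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # (..., n_bins + 1)

    batch_shape = cdf.shape[:-1]
    if det:  # evenly spaced quantiles
        u = torch.linspace(0.0, 1.0, n_samples, device=cdf.device).expand(*batch_shape, n_samples)
    else:  # random quantiles
        u = torch.rand(*batch_shape, n_samples, device=cdf.device)
    u = u.contiguous()

    # Invert the CDF: find the bin each quantile falls into ...
    idx = torch.searchsorted(cdf, u, right=True)
    below = (idx - 1).clamp(min=0)
    above = idx.clamp(max=cdf.shape[-1] - 1)
    cdf_lo, cdf_hi = torch.gather(cdf, -1, below), torch.gather(cdf, -1, above)
    bin_lo, bin_hi = torch.gather(bins, -1, below), torch.gather(bins, -1, above)
    # ... and interpolate linearly inside that bin.
    denom = cdf_hi - cdf_lo
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_lo) / denom
    return bin_lo + t * (bin_hi - bin_lo)
```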
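
The ray_sampler/ conventions can be illustrated with a small pinhole-camera sketch: x right, y down, z into the scene, cam2world poses of shape (B, 3, 4) or (B, 4, 4), unnormalized directions, outputs of shape (B, H, W, 3), and an lru_cache'd pixel grid. The single focal length, the principal point at the image centre, and all function names are assumptions for illustration. ray_bundle_to_points shows the (origins, directions, lengths) to points conversion mentioned under networks/.

```python
import functools
from typing import Tuple

import torch


@functools.lru_cache(maxsize=8)
def xy_grid(image_width: int, image_height: int) -> torch.Tensor:
    """Pixel-centre (x, y) coordinates of shape (H, W, 2); cached, so never modify it in place."""
    ys, xs = torch.meshgrid(
        torch.arange(image_height, dtype=torch.float32) + 0.5,
        torch.arange(image_width, dtype=torch.float32) + 0.5,
        indexing="ij",
    )
    return torch.stack([xs, ys], dim=-1)


def get_rays(cam2world: torch.Tensor, focal: float, image_width: int,
             image_height: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """cam2world: (B, 3, 4) or (B, 4, 4). Returns origins and directions, both (B, H, W, 3)."""
    rotation = cam2world[..., :3, :3]  # works for both pose shapes
    translation = cam2world[..., :3, 3]
    grid = xy_grid(image_width, image_height).to(cam2world)
    # Camera frame: x right, y down, z pointing into the scene (right-handed).
    dirs_cam = torch.stack(
        [
            (grid[..., 0] - image_width / 2) / focal,
            (grid[..., 1] - image_height / 2) / focal,
            torch.ones_like(grid[..., 0]),
        ],
        dim=-1,
    )  # (H, W, 3)
    directions = torch.einsum("bij,hwj->bhwi", rotation, dirs_cam)  # left unnormalized
    origins = translation[:, None, None, :].expand_as(directions)
    return origins, directions


def ray_bundle_to_points(origins, directions, lengths):
    """(..., 3), (..., 3), (..., n_pts) -> points (..., n_pts, 3)."""
    return origins[..., None, :] + directions[..., None, :] * lengths[..., :, None]
```

Because the camera-frame z component of every direction is 1 in this construction, lengths measure depth along the optical axis, which fits min/max_depth bounds.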
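
For the dataset/ conventions, here is a sketch of a Dataset that returns a NamedTuple bundle of [0, 1] images; the field names and ToyImages are hypothetical. PyTorch's default collation keeps the NamedTuple type and stacks each field, which yields the (B, H, W, 3) shapes required above, and _asdict() turns the batch into keyword args whose keys can match the pipeline's argument names.

```python
from typing import NamedTuple

import torch
from torch.utils.data import DataLoader, Dataset


class DatasetBundle(NamedTuple):
    """Hypothetical per-sample fields; keys mirror the pipeline's argument names."""
    gt_rgb: torch.Tensor     # (H, W, 3), float in [0, 1]
    bg_rgb: torch.Tensor     # (H, W, 3), float in [0, 1]
    cam2world: torch.Tensor  # (3, 4)


class ToyImages(Dataset):
    def __init__(self, n: int, height: int, width: int):
        g = torch.Generator().manual_seed(0)
        self.images = torch.randint(0, 256, (n, height, width, 3), dtype=torch.uint8, generator=g)

    def __len__(self) -> int:
        return self.images.shape[0]

    def __getitem__(self, idx: int) -> DatasetBundle:
        gt_rgb = self.images[idx].float() / 255.0  # uint8 [0, 255] -> float [0, 1]
        return DatasetBundle(gt_rgb=gt_rgb, bg_rgb=torch.ones_like(gt_rgb),
                             cam2world=torch.eye(4)[:3])


loader = DataLoader(ToyImages(n=4, height=8, width=8), batch_size=2)
batch = next(iter(loader))  # default collation keeps the NamedTuple and stacks each field
kwargs = batch._asdict()    # e.g. passed on as pipeline(**kwargs)
print(type(batch).__name__, kwargs["gt_rgb"].shape)  # DatasetBundle torch.Size([2, 8, 8, 3])
```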

Structure of

  1. models
    1. renderer
    2. networks
  2. data module
  3. trainer
    1. train / eval
    2. losses
    3. metrics
    4. opt
    5. (utils) optimizer / scheduler
    6. (utils) ckpt io
    7. (utils) visualization

Entry point of implicitron

projects/implicitron_trainer/ (uses the logger from logging)

Global args:

  • exp_dir
  • dataset_args / dataloader_args (both are non_leaf)

Running Pipeline

  • Build exp_dir
  • Get dataset & dataloader (function)
  • Build model (init_model)
    • Also responsible for resuming, so it additionally returns the training stats & optimizer_state
    • Then moves the model to the device(s)
  • Build optimizer & scheduler from the previous optimizer_state
  • Training loops
    • seed all
    • Record lr from lr scheduler
    • train & val: trainvalidate
    • test: run_eval
    • save checkpoint
  • test_when_finish flag for final test
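
The Running Pipeline steps above can be mapped onto a compact training loop. This is a toy sketch, not implicitron's code: the Linear model, Adam, ExponentialLR with gamma=0.9, the stats layout, and the in-memory checkpoint (a real run would torch.save it under exp_dir) are all assumptions. It does show init_model handling resume by returning the stats and optimizer state, and the optimizer and scheduler being rebuilt from them.

```python
import random
from typing import Any, Dict, Optional

import torch


def seed_all(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)  # also seeds every CUDA device


def init_model(ckpt: Optional[Dict[str, Any]], device: str):
    """Build the model; on resume, also return the saved stats and optimizer state."""
    model = torch.nn.Linear(3, 3)  # placeholder for the NeRF pipeline
    stats: Dict[str, Any] = {"epoch": -1, "lr": []}
    optimizer_state = None
    if ckpt is not None:
        model.load_state_dict(ckpt["model"])
        stats, optimizer_state = ckpt["stats"], ckpt["optimizer"]
    return model.to(device), stats, optimizer_state


def run_training(max_epochs: int, ckpt: Optional[Dict[str, Any]] = None, device: str = "cpu",
                 test_when_finish: bool = True) -> Optional[Dict[str, Any]]:
    model, stats, optimizer_state = init_model(ckpt, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    if optimizer_state is not None:  # resume optimizer and scheduler from the checkpoint
        optimizer.load_state_dict(optimizer_state)
        scheduler.load_state_dict(ckpt["scheduler"])
    data = torch.rand(16, 3, device=device)  # toy stand-in for the ray batches
    for epoch in range(stats["epoch"] + 1, max_epochs):
        seed_all(epoch)
        stats["lr"].append(scheduler.get_last_lr()[0])  # record lr from the scheduler
        loss = (model(data) - data).pow(2).mean()  # train & val: one toy step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        stats["epoch"] = epoch
        # save checkpoint (kept in memory here; a real run would torch.save it)
        ckpt = {"model": model.state_dict(), "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(), "stats": stats}
    if test_when_finish:  # final test
        with torch.no_grad():
            stats["test_loss"] = (model(data) - data).pow(2).mean().item()
    return ckpt


ckpt = run_training(max_epochs=3)
ckpt = run_training(max_epochs=5, ckpt=ckpt)  # resumes at epoch 3
print(len(ckpt["stats"]["lr"]))  # 5
```

Resuming continues the recorded learning-rate schedule from epoch 3 instead of restarting it at the initial value.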


  • Checkpoints
  • Stats
  • Visualizations

Citation and Acknowledgement

Kudos to the authors for their amazing results:

    Ben Mildenhall, Pratul P. Srinivasan, Matthew Tancik, Jonathan T. Barron, Ravi Ramamoorthi, and Ren Ng. NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis.

This project also draws heavily on the following repositories:

  • nerf, nerf-pytorch (yenchenlin), nerf-pytorch (krrish94), nerf_pl, MMCV, MMDetection, PyTorch3D