Visualizing method #5

Open
chiehwangs opened this issue Nov 25, 2023 · 1 comment

Comments

chiehwangs commented Nov 25, 2023

Hi Brent,

First of all, thanks for sharing this excellent and robust work!

Would it be possible to share the method used to visualize "the structure-revealing L2-norm of interpolated features" in Figure 8? I am trying to draw inspiration from your work.

best regards,
chieh

brentyi (Owner) commented Nov 26, 2023

Hi @chiehwangs,

Thanks for your nice words!

For the code that generates these specific heatmaps, you can search the repo for "transform_feature_norms". What's happening here is: we compute the feature norm corresponding to each transform for each sample along the ray, and then alpha composite these norms (with the same weights we would use for compositing RGB). The final visualization shows the feature norms of a single transform, which we pick by argmax-ing the norm itself.

The norms are first computed on a per-transform basis in the rendering pipeline. Note that the output shape here is (# rays, # transforms):

component_norms = functools.reduce(
    jnp.add,
    map(
        # Each component has shape (groups, rays, samples, channels)
        lambda a: reduce(
            a**2,
            "(g transform_count) rays samples channels -> rays samples transform_count",
            reduction="sum",
            transform_count=transform_count,
            rays=num_rays,
        ),
        primary_components,
    ),
)
assert (
    component_norms.shape
    == probs.p_terminates.shape + (transform_count,)
    == rgb.shape[:-1] + (transform_count,)
)
transform_feature_norm = einsum(
    component_norms,
    probs.p_terminates,
    "rays samples transform_count, rays samples -> rays transform_count",
)
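For anyone who wants to try the compositing step outside the repo, here is a rough NumPy-only sketch of the same idea. It is not the repo's code: the function name `composite_transform_norms` is made up, it handles a single feature component rather than summing over `primary_components`, and it assumes the leading axis is laid out as (groups, transform_count), which is how I read the einops pattern above. Like the snippet above, it sums squares without taking a square root.

```python
import numpy as np


def composite_transform_norms(features, p_terminates, transform_count):
    """Hypothetical helper: per-transform feature norms, composited along rays.

    features:     (groups * transform_count, rays, samples, channels)
    p_terminates: (rays, samples) compositing weights, as used for RGB
    returns:      (rays, transform_count)
    """
    total, rays, samples, channels = features.shape
    groups = total // transform_count
    # Split the leading axis the way "(g transform_count)" does in einops:
    # g is the outer index, transform_count the inner one.
    split = features.reshape(groups, transform_count, rays, samples, channels)
    # Sum of squares over groups and channels -> (transform_count, rays, samples).
    squared = (split**2).sum(axis=(0, 4))
    # Move the transform axis last -> (rays, samples, transform_count).
    squared = np.moveaxis(squared, 0, -1)
    # Weighted sum over samples, one value per ray and transform.
    return np.einsum("rst,rs->rt", squared, p_terminates)


if __name__ == "__main__":
    feats = np.random.default_rng(0).normal(size=(4, 2, 3, 5))
    weights = np.full((2, 3), 0.25)
    print(composite_transform_norms(feats, weights, transform_count=2).shape)
```

Calling it with features of shape (4, 2, 3, 5) and `transform_count=2` gives an array of shape (2, 2), i.e. one composited value per ray and per transform.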

These norms are then scaled to [0, 1]. We index into only one transform at a time, (# rays, # transforms) => (# rays,), then apply a colormap to convert to RGB for the final visualization:

tilted/visualize_nerf.py

Lines 188 to 199 in df4614a

if self.mode == "transform_feature_norm":
    image = image - image.min()
    image /= image.max()
    image = image[
        ...,
        onp.argsort(
            -onp.linalg.norm(
                image.reshape((-1, image.shape[-1])), axis=0
            )
        )[self.transform_viz_index % image.shape[-1]],
    ]
    image = (mpl.colormaps[self.cmap](image) * 255.0).astype(onp.uint8)
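The selection logic in that snippet can also be pulled out into a small standalone function, roughly as sketched below. The name `select_transform_heatmap` is hypothetical, and the guard against an all-constant image is an addition of this sketch, not something the repo does. The colormap step is left out so the example only needs NumPy; the returned array is in [0, 1] and can be passed to a matplotlib colormap as in the snippet above.

```python
import numpy as np


def select_transform_heatmap(image, transform_viz_index=0):
    """Hypothetical helper: scale to [0, 1], then pick one transform channel.

    image:   (..., transform_count) composited feature norms
    returns: (heatmap, chosen) where heatmap has shape image.shape[:-1]
    """
    image = image - image.min()
    peak = image.max()
    if peak > 0:  # assumption of this sketch: avoid dividing by zero
        image = image / peak
    transform_count = image.shape[-1]
    # Rank transforms by the norm of their whole channel, largest first.
    channel_norms = np.linalg.norm(image.reshape(-1, transform_count), axis=0)
    order = np.argsort(-channel_norms)
    # Index 0 is the strongest transform; larger indices wrap around.
    chosen = int(order[transform_viz_index % transform_count])
    return image[..., chosen], chosen


if __name__ == "__main__":
    norms = np.random.default_rng(0).uniform(size=(4, 4, 3))
    heatmap, chosen = select_transform_heatmap(norms)
    print(heatmap.shape, chosen)
```

With `transform_viz_index=0` this returns the channel with the largest overall norm, which matches the "argmax" description; other index values step through the remaining transforms in decreasing order.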

Hope that's helpful, and please let me know if anything's unclear!
