-
I have the following loop to update the heightmap of my mesh at each iteration before rendering and computing the loss.

```python
params = mi.traverse(scene)
obj_vertices = dr.unravel(mi.Vector3f, params['object.vertex_positions'])
new_height = dr.zeros(mi.Float, shape=dr.width(obj_vertices))
for k in range(num_elems):  # num_elems > 50
    # compute the distance of each point in the x-y plane from the given center
    dist_from_center_k = dr.sqr(obj_vertices.x - opt[f'center_x_{k:02d}']) \
                       + dr.sqr(obj_vertices.y - opt[f'center_y_{k:02d}'])
    dist_from_center_k = dr.sqrt(dist_from_center_k)
    # compute new height of the mesh as a function of the distance of the point from the given center
    new_height += get_mesh_height(dist_from_center_k)
# repeat this for all the given center values
# and optimizer updates the center values at each iteration
obj_vertices.z = new_height
```

This works well, and the forward operation takes roughly 500 ms per iteration (on CPU). The backward pass successfully updates the desired parameters.
-
Instead of using a for loop, you should be able to compute all the distances to all the centers at once using a larger wavefront size and then accumulate all the height values using a `dr.scatter_reduce` operation. This would be very fast to compile for Dr.Jit (in forward and backward) and scale better as the number of elements increases.

Something like this (un-tested!):

```python
obj_vertices = dr.unravel(mi.Vector3f, params['object.vertex_positions'])
num_vertices = dr.width(obj_vertices)
size = num_elems * num_vertices
centers = dr.repeat(mi.Point2f(center_x_values, center_y_values), num_vertices)
dist = dr.sqrt(dr.sum(dr.sqr(mi.Point2f(obj_vertices.x, obj_vertices.y) - centers)))
new_height = get_mesh_height(dist)
dr.scatter_reduce(dr.ReduceOp.Add, obj_vertices.z, new_height, dr.arange(mi.UInt32, size) % num_elems)
dr.eval()
```
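The index bookkeeping behind this flattened approach can be sketched in plain Python. This is only an illustration under assumptions: the lists stand in for Dr.Jit arrays, the helpers `repeat` and `tile` and the accumulation loop are hypothetical stand-ins for `dr.repeat`, `dr.tile` and `dr.scatter_reduce`, and `get_mesh_height` is a made-up falloff, since the real one is not shown in the thread. It suggests that when the centers are repeated `num_vertices` times, the vertex coordinates have to be tiled to the same length, and the scatter target is the flat index modulo `num_vertices`.

```python
import math

def get_mesh_height(dist):
    # Hypothetical height profile; the thread does not show the real one.
    return math.exp(-dist * dist)

def repeat(values, count):
    # [a, b] -> [a, a, ..., b, b, ...] (each element repeated `count` times)
    return [v for v in values for _ in range(count)]

def tile(values, count):
    # [a, b] -> [a, b, a, b, ...] (whole sequence repeated `count` times)
    return list(values) * count

def heights_loop(vertices_xy, centers_xy):
    # Reference: one pass over all vertices per center, as in the question.
    heights = [0.0] * len(vertices_xy)
    for cx, cy in centers_xy:
        for i, (vx, vy) in enumerate(vertices_xy):
            heights[i] += get_mesh_height(math.hypot(vx - cx, vy - cy))
    return heights

def heights_flat(vertices_xy, centers_xy):
    # One flat array of size num_elems * num_vertices, then a scatter-add.
    num_vertices = len(vertices_xy)
    num_elems = len(centers_xy)
    size = num_elems * num_vertices
    centers = repeat(centers_xy, num_vertices)   # c0, c0, ..., c1, c1, ...
    vertices = tile(vertices_xy, num_elems)      # v0, v1, ..., v0, v1, ...
    index = [i % num_vertices for i in range(size)]  # target vertex per entry
    heights = [0.0] * num_vertices
    for i in range(size):
        (vx, vy), (cx, cy) = vertices[i], centers[i]
        heights[index[i]] += get_mesh_height(math.hypot(vx - cx, vy - cy))
    return heights

vertices_xy = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)]
centers_xy = [(0.5, 0.5), (-1.0, 1.0)]
reference = heights_loop(vertices_xy, centers_xy)
flat = heights_flat(vertices_xy, centers_xy)
```

Here `reference` and `flat` hold the same per-vertex heights, so the flattened layout reproduces the loop. Note that the accumulation starts from zeros rather than from the existing `z` values, matching `new_height` in the question.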
-
Thank you very much. Your code provides me with some very important pointers, especially the While I can easily implement the matrix operations in NumPy, I am having difficulty following it in Dr.Jit. For instance, using your code I would really appreciate any pointers/ideas here.