Skip to content

Commit

Permalink
add simulation start button in xpbd
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahdhn committed Oct 9, 2024
1 parent 9358a8e commit 47f7d56
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 104 deletions.
2 changes: 1 addition & 1 deletion apps/MCF/mcf.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ int main(int argc, char** argv)
" -perm: Permutation method for Cholesky factorization. Default is {}\n"
" -max_cg_iter: Conjugate gradient maximum number of iterations. Default is {}\n"
" -device_id: GPU device ID. Default is {}",
Arg.obj_file_name, Arg.output_folder, (Arg.use_uniform_laplace? "true" : "false"), Arg.solver, Arg.time_step, Arg.cg_tolerance, Arg.perm_method, Arg.max_num_cg_iter, Arg.device_id);
Arg.obj_file_name, Arg.output_folder, (Arg.use_uniform_laplace? "true" : "false"), Arg.time_step, Arg.solver, Arg.cg_tolerance, Arg.perm_method, Arg.max_num_cg_iter, Arg.device_id);
// clang-format on
exit(EXIT_SUCCESS);
}
Expand Down
223 changes: 120 additions & 103 deletions apps/XPBD/xpbd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
#include "rxmesh/query.cuh"
#include "rxmesh/rxmesh_static.h"

#include "imgui.h"
#include "polyscope/polyscope.h"


using namespace rxmesh;

template <uint32_t blockThreads>
Expand Down Expand Up @@ -100,7 +104,7 @@ void __global__ solve_bending(const Context context,
const float bending_compliance,
const float bending_relaxation,
const float dt2)
{
{
auto solve = [&](const EdgeHandle& eh, const VertexIterator& iter) {
// iter[0] and iter[2] are the edge two vertices
// iter[1] and iter[3] are the two opposite vertices
Expand Down Expand Up @@ -272,119 +276,132 @@ int main(int argc, char** argv)
float mean(0.f);
float mean2(0.f);


// solve
bool started = false;

auto polyscope_callback = [&]() mutable {
GPUTimer timer;
timer.start();
float frame_time_left = frame_dt;
while (frame_time_left > 0.0) {
float dt0 = std::min(dt, frame_time_left);
frame_time_left -= dt0;

// applyExtForce
rx.for_each_vertex(DEVICE,
[dt0,
gravity,
invM = *invM,
v = *v,
new_x = *new_x,
x = *x] __device__(VertexHandle vh) {
if (invM(vh, 0) > 0.0) {
v(vh, 0) += gravity[0] * dt0;
v(vh, 1) += gravity[1] * dt0;
v(vh, 2) += gravity[2] * dt0;
}
new_x(vh, 0) = x(vh, 0) + v(vh, 0) * dt0;
new_x(vh, 1) = x(vh, 1) + v(vh, 1) * dt0;
new_x(vh, 2) = x(vh, 2) + v(vh, 2) * dt0;
});
if (ImGui::Button("Start Simulation") || started) {
started = true;

GPUTimer timer;
timer.start();
float frame_time_left = frame_dt;
while (frame_time_left > 0.0) {
float dt0 = std::min(dt, frame_time_left);
frame_time_left -= dt0;

// applyExtForce
rx.for_each_vertex(DEVICE,
[dt0,
gravity,
invM = *invM,
v = *v,
new_x = *new_x,
x = *x] __device__(VertexHandle vh) {
if (invM(vh, 0) > 0.0) {
v(vh, 0) += gravity[0] * dt0;
v(vh, 1) += gravity[1] * dt0;
v(vh, 2) += gravity[2] * dt0;
}
new_x(vh, 0) = x(vh, 0) + v(vh, 0) * dt0;
new_x(vh, 1) = x(vh, 1) + v(vh, 1) * dt0;
new_x(vh, 2) = x(vh, 2) + v(vh, 2) * dt0;
});

if (XPBD) {
la_s->reset(0.0, DEVICE);
la_b->reset(0.0, DEVICE);
}
if (XPBD) {
la_s->reset(0.0, DEVICE);
la_b->reset(0.0, DEVICE);
}

for (uint32_t iter = 0; iter < rest_iter; ++iter) {
// preSolve
dp->reset(0, DEVICE);

// solveStretch
solve_stretch<blockThreads>
<<<solve_stretch_lb.blocks,
solve_stretch_lb.num_threads,
solve_stretch_lb.smem_bytes_dyn>>>(
rx.get_context(),
*dp,
*la_s,
*invM,
*new_x,
*rest_len,
XPBD,
stretch_compliance,
stretch_relaxation,
dt0 * dt0);

// solveBending
solve_bending<blockThreads>
<<<solve_bending_lb.blocks,
solve_bending_lb.num_threads,
solve_bending_lb.smem_bytes_dyn>>>(
rx.get_context(),
*dp,
*la_b,
*invM,
*new_x,
XPBD,
bending_compliance,
bending_relaxation,
dt0 * dt0);

// postSolve
rx.for_each_vertex(
DEVICE,
[dp = *dp, new_x = *new_x] __device__(VertexHandle vh) {
new_x(vh, 0) += dp(vh, 0);
new_x(vh, 1) += dp(vh, 1);
new_x(vh, 2) += dp(vh, 2);
});
}

for (uint32_t iter = 0; iter < rest_iter; ++iter) {
// preSolve
dp->reset(0, DEVICE);

// solveStretch
solve_stretch<blockThreads>
<<<solve_stretch_lb.blocks,
solve_stretch_lb.num_threads,
solve_stretch_lb.smem_bytes_dyn>>>(rx.get_context(),
*dp,
*la_s,
*invM,
*new_x,
*rest_len,
XPBD,
stretch_compliance,
stretch_relaxation,
dt0 * dt0);

// solveBending
solve_bending<blockThreads>
<<<solve_bending_lb.blocks,
solve_bending_lb.num_threads,
solve_bending_lb.smem_bytes_dyn>>>(rx.get_context(),
*dp,
*la_b,
*invM,
*new_x,
XPBD,
bending_compliance,
bending_relaxation,
dt0 * dt0);

// postSolve
// update;
rx.for_each_vertex(
DEVICE,
[dp = *dp, new_x = *new_x] __device__(VertexHandle vh) {
new_x(vh, 0) += dp(vh, 0);
new_x(vh, 1) += dp(vh, 1);
new_x(vh, 2) += dp(vh, 2);
[dt0,
invM = *invM,
v = *v,
new_x = *new_x,
x = *x] __device__(VertexHandle vh) {
if (invM(vh, 0) <= 0.0) {
new_x(vh, 0) = x(vh, 0);
new_x(vh, 1) = x(vh, 1);
new_x(vh, 2) = x(vh, 2);
} else {
v(vh, 0) = (new_x(vh, 0) - x(vh, 0)) / dt0;
v(vh, 1) = (new_x(vh, 1) - x(vh, 1)) / dt0;
v(vh, 2) = (new_x(vh, 2) - x(vh, 2)) / dt0;

x(vh, 0) = new_x(vh, 0);
x(vh, 1) = new_x(vh, 1);
x(vh, 2) = new_x(vh, 2);
}
});
}

// update;
rx.for_each_vertex(
DEVICE,
[dt0, invM = *invM, v = *v, new_x = *new_x, x = *x] __device__(
VertexHandle vh) {
if (invM(vh, 0) <= 0.0) {
new_x(vh, 0) = x(vh, 0);
new_x(vh, 1) = x(vh, 1);
new_x(vh, 2) = x(vh, 2);
} else {
v(vh, 0) = (new_x(vh, 0) - x(vh, 0)) / dt0;
v(vh, 1) = (new_x(vh, 1) - x(vh, 1)) / dt0;
v(vh, 2) = (new_x(vh, 2) - x(vh, 2)) / dt0;

x(vh, 0) = new_x(vh, 0);
x(vh, 1) = new_x(vh, 1);
x(vh, 2) = new_x(vh, 2);
}
});
}

timer.stop();
RXMESH_INFO("Frame {}, time= {}(ms)", frame, timer.elapsed_millis());
timer.stop();
RXMESH_INFO(
"Frame {}, time= {}(ms)", frame, timer.elapsed_millis());
#if USE_POLYSCOPE
x->move(DEVICE, HOST);
rx.get_polyscope_mesh()->updateVertexPositions(*x);
x->move(DEVICE, HOST);
rx.get_polyscope_mesh()->updateVertexPositions(*x);
#endif
frame++;
if (test) {
if (frame == 99) {
rx.for_each_vertex(HOST, [&](VertexHandle vh) {
for (int i = 0; i < 3; ++i) {
mean += (*x)(vh, i);
mean2 += (*x)(vh, i) * (*x)(vh, i);
}
});
mean /= (3.f * rx.get_num_vertices());
mean2 /= (3.f * rx.get_num_vertices());
frame++;
if (test) {
if (frame == 99) {
rx.for_each_vertex(HOST, [&](VertexHandle vh) {
for (int i = 0; i < 3; ++i) {
mean += (*x)(vh, i);
mean2 += (*x)(vh, i) * (*x)(vh, i);
}
});
mean /= (3.f * rx.get_num_vertices());
mean2 /= (3.f * rx.get_num_vertices());
}
}
}
};
Expand Down

0 comments on commit 47f7d56

Please sign in to comment.